[Algorithm] Optimal binary Search (최적 이진 탐색 트리)

2023. 5. 20. 22:59

오늘은 Optimal Binary Search Tree ( 최적 이진 탐색 트리 ) 에 대해 정리해보았다.

 

Binary Search Tree

Binary Search Tree는 Binary Tree의 성질을 만족하면서 Ordered 되어 있는 상태이다.

BST의 성질은 다음과 같다.

왼쪽 자식 노드는 부모 자식의 값보다 작다.
오른쪽 자식 노드는 부모 자식의 값보다 크다.

BST는 탐색시간을 줄여준다는 것에서 굉장히 큰 의미가 있다. 

 

Optimal BST (최적 이진 탐색 트리)

1. 기본 개념

오늘 할 Optimal BST는 BST 중 평균 탐색 시간이 가장 작은 Tree를 의미한다.

여기서 '평균 탐색시간'이란 무엇일까?

이때까지 우리는 노드를 탐색할 확률이 같다고 가정하고 Tree를 만들었다. (root 노드와 leaf 노드의 탐색 횟수가 같을 때)
그럼 탐색 확률이 다를 때는 어떨까?

 위의 Tree가 BST라고 생각해보자. 
각 노드를 탐색하는 횟수는 다음과 같다.

node A B C D
횟수 2 1 2 3

A를 찾으려면 B -> A로 이동해야하기 때문에 총 2번을 비교해야해서 2이다. 

 

case 1) 

node A B C D
P(탐색 확률) 0.2 0.5 0.1 0.2

A의 탐색시간 : 2 * 0.2 (탐색 횟수 * 확률)
B의 탐색시간 : 1 * 0.5
C의 탐색시간 : 2 * 0.1
D의 탐색시간 : 3 * 0.2

=> 전체 평균 탐색시간 = 0.4 + 0.5 + 0.2 + 0.6 = 1.7

 

case 2)

node A B C D
P(탐색 확률) 0.5 0.2 0.1 0.2

A의 탐색시간 : 2 * 0.5
B의 탐색시간 : 1 * 0.2
C의 탐색시간 : 2 * 0.1
D의 탐색시간 : 3 * 0.2

=> 전체 평균 탐색시간 = 1 + 0.2 + 0.2 + 0.6 = 2

 

case 1과 case2는 같은 TREE를 사용하지만,  탐색 확률에 따라 평균 탐색시간이 달라지는 모습을 볼 수 있다. 
그래서 case 2 보다 case 1 이 더 최적의 BST라고 할 수 있다.

Optimal Binary Search Tree는 각 노드로 만들 수 있는 모든 BST 경우의 수 중 가장 평균 탐색시간이 짧은 것을 의미한다.

 

2. 점화식 구하기

Optimal BST를 어떻게 구할 수 있을까?

Tree의 가장 기본 성질을 생각해보자. -> Tree의 subtree는 Tree이다.

Optimal BST도 동일한 원리를 가지고 있다. Optimal BST의 subtree도 Optimal BST이다. 이 성질을 이용하여 recursive하게 작은 문제들로 쪼갤 수 있어서 DP를 사용할 수 있다. 

n개의 노드가 있다고 하자.

  • k1 ~ kn : 각 노드에 대한 key
  • pi : 각 노드의 탐색 확률
  • ci : 노드 탐색을 위한 횟수 (root = 1)
  • A[i][j] : ki 에서 kj 까지의 optimal value

A[1][n] 을 구해보자.
1<=k<=n에 대하여, A[1][k-1] (왼쪽 subtree) + A[k+1][n] (오른쪽 subtree) 의 최솟값을 취하면 된다. 
이를 식을 적어보면 다음과 같다.
A[1][n] = min( (k = 1 ~ n) A[1][k-1] + A[k+1][n] ) + pm (m = i ~ j )

이를 일반화 시키면 아래와 같은 식이 나온다. 

  • i < j
    마지막에 pi ~ pj를 더하는 이유는 각 노드를 한번씩 거쳐가야 하기 때문이다.

  • i == j
    ci ( == cj) 가 1이기 때문에 pi의 값과 같아진다.
    A[i][i] = pi

  • i > j
    범위를 벗어나기 때문에 0으로 한다.

 

3. 예제 풀이

Optimal Binary Search는 A배열과 R배열을 완성시키는 것을 목적으로 한다.

A 배열은 위에서 이미 소개했기 때문에 넘어가고 R 배열만 말하자면, R[i][j]는 A[i][j] 선택할 때 minimum이었던 k 값을 의미한다. 예제 풀이를 차근차근 따라오면 쉽게 이해할 수 있다. 

내가 예제를 만드는 것은 귀찮으니 교과서에 있는 예제로 설명을 하겠다. 

4개의 노드 ( 사람 ) 이 있고, 각 사람을 찾을 확률이 위와 같을 때 A배열과 R배열을 완성시켜보자.

 


p = [3, 3, 1, 1]  (실제 p는 [ 3/8, 3/8, 1/8, 1/8 ] 인데 계산이 귀찮으니까 그냥 자연수로 생각하고 계산하겠다.)
c = [1, 2, 3, 4]

위의 순서로 채울 것이다. 

1. i == j일 때 ( 대각성분)
A[i][j] = pi이다. R[i][j] = i가 된다.

2. 첫번째 대각성분 채우기

이 식을 이해한 상태로 따라와야 수월할 것이다.

A[1][2] ( => i = 1, j = 2 )
(k = 1) A[1][0] + A[2][2] = 0 + 3 = 3
(k = 2) A[1][1] + A[3][2] = 3 + 0 = 3
=> mininum = 3

p1 + p2 = 3 + 3 = 6

따라서 A[1][2] = 3 + 6 = 9

R[1][2]는 minimum에 해당하는 k값인데 이번 경우는 minimum 값이 3에 해당하는 수가 k = 1, k = 2일 때가 나왔다. 여러개가 나온 경우 더 작은 것을 취한다. 

나머지 성분을 구하면 다음과 같다.

A[2][3] = min( A[2][1] + A[3][3],  A[2][2] + A[4][3] ) + p2 + p3 = 1 + 3 + 1 = 5
R[2][3] = 2 (k = 2 -> A[2][1] + A[3][3] = 1 -> 이게 최소 값이기 때문에 R값이 된다)

A[3][4] = min( A[3][2] + A[4][4], A[3][3] + A[5][4] ) + p3 + p4 = 1 + 1 + 1 = 3
R[3][4] = 3

 

3. 세번째 대각성분 채우기

A[1][3] = min( A[1][0] + A[2][3] , A[1][1] + A[3][3], A[1][2] + A[4][3] ) + p1 + p2 + p3 = 4 + 3 + 3 + 1 = 11
R[1][3] = 2 

A[2][4] = min( A[2][1] + A[3][4], A[2][2] + A[4][4], A[2][3] + A[5][4]) + p2 + p3 + p4 = 3 + 3 + 1 + 1 = 8
R[2][4] = 2

 

 

4. 네 번째 대각성분 채우기

A[1][4] = min( A[1][0] + A[2][4] , A[1][1] + A[3][4], A[1][2] + A[4][4], A[1][3] + A[4][5] ) + p1 + p2 + p3 + p4 = 6 + 3 + 3 + 1 + 1 = 14
R[1][4] = 2

 

4. 코드 구현

책에 표현된 슈도코드는 다음과 같다.

 

python 으로 구현한 코드는 다음과 같다. ( 위의 슈도 코드보다는 더 직관적인 코드 일 것이다 )

def optimal_bst(p, q, n):
	# 데이터 정의
    A = [[-1 for _ in range(n+2)] for _ in  range(n+2)] # n+1 by n+1 matrix
    R = [[-1 for _ in range(n+2)] for _ in  range(n+2)]
    
    # 데이터 초기화 및 기본 대각 성분 초기화
    for i in range(1, n+1):
    	A[i][i-1] = 0
        R[i][i-1] = 0
        A[i][i] = p[i]
        R[i][i] = i
   A[n+1][n] = 0
   R[n+1][n] = 0
   
   # 최적의 root 찾기
   for d in range(1, n): # d = 몇번째 대각성분인지
   		for i in range(1, n-d + 1):
        	j = i + d
            A[i][j] , R[i][j] = minimum(A, p , i, j)
   return A, R
   
   
def minimum(A, p, i, j):
	minValue = INF # A[i][j]
    minK = 0 # R[i][j]
	for k in range(i, j+1):
    	v = A[i][k-1] + A[k+1][j]
    	for x in range(i, j+1)
        	v += p[x]
        if minValue < v:
        	minValue = v
            minK = k
    
    return minValue, minK

 

BELATED ARTICLES

more