알고리즘/주요 개념

세그먼트 트리(Segment tree)

Chavo Kim 2020. 9. 1. 20:44

{1, 3, 100, 5, -2, 10, 4, 20, 1, 5}

 

다음 배열의 구간 합을 m개의 각 요청마다 구한다고 생각해보자

 

그냥 구하게 된다면

 

O(mn)의 시간복잡도를 가지게 될 것이다.

 

이를 줄이고 싶다면, 누적합(prefix sum) 기법을 사용하면 된다.

 

미리 index = 0부터 더한 값들을 배열에 저장한다.

 

pre_sum = {1, 4, 104, 109, 107, 117, 121, 141, 142, 147}

 

그리고 만약 i번째에서 j번째의 구간 합을 구하라는 쿼리가 들어온다면

 

prefix[j - 1] - prefix[i - 1]의 값을 바로 구하면 되기 때문에 O(1)의 시간복잡도 내에 각각의 쿼리를 처리할 수 있다.

 

이렇게 되면 처음에 누적합을 구할 때만 배열 전체를 순회하면 되기 때문에 O(N)의 시간 복잡도를 가지게 된다.

 

 

하지만 구간 내에 있는 숫자가 계속 바뀐다면?

 

 

구간 내에 숫자가 m번 바뀐다고 할 때 그때마다 누적합을 갱신해줘야하기 때문에, O(nm)의 시간이 걸리게 될 것이다.

 

그렇기 때문에 세그먼트 트리(Segment tree)가 필요하다.

 

 

 

세그먼트 트리란?

 

 

주어진 쿼리에 빠르게 대응하기 위해 만들어진 트리이다.

 

세그먼트 트리에서 수를 바꾸는 과정과 수를 더하는 과정은 각각 O(logn)에 수행 가능하다.

 

그림을 통해 알아보자. (출처 www.crocus.co.kr/648 )

 

세그먼트 트리

아래는 n이 12일 때의 세그먼트 트리이다.

 

노드는 n의 가장 가까운 2제곱수에서 2를 곱한 만큼 생기기 때문에 위의 예에서는 총 32개(16 * 2)의 노드가 생기게 된다.

 

그래서 보통 n의 4배만큼의 공간을 할당하게 된다.

 

세그먼트 트리는 full binary tree의 형태를 하기 때문에 루트 노드의 좌우 공간에 거의 균등하게 데이터가 들어오게 된다.

 

그렇기 때문에 세그먼트 트리를 구현하기 위해 배열을 사용하게 된다.

 

세그먼트 트리를 배열에 구현하는 규칙

현재 노드 번호를 node이라고 했을 때 왼쪽 자식 노드를 2*node, 오른쪽 자식 노드를 2*node+1에 위치한다.

 

 

세그먼트 트리 구현 방법

 

아래는 배열 arr[n]이 들어왔을 때 세그먼트 트리를 만드는 함수 init이다.

 

int init(int node, int start, int end)
{
	if(start == end)
		return tree[node] = arr[start]

	int mid = (start+end) / 2;

	return tree[node] = init(node * 2, start, mid) + init(node * 2, mid + 1, end);	
}

 만약 최하단 if (start == end)으로 내려가게 되면 tree[node]에 arr[start]를 넣고 끝내고

 

나머지에는 각각 왼쪽 노드의 합, 오른쪽 노드의 합을 더한 뒤 init 함수를 재귀 호출해준다.

 

 

다음은 배열이 바뀌었을 때 이를 update 해주는 함수이다,

 

void update(int node, int start, int end, int index, int diff)
{
	if(!(start <= index && index <= end))
		return;

	tree[node] += diff;

	if(start != end)
	{
		int mid = (start + end) / 2
		update(node * 2, start, mid, index, diff);
		update(node * 2 + 1, mid + 1, end, index, diff);
	}
}

diff은 새로 바뀐 수와 기존의 수의 차이다. 범위에 해당하지 않을 때는 아무것도 하지 않고 return하고,

 

범위에 들어온다면 해당 node에 diff을 더해준다.

 

 

 

아래는 세그먼트 트리에서 구간 합을 구하는 코드이다.

 

int sum(int node, int start, int end, int left, int right)
{
	if(left > end || right > start)
		return 0;

	if(left <= start && end <= right)
		return tree[node];

	int mid = (start + end) / 2
	return sum(node * 2, start, mid, left, right) + sum(node * 2 + 1, mid + 1, end, left, right);
}

 

이를 통해서 구간 합 또한 logN에 구해줄 수 있다.