알고리즘/이론

[Algorithm] 세그먼트 트리(Segment Tree)

ssung.k 2019. 9. 27. 22:15

세그먼트 트리(Segment Tree)란?

배열에 부분 합을 구할 때 사용하는 개념입니다. 이 때 문제는 배열의 값이 지속적으로 바뀔 수 있기 때문에 매 순간 배열의 부분 길이 만큼, 즉 O(N) 만큼의 시간이 걸리기 때문에 이를 트리로 구현하여 O(logN) 의 시간으로 해결하는 방법입니다.

 

 

배열을 세그먼트 트리로

세그먼트 트리 를 사용하기 위해서는 주어진 배열을 이진 트리 구조로 만들어야 합니다. 이 때 트리를 구현하는 알고리즘은 다음과 같습니다.

  1. 부모노드의 값은 양 쪽 자식 노드 값의 합
  2. 배열의 요소들은 리프 노드에 위치

위 알고리즘을 통해 N = 10 일 때의 세그먼트 트리를 그리면 다음과 같습니다.

 

이 때 기존 데이터의 배열의 크기를 통해서 트리 배열의 최대 크기를 알 수 있습니다. 기존 데이터 배열의 크기를 N 이라 하면, 리프 노드의 개수가 N 이 되고, 트리의 높이 H 는 [ logN ] 이 되고, 배열의 크기는 2^(H+1) 이 됩니다.

 

트리는 단순히 1차원 배열과 인덱싱을 통해 구현할 수 있습니다. 트리의 각 노드 별 배열의 인덱스는 아래 그림과 같습니다.

 

이를 코드로 구현하면 다음과 같습니다.

#include <iostream>
#include <cmath>

using namespace std;

long long *tree;
long long A[10] = {1,2,3,4,5,6,7,8,9,10};

long long init(int index, int start, int end){
    if (start == end)
        tree[index] = A[start];
    else{
        int mid = (start+end)/2;
        tree[index] = init(index*2+1, start, mid) + init(index*2+2, mid+1, end);
    }
    return tree[index];
}

int main(int argc, const char * argv[]) {

    int N = 10;
    int h = ceil(log2(N));

    tree = new long long[1<<(h+1)];

    init(0,0,N-1);

    for (int i=0;i<1<<(h+1);i++){
        cout << tree[i] << "\n";
    }

    return 0;
}

위 그림과 약간의 차이점은 그림은 처음 인덱스를 1부터 시작한데 비해 구현할 때는 0부터 시작하였습니다. 현재 노드의 인덱스 index 에 대해 왼쪽 자식 노드의 인덱스는 index*2+1, 오른쪽 자식 노드의 인덱스는 index*2+2 로서 재귀적으로 구현하였습니다.

 

구간의 합 구하기

세그먼트 트리를 구현하였으니 이를 이용해서 본래의 목적, 구간의 합을 구해보도록 하겠습니다. 두 구간에 대해서 왼쪽 구간을left, 오른쪽 구간을 right 이라 할 때, 탐색 범위 [start, end] 와 합의 구간 [left, right] 의 관계는 다음과 같습니다.

  1. [left, right] 와 [start, end]가 전혀 겹치지 않는 경우

    탐색 범위 내에 구하는 범위가 존재하지 않습니다. 그렇다면 탐색 범위에 값들은 아무 의미없는 값이므로 0을 return 합니다.

  2. [start, end] 가 [left, right]에 속해 있는 경우

    탐색 범위 내에 값들이 전부 구하는 범위의 값들입니다. 하위 노드들을 탐색할 필요없이 이미 하위 노드들의 합을 저장하고 있는 tree[index] 를 반환합니다.

  3. [left, right] 가 [start, end]에 속해 있는 경우

  4. [left, right] 와 [start, end] 가 일부 겹치는 경우

    해당 경우에는 아직 결론을 내리기에 시기상조입니다. 재귀적으로 더 들어가서 어디까지의 값들이 필요한지에 대해 구해줍니다.

이를 코드로 구현하면 다음과 같습니다.

long long sum(int index, int start, int end, int left, int right){
    // 구간이 전혀 겹치지 않는 경우
    if (start > right || end < left)
        return 0;
    else if (left <= start && end <=right)
        return tree[index];
    else {
        int mid = (start+end) / 2;
        return sum(index*2+1, start, mid, left, right) + sum(index*2+2, mid+1, end, left, right);
    }
}

 

값 변경하기

마지막으로 배열의 값을 변경해보도록 하겠습니다. 이에 따라 세그먼트 트리의 값도 변경이 되야하죠.

void update(int changed_index, long long diff, int index, int start, int end){
    if (changed_index < start || changed_index > end)
        return;
    tree[index] += diff;
    
    if (start != end){
        int mid = (start+end) / 2;
        update(changed_index, diff, index*2+1, start, mid);
        update(changed_index, diff, index*2+2, mid+1, end);
    }
}

여기서 주의해야 할 점은 diff 는 바꿀 새로운 값이 아닙니다. diff = 새로 바꿀 값 - A[changed_index] (기존의 값) 으로서 새로 바꿀 값과 기존의 값의 차이입니다.

  • diff = 새로 바꿀 값 - A[changed_index] (기존의 값)

  • A[changed_index] = 새로 바꿀 값

  • 재귀적으로 양 쪽 자식노드로 나눠가며 start == end 가 될 때 까지, 즉 리프 노드가 될 때까지 탐색을 합니다. 탐색을 할 시에는 탐색 범위 안에 없다면 return 하고 탐색 범위 안에 있다면 변경된 노드의 증가값 diff 만큼 노드에 더해줍니다.