HwangHub

자바로 알고리즘 시작하기 6 - 세그먼트 트리 본문

CS-STUDY/자료구조 & 알고리즘

자바로 알고리즘 시작하기 6 - 세그먼트 트리

HwangJerry 2024. 2. 8. 18:22

기왕 배운 김에 정리합니다. 코테 빈출 유형은 아니라서, 제 블로그 아무도 안보긴 하겠지만... 혹시나 누군가 본다면 그냥 얘 이런것도 배웠구나 하고 넘어가시면 됩니다.

세그먼트 트리란

세그먼트(Segment)는 '부분'을 의미합니다. 세그먼트 트리는 말 그대로 각 노드에 전체 배열의 부분 정보를 저장시켜두고, 이를 완전 이진 트리로 구성하여 빠르게 해당 값을 탐색할 수 있도록 구현한 자료구조입니다. 각 노드에 담기는 데이터는 대표적으로 "구간합", 그리고 특정 구간에서의 "최대/최소값"이 있습니다.

 

이 데이터들을 완전 이진 트리로 관리한다는 게 장점인데, 이를 큰 범위의 배열로 관리하기 때문에 메모리를 많이 사용한다는 점을 감안해야 합니다. 공간을 내어주고 시간 활용도를 많이 높이는 전형적인 자료구조입니다. 세그먼트 트리는 특정 인덱스의 데이터보다 "구간" 상의 특정 데이터 값을 활용하여 연산하고, 값이 빈번하게 수정될 경우 구간 값이 어떻게 변하는지를 체크할 때 용이합니다.

 

구현 방법

배열을 활용하여 트리를 구성하려면 우선 트리의 크기부터 설정해야 합니다. 세그먼트 트리는 완전 이진 트리를 응용한 구조이므로 각 노드가 최대 2개의 자식 노드를 가지며, 트리의 높이가 h일 때 트리의 각 층은 다음과 같이 구성됩니다.

  • level 0 (root) : 1개의 노드
  • level 1 : 2개의 노드
  • level 2 : 4개의 노드
  • level 3 : 8개의 노드
  • ...
  • level h : 2^h개의 노드

따라서 이를 바탕으로 등비수열의 합 공식에 따라 2^(h + 1) - 1로 계산되지만 설정하되, 간단하게 활용하기 위해 트리 사이즈는 일반적으로 다음과 같이 설정합니다.

int h = (int) Math.ceil(Math.log(arr.length)/Math.log(2));
int treeSize = (int) Math.pow(2, h + 1);
tree = new long[treeSize];

 

 

배열 상에서 트리를 구현하기 위해, 왼쪽 노드와 오른쪽 노드의 인덱스는 다음과 같이 구성합니다.

  • left node idx = parent node idx * 2;
  • right node idx = parent node idx * 2 + 1;

만약 위 사이즈가 갑자기 기억이 안날 때에는 간단하게 arr.length * 4로 대체하여 표현할 수도 있습니다. 이를 바탕으로 구간에서 필요한 값을 저장할 수 있도록 초기화 메서드, 업데이트 메서드, 값 조회 메서드를 구현하여 사용하면 됩니다. internal node(중간 노드)에 우리가 원하는 구간에 대한 관리 값이 들어가고, 리프 노드에는 관리하려는 원본 배열의 데이터들이 들어갑니다. 이 리프 노드는 start == end인 노드를 찾아가면 데이터를 얻을 수 있습니다.

중요한 점

세그먼트 트리를 활용하기 위해서는 각 internal node에, 쉽게 말해서 root 노드가 어떤 데이터를 담고 있을지를 명확하게 정의해야 합니다.

  • 리프노드에 무엇을 저장할지
  • 트리의 중간값에 무엇을 저장할지
  • 중간 값에 어떻게 찾아갈건지

 

이를 좀 더 명확하게 느끼기 위해 문제를 풀어봅시다.

 

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

 

구간 합 구하기는 세그먼트 트리를 연습하는 가장 대표적인 문제입니다. 풀이 코드는 다음과 같습니다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    private static long[] tree; // 트리 배치
    private static long[] arr; // 문자 배치

    // 구간합 세그먼트 트리 초기화
    private static long init(int node, int start, int end) {
        if (start == end) { // leaf node인 경우
            return tree[node] = arr[start];
        }

        int mid = (start + end) >> 1;
        return tree[node] = init(node * 2, start, mid) + init(node * 2 + 1, mid + 1, end);
    }

    private static long sum(int node, int start, int end, int left, int right) {
        // 범위를 아예 벗어난 경우에는 0을 더함
        if (left > end || right < start) {
            return 0;
        }

        // 범위 내에 있는 internal node에 대하여는 그 값을 반환
        if (left <= start && end <= right) {
            return tree[node];
        }

		// 구간합 연산
        int mid = (start + end) >> 1;
        return sum(node * 2, start, mid, left, right) + sum(node * 2 + 1, mid + 1, end, left, right);
    }

    private static void update(int node, int start, int end, int idx, long diff) {
        // 범위 외 노드에 대하여는 pass
        if (idx < start || idx > end) {
            return;
        }

		// 지나가는 노드의 값을 업데이트 (리프노드까지)
        tree[node] += diff;

		// 리프 노드에 도착하면 업데이트 종료
        if (start == end) {
            return;
        }

		// 이분 탐색을 이용하여 다음 노드로 진행
        int mid = (start + end) >> 1;
        update(node * 2, start, mid, idx, diff);
        update(node * 2 + 1, mid + 1, end, idx, diff);
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken()); // update 횟수
        int k = Integer.parseInt(st.nextToken()); // 구간합 횟수

        arr = new long[n+1];
        int h = (int) Math.ceil(Math.log(arr.length) / Math.log(2));
        int treeSize = (int) Math.pow(2, h + 1);
        tree = new long[treeSize];
//        tree = new long[n << 2];


        for (int i = 1; i <= n; i++) {
            long num = Long.parseLong(br.readLine().trim());
            arr[i] = num;
        }
        init(1, 1, n);

        StringBuffer sb = new StringBuffer();

        for (int i = 0; i < m + k; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            long c = Long.parseLong(st.nextToken());

            if (a == 1) { // update
                long diff = c - arr[b];
                arr[b] = c; // 배열 값도 업데이트
                update(1, 1, n, b, diff);
            } else { // sum
                long res = sum(1, 1, n, b, (int) c);
                sb.append(res).append("\n");
            }
        }
        System.out.println(sb);
    }
}

 

유사하게 연습해볼 문제는 다음 구간곱 구하기 입니다.

 

 

 

11505번: 구간 곱 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 곱을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

 

풀이 코드는 다음과 같습니다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    static long[] tree;
    static long[] arr;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken()); // 숫자 입력량
        int m = Integer.parseInt(st.nextToken()); // 업데이트 획수
        int k = Integer.parseInt(st.nextToken()); // 구간 곱 횟수
        arr = new long[n];

        for (int i = 0; i < n; i++) {
            long num = Long.parseLong(br.readLine());
            arr[i] = num;
        }

        int h = (int) Math.ceil(Math.log(arr.length) / Math.log(2));
        int treeSize = (int) Math.pow(2, h + 1);
        tree = new long[treeSize];

        init(1, 0, n - 1);

        StringBuffer sb = new StringBuffer();

        for (int i = 0; i < m + k; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken()) - 1;
            long c = Long.parseLong(st.nextToken());

            if (a == 1) {
                long diff = mod(c - arr[b]);
                arr[b] = c;
                update(1, 0, n - 1, b, diff);

            } else {
                long res = mul(1, 0, n - 1, b, (int) c - 1);
                sb.append(res).append("\n");
            }
        }
        System.out.println(sb);
    }

    public static long init(int node, int start, int end) {
        if (start == end) {
            return tree[node] = arr[start];
        }

        int mid = (start + end) >> 1;
        return tree[node] = mod(init(node * 2, start, mid) * init(node * 2 + 1, mid + 1, end));
    }

    public static void update(int node, int start, int end, int idx, long diff) {
        if (idx < start || idx > end) {
            return;
        }

        if (start == end) {
            tree[node] += diff;
            return;
        }

        int mid = (start + end) >> 1;
        update(node * 2, start, mid, idx, diff);
        update(node * 2 + 1, mid + 1, end, idx, diff);

		// mul() 메서드에서 범위 내 return tree[node] 하기 위해선 업데이트 코드가 돌 때 internal node를 업데이트
        tree[node] = mod(tree[node * 2] * tree[node * 2 + 1]);
    }

    public static long mul(int node, int start, int end, int left, int right) {
        // 범위 나가는 곱은 무효 처리
        if (left > end || right < start) {
            return 1;
        }

        // 범위 내 값은 이미 있는 값으로 처리
        if (left <= start && end <= right) {
            return tree[node];
        }

        int mid = (start + end) >> 1;
        return mod(mul(node * 2, start, mid, left, right) * mul(node * 2 + 1, mid + 1, end, left, right));
    }

    public static long mod(long v) {
        return v % ((long) 1e9 + 7);
    }
}

 

만약 세그먼트 트리를 더 잘 이해하고 싶다면, 다음 문제를 풀어보면 됩니다. (이 문제는 제가 잘 풀게 되면 정답 코드를 추후 추가하겠습니다 🥲)

 

1168번: 요세푸스 문제 2

첫째 줄에 N과 K가 빈 칸을 사이에 두고 순서대로 주어진다. (1 ≤ K ≤ N ≤ 100,000)

www.acmicpc.net

 

Comments