Algorithm

[Algo] 백준 2042 구간 합 구하기

조핑구 2024. 8. 1. 10:14

문제 바로가기 : https://www.acmicpc.net/problem/2042

메모리 : 103268KB

시간 : 616 ms

언어 : Java 11


코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;


public class Main {
    static class SegmentTree {
        long[] tree;
        int treeSize;

        /**
         * 세그먼트트리 만들기
         */
        public SegmentTree(int arrSize) {
						//size N*4해도됨
            int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2)); //트리의 높이
            this.treeSize = (int) Math.pow(2, h + 1); //트리의 사이즈
            tree = new long[treeSize]; //트리 사이즈로 배열 생성
        }

        /**
         * 트리에 값 넣기
         */
        //node: 현재노드/ start: 배열의 시작/ end: 배열의 끝/
        public long init(long[] arr, int node, int start, int end) {
            if (start == end) { //leaf 노드다!
                return tree[node] = arr[start];
            }
            //leaf 노드가 아니면 자식 노드 값의 합을 담는다.
            return tree[node] = init(arr, node * 2, start, (start + end) / 2) //왼쪽노드
                    + init(arr, node * 2 + 1, (start + end) / 2 + 1, end); //오른쪽노드
        }

        /**
         * 트리에 숫자 고치기
         */
        public void update(int node, int start, int end, int idx, long diff) {
            //만약 변경할 index값이 범위 바깥이면 확인 불필요
            if (idx < start || end < idx) return;
            tree[node] += diff;
            if (start != end) { //leaf 노드가 아니면 아래 자식들도 확인
                update(node * 2, start, (start + end) / 2, idx, diff);
                update(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);
            }
        }

        /**
         * 구간합 구하기
         */
				// left: 구간합을 구할 범위1 right: 범위2
        public long sum(int node, int start, int end, int left, int right) {
            if (left > end || right < start) { //범위를 벗어나면 return
                return 0;
            }
            if (left <= start && end <= right) { //노드가 해당 범위의 구간합을 가지고있다면 값을 리턴
                return tree[node];
            }
            return sum(node * 2, start, (start + end) / 2, left, right) +
                    sum(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        StringBuilder sb = new StringBuilder();
        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        int K = Integer.parseInt(st.nextToken());
        long[] arr = new long[N + 1];
        for (int i = 1; i <= N; i++) {
            arr[i] = Long.parseLong(br.readLine());
        }
        SegmentTree tree = new SegmentTree(N);
        tree.init(arr, 1, 1, N);

        for (int i = 0; i < M + K; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            if (a == 1) { //바꾸기
                int b = Integer.parseInt(st.nextToken());
                long c = Long.parseLong(st.nextToken());
                tree.update(1, 1, N, b, c - arr[b]);
                arr[b] = c; //기본배열도 바꿔주기
            } else { //구간합
                int b = Integer.parseInt(st.nextToken());
                int c = Integer.parseInt(st.nextToken());
                sb.append(tree.sum(1, 1, N, b, (int) c) + "\n");
            }
        }
        System.out.println(sb);
    }
}

 

 

정석적인 세그먼트 트리 문제. 

트리 그래프를 그려보면서 코드를 따라가보면 이해가 쉽다.