전공/알고리즘

백준 2042(구간 합 구하기)

xkdlaldfjtnl 2020. 6. 6. 17:40

이 문제는 전형적인 세그먼트 트리 문제이다. 

 

세그먼트 트리는 어떤 연속된 구간의 합 차 등 어떤 계산의 반복을 미리 쪼개어 놔서 시간을 월등히 줄여주는 자료구조이다.

특징은 완전 이진 트리로 구성되어 배열 자료구조를 사용하며 맨 위 노드의 index를 1로 잡아서 좌측 노드는 2*index, 우측 노드는 2*index+1이다.

 

위 특징을 이용해서 구현을 진행한다.

 

단순 계산으로는 O(쿼리의 횟수 * 배열의 크기) O(mn)의 시간복잡도가

 

세그먼트 트리로는 O(mlogn)로 줄어든다.

 

 

코드의 구성은

 

세그먼트 트리 구현,

쿼리 함수 구현

 

이렇게 나뉜다고 볼 수 있다.

 

세그먼트 트리의 구현은 n의 값이 2의 거듭제곱꼴이 아닌 어떤 자연수이므로,

top - down 형식의 재귀적 방법을 사용한다. 

 

앞선 문제에서 언급했듯이 재귀적 구성은 

leaf 노드까지 도달한 뒤 다시 재귀적 특성을 이용해서 올라와서 나머지 노드들을 채우는 꼴로 진행된다. 

 

ll init(int node, int left, int right) {
	if (left == right) return tree[node] = arr[left]; // leaf 노드의 값 

	int mid = (left + right) / 2;
	ll m1 = init(node * 2, left, mid); // 우측 자식
	ll m2 = init(node * 2 + 1, mid + 1, right); // 좌측 자식
	return tree[node] = m1+m2; // 우측과 좌측 자식의 합 (문제 조건)
}

 쿼리 함수는 두가지로 int query와 void change 로 두 함수를 구성하였다.

 

int query 함수

ll query(int node, int start, int end, int left, int right) {
	if (left <= start && right >= end) return tree[node]; // 찾고자 하는 구간과 노드의 구간 비교
	if (left > end || right < start) return 0; // 구간을 벗어난 경우
	int mid = (start + end) / 2;
	ll m1 = query(node * 2, start, mid, left, right); // 좌측 자식
	ll m2 = query(node * 2 + 1, mid + 1, end, left, right); // 우측 자식
	return m1 + m2; 
} 

void change 함수

 

void change(int node, int start, int end, int index, int after) {
	if (start == end && start == index) { //찾는 index인 경우
		tmp = after-tree[node]; //차이를 저장 
		tree[node] = after;
		return;
	}
	int mid = (start + end) / 2; 
	if (index > mid) { 
		change(node * 2 + 1, mid + 1, end, index, after);
	}
	else {
		change(node * 2, start, mid, index, after);
	}
	tree[node] += tmp; //해당 index값이 쓰인 모든 부모노드의 값 + 차이 
}
#include<iostream>
#include<cmath>
#include<vector>
using namespace std;

typedef long long ll;

vector<ll> arr;
vector<ll> tree;
int N, M, K, a, b, c;
ll tmp;

ll init(int node, int left, int right) {
	if (left == right) return tree[node] = arr[left];

	int mid = (left + right) / 2;
	ll m1 = init(node * 2, left, mid);
	ll m2 = init(node * 2 + 1, mid + 1, right);
	return tree[node] = m1+m2;
}

ll query(int node, int start, int end, int left, int right) {
	if (left <= start && right >= end) return tree[node];
	if (left > end || right < start) return 0;
	int mid = (start + end) / 2;
	ll m1 = query(node * 2, start, mid, left, right);
	ll m2 = query(node * 2 + 1, mid + 1, end, left, right);
	return m1 + m2;
}

void change(int node, int start, int end, int index, int after) {
	if (start == end && start == index) {
		tmp = after-tree[node];
		tree[node] = after;
		return;
	}
	int mid = (start + end) / 2;
	if (index > mid) {
		change(node * 2 + 1, mid + 1, end, index, after);
	}
	else {
		change(node * 2, start, mid, index, after);
	}
	tree[node] += tmp;
}

int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);

	cin >> N >> M >> K;
	arr.resize(N);
	int h = (int)ceil(log2(N));
	int tree_size = (1 << (h + 1));
	tree.resize(tree_size);
	for (int i = 0; i < N; i++)
		cin >> arr[i];
	init(1, 0, N - 1);
	for (int i = 0; i < M + K; i++) {
		tmp = 0;
		cin >> a >> b >> c;
		if (a == 1) {
			change(1, 0, N - 1, b-1, c);
		}
		else if (a == 2) {
			cout << query(1, 0, N - 1, b-1, c-1) << "\n";
		}
	}
}

전형적인 세그먼트 구현 문제이고 개념도 쉽게 터득할 수 있는 것 같다. 재귀적 성질에 대한 이해만 제대로 된다면.