고난이도 자료구조 세그먼트 트리 구현하면서 이해하기 3편(값을 업데이트 하는 방법)

1. 배열의 일부 원소를 바꾸는 경우

 

배열의 index번째 수를 value로 바꾸고자 한다면 세그먼트 트리를 어떻게 구성해야할까

 

1-1) 당연히 해당 index번째 수를 포함하고 있는 구간의 합을 저장하고 있는 모든 노드의 값을 바꿔주면 된다

 

원래 index번째 수가 A[index]였다고 하자. 해당 구간의 합은 s+A[index]인데

 

A[index]가 value로 바뀌었다고 한다면? 해당 구간의 합은 s+value이다.

 

그러면 해당 구간의 합 s+A[index]는 s+value로 바꿀려면 어떻게 해야할까?

 

당연히 value-A[index]를 s+A[index]에 더해주면 된다

 

1-2) 노드가 저장하는 구간이 [start,end]라고 한다면, 당연히 index가 [start,end]에 포함되는 경우와

 

포함되지 않는 경우로 나눌 수 있다.

 

index가 [start,end]에 포함되지 않으면 당연히 그 아래 자식 노드들도 index를 포함하지 않으므로 재귀호출로 탐색할 필요는 없다

 

index가 [start,end]에 포함된다면 당연히 재귀호출로 계속 탐색하면서 해당 노드의 값을 바꿔준다

 

 

2. 구현 예시1

 

배열 A의 index번째 수를 value로 바꾼다면.. A[index] = value로 바꿔주고

 

tree_index에 저장된 합은 value-A[index]를 더해주면 된다는 것을 위에서 보였다

 

index가 [start,end]에 포함되지 않는다면, 탐색을 하지 않고 return하면 되는데

 

index가 [start,end]에 포함된다면, value-A[index]값을 더해주면서 값을 바꿔나간다

 

만약 리프노드가 아니라면, 왼쪽, 오른쪽으로 나누어서 계속 탐색을 수행해나간다

 

def update_tree(tree,tree_index,start,end,index,difference):
    
    #index가 [start,end]에 포함되지 않는 경우
    
    if index < start or index > end:
        
        return
    
    #index가 포함되는 구간이라면.. 해당 노드의 값을 바꿔줌
    tree[tree_index] = tree[tree_index] + difference
    
    #리프노드가 아닌 경우, 왼쪽,오른쪽 자식으로 탐색함
    
    mid = (start + end) // 2
    
    if start != end:
        
        update_tree(tree,2*tree_index,start,mid, index, difference)
        update_tree(tree,2*tree_index+1,mid+1,end,index,difference)

def update(A,tree,N,index,value):
    
    difference = value - A[index] #합을 변경시킬 값
    
    A[index] = value #index값이 value로 바뀜
    
    update_tree(tree,1,0,n-1,index,difference) #루트부터 탐색해서 값이 바뀐다

 

3. 구현 예시2

 

다른 방법으로 리프 노드를 찾을 때까지 재귀 호출을 계속 수행하고, 리프 노드를 찾으면 해당 노드의 값을 바꿔준다.

 

그리고 재귀 함수가 return될때마다 바뀌어 있는 자식 노드의 값을 이용해서 상위 노드의 값을 바꿔나가는 방법이 있다

 

def update(A,tree, tree_index, start, end, index, value):
    
    #index가 [start,end]에 포함되어 있지 않다
    
    if index < start or index > end:
        
        return
    
    #리프노드를 찾았다면... 값을 바꿔준다
    
    if start == end:
        
        A[index] = value
        tree[tree_index] = value
        return
    
    #리프노드가 아니라면, 왼쪽, 오른쪽으로 나누어서 탐색
    
    mid = (start + end) // 2
    
    update(A,tree,2*tree_index, start, mid, index, value)
    update(A,tree,2*tree_index+1,mid+1,end, index, value)
    
    #update 함수가 return되면 자식 노드의 값들은 변경되어 있다
    tree[tree_index] = tree[2*tree_index] + tree[2*tree_index+1]

 

4. 시간복잡도

 

트리의 높이 $h= \left \lceil log_{2} N \right \rceil $로 구해지므로, 값을 바꾸는 시간복잡도는 O(logN)이다.

 

logN개를 탐색하면서 값을 바꾼다는 뜻이다.

 

 

5. 연습문제

 

2042번: 구간 합 구하기 (acmicpc.net)

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

 

 

중간에 수가 바뀔 수 있을때, 주어진 구간의 합을 계속해서 구하는 문제

 

인덱스 같은거 주의하면서 외워질때까지 자주 반복하기

 

import math
from sys import stdin

#세그먼트 트리 생성

def create_segment(A,tree,tree_index,A_start,A_end):
    
    #리프 노드라면

    if A_start == A_end:
        
        tree[tree_index] = A[A_start]
    
    else: #리프노드가 아니라면
        
        A_mid = (A_start + A_end) // 2

        create_segment(A,tree,2*tree_index, A_start, A_mid)
        create_segment(A,tree,2*tree_index+1, A_mid+1, A_end)

        tree[tree_index] = tree[2*tree_index] + tree[2*tree_index+1]

#구간의 합을 구하는 함수
def query_sum(tree,tree_index,start,end,left,right):
    
    #[left,right]가 [start,end]를 벗어났다

    if left > end or right < start:
        
        return 0
    
    #[left,right]가 [start,end]를 포함한다

    if left <= start and end <= right:
        
        return tree[tree_index]
    
    #그 이외의 경우 왼쪽, 오른쪽 탐색

    mid = (start+end)//2

    left_sum = query_sum(tree,2*tree_index, start,mid,left,right)
    right_sum = query_sum(tree,2*tree_index+1,mid+1,end,left,right)

    return left_sum + right_sum

#배열의 값을 바꾸는 함수

def update_tree(tree,tree_index,start,end,index,value,difference):
    
    #index가 [start,end]에 포함되지 않는다
    
    if index < start or index > end:
        
        return
    
    #index가 [start,end]에 포함된다
    
    tree[tree_index] = tree[tree_index] + difference

    #리프노드가 아니라면 왼쪽, 오른쪽으로 나누어 탐색

    mid = (start + end) // 2

    if start != end: #이 부분 까먹었는데... 주의

        update_tree(tree,2*tree_index, start, mid, index, value, difference)
        update_tree(tree,2*tree_index+1,mid+1,end, index, value, difference)

def update(A,tree,N,index,value):
    
    difference = value - A[index]

    A[index] = value

    update_tree(tree,1,0,N-1,index,value,difference)

N,m,k = map(int,stdin.readline().split())

#A배열 만들기

A = []

for _ in range(N):
    
    a = int(stdin.readline())

    A.append(a)

#세그먼트 트리 초기화

n = 2**(math.ceil(math.log2(N))+1) -1

tree = [0]*(n+1)

create_segment(A,tree,1,0,N-1)

for _ in range(m+k):
    
    a,b,c = map(int,stdin.readline().split())

    if a == 1:
        
        update(A,tree,N,b-1,c)
    
    elif a == 2:
        
        print(query_sum(tree,1,0,N-1,b-1,c-1))

 

 

참조

 

41. 세그먼트 트리(Segment Tree) : 네이버 블로그 (naver.com)

 

41. 세그먼트 트리(Segment Tree)

이번 시간에 다룰 내용은 여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하는 ...

blog.naver.com

 

 

세그먼트 트리 (acmicpc.net)

 

세그먼트 트리

누적 합을 사용하면, 1번 연산의 시간 복잡도를 $O(1)$로 줄일 수 있습니다. 하지만, 2번 연산으로 수가 변경될 때마다 누적 합을 다시 구해야 하기 때문에, 2번 연산의 시간 복잡도는 $O(N)$입니다.

book.acmicpc.net

 

 

TAGS.

Comments