반복문으로 구현하는 세그먼트 트리(iterative segment tree) 배우기

1. 반복문으로 구현하는 세그먼트 트리

 

세그먼트 트리 기본 버전은 재귀함수로 구현되어 있다

 

import math

def create_segment(A,tree,tree_index,start,end):
    
    if start == end:
        
        tree[tree_index] = A[start]
    
    else:
        
        mid = (start+end)//2

        create_segment(A,tree,2*tree_index,start,mid)
        create_segment(A,tree,2*tree_index+1,mid+1,end)

        tree[tree_index] = tree[2*tree_index] + tree[2*tree_index+1]

def query_sum(tree,tree_index,start,end,left,right):
    
    if left > end or right < start:
        
        return 0
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_sum = query_sum(tree,2*tree_index,start,mid,left,right)
    right_sum = query_sum(tree,2*tree_index+1,mid+1,end,left,right)

    return left_sum + right_sum

def update(A,tree,tree_index,start,end,index,value):
    
    if index < start or end < index:
        
        return
    
    if start == end:
        
        A[index] = value
        tree[tree_index] = value
        return
    
    mid = (start+end)//2

    update(A,tree,2*tree_index,start,mid,index,value)
    update(A,tree,2*tree_index+1,mid+1,end,index,value)

    tree[tree_index] = tree[2*tree_index] + tree[2*tree_index+1]


N = int(input())

A = list(map(int,input().split()))

n = 2**(math.ceil(math.log2(N))+1)-1

tree = [0]*(n+1)

create_segment(A,tree,1,0,N-1)

 

 

근데 재귀함수는 반복문보다 느리다는 사실은 잘 알려져있다..

 

문제를 푸는데 아무리 해도 맞는데.. 시간초과가 나니까

 

설마 재귀함수라서 느린건가.. 생각이 들어 반복문으로 구현할 수 있는지 찾아보았는데..

 

실제로 있더라고

 

Iterative Segment Tree (Range Minimum Query) - GeeksforGeeks

 

Iterative Segment Tree (Range Minimum Query) - GeeksforGeeks

A Computer Science portal for geeks. It contains well written, well thought and well explained computer science and programming articles, quizzes and practice/competitive programming/company interview Questions.

www.geeksforgeeks.org

 

 

대충 코드만 있긴하던데..

 

코드라도 외우고 필요하면 써먹어야지

 

일단 특징이 배열 A의 크기가 n일때, 세그먼트 트리의 크기는 2n으로 초기화하는데

 

이게 왜 그런지 대충 분석해본 결과

 

https://deepdata.tistory.com/553

 

고난이도 자료구조 세그먼트 트리 개념 이해하기 1편

1. 특정한 구간에 존재하는 모든 수의 합 어떤 수열이 주어질때, 만약 특정 구간 [a,b]에 존재하는 모든 수의 합을 구하라고 한다면 어떻게 구할 수 있을까? 가장 쉬운 방법은, 그냥 반복문으로 [a,b

deepdata.tistory.com

 

"배열의 크기가 N일때, 세그먼트 트리에서 필요한 노드의 개수가 2N-1이다"에 답이 있는듯?

 

#반복문으로 구현하는 세그먼트 트리
#배열 A의 크기가 n이면, 세그먼트 트리에서 필요한 노드 수는 2n-1개

def create_segment(A,tree,n):
    
    #세그먼트 트리의 리프 노드를 결정
    for i in range(n):
        
        tree[n+i] = A[i]
    
    #나머지 노드의 값을 아래에서부터 올라가면서 결정
    #부모 노드가 i이면, 왼쪽 자식은 2i, 오른쪽 자식은 2i+1
    for i in range(n-1,0,-1):
        
        tree[i] = tree[2*i] + tree[2*i+1]

 

 

근데... update랑 range query를 구하는 함수가 사실 이해가 안됨 .. ㅜㅜ

 

언젠가 이해할 수 있지 않을까.. 허허

 

일단 외우기라도 해

 

def query_sum(tree,left,right,n):
    
    left += n
    right += n

    """ Basically the left and right indices
    will move towards right and left respectively
    and with every each next higher level and
    compute the minimum at each height change
    the index to leaf node first """
    
    result = 0 #합의 초기값

    while left < right:
        
        if left & 1: #left index가 홀수라면...
            
            result += tree[left]
            left += 1
        
        if right & 1: #right index가 홀수라면...
            
            right -= 1
            result += tree[right]
        
        #move to the next higher level
        left //= 2
        right //= 2
    
    return result

def update(tree,index,value,n):
    
    #change the index to leaf node first
    index += n

    #update the value at the leaf node, at the exact index
    tree[index] = value

    while index > 1:
        
        #move up one level at a time in the segment tree
        index >>= 1

        #update the value in the node in the next higher level
        tree[index] = tree[2*index] + tree[2*index+1]

 

update는 보니까 대충 알겠네...

 

leaf 부분부터 업데이트 시키고, 부모 노드가 현재 노드의 1/2이니까, 올라가면서 업데이트 시키면 된다는 소리인것 같은데

 

range query는 ... 나중에 각잡고 분석해봐야 알듯..

 

 

2. 연습문제

 

25778번: House Prices Going Up (acmicpc.net)

 

25778번: House Prices Going Up

The first input line contains an integer, n (1 ≤ n ≤ 5 ×105), indicating the number of houses in VC. Each of the next n input lines contains an integer (between 1 and 109, inclusive) indicating the initial price for a house; first integer is the price

www.acmicpc.net

 

 

3. 풀이

 

재귀 버전으로 세그먼트 트리 구현하면 시간초과나더라

 

반복문 버전으로 구현해야 통과가능

 

그대로 하면 되는데..

 

중요한게 geeksforgeeks에도 나와있긴 하지만 range query의 값을 구할때..

 

[left_index, right_index]의 합을 구하고 싶다면...

 

query_sum 함수에 left_index, right_index를 넣는게 아니라

 

left_index, right_index+1을 넣어줘야한다.

 

왜 인지는 모름.... 이러면 안되는데.. 언젠가 분석해볼 기회가 있으려나

 

import math
from sys import stdin

def create_segment(A,tree,n):
    
    for i in range(n):
        
        tree[n + i] = A[i]
    
    for i in range(n-1,0,-1):
        
        tree[i] = tree[2*i] + tree[2*i+1]

def query_sum(tree,left,right,n):
    
    left += n
    right += n

    m = 0

    while (left < right):
        
        if (left & 1):
            
            m += tree[left]
            left += 1
        
        if (right & 1):
            
            right -= 1
            m += tree[right]
        
        left = left//2
        right = right//2
    
    return m

def update(tree,index,value,n):
    
    index += n

    tree[index] += value

    while (index > 1):
        
        index >>= 1

        tree[index] = tree[2*index] + tree[2*index+1]


N = int(stdin.readline())

A = []

for _ in range(N):
    
    h = int(stdin.readline())

    A.append(h)

n = 2*N - 1

tree = [0]*(n+1)

create_segment(A,tree,N)

t = int(stdin.readline())

for _ in range(t):
    
    q,a,b = stdin.readline().rstrip().split()

    a = int(a)
    b = int(b)

    if q[0] == 'U':

        update(tree,a-1,b,N)
    
    else:
        
        print(query_sum(tree,a-1,b,N))
TAGS.

Comments