느리게 갱신되는 세그먼트 트리 응용 -리프노드의 값을 출력하는 트리?-

1. 문제

 

16975번: 수열과 쿼리 21 (acmicpc.net)

 

16975번: 수열과 쿼리 21

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. 1 i j k: Ai, Ai+1, ..., Aj에 k를 더한다. 2 x: Ax 를 출력한다.

www.acmicpc.net

 

주어진 구간에 특정 값을 더해주고, 배열 A의 x번째 수를 출력하는 문제

 

 

2. 풀이

 

lazy propagation을 쓰지 않고 팬윅 트리를 이용하는 방법도 있다는데..

 

아직 팬윅 트리는 공부하지 않았으니까 넘어가고 정직하게 lazy propagation으로 풀어보자

 

먼저 목표는 A 배열의 값을 바꾸고, 구간의 합이나 곱이나 이런게 아니라 결국에 A[x]를 출력하는 것이다.

 

그러니까, 리프노드의 값을 출력하는게 결국 목표다.

 

그러니까 구간의 값을 저장하는 tree에 합을 저장하거나 그럴 필요는 없다

 

import math
from sys import stdin

def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        tree[tree_index] = A[A_start]

    else:
        
        A_mid = (A_start+A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)
        
        #아무 연산도 하지 않는다

 

왼쪽, 오른쪽으로 나누어 들어간 다음에, 굳이 tree[tree_index]에 구간합 등을 저장하지 않는다

 

lazy propagation으로 리프노드도 업데이트가 되니까, lazy 배열 업데이트 하는 것은 그대로

 

def update_lazy(tree,lazy,tree_index,start,end):
    
    if lazy[tree_index] != 0:
        
        tree[tree_index] += (end-start+1)*lazy[tree_index]

        if start != end:
            
            lazy[2*tree_index] += lazy[tree_index]
            lazy[2*tree_index+1] += lazy[tree_index]
        
        lazy[tree_index] = 0

 

 

구간의 값을 변경하는 함수도.. 그대로 쓰면 되는데 

 

[left,right]와 [start,end]가 서로 일부만 겹치는 경우에 굳이 구간합이나 구간곱 등을 저장하는 연산은 수행하지 않는다

 

리프노드를 변경시키는게 목표라서 [left,right]가 [start,end]를 완전히 포함하는 경우만 변경시키자

 

def update_range(tree,lazy,tree_index,start,end,left,right,value):
    
    update_lazy(tree,lazy,tree_index,start,end)
    
    if left > end or right < start:
        
        return
    
    if left <= start and end <= right:
        
        tree[tree_index] += (end-start+1)*value

        if start != end:
            
            lazy[2*tree_index] += value
            lazy[2*tree_index+1] += value
        
        return
    
    mid = (start+end)//2

    update_range(tree,lazy,2*tree_index,start,mid,left,right,value)
    update_range(tree,lazy,2*tree_index+1,mid+1,end,left,right,value)
    
    #아무 연산도 수행하지 않는다

 

 

그러면 리프노드의 값을 출력하는 함수는?

 

index를 받으면 A[index]를 출력하고 싶다.

 

index가 [start,end]를 벗어나면 return해서 더 이상 탐색하지 않고,

 

start == end로 리프노드에 도달했는데, index == end이면, 해당 index에 맞는 값이므로 tree[tree_index] 값을 return

 

그 이외의 경우에는 왼쪽 오른쪽으로 나누어 탐색하는데...

 

def query_print(tree,lazy,tree_index,start,end,index):
    
    update_lazy(tree,lazy,tree_index,start,end)
    
    if index > end or index < start:
        
        return
    
    if start == end and index == end:
        
        return tree[tree_index]
    
    mid = (start+end)//2

   query_print(tree,lazy,2*tree_index,start,mid,index)
   query_print(tree,lazy,2*tree_index+1,mid+1,end,index)

 

이러면 왜 안됨??

 

 그냥 전체 query_print()함수가 반드시 None을 return하기 때문임

 

 

하나의 쿼리 query_print(tree,lazy,1,0,N-1,index)를 수행하면...

 

왼쪽, 오른쪽으로 재귀가 들어가면서

 

query_print(tree,lazy,2,0,N-1//2,index)

query_print(tree,lazy,3,N-1//2+1,N-1,index)가 수행되겠지..

 

그러면 각각이 수행되면서 return1과 return2가 수행될거임...

 

그리고 나서  query_print(tree,lazy,1,0,N-1,index)의 return3이 수행되는거다.

 

그런데 return1과 return2는 어디 변수에 저장이 안되니까.. 그냥 날라가고..

 

return3은 현재 return문이 없으니 None이 return됨

 

어쨌든 재귀 호출을 수행하면서 둘중 하나를 반드시 return한다

 

None을 return하거나, index에 도달하면 A[index]를 return하거나

 

왼쪽으로 들어간 재귀경로나 오른쪽으로 들어간 재귀경로 둘중 하나는 None이고 다른것은 A[index]라는 말

 

def query_print(tree,lazy,tree_index,start,end,index):
    
    update_lazy(tree,lazy,tree_index,start,end)
    
    if index > end or index < start:
        
        return
    
    if start == end and index == end:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_value = query_print(tree,lazy,2*tree_index,start,mid,index)
    right_value = query_print(tree,lazy,2*tree_index+1,mid+1,end,index)

    if left_value == None:
        
        return right_value
    
    elif right_value == None:
        
        return left_value

 

 

이렇게 해야 전체 query_print(tree,lazy,1,0,N-1,index)의 return값이 있겠지?

 

그러므로.. 리프노드를 출력하는 세그먼트 트리는..

 

import math
from sys import stdin

def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        tree[tree_index] = A[A_start]

    else:
        
        A_mid = (A_start+A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)
        
        #아무 연산도 하지 않는다

def update_lazy(tree,lazy,tree_index,start,end):
    
    if lazy[tree_index] != 0:
        
        tree[tree_index] += (end-start+1)*lazy[tree_index]

        if start != end:
            
            lazy[2*tree_index] += lazy[tree_index]
            lazy[2*tree_index+1] += lazy[tree_index]
        
        lazy[tree_index] = 0


def query_print(tree,lazy,tree_index,start,end,index):
    
    update_lazy(tree,lazy,tree_index,start,end)
    
    if index > end or index < start:
        
        return
    
    #원하는 리프노드 index에 도달하면... 해당 값을 return
    if start == end and index == end:
        
        return tree[tree_index]
    
    mid = (start+end)//2
    
    #왼쪽으로 들어간 경로와 오른쪽으로 들어간 경로는 둘 중 하나
    #None이거나, A[index]

    left_value = query_print(tree,lazy,2*tree_index,start,mid,index)
    right_value = query_print(tree,lazy,2*tree_index+1,mid+1,end,index)

    if left_value == None:
        
        return right_value
    
    elif right_value == None:
        
        return left_value

def update_range(tree,lazy,tree_index,start,end,left,right,value):
    
    update_lazy(tree,lazy,tree_index,start,end)
    
    if left > end or right < start:
        
        return
    
    if left <= start and end <= right:
        
        tree[tree_index] += (end-start+1)*value

        if start != end:
            
            lazy[2*tree_index] += value
            lazy[2*tree_index+1] += value
        
        return
    
    mid = (start+end)//2

    update_range(tree,lazy,2*tree_index,start,mid,left,right,value)
    update_range(tree,lazy,2*tree_index+1,mid+1,end,left,right,value)
    
    #아무 연산도 하지 않는다
    

N = int(stdin.readline())

A = list(map(int,stdin.readline().split()))

m = int(stdin.readline())

n = 2**(math.ceil(math.log2(N))+1)-1

tree = [0]*(n+1)
lazy = [0]*(n+1)

create_segment(A,tree,1,0,N-1)

for _ in range(m):
    
    query = list(map(int,stdin.readline().split()))

    if query[0] == 1:
        
        update_range(tree,lazy,1,0,N-1,query[1]-1,query[2]-1,query[3])
    
    elif query[0] == 2:
        
        print(query_print(tree,lazy,1,0,N-1,query[1]-1))

 

TAGS.

Comments