세그먼트 트리 응용 - XOR 연산을 수행하는 트리

1. 문제

 

14245번: XOR (acmicpc.net)

 

14245번: XOR

첫 번째 줄에 수열의 크기 n (0 < n ≤ 500,000)이 주어진다. 두 번째 줄에 수열의 원소가 0번부터 n - 1번까지 차례대로 주어진다. 수열의 원소는 100,000보다 크지 않은 음이 아닌 정수이다. 세 번째 줄

www.acmicpc.net

 

구간에 어떤 수를 XOR하는 쿼리와 어떤 index에 대응하는 원소를 출력하는 쿼리를 수행하는 문제

 

 

2. 풀이

 

구간에 수를 연산하니까 느리게 갱신되는 세그먼트 트리가 필요할 것 같고

 

index 하나 원소만 출력하는 트리는 이미 배웠으니까

 

def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        tree[tree_index] = A[A_start]

    else:
        
        A_mid = (A_start+A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)
        
        #아무 연산도 수행하지 않는다

 

원소 하나를 출력하는 세그먼트 트리는 리프 노드에 원소를 저장하면 되니까... 재귀호출 이후에 연산은 수행하지 않아

 

lazy를 업데이트 하는 함수는 조금 다른데

 

현재 노드에서 lazy값이 0이 아니면, 두 자식 노드에 lazy값을 전달한다

 

구간 합을 구하는 트리에서는

 

lazy[2*tree_index] += lazy[tree_index]
lazy[2*tree_index+1] += lazy[tree_index]

 

이렇게 lazy를 누적합 시켜왔는데 당연히 구간에 xor연산을 하는 트리에서는 누적합을 시키는게 아니라

 

누적 xor연산을 수행해야겠지

 

def update_lazy(tree,lazy,tree_index,start,end):
    
    if lazy[tree_index] != 0:
        
        #리프노드가 아니라면, 두 자식 노드에 xor연산을 하면서 lazy값을 누적 전달
        if start != end:
            
            lazy[2*tree_index] ^= lazy[tree_index]
            lazy[2*tree_index+1] ^= lazy[tree_index]
        
        #리프노드라면, 해당 노드의 값에 lazy값을 xor연산해서 갱신
        elif start == end:
            
            tree[tree_index] ^= lazy[tree_index]
        
        lazy[tree_index] = 0

 

그리고 해당 노드인 tree_index에서 자식으로 lazy값을 전달하거나 리프노드에서 lazy값을 갱신시켰다면

 

lazy를 썼으니까 lazy[tree_index] = 0으로 바꿔줘야함

 

원소를 출력하는 함수는 이전에 배운거랑 똑같다

 

def query_print(tree,lazy,tree_index,start,end,index):
    
    #lazy배열 업데이트
    update_lazy(tree,lazy,tree_index,start,end)

    if index > end or index < start:
        
        return
    
    if start == end and index == end:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_value = query_print(tree,lazy,2*tree_index,start,mid,index)
    right_value = query_print(tree,lazy,2*tree_index+1,mid+1,end,index)
    
    #둘중 하나는 None이고, 하나는 리프노드값
    if left_value == None:
        
        return right_value
    
    elif right_value == None:
        
        return left_value

 

 

구간을 업데이트하는 함수도 약간 다르다

 

리프노드가 아니라면, 두 자식노드의 lazy배열에 업데이트 value를 xor연산으로 넘겨주고

 

리프노드에 도달했다면, 해당 tree_index 노드에 value를 xor연산해서 계산

 

def update_range(tree,lazy,tree_index,start,end,left,right,value):
    
    update_lazy(tree,lazy,tree_index,start,end)

    if left > end or right < start:
        
        return
    
    if left <= start and end <= right:
        
        #리프노드가 아니라면, 두 자식노드의 lazy배열에 value값을 전달
        if start != end:
            
            lazy[2*tree_index] ^= value
            lazy[2*tree_index+1] ^= value
        
        #리프노드에 도달했다면 해당 노드에 value를 연산
        elif start == end:
            
            tree[tree_index] ^= value

        return
    
    mid = (start+end)//2

    update_range(tree,lazy,2*tree_index,start,mid,left,right,value)
    update_range(tree,lazy,2*tree_index+1,mid+1,end,left,right,value)

 

따라서 최종 세그먼트 트리는...

 

이 문제는 제시된 쿼리의 인덱스가 0번부터라서 -1을 안해도 되는.. 그런것도 잘 확인해야지

 

 

import math
from sys import stdin

def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        tree[tree_index] = A[A_start]

    else:
        
        A_mid = (A_start+A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)
        
        #아무 연산도 수행하지 않는다

def update_lazy(tree,lazy,tree_index,start,end):
    
    if lazy[tree_index] != 0:

        if start != end:
            
            lazy[2*tree_index] ^= lazy[tree_index]
            lazy[2*tree_index+1] ^= lazy[tree_index]
        
        elif start == end:
            
            tree[tree_index] ^= lazy[tree_index]
        
        lazy[tree_index] = 0

def query_print(tree,lazy,tree_index,start,end,index):
    
    update_lazy(tree,lazy,tree_index,start,end)

    if index > end or index < start:
        
        return
    
    if start == end and index == end:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_value = query_print(tree,lazy,2*tree_index,start,mid,index)
    right_value = query_print(tree,lazy,2*tree_index+1,mid+1,end,index)

    if left_value == None:
        
        return right_value
    
    elif right_value == None:
        
        return left_value

def update_range(tree,lazy,tree_index,start,end,left,right,value):
    
    update_lazy(tree,lazy,tree_index,start,end)

    if left > end or right < start:
        
        return
    
    if left <= start and end <= right:
        
        if start != end:
            
            lazy[2*tree_index] ^= value
            lazy[2*tree_index+1] ^= value
        
        elif start == end:
            
            tree[tree_index] ^= value

        return
    
    mid = (start+end)//2

    update_range(tree,lazy,2*tree_index,start,mid,left,right,value)
    update_range(tree,lazy,2*tree_index+1,mid+1,end,left,right,value)


N = int(stdin.readline())

A = list(map(int,stdin.readline().split()))

m = int(stdin.readline())

n = 2**(math.ceil(math.log2(N))+1)-1

tree = [0]*(n+1)
lazy = [0]*(n+1)

create_segment(A,tree,1,0,N-1)

for _ in range(m):
    
    q = list(map(int,stdin.readline().split()))

    if q[0] == 1:
        
        update_range(tree,lazy,1,0,N-1,q[1],q[2],q[3])
    
    elif q[0] == 2:
        
        print(query_print(tree,lazy,1,0,N-1,q[1]))
TAGS.

Comments