세그먼트 트리 응용 - 2개의 쿼리를 동시에 수행할 수 있는 2차원 세그먼트 트리

1. 문제

 

18436번: 수열과 쿼리 37 (acmicpc.net)

 

18436번: 수열과 쿼리 37

길이가 N인 수열 A1, A2, ..., AN이 있다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. 1 i x: Ai를 x로 바꾼다. 2 l r: l ≤ i ≤ r에 속하는 모든 Ai중에서 짝수의 개수를 출력한다. 3 l r: l ≤ i ≤

www.acmicpc.net

 

구간에서 짝수의 개수를 출력하거나, 홀수의 개수를 출력하거나, 값을 바꾸는 쿼리를 수행하는 세그먼트 트리를 만드는 문제

 

 

2. 풀이

 

조금 응용해서, 세그먼트 트리를 2차원으로 구성해보자

 

[[0,0] for _ in range(n+1)] 형태로 tree[tree_index][0]은 구간의 짝수의 개수, tree[tree_index][1]은 구간의 홀수의 개수로

 

그러면, 부모 노드의 짝수의 개수는? 당연히 두 자식노드의 짝수의 개수 합이고, 홀수의 개수는 두 자식노드의 홀수의 개수 합이다.

 

import math
from sys import stdin

#[[0,0] for _ in range(n+1)] 형태로 트리 생성

def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        #0번에 짝수의 개수를 저장하고
        
        if A[A_start] % 2 == 0:
            
            tree[tree_index][0] += 1
        
        #1번에 홀수의 개수를 저장
        elif A[A_start] % 2 == 1:
            
            tree[tree_index][1] += 1
    
    else:
        
        A_mid = (A_start+A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)
        
        #부모 노드의 짝수 개수는, 자식 노드의 짝수 개수의 합
        #홀수 개수는, 자식 노드의 홀수 개수의 합

        tree[tree_index][0] = tree[2*tree_index][0] + tree[2*tree_index+1][0]
        tree[tree_index][1] = tree[2*tree_index][1] + tree[2*tree_index+1][1]

 

 

그러면, 결국 구간합 트리를 구하는거나 마찬가지다.

 

query함수도 number라는 인자를 받아서 2,i,j 형태는 짝수이고 3,i,j 형태는 홀수개수를 출력하라고 했으니까,

 

number-2에 대응하는 개수가 저장되어 있다(number=2이면, 0이므로 짝수개수, number=3이면 1이므로 홀수개수)

 

def query(tree,tree_index,start,end,left,right,number):
    
    if left > end or right < start:
        
        return 0
    
    if left <= start and end <= right:
        
        #number = 2이면, 짝수개수
        #number = 3이면, 홀수개수
        return tree[tree_index][number-2]
    
    mid = (start+end)//2

    left_num = query(tree,2*tree_index,start,mid,left,right,number)
    right_num = query(tree,2*tree_index+1,mid+1,end,left,right,number)

    return left_num+right_num

 

 

업데이트 함수도 조금 머리써서, 리프노드에 도달했을때, A[index]에 해당 노드의 값이 저장되어 있으니까, 먼저 이것이 짝수인지 홀수인지 판단한다

 

어차피 바꿀거니까 짝수면 tree[tree_index][0]에 1을 빼고 홀수면 tree[tree_index][1]에 1을 빼고 

 

그 후에 A[index] = value로 바꿔주고, 이 value가 짝수면 tree[tree_index][0]에 1을 더해주고, 홀수면 tree[tree_index][1]에 1을 더해주고 

 

def update(A,tree,tree_index,start,end,index,value):
    
    if index < start or index > end:
        
        return
    
    if start == end:
        
        #기존의 값 A[index]가 짝수이면
        if A[index] % 2 == 0:
            
            tree[tree_index][0] -= 1
        #홀수이면..
        elif A[index] % 2 == 1:
            
            tree[tree_index][1] -= 1
        
        #기존 값 제거후에 업데이트하고
        A[index] = value
        
        #업데이트한 값이 짝수인지 홀수인지에 따라
        if value % 2 == 0:
            
            tree[tree_index][0] += 1
        
        elif value % 2 == 1:

            tree[tree_index][1] += 1
        
        return
    
    mid = (start+end)//2

    update(A,tree,2*tree_index,start,mid,index,value)
    update(A,tree,2*tree_index+1,mid+1,end,index,value)

    tree[tree_index][0] = tree[2*tree_index][0]+tree[2*tree_index+1][0]
    tree[tree_index][1] = tree[2*tree_index][1]+tree[2*tree_index+1][1]

 

따라서, 구하고자하는 세그먼트 트리는...

 

import math
from sys import stdin

def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        if A[A_start] % 2 == 0:
            
            tree[tree_index][0] += 1
        
        elif A[A_start] % 2 == 1:
            
            tree[tree_index][1] += 1
    
    else:
        
        A_mid = (A_start+A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)

        tree[tree_index][0] = tree[2*tree_index][0] + tree[2*tree_index+1][0]
        tree[tree_index][1] = tree[2*tree_index][1] + tree[2*tree_index+1][1]

def query(tree,tree_index,start,end,left,right,number):
    
    if left > end or right < start:
        
        return 0
    
    if left <= start and end <= right:
        
        return tree[tree_index][number-2]
    
    mid = (start+end)//2

    left_num = query(tree,2*tree_index,start,mid,left,right,number)
    right_num = query(tree,2*tree_index+1,mid+1,end,left,right,number)

    return left_num+right_num

def update(A,tree,tree_index,start,end,index,value):
    
    if index < start or index > end:
        
        return
    
    if start == end:
        
        if A[index] % 2 == 0:
            
            tree[tree_index][0] -= 1
        
        elif A[index] % 2 == 1:
            
            tree[tree_index][1] -= 1
        
        A[index] = value

        if value % 2 == 0:
            
            tree[tree_index][0] += 1
        
        elif value % 2 == 1:

            tree[tree_index][1] += 1
        
        return
    
    mid = (start+end)//2

    update(A,tree,2*tree_index,start,mid,index,value)
    update(A,tree,2*tree_index+1,mid+1,end,index,value)

    tree[tree_index][0] = tree[2*tree_index][0]+tree[2*tree_index+1][0]
    tree[tree_index][1] = tree[2*tree_index][1]+tree[2*tree_index+1][1]


N = int(stdin.readline())

A = list(map(int,stdin.readline().split()))

m = int(stdin.readline())

n = 2**(math.ceil(math.log2(N))+1)-1

tree = [[0,0] for _ in range(n+1)]

create_segment(A,tree,1,0,N-1)

for _ in range(m):
    
    a,b,c = map(int,stdin.readline().split())

    if a == 1:
        
        update(A,tree,1,0,N-1,b-1,c)
    
    else:
        
        print(query(tree,1,0,N-1,b-1,c-1,a))
TAGS.

Comments