세그먼트 트리 기본문제로 연습하며 재활(반복문, 재귀 연습)

1. 문제1

 

11143번: Beads (acmicpc.net)

 

11143번: Beads

The first line of the input consists of a single number T, the number of games played. Each game start with a line describing B, P and Q, the number of boxes, put requests and query requests, respectively. Then follows P + Q lines with either P i a, saying

www.acmicpc.net

 

2. 풀이

 

문제를 요약하자면 구간의 합을 구하는 세그먼트 트리

 

create_segment라는 함수는 만들필요가 없다

 

P라는 query가 index번째 박스에 a개의 구슬을 넣어주는 query라서, update 함수를 수행하면서, tree가 생성된다

 

실수하기 쉬운 부분이 update 함수에서.. A배열과 tree배열에서 index번째 값에 누적합을 하지 않고.. 단순히 값만 변경하는 경우가 있는데

 

    if start == end:
        
        A[index] = value
        tree[tree_index] = value
        return

 

문제를 보면.. index번째 박스에 구슬을 계속 놓아주는 의미라서, 누적합으로 

 

    if start == end:
        
        A[index] += value
        tree[tree_index] += value
        return

 

기본형태인 재귀함수 버전...

 

import math
from sys import stdin

def update(A,tree,tree_index,start,end,index,value):
    
    if index < start or end < index:
        
        return
    
    if start == end:
        
        A[index] += value
        tree[tree_index] += value
        return
    
    mid = (start + end)//2
    
    update(A,tree,2*tree_index,start,mid,index,value)
    update(A,tree,2*tree_index+1,mid+1,end,index,value)

    tree[tree_index] = tree[2*tree_index] + tree[2*tree_index+1]

def query_sum(tree,tree_index,start,end,left,right):
    
    if right < start or end < left:
        
        return 0
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_sum = query_sum(tree,2*tree_index,start,mid,left,right)
    right_sum = query_sum(tree,2*tree_index+1,mid+1,end,left,right)

    return left_sum + right_sum
    

T = int(stdin.readline())

for _ in range(T):
    
    b,p,q = map(int,stdin.readline().split())

    A = [0]*(b)

    n = 2**(math.ceil(math.log2(b))+1)-1

    tree = [0]*(n+1)

    for _ in range(p+q):
        
        q,x,y = stdin.readline().rstrip().split()

        x = int(x)
        y = int(y)

        if q == 'P':
            
            update(A,tree,1,0,b-1,x-1,y)
        
        else:
            
            print(query_sum(tree,1,0,b-1,x-1,y-1))

 

다음은 최근에 배운 반복문 버전 연습..

 

중요한 점은 역시 update에 tree[index] = value로 값만 바꾸지말고.. 누적합해주는거

 

특히 반복문 버전에서는 구간 쿼리 구할때, [left,right]를 구한다면.. [left,right+1]을 구해줘야한다는 점

 

from sys import stdin

def update(tree,index,value,b):
    
    index += b

    tree[index] += value

    while index > 1:
        
        index >>= 1

        tree[index] = tree[2*index] + tree[2*index+1]

def query(tree,b,left,right):
    
    left += b
    right += b

    m = 0

    while left < right:
        
        if left & 1:
            
            m += tree[left]
            left += 1
        
        if right & 1:
            
            right -= 1
            m += tree[right]
        
        left //= 2
        right //= 2
    
    return m


T = int(stdin.readline())

for _ in range(T):
    
    b,p,q = map(int,stdin.readline().split())

    tree = [0]*(2*b)

    for _ in range(p+q):
        
        x,y,z = stdin.readline().rstrip().split()

        y = int(y)
        z = int(z)

        if x == 'P':
            
            update(tree,y-1,z,b)
        
        else:
            
            print(query(tree,b,y-1,z))

 

 

3. 문제2

 

12837번: 가계부 (Hard) (acmicpc.net)

 

12837번: 가계부 (Hard)

살아있는 화석이라고 불리는 월곡이는 돈에 찌들려 살아가고 있다. 그에게 있어 수입과 지출을 관리하는 것은 굉장히 중요한 문제이다. 스마트폰에 가계부 어플리케이션을 설치해서 사용하려

www.acmicpc.net

 

 

4. 풀이

 

역시 구간의 합을 구하는 세그먼트 트리

 

위 문제랑 사실상 똑같다

 

create segment 함수를 만들 필요 없고, update함수에서 값을 "추가"하는 방식

 

가계부에 수입/지출을 "추가"한다는건... 가계부에서 수입/지출을 바꾼다는 의미가 아니고, 해당 수입/지출을 누적합시키는것

 

대충 읽으면 "생후 p일에 추가한다", "생후 p일부터 q일에 변화량을 출력한다"이게 뭔소린가 하는데..

 

p가 1이상 N이하이고.. "살아온 날이 N일"

 

생후라는건 "0일부터"라는 의미니까...

 

쉽게 생각하면 배열 A의 크기가 N이고.. p는 배열 A의 index

 

다음은 재귀 버전

 

이번엔 이전 풀이와는 다르게 A배열을 그냥 없애버렸다..

 

애초에 필요가 없으니까

 

import math
from sys import stdin

def update(tree,tree_index,start,end,index,value):
    
    if index < start or index > end:
        
        return
    
    if start == end:
        
        tree[tree_index] += value
        return
    
    mid = (start+end)//2

    update(tree,2*tree_index,start,mid,index,value)
    update(tree,2*tree_index+1,mid+1,end,index,value)

    tree[tree_index] = tree[2*tree_index] + tree[2*tree_index+1]

def query(tree,tree_index,start,end,left,right):
    
    if left > end or right < start:
        
        return 0
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_sum = query(tree,2*tree_index,start,mid,left,right)
    right_sum = query(tree,2*tree_index+1,mid+1,end,left,right)

    return left_sum + right_sum

N,q = map(int,stdin.readline().split())

n = 2**(math.ceil(math.log2(N))+1)-1

tree = [0]*(n+1)

for _ in range(q):
    
    a,b,c = map(int,stdin.readline().split())

    if a == 1:
        
        update(tree,1,0,N-1,b-1,c)
    
    else:
        
        print(query(tree,1,0,N-1,b-1,c-1))

 

 

다음은 반복문 버전

 

역시 구간 쿼리 구할때 [left,right]를 구하고 싶으면 [left,right+1]을 넣어줘야한다는 점 기억해야

 

from sys import stdin

def update(tree,index,value,n):
    
    index += n

    tree[index] += value

    while index > 1:
        
        index >>= 1

        tree[index] = tree[2*index] + tree[2*index+1]

def query(tree,n,left,right):
    
    left += n
    right += n

    m = 0

    while left < right:
        
        if left & 1:
            
            m += tree[left]
            left += 1
        
        if right & 1:
            
            right -= 1
            m += tree[right]
        
        left //= 2
        right //= 2
    
    return m

n,q = map(int,stdin.readline().split())

tree = [0]*(2*n)

for _ in range(q):
    
    a,b,c = map(int,stdin.readline().split())

    if a == 1:
        
        update(tree,b-1,c,n)
    
    else:
        
        print(query(tree,n,b-1,c))

 

 

5. 문제3

 

6213번: Balanced Lineup (acmicpc.net)

 

6213번: Balanced Lineup

For the daily milking, Farmer John's N cows (1 <= N <= 50,000) always line up in the same order. One day Farmer John decides to organize a game of Ultimate Frisbee with some of the cows. To keep things simple, he will take a contiguous range of cows from t

www.acmicpc.net

 

6. 풀이

 

조금 난이도가 높다면.. 최댓값과 최솟값을 모두 관리하는 세그먼트 트리를 만들어야한다.

 

재귀 버전은 최댓값을 구하는 세그먼트 트리, 최솟값을 구하는 세그먼트 트리 2개를 만들었다

 

1개로도 가능하긴한데..

 

이 문제는 update 함수가 필요없다.. 업데이트 한다는 말이 없으니까

 

import math
from sys import stdin

def max_create_segment(A,max_tree,tree_index,start,end):
    
    if start == end:
        
        max_tree[tree_index] = A[start]
    
    else:
        
        mid = (start+end)//2

        max_create_segment(A,max_tree,2*tree_index,start,mid)
        max_create_segment(A,max_tree,2*tree_index+1,mid+1,end)

        max_tree[tree_index] = max(max_tree[2*tree_index],max_tree[2*tree_index+1])

def min_create_segment(A,min_tree,tree_index,start,end):
    
    if start == end:
        
        min_tree[tree_index] = A[start]
    
    else:
        
        mid = (start+end)//2

        min_create_segment(A,min_tree,2*tree_index,start,mid)
        min_create_segment(A,min_tree,2*tree_index+1,mid+1,end)

        min_tree[tree_index] = min(min_tree[2*tree_index],min_tree[2*tree_index+1])


def query(tree,tree_index,start,end,left,right,ind):
    
    if left > end or right < start:
        
        return
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start + end)//2

    left_value = query(tree,2*tree_index,start,mid,left,right,ind)
    right_value = query(tree,2*tree_index+1,mid+1,end,left,right,ind)

    if ind == 0:
        
        if left_value == None:
            
            return right_value
        
        elif right_value == None:
            
            return left_value
        
        else:

            return min(left_value,right_value)
    
    else:
        
        if left_value == None:
            
            return right_value
        
        elif right_value == None:
            
            return left_value
        
        else:

            return max(left_value,right_value)



        
N,q = map(int,stdin.readline().split())

A = []

for _ in range(N):
    
    a = int(stdin.readline())
    A.append(a)

n = 2**(math.ceil(math.log2(N))+1)-1

max_tree = [0]*(n+1)
min_tree = [0]*(n+1)

max_create_segment(A,max_tree,1,0,N-1)
min_create_segment(A,min_tree,1,0,N-1)

for _ in range(q):
    
    a,b = map(int,stdin.readline().split())

    print(query(max_tree,1,0,N-1,a-1,b-1,1) - query(min_tree,1,0,N-1,a-1,b-1,0))

 

다음은 반복문 버전

 

이번엔 세그먼트 트리 1개로 최댓값, 최솟값을 모두 관리했다

 

세그먼트 트리 한 원소가 [최솟값, 최댓값]을 가지도록 만든다면.. 가능하겠지

 

반복문 버전에서 주의할 점은 3번째 말하는데 [left,right] 구간 쿼리 구할때, [left,right+1]을 넣어줘야한다는 점

 

from sys import stdin

def create_segment(A,tree,n):
    
    for i in range(n):
        
        tree[n+i][0] = A[i]
        tree[n+i][1] = A[i]
    
    for i in range(n-1,0,-1):
        
        tree[i][0] = min(tree[2*i][0],tree[2*i+1][0])
        tree[i][1] = max(tree[2*i][1],tree[2*i+1][1])

def query(tree,n,left,right):
    
    left += n
    right += n

    m1 = 1000001
    m2 = 0

    while left < right:
        
        if left & 1:
            
            m1 = min(m1,tree[left][0])
            m2 = max(m2,tree[left][1])
            left += 1
        
        if right & 1:
            
            right -= 1
            m1 = min(m1,tree[right][0])
            m2 = max(m2,tree[right][1])
        
        left //= 2
        right //= 2
    
    return m2-m1

n,q = map(int,stdin.readline().split())

A = []

for _ in range(n):
    
    a = int(stdin.readline())
    A.append(a)

tree = [[0]*2 for _ in range(2*n)]

create_segment(A,tree,n)

for _ in range(q):
    
    a,b = map(int,stdin.readline().split())

    print(query(tree,n,a-1,b))
TAGS.

Comments