세그먼트 트리 응용2 -최솟값과 최댓값을 구하는 세그먼트 트리-

1. 문제

 

2357번: 최솟값과 최댓값 (acmicpc.net)

 

2357번: 최솟값과 최댓값

N(1 ≤ N ≤ 100,000)개의 정수들이 있을 때, a번째 정수부터 b번째 정수까지 중에서 제일 작은 정수, 또는 제일 큰 정수를 찾는 것은 어려운 일이 아니다. 하지만 이와 같은 a, b의 쌍이 M(1 ≤ M ≤ 100

www.acmicpc.net

 

이번엔 임의의 구간의 최솟값과 최댓값을 구하는 문제

 

 

2. 풀이

 

구간합이나 구간곱과 큰 차이 없긴한데 query 함수 구할때 조금 신경써야할듯

 

create_segment함수에는 왼쪽 자식과 오른쪽 자식의 최솟값과 최댓값을 저장해야함

 

최솟값 구하는 트리와 최댓값 구하는 트리 함수를 따로 만들수도 있지만...

 

그냥 인자 method를 받아서 method가 0이면 최솟값 트리를 만들고 1이면 최댓값 트리를 만들도록 함수를 조정했음

 

def create_segment(A,tree,tree_index,A_start,A_end,method):
    
    if A_start == A_end:
        
        tree[tree_index] = A[A_start]
    
    else:
        
        A_mid = (A_start+A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid,method)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end,method)
        
        #method가 0이면 왼쪽 자식과 오른쪽 자식의 최솟값을 저장
        
        if method == 0:

            tree[tree_index] = min(tree[2*tree_index],tree[2*tree_index+1])
        
        #method가 1이면 왼쪽 자식과 오른쪽 자식의 최댓값을 저장
        
        elif method == 1:
            
            tree[tree_index] = max(tree[2*tree_index],tree[2*tree_index+1])

 

 

query함수에는 조금 신경을 써야할 것 같은데

 

먼저 [left,right]와 [start,end]가 완전히 벗어날때 어떤 값을 return해야하는가?

 

의미없는 값을 return해야하잖아

 

그런데 입력으로 주어지는 수가 1이상 1000000000이하이므로 

 

최솟값을 구할려면 1000000001을 return하고 최댓값을 구할려면 0을 return하도록

 

def query(tree,tree_index,start,end,left,right,method):
    
    if left > end or right < start:
        
        if method == 0:
            
            return 1000000001
        
        elif method == 1:
            
            return 0

 

 

그리고 [left,right] 안에 완전히 [start,end]가 포함된다면... 노드에 저장된 값을 그대로 return해준다

 

구간의 최솟값 아니면 최댓값이 저장되어 있을거임

 

    if left <= start and end <= right:
        
        return tree[tree_index]

 

그렇지 않다면 왼쪽, 오른쪽 아래로 내려가서 탐색을 진행할건데

 

    mid = (start+end)//2
    
    left_value = query(tree,2*tree_index,start,mid,left,right,method)
    right_value = query(tree,2*tree_index+1,mid+1,end,left,right,method)

 

그리고 method가 0이면 left와 right의 min을 return하고 1이면 max를 return하도록

 

    mid = (start+end)//2
    
    left_value = query(tree,2*tree_index,start,mid,left,right,method)
    right_value = query(tree,2*tree_index+1,mid+1,end,left,right,method)

    if method == 0:
        
        return min(left_value,right_value)
    
    elif method == 1:
        
        return max(left_value,right_value)

 

 

근데 뭔가 처음에는 left_value랑 right_value에서 계속 최소,최대 갱신하도록 만들어야한다고 생각을 했는데

 

    mid = (start+end)//2

    if method == 0:
        
        left_value = min(left_value,query(tree,2*tree_index,start,mid,left,right,method,left_value,right_value))
        right_value = min(right_value,query(tree,2*tree_index+1,mid+1,end,left,right,method,left_value,right_value))
        
        return min(left_value,right_value)
    
    elif method == 1:
        
        left_value = max(left_value,query(tree,2*tree_index,start,mid,left,right,method,left_value,right_value))
        right_value = max(right_value,query(tree,2*tree_index+1,mid+1,end,left,right,method,left_value,right_value))
        
        return max(left_value,right_value)

 

왜 이럴 필요가 없을까??

 

 

 

최댓값 트리를 만들어서 한번 생각해보자고

 

[2,8]에 속하는 수의 최댓값은 어떻게 구할까?

 

left_value = query(tree,2*tree_index,start,mid,left,right,method) 얘가 먼저 수행될텐데

 

 

다시 query(tree,2*tree_index,start,mid,left,right,method) 들어가서...

 

start = 0, end = 4이고 left = 2, right = 8인데

 

다시 왼쪽 오른쪽으로 나뉘면서

 

left_value = query(tree,2*tree_index,start,mid,left,right,method) 수행될텐데

 

 

다시 start = 0, end = 2이고 left = 2, right = 8에서 

 

left_value = query(tree,2*tree_index,start,mid,left,right,method)  수행될텐데

 

 

다시 left_value = query(tree,2*tree_index,start,mid,left,right,method) 수행되면서

 

start = 0, end = 1, left = 2, right = 8인데

 

    if left > end or right < start:
        
        if method == 0:
            
            return 1000000001
        
        elif method == 1:
            
            return 0

 

에 걸리면서 0을 return하고 left_value = 0이 저장되고..

 

left_value = query(tree,2*tree_index,0,4,left,right,method)

 

>>>재귀 들어가서

 

left_value = query(tree,2*tree_index,0,2,left,right,method)

 

>>>재귀 들어가서

 

left_value = query(tree,2*tree_index,0,1,left,right,method)

right_value = query(tree,2*tree_index+1,2,2,left,right,method)

 

부분에서..

 

left_value = 0

right_value = query(tree,2*tree_index+1,2,2,left,right,method)가 수행될건데

 

 

 

 

    if left <= start and end <= right:
        
        return tree[tree_index]

 

에 걸리므로, 3를 return할거임

 

left_value = 0

right_value = 3이 될거고 이 중에서 최댓값? 3을 return하도록 할건데

 

3은 어디에 return될까?

 

현재 수행되는 함수

 

left_value = query(tree,2*tree_index,0,4,left,right,method)

 

>>재귀 들어가서

 

left_value = query(tree,2*tree_index,0,2,left,right,method)

 

>>재귀 들어가서

 

left_value = 0

right_value = 3

 

에서 left_value = query(tree,2*tree_index,0,2,left,right,method)가 수행중이니까

 

left_value = query(tree,2*tree_index,0,4,left,right,method)

 

left_value = 3

right_value = query(tree,2*tree_index+1,3,4,left,right,method)

 

이 진행되고 right_value = query(tree,2*tree_index+1,3,4,left,right,method)가 수행될거임

 

 

여기서는 이제 [2,8]안에 [3,4]가 포함되므로 5를 바로 return하게 될거임

 

left_value = query(tree,2*tree_index,0,4,left,right,method)

 

left_value = 3

right_value = 5로 될거고

 

여기서 left와 right의 최댓값인 5가 return될건데... 그러면 이 5는 어디에 return될까?

left_value = query(tree,2*tree_index,0,4,left,right,method)를 수행중이니까 여기에 return될거임

 

따라서 원래 함수 진행을 다시 추적해보면..

 

query(tree,2*tree_index,0,9,left,right,method)

 

에서..

 

left_value = query(tree,2*tree_index,0,4,left,right,method)

right_value = query(tree,2*tree_index+1,5,9,left,right,method)

 

이고 left_value = query(tree,2*tree_index,0,4,left,right,method)가 재귀로 수행되면서 마지막에서 가지고 올라와서

 

left_value = 5가 된다는 것을 알았고

 

최종 return되었으니까 이제는

 

right_value = query(tree,2*tree_index+1,5,9,left,right,method)가 수행될거다

 

 

다시 right_value = query(tree,2*tree_index+1,5,9,left,right,method)로 들어가면서

 

left_value = query(tree,2*tree_index,5,7,left,right,method)

right_value = query(tree,2*tree_index+1,8,9,left,right,method)가 수행될건데

 

다시 left_value = query(tree,2*tree_index,5,7,left,right,method)가 먼저 들어가게 될거임

 

 

그런데 [5,7]은 [2,8]에 완전히 포함되므로 8을 바로 return할거다

 

그러므로 

 

left_value = 8

right_value = query(tree,2*tree_index+1,8,9,left,right,method)이고

 

right_value = query(tree,2*tree_index+1,8,9,left,right,method)가 수행될거다

 

 

 

그러면 다시 왼쪽, 오른쪽으로 나누어 들어가면서

 

left_value = 8

right_value = query(tree,2*tree_index+1,8,9,left,right,method)

 

>>재귀로 들어가서

 

left_value = query(tree,2*tree_index,8,8,left,right,method)

right_value = query(tree,2*tree_index+1,9,9,left,right,method)가 수행될건데

 

왼쪽부터 [8,8]은 [2,8]에 완전히 포함되므로 left_value = query(tree,2*tree_index,8,8,left,right,method) = 9

 

[9,9]는 [2,8]에 완전히 벗어났으니 right_value = query(tree,2*tree_index+1,9,9,left,right,method) = 0이 될거다.

 

 

 

left_value = 8

right_value = query(tree,2*tree_index+1,8,9,left,right,method)

 

>>재귀로 들어가서

 

left_value = 9

right_value = 0

 

이 중에서 최댓값인 9가 return되는데 어디에 return될까?

 

right_value = query(tree,2*tree_index+1,8,9,left,right,method)을 수행중이므로... 여기에 9가 return된다

 

 

그러면 결국.. 

 

left_value = 8

right_value = query(tree,2*tree_index+1,8,9,left,right,method) = 9니까..

 

이들의 최댓값인 9가 return되는데 어디에 return될까?

 

left_value = query(tree,2*tree_index,0,4,left,right,method)

right_value = query(tree,2*tree_index+1,5,9,left,right,method)에서 

 

 

left_value = query(tree,2*tree_index,0,4,left,right,method) = 5임을 구했고

 

 

right_value = query(tree,2*tree_index+1,5,9,left,right,method)가 재귀로 호출되어 return되면서 올라와가지고

 

right_value = query(tree,2*tree_index+1,5,9,left,right,method) = 9임을 구하게 된다

 

따라서...

 

query(tree,2*tree_index,0,9,left,right,method)가 수행되면서...

 

left_value = query(tree,2*tree_index,0,4,left,right,method)

right_value = query(tree,2*tree_index+1,5,9,left,right,method)

 

나누어지던게... left_value = 5, right_value = 9로 구해지고..

 

이들의 최댓값을 return하게 되는데.. 그러므로 최종 9를 return하게 된다

 

그러니까 [2,8]의 최댓값은 9가 된다

 

 

 

그러니까.. 굳이 매번 최소,최대를 갱신하지 않더라도.. 알아서 최소 최대가 갱신되면서 올라오게 됨

 

def query(tree,tree_index,start,end,left,right,method):
    
    if left > end or right < start:
        
        if method == 0:
            
            return 1000000001
        
        elif method == 1:
            
            return 0
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start+end)//2
    
    left_value = query(tree,2*tree_index,start,mid,left,right,method)
    right_value = query(tree,2*tree_index+1,mid+1,end,left,right,method)

    if method == 0:
        
        return min(left_value,right_value)
    
    elif method == 1:
        
        return max(left_value,right_value)

 

 

재귀함수를 아직 제대로 이해하지 못했다는 뜻이겠지..?

 

근데 걍 직관적으로 생각하자.. 처음에 너무 어렵게 생각했어

 

왼쪽으로 계속 재귀로 파고들면서 마지막에 호출이 끝나면 min이나 max를 return하면서

 

가지고 올라올거잖아.. 그러니까 min이나 max가 계속 갱신되면서 올라오겠지  

 

 

 

아무튼 create_segment와 query함수를 만들면... 끝

 

import math
from sys import stdin

def create_segment(A,tree,tree_index,A_start,A_end,method):
    
    if A_start == A_end:
        
        tree[tree_index] = A[A_start]
    
    else:
        
        A_mid = (A_start+A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid,method)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end,method)

        if method == 0:

            tree[tree_index] = min(tree[2*tree_index],tree[2*tree_index+1])
        
        elif method == 1:
            
            tree[tree_index] = max(tree[2*tree_index],tree[2*tree_index+1])

def query(tree,tree_index,start,end,left,right,method):
    
    if left > end or right < start:
        
        if method == 0:
            
            return 1000000001
        
        elif method == 1:
            
            return 0
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start+end)//2
    
    left_value = query(tree,2*tree_index,start,mid,left,right,method)
    right_value = query(tree,2*tree_index+1,mid+1,end,left,right,method)

    if method == 0:
        
        return min(left_value,right_value)
    
    elif method == 1:
        
        return max(left_value,right_value)

    


N,m = map(int,stdin.readline().split())

A = []

for _ in range(N):
    
    a = int(stdin.readline())

    A.append(a)

n = 2**(math.ceil(math.log2(N))+1)-1

min_tree = [0]*(n+1)
max_tree = [0]*(n+1)

create_segment(A,min_tree,1,0,N-1,0)
create_segment(A,max_tree,1,0,N-1,1)
    
for _ in range(m):
    
    a,b = map(int,stdin.readline().split())
    
    print(query(min_tree,1,0,N-1,a-1,b-1,0),query(max_tree,1,0,N-1,a-1,b-1,1))

 

 

 

3. 문제2

 

14438번: 수열과 쿼리 17 (acmicpc.net)

 

14438번: 수열과 쿼리 17

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. 1 i v : Ai를 v로 바꾼다. (1 ≤ i ≤ N, 1 ≤ v ≤ 109) 2 i j : Ai, Ai+1, ..., Aj에서 크기가 가장 작은 값을

www.acmicpc.net

 

이번엔 값을 업데이트 하는 기능도 추가된 최솟값을 구하는 세그먼트 트리 구현하기

 

4. 풀이2

 

업데이트는 리프 노드 찾을때까지 재귀호출 하다가, 리프노드 찾으면 A배열 바꾸고 tree의 tree_index값 바꾼 다음에

 

return하면서 변경된 자식을 이용해서 min값을 갱신해나가는 방식으로

 

다른거랑 별 차이 없다

 

    if index > end or index < start:
        
        return
    
    if start == end:
        
        A[index] = value
        tree[tree_index] = value
        return ##return을 하지 않으면 에러난다

 

return을 까먹어서... 에러난거 주의하고

 

import math
from sys import stdin

def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        tree[tree_index] = A[A_start]
    
    else:
        
        A_mid = (A_start + A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)

        tree[tree_index] = min(tree[2*tree_index],tree[2*tree_index+1])

def query_min(tree,tree_index,start,end,left,right):
    
    if left > end or right < start:
        
        return 10000000000
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_min = query_min(tree,2*tree_index,start,mid,left,right)

    right_min = query_min(tree,2*tree_index+1,mid+1,end,left,right)

    return min(left_min,right_min)

def update(A,tree,tree_index,start,end,index,value):
    
    if index > end or index < start:
        
        return
    
    if start == end:
        
        A[index] = value
        tree[tree_index] = value
        return ##return을 하지 않으면 에러난다
    
    mid = (start+end)//2

    update(A,tree,2*tree_index,start,mid,index,value)
    update(A,tree,2*tree_index+1,mid+1,end,index,value)

    tree[tree_index] = min(tree[2*tree_index],tree[2*tree_index+1])
        
N = int(stdin.readline())

A = list(map(int,stdin.readline().split()))

m = int(stdin.readline())

n = 2**(math.ceil(math.log2(N))+1)-1

tree = [0]*(n+1)

create_segment(A,tree,1,0,N-1)

for _ in range(m):
    
    a,b,c = map(int,stdin.readline().split())

    if a == 1:
        
        update(A,tree,1,0,N-1,b-1,c)
    
    elif a == 2:
        
        print(query_min(tree,1,0,N-1,b-1,c-1))
TAGS.

Comments