파이썬 알고리즘 기본기 곱셈 연산에서 주의할 점 배우기(세그먼트 트리 문제)

1. 문제

 

5676번: 음주 코딩 (acmicpc.net)

 

5676번: 음주 코딩

각 테스트 케이스마다 곱셈 명령의 결과를 한 줄에 모두 출력하면 된다. 출력하는 i번째 문자는 i번째 곱셈 명령의 결과이다. 양수인 경우에는 +, 음수인 경우에는 -, 영인 경우에는 0을 출력한다.

www.acmicpc.net

 

특정 구간의 수열의 곱이 양수인지 음수인지 0인지 판단하거나 특정 인덱스의 원소를 바꾸는 쿼리를 수행하는 문제

 

 

2. 곱셈의 시간복잡도

 

매우 큰 수를 곱할수록 프로그램의 시간복잡도가 높아진다.

 

이게 몇번만 연산하면 별로 차이 없어보이지만 10만번이나 반복연산해야한다면, 눈에띄게 느려진다

 

 

매우 큰 4개의 수를 곱하는 과정을 10만번 반복했는데 평균적으로 0.03초 걸린다면...

 

 

단순히 1,2,3,4 4개를 곱하는 과정을 10만 반복하면.. 평균 0.01초로 약 3배차이남

 

별거 없어보이긴 한디

 

이 문제의 수열의 숫자 범위가 -100~100이라 최대 100까지이고..

 

수열의 크기는 10만까지이고 쿼리 수도 10만까지라는 점을 생각해본다면..

 

구간 곱의 최댓값은 100을 10만번 곱하는 경우인데

 

쿼리 수가 10만개이고 곱셈을 구하기를 원하는 구간이 약 10만까지이며, 수열의 모든 값이 100이라면?

 

 

예상 시간이 36분임...

 

따라서 이걸 어떻게 피할 수 있을까?

 

문제에서 구간 곱셈의 값이 양수인지, 음수인지, 0인지만 판단하라고 했다.

 

괜히 이런 이유가 있는게 아니다.

 

구간 곱셈의 값이 중요한게 아니고 곱셈의 값이 양수이냐, 음수이냐, 0이냐만 중요하다.

 

따라서 구간의 곱을 구하는 세그먼트 트리를 만드는데, 트리의 노드에 구간 곱셈의 부호만 저장해두면 된다.

 

 

1을 10만번 곱하는걸 10만번 반복하면 0.08초밖에 안걸림..

 

 

근데 이게 2를 10만번 곱하는걸 10만번 반복하면 35초나 걸림

 

 

3. 풀이

 

세그먼트 트리를 만드는데, 구간의 곱을 트리의 노드에 저장하는 세그먼트 트리를 만든다

 

그런데 단순히 구간의 곱을 저장하면 수가 매우 커져서 시간 복잡도가 매우 증가하므로,

 

구간의 곱을 저장하지 않고, 구간의 곱의 부호를 저장해둔다

 

리프노드에만 값의 부호를 저장해두면.. 올라오면서 계산하는 결과는 값의 부호가 된다

 

import math
from sys import stdin

#값의 부호를 구하는 함수
def sign(x):
    
    if x > 0:
        return 1
    
    elif x < 0:
        return -1
    
    else:
        
        return 0
 
#세그먼트 트리를 만들때, 값의 부호를 저장한다
def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        tree[tree_index] = sign(A[A_start]) #값의 부호 저장
    
    else:
        
        A_mid = (A_start + A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)
        
        #리프노드에 모두 값의 부호가 저장되어 있으니, 올라오면서 알아서 값의 부호끼리 곱셈이 된다
        tree[tree_index] = tree[2*tree_index]*tree[2*tree_index+1]

 

 

구간의 곱을 구할때는, 어차피 트리에 값의 부호가 저장되어 있으니... 계산해도 값의 부호가 나온다

 

sign함수를 쓸 필요가 없다는 말이야

 

#트리의 노드에 값의 부호가 저장되어 있으니, 계산 결과는 무조건 값의 부호가 된다.
def query_product(tree,tree_index,start,end,left,right):
    
    if left > end or right < start:
        
        return 1
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_product = query_product(tree,2*tree_index,start,mid,left,right)
    right_product = query_product(tree,2*tree_index+1,mid+1,end,left,right)

    return left_product * right_product

 

 

update 함수에는 리프노드를 만날때, 새로운 값을 넣어줄때만 값의 부호를 넣어주고...

 

그러면, 모든 노드에 값의 부호가 들어가있으니, 올라오면서 값을 업데이트해줄때, 알아서 값의 부호가 계산될 것이다

 

#리프노드까지 탐색하면서 새로운 값으로 바꿔줄때, 값의 부호를 넣어준다. 
def update(A,tree,tree_index,start,end,index,value):
    
    if index > end or index < start:
        
        return
    
    if start == end:
        
        A[index] = value
        tree[tree_index] = sign(value) #여기만 값의 부호를 넣어서 바꿔주면..
        return
    
    mid = (start+end)//2

    update(A,tree,2*tree_index,start,mid,index,value)
    update(A,tree,2*tree_index+1,mid+1,end,index,value)
    
    #자식에 모두 값의 부호로 되어 있으니까, 계산하면 값의 부호로 나온다.
    tree[tree_index] = tree[2*tree_index]*tree[2*tree_index+1]

 

 

그리고 최근에 배운 EOF처리가 필요하다..

 

마지막을 읽으면 빈문자열을 받아오는데... 그것을 split해서 int로 바꿀려하면 에러가 나니까

 

try ~ except로 에러 처리를 해준다.

 

그래서 모두 작성하면... 최종 세그먼트 트리는..

 

import math
from sys import stdin

#값의 부호를 구하는 함수
def sign(x):
    
    if x > 0:
        return 1
    
    elif x < 0:
        return -1
    
    else:
        
        return 0
 
#세그먼트 트리를 만들때, 값의 부호를 저장한다
def create_segment(A,tree,tree_index,A_start,A_end):
    
    if A_start == A_end:
        
        tree[tree_index] = sign(A[A_start]) #값의 부호 저장
    
    else:
        
        A_mid = (A_start + A_end)//2

        create_segment(A,tree,2*tree_index,A_start,A_mid)
        create_segment(A,tree,2*tree_index+1,A_mid+1,A_end)
        
        #리프노드에 모두 값의 부호가 저장되어 있으니, 올라오면서 알아서 값의 부호끼리 곱셈이 된다
        tree[tree_index] = tree[2*tree_index]*tree[2*tree_index+1]

#트리의 노드에 값의 부호가 저장되어 있으니, 계산 결과는 무조건 값의 부호가 된다.
def query_product(tree,tree_index,start,end,left,right):
    
    if left > end or right < start:
        
        return 1
    
    if left <= start and end <= right:
        
        return tree[tree_index]
    
    mid = (start+end)//2

    left_product = query_product(tree,2*tree_index,start,mid,left,right)
    right_product = query_product(tree,2*tree_index+1,mid+1,end,left,right)

    return left_product * right_product

#리프노드까지 탐색하면서 새로운 값으로 바꿔줄때, 값의 부호를 넣어준다. 
def update(A,tree,tree_index,start,end,index,value):
    
    if index > end or index < start:
        
        return
    
    if start == end:
        
        A[index] = value
        tree[tree_index] = sign(value) #여기만 값의 부호를 넣어서 바꿔주면..
        return
    
    mid = (start+end)//2

    update(A,tree,2*tree_index,start,mid,index,value)
    update(A,tree,2*tree_index+1,mid+1,end,index,value)
    
    #자식에 모두 값의 부호로 되어 있으니까, 계산하면 값의 부호로 나온다.
    tree[tree_index] = tree[2*tree_index]*tree[2*tree_index+1]

while 1:
    
    try:
        
        N,k = map(int,stdin.readline().split())

        A = list(map(int,stdin.readline().split()))

        n = 2**(math.ceil(math.log2(N)+1))-1

        tree = [0]*(n+1)

        create_segment(A,tree,1,0,N-1)

        answer = []

        for _ in range(k):

            command = stdin.readline().rstrip().split()

            if command[0] == 'C':

                update(A,tree,1,0,N-1,int(command[1])-1, int(command[2]))

            elif command[0] == 'P':

                value = query_product(tree,1,0,N-1,int(command[1])-1,int(command[2])-1)

                if value == 0:

                    answer.append('0')

                elif value == 1:

                    answer.append('+')

                elif value == -1:

                    answer.append('-')

        print(''.join(answer))
    
    except:
        
        break

 

 

4. 되돌아보기

 

직접 실험해보니까 이제 확 와닿겠지..?

 

매우 큰 수를 매우 많이 반복연산하면 시간이 상당히 오래걸린다는 점

 

그래서 맨날 알고리즘 문제 보면 곱셈 문제에서 10000009로 나눈 나머지를 구해라...

 

그래서 리스트에 저장할때 중간 계산 결과의 나머지를 저장해두면서 .. 시간을 줄이는게 이런 이유였다

 

이 문제도 마찬가지로 숫자를 1,0,-1로 치환해서 시간을 줄이는...

 

결국 문제 자체에 힌트가 있었네

 

 

TAGS.

Comments