세그먼트 트리 기본문제로 연습하며 재활(반복문, 재귀 연습)
1. 문제1
2. 풀이
문제를 요약하자면 구간의 합을 구하는 세그먼트 트리
create_segment라는 함수는 만들필요가 없다
P라는 query가 index번째 박스에 a개의 구슬을 넣어주는 query라서, update 함수를 수행하면서, tree가 생성된다
실수하기 쉬운 부분이 update 함수에서.. A배열과 tree배열에서 index번째 값에 누적합을 하지 않고.. 단순히 값만 변경하는 경우가 있는데
if start == end:
A[index] = value
tree[tree_index] = value
return
문제를 보면.. index번째 박스에 구슬을 계속 놓아주는 의미라서, 누적합으로
if start == end:
A[index] += value
tree[tree_index] += value
return
기본형태인 재귀함수 버전...
import math
from sys import stdin
def update(A,tree,tree_index,start,end,index,value):
if index < start or end < index:
return
if start == end:
A[index] += value
tree[tree_index] += value
return
mid = (start + end)//2
update(A,tree,2*tree_index,start,mid,index,value)
update(A,tree,2*tree_index+1,mid+1,end,index,value)
tree[tree_index] = tree[2*tree_index] + tree[2*tree_index+1]
def query_sum(tree,tree_index,start,end,left,right):
if right < start or end < left:
return 0
if left <= start and end <= right:
return tree[tree_index]
mid = (start+end)//2
left_sum = query_sum(tree,2*tree_index,start,mid,left,right)
right_sum = query_sum(tree,2*tree_index+1,mid+1,end,left,right)
return left_sum + right_sum
T = int(stdin.readline())
for _ in range(T):
b,p,q = map(int,stdin.readline().split())
A = [0]*(b)
n = 2**(math.ceil(math.log2(b))+1)-1
tree = [0]*(n+1)
for _ in range(p+q):
q,x,y = stdin.readline().rstrip().split()
x = int(x)
y = int(y)
if q == 'P':
update(A,tree,1,0,b-1,x-1,y)
else:
print(query_sum(tree,1,0,b-1,x-1,y-1))
다음은 최근에 배운 반복문 버전 연습..
중요한 점은 역시 update에 tree[index] = value로 값만 바꾸지말고.. 누적합해주는거
특히 반복문 버전에서는 구간 쿼리 구할때, [left,right]를 구한다면.. [left,right+1]을 구해줘야한다는 점
from sys import stdin
def update(tree,index,value,b):
index += b
tree[index] += value
while index > 1:
index >>= 1
tree[index] = tree[2*index] + tree[2*index+1]
def query(tree,b,left,right):
left += b
right += b
m = 0
while left < right:
if left & 1:
m += tree[left]
left += 1
if right & 1:
right -= 1
m += tree[right]
left //= 2
right //= 2
return m
T = int(stdin.readline())
for _ in range(T):
b,p,q = map(int,stdin.readline().split())
tree = [0]*(2*b)
for _ in range(p+q):
x,y,z = stdin.readline().rstrip().split()
y = int(y)
z = int(z)
if x == 'P':
update(tree,y-1,z,b)
else:
print(query(tree,b,y-1,z))
3. 문제2
12837번: 가계부 (Hard) (acmicpc.net)
4. 풀이
역시 구간의 합을 구하는 세그먼트 트리
위 문제랑 사실상 똑같다
create segment 함수를 만들 필요 없고, update함수에서 값을 "추가"하는 방식
가계부에 수입/지출을 "추가"한다는건... 가계부에서 수입/지출을 바꾼다는 의미가 아니고, 해당 수입/지출을 누적합시키는것
대충 읽으면 "생후 p일에 추가한다", "생후 p일부터 q일에 변화량을 출력한다"이게 뭔소린가 하는데..
p가 1이상 N이하이고.. "살아온 날이 N일"
생후라는건 "0일부터"라는 의미니까...
쉽게 생각하면 배열 A의 크기가 N이고.. p는 배열 A의 index
다음은 재귀 버전
이번엔 이전 풀이와는 다르게 A배열을 그냥 없애버렸다..
애초에 필요가 없으니까
import math
from sys import stdin
def update(tree,tree_index,start,end,index,value):
if index < start or index > end:
return
if start == end:
tree[tree_index] += value
return
mid = (start+end)//2
update(tree,2*tree_index,start,mid,index,value)
update(tree,2*tree_index+1,mid+1,end,index,value)
tree[tree_index] = tree[2*tree_index] + tree[2*tree_index+1]
def query(tree,tree_index,start,end,left,right):
if left > end or right < start:
return 0
if left <= start and end <= right:
return tree[tree_index]
mid = (start+end)//2
left_sum = query(tree,2*tree_index,start,mid,left,right)
right_sum = query(tree,2*tree_index+1,mid+1,end,left,right)
return left_sum + right_sum
N,q = map(int,stdin.readline().split())
n = 2**(math.ceil(math.log2(N))+1)-1
tree = [0]*(n+1)
for _ in range(q):
a,b,c = map(int,stdin.readline().split())
if a == 1:
update(tree,1,0,N-1,b-1,c)
else:
print(query(tree,1,0,N-1,b-1,c-1))
다음은 반복문 버전
역시 구간 쿼리 구할때 [left,right]를 구하고 싶으면 [left,right+1]을 넣어줘야한다는 점 기억해야
from sys import stdin
def update(tree,index,value,n):
index += n
tree[index] += value
while index > 1:
index >>= 1
tree[index] = tree[2*index] + tree[2*index+1]
def query(tree,n,left,right):
left += n
right += n
m = 0
while left < right:
if left & 1:
m += tree[left]
left += 1
if right & 1:
right -= 1
m += tree[right]
left //= 2
right //= 2
return m
n,q = map(int,stdin.readline().split())
tree = [0]*(2*n)
for _ in range(q):
a,b,c = map(int,stdin.readline().split())
if a == 1:
update(tree,b-1,c,n)
else:
print(query(tree,n,b-1,c))
5. 문제3
6213번: Balanced Lineup (acmicpc.net)
6. 풀이
조금 난이도가 높다면.. 최댓값과 최솟값을 모두 관리하는 세그먼트 트리를 만들어야한다.
재귀 버전은 최댓값을 구하는 세그먼트 트리, 최솟값을 구하는 세그먼트 트리 2개를 만들었다
1개로도 가능하긴한데..
이 문제는 update 함수가 필요없다.. 업데이트 한다는 말이 없으니까
import math
from sys import stdin
def max_create_segment(A,max_tree,tree_index,start,end):
if start == end:
max_tree[tree_index] = A[start]
else:
mid = (start+end)//2
max_create_segment(A,max_tree,2*tree_index,start,mid)
max_create_segment(A,max_tree,2*tree_index+1,mid+1,end)
max_tree[tree_index] = max(max_tree[2*tree_index],max_tree[2*tree_index+1])
def min_create_segment(A,min_tree,tree_index,start,end):
if start == end:
min_tree[tree_index] = A[start]
else:
mid = (start+end)//2
min_create_segment(A,min_tree,2*tree_index,start,mid)
min_create_segment(A,min_tree,2*tree_index+1,mid+1,end)
min_tree[tree_index] = min(min_tree[2*tree_index],min_tree[2*tree_index+1])
def query(tree,tree_index,start,end,left,right,ind):
if left > end or right < start:
return
if left <= start and end <= right:
return tree[tree_index]
mid = (start + end)//2
left_value = query(tree,2*tree_index,start,mid,left,right,ind)
right_value = query(tree,2*tree_index+1,mid+1,end,left,right,ind)
if ind == 0:
if left_value == None:
return right_value
elif right_value == None:
return left_value
else:
return min(left_value,right_value)
else:
if left_value == None:
return right_value
elif right_value == None:
return left_value
else:
return max(left_value,right_value)
N,q = map(int,stdin.readline().split())
A = []
for _ in range(N):
a = int(stdin.readline())
A.append(a)
n = 2**(math.ceil(math.log2(N))+1)-1
max_tree = [0]*(n+1)
min_tree = [0]*(n+1)
max_create_segment(A,max_tree,1,0,N-1)
min_create_segment(A,min_tree,1,0,N-1)
for _ in range(q):
a,b = map(int,stdin.readline().split())
print(query(max_tree,1,0,N-1,a-1,b-1,1) - query(min_tree,1,0,N-1,a-1,b-1,0))
다음은 반복문 버전
이번엔 세그먼트 트리 1개로 최댓값, 최솟값을 모두 관리했다
세그먼트 트리 한 원소가 [최솟값, 최댓값]을 가지도록 만든다면.. 가능하겠지
반복문 버전에서 주의할 점은 3번째 말하는데 [left,right] 구간 쿼리 구할때, [left,right+1]을 넣어줘야한다는 점
from sys import stdin
def create_segment(A,tree,n):
for i in range(n):
tree[n+i][0] = A[i]
tree[n+i][1] = A[i]
for i in range(n-1,0,-1):
tree[i][0] = min(tree[2*i][0],tree[2*i+1][0])
tree[i][1] = max(tree[2*i][1],tree[2*i+1][1])
def query(tree,n,left,right):
left += n
right += n
m1 = 1000001
m2 = 0
while left < right:
if left & 1:
m1 = min(m1,tree[left][0])
m2 = max(m2,tree[left][1])
left += 1
if right & 1:
right -= 1
m1 = min(m1,tree[right][0])
m2 = max(m2,tree[right][1])
left //= 2
right //= 2
return m2-m1
n,q = map(int,stdin.readline().split())
A = []
for _ in range(n):
a = int(stdin.readline())
A.append(a)
tree = [[0]*2 for _ in range(2*n)]
create_segment(A,tree,n)
for _ in range(q):
a,b = map(int,stdin.readline().split())
print(query(tree,n,a-1,b))
'알고리즘 > 세그먼트 트리' 카테고리의 다른 글
binary indexed tree(BIT, fenwick tree) 간단하게 배우기 (0) | 2024.02.13 |
---|---|
세그먼트 트리 응용단계 - 이분탐색과 콜라보1 (0) | 2023.05.09 |
반복문으로 구현하는 세그먼트 트리(iterative segment tree) 배우기 (0) | 2023.05.04 |
세그먼트 트리 응용 - 2개의 쿼리를 동시에 수행할 수 있는 2차원 세그먼트 트리 (0) | 2022.12.15 |
세그먼트 트리 응용 - XOR 연산을 수행하는 트리 (0) | 2022.12.15 |