union find 재활훈련 - union by size 복기하기
1. 문제1
4143번: Bridges and Tunnels (acmicpc.net)
2. 풀이1
서로 연결되는 간선이 주어질때마다 각 노드의 집합 크기를 구하는 문제
find 함수는 쉽게 기억하는데
from sys import stdin
def find_parent(x,parent):
if x != parent[x]:
parent[x] = find_parent(parent[x],parent)
return parent[x]
union by size 함수가 하도 안쓰다보니 기억이 잘 안나더라고
이번 기회에 복기하자고
def union(a,b,parent):
a = find_parent(a,parent)
b = find_parent(b,parent)
if a != b:
if rank[a] > rank[b]:
parent[b] = a
else:
parent[a] = b
rank[a] += rank[b]
rank[b] = rank[a]
두 정점 a,b의 부모를 찾는다
부모가 서로 다르다면, union을 수행한다
이때, a,b의 rank를 비교해서 rank가 작은 노드의 부모를 rank가 큰 노드로 한다
그리고 rank[a]에 rank[b]를 더해주고 rank[b] = rank[a]로 바꿔준다
rank[a]는 a 노드의 집합 크기를 의미한다
그러니까 rank 배열은 최초 [1]*n 같이 모든 원소가 1인 배열로 초기화해줘야한다.
t = int(stdin.readline())
for _ in range(t):
n = int(stdin.readline())
map_dict = {}
i = 0
parent = [j for j in range(2*n+1)]
rank = [1] * (2*n+1)
for _ in range(n):
a,b = stdin.readline().rstrip().split()
if map_dict.get(a,-1) == -1:
map_dict[a] = i
i += 1
if map_dict.get(b,-1) == -1:
map_dict[b] = i
i += 1
union(map_dict[a],map_dict[b],parent)
print(rank[find_parent(map_dict[a],parent)])
이 문제는 먼저 노드가 숫자로 안주어지는데 그거는 dictionary로 mapping시켜주면 된다.
그리고 노드 수가 주어지지 않고 간선의 수만 주어지는데..
간선의 수가 n개이면 가능한 최대 노드는 2n개이다.
그러므로 parent와 rank배열 초기화 할때 1~2n번까지 접근할 수 있도록 2n+1 크기의 배열로 초기화하면 된다
그리고 rank[a]를 노드 a가 나타내는 집합의 크기로 정의했잖아
union할때마다 각 노드의 집합 크기를 출력해야하는데...
print(rank[a])하면 될까??
a는 그냥 노드일 뿐이고... a가 속한 집합의 rank를 구해야지
a가 속한 집합은 어떻게 구한다고 했지?? find함수를 이용해서 a가 속하는 집합의 대표자를 찾아야겠지
대표자 노드가 a가 속하는 집합을 나타낸다고 했었거든
그래서 a의 부모를 찾아서 rank[find_parent(a,parent)]를 print해줘야겠다
3. 문제2
4. 풀이2
위 문제랑 똑같은 문제다
간선을 연결할때는 union으로 연결하고, 해당 노드가 속하는 집합의 크기를 구해야할 때는 크기를 구해서 출력해주고
중요한 점은 union by size
a,b의 부모를 찾고 a,b가 다르다면 union을 수행
rank가 적은 노드의 부모를 rank가 큰 노드로 갱신한다
union하고 나서 rank를 갱신해야하는데 rank[a] += rank[b]하고 rank[b] = rank[a]로 해준다
a와 b가 합쳐지는거니까
그리고 rank는 집합의 크기를 나타내니 최초 모든 원소가 1인 배열로 초기화
++ 두번째로는 노드 i가 속하는 집합의 크기를 rank[i]로 하면 안된다
i는 그냥 i일 뿐이고.. i가 속하는 집합은 i의 부모, i가 속한 집합의 대표자인 find_parent(i,parent)로 구해줘야한다
from sys import stdin
def find_parent(x,parent):
if x != parent[x]:
parent[x] = find_parent(parent[x],parent)
return parent[x]
def union(a,b,parent):
a = find_parent(a,parent)
b = find_parent(b,parent)
if a != b:
if rank[a] > rank[b]:
parent[b] = a
else:
parent[a] = b
rank[a] += rank[b]
rank[b] = rank[a]
N = int(stdin.readline())
n = 10**6
parent = [0]*(n+1)
rank = [1]*(n+1)
for i in range(1,n+1):
parent[i] = i
for _ in range(N):
c = stdin.readline().rstrip().split()
if c[0] == 'I':
union(int(c[1]),int(c[2]),parent)
else:
p = find_parent(int(c[1]),parent)
print(rank[p])
'알고리즘 > 그래프 이론 정복기' 카테고리의 다른 글
최소 공통 조상(Lowest Common Ancestor,LCA)문제 기본 해법 배우기1 (0) | 2023.08.07 |
---|---|
컴퓨터로 한붓그리기 하는 방법 - 오일러 경로를 찾는 알고리즘 (0) | 2023.02.19 |
union find 응용문제 풀어보면서 분리집합 개념 재활하기 (0) | 2023.01.09 |
union find 최적화 - union by size 배우기 (0) | 2022.10.05 |
최대힙, 최소힙 직접 구현하기 (0) | 2022.10.02 |