union find 재활훈련 - union by size 복기하기

1. 문제1

 

4143번: Bridges and Tunnels (acmicpc.net)

 

4143번: Bridges and Tunnels

The first line of input contains one integer specifying the number of test cases to follow. Each test case begins with a line containing an integer n, the total number of bridges or tunnels built. This number will be no more than 100,000. We assume that al

www.acmicpc.net

 

2. 풀이1

 

서로 연결되는 간선이 주어질때마다 각 노드의 집합 크기를 구하는 문제

 

find 함수는 쉽게 기억하는데

 

from sys import stdin

def find_parent(x,parent):
    
    if x != parent[x]:
        
        parent[x] = find_parent(parent[x],parent)
    
    return parent[x]

 

union by size 함수가 하도 안쓰다보니 기억이 잘 안나더라고

 

이번 기회에 복기하자고

 

def union(a,b,parent):
    
    a = find_parent(a,parent)
    b = find_parent(b,parent)

    if a != b:
        
        if rank[a] > rank[b]:
            
            parent[b] = a
        
        else:
            
            parent[a] = b
        
        rank[a] += rank[b]

        rank[b] = rank[a]

 

두 정점 a,b의 부모를 찾는다

 

부모가 서로 다르다면, union을 수행한다

 

이때, a,b의 rank를 비교해서 rank가 작은 노드의 부모를 rank가 큰 노드로 한다

 

그리고 rank[a]에 rank[b]를 더해주고 rank[b] = rank[a]로 바꿔준다

 

rank[a]는 a 노드의 집합 크기를 의미한다

 

그러니까 rank 배열은 최초 [1]*n 같이 모든 원소가 1인 배열로 초기화해줘야한다.

 

t = int(stdin.readline())

for _ in range(t):
    
    n = int(stdin.readline())

    map_dict = {}

    i = 0
    
    parent = [j for j in range(2*n+1)]

    rank = [1] * (2*n+1)

    for _ in range(n):
        
        a,b = stdin.readline().rstrip().split()

        if map_dict.get(a,-1) == -1:

            map_dict[a] = i

            i += 1
        
        if map_dict.get(b,-1) == -1:

            map_dict[b] = i

            i += 1
            
        union(map_dict[a],map_dict[b],parent)

        print(rank[find_parent(map_dict[a],parent)])

 

이 문제는 먼저 노드가 숫자로 안주어지는데 그거는 dictionary로 mapping시켜주면 된다.

 

그리고 노드 수가 주어지지 않고 간선의 수만 주어지는데..

 

간선의 수가 n개이면 가능한 최대 노드는 2n개이다.

 

그러므로 parent와 rank배열 초기화 할때 1~2n번까지 접근할 수 있도록 2n+1 크기의 배열로 초기화하면 된다

 

그리고 rank[a]를 노드 a가 나타내는 집합의 크기로 정의했잖아

 

union할때마다 각 노드의 집합 크기를 출력해야하는데...

 

print(rank[a])하면 될까??

 

a는 그냥 노드일 뿐이고... a가 속한 집합의 rank를 구해야지

 

a가 속한 집합은 어떻게 구한다고 했지?? find함수를 이용해서 a가 속하는 집합의 대표자를 찾아야겠지

 

대표자 노드가 a가 속하는 집합을 나타낸다고 했었거든

 

그래서 a의 부모를 찾아서 rank[find_parent(a,parent)]를 print해줘야겠다

 

 

3. 문제2

 

18116번: 로봇 조립 (acmicpc.net)

 

18116번: 로봇 조립

성규는 로봇을 조립해야 한다. 상자 안에는 여러 로봇의 부품들이 섞여 있다. 그런데 어떤 부품이 어느 로봇의 부품인지 표시가 되어있지 않다. 호재는 전자과라서 두 부품을 보면 같은 로봇의

www.acmicpc.net

 

4. 풀이2

 

위 문제랑 똑같은 문제다

 

간선을 연결할때는 union으로 연결하고, 해당 노드가 속하는 집합의 크기를 구해야할 때는 크기를 구해서 출력해주고

 

중요한 점은 union by size

 

a,b의 부모를 찾고 a,b가 다르다면 union을 수행

 

rank가 적은 노드의 부모를 rank가 큰 노드로 갱신한다

 

union하고 나서 rank를 갱신해야하는데 rank[a] += rank[b]하고 rank[b] = rank[a]로 해준다

 

a와 b가 합쳐지는거니까

 

그리고 rank는 집합의 크기를 나타내니 최초 모든 원소가 1인 배열로 초기화

 

++ 두번째로는 노드 i가 속하는 집합의 크기를 rank[i]로 하면 안된다

 

i는 그냥 i일 뿐이고.. i가 속하는 집합은 i의 부모, i가 속한 집합의 대표자인 find_parent(i,parent)로 구해줘야한다

 

from sys import stdin

def find_parent(x,parent):
    
    if x != parent[x]:
        
        parent[x] = find_parent(parent[x],parent)
    
    return parent[x]

def union(a,b,parent):
    
    a = find_parent(a,parent)
    b = find_parent(b,parent)

    if a != b:

        if rank[a] > rank[b]:
            
            parent[b] = a
        
        else:
            
            parent[a] = b
        
        rank[a] += rank[b]
        rank[b] = rank[a]

N = int(stdin.readline())

n = 10**6
parent = [0]*(n+1)
rank = [1]*(n+1)

for i in range(1,n+1):
    
    parent[i] = i

for _ in range(N):
    
    c = stdin.readline().rstrip().split()

    if c[0] == 'I':
        
        union(int(c[1]),int(c[2]),parent)
    
    else:
        
        p = find_parent(int(c[1]),parent)
        
        print(rank[p])
TAGS.

Comments