union find 최적화 - union by size 배우기

1. 문제

 

4195번: 친구 네트워크 (acmicpc.net)

 

4195번: 친구 네트워크

첫째 줄에 테스트 케이스의 개수가 주어진다. 각 테스트 케이스의 첫째 줄에는 친구 관계의 수 F가 주어지며, 이 값은 100,000을 넘지 않는다. 다음 F개의 줄에는 친구 관계가 생긴 순서대로 주어진

www.acmicpc.net

 

친구관계인 간선이 주어질때마다 각 사람(노드)에 속한 집합의 크기를 구하는 문제

 

 

2. 풀이

 

이 문제가 어렵게 느껴지는 이유는 일단 노드 수가 주어지지 않는다..

 

간선의 수 f만 주어지는데.. 이럴 경우 최대 노드 수는 2f개 일 것이다.

 

사실 최대 노드 수만 알아도 문제를 푸는데 상관이 없다.

 

왜냐하면 union이 일어날때마다 각 노드에 속한 집합의 크기만 구하면 되니까

 

parent = [i for i in range(2*f+1)]
rank = [1]*(2*f+1)

 

다음으로 각 노드가 번호가 아니고 영어이름으로 주어지니까, 영어이름을 노드 번호로 바꿔줄 필요가 있다

 

사전에 담아둬서 사전의 key를 영어이름으로 하고 번호를 value로 해서 바꿔줄 수 있다

 

for _ in range(f):
        
    a,b = stdin.readline().rstrip().split()

    if ind_dict.get(a,0) == 0:

        ind_dict[a] = ind
        ind += 1

    else:

        pass

    if ind_dict.get(b,0) == 0:

        ind_dict[b] = ind
        ind += 1

    else:

        pass

 

사전에 이미 key가 존재한다면, 대응하는 value를 가지고 오면 번호로 바꿀 수 있지만

 

아직 key가 존재하지 않는다면 새로 나온 노드니까 새로 key를 추가하고 추가할때마다 번호는 1씩 늘려줘야겠지

 

다음으로 이제 union할때마다 각 대표자가 속한 집합의 크기를 구해야하는데

 

어떻게 가능하냐면, union by size를 이용하면 가능하다

 

내가 알고 있는 union by rank는 rank배열에 트리의 높이를 저장했어가지고 애초에 초기 단계가

 

rank = [0]*(v+1)이었다.

 

하지만 union by size는 rank 배열에 트리를 이루는 집합의 크기를 저장한다.

 

대표자가 i번인 집합의 크기를 rank[i]에 저장하는 것이다.

 

그러면 초기에는 각각 1개씩 있으니까 rank = [1]*(v+1)로 시작을 한다.

 

그러면 이제 집합의 크기가 작은 노드의 부모를 집합의 크기가 큰 노드로 하는 것은 동일하다.

 

rank[a] > rank[b]이면, parent[b] = a로 하는 것은 동일하다.

 

그러면 이제 a와 b가 union되어 합쳐지는 것이니까, 집합 a안에 집합 b가 들어간다는 소리

 

그러니까 a집합의 크기는 b집합의 크기를 더해준 rank[a] = rank[a] + rank[b]가 된다.

 

하지만 집합 a와 집합 b가 합쳐지면, 대표자 a와 b가 가리키는 집합은 동일한 집합이고, 그러므로 집합의 크기는 동일하다.

 

따라서 rank[b] = rank[a]

 

 

그리고 마지막으로 중요한 점은 union by size는 a와 b가 다를때만 수행해야한다는 점이다.

 

근데 다른 것도 사실 a와 b가 같으면 수행안해도 되는데, 수행해도 상관이 없어서 고려하지 않았음

 

하지만 union by size는 a와 b가 같을때도 수행해버리면 union이 안되는거니까 집합 크기는 그대로인데

 

집합 크기는 2배로 늘어나버리는 잘못된 로직을 짤 가능성이 높다

 

그래서 union by size의 코드는 다음과 같다.

 

##union by size

##rank 배열을 집합의 크기로 저장

def union_parent(parent,a,b):

    
    ##a,b의 대표자를 찾고
    
    a = find_parent(parent,a)
    b = find_parent(parent,b)
    
    ##a,b가 서로 다르다면 union을 수행
    
    if a != b:
        
    ##크기가 작은 집합의 부모를 크기가 큰 집합으로
        
        if rank[a] > rank[b]:
            
            parent[b] = a
        
        else:
            
            parent[a] = b
        
        ##union 되어서 집합의 크기를 변경
        
        rank[a] += rank[b]
        
        rank[b] = rank[a]
    
    

parent = [i for i in range(v+1)] ##최초 부모는 자기 자신으로 설정

rank = [1]*(v+1) ##각 집합의 크기를 저장

 

최종 코드..

 

from sys import stdin

def find_parent(parent,x):
    
    if x != parent[x]:
        
        parent[x] = find_parent(parent,parent[x])

    return parent[x]

def union_parent(parent,a,b):
    
    a = find_parent(parent,a)
    b = find_parent(parent,b)

    if a != b:

        if rank[a] > rank[b]:
            
            parent[b] = a
        
        else:
            
            parent[a] = b

        
        rank[a] += rank[b]
        rank[b] = rank[a]

    return rank[find_parent(parent,a)]
        

    
T = int(stdin.readline())

for _ in range(1,T+1):
    
    f = int(stdin.readline())

    parent = [i for i in range(2*f+1)]
    rank = [1]*(2*f+1)

    ind_dict = {}

    ind = 1

    for _ in range(f):
        
        a,b = stdin.readline().rstrip().split()

        if ind_dict.get(a,0) == 0:
            
            ind_dict[a] = ind
            ind += 1
        
        else:
            
            pass
        
        if ind_dict.get(b,0) == 0:
            
            ind_dict[b] = ind
            ind += 1
        
        else:
            
            pass
        
        print(union_parent(parent,ind_dict[a],ind_dict[b]))
TAGS.

Comments