대표자가 같으면 union하지 말아야한다 (union by rank 주의점)

https://atcoder.jp/contests/abc420/tasks/abc420_e

 

E - Reachability Query

AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.

atcoder.jp

 

n개의 정점과 0개의 간선이 주어진다.

 

각 정점은 1번부터 n번까지 번호가 있고, 처음에 모든 정점은 흰색이다.

 

q개의 쿼리가 주어지는데 각 쿼리는 3종류중 하나이다.

 

1) u와 v를 무방향 간선으로 연결

 

2) v가 흰색이면 검은색으로 바꾸고 검은색이면 흰색으로 바꿈

 

3) 정점 v에서 검은색 정점에 0개 이상의 간선을 타고 도달할 수 있는지 검사한다. 

 

3번 쿼리가 주어질때, 가능하면 Yes, 불가능하면 No를 출력

 

--------------------------------------------------------------------------------------------------------------------------------

 

1번 쿼리가 주어질때, u와 v를 union하고...

 

2번 쿼리가 주어지면 각 정점의 색을 바꾼다.

 

3번 쿼리가 주어질때, 어떤 정점 v가 속하는 그룹을 찾고, 그 그룹에 검은색 정점이 존재하는지만 체크하면 된다.

 

그룹내에서는 서로 도달가능이기 때문이다.

 

따라서 각 그룹의 검은색 정점의 개수를 알고 있다면, 3번 쿼리가 들어올때 v의 그룹을 find로 찾고

 

해당 그룹의 검은색 정점의 개수가 1이상이면 Yes, 아니면 No를 출력

 

검은색 정점의 개수는 rank로 관리하면 된다.

 

https://deepdata.tistory.com/433

 

union find 알고리즘 최적화 -경로압축과 rank union-

1. 기본 union find 알고리즘의 문제점 기본적인 union find 알고리즘을 사용할 경우 union을 수행하면.. 예를 들어 최악의 경우 위와 같은 union이 형성될 수 있다. 1의 대표자를 찾을 때, 2번, 3번, 4번을

deepdata.tistory.com

 

 

https://deepdata.tistory.com/676

 

union find 재활훈련 - union by size 복기하기

1. 문제1 4143번: Bridges and Tunnels (acmicpc.net) 4143번: Bridges and Tunnels The first line of input contains one integer specifying the number of test cases to follow. Each test case begins with a line containing an integer n, the total number of br

deepdata.tistory.com

 

 

 

여기까지 생각했으면 맞은거나 다름없긴한데...

 

처음에 이렇게 하다가 틀렸다

 

def find(x):
    
    if parent[x] != x:
        
        parent[x] = find(parent[x])
    
    return parent[x]

def union(a,b):
    
    a = find(a)
    b = find(b)

    if parent[a] > parent[b]:
        
        parent[a] = b

        rank[b] += rank[a]

        rank[a] = rank[b]
    
    else:
        
        parent[b] = a

        rank[a] += rank[b]

        rank[b] = rank[a]

n,q = map(int,input().split())

A = [0]*(n+1)

parent = [i for i in range(n+1)]
rank = [0]*(n+1)

for _ in range(q):
    
    query = list(map(int,input().split()))

    if query[0] == 1:
        
        u,v = query[1],query[2]
        
        union(u,v)

    elif query[0] == 2:
        
        v = query[1]

        if A[v] == 0:
            
            A[v] = 1

            rank[find(v)] += 1
        
        else:
            
            A[v] = 0

            rank[find(v)] -= 1

    else:
        
        v = query[1]

        v = find(v)

        if rank[v] >= 1:
            
            print('Yes')
        
        else:
            
            print('No')

 

 

 

자세히 보니 union이 이상하더라고

 

def union(a,b):
    
    a = find(a)
    b = find(b)

    if parent[a] > parent[b]:
        
        parent[a] = b

        rank[b] += rank[a]

        rank[a] = rank[b]
    
    else:
        
        parent[b] = a

        rank[a] += rank[b]

        rank[b] = rank[a]

 

 

왜 union by rank할때 a != b이면 하는지 드디어 깨달았다

 

def union(a,b,parent):
    
    a = find_parent(a,parent)
    b = find_parent(b,parent)

    if a != b:
        
        if rank[a] > rank[b]:
            
            parent[b] = a
        
        else:
            
            parent[a] = b
        
        rank[a] += rank[b]

        rank[b] = rank[a]

 

 

a와 b의 대표자가 같은데, rank[a] += rank[b], rank[b] = rank[a]해버리면...

 

rank가 중복이 되잖아!

 

아래와 같이 해버리면 find(a)와 find(b)가 같을때 else문으로 들어가고.. 이거는 상관이 없는데

 

rank[a] += rank[b], rank[b] = rank[a]해버리면..

 

find(a), find(b)가 같으면 두 집단을 합치면 안되는데 rank가 중복이 되어 2배로 커져버리니까..

 

여기서 문제가 생긴다

 

def union(a,b):
    
    a = find(a)
    b = find(b)

    if parent[a] > parent[b]:
        
        parent[a] = b

        rank[b] += rank[a]

        rank[a] = rank[b]
    
    else:
        
        parent[b] = a

        rank[a] += rank[b]

        rank[b] = rank[a]

 

 

그래서 이것만 추가하면 정답

 

def find(x):
    
    if parent[x] != x:
        
        parent[x] = find(parent[x])
    
    return parent[x]

def union(a,b):
    
    a = find(a)
    b = find(b)

    if a != b:

        if parent[a] > parent[b]:
            
            parent[a] = b

            rank[b] += rank[a]

            rank[a] = rank[b]
        
        else:
            
            parent[b] = a

            rank[a] += rank[b]

            rank[b] = rank[a]

n,q = map(int,input().split())

A = [0]*(n+1)

parent = [i for i in range(n+1)]
rank = [0]*(n+1)

for _ in range(q):
    
    query = list(map(int,input().split()))

    if query[0] == 1:
        
        u,v = query[1],query[2]
        
        union(u,v)

    elif query[0] == 2:
        
        v = query[1]

        if A[v] == 0:
            
            A[v] = 1

            rank[find(v)] += 1
        
        else:
            
            A[v] = 0

            rank[find(v)] -= 1

    else:
        
        v = query[1]

        v = find(v)

        if rank[v] >= 1:
            
            print('Yes')
        
        else:
            
            print('No')

 

 

union by rank할때는 find(a), find(b)가 같을때 rank를 합치면 안된다는 것을 기억

728x90