Python 제곱근 연산 **(1/2)를 함부로 하면 안되는 이유

10216번: Count Circle Groups (acmicpc.net)

 

10216번: Count Circle Groups

백준이는 국방의 의무를 수행하기 위해 떠났다. 혹독한 훈련을 무사히 마치고 나서, 정말 잘 생겼고 코딩도 잘하는 백준은 그 특기를 살려 적군의 진영을 수학적으로 분석하는 일을 맡게 되었

www.acmicpc.net

 

 

각 위치 A[i]가 (x,y)를 중심으로 가지고 반지름은 r인 원으로 주어지는데...

 

두 원이 서로 닿거나 겹친다면 두 원은 통신이 가능하다.

 

두 원이 서로 닿거나 겹칠려면 어떻게 해야하나?

 

 

 

두 원의 중심사이 거리 AB랑 r1+r2를 비교하면 된다.

 

AB = (r1+r2)이면 두 원이 서로 닿는것이고 AB < (r1+r2)이면 두 원이 서로 겹치는 것이다.

 

원의 개수가 3000개이기 때문에 두 원 i,j를 $O(N^{2})$으로 순회한 다음,

 

두 원 사이 거리와 두 원의 반지름 합을 비교해서 AB <= (r1+r2)이면 두 원을 하나의 그룹으로 합친다.

 

하나의 그룹으로 합치는 방법은? union find로 두 점을 합칠 수 있다.

 

이렇게 조건을 만족하면 계속 union연산으로 두 점을 합친 다음, 마지막에 0부터 n-1 점을 순회해서...

 

이들의 대표자를 find_parent로 모두 찾아 set에 넣은 다음 set의 크기를 구하면 그룹의 개수가 된다.

 

알고 있겠지만 parent[i]는 i의 최종 대표자가 아니다.

 

최종 대표자는 find_parent(i)로 찾아야한다

 

from sys import stdin

def find_parent(x,parent):
    
    if x != parent[x]:
        
        parent[x] = find_parent(parent[x],parent)
    
    return parent[x]

def union(a,b):
    
    a = find_parent(a,parent)
    b = find_parent(b,parent)
    
    if a != b:

        if parent[a] < parent[b]:

            parent[a] = b

        else:

            parent[b] = a

def distance(x,y,z,w):
    
    return ((x-z)**2 + (y-w)**2)

T = int(stdin.readline())

for _ in range(T):
    
    n = int(stdin.readline())

    points = []

    for _ in range(n):
        
        x,y,r = map(int,stdin.readline().split())

        points.append((x,y,r))
    
    parent = [i for i in range(n)]

    for i in range(n-1):
        
        for j in range(i+1,n):
            
            x,y,r1 = points[i]
            z,w,r2 = points[j]

            d = distance(x,y,z,w)

            if d <= (r1+r2)*(r1+r2):
                
                union(i,j)
    
    s = set()

    for i in range(n):
        
        s.add(find_parent(i,parent))
    
    print(len(s))

 

 

마지막에 순회하면서 set에 넣은 다음 set의 크기를 구하는것보다 효과적으로 그룹의 개수를 찾을 수 있는데

 

처음에 union 하기 전에는 n개의 점이 있다.

 

초기에는 n개의 그룹이 있는 것이다.

 

순회하면서 조건을 만족하는 두 점 i,j를 union하면  2개의 그룹을 서로 합쳐 1개의 그룹으로 만드므로,

 

현재 그룹의 개수에서 1씩 감소한다.

 

이렇게 하면 마지막에 별도로 순회해서 set에 넣는 과정을 안하더라도 바로 그룹의 개수를 찾을 수 있다.

 

parent = [i for i in range(n)]
answer = n

for i in range(n-1):

    for j in range(i+1,n):

        x,y,r1 = points[i]
        z,w,r2 = points[j]

        d = distance(x,y,z,w)

        if d <= (r1+r2)*(r1+r2):

            I = find_parent(i,parent)
            J = find_parent(j,parent)

            if I != J:

                union(I,J)
                answer -= 1
print(answer)

 

 

union find를 하지 않고 다른 방법으로도 할 수 있는데,

 

두 점을 하나로 합치는 것은 두 점 사이를 이동할 수 있다는 의미로, 초기에 n개의 점이 있고 0개의 간선이 있는 그래프에서

 

i,j가 서로 겹치거나 닿는 원이라면 두 점 사이를 이동할 수 있다는 의미에서 i,j번 노드 사이에 간선을 추가해준다.

 

이 과정이 끝나면 하나의 그래프가 만들어지고 이 그래프를 BFS하면 그룹의 개수를 찾을 수 있다.

 

그룹의 개수를 찾는 방법은?

 

새로 만들어진 그래프에서 0번부터 n-1번까지 순회를 하는데, 방문하지 않은 노드 i번 노드라면 i번에서 BFS를 시작해서,

 

방문할 수 있는 모든 노드를 방문한다.

 

방문이 끝나면 그룹을 하나 찾은 것이다.

 

그리고 다시 0번부터 N-1번까지 순회하면서 아직 방문하지 않은 노드를 찾고,,.... bfs를 반복하면서...

 

모든 노드를 방문할 때까지 위 과정을 반복하면 BFS를 수행한 횟수가 그룹의 개수가 되는 이미 배운 기본 테크닉

 

from sys import stdin
from collections import deque

def bfs(i):
    
    queue = deque([i])
    
    while queue:
        
        i = queue.popleft()
        
        for v in graph[i]:
            
            if visited[v] == 0:
                
                visited[v] = 1
                queue.append(v)
    
    return 1
        
def distance(x,y,z,w):
    
    return ((x-z)**2 + (y-w)**2)

T = int(stdin.readline())

for _ in range(T):
    
    n = int(stdin.readline())

    points = []

    for _ in range(n):
        
        x,y,r = map(int,stdin.readline().split())

        points.append((x,y,r))
    
    graph = [[] for _ in range(n)]
    
    for i in range(n-1):
        
        for j in range(i+1,n):
            
            x,y,r1 = points[i]
            z,w,r2 = points[j]

            d = distance(x,y,z,w)

            if d <= (r1+r2)*(r1+r2):
                
                graph[i].append(j)
                graph[j].append(i)
    
    visited = [0]*n
    count = 0
    
    for i in range(n):
        
        if visited[i] == 0:
            
            visited[i] = 1
            count += bfs(i)
    
    print(count)

 

 

 

근데 사실 이 글을 쓴 이유는... 문제를 푸는 것보다 다른 이유가 있다

 

처음에는 다음과 같이 제출했는데 시간초과를 받았다.

 

from sys import stdin

def find_parent(x,parent):
    
    if x != parent[x]:
        
        parent[x] = find_parent(parent[x],parent)
    
    return parent[x]

def union(a,b):
    
    a = find_parent(a,parent)
    b = find_parent(b,parent)
    
    if parent[a] < parent[b]:

        parent[a] = b

    else:

        parent[b] = a

def distance(x,y,z,w):
    
    return ((x-z)**2 + (y-w)**2)**(1/2)

T = int(stdin.readline())

for _ in range(T):
    
    n = int(stdin.readline())

    points = []

    for _ in range(n):
        
        x,y,r = map(int,stdin.readline().split())

        points.append((x,y,r))
    
    parent = [i for i in range(n)]

    for i in range(n-1):
        
        for j in range(i+1,n):
            
            x,y,r1 = points[i]
            z,w,r2 = points[j]

            d = distance(x,y,z,w)

            if d <= (r1+r2):
                
                union(i,j)
    
    s = set()

    for i in range(n):
        
        s.add(find_parent(i,parent))
    
    print(len(s))

 

 

통과한 코드랑 차이는 거리를 계산하고 비교하는 과정에 차이가 있다.

 

이 코드는 거리를 계산할때, 제곱근을 반환했고..

 

def distance(x,y,z,w):
    
    return ((x-z)**2 + (y-w)**2)**(1/2)

 

 

통과한 코드는 제곱한 값을 반환했다.

 

def distance(x,y,z,w):
    
    return ((x-z)**2 + (y-w)**2)

 

 

당연히 반환한 거리에 따라 비교하는 부등식이 달라진다.

 

제곱근을 반환하면 그 값이 거리가 되므로 AB <= (r1+r2)

 

제곱을 반환하면 양변을 제곱한 것과 같으므로 AB**2 <= (r1+r2)*(r1+r2)

 

d = distance(x,y,z,w)

if d <= (r1+r2)*(r1+r2):

    I = find_parent(i,parent)
    J = find_parent(j,parent)

 

 

이게 도대체 무슨 차이라고 그러는건지... 생각하는데

 

추측할 수 있는건 거듭제곱이 O(logN)이니까 제곱근하면 로그시간이 들고 제곱근하지 않으면 시간이 안드니까

 

이게 최악의 경우 3000*3000번하면 시간 차이가 생기는 것 같다라고 생각을 했다..

 

여러가지 자료를 찾아보면 chatgpt 피셜 제곱근 연산은 계산비용이 크다

 

 

 

 

https://www.geeksforgeeks.org/square-root-of-an-integer/

 

Square root of an integer - GeeksforGeeks

A Computer Science portal for geeks. It contains well written, well thought and well explained computer science and programming articles, quizzes and practice/competitive programming/company interview Questions.

www.geeksforgeeks.org

 

여기서 **(1/2) 연산이 O(logX)라고 함

 

https://stackoverflow.com/questions/327002/which-is-faster-in-python-x-5-or-math-sqrtx

 

Which is faster in Python: x**.5 or math.sqrt(x)?

I've been wondering this for some time. As the title say, which is faster, the actual function or simply raising to the half power? UPDATE This is not a matter of premature optimization. This is ...

stackoverflow.com

 

math 라이브러리의 math.sqrt()와 **(1/2)을 비교했을때 **(1/2)가 압도적으로 느리다고 한다

 

실제로 math.sqrt()로 바꾸면 시간초과없이 통과할 수 있다

 

from sys import stdin
import math

def find_parent(x,parent):
    
    if x != parent[x]:
        
        parent[x] = find_parent(parent[x],parent)
    
    return parent[x]

def union(a,b):

    if parent[a] < parent[b]:

        parent[a] = b

    else:

        parent[b] = a
        
def distance(x,y,z,w):
    
    return math.sqrt((x-z)**2 + (y-w)**2)

T = int(stdin.readline())

for _ in range(T):
    
    n = int(stdin.readline())

    points = []

    for _ in range(n):
        
        x,y,r = map(int,stdin.readline().split())

        points.append((x,y,r))
    
    parent = [i for i in range(n)]
    answer = n
    
    for i in range(n-1):
        
        for j in range(i+1,n):
            
            x,y,r1 = points[i]
            z,w,r2 = points[j]

            d = distance(x,y,z,w)

            if d <= (r1+r2):
                
                I = find_parent(i,parent)
                J = find_parent(j,parent)
                
                if I != J:
                    
                    union(I,J)
                    answer -= 1
    print(answer)

 

TAGS.

Comments