Python 제곱근 연산 **(1/2)를 함부로 하면 안되는 이유
10216번: Count Circle Groups (acmicpc.net)
10216번: Count Circle Groups
백준이는 국방의 의무를 수행하기 위해 떠났다. 혹독한 훈련을 무사히 마치고 나서, 정말 잘 생겼고 코딩도 잘하는 백준은 그 특기를 살려 적군의 진영을 수학적으로 분석하는 일을 맡게 되었
www.acmicpc.net
각 위치 A[i]가 (x,y)를 중심으로 가지고 반지름은 r인 원으로 주어지는데...
두 원이 서로 닿거나 겹친다면 두 원은 통신이 가능하다.
두 원이 서로 닿거나 겹칠려면 어떻게 해야하나?
두 원의 중심사이 거리 AB랑 r1+r2를 비교하면 된다.
AB = (r1+r2)이면 두 원이 서로 닿는것이고 AB < (r1+r2)이면 두 원이 서로 겹치는 것이다.
원의 개수가 3000개이기 때문에 두 원 i,j를 $O(N^{2})$으로 순회한 다음,
두 원 사이 거리와 두 원의 반지름 합을 비교해서 AB <= (r1+r2)이면 두 원을 하나의 그룹으로 합친다.
하나의 그룹으로 합치는 방법은? union find로 두 점을 합칠 수 있다.
이렇게 조건을 만족하면 계속 union연산으로 두 점을 합친 다음, 마지막에 0부터 n-1 점을 순회해서...
이들의 대표자를 find_parent로 모두 찾아 set에 넣은 다음 set의 크기를 구하면 그룹의 개수가 된다.
알고 있겠지만 parent[i]는 i의 최종 대표자가 아니다.
최종 대표자는 find_parent(i)로 찾아야한다
from sys import stdin
def find_parent(x,parent):
if x != parent[x]:
parent[x] = find_parent(parent[x],parent)
return parent[x]
def union(a,b):
a = find_parent(a,parent)
b = find_parent(b,parent)
if a != b:
if parent[a] < parent[b]:
parent[a] = b
else:
parent[b] = a
def distance(x,y,z,w):
return ((x-z)**2 + (y-w)**2)
T = int(stdin.readline())
for _ in range(T):
n = int(stdin.readline())
points = []
for _ in range(n):
x,y,r = map(int,stdin.readline().split())
points.append((x,y,r))
parent = [i for i in range(n)]
for i in range(n-1):
for j in range(i+1,n):
x,y,r1 = points[i]
z,w,r2 = points[j]
d = distance(x,y,z,w)
if d <= (r1+r2)*(r1+r2):
union(i,j)
s = set()
for i in range(n):
s.add(find_parent(i,parent))
print(len(s))
마지막에 순회하면서 set에 넣은 다음 set의 크기를 구하는것보다 효과적으로 그룹의 개수를 찾을 수 있는데
처음에 union 하기 전에는 n개의 점이 있다.
초기에는 n개의 그룹이 있는 것이다.
순회하면서 조건을 만족하는 두 점 i,j를 union하면 2개의 그룹을 서로 합쳐 1개의 그룹으로 만드므로,
현재 그룹의 개수에서 1씩 감소한다.
이렇게 하면 마지막에 별도로 순회해서 set에 넣는 과정을 안하더라도 바로 그룹의 개수를 찾을 수 있다.
parent = [i for i in range(n)]
answer = n
for i in range(n-1):
for j in range(i+1,n):
x,y,r1 = points[i]
z,w,r2 = points[j]
d = distance(x,y,z,w)
if d <= (r1+r2)*(r1+r2):
I = find_parent(i,parent)
J = find_parent(j,parent)
if I != J:
union(I,J)
answer -= 1
print(answer)
union find를 하지 않고 다른 방법으로도 할 수 있는데,
두 점을 하나로 합치는 것은 두 점 사이를 이동할 수 있다는 의미로, 초기에 n개의 점이 있고 0개의 간선이 있는 그래프에서
i,j가 서로 겹치거나 닿는 원이라면 두 점 사이를 이동할 수 있다는 의미에서 i,j번 노드 사이에 간선을 추가해준다.
이 과정이 끝나면 하나의 그래프가 만들어지고 이 그래프를 BFS하면 그룹의 개수를 찾을 수 있다.
그룹의 개수를 찾는 방법은?
새로 만들어진 그래프에서 0번부터 n-1번까지 순회를 하는데, 방문하지 않은 노드 i번 노드라면 i번에서 BFS를 시작해서,
방문할 수 있는 모든 노드를 방문한다.
방문이 끝나면 그룹을 하나 찾은 것이다.
그리고 다시 0번부터 N-1번까지 순회하면서 아직 방문하지 않은 노드를 찾고,,.... bfs를 반복하면서...
모든 노드를 방문할 때까지 위 과정을 반복하면 BFS를 수행한 횟수가 그룹의 개수가 되는 이미 배운 기본 테크닉
from sys import stdin
from collections import deque
def bfs(i):
queue = deque([i])
while queue:
i = queue.popleft()
for v in graph[i]:
if visited[v] == 0:
visited[v] = 1
queue.append(v)
return 1
def distance(x,y,z,w):
return ((x-z)**2 + (y-w)**2)
T = int(stdin.readline())
for _ in range(T):
n = int(stdin.readline())
points = []
for _ in range(n):
x,y,r = map(int,stdin.readline().split())
points.append((x,y,r))
graph = [[] for _ in range(n)]
for i in range(n-1):
for j in range(i+1,n):
x,y,r1 = points[i]
z,w,r2 = points[j]
d = distance(x,y,z,w)
if d <= (r1+r2)*(r1+r2):
graph[i].append(j)
graph[j].append(i)
visited = [0]*n
count = 0
for i in range(n):
if visited[i] == 0:
visited[i] = 1
count += bfs(i)
print(count)
근데 사실 이 글을 쓴 이유는... 문제를 푸는 것보다 다른 이유가 있다
처음에는 다음과 같이 제출했는데 시간초과를 받았다.
from sys import stdin
def find_parent(x,parent):
if x != parent[x]:
parent[x] = find_parent(parent[x],parent)
return parent[x]
def union(a,b):
a = find_parent(a,parent)
b = find_parent(b,parent)
if parent[a] < parent[b]:
parent[a] = b
else:
parent[b] = a
def distance(x,y,z,w):
return ((x-z)**2 + (y-w)**2)**(1/2)
T = int(stdin.readline())
for _ in range(T):
n = int(stdin.readline())
points = []
for _ in range(n):
x,y,r = map(int,stdin.readline().split())
points.append((x,y,r))
parent = [i for i in range(n)]
for i in range(n-1):
for j in range(i+1,n):
x,y,r1 = points[i]
z,w,r2 = points[j]
d = distance(x,y,z,w)
if d <= (r1+r2):
union(i,j)
s = set()
for i in range(n):
s.add(find_parent(i,parent))
print(len(s))
통과한 코드랑 차이는 거리를 계산하고 비교하는 과정에 차이가 있다.
이 코드는 거리를 계산할때, 제곱근을 반환했고..
def distance(x,y,z,w):
return ((x-z)**2 + (y-w)**2)**(1/2)
통과한 코드는 제곱한 값을 반환했다.
def distance(x,y,z,w):
return ((x-z)**2 + (y-w)**2)
당연히 반환한 거리에 따라 비교하는 부등식이 달라진다.
제곱근을 반환하면 그 값이 거리가 되므로 AB <= (r1+r2)
제곱을 반환하면 양변을 제곱한 것과 같으므로 AB**2 <= (r1+r2)*(r1+r2)
d = distance(x,y,z,w)
if d <= (r1+r2)*(r1+r2):
I = find_parent(i,parent)
J = find_parent(j,parent)
이게 도대체 무슨 차이라고 그러는건지... 생각하는데
추측할 수 있는건 거듭제곱이 O(logN)이니까 제곱근하면 로그시간이 들고 제곱근하지 않으면 시간이 안드니까
이게 최악의 경우 3000*3000번하면 시간 차이가 생기는 것 같다라고 생각을 했다..
여러가지 자료를 찾아보면 chatgpt 피셜 제곱근 연산은 계산비용이 크다
https://www.geeksforgeeks.org/square-root-of-an-integer/
Square root of an integer - GeeksforGeeks
A Computer Science portal for geeks. It contains well written, well thought and well explained computer science and programming articles, quizzes and practice/competitive programming/company interview Questions.
www.geeksforgeeks.org
여기서 **(1/2) 연산이 O(logX)라고 함
https://stackoverflow.com/questions/327002/which-is-faster-in-python-x-5-or-math-sqrtx
Which is faster in Python: x**.5 or math.sqrt(x)?
I've been wondering this for some time. As the title say, which is faster, the actual function or simply raising to the half power? UPDATE This is not a matter of premature optimization. This is ...
stackoverflow.com
math 라이브러리의 math.sqrt()와 **(1/2)을 비교했을때 **(1/2)가 압도적으로 느리다고 한다
실제로 math.sqrt()로 바꾸면 시간초과없이 통과할 수 있다
from sys import stdin
import math
def find_parent(x,parent):
if x != parent[x]:
parent[x] = find_parent(parent[x],parent)
return parent[x]
def union(a,b):
if parent[a] < parent[b]:
parent[a] = b
else:
parent[b] = a
def distance(x,y,z,w):
return math.sqrt((x-z)**2 + (y-w)**2)
T = int(stdin.readline())
for _ in range(T):
n = int(stdin.readline())
points = []
for _ in range(n):
x,y,r = map(int,stdin.readline().split())
points.append((x,y,r))
parent = [i for i in range(n)]
answer = n
for i in range(n-1):
for j in range(i+1,n):
x,y,r1 = points[i]
z,w,r2 = points[j]
d = distance(x,y,z,w)
if d <= (r1+r2):
I = find_parent(i,parent)
J = find_parent(j,parent)
if I != J:
union(I,J)
answer -= 1
print(answer)
'알고리즘 > 알고리즘 일반' 카테고리의 다른 글
Python dict indexing보다는 list indexing을 사용해야하는 이유 (0) | 2024.04.12 |
---|---|
알고리즘 테크닉 - LR 테크닉 (0) | 2023.04.16 |
자바 알고리즘 기본 -입력을 받는 방법- (0) | 2023.02.15 |
파이썬 알고리즘 기본기 EOF(End of File) 배우기 (0) | 2022.12.09 |
알고리즘 문제를 풀기위해 2차원 배열에서 이해해야할 테크닉들 (0) | 2022.09.11 |