트리 위의 모든 입자 이동시켜서 소멸시키기(트리에서의 다이나믹 프로그래밍)

https://atcoder.jp/contests/abc409/tasks/abc409_e

 

E - Pair Annihilation

AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.

atcoder.jp

 

트리 위에 정점 i에서 xi개의 양전자가 놓여있고, 혹은 -xi개의 전자가 놓여있다.

 

이때 모든 입자의 합은 0임이 보장된다.

 

따라서 입자들을 적절히 이동시키면 모든 입자를 소멸시킬 수 있다.

 

한 입자를 간선 j를 따라 이동시키면 에너지 wj가 소모된다.

 

양전자와 전자가 같은 정점에 속하면 입자가 소멸된다.

 

모든 입자를 완전히 소멸하는데 필요한 최소 에너지량을 구한다면?

 

------------------------------------------------------------------------------------------------------------------------------------------------

 

문제만 읽어보면 굉장히 어렵다...

 

입자들을 적절히 이동시켜서 모든 정점에 있는 입자들을 없애라고?

 

무슨 입자를 이동시켜야하고? 어디로 이동시켜야하고??...

 

 

어떤 간선 j가 정답인 에너지량에 얼마나 기여할 수 있을까?

 

uj가 vj의 부모라고 하자.

 

Xj가 vj를 루트로 하는 서브트리에 있는 모든 정점의 x값의 합이라고 하자.

 

따라서 이 서브트리 내에서는 어떤 연산을 하더라도 반드시 |Xj|개의 입자가 남게 된다.

 

그런데 이 입자들을 모두 소멸시키기 위해서는, 간선 j를 통해서 서브트리 외부로 모두 옮기거나, 

 

외부에서 입자들을 끌고와 간선 j를 통해 서브트리 내로 옮겨서 쌍소멸시키는 방법밖에 없다.

 

어떠한 방식으로 하더라도, 반드시 최소 |Xj|개의 입자이동이 간선 j를 통해 필요하다.

 

그러므로 간선 j를 통해서 |Xj|*wj만큼의 에너지량이 발생한다.

 

그러므로 전체 최소 비용은 j = 1,2,..,n에 대해 |Xj|*wj의 합이다.

 

 

|Xj|를 구하는 방법은?

 

정의 상 노드 vj의 서브트리내 모든 정점들의 정수값 합이 Xj이므로, DFS를 이용한 트리에서의 다이나믹 프로그래밍의 전형적인 문제가 된다.

 

from sys import setrecursionlimit
setrecursionlimit(10**6)

n = int(input())

X = [0] + list(map(int,input().split()))

graph = [[] for _ in range(n+1)]
edge = []

for _ in range(n-1):
    
    u,v,w = map(int,input().split())
    graph[u].append((v,w))
    graph[v].append((u,w))
    edge.append((u,v,w))

dp = [0]*(n+1)
visited = [0]*(n+1)
visited[1] = 1

answer = 0

def dfs(u):
    
    global answer
    
    dp[u] += X[u]

    for v,w in graph[u]:
        
        if visited[v] == 0:
            
            visited[v] = 1
            dfs(v)
            answer += abs(dp[v])*w
            dp[u] += dp[v]

dfs(1)

print(answer)

 

 

 

dfs로 1번부터 순회하면서 자식 정점 v로 계속 타고타고 들어가는데

 

dfs(1) > dfs(2) > dfs(3) > dfs(4) > dfs(5)

 

각 정점에서 먼저 dp[u]에 X[u]를 더해주면서 현재 정점 u의 입자 수 합을 구해주고

 

마지막 자식 정점 5번에 왔을때, 더 이상 들어갈 정점이 없으니까

 

dfs(5)가 return되면서 dfs(4) 호출 중에 dfs(5)가 return 되면서 4의 자식 정점인 dp[5]는 계산이 끝나있고 그 다음 줄

 

answer += abs(dp[5])*w로 더해주고, 5의 부모 정점인 dp[4] += dp[5]로 더해주고...

 

이러면 dfs(4)도 끝났으니까 return 되면서 dfs(3) 호출 중에 dfs(4)가 return되면서 3의 자식 정점인 dp[4]는 계산 끝났고

 

그 다음 줄인 answer  += abs(dp[4])*w, 부모 정점인 dp[3] += dp[4]

 

... 이러면 dfs(3)도 끝나서 return되면서 dfs(2) 호출 중에 dfs(3)이 return되면서 2의 자식 정점인 dp[3]이 계산 끝났고...

 

그 다음 줄인 answer += abs(dp[3])*w, 부모 정점인 dp[2] += dp[3]

 

이러면 dfs(2)도 끝나서 return되면서 dfs(1) 호출 중에 dfs(2)가 return되면서 1의 자식 정점인 dp[2]가 끝났고

 

그 다음 줄인 answer += abs(dp[2])*w 부모 정점인 dp[1] += dp[2]로 끝..

 

아쉽긴한데... 문제를 이렇게 바꿔서 생각해야하구나

 

728x90