절댓값을 풀어내는 필수 테크닉 - 모든 i,j에 대해 (i-j)|ai-bj|의 합을 빠르게 구하는 방법

28867번: Портальная пушка

 

최대 100000개의 원소를 가지는 배열 A,B에 대하여 $\sum_{i,j}^{}(i-j)|a_{i} - b_{j}|$를 구하는 문제

 

당연히 $O(N^{2})$은 안될거고 O(N)에는 풀어야하는데

 

$a_{i} >= b_{j}$이고 $a_{i} < b_{j}$인 경우에 따라 $|a_{i} - b_{j}| = a_{i} - b_{j}$이거나 $b_{j}-a_{i}$이다.

 

따라서 $(i-j)|a_{i}-b_{j}| = i(a_{i}-b_{j})-j(a_{i}-b_{j})+i(b_{j}-a_{i})-j(b_{j}-a_{i})$

 

그래서 ai>=bj인 경우 i(ai-bj)와 j(ai-bj), ai < bj인 경우 i(bj-ai), j(bj-ai) 4가지 부분으로 나눠서 계산하면 좋을것 같다.

 

여기서 핵심은 앞에 붙은 i,j인데 얘를 고정시킨다면 투포인터를 활용해서 O(N)에 계산할 수 있다.

 

예를 들어 i = 0인 경우 j = 0,1,2,...증가시켜서 ai >= bj인 경우의 j를 찾아서 i(ai-bj) + i(bj-ai)를 구하면 된다.

 

마찬가지로 j를 고정시킨다면, j = 0인 경우 i = 0,1,2,... 증가시켜서 ai >= bj인 경우의 i를 찾아서 j(ai-bj)+j(bj-ai)를 구하면 된다.

 

그러면 i,j에 따라 ai >= bj인 경우를 찾기 위해 a,b를 정렬해둔다.

 

이때 인덱스는 보존해야하므로 인덱스까지 같이 넣어서

 

n = int(input())

a = list(map(int,input().split()))

m = int(input())

b = list(map(int,input().split()))

A = []

for i in range(len(a)):

    A.append((a[i],i+1))

B = []

for j in range(len(b)):

    B.append((b[j],j+1))

A.sort()
B.sort()

 

 

 

그 다음 먼저 i = 0,1,2..를 고정시키고 j = 0,1,2,... 증가시키면서 ai >= bj인 가장 작은 j를 찾는다.

 

0~j-1까지는 ai < bj인 경우이므로 i(bj-ai)에 영향을 미치고

 

i가 고정되어 있으므로 ai에는 0~j-1개수만큼 곱해서 j*ai이고 bj에는 0~j-1까지 b의 합을 구하면 될 것이고

 

j~m-1까지는 ai >= bj인 경우이므로 i(ai-bj)에 영향을 미친다.

 

i가 고정되어 있으므로 ai에는 j~m-1 개수 m-j만큼 곱해서 (m-j)ai이고 bj에는 j~m-1까지 b의 합을 구하면 된다.

 

따라서 A,B의 누적합이 필요하다.

 

p_a = [A[0][0]]

for i in range(1,n):
    
    p_a.append(p_a[-1]+A[i][0])

p_b = [B[0][0]]

for i in range(1,m):
    
    p_b.append(p_b[-1]+B[i][0])

 

 

누적합을 구한다면 먼저 i를 고정시켜서 j를 증가시킨다음... 위에서 말한대로 구한다면

 

i = 0,1,2.. 각각에 대하여 j = 0부터 시작해서 A[i][0] > B[j][0]이면 j를 증가시키다가..

 

A[i][0] <= B[j][0]가 될때 j부터는 A,B가 정렬되어있으므로 무조건 A[i][0] <= B[j][0]이므로 j를 그만 증가시키고

 

0~j-1까지는 A[i][0] > B[j][0]이고 j~m-1까지는 A[i][0] <= B[j][0]이다.

 

-----------------------------------------------------------------------------------------------------------------------------------------------------

 

여기서 A[i][0] < B[j][0]이면 j를 증가시키다가 A[i][0] >= B[j][0]이면 그만 증가시킬려 했는데

 

이러면 안되는게 i가 고정되어 있기 때문에 A[i][0]는 일정하고 B가 정렬되어 있으므로 B[j][0]는 계속해서 증가하니까

 

-----------------------------------------------------------------------------------------------------------------------------------------------------------

 

따라서 0~j-1에서는 i(ai-bj)의 합인 A[i][1](A[i][0]*j - p_b[-1])

 

j~m-1에는 i(bj-ai)의 합인 A[i][1](p_b[-1]-p_b[j-1] - A[i][0]*(m-j))

 

여기서 j = 0인 경우는 0~m-1까지 모두 A[i][0] <= B[j][0]라는 뜻이므로 i(bj-ai)의 합 A[i][1](p_b[-1]-A[i][0]*m)만 구하면 된다.

 

v1 = 0
v2 = 0
j = 0

for i in range(n):
    
    while j < m and A[i][0] > B[j][0]:

        j += 1
    
    if j == 0:

        v2 += (A[i][1]*(p_b[-1] - A[i][0]*m))

    else:
        
        v1 += (A[i][1]*(A[i][0]*j - p_b[j-1]))
        v2 += (A[i][1]*(p_b[-1] - p_b[j-1] - A[i][0]*(m-j)))

 

 

비슷하게 j = 0,1,2..로 고정시켜서 i = 0부터 시작해서 증가시키다가

 

A[i][0] < B[j][0]이면 i를 증가시키다가 A[i][0] >= B[j][0]이면 반복문을 빠져나오고

 

0~i-1에는 A[i][0] < B[j][0]이므로 j(bj-ai)의 합인 B[j][1](B[j][0]*i-p_a[i-1])

 

i~n-1에는 A[i][0] >= B[j][0]이므로 j(ai-bj)의 합인 B[j][1](p_a[-1]-p_a[i-1] - B[j][0]*(n-i))

 

v3 = 0
v4 = 0
i = 0

for j in range(m):
    
    while i < n and A[i][0] <= B[j][0]:
        
        i += 1
    
    if i == 0:
        
        v3 += B[j][1]*(p_a[-1] - B[j][0]*n)
    
    else:
        
        v4 += B[j][1]*(B[j][0]*i - p_a[i-1] )
        v3 += B[j][1]*(p_a[-1] - p_a[i-1] - B[j][0]*(n-i))

print(v1 + v2 - v3 - v4)

 

 

이제 4부분 v1 = i(ai-bj),v2 = i(bj-ai),v3 = j(ai-bj),v4 = j(bj-ai)를 구했으니 v1+v2-v3-v4로 답을 구하면 끝

 
TAGS.

Comments