절댓값을 풀어내는 필수 테크닉 - 모든 i,j에 대해 (i-j)|ai-bj|의 합을 빠르게 구하는 방법
최대 100000개의 원소를 가지는 배열 A,B에 대하여 $\sum_{i,j}^{}(i-j)|a_{i} - b_{j}|$를 구하는 문제
당연히 $O(N^{2})$은 안될거고 O(N)에는 풀어야하는데
$a_{i} >= b_{j}$이고 $a_{i} < b_{j}$인 경우에 따라 $|a_{i} - b_{j}| = a_{i} - b_{j}$이거나 $b_{j}-a_{i}$이다.
따라서 $(i-j)|a_{i}-b_{j}| = i(a_{i}-b_{j})-j(a_{i}-b_{j})+i(b_{j}-a_{i})-j(b_{j}-a_{i})$
그래서 ai>=bj인 경우 i(ai-bj)와 j(ai-bj), ai < bj인 경우 i(bj-ai), j(bj-ai) 4가지 부분으로 나눠서 계산하면 좋을것 같다.
여기서 핵심은 앞에 붙은 i,j인데 얘를 고정시킨다면 투포인터를 활용해서 O(N)에 계산할 수 있다.
예를 들어 i = 0인 경우 j = 0,1,2,...증가시켜서 ai >= bj인 경우의 j를 찾아서 i(ai-bj) + i(bj-ai)를 구하면 된다.
마찬가지로 j를 고정시킨다면, j = 0인 경우 i = 0,1,2,... 증가시켜서 ai >= bj인 경우의 i를 찾아서 j(ai-bj)+j(bj-ai)를 구하면 된다.
그러면 i,j에 따라 ai >= bj인 경우를 찾기 위해 a,b를 정렬해둔다.
이때 인덱스는 보존해야하므로 인덱스까지 같이 넣어서
n = int(input())
a = list(map(int,input().split()))
m = int(input())
b = list(map(int,input().split()))
A = []
for i in range(len(a)):
A.append((a[i],i+1))
B = []
for j in range(len(b)):
B.append((b[j],j+1))
A.sort()
B.sort()
그 다음 먼저 i = 0,1,2..를 고정시키고 j = 0,1,2,... 증가시키면서 ai >= bj인 가장 작은 j를 찾는다.
0~j-1까지는 ai < bj인 경우이므로 i(bj-ai)에 영향을 미치고
i가 고정되어 있으므로 ai에는 0~j-1개수만큼 곱해서 j*ai이고 bj에는 0~j-1까지 b의 합을 구하면 될 것이고
j~m-1까지는 ai >= bj인 경우이므로 i(ai-bj)에 영향을 미친다.
i가 고정되어 있으므로 ai에는 j~m-1 개수 m-j만큼 곱해서 (m-j)ai이고 bj에는 j~m-1까지 b의 합을 구하면 된다.
따라서 A,B의 누적합이 필요하다.
p_a = [A[0][0]]
for i in range(1,n):
p_a.append(p_a[-1]+A[i][0])
p_b = [B[0][0]]
for i in range(1,m):
p_b.append(p_b[-1]+B[i][0])
누적합을 구한다면 먼저 i를 고정시켜서 j를 증가시킨다음... 위에서 말한대로 구한다면
i = 0,1,2.. 각각에 대하여 j = 0부터 시작해서 A[i][0] > B[j][0]이면 j를 증가시키다가..
A[i][0] <= B[j][0]가 될때 j부터는 A,B가 정렬되어있으므로 무조건 A[i][0] <= B[j][0]이므로 j를 그만 증가시키고
0~j-1까지는 A[i][0] > B[j][0]이고 j~m-1까지는 A[i][0] <= B[j][0]이다.
-----------------------------------------------------------------------------------------------------------------------------------------------------
여기서 A[i][0] < B[j][0]이면 j를 증가시키다가 A[i][0] >= B[j][0]이면 그만 증가시킬려 했는데
이러면 안되는게 i가 고정되어 있기 때문에 A[i][0]는 일정하고 B가 정렬되어 있으므로 B[j][0]는 계속해서 증가하니까
-----------------------------------------------------------------------------------------------------------------------------------------------------------
따라서 0~j-1에서는 i(ai-bj)의 합인 A[i][1](A[i][0]*j - p_b[-1])
j~m-1에는 i(bj-ai)의 합인 A[i][1](p_b[-1]-p_b[j-1] - A[i][0]*(m-j))
여기서 j = 0인 경우는 0~m-1까지 모두 A[i][0] <= B[j][0]라는 뜻이므로 i(bj-ai)의 합 A[i][1](p_b[-1]-A[i][0]*m)만 구하면 된다.
v1 = 0
v2 = 0
j = 0
for i in range(n):
while j < m and A[i][0] > B[j][0]:
j += 1
if j == 0:
v2 += (A[i][1]*(p_b[-1] - A[i][0]*m))
else:
v1 += (A[i][1]*(A[i][0]*j - p_b[j-1]))
v2 += (A[i][1]*(p_b[-1] - p_b[j-1] - A[i][0]*(m-j)))
비슷하게 j = 0,1,2..로 고정시켜서 i = 0부터 시작해서 증가시키다가
A[i][0] < B[j][0]이면 i를 증가시키다가 A[i][0] >= B[j][0]이면 반복문을 빠져나오고
0~i-1에는 A[i][0] < B[j][0]이므로 j(bj-ai)의 합인 B[j][1](B[j][0]*i-p_a[i-1])
i~n-1에는 A[i][0] >= B[j][0]이므로 j(ai-bj)의 합인 B[j][1](p_a[-1]-p_a[i-1] - B[j][0]*(n-i))
v3 = 0
v4 = 0
i = 0
for j in range(m):
while i < n and A[i][0] <= B[j][0]:
i += 1
if i == 0:
v3 += B[j][1]*(p_a[-1] - B[j][0]*n)
else:
v4 += B[j][1]*(B[j][0]*i - p_a[i-1] )
v3 += B[j][1]*(p_a[-1] - p_a[i-1] - B[j][0]*(n-i))
print(v1 + v2 - v3 - v4)
이제 4부분 v1 = i(ai-bj),v2 = i(bj-ai),v3 = j(ai-bj),v4 = j(bj-ai)를 구했으니 v1+v2-v3-v4로 답을 구하면 끝
'알고리즘 > 투 포인터 알고리즘' 카테고리의 다른 글
배열에서 두 수의 합이 s가 되는 경우의 수(서로 다른 방향으로 움직이는 투 포인터) (0) | 2024.09.02 |
---|---|
겹치는 직선 구간쌍의 개수 빠르게 세기 (0) | 2024.08.02 |
투 포인터로 원소를 삭제하면서 가장 긴 수열을 찾을 수 있을까? (0) | 2023.06.08 |
투 포인터 올바르게 생각하기 기본문제로 연습2 (0) | 2023.04.12 |
투 포인터 기억해야할 점 - 끝 포인터를 처음으로 옮기지 않기 (0) | 2023.04.11 |