모든 순서쌍의 합의 나머지를 합해야하는데 매 항마다 나머지를 더하면 안되는 문제

C - Sigma Problem (atcoder.jp)

 

C - Sigma Problem

AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.

atcoder.jp

 

 

f(x,y) = x + y를 $10^{8}$로 나눈 나머지라고 정의

 

모든 i = 1,2,3,...,n-1, i < j에 대하여 f(A[i],A[j])의 합을 구하는 문제

 

예를 들어 3 50000001 50000002이면...

 

(3, 50000001), (3, 50000002), ( 50000001, 50000002)가 있고...

 

50000004, 50000005, (100000003 % 100000000 = 3)이 된다.

 

이들을 합하면 100000012

 

모든 순서쌍 (i,j)를 찾는 것은 기본적으로 $O(N^{2})$이지만 N이 30만이라 시간초과

 

하지만 다음과 같이 수식을 먼저 조작해보면 어떨까

 

예를 들어 A = [a1,a2,a3,a4]라고 하자.

 

단순히 f(a,b) = a+b라고 해보자.

 

a1 + a2 + a1 + a3 + a1 + a4

 

+ a2 + a3 + a2 + a4

 

+ a3 + a4

 

 = 3*a1 + 3*a2 + 3*a3 + 3*a4

 

그러면 구하고자 하는 합은 (n-1)*(a1+a2+a3+...+an)

 

-----------------------------------------------------------------------------------------------------

 

그런데 정수론에서 (x+y) % mod = (x%mod + y%mod)니까... 그냥 a1,a2,a3,...,an을 전부 mod = 10^8로 나눈 다음

 

다 합해버리고 n-1 곱해도 되는거 아니야?

 

라고 생각했는데..

 

1) (x+y) % mod = (x%mod + y%mod)가 아니라 (x+y) % mod = (x%mod + y%mod)%mod이다.

 

2) 두번째로 이 문제가 더 어려운 점은 모든 f(A[i],A[j])의 합을 mod로 나눈 나머지를 구하는 것이 아니다.

 

정수론 테크닉 중 하나가 이렇게 매 항마다 더하고, mod로 나눠서 나머지를 더하는거잖아

 

근데 이렇게 하면 안된다는 거지..

 

for i in range(n):
    
    v += i
    v %= mod

print(v % mod)

 

 

조금 다르게 생각해보면 f(a,b)의 정의에 대해 다시 한번 생각해보자.

 

f(a,b)는 a+b를 10^8로 나눈 나머지이다.

 

여기서 배열의 수가 10^8보다 작기 때문에 두 수를 합해도 2*10^8보다는 작다.

 

따라서 a+b가 10^8보다 크거나 같다면 10^8으로 나눠도 최대 몫이 1이므로 a+b - 10^8과 같다.

 

a+b가 10^8보다 작다면 10^8로 나눠도 그대로이므로 a+b와 같다.

 

 

그래서 먼저 모든 순서쌍에 대해 a+b가 10^8보다 크거나 같은 개수를 찾는다면?

 

모든 순서쌍의 합은 (n-1)*(a1+a2+...+an)임을 위에서 보였다.

 

여기에 10^8 * (10^8보다 크거나 같은 순서쌍의 개수)를 빼주기만 하면 된다.

 

 

그러면 모든 순서쌍을 다 검사해야돼?

 

배열 A를 정렬하더라도 순서쌍의 합은 변하지 않는다.

 

따라서 A를 먼저 오름차순으로 정렬한 다음... 작은 수부터 큰 수대로 차례대로 순회하자.

 

현재 A[i]를 보고 있고 오른쪽 포인터가 r = n이라고 초기화하자.

 

A[i] + A[r-1] > 10^8이면, r을 1 감소시킨다.

 

현재 i번 수에서 A[i] + A[r-1] > 10^8인 가장 작은 오른쪽 포인터를 찾는 것이다.

 

그 포인터를 찾았다면 r ~ n까지 길이가 합이 10^8보다 큰 (A[i],A[r])의 개수가 된다.

 

 

그런데 여기서 또 하나 핵심은..

 

A가 정렬되어 있기 때문에 왼쪽 포인터 i가 증가할 수록 왼쪽 수는 점점 커지고...

 

그렇기 때문에 왼쪽 수 + 오른쪽 수가 10^8보다 커질려면 오른쪽 수는 점점 작아져도 된다.

 

그래서 오른쪽 포인터를 초기화하지 않고 그대로 사용해도 된다.

 

 

물론 여기서 주의해야할 것은 i < j인 경우에만 해당되는데... 오른쪽 포인터는 계속 감소하다보니...

 

어느 순간 i > r이 될수도 있다.

 

그래서 r은 항상 i+1보다 커지도록 유지해야한다.

 

n = int(input())

A = list(map(int,input().split()))

A.sort()

mod = 10**8

count = 0

j = n

for i in range(n-1):
    
    if j < i+1:
        
        j = i+1
        
    while j-i > 1 and A[j-1]+A[i] >= mod:
        
        j -= 1

    count += (n-j)


v = 0

for i in range(n):
    
    v += ((n-1)*A[i])

print(v - count*mod)

 

 

A를 정렬하고, 이분탐색으로 A[i] + A[j] > 10^8인 가장 작은 j를 찾아도 된다...

 

def binary_search(x,target,array,start,end):
    
    while start < end:
        
        mid = (start + end)//2

        if array[mid]+x >= target:
            
            end = mid
        
        else:
            
            start = mid + 1
    
    return end

n = int(input())

A = list(map(int,input().split()))

mod = 10**8

A.sort()

count = 0

for i in range(n-1):
    
    loc = binary_search(A[i],mod,A,i+1,n)

    count += (n-loc)

v = 0

for i in range(n):
    
    v += ((n-1)*A[i])

print(v - mod*count)

 

TAGS.

Comments