코딩테스트 복기 - 구간합이 전부 똑같도록 3구간으로 나누는 방법(잘 모를때는 조건식을 써봐라)

1. 문제

 

구간 A를 1번부터 x번까지, 구간 B를 x+1번부터 y번까지, 구간 C를 y+1번부터 n번까지 나눈다.

 

각 구간의 모든 원소의 합을 각각 a,b,c라고 하자.

 

a,b,c가 전부 같도록 x,y를 정하자. 여기서 1 <= x < y < n이다.

 

그러한 방법의 수가 몇가지일까?

 

n은 최대 50만

 

배열의 각 원소는 -100만부터 100만까지로 음수일수도 있다.

 

예를 들어 [1,2,3,0,3]이면..

 

A가 1번 2번 = 3

 

B가 3번 = 3

 

C가 4번 5번 = 3

 

---------------------------

 

A가 1번 2번  = 3

 

B가 3번 4번 = 3

 

C가 5번 = 3

 

2가지 있다.

 

 

2. 풀이

 

구간합이니까, prefix sum으로 누적합을 만들어야하는 것은 명확하다

 

[1,3,6,6,9]

 

n = int(input())

A = list(map(int,input().split()))

for i in range(1,n):
    
    A[i] += A[i-1]

print(A)
[1, 3, 6, 6, 9]

 

 

가장 쉬운건 x번을 고정하고 y번을 고정한 다음 a,b,c를 구해서 비교해서 count하는 $O(N^{2})$풀이

 

1 <= x < y < n이므로, x는 0번부터 n-3번까지 가능하고, y는 x+1번부터 n-2번까지 가능하다.

 

또한 1번부터 x번까지 합은 누적합 배열 A에서 A[x]

 

x+1번부터 y번까지 합은 A[y] - A[x]

 

y+1번부터 n번까지 합은 A[n-1] - A[y]

 

count = 0

for x in range(n-2):
    
    for y in range(x+1,n-1):
        
        a = A[x]
        b = A[y] - A[x]
        c = A[n-1] - A[y]

        if a == b and b == c:
            
            count += 1

print(count)
2

 

 

n이 최대 50만이라 시간초과날게 뻔하다.

 

x를 고정하고 y 위치를 이분탐색으로 찾는 방법?

 

x,y를 투포인터로 찾는 방법?? 여러가지 생각해봤는데 이 방법들은 문제가 있다

 

A의 원소가 음수일 수도 있기 때문에, 누적합 배열 A가 오름차순이 아닐 수 있다는 것

 

-----------------------------------------------------------------------------------------------------------------------------------------------

 

고민끝에 엄청난 방법을 떠올렸다..

 

a = A[x]

 

b = A[y] - A[x]

 

c = A[n-1] - A[y]인데 이 식이 뭔가 서로 비슷해보인다..

 

그런데 a = b = c여야 하므로, A[x] = A[y] - A[x]이고 A[y] - A[x] = A[n-1] - A[y]이고 A[x] = A[n-1] - A[y]이다.

 

여기서 x,y가 미지수이므로 이 3개의 방정식을 연립한다면?

 

첫번째 방정식에서 2A[x] = A[y]인데 A[x] = A[n-1] - A[y]이므로, 3A[x] = A[n-1]이다.

 

반대로 A[x] = A[y]/2이므로, 두번째 방정식에서 A[y] - A[x] = A[n-1] - A[y]이므로 3A[y]/2 = A[n-1]

 

그런데 분수가 있으니 정수 연산을 위해 3A[y] = 2A[n-1]

 

A[n-1]은 고정된 값이므로, 3A[x] = A[n-1], 3A[y] = 2A[n-1]을 만족하는 x,y를 찾을 수 있다.

 

즉, a = b = c일려면 일단 x,y는 3A[X] = A[n-1] , 3A[y] = 2A[n-1]을 만족한다.

 

1 <= x < y < n이므로 x는 0에서 n-3, y는 1에서 n-2 범위를 가진다. 

 

x_list = []

for x in range(n-2):

    if 3*A[x] == A[n-1]:

        x_list.append(x)

y_list = []

for y in range(1,n-1):
    
    if 3*A[y] == 2*A[n-1]:
        
        y_list.append(y)

print(x_list)
print(y_list)
[1]
[2,3]

 

 

이제 x_list를 순회해서 고정된 x에 대하여, x < y를 만족하는 y_list에서의 y의 개수를 찾는다.

 

이분탐색으로 찾을 수도 있겠고... 이건 O(NlogN)일거고

 

x_list, y_list는 오름차순으로 정렬되어 있기 때문에.. 어떤 x에 대해 x < y인 y를 i번에서 찾았다면,

 

i번 부터 y_list의 끝까지는 모두 조건을 만족하기 때문에 i ~ len(y)-1까지 counting해주고,

 

다음 x에 대해서는 i번 이후에 나타난다는 점을 이용한다면 ... 이건 O(N)일듯?

 

count = 0

s = 0

for x in x_list:
    
    for i in range(s,len(y_list)):
        
        y = y_list[i]

        if y > x:
            
            break
    
    count += (len(y_list) - i)
    s = i

print(count)

 

 

아니 근데 다시 보니까 잘못푼것 같다 ㅁㅊ;

 

반례가 있겠는걸?

 

counting하는 부분을 조건문 안에 넣어야겠다..

 

y_list에 있는 원소들이 x_list에 있는 원소들보다 작을 수도 있다보니... 애초에 조건문에 안걸리면 0으로 해야하니까

 

count = 0

s = 0

for x in x_list:
    
    for i in range(s,len(y_list)):
        
        y = y_list[i]

        if y > x:
            
            count += (len(y_list) - i)
            s = i
            break
            
print(count)

 

 

대충 50만개짜리 해봐도 시간초과는 안남... 접근방식은 맞긴한듯?

 

아 근데 접근법은 맞았는데 사소한 실수로 또 틀리네 ㅡㅡ

 

이게 테스트를 못해보니 억울하다 ㄹㅇ

TAGS.

Comments