실수 구간에서 이분 탐색 방법 제대로 배우기

1. 문제

 

21627번: Ice Cream (acmicpc.net)

 

21627번: Ice Cream

The first line contains three integers $n$, $v$, $u$ --- the number of ice cream cones, the melting rate and the rate of eating ice cream ($1\le n\le 3\cdot10^5$, $1\le v,u \le 10^9$). The second lines contains $n$ integers $a_i$ ($1\le a_i \le 10^9$) ---

www.acmicpc.net

 

 

2. 풀이

 

n개의 아이스크림이 초당 v만큼 동시에 녹으며, 나는 초당 u만큼 아이스크림을 1개씩 먹을 수 있다.

 

모든 아이스크림이 없어질때까지 아이스크림을 먹는데, 최소한의 양만을 먹고 싶다면... 얼마나 먹을 수 있을까?

 

고정된 시간 t에 대하여, i번째 아이스크림 A[i]는 t*v만큼 녹아있다.

 

그러면 나는 i번째 아이스크림은 A[i] - t*v만큼의 아이스크림을 먹게된다.

 

 

물론 A[i] > t*v일때만 먹고 A[i] <= t*v이면 t 시간에 이미 녹아있으니 먹을게 없다.

 

 

주어진 시간 t동안 내가 먹을 수 있는 전체 양은 얼마인가?

 

초당 u만큼 먹으니까 당연히 t*u이다.

 

 

또 다른 표현으로 모든 아이스크림에 대하여 내가 먹을 수 있는 양은 $s = \sum_{i = 0}^{n} max(0,A[i]-t*v)$

 

 

만약에 s > t*u라면? t가 작아서 먹을 수 있는 양보다 더 먹었다.

 

s <= t*u라면? t가 커서 먹을 수 있는 양보다 덜 먹었다.

 

 

그래서 이분탐색으로 t를 찾으면 O(NlogN)에 답을 찾을 수 있다.

 

s > t*u라면 start = mid + a로 start를 키우고,

 

s <= t*u라면 end = mid로 end를 줄여서 먹는 양의 최솟값을 찾는다.

 

이분탐색이 끝나면 end가 구하고자 하는 시간이고, 이 end에 대하여 내가 먹을 수 있는 아이스크림의 양은?

 

 $s = \sum_{i = 0}^{n} max(0,A[i]-t*v)$

 

이건 O(N)으로 구해야하잖아...

 

end * u로 바로 구할수 있잖아

 

 

 

그런데 문제는 t가 실수라는 점

 

오차 $10^{-6}$이내로 답을 내야하는데.. 

 

https://deepdata.tistory.com/779

 

특별한 이분탐색 -실수 구간에서도 이분탐색이 가능할까-

1. 문제 14609번: 구분구적법 (Large) (acmicpc.net) 14609번: 구분구적법 (Large) 첫 번째 줄에는 다항함수의 차수를 나타내는 양의 정수 K(1 ≤ K ≤ 10) 가 주어진다. 두 번째 줄에는 최고차항부터 내림차순

deepdata.tistory.com

 

예전에 공부한 실수 구간에서 이분탐색 방법을 적용해서

 

def binary_search(array,v,u,start,end):
    
    while end - start > 10**(-6):
        
        mid = (start+end)/2

        t = 0

        for i in range(len(array)):
            
            if array[i] > mid*v:
                
                t += ((array[i]-mid*v)/u)
        
        if t <= mid:
            
            end = mid
        
        else:
            
            start = mid + 0.0000001
    
    answer = 0

    for i in range(len(array)):
        
        answer += ((array[i] - mid*v))

    return answer

n,v,u = map(int,input().split())

A = list(map(int,input().split()))

print(binary_search(A,v,u,0,10**12+1))

 

 

이런식으로 해봤는데 틀렸다네?

 

실제 공식 데이터로 해봤는데.. 오차가 $10^{-6}$보다 크더라고

 

end - start > 10^-18까지도 해보고

 

start = mid + 10^-36까지도 해보고...

 

시작 구간을 0~10^12+1이 아니라 10^-18~10^18+1로도 해보고...

 

Decimal도 써보고 그랬지만 안되더라

 

공식 솔루션에서 실수 구간에서 이분탐색을 어떤 식으로 했는지 잘 보여주고 있다..

 

앞으로 이렇게 하면 좋을 것 같다

 

#실수 이분탐색
def binary_search(array,v,u,start,end):
    
    #실수 이분탐색은 핵심이 start < end 조건 반복이 아니라 그냥 100번 정도만 돌린다
    for _ in range(100):
        
        mid = (start+end)/2 #고정된 시간 mid

        t = 0

        for i in range(len(array)):
            
            if array[i] > mid*v: #아이스크림은 mid시간동안 mid*v만큼 녹는다.
                
                t += ((array[i]-mid*v)) #내가 i번째 아이스크림은 array[i]-mid*v만큼 먹을 수 있다
        
        #mid동안 이론상 내가 먹을 수 있는 양은 mid*u
        #실제 먹은 양 t와 mid*u를 비교해서
        
        #t가 작으면 mid가 너무 커서 덜 먹은 것이니 end를 줄인다
        if t <= mid*u:
            
            end = mid
        
        #t가 크면 mid가 작아서 많이 먹은거니 start를 키운다
        else:
            
            start = mid
    
    #이분탐색이 끝나면 end시간이 최적해
    #end 시간동안 먹을 수 있는 양은 end*u
    return end*u

n,v,u = map(int,input().split())

A = list(map(int,input().split()))

print(binary_search(A,v,u,0,10**9+1))

 

 

어차피 오차가 무조건 나기 때문에, 오차 이내로 내놓으면 된다는 생각

 

그리고 원래 start = mid + 0.0000000001로 아주 작은 수를 더했지만... 사실 이건 더하나 마나잖아

 

그냥 start = mid, end = mid로 둔 다음,

 

핵심 포인트는 while문으로 end - start > error 쓰지말고 그냥 for문으로 100번정도만 돌리자 이거임

 

예전에 했던 구분구적법 문제도 이렇게 바꾸니까 통과하더라

 

from sys import stdin

def binary_search(a,b,n,real,k,c):
    
    delta = (b-a)/n

    start = 0
    end = delta

    for _ in range(100):
        
        approx = n*(c[-1])
        
        mid = start + (end - start)/2

        for j in range(1,k+1):
            
            for h in range(n):
            
                approx += (c[-(j+1)]*((a+h*delta+mid)**j))
        
        approx *= delta
            
        if real > approx:

            start = mid

        else:

            end = mid

    return end
        
k = int(stdin.readline())

c = list(map(int,stdin.readline().split()))

a,b,n = map(int,stdin.readline().split())

real = 0

for i in range(k+1):
    
    real += c[i]*(((b**(k+1-i))- (a**(k+1-i)))/(k+1-i))

print(binary_search(a,b,n,real,k,c))

 

TAGS.

Comments