라그랑주의 네 제곱수 정리를 이용한 알고리즘(다이나믹 프로그래밍, 브루트포스 연습)

1. 라그랑주의 네 제곱수 정리

 

모든 자연수 n은 많아야 음이 아닌 4개의 정수의 제곱수 합으로 표현할 수 있다.

 

$$n = a^{2} + b^{2} + c^{2} + d^{2}$$을 만족하는 0 이상의 정수 a,b,c,d가 존재한다.

 

증명은 매우 까다롭다.. 너무 길어서 따라하기도 힘들다..

 

https://jjycjnmath.tistory.com/295

 

[퍼온글] 라그랑주의 네제곱수 정리(Four Square Theorem)와 그 증명

※ 출처 - http://kevin0960.tistory.com/ 디오판토스의 저서 '산학'에는 '모든 양의 정수는 네 제곱수의 합으로 표현될 수 있다.' 라는 내용이 담겨 있다. 예를 들어, \[ \begin{aligned} 3 &= 1^2 + 1^2 + 1^2 + 0^2 \\ 3

jjycjnmath.tistory.com

 

https://ko.wikipedia.org/wiki/%EB%9D%BC%EA%B7%B8%EB%9E%91%EC%A3%BC_%EB%84%A4_%EC%A0%9C%EA%B3%B1%EC%88%98_%EC%A0%95%EB%A6%AC

 

라그랑주 네 제곱수 정리 - 위키백과, 우리 모두의 백과사전

위키백과, 우리 모두의 백과사전. 수론에서 라그랑주 네 제곱수 정리(-數定理, 영어: Lagrange's four-square theorem)는 모든 양의 정수가 많아야 4개의 제곱수의 합이라는 정리이다.[1] 양의 정수 n ∈ Z +

ko.wikipedia.org

 

2. 문제1

 

https://www.acmicpc.net/problem/17626

 

17626번: Four Squares

라그랑주는 1770년에 모든 자연수는 넷 혹은 그 이하의 제곱수의 합으로 표현할 수 있다고 증명하였다. 어떤 자연수는 복수의 방법으로 표현된다. 예를 들면, 26은 52과 12의 합이다; 또한 42 + 32 + 1

www.acmicpc.net

 

3. 풀이

 

자연수 n이 최소 몇개의 제곱수 합으로 표현할 수 있는지 구하는 문제

 

문제에서 모든 자연수가 4개 이하의 제곱수 합으로 표현할 수 있다고 명시하고 있다

 

이걸 명시해주냐 아니냐에 따라 사실 난이도가 천지차이다

 

 

이 사실을 모르면 상한이 몇인지를 모르는데.. 알고 있다면 많아야 4개니까 

 

for문 1번, for문 2번, for문 3번까지 모두 확인해봐서 답을 찾을 수 있다는 얘기

 

N의 최대가 5만인데 $\sqrt{N} = 223$정도라 223 + 223*223 + 223*223*223 = 1100만정도

 

1초에 1억번 정도 연산하던가?? 2천만번이라는 사람도 있고 1억번으로 알고 있는데

 

아무튼 0.5초안에 들어올만해

 

from sys import stdin

n = int(stdin.readline())

x = 1

m = int(n**(1/2))

num_list = [y**2 for y in range(1,m+1)]

find = False

#1개의 제곱수 합으로 나타내는 경우
for i in range(m):

    if num_list[i] == n:
        
        find = True

        break

if find:

    print(x)

#제곱수 1개로 찾지 못한 경우
else:

    x += 1
    
    #제곱수 2개로 찾는 경우
    for i in range(m):

        for j in range(m):

            if num_list[i] + num_list[j] == n:

                find = True
                break

        if find:

            break

    if find:

        print(x)
    
    #제곱수 2개로 찾지 못한 경우
    else:

        x += 1
        
        #제곱수 3개로 찾는 경우
        for i in range(m):

            for j in range(m):

                for k in range(m):

                    if num_list[i] + num_list[j] + num_list[k] == n:

                        find = True
                        break

                if find:

                    break

            if find:

                break

        if find:

            print(x)
        
        #제곱수 3개로도 답을 찾지 못한다면, 정답은 4이다.
        else:

            print(x+1)

 

4. 문제2

 

https://www.acmicpc.net/problem/1699

 

1699번: 제곱수의 합

어떤 자연수 N은 그보다 작거나 같은 제곱수들의 합으로 나타낼 수 있다. 예를 들어 11=32+12+12(3개 항)이다. 이런 표현방법은 여러 가지가 될 수 있는데, 11의 경우 11=22+22+12+12+12(5개 항)도 가능하다

www.acmicpc.net

 

5. 풀이1

 

이 문제는 라그랑주의 네 제곱수 정리를 명시하지 않는다.

 

그러니까 네 제곱수 정리를 모르면 상한이 있는지를 모르니까 위와 같은 브루트 포스로 풀 생각을 못하는거

 

그래서 한동안 못풀었다

 

하지만 알고있다면.. 정확히 똑같이 제출해도 정답임

 

from sys import stdin

n = int(stdin.readline())

x = 1

m = int(n**(1/2))

num_list = [y**2 for y in range(1,m+1)]

find = False

for i in range(m):

    if num_list[i] == n:
        
        find = True

        break

if find:

    print(x)

else:

    x += 1

    for i in range(m):

        for j in range(m):

            if num_list[i] + num_list[j] == n:

                find = True
                break

        if find:

            break

    if find:

        print(x)

    else:

        x += 1

        for i in range(m):

            for j in range(m):

                for k in range(m):

                    if num_list[i] + num_list[j] + num_list[k] == n:

                        find = True
                        break

                if find:

                    break

            if find:

                break

        if find:

            print(x)

        else:

            print(x+1)

 

6. 풀이2

 

태그에 다이나믹 프로그래밍이 있는데.. 다이나믹 프로그래밍으로는 어떻게 풀 수 있을까

 

dp배열을 정확히 정의하고 시작한다

 

dp[i] = i를 제곱수 합으로 나타낼 수 있는, 제곱수의 개수의 최솟값

 

dp[i] = i가 되도록 초기화

 

1부터 n까지 dp배열을 채워넣는다.

 

j*j <= i인 j에 대하여, i-(j*j)에 j*j를 더하면, i가 되므로 i-(j*j)를 제곱수 합으로 나타낼 수 있는 개수의 최솟값에 1을 더한 개수로 i를 나타낼 수 있다.

 

즉 dp[i]는 dp[i-(j*j)]에 1을 더한 값과 dp[i]중 최솟값이 된다.

 

from sys import stdin

n = int(stdin.readline())

dp = [i for i in range(n+1)]

for i in range(1,n+1):
    
    for j in range(1,i+1):
        
        val = j*j

        if val > i:
            
            break
        
        dp[i] = min(dp[i],dp[i-val]+1)

print(dp[n])

 

그리고 조금 디테일이라면.. 시간복잡도는 위 코드는 $O(N\sqrt{N})$인데.. python3로는 통과를 못해

 

두번째 반복문 안에서 매번 조건문에 의한 비교연산을 하기 때문에 연산량이 많다

 

애초에 j*j <= i인 j만 돌도록 반복문을 설정하고 조건문을 수행하지 않게 고친다면..

 

시간복잡도는 $O(N\sqrt{N})$인데 연산량이 조금 줄어서 python3로도 통과 가능함

 

from sys import stdin

n = int(stdin.readline())

dp = [i for i in range(n+1)]

dp[0] = 0

for i in range(1,n+1):
    
    for j in range(1,int(i**(1/2))+1):
          
        dp[i] = min(dp[i],dp[i-(j*j)]+1)

print(dp[n])

 

 

7. 문제3

 

https://www.acmicpc.net/problem/3933

 

3933번: 라그랑주의 네 제곱수 정리

입력은 최대 255줄이다. 각 줄에는 215보다 작은 양의 정수가 하나씩 주어진다. 마지막 줄에는 0이 하나 있고, 입력 데이터가 아니다.

www.acmicpc.net

 

8. 풀이1

 

이번에는 최소 개수가 아니라, n을 제곱수 합으로 나타낼 수 있는 경우의 수를 묻는 문제

 

다이나믹 프로그래밍으로도 풀 수 있다는데... 전혀 생각도 못하겠다

 

근데 2**15  = 33000정도라서 $\sqrt{33000} = 14$라서 브루트포스로도 할 수 있을 것 같다.

 

브루트포스를 잘 해야하는게 $3^{2} + 4^{2}$와 $4^{2} + 3^{2}$는 같은 경우라서 제외해야한다.

 

아무튼..

 

1부터 n**(1/2) = m까지 for문으로 순회해서 그 수를 i라고 한다면...

 

i*i == n이라면.. 경우의 수 +1

 

i*i == n이 아니라면? r = n - (i*i)이라 하고, i부터 m까지 순회해서 그 수를 j라고 하고...

 

j*j == r이라면 i*i + j*j == n이라는 뜻이므로.. 또 경우의 수 +1

 

----------------------------------------------------------------------------------------------------------------

 

여기서 j를 구할 때 i부터 순회하는 이유는?

 

1부터 순회하면 $3^{2} + 4^{2}$와 $4^{2} + 3^{2}$처럼 같은 경우를 제외하지 못해서 그렇다

 

i부터 순회한다면.. 자연스럽게 오름차순 정렬되어 $i*i + j*j + ... == n$로 구해지니까

 

m = int(n**(1/2))

count = 0

for i in range(1,m+1):

    if i*i == n:

        count += 1

    else:

        remain1 = n-i*i

        for j in range(i,m+1):

            if j*j == remain1:

                count += 1

 

근데 j를 m까지 순회해야하나?

 

j의 최댓값이 m인지는 한번 생각해볼필요가 있는게... 결국 remain1의 제곱근까지가 최댓값이므로...

 

더욱 범위를 압축할 수 있다(하지 않으면 시간초과임)

 

m = int(n**(1/2))

count = 0

for i in range(1,m+1):

    if i*i == n:

        count += 1

    else:

        remain1 = n-i*i

        for j in range(i,int(remain1**(1/2))+1):

            if j*j == remain1:

                count += 1

 

마찬가지로.. j*j == remain1도 안된다면...

 

remain2 = remain1 - j*j = n- i*i - j*j로 설정하고

 

k를 j부터 remain2**(1/2)까지 순회해서..

 

k*k == remain2라면 경우의 수 +1

 

그렇지 않다면...

 

remain3 = remain2 - k*k로 설정하자

 

여기서 for문을 또 설정해야하나? 그렇지 않다.. remain3가 제곱수인지 아닌지만 판정하면 된다

 

근데 remain3가 음수일 수 있어서 remain3**(1/2)하면 복소수 나올 수 있으니 int(remain3**(1/2))를 바로 하면..

 

런타임에러남

 

remain3 > 0인지도 판단해줘야하고..

 

그리고 또 중요한 점은 remain3**(1/2)가 k이상이어야한다.

 

왜냐하면? 그래야 $3^{2} + 4^{2}$와 $4^{2} + 3^{2}$처럼 같은 경우를 제외할 수 있으니까..

 

자연스럽게 i,j,k,remain4가 오름차순으로 정렬되도록 브루트포스를 할 수 있다는 뜻

 

from sys import stdin

while 1:
    
    n = int(stdin.readline())

    if n == 0:
        
        break
    
    else:
        
        m = int(n**(1/2))

        count = 0

        for i in range(1,m+1):
            
            if i*i == n:
                
                count += 1
            
            else:
                
                remain1 = n-i*i

                for j in range(i,int(remain1**(1/2))+1):
                    
                    if j*j == remain1:
                        
                        count += 1
                    
                    else:
                        
                        remain2 = remain1-j*j

                        for k in range(j,int(remain2**(1/2))+1):
                            
                            if k*k == remain2:
                                
                                count += 1
                            
                            else:
                                
                                remain3 = remain2-k*k
                                
                                remain4 = remain3**(1/2)

                                if remain3 > 0 and int(remain4) == remain4 and remain4 >= k:
                                    
                                    count += 1
      
        print(count)

 

9. 풀이2

 

다이나믹 프로그래밍으로는 어떻게 풀 수 있을까...

 

dp배열을 정확히 정의한다.

 

dp[i][j]를 i를 j개의 제곱수 합으로 나타낼 수 있는 경우의 수로 정의한다.

 

그러면 j는 1부터 4까지 가능하니까.. dp = [[0]*5 for _ in range(n+1)]처럼 초기화 가능

 

그리고 m = int(n**(1/2))로 압축시킨 다음, 1부터 m까지 순회해서..

 

최초 dp[i*i][1] = 1로 초기화

 

왜냐하면 i*i는 i 1개로 나타낼 수 있으니까 최소한 1가지는 확보할 수 있다

 

그리고 j를 i*i부터 n까지 순회해서...

 

j를 k개의 제곱수 합으로 나타낼 수 있는 경우의 수는 어떻게 구할 수 있을까?

 

j - (i*i) + (i*i) = j이므로...

 

j - (i*i)를 나타낼 수 있는 모든 경우의 수 각각에  그냥 i*i만 더해주면 j가 된다..

 

예를 들어..

 

$2003 = 3^{2} + 21^{2} + 23^{2} + 32^{2}$

 

여기서 $3^{2} + 21^{2} + 23^{2} = 979$인데..

 

979를 3가지 제곱수 합으로 나타내는 경우의 수가..

 

(3,3,31)

 

(3,21,23)

 

(5,15,27)

 

(9,13,27)

 

(15,15,23)

 

이렇게 5가지가 있다.

 

이 5가지에 32*32만 더해준다면..

 

(3,3,31,32)

 

(3,21,23,32)

 

(5,15,27,32)

 

(9,13,27,32)

 

(15,15,23,32)

 

이들 제곱수 합은 2003이 된다..

 

즉 2003을 4개의 제곱수 합으로 나타내는 방법은.. 최소한 979를 3개의 제곱수 합으로 나타내는 방법을 포함한다.

 

그래서 일반적으로.. j를 k개의 제곱수 합으로 나타내는 방법의 수는..

 

j를 k개의 제곱수 합으로 나타내는 방법의 수  + (j-(i*i))를 k-1개의 제곱수 합으로 나타내는 방법의 수가 된다.

 

여기서 k는 2,3,4

 

from sys import stdin

N = 2**15

#dp[i][j] = i를 j개의 제곱수 합으로 나타낼 수 있는 경우의 수
dp = [[0]*5 for _ in range(N+1)]

m = int(N**(1/2))

for i in range(1,m+1):
    
    #초기화    
    dp[i*i][1] = 1 #i*i는 i 1개로 나타낼 수 있다

    for j in range(i*i,N+1):
        
        for k in range(2,5):
            
            #j는 j-(i*i) + (i*i)로 나타낼 수 있다
            #따라서, j-(i*i)를 나타낼 수 있는 경우의 수들 각각에 (i*i)만 더해주면 j가 되므로..
            #k-1개의 제곱수로 j-(i*i)를 나타낼 수 있는 경우의 수들 각각에 1개의 제곱수 i*i만 더해준다
            #따라서, j를 k개의 제곱수 합으로 나타낼 수 있는 경우의 수는, k-1개의 제곱수로 j-(i*i)를 나타낼 수 있는 경우의 수와 같다.
            dp[j][k] = dp[j][k] + dp[j-i*i][k-1]

while 1:
    
    n = int(stdin.readline())

    answer = 0
    
    if n == 0:
        
        break
    
    else:
        
        for i in range(1,5):
            
            answer += dp[n][i]
        
        print(answer)
TAGS.

Comments