[Java]우선순위 큐 응용 - MN개의 조합에서 조건을 만족하는 k번째 수를 찾는 빠르게 찾는 방법 1편

1. 문제

 

n개의 숫자로 이루어진 수열과 m개의 숫자로 이루어진 수열이 주어졌을 때, 각 수열에서 정확히 원소를 하나씩만 뽑아 나올 수 있는 모든 쌍들을 모두 구하고, 그 값들을 오름차순이 되도록 나열했을 때의 k번째 쌍의 두 수의 합을 구하는 프로그램을 작성해보세요.

 

<제한>

 

n,m은 1이상 10만 이하의 자연수

 

k는 1이상 min(nm, 10만)이하

 

 

2. 풀이

 

가장 쉽게 생각할 수 있는 방법은 mn개의 모든 조합을 만든 다음에 우선순위 큐에 모두 집어 넣고, k번째 빠지는 수를 출력하면 된다

 

근데 뭐 당연히 시간초과  + 메모리 초과임

 

m과 n이 10만인데 시간복잡도가 얼마냐 이거 O(M + N + MNlogMN + KlogMN)인가..?

 

아무튼 O(MN)으로 생각하는 순간 10만*10만 = 100억으로 오바다

 

import java.util.Scanner;
import java.util.PriorityQueue;

class Pair implements Comparable<Pair> {
    int x,y;

    public Pair(int x, int y){
        this.x = x;
        this.y = y;
    }

    @Override
    public int compareTo(Pair p){
        return (this.x+this.y) - (p.x + p.y);
    }
}

public class Main {
    public static void main(String[] args) {
        // 여기에 코드를 작성해주세요.

        Scanner sc = new Scanner(System.in);

        int n = sc.nextInt();
        int m = sc.nextInt();
        int k = sc.nextInt();

        int[] arr1 = new int[n];
        int[] arr2 = new int[m];
        PriorityQueue<Pair> pq = new PriorityQueue<>();

        for(int i = 0; i < n; i++){
            arr1[i] = sc.nextInt();
        }

        for(int j = 0; j < m; j++){
            arr2[j] = sc.nextInt();
        }

        for(int i = 0; i < n; i++){
            for(int j = 0; j < m; j++){
                pq.add(new Pair(arr1[i],arr2[j]));
            }
        }

        for(int i = 0; i < k-1; i++){
            pq.poll();
        }

        Pair p = pq.poll();

        System.out.println(p.x+p.y);

    }
}

 

결국 MN개의 모든 조합을 안넣더라도 k번째 수를 찾는 방법을 생각해야하는데

 

어떻게 가능할까?

 

정수쌍 (p,q)에 대하여 p+q가 작은 순서대로 오름차순 배열하여 k번째 수를 찾는 문제이다.

 

p+q가 작을려면 p,q가 작아야한다

 

따라서 p를 가지는 A와 q를 가지는 B에 대해 A,B를 오름차순으로 정렬한다.

 

그러면 A의 모든 수 $p_{1}, p_{2}, ... , p_{n}$과 B의 첫번째 수 $q_{1}$에 대하여,

 

$(p_{1}, q_{1}), (p_{2}, q_{1}), ... (p_{n}, q_{1})$중 합이 가장 작은 원소가 존재한다.

 

왜냐하면, $q_{1}$이 B에 든 원소중에서 가장 작은 원소니까

 

그러므로 $(p_{1}, q_{1}), (p_{2}, q_{1}), ... (p_{n}, q_{1})$을 모두 우선순위 큐에 넣고 하나를 빼면, 그것이 첫번째로 가장 작은 원소가 된다.

 

그것이 $(p_{a}, q_{1})$이라고 하자.

 

그러면 다음, $(p_{a}, q_{2})$는 2번째로 작은 원소가 될 수 있는 후보이다.

 

즉, 우선순위 큐에 남아있는 n-1개의 원소와 $(p_{a}, q_{2})$는 2번째로 작은 원소가 될 수 있는 후보이고, 여기서 하나를 빼면 그것이 2번째로 작은 원소가 된다.

 

이러한 과정을 k번 반복하면 k번째 나오는 원소가 k번째 작은 원소가 된다.

 

먼저 p값과 q값, 그리고 p값은 첫번째 array의 어떤 index인지, q값은 두번째 array의 어디 index인지를 모두 관리하는 클래스 Pair를 구현

 

여기서 정렬 기준은 p값과 q값의 합의 오름차순으로 정렬하도록

 

class Pair implements Comparable<Pair> {
    int x,y,a,b;

    public Pair(int x, int y, int a, int b){
        this.x = x;
        this.y = y;
        this.a = a;
        this.b = b;
    }

    @Override
    public int compareTo(Pair p){
        return (this.x+this.y) - (p.x + p.y);
    }
}

 

 

다음, n개의 수를 첫번째 array에 넣어주고 m개의 수를 두번째 array에 넣어준 다음, 

 

두 array를 오름차순으로 정렬한다.

 

public class Main {
    public static void main(String[] args) {
        // 여기에 코드를 작성해주세요.

        Scanner sc = new Scanner(System.in);

        int n = sc.nextInt();
        int m = sc.nextInt();
        int k = sc.nextInt();

        int[] arr1 = new int[n];
        int[] arr2 = new int[m];
        PriorityQueue<Pair> pq = new PriorityQueue<>();

        for(int i = 0; i < n; i++){
            arr1[i] = sc.nextInt();
        }

        for(int j = 0; j < m; j++){
            arr2[j] = sc.nextInt();
        }

        Arrays.sort(arr1);
        Arrays.sort(arr2);

 

 

다음 첫번째 array에 든 arr1[0], arr1[1], ..., arr1[n-1]과 두번째 array의 가장 작은 원소인 arr2[0]와 매칭한

 

(arr1[0], arr2[0]), (arr1[1], arr2[0]), ... , (arr1[n-1], arr2[0])를 모두 우선순위 큐에 넣어준다.

 

 

for(int i = 0; i < n; i++){
    pq.add(new Pair(arr1[i],arr2[0],i,0));
}

 

그리고 우선순위 큐에서 원소 하나를 빼면 그것이 i번째 작은 원소가 된다.

 

그 원소에 들어있는 p,q값과 p,q의 index값을 이용해서 p값은 그대로 두고, q값의 다음 index와 매칭시킨

 

(arr1[p.a], arr2[p.b+1])은 다음 i+1번째 작은 원소 후보가 될 수 있으므로 우선순위 큐에 넣어준다.

 

여기서 주의할 점은, 두번째 array의 길이가 작아서, 매칭시킬 원소가 남아있지 않을 수 있다.

 

즉 p.b+1이 m 이상이 된다면, 더 이상 매칭 시킬 것이 없으니 우선순위 큐에 남아있는 원소들만으로도 i+1번째 작은 원소들의 후보로 충분하다

 

for(int i = 0; i < k-1; i++){

    Pair p = pq.poll();

    if (p.b+1 < m){
        pq.add(new Pair(p.x,arr2[p.b+1],p.a,p.b+1));
    }
}

 

 

k-1번째 작은 원소까지 뽑았으므로, 1번 더 우선순위 큐에서 빼주면 그것이 k번째 작은 원소가 된다.

 

import java.util.Scanner;
import java.util.PriorityQueue;
import java.util.Arrays;

class Pair implements Comparable<Pair> {
    int x,y,a,b;

    public Pair(int x, int y, int a, int b){
        this.x = x;
        this.y = y;
        this.a = a;
        this.b = b;
    }

    @Override
    public int compareTo(Pair p){
        return (this.x+this.y) - (p.x + p.y);
    }
}

public class Main {
    public static void main(String[] args) {
        // 여기에 코드를 작성해주세요.

        Scanner sc = new Scanner(System.in);

        int n = sc.nextInt();
        int m = sc.nextInt();
        int k = sc.nextInt();

        int[] arr1 = new int[n];
        int[] arr2 = new int[m];
        PriorityQueue<Pair> pq = new PriorityQueue<>();

        for(int i = 0; i < n; i++){
            arr1[i] = sc.nextInt();
        }

        for(int j = 0; j < m; j++){
            arr2[j] = sc.nextInt();
        }

        Arrays.sort(arr1);
        Arrays.sort(arr2);

        for(int i = 0; i < n; i++){
            pq.add(new Pair(arr1[i],arr2[0],i,0));
        }

        for(int i = 0; i < k-1; i++){

            Pair p = pq.poll();

            if (p.b+1 < m){
                pq.add(new Pair(p.x,arr2[p.b+1],p.a,p.b+1));
            }
        }

        Pair p =pq.poll();

        System.out.println(p.x+p.y);

    }
}

 

 

중요한 알고리즘 같으니까... 파이썬으로도 구현해보자

 

import heapq

n,m,k = map(int,input().split())

arr1 = list(map(int,input().split()))
arr2 = list(map(int,input().split()))

#두 정수 배열을 오름차순 정렬
arr1.sort()
arr2.sort()

q = []

# 처음에는 n개의 원소에 대해 각각 
# 두 번째 수열의 첫 번째 원소를 대응시켜줍니다.
# 두 수의 합이 작은 값이 더 먼저 나와야 하므로
# +를 붙여서 넣어줍니다. 

#두 정수쌍의 합이 작은 순서대로 k번째 작은 합을 구해야하므로,
#두 정수쌍의 합과 두 정수의 index를 tuple로 구성
for i in range(n):

    heapq.heappush(q,(arr1[i]+arr2[0], i, 0))

# 1번부터 k - 1번까지 합이 작은 쌍을 골라줍니다.

#우선순위 큐에서 하나를 빼면 그것이 i번째 작은 합
#뺀 튜플에 든 2번째 array의 index를 이용해서 다음 index와 매칭시켜
#그것도 i+1번째 작은 합의 후보가 될 수 있다.
for _ in range(k-1):

    _,ind1,ind2 = heapq.heappop(q)

    #두번째 array에서 매칭 시킬 원소가 남아있다면...
    if ind2+1 < m:
        heapq.heappush(q,(arr1[ind1]+arr2[ind2+1],ind1,ind2+1))


#k-1번 빼고 나서, 1번 더 빼면 이것이 k번째 작은 합
value,_,_ = heapq.heappop(q)

print(value)

 

 

 

atcoder에 비슷한 문제가 나왔는데... 이걸 응용하면 풀 수 있으려나

TAGS.

Comments