k fold cross validation을 구하는 알고리즘 문제 복기하기

1. 알고리즘

 

주어진 데이터에 대한 k fold cross validation은...

 

1) 데이터를 k개의 크기가 같은 부분집합으로 분할한다.

 

여기서 전체 데이터 개수가 k의 배수가 아니면 마지막 집합은 나머지 데이터로만 채워넣는다

 

2) k-1개의 부분집합을 training set, 나머지 1개의 부분집합을 test set으로 한다

 

3) test set을 고르는 방법은 총 k가지가 있다.

 

4) (training set, test set) 순서로 총 k개의 세트를 하나의 list에 넣어 return

 

from collections import deque

def kfold(indices,k):

    answer = []

    deque_data = deque(indices)

    data_num = len(indices)

    if int(data_num/k) == data_num/k:
        
        num_per_k = int(data_num/k)
    
    else:
        
        num_per_k = int(data_num/k)+1

    last_num = data_num-(k-1)*num_per_k

    split_total = deque()

    split = []

    for ind in range(k):
        
        if ind <= k-2:
            for _ in range(num_per_k):
                
                data = deque_data.popleft()

                split.append(data)
            
            split_total.append(split)

            split = []
        
        else:
            
            for _ in range(last_num):
                data = deque_data.popleft()

                split.append(data)
            
            split_total.append(split)

            split = []
    
    for ind in range(k):

        test = split_total[-1]
        train = sum(list(split_total)[:-1],[])

        answer.append(train)
        answer.append(test)

        split_total.rotate()
    
    return answer

 

먼저 초기 필요한 변수들을 초기화함

 

def kfold(indices,k):

    answer = []

    deque_data = deque(indices)

    data_num = len(indices)

    if int(data_num/k) == data_num/k:
        
        num_per_k = int(data_num/k)
    
    else:
        
        num_per_k = int(data_num/k)+1

    last_num = data_num-(k-1)*num_per_k

    split_total = deque()

    split = []

 

전체 데이터 수가 k의 배수이면 int(data_num/k) == data_num/k여서 각 부분집합이 int(data_num/k)만큼 sample을 가져감

 

k의 배수가 아니면 부분집합에 속하는 데이터 수를 int(data_num/k)+1로 한다..

 

data_num/k가 소수이면 int(data_num/k)하면 소수를 버림하여 정수로 바꿔줌

 

예를 들어 data_num/k = 2.7이면 int(data_num/k)=2

 

k의 배수가 아닐때 마지막 부분집합의 수가 중요해서 last_num도 전체 데이터수에 각 부분집합 당 데이터 수 * (k-1)을 빼면 구할 수 있을 것

 

for ind in range(k):

    if ind <= k-2:
        for _ in range(num_per_k):

            data = deque_data.popleft()

            split.append(data)

        split_total.append(split)

        split = []

    else:

        for _ in range(last_num):
            data = deque_data.popleft()

            split.append(data)

        split_total.append(split)

        split = []

 

이제 k개의 부분집합으로 전체 데이터를 split함

 

1번부터 k-1번까지는 각 부분집합에 num_per_k만큼 들어간다

 

deque_data에서 왼쪽부터 하나씩 빼서 split에 넣는데 num_per_k만큼 반복함

 

일반 리스트가 아니라 deque로 바꿔야 시간을 절약할 수 있음( O(n) >>> O(1) )

 

split_total에 split된 set을 넣고 split을 다시 초기화하여 k-1번까지 반복

 

마지막 부분집합에는 last_num만큼 들어가기 때문에 num_per_k 대신에 last_num만큼 반복

 

for ind in range(k):

    test = split_total[-1]
    train = sum(list(split_total)[:-1],[])

    answer.append(train)
    answer.append(test)

    split_total.rotate()

return answer

 

이제 train과 test로 split을 해야함

 

kfold니까 k개의 fold만큼 구성되어서 range(k)로 k번 반복함

 

train과 test를 선택하기 쉽게 split_total[-1]로 마지막 split set을 test로 하고 나머지는 train으로 구성함

 

list(split_total)[:-1]하면 [ [1,2,3],[4,5,6] ] 이런식으로 이중리스트로 구성되어 있어서 이중리스트를 풀어줄거임

 

sum(list(split_total)[:-1],[])을 이용해서 이중리스트를 풀어줄 수 있음

 

sum([[1,2,3],[4,5,6]],[])

[1,2,3,4,5,6]

 

train과 test 순으로 answer에 넣어주고.. fold마다 train test set이 다르게 구성되어야 하니까

 

split_total.rotate()를 이용해 한칸씩 deque를 회전시킴

 

이러면 split_total[-1]을 하더라도 매 반복마다 다른 test set을 고를거임

 

answer = kfold([1,2,3,4,5,6,7,8,9,10],5)
print(answer)

[[1, 2, 3, 4, 5, 6, 7, 8], [9, 10], [9, 10, 1, 2, 3, 4, 5, 6], [7, 8], [7, 8, 9, 10, 1, 2, 3, 4], [5, 6], [5, 6, 7, 8, 9, 10, 1, 2], [3, 4], [3, 4, 5, 6, 7, 8, 9, 10], [1, 2]]

 

2. 되돌아보기

 

정답인지는 모르겠는데 여기서 크게 다른 부분은 없을듯

 

시간있었으면 충분히 풀만한 문제였을듯??.. 다른 시험이랑 겹쳐서 못했던거지.. 자신감 가져도 좋을듯

 

마지막 로직은 기억하면 좋을듯

 

for ind in range(k):

    test = split_total[-1]
    train = sum(list(split_total)[:-1],[])

    answer.append(train)
    answer.append(test)

    split_total.rotate()

return answer

 

sum(A,[])로 A의 이중리스트를 단일리스트로 풀어낼 수 있고

 

마지막 index로 고정해서 rotate를 통해 매 반복마다 다른거 선택하는거

TAGS.

Comments