k fold cross validation을 구하는 알고리즘 문제 복기하기
1. 알고리즘
주어진 데이터에 대한 k fold cross validation은...
1) 데이터를 k개의 크기가 같은 부분집합으로 분할한다.
여기서 전체 데이터 개수가 k의 배수가 아니면 마지막 집합은 나머지 데이터로만 채워넣는다
2) k-1개의 부분집합을 training set, 나머지 1개의 부분집합을 test set으로 한다
3) test set을 고르는 방법은 총 k가지가 있다.
4) (training set, test set) 순서로 총 k개의 세트를 하나의 list에 넣어 return
from collections import deque
def kfold(indices,k):
answer = []
deque_data = deque(indices)
data_num = len(indices)
if int(data_num/k) == data_num/k:
num_per_k = int(data_num/k)
else:
num_per_k = int(data_num/k)+1
last_num = data_num-(k-1)*num_per_k
split_total = deque()
split = []
for ind in range(k):
if ind <= k-2:
for _ in range(num_per_k):
data = deque_data.popleft()
split.append(data)
split_total.append(split)
split = []
else:
for _ in range(last_num):
data = deque_data.popleft()
split.append(data)
split_total.append(split)
split = []
for ind in range(k):
test = split_total[-1]
train = sum(list(split_total)[:-1],[])
answer.append(train)
answer.append(test)
split_total.rotate()
return answer
먼저 초기 필요한 변수들을 초기화함
def kfold(indices,k):
answer = []
deque_data = deque(indices)
data_num = len(indices)
if int(data_num/k) == data_num/k:
num_per_k = int(data_num/k)
else:
num_per_k = int(data_num/k)+1
last_num = data_num-(k-1)*num_per_k
split_total = deque()
split = []
전체 데이터 수가 k의 배수이면 int(data_num/k) == data_num/k여서 각 부분집합이 int(data_num/k)만큼 sample을 가져감
k의 배수가 아니면 부분집합에 속하는 데이터 수를 int(data_num/k)+1로 한다..
data_num/k가 소수이면 int(data_num/k)하면 소수를 버림하여 정수로 바꿔줌
예를 들어 data_num/k = 2.7이면 int(data_num/k)=2
k의 배수가 아닐때 마지막 부분집합의 수가 중요해서 last_num도 전체 데이터수에 각 부분집합 당 데이터 수 * (k-1)을 빼면 구할 수 있을 것
for ind in range(k):
if ind <= k-2:
for _ in range(num_per_k):
data = deque_data.popleft()
split.append(data)
split_total.append(split)
split = []
else:
for _ in range(last_num):
data = deque_data.popleft()
split.append(data)
split_total.append(split)
split = []
이제 k개의 부분집합으로 전체 데이터를 split함
1번부터 k-1번까지는 각 부분집합에 num_per_k만큼 들어간다
deque_data에서 왼쪽부터 하나씩 빼서 split에 넣는데 num_per_k만큼 반복함
일반 리스트가 아니라 deque로 바꿔야 시간을 절약할 수 있음( O(n) >>> O(1) )
split_total에 split된 set을 넣고 split을 다시 초기화하여 k-1번까지 반복
마지막 부분집합에는 last_num만큼 들어가기 때문에 num_per_k 대신에 last_num만큼 반복
for ind in range(k):
test = split_total[-1]
train = sum(list(split_total)[:-1],[])
answer.append(train)
answer.append(test)
split_total.rotate()
return answer
이제 train과 test로 split을 해야함
kfold니까 k개의 fold만큼 구성되어서 range(k)로 k번 반복함
train과 test를 선택하기 쉽게 split_total[-1]로 마지막 split set을 test로 하고 나머지는 train으로 구성함
list(split_total)[:-1]하면 [ [1,2,3],[4,5,6] ] 이런식으로 이중리스트로 구성되어 있어서 이중리스트를 풀어줄거임
sum(list(split_total)[:-1],[])을 이용해서 이중리스트를 풀어줄 수 있음
sum([[1,2,3],[4,5,6]],[])
[1,2,3,4,5,6]
train과 test 순으로 answer에 넣어주고.. fold마다 train test set이 다르게 구성되어야 하니까
split_total.rotate()를 이용해 한칸씩 deque를 회전시킴
이러면 split_total[-1]을 하더라도 매 반복마다 다른 test set을 고를거임
answer = kfold([1,2,3,4,5,6,7,8,9,10],5)
print(answer)
[[1, 2, 3, 4, 5, 6, 7, 8], [9, 10], [9, 10, 1, 2, 3, 4, 5, 6], [7, 8], [7, 8, 9, 10, 1, 2, 3, 4], [5, 6], [5, 6, 7, 8, 9, 10, 1, 2], [3, 4], [3, 4, 5, 6, 7, 8, 9, 10], [1, 2]]
2. 되돌아보기
정답인지는 모르겠는데 여기서 크게 다른 부분은 없을듯
시간있었으면 충분히 풀만한 문제였을듯??.. 다른 시험이랑 겹쳐서 못했던거지.. 자신감 가져도 좋을듯
마지막 로직은 기억하면 좋을듯
for ind in range(k):
test = split_total[-1]
train = sum(list(split_total)[:-1],[])
answer.append(train)
answer.append(test)
split_total.rotate()
return answer
sum(A,[])로 A의 이중리스트를 단일리스트로 풀어낼 수 있고
마지막 index로 고정해서 rotate를 통해 매 반복마다 다른거 선택하는거
'알고리즘 > 알고리즘 일반' 카테고리의 다른 글
stack 활용법 - 올바른 괄호 문자열 판별 (0) | 2021.11.17 |
---|---|
합의 차이가 최소가 되는 분할 1편 (0) | 2021.11.16 |
명일방주 픽업을 위한 평균 가챠횟수 4편(일반 헤드헌팅) (0) | 2021.11.09 |
명일방주 픽업을 위한 평균 가챠횟수 3편(한정 헤드헌팅) (0) | 2021.11.09 |
명일방주 픽업을 위한 평균 가챠횟수 2편(위기협약) (0) | 2021.11.09 |