728x90
split_data.py 실행순서에 따른 코드 분석
(1) module
import copy
import random
import fire
import numpy as np
# arena_util.py에서 함수 호출
from arena_util import load_json, write_json
(2) main
if __name__ == "__main__":
fire.Fire(ArenaSplitter)
fire패키지에 대한 이해는 다음을 참고해주세요.
(3) run method
def run(self, fname):
# random shuffle때문에 seed설정
random.seed(777)
print("Reading data...\n")
# json파일 불러오기
playlists = load_json(fname) # -> (4)
# 불러온 playlist파일 순서 섞기
random.shuffle(playlists)
print(f"Total playlists: {len(playlists)}")
# train, test split
print("Splitting data...")
train, val = self._split_data(playlists) # -> (5)
# train, val나눈 데이터 json파일 작성
print("Original train...")
write_json(train, "orig/train.json") # -> (6)
print("Original val...")
write_json(val, "orig/val.json") # -> (6)
# masking 작업
print("Masked val...")
val_q, val_a = self._mask_data(val) # -> (7)
write_json(val_q, "questions/val.json") # -> (9)
write_json(val_a, "answers/val.json") # -> (9)
(4) arena_utils.py -> load_json
def load_json(fname):
with open(fname, encoding="utf-8") as f:
json_obj = json.load(f)
return json_obj
(5) class ArenaSplitter -> split_data
def _split_data(playlists):
tot = len(playlists) # 전체 길이
train = playlists[:int(tot*0.80)] # 전체 길이의 80% train
val = playlists[int(tot*0.80):] # 전체 길이의 20% test
return train, val
(6) arena_utils.py -> write_json
def write_json(data, fname):
def _conv(o):
# isinstance는 첫번째 객체가 뒤에 타입에 속해있는지 확인
# o라는 객체가 numpy배열 int64, int32타입에 속해있는지 확인
if isinstance(o, (np.int64, np.int32)):
return int(o)
# 속해있지 않다면 type error발생
raise TypeError
# 부모 directory 경로 설정
parent = os.path.dirname(fname)
# 새로운 경로 만들기
distutils.dir_util.mkpath("./arena_data/" + parent)
with io.open("./arena_data/" + fname, "w", encoding="utf-8") as f:
json_str = json.dumps(data, ensure_ascii=False, default=_conv)
f.write(json_str)
(7) class ArenaSplitter -> mask_data
def _mask_data(self, playlists):
# deep copy
playlists = copy.deepcopy(playlists)
tot = len(playlists)
# song과 tag를 예측하는 문제이기 때문에 일부 값을 mask처리
song_only = playlists[:int(tot * 0.3)] # 곡만 존재
song_and_tags = playlists[int(tot * 0.3):int(tot * 0.8)] # 곡,태그 둘다 존재
tags_only = playlists[int(tot * 0.8):int(tot * 0.95)] # 태그만 존재
title_only = playlists[int(tot * 0.95):] # 제목만 존재
# 길이 확인
print(f"Total: {len(playlists)}, "
f"Song only: {len(song_only)}, "
f"Song & Tags: {len(song_and_tags)}, "
f"Tags only: {len(tags_only)}, "
f"Title only: {len(title_only)}")
song_q, song_a = self._mask(song_only, ['songs'], ['tags']) # -> (8)
songtag_q, songtag_a = self._mask(song_and_tags, ['songs', 'tags'], []) # -> (8)
tag_q, tag_a = self._mask(tags_only, ['tags'], ['songs']) # -> (8)
title_q, title_a = self._mask(title_only, [], ['songs', 'tags']) # -> (8)
# 합치기
q = song_q + songtag_q + tag_q + title_q
a = song_a + songtag_a + tag_a + title_a
# random하게 섞기
shuffle_indices = np.arange(len(q))
np.random.shuffle(shuffle_indices)
q = list(np.array(q)[shuffle_indices])
a = list(np.array(a)[shuffle_indices])
(8) class ArenaSplitter -> mask
def _mask(playlists, mask_cols, del_cols):
# deep copy
q_pl = copy.deepcopy(playlists) # list다 비워놓기
a_pl = copy.deepcopy(playlists) # 정답
for i in range(len(playlists)):
# 삭제할 컬럼
for del_col in del_cols:
# question deep copy는 맞춰야하니까 빈 list생성
q_pl[i][del_col] = []
if del_col == 'songs':
# 상위 100개의 곡 추출
a_pl[i][del_col] = a_pl[i][del_col][:100]
elif del_col == 'tags':
# 상위 10개 태그 추출
a_pl[i][del_col] = a_pl[i][del_col][:10]
# masking
for col in mask_cols:
mask_len = len(playlists[i][col])
mask = np.full(mask_len, False)
# 절반은 True, 절반은 False
mask[:mask_len//2] = True
np.random.shuffle(mask)
q_pl[i][col] = list(np.array(q_pl[i][col])[mask])
a_pl[i][col] = list(np.array(a_pl[i][col])[np.invert(mask)])
return q_pl, a_pl
(9) arena_utils.py -> write_json
def write_json(data, fname):
def _conv(o):
# isinstance는 첫번째 객체가 뒤에 타입에 속해있는지 확인
# o라는 객체가 numpy배열 int64, int32타입에 속해있는지 확인
if isinstance(o, (np.int64, np.int32)):
return int(o)
# 속해있지 않다면 type error발생
raise TypeError
# 부모 directory 경로 설정
parent = os.path.dirname(fname)
# 새로운 경로 만들기
distutils.dir_util.mkpath("./arena_data/" + parent)
with io.open("./arena_data/" + fname, "w", encoding="utf-8") as f:
json_str = json.dumps(data, ensure_ascii=False, default=_conv)
f.write(json_str)
출처
728x90
'추천 시스템 이론' 카테고리의 다른 글
Pytorch Recommend system github 작동 순서 (0) | 2021.04.23 |
---|---|
ALS 알고리즘 (1) | 2021.03.07 |
SGD를 사용한 Matrix Factorization 알고리즘 (1) | 2021.01.13 |
카카오 아레나 Melon Playlist Continuation baseline github 분석 - (3) (0) | 2020.10.26 |
카카오 아레나 Melon Playlist Continuation baseline github 분석 - (2) (0) | 2020.10.20 |
댓글