728x90
split_data.py 실행순서에 따른 코드 분석
(1) module
import copy
import random
import fire
import numpy as np
# arena_util.py에서 함수 호출
from arena_util import load_json, write_json
(2) main
if __name__ == "__main__":
fire.Fire(ArenaSplitter)
fire패키지에 대한 이해는 다음을 참고해주세요.
Python fire package
fire 패키지는 Python에서의 모든 객체를 command line interface로 만들어 준다. python 객체(함수, 클래스, dictionary, list, tuple 모두다 호출이 가능하다) 함수 예시 import fire def hello(name="World"):..
bladejun.tistory.com
(3) run method
def run(self, fname):
# random shuffle때문에 seed설정
random.seed(777)
print("Reading data...\n")
# json파일 불러오기
playlists = load_json(fname) # -> (4)
# 불러온 playlist파일 순서 섞기
random.shuffle(playlists)
print(f"Total playlists: {len(playlists)}")
# train, test split
print("Splitting data...")
train, val = self._split_data(playlists) # -> (5)
# train, val나눈 데이터 json파일 작성
print("Original train...")
write_json(train, "orig/train.json") # -> (6)
print("Original val...")
write_json(val, "orig/val.json") # -> (6)
# masking 작업
print("Masked val...")
val_q, val_a = self._mask_data(val) # -> (7)
write_json(val_q, "questions/val.json") # -> (9)
write_json(val_a, "answers/val.json") # -> (9)
(4) arena_utils.py -> load_json
def load_json(fname):
with open(fname, encoding="utf-8") as f:
json_obj = json.load(f)
return json_obj
(5) class ArenaSplitter -> split_data
def _split_data(playlists):
tot = len(playlists) # 전체 길이
train = playlists[:int(tot*0.80)] # 전체 길이의 80% train
val = playlists[int(tot*0.80):] # 전체 길이의 20% test
return train, val
(6) arena_utils.py -> write_json
def write_json(data, fname):
def _conv(o):
# isinstance는 첫번째 객체가 뒤에 타입에 속해있는지 확인
# o라는 객체가 numpy배열 int64, int32타입에 속해있는지 확인
if isinstance(o, (np.int64, np.int32)):
return int(o)
# 속해있지 않다면 type error발생
raise TypeError
# 부모 directory 경로 설정
parent = os.path.dirname(fname)
# 새로운 경로 만들기
distutils.dir_util.mkpath("./arena_data/" + parent)
with io.open("./arena_data/" + fname, "w", encoding="utf-8") as f:
json_str = json.dumps(data, ensure_ascii=False, default=_conv)
f.write(json_str)
(7) class ArenaSplitter -> mask_data
def _mask_data(self, playlists):
# deep copy
playlists = copy.deepcopy(playlists)
tot = len(playlists)
# song과 tag를 예측하는 문제이기 때문에 일부 값을 mask처리
song_only = playlists[:int(tot * 0.3)] # 곡만 존재
song_and_tags = playlists[int(tot * 0.3):int(tot * 0.8)] # 곡,태그 둘다 존재
tags_only = playlists[int(tot * 0.8):int(tot * 0.95)] # 태그만 존재
title_only = playlists[int(tot * 0.95):] # 제목만 존재
# 길이 확인
print(f"Total: {len(playlists)}, "
f"Song only: {len(song_only)}, "
f"Song & Tags: {len(song_and_tags)}, "
f"Tags only: {len(tags_only)}, "
f"Title only: {len(title_only)}")
song_q, song_a = self._mask(song_only, ['songs'], ['tags']) # -> (8)
songtag_q, songtag_a = self._mask(song_and_tags, ['songs', 'tags'], []) # -> (8)
tag_q, tag_a = self._mask(tags_only, ['tags'], ['songs']) # -> (8)
title_q, title_a = self._mask(title_only, [], ['songs', 'tags']) # -> (8)
# 합치기
q = song_q + songtag_q + tag_q + title_q
a = song_a + songtag_a + tag_a + title_a
# random하게 섞기
shuffle_indices = np.arange(len(q))
np.random.shuffle(shuffle_indices)
q = list(np.array(q)[shuffle_indices])
a = list(np.array(a)[shuffle_indices])
(8) class ArenaSplitter -> mask
def _mask(playlists, mask_cols, del_cols):
# deep copy
q_pl = copy.deepcopy(playlists) # list다 비워놓기
a_pl = copy.deepcopy(playlists) # 정답
for i in range(len(playlists)):
# 삭제할 컬럼
for del_col in del_cols:
# question deep copy는 맞춰야하니까 빈 list생성
q_pl[i][del_col] = []
if del_col == 'songs':
# 상위 100개의 곡 추출
a_pl[i][del_col] = a_pl[i][del_col][:100]
elif del_col == 'tags':
# 상위 10개 태그 추출
a_pl[i][del_col] = a_pl[i][del_col][:10]
# masking
for col in mask_cols:
mask_len = len(playlists[i][col])
mask = np.full(mask_len, False)
# 절반은 True, 절반은 False
mask[:mask_len//2] = True
np.random.shuffle(mask)
q_pl[i][col] = list(np.array(q_pl[i][col])[mask])
a_pl[i][col] = list(np.array(a_pl[i][col])[np.invert(mask)])
return q_pl, a_pl
(9) arena_utils.py -> write_json
def write_json(data, fname):
def _conv(o):
# isinstance는 첫번째 객체가 뒤에 타입에 속해있는지 확인
# o라는 객체가 numpy배열 int64, int32타입에 속해있는지 확인
if isinstance(o, (np.int64, np.int32)):
return int(o)
# 속해있지 않다면 type error발생
raise TypeError
# 부모 directory 경로 설정
parent = os.path.dirname(fname)
# 새로운 경로 만들기
distutils.dir_util.mkpath("./arena_data/" + parent)
with io.open("./arena_data/" + fname, "w", encoding="utf-8") as f:
json_str = json.dumps(data, ensure_ascii=False, default=_conv)
f.write(json_str)
출처
728x90
'추천 시스템 이론' 카테고리의 다른 글
Pytorch Recommend system github 작동 순서 (0) | 2021.04.23 |
---|---|
ALS 알고리즘 (1) | 2021.03.07 |
SGD를 사용한 Matrix Factorization 알고리즘 (1) | 2021.01.13 |
카카오 아레나 Melon Playlist Continuation baseline github 분석 - (3) (0) | 2020.10.26 |
카카오 아레나 Melon Playlist Continuation baseline github 분석 - (2) (0) | 2020.10.20 |
댓글