본문 바로가기
추천 시스템 이론

카카오 아레나 Melon Playlist Continuation baseline github 분석 - (1)

by 블쭌 2020. 10. 19.
728x90

split_data.py 실행순서에 따른 코드 분석

 

(1) module

import copy
import random
import fire
import numpy as np

# arena_util.py에서 함수 호출
from arena_util import load_json, write_json

 

 

(2) main

if __name__ == "__main__":
    fire.Fire(ArenaSplitter)

fire패키지에 대한 이해는 다음을 참고해주세요.

bladejun.tistory.com/21

 

Python fire package

fire 패키지는 Python에서의 모든 객체를 command line interface로 만들어 준다. python 객체(함수, 클래스, dictionary, list, tuple 모두다 호출이 가능하다) 함수 예시 import fire def hello(name="World"):..

bladejun.tistory.com

 

 

(3) run method

def run(self, fname):
        # random shuffle때문에 seed설정
        random.seed(777)

        print("Reading data...\n")
        
        # json파일 불러오기
        playlists = load_json(fname)  # -> (4)
        
        # 불러온 playlist파일 순서 섞기
        random.shuffle(playlists)
        print(f"Total playlists: {len(playlists)}")

        # train, test split
        print("Splitting data...")
        train, val = self._split_data(playlists) # -> (5)
        
        # train, val나눈 데이터 json파일 작성
        print("Original train...")
        write_json(train, "orig/train.json") # -> (6)
        print("Original val...")
        write_json(val, "orig/val.json") # -> (6)

        # masking 작업
        print("Masked val...")
        val_q, val_a = self._mask_data(val) # -> (7)
        write_json(val_q, "questions/val.json") # -> (9)
        write_json(val_a, "answers/val.json") # -> (9)

 

 

(4) arena_utils.py -> load_json

def load_json(fname):
    with open(fname, encoding="utf-8") as f:
        json_obj = json.load(f)

    return json_obj

 

 

(5) class ArenaSplitter -> split_data 

def _split_data(playlists):
        tot = len(playlists) # 전체 길이
        train = playlists[:int(tot*0.80)] # 전체 길이의 80% train
        val = playlists[int(tot*0.80):]  # 전체 길이의 20% test
        
        return train, val

 

 

(6) arena_utils.py -> write_json

def write_json(data, fname):
    def _conv(o):
        # isinstance는 첫번째 객체가 뒤에 타입에 속해있는지 확인
        # o라는 객체가 numpy배열 int64, int32타입에 속해있는지 확인
        if isinstance(o, (np.int64, np.int32)):
            return int(o)
        
        # 속해있지 않다면 type error발생
        raise TypeError
    
    # 부모 directory 경로 설정
    parent = os.path.dirname(fname)
    
    # 새로운 경로 만들기
    distutils.dir_util.mkpath("./arena_data/" + parent)
    
    with io.open("./arena_data/" + fname, "w", encoding="utf-8") as f:
        json_str = json.dumps(data, ensure_ascii=False, default=_conv)
        f.write(json_str)

 

 

(7) class ArenaSplitter -> mask_data

def _mask_data(self, playlists):
    # deep copy
    playlists = copy.deepcopy(playlists)
    tot = len(playlists)
    
    # song과 tag를 예측하는 문제이기 때문에 일부 값을 mask처리
    song_only = playlists[:int(tot * 0.3)]                   # 곡만 존재
    song_and_tags = playlists[int(tot * 0.3):int(tot * 0.8)] # 곡,태그 둘다 존재
    tags_only = playlists[int(tot * 0.8):int(tot * 0.95)]    # 태그만 존재
    title_only = playlists[int(tot * 0.95):]                 # 제목만 존재
	
    # 길이 확인
    print(f"Total: {len(playlists)}, "
          f"Song only: {len(song_only)}, "
          f"Song & Tags: {len(song_and_tags)}, "
          f"Tags only: {len(tags_only)}, "
          f"Title only: {len(title_only)}")

    song_q, song_a = self._mask(song_only, ['songs'], ['tags'])             # -> (8)
    songtag_q, songtag_a = self._mask(song_and_tags, ['songs', 'tags'], []) # -> (8)
    tag_q, tag_a = self._mask(tags_only, ['tags'], ['songs'])               # -> (8)
    title_q, title_a = self._mask(title_only, [], ['songs', 'tags'])        # -> (8)
	
    # 합치기
    q = song_q + songtag_q + tag_q + title_q
    a = song_a + songtag_a + tag_a + title_a
	
    # random하게 섞기
    shuffle_indices = np.arange(len(q))
    np.random.shuffle(shuffle_indices)
	
    q = list(np.array(q)[shuffle_indices])
    a = list(np.array(a)[shuffle_indices])

 

 

(8) class ArenaSplitter -> mask

def _mask(playlists, mask_cols, del_cols):
    
    # deep copy
    q_pl = copy.deepcopy(playlists) # list다 비워놓기
    a_pl = copy.deepcopy(playlists) # 정답 
    
    for i in range(len(playlists)):
        # 삭제할 컬럼
        for del_col in del_cols:
            # question deep copy는 맞춰야하니까 빈 list생성
            q_pl[i][del_col] = []
            if del_col == 'songs':
                # 상위 100개의 곡 추출
                a_pl[i][del_col] = a_pl[i][del_col][:100]
            elif del_col == 'tags':
                # 상위 10개 태그 추출
                a_pl[i][del_col] = a_pl[i][del_col][:10]
        
        # masking
        for col in mask_cols:
            mask_len = len(playlists[i][col])
            mask = np.full(mask_len, False)
            
            # 절반은 True, 절반은 False
            mask[:mask_len//2] = True
            np.random.shuffle(mask)

            q_pl[i][col] = list(np.array(q_pl[i][col])[mask])
            a_pl[i][col] = list(np.array(a_pl[i][col])[np.invert(mask)])

    return q_pl, a_pl

 

 

(9) arena_utils.py -> write_json

def write_json(data, fname):
    def _conv(o):
        # isinstance는 첫번째 객체가 뒤에 타입에 속해있는지 확인
        # o라는 객체가 numpy배열 int64, int32타입에 속해있는지 확인
        if isinstance(o, (np.int64, np.int32)):
            return int(o)
        
        # 속해있지 않다면 type error발생
        raise TypeError
    
    # 부모 directory 경로 설정
    parent = os.path.dirname(fname)
    
    # 새로운 경로 만들기
    distutils.dir_util.mkpath("./arena_data/" + parent)
    
    with io.open("./arena_data/" + fname, "w", encoding="utf-8") as f:
        json_str = json.dumps(data, ensure_ascii=False, default=_conv)
        f.write(json_str)

출처

github.com/kakao-arena/melon-playlist-continuation

728x90

댓글