본문 바로가기
추천 시스템 논문

code로 이해하는 SR-GNN 논문

by 블쭌 2021. 5. 21.
728x90
  • As 행렬 만들기

v1 -> v2 -> v3 -> v2 -> v4의 session이 존재할때 그래프는 위의 노드와 간선 연결이 보인다.

해당 session의 그래프를 바탕으로 connectionm matrix A_s 생성

A_s는 두개의 인접행렬 A_s(out)과 A_s(in)의 연결로 정의된다.

이는 즉, 세션 그래프에서 각각 진입, 진출의 간선 연결로 생각하면 된다.

# batch가 10이라고 가정
i = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# input data, masking된 데이터, target 데이터 불러오기
inputs, mask, targets = train_data.inputs[i], train_data.mask[i], train_data.targets[i]

items, n_node, A, alias_inputs = [], [], [], []

# input 각 노드의 길이를 담아주고 최대 길이를 파악한다. 
for u_input in inputs:
    n_node.append(len(np.unique(u_input)))
max_n_node = np.max(n_node)


for u_input in inputs:
    # 유일한 값 찾기
    node = np.unique(u_input)
    
    # 최대길이에 도달하지 못하면 0으로 패딩
    items.append(node.tolist() + (max_n_node - len(node)) * [0]) 
    
    # A_s 0으로 세팅
    u_A = np.zeros((max_n_node, max_n_node))
    
    for i in np.arange(len(u_input) - 1):
        # target이 0이면 종료
        if u_input[i + 1] == 0:
            break
            
        # 해당 하는 index 찾기(u->v)
        u = np.where(node == u_input[i])[0][0]
        v = np.where(node == u_input[i + 1])[0][0]
        
        # u -> v
        u_A[u][v] = 1 
    
    # column 합
    u_sum_in = np.sum(u_A, 0)
    
    # 0인 부분을 1로 바꾸어준다.
    u_sum_in[np.where(u_sum_in == 0)] = 1
    
    # row(out) 합
    u_sum_out = np.sum(u_A, 1)
    
    # 0인 부분을 1로 바꾸어준다.
    u_sum_out[np.where(u_sum_out == 0)] = 1
    
    # 평균으로 나누어준다.
    u_A_out = np.divide(u_A.transpose(), u_sum_out)
    
    # 두 행렬을 concat진행
    u_A = np.concatenate([u_A_in, u_A_out], axis=1)

    A.append(u_A)
    
    # node의 index 저장
    alias_inputs.append([np.where(node == i)[0][0] for i in u_input])
    

  • gate 연산처리

 


  • A_s,i = A_s의 in과 out의 concat
input_in = torch.matmul(A[:, :, :A.shape[1]], linear_edge_in(hidden)) + b_iah # (10, 5, 20)
input_out = torch.matmul(A[:, :, A.shape[1]: 2 * A.shape[1]], linear_edge_out(hidden)) + b_oah # (10, 5, 20)
inputs = torch.cat([input_in, input_out], 2) # (10, 5, 40)

  • a_s,i
gi = F.linear(inputs, w_ih, b_ih) # (10, 5, 60)

  • v_i 
gh = F.linear(hidden, w_hh, b_hh) # (10, 5, 60)

  • r_s,i = reset_gate (논문에서는 z로 표시가 되어있는데 밑에서는 r로 다시적혀있고 코드를 보니 이게 맞는것같아요)
# 마지막 채널을 3개로 분리 (10, 5, 20), (10, 5, 20), (10, 5, 20)
i_r, i_i, i_n = gi.chunk(3, dim=2)

# 마지막 채널을 3개로 분리 (10, 5, 20), (10, 5, 20), (10, 5, 20)
h_r, h_i, h_n = gh.chunk(3, dim=2) 

reset_gate = torch.sigmoid(i_r + h_r)

  • z_s,i : update gate
# 마지막 채널을 3개로 분리 (10, 5, 20), (10, 5, 20), (10, 5, 20)
i_r, i_i, i_n = gi.chunk(3, dim=2)

# 마지막 채널을 3개로 분리 (10, 5, 20), (10, 5, 20), (10, 5, 20)
h_r, h_i, h_n = gh.chunk(3, dim=2) 

update_gate = torch.sigmoid(i_i + h_i)

  • v_i ~: new_gate
new_gate = torch.tanh(i_n + reset_gate * h_n)

  • 최종 node vector
hy = new_gate + update_gate * (hidden - new_gate)

  • score 계산
linear_one = nn.Linear(20, 20, bias=True) # w1
linear_two = nn.Linear(20, 20, bias=True) # w2
linear_three = nn.Linear(20, 1, bias=False) # q
linear_transform = nn.Linear(20 * 2, 20, bias=True) # w3

  • a_i

Vn = hidden[torch.arange(mask.shape[0]).long(), torch.sum(mask, 1) - 1]  # batch_size x latent_size
w1_Vn = linear_one(ht).view(ht.shape[0], 1, ht.shape[1])  # batch_size x 1 x latent_size
w2_Vi = linear_two(hidden)  # batch_size x seq_length x latent_size
alpha = linear_three(torch.sigmoid(w1_Vn + w2_Vi))

  • S_g

s_g = torch.sum(alpha * hidden * mask.view(mask.shape[0], -1, 1).float(), 1)

  • S_h

s_h = linear_transform(torch.cat([s_g, Vn], 1))

s = [v_s,1 v_s,2 v_s,3 ... v_s,n] 세션 s의 경우 local embedding을 마지막에 들어갔던 item v_s,n의 v_n으로 정의한다.

즉(s1 = vn)


  • z_i

v_i = embedding.weight[1:]  
scores = torch.matmul(s_h, v_i.transpose(1, 0))

  • 결과

scores.topk(20)[1]
728x90

댓글