본문 바로가기
Tensorflow

Tensorflow vs Pytorch 명령어 비교 -(5)

by 블쭌 2021. 6. 2.
728x90
  • tf.gather(params, indices, validate_indices=None, name=None, axis=None, batch_dims=0)
v1 = tf.constant([1, 3, 5, 7, 9, 0, 2, 4, 6, 8])
v2 = tf.constant([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])

with tf.Session() as sess:
    print(sess.run(tf.gather(v1, [2, 5, 2, 5], axis=0)))
    print(sess.run(tf.gather(v2, [0, 1], axis=0)))
    print(sess.run(tf.gather(v2, [0, 1], axis=1)))
    
[5 0 5 0]

[[ 1  2  3  4  5  6]
 [ 7  8  9 10 11 12]]

[[1 2]
 [7 8]]

axis 축을 기준으로 원하는 index의 값을 뽑아온다.

axis=0이면 행을 기준으로 0번째, 1번째 행을 불러오기 때문에 2번째 결과와 같다.

axis=1이면 열을 기준으로 0번째, 1번째 열을 불러오기 때문에 3번째 결과와 같다.


  • torch.gather(tensor, dim, indices)
v1 = torch.tensor([1, 3, 5, 7, 9, 0, 2, 4, 6, 8])
v2 = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])

print(torch.gather(v1, 0, torch.tensor([2, 5, 2, 5])))
print(torch.gather(v2, 1, torch.tensor([[0, 1, 2, 3, 4, 5],
                                        [0, 1, 2, 3, 4, 5]])))
print(torch.gather(v2, 0, torch.tensor([[0, 0],
                                        [1, 1]])))
print(torch.gather(v2, 1, torch.tensor([[0, 0],
                                        [1, 1]])))
                                        
                                        
tensor([5, 0, 5, 0])

tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12]])
        
tensor([[1, 2],
        [7, 8]])
        
tensor([[1, 1],
        [8, 8]])                                        

torch.gather은 조금 더 헷갈려서 잘 이해해야한다.

dim=0은 열을 기준으로 dim=1은 행을 기준으로 이해하면 편하다.

그리고 indices는 입력 tensor와 shape가 동일해야한다.

3번째를 보면 dim=0이기 때문에 열이 기준이다. 0행 [0, 0]은 0번째 열에서 0번째 index를 불러오고, 1번째 열에서 0번째 index를 불러오라는 뜻이다. 마찬가지로 1행 [1, 1]은 0번째 열에서 1번째 index를 불러오고 1번째 열에서 1번째 index를 불러오라는 뜻이다.

 

4번째를 보면 dim=1이기 때문에 행이 기준이다. 0행 [0, 0]은 0번째 행에서 0번째 index를 불러오고 0번째 행에서 0번째 index를 불러오라는 뜻이다. 1행 [1, 1]은 1번째 행에서 1번째 index를 불러오고 1번째 행에서 1번째 index를 불러오라는 뜻이다.

 

 

추가적인 코드 몇개 더 첨부할테니 이해해보시길 바라겠습니다!

v3 = tf.constant([x for x in range(1, 101)])
v3 = tf.reshape(v3, (10, 10))

with tf.Session() as sess:
    print(sess.run(v3))
    print(sess.run(tf.gather(v3, [x for x in range(10) if x%2 == 1], axis=1)))
    print(sess.run(tf.gather(v3, [x for x in range(10) if x%2 == 1], axis=0)))
    
[[  2   4   6   8  10]
 [ 12  14  16  18  20]
 [ 22  24  26  28  30]
 [ 32  34  36  38  40]
 [ 42  44  46  48  50]
 [ 52  54  56  58  60]
 [ 62  64  66  68  70]
 [ 72  74  76  78  80]
 [ 82  84  86  88  90]
 [ 92  94  96  98 100]]
 
[[ 11  12  13  14  15  16  17  18  19  20]
 [ 31  32  33  34  35  36  37  38  39  40]
 [ 51  52  53  54  55  56  57  58  59  60]
 [ 71  72  73  74  75  76  77  78  79  80]
 [ 91  92  93  94  95  96  97  98  99 100]]
v3 = torch.tensor([x for x in range(1, 101)]).reshape(10, 10)

# 각 행의 홀수번째 불러오기
indices = torch.tensor([x for x in range(v3.size(1)) if x%2 == 1])
indices = indices.repeat(v3.size(0)).reshape(10, 5)
print(torch.gather(v3, 1, indices))

# 홀수번째 행 불러오기
indices = torch.tensor([[i]*10 for i in range(10) if i%2 == 1])
print(torch.gather(v3, 0, indices))


tensor([[  2,   4,   6,   8,  10],
        [ 12,  14,  16,  18,  20],
        [ 22,  24,  26,  28,  30],
        [ 32,  34,  36,  38,  40],
        [ 42,  44,  46,  48,  50],
        [ 52,  54,  56,  58,  60],
        [ 62,  64,  66,  68,  70],
        [ 72,  74,  76,  78,  80],
        [ 82,  84,  86,  88,  90],
        [ 92,  94,  96,  98, 100]])
        
tensor([[ 11,  12,  13,  14,  15,  16,  17,  18,  19,  20],
        [ 31,  32,  33,  34,  35,  36,  37,  38,  39,  40],
        [ 51,  52,  53,  54,  55,  56,  57,  58,  59,  60],
        [ 71,  72,  73,  74,  75,  76,  77,  78,  79,  80],
        [ 91,  92,  93,  94,  95,  96,  97,  98,  99, 100]])

 제가 이해한바로는 tensorflow에서 작성이 조금 더 쉽네요!

728x90

댓글