728x90
- tf.gather(params, indices, validate_indices=None, name=None, axis=None, batch_dims=0)
v1 = tf.constant([1, 3, 5, 7, 9, 0, 2, 4, 6, 8])
v2 = tf.constant([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
with tf.Session() as sess:
print(sess.run(tf.gather(v1, [2, 5, 2, 5], axis=0)))
print(sess.run(tf.gather(v2, [0, 1], axis=0)))
print(sess.run(tf.gather(v2, [0, 1], axis=1)))
[5 0 5 0]
[[ 1 2 3 4 5 6]
[ 7 8 9 10 11 12]]
[[1 2]
[7 8]]
axis 축을 기준으로 원하는 index의 값을 뽑아온다.
axis=0이면 행을 기준으로 0번째, 1번째 행을 불러오기 때문에 2번째 결과와 같다.
axis=1이면 열을 기준으로 0번째, 1번째 열을 불러오기 때문에 3번째 결과와 같다.
- torch.gather(tensor, dim, indices)
v1 = torch.tensor([1, 3, 5, 7, 9, 0, 2, 4, 6, 8])
v2 = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
print(torch.gather(v1, 0, torch.tensor([2, 5, 2, 5])))
print(torch.gather(v2, 1, torch.tensor([[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5]])))
print(torch.gather(v2, 0, torch.tensor([[0, 0],
[1, 1]])))
print(torch.gather(v2, 1, torch.tensor([[0, 0],
[1, 1]])))
tensor([5, 0, 5, 0])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
tensor([[1, 2],
[7, 8]])
tensor([[1, 1],
[8, 8]])
torch.gather은 조금 더 헷갈려서 잘 이해해야한다.
dim=0은 열을 기준으로 dim=1은 행을 기준으로 이해하면 편하다.
그리고 indices는 입력 tensor와 shape가 동일해야한다.
3번째를 보면 dim=0이기 때문에 열이 기준이다. 0행 [0, 0]은 0번째 열에서 0번째 index를 불러오고, 1번째 열에서 0번째 index를 불러오라는 뜻이다. 마찬가지로 1행 [1, 1]은 0번째 열에서 1번째 index를 불러오고 1번째 열에서 1번째 index를 불러오라는 뜻이다.
4번째를 보면 dim=1이기 때문에 행이 기준이다. 0행 [0, 0]은 0번째 행에서 0번째 index를 불러오고 0번째 행에서 0번째 index를 불러오라는 뜻이다. 1행 [1, 1]은 1번째 행에서 1번째 index를 불러오고 1번째 행에서 1번째 index를 불러오라는 뜻이다.
추가적인 코드 몇개 더 첨부할테니 이해해보시길 바라겠습니다!
v3 = tf.constant([x for x in range(1, 101)])
v3 = tf.reshape(v3, (10, 10))
with tf.Session() as sess:
print(sess.run(v3))
print(sess.run(tf.gather(v3, [x for x in range(10) if x%2 == 1], axis=1)))
print(sess.run(tf.gather(v3, [x for x in range(10) if x%2 == 1], axis=0)))
[[ 2 4 6 8 10]
[ 12 14 16 18 20]
[ 22 24 26 28 30]
[ 32 34 36 38 40]
[ 42 44 46 48 50]
[ 52 54 56 58 60]
[ 62 64 66 68 70]
[ 72 74 76 78 80]
[ 82 84 86 88 90]
[ 92 94 96 98 100]]
[[ 11 12 13 14 15 16 17 18 19 20]
[ 31 32 33 34 35 36 37 38 39 40]
[ 51 52 53 54 55 56 57 58 59 60]
[ 71 72 73 74 75 76 77 78 79 80]
[ 91 92 93 94 95 96 97 98 99 100]]
v3 = torch.tensor([x for x in range(1, 101)]).reshape(10, 10)
# 각 행의 홀수번째 불러오기
indices = torch.tensor([x for x in range(v3.size(1)) if x%2 == 1])
indices = indices.repeat(v3.size(0)).reshape(10, 5)
print(torch.gather(v3, 1, indices))
# 홀수번째 행 불러오기
indices = torch.tensor([[i]*10 for i in range(10) if i%2 == 1])
print(torch.gather(v3, 0, indices))
tensor([[ 2, 4, 6, 8, 10],
[ 12, 14, 16, 18, 20],
[ 22, 24, 26, 28, 30],
[ 32, 34, 36, 38, 40],
[ 42, 44, 46, 48, 50],
[ 52, 54, 56, 58, 60],
[ 62, 64, 66, 68, 70],
[ 72, 74, 76, 78, 80],
[ 82, 84, 86, 88, 90],
[ 92, 94, 96, 98, 100]])
tensor([[ 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
[ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40],
[ 51, 52, 53, 54, 55, 56, 57, 58, 59, 60],
[ 71, 72, 73, 74, 75, 76, 77, 78, 79, 80],
[ 91, 92, 93, 94, 95, 96, 97, 98, 99, 100]])
제가 이해한바로는 tensorflow에서 작성이 조금 더 쉽네요!
728x90
'Tensorflow' 카테고리의 다른 글
Tensorflow vs Pytorch 명령어 비교 -(6) (0) | 2021.06.03 |
---|---|
Tensorflow vs Pytorch -(4) (0) | 2021.05.19 |
Tensorflow vs Pytorch 명령어 비교 - (3) (0) | 2021.05.17 |
Tensorflow vs Pytorch 명령어 비교 - (2) (0) | 2021.05.15 |
Tensorflow vs Pytorch 명령어 비교 (0) | 2021.05.14 |
댓글