torch.vmap

사용법

def my_dot(a, b):
    return torch.dot(a, b)


a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
b = torch.tensor([1,2,3])
vfunc = torch.vmap(my_dot, in_dims=(0, None), out_dims=0)
print(vfunc(a, b))
# tensor([14, 32, 50])

vfunc = torch.vmap(my_dot, in_dims=(1, None), out_dims=0)
print(vfunc(a, b))
# tensor([30, 36, 42])

in_dims 는 np와 dataframe의 axis와 같다고 볼 수 있다.
다음과 같이 행렬 a와,

1 2 3
4 5 6
7 8 9

행렬 b가 주어졌을 때

1 2 3

in_dims = (1, None)

in_dims = (1, None)