Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

Pytorch で Tensor Comprehensions を使う



Versionが 0.1.1 になり、PyTorch でも使えるようになったようです。


githubに、tensor_comprehensionsが生成されました。

サンプルコード
引用
import tensor_comprehensions as tc
import torch
lang = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
  output(i, j) +=! A(i, kk) * B(kk, j)
}
"""
# The name should match the name of the "def" in "lang"
matmul = tc.define(lang, name="matmul")
mat1, mat2 = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
out = matmul(mat1, mat2)

Tensor Comprehensions は、関数というかある処理を上記のように def で定義して、
TensorFlowのユーザー定義のカーネルを作るような感じです。