こういうちょっとしたものも、エコシステムでは非常に大切なのよね。。。
特に、zero-copy conversionが大切。。。
フレームワーク間で頻繁にコピーが発生していちゃね。。。。
最後は、TVM の Packed Function に渡して、実行できるんだよね。
引用
n = tvm.convert(56)
X = tvm.placeholder((n,n), name='X')
Y = tvm.placeholder((n,n), name='Y')
k = tvm.reduce_axis((0, n), name='k')
Z = tvm.compute((n,n), lambda i,j : tvm.sum(X[i,k]*Y[k,j], axis=k))
s = tvm.create_schedule(Z.op)
fmm = tvm.build(s, [X, Y, Z], target_host='llvm', name='fmm')
のように、TVM で Packed Function (fmm) を生成し、
引用
from tvm.contrib.dlpack import to_pytorch_func
# fmm is the previously built TVM function (Python function)
# fmm is the wrapped TVM function (Python function)
fmm_pytorch = to_pytorch_func(fmm)
z2 = torch.empty(56,56)
fmm_pytorch(x, y, z2)
np.testing.assert_allclose(z.numpy(), z2.numpy())
のように、to_pytorch_func にて、TVM の Packed Function を Pytorch の関数に変換するだけだって。。。MxNet でもできるって、
引用
import mxnet
from tvm.contrib.mxnet import to_mxnet_func
ctx = mxnet.cpu(0)
x = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
y = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
z = mxnet.nd.empty(shape=(56,56), ctx=ctx)
f = tvm.build(s, [X, Y, Z], target_host='llvm', name='f')
f_mxnet = to_mxnet_func(f)
f_mxnet(x, y, z)
np.testing.assert_allclose(z.asnumpy(), x.asnumpy().dot(y.asnumpy()))
to_mxnet_func を使って、変換していますね。
