@Vengineerの戯言 : Twitter
SystemVerilogの世界へようこそ、すべては、SystemC v0.9公開から始まった
PyTorchがTPUで動くようになっています。この記事では、Colabにて、PyTorch + Cloud TPUを使うというもの。
ドキュメント、PYTORCH ON XLA DEVICES は、こちら。
ポイントは、
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
のように、torch_xla.core.xla_model を import するというもの。
xla_model は、ここ で定義されている。
def xla_device(n=None, devkind=None):
"""Returns a given instance of an XLA device.
Args:
n (int, optional): The specific instance (ordinal) to be returned. If
specified, the specific XLA device instance will be returned. Otherwise
the first device of `devkind` will be returned.
devkind (string..., optional): If specified, one of `TPU`, `GPU` or `CPU`
(the 'GPU' XLA device is currently not implemented).Returns:
A `torch.device` with the requested instance.
"""
if n is None:
devices = get_xla_supported_devices(
devkind=[devkind] if devkind is not None else None)
assert devices, 'No devices of {} kind'.format(devkind or 'ANY')
# This is a utility API mainly called from tests or simple code which wants
# to just have a single device to run on. Set the default device so that
# the tensor barrier can work correctly and avoid growing graphs surprises.
device = devices[0]
else:
device = 'xla:{}'.format(n)
torch_xla._XLAC._xla_set_default_device(device)
return torch.device(device)
引数 n が None なら、get_xla_supported_devices が呼ばれる。
def get_xla_supported_devices(devkind=None, max_devices=None):
"""Returns a list of supported devices of a given kind.Args:
devkind (string..., optional): If specified, one of `TPU`, `GPU` or `CPU`
(the 'GPU' XLA device is currently not implemented).
max_devices (int, optional): The maximum number of devices to be returned of
that kind.Returns:
The list of device strings.
"""
xla_devices = torch_xla._XLAC._xla_get_devices()
devkind = devkind or ['TPU', 'GPU', 'CPU']
for kind in devkind:
kind_devices =
for i, device in enumerate(xla_devices):
if re.match(kind + r':\d+$', device):
kind_devices.append('xla:{}'.format(i))
if kind_devices:
return kind_devices[:max_devices] if max_devices else kind_devices
torch_xla._XLAC._xla_get_devices() は、ここ
m.def("_xla_get_devices",
() { return xla::ComputationClient::Get()->GetLocalDevices(); });
結局、XLAの GetLocalDevices を呼んでいる。。。
PyTorchのモデル(Net)を作って、
# Places network on the default TPU core
net = Net().to(dev)
のように、モデルを TPU に送っちゃうんですね。ここで、モデルをTPU用にコンパイルするんですかね。