Vengineerの妄想

人生を妄想しています。

PyTorch を Cloud TPU & Colab で動かす

@Vengineerの戯言 : Twitter
SystemVerilogの世界へようこそすべては、SystemC v0.9公開から始まった 

PyTorchがTPUで動くようになっています。この記事では、Colabにて、PyTorch + Cloud TPUを使うというもの。

medium.com

ドキュメント、PYTORCH ON XLA DEVICES は、こちら。

pytorch.org

ポイントは、

import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())

のように、torch_xla.core.xla_model を import するというもの。

xla_model は、ここ で定義されている。

def xla_device(n=None, devkind=None):
  """Returns a given instance of an XLA device.
  Args:
    n (int, optional): The specific instance (ordinal) to be returned. If
      specified, the specific XLA device instance will be returned. Otherwise
      the first device of `devkind` will be returned.
      devkind (string..., optional): If specified, one of `TPU`, `GPU` or `CPU`
      (the 'GPU' XLA device is currently not implemented).

  Returns:
    A `torch.device` with the requested instance.
  """
  if n is None:
    devices = get_xla_supported_devices(
         devkind=[devkind] if devkind is not None else None)
    assert devices, 'No devices of {} kind'.format(devkind or 'ANY')
    # This is a utility API mainly called from tests or simple code which wants
    # to just have a single device to run on. Set the default device so that
    # the tensor barrier can work correctly and avoid growing graphs surprises.
    device = devices[0]
  else:
    device = 'xla:{}'.format(n)
  torch_xla._XLAC._xla_set_default_device(device)
  return torch.device(device)

 引数 n が None なら、get_xla_supported_devices が呼ばれる。

def get_xla_supported_devices(devkind=None, max_devices=None):
  """Returns a list of supported devices of a given kind.

  Args:
    devkind (string..., optional): If specified, one of `TPU`, `GPU` or `CPU`
      (the 'GPU' XLA device is currently not implemented).
    max_devices (int, optional): The maximum number of devices to be returned of
    that kind.

  Returns:
    The list of device strings.
  """
  xla_devices = torch_xla._XLAC._xla_get_devices()
  devkind = devkind or ['TPU', 'GPU', 'CPU']
  for kind in devkind:
    kind_devices =
    for i, device in enumerate(xla_devices):
      if re.match(kind + r':\d+$', device):
        kind_devices.append('xla:{}'.format(i))
      if kind_devices:
        return kind_devices[:max_devices] if max_devices else kind_devices

 torch_xla._XLAC._xla_get_devices() は、ここ

m.def("_xla_get_devices",
         () { return xla::ComputationClient::Get()->GetLocalDevices(); });

 結局、XLAの GetLocalDevices を呼んでいる。。。

 

PyTorchのモデル(Net)を作って、

# Places network on the default TPU core
net = Net().to(dev)

 のように、モデルを TPU に送っちゃうんですね。ここで、モデルをTPU用にコンパイルするんですかね。