Vengineerの妄想

人生を妄想しています。

MNIST pytorch TPU demo


MNIST pytorch TPU demoを見つけた。

  devices = xm.get_xla_supported_devices()
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(MNIST, device_ids=devices)

の部分でデバイス選んで、DataParalell でモデルを実行するのね。

  accuracy = 0.0
  for epoch in range(1, num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(devices)
    if metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

の部分で、学習+テストの繰り返しを行っている。model_parallel に実行する関数とDataLoaderを渡している。。
最初の model_paralell で学習して、2番目の model_paralell でテスト。。。

DataParallelでは、

  def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    if device_ids is None:
      device_ids = xm.get_xla_supported_devices()
    self._batchdim = batchdim
    self._drop_last = drop_last
    self._device_ids = list(device_ids)
    self._replication = (
        xm.Replication(self._device_ids) if self._device_ids else None)
    self._models = []
    for device in device_ids:
      module = network().to(device=torch.device(device))
      self._models.append(module)
    if not self._models:
      # No XLA device, push a vanilla network in.
      self._models.append(network())

XLA対応デバイスのときは、network().to(device=torch.device(device)) をそうでないときは、network()

network()は、MNISTですね。

class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

ResNetとかでもできるのかな。。。