MNIST pytorch TPU demoを見つけた。
devices = xm.get_xla_supported_devices() # Pass [] as device_ids to run using the PyTorch/CPU engine. model_parallel = dp.DataParallel(MNIST, device_ids=devices)
の部分でデバイス選んで、DataParalell でモデルを実行するのね。
accuracy = 0.0 for epoch in range(1, num_epochs + 1): model_parallel(train_loop_fn, train_loader) accuracies = model_parallel(test_loop_fn, test_loader) accuracy = sum(accuracies) / len(devices) if metrics_debug: print(torch_xla._XLAC._xla_metrics_report())
の部分で、学習+テストの繰り返しを行っている。model_parallel に実行する関数とDataLoaderを渡している。。
最初の model_paralell で学習して、2番目の model_paralell でテスト。。。
最初の model_paralell で学習して、2番目の model_paralell でテスト。。。
DataParallelでは、
def __init__(self, network, device_ids=None, batchdim=0, drop_last=False): if device_ids is None: device_ids = xm.get_xla_supported_devices() self._batchdim = batchdim self._drop_last = drop_last self._device_ids = list(device_ids) self._replication = ( xm.Replication(self._device_ids) if self._device_ids else None) self._models = [] for device in device_ids: module = network().to(device=torch.device(device)) self._models.append(module) if not self._models: # No XLA device, push a vanilla network in. self._models.append(network())
XLA対応デバイスのときは、network().to(device=torch.device(device)) をそうでないときは、network()
network()は、MNISTですね。
class MNIST(nn.Module): def __init__(self): super(MNIST, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.bn1 = nn.BatchNorm2d(10) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.bn2 = nn.BatchNorm2d(20) self.fc1 = nn.Linear(320, 50) self.fc2 = nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = self.bn1(x) x = F.relu(F.max_pool2d(self.conv2(x), 2)) x = self.bn2(x) x = x.view(-1, 320) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1)
ResNetとかでもできるのかな。。。