MNIST pytorch TPU demoを見つけた。
devices = xm.get_xla_supported_devices() # Pass [] as device_ids to run using the PyTorch/CPU engine. model_parallel = dp.DataParallel(MNIST, device_ids=devices)
の部分でデバイス選んで、DataParalell でモデルを実行するのね。
accuracy = 0.0
for epoch in range(1, num_epochs + 1):
model_parallel(train_loop_fn, train_loader)
accuracies = model_parallel(test_loop_fn, test_loader)
accuracy = sum(accuracies) / len(devices)
if metrics_debug:
print(torch_xla._XLAC._xla_metrics_report())
の部分で、学習+テストの繰り返しを行っている。model_parallel に実行する関数とDataLoaderを渡している。。
最初の model_paralell で学習して、2番目の model_paralell でテスト。。。
最初の model_paralell で学習して、2番目の model_paralell でテスト。。。
DataParallelでは、
def __init__(self, network, device_ids=None, batchdim=0, drop_last=False): if device_ids is None: device_ids = xm.get_xla_supported_devices() self._batchdim = batchdim self._drop_last = drop_last self._device_ids = list(device_ids) self._replication = ( xm.Replication(self._device_ids) if self._device_ids else None) self._models = [] for device in device_ids: module = network().to(device=torch.device(device)) self._models.append(module) if not self._models: # No XLA device, push a vanilla network in. self._models.append(network())
XLA対応デバイスのときは、network().to(device=torch.device(device)) をそうでないときは、network()
network()は、MNISTですね。
class MNIST(nn.Module):
def __init__(self):
super(MNIST, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.bn1 = nn.BatchNorm2d(10)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.bn2 = nn.BatchNorm2d(20)
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = self.bn1(x)
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = self.bn2(x)
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
ResNetとかでもできるのかな。。。