@Vengineerの戯言 : Twitter
SystemVerilogの世界へようこそ、すべては、SystemC v0.9公開から始まった
1年ぐらい前に、「PyTorch + XLA」のソースコード解析を3回に分けて行ったけど、どうもそれからコードがかなり変わったみたい。
下記のコードは、API_GUIDE.md にあるコード。
凄くきれいになっている。
import torch_xla.core.xla_model as xm device = xm.xla_device() model = MNIST().train().to(device) loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) for data, target in train_loader: optimizer.zero_grad() data = data.to(device) target = target.to(device) output = model(data) loss = loss_fn(output, target) loss.backward() xm.optimizer_step(optimizer, barrier=True)
torch_xla.core.xla_model がポイント。。。
xla_device() で device を獲得するだけで、あとは基本的に同じなのか。。。
PyTorch公式の例題のコード。。。
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
ほぼ同じ。
違いは、optimizer_step が xm.optimizer_step になっているだけ。。
そして、マルチプロセッシングもできる。
import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl import torch_xla.distributed.xla_multiprocessing as xmp def _mp_fn(index): device = xm.xla_device() para_loader = pl.ParallelLoader(train_loader, [device]) model = MNIST().train().to(device) loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) for data, target in para_loader.per_device_loader(device): optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() xm.optimizer_step(optimizer) if __name__ == '__main__': xmp.spawn(_mp_fn, args=())
PyTorchの例題では、こちら。
ngpus_per_node = torch.cuda.device_count()
if args.multiprocessing_distributed:
# Since we have ngpus_per_node processes per node, the total world_size
# needs to be adjusted accordingly
args.world_size = ngpus_per_node * args.world_size
# Use torch.multiprocessing.spawn to launch distributed processes: the
# main_worker process function
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
else:
# Simply call main_worker function
main_worker(args.gpu, ngpus_per_node, args)
mp.spawn が xmp.spawn になる。。引数もだいたい同じ。