Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

正式リリースした PyTorch + XLA を見てみたら、凄くきれいになっていた

@Vengineerの戯言 : Twitter
SystemVerilogの世界へようこそすべては、SystemC v0.9公開から始まった 

1年ぐらい前に、「PyTorch + XLA」のソースコード解析を3回に分けて行ったけど、どうもそれからコードがかなり変わったみたい。

vengineer.hatenablog.com

下記のコードは、API_GUIDE.md にあるコード。

凄くきれいになっている。

import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  xm.optimizer_step(optimizer, barrier=True)

 

torch_xla.core.xla_model がポイント。。。

xla_device() で device を獲得するだけで、あとは基本的に同じなのか。。。

PyTorch公式の例題のコード。。。

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()

ほぼ同じ。

違いは、optimizer_step が xm.optimizer_step になっているだけ。。

 

そして、マルチプロセッシングもできる。

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  para_loader = pl.ParallelLoader(train_loader, [device])

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in para_loader.per_device_loader(device):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())


PyTorchの例題では、こちら

  ngpus_per_node = torch.cuda.device_count()
  if args.multiprocessing_distributed:
    # Since we have ngpus_per_node processes per node, the total world_size
    # needs to be adjusted accordingly
    args.world_size = ngpus_per_node * args.world_size
    # Use torch.multiprocessing.spawn to launch distributed processes: the
    # main_worker process function
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
  else:
    # Simply call main_worker function
    main_worker(args.gpu, ngpus_per_node, args)

 

mp.spawn が xmp.spawn になる。。引数もだいたい同じ。