utils.py 321 B

12345678910111213
  1. #!/usr/bin/python
  2. # encoding: utf-8
  3. import torch.nn as nn
  4. import torch.nn.parallel
  5. def data_parallel(model, input, ngpu):
  6. if isinstance(input.data, torch.cuda.FloatTensor) and ngpu > 1:
  7. output = nn.parallel.data_parallel(model, input, range(ngpu))
  8. else:
  9. output = model(input)
  10. return output