12345678910111213 |
- #!/usr/bin/python
- # encoding: utf-8
- import torch.nn as nn
- import torch.nn.parallel
def data_parallel(model, input, ngpu):
    """Run ``model`` on ``input``, replicating it across GPUs when possible.

    Args:
        model: callable (typically an ``nn.Module``) applied to ``input``.
        input: tensor fed to ``model``.
        ngpu: number of GPUs to use; device ids ``0 .. ngpu - 1`` are used.

    Returns:
        ``nn.parallel.data_parallel(model, input, [0, ..., ngpu - 1])`` when
        ``input`` lives on a CUDA device and ``ngpu > 1``; otherwise
        ``model(input)``.
    """
    # Test the tensor's device rather than the legacy
    # ``torch.cuda.FloatTensor`` type: that check matched only float32 CUDA
    # tensors, so half/double CUDA inputs silently ran on a single GPU.
    if input.is_cuda and ngpu > 1:
        return nn.parallel.data_parallel(model, input, list(range(ngpu)))
    return model(input)
|