vgg16_bn.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from collections import namedtuple
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.init as init
  5. from torchvision import models
  6. from torchvision.models.vgg import model_urls
  7. def init_weights(modules):
  8. for m in modules:
  9. if isinstance(m, nn.Conv2d):
  10. init.xavier_uniform_(m.weight.data)
  11. if m.bias is not None:
  12. m.bias.data.zero_()
  13. elif isinstance(m, nn.BatchNorm2d):
  14. m.weight.data.fill_(1)
  15. m.bias.data.zero_()
  16. elif isinstance(m, nn.Linear):
  17. m.weight.data.normal_(0, 0.01)
  18. m.bias.data.zero_()
  19. class vgg16_bn(torch.nn.Module):
  20. def __init__(self, pretrained=True, freeze=True):
  21. super(vgg16_bn, self).__init__()
  22. model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
  23. vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
  24. self.slice1 = torch.nn.Sequential()
  25. self.slice2 = torch.nn.Sequential()
  26. self.slice3 = torch.nn.Sequential()
  27. self.slice4 = torch.nn.Sequential()
  28. self.slice5 = torch.nn.Sequential()
  29. for x in range(12): # conv2_2
  30. self.slice1.add_module(str(x), vgg_pretrained_features[x])
  31. for x in range(12, 19): # conv3_3
  32. self.slice2.add_module(str(x), vgg_pretrained_features[x])
  33. for x in range(19, 29): # conv4_3
  34. self.slice3.add_module(str(x), vgg_pretrained_features[x])
  35. for x in range(29, 39): # conv5_3
  36. self.slice4.add_module(str(x), vgg_pretrained_features[x])
  37. # fc6, fc7 without atrous conv
  38. self.slice5 = torch.nn.Sequential(
  39. nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
  40. nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
  41. nn.Conv2d(1024, 1024, kernel_size=1)
  42. )
  43. if not pretrained:
  44. init_weights(self.slice1.modules())
  45. init_weights(self.slice2.modules())
  46. init_weights(self.slice3.modules())
  47. init_weights(self.slice4.modules())
  48. init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
  49. if freeze:
  50. for param in self.slice1.parameters(): # only first conv
  51. param.requires_grad= False
  52. def forward(self, X):
  53. h = self.slice1(X)
  54. h_relu2_2 = h
  55. h = self.slice2(h)
  56. h_relu3_2 = h
  57. h = self.slice3(h)
  58. h_relu4_3 = h
  59. h = self.slice4(h)
  60. h_relu5_3 = h
  61. h = self.slice5(h)
  62. h_fc7 = h
  63. vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
  64. out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
  65. return out