crnn.py

import sys
sys.path.insert(1, "./crnn")
import torch.nn as nn
import models.utils as utils


class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut, ngpu):
        super(BidirectionalLSTM, self).__init__()
        self.ngpu = ngpu

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = utils.data_parallel(self.rnn, input,
                                           self.ngpu)  # [T, b, h * 2]

        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)  # flatten time and batch for the linear layer
        output = utils.data_parallel(self.embedding, t_rec,
                                     self.ngpu)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output
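# Shape contract: BidirectionalLSTM maps [T, b, nIn] -> [T, b, nOut].
# The bidirectional LSTM doubles the hidden size, so the Linear layer
# projects the 2 * nHidden features back to nOut at every time step.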
class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, ngpu, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        self.ngpu = ngpu
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # Per-layer kernel sizes, paddings, strides and channel counts
        # for the seven convolutional blocks.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        # The last two pooling layers halve the height only (stride (2, 1)),
        # keeping the width nearly intact so the output sequence stays long.
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh, ngpu),
            BidirectionalLSTM(nh, nh, nclass, ngpu))
    def forward(self, input):
        # conv features
        conv = utils.data_parallel(self.cnn, input, self.ngpu)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features: one prediction over nclass per horizontal position
        output = utils.data_parallel(self.rnn, conv, self.ngpu)

        return output
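

# A minimal smoke test, not part of the original file. It assumes the
# "./crnn" checkout on sys.path provides models/utils.py with the
# data_parallel helper used above, and that data_parallel falls back to
# calling the module directly when the input is not on the GPU.
if __name__ == '__main__':
    import torch

    # Typical crnn.pytorch settings: grayscale 32x100 crops, 37 classes
    # (10 digits + 26 letters + 1 CTC blank), 256 hidden units.
    model = CRNN(imgH=32, nc=1, nclass=37, nh=256, ngpu=1)
    image = torch.randn(4, 1, 32, 100)  # [b, c, H, W]
    preds = model(image)
    print(preds.size())  # expected: [26, 4, 37], i.e. [T, b, nclass]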