crnn.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # coding:utf-8
  2. import sys
  3. from PIL import Image
  4. sys.path.insert(1, "./crnn")
  5. import torch
  6. import torch.utils.data
  7. from torch.autograd import Variable
  8. import numpy as np
  9. import util
  10. import dataset
  11. import models.crnn as crnn
  12. import keys_crnn
  13. from math import *
  14. import cv2
  15. GPU = True
  16. def dumpRotateImage_(img, degree, pt1, pt2, pt3, pt4):
  17. height, width = img.shape[:2]
  18. heightNew = int(width * fabs(sin(radians(degree))) + height * fabs(cos(radians(degree))))
  19. widthNew = int(height * fabs(sin(radians(degree))) + width * fabs(cos(radians(degree))))
  20. matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
  21. matRotation[0, 2] += (widthNew - width) / 2
  22. matRotation[1, 2] += (heightNew - height) / 2
  23. imgRotation = cv2.warpAffine(img, matRotation, (widthNew, heightNew), borderValue=(255, 255, 255))
  24. pt1 = list(pt1)
  25. pt3 = list(pt3)
  26. [[pt1[0]], [pt1[1]]] = np.dot(matRotation, np.array([[pt1[0]], [pt1[1]], [1]]))
  27. [[pt3[0]], [pt3[1]]] = np.dot(matRotation, np.array([[pt3[0]], [pt3[1]], [1]]))
  28. imgOut = imgRotation[int(pt1[1]):int(pt3[1]), int(pt1[0]):int(pt3[0])]
  29. height, width = imgOut.shape[:2]
  30. return imgOut
  31. def crnnSource():
  32. alphabet = keys_crnn.alphabet
  33. converter = util.strLabelConverter(alphabet)
  34. if torch.cuda.is_available() and GPU:
  35. model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cuda()
  36. else:
  37. model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cpu()
  38. # path = '../crnn/samples/netCRNN_61_134500.pth'
  39. path = './crnn/samples/model_acc97.pth'
  40. model.eval()
  41. # w = torch.load(path)
  42. # ww = {}
  43. # for i in w:
  44. # ww[i.replace('module.', '')] = w[i]
  45. #
  46. # model.load_state_dict(ww)
  47. model.load_state_dict(torch.load(path))
  48. return model, converter
  49. ##加载模型
  50. model, converter = crnnSource()
  51. def crnnOcr(image):
  52. """
  53. crnn模型,ocr识别
  54. @@model,
  55. @@converter,
  56. @@im
  57. @@text_recs:text box
  58. """
  59. if isinstance(image,str):
  60. image = Image.open(image).convert("L")
  61. else:
  62. image = Image.fromarray(image).convert("L")
  63. scale = image.size[1] * 1.0 / 32
  64. w = image.size[0] / scale
  65. w = int(w)
  66. # print "im size:{},{}".format(image.size,w)
  67. transformer = dataset.resizeNormalize((w, 32))
  68. if torch.cuda.is_available() and GPU:
  69. image = transformer(image).cuda()
  70. else:
  71. image = transformer(image).cpu()
  72. image = image.view(1, *image.size())
  73. image = Variable(image)
  74. model.eval()
  75. preds = model(image)
  76. _, preds = preds.max(2)
  77. preds = preds.transpose(1, 0).contiguous().view(-1)
  78. preds_size = Variable(torch.IntTensor([preds.size(0)]))
  79. sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
  80. if len(sim_pred) > 0:
  81. if sim_pred[0] == u'-':
  82. sim_pred = sim_pred[1:]
  83. return quchong(sim_pred)
  84. def quchong(s):
  85. ls = list(s)
  86. for i in range(len(ls)-1):
  87. if ls[i]==ls[i+1]:
  88. ls[i+1]=''
  89. return ''.join(ls)
  90. if __name__ == '__main__':
  91. #
  92. print(crnnOcr(Image.open(r'D:\试卷切割\result\text_img\132-145-46-182.png').convert("L")))
  93. # print(quchong('abcdefghiijjklmn'))