test.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # coding:utf-8
  2. import dataset
  3. import keys
  4. import models.crnn as crnn
  5. import torch.utils.data
  6. import util
  7. from PIL import Image
  8. from torch.autograd import Variable
  9. alphabet = keys.alphabet
  10. print(len(alphabet))
  11. raw_input('\ninput:')
  12. converter = util.strLabelConverter(alphabet)
  13. model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cuda()
  14. path = './samples/netCRNN63.pth'
  15. model.load_state_dict(torch.load(path))
  16. print(model)
  17. while 1:
  18. im_name = raw_input("\nplease input file name:")
  19. im_path = "./img/" + im_name
  20. image = Image.open(im_path).convert('L')
  21. scale = image.size[1] * 1.0 / 32
  22. w = image.size[0] / scale
  23. w = int(w)
  24. print(w)
  25. transformer = dataset.resizeNormalize((w, 32))
  26. image = transformer(image).cuda()
  27. image = image.view(1, *image.size())
  28. image = Variable(image)
  29. model.eval()
  30. preds = model(image)
  31. _, preds = preds.max(2)
  32. preds = preds.squeeze(2)
  33. preds = preds.transpose(1, 0).contiguous().view(-1)
  34. preds_size = Variable(torch.IntTensor([preds.size(0)]))
  35. raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
  36. sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
  37. print('%-20s => %-20s' % (raw_pred, sim_pred))