# util.py — CRNN label-conversion and tensor helper utilities (~2.9 KB)
  1. #!/usr/bin/python
  2. # encoding: utf-8
  3. import torch
  4. import torch.nn as nn
  5. unicode = str
  6. class strLabelConverter(object):
  7. def __init__(self, alphabet):
  8. self.alphabet = alphabet + u'-' # for `-1` index
  9. self.dict = {}
  10. for i, char in enumerate(alphabet):
  11. # NOTE: 0 is reserved for 'blank' required by wrap_ctc
  12. self.dict[char] = i + 1
  13. def encode(self, text, depth=0):
  14. """Support batch or single str."""
  15. length = []
  16. result = []
  17. for str in text:
  18. str = unicode(str, "utf8")
  19. length.append(len(str))
  20. for char in str:
  21. # print(char)
  22. index = self.dict[char]
  23. result.append(index)
  24. text = result
  25. return (torch.IntTensor(text), torch.IntTensor(length))
  26. def decode(self, t, length, raw=False):
  27. if length.numel() == 1:
  28. length = length[0]
  29. t = t[:length]
  30. if raw:
  31. return ''.join([self.alphabet[i - 1] for i in t])
  32. else:
  33. char_list = []
  34. for i in range(length):
  35. if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
  36. char_list.append(self.alphabet[t[i] - 1])
  37. return ''.join(char_list)
  38. else:
  39. texts = []
  40. index = 0
  41. for i in range(length.numel()):
  42. l = length[i]
  43. texts.append(self.decode(
  44. t[index:index + l], torch.IntTensor([l]), raw=raw))
  45. index += l
  46. return texts
  47. class averager(object):
  48. def __init__(self):
  49. self.reset()
  50. def add(self, v):
  51. self.n_count += v.data.numel()
  52. # NOTE: not `+= v.sum()`, which will add a node in the compute graph,
  53. # which lead to memory leak
  54. self.sum += v.data.sum()
  55. def reset(self):
  56. self.n_count = 0
  57. self.sum = 0
  58. def val(self):
  59. res = 0
  60. if self.n_count != 0:
  61. res = self.sum / float(self.n_count)
  62. return res
  63. def oneHot(v, v_length, nc):
  64. batchSize = v_length.size(0)
  65. maxLength = v_length.max()
  66. v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0)
  67. acc = 0
  68. for i in range(batchSize):
  69. length = v_length[i]
  70. label = v[acc:acc + length].view(-1, 1).long()
  71. v_onehot[i, :length].scatter_(1, label, 1.0)
  72. acc += length
  73. return v_onehot
  74. def loadData(v, data):
  75. v.data.resize_(data.size()).copy_(data)
  76. def prettyPrint(v):
  77. print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
  78. print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0], v.mean().data[0]))
  79. def assureRatio(img):
  80. """Ensure imgH <= imgW."""
  81. b, c, h, w = img.size()
  82. if h > w:
  83. main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
  84. img = main(img)
  85. return img