123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- #!/usr/bin/python
- # encoding: utf-8
- import torch
- import torch.nn as nn
# Compatibility alias, presumably carried over from Python 2 code where
# ``unicode`` was the text type. NOTE(review): under Python 3,
# ``unicode(x, "utf8")`` only accepts bytes-like ``x``; passing a ``str``
# raises TypeError.
unicode = str
class strLabelConverter(object):
    """Convert between text and the integer label sequences used for CTC.

    Label ``0`` is reserved for the CTC blank; the character at position
    ``i`` of ``alphabet`` is encoded as ``i + 1``.

    Args:
        alphabet (str): every character the model can emit.
    """

    def __init__(self, alphabet):
        # The trailing '-' makes label 0 (blank) render as '-' in raw decoding,
        # because decoding looks up ``self.alphabet[label - 1]`` == index -1.
        self.alphabet = alphabet + u'-'  # for `-1` index
        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    @staticmethod
    def _to_text(s):
        """Return ``s`` as text, decoding UTF-8 ``bytes`` when necessary."""
        if isinstance(s, (bytes, bytearray)):
            return s.decode("utf8")
        return s

    def encode(self, text, depth=0):
        """Support batch or single str.

        Args:
            text (str, bytes, or iterable of str/bytes): text(s) to encode.
                ``bytes`` items are decoded as UTF-8. A single ``str`` or
                ``bytes`` value is treated as one sequence.
            depth: unused; kept for backward compatibility.

        Returns:
            tuple(torch.IntTensor, torch.IntTensor): the labels of all texts
            concatenated, and the length of each text.

        Raises:
            KeyError: if a character is not in the alphabet.
        """
        # A lone string is one sequence, not a batch of one-char sequences.
        if isinstance(text, (str, bytes, bytearray)):
            text = [text]
        length = []
        result = []
        for item in text:
            # Python 3: ``str(s, "utf8")`` rejects ``str``; decode bytes only.
            s = self._to_text(item)
            length.append(len(s))
            result.extend(self.dict[char] for char in s)
        return (torch.IntTensor(result), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode label sequence(s) back to text.

        Args:
            t (torch.IntTensor): labels of all sequences concatenated.
            length (torch.IntTensor): length of each sequence in ``t``.
            raw (bool): if True, map every label directly (blank shown as
                '-'); otherwise apply CTC greedy collapsing: drop blanks (0)
                and merge consecutive repeated labels.

        Returns:
            A ``str`` when ``length`` has exactly one element, otherwise a
            list of ``str`` (one per sequence).
        """
        if length.numel() == 1:
            n = int(length.reshape(-1)[0])
            codes = [int(c) for c in t[:n].tolist()]
            if raw:
                return ''.join([self.alphabet[c - 1] for c in codes])
            char_list = []
            for i in range(n):
                # Keep a non-blank label only if it differs from the previous step.
                if codes[i] != 0 and (not (i > 0 and codes[i - 1] == codes[i])):
                    char_list.append(self.alphabet[codes[i] - 1])
            return ''.join(char_list)
        texts = []
        index = 0
        for seq_len in length.reshape(-1).tolist():
            seq_len = int(seq_len)
            texts.append(self.decode(
                t[index:index + seq_len], torch.IntTensor([seq_len]), raw=raw))
            index += seq_len
        return texts
class averager(object):
    """Running mean over every element of the tensors passed to ``add``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.sum = 0
        self.n_count = 0

    def add(self, v):
        """Accumulate the element count and element sum of tensor ``v``."""
        # Read through ``.data`` so the running total stays off the autograd
        # graph; accumulating ``v.sum()`` directly would keep extending the
        # graph and leak memory.
        detached = v.data
        self.n_count += detached.numel()
        self.sum += detached.sum()

    def val(self):
        """Return the mean so far, or ``0`` when nothing has been added."""
        if not self.n_count:
            return 0
        return self.sum / float(self.n_count)
def oneHot(v, v_length, nc):
    """Expand concatenated label sequences into a padded one-hot tensor.

    Args:
        v (torch.Tensor): 1-D labels of all sequences concatenated.
        v_length (torch.Tensor): 1-D length of each sequence in ``v``.
        nc (int): number of classes (size of the last dimension).

    Returns:
        torch.Tensor (float32) of shape ``(batch, max_length, nc)``; time
        steps past a sequence's length stay all-zero. An empty batch yields
        shape ``(0, 0, nc)`` instead of raising.
    """
    batchSize = v_length.size(0)
    # Plain Python ints for sizes/slices instead of 0-dim tensors.
    lengths = [int(n) for n in v_length.tolist()]
    maxLength = max(lengths, default=0)
    v_onehot = torch.zeros(batchSize, maxLength, nc, dtype=torch.float32)
    acc = 0
    for i, length in enumerate(lengths):
        label = v[acc:acc + length].view(-1, 1).long()
        # Set column ``label`` of each row to 1.0; writes through the view.
        v_onehot[i, :length].scatter_(1, label, 1.0)
        acc += length
    return v_onehot
def loadData(v, data):
    """Resize tensor ``v`` in place to ``data``'s shape and copy ``data`` in.

    Args:
        v (torch.Tensor): destination tensor, modified in place.
        data (torch.Tensor): source values.
    """
    # Resizing through ``v.data`` raises "set_sizes_contiguous is not allowed
    # on a Tensor created from .data or .detach()" on current PyTorch. Mutating
    # ``v`` itself inside ``no_grad`` is the supported untracked alternative.
    with torch.no_grad():
        v.resize_(data.size()).copy_(data)
def prettyPrint(v):
    """Print the size and tensor type of ``v``, then its max, min and mean."""
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
    values = v.data
    # mean() raises for integer tensors, so average those in double.
    if values.is_floating_point():
        mean = values.mean()
    else:
        mean = values.double().mean()
    # Reductions return 0-dim tensors: ``.data[0]`` raises IndexError on
    # PyTorch >= 0.4, ``.item()`` returns the Python scalar.
    print('| Max: %f | Min: %f | Mean: %f' % (
        values.max().item(), values.min().item(), mean.item()))
def assureRatio(img):
    """Ensure imgH <= imgW.

    Args:
        img (torch.Tensor): 4-D batch, unpacked as ``(b, c, h, w)``.

    Returns:
        ``img`` itself when ``h <= w``; otherwise ``img`` bilinearly resized
        to ``(h, h)`` so that the width equals the height.
    """
    _, _, h, w = img.size()
    if h > w:
        # Equivalent to the deprecated nn.UpsamplingBilinear2d(size=(h, h)),
        # which PyTorch documents as interpolate(mode='bilinear',
        # align_corners=True); avoids constructing a module on every call.
        img = nn.functional.interpolate(
            img, size=(h, h), mode='bilinear', align_corners=True)
    return img
|