dataset.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #!/usr/bin/python
  2. # encoding: utf-8
  3. import random
  4. import sys
  5. import lmdb
  6. import numpy as np
  7. import six
  8. import torch
  9. import torchvision.transforms as transforms
  10. from PIL import Image
  11. from torch.utils.data import Dataset
  12. from torch.utils.data import sampler
  13. class lmdbDataset(Dataset):
  14. def __init__(self, root=None, transform=None, target_transform=None):
  15. self.env = lmdb.open(
  16. root,
  17. max_readers=1,
  18. readonly=True,
  19. lock=False,
  20. readahead=False,
  21. meminit=False)
  22. if not self.env:
  23. print('cannot creat lmdb from %s' % (root))
  24. sys.exit(0)
  25. with self.env.begin(write=False) as txn:
  26. nSamples = int(txn.get('num-samples'))
  27. self.nSamples = nSamples
  28. self.transform = transform
  29. self.target_transform = target_transform
  30. def __len__(self):
  31. return self.nSamples
  32. def __getitem__(self, index):
  33. assert index <= len(self), 'index range error'
  34. index += 1
  35. with self.env.begin(write=False) as txn:
  36. img_key = 'image-%09d' % index
  37. imgbuf = txn.get(img_key)
  38. buf = six.BytesIO()
  39. buf.write(imgbuf)
  40. buf.seek(0)
  41. try:
  42. img = Image.open(buf).convert('L')
  43. except IOError:
  44. print('Corrupted image for %d' % index)
  45. return self[index + 1]
  46. if self.transform is not None:
  47. img = self.transform(img)
  48. label_key = 'label-%09d' % index
  49. label = str(txn.get(label_key))
  50. if self.target_transform is not None:
  51. label = self.target_transform(label)
  52. return (img, label)
  53. class resizeNormalize(object):
  54. def __init__(self, size, interpolation=Image.BILINEAR):
  55. self.size = size
  56. self.interpolation = interpolation
  57. self.toTensor = transforms.ToTensor()
  58. def __call__(self, img):
  59. img = img.resize(self.size, self.interpolation)
  60. img = self.toTensor(img)
  61. img.sub_(0.5).div_(0.5)
  62. return img
  63. class randomSequentialSampler(sampler.Sampler):
  64. def __init__(self, data_source, batch_size):
  65. self.num_samples = len(data_source)
  66. self.batch_size = batch_size
  67. def __iter__(self):
  68. n_batch = len(self) // self.batch_size
  69. tail = len(self) % self.batch_size
  70. index = torch.LongTensor(len(self)).fill_(0)
  71. for i in range(n_batch):
  72. random_start = random.randint(0, len(self) - self.batch_size)
  73. batch_index = random_start + torch.range(0, self.batch_size - 1)
  74. index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
  75. # deal with tail
  76. if tail:
  77. random_start = random.randint(0, len(self) - self.batch_size)
  78. tail_index = random_start + torch.range(0, tail - 1)
  79. index[(i + 1) * self.batch_size:] = tail_index
  80. return iter(index)
  81. def __len__(self):
  82. return self.num_samples
  83. class alignCollate(object):
  84. def __init__(self, imgH=32, imgW=128, keep_ratio=False, min_ratio=1):
  85. self.imgH = imgH
  86. self.imgW = imgW
  87. self.keep_ratio = keep_ratio
  88. self.min_ratio = min_ratio
  89. def __call__(self, batch):
  90. images, labels = zip(*batch)
  91. imgH = self.imgH
  92. imgW = self.imgW
  93. if self.keep_ratio:
  94. ratios = []
  95. for image in images:
  96. w, h = image.size
  97. ratios.append(w / float(h))
  98. ratios.sort()
  99. max_ratio = ratios[-1]
  100. imgW = int(np.floor(max_ratio * imgH))
  101. imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW
  102. transform = resizeNormalize((imgW, imgH))
  103. images = [transform(image) for image in images]
  104. images = torch.cat([t.unsqueeze(0) for t in images], 0)
  105. return images, labels