#!/usr/bin/python
# encoding: utf-8

import random
import sys

import lmdb
import numpy as np
import six
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import sampler
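

# lmdbDataset reads (image, label) pairs from an LMDB database whose keys are
# 'num-samples', 'image-%09d', and 'label-%09d' (records are 1-indexed).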
class lmdbDataset(Dataset):

    def __init__(self, root=None, transform=None, target_transform=None):
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot create lmdb from %s' % root)
            sys.exit(0)

        with self.env.begin(write=False) as txn:
            # LMDB keys are stored as bytes, so encode the string key first.
            nSamples = int(txn.get('num-samples'.encode()))
            self.nSamples = nSamples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        index += 1  # LMDB records are 1-indexed
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key.encode())

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                # skip the corrupted record and try another one
                return self[index + 1]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % index
            # decode the raw bytes rather than str(), which would yield "b'...'"
            label = txn.get(label_key.encode()).decode('utf-8')

            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
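

# resizeNormalize resizes an image to a fixed (width, height) and maps pixel
# values from [0, 1] to [-1, 1].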
class resizeNormalize(object):

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)
        return img
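

# randomSequentialSampler yields indices in which every batch is a contiguous
# run of samples starting at a random offset.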
class randomSequentialSampler(sampler.Sampler):

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        n_batch = len(self) // self.batch_size
        tail = len(self) % self.batch_size
        index = torch.LongTensor(len(self)).fill_(0)
        for i in range(n_batch):
            random_start = random.randint(0, len(self) - self.batch_size)
            # torch.arange replaces the removed torch.range (end is exclusive)
            batch_index = random_start + torch.arange(0, self.batch_size)
            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
        # deal with tail
        if tail:
            random_start = random.randint(0, len(self) - self.batch_size)
            tail_index = random_start + torch.arange(0, tail)
            index[n_batch * self.batch_size:] = tail_index
        return iter(index)

    def __len__(self):
        return self.num_samples
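

# alignCollate resizes all images in a batch to a common (imgW, imgH) and
# stacks them into a single tensor; with keep_ratio=True the width follows
# the widest image's aspect ratio.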
class alignCollate(object):

    def __init__(self, imgH=32, imgW=128, keep_ratio=False, min_ratio=1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        images, labels = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            ratios = []
            for image in images:
                w, h = image.size
                ratios.append(w / float(h))
            ratios.sort()
            max_ratio = ratios[-1]
            imgW = int(np.floor(max_ratio * imgH))
            imgW = max(imgH * self.min_ratio, imgW)  # assure imgW >= imgH * min_ratio

        transform = resizeNormalize((imgW, imgH))
        images = [transform(image) for image in images]
        images = torch.cat([t.unsqueeze(0) for t in images], 0)

        return images, labels
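

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file: the LMDB path is a
    # placeholder and the batch size / image sizes are illustrative values.
    from torch.utils.data import DataLoader

    train_dataset = lmdbDataset(root='./data/train_lmdb')
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        sampler=randomSequentialSampler(train_dataset, 32),
        num_workers=0,
        collate_fn=alignCollate(imgH=32, imgW=128, keep_ratio=True))

    images, labels = next(iter(train_loader))
    print(images.size(), len(labels))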