第12章 PyTorch像分割代码框架1
墨初 知识笔记 563阅读
A PyTorch Dataset class for the VOC Segmentation dataset. def __init__(self, root, image_settrain, transformNone): Initialize the dataset. Args: root (str): Path to the dataset root directory. image_set (str): The image set to use (train or val). transform (callable, optional): A function/transform to apply to the images and masks. self.root Path(root) self.transform transform self.image_set image_set base_dir VOCdevkit/VOC2012 voc_root self.root / base_dir image_dir voc_root / JPEGImages if not voc_root.is_dir(): raise RuntimeError(Dataset not found.) mask_dir voc_root / SegmentationClass splits_dir voc_root / ImageSets/Segmentation split_f splits_dir / f{image_set.rstrip()}.txt with open(split_f, r) as f: file_names [x.strip() for x in f.readlines()] self.images [image_dir / f{x}.jpg for x in file_names] self.masks [mask_dir / f{x}.png for x in file_names] assert (len(self.images) len(self.masks)) def __getitem__(self, index): Get an item from the dataset. Args: index (int): Index of the item to get. Returns: tuple: (image, target) where target is the image segmentation. img Image.open(self.images[index]).convert(RGB) target Image.open(self.masks[index]) if self.transform is not None: img, target self.transform(img, target) return img, target def __len__(self): return len(self.images)
如代码11-2所示我们通过pathlib库来定义数据路径通过pillow的Image.open函数来读取图像并进行同步转换。除此之外VOC数据集掩码还需要单独进行颜色编码的接码所以实际操作时要单独定义voc_map函数以及在VOCSegmentation类中补充一个掩码图像的解码方法decode_target。完整代码可参考本书配套代码对应章节。

需要特别说明的是torchvision库中提供的transform模块提供了各种图像变换方法也就是我们通常所说的在线数据增强Online Data Augmentation。在线数据增强是指在训练过程中每次读取一个样本时都会进行数据增强操作。也就是说数据增强是在每个小批量batch的数据上实时进行的。在线数据增强可以通过数据转换的方式在每个训练迭代中生成多个不同的数据样本以增加训练集的多样性但不实际增加训练数据的数量。常见的在线数据增强操作包括随机裁剪random cropping、翻转flipping、旋转rotation、缩放scaling等。与在线数据增强对应的是离线数据增强Offline Data Augmentation。离线数据增强是指在训练开始之前将原始数据集进行增强并将增强后的数据保存为新的训练集。然后在训练过程中使用增强后的训练集进行模型训练。离线数据增强的好处是可以节省训练时间因为数据增强只需在训练开始之前完成一次。常见的离线数据增强操作包括扩充数据集例如通过旋转、平移、缩放等方式生成新的图像。常用的离线数据增强库包括imgaug、albumentations和Augmentor等。在代码11-2中我们在初始化方法里面提供数据transform方式同步接收img和target作为输入进行在线数据增强以增强训练样本的多样性。当然了这需要我们对torchvision中的transform方法稍微进行改动。
后续全书内容和代码将在github上开源请关注仓库

未完待续