import torch
import torch.utils.data
from torchvision import transforms,datasets
# 定义transforms的一些操作
data_transform = transforms.Compose([
# Resize后数据的大小为224 * 224
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# 数据标准化,采用的图片标准化参数
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 使用ImageFolder去读取,返回后的数据路径和标签对应起来
all_dataset = datasets.ImageFolder('../data/amazon/images', transform=data_transform)
# 使用random_split实现数据集的划分,lengths是一个list,按照对应的数量返回数据个数。
# 这儿需要注意的是,lengths的数据量总和等于all_dataset中的数据个数,这儿不是按比例划分的
train, test, valid = torch.utils.data.random_split(dataset= all_dataset, lengths=[2000, 417, 400])
# 接着按照正常方式使用DataLoader读取数据,返回的是DataLoader对象
train = torch.utils.data.DataLoader(train, batch_size=4, shuffle=True, num_workers=4)
test = torch.utils.data.DataLoader(test, batch_size=4, shuffle=True, num_workers=4)
valid = torch.utils.data.DataLoader(valid, batch_size=4, shuffle=True, num_workers=4)