| import os |
| from torch.utils.data import DataLoader, random_split |
| from torchvision import datasets, transforms |
|
|
| def get_dataloaders(data_dir="path/to/data/dir", batch_size=512, train_split=0.8, img_size=224, num_workers=4): |
| """ |
| Returns training and validation dataloaders for an image classification dataset. |
| |
| Parameters: |
| - data_dir (str): Path to the directory containing image data in a folder structure compatible with ImageFolder. |
| - batch_size (int): Number of samples per batch. |
| - train_split (float): Fraction of data to use for training. Remaining is for validation. |
| - img_size (int): Target size to which all images are resized after validation. |
| - num_workers (int): Number of worker processes for data loading. |
| |
| Image Size Validation: |
| - Minimum allowed image size: 49x49 pixels. |
| - If an image has either width or height less than 49 pixels, a ValueError is raised. |
| |
| Returns: |
| - train_dataloader (DataLoader): DataLoader for the training split. |
| - val_dataloader (DataLoader): DataLoader for the validation split. |
| """ |
|
|
| |
| if img_size < 49: |
| raise ValueError(f"Image size must be at least 49x49 pixels, but got {img_size}x{img_size}.") |
| |
| transform = transforms.Compose([ |
| transforms.Resize((img_size, img_size)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
| ]) |
|
|
| |
| full_dataset = datasets.ImageFolder(root=data_dir, transform=transform) |
|
|
| |
| train_size = int(train_split * len(full_dataset)) |
| val_size = len(full_dataset) - train_size |
| train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) |
|
|
| |
| train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) |
| val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) |
|
|
| return train_dataloader, val_dataloader |
|
|