from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
# Default per-image preprocessing pipeline, applied in order:
#   1. force a single (grayscale) channel,
#   2. convert the PIL image to a float tensor scaled to [0, 1],
#   3. shift/scale with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5,
#      which maps [0, 1] onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
class FLDataset(Dataset):
    """Map-style PyTorch dataset over one split of a Hugging Face dataset.

    Each item is an ``(image, label)`` pair taken from the ``'image'`` and
    ``'label'`` columns of the selected split. When a transform is set, it
    is applied to the image before the pair is returned.
    """

    def __init__(
        self,
        dataset_name_or_path="ylecun/mnist",
        split="train",
        transform=transform,
    ):
        """Load the dataset and keep only the requested split.

        Args:
            dataset_name_or_path: Dataset id or path passed to
                ``load_dataset``; defaults to ``"ylecun/mnist"``.
            split: Key of the split to keep from the loaded splits
                (default ``"train"``).
            transform: Callable applied to each image in ``__getitem__``,
                or ``None`` to return images unchanged. Defaults to the
                module-level ``transform`` pipeline (bound once, when the
                class is defined).
        """
        # load_dataset without ``split=`` returns every split; index out the
        # one we want. NOTE(review): an unknown split name presumably raises
        # KeyError here — confirm against the datasets version in use.
        self.dataset = load_dataset(dataset_name_or_path)[split]
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``, transforming the image if configured."""
        # Fetch the row once. Each indexing call materializes the whole row,
        # which for an image column typically means decoding the image, so
        # indexing separately for 'image' and 'label' would do that work twice.
        row = self.dataset[idx]
        image, label = row['image'], row['label']
        # Explicit None check: only skip the transform when none was given.
        if self.transform is not None:
            image = self.transform(image)
        return image, label