from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

# Default image preprocessing pipeline, applied in order:
#   1. collapse the image to a single grayscale channel,
#   2. convert it to a tensor,
#   3. normalise that channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

class FLDataset(Dataset):
    """Map-style PyTorch dataset backed by one split of a Hugging Face dataset.

    Every item is an ``(image, label)`` pair taken from the ``'image'`` and
    ``'label'`` columns of the selected split.

    Args:
        dataset_name_or_path: Identifier or path handed to ``load_dataset``.
        split: Key of the split to select from the loaded dataset.
        transform: Callable applied to each image before it is returned.
            A falsy value (e.g. ``None``) returns the image untouched.
            Defaults to the module-level ``transform`` pipeline.
    """

    def __init__(self, dataset_name_or_path="ylecun/mnist", split="train", transform=transform):
        self.dataset = load_dataset(dataset_name_or_path)[split]
        self.transform = transform

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        # Look the row up exactly once: each ``self.dataset[idx]`` call
        # materialises the full example (image included), so indexing
        # separately for the image and the label doubles the per-sample work.
        example = self.dataset[idx]
        image = example['image']
        label = example['label']

        # Apply transformations to the image
        if self.transform:
            image = self.transform(image)

        return image, label