# src/dataset.py
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

def get_dataloaders(data_dir, image_size=224, batch_size=8, num_workers=4, seed=None):
    """Build train/validation/test DataLoaders from an ImageFolder directory.

    The images under ``data_dir`` (one subfolder per class) are randomly
    split 80% / 10% / 10% into train / validation / test subsets. Only the
    training subset is augmented (random rotation and flips). Validation and
    test images are only resized, converted to tensors and normalized.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.ImageFolder``.
        image_size: Images are resized to ``(image_size, image_size)``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per DataLoader (default 4).
        seed: Optional integer seed for the random split. ``None`` (default)
            draws from torch's global RNG, as before. Pass a value to get the
            same train/val/test membership on every run.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.

    Raises:
        RuntimeError: If the two scans of ``data_dir`` disagree, e.g. because
            files changed between them.
    """
    import torch
    from torch.utils.data import Subset

    # Commonly used ImageNet mean/std; presumably chosen to match an
    # ImageNet-pretrained backbone -- confirm against the model in use.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize,
    ])

    # Subsets from random_split all reference the same underlying dataset.
    # Assigning ``.transform`` through one subset would change it for every
    # subset and strip augmentation from the training split. Use two dataset
    # views instead, one per transform pipeline.
    train_base = datasets.ImageFolder(root=data_dir, transform=train_transforms)
    eval_base = datasets.ImageFolder(root=data_dir, transform=val_transforms)
    if train_base.samples != eval_base.samples:
        raise RuntimeError(
            "Dataset directory changed while loading; train/eval views differ")
    print("Class to index mapping:", train_base.class_to_idx)

    # Split dataset: 80% train, 10% val, 10% test (test absorbs rounding).
    total_size = len(train_base)
    train_size = int(0.8 * total_size)
    val_size = int(0.1 * total_size)
    test_size = total_size - train_size - val_size

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    # random_split only needs the length, so split plain indices once and
    # apply the same index lists to whichever view each subset needs.
    train_idx, val_idx, test_idx = random_split(
        range(total_size), [train_size, val_size, test_size], **split_kwargs)

    train_set = Subset(train_base, train_idx.indices)
    val_set = Subset(eval_base, val_idx.indices)
    test_set = Subset(eval_base, test_idx.indices)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    return train_loader, val_loader, test_loader

if __name__ == "__main__":
    # Quick smoke test. The data directory is resolved relative to this file
    # (src/../data), so the script works from any working directory; update
    # the path if your data lives elsewhere.
    data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
    train_loader, val_loader, test_loader = get_dataloaders(data_root)
    print("Training set size:", len(train_loader.dataset))
    print("Validation set size:", len(val_loader.dataset))
    print("Test set size:", len(test_loader.dataset))
