import argparse
import os

import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

from model import ColonCancerModel  # Ensure this file defines your model architecture

# Pick the compute device once, at import time: the default CUDA device when
# PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def generate_gradcam(model, img_tensor, target_layer):
    """
    Generate a Grad-CAM heatmap for an input image and a specified model target layer.

    A forward hook on ``target_layer`` keeps a copy of the layer's output
    activation and registers a tensor hook on that output. The tensor hook
    captures the gradient of the model's raw output with respect to the
    activation during the backward pass.

    Parameters:
      - model: The classifier. The script treats its output as a binary
        logit (it applies ``torch.sigmoid`` to get a probability). It must
        already live on the module-level ``device``, because the input is
        moved there.
      - img_tensor: Input batch. Dims 2 and 3 are read as the spatial size
        (assumed [N, C, H, W]); only the first sample is used for the map.
      - target_layer: Module whose output is a feature map. Assumed
        [N, C, h, w], since its gradient is averaged over the last two axes.

    Returns:
      - cam: float32 numpy array of shape [H, W], min-max normalised to
        [0, 1]. It is all zeros when the map has no contrast (for example,
        no positive evidence at all).

    Raises:
      - RuntimeError: if the target layer's output does not require grad.
      - ValueError: if no activation or gradient was captured.
    """
    model.eval()

    # Filled in by the hooks below.
    activation = None
    gradient = None

    def activation_gradient_hook(grad):
        nonlocal gradient
        gradient = grad

    def forward_hook(module, inputs, output):
        nonlocal activation
        activation = output.clone()  # Keep a copy of the activation
        # Register a hook on the activation tensor so its gradient is captured
        # during backpropagation.
        if output.requires_grad:
            output.register_hook(activation_gradient_hook)
        else:
            raise RuntimeError("Output of target layer does not require grad. "
                               "Ensure that the input tensor has requires_grad=True.")

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # Move the image to the device and ensure it requires gradient.
        img_tensor = img_tensor.to(device)
        if not img_tensor.requires_grad:
            img_tensor.requires_grad_()

        output = model(img_tensor)

        model.zero_grad()
        # Backpropagate the raw logit, not sigmoid(logit). For a confident
        # positive prediction float32 sigmoid rounds to exactly 1.0, so its
        # derivative y * (1 - y) is 0 and every Grad-CAM weight would vanish.
        # Otherwise the sigmoid only contributes a positive per-sample scale,
        # which the ReLU + min-max normalisation below cancels out.
        output.backward(torch.ones_like(output))
    finally:
        # Always detach the hook, even if the forward/backward pass raised,
        # so the layer is not left instrumented for later calls.
        handle.remove()

    if activation is None or gradient is None:
        raise ValueError("No activation or gradient captured. Check that the target layer is correct.")

    # First sample only; expected shapes are [C, h, w].
    activation_np = activation.detach().cpu().numpy()[0]
    gradient_np = gradient.detach().cpu().numpy()[0]

    # Channel weights: global-average pooling of the gradients over h and w.
    weights = gradient_np.mean(axis=(1, 2))  # shape: [C]

    # Weighted sum of the activation maps over the channel axis -> [h, w].
    cam = np.tensordot(weights, activation_np, axes=1).astype(np.float32)

    # ReLU: keep only features that push the logit up.
    cam = np.maximum(cam, 0)

    # Resize to the input's spatial size (cv2 takes (width, height)).
    cam = cv2.resize(cam, (img_tensor.shape[3], img_tensor.shape[2]))

    # Min-max normalise to [0, 1]. A constant map has zero range. It becomes
    # all zeros instead of 0/0 = NaN.
    cam = cam - cam.min()
    cam_range = cam.max()
    if cam_range > 0:
        cam = cam / cam_range

    return cam

def overlay_heatmap(original_img, heatmap, alpha=0.4, colormap=cv2.COLORMAP_JET):
    """
    Overlays the heatmap on the original image.

    Parameters:
      - original_img: Original image in OpenCV BGR format. It is presumably
        8-bit, because it is blended with an 8-bit colour map.
      - heatmap: 2-D heatmap, expected in [0, 1]. Values outside that range
        are clipped and NaNs become 0 before the 8-bit conversion, so they
        cannot wrap around. A heatmap whose size differs from the image is
        resized to the image's size.
      - alpha: Opacity for the heatmap overlay.
      - colormap: OpenCV colormap to use.

    Returns:
      - overlay: The resulting image after overlay.
    """
    # Sanitise: casting NaN or out-of-range floats to uint8 would wrap.
    heatmap = np.nan_to_num(np.asarray(heatmap, dtype=np.float32), nan=0.0)
    heatmap = np.clip(heatmap, 0.0, 1.0)

    # cv2.addWeighted requires both inputs to have the same size.
    height, width = original_img.shape[:2]
    if heatmap.shape[:2] != (height, width):
        heatmap = cv2.resize(heatmap, (width, height))

    heatmap_uint8 = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, colormap)
    overlay = cv2.addWeighted(original_img, 1 - alpha, heatmap_color, alpha, 0)
    return overlay

def main(argv=None):
    """
    Predict the cancer probability for one image and save/show its Grad-CAM overlay.

    Parameters:
      - argv: Optional argument list for argparse. ``None`` means read
        ``sys.argv``. With no arguments, the original hard-coded model, image
        and output paths are used.

    Raises:
      - OSError: if the overlay image could not be written.
    """
    parser = argparse.ArgumentParser(
        description="Colon cancer prediction with a Grad-CAM overlay.")
    parser.add_argument("--model", default="../models/mobilenet_model.pth",
                        help="Path to the trained model state_dict.")
    parser.add_argument("--image", default="../data/0_colon_benign_tissue/colonn8.jpeg",
                        help="Path to the input image.")
    parser.add_argument("--output", default="../results/cam_images/gradcam_result.jpg",
                        help="Where to save the Grad-CAM overlay.")
    args = parser.parse_args(argv)

    # ---------------------- MODEL LOADING AND PREDICTION ----------------------
    model = ColonCancerModel()
    # NOTE(review): torch.load unpickles the file. Only load checkpoints you
    # trust; PyTorch versions that support it can pass weights_only=True.
    model.load_state_dict(torch.load(args.model, map_location=device))
    model.to(device)
    model.eval()

    # Prepare the input image; the context manager closes the file handle.
    with Image.open(args.image) as raw_img:
        img = raw_img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
    img_tensor = transform(img).unsqueeze(0)  # Shape: [1, 3, 224, 224]

    # Plain inference: the prediction needs no autograd graph.
    # generate_gradcam enables gradients on its own.
    with torch.no_grad():
        probability = torch.sigmoid(model(img_tensor.to(device))).item()

    print(f"Predicted Probability of Cancer: {probability:.4f}")
    if probability > 0.5:
        print("Diagnosis: Cancer Detected")
    else:
        print("Diagnosis: No Cancer Detected")
    # ---------------------- END OF PREDICTION SECTION ----------------------

    # ---------------------- GRAD-CAM GENERATION ----------------------
    # Specify the target layer. For MobileNetV2, model.mobilenet.features[-1] is typically
    # the last convolutional layer. Adjust it if your model's architecture differs.
    target_layer = model.mobilenet.features[-1]

    cam = generate_gradcam(model, img_tensor, target_layer)

    # Convert the original image to OpenCV format (BGR) for overlaying.
    original_img_cv = cv2.cvtColor(np.array(img.resize((224, 224))), cv2.COLOR_RGB2BGR)
    overlay = overlay_heatmap(original_img_cv, cam)

    # cv2.imwrite does not create directories and only signals failure by
    # returning False, so create the folder and check the result explicitly.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if not cv2.imwrite(args.output, overlay):
        raise OSError(f"Failed to write Grad-CAM overlay to {args.output}")

    # Display the result using matplotlib.
    plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    plt.title("Grad-CAM")
    plt.show()


if __name__ == "__main__":
    main()
