import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
import sys
import os
import cv2
import numpy as np
from pathlib import Path

# Project paths, resolved relative to this file rather than the working directory.
BASE_DIR = Path(__file__).resolve().parent
SRC_DIR = BASE_DIR.joinpath("src")
MODEL_PATH = BASE_DIR.joinpath("models", "mobilenet_model.pth")

# Put the src folder on the import path; the two project-local imports below
# depend on this happening first, so the order of these lines matters.
sys.path.append(str(SRC_DIR))
from model import ColonCancerModel
from gradcam import generate_gradcam, overlay_heatmap

# ---------------- PAGE CONFIG ---------------- #
# Browser tab title and full-width ("wide") page layout for the whole app.
st.set_page_config(
    layout="wide",
    page_title="Diagnostic Dashboard",
)

# ---------------- GLOBAL STYLING ---------------- #
# App-wide CSS: gradient page background, the `.title` heading class used by
# the login form, Streamlit button appearance, and rounded input fields.
# NOTE(review): the markup in this file also uses class 'card', which has no
# rule in this stylesheet.
_GLOBAL_CSS = """
<style>

/* Background */
.stApp {
    background: linear-gradient(135deg, #6a85f1, #7f53ac);
}



/* Title */
.title {
    font-size: 32px;
    font-weight: 600;
    color: #2c2c2c;
}

/* Button */
.stButton>button {
    background-color: #5a8dee;
    color: white;
    border-radius: 10px;
    height: 45px;
    width: 100%;
}

/* Input fields */
input {
    border-radius: 8px !important;
}

</style>
"""

# unsafe_allow_html is needed so the <style> element is emitted as raw HTML.
st.markdown(_GLOBAL_CSS, unsafe_allow_html=True)

# ---------------- USERS ---------------- #
# Credential table consulted by the login form: username -> password.
# NOTE(review): passwords are stored in plaintext in source and every account
# shares the same value; move to hashed credentials loaded from configuration
# before using this outside a demo.
users = dict.fromkeys(("akhila", "maithri", "tejaswi", "kowshik"), "1234")

# Initialise the login flag on the first run of a session only, so later
# reruns keep whatever value the login/logout buttons stored.
if "logged_in" not in st.session_state:
    st.session_state["logged_in"] = False


@st.cache_resource
def load_model():
    """Build the classifier, load its saved weights on CPU and return it.

    The function is wrapped in ``st.cache_resource``. The returned network has
    been switched to evaluation mode.

    Returns:
        The ``ColonCancerModel`` instance with the state dict from
        ``MODEL_PATH`` loaded.
    """
    net = ColonCancerModel()
    # map_location="cpu" lets the checkpoint load on machines without a GPU.
    state_dict = torch.load(MODEL_PATH, map_location="cpu")
    net.load_state_dict(state_dict)
    net.eval()
    return net

# ---------------- LOGIN PAGE ---------------- #
def login_page():
    """Render the login screen and check credentials against ``users``.

    The left column shows the product banner and the right column shows the
    username/password form. When the Login button is pressed with a matching
    username and password, ``st.session_state.logged_in`` is set to ``True``
    and the script is rerun; otherwise an error message is displayed.
    """
    # Function-scope import: only the credential check below needs it.
    import hmac

    col1, col2 = st.columns([1.2, 1])

    with col1:
        st.markdown("""
        <div style='padding:60px; color:white;'>
            <h1>Colon Care</h1>
            <p style='font-size:18px;'>
            Colon Cancer Detection Using A Light-weight MobileNetV2 Model with Grad-CAM.
            </p>
        </div>
        """, unsafe_allow_html=True)

    with col2:
        st.markdown("<div class='card'>", unsafe_allow_html=True)

        st.markdown("<div class='title'>Login</div>", unsafe_allow_html=True)

        username = st.text_input("Username")
        password = st.text_input("Password", type="password")

        if st.button("Login"):
            stored_password = users.get(username)
            # Compare with hmac.compare_digest rather than `==` so the check
            # does not short-circuit on the first differing character. Both
            # sides are encoded to bytes because compare_digest only accepts
            # ASCII-only str. The comparison also runs for unknown usernames;
            # its result is then discarded by the `is not None` test, so an
            # unknown user can never log in with an empty password.
            password_ok = hmac.compare_digest(
                (stored_password or "").encode("utf-8"),
                password.encode("utf-8"),
            )
            if stored_password is not None and password_ok:
                st.session_state.logged_in = True
                st.success("Login successful")
                st.rerun()
            else:
                st.error("Invalid credentials")

        st.markdown("</div>", unsafe_allow_html=True)

# ---------------- DASHBOARD ---------------- #
def dashboard():
    """Render the diagnostic dashboard shown after login.

    Layout, top to bottom: a header row with a Logout button, a row with the
    image uploader and the prediction result, and, once an upload has been
    decoded, a row with the original image beside its Grad-CAM overlay.

    The upload is decoded once and the decoded image, input tensor and model
    are reused for the Grad-CAM row. An upload that PIL cannot read is
    reported with ``st.error`` and the Grad-CAM row is skipped.
    """

    # Top bar
    col1, col2 = st.columns([6,1])
    with col1:
        st.markdown("<h2 style='color:white;'>Colon Cancer Diagnostic Dashboard</h2>", unsafe_allow_html=True)
    with col2:
        if st.button("Logout"):
            st.session_state.logged_in = False
            st.rerun()

    st.markdown("<br>", unsafe_allow_html=True)

    # Main container
    col1, col2 = st.columns([1,1])

    # ---------------- LEFT: Upload ---------------- #
    with col1:
        st.markdown("<div class='card'>", unsafe_allow_html=True)

        st.subheader("Upload Image")

        uploaded_file = st.file_uploader("Upload histopathology image", type=["jpg","png","jpeg"])

        st.markdown("</div>", unsafe_allow_html=True)

    # These stay None unless the upload decodes; the Grad-CAM row below is
    # gated on `img` so it never touches undefined names.
    img = None
    img_tensor = None
    model = None

    # ---------------- RIGHT: Results ---------------- #
    with col2:
        st.markdown("<div class='card'>", unsafe_allow_html=True)

        st.subheader("Prediction Result")

        if not uploaded_file:
            st.info("Upload an image to see results")
        else:
            try:
                img = Image.open(uploaded_file).convert("RGB")
            except OSError:
                # The uploader filters by file extension only, so the bytes
                # may not be a readable image. PIL reports unidentified or
                # truncated image data as OSError (or a subclass of it).
                st.error("Could not read the uploaded file as an image")
            else:
                transform = transforms.Compose([
                    transforms.Resize((224,224)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485,0.456,0.406],
                                         [0.229,0.224,0.225])
                ])

                # unsqueeze(0) adds the leading batch dimension.
                img_tensor = transform(img).unsqueeze(0)

                # Load model
                model = load_model()

                with torch.no_grad():
                    output = model(img_tensor)
                    prob = torch.sigmoid(output).item()

                st.write(f"Probability: {prob:.4f}")

                if prob > 0.5:
                    st.error("Cancer Detected")
                else:
                    st.success("No Cancer Detected")

        st.markdown("</div>", unsafe_allow_html=True)

    # ---------------- FULL WIDTH: IMAGE + GRADCAM ---------------- #
    if img is not None:
        st.markdown("<br>", unsafe_allow_html=True)

        col1, col2 = st.columns(2)

        with col1:
            st.markdown("<div class='card'>", unsafe_allow_html=True)
            st.subheader("Original Image")
            st.image(img, width=300)
            st.markdown("</div>", unsafe_allow_html=True)

        with col2:
            st.markdown("<div class='card'>", unsafe_allow_html=True)
            st.subheader("Grad-CAM Visualization")

            target_layer = model.mobilenet.features[-1]
            cam = generate_gradcam(model, img_tensor, target_layer)

            # The image is converted to BGR before overlay_heatmap and the
            # overlay is converted back to RGB for display.
            original = cv2.cvtColor(np.array(img.resize((224,224))), cv2.COLOR_RGB2BGR)
            overlay = overlay_heatmap(original, cam)

            st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB), width=300)

            st.markdown("</div>", unsafe_allow_html=True)

# ---------------- APP FLOW ---------------- #
# Pick the page to render from the session's login flag, then render it.
active_page = dashboard if st.session_state.logged_in else login_page
active_page()
