
import numpy as np
from scipy.ndimage import gaussian_filter


def _to_float(x):
    """Return *x* converted to float64.

    Objects that provide ``astype`` (NumPy arrays, including subclasses) are
    converted with ``astype(np.float64)``, which always returns a new object,
    so callers may modify the result without touching their input.
    Plain array-likes without ``astype`` (lists, tuples, Python scalars) are
    converted with ``np.asarray`` into a fresh float64 ndarray.
    """
    if hasattr(x, "astype"):
        return x.astype(np.float64)
    # Lists/tuples/scalars have no ``astype``; build a new float64 array.
    return np.asarray(x, dtype=np.float64)


def _ensure_same_shape(x, y):
    """Raise ``ValueError`` unless *x* and *y* have identical shapes.

    ``np.shape`` reads ``.shape`` when present and otherwise converts the
    argument, so plain array-likes (lists, scalars) are accepted as well as
    arrays. The error message reports both shapes to ease debugging.
    """
    shape_x = np.shape(x)
    shape_y = np.shape(y)
    if shape_x != shape_y:
        raise ValueError(
            f"Input images must have the same shape, got {shape_x} and {shape_y}."
        )


# ======================================================================
#                           BASIC METRICS
# ======================================================================

def mse(x, y):
    """Mean Squared Error between two images.

    Squared differences are averaged over every element, so multi-channel
    inputs are handled by pooling all channels together.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different shapes.
    """
    ref = _to_float(x)
    est = _to_float(y)
    _ensure_same_shape(ref, est)

    residual = ref - est
    return np.mean(residual ** 2)


def psnr(x, y, data_range=255.0):
    """Peak Signal-to-Noise Ratio in decibels.

    PSNR = 20 * log10(data_range / sqrt(MSE)), where ``data_range`` is the
    peak value span of the images (default 255.0, e.g. for 8-bit data).
    Identical images (MSE == 0) yield ``np.inf``.
    """
    err = mse(x, y)
    return np.inf if err == 0 else 20 * np.log10(data_range / np.sqrt(err))


def snr(x, y):
    """Signal-to-Noise Ratio in decibels, with multi-channel support.

    Uses the same definition as ``snr.py``:

        snr = 20 * log10(||x|| / ||x - y||)

    where ``x`` is the reference signal, ``x - y`` is the noise, and
    ``||.||`` is the Euclidean norm over all elements (every channel
    included). Returns ``np.inf`` when ``x`` and ``y`` are identical; if
    ``x`` is all zeros but differs from ``y``, the result is ``-inf``
    (NumPy emits a divide-by-zero warning from ``log10(0)``).

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different shapes.
    """
    signal = _to_float(x)
    estimate = _to_float(y)
    _ensure_same_shape(signal, estimate)

    noise_norm = np.linalg.norm((signal - estimate).ravel())
    signal_norm = np.linalg.norm(signal.ravel())

    if noise_norm == 0:
        return np.inf
    ratio = signal_norm / noise_norm
    return 20 * np.log10(ratio)


# ======================================================================
#                           UQI
# ======================================================================

def uqi_channel(x, y):
    """Universal Quality Index (Wang & Bovik) for a single channel.

    Uses global (whole-channel) statistics rather than a sliding window:

        Q = 4 * cov * mx * my / ((mx**2 + my**2) * (vx + vy))

    Q is the product of a luminance term ``2*mx*my / (mx**2 + my**2)`` and a
    contrast/structure term ``2*cov / (vx + vy)``. When one of those terms is
    0/0 it is treated as 1 (the two channels agree perfectly on that
    component) and Q reduces to the other term; when both are 0/0, Q is 1.
    This keeps e.g. two *different* constant channels from scoring 1.0.

    Parameters
    ----------
    x, y : array_like
        The two channels to compare; converted to float64.

    Returns
    -------
    float
        The quality index; 1.0 for identical channels.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different shapes.
    """
    x = _to_float(x)
    y = _to_float(y)
    # Guard against silent broadcasting in the covariance product below.
    _ensure_same_shape(x, y)

    x_mean = x.mean()
    y_mean = y.mean()

    x_var = np.var(x)
    y_var = np.var(y)

    cov = np.mean((x - x_mean) * (y - y_mean))

    mean_term = x_mean**2 + y_mean**2
    var_term = x_var + y_var

    if mean_term == 0 and var_term == 0:
        # Both channels are identically zero.
        return 1.0
    if var_term == 0:
        # Both channels are constant: only luminance agreement is defined.
        return 2 * x_mean * y_mean / mean_term
    if mean_term == 0:
        # Both channels are zero-mean: only contrast/structure is defined.
        return 2 * cov / var_term

    return 4 * cov * x_mean * y_mean / (mean_term * var_term)


def uqi(x, y):
    """Universal Quality Index with multi-channel support.

    A 2-D input is scored as a single channel. For any other rank the last
    axis is treated as the channel axis: each ``x[..., ch]`` / ``y[..., ch]``
    pair is scored with ``uqi_channel`` and the scores are averaged.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different shapes.
    """
    _ensure_same_shape(x, y)

    if x.ndim != 2:
        channel_count = x.shape[-1]
        scores = [
            uqi_channel(x[..., ch], y[..., ch])
            for ch in range(channel_count)
        ]
        return np.mean(scores)

    return uqi_channel(x, y)


# ======================================================================
#                           SSIM
# ======================================================================

def ssim_channel(x, y, data_range=255.0, K1=0.01, K2=0.03, sigma=1.5):
    """Mean SSIM for a single channel, using Gaussian-weighted local statistics.

    Local means, variances and covariance are obtained with
    ``gaussian_filter`` (standard deviation ``sigma``, default boundary
    handling), with variances computed as E[x*x] - E[x]**2. The SSIM map

        ((2*mu_x*mu_y + C1) * (2*cov_xy + C2))
        / ((mu_x**2 + mu_y**2 + C1) * (var_x + var_y + C2))

    with C1 = (K1*data_range)**2 and C2 = (K2*data_range)**2 is averaged
    over all pixels.
    """
    x = _to_float(x)
    y = _to_float(y)

    # Gaussian-weighted first moments.
    mu_x = gaussian_filter(x, sigma)
    mu_y = gaussian_filter(y, sigma)

    # Second moments minus squared means -> local (co)variances.
    var_x = gaussian_filter(x * x, sigma) - mu_x * mu_x
    var_y = gaussian_filter(y * y, sigma) - mu_y * mu_y
    cov_xy = gaussian_filter(x * y, sigma) - mu_x * mu_y

    c1 = (K1 * data_range)**2
    c2 = (K2 * data_range)**2

    luminance_num = 2 * mu_x * mu_y + c1
    structure_num = 2 * cov_xy + c2
    luminance_den = mu_x**2 + mu_y**2 + c1
    structure_den = var_x + var_y + c2

    ssim_map = (luminance_num * structure_num) / (luminance_den * structure_den)
    return np.mean(ssim_map)


def ssim(x, y, data_range=255.0, K1=0.01, K2=0.03, sigma=1.5):
    """Mean SSIM with multi-channel support.

    A 2-D input is scored as a single channel with ``ssim_channel``. For any
    other rank the last axis is treated as the channel axis and the
    per-channel SSIM values are averaged.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different shapes.
    """
    _ensure_same_shape(x, y)
    settings = dict(data_range=data_range, K1=K1, K2=K2, sigma=sigma)

    if x.ndim != 2:
        scores = [
            ssim_channel(x[..., ch], y[..., ch], **settings)
            for ch in range(x.shape[-1])
        ]
        return np.mean(scores)

    return ssim_channel(x, y, **settings)


# ======================================================================
#                           MSSIM
# ======================================================================

def mssim(x, y, levels=5, data_range=255.0, K1=0.01, K2=0.03, sigma=1.5):
    """Simplified multi-scale SSIM with multi-channel support.

    ``ssim`` is evaluated on a dyadic pyramid: scale k is the input
    decimated by 2**k along the first two (spatial) axes, keeping every other
    row/column with no anti-aliasing filter. The result is the unweighted
    arithmetic mean of the per-scale scores.

    NOTE: this is not the MS-SSIM of Wang, Simoncelli & Bovik (2003), which
    low-pass filters before downsampling, uses only the contrast/structure
    terms at the finer scales and combines scales with a weighted product.
    Scores are therefore not comparable with that metric.

    Parameters
    ----------
    levels : int
        Maximum number of scales (>= 1). Fewer are used once either spatial
        dimension would drop below 2 pixels.
    data_range, K1, K2, sigma
        Forwarded unchanged to ``ssim`` at every scale.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different shapes, or ``levels < 1``
        (previously this silently returned ``nan``).
    """
    _ensure_same_shape(x, y)
    if levels < 1:
        raise ValueError("levels must be >= 1.")

    scores = []

    for _ in range(levels):
        scores.append(
            ssim(x, y, data_range=data_range, K1=K1, K2=K2, sigma=sigma)
        )

        # Decimate both spatial axes; the Ellipsis keeps any trailing
        # (channel) axes intact and is a no-op for 2-D input.
        x = x[::2, ::2, ...]
        y = y[::2, ::2, ...]

        # Do not evaluate SSIM on a degenerate (< 2 px) grid.
        if min(x.shape[0], x.shape[1]) < 2:
            break

    return np.mean(scores)
