#!/usr/bin/env python3
"""
"""
from __future__ import annotations
import numpy as np
from typing import Tuple

# Format identifiers used by save_image; the Netpbm magic number each one
# stands for is listed here:
#   PGM_RAW -> P5, PGM_ASCII -> P2, PPM_RAW -> P6, PPM_ASCII -> P3
PGM_RAW, PGM_ASCII, PPM_RAW, PPM_ASCII = 0, 1, 2, 3

# ------------------------------
# Header parsing helpers
# ------------------------------
def _read_header_tokens(f, n_needed: int):
    """Read header tokens while skipping full-line and inline '#' comments."""
    tokens = []
    while len(tokens) < n_needed:
        line = f.readline()
        if not line:
            break
        # Remove inline comments
        if b'#' in line:
            line = line.split(b'#', 1)[0]
        parts = line.split()
        if parts:
            tokens.extend(parts)
    return tokens

# ------------------------------
# I/O
# ------------------------------
def load_image(file_name: str) -> Tuple[np.ndarray, int, int, int]:
    """Load a PGM/PPM image in plain (P2, P3) or raw (P5, P6) form.

    Returns ``(img, width, height, channels)``.  ``img`` is a ``uint8`` array
    of shape ``(height, width)`` for PGM (``channels == 1``) or
    ``(height, width, 3)`` for PPM (``channels == 3``).  Samples are rescaled
    from ``[0, maxval]`` to ``[0, 255]`` and rounded to the nearest integer;
    out-of-range samples in a malformed file are clamped instead of wrapping.
    Raw files with ``maxval >= 256`` are read as 16-bit big-endian samples.

    Raises:
        ValueError: empty file name, incomplete/malformed/unsupported header,
            or not enough pixel data.
        OSError: the file cannot be opened or read.
    """
    if not file_name:
        raise ValueError("Error: Please specify a filename")

    # magic number -> (samples stored as ASCII text?, channels per pixel)
    formats = {"P2": (True, 1), "P5": (False, 1), "P3": (True, 3), "P6": (False, 3)}

    with open(file_name, 'rb') as f:
        header = _read_header_tokens(f, 4)
        if len(header) < 4:
            raise ValueError("LoadImage Error: Incomplete PNM header")

        # Check the magic number first so a non-PNM file gets a clear error
        # instead of an int() failure on arbitrary header bytes.
        magic = header[0].decode('ascii', errors='replace')
        if magic not in formats:
            raise ValueError("LoadImage Error: Unsupported file type")
        is_ascii, channels = formats[magic]

        try:
            width, height, maxval = (int(tok) for tok in header[1:4])
        except ValueError as err:
            raise ValueError("LoadImage Error: malformed PNM header") from err
        if width < 0 or height < 0:
            raise ValueError("LoadImage Error: invalid image dimensions")
        # PNM samples are 8- or 16-bit, so maxval must lie in 1..65535.
        if not 0 < maxval < 65536:
            raise ValueError("LoadImage Error: invalid maxval")

        nvals = width * height * channels

        if is_ascii:
            # Strip '#' comments per line and join the pieces with a space:
            # removing a trailing comment also removes the newline, which
            # would otherwise glue this line's last number to the next one.
            text = b" ".join(line.split(b'#', 1)[0] for line in f)
            arr = np.fromstring(text.decode('ascii'), sep=' ', dtype=np.int64)
            if arr.size < nvals:
                raise ValueError("LoadImage Error: Not enough pixel data in ASCII file")
            samples = arr[:nvals].astype(np.float64)
        elif maxval < 256:
            raw = f.read(nvals)
            if len(raw) < nvals:
                raise ValueError("LoadImage Error: Not enough pixel data in RAW file (8-bit)")
            samples = np.frombuffer(raw, dtype=np.uint8).astype(np.float64)
        else:
            # 16-bit samples, big-endian per the PNM spec.
            raw = f.read(nvals * 2)
            if len(raw) < nvals * 2:
                raise ValueError("LoadImage Error: Not enough pixel data in RAW file (16-bit)")
            samples = np.frombuffer(raw, dtype='>u2').astype(np.float64)

    # Map [0, maxval] onto [0, 255].  Round rather than truncate so a full-scale
    # sample lands on 255, and clamp so malformed samples cannot wrap in uint8.
    scaled = np.rint(samples * (255.0 / maxval))
    img = np.clip(scaled, 0, 255).astype(np.uint8)

    shape = (height, width) if channels == 1 else (height, width, channels)
    return img.reshape(shape), width, height, channels


def _quantize_u8(values) -> np.ndarray:
    """Round *values* to the nearest integer, clamp to [0, 255], cast to uint8."""
    return np.clip(np.rint(np.asarray(values, dtype=np.float64)), 0, 255).astype(np.uint8)


def save_image(file_name: str, img, fmt: int):
    """Save *img* as an 8-bit (maxval=255) PGM/PPM file.

    *fmt* selects the encoding: PGM_ASCII (P2), PGM_RAW (P5), PPM_ASCII (P3)
    or PPM_RAW (P6).  A gray ``(H, W)`` image (or ``(H, W, 1)``) and a color
    ``(H, W, 3)`` image are both accepted.  Saving color as PGM converts it
    to gray with an unweighted mean of the three channels.  Saving gray as
    PPM replicates the gray value into R, G and B.  Samples are rounded to
    the nearest integer and clamped to [0, 255].  ASCII formats write one
    text line per image row.

    Raises:
        ValueError: empty file name, empty image, unsupported shape, or
            unsupported *fmt*.
        OSError: the file cannot be written.
    """
    if not file_name:
        raise ValueError("Error!! save_image: empty file name")
    if img is None:
        raise ValueError("SaveImage Error: image is empty")

    img = np.asarray(img)
    if img.size == 0:
        raise ValueError("SaveImage Error: image is empty")

    # Squeeze possible singleton channel (H,W,1) -> (H,W)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img.squeeze(axis=2)

    if img.ndim == 2:
        height, width = img.shape
        channels = 1
    elif img.ndim == 3 and img.shape[2] == 3:
        height, width, channels = img.shape
    else:
        raise ValueError(f"SaveImage Error: img must be (H,W) or (H,W,3), got shape {img.shape}")

    # fmt -> (magic number, write a single gray plane?, binary raster?)
    layouts = {
        PGM_ASCII: (b"P2", True, False),
        PGM_RAW: (b"P5", True, True),
        PPM_ASCII: (b"P3", False, False),
        PPM_RAW: (b"P6", False, True),
    }
    if fmt not in layouts:
        raise ValueError("SaveImage Error: Unsupported file format")
    magic, as_gray, is_raw = layouts[fmt]

    # Build the payload before opening the file so a failure here cannot
    # leave a header-only file behind.
    if as_gray:
        data = _quantize_u8(img if channels == 1 else img.mean(axis=2))
    else:
        data = _quantize_u8(img)
        if channels == 1:
            data = np.repeat(data[:, :, np.newaxis], 3, axis=2)

    with open(file_name, 'wb') as f:
        f.write(magic + b"\n")
        f.write(f"{width} {height}\n".encode('ascii'))
        f.write(b"255\n")
        if is_raw:
            # tobytes() emits row-major order: row by row, R G B interleaved.
            f.write(data.tobytes())
        else:
            for row in data.reshape(height, -1):
                f.write((" ".join(map(str, row.tolist())) + "\n").encode('ascii'))


# ------------------------------
# Point operations
# ------------------------------
def invert(img: np.ndarray) -> np.ndarray:
    """Return the negative of *img*: every sample v becomes 255 - v.

    The operation is element-wise, so gray and color images are both
    accepted.  The result dtype follows NumPy's rules for ``255 - img``.
    """
    max_level = 255
    negative = max_level - img
    return negative

def threshold(img: np.ndarray, tvalue: float) -> np.ndarray:
    """Binarize *img* against *tvalue*.

    Every sample strictly greater than *tvalue* becomes 255 and every other
    sample becomes 0.  The test is element-wise, so color images are
    thresholded per channel.  Returns a new uint8 array with the same shape
    as *img*; the input is not modified.
    """
    mask = np.asarray(img) > tvalue
    return np.where(mask, 255, 0).astype(np.uint8)

def rescale(img: np.ndarray) -> np.ndarray:
    """Linearly stretch *img* so its minimum maps to 0 and its maximum to 255.

    The stretch uses the global minimum and maximum over all channels.
    Returns a new float64 array of the same shape.  An empty or constant
    image has no range to stretch and is returned unchanged as a float64 copy.
    """
    data = np.asarray(img, dtype=np.float64)
    if data.size == 0:
        return data.copy()
    lo = data.min()
    hi = data.max()
    if hi == lo:
        return data.copy()
    # Multiply before dividing: for integer-valued input the product is exact,
    # so the maximum maps to exactly 255.0.
    return (data - lo) * 255.0 / (hi - lo)

def average(img: np.ndarray, channel: int = 0) -> float:
    """Return the mean sample value of one channel of *img*.

    For a gray ``(H, W)`` image *channel* is ignored.  For a color
    ``(H, W, C)`` image the mean of ``img[..., channel]`` is returned.
    An empty image yields ``0.0``.
    """
    data = np.asarray(img)
    plane = data[..., channel] if data.ndim == 3 else data
    if plane.size == 0:
        return 0.0
    return float(np.mean(plane, dtype=np.float64))

def contrast(img: np.ndarray, channel: int = 0) -> float:
    """Return the RMS contrast (population standard deviation) of one channel.

    For a gray ``(H, W)`` image *channel* is ignored.  For a color
    ``(H, W, C)`` image the standard deviation of ``img[..., channel]`` is
    returned.  An empty image yields ``0.0``.
    """
    data = np.asarray(img)
    plane = data[..., channel] if data.ndim == 3 else data
    if plane.size == 0:
        return 0.0
    return float(np.std(plane, dtype=np.float64))

def GammaCorrect(img: np.ndarray, gamma: float) -> np.ndarray:
    """Apply gamma correction to an image whose samples lie in [0, 255].

    Computes ``255 * (img / 255) ** gamma`` element-wise.  Samples are
    normalized to [0, 1] first, so 0 and 255 are fixed points and inputs in
    [0, 255] stay in [0, 255].  A gamma below 1 brightens and a gamma above
    1 darkens.  Computing in float64 also avoids the uint8 wrap-around that
    a raw ``img ** gamma`` suffers with an integer gamma.

    Returns a new float64 array with the same shape as *img*.

    Raises:
        ValueError: if *gamma* is not positive.
    """
    if gamma <= 0:
        raise ValueError("GammaCorrect Error: gamma must be positive")
    normalized = np.asarray(img, dtype=np.float64) / 255.0
    return 255.0 * normalized ** gamma

# ------------------------------
# Histogram
# ------------------------------
def ComputeHistogram(img: np.ndarray,
                      channel: int,
                      norm: bool = False) -> np.ndarray:
    """Return the 256-bin histogram of one channel of *img*.

    For a gray ``(H, W)`` image *channel* is ignored.  For a color
    ``(H, W, C)`` image the histogram of ``img[..., channel]`` is computed.
    Samples are rounded to the nearest integer level and clamped to
    [0, 255] before binning.

    Returns an array of length 256 where entry ``v`` counts the samples at
    level ``v``.  The counts are integers, or float64 fractions summing to 1
    when *norm* is true.  An empty image gives all zeros.
    """
    data = np.asarray(img)
    plane = data[..., channel] if data.ndim == 3 else data
    levels = np.clip(np.rint(plane), 0, 255).astype(np.intp)
    hist = np.bincount(levels.ravel(), minlength=256)
    if norm:
        total = hist.sum()
        # Avoid dividing by zero for an empty image.
        return hist / total if total else hist.astype(np.float64)
    return hist

def HistogramEqualization(img: np.ndarray,
                           channel: int = 0) -> np.ndarray:
    """Equalize the histogram of one channel of *img*.

    Uses the cumulative-histogram mapping
    ``v -> round(255 * (cdf(v) - cdf_min) / (N - cdf_min))``.  Here ``N`` is
    the number of samples and ``cdf_min`` is the cumulative count at the
    darkest level present, so that level maps to 0 and the brightest to 255.
    Samples are rounded and clamped to [0, 255] before binning.

    For a gray ``(H, W)`` image the whole image is remapped.  For a color
    ``(H, W, C)`` image only ``img[..., channel]`` is remapped and the other
    channels are copied unchanged.  Returns a new array with the shape and
    dtype of *img*.  An empty or single-level image is returned as an
    unchanged copy.
    """
    out = np.array(img, copy=True)
    plane = out[..., channel] if out.ndim == 3 else out
    levels = np.clip(np.rint(plane), 0, 255).astype(np.intp)
    cdf = np.cumsum(np.bincount(levels.ravel(), minlength=256))
    total = int(cdf[-1])
    if total == 0:
        return out
    cdf_min = int(cdf[levels.min()])
    if total == cdf_min:
        # A single gray level: there is no range to spread out.
        return out
    lut = np.clip(np.rint((cdf - cdf_min) * 255.0 / (total - cdf_min)), 0, 255)
    equalized = lut[levels].astype(out.dtype)
    if out.ndim == 3:
        out[..., channel] = equalized
    else:
        out[...] = equalized
    return out
