#!/usr/bin/env python3
import argparse
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import imageio.v3 as iio

from MI_image import LoadImage

def read_list_file(list_path: str):
    """Parse an image-list file.

    The file is treated as a flat sequence of whitespace-separated tokens
    (line breaks and blank lines are irrelevant): the first token is the
    number of images ``nb``, and the next ``nb`` tokens are the image paths.
    Any tokens after those are ignored.

    Args:
        list_path: Path to the list file (read as UTF-8 text).

    Returns:
        A tuple ``(nb, paths)`` where ``paths`` is a list of exactly ``nb``
        path strings.

    Raises:
        RuntimeError: If the file contains no tokens, declares a negative
            image count, or lists fewer paths than declared.
        ValueError: If the first token is not an integer.
    """
    with open(list_path, 'r', encoding='utf-8') as f:
        # str.split() with no argument splits on any run of whitespace and
        # drops empty strings, so blank lines / mixed spacing need no care.
        tokens = f.read().split()

    if not tokens:
        raise RuntimeError('Empty list file')

    try:
        nb = int(tokens[0])
    except ValueError as err:
        raise ValueError(
            f'Invalid image count {tokens[0]!r} in list file {list_path!r}'
        ) from err

    # Without this check a negative count slices tokens[1:1+nb] to an empty
    # (or truncated) list and is silently accepted.
    if nb < 0:
        raise RuntimeError(f'Negative image count in list file: {nb}')

    paths = tokens[1:1 + nb]
    if len(paths) < nb:
        raise RuntimeError('List file has fewer paths than declared')
    return nb, paths


def _to_grayscale(arr):
    """Return an image array as a 2-D float32 grayscale image.

    A 2-D array is kept as is. For an (H, W, 1) array the singleton channel
    is dropped. An (H, W, 3) array is converted with the 0.299 / 0.587 / 0.114
    luma weights.

    Raises:
        ValueError: For any other shape.
    """
    arr = np.asarray(arr)
    if arr.ndim == 3:
        channels = arr.shape[2]
        if channels == 1:
            arr = arr[:, :, 0]
        elif channels == 3:
            arr = (0.299 * arr[:, :, 0]
                   + 0.587 * arr[:, :, 1]
                   + 0.114 * arr[:, :, 2])
        else:
            raise ValueError(f"Unsupported channel count: {arr.shape}")
    # Anything that is still not 2-D (e.g. 1-D or 4-D data) cannot be shown
    # as a single grayscale frame.
    if arr.ndim != 2:
        raise ValueError(f"Unsupported image shape: {arr.shape}")
    return arr.astype(np.float32)


def temporal_median_pipeline(list_path: str, fps: int = 20):
    """Load the frames listed in ``list_path`` and play them as an animation.

    NOTE(review): despite its name, this function does not compute a
    temporal median. It only loads the frames, converts them to grayscale
    and displays them in a matplotlib window (blocking in ``plt.show()``).

    Args:
        list_path: List file in the format read by ``read_list_file``
            (image count followed by that many image paths).
        fps: Playback rate in frames per second; must be positive.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object. Keep a reference to
        it: matplotlib stops an animation whose object is garbage-collected.

    Raises:
        ValueError: If ``fps`` is not positive, the list contains no images,
            an image has an unsupported shape, or the frames differ in size.
    """
    # Validate before the (possibly slow) image loading instead of failing
    # with ZeroDivisionError at the very end.
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # ---- Read list of image paths ----
    nb, paths = read_list_file(list_path)
    print(f"Found {nb} image paths")
    if not paths:
        raise ValueError("List file contains no images")

    # ---- Load frames with your own PNM reader, as 2-D grayscale ----
    frames = []
    for p in paths:
        arr = _to_grayscale(LoadImage(p)["data"])
        # np.stack needs identical shapes; report which file breaks that.
        if frames and arr.shape != frames[0].shape:
            raise ValueError(
                f"Frame {p!r} has shape {arr.shape}, expected {frames[0].shape}"
            )
        frames.append(arr)

    # ---- Stack into (T, H, W) ----
    frames = np.stack(frames, axis=0)
    print(f"Loaded frames shape = {frames.shape}")

    # ---- Fix intensity window for stable brightness across frames ----
    vmin, vmax = np.percentile(frames, (1, 99))

    # ---- Matplotlib viewer setup ----
    fig, ax = plt.subplots()
    im = ax.imshow(frames[0], cmap="gray", vmin=vmin, vmax=vmax, animated=True)
    ax.set_axis_off()

    # ---- Animation update ----
    def update(i):
        im.set_array(frames[i])
        return (im,)

    ani = animation.FuncAnimation(
        fig,
        update,
        frames=frames.shape[0],
        # Milliseconds between frames; clamp so a very high fps never gives 0.
        interval=max(1, round(1000 / fps)),
        blit=True,
    )

    plt.show()
    return ani
    
# -------- CLI --------
def make_parser():
    """Build the command-line parser for the TP2 tools.

    One sub-command is mandatory; its name is stored in ``args.cmd``.
    The sub-commands are registered in this order:

    * ``median``: ``--list`` (required), ``--median-out``,
      ``--frame-prefix``, ``--threshold`` (float).
    * ``warp``: ``--input`` (required), ``--output0``, ``--output1``.
    * ``snr``: ``--input`` (required), ``-N`` (int), ``--noise-type``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='TP2 using MI_image_numpy.py')
    commands = parser.add_subparsers(dest='cmd', required=True)

    # Each entry: (sub-command name, help text, [(flag, add_argument kwargs)]).
    command_specs = (
        ('median', 'Temporal median pipeline', [
            ('--list', {'required': True,
                        'help': 'Text file: first token = nbImages, then nbImages paths'}),
            ('--median-out', {'default': 'outTemporalMedianImage.ppm',
                              'help': 'Output PPM for median image'}),
            ('--frame-prefix', {'default': 'outVideoFile_',
                                'help': 'Prefix for per-frame output PGM files'}),
            ('--threshold', {'type': float, 'default': 25.0,
                             'help': 'Threshold value (default: 25)'}),
        ]),
        ('warp', 'Wave warping (nearest and bilinear)', [
            ('--input', {'required': True, 'help': 'Input image (PGM/PPM)'}),
            ('--output0', {'default': 'outWarping0.ppm',
                           'help': 'Output PPM for nearest neighbor warping'}),
            ('--output1', {'default': 'outWarping1.ppm',
                           'help': 'Output PPM for bilinear warping'}),
        ]),
        ('snr', 'Signal to noise ratio experiment', [
            ('--input', {'required': True, 'help': 'Input image (PGM/PPM)'}),
            ('-N', {'type': int, 'default': 64,
                    'help': 'Number of noisy images to average (default: 64)'}),
            ('--noise-type', {'choices': ['gaussian', 'uniform', 'saltpepper'],
                              'default': 'gaussian',
                              'help': 'Type of noise to add'}),
        ]),
    )

    for command, summary, options in command_specs:
        subparser = commands.add_parser(command, help=summary)
        for flag, kwargs in options:
            subparser.add_argument(flag, **kwargs)
    return parser

def main():
    """Parse the command line and run the selected sub-command.

    ``median`` runs ``temporal_median_pipeline``. ``warp`` and ``snr`` call
    ``warp_both`` and ``snr``, which are looked up at call time. If either is
    not defined in this module, the command exits with an argparse usage
    error instead of crashing with a ``NameError`` traceback.
    """
    parser = make_parser()
    args = parser.parse_args()

    def resolve(func_name):
        # Look the implementation up lazily so a missing function yields a
        # clean CLI error (parser.error prints usage and exits with code 2).
        func = globals().get(func_name)
        if func is None:
            parser.error(f"command '{args.cmd}' is not available: "
                         f"{func_name}() is not defined")
        return func

    if args.cmd == 'warp':
        resolve('warp_both')(args.input, args.output0, args.output1)
    elif args.cmd == 'median':
        # NOTE(review): --median-out, --frame-prefix and --threshold are
        # parsed but not forwarded; temporal_median_pipeline does not take them.
        temporal_median_pipeline(args.list)
    elif args.cmd == 'snr':
        resolve('snr')(args.input, args.N, args.noise_type)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
