#!/usr/bin/env python3
"""
Sleep Staging with AnySleep – Automatic EEG-based sleep phase classification.

Uses the pre-trained AnySleep model to classify 30-second epochs of a full-night
EEG recording into Wake, N1, N2, N3, or REM sleep stages. Outputs a visualization
with probability panel, hypnogram, and EEG spectrogram.

Usage:
  python plot_sleep_scoring.py recording.EDF --channel 0
  python plot_sleep_scoring.py recording.EDF --channel 0 --start 3600 --duration 7200
"""

import argparse
import struct
import urllib.request
from fractions import Fraction
from pathlib import Path

import numpy as np
import scipy.signal
import scipy.special
import scipy.ndimage
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from anysleep_no_hydra import AnySleep


# Dark theme colors
DARK = '#0d1c2b'    # figure/axes background, marker edge color
WHITE = '#ffffff'   # stats-panel section headings
TEXT = '#aabbcc'    # axis labels, tick labels, titles
GRID = '#1a2f42'    # grid lines
ORANGE = '#ff9151'
GREEN = '#7be0ad'
BLUE = '#abd5ff'
CYAN = '#4dd4f4'

# Stage information (Wake=0, N1=1, N2=2, N3=3, REM=4).
# The index is the predicted class index (argmax over the model output) and
# selects the matching entry in both STAGE_NAMES and STAGE_COLORS.
STAGE_NAMES = ['Wake', 'N1', 'N2', 'N3', 'REM']
STAGE_COLORS = [
    '#abd5ff',  # Wake
    '#fff4a8',  # N1
    '#ffc8a8',  # N2
    '#ff9151',  # N3
    '#d151ff',  # REM
]

# Hypnogram Y-axis: traditional order — Wake top (4), deep sleep bottom (0).
# STAGE_TO_HYPO_Y maps stage index -> y position; HYPO_Y_LABELS[y] is the tick
# label drawn at that y (REM sits between N1 and Wake).
STAGE_TO_HYPO_Y = {0: 4, 1: 2, 2: 1, 3: 0, 4: 3}
HYPO_Y_LABELS = ['N3', 'N2', 'N1', 'REM', 'WAKE']

# Model weights auto-download (fetched by ensure_model when the local file is missing)
MODEL_URL = 'https://media.githubusercontent.com/media/dslaborg/AnySleep/main/models/anysleep-run1.pth'


# ── EDF reader ────────────────────────────────────────────────────────────────

def read_edf_channel(filepath, channel_idx):
    """Read one ADS channel (0–7) from a Traumschreiber EDF file.

    Header fields are parsed leniently: a missing/zero header size falls back
    to 3072 bytes, a missing/zero signal count to 11, an unparsable record
    duration to 1.0 s, and a blank/unparsable record count is treated as
    "unknown" (-1). Only complete data records are read; a truncated trailing
    record is ignored. Samples are returned as raw digital values — no
    physical (µV) scaling is applied.

    Args:
        filepath: Path to the EDF file.
        channel_idx: Zero-based signal index; must be < min(num_signals, 8).

    Returns:
        (signal, sample_rate): ``signal`` is a 1-D float64 array holding the
        channel's samples concatenated over all records; ``sample_rate`` is
        that channel's samples-per-record divided by the record duration
        (250.0 Hz if the duration is not positive).

    Raises:
        ValueError: file shorter than the 256-byte fixed header, invalid
            channel index, or a header declaring 0 samples per record (in
            total or for the requested channel).
    """
    raw = Path(filepath).read_bytes()
    if len(raw) < 256:
        raise ValueError('File too small – not a valid EDF.')

    hdr = raw[:256].decode('latin-1')

    def _i(s, default=0):
        try:
            return int(s.strip())
        except ValueError:
            return default

    def _f(s):
        try:
            return float(s.strip())
        except ValueError:
            return 1.0

    header_bytes    = _i(hdr[184:192]) or 3072
    # -1 = unknown (e.g. still recording). A blank/unparsable field is also
    # treated as unknown instead of "0 records", so the data is still read.
    num_records     = _i(hdr[236:244], default=-1)
    record_duration = _f(hdr[244:252])      # seconds per record
    num_signals     = _i(hdr[252:256]) or 11

    n_eeg = min(num_signals, 8)
    if not (0 <= channel_idx < n_eeg):
        raise ValueError(
            f'Channel {channel_idx} invalid. Available EEG channels: '
            f'0–{n_eeg - 1}.'
        )

    # EDF signal headers are stored as grouped fields across all signals:
    # [all labels (16 bytes each)] [all transducers (80 bytes each)] ...
    sig_hdr = raw[256:header_bytes].decode('latin-1')
    field_widths = [16, 80, 8, 8, 8, 8, 8, 80, 8, 32]
    field_names  = ['label', 'transducer', 'dim', 'phys_min', 'phys_max',
                    'dig_min', 'dig_max', 'prefilter', 'num_samples', 'reserved']
    sigs = [{} for _ in range(num_signals)]
    off = 0
    for fw, fn in zip(field_widths, field_names):
        for i in range(num_signals):
            sigs[i][fn] = sig_hdr[off:off + fw].strip()
            off += fw

    # Samples per data record for each signal (unparsable -> 0)
    ns = [_i(s['num_samples']) for s in sigs]

    record_bytes = 2 * sum(ns)              # int16 samples
    if record_bytes == 0:
        raise ValueError('EDF header reports 0 samples/record – file corrupt?')

    n_ch = ns[channel_idx]
    if n_ch <= 0:
        raise ValueError(
            f'Channel {channel_idx} reports {n_ch} samples/record – file corrupt?'
        )

    # Rate of the *requested* channel: EDF allows each signal its own
    # samples-per-record, so ns[0] is only right when all rates match.
    sample_rate = n_ch / record_duration if record_duration > 0 else 250.0

    # Number of complete records present, capped by the header count if known.
    n_rec = max(0, (len(raw) - header_bytes) // record_bytes)
    if num_records >= 0:
        n_rec = min(n_rec, num_records)
    if n_rec == 0:
        return np.zeros(0, dtype=np.float64), float(sample_rate)

    # View the data section as (records, int16 words per record) and slice out
    # this channel's contiguous run of samples inside every record.
    words = record_bytes // 2
    data = np.frombuffer(raw, dtype='<i2', count=n_rec * words, offset=header_bytes)
    first = sum(ns[:channel_idx])
    chan = data.reshape(n_rec, words)[:, first:first + n_ch]
    return chan.astype(np.float64).ravel(), float(sample_rate)


# ── Model helpers ─────────────────────────────────────────────────────────────

def ensure_model(model_path):
    """Return the path to the AnySleep weights, downloading them if missing.

    The download is written to a sibling ``<name>.part`` file and renamed to
    ``model_path`` only after it completes. An interrupted or failed download
    therefore never leaves a truncated file at ``model_path``, which the
    ``exists()`` check would otherwise accept on the next run.

    Args:
        model_path: Destination path (str or Path) of the weights file.

    Returns:
        ``str(model_path)``.

    Raises:
        RuntimeError: if the download fails (the original error is chained).
    """
    model_path = Path(model_path)
    if model_path.exists():
        return str(model_path)

    print(f'Model not found: {model_path}')
    print('Downloading AnySleep model (~12 MB)...')
    tmp_path = model_path.with_name(model_path.name + '.part')
    try:
        model_path.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(MODEL_URL, tmp_path)
        tmp_path.replace(model_path)
    except Exception as e:
        # Best-effort cleanup of the partial file; the download error is what
        # the caller needs to see, so a cleanup failure must not mask it.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
        raise RuntimeError(f'Download failed: {e}\nManually download from:\n{MODEL_URL}') from e
    print(f'✓ Model saved to {model_path}')
    return str(model_path)


# ── Signal preprocessing ──────────────────────────────────────────────────────

def preprocess(signal, original_fs, verbose=False):
    """Filter, resample to 128 Hz, clip, robust scale — for AnySleep input.

    AnySleep was trained on clinical PSG signals, which are hardware-filtered
    by the amplifier. The Traumschreiber records broadband — without explicit
    filtering, DC drift and powerline noise dominate the RobustScaler statistics
    and the model collapses to WAKE. We apply a clinical-style bandpass first.

    Both outputs are resampled by the exact rational factor 128 / original_fs,
    so non-integer input rates yield exactly 128 Hz (integer rates give the
    same result as resample_poly(x, 128, fs)). The 50 Hz notch is applied only
    when 50 Hz lies below the input Nyquist frequency, and the bandpass upper
    edge is lowered to 0.45·fs for inputs slower than ~78 Hz.

    Args:
        signal: 1-D raw samples.
        original_fs: sampling rate of ``signal`` in Hz.
        verbose: print median / IQR / range at each stage.

    Returns (sig_model, sig_spec):
      sig_model : notch + BP 0.3–35 Hz + resample + clip ±20·IQR
                  + (x - median) / IQR                          → model input
      sig_spec  : notch + HP 0.3 Hz   + resample                → spectrogram
    """
    fs = float(original_fs)
    sig = np.asarray(signal, dtype=np.float64)

    def _iqr(x):
        return float(np.subtract(*np.percentile(x, [75, 25])))

    if verbose:
        print(f'  raw    : median={np.median(sig):+9.1f}  iqr={_iqr(sig):.1f}  range=[{sig.min():.0f}, {sig.max():.0f}]')

    # Rational resampling factor 128/fs. int(fs) would truncate non-integer
    # rates and stretch/shrink the output, misaligning the fixed
    # 3840-sample (30 s @ 128 Hz) epochs downstream.
    ratio = Fraction(128) / Fraction(fs).limit_denominator(1000)
    up, down = ratio.numerator, ratio.denominator

    # 50 Hz notch (Europe powerline) — applied to both paths. Skipped when
    # 50 Hz is at/above Nyquist: there is nothing to remove and iirnotch
    # rejects such a frequency.
    if 50.0 < fs / 2:
        b_notch, a_notch = scipy.signal.iirnotch(50.0, 30.0, fs=fs)
        sig_n = scipy.signal.sosfiltfilt(scipy.signal.tf2sos(b_notch, a_notch), sig)
    else:
        sig_n = sig

    # ── Spectrogram path: HP 0.3 Hz only (keeps content up to Nyquist) ──────
    sos_hp = scipy.signal.butter(4, 0.3, btype='highpass', fs=fs, output='sos')
    sig_hp = scipy.signal.sosfiltfilt(sos_hp, sig_n)
    sig_spec = scipy.signal.resample_poly(sig_hp, up, down)

    # ── Model path: full BP 0.3–35 Hz to match training distribution ────────
    # Upper edge must stay below Nyquist; unchanged (35 Hz) for fs >= ~78 Hz.
    f_hi = min(35.0, 0.45 * fs)
    sos_bp = scipy.signal.butter(4, [0.3, f_hi], btype='bandpass', fs=fs, output='sos')
    sig_bp = scipy.signal.sosfiltfilt(sos_bp, sig_n)

    if verbose:
        print(f'  filt   : median={np.median(sig_bp):+9.3f}  iqr={_iqr(sig_bp):.3f}  range=[{sig_bp.min():.1f}, {sig_bp.max():.1f}]')

    sig_m = scipy.signal.resample_poly(sig_bp, up, down)

    # Clip outliers to ±20·IQR (bounds are symmetric about 0, the expected
    # centre after the 0.3 Hz high-pass), then robust-scale with statistics
    # recomputed on the clipped signal.
    iqr = _iqr(sig_m)
    if iqr > 0:
        sig_m = np.clip(sig_m, -20 * iqr, 20 * iqr)

    median = float(np.median(sig_m))
    iqr = _iqr(sig_m)
    sig_m = (sig_m - median) / max(iqr, 1e-8)

    if verbose:
        print(f'  scaled : median={np.median(sig_m):+9.3f}  iqr={_iqr(sig_m):.3f}  range=[{sig_m.min():.2f}, {sig_m.max():.2f}]')

    return sig_m, sig_spec


def run_inference(signal, fs, model_path, device='cpu', verbose=True):
    """Stage every complete 30-s epoch of ``signal`` with the AnySleep model.

    Args:
        signal: 1-D raw EEG samples.
        fs: sampling rate of ``signal`` in Hz.
        model_path: path to the AnySleep weights file.
        device: torch device string used for the model and the input tensor.
        verbose: forwarded to preprocess() to print signal statistics.

    Returns:
        (preds, probs, n_epochs, sig_128hz):
          preds     – argmax stage index per epoch
          probs     – softmax of the model output over axis 1 (per epoch)
          n_epochs  – number of complete 30-s epochs at 128 Hz (remainder dropped)
          sig_128hz – spectrogram-path signal from preprocess(), trimmed to
                      n_epochs * 3840 samples

    Raises:
        ValueError: if the resampled signal holds less than one 30-s epoch.
    """
    model_input, spec_signal = preprocess(signal, fs, verbose=verbose)

    samples_per_epoch = 30 * 128          # 3840 samples = 30 s at 128 Hz
    n_epochs = len(model_input) // samples_per_epoch
    if n_epochs == 0:
        raise ValueError('Signal too short for even one 30-second epoch.')
    n_used = n_epochs * samples_per_epoch

    # NOTE(review): sleep_stage_frequency=1 presumably requests one prediction
    # per 30-s epoch — confirm against the AnySleep implementation.
    net = AnySleep(path=model_path, sleep_stage_frequency=1).to(device).eval()

    # One recording, one channel: (batch=1, samples, channels=1)
    batch = torch.from_numpy(model_input[:n_used].reshape(1, -1, 1)).float().to(device)
    with torch.no_grad():
        out = net(batch)                  # expected (1, n_epochs, 5) — TODO confirm

    scores = out.cpu().numpy()[0]
    probs = scipy.special.softmax(scores, axis=1)
    preds = probs.argmax(axis=1)
    return preds, probs, n_epochs, spec_signal[:n_used]


# ── Plotting ──────────────────────────────────────────────────────────────────

def compute_spectrogram(signal, fs=128, fmax=30, median_t=31.0, median_f=3.0):
    """Log10 power spectral density on a 4-s / 75 %-overlap grid, median-smoothed.

    With the default fs=128 the grid is 1 s per time bin and 0.25 Hz per
    frequency bin. Frequencies above ``fmax`` are dropped. When both
    ``median_t`` (seconds) and ``median_f`` (Hz) are positive, a 2-D median
    filter of roughly that extent is applied; its size is rounded to an odd
    number of bins, with at least 3 time bins and 1 frequency bin.

    Returns:
        (f, t, logP): frequency bins, time bins (s), and log10 power of shape
        (len(f), len(t)).
    """
    seg_len = int(fs * 4)                 # 4-second window
    freqs, times, power = scipy.signal.spectrogram(
        signal, fs=fs, nperseg=seg_len, noverlap=int(seg_len * 0.75),
        scaling='density',
    )
    in_band = freqs <= fmax
    freqs = freqs[in_band]
    log_power = np.log10(power[in_band, :] + 1e-20)

    if not (median_t > 0 and median_f > 0):
        return freqs, times, log_power

    step_t = float(times[1] - times[0]) if times.size > 1 else 1.0
    step_f = float(freqs[1] - freqs[0]) if freqs.size > 1 else 0.25

    def _odd_bins(extent, step):
        k = int(round(extent / step))
        return k if k % 2 else k + 1

    # Median filtering keeps stage-transition edges sharper than a mean blur.
    kernel = (max(1, _odd_bins(median_f, step_f)),
              max(3, _odd_bins(median_t, step_t)))
    log_power = scipy.ndimage.median_filter(log_power, size=kernel, mode='reflect')
    return freqs, times, log_power


def _fmt_hm(seconds):
    """Format duration as 'H:MM' (e.g. '1:23')."""
    if seconds is None:
        return '—'
    h = int(seconds // 3600)
    m = int(round((seconds % 3600) / 60))
    if m == 60:
        h += 1
        m = 0
    return f'{h}:{m:02d}'


# Normative values for healthy adults (~25–45 y)
# Sources: Ohayon et al. 2004 (meta-analysis); Boulos et al. 2019 (Lancet
# Respir Med); AASM scoring manual. Means and SDs are pooled across studies.
#
# Per-metric fields:
#   mean, sd : population mean and standard deviation, in the display unit
#   unit     : display unit — 'h', 'min', '%', or '' for a plain count
#   in_s     : True when the raw statistic is in seconds and must be divided
#              into `unit` ('h' or 'min') before comparison (_to_norm_unit)
#   lo, hi   : optional hard bounds; the mini boxplot's mean ± 2.5·SD range
#              is clamped to them (_draw_minibox)
# Keys are the metric keys of the stats panel plus the names in STAGE_NAMES.
NORMS = {
    'TIB':  {'mean': 8.0,  'sd': 1.0,  'unit': 'h',   'in_s': True,  'lo': 0},
    'TST':  {'mean': 7.0,  'sd': 1.0,  'unit': 'h',   'in_s': True,  'lo': 0},
    'SOL':  {'mean': 15.0, 'sd': 12.0, 'unit': 'min', 'in_s': True,  'lo': 0},
    'WASO': {'mean': 30.0, 'sd': 25.0, 'unit': 'min', 'in_s': True,  'lo': 0},
    'SE':   {'mean': 87.0, 'sd': 8.0,  'unit': '%',   'in_s': False, 'lo': 0, 'hi': 100},
    'RL':   {'mean': 85.0, 'sd': 35.0, 'unit': 'min', 'in_s': True,  'lo': 0},
    'AW':   {'mean': 8.0,  'sd': 5.0,  'unit': '',    'in_s': False, 'lo': 0},
    # Stage proportions (% of TIB for Wake, % of TST for sleep stages)
    'Wake': {'mean': 10.0, 'sd': 6.0,  'unit': '%',   'in_s': False, 'lo': 0, 'hi': 100},
    'N1':   {'mean': 6.0,  'sd': 3.0,  'unit': '%',   'in_s': False, 'lo': 0, 'hi': 100},
    'N2':   {'mean': 50.0, 'sd': 7.0,  'unit': '%',   'in_s': False, 'lo': 0, 'hi': 100},
    'N3':   {'mean': 17.0, 'sd': 8.0,  'unit': '%',   'in_s': False, 'lo': 0, 'hi': 100},
    'REM':  {'mean': 22.0, 'sd': 5.0,  'unit': '%',   'in_s': False, 'lo': 0, 'hi': 100},
}


def _to_norm_unit(value, key):
    """Convert raw value (seconds, %, count) to the norm's display unit."""
    if value is None or key not in NORMS:
        return None
    n = NORMS[key]
    if n['unit'] == 'h':
        return value / 3600.0 if n['in_s'] else value
    if n['unit'] == 'min':
        return value / 60.0 if n['in_s'] else value
    return value  # %, count


def _fmt_norm(key):
    """Render the norm for ``key`` as 'mean±sd unit' with whole-number precision.

    Only the units 'h', 'min' and '%' get a suffix; anything else (e.g. the
    unitless awakening count) is shown as bare numbers.
    """
    spec = NORMS[key]
    core = f"{spec['mean']:.0f}±{spec['sd']:.0f}"
    if spec['unit'] in ('h', 'min', '%'):
        return f"{core} {spec['unit']}"
    return core


def _draw_minibox(ax, x0, x1, y, value_raw, key, h_box=0.022):
    """Draw a one-row horizontal boxplot of the population norm plus the user's value.

    All coordinates are axes fractions (``ax.transAxes``); the row spans x0..x1
    at height y. The value axis maps [lo, hi] onto [x0, x1], where lo/hi are
    mean ∓ 2.5·SD clamped to the norm's optional 'lo'/'hi' bounds.

      Box     : mean ± 1·SD (typical range, ~68 % of the population)
      Whisker : the clamped mean ± 2.5·SD range (~99 %)
      Tick    : population mean
      Marker  : the measured value — green within 1·SD, yellow within 2.5·SD,
                red beyond. Values outside [lo, hi] are pinned to the nearest
                end and flagged with a red ◀/▶ arrow.

    Nothing is drawn when ``value_raw`` is None or ``key`` has no norm.
    """
    if value_raw is None or key not in NORMS:
        return
    spec = NORMS[key]
    value = _to_norm_unit(value_raw, key)
    mu, sigma = spec['mean'], spec['sd']

    lo = max(spec.get('lo', -1e9), mu - 2.5 * sigma)
    hi = min(spec.get('hi',  1e9), mu + 2.5 * sigma)
    span = x1 - x0

    def to_x(v):
        # value -> axes x, pinned to the drawn range
        frac = np.clip((v - lo) / (hi - lo), 0, 1)
        return x0 + frac * span

    whisker = '#90a5b8'
    bad = '#ff4466'
    axes_kw = dict(transform=ax.transAxes, clip_on=False)

    # Whisker line and its end caps
    ax.plot([x0, x1], [y, y], color=whisker, lw=1.1,
            solid_capstyle='butt', **axes_kw)
    for end in (x0, x1):
        ax.plot([end, end], [y - 0.017, y + 0.017], color=whisker, lw=1.1,
                **axes_kw)

    # Box covering mean ± 1·SD
    left, right = to_x(mu - sigma), to_x(mu + sigma)
    ax.add_patch(plt.Rectangle((left, y - h_box), right - left, 2 * h_box,
                               facecolor='#7a9ab0', alpha=0.55,
                               edgecolor=whisker, linewidth=0.8, **axes_kw))
    # Mean tick
    ax.plot([to_x(mu)] * 2, [y - 1.35 * h_box, y + 1.35 * h_box],
            color='#ffffff', lw=1.8, **axes_kw)

    # Measured value — three-tier color
    dev = abs(value - mu)
    if dev <= sigma:
        dot = '#4ade80'
    elif dev <= 2.5 * sigma:
        dot = '#ffe851'
    else:
        dot = bad
    ax.plot(to_x(value), y, 'o', markersize=9, color=dot,
            markeredgecolor=DARK, markeredgewidth=1.4, zorder=10, **axes_kw)

    above = value > hi
    if above or value < lo:
        ax.text(x1 + 0.008 if above else x0 - 0.008, y,
                '▶' if above else '◀',
                color=bad, fontsize=9, fontweight='800',
                transform=ax.transAxes, va='center',
                ha='left' if above else 'right')


def compute_sleep_stats(preds, epoch_s=30):
    """Classical sleep architecture metrics from epoch-level predictions.

    Conventions (stage index i is STAGE_NAMES[i]; 0 = Wake, 4 = REM):
      - TIB          = all epochs
      - Sleep onset  = first non-Wake epoch
      - SOL          = time before sleep onset
      - TST          = non-Wake epochs (all of them lie at/after onset)
      - WASO         = Wake epochs after sleep onset
      - SE           = TST / TIB
      - REM latency  = onset → first REM epoch (None if there is no REM)
      - Awakenings   = sleep→wake transitions after onset
      - Stage %      = % of TST for N1/N2/N3/REM; % of TIB for Wake

    A recording without any sleep epoch (including an empty one) reports
    TST, WASO, SE and awakenings as 0, SOL = TIB and REM latency None. Its
    per-stage times/percentages are still computed from the predictions, so
    an all-Wake night shows Wake = TIB = 100 %.

    Args:
        preds: sequence of integer stage indices, one per epoch.
        epoch_s: epoch length in seconds.

    Returns:
        dict with 'tib_s', 'tst_s', 'sol_s', 'waso_s', 'rem_lat_s' (seconds;
        'rem_lat_s' may be None), 'se_pct', 'n_awakenings', 'stage_times'
        (stage name -> seconds) and 'stage_pct' (stage name -> percent).
    """
    preds = np.asarray(preds)
    n_epochs = len(preds)
    tib_s = n_epochs * epoch_s

    sleep_mask = preds != 0          # 0 = Wake
    tst_epochs = int(np.sum(sleep_mask))

    # Computed for every case, so sleep-free recordings still report their
    # Wake time instead of 0.
    stage_times, stage_pct = {}, {}
    for i, name in enumerate(STAGE_NAMES):
        count = int(np.sum(preds == i))
        stage_times[name] = count * epoch_s
        denom = n_epochs if name == 'Wake' else tst_epochs
        stage_pct[name] = 100.0 * count / max(denom, 1)   # guard empty / no-sleep

    if tst_epochs == 0:
        return {
            'tib_s': tib_s, 'tst_s': 0, 'sol_s': tib_s,
            'waso_s': 0, 'se_pct': 0.0, 'rem_lat_s': None,
            'n_awakenings': 0,
            'stage_times': stage_times, 'stage_pct': stage_pct,
        }

    onset_idx = int(np.argmax(sleep_mask))
    sol_s = onset_idx * epoch_s

    after = preds[onset_idx:]
    waso_epochs = int(np.sum(after == 0))
    tst_s  = tst_epochs * epoch_s
    waso_s = waso_epochs * epoch_s
    se_pct = 100.0 * tst_s / tib_s if tib_s else 0.0

    rem_mask_after = after == 4
    rem_lat_s = int(np.argmax(rem_mask_after)) * epoch_s if rem_mask_after.any() else None

    sleep_to_wake = (after[:-1] != 0) & (after[1:] == 0)
    n_awakenings  = int(np.sum(sleep_to_wake))

    return {
        'tib_s': tib_s, 'tst_s': tst_s, 'sol_s': sol_s,
        'waso_s': waso_s, 'se_pct': se_pct, 'rem_lat_s': rem_lat_s,
        'n_awakenings': n_awakenings,
        'stage_times': stage_times, 'stage_pct': stage_pct,
    }


def _add_stats_panel(fig, gs_pos, stats):
    """Render the sleep-statistics panel into grid slot ``gs_pos`` of ``fig``.

    Left half: seven whole-night metrics (TIB … awakenings), each with its
    value, a normative mini boxplot and the 'mean±sd' norm text.
    Right half: one row per sleep stage with a color swatch, time, percentage
    (Wake as % of TIB, other stages as % of TST), mini boxplot and norm text.
    ``stats`` is the dict produced by compute_sleep_stats; every position is
    a fixed axes fraction.
    """
    ax = fig.add_subplot(gs_pos)
    ax.set_facecolor(DARK)
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ax.spines.values():
        side.set_edgecolor('#2a3f52')
        side.set_linewidth(1)

    def put(x, y, text, **style):
        # All text in this panel is placed in axes-fraction coordinates.
        ax.text(x, y, text, transform=ax.transAxes, **style)

    muted = '#7a9ab0'
    value_style = dict(color='white', fontsize=11, fontweight='700',
                       family='monospace', va='center', ha='right')

    # ── Section headers ──────────────────────────────────────────────────────
    for x_head, heading in ((0.008, 'SLEEP STATISTICS'), (0.515, 'STAGE BREAKDOWN')):
        put(x_head, 0.93, heading, color=WHITE, fontsize=11, fontweight='800',
            va='top')

    # ── Legend: explainer text followed by three colored dots ───────────────
    legend_y = 0.91
    put(0.745, legend_y, 'box ±1σ · whisker ±2.5σ ·', color=muted, fontsize=8,
        fontstyle='italic', va='center', ha='left')
    legend_items = (
        (0.870, 0.882, 'in ±1σ', '#4ade80'),
        (0.916, 0.927, 'in ±2.5σ', '#ffe851'),
        (0.974, 0.985, 'out', '#ff4466'),
    )
    for dot_x, caption_x, caption, dot_color in legend_items:
        ax.plot(dot_x, legend_y, 'o', markersize=7,
                color=dot_color, markeredgecolor=DARK, markeredgewidth=1.0,
                transform=ax.transAxes, clip_on=False)
        put(caption_x, legend_y, caption, color=muted, fontsize=8,
            va='center', ha='left')

    # ── LEFT: overall sleep metrics (label, shown text, raw value, norm key) ─
    metric_rows = (
        ('Time in bed (TIB)',          _fmt_hm(stats['tib_s']),      stats['tib_s'],     'TIB'),
        ('Total sleep time (TST)',     _fmt_hm(stats['tst_s']),      stats['tst_s'],     'TST'),
        ('Sleep onset latency (SOL)',  _fmt_hm(stats['sol_s']),      stats['sol_s'],     'SOL'),
        ('Wake after onset (WASO)',    _fmt_hm(stats['waso_s']),     stats['waso_s'],    'WASO'),
        ('Sleep efficiency (SE)',      f"{stats['se_pct']:.1f} %",   stats['se_pct'],    'SE'),
        ('REM latency (RL)',           _fmt_hm(stats['rem_lat_s']),  stats['rem_lat_s'], 'RL'),
        ('Awakenings (AW)',            str(stats['n_awakenings']),   stats['n_awakenings'], 'AW'),
    )
    row_step = 0.105
    for row, (caption, shown, raw, key) in enumerate(metric_rows):
        y = 0.78 - row * row_step
        put(0.008, y, caption, color=TEXT, fontsize=10, va='center')
        put(0.235, y, shown, **value_style)
        _draw_minibox(ax, 0.26, 0.40, y, raw, key)
        put(0.415, y, _fmt_norm(key), color=muted, fontsize=8.5,
            family='monospace', va='center')

    # ── RIGHT: per-stage breakdown — 5 rows over ~90 % of the left's 7-row span
    stage_step = row_step * 7 / 5 * 0.9
    for row, (name, swatch) in enumerate(zip(STAGE_NAMES, STAGE_COLORS)):
        y = 0.78 - row * stage_step
        ax.add_patch(plt.Rectangle((0.520, y - 0.038), 0.016, 0.076,
                                   color=swatch, alpha=0.9,
                                   transform=ax.transAxes, clip_on=False))
        put(0.545, y, name, color='white', fontsize=11, fontweight='700',
            va='center')
        put(0.612, y, _fmt_hm(stats['stage_times'][name]), **value_style)
        put(0.680, y, f"{stats['stage_pct'][name]:.1f} %", **value_style)
        _draw_minibox(ax, 0.700, 0.870, y, stats['stage_pct'][name], name)
        basis = 'TIB' if name == 'Wake' else 'TST'
        put(0.880, y, _fmt_norm(name) + f' ({basis})', color=muted,
            fontsize=8.5, family='monospace', va='center')


def _style_ax(ax):
    """Apply the dark theme to ``ax``: background, spine color/width, tick color/size."""
    ax.set_facecolor(DARK)
    for edge in ax.spines.values():
        edge.set(edgecolor='#2a3f52', linewidth=1)
    ax.tick_params(colors=TEXT, labelsize=9)


def plot_sleep_staging(preds, probs, n_epochs, sig_128, ch_label='', fmax=30.0):
    """Build and return the sleep-staging figure (it is not shown or saved here).

    Rows of the left column (the right column holds the spectrogram colorbar):
      0. stacked per-epoch stage probabilities
      1. hypnogram: colored stage bands plus a step line
      2. EEG spectrogram of ``sig_128`` up to ``fmax`` Hz
      3. sleep statistics with normative mini boxplots

    Args:
        preds: predicted stage index per epoch (index into STAGE_NAMES).
        probs: (n_epochs, 5) stage probabilities per epoch.
        n_epochs: number of 30-s epochs; expected to equal len(preds).
        sig_128: spectrogram signal sampled at 128 Hz.
        ch_label: optional suffix appended to the figure title.
        fmax: upper frequency limit of the spectrogram in Hz.

    Returns:
        The matplotlib Figure.
    """
    preds = np.asarray(preds)
    epoch_h = 30 / 3600
    duration_h = n_epochs * 30 / 3600
    t_h = np.arange(n_epochs) * 30 / 3600        # epoch start times in hours

    fig = plt.figure(figsize=(20, 12.5), facecolor=DARK)
    gs = gridspec.GridSpec(4, 2, height_ratios=[1.2, 0.8, 1.5, 1.5],
                            width_ratios=[80, 1], hspace=0.42, wspace=0.012)

    # ── Panel 0: stacked probability area ────────────────────────────────────
    ax0 = fig.add_subplot(gs[0, 0])
    ax0.stackplot(t_h, *[probs[:, i] for i in range(5)],
                  labels=STAGE_NAMES, colors=STAGE_COLORS, alpha=0.85)
    ax0.set_ylim(0, 1)
    ax0.set_xlim(0, duration_h)
    ax0.set_ylabel('Probability', color=TEXT, fontsize=11, fontweight='600')
    ax0.legend(loc='upper right', framealpha=0.85, fontsize=10,
               labelcolor=DARK, facecolor=TEXT)
    ax0.grid(True, alpha=0.15, color=GRID, linestyle='--', linewidth=0.5)
    _style_ax(ax0)
    ax0.set_xticklabels([])

    # ── Panel 1: hypnogram ───────────────────────────────────────────────────
    ax1 = fig.add_subplot(gs[1, 0])
    hypo_y = np.array([STAGE_TO_HYPO_Y[int(p)] for p in preds])

    # Colored band per run of identical consecutive stages: one broken_barh
    # collection per stage instead of one fill_between artist per epoch
    # (~1000 artists for a full night).
    bounds = np.flatnonzero(np.diff(preds)) + 1
    run_starts = np.concatenate(([0], bounds))
    run_ends = np.concatenate((bounds, [len(preds)]))
    for stage, color in enumerate(STAGE_COLORS):
        runs = [(t_h[s], (e - s) * epoch_h)
                for s, e in zip(run_starts, run_ends) if preds[s] == stage]
        if runs:
            y = STAGE_TO_HYPO_Y[stage]
            ax1.broken_barh(runs, (y - 0.45, 0.9), color=color, alpha=0.7)

    # Repeat the last stage at the recording end so the step line also spans
    # the final epoch; with where='post' it would otherwise stop at that
    # epoch's start.
    ax1.step(np.append(t_h, duration_h), np.append(hypo_y, hypo_y[-1]),
             where='post', color='white', linewidth=1.2, alpha=0.6)
    ax1.set_yticks(range(5))
    ax1.set_yticklabels(HYPO_Y_LABELS)
    ax1.set_ylim(-0.6, 4.6)
    ax1.set_xlim(0, duration_h)
    ax1.set_ylabel('Stage', color=TEXT, fontsize=11, fontweight='600')
    ax1.grid(True, alpha=0.12, color=GRID, linestyle='--', linewidth=0.5, axis='x')
    _style_ax(ax1)
    ax1.set_xticklabels([])

    # ── Panel 2: EEG spectrogram ─────────────────────────────────────────────
    ax2 = fig.add_subplot(gs[2, 0])
    f, t_spec, logP = compute_spectrogram(sig_128, fs=128, fmax=fmax)

    # Percentile color limits keep a few extreme bins from washing out the map
    vmin, vmax = np.percentile(logP, [5, 97])
    im = ax2.pcolormesh(t_spec / 3600, f, logP,
                        shading='auto', cmap='turbo',
                        vmin=vmin, vmax=vmax, rasterized=True)
    ax2.set_ylim(0, fmax)
    ax2.set_xlim(0, duration_h)
    ax2.set_xlabel('Time (hours)', color=TEXT, fontsize=11, fontweight='600')
    ax2.set_ylabel('Frequency (Hz)', color=TEXT, fontsize=11, fontweight='600')
    _style_ax(ax2)

    cax = fig.add_subplot(gs[2, 1])
    cbar = fig.colorbar(im, cax=cax)
    cbar.set_label('log₁₀(Power)', color=TEXT, fontsize=9)
    cbar.ax.tick_params(colors=TEXT, labelsize=8)
    cbar.outline.set_edgecolor('#2a3f52')

    # ── Panel 3: sleep statistics table ──────────────────────────────────────
    stats = compute_sleep_stats(preds, epoch_s=30)
    _add_stats_panel(fig, gs[3, 0], stats)

    title = f'Sleep Staging  —  {n_epochs} epochs  ({duration_h:.1f} h)'
    if ch_label:
        title += f'  —  {ch_label}'
    fig.suptitle(title, color=TEXT, fontsize=13, fontweight='700', y=0.975)
    fig.subplots_adjust(top=0.945, bottom=0.04, left=0.05, right=0.965)
    return fig


# ── Main ──────────────────────────────────────────────────────────────────────

def main():
    """Command-line entry point: read an EDF channel, stage it with AnySleep, plot.

    The steps are:
      1. Parse arguments and read the requested channel.
      2. Crop the signal to the --start / --duration window; an empty or
         negative window is rejected via ``parser.error``.
      3. Make sure the model weights exist, downloading them if needed.
      4. Run inference on CUDA when available, otherwise on CPU, and print
         the per-stage epoch counts.
      5. Save the figure (--save) or show it interactively.
    """
    parser = argparse.ArgumentParser(description='Sleep Staging with AnySleep')
    parser.add_argument('file', help='EDF file path')
    parser.add_argument('--channel', '-c', type=int, default=0,
                        help='EEG channel index (0–7), default 0')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to anysleep-run1.pth (auto-download if omitted)')
    parser.add_argument('--start', '-s', type=float, default=0.0,
                        help='Start time in seconds')
    parser.add_argument('--duration', '-d', type=float, default=-1,
                        help='Duration in seconds (-1 = full file)')
    parser.add_argument('--fmax', type=float, default=30.0,
                        help='Spectrogram max frequency in Hz (default 30, max 60)')
    parser.add_argument('--save', type=str, default=None,
                        help='Save plot to PNG instead of showing interactively')
    args = parser.parse_args()

    # A negative start would silently slice from the end of the recording.
    if args.start < 0:
        parser.error('--start must be >= 0 seconds')

    # ── Load signal ───────────────────────────────────────────────────────────
    print(f'Loading {args.file}, channel {args.channel}...')
    signal, fs = read_edf_channel(args.file, args.channel)
    total_samples = len(signal)
    print(f'✓ {len(signal)} samples @ {fs:.1f} Hz  ({len(signal)/fs/3600:.2f} h)')

    start_idx = int(args.start * fs)
    end_idx   = (start_idx + int(args.duration * fs)) if args.duration > 0 else len(signal)
    signal    = signal[start_idx:end_idx]
    # Report an empty window here; otherwise it fails later inside the
    # filtering code with an opaque NumPy/SciPy error.
    if len(signal) == 0:
        parser.error(
            f'No samples in the selected window (--start {args.start:g} s, '
            f'--duration {args.duration:g} s); the recording is '
            f'{total_samples / fs:.0f} s long.'
        )
    print(f'  Using {len(signal)/fs:.0f} s  ({len(signal)/fs/3600:.2f} h)')

    # ── Model ─────────────────────────────────────────────────────────────────
    if args.model is None:
        args.model = Path(__file__).parent / 'anysleep-run1.pth'
    model_path = ensure_model(args.model)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}')

    # ── Inference ─────────────────────────────────────────────────────────────
    print('Running inference (1–3 min on CPU)...')
    preds, probs, n_epochs, sig_128 = run_inference(signal, fs, model_path, device=device)
    print(f'✓ {n_epochs} epochs classified')

    stage_counts = {name: int(np.sum(preds == i)) for i, name in enumerate(STAGE_NAMES)}
    print('  ' + '  '.join(f'{k}: {v}' for k, v in stage_counts.items()))

    # ── Plot ──────────────────────────────────────────────────────────────────
    fmax = float(np.clip(args.fmax, 5.0, 60.0))
    ch_label = f'Channel {args.channel} (EEG Ch{args.channel + 1})'
    fig = plot_sleep_staging(preds, probs, n_epochs, sig_128,
                             ch_label=ch_label, fmax=fmax)
    if args.save:
        fig.savefig(args.save, dpi=110, facecolor=DARK, bbox_inches='tight')
        print(f'✓ Plot saved to {args.save}')
    else:
        plt.show()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
