#!/usr/bin/env python3
"""
Traumschreiber EDF – EEG Time-Frequency-Analyse (Alpha/Beta)
=============================================================
Creates a time-frequency plot of one EEG channel in the 13–30 Hz range.
Suited to relaxation vs. cognitive-load experiments – the beta band
(13–30 Hz) is marked in the plot so the phases are easy to compare.

Layout (2 panels):
  1. Filtered EEG signal (band-pass 1–40 Hz, notch 50 Hz, low-pass 30 Hz)
  2. Spectrogram 13–30 Hz with the beta band (13–30 Hz) marked

Beispielexperiment (je 20 Sekunden):
  Augen zu / Entspannung → Runterzählen von 700 in 7er-Schritten →
  Entspannung → Runterzählen → Entspannung
  → Alpha-Band sollte während Entspannung heller sein (mehr Power),
    Beta-Band während kognitiver Belastung.

Verwendung:
  python plot_eeg_timefreq.py DATEI.EDF
  python plot_eeg_timefreq.py DATEI.EDF --channel 0 --start 0 --duration 120
  python plot_eeg_timefreq.py DATEI.EDF --channel 2 --duration 90

Kanal wählen:
  Bei 3–4 Elektroden typisch Kanal 5 oder 7 (Ch6/Ch8).
  Bei 8 Elektroden typisch Kanal 0 oder 2 (Ch1/Ch3).
  Einfach alle 8 Kanäle nacheinander ausprobieren.

Bibliotheken installieren (einmalig):
  pip install numpy matplotlib scipy
"""

import argparse
import struct
import sys
import warnings
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# scipy is imported defensively: when it is missing, HAS_SCIPY lets main()
# print an installation hint instead of failing with a traceback at import.
try:
    from scipy.signal import butter, sosfilt, spectrogram as _spectrogram
    from scipy.ndimage import median_filter as _median_filter
    HAS_SCIPY = True
except ImportError:
    _median_filter = None  # type: ignore[assignment]
    HAS_SCIPY = False

DEFAULT_DURATION = 120  # default length of the analysed window in seconds (--duration)


# ── EDF reader ────────────────────────────────────────────────────────────────

def read_edf_channel(filepath, channel_idx):
    """Read one channel (0–7) from a Traumschreiber EDF file.

    Samples are returned as raw digital values: 16-bit little-endian integers
    converted to float64.  The physical/digital scaling fields of the header
    are not applied.

    Args:
        filepath: Path of the EDF file.
        channel_idx: Zero-based index of the signal to extract.  Only the
            first eight signals of the file are selectable.

    Returns:
        Tuple ``(signal, sample_rate)``: a 1-D float64 array holding the
        samples of all complete data records, and the sampling rate in Hz of
        the selected channel (250.0 if the record duration is not positive).

    Raises:
        ValueError: If the file is shorter than the 256-byte fixed header,
            the channel index is out of range, or the per-record sample
            counts in the header are unusable.
        OSError: If the file cannot be read.
    """
    raw = Path(filepath).read_bytes()
    if len(raw) < 256:
        raise ValueError("Datei zu klein – kein gültiges EDF-Format.")

    hdr = raw[:256].decode('latin-1')

    def _i(s, default=0):
        try:
            return int(s.strip())
        except ValueError:
            return default

    def _f(s):
        try:
            return float(s.strip())
        except ValueError:
            return 1.0

    # Missing or zero fields fall back to 3072 header bytes / 11 signals
    # (3072 = 256 + 11 * 256; presumably the Traumschreiber default layout).
    header_bytes    = _i(hdr[184:192]) or 3072
    num_signals     = _i(hdr[252:256]) or 11
    # A negative record count means "not limited by the header".  An
    # unreadable field is treated the same way so that the file size decides,
    # instead of silently reading zero records.
    num_records     = _i(hdr[236:244], default=-1)
    record_duration = _f(hdr[244:252])

    if not (0 <= channel_idx < min(num_signals, 8)):
        raise ValueError(
            f"Kanal {channel_idx} ist ungültig. "
            f"Verfügbare Kanäle: 0–{min(num_signals, 8) - 1}."
        )

    # The per-signal header stores each field for ALL signals back to back
    # (all labels, then all transducers, ...).  Only the samples-per-record
    # field is needed here, so jump straight to its section.
    sig_hdr = raw[256:header_bytes].decode('latin-1')
    field_widths = [16, 80, 8, 8, 8, 8, 8, 80, 8, 32]
    field_names  = ['label', 'transducer', 'dim', 'phys_min', 'phys_max',
                    'dig_min', 'dig_max', 'prefilter', 'num_samples', 'reserved']
    ns_field = field_names.index('num_samples')
    ns_width = field_widths[ns_field]
    ns_start = sum(field_widths[:ns_field]) * num_signals
    # Unparsable (or truncated) entries count as 0 samples per record.
    ns = [
        _i(sig_hdr[pos:pos + ns_width])
        for pos in range(ns_start, ns_start + num_signals * ns_width, ns_width)
    ]

    if any(n < 0 for n in ns):
        raise ValueError("EDF-Header meldet negative Sample-Anzahl – Datei beschädigt?")

    record_bytes = sum(n * 2 for n in ns)
    if record_bytes == 0:
        raise ValueError("EDF-Header meldet 0 Samples/Record – Datei beschädigt?")

    n_channel = ns[channel_idx]
    if n_channel == 0:
        raise ValueError(f"Kanal {channel_idx} enthält laut Header keine Samples.")

    # Use the samples-per-record of the SELECTED channel; channels of one
    # file are allowed to differ, so ns[0] is only right for channel 0.
    sample_rate = n_channel / record_duration if record_duration > 0 else 250.0

    # Only complete records are read; a trailing partial record is ignored.
    n_complete = (len(raw) - header_bytes) // record_bytes
    if num_records >= 0:
        n_complete = min(n_complete, num_records)
    if n_complete <= 0:
        return np.empty(0, dtype=np.float64), float(sample_rate)

    # View the data area as an (records x int16-per-record) matrix and cut
    # out the column range that belongs to the requested channel.
    per_record = record_bytes // 2
    records = np.frombuffer(
        raw, dtype='<i2', count=n_complete * per_record, offset=header_bytes
    ).reshape(n_complete, per_record)
    first = sum(ns[:channel_idx])
    signal = records[:, first:first + n_channel].astype(np.float64).ravel()

    return signal, float(sample_rate)


# ── Filter helpers ────────────────────────────────────────────────────────────

def _safe_bandpass(signal, fs, hp_hz, lp_hz, order=5):
    """Apply a Butterworth band-pass between ``hp_hz`` and ``lp_hz``.

    The upper edge is capped at 99 % of the Nyquist frequency.  If the
    normalised edges are still unusable, a RuntimeWarning is issued and an
    unfiltered copy of ``signal`` is returned instead.
    """
    nyquist = fs * 0.5
    low_edge = hp_hz / nyquist
    high_edge = min(lp_hz, 0.99 * nyquist) / nyquist

    # Each entry is True when the corresponding edge constraint is violated.
    edge_problems = (
        low_edge <= 0.0 or low_edge >= 1.0,
        high_edge <= 0.0 or high_edge >= 1.0,
        low_edge >= high_edge,
    )
    if any(edge_problems):
        warnings.warn('Bandpass-Grenzen ungültig – Filter übersprungen.', RuntimeWarning)
        return signal.copy()

    coeffs = butter(order, [low_edge, high_edge], btype='band', output='sos')
    return sosfilt(coeffs, signal)


def _safe_notch(signal, fs, freq=50.0, bw=2.0, order=4):
    """Suppress a narrow band around ``freq`` with a Butterworth band-stop.

    The stop band spans ``freq - bw`` to ``freq + bw``.  When that band does
    not lie strictly between 0 and the Nyquist frequency, an unfiltered copy
    of ``signal`` is returned.
    """
    nyquist = fs * 0.5
    stop_band = [(freq - bw) / nyquist, (freq + bw) / nyquist]
    if stop_band[0] <= 0.0 or stop_band[1] >= 1.0:
        return signal.copy()
    coeffs = butter(order, stop_band, btype='bandstop', output='sos')
    return sosfilt(coeffs, signal)


def _lp30(signal, fs, order=5):
    """Low-pass ``signal`` at 30 Hz with a Butterworth filter run via ``sosfilt``.

    The cutoff is capped at 99 % of the Nyquist frequency; ``order`` defaults
    to 5.
    """
    nyquist = fs * 0.5
    normalised_cutoff = min(30.0, 0.99 * nyquist) / nyquist
    return sosfilt(
        butter(order, normalised_cutoff, btype='low', output='sos'),
        signal,
    )


# ── Spectrogram ───────────────────────────────────────────────────────────────

def compute_tf(signal, fs, fmin=13.0, fmax=30.0):
    """Compute a log-power spectrogram restricted to ``fmin``–``fmax`` Hz.

    Uses ``scipy.signal.spectrogram`` with power-spectral-density scaling,
    segments of about four seconds and 75 % overlap.

    Args:
        signal: 1-D array of samples.
        fs: Sampling rate in Hz.
        fmin: Lowest frequency to keep (inclusive).
        fmax: Highest frequency to keep (inclusive).

    Returns:
        Tuple ``(freqs, times, log_power)``: the retained frequencies, the
        segment times as returned by scipy, and ``log10`` of the power
        density with shape ``(len(freqs), len(times))``.
    """
    n_samples = len(signal)
    # 4-second segments give good frequency resolution in the alpha/beta
    # range; never go below 64 samples per segment for long signals.
    nperseg = max(min(int(fs * 4), n_samples), 64)
    # For very short signals the 64-sample floor would exceed the signal
    # length; scipy would then shrink nperseg on its own while noverlap
    # stays at 48 and the call fails with "noverlap must be less than
    # nperseg".  Clamp here so both values stay consistent.
    if 0 < n_samples < nperseg:
        nperseg = n_samples
    noverlap = int(nperseg * 0.75)  # 75 % overlap -> smooth time axis
    f, t, Sxx = _spectrogram(
        signal, fs=fs, nperseg=nperseg, noverlap=noverlap, scaling='density'
    )
    mask = (f >= fmin) & (f <= fmax)
    # The small offset keeps log10 finite for bins with zero power.
    return f[mask], t, np.log10(Sxx[mask] + 1e-20)


# ── Plot style ────────────────────────────────────────────────────────────────

# Colour palette shared by the figure, the axes styling and the band markers.
DARK   = '#0d1c2b'  # figure and axes background
GRID   = '#1a2f42'  # grid lines
TEXT   = '#aabbcc'  # tick labels and axis labels
BLUE   = '#4dd4f4'  # EEG signal trace
YELLOW = '#ffe851'  # not referenced in the code visible here
ORANGE = '#ff9151'  # beta-band boundary lines, shading and label


def _style(ax):
    """Apply the dark theme to ``ax``: background, spines, ticks, labels, title, grid."""
    ax.set_facecolor(DARK)

    for spine in ax.spines.values():
        spine.set_color('#2a3f52')

    ax.tick_params(colors=TEXT, labelsize=7)
    for axis in (ax.yaxis, ax.xaxis):
        axis.label.set_color(TEXT)
    ax.title.set_color('white')

    grid_kwargs = {'linestyle': '--', 'linewidth': 0.35, 'color': GRID, 'alpha': 0.8}
    ax.grid(True, **grid_kwargs)


# ── Main ──────────────────────────────────────────────────────────────────────

def main():
    """Command-line entry point: load one EDF channel, filter it and plot it.

    Steps:
      1. Parse the command line (file, channel, time window, median sizes).
      2. Read the channel with ``read_edf_channel`` and cut the window.
      3. Filter: band-pass 1–40 Hz, 50 Hz notch, 30 Hz low-pass.
      4. Compute the 13–30 Hz log-power spectrogram and optionally smooth it
         with a 2-D median filter (time x frequency).
      5. Draw two panels (filtered signal, spectrogram) and show the window.

    Exits with status 1 when scipy is missing, the file cannot be read, or
    the requested window is invalid or shorter than 8 seconds.
    """
    parser = argparse.ArgumentParser(
        # The plot covers the beta band only; the former text "5–25 Hz" did
        # not match what is drawn.
        description='Traumschreiber EEG – Time-Frequency-Analyse (Beta, 13–30 Hz)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            'Beispiele:\n'
            '  python plot_eeg_timefreq.py ADC/AD0050.EDF\n'
            '  python plot_eeg_timefreq.py ADC/AD0050.EDF --channel 0 --duration 120\n'
            '  python plot_eeg_timefreq.py ADC/AD0050.EDF --channel 2 --start 10 --duration 90\n\n'
            'Tipp: Beim Entspannungs-/Kognitionsexperiment (4 × 20 s) reichen 90–120 s Dauer.'
        ),
    )
    parser.add_argument('file',
                        help='EDF-Datei (z.B. ADC/AD0050.EDF)')
    parser.add_argument('--channel',  '-c', type=int,   default=0,
                        help='Kanal 0–7 (Standard: 0 = Ch1)')
    parser.add_argument('--start',    '-s', type=float, default=0.0,
                        help='Startzeit in Sekunden (Standard: 0)')
    parser.add_argument('--duration', '-d', type=float, default=DEFAULT_DURATION,
                        help=f'Fensterlänge in Sekunden (Standard: {DEFAULT_DURATION})')
    parser.add_argument('--median', '-m', type=float, default=11.0,
                        help='Länge des horizontalen Median-Filters in Sekunden (Standard: 11, 0 = deaktiviert)')
    parser.add_argument('--median-freq', type=float, default=1.0,
                        help='Länge des vertikalen Median-Filters in Hz (Standard: 1, 0 = deaktiviert)')
    args = parser.parse_args()

    if not HAS_SCIPY:
        print('FEHLER: scipy ist nicht installiert.')
        print('Bitte ausführen:  pip install numpy matplotlib scipy')
        sys.exit(1)

    # A negative start would become a negative slice index further down and
    # silently select samples counted from the END of the recording.
    if args.start < 0:
        print('Fehler: --start darf nicht negativ sein.')
        sys.exit(1)

    print(f'Loading   : {args.file}')
    print(f'Channel   : {args.channel} (Ch{args.channel + 1})')
    print(f'Window    : {args.start:.1f} s + {args.duration:.0f} s')

    try:
        sig_full, fs = read_edf_channel(args.file, args.channel)
    except FileNotFoundError:
        print(f'Datei nicht gefunden: {args.file}')
        sys.exit(1)
    except Exception as exc:
        # Top-level boundary: report any parsing problem as a short message.
        print(f'Fehler beim Lesen: {exc}')
        sys.exit(1)

    total_s = len(sig_full) / fs
    print(f'Sample rate: {fs:.1f} Hz')
    print(f'Total length: {total_s:.1f} s ({int(total_s // 60)} min {total_s % 60:.0f} s)')

    start_s = float(args.start)
    end_s   = min(start_s + args.duration, total_s)

    if start_s >= total_s:
        print(f'Fehler: --start {start_s:.0f} s liegt hinter dem Signalende ({total_s:.0f} s).')
        sys.exit(1)
    if end_s - start_s < 8.0:
        print('Fehler: Fenster zu kurz (mindestens 8 s nötig für Spektrogramm).')
        sys.exit(1)

    i0  = int(start_s * fs)
    i1  = int(end_s   * fs)
    sig = sig_full[i0:i1]
    # Time stamps are derived from the sample indices so that every sample
    # sits exactly at index / fs, independent of the rounding of i0 and i1.
    t   = (i0 + np.arange(len(sig))) / fs
    print(f'Analysing  : {start_s:.1f}–{end_s:.1f} s ({len(sig)} samples)\n')

    # ── Filtering ────────────────────────────────────────────────────────────
    sig_filt = _safe_bandpass(sig, fs, hp_hz=1.0, lp_hz=40.0)
    sig_filt = _safe_notch(sig_filt, fs)
    sig_filt = _lp30(sig_filt, fs)

    # ── Spectrogram ──────────────────────────────────────────────────────────
    f_s, t_s, logP = compute_tf(sig_filt, fs)

    # Median filter (horizontal = time, vertical = frequency)
    def _odd(n): return n if n % 2 else n + 1

    # Bin spacing of the spectrogram; 0.0 when an axis has fewer than 2 bins.
    dt_spec = float(t_s[1] - t_s[0]) if len(t_s) > 1 else 0.0
    df_spec = float(f_s[1] - f_s[0]) if len(f_s) > 1 else 0.0

    med_t = 1   # 1 bin = no effect
    med_f = 1
    if args.median > 0 and len(t_s) > 1:
        med_t = max(3, _odd(int(round(args.median / dt_spec))))
    if args.median_freq > 0 and len(f_s) > 1:
        med_f = max(1, _odd(int(round(args.median_freq / df_spec))))

    if med_t > 1 or med_f > 1:
        logP = _median_filter(logP, size=(med_f, med_t), mode='reflect')

    print(f'Median filter: time={med_t} bins ({med_t * dt_spec:.1f} s)  '
          f'freq={med_f} bins ({med_f * df_spec:.2f} Hz)')

    # ── Figure ────────────────────────────────────────────────────────────────
    dur = end_s - start_s
    # Width grows with the window length, limited to 14–24 inches.
    fig_w = max(14.0, min(dur / 5.0, 24.0))
    fig   = plt.figure(figsize=(fig_w, 8), facecolor=DARK)
    gs    = gridspec.GridSpec(
        2, 1, figure=fig,
        hspace=0.38,
        left=0.07, right=0.97, top=0.92, bottom=0.08,
        height_ratios=[1, 1.9],
    )

    ax_sig = fig.add_subplot(gs[0])
    ax_tf  = fig.add_subplot(gs[1])

    ch_lbl  = f'Channel {args.channel} (Ch{args.channel + 1})'
    win_lbl = f'{start_s:.0f}–{end_s:.0f} s'

    # Panel 1: the filtered signal exactly as it enters the spectrogram.
    # sig_filt already contains the 30 Hz low-pass; applying it a second time
    # would show a signal that matches neither the title nor panel 2.
    sig_display = sig_filt
    ax_sig.plot(t, sig_display, '-', lw=0.45, color=BLUE, rasterized=True)
    ax_sig.set_title(
        f'EEG – Filtered Signal ({ch_lbl}, Bandpass 1–40 Hz + Notch 50 Hz + LP 30 Hz)',
        fontsize=9, fontweight='bold', pad=4, color='white'
    )
    ax_sig.set_yticks([])
    ax_sig.set_ylabel('Amplitude', fontsize=7)
    ax_sig.set_xlim(start_s, end_s)
    _style(ax_sig)

    # Panel 2: spectrogram; the colour range is clipped to the 5th–97th
    # percentile so that outliers do not wash out the colour map.
    vmin = float(np.percentile(logP, 5))
    vmax = float(np.percentile(logP, 97))
    # Spectrogram times are relative to the window, so shift them by start_s.
    T_m, F_m = np.meshgrid(t_s + start_s, f_s)
    ax_tf.pcolormesh(T_m, F_m, logP, shading='auto', cmap='turbo', vmin=vmin, vmax=vmax)
    ax_tf.set_ylim(13, 30)
    ax_tf.set_xlim(start_s, end_s)
    ax_tf.set_ylabel('Frequency (Hz)', fontsize=7)
    ax_tf.set_xlabel('Time (s)', fontsize=7)
    ax_tf.set_title(
        'Time-Frequency Plot (13–30 Hz)  —  Beta band',
        fontsize=9, fontweight='bold', pad=4, color='white'
    )
    _style(ax_tf)

    # Beta band (13–30 Hz): lower and upper boundary, faint shading and label
    ax_tf.axhline(13, color=ORANGE, lw=1.0, ls='--', alpha=0.75)
    ax_tf.axhline(30, color=ORANGE, lw=1.0, ls='--', alpha=0.75)
    ax_tf.axhspan(13, 30, alpha=0.03, color=ORANGE)
    ax_tf.text(start_s + dur * 0.01, 21.5,
               'Beta  13–30 Hz', fontsize=6.5, color=ORANGE, alpha=0.9, va='center')

    fig.suptitle(
        f'Traumschreiber – EEG Alpha/Beta  |  {ch_lbl}  |  {win_lbl}',
        fontsize=11, fontweight='bold', color='white',
    )

    print('Panels:')
    print('  1. Filtered EEG (Bandpass 1–40 Hz, Notch 50 Hz, LP 30 Hz)')
    print('  2. Time-Frequency Plot (13–30 Hz)  [median-filtered per freq row]')
    print('     Orange dashed : Beta band (13–30 Hz) – brighter during cognitive load')
    print('\nClose window to exit.')

    plt.show()


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
