#!/usr/bin/env python3
"""
Traumschreiber EDF Tutorial – Signalanalyse mit Python
=======================================================
Zeigt bis zu 30 Sekunden eines EEG-Kanals aus einer EDF-Datei in vier Zeilen:
  1. Rohsignal
  2. Time-Frequency-Plot (Spektrogramm des Rohsignals)
  3. Vergleich: ohne vs. mit 0,5 Hz Hochpassfilter (Butterworth, 5. Ordnung)
  4. Vergleich: ohne vs. mit 30 Hz Tiefpassfilter (Butterworth, 5. Ordnung)

Verwendung:
  python plot_edf_tutorial.py DATEI.EDF
  python plot_edf_tutorial.py DATEI.EDF --channel 0 --start 30

Bibliotheken installieren (einmalig):
  pip install numpy matplotlib scipy
"""

import argparse
import struct
import sys
import warnings
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

try:
    from scipy.signal import butter, sosfilt, spectrogram as _spectrogram
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

WINDOW_S = 30  # length of the plotted window in seconds (always 30 s, shorter only at the end of the recording)


# ── EDF reader ───────────────────────────────────────────────────────────────

def read_edf_channel(filepath, channel_idx):
    """Read one ADS channel (0–7) from a Traumschreiber EDF file.

    The samples are the raw little-endian 16-bit digital values converted to
    float64; no physical (phys_min/phys_max) scaling is applied.

    Parameters
    ----------
    filepath : str or os.PathLike
        Path to the EDF file.
    channel_idx : int
        Zero-based signal index; must be below ``min(num_signals, 8)``.

    Returns
    -------
    (signal, sample_rate) : (np.ndarray[float64], float)
        ``signal`` concatenates the channel's samples from every complete data
        record (a trailing partial record is ignored). ``sample_rate`` is the
        channel's own samples-per-record divided by the record duration, or
        250 Hz when the header duration is not positive.

    Raises
    ------
    ValueError
        File shorter than the 256-byte main header, channel index out of
        range, negative or all-zero sample counts, or a channel with zero
        samples per record.
    """
    raw = Path(filepath).read_bytes()
    if len(raw) < 256:
        raise ValueError("Datei zu klein – kein gültiges EDF-Format.")

    hdr = raw[:256].decode('latin-1')

    def _i(s, default=0):
        try:
            return int(s.strip())
        except ValueError:
            return default

    def _f(s):
        try:
            return float(s.strip())
        except ValueError:
            return 1.0

    header_bytes    = _i(hdr[184:192]) or 3072
    num_signals     = _i(hdr[252:256]) or 11
    # -1 (or an unparsable field) means "unknown" → read every record up to EOF.
    num_records     = _i(hdr[236:244], default=-1)
    record_duration = _f(hdr[244:252])

    n_eeg = min(num_signals, 8)
    if not (0 <= channel_idx < n_eeg):
        raise ValueError(
            f"Kanal {channel_idx} ist ungültig. "
            f"Verfügbare EEG-Kanäle: 0–{n_eeg - 1}."
        )

    # EDF signal headers: fields are stored as groups, each group covers all signals
    # Layout: [all labels][all transducers][all dims]...[all num_samples][all reserved]
    sig_hdr = raw[256:header_bytes].decode('latin-1')
    field_widths = [16, 80, 8, 8, 8, 8, 8, 80, 8, 32]
    field_names  = ['label', 'transducer', 'dim', 'phys_min', 'phys_max',
                    'dig_min', 'dig_max', 'prefilter', 'num_samples', 'reserved']
    sigs = [{} for _ in range(num_signals)]
    off = 0
    for fw, fn in zip(field_widths, field_names):
        for i in range(num_signals):
            sigs[i][fn] = sig_hdr[off:off + fw].strip()
            off += fw

    ns = []
    for s in sigs:
        try:
            ns.append(int(s['num_samples']))
        except ValueError:
            ns.append(0)

    if any(n < 0 for n in ns):
        raise ValueError("EDF-Header meldet negative Samples/Record – Datei beschädigt?")

    record_bytes = 2 * sum(ns)  # every sample is a 16-bit integer
    if record_bytes == 0:
        raise ValueError("EDF-Header meldet 0 Samples/Record – Datei beschädigt?")

    n_ch = ns[channel_idx]
    if n_ch == 0:
        raise ValueError(f"Kanal {channel_idx} enthält laut Header keine Samples.")

    # Use the requested channel's own rate: signals may differ in samples/record.
    sample_rate = n_ch / record_duration if record_duration > 0 else 250.0

    # Only complete records are read; the header count (if known) caps it.
    available = max(0, (len(raw) - header_bytes) // record_bytes)
    n_rec = available if num_records < 0 else min(num_records, available)
    if n_rec == 0:
        return np.empty(0, dtype=np.float64), float(sample_rate)

    words_per_record = record_bytes // 2
    records = np.frombuffer(
        raw, dtype='<i2', count=n_rec * words_per_record, offset=header_bytes
    ).reshape(n_rec, words_per_record)
    first = sum(ns[:channel_idx])  # word offset of this channel inside a record
    signal = records[:, first:first + n_ch].astype(np.float64).ravel()
    return signal, float(sample_rate)


# ── Filter helpers ────────────────────────────────────────────────────────────

def _butter_causal(signal, fs, cutoff, order, btype):
    """Apply a causal Butterworth filter (SOS form) to a 1-D signal.

    The filter state is initialised to the steady state for a constant input
    equal to the first sample (``sosfilt_zi``). Without this, a DC offset in
    the raw ADC values acts like a step at t=0 and produces a large start-up
    transient at the left edge of the plotted window. The filter is still
    causal (single forward pass), so its phase delay is unchanged.

    Raises ValueError (from ``butter``) if the cutoff is not below Nyquist.
    """
    from scipy.signal import sosfilt_zi  # local: not in the module-level scipy import

    x = np.asarray(signal, dtype=np.float64)
    sos = butter(order, cutoff / (0.5 * fs), btype=btype, output='sos')
    if x.size == 0:
        return x.copy()
    zi = sosfilt_zi(sos) * x[0]
    y, _ = sosfilt(sos, x, zi=zi)
    return y


def hp_filter(signal, fs, cutoff=0.5, order=5):
    """High-pass filter (default 0.5 Hz) – Butterworth, 5th order, causal (sosfilt).

    Removes slow drift. Returns a new float64 array of the same length.
    """
    return _butter_causal(signal, fs, cutoff, order, 'high')


def lp_filter(signal, fs, cutoff=30.0, order=5):
    """Low-pass filter (default 30 Hz) – Butterworth, 5th order, causal (sosfilt).

    Suppresses high-frequency content such as 50 Hz mains noise. Returns a new
    float64 array of the same length.
    """
    return _butter_causal(signal, fs, cutoff, order, 'low')


# ── Spectrogram helper ────────────────────────────────────────────────────────

def compute_tf(signal, fs, seg_s=2.0, min_nperseg=64):
    """Return ``(freqs, times, log10_power)`` of a short-time Fourier spectrogram.

    Uses ``scipy.signal.spectrogram`` with PSD scaling ('density') and 50 %
    overlap. The segment length is ``seg_s`` seconds, raised to at least
    ``min_nperseg`` samples but never longer than the signal. ``times`` are
    relative to the first sample of ``signal``. A tiny epsilon (1e-20) avoids
    ``log10(0)``; the power is not clipped.

    Raises ValueError for an empty signal.
    """
    n = len(signal)
    if n == 0:
        raise ValueError('Leeres Signal – kein Spektrogramm möglich.')
    # Clamp to the signal length *after* applying the lower bound: otherwise a
    # signal shorter than min_nperseg keeps noverlap=32 while scipy shrinks
    # nperseg to the signal length, and noverlap >= nperseg raises.
    nperseg = min(max(int(fs * seg_s), min_nperseg), n)
    f, t, Sxx = _spectrogram(
        signal, fs=fs, nperseg=nperseg, noverlap=nperseg // 2, scaling='density'
    )
    logP = np.log10(Sxx + 1e-20)
    return f, t, logP


# ── Plot helpers ──────────────────────────────────────────────────────────────

# Dark-theme palette shared by all axes and traces.
DARK  = '#0d1c2b'    # axes and figure background
GRID  = '#1a2f42'    # dashed grid lines
TEXT  = '#aabbcc'    # tick labels and axis labels
ORANGE = '#ff9151'   # raw signal trace (also plot_wave's default colour)
GREEN  = '#7be0ad'   # high-pass filtered trace
BLUE   = '#abd5ff'   # low-pass filtered trace
GREY   = '#7a9ab0'   # unfiltered comparison traces


def _style(ax):
    """Apply the dark tutorial theme to *ax*: background, spines, ticks, labels, grid."""
    ax.set_facecolor(DARK)
    spine_color = '#2a3f52'
    for spine in ax.spines.values():
        spine.set_color(spine_color)
    ax.tick_params(colors=TEXT, labelsize=7)
    for axis_label in (ax.yaxis.label, ax.xaxis.label):
        axis_label.set_color(TEXT)
    ax.title.set_color('white')
    ax.grid(True, linestyle='--', linewidth=0.35, color=GRID, alpha=0.8)


def plot_wave(ax, t, sig, title, color=ORANGE, xlabel=False):
    """Draw *sig* against *t* as a thin, rasterized line on *ax*.

    Y tick labels are hidden (only the shape of the trace matters). The x-axis
    label is set only when *xlabel* is true. The dark theme is applied last.
    """
    small_label = {'fontsize': 7}
    ax.plot(t, sig, '-', lw=0.55, color=color, rasterized=True)
    ax.set_title(title, fontsize=9, fontweight='bold', pad=4)
    ax.set_yticks([])
    ax.set_ylabel('Amplitude', **small_label)
    if xlabel:
        ax.set_xlabel('Time (s)', **small_label)
    _style(ax)


def plot_tf_ax(ax, f, t, logP, t0=0.0, max_freq=62.0):
    """Draw a log-power spectrogram on *ax*.

    Only frequency rows within [0.3 Hz, *max_freq*] are drawn. The colour range
    spans the 5th–97th percentile of those rows. *t* is shifted by *t0* so the
    x-axis shows absolute time. A dashed white line marks 50 Hz (mains noise).
    """
    keep = (f >= 0.3) & (f <= max_freq)
    freqs_shown = f[keep]
    power_shown = logP[keep]
    lo, hi = (float(np.percentile(power_shown, q)) for q in (5, 97))
    time_grid, freq_grid = np.meshgrid(t + t0, freqs_shown)
    ax.pcolormesh(time_grid, freq_grid, power_shown,
                  shading='auto', cmap='turbo', vmin=lo, vmax=hi)
    ax.set_ylim(0, max_freq)
    ax.set_ylabel('Frequency (Hz)', fontsize=7)
    ax.set_xlabel('Time (s)', fontsize=7)
    ax.set_title('Time-Frequency Plot (Spectrogram)', fontsize=9, fontweight='bold', pad=4)
    _style(ax)
    mains_hz = 50
    ax.axhline(mains_hz, color='white', lw=0.7, ls='--', alpha=0.55)
    ax.text(t0 + 0.3, 50.8, '50 Hz', fontsize=6, color='white', alpha=0.6)


# ── Main ──────────────────────────────────────────────────────────────────────

def _die(*lines):
    """Print *lines* to stderr and terminate the script with exit status 1."""
    for line in lines:
        print(line, file=sys.stderr)
    sys.exit(1)


def main():
    """Command-line entry point: load one EEG channel, filter it and show the overview plot.

    Exits with status 1 (message on stderr) if scipy is missing, the file cannot
    be read, or the requested window lies outside the recording. Invalid
    command-line arguments (e.g. a negative --start) are reported by argparse.
    """
    parser = argparse.ArgumentParser(
        description='Traumschreiber EDF Tutorial – Signalanalyse',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            'Beispiele:\n'
            '  python plot_edf_tutorial.py ADC/AD0041.EDF\n'
            '  python plot_edf_tutorial.py ADC/AD0041.EDF --channel 0 --start 30\n'
            '  python plot_edf_tutorial.py ADC/AD0041.EDF --channel 2 --start 60'
        ),
    )
    parser.add_argument('file',
                        help='EDF-Datei (z.B. ADC/AD0041.EDF)')
    parser.add_argument('--channel', '-c', type=int, default=0,
                        help='EEG-Kanal 0–7 (Standard: 0 = EEG Ch1)')
    parser.add_argument('--start', '-s', type=float, default=0.0,
                        help='Startzeit in Sekunden (Standard: 0). '
                             f'Es werden immer {WINDOW_S} s geplottet.')
    args = parser.parse_args()
    # `not >= 0` also rejects NaN. A negative start would otherwise become a
    # negative slice index and silently select samples from the end of the file.
    if not args.start >= 0.0:
        parser.error('--start muss >= 0 sein.')

    # ── dependency check ─────────────────────────────────────────────────────
    if not HAS_SCIPY:
        _die('FEHLER: scipy ist nicht installiert.',
             'Bitte ausführen:  pip install numpy matplotlib scipy')

    # ── load ─────────────────────────────────────────────────────────────────
    print(f'Lade: {args.file}  |  Kanal {args.channel}  |  Start {args.start:.1f} s')
    try:
        sig_full, fs = read_edf_channel(args.file, args.channel)
    except FileNotFoundError:
        _die(f'Datei nicht gefunden: {args.file}')
    except Exception as exc:  # top-level boundary: report any read/parse failure
        _die(f'Fehler beim Lesen: {exc}')

    total_s = len(sig_full) / fs
    print(f'  Abtastrate : {fs:.1f} Hz')
    print(f'  Gesamtdauer: {total_s:.1f} s  ({int(total_s//60)} min {total_s%60:.0f} s)')

    start_s = float(args.start)
    end_s   = min(start_s + WINDOW_S, total_s)

    if start_s >= total_s:
        _die(f'Fehler: --start {start_s:.0f} s liegt hinter dem Signalende ({total_s:.0f} s).')
    if end_s - start_s < 1.0:
        _die('Fehler: Zeitfenster zu kurz (< 1 s).')

    i0  = int(start_s * fs)
    i1  = int(end_s   * fs)
    sig = sig_full[i0:i1]
    # Derive each sample's time from its index so the axis is exact even when
    # start_s * fs is not an integer (i0 is truncated).
    t0_exact = i0 / fs
    t   = np.arange(i0, i0 + len(sig)) / fs
    print(f'  Fenster    : {start_s:.1f}–{end_s:.1f} s  ({len(sig)} Samples)\n')

    # ── filters ──────────────────────────────────────────────────────────────
    nyq = 0.5 * fs

    if 0.5 / nyq >= 1.0:
        warnings.warn('Abtastrate zu niedrig für HP-Filter – überspringe.', RuntimeWarning)
        sig_hp = sig.copy()
    else:
        sig_hp = hp_filter(sig, fs)

    if 30.0 / nyq >= 1.0:
        warnings.warn('Abtastrate zu niedrig für LP-Filter – überspringe.', RuntimeWarning)
        sig_lp = sig.copy()
    else:
        sig_lp = lp_filter(sig, fs)

    # ── spectrogram (only the raw signal's spectrogram is plotted) ───────────
    f_r, t_r, lP_r = compute_tf(sig, fs)

    # ── figure ────────────────────────────────────────────────────────────────
    # Layout:
    #   Row 0 (full width): raw signal
    #   Row 1 (full width): time-frequency plot
    #   Row 2 left/right  : without HP | with 0.5 Hz HP
    #   Row 3 left/right  : without LP | with 30 Hz LP
    fig = plt.figure(figsize=(14, 13), facecolor=DARK)
    gs  = gridspec.GridSpec(
        4, 2, figure=fig,
        hspace=0.50, wspace=0.20,
        left=0.07, right=0.97, top=0.94, bottom=0.05,
        height_ratios=[1, 1.1, 1, 1],
    )

    ax_raw  = fig.add_subplot(gs[0, :])
    ax_tf   = fig.add_subplot(gs[1, :])
    ax_nohp = fig.add_subplot(gs[2, 0])
    ax_hp   = fig.add_subplot(gs[2, 1])
    ax_nolp = fig.add_subplot(gs[3, 0])
    ax_lp   = fig.add_subplot(gs[3, 1])

    ch_lbl  = f'Channel {args.channel} (EEG Ch{args.channel + 1})'
    win_lbl = f'{start_s:.0f}–{end_s:.0f} s'

    plot_wave(ax_raw,  t, sig,    f'Raw Signal – {ch_lbl}',                   ORANGE)
    # Spectrogram times are relative to the first sample of the slice.
    plot_tf_ax(ax_tf,  f_r, t_r, lP_r, t0=t0_exact)

    plot_wave(ax_nohp, t, sig,    'Without High-pass Filter (Raw Signal)',     GREY,  xlabel=True)
    plot_wave(ax_hp,   t, sig_hp, 'With 0.5 Hz High-pass Filter',              GREEN, xlabel=True)

    plot_wave(ax_nolp, t, sig,    'Without Low-pass Filter (Raw Signal)',      GREY,  xlabel=True)
    plot_wave(ax_lp,   t, sig_lp, 'With 30 Hz Low-pass Filter',                BLUE,  xlabel=True)

    for ax in (ax_raw, ax_tf, ax_nohp, ax_hp, ax_nolp, ax_lp):
        ax.set_xlim(start_s, end_s)

    fig.suptitle(
        f'Traumschreiber – Signal Analysis  |  {ch_lbl}  |  {win_lbl}',
        fontsize=11, fontweight='bold', color='white',
    )

    print('Plot-Beschreibung:')
    print('  Zeile 1: Rohsignal')
    print('  Zeile 2: Time-Frequency-Plot (weiß gestrichelt = 50 Hz Netzrauschen)')
    print('  Zeile 3: Vergleich ohne/mit 0,5 Hz Hochpassfilter (entfernt langsame Drift)')
    print('  Zeile 4: Vergleich ohne/mit 30 Hz Tiefpassfilter (entfernt 50 Hz Rauschen)')
    print('\nFenster schließen zum Beenden.')

    plt.show()


# Run the tutorial only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
