#!/usr/bin/env python3
"""
Traumschreiber ExG-Zeitbereich – EKG / EEG / EMG / EOG
========================================================
Zeigt 30 Sekunden eines ExG-Kanals (EKG, EEG, EMG oder EOG) im Zeitbereich
mit signaltyp-spezifischen Filtervoreinstellungen – kein FFT / Spektrogramm.

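Signaltypen und Filtervoreinstellungen (+ abschließend 30 Hz Tiefpass für alle;
die effektive obere Grenzfrequenz liegt damit bei allen Typen bei höchstens 30 Hz,
beim EMG bleibt so nur das Band 20–30 Hz übrig):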
  ekg  – Bandpass 0,5–40 Hz  + Kerbfilter 50 Hz + 30 Hz LP | Farbe #ff9151 (Orange)
  eeg  – Bandpass 1–40 Hz    + Kerbfilter 50 Hz + 30 Hz LP | Farbe #4dd4f4 (Cyan)
  emg  – Bandpass 20–120 Hz  + Kerbfilter 50 Hz + 30 Hz LP | Farbe #7be0ad (Grün)
  eog  – Bandpass 0,1–15 Hz  + Kerbfilter 50 Hz + 30 Hz LP | Farbe #ffe851 (Gelb)

Layout (3 untereinander angeordnete, horizontale Panels):
  1. Rohsignal (grau)
  2. Gefiltertes Signal (Signaltyp-Farbe)
  3. 5-Sekunden-Zoom der Mitte des gefilterten Signals (gleiche Farbe)

Der Kerbfilter (50 Hz Netzrauschen) wird als schmalbandiger Butterworth-
Bandsperr-Filter (±2 Hz um 50 Hz) implementiert.

Verwendung:
  python plot_exg_timedomain.py DATEI.EDF
  python plot_exg_timedomain.py DATEI.EDF --type eeg --channel 1
  python plot_exg_timedomain.py DATEI.EDF --type emg --channel 3 --start 60
  python plot_exg_timedomain.py DATEI.EDF --type eog --channel 0 --start 120

Bibliotheken installieren (einmalig):
  pip install numpy matplotlib scipy
"""

import argparse
import struct
import sys
import warnings
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

try:
    from scipy.signal import butter, sosfilt
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

WINDOW_S = 30  # length of the plotted window in seconds (always 30 s)
ZOOM_S = 5     # length in seconds of the zoom panel, cut from the window centre


# ── Signaltyp-Presets ─────────────────────────────────────────────────────────

# Filter presets per signal type. Keys read by this script:
#   label        – short name shown in the plot titles
#   hp_hz, lp_hz – band-pass edges in Hz
#   notch_hz     – centre frequency of the mains-hum band-stop in Hz
#   color        – trace colour of the filtered and zoom panels
#   beschreibung – human-readable description printed on the console
# apply_preset_filter always appends a fixed 30 Hz low-pass after the
# band-pass and notch, so the descriptions mention it explicitly; it caps the
# effective upper edge of every preset at 30 Hz (EMG effectively 20–30 Hz).
PRESETS = {
    'ekg': dict(
        label        = 'EKG',
        hp_hz        = 0.5,
        lp_hz        = 40.0,
        notch_hz     = 50.0,
        color        = '#ff9151',
        beschreibung = 'EKG  (Bandpass 0,5–40 Hz + Kerbfilter 50 Hz + 30 Hz Tiefpass)',
    ),
    'eeg': dict(
        label        = 'EEG',
        hp_hz        = 1.0,
        lp_hz        = 40.0,
        notch_hz     = 50.0,
        color        = '#4dd4f4',
        beschreibung = 'EEG  (Bandpass 1–40 Hz + Kerbfilter 50 Hz + 30 Hz Tiefpass)',
    ),
    'emg': dict(
        label        = 'EMG',
        hp_hz        = 20.0,
        lp_hz        = 120.0,
        notch_hz     = 50.0,
        color        = '#7be0ad',
        beschreibung = 'EMG  (Bandpass 20–120 Hz + Kerbfilter 50 Hz + 30 Hz Tiefpass)',
    ),
    'eog': dict(
        label        = 'EOG',
        hp_hz        = 0.1,
        lp_hz        = 15.0,
        notch_hz     = 50.0,
        color        = '#ffe851',
        beschreibung = 'EOG  (Bandpass 0,1–15 Hz + Kerbfilter 50 Hz + 30 Hz Tiefpass)',
    ),
}


# ── EDF Reader ────────────────────────────────────────────────────────────────

def read_edf_channel(filepath, channel_idx):
    """Read one ADS channel (0–7) from a Traumschreiber EDF file.

    Parameters
    ----------
    filepath : str or os.PathLike
        Path of the EDF file.
    channel_idx : int
        Signal index to extract; must satisfy
        ``0 <= channel_idx < min(num_signals, 8)``.

    Returns
    -------
    tuple[np.ndarray, float]
        ``(signal, sample_rate)``. ``signal`` is float64 and holds the raw
        digital sample values (no digital-to-physical scaling is applied).
        ``sample_rate`` is the rate of *this* channel in Hz: its samples per
        record divided by the record duration, or 250.0 if the header
        duration is not positive.

    Raises
    ------
    FileNotFoundError
        If *filepath* does not exist.
    ValueError
        If the file is shorter than the 256-byte main header, the channel
        index is out of range, or the signal headers report invalid sample
        counts.
    """
    raw = Path(filepath).read_bytes()
    if len(raw) < 256:
        raise ValueError("Datei zu klein – kein gültiges EDF-Format.")

    hdr = raw[:256].decode('latin-1')

    def _i(s):
        # Integer header field; unparsable text yields 0.
        try:
            return int(s.strip())
        except ValueError:
            return 0

    def _f(s):
        # Float header field; unparsable text yields 1.0.
        try:
            return float(s.strip())
        except ValueError:
            return 1.0

    # Fallbacks are consistent with an 11-signal header (256 + 11 * 256 = 3072 bytes).
    header_bytes    = _i(hdr[184:192]) or 3072
    num_signals     = _i(hdr[252:256]) or 11
    num_records     = _i(hdr[236:244])   # -1 means unknown → read until EOF
    record_duration = _f(hdr[244:252])

    n_available = min(num_signals, 8)
    if not (0 <= channel_idx < n_available):
        raise ValueError(
            f"Kanal {channel_idx} ist ungültig. "
            f"Verfügbare EEG-Kanäle: 0–{n_available - 1}."
        )

    # EDF signal headers are stored field by field:
    # [all labels][all transducers]...[all num_samples][all reserved].
    # Only num_samples (8 chars per signal) is needed. It starts after
    # label(16) + transducer(80) + five 8-char fields + prefilter(80)
    # = 216 bytes per signal.
    sig_hdr = raw[256:header_bytes].decode('latin-1')
    ns_start = (16 + 80 + 5 * 8 + 80) * num_signals
    ns = [_i(sig_hdr[ns_start + 8 * k:ns_start + 8 * (k + 1)])
          for k in range(num_signals)]

    if any(n < 0 for n in ns):
        raise ValueError("EDF-Header enthält negative Sample-Anzahlen – Datei beschädigt?")
    words_per_record = sum(ns)          # int16 samples per data record, all signals
    if words_per_record == 0:
        raise ValueError("EDF-Header meldet 0 Samples/Record – Datei beschädigt?")

    n_ch = ns[channel_idx]
    if n_ch == 0:
        # Would otherwise yield a sample rate of 0 and divisions by zero downstream.
        raise ValueError(f"Kanal {channel_idx} enthält laut EDF-Header keine Samples.")

    # Use this channel's own sample count: EDF allows different rates per signal.
    sample_rate = n_ch / record_duration if record_duration > 0 else 250.0

    # Only complete data records are used; a header record count >= 0 limits them.
    record_bytes = 2 * words_per_record
    n_records = max(len(raw) - header_bytes, 0) // record_bytes
    if num_records >= 0:
        n_records = min(n_records, num_records)
    if n_records == 0:
        return np.empty(0, dtype=np.float64), float(sample_rate)

    # View the records as an (n_records, words_per_record) little-endian int16
    # matrix and take this channel's column range from every record.
    data = np.frombuffer(raw, dtype='<i2', count=n_records * words_per_record,
                         offset=header_bytes)
    first = sum(ns[:channel_idx])
    block = data.reshape(n_records, words_per_record)[:, first:first + n_ch]
    return block.astype(np.float64).ravel(), float(sample_rate)


# ── Filterhelfer ──────────────────────────────────────────────────────────────

def _safe_bandpass(signal, fs, hp_hz, lp_hz, order=5):
    """Butterworth-Bandpassfilter; klemmt lp_hz unterhalb der Nyquist-Grenze."""
    nyq = 0.5 * fs
    hp_norm = hp_hz / nyq
    lp_norm = min(lp_hz, nyq * 0.99) / nyq   # Sicherheits-Clamp (z.B. EMG 120 Hz @ 250 Hz)

    if hp_norm <= 0.0 or hp_norm >= 1.0:
        warnings.warn(
            f'HP-Grenze {hp_hz} Hz ungültig bei fs={fs} Hz – überspringe Hochpass.',
            RuntimeWarning,
        )
        return signal.copy()

    if lp_norm <= 0.0 or lp_norm >= 1.0:
        warnings.warn(
            f'LP-Grenze {lp_hz} Hz ungültig bei fs={fs} Hz – überspringe Tiefpass.',
            RuntimeWarning,
        )
        return signal.copy()

    if hp_norm >= lp_norm:
        warnings.warn(
            f'HP {hp_hz} Hz ≥ LP {lp_hz} Hz – Bandpass unmöglich – überspringe.',
            RuntimeWarning,
        )
        return signal.copy()

    sos = butter(order, [hp_norm, lp_norm], btype='bandpass', output='sos')
    return sosfilt(sos, signal)


def _safe_notch(signal, fs, notch_hz=50.0, bw_hz=2.0, order=4):
    """Schmalbandiger Butterworth-Bandsperr-Filter (±bw_hz um notch_hz).

    Implementiert als Butterworth-Bandsperre anstelle von iirnotch/tf2sos,
    um Kompatibilitätsprobleme mit älteren scipy-Versionen zu vermeiden.
    """
    nyq = 0.5 * fs
    lo  = (notch_hz - bw_hz) / nyq
    hi  = (notch_hz + bw_hz) / nyq

    if lo <= 0.0 or hi >= 1.0:
        warnings.warn(
            f'Kerbfilter {notch_hz} Hz außerhalb des darstellbaren Bereichs '
            f'bei fs={fs} Hz – überspringe.',
            RuntimeWarning,
        )
        return signal.copy()

    sos = butter(order, [lo, hi], btype='bandstop', output='sos')
    return sosfilt(sos, signal)


def _lp30(signal, fs, order=5):
    """Butterworth 30 Hz Tiefpassfilter, 5. Ordnung, sosfilt – wird auf alle Signaltypen angewendet."""
    nyq = 0.5 * fs
    cutoff = min(30.0, nyq * 0.99) / nyq
    sos = butter(order, cutoff, btype='low', output='sos')
    return sosfilt(sos, signal)


def apply_preset_filter(signal, fs, preset):
    """Run the full preset filter chain on *signal* and return the result.

    The stages run in this order:

    1. Band-pass between ``preset['hp_hz']`` and ``preset['lp_hz']``.
    2. Band-stop (notch) around ``preset['notch_hz']``.
    3. The fixed 30 Hz low-pass.
    """
    stages = (
        lambda x: _safe_bandpass(x, fs, preset['hp_hz'], preset['lp_hz']),
        lambda x: _safe_notch(x, fs, preset['notch_hz']),
        lambda x: _lp30(x, fs),
    )
    result = signal
    for stage in stages:
        result = stage(result)
    return result


# ── Plothelfer ────────────────────────────────────────────────────────────────

DARK = '#0d1c2b'  # figure and axes background
GRID = '#1a2f42'  # dashed grid lines
TEXT = '#aabbcc'  # tick labels and axis-label text
GREY = '#7a9ab0'  # trace colour of the raw-signal panel


def _style(ax):
    """Apply the shared dark theme to a matplotlib axis *ax*.

    Sets the background, spine colour, tick colour/size, axis-label and title
    colours, and a faint dashed grid.
    """
    ax.set_facecolor(DARK)
    for spine in ax.spines.values():
        spine.set_color('#2a3f52')
    ax.tick_params(colors=TEXT, labelsize=7)
    for axis in (ax.yaxis, ax.xaxis):
        axis.label.set_color(TEXT)
    ax.title.set_color('white')
    ax.grid(True, linestyle='--', linewidth=0.35, color=GRID, alpha=0.8)


def _plot_wave(ax, t, sig, title, color, xlabel=False):
    """Plot one waveform panel on *ax* with the shared look.

    Steps, in order:

    1. Draw *sig* against *t* as a thin rasterized line in *color*.
    2. Set the bold panel *title* and remove the y ticks.
    3. Label the y axis 'Amplitude'; label the x axis 'Time (s)' only when
       *xlabel* is true.
    4. Apply the dark theme via ``_style``.
    """
    trace_opts = {'lw': 0.6, 'color': color, 'rasterized': True}
    small_font = {'fontsize': 7}

    ax.plot(t, sig, '-', **trace_opts)
    ax.set_title(title, fontsize=9, fontweight='bold', pad=4)
    ax.set_yticks([])
    ax.set_ylabel('Amplitude', **small_font)
    if xlabel:
        ax.set_xlabel('Time (s)', **small_font)
    _style(ax)


# ── Hauptprogramm ─────────────────────────────────────────────────────────────

def _build_parser():
    """Build the command-line parser: EDF file, --type, --channel, --start."""
    parser = argparse.ArgumentParser(
        description='Traumschreiber ExG – Zeitbereichsdarstellung',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            'Signaltypen (jeweils + abschließend 30 Hz Tiefpass):\n'
            '  ekg  – EKG  (Bandpass 0,5–40 Hz  + Kerbfilter 50 Hz)\n'
            '  eeg  – EEG  (Bandpass 1–40 Hz    + Kerbfilter 50 Hz)\n'
            '  emg  – EMG  (Bandpass 20–120 Hz  + Kerbfilter 50 Hz)\n'
            '  eog  – EOG  (Bandpass 0,1–15 Hz  + Kerbfilter 50 Hz)\n'
            '\n'
            'Beispiele:\n'
            '  python plot_exg_timedomain.py ADC/AD0041.EDF\n'
            '  python plot_exg_timedomain.py ADC/AD0041.EDF --type eeg --channel 1\n'
            '  python plot_exg_timedomain.py ADC/AD0041.EDF --type emg --channel 3 --start 60\n'
            '  python plot_exg_timedomain.py ADC/AD0041.EDF --type eog --channel 0 --start 120'
        ),
    )
    parser.add_argument('file',
                        help='EDF-Datei (z.B. ADC/AD0041.EDF)')
    parser.add_argument('--type', '-t', dest='sigtype',
                        choices=list(PRESETS.keys()), default='ekg',
                        help='Signaltyp: ekg / eeg / emg / eog  (Standard: ekg)')
    parser.add_argument('--channel', '-c', type=int, default=0,
                        help='EDF-Kanal 0–7  (Standard: 0)')
    parser.add_argument('--start', '-s', type=float, default=0.0,
                        help=f'Startzeit in Sekunden  (Standard: 0). '
                             f'Es werden immer {WINDOW_S} s geplottet.')
    return parser


def _fail(message):
    """Print an error *message* to stderr and exit with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def main():
    """Command-line entry point.

    Loads one channel of an EDF file and filters a window of up to WINDOW_S
    seconds with the chosen signal-type preset. Shows three panels: raw,
    filtered, and a zoom of the window centre. On invalid input or read
    errors it prints a message to stderr and exits with status 1.
    """
    args = _build_parser().parse_args()

    # ── Dependency check ─────────────────────────────────────────────────────
    if not HAS_SCIPY:
        _fail('FEHLER: scipy ist nicht installiert.\n'
              'Bitte ausführen:  pip install numpy matplotlib scipy')

    start_s = float(args.start)
    # A negative start would become a negative slice index below and silently
    # select samples from the end of the recording; NaN/inf would crash int().
    if not np.isfinite(start_s) or start_s < 0.0:
        _fail(f'FEHLER: --start muss eine endliche Zahl ≥ 0 sein (erhalten: {args.start}).')

    preset = PRESETS[args.sigtype]

    # ── Load ─────────────────────────────────────────────────────────────────
    print(f'Loading     : {args.file}')
    print(f'Signal type : {preset["beschreibung"]}')
    print(f'Channel     : {args.channel}')
    print(f'Start time  : {start_s:.1f} s')

    try:
        sig_full, fs = read_edf_channel(args.file, args.channel)
    except FileNotFoundError:
        _fail(f'FEHLER: Datei nicht gefunden – {args.file}')
    except Exception as exc:  # top-level boundary: report any read error and exit
        _fail(f'FEHLER beim Lesen: {exc}')

    if not (fs > 0 and np.isfinite(fs)):
        _fail(f'FEHLER: ungültige Abtastrate ({fs} Hz) im EDF-Header.')

    total_s = len(sig_full) / fs
    print(f'Sample rate : {fs:.1f} Hz')
    print(f'Total length: {total_s:.1f} s  ({int(total_s // 60)} min {total_s % 60:.0f} s)')

    if start_s >= total_s:
        _fail(f'FEHLER: --start {start_s:.0f} s liegt hinter dem Signalende ({total_s:.0f} s).')
    end_s = min(start_s + WINDOW_S, total_s)
    if end_s - start_s < 1.0:
        _fail('FEHLER: Zeitfenster zu kurz (< 1 s).')

    # Sample index range of the window. i1 is rounded and clamped so that a
    # window ending at the signal end keeps its last sample despite float error.
    i0 = int(start_s * fs)
    i1 = min(int(round(end_s * fs)), len(sig_full))
    sig = sig_full[i0:i1]
    # Time of each sample derived from its index: t[k] = (i0 + k) / fs.
    t = (i0 + np.arange(len(sig))) / fs
    win_dauer = end_s - start_s
    print(f'Window      : {start_s:.1f}–{end_s:.1f} s  ({len(sig)} samples)\n')

    # ── Filter ───────────────────────────────────────────────────────────────
    sig_filt = apply_preset_filter(sig, fs, preset)

    # ── Zoom window (up to ZOOM_S seconds around the window centre) ──────────
    mid_s      = start_s + win_dauer / 2.0
    zoom_half  = ZOOM_S / 2.0
    zoom_start = max(start_s, mid_s - zoom_half)
    zoom_end   = min(end_s,   mid_s + zoom_half)
    zoom_len   = zoom_end - zoom_start   # shorter than ZOOM_S for short windows

    in_zoom  = (t >= zoom_start) & (t < zoom_end)
    t_zoom   = t[in_zoom]
    sig_zoom = sig_filt[in_zoom]

    # ── Figure ───────────────────────────────────────────────────────────────
    fig = plt.figure(figsize=(14, 9), facecolor=DARK)
    gs  = gridspec.GridSpec(
        3, 1, figure=fig,
        hspace=0.52,
        left=0.06, right=0.97, top=0.91, bottom=0.07,
    )

    ax_raw   = fig.add_subplot(gs[0])
    ax_filt  = fig.add_subplot(gs[1])
    ax_zoom  = fig.add_subplot(gs[2])

    lbl      = preset['label']
    farbe    = preset['color']
    ch_lbl   = f'Channel {args.channel}'
    win_lbl  = f'{start_s:.0f}–{end_s:.0f} s'
    # apply_preset_filter always finishes with a fixed 30 Hz low-pass (_lp30);
    # the label states it so the displayed pass band is not overstated.
    filt_lbl = (f'Bandpass {preset["hp_hz"]:g}–{preset["lp_hz"]:g} Hz'
                f' + Notch {preset["notch_hz"]:g} Hz + Lowpass 30 Hz')

    _plot_wave(
        ax_raw, t, sig,
        f'Raw Signal – {lbl}  |  {ch_lbl}',
        GREY,
    )
    _plot_wave(
        ax_filt, t, sig_filt,
        f'Filtered Signal – {lbl}  |  {filt_lbl}',
        farbe,
    )
    _plot_wave(
        ax_zoom, t_zoom, sig_zoom,
        f'Zoom ({zoom_len:g} s, window centre)  |  {zoom_start:.1f}–{zoom_end:.1f} s',
        farbe,
        xlabel=True,
    )

    ax_raw.set_xlim(start_s, end_s)
    ax_filt.set_xlim(start_s, end_s)
    ax_zoom.set_xlim(zoom_start, zoom_end)

    # Highlight the zoom range inside the filtered panel
    ax_filt.axvspan(zoom_start, zoom_end,
                    color='white', alpha=0.05, zorder=0)
    ax_filt.axvline(zoom_start, color='white', lw=0.6, ls=':', alpha=0.4)
    ax_filt.axvline(zoom_end,   color='white', lw=0.6, ls=':', alpha=0.4)

    fig.suptitle(
        f'Traumschreiber – {lbl} Time Domain  |  {ch_lbl}  |  {win_lbl}',
        fontsize=11, fontweight='bold', color='white',
    )

    # ── Console output ───────────────────────────────────────────────────────
    print('Panels:')
    print('  Panel 1: Raw signal (grey)')
    print(f'  Panel 2: Filtered signal ({filt_lbl})')
    print(f'  Panel 3: {zoom_len:g}-second zoom of window centre '
          f'({zoom_start:.1f}–{zoom_end:.1f} s)')
    print('\nClose window to exit.')

    plt.show()


if __name__ == '__main__':  # run the CLI only when executed as a script, not on import
    main()
