#!/usr/bin/env python3
"""
Traumschreiber EDF – EKG Herzschlag-Detektion & HRV
====================================================
Erkennt automatisch Herzschläge (R-Zacken) in einem EKG-Kanal und
berechnet grundlegende Herzratenvariabilität (HRV).

Metriken:
  Herzrate  – Mittlere Schläge pro Minute (bpm)
  SDNN      – Standardabweichung der RR-Intervalle (Gesamtvariabilität)
  RMSSD     – Quadratisches Mittel aufeinanderfolgender RR-Differenzen
  pNN50     – Anteil konsekutiver RR-Differenzen > 50 ms

Layout (4 Panels):
  1. EKG-Signal (30 s) mit markierten R-Zacken
  2. 8-Sekunden-Zoom aus der Fenstermitte mit einzelnen Herzschlägen und markierten R-Zacken
  3. RR-Tachogramm (Herzschlag-zu-Herzschlag-Intervalle über Zeit)
  4. HRV-Metriken als Textübersicht

Verwendung:
  python plot_ecg_hrv.py DATEI.EDF
  python plot_ecg_hrv.py DATEI.EDF --channel 0 --start 30
  python plot_ecg_hrv.py DATEI.EDF --channel 5 --start 0

Bibliotheken installieren (einmalig):
  pip install numpy matplotlib scipy
"""

import argparse
import struct
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

try:
    from scipy.signal import butter, sosfilt, find_peaks
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

# Length of the analysed signal window, in seconds.
WINDOW_S = 30
# Length of the zoom panel, in seconds; main() centres it in the window when possible.
ZOOM_S = 8


# ── EDF reader ────────────────────────────────────────────────────────────────

def read_edf_channel(filepath, channel_idx):
    """Read one channel (0-7) from a Traumschreiber EDF file.

    Parses the 256-byte EDF main header and the per-signal header, then
    extracts the 16-bit little-endian samples of ``channel_idx`` from every
    complete data record. No digital-to-physical scaling is applied; the
    returned values are the raw digital sample values.

    Args:
        filepath: Path to the EDF file.
        channel_idx: Zero-based signal index; must be < ``min(num_signals, 8)``.

    Returns:
        ``(samples, sample_rate)``: a 1-D float64 array and the sampling rate
        of the requested channel in Hz.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
        ValueError: if the file is too small, the channel index is out of
            range, or the header reports zero/negative samples per record.
    """
    raw = Path(filepath).read_bytes()
    if len(raw) < 256:
        raise ValueError("Datei zu klein – kein gültiges EDF-Format.")
    hdr = raw[:256].decode('latin-1')

    def _i(s, default=0):
        try:
            return int(s.strip())
        except ValueError:
            return default

    def _f(s, default=1.0):
        try:
            return float(s.strip())
        except ValueError:
            return default

    # Fallbacks for blank/garbled fields. 3072 = 256 * (1 + 11), i.e. a header
    # matching the 11-signal default. An unreadable record count is treated
    # as -1 ("unknown"), so all complete records in the file are read.
    header_bytes    = _i(hdr[184:192]) or 3072
    num_signals     = _i(hdr[252:256]) or 11
    num_records     = _i(hdr[236:244], default=-1)
    record_duration = _f(hdr[244:252])

    n_channels = min(num_signals, 8)
    if not (0 <= channel_idx < n_channels):
        raise ValueError(
            f"Kanal {channel_idx} ungültig. Verfügbare Kanäle: 0–{n_channels - 1}."
        )

    # The signal header is stored field-major: every field is repeated once
    # per signal before the next field starts.
    sig_hdr = raw[256:header_bytes].decode('latin-1')
    field_widths = [16, 80, 8, 8, 8, 8, 8, 80, 8, 32]
    field_names  = ['label', 'transducer', 'dim', 'phys_min', 'phys_max',
                    'dig_min', 'dig_max', 'prefilter', 'num_samples', 'reserved']
    sigs = [{} for _ in range(num_signals)]
    off = 0
    for fw, fn in zip(field_widths, field_names):
        for i in range(num_signals):
            sigs[i][fn] = sig_hdr[off:off + fw].strip()
            off += fw

    # Samples per data record, per signal (0 when the field is unparseable).
    ns = [_i(s['num_samples']) for s in sigs]
    if any(n < 0 for n in ns):
        raise ValueError("EDF-Header meldet negative Samples/Record – Datei beschädigt?")

    record_bytes = 2 * sum(ns)          # 2 bytes per int16 sample
    if record_bytes == 0:
        raise ValueError("EDF-Header meldet 0 Samples/Record – Datei beschädigt?")

    n_chan = ns[channel_idx]
    if n_chan == 0:
        raise ValueError(f"Kanal {channel_idx} enthält keine Samples (0 Samples/Record).")

    # Use the rate of the channel actually read; signals may differ in rate.
    sample_rate = n_chan / record_duration if record_duration > 0 else 250.0

    # Only whole records are used; a truncated trailing record is ignored.
    n_avail = max(0, (len(raw) - header_bytes) // record_bytes)
    n_read = n_avail if num_records < 0 else min(num_records, n_avail)
    if n_read == 0:
        return np.empty(0, dtype=np.float64), float(sample_rate)

    words_per_record = record_bytes // 2
    first = sum(ns[:channel_idx])       # int16 offset of this channel in a record
    records = np.frombuffer(raw, dtype='<i2', count=n_read * words_per_record,
                            offset=header_bytes).reshape(n_read, words_per_record)
    samples = records[:, first:first + n_chan].astype(np.float64).ravel()
    return samples, float(sample_rate)


# ── Filter ────────────────────────────────────────────────────────────────────

def _bandpass(signal, fs, hp=0.5, lp=40.0, order=5):
    nyq = 0.5 * fs
    lo  = hp / nyq
    hi  = min(lp, nyq * 0.99) / nyq
    if lo <= 0 or lo >= 1 or hi <= 0 or hi >= 1 or lo >= hi:
        return signal.copy()
    sos = butter(order, [lo, hi], btype='band', output='sos')
    return sosfilt(sos, signal)


def _notch(signal, fs, freq=50.0, bw=2.0, order=4):
    nyq = 0.5 * fs
    lo  = (freq - bw) / nyq
    hi  = (freq + bw) / nyq
    if lo <= 0 or hi >= 1:
        return signal.copy()
    sos = butter(order, [lo, hi], btype='bandstop', output='sos')
    return sosfilt(sos, signal)


def _lp30(signal, fs, order=5):
    """Butterworth 30 Hz Tiefpassfilter, 5. Ordnung, sosfilt."""
    nyq = 0.5 * fs
    cut = min(30.0, nyq * 0.99) / nyq
    sos = butter(order, cut, btype='low', output='sos')
    return sosfilt(sos, signal)


def ecg_filter(signal, fs):
    """Run the ECG filter chain on ``signal`` sampled at ``fs`` Hz.

    Stages, applied in this order: 0.5–40 Hz band-pass (``_bandpass``),
    50 Hz band-stop (``_notch``), 30 Hz low-pass (``_lp30``).
    """
    filtered = signal
    for stage in (_bandpass, _notch, _lp30):
        filtered = stage(filtered, fs)
    return filtered


# ── R-Zacken-Detektion ────────────────────────────────────────────────────────

def detect_r_peaks(signal, fs):
    """Detect R peaks by amplitude thresholding with scipy ``find_peaks``.

    Attempts, in order (the first one yielding >= 3 peaks wins):
      1. peaks of ``signal`` above its 75th percentile;
      2. peaks of ``signal`` above its 60th percentile;
      3. peaks of ``-signal`` above its 75th percentile (inverted R waves).
    If all attempts yield fewer than 3 peaks, the result of attempt 3 is
    returned. Peaks must be at least 350 ms apart (max. ~171 bpm).

    Args:
        signal: 1-D (filtered) ECG samples.
        fs: Sampling rate in Hz.

    Returns:
        Integer array of peak sample indices; empty for an empty signal.
    """
    signal = np.asarray(signal, dtype=float)
    if signal.size == 0:
        # np.percentile raises on empty input; there is nothing to detect.
        return np.empty(0, dtype=np.intp)

    # find_peaks requires distance >= 1; int(0.35 * fs) is 0 for fs < ~2.9 Hz.
    min_dist = max(1, int(0.35 * fs))

    def _peaks_above(x, pct):
        threshold = float(np.percentile(x, pct))
        return find_peaks(x, height=threshold, distance=min_dist)[0]

    peaks = _peaks_above(signal, 75)
    if len(peaks) < 3:
        peaks = _peaks_above(signal, 60)
    if len(peaks) < 3:
        peaks = _peaks_above(-signal, 75)
    return peaks


# ── HRV-Metriken ──────────────────────────────────────────────────────────────

def compute_hrv(peaks, fs):
    """Derive basic time-domain HRV metrics from R-peak sample indices.

    Args:
        peaks: Sample indices of consecutive R peaks.
        fs: Sampling rate in Hz.

    Returns:
        ``None`` for fewer than 3 peaks; otherwise a dict with ``rr_ms``
        (RR intervals in ms), ``mean_hr`` (bpm), ``mean_rr`` (ms), ``sdnn``
        (sample standard deviation of RR, ms), ``rmssd`` (ms), ``pnn50``
        (% of successive RR differences > 50 ms) and ``n_beats``.
    """
    n_beats = len(peaks)
    if n_beats < 3:
        return None

    intervals = 1000.0 * (np.diff(peaks).astype(float) / fs)
    # With >= 3 peaks there are >= 2 intervals, hence >= 1 successive difference.
    successive = np.diff(intervals)
    mean_rr = float(intervals.mean())

    return dict(
        rr_ms=intervals,
        mean_hr=60000.0 / mean_rr,
        mean_rr=mean_rr,
        sdnn=float(intervals.std(ddof=1)),
        rmssd=float(np.sqrt(np.mean(np.square(successive)))),
        pnn50=float(100.0 * np.mean(np.abs(successive) > 50.0)),
        n_beats=n_beats,
    )


# ── Plot-Stil ─────────────────────────────────────────────────────────────────

# Dark colour palette shared by all plot panels.
DARK   = '#0d1c2b'   # figure and axes background
GRID   = '#1a2f42'   # grid lines
TEXT   = '#aabbcc'   # tick labels and axis labels
ORANGE = '#ff9151'   # ECG trace
RED    = '#ff4466'   # R-peak markers
GREEN  = '#7be0ad'   # RR tachogram
YELLOW = '#ffe851'   # mean-value line
BLUE   = '#abd5ff'   # metrics
GREY   = '#7a9ab0'   # muted / secondary text


def _style(ax):
    """Apply the shared dark theme to a matplotlib Axes, in place."""
    ax.set_facecolor(DARK)
    for spine in ax.spines.values():
        spine.set_color('#2a3f52')
    ax.tick_params(colors=TEXT, labelsize=7)
    for axis in (ax.yaxis, ax.xaxis):
        axis.label.set_color(TEXT)
    ax.title.set_color('white')
    ax.grid(True, linestyle='--', linewidth=0.35, color=GRID, alpha=0.8)


def _plot_ecg(ax, t, sig, peaks, title, lw=0.6, ms=30):
    """Draw an ECG trace and its detected R peaks into ``ax``.

    Every peak gets a faint full-height vertical line plus a scatter marker.
    Y tick labels are hidden; the x range is fitted to ``t``.

    Args:
        ax: Target matplotlib Axes.
        t: Sample times in seconds (same length as ``sig``).
        sig: Signal samples.
        peaks: Indices into ``t``/``sig`` of the R peaks.
        title: Panel title.
        lw: Line width of the trace.
        ms: Marker size for the peak scatter.
    """
    t = np.asarray(t)
    sig = np.asarray(sig)
    peaks = np.asarray(peaks, dtype=int)   # int dtype so an empty list still indexes

    ax.plot(t, sig, '-', lw=lw, color=ORANGE, rasterized=True)
    if peaks.size:
        for pt in t[peaks]:
            ax.axvline(pt, color=RED, alpha=0.15, lw=0.7)
        ax.scatter(t[peaks], sig[peaks], color=RED, s=ms, zorder=5)
    ax.set_title(title, fontsize=9, fontweight='bold', pad=4, color='white')
    ax.set_yticks([])
    ax.set_ylabel('Amplitude', fontsize=7)
    ax.set_xlabel('Time (s)', fontsize=7)
    # t[0] raises on an empty window, and left == right (one sample) makes
    # matplotlib warn about a singular range; only fit limits when meaningful.
    if len(t) > 1:
        ax.set_xlim(t[0], t[-1])
    _style(ax)


# ── Main ──────────────────────────────────────────────────────────────────────

def _parse_args():
    """Parse the command line: EDF file, ``--channel`` and ``--start``."""
    parser = argparse.ArgumentParser(
        description='Traumschreiber EDF – EKG Herzschlag-Detektion & HRV',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            'Beispiele:\n'
            '  python plot_ecg_hrv.py ADC/AD0041.EDF\n'
            '  python plot_ecg_hrv.py ADC/AD0041.EDF --channel 0 --start 30\n'
            '  python plot_ecg_hrv.py ADC/AD0041.EDF --channel 5'
        ),
    )
    parser.add_argument('file',
                        help='EDF-Datei (z.B. ADC/AD0041.EDF)')
    parser.add_argument('--channel', '-c', type=int, default=0,
                        help='Kanal 0–7 (Standard: 0 = Ch1)')
    parser.add_argument('--start',   '-s', type=float, default=0.0,
                        help='Startzeit in Sekunden (Standard: 0)')
    return parser.parse_args()


def _fail(message):
    """Print ``message`` to stderr and terminate with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _print_summary(peaks, hrv):
    """Print the beat count and, when available, the HRV metrics."""
    print(f'Erkannte Herzschläge: {len(peaks)}')
    if hrv:
        print(f'  Herzrate : {hrv["mean_hr"]:.1f} bpm')
        print(f'  RR-Mittel: {hrv["mean_rr"]:.1f} ms')
        print(f'  SDNN     : {hrv["sdnn"]:.1f} ms')
        print(f'  RMSSD    : {hrv["rmssd"]:.1f} ms')
        print(f'  pNN50    : {hrv["pnn50"]:.1f} %')
    else:
        print('  Zu wenige Herzschläge. Tipp: --channel anpassen.')


def _plot_tachogram(ax, t, peaks, hrv):
    """Plot each RR interval at the time of the beat that ends it, plus its mean."""
    _style(ax)
    ax.set_title('RR Tachogram  (beat-to-beat intervals)', fontsize=9,
                 fontweight='bold', pad=4, color='white')
    ax.set_ylabel('RR Interval (ms)', fontsize=7)
    ax.set_xlabel('Time (s)', fontsize=7)
    if not hrv or len(hrv['rr_ms']) == 0:
        ax.text(0.5, 0.5, 'Too few peaks', ha='center', va='center',
                color=GREY, fontsize=9, transform=ax.transAxes)
        return
    rr_t = t[peaks[1:]]   # same time base as the ECG panels
    ax.plot(rr_t, hrv['rr_ms'], 'o-', color=GREEN, lw=1.0, ms=3.5, zorder=3)
    ax.axhline(hrv['mean_rr'], color=YELLOW, lw=0.9, ls='--', alpha=0.75)
    ax.text(rr_t[0] + 0.3, hrv['mean_rr'] + 5,
            f'mean {hrv["mean_rr"]:.0f} ms', fontsize=6.5, color=YELLOW, alpha=0.9)


def _plot_metrics(ax, hrv, duration_s):
    """Render the HRV metrics as a text table, or a hint when ``hrv`` is falsy."""
    ax.set_facecolor(DARK)
    for sp in ax.spines.values():
        sp.set_color('#2a3f52')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title('HRV Metrics', fontsize=9, fontweight='bold', pad=4, color='white')

    if not hrv:
        ax.text(0.5, 0.5,
                'Too few heartbeats\nfor HRV computation\n\n'
                'Tip: try a different channel\n(--channel 5 or 7)',
                ha='center', va='center', color=GREY, fontsize=9,
                linespacing=1.8, transform=ax.transAxes)
        return

    rows = [
        ('Heart Rate', f'{hrv["mean_hr"]:.1f} bpm', 'Mean beats per minute',    ORANGE),
        ('SDNN',       f'{hrv["sdnn"]:.1f} ms',      'Overall variability',       GREEN),
        ('RMSSD',      f'{hrv["rmssd"]:.1f} ms',     'Short-term variability',    BLUE),
        ('pNN50',      f'{hrv["pnn50"]:.1f} %',      'RR diff. > 50 ms',         YELLOW),
        ('Beats',      str(hrv['n_beats']),           f'in {duration_s:.0f} s', GREY),
    ]
    for k, (name, value, desc, color) in enumerate(rows):
        y  = 0.84 - k * 0.17      # row position in axes coordinates
        yd = y - 0.065            # description line just below the value
        ax.text(0.06, y,  name,  fontsize=8,    fontweight='800', color=color,
                va='center', transform=ax.transAxes)
        ax.text(0.42, y,  value, fontsize=11,   fontweight='900', color='white',
                va='center', transform=ax.transAxes)
        ax.text(0.42, yd, desc,  fontsize=6.5,  color=GREY, alpha=0.85,
                va='center', transform=ax.transAxes)


def main():
    """Run the command-line entry point.

    Loads one EDF channel and filters a ``WINDOW_S``-second window starting
    at ``--start``. It then detects R peaks, prints HRV metrics, and shows a
    four-panel figure: full window, centred zoom, RR tachogram and metrics.
    On input errors it prints a message to stderr and exits with status 1.
    """
    args = _parse_args()

    if not HAS_SCIPY:
        _fail('FEHLER: scipy ist nicht installiert.\n'
              'Bitte ausführen:  pip install numpy matplotlib scipy')

    print(f'Lade       : {args.file}  |  Kanal {args.channel}  |  Start {args.start:.1f} s')

    try:
        sig_full, fs = read_edf_channel(args.file, args.channel)
    except FileNotFoundError:
        _fail(f'Datei nicht gefunden: {args.file}')
    except Exception as exc:  # top-level CLI boundary: report instead of traceback
        _fail(f'Fehler beim Lesen: {exc}')

    # Guard the divisions by fs below (e.g. a channel with 0 samples/record).
    if not (fs > 0) or len(sig_full) == 0:
        _fail(f'Fehler: Kanal {args.channel} enthält keine auswertbaren Samples '
              f'(Abtastrate {fs:.1f} Hz).')

    total_s = len(sig_full) / fs
    print(f'Abtastrate : {fs:.1f} Hz')
    print(f'Gesamtdauer: {total_s:.1f} s  ({int(total_s // 60)} min {total_s % 60:.0f} s)')

    start_s = float(args.start)
    # A negative start would become a negative slice index below and silently
    # select samples counted from the END of the recording; `not (x >= 0)`
    # also rejects NaN.
    if not (start_s >= 0.0):
        _fail(f'Fehler: --start muss ≥ 0 s sein (erhalten: {args.start}).')
    if start_s >= total_s:
        _fail(f'Fehler: --start {start_s:.0f} s liegt hinter dem Signalende ({total_s:.0f} s).')
    end_s = min(start_s + WINDOW_S, total_s)

    i0 = int(start_s * fs)
    i1 = int(end_s   * fs)
    if i1 - i0 < 2:
        _fail(f'Fehler: Fenster {start_s:.1f}–{end_s:.1f} s enthält zu wenige Samples.')
    sig_filt = ecg_filter(sig_full[i0:i1], fs)
    # Exact time stamp of each sample in the slice; the ECG markers (t[peaks])
    # and the tachogram both use this single time base.
    t = (i0 + np.arange(len(sig_filt))) / fs

    print(f'Fenster    : {start_s:.1f}–{end_s:.1f} s  ({len(sig_filt)} Samples)\n')

    # ── Peak detection & HRV ──
    peaks = detect_r_peaks(sig_filt, fs)
    hrv   = compute_hrv(peaks, fs)
    _print_summary(peaks, hrv)

    # ── Zoom: ZOOM_S seconds centred in the window (clipped to the window) ──
    mid        = (start_s + end_s) / 2.0
    zoom_start = max(start_s, mid - ZOOM_S / 2.0)
    zoom_end   = min(end_s,   zoom_start + ZOOM_S)
    zi0     = int((zoom_start - start_s) * fs)
    zi1     = int((zoom_end   - start_s) * fs)
    t_z     = t[zi0:zi1]
    sig_z   = sig_filt[zi0:zi1]
    peaks_z = peaks[(peaks >= zi0) & (peaks < zi1)] - zi0

    # ── Figure ──
    fig = plt.figure(figsize=(14, 12), facecolor=DARK)
    gs  = gridspec.GridSpec(
        3, 2, figure=fig,
        hspace=0.52, wspace=0.25,
        left=0.07, right=0.97, top=0.94, bottom=0.06,
        height_ratios=[1.1, 1.1, 1.2],
    )
    ax_ecg  = fig.add_subplot(gs[0, :])
    ax_zoom = fig.add_subplot(gs[1, :])
    ax_rr   = fig.add_subplot(gs[2, 0])
    ax_hrv  = fig.add_subplot(gs[2, 1])

    ch_lbl  = f'Channel {args.channel} (Ch{args.channel + 1})'
    win_lbl = f'{start_s:.0f}–{end_s:.0f} s'

    _plot_ecg(ax_ecg, t, sig_filt, peaks,
              f'ECG – {ch_lbl}  ({len(peaks)} detected heartbeats)')

    _plot_ecg(ax_zoom, t_z, sig_z, peaks_z,
              f'Zoom  {zoom_start:.0f}–{zoom_end:.0f} s  –  individual heartbeats',
              lw=0.9, ms=55)

    _plot_tachogram(ax_rr, t, peaks, hrv)
    _plot_metrics(ax_hrv, hrv, end_s - start_s)

    fig.suptitle(
        f'Traumschreiber – ECG Heartbeat Detection & HRV  |  {ch_lbl}  |  {win_lbl}',
        fontsize=11, fontweight='bold', color='white',
    )

    print('\nClose window to exit.')
    plt.show()


if __name__ == '__main__':
    # Launch the CLI only when executed as a script, not when imported.
    main()
