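"""Spike S1: read the Tundra MI_00 IMU stream with no SteamVR running.

Captures raw HID reports from VID 28DE PID 2300 interface MI_00, parses the
Watchman 0x20 IMU report (3 samples x {int16 accel[3], int16 gyro[3],
uint32 timecode, uint8 seq}), and prints stream statistics that decide the
spike pass criterion: continuous samples at the expected rate, sane gravity
magnitude, monotonic seq/timecode. The automatic VERDICT line checks only the
sample rate and seq continuity; gravity magnitude and timecode deltas are
reported in the statistics for manual review.

Outputs (in outdir): raw_reports.bin (one record per report: float64 host
time, uint32 length, report bytes) and, when IMU data arrived, stats.json and
samples.csv.

Usage: python capture.py [seconds] [outdir]
    seconds  capture duration (default 20)
    outdir   output directory (default: data/ next to this script)
"""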

import hid
import struct
import sys
import time
import json
from pathlib import Path

# Positional CLI arguments: capture length in seconds, then output directory.
if len(sys.argv) > 1:
    SECONDS = float(sys.argv[1])
else:
    SECONDS = 20.0
if len(sys.argv) > 2:
    OUTDIR = Path(sys.argv[2])
else:
    OUTDIR = Path(__file__).parent / "data"
# Create the output directory up front, at import time.
OUTDIR.mkdir(parents=True, exist_ok=True)


def find_imu_path():
    """Return the HID path of the first 28DE:2300 device on interface 0 (MI_00).

    Raises SystemExit with a FAIL message when no such interface is enumerated.
    """
    not_found = object()  # sentinel, so any stored path value is returned as-is
    matches = (
        info["path"]
        for info in hid.enumerate(0x28DE, 0x2300)
        if info.get("interface_number") == 0
    )
    imu_path = next(matches, not_found)
    if imu_path is not_found:
        raise SystemExit("FAIL: no 28DE:2300 MI_00 device found (Beyond unplugged?)")
    return imu_path


_IMU_REPORT_ID = 0x20
_SAMPLES_PER_REPORT = 3
# One IMU sample: int16 accel[3], int16 gyro[3], uint32 timecode, uint8 seq
# ("<" = little-endian, no padding -> 17 bytes).
_IMU_SAMPLE = struct.Struct("<6hIB")
# Report ID byte followed by three packed samples (52 bytes).
_IMU_REPORT_MIN_LEN = 1 + _SAMPLES_PER_REPORT * _IMU_SAMPLE.size
# raw_reports.bin record header: float64 host time, uint32 report length.
_RAW_HEADER = struct.Struct("<dI")


def _parse_imu_samples(report, host_t):
    """Decode the samples packed in one 0x20 IMU report.

    Returns ``(host_t, timecode, seq, ax, ay, az, gx, gy, gz)`` tuples; every
    sample of a report is stamped with that report's host arrival time.
    """
    out = []
    for i in range(_SAMPLES_PER_REPORT):
        ax, ay, az, gx, gy, gz, timecode, seq = _IMU_SAMPLE.unpack_from(
            report, 1 + _IMU_SAMPLE.size * i
        )
        out.append((host_t, timecode, seq, ax, ay, az, gx, gy, gz))
    return out


def _capture(path, seconds, raw_path):
    """Read HID reports from ``path`` for ``seconds`` and log them to ``raw_path``.

    Returns ``(samples, n_reports, n_other, other_ids, lengths)``.

    The device and the raw log are closed even when reading raises (HID read
    error, Ctrl-C), so the .bin file is flushed and the handle released.

    Raises SystemExit if the device cannot be opened.
    """
    dev = hid.device()
    try:
        dev.open_path(path)
    except OSError as err:
        raise SystemExit(
            f"FAIL: cannot open 28DE:2300 MI_00 ({err}); device busy (SteamVR running?)"
        ) from err

    samples = []  # (host_t, timecode, seq, ax, ay, az, gx, gy, gz)
    n_reports = 0
    n_other = 0
    other_ids = {}
    lengths = {}
    try:
        dev.set_nonblocking(False)
        with open(raw_path, "wb") as raw_f:
            t0 = time.perf_counter()
            while time.perf_counter() - t0 < seconds:
                data = dev.read(64, timeout_ms=1000)
                host_t = time.perf_counter() - t0
                if not data:
                    print(f"  [{host_t:7.3f}s] read timeout (1s, no data)")
                    continue
                b = bytes(data)
                raw_f.write(_RAW_HEADER.pack(host_t, len(b)) + b)
                lengths[len(b)] = lengths.get(len(b), 0) + 1
                if b[0] == _IMU_REPORT_ID and len(b) >= _IMU_REPORT_MIN_LEN:
                    n_reports += 1
                    samples.extend(_parse_imu_samples(b, host_t))
                else:
                    n_other += 1
                    other_ids[b[0]] = other_ids.get(b[0], 0) + 1
    finally:
        dev.close()
    return samples, n_reports, n_other, other_ids, lengths


def main():
    """Capture the MI_00 IMU stream, write outputs to OUTDIR, print a verdict.

    Returns the process exit status: 1 when no 0x20 IMU report arrived,
    otherwise 0 (for both PASS and MARGINAL verdicts).
    """
    path = find_imu_path()
    print(f"device: {path.decode()}")

    samples, n_reports, n_other, other_ids, lengths = _capture(
        path, SECONDS, OUTDIR / "raw_reports.bin"
    )
    if not samples:
        print("FAIL: zero 0x20 IMU reports received")
        return 1

    n = len(samples)
    pairs = list(zip(samples, samples[1:]))

    # Host-side span from the first to the last IMU report arrival. All samples
    # of a report share its arrival time, so this span contains n_reports - 1
    # report intervals; dividing n_reports by it would overstate the rate.
    dur = samples[-1][0] - samples[0][0]
    if dur > 0:
        report_rate = (n_reports - 1) / dur
        sample_rate = report_rate * _SAMPLES_PER_REPORT
    else:
        report_rate = sample_rate = None

    # seq continuity (uint8 wrap): consecutive samples should differ by 1.
    seq_gaps = sum(1 for prev, cur in pairs if (cur[2] - prev[2]) & 0xFF != 1)

    # Timecode deltas between consecutive samples (uint32 wrap-safe). The upper
    # median is used so the reported value stays an integer tick count.
    tc_deltas = sorted((cur[1] - prev[1]) & 0xFFFFFFFF for prev, cur in pairs)
    tc_median = tc_deltas[len(tc_deltas) // 2] if tc_deltas else None

    # Accel / gyro stats; only meaningful if the device was at rest.
    mags = [(s[3] ** 2 + s[4] ** 2 + s[5] ** 2) ** 0.5 for s in samples]
    mean_mag = sum(mags) / n
    gyro_means = [sum(s[6 + k] for s in samples) / n for k in range(3)]
    gyro_max = max(abs(s[6 + k]) for s in samples for k in range(3))

    stats = {
        "capture_seconds": round(dur, 3),
        "imu_reports": n_reports,
        "imu_samples": n,
        "sample_rate_hz": round(sample_rate, 1) if sample_rate is not None else None,
        "report_rate_hz": round(report_rate, 1) if report_rate is not None else None,
        "report_lengths_seen": lengths,
        "non_imu_reports": n_other,
        "non_imu_report_ids": {hex(k): v for k, v in other_ids.items()},
        "seq_gaps": seq_gaps,
        "timecode_median_delta": tc_median,
        "timecode_implied_tick_hz": (
            round(sample_rate * tc_median, 0)
            if sample_rate is not None and tc_median is not None
            else None
        ),
        "accel_mean_magnitude_lsb": round(mean_mag, 1),
        # At rest the accel magnitude is 1 g, so the mean magnitude is LSB/g.
        "accel_implied_lsb_per_g": round(mean_mag, 1),
        "gyro_mean_lsb": [round(g, 2) for g in gyro_means],
        "gyro_max_abs_lsb": gyro_max,
        "first_sample_example": samples[0],
    }
    (OUTDIR / "stats.json").write_text(json.dumps(stats, indent=2))

    with open(OUTDIR / "samples.csv", "w") as f:
        f.write("host_t,timecode,seq,ax,ay,az,gx,gy,gz\n")
        f.writelines(",".join(str(x) for x in s) + "\n" for s in samples)

    print(json.dumps(stats, indent=2))
    passed = sample_rate is not None and sample_rate > 500 and seq_gaps < n * 0.01
    verdict = "PASS" if passed else "MARGINAL"
    print(f"VERDICT: {verdict}  ({n} samples in {dur:.1f}s)")
    return 0


if __name__ == "__main__":
    # main()'s return value (0, or 1 when no IMU data arrived) becomes the exit status.
    raise SystemExit(main())
