"""M2 gate 2/3: solve gyro scale + axis mapping against SteamVR ground truth.

No jig, no exact rotations. With SteamVR tracking the Beyond (base stations
required), wave the headset smoothly through all three axes for ~60 s. The
tool reads the raw Watchman IMU stream directly (S1 path) while polling the
SteamVR HMD pose, then solves the least-squares model

    delta_rotation_vr(bin) = A @ integral(raw_gyro, bin) + b * dt(bin)

over ~0.2 s bins. A (3x3) factors into per-axis scale (column norms) and
the IMU->head axis mapping (column directions), which is checked against
the config imu block (plus_x / plus_z). b recovers the rest bias. This
closes gate 2 (scale, target ~2%) and the gate-3 axis/sense check in one
capture.

VR ground truth: delta rotations between pose samples (log map of
R0^T @ R1 = body-frame integrated angular velocity), bins anchored exactly
at pose timestamps. IMU/VR clock offset is estimated by cross-correlating
angular-speed envelopes and applied before binning.

Usage:
  python gyro_cal_vr.py [--seconds 60] [--config path/to/config.json]
                        [--bin-poses 50] [--save capture.npz]
                        [--replay capture.npz] [--selftest]

Requires: SteamVR running and tracking the headset (not needed for --replay
or --selftest). pip: hid, openvr, numpy.
"""

import argparse
import json
import struct
import sys
import threading
import time

import numpy as np

# Device timecode rate; solve() divides timecode deltas by this to get
# per-sample integration dt in seconds.
TIMECODE_HZ = 48e6  # S1: 48 MHz device clock, uint32 wrap ~89.5 s


# ---------------------------------------------------------------- IMU reader

class ImuReader:
    """Collect gyro samples from Watchman 0x20 IMU reports on a daemon thread.

    Each 0x20 report holds three 17-byte samples starting at byte 1: six
    little-endian int16 (the last three are gyro x/y/z; the first three are
    ignored here), a uint32 device timecode and a uint8 sequence number.
    Reports form a sliding window, so samples are deduplicated (S1
    conventions): a sample with the same seq and timecode as the last
    accepted one is a window overlap, and one whose seq is behind the last
    accepted seq (mod 256) is stale.

    Start with ``thread.start()``; set ``stop = True`` to end the loop (it is
    checked before each read, and reads time out after 1 s). Accepted samples
    accumulate in ``samples`` as ``(host_t, timecode, gx, gy, gz)``, where
    host_t is ``time.perf_counter()`` at report arrival.
    """

    def __init__(self, dev):
        self.dev = dev  # open HID device; only read(n, timeout_ms=...) is used
        self.samples = []  # (host_t, timecode, gx, gy, gz)
        self.stop = False  # set True from another thread to end _run
        self.thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        prev = None  # (timecode, seq) of the last accepted sample
        while not self.stop:
            report = self.dev.read(64, timeout_ms=1000)
            arrival = time.perf_counter()
            if not report:
                print("  [imu] 1 s read timeout -- stream stalled?")
                continue
            raw = bytes(report)
            if raw[0] != 0x20 or len(raw) < 52:
                continue
            for offset in (1, 18, 35):
                gx, gy, gz = struct.unpack_from("<6h", raw, offset)[3:]
                (timecode,) = struct.unpack_from("<I", raw, offset + 12)
                seq = raw[offset + 16]
                if prev is not None:
                    prev_tc, prev_seq = prev
                    ahead = (seq - prev_seq) & 0xFF
                    if (ahead == 0 and timecode == prev_tc) or ahead >= 128:
                        continue  # window overlap, or stale (behind last)
                prev = (timecode, seq)
                self.samples.append((arrival, timecode, gx, gy, gz))


def open_imu_matching_serial(want_serial):
    """Open the 28DE:2300 MI_00 HID device whose serial matches *want_serial*.

    Matching on the SteamVR HMD serial disambiguates the headset from desk
    Watchman devices (S1 field hazard: a wired Index controller enumerates
    first on enlyzeam). The comparison is case-insensitive. Candidates that
    fail to open, or that report no serial, are skipped instead of aborting
    the search; every non-matching device that was opened is closed again.

    Returns:
        (dev, serial): the open device (switched to blocking reads) and its
        HID serial string.

    Raises:
        SystemExit: no MI_00 candidate exists, or none matches.
    """
    import hid

    candidates = [d for d in hid.enumerate(0x28DE, 0x2300)
                  if d.get("interface_number") == 0]
    if not candidates:
        raise SystemExit("FAIL: no 28DE:2300 MI_00 device (Beyond unplugged?)")
    want = (want_serial or "").lower()
    seen = []
    for c in candidates:
        dev = hid.device()
        try:
            dev.open_path(c["path"])
        except OSError as e:
            # e.g. device held exclusively by another process; try the rest.
            seen.append(f"<open failed: {e}>")
            continue
        matched = False
        try:
            serial = dev.get_serial_number_string()
            seen.append(serial)
            # Some devices report no serial (None); never call .lower() on it.
            if serial and serial.lower() == want:
                dev.set_nonblocking(False)
                matched = True
                return dev, serial
        finally:
            if not matched:
                dev.close()  # release every device we are not returning
    raise SystemExit(f"FAIL: no Watchman HID serial matches SteamVR HMD "
                     f"{want_serial}; saw {seen}")


# ---------------------------------------------------------------- math

def rot_from_pose(m):
    """Return the upper-left 3x3 (rotation) part of pose matrix *m*.

    *m* only needs to support ``m[row][col]`` indexing, e.g. an OpenVR
    HmdMatrix34_t or a nested list.
    """
    rows = []
    for r in range(3):
        row = m[r]
        rows.append([row[0], row[1], row[2]])
    return np.array(rows)


def log_so3(R):
    """Rotation vector (unit axis * angle, radians) of rotation matrix R.

    Returns zeros for (near-)identity. Near angle pi the skew-symmetric
    formula degenerates (sin(angle) -> 0). There the axis is taken from the
    symmetric part (R + R^T)/2 + I)/2 ~= k k^T, using its largest-diagonal
    column, which keeps the relative signs of the axis components. The
    overall sign comes from the skew part v = 2 sin(angle) k.
    """
    R = np.asarray(R, dtype=float)
    c = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    angle = np.arccos(c)
    if angle < 1e-9:
        return np.zeros(3)
    v = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    if angle > np.pi - 1e-3:
        # (1+cos)/2 * I + (1-cos)/2 * k k^T  -> k k^T as angle -> pi.
        S = ((R + R.T) / 2.0 + np.eye(3)) / 2.0
        j = int(np.argmax(np.diag(S)))
        axis = S[:, j] / np.sqrt(max(S[j, j], 1e-300))  # = k * sign(k_j)
        axis /= np.linalg.norm(axis) or 1.0
        if np.dot(axis, v) < 0:  # v is parallel to +k whenever angle < pi
            axis = -axis
        return axis * angle
    return v * (angle / (2.0 * np.sin(angle)))


def estimate_lag(imu_t, imu_speed, vr_t, vr_speed, max_lag=0.25):
    """Cross-correlate angular-speed envelopes on a 500 Hz grid; returns
    seconds to ADD to imu_t to align with vr_t."""
    t0 = max(imu_t[0], vr_t[0])
    t1 = min(imu_t[-1], vr_t[-1])
    step = 0.002
    grid = np.arange(t0, t1, step)
    a = np.interp(grid, imu_t, imu_speed)
    b = np.interp(grid, vr_t, vr_speed)
    a = a - a.mean()
    b = b - b.mean()
    n = int(max_lag / step)
    scores = np.full(2 * n + 1, -np.inf)
    for j, k in enumerate(range(-n, n + 1)):
        if k >= 0:
            scores[j] = np.dot(a[: len(a) - k or None], b[k:])
        else:
            scores[j] = np.dot(a[-k:], b[: len(b) + k])
    j = int(np.argmax(scores))
    lag = (j - n) * step
    # Parabolic peak refinement (sub-step).
    if 0 < j < len(scores) - 1:
        y0, y1, y2 = scores[j - 1], scores[j], scores[j + 1]
        denom = y0 - 2 * y1 + y2
        if denom < 0:
            lag += 0.5 * (y0 - y2) / denom * step
    return lag


# ---------------------------------------------------------------- capture

def capture(seconds):
    """Record raw IMU samples and SteamVR HMD poses for *seconds* seconds.

    Initialises OpenVR as a background application and opens the Watchman
    HID device whose serial matches the SteamVR HMD. It then starts an
    ImuReader thread and polls the HMD pose roughly every 4 ms. Only poses
    that are valid with TrackingResult_Running_OK are kept; the rest are
    counted as invalid.

    The reader thread, HID device and OpenVR session are released in a
    ``finally`` block, so an exception or Ctrl-C during capture does not
    leave them open. If the reader thread fails to stop, the HID device is
    left open rather than closed underneath a pending read.

    Returns:
        (imu_samples, poses, vr_serial):
        - imu_samples: ``(host_t, timecode, gx, gy, gz)`` tuples.
        - poses: ``(host_t, R)`` pairs with R the 3x3 rotation from
          mDeviceToAbsoluteTracking.
        - vr_serial: the HMD serial string.

    Raises:
        SystemExit: OpenVR init failed or no matching HID device was found.
    """
    import openvr

    try:
        openvr.init(openvr.VRApplication_Background)
    except openvr.error_code.InitError as e:
        raise SystemExit(f"OpenVR init failed: {e}\n"
                         "Start SteamVR with the Beyond tracking, then re-run.")
    dev = None
    reader = None
    try:
        vrsys = openvr.VRSystem()
        hmd = openvr.k_unTrackedDeviceIndex_Hmd
        vr_serial = vrsys.getStringTrackedDeviceProperty(
            hmd, openvr.Prop_SerialNumber_String)
        print(f"SteamVR HMD: {vr_serial}")

        dev, serial = open_imu_matching_serial(vr_serial)
        print(f"IMU HID: {serial} (matched)")

        reader = ImuReader(dev)
        reader.thread.start()

        print(f"\ncapturing {seconds:.0f} s -- rotate the headset SMOOTHLY "
              "through all three axes\n(yaw, pitch, roll; vary speed; keep "
              "it inside tracking coverage; brief rests are fine)\n")

        poses_out = []  # (host_t, R)
        invalid = 0
        t_end = time.perf_counter() + seconds
        last_note = time.perf_counter()
        while time.perf_counter() < t_end:
            poses = vrsys.getDeviceToAbsoluteTrackingPose(
                openvr.TrackingUniverseStanding, 0.0, hmd + 1)
            p = poses[hmd]
            host_t = time.perf_counter()
            if (p.bPoseIsValid
                    and p.eTrackingResult == openvr.TrackingResult_Running_OK):
                poses_out.append((host_t,
                                  rot_from_pose(p.mDeviceToAbsoluteTracking)))
            else:
                invalid += 1
            if host_t - last_note > 5.0:
                last_note = host_t
                print(f"  {t_end - host_t:5.1f} s left  | poses "
                      f"{len(poses_out)} (invalid {invalid}) | imu samples "
                      f"{len(reader.samples)}")
            time.sleep(0.004)  # ~250 Hz
    finally:
        reader_alive = False
        if reader is not None:
            reader.stop = True
            if reader.thread.is_alive():
                # Reads time out after 1 s, so 2 s is enough for a clean exit.
                reader.thread.join(timeout=2.0)
            reader_alive = reader.thread.is_alive()
        if dev is not None:
            if reader_alive:
                print("WARNING: IMU reader thread did not stop; leaving the "
                      "HID device open")
            else:
                dev.close()
        openvr.shutdown()

    if invalid:
        print(f"NOTE: {invalid} invalid/lost pose polls (tracking dropouts "
              "shrink usable data)")
    return reader.samples, poses_out, vr_serial


# ---------------------------------------------------------------- solve

# solve() drops bins whose VR delta rotation exceeds this. The fit equates
# the log map of a bin's delta rotation with the integrated gyro vector,
# which only holds while the rotation stays small (rotations about a changing
# axis do not commute).
MAX_BIN_ANGLE = 0.6   # rad (~34 deg): keep log-map ≈ integral small-angle valid
# solve() warns when the summed |rotation| about a head axis is below this.
MIN_TOTAL_PER_AXIS = np.pi  # rad of total rotation wanted per head axis


def _vr_angular_speed(pose_t, Rs, max_dt=0.1):
    """Angular speed (rad/s) at each pose, from the rotation since the last.

    Entry 0, and any step whose time delta is non-positive or exceeds
    *max_dt* seconds (a tracking dropout), is reported as 0.
    """
    speed = np.zeros(len(Rs))
    for i in range(1, len(Rs)):
        d = pose_t[i] - pose_t[i - 1]
        if 0 < d <= max_dt:
            speed[i] = np.linalg.norm(log_so3(Rs[i - 1].T @ Rs[i])) / d
    return speed


def _config_imu_basis(config_path):
    """Read a config's imu block; return (device_serial_number, R_cfg).

    R_cfg's columns are the configured IMU x/y/z axes in the head frame:
    x = plus_x, z = plus_z, y = z cross x, each normalised to unit length.
    """
    with open(config_path, encoding="utf-8") as fh:
        cfg = json.load(fh)
    imu_cfg = cfg["imu"]
    px = np.array(imu_cfg["plus_x"], dtype=float)
    pz = np.array(imu_cfg["plus_z"], dtype=float)
    py = np.cross(pz, px)
    R_cfg = np.column_stack([px / np.linalg.norm(px),
                             py / np.linalg.norm(py),
                             pz / np.linalg.norm(pz)])
    return cfg.get("device_serial_number"), R_cfg


def solve(imu_samples, poses, config_path, bin_poses=None):
    """Fit delta_rotation_vr = A @ integral(raw_gyro) + b * dt over pose bins.

    Args:
        imu_samples: sequence of (host_t, timecode, gx, gy, gz).
        poses: sequence of (host_t, R) with R a 3x3 rotation matrix (as
            produced by capture()).
        config_path: optional config.json whose imu plus_x/plus_z basis is
            compared against the measured IMU axis directions.
        bin_poses: poses per bin; defaults to the module-level BIN_POSES.

    Returns:
        (scales, dirs, bias_lsb):
        - scales: per-IMU-axis scale in rad/s per LSB.
        - dirs: 3x3 matrix whose columns are the IMU axes in the head frame.
        - bias_lsb: the gyro bias in LSB.

    Raises:
        SystemExit: bad bin size, too little data, or too few usable bins.
    """
    if bin_poses is None:
        bin_poses = BIN_POSES
    if bin_poses < 1:
        raise SystemExit(f"FAIL: bin_poses must be >= 1 (got {bin_poses})")
    imu = np.array(imu_samples, dtype=np.float64)  # host_t, tc, gx, gy, gz
    if len(imu) < 1000 or len(poses) < 200:
        raise SystemExit(f"FAIL: too little data (imu {len(imu)}, "
                         f"poses {len(poses)})")
    imu_t = imu[:, 0]
    gyro = imu[:, 2:5]
    # Per-sample dt from timecode (uint32 wrap-safe), for integration. The
    # first sample has no predecessor and gets a nominal 1 ms.
    tc = imu[:, 1].astype(np.uint64)
    dtc = np.diff(tc).astype(np.int64) & 0xFFFFFFFF
    dt = np.concatenate([[1e-3], dtc / TIMECODE_HZ])
    n_gaps = int(np.sum(dt > 0.5))
    if n_gaps:
        print(f"WARNING: {n_gaps} IMU stream gaps > 0.5 s -- "
              "their bins are dropped")

    pose_t = np.array([p[0] for p in poses])
    Rs = [p[1] for p in poses]

    # --- clock offset between HID arrival times and pose poll times.
    bias0 = np.median(gyro, axis=0)  # crude; refined by regression later
    imu_speed = np.linalg.norm(gyro - bias0, axis=1)
    lag = estimate_lag(imu_t, imu_speed, pose_t, _vr_angular_speed(pose_t, Rs))
    print(f"clock offset (added to IMU times): {lag*1000:+.1f} ms")
    imu_t = imu_t + lag

    # --- bins anchored at every bin_poses-th pose sample (no edge interp).
    X, Y, T = [], [], []
    dropped_angle = dropped_gap = 0
    idx = np.searchsorted(imu_t, pose_t)
    max_span = bin_poses * 0.012 * 2  # s; a longer bin means a pose dropout
    for k in range(0, len(poses) - bin_poses, bin_poses):
        i0, i1 = k, k + bin_poses
        if pose_t[i0] < imu_t[0] or pose_t[i1] > imu_t[-1]:
            # Bin extends past the (lag-shifted) IMU record: the gyro
            # integral would cover only part of the VR rotation.
            dropped_gap += 1
            continue
        if pose_t[i1] - pose_t[i0] > max_span:
            dropped_gap += 1  # pose dropout inside the bin
            continue
        rv = log_so3(Rs[i0].T @ Rs[i1])
        if np.linalg.norm(rv) > MAX_BIN_ANGLE:
            dropped_angle += 1
            continue
        s0, s1 = idx[i0], idx[i1]
        if s1 - s0 < 10:
            dropped_gap += 1
            continue
        seg_dt = dt[s0:s1]
        if np.any(seg_dt > 0.5):
            dropped_gap += 1
            continue
        X.append((gyro[s0:s1] * seg_dt[:, None]).sum(axis=0))
        T.append(seg_dt.sum())
        Y.append(rv)
    # Count before converting: np.array([]) has no second axis to measure.
    n = len(Y)
    print(f"bins: {n} used, {dropped_angle} dropped "
          f"(>{np.degrees(MAX_BIN_ANGLE):.0f} deg rotation), "
          f"{dropped_gap} dropped (gaps)")
    if n < 40:
        raise SystemExit("FAIL: too few usable bins -- capture longer / keep "
                         "tracking solid")
    X = np.array(X).T  # 3 x N
    Y = np.array(Y).T  # 3 x N
    T = np.array(T)[None, :]  # 1 x N

    # --- least squares  Y = [A|b] @ [X; T]
    Xa = np.vstack([X, T])  # 4 x N
    AB, *_ = np.linalg.lstsq(Xa.T, Y.T, rcond=None)
    AB = AB.T  # 3 x 4
    A, b = AB[:, :3], AB[:, 3]

    resid = Y - AB @ Xa
    resid_deg = np.degrees(np.linalg.norm(resid, axis=0))
    rms_deg = np.sqrt(np.mean(resid_deg ** 2))
    total_rot = np.degrees(np.abs(Y).sum(axis=1))
    print(f"residual per bin: rms {rms_deg:.3f} deg "
          f"max {resid_deg.max():.3f} deg  (bin rotations up to "
          f"{np.degrees(MAX_BIN_ANGLE):.0f} deg)")
    print(f"total rotation seen per head axis (deg): "
          f"x {total_rot[0]:.0f}  y {total_rot[1]:.0f}  z {total_rot[2]:.0f}")
    for ax, name in enumerate("xyz"):
        if total_rot[ax] < np.degrees(MIN_TOTAL_PER_AXIS):
            print(f"WARNING: little rotation about head {name} -- that axis's "
                  "numbers are weakly constrained, redo with more motion")

    # --- factor A
    scales = np.linalg.norm(A, axis=0)          # rad/s per LSB, per IMU axis
    dirs = A / scales                            # IMU axis dirs in head frame
    # X carries the bias: X = A^-1 @ integral(omega) + bias*T, so the fitted
    # constant term is b = -A @ bias.
    bias_lsb = -np.linalg.solve(A, b)
    nominal = 2000 * np.pi / 180 / 32768  # +-2000 dps over int16
    print("\n==== result ====")
    for ax, name in enumerate("XYZ"):
        print(f"IMU {name}: scale {scales[ax]:.9f} rad/s/LSB  "
              f"(vs 2000dps/32768 {scales[ax]/nominal-1:+.2%}, vs 1/1024 "
              f"{scales[ax]*1024-1:+.2%})  -> head [{dirs[0,ax]:+.4f} "
              f"{dirs[1,ax]:+.4f} {dirs[2,ax]:+.4f}]")
    print(f"gyro bias (LSB): [{bias_lsb[0]:+.2f} {bias_lsb[1]:+.2f} "
          f"{bias_lsb[2]:+.2f}]")
    print(f"axis-mapping handedness det(dirs) = {np.linalg.det(dirs):+.3f} "
          "(expect +1)")
    ortho = dirs.T @ dirs - np.eye(3)
    print(f"axis orthogonality max |off-diag| = "
          f"{np.abs(ortho - np.diag(np.diag(ortho))).max():.4f}")

    # --- compare against config imu block
    if config_path:
        cfg_serial, R_cfg = _config_imu_basis(config_path)
        print(f"\nconfig {cfg_serial} imu basis check "
              "(measured column vs config column):")
        for ax, name in enumerate("XYZ"):
            ang = np.degrees(np.arccos(np.clip(
                np.dot(dirs[:, ax], R_cfg[:, ax]), -1, 1)))
            print(f"  IMU {name}: {ang:6.2f} deg from config "
                  f"[{R_cfg[0,ax]:+.4f} {R_cfg[1,ax]:+.4f} {R_cfg[2,ax]:+.4f}]"
                  f"{'   <-- MISMATCH' if ang > 10 else ''}")
        print("(small angles = config plus_x/plus_z mapping + signs verified "
              "-> gate-3 sense check)")

    spread = np.abs(scales / scales.mean() - 1).max()
    print(f"\nper-axis scale spread vs mean: {spread:.2%} "
          f"(gate ~2%)  mean {scales.mean():.9f} rad/s/LSB")
    return scales, dirs, bias_lsb


# Number of consecutive poses per fit bin, read by solve(); main() assigns
# it from --bin-poses before solving.
BIN_POSES = 50  # overwritten by --bin-poses


def exp_so3(v):
    """Rotation matrix for rotation vector *v* (radians), via Rodrigues.

    Vectors shorter than 1e-12 map to the identity.
    """
    theta = np.linalg.norm(v)
    if theta < 1e-12:
        return np.eye(3)
    x, y, z = v / theta
    hat = np.array([[0, -z, y], [z, 0, -x], [-y, x, 0]])
    return np.eye(3) + np.sin(theta) * hat + (1 - np.cos(theta)) * (hat @ hat)


def selftest(config_path):
    """Synthetic round trip: known A/bias/lag -> generated streams -> solver
    must recover them. No hardware needed.

    The synthetic device timecode starts 10 s before its uint32 wrap, so the
    solver's wrap-safe dt computation is exercised as well. Failed checks
    raise SystemExit rather than using ``assert``, so they still run under
    ``python -O``.
    """
    rng = np.random.default_rng(7)
    # Truth: config imu basis x per-axis scales near 1/1024, bias in LSB.
    px = np.array([-1.0, 0.0, 0.0])
    pz = np.array([0.00067, 0.1961, -0.98058])
    pz /= np.linalg.norm(pz)
    py = np.cross(pz, px)
    R_true = np.column_stack([px, py, pz])
    s_true = np.array([0.000976, 0.000981, 0.000972])
    A_true = R_true * s_true[None, :]
    bias_true = np.array([1.5, 2.5, 3.8])
    lag_true = 0.020  # IMU host stamps arrive 20 ms late

    # Smooth random head-frame angular velocity, 60 s.
    dur, imu_hz, vr_hz = 60.0, 994.0, 250.0
    tgrid = np.arange(0, dur, 1 / imu_hz)
    omega = np.zeros((len(tgrid), 3))
    for a in range(3):
        for f, amp in ((0.13 + 0.11 * a, 1.2), (0.31 + 0.07 * a, 0.8),
                       (0.73 + 0.05 * a, 0.35)):
            ph = rng.uniform(0, 2 * np.pi)
            omega[:, a] += amp * np.sin(2 * np.pi * f * tgrid + ph)

    # IMU stream: gyro_raw = A^-1 omega + bias + noise; timecode at 48 MHz,
    # offset so the uint32 counter wraps ~10 s into the capture.
    gyro_raw = (np.linalg.solve(A_true, omega.T).T + bias_true
                + rng.normal(0, 1.0, omega.shape))
    tc0 = (1 << 32) - int(10 * TIMECODE_HZ)
    imu = [(t + lag_true, (tc0 + int(round(t * TIMECODE_HZ))) & 0xFFFFFFFF,
            gyro_raw[i, 0], gyro_raw[i, 1], gyro_raw[i, 2])
           for i, t in enumerate(tgrid)]

    # Pose stream: integrate world<-head at IMU rate, sample at vr_hz.
    poses = []
    R = np.eye(3)
    next_pose_t = 0.0
    for i, t in enumerate(tgrid):
        if t >= next_pose_t:
            poses.append((t, R.copy()))
            next_pose_t += 1 / vr_hz
        R = R @ exp_so3(omega[i] / imu_hz)

    scales, dirs, bias = solve(imu, poses, config_path)

    failures = []
    ratio = scales / s_true
    if not np.all(np.abs(ratio - 1) < 0.005):
        failures.append(f"scale ratios {ratio}")
    for a in range(3):
        if not np.dot(dirs[:, a], R_true[:, a]) > 0.9999:
            failures.append(f"axis {a} direction {dirs[:, a]}")
    if not np.all(np.abs(bias - bias_true) < 1.0):
        failures.append(f"bias {bias}")
    if failures:
        raise SystemExit("SELFTEST FAIL: " + "; ".join(failures))
    print("\nSELFTEST PASS -- scales within 0.5%, axes aligned, bias "
          "recovered, lag handled")


def main():
    """Parse arguments, then run the selftest, a replay, or a live capture.

    --bin-poses sets the module-level BIN_POSES used by solve(). A live
    capture can be saved with --save and re-solved later with --replay.
    """
    global BIN_POSES
    ap = argparse.ArgumentParser()
    ap.add_argument("--seconds", type=float, default=60.0)
    ap.add_argument("--config", default=None,
                    help="per-unit config.json for the imu-basis cross-check")
    ap.add_argument("--bin-poses", type=int, default=50,
                    help="poses per bin (~250 Hz poll -> 50 = ~0.2 s bins)")
    ap.add_argument("--save", default=None, help="save raw capture to .npz")
    ap.add_argument("--replay", default=None,
                    help="re-solve from a saved .npz instead of capturing")
    ap.add_argument("--selftest", action="store_true",
                    help="synthetic round-trip check of the solver (no VR)")
    args = ap.parse_args()
    # Reject values that would otherwise crash deep inside solve()/capture().
    if args.bin_poses < 1:
        ap.error("--bin-poses must be >= 1")
    if args.seconds <= 0:
        ap.error("--seconds must be > 0")
    BIN_POSES = args.bin_poses

    if args.selftest:
        selftest(args.config)
        return

    if args.replay:
        # NpzFile holds the archive open; the context manager closes it.
        with np.load(args.replay, allow_pickle=False) as d:
            imu = [tuple(r) for r in d["imu"]]
            poses = list(zip(d["pose_t"], d["pose_R"]))
        solve(imu, poses, args.config)
        return

    imu, poses, vr_serial = capture(args.seconds)
    print(f"\ncaptured: {len(imu)} imu samples, {len(poses)} poses")
    if args.save:
        np.savez_compressed(
            args.save, imu=np.array(imu, dtype=np.float64),
            pose_t=np.array([p[0] for p in poses]),
            pose_R=np.array([p[1] for p in poses]), serial=vr_serial)
        print(f"saved {args.save}")
    solve(imu, poses, args.config)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
