"""Spike S3 verdict: test each evaluator hypothesis against the OpenVR dump.

Usage:
  python verify.py <distortion_dump_*.json> <config.json>
  python verify.py --selftest

Error metric: source-UV delta scaled to panel pixels (x2544) — the units the
warp actually renders in. Pass criterion (PLAN.md): sub-pixel, i.e. winning
hypothesis max error < 1.0 px across both eyes, all three channels.

Also cross-checks projection_raw vs config intrinsics and eye_to_head
rotation vs config (canting), since the dump carries them for free.
"""

import json
import math
import sys

from poly3 import EyeDistortion, HYPOTHESES, load_config

PANEL_PX = 2544  # per-eye panel pixels (eye_target_width_in_pixels)


def test_hypothesis(dump, eyes_cfg, form, use_grow, *, panel_px=PANEL_PX):
    """Score one evaluator hypothesis against every eye/channel in *dump*.

    For each sample point, the UV predicted by
    ``model.compute(u, v, form=form, use_grow=use_grow)`` is compared with
    the dumped UV for each of the r/g/b channels. The delta is scaled by
    *panel_px* on both axes to give an error in pixels.

    Args:
        dump: parsed dump. ``dump["eyes"][i]`` has ``"eye"`` (a name used as
            a stats key) and ``"points"``. Each point is either None
            (skipped) or a dict with ``"uv"``, ``"r"``, ``"g"`` and ``"b"``
            UV pairs.
        eyes_cfg: per-eye models, indexed in the same order as
            ``dump["eyes"]``. Each model's ``compute`` returns an
            indexable (r, g, b) of UV pairs.
        form, use_grow: the hypothesis, forwarded to ``compute``.
        panel_px: scale from UV units to pixels (default ``PANEL_PX``).

    Returns:
        dict mapping ``(eye_name, channel)`` to ``(rms_px, max_px)``, plus
        ``"overall_max"`` (largest ``max_px`` seen).
        A non-finite per-point error is counted as ``inf``.
        A channel with no valid points reports ``(nan, nan)`` and forces
        ``overall_max`` to ``inf``.
        In both cases the hypothesis can never pass a finite threshold.
    """
    channels = ("r", "g", "b")
    stats = {}
    overall_max = 0.0
    for eye_idx, eye_dump in enumerate(dump["eyes"]):
        model = eyes_cfg[eye_idx]
        sq_sum = [0.0, 0.0, 0.0]  # per-channel sum of squared errors
        sq_max = [0.0, 0.0, 0.0]  # per-channel largest squared error
        count = 0
        for pt in eye_dump["points"]:
            if pt is None:  # the dump stores None for missing samples
                continue
            u, v = pt["uv"]
            # One evaluation yields all three channels; no need to call
            # compute once per channel.
            ours_rgb = model.compute(u, v, form=form, use_grow=use_grow)
            for ch_idx, ch_name in enumerate(channels):
                ours = ours_rgb[ch_idx]
                theirs = pt[ch_name]
                dx = (ours[0] - theirs[0]) * panel_px
                dy = (ours[1] - theirs[1]) * panel_px
                e2 = dx * dx + dy * dy
                if not math.isfinite(e2):
                    # A NaN would be silently skipped by the `>` below and
                    # could make a broken hypothesis look perfect.
                    e2 = math.inf
                sq_sum[ch_idx] += e2
                if e2 > sq_max[ch_idx]:
                    sq_max[ch_idx] = e2
            count += 1
        for ch_idx, ch_name in enumerate(channels):
            if count:
                rms = math.sqrt(sq_sum[ch_idx] / count)
                mxe = math.sqrt(sq_max[ch_idx])
                overall_max = max(overall_max, mxe)
            else:
                # No data must not read as a perfect (0 px) score.
                rms = mxe = float("nan")
                overall_max = math.inf
            stats[(eye_dump["eye"], ch_name)] = (rms, mxe)
    stats["overall_max"] = overall_max
    return stats


# The name starts with "test_", so pytest would try to collect this function
# (and fail on missing fixtures) whenever it is imported into a test module.
test_hypothesis.__test__ = False


def cross_checks(dump, cfg):
    """Print dump-vs-config sanity checks. Returns nothing.

    Each entry of ``dump["eyes"]`` is paired by index with
    ``cfg["tracking_to_eye_transform"]``. For each eye this:

    * prints the dump's OpenVR ``projection_raw`` (L, R, T, B tangents)
      next to tangents derived from the config ``intrinsics`` matrix;
    * prints the largest absolute difference between the 3x3 rotation parts
      of the config and dump ``eye_to_head`` matrices;
    * prints the dump's ``eye_to_head`` x translation, multiplied by 1000
      and labelled in mm.

    Finally the config's ``ipd`` block (``{}`` if absent) is printed.
    """
    print("\nCross-checks (dump vs config):")
    for idx, eye in enumerate(dump["eyes"]):
        eye_cfg = cfg["tracking_to_eye_transform"][idx]
        intr = eye_cfg["intrinsics"]
        # The principal point is taken as the NEGATED K[.][2] entries.
        # NOTE(review): the sign convention (especially for cy) is not
        # settled; compare against the dump's projection_raw line below.
        focal_x, centre_x = intr[0][0], -intr[0][2]
        focal_y, centre_y = intr[1][1], -intr[1][2]
        tan_left = (-1 - centre_x) / focal_x
        tan_right = (1 - centre_x) / focal_x
        tan_top = (-1 - centre_y) / focal_y
        tan_bottom = (1 - centre_y) / focal_y
        raw_l, raw_r, raw_t, raw_b = eye["projection_raw"]
        label = eye["eye"]
        print(f"  {label} projection_raw  dump "
              f"L{raw_l:+.5f} R{raw_r:+.5f} T{raw_t:+.5f} B{raw_b:+.5f}")
        print(f"  {label} from intrinsics      "
              f"L{tan_left:+.5f} R{tan_right:+.5f} "
              f"T{tan_top:+.5f} B{tan_bottom:+.5f}")

        cfg_e2h = eye_cfg["eye_to_head"]
        dump_e2h = eye["eye_to_head"]
        rot_err = max(abs(cfg_e2h[i][j] - dump_e2h[i][j])
                      for i in range(3) for j in range(3))
        # The value is multiplied by 1000 and labelled mm (so presumably
        # stored in metres).
        tx_mm = dump_e2h[0][3] * 1000.0
        print(f"  {label} eye_to_head rot max-abs-diff {rot_err:.2e}; "
              f"translation x {tx_mm:+.3f} mm (ipd.default_mm/2 expected)")
    ipd_block = cfg.get("ipd", {})
    print(f"  config ipd block: {ipd_block}")


def selftest():
    """Synthetic round trip: verify the verifier itself.

    Builds a fake dump from one known hypothesis ("recip" form with grow,
    i.e. pretending Valve uses the Monado form). It then scores every entry
    of ``HYPOTHESES`` against that dump and checks that:

    * the hypothesis matching the constructed (form, grow) wins with
      essentially zero error (< 1e-9 px), and
    * every other hypothesis is off by more than 1 px, so a real dump can
      tell them apart.

    Raises:
        AssertionError: if any check fails. Checks use explicit ``raise``
            rather than ``assert``, so ``python -O`` cannot turn a failure
            into a false "SELFTEST PASS".
    """

    def check(ok, msg):
        # Explicit raise: `assert` statements are stripped under -O.
        if not ok:
            raise AssertionError(msg)

    tte = {
        "grow_for_undistort": 0.6,
        "undistort_r2_cutoff": 1.0,
        "distortion": {"center_x": 0.01, "center_y": -0.012,
                       "coeffs": [0.2438, 0.0976, 0.0808], "type": "DISTORT_POLY3"},
        "distortion_red": {"center_x": 0.01, "center_y": -0.012,
                           "coeffs": [0.2159, 0.0887, 0.1174], "type": "DISTORT_POLY3"},
        "distortion_blue": {"center_x": 0.01, "center_y": -0.012,
                            "coeffs": [0.2805, -0.0059, 0.1701], "type": "DISTORT_POLY3"},
    }
    eyes = [EyeDistortion(tte), EyeDistortion(tte)]
    truth_form, truth_grow = "recip", True  # pretend Valve uses Monado form

    # Derive the expected winner's name from HYPOTHESES so it cannot drift
    # out of sync with (truth_form, truth_grow).
    candidates = [name for name, form, grow in HYPOTHESES
                  if form == truth_form and grow == truth_grow]
    check(len(candidates) == 1,
          f"selftest failed: need exactly one hypothesis for "
          f"form={truth_form!r} grow={truth_grow}, found {candidates}")
    expected = candidates[0]

    # n x n grid over [0, 1]^2, including both edges.
    n = 17
    dump = {"eyes": []}
    for eye_name, model in zip(("left", "right"), eyes):
        points = []
        for j in range(n):
            for i in range(n):
                u, v = i / (n - 1), j / (n - 1)
                r, g, b = model.compute(u, v, form=truth_form, use_grow=truth_grow)
                points.append({"uv": [u, v], "r": list(r), "g": list(g), "b": list(b)})
        dump["eyes"].append({"eye": eye_name, "points": points})

    results = {name: test_hypothesis(dump, eyes, form, grow)
               for name, form, grow in HYPOTHESES}
    report(results)
    winner = min(results, key=lambda k: results[k]["overall_max"])
    check(winner == expected, f"selftest failed: winner {winner}")
    check(results[winner]["overall_max"] < 1e-9,
          f"selftest failed: winner error {results[winner]['overall_max']} px")
    # All other hypotheses must be clearly distinguishable (not sub-pixel).
    for name, st in results.items():
        if name != winner:
            check(st["overall_max"] > 1.0, f"hypothesis {name} not distinguishable")
    print(f"\nSELFTEST PASS — winner {expected} as constructed, others distinguishable")


def report(results):
    """Print a one-line summary per hypothesis.

    Args:
        results: mapping of hypothesis name to a stats dict. Each stats dict
            holds ``"overall_max"`` plus ``(eye, channel)`` tuple keys whose
            values are ``(rms_px, max_px)``.

    Each line shows the name, ``overall_max`` in px, and for each of r/g/b
    the mean of the per-eye RMS values. A channel with no entries prints
    ``nan`` instead of raising ZeroDivisionError.
    """
    print(f"\n{'hypothesis':<12} {'overall max px':>14}   per-channel rms px (L/R averaged)")
    for name, st in results.items():
        per_ch = []
        for ch in ("r", "g", "b"):
            vals = [val[0] for key, val in st.items()
                    if isinstance(key, tuple) and key[1] == ch]
            mean = sum(vals) / len(vals) if vals else float("nan")
            per_ch.append(f"{ch}:{mean:8.3f}")
        print(f"{name:<12} {st['overall_max']:>14.4f}   {' '.join(per_ch)}")


def main():
    """Command-line entry point; see the module docstring for usage.

    If ``--selftest`` appears anywhere in argv, runs ``selftest()`` instead.
    Otherwise it:

    * loads the dump (argv[1]) and the config (argv[2]);
    * warns on a serial-number mismatch;
    * scores every hypothesis and prints the table and the winner with a
      PASS/FAIL verdict (PASS means max error < 1.0 px);
    * prints the dump-vs-config cross-checks.

    Returns:
        None for ``--selftest``; otherwise 0 on PASS and 1 on FAIL. The
        ``__main__`` guard turns this into the process exit status. A wrong
        argument count still exits with status 2 after printing usage.
    """
    if "--selftest" in sys.argv:
        selftest()
        return None
    if len(sys.argv) != 3:
        print(__doc__)
        sys.exit(2)

    with open(sys.argv[1], encoding="utf-8") as f:
        dump = json.load(f)
    left, right, cfg = load_config(sys.argv[2])
    eyes_cfg = [left, right]

    serial_cfg = cfg.get("device_serial_number", "?")
    print(f"dump unit: {dump.get('serial')}   config unit: {serial_cfg}")
    # Tolerate a JSON null / non-string serial instead of crashing on .lower().
    dump_serial = str(dump.get("serial") or "")
    if dump_serial.lower() != str(serial_cfg or "?").lower():
        print("WARNING: serial mismatch — dump and config are from different units!")

    results = {name: test_hypothesis(dump, eyes_cfg, form, grow)
               for name, form, grow in HYPOTHESES}
    report(results)

    winner = min(results, key=lambda k: results[k]["overall_max"])
    wmax = results[winner]["overall_max"]
    passed = wmax < 1.0  # NaN/inf compare False -> FAIL
    verdict = "PASS" if passed else "FAIL"
    print(f"\nWinner: {winner}  (overall max {wmax:.4f} px)  ->  {verdict} "
          f"(criterion: sub-pixel)")

    cross_checks(dump, cfg)
    return 0 if passed else 1


if __name__ == "__main__":
    # Non-zero exit on FAIL so the verdict is usable from shell scripts/CI.
    sys.exit(main())
