"""M2 gate 4: verify the closed-form config -> GetProjectionRaw fold.

S3 left this open: the driver folds the eye_to_head canting into the
asymmetric projection (parallel_render_cameras), and M2 needs that frustum
for ANY unit, not just ones we dumped. Hypothesis (derived by hand against
the LHR-599F3B91 dump, verified here on every available unit):

    G        = 1 + grow_for_undistort          (distorted-NDC half-range)
    c_x      = -K02,  c_y = K12                (principal point; x sign-flipped)
    tau_x(q) = (q - c_x) / K00                 (eye-frame tangent at NDC q)
    tau_y(q) = (q - c_y) / K11
    tphi_x   = R02 / R22                       (e2h cant, yaw component)
    tphi_y   = R12 / R22                       (e2h cant, pitch component)

    L = (tau_x(-G) - tphi_x) / (1 + tau_x(-G) * tphi_x)
    R = (tau_x(+G) - tphi_x) / (1 + tau_x(+G) * tphi_x)
    T = -(tau_y(+G) + tphi_y) / (1 - tau_y(+G) * tphi_y)    (y flip: v=0 top)
    B = -(tau_y(-G) + tphi_y) / (1 - tau_y(-G) * tphi_y)

i.e. each axis is a 1D rotation homography (tan subtraction) applied to the
grow-expanded intrinsics frustum edge, with the eye's cant taken from the z
column of the eye_to_head rotation. Note the interior of the RT stays AFFINE
in q (ComputeDistortion, S3-exact); only the edge tangents get the exact
homography. That mismatch is the driver's own approximation -- replicating
it bit-for-bit is the "replicate the fold" branch of the S3 warning.

Usage: python derive_projection_fold.py [<config.json> <dump.json> ...]
       (config/dump paths given in pairs; defaults to both units in
       spikes/s3-distortion-truth)
"""

import json
import math
import sys


def predict(cfg, eye_idx):
    """Predict the GetProjectionRaw edges (L, R, T, B) of one eye from config.

    Args:
        cfg: parsed config JSON. Only
            ``cfg["tracking_to_eye_transform"][eye_idx]`` is read, using its
            ``intrinsics`` (3x3 K), ``eye_to_head`` (3x3 rotation R, indexed
            row-major) and ``grow_for_undistort`` entries.
        eye_idx: index into ``tracking_to_eye_transform``.

    Returns:
        Tuple ``(L, R, T, B)`` of frustum edge tangents. Each edge is the
        grow-expanded NDC edge ``q = +/-(1 + grow)`` mapped to an eye-frame
        tangent through K. That tangent is then rotated by the eye's cant via
        tangent subtraction, ``tan(a - b)``. T and B are negated for the
        y flip (v=0 is the top row).
    """
    eye = cfg["tracking_to_eye_transform"][eye_idx]
    intr = eye["intrinsics"]
    rot = eye["eye_to_head"]
    half_range = 1.0 + eye["grow_for_undistort"]

    # Principal point: the x component is the NEGATED K[0][2], while the
    # y component is K[1][2] as-is.
    ppx = -intr[0][2]
    ppy = intr[1][2]

    # Cant tangents come from the rotation's third column, normalised by R22.
    cant_x = rot[0][2] / rot[2][2]
    cant_y = rot[1][2] / rot[2][2]

    def rotate(tan_edge, tan_cant):
        # tan(a - b) = (tan a - tan b) / (1 + tan a * tan b)
        return (tan_edge - tan_cant) / (1.0 + tan_edge * tan_cant)

    left, right = (
        rotate((q - ppx) / intr[0][0], cant_x)
        for q in (-half_range, half_range)
    )
    # The vertical axis uses the opposite cant sign, then the y flip.
    top, bottom = (
        -rotate((q - ppy) / intr[1][1], -cant_y)
        for q in (half_range, -half_range)
    )
    return left, right, top, bottom


def _load_json(path):
    """Return the parsed JSON document stored at *path*."""
    # A context manager closes the handle deterministically;
    # json.load(open(...)) would leave it for the garbage collector.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def main(argv=None):
    """Check predict() against dumped GetProjectionRaw values for each unit.

    Args:
        argv: alternating config/dump JSON paths. Defaults to
            ``sys.argv[1:]``. If that is empty too, the two units under
            spikes/s3-distortion-truth are used.

    Prints the dumped and predicted L/R/T/B for every eye of every unit,
    then the worst absolute difference and a PASS/FAIL verdict.

    Raises:
        SystemExit: with the module docstring (usage) when an odd number of
            paths is given. With status 1 when the gate fails: the worst
            difference is >= 5e-4, any difference is non-finite, or no eyes
            were compared at all.
        ValueError: if an eye's ``projection_raw`` does not hold exactly
            four values.
    """
    pairs = sys.argv[1:] if argv is None else list(argv)
    if not pairs:
        base = "spikes/s3-distortion-truth"
        pairs = [f"{base}/config_lhr_599f3b91.json",
                 f"{base}/distortion_dump_lhr_599f3b91.json",
                 f"{base}/config_lhr_1f8e25f1.json",
                 f"{base}/distortion_dump_lhr_1f8e25f1.json"]
    if len(pairs) % 2:
        raise SystemExit(__doc__)

    worst = 0.0
    compared = 0
    for cfg_path, dump_path in zip(pairs[0::2], pairs[1::2]):
        cfg = _load_json(cfg_path)
        dump = _load_json(dump_path)
        print(f"\n{cfg['device_serial_number']}:")
        # NOTE(review): dump eyes are paired with tracking_to_eye_transform
        # entries by position, not by their 'eye' name. This assumes both
        # lists use the same eye order.
        for e, eye in enumerate(dump["eyes"]):
            pred = predict(cfg, e)
            got = eye["projection_raw"]
            # zip() would silently truncate a short/long entry.
            if len(got) != len(pred):
                raise ValueError(
                    f"{dump_path}: eye #{e} projection_raw has {len(got)} "
                    f"values, expected {len(pred)} (L, R, T, B)")
            diffs = [abs(p - g) for p, g in zip(pred, got)]
            # max() drops NaN unless it happens to come first, so a NaN
            # would otherwise slip past the gate. Map any non-finite
            # difference to inf so it is guaranteed to FAIL.
            if all(math.isfinite(d) for d in diffs):
                diff = max(diffs)
            else:
                diff = math.inf
            worst = max(worst, diff)
            compared += 1
            print(f"  {eye['eye']:<5} dump L{got[0]:+.5f} R{got[1]:+.5f} "
                  f"T{got[2]:+.5f} B{got[3]:+.5f}")
            print(f"        pred L{pred[0]:+.5f} R{pred[1]:+.5f} "
                  f"T{pred[2]:+.5f} B{pred[3]:+.5f}   max|diff| {diff:.6f}")

    # An empty dump must not pass vacuously with worst == 0.0.
    passed = compared > 0 and worst < 5e-4
    print(f"\nworst |diff| across all eyes/units: {worst:.6f} tan units "
          f"-> {'PASS' if passed else 'FAIL'} (criterion 5e-4)")
    if not compared:
        print("no eyes were compared; an empty dump cannot pass the gate")
    if not passed:
        # Non-zero exit status so callers/CI can act on the gate result.
        raise SystemExit(1)


if __name__ == "__main__":
    # Run the gate only when executed as a script; importing the module
    # merely defines predict() and main().
    main()
