"""Spike S3 empirical analysis: discover Valve's actual POLY3 evaluation.

The 4 naive hypotheses all failed because the source-UV space is the
GetProjectionRaw frustum, not a grow-rescaled NDC. This script fits the
general radial model per eye/channel:

    t   = (2u-1, 2v-1)            panel NDC (v down, OpenVR viewport)
    r2  = |t - c|^2
    tan = a + s * (t - c) * d(r2)   (s, a: 2-vectors applied per axis; tan is
                                     in raw-projection tangent space, y down:
                                     tan = (L + u'*(R-L), T + v'*(B-T)) for
                                     the source UV (u', v'))

with d either  poly: 1 + k1 r2 + k2 r2^2 + k3 r2^3
        or    recip: 1 / (1 + k1 r2 + k2 r2^2 + k3 r2^3)

Free params per channel: c(2) a(2) s(2) k(3) = 9. Ground truth: 4225 sample
points. The script then compares the fitted c, k with the config block, and the
fitted a, s with the intrinsics / projection_raw, to recover the closed form.

Usage: python analyze.py distortion_dump_*.json config_*.json
"""

import json
import sys

import numpy as np
from scipy.optimize import least_squares

# Maps each colour-channel key used in the dump points ("r"/"g"/"b") to the name
# of that channel's distortion block inside a config tracking_to_eye_transform
# entry; green uses the unsuffixed "distortion" block.
CHANNELS = {"r": "distortion_red", "g": "distortion", "b": "distortion_blue"}


def load_eye_points(eye_dump):
    """Collect one eye's valid samples and map their source UVs to tangent space.

    Entries of ``eye_dump["points"]`` that are None are skipped.

    Returns:
        (uv, tans): ``uv`` is an array of the panel UVs of the kept points;
        ``tans`` maps "r", "g", "b" to (N, 2) arrays holding that channel's
        source UV mapped onto the raw projection bounds, i.e.
        (L + u*(R-L), T + v*(B-T)) with v = 0 mapping to T (top).
    """
    kept = [sample for sample in eye_dump["points"] if sample is not None]
    panel_uv = np.array([sample["uv"] for sample in kept])

    left, right, top, bottom = eye_dump["projection_raw"]
    origin = np.array([left, top])
    extent = np.array([right - left, bottom - top])

    tangents = {}
    for channel in ("r", "g", "b"):
        source_uv = np.array([sample[channel] for sample in kept])
        # Only the first two components (u, v) take part in the mapping.
        tangents[channel] = origin + source_uv[:, :2] * extent
    return panel_uv, tangents


def model(params, t, form):
    """Evaluate the radial distortion model at panel NDC points.

    Args:
        params: 9 values (cx, cy, ax, ay, sx, sy, k1, k2, k3).
        t: (N, 2) array of panel NDC coordinates.
        form: "poly" for d = p, or "recip" for d = 1 / p, where
            p = 1 + k1*r2 + k2*r2^2 + k3*r2^3 and r2 = |t - c|^2.

    Returns:
        (N, 2) array a + s * (t - c) * d, with a and s applied per axis.

    Raises:
        ValueError: if ``form`` is neither "poly" nor "recip" (previously any
            other value silently fell through to the reciprocal form).
    """
    if form not in ("poly", "recip"):
        raise ValueError(f"unknown model form {form!r}; expected 'poly' or 'recip'")
    cx, cy, ax, ay, sx, sy, k1, k2, k3 = params
    d0 = t - np.array([cx, cy])
    r2 = (d0 ** 2).sum(axis=1)
    # Horner evaluation of 1 + k1 r2 + k2 r2^2 + k3 r2^3.
    p = 1 + r2 * (k1 + r2 * (k2 + r2 * k3))
    d = p if form == "poly" else 1.0 / p
    return np.array([ax, ay]) + np.array([sx, sy]) * d0 * d[:, None]


def fit(uv, tan, form, k0):
    """Least-squares fit of the radial model to one eye/channel.

    Args:
        uv: (N, 2) panel UV coordinates (array or nested list); converted to
            panel NDC via t = 2*uv - 1.
        tan: (N, 2) target tangent-space coordinates, as produced by
            load_eye_points.
        form: "poly" or "recip", forwarded to model().
        k0: initial guess for (k1, k2, k3), e.g. the config block's coeffs.
            Only the first three values are used and missing ones start at 0,
            so the parameter vector always has the 9 entries model() unpacks.

    Returns:
        The scipy.optimize.least_squares result; ``x`` holds
        (cx, cy, ax, ay, sx, sy, k1, k2, k3) and ``fun`` the flattened
        (x, y) residuals per point.
    """
    t = 2 * np.asarray(uv, dtype=float) - 1
    k_init = [float(k) for k in list(k0)[:3]]
    k_init += [0.0] * (3 - len(k_init))
    # Start with zero centre/offset and an initial per-axis scale of 0.7.
    x0 = np.array([0.0, 0.0, 0.0, 0.0, 0.7, 0.7] + k_init)

    def resid(params):
        return (model(params, t, form) - tan).ravel()

    return least_squares(resid, x0, method="lm", max_nfev=20000)


def main():
    """Fit both model forms per eye/channel and print a comparison report.

    Expects two command-line arguments: the distortion dump JSON (argv[1]) and
    the config JSON (argv[2]). For each eye in the dump it prints the raw
    projection bounds, the config intrinsics, the eye_to_head matrices from
    both files, and then, for each channel ("g", "r", "b") and form ("poly",
    "recip"), the fit residual in source pixels next to the fitted parameters
    and the config block's coefficients and centre.
    """
    if len(sys.argv) < 3:
        sys.exit("Usage: python analyze.py distortion_dump_*.json config_*.json")
    # Context managers close the files deterministically.
    with open(sys.argv[1], encoding="utf-8") as f:
        dump = json.load(f)
    with open(sys.argv[2], encoding="utf-8") as f:
        cfg = json.load(f)

    # NOTE(review): assumes [width, height] order (as returned by OpenVR's
    # GetRecommendedRenderTargetSize); falls back to the width when only one
    # value is present.
    rt = dump["recommended_render_target"]
    rt_w = rt[0]
    rt_h = rt[1] if len(rt) > 1 else rt_w

    for eye_idx, eye_dump in enumerate(dump["eyes"]):
        tte = cfg["tracking_to_eye_transform"][eye_idx]
        K = tte["intrinsics"]
        L, R, T, B = eye_dump["projection_raw"]
        print(f"\n=== {eye_dump['eye'].upper()} ===")
        print(f"projection_raw L{L:+.5f} R{R:+.5f} T{T:+.5f} B{B:+.5f}")
        print(f"intrinsics fx {K[0][0]:.5f} fy {K[1][1]:.5f} K02 {K[0][2]:+.5f} K12 {K[1][2]:+.5f}")
        print("config eye_to_head:")
        for row in tte["eye_to_head"]:
            print("   ", [round(x, 6) for x in row])
        print("dump eye_to_head:")
        for row in eye_dump["eye_to_head"]:
            print("   ", [round(x, 6) for x in row])

        uv, tans = load_eye_points(eye_dump)
        for ch in ("g", "r", "b"):
            blk = tte[CHANNELS[ch]]
            kcfg = blk["coeffs"]
            # Print every config coefficient; identical to the 3-term format
            # when there are exactly three.
            kcfg_str = ", ".join(f"{k:+.5f}" for k in kcfg)
            for form in ("poly", "recip"):
                res = fit(uv, tans[ch], form, kcfg)
                cx, cy, ax, ay, sx, sy, k1, k2, k3 = res.x
                # Residuals are in tangent units. Dividing by each axis'
                # tangent extent gives source UV (inverse of L + u*(R-L),
                # T + v*(B-T)); multiplying by that axis' render-target size
                # gives source pixels.
                err = res.fun.reshape(-1, 2)
                err_px = np.hypot(err[:, 0] * rt_w / abs(R - L),
                                  err[:, 1] * rt_h / abs(B - T))
                # Root-mean-square of the per-point pixel error.
                rms_px = np.sqrt(np.mean(err_px ** 2))
                print(f" {ch}/{form:5s} rms {rms_px:7.3f}px max {err_px.max():8.3f}px | "
                      f"c ({cx:+.5f},{cy:+.5f}) a ({ax:+.5f},{ay:+.5f}) s ({sx:+.5f},{sy:+.5f})")
                print(f"          k_fit  [{k1:+.5f}, {k2:+.5f}, {k3:+.5f}]")
                print(f"          k_cfg  [{kcfg_str}]  "
                      f"center_cfg ({blk['center_x']:+.5f},{blk['center_y']:+.5f})  "
                      f"1/fx {1/K[0][0]:.5f} 1/fy {1/K[1][1]:.5f}")


if __name__ == "__main__":
    # Run the analysis only when executed as a script (see Usage in the module
    # docstring), not when imported.
    main()
