"""Stereo field audit: measure binocular disparity of the rendered+warped
output across the visual field, against the DRIVER's dumped ground truth.

Answers "is the stereo geometry right?" with numbers instead of eyeballs:
for a flat fronto-parallel quad at distance Z rendered with parallel
cameras at separation IPD, every world feature must show horizontal
disparity EXACTLY IPD/Z (tan units) and vertical disparity EXACTLY 0,
uniform across the whole field. Any per-eye mapping bug (swapped frusta,
wrong center, doubled IPD, v-flip) shows as a deviation pattern.

Method: predict each checker-corner's panel position per eye (world ->
texture via the render window, texture -> panel via Newton-inverted POLY3),
refine with subpixel saddle correlation on the actual panel image, then map
the DETECTED position through the dumped ComputeDistortion grid (bilinear)
and projection_raw window into tangent space. Left-minus-right per corner.

Inputs: a panel raw (from warp_check on render_core dumps at identity pose)
plus the unit's config + OpenVR dump. Run:
  render_core --config cfg.json --pose 0,0,0 --dump /tmp/st   (BMP -> raw)
  warp_check --config cfg.json --left l.raw --right r.raw --out panel.raw --linear
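  python stereo_field_audit.py cfg.json dump.json panel.raw [Z IPD QUADW QUADH]
  (Z, IPD, QUADW, QUADH are optional; defaults 2.0, 0.063, 2.1333, 1.2, all
  in the same world units.)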

Result on LHR-1F8E25F1 (2026-06-10, M2): 113 corners, ecc 0..0.5 tan,
d_tanx +0.0421 +/- 0.0001 (expect +0.0420), d_tany <= 0.0002. Stereo field
verified uniform to ~0.3 arcmin — the eyes-in divergence complaint was not
render geometry (see session notes: pupil swim / motion candidates).
"""

import itertools
import json
import os
import sys

import numpy as np

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "spikes",
                                "s3-distortion-truth"))
from poly3 import load_config  # noqa: E402

# Panel geometry in pixels. The raw panel image is PANEL_H rows by PANEL_W
# columns, split side by side into two HALF-wide per-eye halves (eye 0 in
# columns [0, HALF), eye 1 in [HALF, PANEL_W)).
PANEL_W, PANEL_H = 5088, 2544
HALF = PANEL_W // 2


def main():
    """Run the stereo field audit and print disparity by eccentricity.

    Command line (``sys.argv``)::

        stereo_field_audit.py cfg.json dump.json panel.raw [Z IPD QUADW QUADH]

    cfg.json   unit config, read with ``poly3.load_config``; its first two
               return values are the left- and right-eye distortion models.
    dump.json  driver dump with ``dump["grid"]`` (grid size n) and, per eye,
               ``projection_raw`` (L, R, T, B) plus n*n ``points`` whose
               ``"g"`` entries are the distortion samples used here.
    panel.raw  PANEL_H x PANEL_W x 4 uint8 image; only byte 1 of each pixel
               (the green channel) is used.
    Z IPD QUADW QUADH
               optional floats, defaults 2.0 0.063 2.1333 1.2: quad distance,
               camera separation and quad width/height, all in the same
               world units.

    For every checker corner found in both eyes, the left-minus-right
    tangent-space disparity is collected. The rows are then printed grouped
    by eccentricity (cyclopean, rounded to 0.1 tan).

    Raises:
        SystemExit: with a usage message if fewer than three positional
            arguments are given, or if Z is not positive.
    """
    if len(sys.argv) < 4:
        raise SystemExit(
            "usage: stereo_field_audit.py cfg.json dump.json panel.raw "
            "[Z IPD QUADW QUADH]")
    cfg_path, dump_path, panel_path = sys.argv[1:4]
    Z = float(sys.argv[4]) if len(sys.argv) > 4 else 2.0
    IPD = float(sys.argv[5]) if len(sys.argv) > 5 else 0.063
    QW = float(sys.argv[6]) if len(sys.argv) > 6 else 2.1333
    QH = float(sys.argv[7]) if len(sys.argv) > 7 else 1.2
    # Every world->tan projection divides by Z; fail early, not mid-loop.
    if Z <= 0:
        raise SystemExit(f"Z must be positive, got {Z}")

    left, right, _ = load_config(cfg_path)
    models = [left, right]
    with open(dump_path, encoding="utf-8") as f:
        dump = json.load(f)
    n = dump["grid"]  # distortion grid is n x n samples per eye
    panel = np.fromfile(panel_path, dtype=np.uint8).reshape(
        PANEL_H, PANEL_W, 4)[:, :, 1].astype(float)  # green channel

    def window(eye):
        """Return the eye's raw projection window (L, R, T, B)."""
        return dump["eyes"][eye]["projection_raw"]

    def world_to_tex(eye, wx, wy):
        """Project world point (wx, wy) on the plane at Z into texture coords.

        Cameras are parallel, with eye 0 at x = -IPD/2 and eye 1 at +IPD/2.
        The result is normalized so the projection window spans 0..1 on each
        axis. The vertical coordinate is built from -ty, i.e. (-ty - T) /
        (B - T), to match the window's T/B ordering.
        """
        L, R, T, B = window(eye)
        ex = -IPD / 2 if eye == 0 else IPD / 2
        tx, ty = (wx - ex) / Z, wy / Z
        return (tx - L) / (R - L), (-ty - T) / (B - T)

    def tex_to_panel(eye, su, sv):
        """Find the panel pixel whose POLY3 sample lands on (su, sv).

        The eye's POLY3 model (channel 1, presumably green like the panel
        image) maps normalized panel coords (u, v) to texture coords. Eight
        Newton steps, starting at (u, v) = (su, sv), invert that mapping
        using the analytic Jacobian. The result is returned as pixel-centre
        coordinates inside the eye's HALF-wide half of the panel.
        """
        m = models[eye]
        ch = m.channels[1]
        s = 0.5 / (1 + m.grow)
        cx, cy = ch["cx"], ch["cy"]
        k1, k2, k3 = ch["k"]
        u, v = su, sv
        for _ in range(8):
            tx, ty = 2 * u - 1 - cx, 2 * v - 1 - cy
            r2 = tx * tx + ty * ty
            d = 1 + r2 * (k1 + r2 * (k2 + r2 * k3))
            dp = k1 + 2 * k2 * r2 + 3 * k3 * r2 * r2  # d(d)/d(r2)
            fu = 0.5 + (tx * d + cx) * s - su
            fv = 0.5 + (ty * d + cy) * s - sv
            # Jacobian of the forward map w.r.t. (u, v); d(tx)/du = 2.
            J00 = 2 * s * (d + tx * dp * 2 * tx)
            J01 = 2 * s * (tx * dp * 2 * ty)
            J10 = 2 * s * (ty * dp * 2 * tx)
            J11 = 2 * s * (d + ty * dp * 2 * ty)
            det = J00 * J11 - J01 * J10
            u -= (J11 * fu - J01 * fv) / det
            v -= (-J10 * fu + J00 * fv) / det
        # Pixel i covers [i, i+1) with its centre at (i + 0.5) / size.
        return u * HALF - 0.5, v * PANEL_H - 0.5

    def saddle_refine(img, px, py, half=14):
        """Refine a predicted checker corner to subpixel precision.

        The search covers integer offsets within +/-4 px of the rounded
        prediction. Each offset is scored by |sum((w - mean(w)) * kern)|
        over a (2*half+1)^2 window, where kern = sign(x) * sign(y) is a
        quadrant saddle pattern. A 1-D parabola fit on each axis around
        the best score gives the subpixel position.

        Returns (x, y) in img pixel coordinates. Returns None if the window
        would run off the image, or if the best offset is on the outer ring
        of the search (|offset| > 3), where interpolation neighbours are
        missing.
        """
        x0, y0 = int(round(px)), int(round(py))
        # Margin = half + search radius 4 + parabola neighbour 1 + slack.
        if (x0 < half + 6 or y0 < half + 6 or x0 >= img.shape[1] - half - 6 or
                y0 >= img.shape[0] - half - 6):
            return None
        yy, xx = np.mgrid[-half:half + 1, -half:half + 1]
        kern = np.sign(xx) * np.sign(yy)
        best, bx, by = -1e18, 0, 0
        scores = {}
        for dy in range(-4, 5):
            for dx in range(-4, 5):
                w = img[y0 + dy - half:y0 + dy + half + 1,
                        x0 + dx - half:x0 + dx + half + 1]
                s = abs(((w - w.mean()) * kern).sum())
                scores[(dx, dy)] = s
                if s > best:
                    best, bx, by = s, dx, dy
        if abs(bx) > 3 or abs(by) > 3:
            return None

        def para(m1, c0, p1):
            """Vertex offset of the parabola through (-1, 0, +1) samples.

            Returns 0 when the samples do not form a maximum (curvature >= 0).
            """
            d = m1 - 2 * c0 + p1
            return 0.5 * (m1 - p1) / d if d < 0 else 0.0

        sx = para(scores[(bx - 1, by)], scores[(bx, by)], scores[(bx + 1, by)])
        sy = para(scores[(bx, by - 1)], scores[(bx, by)], scores[(bx, by + 1)])
        return x0 + bx + sx, y0 + by + sy

    # Each eye's "points" list holds n*n entries; slice j*n:(j+1)*n is row j,
    # so grids[e][j, i] is row j (v axis) and column i (u axis).
    grids = []
    for e in (0, 1):
        eye = dump["eyes"][e]
        grids.append(np.array([[p["g"] for p in eye["points"][j * n:(j + 1) * n]]
                               for j in range(n)]))

    def to_tan(e, pt):
        """Map a detected panel pixel of eye e to tangent space.

        The pixel centre is normalized to 0..1 over the eye's half-panel.
        A bilinear lookup in the dumped grid (samples at i / (n - 1)) gives
        a normalized texture coordinate, which the projection window then
        scales. The y component is T + sv * (B - T), the inverse of
        world_to_tex's sv mapping, so it equals -ty (down-positive), not ty.
        """
        g = grids[e]
        L, R, T, B = window(e)
        u = (pt[0] + 0.5) / HALF
        v = (pt[1] + 0.5) / PANEL_H
        fu, fv = u * (n - 1), v * (n - 1)
        # Clamp so the last cell is used at the far edge (u or v == 1).
        i0, j0 = min(int(fu), n - 2), min(int(fv), n - 2)
        a, b = fu - i0, fv - j0
        src = (g[j0, i0] * (1 - a) * (1 - b) + g[j0, i0 + 1] * a * (1 - b) +
               g[j0 + 1, i0] * (1 - a) * b + g[j0 + 1, i0 + 1] * a * b)
        return np.array([L + src[0] * (R - L), T + src[1] * (B - T)])

    # Checker corners are assumed on a 0.1-unit lattice inset one square from
    # the quad edges; the -0.099 stop keeps the last corner despite float
    # step accumulation in np.arange.
    rows = []
    for wy in np.arange(-QH / 2 + 0.1, QH / 2 - 0.099, 0.1):
        for wx in np.arange(-QW / 2 + 0.1, QW / 2 - 0.099, 0.1):
            tans = []
            for e in (0, 1):
                su, sv = world_to_tex(e, wx, wy)
                px, py = tex_to_panel(e, su, sv)
                r = saddle_refine(panel[:, e * HALF:(e + 1) * HALF], px, py)
                tans.append(None if r is None else to_tan(e, r))
            if tans[0] is None or tans[1] is None:
                continue  # corner not recovered in both eyes
            d = tans[0] - tans[1]
            rows.append((np.hypot(wx, wy) / Z, d[0], d[1]))

    # Sorting by eccentricity first makes groupby's buckets contiguous.
    rows.sort()
    print(f"{len(rows)} corners.  expect d_tanx={IPD/Z:+.4f} uniform, d_tany=0")
    print(f"{'ecc(tan)':>8} {'d_tanx':>9} {'d_tany':>9}")
    for ecc, grp in itertools.groupby(rows, key=lambda r: round(r[0], 1)):
        g = list(grp)
        print(f"{ecc:>8.1f} {np.mean([r[1] for r in g]):+9.5f} "
              f"{np.mean([r[2] for r in g]):+9.5f}   (n={len(g)}, "
              f"sd_x {np.std([r[1] for r in g]):.5f})")


# Script entry point; main() takes its inputs from sys.argv.
if __name__ == "__main__":
    main()
