"""Standalone Beyond distortion evaluator — spike S3.

Maps panel/output UV -> pre-distortion render-target UV, per color channel,
from the device config JSON (tracking_to_eye_transform block). No SteamVR.

The open question this spike settles: production B1 configs use type
DISTORT_POLY3 with POSITIVE leading coeffs (+0.2438 on this unit), while
EVT/DVT used DISTORT_DPOLY3 with NEGATIVE leading coeffs (-0.22). Public
reimplementations (Monado u_compute_distortion_vive) evaluate the Vive
DPOLY3 family as a RECIPROCAL: d = 1/(1 + k1*r2 + k2*r4 + k3*r6).
Hypothesis: POLY3 is the same curve fitted the other way round, i.e.
d = 1 + k1*r2 + k2*r4 + k3*r6 directly. Four candidate models (two
polynomial forms x two output scalings, listed in HYPOTHESES) are
implemented; verify.py picks the one matching OpenVR ComputeDistortion.

Hypothesis axes:
  form: 'poly'  -> d = 1 + k1*r2 + k2*r4 + k3*r6
        'recip' -> d = 1 / (1 + k1*r2 + k2*r4 + k3*r6)
  grow: True  -> output scale 0.5/(1+grow_for_undistort)   (Monado)
        False -> output scale 0.5                          (no grow rescale)
"""

import json

# Config-block key for each color channel, in (red, green, blue) order.
# Green uses the unsuffixed "distortion" key.
CHANNELS = (
    "distortion_red",   # red
    "distortion",       # green
    "distortion_blue",  # blue
)

# Candidate models as (label, form, use_grow). `form` and `use_grow`
# correspond to EyeDistortion.compute's parameters of the same names;
# see the module docstring for what each axis means.
HYPOTHESES = [
    # label         form     use_grow
    ("poly+grow",   "poly",  True),
    ("poly",        "poly",  False),
    ("recip+grow",  "recip", True),
    ("recip",       "recip", False),
]


class EyeDistortion:
    """Distortion model for one eye, built from tracking_to_eye_transform[eye].

    Holds the eye-wide ``grow_for_undistort`` and ``undistort_r2_cutoff``
    values. For each color channel, in CHANNELS order (r, g, b), it also
    holds the distortion center, the three polynomial coefficients and the
    config ``type`` string.
    """

    # Radial-term forms accepted by compute(); see the module docstring.
    FORMS = ("poly", "recip")

    def __init__(self, tte):
        """Read one eye's block (a dict parsed from the config JSON).

        Raises KeyError if a required key is missing.
        """
        self.grow = tte["grow_for_undistort"]
        # Not used by compute(), which deliberately does not clamp r2 at
        # this cutoff (see the S3 finding note in compute()).
        self.r2_cutoff = tte["undistort_r2_cutoff"]
        # center is shared across channels within an eye on the Beyond,
        # but read per channel anyway (B2 may differ).
        self.channels = []
        for key in CHANNELS:
            block = tte[key]
            self.channels.append({
                "cx": block["center_x"],
                "cy": block["center_y"],
                "k": list(block["coeffs"]),
                "type": block["type"],
            })

    def compute(self, u, v, form="poly", use_grow=True, aspect_x_over_y=1.0):
        """Panel UV (0..1, origin upper-left per OpenVR) -> source UV per channel.

        Args:
            u, v: panel/output UV coordinate.
            form: 'poly' (d = p) or 'recip' (d = 1/p), where
                p = 1 + k1*r2 + k2*r2**2 + k3*r2**3.
            use_grow: if True the output is scaled by
                0.5 / (1 + grow_for_undistort), otherwise by 0.5.
            aspect_x_over_y: the centered y coordinate is divided by this
                before the radius is computed and multiplied back on output.

        Returns ((ur, vr), (ug, vg), (ub, vb)).

        Raises:
            ValueError: if ``form`` is not one of FORMS.
            ZeroDivisionError: with form 'recip' when p evaluates to 0.
        """
        # Reject typos explicitly: an unknown form used to fall through to
        # the reciprocal model silently, which is exactly the distinction
        # this spike is trying to measure.
        if form not in self.FORMS:
            raise ValueError(
                f"unknown distortion form {form!r}; expected one of {self.FORMS}"
            )
        recip = form == "recip"
        scale = 0.5 / (1.0 + self.grow) if use_grow else 0.5
        # Map 0..1 UV to -1..1 once; this does not depend on the channel.
        base_x = 2.0 * u - 1.0
        base_y = (2.0 * v - 1.0) / aspect_x_over_y
        out = []
        for ch in self.channels:
            cx = ch["cx"]
            cy = ch["cy"]
            tx = base_x - cx
            ty = base_y - cy
            # NOTE (S3 finding): ComputeDistortion does NOT clamp r2 at
            # undistort_r2_cutoff — ground truth fits exactly with no clamp
            # across the full panel (corner r2 ~ 2.1 > cutoff 1.0). The
            # cutoff presumably gates mesh/hidden-area logic elsewhere.
            r2 = tx * tx + ty * ty
            k1, k2, k3 = ch["k"]
            # Horner form of 1 + k1*r2 + k2*r2^2 + k3*r2^3.
            p = 1.0 + r2 * (k1 + r2 * (k2 + r2 * k3))
            d = 1.0 / p if recip else p
            sx = 0.5 + (tx * d + cx) * scale
            sy = 0.5 + (ty * d + cy) * scale * aspect_x_over_y
            out.append((sx, sy))
        return tuple(out)


def load_config(path):
    """Load a device config.json and build both eyes' distortion models.

    Args:
        path: filesystem path to the device config JSON.

    Returns:
        (left EyeDistortion, right EyeDistortion, raw config dict), built from
        ``tracking_to_eye_transform[0]`` and ``[1]`` respectively.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError / IndexError: if the expected blocks are missing.
    """
    # JSON is UTF-8 by spec (RFC 8259). Pin the encoding so decoding does not
    # depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    tte = cfg["tracking_to_eye_transform"]
    return EyeDistortion(tte[0]), EyeDistortion(tte[1]), cfg
