"""Generate a pre-distorted grid test image for the eyes-in S3 check.

For every panel pixel (per eye, per color channel) evaluate the verified
S3 warp (poly3.py, poly+grow form) to get the source UV, then sample a
procedural grid that is STRAIGHT in source (tangent-linear) space. Through
the Beyond's lens the grid must therefore look straight, and because each
channel is warped with its own coefficients, the lines must show minimal
color fringing. Wrong formula direction / wrong eye / flipped v shows up
as curved lines, rainbow fringes, or an upside-down marker.

Pattern (source UV space, per eye):
  - grid lines every 0.05, intensity white
  - thicker center cross at (0.5, 0.5)
  - filled triangle ABOVE center pointing up (settles v orientation eyes-in)
  - left eye additionally gets a small filled square left of center
    (settles eye assignment)
  - elsewhere inside source [0,1]: faint gray floor (linear 0.06) so the
    source extent stays visible
  - outside source [0,1]: black (matches the grown render target bounds)

Output: raw B8G8R8X8 bytes (gamma 2.2 encoded, X byte = 0), row-major,
W*H*4 — loaded by s2_nvapi --image.

Usage: python make_warped_grid.py <config.json> [out.raw] [W H]
       (default 5088 2544 — the validated native DM mode, left half = left eye)
"""

import json
import sys

import numpy as np

from poly3 import load_config

GRID_STEP = 0.05      # grid pitch in source UV
LINE_HW = 0.0018      # grid line half-width in source UV
CROSS_HW = 0.005      # centre-cross half-width in source UV
GAMMA = 2.2           # gamma used to encode linear intensity to output bytes


def grid_intensity(su, sv):
    """Evaluate the test pattern at source-UV coordinates.

    Returns an intensity array in [0, 1]:
      * 1.0 on grid lines, the centre cross, and the up-pointing triangle;
      * 0.06 elsewhere inside the unit square (shows the source extent);
      * 0.0 outside [0, 1] x [0, 1].
    """
    def line_dist(x):
        # Distance from x to the nearest multiple of GRID_STEP.
        return np.abs(((x / GRID_STEP) + 0.5) % 1.0 - 0.5) * GRID_STEP

    on_grid = (line_dist(su) < LINE_HW) | (line_dist(sv) < LINE_HW)
    on_cross = (np.abs(su - 0.5) < CROSS_HW) | (np.abs(sv - 0.5) < CROSS_HW)

    # Up marker: apex at v=0.30 widening to its base at v=0.38. v=0 is the
    # top of the image, so the apex is the highest point.
    t = (sv - 0.30) / 0.08
    on_tri = (t >= 0) & (t <= 1) & (np.abs(su - 0.5) < 0.04 * t)

    lit = on_grid | on_cross | on_tri
    in_bounds = (0 <= su) & (su <= 1) & (0 <= sv) & (sv <= 1)
    return np.where(in_bounds, np.where(lit, 1.0, 0.06), 0.0)


def left_eye_marker(su, sv):
    """Left-eye identification mark in source-UV space.

    Returns a float array that is 1.0 inside the open 0.04 x 0.04 square
    centred at (u=0.35, v=0.5), i.e. left of centre, and 0.0 elsewhere.
    """
    within_u = np.abs(su - 0.35) < 0.02
    within_v = np.abs(sv - 0.5) < 0.02
    return np.logical_and(within_u, within_v).astype(float)


def render_eye(model, W, H):
    """Return (H, W, 3) float intensity for one eye panel."""
    u = (np.arange(W) + 0.5) / W
    v = (np.arange(H) + 0.5) / H
    uu, vv = np.meshgrid(u, v)
    img = np.zeros((H, W, 3), dtype=np.float32)
    scale = 0.5 / (1.0 + model.grow)
    for ci, ch in enumerate(model.channels):  # r, g, b
        tx = 2.0 * uu - 1.0 - ch["cx"]
        ty = 2.0 * vv - 1.0 - ch["cy"]
        r2 = tx * tx + ty * ty
        k1, k2, k3 = ch["k"]
        d = 1.0 + r2 * (k1 + r2 * (k2 + r2 * k3))
        su = 0.5 + (tx * d + ch["cx"]) * scale
        sv = 0.5 + (ty * d + ch["cy"]) * scale
        img[:, :, ci] = grid_intensity(su, sv)
    return img


def main():
    cfg_path = sys.argv[1]
    out_path = sys.argv[2] if len(sys.argv) > 2 else "warped_grid.raw"
    W = int(sys.argv[3]) if len(sys.argv) > 4 else 5088
    H = int(sys.argv[4]) if len(sys.argv) > 4 else 2544
    half = W // 2

    left, right, cfg = load_config(cfg_path)
    print(f"unit {cfg.get('device_serial_number')}  image {W}x{H} ({half}x{H}/eye)")

    frame = np.zeros((H, W, 3), dtype=np.float32)
    for eye_idx, model in ((0, left), (1, right)):
        img = render_eye(model, half, H)
        if eye_idx == 0:
            # left-eye marker rendered in green channel source space
            scale = 0.5 / (1.0 + model.grow)
            u = (np.arange(half) + 0.5) / half
            v = (np.arange(H) + 0.5) / H
            uu, vv = np.meshgrid(u, v)
            ch = model.channels[1]
            tx = 2.0 * uu - 1.0 - ch["cx"]
            ty = 2.0 * vv - 1.0 - ch["cy"]
            r2 = tx * tx + ty * ty
            k1, k2, k3 = ch["k"]
            d = 1.0 + r2 * (k1 + r2 * (k2 + r2 * k3))
            su = 0.5 + (tx * d + ch["cx"]) * scale
            sv = 0.5 + (ty * d + ch["cy"]) * scale
            m = left_eye_marker(su, sv)
            img = np.maximum(img, m[:, :, None])
        frame[:, eye_idx * half:(eye_idx + 1) * half, :] = img

    # gamma encode, pack B8G8R8X8 little-endian (byte order B,G,R,X)
    enc = (np.clip(frame, 0, 1) ** (1.0 / GAMMA) * 255.0 + 0.5).astype(np.uint8)
    out = np.zeros((H, W, 4), dtype=np.uint8)
    out[:, :, 0] = enc[:, :, 2]  # B
    out[:, :, 1] = enc[:, :, 1]  # G
    out[:, :, 2] = enc[:, :, 0]  # R
    out.tofile(out_path)
    print(f"wrote {out_path} ({out.nbytes} bytes)")


if __name__ == "__main__":
    main()
