# %%
import os
from binhex import binhex
import open3d as o3d
import open3d.visualization
import trimesh
import zipfile
import numpy as np
import math
import xml
import copy
from sklearn.decomposition import PCA
from scipy.spatial.transform import Rotation as R
import time
import xml.etree.ElementTree
import sys
import json
import base64
import ctypes


def get_rotation_matrix(vec2, vec1=None):
    """Return the 3x3 rotation matrix that best aligns ``vec1`` with ``vec2``.

    Wraps :meth:`scipy.spatial.transform.Rotation.align_vectors`, called as
    ``align_vectors(vec2, vec1)``, which estimates the rotation mapping
    ``vec1`` onto ``vec2``.

    Args:
        vec2: Target 3-vector (array-like).
        vec1: Source 3-vector (array-like). Defaults to the x axis ``[1, 0, 0]``.

    Returns:
        numpy.ndarray of shape (3, 3).
    """
    # Local import: this script later rebinds the module-level name ``R`` to a
    # plain transformation matrix, so relying on the global would break any
    # call made after that point.
    from scipy.spatial.transform import Rotation

    # ``None`` sentinel instead of an array default shared between calls.
    if vec1 is None:
        vec1 = np.array([1, 0, 0])
    vec1 = np.reshape(vec1, (1, -1))
    vec2 = np.reshape(vec2, (1, -1))
    result = Rotation.align_vectors(vec2, vec1)
    # align_vectors returns (rotation, rssd, ...); only the rotation is needed.
    return result[0].as_matrix()

def findInZip(z, filepart):
    """Return the first member of archive ``z`` whose name contains ``filepart``.

    Args:
        z: An open :class:`zipfile.ZipFile`.
        filepart: Substring searched for in each member's ``filename``
            (case-sensitive; the first match in archive order wins).

    Returns:
        The matching :class:`zipfile.ZipInfo`.

    Raises:
        KeyError: if no member name contains ``filepart``.
    """
    for info in z.filelist:
        if filepart in info.filename:
            return info
    # Same exception type ZipFile.getinfo() uses for a missing member, with a
    # message naming what was looked for (a bare next() gave no context).
    raise KeyError(f"no member matching {filepart!r} in {z.filename!r}")

def file(path):
    """Print the ``faceLandmarks`` entry of the zip archive at ``path``.

    The archive must also contain a member whose name includes ``headmesh``;
    a missing member of either kind propagates ``StopIteration``. Prints the
    landmarks ``ZipInfo`` followed by an empty line and returns ``None``.
    """
    with zipfile.ZipFile(path) as archive:
        def first_named(fragment):
            return next(entry for entry in archive.filelist if fragment in entry.filename)

        landmarks_info = first_named("faceLandmarks")
        first_named("headmesh")  # presence check only; the entry is not used
        print(landmarks_info)
        print()
        print()

def show(scene_to_show):
    """Display ``scene_to_show`` when the ``SHOW_SCENE`` env var enables it.

    ``SHOW_SCENE`` counts as enabled when it equals ``"true"`` or ``"1"``
    (case-insensitive); when it is unset or has any other value this is a
    no-op. A plain ``list`` (exact type) is passed to Open3D's
    ``draw_geometries``; any other object must provide its own ``show()``.
    """
    flag = os.getenv('SHOW_SCENE')
    enabled = flag is not None and flag.lower() in ['true', '1']
    if not enabled:
        return
    if type(scene_to_show) is list:
        o3d.visualization.draw_geometries(scene_to_show)
        return
    scene_to_show.show()

def alignToMeanFace(mesh):
    """Placeholder for aligning ``mesh`` to a mean-face template (not implemented).

    Ignores ``mesh`` and returns ``None``.
    """
    return None

# Output document; the remaining keys are filled in as the script runs.
payload = dict(
    left_pupil=[],
    right_pupil=[],
    ipd=-1,
    mesh="",
)

# Override defaults (the command-line handling below may replace
# order_ipd and binocular_far_pd_override).
eyeRelief = 13
asymmetry = 0
downward_tilt = 8
binocular_far_pd_override = None
order_ipd = -1

# %%
# Command line: [scan_zip] [output_dir] [output_file] [order_ipd]
# Reject extra arguments before consuming any of them; a plain ``assert``
# would be skipped entirely under ``python -O``.
if len(sys.argv) > 5:
    sys.exit("Too many arguments")

if len(sys.argv) > 1:
    in_path = sys.argv[1]
else:
    # NOTE(review): relative default, resolved against the current working directory.
    in_path = "../../../../tests/fabricator/scan_RhTmPzaAz9XsnuwSOavslBe6N9I.zip"
if len(sys.argv) > 2:
    output_dir = sys.argv[2]
else:
    # Default: a ".out" directory next to this script.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    output_dir = f"{dir_path}/.out"
if len(sys.argv) > 3:
    output_file = sys.argv[3]
else:
    output_file = "output.json"

if len(sys.argv) > 4:
    order_ipd = int(sys.argv[4])
    if order_ipd > 0: # -1 is expected in the case of the order being measure_me
        binocular_far_pd_override = order_ipd

# Create the output directory; a no-op when it already exists (avoids the
# race between an exists() check and makedirs()).
os.makedirs(output_dir, exist_ok=True)

# RGBA colours for left/right debug markers.
LEFT_COLOR = np.array([0, 0, 255, 255])
RIGHT_COLOR = np.array([255, 0, 0, 255])

# %% [markdown]
# # Extract contents
if not os.path.exists(in_path):
    # Explicit check rather than ``assert``, which ``python -O`` strips.
    raise FileNotFoundError(f"scan archive not found: {in_path}")

# Pull the landmarks, scan metadata and textured head mesh out of the archive
# into output_dir (not the current working directory), and close the archive.
with zipfile.ZipFile(in_path) as zipf:
    registeredLandmarks = zipf.extract(findInZip(zipf, "registeredLandmarks.pp"), output_dir)
    sharedScanMetadata = zipf.extract(findInZip(zipf, "sharedScanMetadata.json"), output_dir)
    obj = zipf.extract(findInZip(zipf, "headMesh.obj"), output_dir)
    zipf.extract(findInZip(zipf, "headMesh.mtl"), output_dir)
    zipf.extract(findInZip(zipf, "headMesh.jpg"), output_dir)

# Near pupils: <point> children of the XML root whose attribute values include
# 'RPupil' / 'LPupil'; the coordinates are the x/y/z attributes.
tree = xml.etree.ElementTree.parse(registeredLandmarks)
lm = tree.findall("point")
rp = next(x.attrib for x in lm if 'RPupil' in x.attrib.values())
lp = next(x.attrib for x in lm if 'LPupil' in x.attrib.values())
right_pupil_near = np.asarray([rp['x'], rp['y'], rp['z']], dtype=float)
left_pupil_near = np.asarray([lp['x'], lp['y'], lp['z']], dtype=float)

# Far pupils: the first three components of farRightPupil / farLeftPupil.
with open(sharedScanMetadata, encoding="utf-8") as metadata_file:
    scanMetadata = json.load(metadata_file)
right_pupil = np.asarray(scanMetadata['farRightPupil'][:3])
left_pupil = np.asarray(scanMetadata['farLeftPupil'][:3])
# Scanned far binocular PD: distance between the far pupil positions read
# from sharedScanMetadata.json (farLeftPupil / farRightPupil).
binocular_far_pd = np.linalg.norm(left_pupil - right_pupil)

# Reported scan value, rounded to one decimal place.
payload["scanned_binocular_far_ipd"] = np.round(binocular_far_pd, 1)
# Default SKU: the scanned far PD rounded to the nearest whole number.
recommended_headset = int(np.round(binocular_far_pd))
payload["headset_sku"] = recommended_headset # overwritten with further logic

if order_ipd > 0: # has an ipd requested by the order
    # The ordered IPD replaces the scanned value as the SKU; the scan is only
    # used to flag a mismatch.
    recommended_headset = order_ipd
    payload["headset_sku"] = recommended_headset
    error_diff = binocular_far_pd - order_ipd
    if error_diff < -3 or error_diff > 1: # flag if the scan is 3mm smaller or 1mm larger than the order
        payload["result"] = "discrepancy" #ipd discrepancy between scan and order

# clamping overrides discrepancy, given an individual might know they are an outlier
# SKUs outside 55..72 are pulled to the nearest end of that range.
clamped_headset_sku = np.clip(recommended_headset, 55, 72)
if clamped_headset_sku != recommended_headset:
    payload["result"] = "clamped" # refer to binocular_far_ipd for scanned value
    payload["headset_sku"] = clamped_headset_sku

# favour 69 over 70 headset sku. A clamped SKU is always 55 or 72, so a
# clamped value never reaches this branch and "clamped" is never overwritten.
if payload["headset_sku"] == 70:
    payload["headset_sku"] = 69
    # Only an explicit order for 70 is flagged and re-centres the cushion; a
    # scanned 70 with no order is mapped to 69 without setting an override.
    if order_ipd == 70:
        payload["result"] = "recommend69" # favour 69 over 70 headset sku
        binocular_far_pd_override = 69 # override the binocular_far_pd to 69 for the cushion to be centred.

print(f"headset_sku = {payload['headset_sku']}")

# Both values are stored as strings in the payload.
payload["headset_sku"] = str(payload["headset_sku"])
payload["scanned_binocular_far_ipd"] = str(payload["scanned_binocular_far_ipd"])

mesh = trimesh.load(obj)
# %% compute monocularPD
scene = trimesh.Scene()

# Debug markers at the near pupils.
s = trimesh.primitives.Sphere(radius=2, center=left_pupil_near)
s.visual.face_colors = LEFT_COLOR
scene.add_geometry(s)

s = trimesh.primitives.Sphere(radius=2, center=right_pupil_near)
s.visual.face_colors = RIGHT_COLOR
scene.add_geometry(s)

scene.add_geometry(mesh)

intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh)
pupil_mean = (left_pupil_near + right_pupil_near) / 2

# Every ray in this section is cast along +z from origins offset -20 in z.
ray_dir_z = np.array([0, 0, 1])

# new bridge hit: 100 rays from a line running from 10 in -x of the left pupil
# to 10 in +x of the right pupil; the hit with the largest z is the bridge.
t = time.time()
line_between_pupils = np.linspace(left_pupil_near + [-10, 0, -20], right_pupil_near + [10, 0, -20], 100)
bridge_hits = intersector.intersects_location(
    line_between_pupils,
    np.tile(ray_dir_z, (len(line_between_pupils), 1)),
    multiple_hits=False,
)[0]
if len(bridge_hits) == 0:
    # np.argmax on an empty array would fail with an opaque ValueError.
    raise RuntimeError("no mesh hits between the pupils; cannot locate the nose bridge")
bridge_hit_on_mesh = bridge_hits[np.argmax(bridge_hits[:, 2])]
print(f"time: {time.time() - t}")
for p in bridge_hits:
    s = trimesh.primitives.Sphere(radius=1, center=p)
    s.visual.face_colors = np.array([255, 255, 255, 100])
    scene.add_geometry(s)

s = trimesh.primitives.Sphere(radius=1, center=bridge_hit_on_mesh)
s.visual.face_colors = np.array([255, 255, 255, 255])
scene.add_geometry(s)

# Walk down the nose: for each row i = 1..29 below the bridge (-y), cast 10
# rays spanning +/-5 in x and keep the hit with the largest z.
t = time.time()
down_bridge_on_mesh = [bridge_hit_on_mesh]
last_rows_hit = bridge_hit_on_mesh
for i in range(1, 30):
    # NOTE(review): every row is centred on the bridge point's x rather than
    # on the previous row's hit -- confirm that is intended.
    line_behind_nose = np.linspace(down_bridge_on_mesh[0] + [-5, -i, -20], down_bridge_on_mesh[0] + [5, -i, -20], 10)
    line_on_nose = intersector.intersects_location(
        line_behind_nose,
        np.tile(ray_dir_z, (len(line_behind_nose), 1)),
        multiple_hits=False,
    )[0]
    if len(line_on_nose) == 0:
        raise RuntimeError(f"no mesh hits in nose row {i} below the bridge")
    nose_peek = line_on_nose[np.argmax(line_on_nose[:, 2])]
    down_bridge_on_mesh.append(nose_peek)
    last_rows_hit = nose_peek

    for p in line_behind_nose:
        s = trimesh.primitives.Sphere(radius=1, center=p)
        s.visual.face_colors = np.array([100, 255, 100, 100])
        scene.add_geometry(s)
print(f"Finding the nose took: {time.time() - t}")



#   ____                        _   _
#  |  _ \  _____      ___ __   | \ | | ___  ___  ___
#  | | | |/ _ \ \ /\ / / '_ \  |  \| |/ _ \/ __|/ _ \
#  | |_| | (_) \ V  V /| | | | | |\  | (_) \__ \  __/
#  |____/ \___/ \_/\_/ |_| |_| |_| \_|\___/|___/\___|
#

# Points stepped down (-y) and forward (+z) from each near pupil by 5..29.
# NOTE(review): neither offset list is referenced elsewhere in this file.
increments = list(range(5,30,1))
down_left_pupil_offsets = [left_pupil_near + np.array([0, -i, i]) for i in increments]
down_right_pupil_offsets = [right_pupil_near + np.array([0, -i, i]) for i in increments]

# Principal direction of the sampled nose-ridge points (first PCA component).
pca = PCA(n_components=1)
pca.fit(down_bridge_on_mesh)
nose_vec = pca.components_[0]
nose_mean = pca.mean_  # NOTE(review): not used elsewhere in this file

# Normalise the nose direction, then rotate it 90 degrees about x: -90 when it
# points towards +z, +90 otherwise. The result is compared against vertex
# normals in the pressure-fix section.
nose_vec = nose_vec / np.linalg.norm(nose_vec)
ident = np.eye(4)  # NOTE(review): unused
if np.dot(nose_vec, np.array([0, 0, 1])) > 0:
    rotation = R.from_euler('x', -90, degrees=True)
else:
    rotation = R.from_euler('x', 90, degrees=True)
rotated_nose_vector = rotation.apply(nose_vec)


# Debug markers for the nose-ridge samples.
for hit in down_bridge_on_mesh:
    s = trimesh.primitives.Sphere(radius=1, center=hit)
    s.visual.face_colors = (200, 200, 100, 105)
    scene.add_geometry(s)

# Independent copy of the mesh as it is at this point; later steps transform
# mesh and mesh_copy together.
mesh_copy = copy.deepcopy(mesh)
scene.add_geometry(mesh_copy)


# ██████╗ ██████╗ ███████╗███████╗███████╗██╗   ██╗██████╗ ███████╗    ███████╗██╗██╗  ██╗
# ██╔══██╗██╔══██╗██╔════╝██╔════╝██╔════╝██║   ██║██╔══██╗██╔════╝    ██╔════╝██║╚██╗██╔╝
# ██████╔╝██████╔╝█████╗  ███████╗███████╗██║   ██║██████╔╝█████╗      █████╗  ██║ ╚███╔╝
# ██╔═══╝ ██╔══██╗██╔══╝  ╚════██║╚════██║██║   ██║██╔══██╗██╔══╝      ██╔══╝  ██║ ██╔██╗
# ██║     ██║  ██║███████╗███████║███████║╚██████╔╝██║  ██║███████╗    ██║     ██║██╔╝ ██╗
# ╚═╝     ╚═╝  ╚═╝╚══════╝╚══════╝╚══════╝ ╚═════╝ ╚═╝  ╚═╝╚══════╝    ╚═╝     ╚═╝╚═╝  ╚═╝
#

trimesh.repair.fix_inversion(mesh)
# ``as_open3d`` yields an Open3D triangle mesh (it carries triangles and vertex
# normals), so annotate it as such.
mesh_o3d: o3d.geometry.TriangleMesh = mesh.as_open3d

#mesh_o3d.estimate_normals(mesh_o3d, search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))
mesh_o3d.compute_vertex_normals()

# Point-cloud view of the same vertices and normals.
# NOTE(review): pcd is not referenced elsewhere in this file.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(mesh_o3d.vertices))
pcd.normals = o3d.utility.Vector3dVector(np.asarray(mesh_o3d.vertex_normals))
# Per-VERTEX normals as an (N, 3) array; the name is historical. Rows are
# paired with weights computed from mesh.vertices, which assumes mesh_o3d
# keeps trimesh's vertex order.
face_normals = np.asarray(mesh_o3d.vertex_normals)

# Bulge Nose Bridge (Reduce Glabella pressure)
# Linear falloff: weight 1 at the bridge hit, 0 at max_distance and beyond.
max_distance = 25 # 20, 50
min_distances = np.linalg.norm(mesh.vertices - bridge_hit_on_mesh, axis=1)
weights = np.clip((max_distance - min_distances) / max_distance, 0.0, 1.0)
# Debug colouring: red channel follows the weight.
mesh_o3d.vertex_colors = o3d.utility.Vector3dVector(
    np.column_stack((weights + 0.1, np.full_like(weights, 0.1), np.full_like(weights, 0.1)))
)
# show([mesh_o3d])
# Candidate displacement along the normals; not applied (line below disabled).
offsets = weights[:, np.newaxis] * face_normals * 1 # 2mm
#mesh.vertices += offsets

# Bulge Nose Bridge (Reduce nose pressure)
# Falloff by distance to the nearest sampled nose-ridge point.
max_distance = 25 # 50, 50, 25
min_distances = np.linalg.norm(mesh.vertices - down_bridge_on_mesh[0], axis=1)
for hit in down_bridge_on_mesh[1:]:
    distances = np.linalg.norm(mesh.vertices - hit, axis=1)
    min_distances = np.fmin(min_distances, distances)

weights = np.clip((max_distance - min_distances) / max_distance, 0.0, 1.0)

# Direction factor (assuming unit-length normals): 1 for normals 45 degrees or
# more away from rotated_nose_vector, falling linearly in cosine to 0 for
# normals parallel to it.
dot_products = np.dot(face_normals, rotated_nose_vector)
cos_45_deg = np.cos(np.radians(45))
scaled_cosines = (dot_products - cos_45_deg) / (1 - cos_45_deg)
scaled_cosines = np.clip(1-scaled_cosines, 0, 1)

weights = scaled_cosines * weights
mesh_o3d.vertex_colors = o3d.utility.Vector3dVector(
    np.column_stack((weights + 0.1, np.full_like(weights, 0.1), np.full_like(weights, 0.1)))
)
# show([mesh_o3d])
offsets = weights[:, np.newaxis] * face_normals * 1 # 0.25, 1, 1.5
#mesh.vertices += offsets

s = trimesh.primitives.Sphere(radius=2, center=bridge_hit_on_mesh)
s.visual.face_colors = (0, 255, 0, 255)
scene.add_geometry(s)
scene.add_geometry(mesh)

# show(scene)
#mesh.export('updated_mesh.stl')
#mesh_copy.export('original_mesh.stl')

#exit(1)

# Split the far PD into left/right monocular PDs using the ratio of the near
# pupils' distances (in the x-y plane) from pupil_mean.
# NOTE(review): pupil_mean is the exact midpoint of the two near pupils, so
# left_near_pd == right_near_pd up to rounding and aspect is ~1; the far PD is
# therefore always split evenly. Confirm whether another reference point
# (e.g. the nose bridge) was intended.
left_near_pd = np.linalg.norm(left_pupil_near[:2] - pupil_mean[:2])
right_near_pd = np.linalg.norm(right_pupil_near[:2] - pupil_mean[:2])
aspect = left_near_pd/right_near_pd
left_far_pd = binocular_far_pd/2 * aspect
right_far_pd = binocular_far_pd - left_far_pd
# Replace the metadata far pupils with points offset only in x from
# pupil_mean: left at +x, right at -x.
right_pupil, left_pupil = pupil_mean.copy(), pupil_mean.copy()
left_pupil[0] = left_pupil[0] + left_far_pd
right_pupil[0] = right_pupil[0] - right_far_pd

s = trimesh.primitives.Sphere(radius=2, center=left_pupil)
s.visual.face_colors = LEFT_COLOR
scene.add_geometry(s)
s = trimesh.primitives.Sphere(radius=2, center=right_pupil)
s.visual.face_colors = RIGHT_COLOR
scene.add_geometry(s)
print(f"binocular_far_pd {binocular_far_pd:02}")
# From here on the PD is a whole number.
binocular_far_pd = np.round(binocular_far_pd)
#show(scene)
# %%
#meanFacePath = "/Users/philipkrejov/PycharmProjects/testOpen3d/meanFace.obj"
#meanFace = trimesh.load(meanFacePath)



# %%
# Fresh debug scene with the camera placed at +300 on z.
scene = trimesh.Scene()
slices = trimesh.Scene()  # NOTE(review): not referenced elsewhere in this file
scene.camera_transform = trimesh.transformations.translation_matrix((0,0,300))

# Pupil-aligned frame: translate so right_pupil is at the origin, then apply
# the inverse of the rotation taking +x onto (left_pupil - right_pupil), which
# puts the inter-pupil direction on +x.
# NOTE(review): left_pupil/right_pupil were rebuilt above to differ only in x,
# so rot is expected to be (close to) the identity here.
offset = trimesh.transformations.translation_matrix(right_pupil)
rot = get_rotation_matrix(left_pupil - right_pupil)
T = np.identity(4)
T[:3, :3] = rot
mesh.apply_transform(np.linalg.inv(offset))
mesh.apply_transform(np.linalg.inv(T))


# Keep mesh_copy in the same frame as mesh.
mesh_copy.apply_transform(np.linalg.inv(offset))
mesh_copy.apply_transform(np.linalg.inv(T))


# %%

print("binocular_far_pd_override" + str(binocular_far_pd_override))
if binocular_far_pd_override is not None:
    # Fold a quarter of the (override - scanned) PD difference into the x
    # shift; asymmetry is applied in both translations below.
    asymmetry = ((binocular_far_pd_override - binocular_far_pd) / 4) + asymmetry
    binocular_far_pd = binocular_far_pd_override

# Axis gizmo in the alignment frame (debug only; not added to the scene).
before = trimesh.creation.axis(axis_radius=.1, axis_length=binocular_far_pd)
before.apply_transform(T)
#scene.add_geometry(before)
scene.add_geometry(mesh)
scene.add_geometry(mesh_copy)


# Load the cushion subframe stored next to this script and place it:
# -90 degrees about y, -90 about z, then shift x by (binocular_far_pd - 64) / 2.
dir_path = os.path.dirname(os.path.realpath(__file__))
cushion_subframe_path = f"{dir_path}/CushionSubframe.stl"
cushion = trimesh.load(cushion_subframe_path)
cushion.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
cushion.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
cushion.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))

# Position the face against the cushion. The rotation matrices get their own
# names instead of ``R``, which would shadow scipy's ``Rotation`` (imported as
# ``R`` at the top of the file).
rot_x_5deg = trimesh.transformations.rotation_matrix(np.radians(5), (1,0,0))
mesh.apply_transform(rot_x_5deg)
mesh_copy.apply_transform(rot_x_5deg)
# Shifting +avg_eye_radius in z before the tilt makes the tilt pivot about a
# point avg_eye_radius along -z from the current origin -- presumably the
# eye's centre of rotation; confirm.
avg_eye_radius = 12
T = trimesh.transformations.translation_matrix([asymmetry,0,avg_eye_radius])
mesh.apply_transform(T)
mesh_copy.apply_transform(T)
rot_x_tilt = trimesh.transformations.rotation_matrix(-np.radians(downward_tilt), (1,0,0))
mesh.apply_transform(rot_x_tilt)
mesh_copy.apply_transform(rot_x_tilt)
T = trimesh.transformations.translation_matrix([asymmetry,0,-eyeRelief])
mesh.apply_transform(T)
mesh_copy.apply_transform(T)
#samples = trimesh.sample.sample_surface_even(cushion, 1000)[0]
#show(scene)


#show(scene)

# Ray intersector over the repositioned mesh, used by the cushion probes below.
intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh)

# Outer lateral hits
#  ██████  ██    ██ ████████ ███████ ██████      ██       █████  ████████ ███████ ██████   █████  ██          ██   ██ ██ ████████ ███████
# ██    ██ ██    ██    ██    ██      ██   ██     ██      ██   ██    ██    ██      ██   ██ ██   ██ ██          ██   ██ ██    ██    ██
# ██    ██ ██    ██    ██    █████   ██████      ██      ███████    ██    █████   ██████  ███████ ██          ███████ ██    ██    ███████
# ██    ██ ██    ██    ██    ██      ██   ██     ██      ██   ██    ██    ██      ██   ██ ██   ██ ██          ██   ██ ██    ██         ██
#  ██████   ██████     ██    ███████ ██   ██     ███████ ██   ██    ██    ███████ ██   ██ ██   ██ ███████     ██   ██ ██    ██    ███████


# Cushion limit point in the cushion STL's own frame, mapped into the scene
# frame with the same three transforms applied to the cushion; the mapped
# position is read back from the marker sphere's centre.
cushion_limit = [-28.723,-22.715, -1.663 ]
s = trimesh.primitives.Sphere(radius=1, center=cushion_limit)
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
s.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))
cushion_limit = s.center
scene.add_geometry(s)
# Cast one ray along +x and one along -x from that point (first hit each).
# NOTE(review): the unpacking assumes both rays hit the mesh. The hits are not
# assumed to come back in ray order, so they are sorted by x: right = smaller x.
right_cushion_limit, left_cushion_limit = intersector.intersects_location([cushion_limit, cushion_limit], np.array([[1, 0, 0], [-1, 0, 0]]), multiple_hits=False)[0]
if right_cushion_limit[0] > left_cushion_limit[0]:
    right_cushion_limit, left_cushion_limit = left_cushion_limit, right_cushion_limit

s = trimesh.primitives.Sphere(radius=1, center=left_cushion_limit)
s.visual.face_colors = LEFT_COLOR
scene.add_geometry(s)

s = trimesh.primitives.Sphere(radius=1, center=right_cushion_limit)
s.visual.face_colors = RIGHT_COLOR
scene.add_geometry(s)

# show(scene)

# Fresh scene for the cushion-fit visualisation, showing the aligned mesh_copy.
scene = trimesh.Scene()
# scene.add_geometry(mesh)
scene.add_geometry(mesh_copy)
# colliders_path = f"{dir_path}/Preview.stl"
# colliders = trimesh.load(colliders_path)
# colliders.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
# colliders.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
# colliders.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))
# colliders.visual.face_colors = (20, 20, 20, 250)
#scene.add_geometry(colliders)

# Temple control points in the cushion STL's own frame; the temple sections
# map them into the scene frame with the same transforms as the cushion.
left_cushion_inner = np.array([8.341, 88.092, -1.73])
left_cushion_outer = np.array([ 0.322, 97.71, -1.725])
right_cushion_inner = np.array([ 8.493,-23.715,-1.663])
right_cushion_outer = np.array([  0.475,-33.332,-1.625])
# Move the inner points +12 along the STL-frame z axis.
left_cushion_inner += [0,0,+12]
right_cushion_inner += [0,0,+12]

# Falloff radii around the inner (grow) and outer (shrink) temple hit points.
innner_control_point_distance = 20
outer_control_point_distance = 30

#
# ██      ███████ ███████ ████████     ████████ ███████ ███    ███ ██████  ██      ███████
# ██      ██      ██         ██           ██    ██      ████  ████ ██   ██ ██      ██
# ██      █████   █████      ██           ██    █████   ██ ████ ██ ██████  ██      █████
# ██      ██      ██         ██           ██    ██      ██  ██  ██ ██      ██      ██
# ███████ ███████ ██         ██           ██    ███████ ██      ██ ██      ███████ ███████


# Grow left temple inner (invasive): map the control point into the scene frame.
s = trimesh.primitives.Sphere(radius=1, center=left_cushion_inner)
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
s.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))
s.visual.face_colors = LEFT_COLOR
left_cushion_inner = s.center
scene.add_geometry(s)

# Drop a ray along -z onto the face; [0][0] selects the single (3,) hit point,
# matching the outer lookup below.
left_temple_inner = intersector.intersects_location([left_cushion_inner], np.array([[0, 0, -1]]), multiple_hits=False)[0][0]

distances = np.linalg.norm(mesh.vertices - left_temple_inner, axis=1)
# NOTE(review): normalised by max_distance (set in the nose section) rather
# than innner_control_point_distance; with the current values (20 / 25) this
# weight peaks at 0.8 -- confirm intended.
weights_inner = np.clip((innner_control_point_distance - distances) / max_distance, 0.0, 1.0)

s = trimesh.primitives.Sphere(radius=1, center=left_temple_inner)
s.visual.face_colors = LEFT_COLOR
scene.add_geometry(s)

# Shrink left temple outer (light leakage).
s = trimesh.primitives.Sphere(radius=1, center=left_cushion_outer)
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
s.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))
s.visual.face_colors = LEFT_COLOR
left_cushion_outer = s.center
scene.add_geometry(s)

left_temple_outer = intersector.intersects_location([left_cushion_outer], np.array([[0, 0, -1]]), multiple_hits=False)[0][0]
# Never let the outer point sit at lower z than the left cushion-limit hit.
if left_temple_outer[2] < left_cushion_limit[2]:
    left_temple_outer = left_cushion_limit
s = trimesh.primitives.Sphere(radius=1, center=left_temple_outer)
s.visual.face_colors = LEFT_COLOR - (0, 0, 0, 100)
scene.add_geometry(s)

distances = np.linalg.norm(mesh.vertices - left_temple_outer, axis=1)
# NOTE(review): same max_distance divisor; with 30 / 25 this weight saturates
# at 1 within 5 of the outer point.
weights_outer = np.clip((outer_control_point_distance - distances) / max_distance, 0.0, 1.0)

# Net weight: grow (x3) near the inner point, shrink near the outer point.
weights = 3 * weights_inner + -1 * weights_outer

# Debug colouring: red = inner weight, green = outer weight.
mesh_o3d.vertex_colors = o3d.utility.Vector3dVector(np.column_stack((
    1.5 * weights_inner + 0.1,
    weights_outer + 0.1,
    np.full(len(mesh.vertices), 0.1),
)))
#show([mesh_o3d])

# Candidate displacement along the normals; not applied (line below disabled).
offsets = weights[:, np.newaxis] * face_normals * 0.25 # 2mm
#mesh.vertices += offsets

# show(scene)
# ██████  ██  ██████  ██   ██ ████████     ████████ ███████ ███    ███ ██████  ██      ███████
# ██   ██ ██ ██       ██   ██    ██           ██    ██      ████  ████ ██   ██ ██      ██
# ██████  ██ ██   ███ ███████    ██           ██    █████   ██ ████ ██ ██████  ██      █████
# ██   ██ ██ ██    ██ ██   ██    ██           ██    ██      ██  ██  ██ ██      ██      ██
# ██   ██ ██  ██████  ██   ██    ██           ██    ███████ ██      ██ ██      ███████ ███████

# Grow right temple inner (invasive): map the control point into the scene frame.
s = trimesh.primitives.Sphere(radius=1, center=right_cushion_inner)
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
s.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))
s.visual.face_colors = RIGHT_COLOR
right_cushion_inner = s.center
scene.add_geometry(s)

# Drop a ray along -z onto the face; [0][0] selects the single (3,) hit point,
# matching the outer lookup below.
right_temple_inner = intersector.intersects_location([right_cushion_inner], np.array([[0, 0, -1]]), multiple_hits=False)[0][0]

distances = np.linalg.norm(mesh.vertices - right_temple_inner, axis=1)
# NOTE(review): normalised by max_distance (set in the nose section) rather
# than innner_control_point_distance; with the current values (20 / 25) this
# weight peaks at 0.8 -- confirm intended.
weights_inner = np.clip((innner_control_point_distance - distances) / max_distance, 0.0, 1.0)

s = trimesh.primitives.Sphere(radius=1, center=right_temple_inner)
s.visual.face_colors = RIGHT_COLOR
scene.add_geometry(s)

# Shrink right temple outer (light leakage).
s = trimesh.primitives.Sphere(radius=1, center=right_cushion_outer)
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
s.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))
s.visual.face_colors = RIGHT_COLOR
right_cushion_outer = s.center
scene.add_geometry(s)

right_temple_outer = intersector.intersects_location([right_cushion_outer], np.array([[0, 0, -1]]), multiple_hits=False)[0][0]
# Never let the outer point sit at lower z than the right cushion-limit hit.
if right_temple_outer[2] < right_cushion_limit[2]:
    right_temple_outer = right_cushion_limit
s = trimesh.primitives.Sphere(radius=1, center=right_temple_outer)
s.visual.face_colors = RIGHT_COLOR - (0, 0, 0, 100)
scene.add_geometry(s)

distances = np.linalg.norm(mesh.vertices - right_temple_outer, axis=1)
# NOTE(review): same max_distance divisor; with 30 / 25 this weight saturates
# at 1 within 5 of the outer point.
weights_outer = np.clip((outer_control_point_distance - distances) / max_distance, 0.0, 1.0)

# Net weight: grow (x3) near the inner point, shrink near the outer point.
weights = 3 * weights_inner + -1 * weights_outer

# Debug colouring: red = inner weight, green = outer weight.
mesh_o3d.vertex_colors = o3d.utility.Vector3dVector(np.column_stack((
    1.5 * weights_inner + 0.1,
    weights_outer + 0.1,
    np.full(len(mesh.vertices), 0.1),
)))
#show([mesh_o3d])

# Candidate displacement along the normals; not applied (line below disabled).
offsets = weights[:, np.newaxis] * face_normals * 0.25 # 2mm
#mesh.vertices += offsets

# show(scene)

# ██████  ██████   ██████  ██   ██ ██ ███    ███ ██ ████████ ██    ██
# ██   ██ ██   ██ ██    ██  ██ ██  ██ ████  ████ ██    ██     ██  ██
# ██████  ██████  ██    ██   ███   ██ ██ ████ ██ ██    ██      ████
# ██      ██   ██ ██    ██  ██ ██  ██ ██  ██  ██ ██    ██       ██
# ██      ██   ██  ██████  ██   ██ ██ ██      ██ ██    ██       ██

# Proximity probe position, given in the cushion STL frame and mapped into the
# scene frame with the cushion's transforms; shown as a translucent marker.
proximity_pos = (33.672, 32.035, -12.542)
s = trimesh.primitives.Sphere(radius=1, center=proximity_pos)
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
s.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
s.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))
s.visual.face_colors = (255, 255, 255, 100)

scene.add_geometry(s)

# Look-at target for the probe, mapped the same way (not added to the scene).
proximity_lookat = (21.274,32.035,-18.323)
s_la = trimesh.primitives.Sphere(radius=1, center=proximity_lookat)
s_la.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,1,0)))
s_la.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (0,0,1)))
s_la.apply_transform(trimesh.transformations.translation_matrix(((binocular_far_pd-64)/2,0,0)))
s_la.visual.face_colors = (255, 255, 255, 100)

# Disabled: ray from the probe toward the look-at point and print the distance
# to the first hit.
#prox_hit = intersector.intersects_location([s.center], [s_la.center - s.center])[0]
#print(np.linalg.norm(s.center - prox_hit))
#s = trimesh.primitives.Sphere(radius=1, center=prox_hit)
#s.visual.face_colors = (255, 255, 255, 100)
#scene.add_geometry(s)

# %%
# for i in samples:
#     s = trimesh.primitives.Sphere(radius=.5, center=i)
#     s.visual.face_colors = (255,0,0,255)
#     scene.add_geometry(s)

scene.add_geometry(cushion)
scene.add_geometry(trimesh.creation.axis(axis_radius=.1, axis_length=binocular_far_pd))
# Magenta markers for the reconstructed far pupils.
for i in [left_pupil, right_pupil]:
    l = trimesh.primitives.Sphere(radius=1, center=i)
    l.visual.face_colors = (255,0,255,255)
    # NOTE(review): ``T`` here is the last translation assigned in the face
    # placement step, not the pupil-alignment rotation built earlier, so these
    # markers may not follow the transformed mesh exactly -- confirm.
    l.apply_transform(np.linalg.inv(offset))
    l.apply_transform(np.linalg.inv(T))
    scene.add_geometry(l)

# Renders only when SHOW_SCENE is enabled.
show(scene)

# %%
# Final orientation and scale for export: -90 degrees about x, then scale 0.1.
mesh.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (1,0,0)))
mesh.apply_scale(.1)

mesh_copy.apply_transform(trimesh.transformations.rotation_matrix(-math.pi/2, (1,0,0)))
mesh_copy.apply_scale(.1)

# Remove any STL files already in output_dir before writing the new one.
for existing_name in os.listdir(output_dir):
    if existing_name.endswith(".stl"):
        os.remove(os.path.join(output_dir, existing_name))
# The return value (the exported data) is base64-encoded into the payload below.
stl_export = mesh.export(os.path.join(output_dir, "mesh.stl"))
#_ = mesh_copy.export(os.path.join(output_dir, "mesh_original.stl"))

# Export bytes to base64
payload["mesh"] = base64.b64encode(stl_export).decode("utf-8")
# Keep a status flagged earlier ("discrepancy", "clamped", "recommend69");
# fall back to "success" only when no earlier step set a result.
payload.setdefault("result", "success")
payload["ipd"] = payload["headset_sku"]

# write payload to disk as json
with open(os.path.join(output_dir, output_file), "w", encoding="utf-8") as f:
    json.dump(payload, f)

print("saved output")
