from extended_mode_display.extended_mode_display import ExtendedModeDisplay
import numpy as np
from nicegui import ui, binding
import cv2
import os
import shutil
# from numba import jit

def srgb_to_linear(im):
    """Remove the sRGB transfer curve, returning linear-light values.

    Vectorised NumPy replacement for the former per-pixel numba loop (which
    was commented out while the module still called this function). Works on
    arrays of any shape; the piecewise curve is the standard sRGB one.

    Args:
        im: array-like of sRGB-encoded values, nominally in [0, 1].

    Returns:
        ndarray of the same shape holding linear-light values.
    """
    im = np.asarray(im)
    # Clamp the base of the power branch to the threshold: np.where evaluates
    # both branches, and this keeps discarded low/negative inputs from ever
    # producing NaNs or RuntimeWarnings.
    curve = ((np.maximum(im, 0.04045) + 0.055) / 1.055) ** 2.4
    return np.where(im <= 0.04045, im / 12.92, curve)

def linear_to_srgb(im):
    """Apply the sRGB transfer curve to linear-light values.

    Vectorised NumPy replacement for the former per-pixel numba loop (which
    was commented out while the module still called this function). Works on
    arrays of any shape.

    Args:
        im: array-like of linear-light values, nominally in [0, 1].

    Returns:
        ndarray of the same shape holding sRGB-encoded values.
    """
    im = np.asarray(im)
    # Clamp the base of the fractional power: np.where evaluates both
    # branches, and a negative base would yield NaN plus a RuntimeWarning.
    curve = 1.055 * np.maximum(im, 0.0031308) ** (1 / 2.4) - 0.055
    return np.where(im <= 0.0031308, im * 12.92, curve)

# Image channel exported by write_to_file(), as an index into the OpenCV BGR
# array (0 = blue, 1 = green, 2 = red, matching the "channel" dropdown items).
channel = 1

class Channel:
    """Bar-pattern parameters for a single colour channel.

    Every field is a NiceGUI bindable property so the sliders built in
    bars_page() can be bound to it directly.
    """

    offset = binding.BindableProperty()  # first column of the bar pattern
    repeat = binding.BindableProperty()  # pattern period, in columns
    drop = binding.BindableProperty()  # bar level; gen_bars scales by drop / 255
    width = binding.BindableProperty()  # number of bar columns per period

    def __init__(self):
        self.offset, self.repeat, self.drop, self.width = 5, 6, 252, 3

# One independent parameter set per colour, each starting from Channel defaults.
red, green, blue = Channel(), Channel(), Channel()

# ExtendedModeDisplay instance; assigned by bars_page() when the page opens.
display = None

test_images_path = "./test_images"
test_image_names = ["green",
                    "red",
                    "blue",
                    "yellow",
                    "cyan",
                    "purple",
                    "orange",
                    "gray",
                    "RGB",
                    ]
test_image_paths = [f"{test_images_path}/{name}.png" for name in test_image_names]


def _load_linear_image(path):
    """Read *path* with OpenCV and return it as float32 linear-light values.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. cv2.imread returns
            None instead of raising, which would otherwise surface later as an
            unhelpful AttributeError.
    """
    im = cv2.imread(path)
    if im is None:
        raise FileNotFoundError(f"could not read test image: {path}")
    return srgb_to_linear(im.astype(np.float32) / 255)


# Same order as test_image_names; update_test_image() relies on that.
linear_test_images = [_load_linear_image(path) for path in test_image_paths]

test_im = linear_test_images[0]

def gen_bars(*, channels=None, size=(1920, 1920)):
    """Build a multiplicative mask of vertical bars, one pattern per channel.

    For each channel, a band of ``width`` adjacent columns starting at column
    ``offset`` is scaled by ``drop / 255``, and the band repeats every
    ``repeat`` columns. All other pixels stay at 1.0.

    Args:
        channels: sequence of objects exposing ``offset``, ``repeat``, ``drop``
            and ``width``, given in image-channel order. Defaults to the
            module-level (blue, green, red) so the mask lines up with OpenCV's
            BGR layout.
        size: (height, width) of the mask; defaults to 1920 x 1920.

    Returns:
        float32 array of shape (height, width, len(channels)).
    """
    if channels is None:
        # OpenCV images are BGR: plane 0 is blue and plane 2 is red. This is
        # the same mapping the "channel" dropdown in bars_page() uses.
        channels = (blue, green, red)
    rows, cols = size
    correction = np.ones((rows, cols, len(channels)), dtype=np.float32)
    for c, color in enumerate(channels):
        # Slider values may arrive as floats; slicing and range need ints.
        start = int(color.offset)
        repeat = int(color.repeat)
        level = color.drop / 255
        for i in range(int(color.width)):
            # Columns start+i, start+i+repeat, ... Adding i exactly once keeps
            # the `width` columns of each band adjacent.
            correction[:, start + i::repeat, c] = level
    return correction


def apply_correction(im):
    """Multiply a linear-light image by the bar mask and encode it as 8-bit sRGB.

    Args:
        im: linear-light float image, values nominally in [0, 1]. It must
            broadcast against the (1920, 1920, 3) mask returned by gen_bars().

    Returns:
        uint8 array of the broadcast shape, sRGB-encoded.
    """
    correction = gen_bars()
    linear = im * correction
    srgb = linear_to_srgb(linear)
    # Round instead of truncating: float round-off can leave an encoded 1.0
    # just below 1.0 (e.g. 0.99999994), which truncation would turn into 254
    # instead of 255. Clip to guard against values drifting outside [0, 255].
    return np.clip(np.rint(srgb * 255), 0, 255).astype(np.uint8)

def display_bars():
    """Show the corrected current test image beside an all-black frame.

    Uses the module-level ``display`` (set up by bars_page()) and ``test_im``.
    """
    left = apply_correction(test_im)
    right = np.zeros((1920, 1920, 3), dtype=np.uint8)
    display.display_side_by_side(left, right)


def write_to_file():
    """Export one channel of the corrected test image and copy it to Lighthouse.

    The plane selected by the module-level ``channel`` index is written as a
    single-channel PNG, ``mura.png``, in the current working directory. That
    file is then copied (the local file is kept) to ``dst_path`` as
    ``mura.mc``.

    Raises:
        OSError: if OpenCV reports that ``mura.png`` could not be written, or
            if creating the destination folder or copying fails.
    """
    bars = apply_correction(test_im)
    # `channel` indexes the BGR array. Slicing one plane gives a strided view,
    # so hand imwrite a contiguous copy.
    plane = np.ascontiguousarray(bars[:, :, channel])
    if not cv2.imwrite("mura.png", plane):
        # imwrite reports many failures by returning False rather than raising;
        # without this check a stale or missing file would be copied below.
        raise OSError("cv2.imwrite failed to write mura.png")

    # NOTE(review): device-specific path (lhr-13a0784a); consider making it configurable.
    dst_path = "C:/Program Files (x86)/Steam/config/lighthouse/lhr-13a0784a/userdata/mura.mc"
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    shutil.copy("mura.png", dst_path)


def _add_channel_sliders(color):
    """Add labelled Offset/Repeat/Drop/Width sliders bound to *color*'s fields."""
    ui.label('Offset')
    ui.slider(min=0, max=8, value=color.offset, step=1, on_change=display_bars).bind_value(color, "offset")
    ui.label('Repeat')
    ui.slider(min=2, max=16, value=color.repeat, step=1, on_change=display_bars).bind_value(color, "repeat")
    ui.label('Drop')
    ui.slider(min=200, max=255, value=color.drop, step=1, on_change=display_bars).bind_value(color, "drop")
    ui.label('Width')
    # max is read from `repeat` once, when the page is built; it does not
    # follow later changes to the Repeat slider.
    ui.slider(min=1, max=color.repeat, value=color.width, step=1, on_change=display_bars).bind_value(color, "width")


@ui.page('/')
def bars_page():
    """Build the bar-tuning page and open the side-by-side display.

    Creates the module-level ``display``, draws the initial bars, and adds the
    test-image and channel pickers, one slider panel per colour, and the
    export button.
    """
    global display
    display = ExtendedModeDisplay(resolution=(1920*2, 1920))
    display_bars()

    with ui.dropdown_button("test_image", auto_close=True):
        for name in test_image_names:
            # Bind `name` as a default; a bare closure would see only the last loop value.
            ui.item(name, on_click=lambda e, name=name: update_test_image(name))

    with ui.dropdown_button("channel", auto_close=True):
        # Indices follow OpenCV's BGR plane order.
        for label, index in (("channel Red", 2), ("channel Green", 1), ("channel Blue", 0)):
            ui.item(label, on_click=lambda e, index=index: update_channel(index))

    ui.label('Bars Generator')
    with ui.tabs().classes('w-full') as tabs:
        tab_red = ui.tab("Red")
        tab_green = ui.tab("Green")
        tab_blue = ui.tab("Blue")

    with ui.tab_panels(tabs, value=tab_green).classes('w-full'):
        for tab, color in ((tab_red, red), (tab_green, green), (tab_blue, blue)):
            with ui.tab_panel(tab):
                _add_channel_sliders(color)
    ui.button('Write to file', on_click=write_to_file)


def update_test_image(value):
    """Select the test image named *value* and redraw the bars.

    Args:
        value: one of ``test_image_names``.

    Raises:
        ValueError: if *value* is not a known test image name.
    """
    global test_im
    index = test_image_names.index(value)
    print(value, index)
    test_im = linear_test_images[index]
    display_bars()


def update_channel(value):
    """Select which image plane write_to_file() exports.

    Handler for the "channel" dropdown in bars_page().

    Args:
        value: plane index in OpenCV BGR order (0 = blue, 1 = green, 2 = red).
    """
    global channel
    channel = value


# Start the NiceGUI server only when this file is executed. NiceGUI's reload
# mode re-imports the script as "__mp_main__", so that name is accepted too.
if __name__ in {"__main__", "__mp_main__"}:
    ui.run()
