import logging
import time
import steamvr.steamvr
from steamvr.unbuffered.lighthouse import readline
from steamvr.unbuffered.process import UnbufferedProcess
from dataclasses import dataclass
from typing import List, Callable, Protocol
import json
import tempfile as tf
import os
from dacite import from_dict
from threading import Thread
from steamvr.paths import LHCALIB_PATH, INIT_JSON_CONFIG_PATH
import polling2

# Module-level logger; the calibration threads below log tool output to per-device
# children obtained with logger.getChild(<device serial>).
logger = logging.getLogger(__name__)

@dataclass(frozen=True)
class Imu:
    """IMU calibration parameters, held in ``ImuSolution.imu``.

    Built by ``dacite.from_dict`` as part of an ``ImuSolution``, so the field names must match
    the keys of the calibration JSON exactly.
    """
    acc_bias: List[float]  # presumably accelerometer bias — TODO confirm units/axis order
    acc_scale: List[float]  # presumably accelerometer scale factors — TODO confirm
    gyro_bias: List[float]  # presumably gyroscope bias — TODO confirm units/axis order


@dataclass(frozen=True)
class ImuSample:
    """A single IMU sample record.

    NOTE(review): not constructed by the visible code in this module; presumably mirrors the
    entries of ``ImuObservation.imu_samples`` — confirm against the tool's JSON output.
    """
    acc: List[float]  # presumably accelerometer reading — TODO confirm units
    flags: int  # opaque sample flags; meaning not established here
    gyro: List[float]  # presumably gyroscope reading — TODO confirm units
    time: float  # sample timestamp; units/epoch not established here


@dataclass(frozen=True)
class ImuObservation:
    """Raw IMU observation data for one device.

    NOTE(review): not constructed by the visible code in this module.
    """
    device_serial_number: str
    file_tracking: dict  # kept as a raw dict; not interpreted here
    imu_samples: List[dict]  # raw sample dicts, presumably in the ImuSample shape — confirm


@dataclass(frozen=True)
class ImuSolution:
    """IMU calibration result, deserialized from the tool's JSON output by ``parse_imu_solve``."""
    device_serial_number: str
    file_tracking: dict  # kept as a raw dict; not interpreted here
    imu: Imu  # calibration parameters
    imu_solve: dict  # raw solver output; ``acc_fit_error`` below reads its 'acc_fit_error' key

    @property
    def acc_fit_error(self) -> float:
        """Return the ``'acc_fit_error'`` entry of ``imu_solve``.

        :raises KeyError: if ``imu_solve`` has no ``'acc_fit_error'`` entry.
        """
        return self.imu_solve['acc_fit_error']


def parse_imu_solve(path) -> ImuSolution:
    """Load an IMU calibration result from the JSON file at ``path``.

    :param path: path of the calibration JSON file (as reported by the lighthouse tool).
    :returns: the file contents deserialized into an :class:`ImuSolution`.
    :raises FileNotFoundError: if ``path`` does not exist.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``, which
    # would make the raised exception type depend on interpreter flags.
    if not os.path.exists(path):
        raise FileNotFoundError(f'IMU calibration file not found: {path}')
    # JSON is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(path, encoding='utf-8') as f:
        return from_dict(ImuSolution, json.load(f))


class ImuCalibrationProcess(Thread):
    """Thread that runs the lighthouse IMU calibration for a single device.

    The tool writes into a fresh ``tempfile.mkdtemp()`` directory, which this class never
    removes. Every output line is logged to a per-device child logger. Each time the tool
    reports a written IMU calibration, that file is parsed and handed to ``on_completion``.
    """

    def __init__(self, device_serial, seconds, on_completion: Callable[[ImuSolution], None]):
        """
        :param device_serial: serial of the device to calibrate (``/bodyserial``); also the
            name of the child logger used for its output.
        :param seconds: value passed to the tool's ``/imucal`` option.
        :param on_completion: invoked on this thread with each parsed ImuSolution.
        """
        super().__init__()
        self.device_serial = device_serial
        self.on_completion = on_completion
        self.tempdir = tf.mkdtemp()
        cmd = f'{LHCALIB_PATH} /outputdir {self.tempdir} /imucal {seconds} /bodyserial {device_serial} /imucalautoexit'
        self.process = UnbufferedProcess(cmd)

    def run(self):
        """Start the tool and consume its output until ``readline`` returns None.

        :raises RuntimeError: if the tool reports that no devices are connected.
            NOTE(review): the tool process is not stopped here when that happens.
        """
        self.process.run()
        while (line := readline(self.process.stdout)) is not None:
            logger.getChild(self.device_serial).info(line)
            if "No devices" in line:
                raise RuntimeError("No Lighthouse devices connected")
            if "Wrote IMU calibration" not in line:
                continue
            # The calibration file path is taken as the last whitespace-separated token.
            # NOTE(review): this breaks if the path itself contains spaces.
            self.on_completion(parse_imu_solve(line.split()[-1]))


@dataclass(frozen=True)
class LightHouseCalibration:
    """Lighthouse body-calibration model, held in ``PhotoDiodeSolution.lighthouse_config``.

    Built by ``dacite.from_dict``, which matches field names to JSON keys verbatim; that is
    presumably why these fields are camelCase rather than snake_case.
    """
    channelMap: List[int]  # presumably sensor-to-channel mapping — TODO confirm
    modelNormals: List[List[float]]  # presumably one normal vector per sensor — TODO confirm
    modelPoints: List[List[float]]  # presumably one position per sensor — TODO confirm units


@dataclass(frozen=True)
class PhotoDiodeSolution:
    """Photodiode calibration result for one device, deserialized from the tool's JSON output."""
    lighthouse_config: LightHouseCalibration
    device_serial_number: str

    @staticmethod
    def from_file(path):
        """Load a :class:`PhotoDiodeSolution` from the JSON file at ``path``.

        :param path: path of the calibration JSON file.
        :returns: the deserialized solution.
        :raises FileNotFoundError: if ``path`` does not name an existing regular file.
        """
        # Explicit check instead of ``assert`` (stripped under ``python -O``).
        # isfile() also implies existence.
        if not os.path.isfile(path):
            raise FileNotFoundError(f'Photodiode calibration file not found: {path}')
        # JSON is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return from_dict(PhotoDiodeSolution, json.load(f))


class LockedLogger(Protocol):
    """Structural type for the optional logger accepted by PhotoDiodeCalibrationProcess.

    Any object that provides ``lock()`` and ``info(line)`` satisfies this protocol; the
    meaning of "lock" is left to implementations.
    """

    def lock(self):
        """Protocol hook: lock the logger (semantics defined by the implementation)."""

    def info(self, line):
        """Protocol hook: record ``line``."""


class PhotoDiodeCalibrationProcess(Thread):
    """Thread that runs the lighthouse photodiode (body) calibration for one device.

    The constructor writes a copy of the SteamVR JSON config from ``INIT_JSON_CONFIG_PATH``,
    with this device's serial set, into a private temporary directory. It then prepares the
    calibration command line.

    :meth:`run` starts the tool and logs its output to a per-device child logger. It answers
    the first capture prompt. Once the output ends, it passes the resulting
    :class:`PhotoDiodeSolution` to ``on_completion``.
    """

    def __init__(self, device_serial, on_completion: Callable[[PhotoDiodeSolution], None], min_observations=800,
                 min_hit_per_sensor=200, locked_logger: LockedLogger = None):
        """
        :param device_serial: serial of the device to calibrate. It is passed as
            ``/bodyserial``, written into the generated config, and used as the child-logger
            name.
        :param on_completion: invoked on this thread with the parsed solution.
        :param min_observations: passed to ``/bodycal`` after the config path.
        :param min_hit_per_sensor: passed to ``/bodycal`` after ``min_observations``.
        :param locked_logger: stored as ``self.log``.
            NOTE(review): it is not used by this class itself.
        """
        super().__init__()
        self.log = locked_logger
        self.on_completion = on_completion
        self.process_dir = tf.TemporaryDirectory()

        # Create a temporary json config carrying this device's serial number.
        config = steamvr.steamvr.load_steamvr_json(INIT_JSON_CONFIG_PATH)
        steamvr.steamvr.set_serial_number(config, device_serial)
        config_path = os.path.join(self.process_dir.name, 'config.json')
        steamvr.steamvr.save_steamvr_json(config, config_path)
        # NOTE(review): unexplained fixed delay after writing the config.
        # Confirm whether the tool actually needs it.
        time.sleep(1)
        self.device_serial = device_serial
        cmd = LHCALIB_PATH + f' /bodyserial {device_serial}' \
                             f' /usedisambiguation synconbeam' \
                             f' /bodycalmulti 2' \
                             f' /bodycal {config_path}' \
                             f' {min_observations} {min_hit_per_sensor}' \
                             f' /outputdir {self.process_dir.name}'

        self.process = UnbufferedProcess(cmd)

    def __del__(self):
        # __init__ may have failed before the directory attribute was assigned.
        process_dir = getattr(self, 'process_dir', None)
        if process_dir is not None:
            process_dir.cleanup()

    def parse_line(self, line: str):
        """Log one line of tool output and answer the initial capture prompt.

        :param line: a line read from the tool, or None at end of output (ignored).
        """
        if line is None:
            return
        if line:
            logger.getChild(self.device_serial).info(line)
        if "Ready to run capture position number 0" in line:
            # Press "enter" for the tool so capture position 0 starts.
            os.write(self.process.stdin, b'\r\n')

    def run(self):
        """Run the calibration tool and deliver its result to ``on_completion``.

        :raises polling2.TimeoutException: if the result file does not appear within 20 s
            after the output ends.
        :raises RuntimeError: if the result file belongs to a different device serial.
        """
        # Thread.run() only invokes a ``target``; none is ever passed, so it is not called.
        self.process.run()
        while True:
            line = readline(self.process.stdout)
            self.parse_line(line)
            if line is not None:
                continue
            # End of output: the tool writes auto_<serial>.json into the output directory.
            file_path = os.path.join(self.process_dir.name, f'auto_{self.device_serial}.json')
            polling2.poll(target=lambda: os.path.exists(file_path), step=1, timeout=20)
            result = PhotoDiodeSolution.from_file(file_path)
            # Explicit check instead of ``assert`` (stripped under ``python -O``).
            if result.device_serial_number != self.device_serial:
                raise RuntimeError(
                    f'Calibration result is for device {result.device_serial_number}, '
                    f'expected {self.device_serial}'
                )
            self.on_completion(result)
            break
