from __future__ import annotations

import argparse
import asyncio
import json
import subprocess
import sys
import traceback
from datetime import datetime
from pathlib import Path
from typing import Any, cast

from attractor import (
    Aborted,
    AgentToolUse,
    Checkpointed,
    Engine,
    EngineError,
    HumanGate,
    InputCopy,
    NodeOutcome,
    NodeStarted,
    ParseError,
    RunCompleted,
    RunStarted,
    UnknownChoice,
    ValidationFailed,
    parse,
    validate,
)
from attractor.cli.workflow_resolver import resolve_workflow_path


def _emit(payload: dict[str, Any]) -> None:
    print(json.dumps(payload), flush=True)


def _read_payload() -> dict[str, Any]:
    raw = sys.stdin.read().strip()
    if not raw:
        return {}
    payload_obj = json.loads(raw)
    if not isinstance(payload_obj, dict):
        raise ValueError("bridge payload must be a JSON object")
    return cast(dict[str, Any], payload_obj)


def _resolve_repo_root(cwd: Path, payload: dict[str, Any]) -> Path:
    repo_root = payload.get("repo_root")
    if isinstance(repo_root, str) and repo_root.strip():
        return Path(repo_root).expanduser().resolve()

    proc = subprocess.run(
        ["git", "-C", str(cwd), "rev-parse", "--show-toplevel"],
        capture_output=True,
        check=False,
        text=True,
    )
    if proc.returncode == 0:
        return Path(proc.stdout.strip()).resolve()
    return cwd.resolve()


def _resolve_path(raw_path: str, cwd: Path) -> Path:
    path = Path(raw_path).expanduser()
    if not path.is_absolute():
        path = cwd / path
    return path.resolve()


def _resolve_workflow_arg(raw_path: str, cwd: Path, payload: dict[str, Any]) -> Path:
    """Resolve a workflow reference from the payload to an absolute path.

    The repository root chosen by ``_resolve_repo_root`` is passed to
    ``resolve_workflow_path`` as ``project_root``; how *raw_path* itself is
    interpreted is up to that resolver.
    """
    project_root = _resolve_repo_root(cwd, payload)
    located = resolve_workflow_path(raw_path, project_root=project_root)
    return located.resolve()


def _serialize_timestamp(value: datetime | None) -> str | None:
    return value.isoformat() if value is not None else None


def _serialize_run_handle(handle: Any) -> dict[str, Any]:
    """Project a run handle (as returned by ``Engine.list``) onto JSON-safe fields.

    ``status`` is reduced to its ``.value``, ``worktree_path`` is stringified
    and ``started_at`` goes through ``_serialize_timestamp``.
    """
    return dict(
        run_id=handle.run_id,
        status=handle.status.value,
        worktree_path=str(handle.worktree_path),
        started_at=_serialize_timestamp(handle.started_at),
    )


def _serialize_history_record(record: Any) -> dict[str, Any]:
    """Convert one entry of a summary's ``history`` to JSON-safe data."""
    suggested = record.suggested_next_ids
    return {
        "node_id": record.node_id,
        "visit": record.visit,
        "status": record.status.value,
        "captured_output": record.captured_output,
        "duration_ms": record.duration_ms,
        "tokens_in": record.tokens_in,
        "tokens_out": record.tokens_out,
        # Use the shared helper (as the summary-level timestamps do) so a
        # record without a timestamp becomes null instead of raising.
        "timestamp": _serialize_timestamp(record.timestamp),
        "preferred_label": record.preferred_label,
        "suggested_next_ids": list(suggested) if suggested is not None else None,
    }


def _serialize_summary(summary: Any) -> dict[str, Any]:
    """Convert a run summary (as returned by ``Engine.show``) to JSON-safe data.

    Enum-valued statuses are reduced to ``.value``, datetimes to ISO-8601
    strings (``None`` preserved), and sequences to lists. ``history`` keeps
    the order of ``summary.history``.
    """
    return {
        "run_id": summary.run_id,
        "status": summary.status.value,
        "current_node": summary.current_node,
        "started_at": _serialize_timestamp(summary.started_at),
        "last_updated_at": _serialize_timestamp(summary.last_updated_at),
        "total_tokens_in": summary.total_tokens_in,
        "total_tokens_out": summary.total_tokens_out,
        "paused_prompt": summary.paused_prompt,
        "paused_choices": list(summary.paused_choices),
        "history": [_serialize_history_record(record) for record in summary.history],
    }


def _event_payload(event_name: str, run_id: Any, **fields: Any) -> dict[str, Any]:
    """Build the common event envelope (type/event/run_id) followed by *fields*."""
    return {"type": "event", "event": event_name, "run_id": run_id, **fields}


def _serialize_event(event: object) -> dict[str, Any]:
    """Convert an engine event into the bridge's JSON ``event`` message.

    Every message carries ``type="event"``, an ``event`` name and the
    ``run_id``; the remaining keys depend on the event class.

    Raises:
        TypeError: if *event* is not one of the event classes handled here.
    """
    if isinstance(event, RunStarted):
        return _event_payload("run_started", event.run_id)
    if isinstance(event, NodeStarted):
        return _event_payload(
            "node_started",
            event.run_id,
            node_id=event.node_id,
            kind=event.kind.value,
            visit=event.visit,
        )
    if isinstance(event, NodeOutcome):
        suggested = event.suggested_next_ids
        return _event_payload(
            "node_outcome",
            event.run_id,
            node_id=event.node_id,
            status=event.status.value,
            captured_output=event.captured_output,
            preferred_label=event.preferred_label,
            suggested_next_ids=list(suggested) if suggested is not None else None,
        )
    if isinstance(event, HumanGate):
        return _event_payload(
            "human_gate",
            event.run_id,
            node_id=event.node_id,
            prompt=event.prompt,
            choices=list(event.choices),
        )
    if isinstance(event, Checkpointed):
        return _event_payload(
            "checkpointed",
            event.run_id,
            state_commit=event.state_commit,
            worktree_commit=event.worktree_commit,
        )
    if isinstance(event, AgentToolUse):
        return _event_payload(
            "agent_tool_use",
            event.run_id,
            node_id=event.node_id,
            tool_name=event.tool_name,
            args_preview=event.args_preview,
        )
    if isinstance(event, RunCompleted):
        status = event.status
        return _event_payload(
            "run_completed",
            event.run_id,
            # Every other status in this file is sent as ``.value``; a non-str
            # Enum member here would make json.dumps in _emit raise TypeError.
            # Plain strings (no ``.value``) pass through unchanged.
            status=getattr(status, "value", status),
        )
    if isinstance(event, Aborted):
        return _event_payload("aborted", event.run_id, reason=event.reason)
    raise TypeError(f"unsupported event type: {type(event)!r}")


def _emit_exception(exc: Exception) -> int:
    """Report *exc* as a JSON ``error`` message on stdout and return exit code 1.

    Parse errors, validation failures, unknown choices, other ``EngineError``
    subclasses and ``OSError`` get a structured payload. Any other exception
    additionally has its traceback printed to stderr; this relies on being
    called from inside an ``except`` block so ``traceback.print_exc`` can see
    the active exception.
    """
    report: dict[str, Any] = {"type": "error"}
    if isinstance(exc, ParseError):
        report["error_type"] = "parse_error"
        report["message"] = exc.message
        report["line"] = exc.line
        report["column"] = exc.column
    elif isinstance(exc, ValidationFailed):
        report["error_type"] = "validation_failed"
        report["message"] = "workflow validation failed"
        report["errors"] = [
            {"line": item.line, "column": item.column, "message": item.message}
            for item in exc.errors
        ]
    elif isinstance(exc, UnknownChoice):
        report["error_type"] = "unknown_choice"
        report["message"] = str(exc)
        report["valid_choices"] = list(exc.valid_choices)
    elif isinstance(exc, EngineError):
        report["error_type"] = exc.__class__.__name__.lower()
        report["message"] = str(exc)
    elif isinstance(exc, OSError):
        report["error_type"] = "os_error"
        report["message"] = str(exc)
    else:
        # Unexpected failure: keep the full traceback on stderr for debugging
        # while stdout still receives a parseable error record.
        traceback.print_exc(file=sys.stderr)
        report["error_type"] = exc.__class__.__name__.lower()
        report["message"] = str(exc)
    _emit(report)
    return 1


# [impl->REQ-API-IN-PROCESS]
def _load_valid_workflow(workflow_path: Path) -> Any:
    """Read the workflow file as UTF-8, parse it, and return the validated result.

    The value returned by ``validate`` is what callers use as the workflow
    graph. Errors raised by ``parse``/``validate`` propagate unchanged
    (presumably ``ParseError`` / ``ValidationFailed``, which
    ``_emit_exception`` reports — confirm against the attractor API).
    """
    source_text = workflow_path.read_text(encoding="utf-8")
    parsed = parse(source_text)
    return validate(parsed)


# [impl->REQ-API-IN-PROCESS]
def _build_inputs(payload: dict[str, Any], cwd: Path) -> list[InputCopy]:
    items_obj = payload.get("inputs", [])
    if not isinstance(items_obj, list):
        raise ValueError("inputs must be a JSON array")
    items = cast(list[object], items_obj)

    inputs: list[InputCopy] = []
    for item_obj in items:
        if not isinstance(item_obj, dict):
            raise ValueError("each input must be an object with name and source")
        item = cast(dict[str, Any], item_obj)
        name = item.get("name")
        source = item.get("source")
        if not isinstance(name, str) or not name.strip():
            raise ValueError("input.name must be a non-empty string")
        if not isinstance(source, str) or not source.strip():
            raise ValueError("input.source must be a non-empty string")
        inputs.append(InputCopy(name=name, source=_resolve_path(source, cwd)))
    return inputs


async def _command_validate(payload: dict[str, Any], cwd: Path) -> None:
    """Parse and validate a workflow, then emit an overview of its graph.

    Reads ``payload["workflow_path"]`` (required: a missing key raises
    ``KeyError``). Parse/validation errors propagate to the caller.
    """
    resolved_path = _resolve_workflow_arg(str(payload["workflow_path"]), cwd, payload)
    graph = _load_valid_workflow(resolved_path)
    node_overview = [
        {"id": node.id, "kind": node.kind.value, "label": node.label}
        for node in graph.nodes
    ]
    _emit(
        {
            "type": "result",
            "ok": True,
            "command": "validate",
            "workflow_path": str(resolved_path),
            "graph_name": graph.name,
            "node_count": len(graph.nodes),
            "edge_count": len(graph.edges),
            "nodes": node_overview,
        }
    )


# [impl->REQ-API-IN-PROCESS]
async def _command_run(payload: dict[str, Any], cwd: Path) -> None:
    """Start a new run of a workflow, stream its events, and emit the result.

    Payload keys read:
        workflow_path: required workflow reference.
        repo_root: optional repository root (see ``_resolve_repo_root``).
        inputs: optional array of ``{name, source}`` objects.
        base_ref: optional git ref handed to ``Engine.start``; omitted,
            ``null`` or blank values all fall back to ``"HEAD"``.

    Every engine event is emitted as soon as it arrives and is also collected
    into the final ``result`` message, which is sent after ``session.wait()``
    returns.
    """
    workflow_path = _resolve_workflow_arg(str(payload["workflow_path"]), cwd, payload)
    graph = _load_valid_workflow(workflow_path)
    repo_root = _resolve_repo_root(cwd, payload)
    inputs = _build_inputs(payload, cwd)

    base_ref_value = payload.get("base_ref")
    # str(None) would hand the engine the literal ref "None"; treat JSON null
    # or a blank string the same as an omitted key.
    if base_ref_value is None or (
        isinstance(base_ref_value, str) and not base_ref_value.strip()
    ):
        base_ref = "HEAD"
    else:
        base_ref = str(base_ref_value)

    engine = Engine(repo_root)
    events: list[dict[str, Any]] = []

    def on_event(event: object) -> None:
        # Stream immediately, and keep a copy for the final result message.
        serialized = _serialize_event(event)
        events.append(serialized)
        _emit(serialized)

    session = await engine.start(
        graph,
        inputs=inputs,
        events=on_event,
        base_ref=base_ref,
    )
    status = await session.wait()
    _emit(
        {
            "type": "result",
            "ok": True,
            "command": "run",
            "run_id": session.run_id,
            "status": status.value,
            "events": events,
            "summary": _serialize_summary(engine.show(session.run_id)),
            "repo_root": str(repo_root),
            "workflow_path": str(workflow_path),
        }
    )


# [impl->REQ-API-IN-PROCESS]
async def _command_resume(payload: dict[str, Any], cwd: Path) -> None:
    """Resume an existing run by id, streaming events, then emit the result.

    Requires ``payload["run_id"]``. Each event is emitted as it arrives and is
    also included in the final ``result`` message alongside the run summary.
    """
    run_id = str(payload["run_id"])
    repo_root = _resolve_repo_root(cwd, payload)
    engine = Engine(repo_root)
    collected: list[dict[str, Any]] = []

    def forward(event: object) -> None:
        # Emit right away and remember it for the closing result message.
        record = _serialize_event(event)
        collected.append(record)
        _emit(record)

    final_status = await engine.resume(run_id, events=forward)
    result: dict[str, Any] = {
        "type": "result",
        "ok": True,
        "command": "resume",
        "run_id": run_id,
        "status": final_status.value,
        "events": collected,
        "summary": _serialize_summary(engine.show(run_id)),
        "repo_root": str(repo_root),
    }
    _emit(result)


# [impl->REQ-API-IN-PROCESS]
async def _command_respond(payload: dict[str, Any], cwd: Path) -> None:
    """Submit a choice for a run (presumably answering a paused human gate).

    Required payload keys: ``run_id`` and ``choice`` (both stringified).
    ``reason`` is forwarded only when it is a string; any other value,
    including absence, becomes ``None``. Emits the updated run summary.
    """
    run_id = str(payload["run_id"])
    choice = str(payload["choice"])
    raw_reason = payload.get("reason")
    reason = None
    if isinstance(raw_reason, str):
        reason = str(raw_reason)
    repo_root = _resolve_repo_root(cwd, payload)
    engine = Engine(repo_root)
    await engine.respond(run_id, choice=choice, reason=reason)
    result: dict[str, Any] = {
        "type": "result",
        "ok": True,
        "command": "respond",
        "run_id": run_id,
        "choice": choice,
        "reason": reason,
        "summary": _serialize_summary(engine.show(run_id)),
        "repo_root": str(repo_root),
    }
    _emit(result)


# [impl->REQ-API-IN-PROCESS]
async def _command_list(payload: dict[str, Any], cwd: Path) -> None:
    """Emit every run the engine reports for the resolved repository."""
    root = _resolve_repo_root(cwd, payload)
    engine = Engine(root)
    runs = [_serialize_run_handle(run_handle) for run_handle in engine.list()]
    _emit(
        {
            "type": "result",
            "ok": True,
            "command": "list",
            "repo_root": str(root),
            "runs": runs,
        }
    )


# [impl->REQ-API-IN-PROCESS]
async def _command_show(payload: dict[str, Any], cwd: Path) -> None:
    """Emit the serialized summary of the run named by ``payload["run_id"]``."""
    run_id = str(payload["run_id"])
    root = _resolve_repo_root(cwd, payload)
    summary = Engine(root).show(run_id)
    _emit(
        {
            "type": "result",
            "ok": True,
            "command": "show",
            "repo_root": str(root),
            "run": _serialize_summary(summary),
        }
    )


def _str_list(payload: dict[str, Any], key: str) -> list[str]:
    value = payload.get(key, [])
    if value is None:
        return []
    if not isinstance(value, list):
        raise ValueError(f"{key} must be a JSON array")
    items: list[str] = []
    for item in cast(list[object], value):
        if not isinstance(item, str) or not item.strip():
            raise ValueError(f"{key} entries must be non-empty strings")
        items.append(item)
    return items


def _prune_cli_args(payload: dict[str, Any]) -> list[str]:
    """Translate a prune payload into ``attractor prune`` command-line arguments.

    An explicit ``cli_args`` array is passed through verbatim after checking
    every entry is a string. Otherwise ``run_ids``, ``statuses``,
    ``older_than`` and the boolean flags ``all_completed``/``dry_run``/``force``
    (only when exactly ``True``) are mapped onto the corresponding options.

    Raises:
        ValueError: on a malformed ``cli_args``, ``run_ids`` or ``statuses``.
    """
    raw_args = payload.get("cli_args")
    if raw_args is not None:
        if not isinstance(raw_args, list):
            raise ValueError("cli_args must be a JSON array")
        passthrough = cast(list[object], raw_args)
        for entry in passthrough:
            if not isinstance(entry, str):
                raise ValueError("cli_args entries must be strings")
        return cast(list[str], list(passthrough))

    args: list[str] = []
    for run_id in _str_list(payload, "run_ids"):
        args += ["--run-id", run_id]
    for status in _str_list(payload, "statuses"):
        args += ["--status", status]
    older_than = payload.get("older_than")
    if isinstance(older_than, str) and older_than.strip():
        args += ["--older-than", older_than]
    boolean_flags = (
        ("all_completed", "--all-completed"),
        ("dry_run", "--dry-run"),
        ("force", "--force"),
    )
    args.extend(flag for key, flag in boolean_flags if payload.get(key) is True)
    return args


async def _command_prune(payload: dict[str, Any], cwd: Path) -> None:
    """Run ``uv run attractor prune`` in the repo root and emit its outcome.

    The subprocess runs in a worker thread so the event loop is not blocked.
    ``ok`` in the emitted result mirrors a zero exit code; stdout/stderr are
    captured and relayed as text.
    """
    repo_root = _resolve_repo_root(cwd, payload)
    cli_args = _prune_cli_args(payload)
    completed = await asyncio.to_thread(
        subprocess.run,
        ["uv", "run", "attractor", "prune", *cli_args],
        cwd=repo_root,
        capture_output=True,
        text=True,
        check=False,
    )
    exit_code = completed.returncode
    _emit(
        {
            "type": "result",
            "ok": exit_code == 0,
            "command": "prune",
            "repo_root": str(repo_root),
            "exit_code": exit_code,
            "stdout": completed.stdout,
            "stderr": completed.stderr,
            "args": cli_args,
        }
    )


# Maps each bridge command name to its async handler. The keys double as the
# ``choices`` accepted by the CLI's positional ``command`` argument.
_COMMANDS = dict(
    validate=_command_validate,
    run=_command_run,
    resume=_command_resume,
    respond=_command_respond,
    list=_command_list,
    show=_command_show,
    prune=_command_prune,
)


async def _dispatch(command: str, payload: dict[str, Any]) -> int:
    """Run the handler registered for *command* and return exit status 0.

    The handler receives *payload* and the process working directory;
    anything it raises propagates to the caller.
    """
    working_dir = Path.cwd()
    run_command = _COMMANDS[command]
    await run_command(payload, working_dir)
    return 0


def main() -> int:
    """CLI entry point: pick a command from argv, read its payload, run it.

    The JSON payload comes from stdin. Any exception while reading the
    payload or running the command is reported through ``_emit_exception``,
    whose return value becomes the exit status; otherwise the value returned
    by ``_dispatch`` is returned.
    """
    description = (
        "Bridge pi's TypeScript extension runtime to Attractor's "
        "stable Python host API."
    )
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("command", choices=sorted(_COMMANDS))
    command_name = parser.parse_args().command

    try:
        payload = _read_payload()
        exit_status = asyncio.run(_dispatch(command_name, payload))
    except Exception as exc:  # pragma: no cover - exercised via subprocess
        return _emit_exception(exc)
    return exit_status


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
