#!/usr/bin/env python3
"""
Standalone LLM endpoint probe. Does not import or execute pi.

Examples:
  # Test the ChatGPT/Codex backend using Pi's saved OAuth token, but outside Pi:
  python scripts/llm_endpoint_probe.py --provider codex --model gpt-5.5 --prompt "write one word"

  # Test the Cerebras OpenAI-compatible endpoint (`set` is Windows cmd syntax; use `export` on POSIX shells):
  set CEREBRAS_API_KEY=...
  python scripts/llm_endpoint_probe.py --provider cerebras --model zai-glm-4.7 --prompt "write one word"

  # Test any OpenAI-compatible endpoint:
  set OPENAI_API_KEY=...
  python scripts/llm_endpoint_probe.py --provider openai-compatible --base-url https://api.example.com/v1 --model model-id --prompt "hi"
"""

from __future__ import annotations

import argparse
import base64
import json
import os
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any

# Pi's saved per-provider auth entries (JSON object keyed by provider, read by load_pi_auth).
PI_AUTH = Path.home().joinpath(".pi", "agent", "auth.json")
# JWT claim namespace in Codex access tokens whose object holds ``chatgpt_account_id``.
CODEX_AUTH_CLAIM = "https://api.openai.com/auth"


def load_pi_auth(provider: str) -> dict[str, Any]:
    if not PI_AUTH.exists():
        raise SystemExit(f"No Pi auth file found at {PI_AUTH}")
    data = json.loads(PI_AUTH.read_text(encoding="utf-8"))
    auth = data.get(provider)
    if not isinstance(auth, dict):
        raise SystemExit(f"No auth entry for provider {provider!r} in {PI_AUTH}")
    return auth


def b64url_decode(segment: str) -> bytes:
    """Decode one base64url segment, first restoring the '=' padding JWTs strip off."""
    missing = (4 - len(segment) % 4) % 4
    padded = segment + "=" * missing
    return base64.urlsafe_b64decode(padded.encode("ascii"))


def jwt_payload(token: str) -> dict[str, Any]:
    """Decode and return the claims (middle segment) of a JWT.

    No signature verification is performed; this only reads the claims.

    Raises:
        ValueError: if *token* does not have exactly three dot-separated
            segments, the payload is not valid base64url / UTF-8 / JSON
            (binascii.Error, UnicodeError and json.JSONDecodeError all subclass
            ValueError), or the decoded JSON is not an object.
    """
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("token is not a JWT")
    payload = json.loads(b64url_decode(parts[1]))
    # Callers use dict lookups on the claims, so enforce the annotated type here.
    if not isinstance(payload, dict):
        raise ValueError("JWT payload is not a JSON object")
    return payload


def codex_access_token() -> str:
    """Return the bearer token for the ChatGPT/Codex backend.

    A non-empty ``CODEX_ACCESS_TOKEN`` environment variable takes precedence;
    otherwise the ``"access"`` field of Pi's saved ``"openai-codex"`` entry is
    used. Raises SystemExit when neither yields a non-empty string.
    """
    from_env = os.environ.get("CODEX_ACCESS_TOKEN")
    if from_env:
        return from_env
    saved = load_pi_auth("openai-codex").get("access")
    if isinstance(saved, str) and saved:
        return saved
    raise SystemExit("No openai-codex access token found. Set CODEX_ACCESS_TOKEN or login in Pi.")


def codex_account_id(token: str) -> str:
    """Extract ``chatgpt_account_id`` from the ``CODEX_AUTH_CLAIM`` claim of *token*.

    Raises:
        SystemExit: if *token* cannot be decoded as a JWT, or the claim / account
            id is missing or has an unexpected type. The token itself is never
            echoed in the message.
    """
    try:
        payload = jwt_payload(token)
    except ValueError as exc:
        # e.g. CODEX_ACCESS_TOKEN set to an opaque (non-JWT) string.
        raise SystemExit(f"Codex access token is not a decodable JWT: {exc}") from None
    claim = payload.get(CODEX_AUTH_CLAIM) if isinstance(payload, dict) else None
    account_id = claim.get("chatgpt_account_id") if isinstance(claim, dict) else None
    if not isinstance(account_id, str) or not account_id:
        raise SystemExit("Could not extract chatgpt_account_id from Codex token")
    return account_id


def redact_headers(headers: Any) -> dict[str, str]:
    """Return a plain ``dict`` copy of *headers* with credential-bearing values masked.

    *headers* may be anything ``dict()`` accepts (a mapping, or the message
    object urllib puts on responses); ``None`` yields ``{}``. Header names are
    matched case-insensitively. Credential headers (auth, API keys, cookies)
    become ``"<redacted>"``; every other value is converted with ``str``.
    """
    if headers is None:
        return {}
    sensitive = {
        "authorization",
        "proxy-authorization",
        "x-api-key",
        "api-key",
        "cookie",
        "set-cookie",  # session cookies are credentials too
    }
    return {
        name: "<redacted>" if name.lower() in sensitive else str(value)
        for name, value in dict(headers).items()
    }


def request_json(url: str, headers: dict[str, str], body: dict[str, Any], timeout: int) -> tuple[int, dict[str, str], str]:
    """POST *body* as JSON to *url* and return ``(status, redacted_headers, body_text)``.

    HTTP error statuses are returned, not raised, so the caller can print the
    error body. Transport failures before or during the response (DNS,
    refused connection, TLS, socket timeout) are returned as status ``0`` with
    empty headers and a one-line description as the body. That way one failed
    attempt does not abort a ``--repeat`` run. The body is decoded as UTF-8,
    with invalid bytes replaced.
    """
    req = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers=headers,
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            raw = resp.read().decode("utf-8", errors="replace")
            return resp.status, redact_headers(resp.headers), raw
    except urllib.error.HTTPError as e:
        # Must precede the OSError clause: HTTPError is itself a URLError/OSError.
        raw = e.read().decode("utf-8", errors="replace")
        return e.code, redact_headers(e.headers), raw
    except OSError as e:
        # URLError (DNS, refused connection, TLS) and socket timeouts derive from OSError.
        reason = getattr(e, "reason", None) or e
        return 0, {}, f"request failed: {type(e).__name__}: {reason}"


def iter_sse(
    url: str, headers: dict[str, str], body: dict[str, Any], timeout: int
) -> tuple[int, dict[str, str], str, list[dict[str, Any]]]:
    """POST *body* as JSON and consume the Server-Sent Events stream in the reply.

    Returns ``(status, redacted_headers, text, events)``:

    * success: *text* is the text collected by ``handle_sse_event`` and *events*
      holds every recorded event;
    * HTTP error status: *text* is the raw error body and *events* is empty;
    * transport failure before a response arrives (DNS, refused connection,
      TLS, timeout): status ``0``, empty headers, a one-line description as
      *text*, and no events.
    """
    req = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers=headers,
        method="POST",
    )
    try:
        resp = urllib.request.urlopen(req, timeout=timeout)
    except urllib.error.HTTPError as e:
        # Must precede the OSError clause: HTTPError is itself a URLError/OSError.
        raw = e.read().decode("utf-8", errors="replace")
        return e.code, redact_headers(e.headers), raw, []
    except OSError as e:
        # URLError (DNS, refused connection, TLS) and socket timeouts derive from OSError.
        reason = getattr(e, "reason", None) or e
        return 0, {}, f"request failed: {type(e).__name__}: {reason}", []

    events: list[dict[str, Any]] = []
    text_parts: list[str] = []
    with resp:
        status = resp.status
        headers_out = redact_headers(resp.headers)
        event_lines: list[str] = []
        # SSE events are separated by blank lines; buffer lines until one arrives.
        # Decoding per line is safe: b"\n" never occurs inside a UTF-8 multi-byte sequence.
        while True:
            line_b = resp.readline()
            if not line_b:
                break
            line = line_b.decode("utf-8", errors="replace").rstrip("\r\n")
            if line == "":
                if event_lines:
                    handle_sse_event(event_lines, events, text_parts)
                    event_lines = []
                continue
            event_lines.append(line)
        # The stream may end without a trailing blank line; flush the last event.
        if event_lines:
            handle_sse_event(event_lines, events, text_parts)
    return status, headers_out, "".join(text_parts), events


def handle_sse_event(lines: list[str], events: list[dict[str, Any]], text_parts: list[str]) -> None:
    """Parse one SSE event (its non-blank lines) and record it.

    The ``event:`` field sets the event name. ``data:`` lines are joined with
    newlines. Other fields and ``:`` comment lines are ignored. Empty data and
    the ``[DONE]`` sentinel are skipped.

    A JSON object is appended to *events* unchanged. Data that is not JSON, or
    JSON that is not an object, is appended as
    ``{"event": name, "raw": first 500 chars}``.

    Text goes into *text_parts* from ``response.output_text.delta`` and
    ``response.refusal.delta`` deltas. A ``response.completed`` or
    ``response.done`` event adds its full text (via ``extract_text``) only if
    no text was collected yet, so nothing is duplicated.
    """
    data_chunks: list[str] = []
    event_name = None
    for line in lines:
        if line.startswith("event:"):
            event_name = line[6:].strip()
        elif line.startswith("data:"):
            data_chunks.append(line[5:].strip())
    data = "\n".join(data_chunks).strip()
    if not data or data == "[DONE]":
        return
    try:
        obj = json.loads(data)
    except json.JSONDecodeError:
        obj = None
    if not isinstance(obj, dict):
        # Unparseable or non-object payloads (e.g. `42`, `null`) are kept as a bounded
        # raw snippet for debugging; the lookups below require a JSON object.
        events.append({"event": event_name, "raw": data[:500]})
        return
    events.append(obj)
    typ = obj.get("type")
    if typ in {"response.output_text.delta", "response.refusal.delta"} and isinstance(obj.get("delta"), str):
        text_parts.append(obj["delta"])
    # Some endpoints send final text only in completed response objects.
    if typ in {"response.completed", "response.done"}:
        final = extract_text(obj)
        if final and not text_parts:
            text_parts.append(final)


def extract_text(obj: Any) -> str:
    """Concatenate every text fragment found anywhere in a decoded JSON value.

    Walks dicts and lists depth-first in document order. A dict contributes its
    ``"text"`` string when its ``"type"`` is ``"output_text"`` or ``"text"``;
    otherwise it contributes its ``"content"`` value when that is a string.
    Nested values are always visited.

    The walk uses an explicit stack, so deeply nested input cannot hit the
    recursion limit. Type membership is tested against a tuple, so unhashable
    ``"type"`` values (e.g. JSON-Schema ``["string", "null"]``) do not raise.
    """
    found: list[str] = []
    stack: list[Any] = [obj]
    while stack:
        node = stack.pop()
        if isinstance(node, dict):
            if node.get("type") in ("output_text", "text") and isinstance(node.get("text"), str):
                found.append(node["text"])
            elif isinstance(node.get("content"), str):
                found.append(node["content"])
            # Push children reversed so they are popped (visited) in original order.
            stack.extend(reversed(list(node.values())))
        elif isinstance(node, list):
            stack.extend(reversed(node))
    return "".join(found)


def run_codex(args: argparse.Namespace) -> int:
    """Send one streaming Responses-API request to the ChatGPT/Codex backend and print it.

    The bearer token comes from ``codex_access_token``. The
    ``chatgpt-account-id`` header is read from that token's claims.
    ``--base-url``, when given, is used as the complete endpoint URL (trailing
    slashes removed), not as a prefix. Returns 0 for a 2xx status, 1 otherwise.
    """
    token = codex_access_token()
    account_id = codex_account_id(token)
    url = args.base_url.rstrip("/") if args.base_url else "https://chatgpt.com/backend-api/codex/responses"
    body = {
        "model": args.model,
        "store": False,
        "stream": True,
        "instructions": "You are a helpful assistant.",
        "input": [{"role": "user", "content": [{"type": "input_text", "text": args.prompt}]}],
        "text": {"verbosity": "low"},
        "include": ["reasoning.encrypted_content"],
        "tool_choice": "auto",
        "parallel_tool_calls": True,
    }
    if args.reasoning_effort:
        body["reasoning"] = {"effort": args.reasoning_effort, "summary": "auto"}
    headers = {
        "Authorization": f"Bearer {token}",
        "chatgpt-account-id": account_id,
        "originator": args.originator,
        "OpenAI-Beta": "responses=experimental",
        "accept": "text/event-stream",
        "content-type": "application/json",
        "User-Agent": "standalone-llm-endpoint-probe/1.0",
    }
    status, headers_out, text, events = iter_sse(url, headers, body, args.timeout)
    ok = 200 <= status < 300
    # On failure iter_sse returns the error body in `text` and no events. print_result
    # shows only its debug argument for non-2xx statuses, so pass the body there;
    # otherwise HTTP errors print as "[]".
    debug: Any = events if ok else (text or events)
    print_result(args.provider, args.model, status, headers_out, text, debug)
    return 0 if ok else 1


def run_openai_compatible(args: argparse.Namespace, default_base_url: str, api_key_env: str) -> int:
    """Send one non-streaming Chat Completions request and print the outcome.

    The API key is read from the environment variable named by
    ``--api-key-env``, falling back to *api_key_env*. The request goes to
    ``{base}/chat/completions``, where *base* is ``--base-url`` or
    *default_base_url* with trailing slashes removed.

    Returns 0 for a 2xx status, 1 otherwise. Raises SystemExit when the key
    variable is unset or empty.
    """
    key_env = args.api_key_env or api_key_env
    api_key = os.environ.get(key_env)
    if not api_key:
        raise SystemExit(f"Missing API key env var: {key_env}")
    base = (args.base_url or default_base_url).rstrip("/")
    url = f"{base}/chat/completions"
    body: dict[str, Any] = {
        "model": args.model,
        "messages": [{"role": "user", "content": args.prompt}],
        "stream": False,
    }
    if args.max_tokens:
        body["max_completion_tokens"] = args.max_tokens
    headers = {
        "Authorization": f"Bearer {api_key}",
        "content-type": "application/json",
        "User-Agent": "standalone-llm-endpoint-probe/1.0",
    }
    status, headers_out, raw = request_json(url, headers, body, args.timeout)
    text = ""
    payload: Any
    try:
        payload = json.loads(raw)
    except ValueError:
        # Non-JSON body (e.g. an HTML error page from a proxy): show a bounded snippet.
        payload = raw[:2000]
    else:
        # Text extraction is best-effort: an unexpected shape must not abort the
        # probe or discard the parsed payload needed for the debug dump.
        try:
            text = extract_text(payload)
        except (TypeError, RecursionError):
            text = ""
        if not text:
            text = _first_choice_content(payload)
    print_result(args.provider, args.model, status, headers_out, text, payload)
    return 0 if 200 <= status < 300 else 1


def _first_choice_content(payload: Any) -> str:
    """Return ``payload["choices"][0]["message"]["content"]`` if it is a string, else ``""``."""
    if not isinstance(payload, dict):
        return ""
    choices = payload.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    message = choices[0].get("message")
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    return content if isinstance(content, str) else ""


def print_result(provider: str, model: str, status: int, headers: dict[str, str], text: str, debug: Any) -> None:
    """Print a human-readable summary of one probe attempt to stdout.

    Always prints a provider/model/status line. Next come the diagnostic
    response headers (retry hints, request ids, date, content type), as
    indented JSON, when any are present. For a 2xx *status* the stripped
    *text* is printed, or a placeholder when it is empty. For any other status
    *debug* is printed instead: as-is if it is a string, otherwise as indented
    JSON. Either way *debug* is truncated to 4000 characters.
    """
    wanted = {"retry-after", "retry-after-ms", "x-request-id", "cf-ray", "date", "content-type"}
    print(f"provider={provider} model={model} status={status}")
    shown = {}
    for name, value in headers.items():
        if name.lower() in wanted:
            shown[name] = value
    if shown:
        print("headers=", json.dumps(shown, indent=2))
    if not 200 <= status < 300:
        print("\n--- error/debug body ---")
        rendered = debug if isinstance(debug, str) else json.dumps(debug, indent=2)
        print(rendered[:4000])
        return
    print("\n--- response text ---")
    print((text or "<no text extracted>").strip())


def main() -> int:
    """Parse CLI arguments, run the chosen probe ``--repeat`` times, and return an exit code.

    The result is the bitwise OR of the per-attempt return codes. It is 0 only
    if every attempt got a 2xx status. Argument problems are reported before
    any request is sent:

    * a non-positive ``--repeat`` or ``--timeout`` makes argparse exit with
      status 2;
    * a missing ``--base-url`` for ``openai-compatible`` raises SystemExit
      with a message.
    """
    p = argparse.ArgumentParser(description="Standalone direct LLM endpoint probe; does not use pi runtime.")
    p.add_argument("--provider", choices=["codex", "cerebras", "openai-compatible"], default="codex")
    p.add_argument("--model", default="gpt-5.5")
    p.add_argument("--prompt", default="write one short poem")
    p.add_argument("--base-url", help="Override provider base URL")
    p.add_argument("--api-key-env", help="API key env var for openai-compatible/cerebras")
    p.add_argument("--max-tokens", type=int, default=1024)
    p.add_argument("--timeout", type=int, default=120)
    p.add_argument("--originator", default="standalone-probe", help="Codex originator header")
    p.add_argument("--reasoning-effort", choices=["minimal", "low", "medium", "high"], help="Optional Codex reasoning effort")
    p.add_argument("--repeat", type=int, default=1)
    args = p.parse_args()
    if args.repeat < 1:
        # Without this, zero attempts would run and the probe would report success.
        p.error("--repeat must be at least 1")
    if args.timeout <= 0:
        # A zero socket timeout means non-blocking mode, so every request would fail at once.
        p.error("--timeout must be positive")
    if args.provider == "openai-compatible" and not args.base_url:
        raise SystemExit("--base-url is required for openai-compatible")

    rc = 0
    for i in range(args.repeat):
        if args.repeat > 1:
            print(f"\n=== attempt {i + 1}/{args.repeat} ===")
        if args.provider == "codex":
            rc |= run_codex(args)
        elif args.provider == "cerebras":
            rc |= run_openai_compatible(args, "https://api.cerebras.ai/v1", "CEREBRAS_API_KEY")
        else:
            rc |= run_openai_compatible(args, args.base_url, "OPENAI_API_KEY")
        if i + 1 < args.repeat:
            time.sleep(1)  # brief pause between attempts
    return rc


if __name__ == "__main__":
    # Exit with main()'s return code (0 = every attempt returned 2xx).
    sys.exit(main())
