"""Tests for the SPEC §6.11.2 open-region resume rule.

Five scenarios:
1. No open region → sequential resume proceeds via §6.4 (regression guard).
2. Crashed mid-branch (one branch never started, the others terminal) → resume dispatches the pending branch from its first node; join AND-folds.
3. All branches done, join not written (§6.11.2 case e) → engine writes join, COMPLETES.
4. Branch hasn't started (crash right after fanout) → dispatched from first node.
5. Mixed: one complete, one in-flight (multi-node branch), one not-started.

Tests 2-5 hand-build journal entries; no live execution crash needed.
"""

from __future__ import annotations

import subprocess
import uuid
from datetime import UTC, datetime
from pathlib import Path

import pygit2
import pytest
from pydantic import TypeAdapter

from attractor.checkpoint import ATTRACTOR_REF_PREFIX, Author, BranchStore, CheckpointTrailers
from attractor.engine import Engine, OutcomeStatus, RunStatus
from attractor.engine.engine import (
    _JOURNAL_ADAPTER,  # pyright: ignore[reportPrivateUsage]
    _find_open_region,  # pyright: ignore[reportPrivateUsage]
    _serialize_workflow,  # pyright: ignore[reportPrivateUsage]
    _state_ref,  # pyright: ignore[reportPrivateUsage]
)
from attractor.engine.journal import (
    JournalEntry,
    NodeCompleted,
    RunInitialized,
    event_path,
)
from attractor.workflow import parse, validate

# ──────────────────────────────────────────────────────────────────────
# Workflow fixtures
# ──────────────────────────────────────────────────────────────────────

# Three-branch parallel workflow: `fan` (shape=component) fans out to three
# single-node tool branches (tests / lint / types, shape=parallelogram) that
# all converge on `join` (shape=tripleoctagon) before `exit`. The hand-built
# journals in the tests below use the edge labels out of `fan` as the
# `NodeCompleted.branch_name` of each branch's entries.
_PARALLEL_WORKFLOW = """\
digraph ParallelResume {
    graph [ default_max_visits = 3 ]
    start [shape=Mdiamond]
    exit  [shape=Msquare]
    fan   [shape=component]
    join  [shape=tripleoctagon]
    tests [shape=parallelogram, script="echo tests-ok"]
    lint  [shape=parallelogram, script="echo lint-ok"]
    types [shape=parallelogram, script="echo types-ok"]
    start -> fan
    fan   -> tests [label="tests"]
    fan   -> lint  [label="lint"]
    fan   -> types [label="types"]
    tests -> join
    lint  -> join
    types -> join
    join  -> exit
}
"""

# Linear workflow with no fan-out (start -> a -> b -> exit). Used by the
# sequential-resume regression test, where no parallel region is ever opened.
_SEQ_WORKFLOW = """\
digraph Seq {
    graph [ default_max_visits = 3 ]
    start [shape=Mdiamond]
    exit  [shape=Msquare]
    a     [shape=parallelogram, script="echo a"]
    b     [shape=parallelogram, script="echo b"]
    start -> a -> b -> exit
}
"""

# Parallel workflow where the "lint" branch has TWO nodes (lint1 → lint2 → join).
# Used in test 5 so we can have a genuinely in-flight branch (lint1 done,
# lint2 not yet started). The branch's edge label is still "lint", and test 5
# records lint1's entry with branch_name="lint" even though no node is named
# "lint" in this graph.
_PARALLEL_MULTINODE = """\
digraph ParallelMulti {
    graph [ default_max_visits = 3 ]
    start [shape=Mdiamond]
    exit  [shape=Msquare]
    fan   [shape=component]
    join  [shape=tripleoctagon]
    tests [shape=parallelogram, script="echo tests-ok"]
    lint1 [shape=parallelogram, script="echo lint1-ok"]
    lint2 [shape=parallelogram, script="echo lint2-ok"]
    types [shape=parallelogram, script="echo types-ok"]
    start -> fan
    fan   -> tests [label="tests"]
    fan   -> lint1 [label="lint"]
    fan   -> types [label="types"]
    tests -> join
    lint1 -> lint2
    lint2 -> join
    types -> join
    join  -> exit
}
"""


# ──────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────

def _ts() -> datetime:
    return datetime.now(UTC)


def _write_entries(
    seeded_repo: Path,
    run_id: str,
    entries: list[JournalEntry],
) -> None:
    """Persist hand-built journal *entries* onto the state branch of *run_id*.

    Each entry is serialized with the engine's own ``_JOURNAL_ADAPTER`` and
    written, in list order, to ``event_path(entry.seq)`` on
    ``_state_ref(run_id)`` with a ``checkpoint: <kind> seq=NNNNNN`` message
    and run-id trailers.
    """
    store = BranchStore(pygit2.Repository(str(seeded_repo)), Author())
    target_ref = _state_ref(run_id)
    for journal_entry in entries:
        payload = _JOURNAL_ADAPTER.dump_json(journal_entry, indent=2)
        seq = journal_entry.seq
        store.write_entry(
            ref=target_ref,
            path=event_path(seq),
            content=payload,
            message=f"checkpoint: {journal_entry.kind} seq={seq:06d}",
            # A fresh trailers object per entry, as each write gets its own.
            trailers=CheckpointTrailers(run_id=run_id),
        )


def _read_entries(seeded_repo: Path, run_id: str) -> list[JournalEntry]:
    """Load every journal entry recorded for *run_id*, ordered by ``seq``.

    All paths under ``events/`` on the run's state ref are read and parsed
    as ``JournalEntry``. Paths whose content comes back empty/falsy are
    skipped. The result is stably sorted by ascending ``seq``.
    """
    store = BranchStore(pygit2.Repository(str(seeded_repo)), Author())
    parser: TypeAdapter[JournalEntry] = TypeAdapter(JournalEntry)
    state_ref = _state_ref(run_id)
    # The outer iterable (list_entries) is evaluated eagerly; each blob is
    # then read and validated one path at a time.
    blobs = (
        store.read_entry(state_ref, path)
        for path in store.list_entries(state_ref, prefix="events/")
    )
    parsed = [parser.validate_json(blob) for blob in blobs if blob]
    return sorted(parsed, key=lambda entry: entry.seq)


def _make_main_worktree(seeded_repo: Path, run_id: str) -> Path:
    """Create the main run worktree for *run_id*, branched from ``main``.

    The worktree is placed at ``<seeded_repo>/.attractor/worktrees/<run_id>``
    on a new branch whose full ref is
    ``<ATTRACTOR_REF_PREFIX>run/<run_id>/worktree``.

    Returns:
        The worktree directory.

    Raises:
        ValueError: if the computed ref is not under ``refs/heads/``.
            ``git worktree add -b`` needs a short branch name, so such a ref
            cannot be expressed.
        subprocess.CalledProcessError: if ``git worktree add`` fails.
    """
    wt_dir = seeded_repo / ".attractor" / "worktrees" / run_id
    full_ref = f"{ATTRACTOR_REF_PREFIX}run/{run_id}/worktree"
    # Blindly slicing len("refs/heads/") chars would silently mangle the
    # branch name if the prefix ever moved outside refs/heads/; fail loudly.
    if not full_ref.startswith("refs/heads/"):
        raise ValueError(f"expected a refs/heads/ ref, got {full_ref!r}")
    short_branch = full_ref.removeprefix("refs/heads/")
    subprocess.run(
        ["git", "-C", str(seeded_repo), "worktree", "add",
         "-b", short_branch, str(wt_dir), "main"],
        check=True, capture_output=True, text=True,
    )
    return wt_dir


def _make_branch_worktree(
    seeded_repo: Path,
    run_id: str,
    branch_name: str,
    fork_oid: str,
) -> Path:
    """Create a per-branch worktree for *branch_name*, forked at *fork_oid*.

    The worktree is placed at
    ``<seeded_repo>/.attractor/worktrees/<run_id>/branch/<branch_name>`` on a
    new branch whose full ref is
    ``<ATTRACTOR_REF_PREFIX>run/<run_id>/branch/<branch_name>``.

    Returns:
        The branch worktree directory.

    Raises:
        ValueError: if the computed ref is not under ``refs/heads/``.
            ``git worktree add -b`` needs a short branch name.
        subprocess.CalledProcessError: if ``git worktree add`` fails.
    """
    bw_dir = seeded_repo / ".attractor" / "worktrees" / run_id / "branch" / branch_name
    bw_dir.parent.mkdir(parents=True, exist_ok=True)
    branch_ref = f"{ATTRACTOR_REF_PREFIX}run/{run_id}/branch/{branch_name}"
    # Guard instead of blind slicing so a relocated prefix fails loudly
    # rather than creating a branch with a truncated, wrong name.
    if not branch_ref.startswith("refs/heads/"):
        raise ValueError(f"expected a refs/heads/ ref, got {branch_ref!r}")
    short = branch_ref.removeprefix("refs/heads/")
    subprocess.run(
        ["git", "-C", str(seeded_repo), "worktree", "add",
         "-b", short, str(bw_dir), fork_oid],
        check=True, capture_output=True, text=True,
    )
    return bw_dir


def _head_oid(path: Path) -> str:
    """Resolve ``HEAD`` in the worktree at *path* and return its commit OID.

    Raises ``subprocess.CalledProcessError`` if ``git rev-parse`` fails.
    """
    completed = subprocess.run(
        ["git", "-C", str(path), "rev-parse", "HEAD"],
        check=True,
        capture_output=True,
        text=True,
    )
    oid = completed.stdout.strip()
    return oid


# ──────────────────────────────────────────────────────────────────────
# Test 1 — no open region → sequential resume unchanged
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-RESUME]
@pytest.mark.asyncio
async def test_sequential_resume_unaffected(seeded_repo: Path) -> None:
    """Regression guard: a run with no parallel region keeps the §6.4 resume.

    After a full sequential run, ``_find_open_region`` must return ``None``
    for its journal. Resuming that already-finalized run must report
    COMPLETED via the existing sequential state machine.
    """
    seq_graph = validate(parse(_SEQ_WORKFLOW))

    first_engine = Engine(seeded_repo)
    first_status = await first_engine.run(seq_graph)
    assert first_status == RunStatus.COMPLETED

    known_runs = first_engine.list_all_run_ids()
    assert len(known_runs) == 1
    the_run = known_runs[0]

    journal = _read_entries(seeded_repo, the_run)
    assert _find_open_region(seq_graph, journal) is None

    # Resuming a run that already has RunFinalized yields COMPLETED at once.
    second_engine = Engine(seeded_repo)
    second_status = await second_engine.resume(the_run)
    assert second_status == RunStatus.COMPLETED


# ──────────────────────────────────────────────────────────────────────
# Test 2 — crashed mid-branch (branch not started = simplest in-flight)
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-RESUME]
# [int->REQ-DOD-RESUME-BETWEEN-NODE]
@pytest.mark.asyncio
async def test_resume_crashed_mid_branch(seeded_repo: Path) -> None:
    """Resume dispatches a branch that crashed before writing any entry.

    Simulated journal: fanout completed; the tests and types branches are
    terminal (``next_node="join"``). lint has NO entries: the crash came
    right after fanout, but only for lint, because the other branches ran
    and committed before the crash.

    This exercises the pending-branch dispatch path: lint is re-dispatched
    from its first node and the join AND-folds the three branches.
    """
    graph = validate(parse(_PARALLEL_WORKFLOW))
    run_id = str(uuid.uuid4())
    wt_dir = _make_main_worktree(seeded_repo, run_id)

    # Pre-create the lint branch worktree so _resume_open_region finds it
    # via list_branches() rather than creating a new one.
    fork_oid = _head_oid(wt_dir)
    _make_branch_worktree(seeded_repo, run_id, "lint", fork_oid)

    entries: list[JournalEntry] = [
        RunInitialized(
            seq=0, run_id=run_id, timestamp=_ts(),
            workflow_dot=_serialize_workflow(graph),
            workflow_hash="x", base_ref="HEAD",
        ),
        NodeCompleted(
            seq=1, run_id=run_id, timestamp=_ts(),
            node_id="start", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="", duration_ms=0,
            next_node="fan", worktree_commit_after=None,
        ),
        # fanout main-thread entry (branch_name=None)
        NodeCompleted(
            seq=2, run_id=run_id, timestamp=_ts(),
            node_id="fan", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="", duration_ms=0,
            next_node=None, worktree_commit_after=None, branch_name=None,
        ),
        # tests: terminal
        NodeCompleted(
            seq=3, run_id=run_id, timestamp=_ts(),
            node_id="tests", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="tests-ok", duration_ms=10,
            next_node="join", worktree_commit_after=None, branch_name="tests",
        ),
        # types: terminal
        NodeCompleted(
            seq=4, run_id=run_id, timestamp=_ts(),
            node_id="types", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="types-ok", duration_ms=10,
            next_node="join", worktree_commit_after=None, branch_name="types",
        ),
        # lint: no entry → pending, dispatched from first node
    ]
    _write_entries(seeded_repo, run_id, entries)

    engine = Engine(seeded_repo)
    status = await engine.resume(run_id)
    # NOTE(review): the scenario has every branch succeeding, so COMPLETED is
    # the expected outcome; INCOMPLETE is still tolerated here. Consider
    # tightening once the engine's behaviour with pre-created branch
    # worktrees is confirmed.
    assert status in (RunStatus.COMPLETED, RunStatus.INCOMPLETE), (
        f"unexpected status: {status}"
    )

    all_entries = _read_entries(seeded_repo, run_id)

    # The pending lint branch must actually have run its first node ("lint"),
    # recorded under its branch name. Checking join alone would not catch a
    # resume that skipped the branch.
    lint_nodes = [
        e.node_id for e in all_entries
        if isinstance(e, NodeCompleted) and e.branch_name == "lint"
    ]
    assert "lint" in lint_nodes, (
        f"lint branch was not dispatched from its first node; got {lint_nodes!r}"
    )

    # join must have been written on the main thread (branch_name=None).
    join_nc = next(
        (e for e in all_entries
         if isinstance(e, NodeCompleted) and e.node_id == "join"
         and e.branch_name is None),
        None,
    )
    assert join_nc is not None, "join NodeCompleted was not written"


# ──────────────────────────────────────────────────────────────────────
# Test 3 — all branches done, join not written (§6.11.2 case e)
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-RESUME]
# [int->REQ-DOD-RESUME-BETWEEN-NODE]
@pytest.mark.asyncio
async def test_resume_all_branches_done_join_not_written(seeded_repo: Path) -> None:
    """Every branch reached ``join`` but the join entry itself is missing.

    tests, lint and types each have a terminal NodeCompleted and there is no
    join entry. Resume must take the §6.11.2 case (e) fast path: write a
    SUCCESS join NodeCompleted whose output names each branch, and finalize
    the run as COMPLETED.
    """
    graph = validate(parse(_PARALLEL_WORKFLOW))
    run_id = str(uuid.uuid4())
    _make_main_worktree(seeded_repo, run_id)

    journal: list[JournalEntry] = [
        RunInitialized(
            seq=0, run_id=run_id, timestamp=_ts(),
            workflow_dot=_serialize_workflow(graph),
            workflow_hash="x", base_ref="HEAD",
        ),
        NodeCompleted(
            seq=1, run_id=run_id, timestamp=_ts(),
            node_id="start", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="", duration_ms=0,
            next_node="fan", worktree_commit_after=None,
        ),
        NodeCompleted(
            seq=2, run_id=run_id, timestamp=_ts(),
            node_id="fan", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="", duration_ms=0,
            next_node=None, worktree_commit_after=None, branch_name=None,
        ),
    ]
    # One terminal entry per branch at seq 3, 4, 5. Each branch is a single
    # node whose id equals its branch name.
    for seq, bname in enumerate(("tests", "lint", "types"), start=3):
        journal.append(
            NodeCompleted(
                seq=seq, run_id=run_id, timestamp=_ts(),
                node_id=bname, visit=1, status=OutcomeStatus.SUCCESS,
                captured_output=f"{bname}-ok", duration_ms=10,
                next_node="join", worktree_commit_after=None, branch_name=bname,
            )
        )
    # join NodeCompleted intentionally absent → open region
    _write_entries(seeded_repo, run_id, journal)

    resumer = Engine(seeded_repo)
    status = await resumer.resume(run_id)
    assert status == RunStatus.COMPLETED, f"expected COMPLETED, got {status}"

    join_nc = None
    for candidate in _read_entries(seeded_repo, run_id):
        is_main_join = (
            isinstance(candidate, NodeCompleted)
            and candidate.node_id == "join"
            and candidate.branch_name is None
        )
        if is_main_join:
            join_nc = candidate
            break
    assert join_nc is not None, "join NodeCompleted not written"
    assert join_nc.status == OutcomeStatus.SUCCESS
    for bname in ("tests", "lint", "types"):
        assert f"[{bname}]" in join_nc.captured_output, (
            f"[{bname}] missing from join output: {join_nc.captured_output!r}"
        )


# ──────────────────────────────────────────────────────────────────────
# Test 4 — branch hasn't started (crash right after fanout)
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-RESUME]
# [int->REQ-DOD-RESUME-BETWEEN-NODE]
@pytest.mark.asyncio
async def test_resume_branch_not_started(seeded_repo: Path) -> None:
    """Only fanout NodeCompleted present; no branch entries at all.

    Resume must dispatch all branches from their first nodes, write the join,
    and complete. In ``_PARALLEL_WORKFLOW`` each branch's first (and only)
    node has the same id as its branch label.
    """
    graph = validate(parse(_PARALLEL_WORKFLOW))
    run_id = str(uuid.uuid4())
    _make_main_worktree(seeded_repo, run_id)

    entries: list[JournalEntry] = [
        RunInitialized(
            seq=0, run_id=run_id, timestamp=_ts(),
            workflow_dot=_serialize_workflow(graph),
            workflow_hash="x", base_ref="HEAD",
        ),
        NodeCompleted(
            seq=1, run_id=run_id, timestamp=_ts(),
            node_id="start", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="", duration_ms=0,
            next_node="fan", worktree_commit_after=None,
        ),
        NodeCompleted(
            seq=2, run_id=run_id, timestamp=_ts(),
            node_id="fan", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="", duration_ms=0,
            next_node=None, worktree_commit_after=None, branch_name=None,
        ),
        # No branch entries at all → all branches not-started
    ]
    _write_entries(seeded_repo, run_id, entries)

    engine = Engine(seeded_repo)
    status = await engine.resume(run_id)
    assert status == RunStatus.COMPLETED, f"expected COMPLETED, got {status}"

    all_entries = _read_entries(seeded_repo, run_id)

    # Each not-started branch must have been dispatched from its first node.
    ran = {
        (e.branch_name, e.node_id)
        for e in all_entries
        if isinstance(e, NodeCompleted) and e.branch_name is not None
    }
    for bname in ("tests", "lint", "types"):
        assert (bname, bname) in ran, (
            f"branch {bname!r} was not dispatched; branch entries: {sorted(ran)!r}"
        )

    # join must have been written on the main thread, as in the sibling tests.
    join_nc = next(
        (e for e in all_entries
         if isinstance(e, NodeCompleted) and e.node_id == "join"
         and e.branch_name is None),
        None,
    )
    assert join_nc is not None, "join NodeCompleted was not written"


# ──────────────────────────────────────────────────────────────────────
# Test 5 — mixed: one complete, one in-flight, one not-started
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-RESUME]
# [int->REQ-DOD-RESUME-BETWEEN-NODE]
@pytest.mark.asyncio
async def test_resume_mixed_branches(seeded_repo: Path) -> None:
    """Uses _PARALLEL_MULTINODE where the 'lint' branch has two nodes (lint1→lint2).

    State: tests=complete (terminal NC), lint=in-flight (lint1 done, lint2 not),
    types=not-started. Resume should:
    - leave tests alone (complete)
    - re-dispatch lint from lint2 (case (a): lint1's next_node="lint2")
    - dispatch types from its first node
    - AND-fold all three → join → exit.
    Each of these expectations is asserted below.
    """
    graph = validate(parse(_PARALLEL_MULTINODE))
    run_id = str(uuid.uuid4())
    wt_dir = _make_main_worktree(seeded_repo, run_id)

    # Pre-create the "lint" branch worktree (in-flight branch needs to exist).
    fork_oid = _head_oid(wt_dir)
    lint_bw_dir = _make_branch_worktree(seeded_repo, run_id, "lint", fork_oid)

    # Make a real commit on the lint branch worktree so HEAD is a real OID.
    (lint_bw_dir / "lint1.log").write_text("lint1-ok\n", encoding="utf-8")
    subprocess.run(["git", "-C", str(lint_bw_dir), "add", "-A"],
                   check=True, capture_output=True)
    subprocess.run([
        "git", "-C", str(lint_bw_dir),
        "-c", "user.name=Eng", "-c", "user.email=eng@example.com",
        "-c", "commit.gpgsign=false",
        "commit", "-m", "checkpoint: lint1 (1) [branch:lint]",
    ], check=True, capture_output=True)
    lint1_oid = _head_oid(lint_bw_dir)

    entries: list[JournalEntry] = [
        RunInitialized(
            seq=0, run_id=run_id, timestamp=_ts(),
            workflow_dot=_serialize_workflow(graph),
            workflow_hash="x", base_ref="HEAD",
        ),
        NodeCompleted(
            seq=1, run_id=run_id, timestamp=_ts(),
            node_id="start", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="", duration_ms=0,
            next_node="fan", worktree_commit_after=None,
        ),
        NodeCompleted(
            seq=2, run_id=run_id, timestamp=_ts(),
            node_id="fan", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="", duration_ms=0,
            next_node=None, worktree_commit_after=None, branch_name=None,
        ),
        # tests: terminal (complete)
        NodeCompleted(
            seq=3, run_id=run_id, timestamp=_ts(),
            node_id="tests", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="tests-ok", duration_ms=10,
            next_node="join", worktree_commit_after=None, branch_name="tests",
        ),
        # lint: lint1 done, next_node="lint2" (NOT terminal) → in-flight
        # worktree_commit_after matches lint1_oid → case (a): proceed to next_node
        NodeCompleted(
            seq=4, run_id=run_id, timestamp=_ts(),
            node_id="lint1", visit=1, status=OutcomeStatus.SUCCESS,
            captured_output="lint1-ok", duration_ms=5,
            next_node="lint2", worktree_commit_after=lint1_oid, branch_name="lint",
        ),
        # types: no entries → not-started
    ]
    _write_entries(seeded_repo, run_id, entries)

    engine = Engine(seeded_repo)
    status = await engine.resume(run_id)
    # NOTE(review): every branch succeeds in this scenario, so COMPLETED is
    # the expected outcome; INCOMPLETE is still tolerated here. Tighten once
    # confirmed.
    assert status in (RunStatus.COMPLETED, RunStatus.INCOMPLETE), (
        f"unexpected status: {status}"
    )

    all_entries = _read_entries(seeded_repo, run_id)
    ran = [
        (e.branch_name, e.node_id)
        for e in all_entries
        if isinstance(e, NodeCompleted) and e.branch_name is not None
    ]
    # Complete branch left alone: only the hand-written entry exists.
    assert ran.count(("tests", "tests")) == 1, f"tests branch re-run: {ran!r}"
    # In-flight branch resumed at next_node without re-running lint1.
    assert ran.count(("lint", "lint1")) == 1, f"lint1 was re-executed: {ran!r}"
    assert ("lint", "lint2") in ran, f"lint branch did not resume at lint2: {ran!r}"
    # Not-started branch dispatched from its first node.
    assert ("types", "types") in ran, f"types branch was not dispatched: {ran!r}"

    # join must have been written
    join_nc = next(
        (e for e in all_entries
         if isinstance(e, NodeCompleted) and e.node_id == "join"
         and e.branch_name is None),
        None,
    )
    assert join_nc is not None, "join NodeCompleted was not written after mixed resume"
