"""End-to-end and unit tests for parallel-region engine traversal.

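Covers:
1. Three-branch happy-path run → COMPLETED.
2. One branch fails → join FAILURE.
3. A merge agent downstream of the join receives predecessor_branches
   context, both on a fresh run and after resume.
4. Per-branch sub-worktree refs exist after run.
5. AND-fold combined-output prefix.
6. branch_name round-trips through the journal store.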

All tests run against a real git repo (`seeded_repo` fixture from
conftest.py) using tool-node scripts compatible with pwsh and sh.
"""

from __future__ import annotations

from datetime import UTC, datetime
from pathlib import Path
from typing import cast

import pygit2
import pytest
from pydantic import TypeAdapter

from attractor.agent import AgentConfig, AgentOutcome
from attractor.checkpoint import Author, BranchStore
from attractor.engine import (
    Engine,
    EngineEvent,
    OutcomeStatus,
    RunStatus,
)
from attractor.engine.journal import (
    JournalEntry,
    NodeCompleted,
    event_path,
)
from attractor.workflow import parse, validate

# ──────────────────────────────────────────────────────────────────────
# Workflow DOT fixtures
# ──────────────────────────────────────────────────────────────────────

# Three-branch happy-path workflow (labels become branch names).
# `fan` (shape=component) opens the parallel region and `join`
# (shape=tripleoctagon) closes it; each branch is a single tool node
# (shape=parallelogram) whose script echoes a marker string.
_PARALLEL_HAPPY = """\
digraph ParallelVerifiers {
    graph [ default_max_visits = 3 ]
    start  [shape=Mdiamond, label="Start"]
    exit   [shape=Msquare,  label="Exit"]
    fan    [shape=component]
    join   [shape=tripleoctagon]
    tests  [shape=parallelogram, script="echo tests-ok"]
    lint   [shape=parallelogram, script="echo lint-ok"]
    types  [shape=parallelogram, script="echo types-ok"]
    start  -> fan
    fan    -> tests [label="tests"]
    fan    -> lint  [label="lint"]
    fan    -> types [label="types"]
    tests  -> join
    lint   -> join
    types  -> join
    join   -> exit
}
"""

# One branch fails; join must emit FAILURE.
# `bad` runs `exit 1` while `ok1`/`ok2` echo and succeed. The join has both
# a SUCCESS and a FAILURE edge to `exit`, so the run reaches the exit node
# whichever outcome the join produces.
_PARALLEL_ONE_FAIL = """\
digraph ParallelOneFail {
    graph [ default_max_visits = 3 ]
    start [shape=Mdiamond]
    exit  [shape=Msquare]
    fan   [shape=component]
    join  [shape=tripleoctagon]
    ok1   [shape=parallelogram, script="echo ok1"]
    ok2   [shape=parallelogram, script="echo ok2"]
    bad   [shape=parallelogram, script="exit 1"]
    start -> fan
    fan   -> ok1 [label="ok1"]
    fan   -> ok2 [label="ok2"]
    fan   -> bad [label="bad"]
    ok1   -> join
    ok2   -> join
    bad   -> join
    join  -> exit  [label="SUCCESS"]
    join  -> exit  [label="FAILURE"]
}
"""

# Two-branch workflow whose join routes into a merge agent.
# `merge` (shape=box, with a prompt) is the agent node that the merge-agent
# tests intercept by monkeypatching `run_agent_node`. The only edge into it
# is the join's SUCCESS edge.
_PARALLEL_JOIN_TO_AGENT = """\
digraph ParallelJoinToAgent {
    graph [ default_max_visits = 3 ]
    start [shape=Mdiamond]
    exit  [shape=Msquare]
    fan   [shape=component]
    join  [shape=tripleoctagon]
    left  [shape=parallelogram, script="echo left-ok"]
    right [shape=parallelogram, script="echo right-ok"]
    merge [shape=box, prompt="Merge predecessor branches."]
    start -> fan
    fan   -> left  [label="left"]
    fan   -> right [label="right"]
    left  -> join
    right -> join
    join  -> merge [label="SUCCESS"]
    merge -> exit
}
"""

# Two-branch workflow with file writes to prove per-branch worktree commits.
# Each branch's tool node writes its own file (alpha.txt / beta.txt) via
# shell redirection, so each branch has something of its own to commit.
_PARALLEL_WRITES = """\
digraph ParallelWrites {
    graph [ default_max_visits = 3 ]
    start  [shape=Mdiamond]
    exit   [shape=Msquare]
    fan    [shape=component]
    join   [shape=tripleoctagon]
    alpha  [shape=parallelogram, script="echo alpha > alpha.txt"]
    beta   [shape=parallelogram, script="echo beta > beta.txt"]
    start  -> fan
    fan    -> alpha [label="alpha"]
    fan    -> beta  [label="beta"]
    alpha  -> join
    beta   -> join
    join   -> exit
}
"""


# ──────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────

def _ts() -> datetime:
    return datetime(2026, 1, 1, 12, 0, tzinfo=UTC)


# Shared pydantic adapter for journal entries: tests use `dump_json` to
# serialize an entry for BranchStore and `validate_json` to decode raw event
# bytes read back from a run's state ref into the concrete entry type.
_ADAPTER: TypeAdapter[JournalEntry] = TypeAdapter(JournalEntry)


# ──────────────────────────────────────────────────────────────────────
# Test 1 — three-branch happy path
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-FANOUT]
# [int->REQ-DOD-WORKFLOW-RUN]
@pytest.mark.asyncio
async def test_three_branch_happy_path(seeded_repo: Path) -> None:
    """All three branches succeed → run COMPLETED.

    Each branch's tool node records a NodeCompleted tagged with its own
    branch_name, and the join's top-level NodeCompleted (branch_name=None)
    has status SUCCESS.
    """
    engine = Engine(seeded_repo)
    graph = validate(parse(_PARALLEL_HAPPY))

    # The callback is wired in so the run goes through event emission; the
    # collected events themselves are not asserted on here.
    events: list[EngineEvent] = []
    status = await engine.run(graph, events=events.append)

    assert status == RunStatus.COMPLETED, f"expected COMPLETED, got {status}"

    # Collect every NodeCompleted by reading the journal.
    run_ids = engine.list_all_run_ids()
    assert len(run_ids) == 1
    run_id = run_ids[0]

    repo = pygit2.Repository(str(seeded_repo))
    store = BranchStore(repo, Author())
    ref = f"refs/heads/attractor/run/{run_id}/state"
    paths = store.list_entries(ref, prefix="events/")
    entries: list[JournalEntry] = []
    for p in paths:
        raw = store.read_entry(ref, p)
        if raw:
            entries.append(_ADAPTER.validate_json(raw))
    entries.sort(key=lambda e: e.seq)

    node_completeds = [e for e in entries if isinstance(e, NodeCompleted)]

    # Branch-internal completions must carry their branch names.
    branch_entries: dict[str, list[NodeCompleted]] = {}
    for nc in node_completeds:
        if nc.branch_name is not None:
            branch_entries.setdefault(nc.branch_name, []).append(nc)

    assert set(branch_entries.keys()) == {"tests", "lint", "types"}, (
        f"expected branch names {{tests, lint, types}}, got {set(branch_entries.keys())}"
    )

    # In the fixture each branch consists of a tool node named like the
    # branch. Checking that node is tagged with *its own* branch catches
    # swapped tags, which the name-set check above would not.
    for name, ncs in branch_entries.items():
        node_ids = {nc.node_id for nc in ncs}
        assert name in node_ids, (
            f"branch {name!r} has no completion for node {name!r}; got {node_ids}"
        )

    # Join's NodeCompleted must have branch_name=None and status SUCCESS.
    join_nc = next(
        (nc for nc in node_completeds if nc.node_id == "join" and nc.branch_name is None),
        None,
    )
    assert join_nc is not None, "join NodeCompleted not found"
    assert join_nc.status == OutcomeStatus.SUCCESS


# ──────────────────────────────────────────────────────────────────────
# Test 2 — one branch fails → join FAILURE
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-JOIN]
@pytest.mark.asyncio
async def test_one_branch_fails_join_failure(seeded_repo: Path) -> None:
    """When one branch exits non-zero, the join's NodeCompleted has status
    FAILURE and its captured_output names the failing branch.

    The join has edges to ``exit`` for both SUCCESS and FAILURE, so the run
    itself still finishes COMPLETED.
    """
    engine = Engine(seeded_repo)
    graph = validate(parse(_PARALLEL_ONE_FAIL))

    # Both SUCCESS and FAILURE edges from join → exit, so the run
    # reaches EXIT regardless of join outcome. No goal_gate → COMPLETED.
    status = await engine.run(graph)
    assert status == RunStatus.COMPLETED, f"expected COMPLETED, got {status}"

    run_ids = engine.list_all_run_ids()
    assert len(run_ids) == 1
    run_id = run_ids[0]

    repo = pygit2.Repository(str(seeded_repo))
    store = BranchStore(repo, Author())
    ref = f"refs/heads/attractor/run/{run_id}/state"
    paths = store.list_entries(ref, prefix="events/")
    entries: list[JournalEntry] = []
    for p in paths:
        raw = store.read_entry(ref, p)
        if raw:
            entries.append(_ADAPTER.validate_json(raw))
    entries.sort(key=lambda e: e.seq)

    # Only the top-level join completion (branch_name=None) is relevant.
    join_nc = next(
        (
            e for e in entries
            if isinstance(e, NodeCompleted)
            and e.node_id == "join"
            and e.branch_name is None
        ),
        None,
    )
    assert join_nc is not None, "join NodeCompleted not found"
    assert join_nc.status == OutcomeStatus.FAILURE, "expected join failure"
    # Combined output must mention the failing branch.
    assert "bad" in join_nc.captured_output, (
        f"expected 'bad' in combined output, got: {join_nc.captured_output!r}"
    )


# ──────────────────────────────────────────────────────────────────────
# Test 3 — join successor agent receives predecessor_branches context (fresh run and after resume)
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-WORKTREE]
# [unit->REQ-EXEC-REVISIT-CONTEXT]
@pytest.mark.asyncio
async def test_join_successor_agent_receives_predecessor_branches_context(
    seeded_repo: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A merge agent downstream of a join receives branch refs + outputs.

    The real agent runner is replaced by a stub that records the
    AgentConfig it receives, so the revisit_context handed to ``merge`` can
    be inspected directly.
    """
    captured_configs: list[AgentConfig] = []

    async def fake_run_agent_node(config: AgentConfig) -> AgentOutcome:
        captured_configs.append(config)
        return AgentOutcome(
            status=OutcomeStatus.SUCCESS,
            captured_output="merged",
        )

    monkeypatch.setattr(
        "attractor.engine.engine.run_agent_node",
        fake_run_agent_node,
    )

    engine = Engine(seeded_repo)
    graph = validate(parse(_PARALLEL_JOIN_TO_AGENT))

    status = await engine.run(graph)

    assert status == RunStatus.COMPLETED
    assert len(captured_configs) == 1
    ctx = captured_configs[0].revisit_context
    assert ctx is not None
    assert ctx["predecessor_node"] == "join"
    assert ctx["entry_edge_label"] == "SUCCESS"
    assert "[left]" in str(ctx["captured_output"])
    assert "[right]" in str(ctx["captured_output"])

    raw_branches = ctx["predecessor_branches"]
    assert isinstance(raw_branches, list)
    branches = cast("list[dict[str, str | bool]]", raw_branches)
    # Pin the count before indexing. A short list would otherwise surface
    # as an IndexError, and extra branches would go unnoticed.
    assert len(branches) == 2, f"expected 2 predecessor branches, got {branches!r}"
    assert branches[0]["name"] == "left"
    assert branches[0]["success"] is True
    assert "left-ok" in str(branches[0]["captured_output"])
    assert branches[1]["name"] == "right"
    assert branches[1]["success"] is True
    assert "right-ok" in str(branches[1]["captured_output"])
    assert str(branches[0]["ref"]).startswith("refs/heads/attractor/run/")
    assert str(branches[0]["ref"]).endswith("/branch/left")
    assert str(branches[1]["ref"]).startswith("refs/heads/attractor/run/")
    assert str(branches[1]["ref"]).endswith("/branch/right")


# [unit->REQ-EXEC-PARALLEL-WORKTREE]
# [unit->REQ-EXEC-PARALLEL-RESUME]
@pytest.mark.asyncio
async def test_resume_after_join_reconstructs_predecessor_branches_context(
    seeded_repo: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Idle-between-nodes resume after a join rebuilds merge-agent context.

    The stubbed agent crashes on its first call (before ``merge`` can
    checkpoint) and succeeds afterwards. After ``resume`` the second call
    must again see both predecessor branches and their refs.
    """
    seen_configs: list[AgentConfig] = []

    async def crash_once_then_merge(config: AgentConfig) -> AgentOutcome:
        seen_configs.append(config)
        # The length of the record doubles as the call counter.
        first_call = len(seen_configs) == 1
        if first_call:
            raise RuntimeError("simulated crash before merge checkpoint")
        return AgentOutcome(
            status=OutcomeStatus.SUCCESS,
            captured_output="merged after resume",
        )

    monkeypatch.setattr(
        "attractor.engine.engine.run_agent_node",
        crash_once_then_merge,
    )

    engine = Engine(seeded_repo)
    workflow = validate(parse(_PARALLEL_JOIN_TO_AGENT))

    with pytest.raises(RuntimeError, match="simulated crash"):
        await engine.run(workflow)

    known_runs = engine.list_all_run_ids()
    assert len(known_runs) == 1

    final_status = await engine.resume(known_runs[0])
    assert final_status == RunStatus.COMPLETED

    assert len(seen_configs) == 2
    resumed_ctx = seen_configs[-1].revisit_context
    assert resumed_ctx is not None

    predecessor_list = resumed_ctx["predecessor_branches"]
    assert isinstance(predecessor_list, list)
    typed = cast("list[dict[str, str | bool]]", predecessor_list)
    names = [entry["name"] for entry in typed]
    assert names == ["left", "right"]
    first, second = typed[0], typed[1]
    assert str(first["ref"]).endswith("/branch/left")
    assert str(second["ref"]).endswith("/branch/right")


# ──────────────────────────────────────────────────────────────────────
# Test 4 — per-branch sub-worktree refs exist after run
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-FANOUT]
@pytest.mark.asyncio
async def test_per_branch_sub_worktree_refs(seeded_repo: Path) -> None:
    """After a successful run each branch ref persists in the repo and
    resolves to a real commit; the main worktree ref also exists."""
    engine = Engine(seeded_repo)
    graph = validate(parse(_PARALLEL_WRITES))

    status = await engine.run(graph)
    assert status == RunStatus.COMPLETED

    run_ids = engine.list_all_run_ids()
    assert len(run_ids) == 1
    run_id = run_ids[0]

    repo = pygit2.Repository(str(seeded_repo))

    # Snapshot ref names once; reused for membership checks and messages.
    all_refs = list(repo.references)
    for branch_name in ("alpha", "beta"):
        ref_name = f"refs/heads/attractor/run/{run_id}/branch/{branch_name}"
        assert ref_name in all_refs, (
            f"branch ref {ref_name!r} not found in repo; all refs: {all_refs}"
        )
        # `peel(pygit2.Commit)` follows symbolic refs/tags down to a commit
        # and raises if the ref does not ultimately name one. Checking
        # `.target is not None` would hold for any direct ref.
        commit = repo.lookup_reference(ref_name).peel(pygit2.Commit)
        assert isinstance(commit, pygit2.Commit), (
            f"branch ref {ref_name} does not resolve to a commit"
        )

    # Main worktree ref also exists. `lookup_reference` raises KeyError
    # rather than returning None, so a membership test is the real check.
    wt_ref = f"refs/heads/attractor/run/{run_id}/worktree"
    assert wt_ref in all_refs, (
        f"worktree ref {wt_ref!r} not found in repo; all refs: {all_refs}"
    )


# ──────────────────────────────────────────────────────────────────────
# Test 5 — AND-fold output has per-branch prefix
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-JOIN]
@pytest.mark.asyncio
async def test_and_fold_output_prefix(seeded_repo: Path) -> None:
    """The join's captured_output contains a '[<branch>] ...' prefix for
    each branch."""
    engine = Engine(seeded_repo)
    graph = validate(parse(_PARALLEL_HAPPY))

    # Fail fast with a clear message if the run itself did not complete,
    # rather than later on a missing join entry.
    status = await engine.run(graph)
    assert status == RunStatus.COMPLETED, f"expected COMPLETED, got {status}"

    run_ids = engine.list_all_run_ids()
    assert len(run_ids) == 1
    run_id = run_ids[0]

    repo = pygit2.Repository(str(seeded_repo))
    store = BranchStore(repo, Author())
    ref = f"refs/heads/attractor/run/{run_id}/state"
    paths = store.list_entries(ref, prefix="events/")
    entries: list[JournalEntry] = []
    for p in paths:
        raw = store.read_entry(ref, p)
        if raw:
            entries.append(_ADAPTER.validate_json(raw))
    entries.sort(key=lambda e: e.seq)

    # Match only the top-level join completion (branch_name=None), as the
    # other join lookups in this module do.
    join_nc = next(
        (
            e for e in entries
            if isinstance(e, NodeCompleted)
            and e.node_id == "join"
            and e.branch_name is None
        ),
        None,
    )
    assert join_nc is not None, "join NodeCompleted not found"
    out = join_nc.captured_output
    # Each branch's output must appear with its prefix.
    for branch in ("tests", "lint", "types"):
        assert f"[{branch}]" in out, (
            f"prefix '[{branch}]' not found in join output: {out!r}"
        )


# ──────────────────────────────────────────────────────────────────────
# Test 6 — branch_name round-trips through BranchStore
# ──────────────────────────────────────────────────────────────────────

# [unit->REQ-EXEC-PARALLEL-FANOUT]
def test_branch_name_round_trips_journal(seeded_repo: Path) -> None:
    """A NodeCompleted tagged with a branch_name, written through BranchStore
    and read back, keeps its branch_name, node_id and status."""
    run_id = "00000000-0000-4000-8000-000000000001"
    state_ref = f"refs/heads/attractor/run/{run_id}/state"
    journal = BranchStore(pygit2.Repository(str(seeded_repo)), Author())
    slot = event_path(0)

    written = NodeCompleted(
        seq=0,
        run_id=run_id,
        timestamp=_ts(),
        node_id="lint",
        visit=1,
        status=OutcomeStatus.SUCCESS,
        captured_output="lint-ok",
        duration_ms=42,
        branch_name="lint",
    )
    journal.write_entry(
        ref=state_ref,
        path=slot,
        content=_ADAPTER.dump_json(written, indent=2),
        message="test: branch_name round-trip",
        trailers=None,
    )

    payload = journal.read_entry(state_ref, slot)
    assert payload is not None
    decoded = _ADAPTER.validate_json(payload)
    assert isinstance(decoded, NodeCompleted)
    assert (decoded.branch_name, decoded.node_id, decoded.status) == (
        "lint",
        "lint",
        OutcomeStatus.SUCCESS,
    )
