from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

from artifact_tool_v2 import BoundingBoxRect, Slide, SlideElement

# Coordinate space used by warn_about_overlaps when reporting the overlap
# rectangle and element centers: "normalized_by_slide_dimensions" divides
# coordinates by the slide frame's width/height; "absolute" reports the raw
# element coordinates.
OverlapUnit = Literal["normalized_by_slide_dimensions", "absolute"]


@dataclass(frozen=True)
class OverlapInfo:
    """Immutable record describing one overlapping pair of slide elements."""

    # Position of the first element of the pair in the slide's shape list.
    element_a_index: int
    # Position of the second element; warn_about_overlaps always produces a
    # value greater than element_a_index.
    element_b_index: int
    # Intersection rectangle in the elements' own (absolute) coordinates.
    intersection: BoundingBoxRect
    # Intersection divided by the slide frame width/height; None when the
    # slide frame has no usable (present, positive) size.
    normalized_intersection: BoundingBoxRect | None
    # Intersection-over-union of the two elements' bounding boxes.
    iou: float
    # warn_about_overlaps currently only emits "warning"; "error" is reserved.
    severity: Literal["warning", "error"]
    # Human-readable description (the same text printed when emit=True).
    message: str


def _element_bbox(element: SlideElement) -> BoundingBoxRect | None:
    """Return the element's bounding box as a left/top/width/height mapping.

    Yields ``None`` when the element has no position, when any of the four
    coordinates is missing, or when the width or height is not positive, so
    callers only ever receive boxes that enclose a real area.
    """
    position = element.position
    if position is None:
        return None
    left, top, width, height = (
        position.left,
        position.top,
        position.width,
        position.height,
    )
    if any(value is None for value in (left, top, width, height)):
        return None
    # Degenerate (zero or negative sized) boxes cannot meaningfully overlap.
    if width <= 0 or height <= 0:
        return None
    return {"left": left, "top": top, "width": width, "height": height}


def _intersection(first: BoundingBoxRect, second: BoundingBoxRect) -> BoundingBoxRect | None:
    """Return the rectangle shared by two boxes, or ``None`` if they do not overlap.

    Boxes that merely share an edge (zero-width or zero-height overlap) are
    treated as non-overlapping.
    """
    x_start = max(first.get("left"), second.get("left"))
    y_start = max(first.get("top"), second.get("top"))
    first_right = first.get("left") + first.get("width")
    second_right = second.get("left") + second.get("width")
    first_bottom = first.get("top") + first.get("height")
    second_bottom = second.get("top") + second.get("height")
    x_end = min(first_right, second_right)
    y_end = min(first_bottom, second_bottom)
    if x_end <= x_start or y_end <= y_start:
        return None
    return BoundingBoxRect(
        left=x_start,
        top=y_start,
        width=x_end - x_start,
        height=y_end - y_start,
    )


def _area(rect: BoundingBoxRect) -> float:
    """Return ``width * height`` of *rect*, or ``0.0`` when either is missing."""
    width, height = rect.get("width"), rect.get("height")
    return 0.0 if width is None or height is None else width * height


def _normalize_intersection(intersection: BoundingBoxRect, slide: Slide) -> BoundingBoxRect | None:
    """Express *intersection* as fractions of the slide frame's width and height.

    Returns ``None`` when the slide has no frame, or when the frame's width or
    height is missing or not positive (nothing sensible to divide by).
    """
    frame = slide.frame
    if frame is None:
        return None
    if frame.width is None or frame.height is None:
        return None
    if frame.width <= 0 or frame.height <= 0:
        return None
    scale_x = frame.width
    scale_y = frame.height
    return BoundingBoxRect(
        left=intersection.get("left") / scale_x,
        top=intersection.get("top") / scale_y,
        width=intersection.get("width") / scale_x,
        height=intersection.get("height") / scale_y,
    )


def _normalize_point(x: float, y: float, slide: Slide) -> tuple[float, float] | None:
    """Scale the point ``(x, y)`` to fractions of the slide frame's size.

    Returns ``None`` when the frame is missing or its width/height is missing
    or not positive.
    """
    frame = slide.frame
    usable = (
        frame is not None
        and frame.width is not None
        and frame.height is not None
        and not (frame.width <= 0 or frame.height <= 0)
    )
    if not usable:
        return None
    return (x / frame.width, y / frame.height)


def _format_rect(rect: BoundingBoxRect | None) -> str:
    """Render *rect* as ``left=…, top=…, width=…, height=…`` to three decimals.

    A missing rectangle is rendered as ``"unavailable"``.
    """
    if rect is None:
        return "unavailable"
    return ", ".join(
        f"{key}={rect.get(key):.3f}" for key in ("left", "top", "width", "height")
    )


def _format_center(point: tuple[float, float] | None) -> str:
    """Render a center point as ``center_x=…, center_y=…`` to three decimals."""
    if point is None:
        return "center_x=unavailable, center_y=unavailable"
    center_x, center_y = point[0], point[1]
    return f"center_x={center_x:.3f}, center_y={center_y:.3f}"


def warn_about_overlaps(
    slide: Slide,
    *,
    slide_number: int | None = None,
    emit: bool = True,
    overlap_unit: OverlapUnit = "normalized_by_slide_dimensions",
    ratio_tolerance: float = 1e-9,
) -> list[OverlapInfo]:
    """Print overlap warnings for slide elements and return overlap details.

    Every pair of shapes in ``slide.elements.shapes.items`` whose bounding box
    is usable (present, with positive width and height) is compared. A pair is
    reported only when the boxes *partially* overlap. Pairs that merely touch,
    and pairs where one box lies entirely inside the other (including identical
    boxes), are skipped.

    Args:
        slide: Slide whose shapes are checked.
        slide_number: Number shown in messages; ``slide.id`` is used when omitted.
        emit: If true, print each message as soon as it is found.
        overlap_unit: ``"normalized_by_slide_dimensions"`` reports the overlap
            rectangle and element centers as fractions of the slide frame
            (shown as "unavailable" if the frame has no usable size). Any other
            value, i.e. ``"absolute"``, reports them in raw element coordinates.
        ratio_tolerance: Small, non-negative, dimensionless slack applied to the
            overlap ratio (intersection area / smaller box area). Ratios within
            this distance of 0 (touching) or 1 (full containment) are treated
            as exactly 0 or 1. Without it, rounding in the edge arithmetic
            could turn a touching or fully contained pair into a spurious
            warning; for example, ``0.7 + 0.1 - 0.7`` is slightly less than
            ``0.1``. Pass ``0.0`` for exact comparisons.

    Returns:
        One ``OverlapInfo`` per reported pair, in (element_a_index,
        element_b_index) order. The indices refer to positions in
        ``slide.elements.shapes.items``.
    """
    normalized = overlap_unit == "normalized_by_slide_dimensions"
    elements = list(slide.elements.shapes.items)
    bboxes = [_element_bbox(element) for element in elements]
    slide_label = f"Slide {slide_number}" if slide_number is not None else f"Slide {slide.id}"
    overlaps: list[OverlapInfo] = []

    def _center(bbox: BoundingBoxRect) -> tuple[float, float] | None:
        # Box center, normalized to the slide frame when requested (None if the
        # frame has no usable size).
        x = bbox.get("left") + bbox.get("width") / 2
        y = bbox.get("top") + bbox.get("height") / 2
        return _normalize_point(x, y, slide) if normalized else (x, y)

    def _get_element_identifier(element: SlideElement, center: tuple[float, float] | None) -> str:
        # Produces "Base (details, id=..., type, center)". A name of the form
        # "Base: details" is split at its first colon so the details land
        # inside the parentheses.
        raw_name = element.name or f"element {element.id}"
        base = raw_name
        details: list[str] = []
        if ":" in raw_name:
            base, rest = raw_name.split(":", 1)
            base = base.strip() or raw_name
            rest = rest.strip()
            if rest:
                details.append(rest)
        details.append(f"id={element.id}")
        details.append(str(element.type))
        details.append(_format_center(center))
        return f"{base} ({', '.join(details)})"

    for i, element_a in enumerate(elements):
        bbox_a = bboxes[i]
        if bbox_a is None:
            continue
        area_a = _area(bbox_a)
        for j in range(i + 1, len(elements)):
            bbox_b = bboxes[j]
            if bbox_b is None:
                continue
            intersection = _intersection(bbox_a, bbox_b)
            if intersection is None:
                continue
            area_b = _area(bbox_b)
            if area_a <= 0 or area_b <= 0:
                continue
            intersection_area = _area(intersection)
            overlap_ratio = intersection_area / min(area_a, area_b)
            # Only partial overlaps are reported. A ratio at, or within rounding
            # of, 0 means the boxes just touch; at, or within rounding of, 1
            # means one box fully contains the other.
            if not ratio_tolerance < overlap_ratio < 1.0 - ratio_tolerance:
                continue
            # The union is strictly positive here, since the intersection is
            # smaller than the smaller box. IoU therefore lies in (0, 1).
            iou = intersection_area / (area_a + area_b - intersection_area)
            element_b = elements[j]
            normalized_intersection = _normalize_intersection(intersection, slide)
            overlap_rect = normalized_intersection if normalized else intersection
            severity: Literal["warning", "error"] = "warning"
            message = (
                f"{'⚠️'} {slide_label}: Overlap detected between "
                f"{_get_element_identifier(element_a, _center(bbox_a))} and "
                f"{_get_element_identifier(element_b, _center(bbox_b))} "
                f"(iou={iou:.3f}, overlap={{{_format_rect(overlap_rect)}}})."
            )
            overlap_info = OverlapInfo(
                element_a_index=i,
                element_b_index=j,
                intersection=intersection,
                normalized_intersection=normalized_intersection,
                iou=iou,
                severity=severity,
                message=message,
            )
            overlaps.append(overlap_info)
            if emit:
                print(message)
    return overlaps
