# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Evaluation identity and eval-hash computation.

This module provides:

* ``ChildEvalRule`` — per-child configuration for eval-hash filtering.
* ``_build_eval_dict`` — builds a filtered dict for eval-hash computation.
* ``compute_eval_hash`` — free function that computes a behavioral equivalence
  hash from a ``ComponentIdentifier``.
* ``EvaluationIdentifier`` — abstract base that wraps a ``ComponentIdentifier``
  with domain-specific eval-hash configuration.  Concrete subclasses declare
  per-child rules via a single ``CHILD_EVAL_RULES`` ClassVar.
* ``ScorerEvaluationIdentifier`` — scorer-domain concrete subclass.
* ``AtomicAttackEvaluationIdentifier`` — attack-domain concrete subclass.
"""

from __future__ import annotations

from abc import ABC
from dataclasses import dataclass, field
from typing import Any, ClassVar, Optional

from pyrit.identifiers.component_identifier import ComponentIdentifier, config_hash


@dataclass(frozen=True)
class ChildEvalRule:
    """
    Immutable filtering rule for one named child during eval-hash computation.

    A rule is looked up by child name while an identifier tree is walked and
    decides how that child contributes to the evaluation hash:

    * ``exclude`` — when ``True``, the child is skipped and contributes
      nothing to the hash.
    * ``included_params`` — when not ``None``, only the param keys in this set
      are kept for the child's own params. Components nested below the child
      are filtered by whichever rule matches *their* own child name, not by
      this set. ``None`` keeps every param.
    * ``included_item_values`` — for list-valued children, keep only the items
      whose ``params`` equal **all** of the given key-value pairs. ``None``
      (or an empty mapping) keeps every item.
    * ``param_fallbacks`` — maps a primary param key to a fallback key. When
      the primary value is missing, ``None``, or an empty string, the
      fallback key's value is read from the component's unfiltered params
      and stored under the primary key, provided that value is itself
      neither ``None`` nor an empty string. Only the eval dictionary is
      affected; the component's own params are left untouched. ``None`` (or
      an empty mapping) disables fallbacks.
    """

    # Field order is part of the positional constructor signature; keep it.
    exclude: bool = False
    included_params: Optional[frozenset[str]] = None
    included_item_values: Optional[dict[str, Any]] = None
    param_fallbacks: Optional[dict[str, str]] = None


def _build_eval_dict(
    identifier: ComponentIdentifier,
    *,
    child_eval_rules: dict[str, ChildEvalRule],
    _included_params: Optional[frozenset[str]] = None,
    _param_fallbacks: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
    """
    Produce the filtered dictionary that is hashed to obtain an eval hash.

    The result always starts with the component's class name and module,
    followed by its non-``None`` params in key order (optionally restricted
    to ``_included_params``).  Each child that is not excluded is then
    replaced by the ``config_hash`` of its own recursively built eval dict.
    Children with no entry in ``child_eval_rules`` are processed without any
    param filtering, item filtering, or fallbacks.

    Args:
        identifier (ComponentIdentifier): The component identity to process.
        child_eval_rules (dict[str, ChildEvalRule]): Per-child eval rules,
            keyed by child name.  The same mapping is passed unchanged to
            every recursive call, so a rule applies wherever a child with
            that name appears in the tree.
        _included_params (Optional[frozenset[str]]): Internal. When not
            ``None``, only params whose keys are in this set are kept for
            ``identifier`` itself.  Supplied from the ``included_params`` of
            the rule that matched this component's child name.
        _param_fallbacks (Optional[dict[str, str]]): Internal. Maps a primary
            param key to a fallback key.  When the primary entry is missing,
            ``None``, or ``""`` after filtering, the fallback key's value from
            the unfiltered params is used, unless that value is also ``None``
            or ``""``.  Supplied from the matching rule's ``param_fallbacks``.

    Returns:
        dict[str, Any]: The filtered dictionary suitable for hashing.
    """
    result: dict[str, Any] = {
        ComponentIdentifier.KEY_CLASS_NAME: identifier.class_name,
        ComponentIdentifier.KEY_CLASS_MODULE: identifier.class_module,
    }

    # Copy params in sorted-key order, skipping unset values and any key the
    # caller's filter does not allow.
    for param_name, param_value in sorted(identifier.params.items()):
        if param_value is None:
            continue
        if _included_params is not None and param_name not in _included_params:
            continue
        result[param_name] = param_value

    # Fill blank primary params from their fallback key.  The fallback is
    # read from the raw params, so it works even when the fallback key itself
    # was filtered out above.
    if _param_fallbacks:
        for primary, fallback in _param_fallbacks.items():
            current = result.get(primary)
            if not (current is None or current == ""):
                continue
            substitute = identifier.params.get(fallback)
            if substitute is not None and substitute != "":
                result[primary] = substitute

    if not identifier.children:
        return result

    hashed_children: dict[str, Any] = {}
    for child_name in sorted(identifier.children):
        rule = child_eval_rules.get(child_name)
        if rule and rule.exclude:
            continue

        members = identifier.get_child_list(child_name)

        # Keep only list items whose params match every required pair
        # (e.g., only is_general_technique=True seeds).
        if rule and rule.included_item_values:
            wanted = rule.included_item_values
            members = [
                member
                for member in members
                if all(member.params.get(key) == expected for key, expected in wanted.items())
            ]

        # A matching rule supplies the param filter and fallbacks for this
        # child; without a rule every param is kept and nothing is substituted.
        if rule:
            params_filter = rule.included_params
            fallbacks = rule.param_fallbacks
        else:
            params_filter = None
            fallbacks = None

        member_hashes = []
        for member in members:
            member_dict = _build_eval_dict(
                member,
                child_eval_rules=child_eval_rules,
                _included_params=params_filter,
                _param_fallbacks=fallbacks,
            )
            member_hashes.append(config_hash(member_dict))

        # Exactly one item is stored as a bare hash; zero or several items
        # are stored as a list of hashes.
        hashed_children[child_name] = member_hashes[0] if len(member_hashes) == 1 else member_hashes

    # The "children" key is omitted entirely when every child was excluded.
    if hashed_children:
        result["children"] = hashed_children

    return result


def compute_eval_hash(
    identifier: ComponentIdentifier,
    *,
    child_eval_rules: dict[str, ChildEvalRule],
) -> str:
    """
    Return the behavioral-equivalence hash used to group evaluations.

    With a non-empty ``child_eval_rules``, the identifier tree is reduced by
    ``_build_eval_dict`` — which can restrict a child's params, drop a child
    entirely, or filter list items — and the resulting dictionary is passed
    to ``config_hash``.  The intent is that the same logical configuration
    yields the same eval hash even when params outside the rules differ.
    Children without a rule are included in full.

    With an empty ``child_eval_rules`` there is nothing to filter, so the
    identifier's own ``hash`` is returned directly.

    Args:
        identifier (ComponentIdentifier): The component identity to compute
            the hash for.
        child_eval_rules (dict[str, ChildEvalRule]): Per-child eval rules,
            keyed by child name.

    Returns:
        str: The eval hash — ``identifier.hash`` when no rules are given,
        otherwise the value produced by ``config_hash``.

    Raises:
        RuntimeError: If ``child_eval_rules`` is empty and ``identifier.hash``
            is ``None``.
    """
    if child_eval_rules:
        filtered = _build_eval_dict(identifier, child_eval_rules=child_eval_rules)
        return config_hash(filtered)

    # No rules: the eval hash is simply the full component hash.
    full_hash = identifier.hash
    if full_hash is None:
        raise RuntimeError("hash should be set by __post_init__")
    return full_hash


class EvaluationIdentifier(ABC):
    """
    Pairs a ``ComponentIdentifier`` with a domain-specific eval hash.

    Subclasses provide ``CHILD_EVAL_RULES`` — a mapping from child name to
    ``ChildEvalRule`` describing how that child is treated when the eval hash
    is computed.  Children with no entry are included in full.

    The eval hash is resolved once, at construction time, and exposed through
    the read-only ``eval_hash`` property.  When it has to be computed, the
    work is delegated to the module-level ``compute_eval_hash`` function.
    """

    # Declared here without a value; each concrete subclass must assign it.
    CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]]

    def __init__(self, identifier: ComponentIdentifier) -> None:
        """
        Store the identifier and resolve its eval hash.

        An identifier that already carries a non-``None`` ``eval_hash`` keeps
        that value as-is.  Otherwise the hash is computed from the identifier
        using this class's ``CHILD_EVAL_RULES``.

        Args:
            identifier (ComponentIdentifier): The component identity to wrap.
        """
        self._identifier = identifier
        preserved_hash = identifier.eval_hash
        self._eval_hash = (
            preserved_hash
            if preserved_hash is not None
            else compute_eval_hash(identifier, child_eval_rules=self.CHILD_EVAL_RULES)
        )

    @property
    def identifier(self) -> ComponentIdentifier:
        """The wrapped component identity, exactly as passed to the constructor."""
        return self._identifier

    @property
    def eval_hash(self) -> str:
        """The eval hash resolved at construction time, used for evaluation grouping."""
        return self._eval_hash


class ScorerEvaluationIdentifier(EvaluationIdentifier):
    """
    Evaluation identity for scorers.

    Only one child is given a rule: ``prompt_target``.  Its params are reduced
    to ``underlying_model_name``, ``temperature`` and ``top_p``, and a blank
    ``underlying_model_name`` is filled from ``model_name``.  Any other params
    on the prompt target therefore do not influence the eval hash, so the same
    scorer configuration on different deployments hashes identically.

    Every other child is included in full.
    """

    CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] = {
        "prompt_target": ChildEvalRule(
            param_fallbacks={"underlying_model_name": "model_name"},
            included_params=frozenset(("temperature", "top_p", "underlying_model_name")),
        ),
    }


class AtomicAttackEvaluationIdentifier(EvaluationIdentifier):
    """
    Evaluation identity for atomic attacks.

    Rules, keyed by child name:

    * ``objective_target`` — only the ``temperature`` param is kept.
    * ``adversarial_chat`` — only ``underlying_model_name``, ``temperature``
      and ``top_p`` are kept; a blank ``underlying_model_name`` is filled
      from ``model_name``.
    * ``objective_scorer`` — excluded entirely.
    * ``seed_identifiers`` — excluded entirely (kept on the identifier for
      traceability only).

    The same rules mapping is applied at every depth of the identifier tree,
    so these rules also govern matching children nested inside another child
    such as ``attack_technique``.

    Child names without a rule — for example ``attack_technique``,
    ``technique_seeds``, ``request_converters`` and ``response_converters`` —
    are included in the eval hash in full.
    """

    CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] = {
        "objective_target": ChildEvalRule(included_params=frozenset(("temperature",))),
        "adversarial_chat": ChildEvalRule(
            param_fallbacks={"underlying_model_name": "model_name"},
            included_params=frozenset(("temperature", "top_p", "underlying_model_name")),
        ),
        "objective_scorer": ChildEvalRule(exclude=True),
        "seed_identifiers": ChildEvalRule(exclude=True),
        # attack_technique is deliberately absent: it is hashed in full.
        # technique_seeds (nested inside attack_technique) is likewise absent
        # and therefore hashed in full.
    }
