"""
Challenge 1: Build Agents — Insurance Claims Processing
Claims Triage Agent and Claims Decision Agent for ClaimSight Insurance.

Usage:
    python agents.py

Builds both agents with system prompts, tools, and conversation handling.
"""

import json
import os
import sys
from pathlib import Path

from dotenv import load_dotenv
from azure.ai.projects import AIProjectClient
from azure.ai.projects.models import FunctionTool, PromptAgentDefinition
from azure.identity import DefaultAzureCredential
from openai.types.responses.response_input_param import FunctionCallOutput


# Resolve repo root by finding .env in parent directories.
def _find_repo_root() -> Path:
    for parent in Path(__file__).resolve().parents:
        if (parent / ".env").exists():
            return parent
    return Path(__file__).resolve().parents[2]


REPO_ROOT = _find_repo_root()

# Load environment
env_path = REPO_ROOT / ".env"
load_dotenv(env_path)

PROJECT_CONNECTION_STRING = os.getenv("PROJECT_CONNECTION_STRING")
MODEL_DEPLOYMENT_NAME = os.getenv("MODEL_DEPLOYMENT_NAME", "gpt-5.4")
CLAIMS_DATA_PATH = Path(__file__).resolve().parent / "claims_data.json"


# =============================================================================
# Tool Function: assess_claim
# This is already implemented — agents can call this to get claim risk analysis
# =============================================================================

def assess_claim(claim_id: str) -> str:
    """
    Reads claims_data.json and checks if a claim's metrics are within acceptable thresholds.
    Returns a JSON string with the analysis.
    """
    with open(CLAIMS_DATA_PATH, "r") as f:
        data = json.load(f)

    claim = None
    for c in data["claims"]:
        if c["claim_id"] == claim_id:
            claim = c
            break

    if not claim:
        return json.dumps({"error": f"Claim '{claim_id}' not found"})

    results = {
        "claim_id": claim["claim_id"],
        "type": claim["type"],
        "claimant": claim["claimant"],
        "date_filed": claim["date_filed"],
        "status": claim["status"],
        "documents_submitted": claim["documents_submitted"],
        "flags": [],
        "all_metrics": {},
    }

    for metric, reading in claim["metrics"].items():
        value = reading["value"]
        threshold = claim["thresholds"][metric]
        in_spec = threshold["min"] <= value <= threshold["max"]

        results["all_metrics"][metric] = {
            "value": value,
            "unit": reading["unit"],
            "min": threshold["min"],
            "max": threshold["max"],
            "in_spec": in_spec,
        }

        if not in_spec:
            deviation = ""
            if value > threshold["max"]:
                pct = ((value - threshold["max"]) / threshold["max"]) * 100
                deviation = f"{pct:.1f}% above max"
            elif value < threshold["min"]:
                pct = ((threshold["min"] - value) / threshold["min"]) * 100
                deviation = f"{pct:.1f}% below min"

            results["flags"].append({
                "metric": metric,
                "value": value,
                "unit": reading["unit"],
                "threshold_min": threshold["min"],
                "threshold_max": threshold["max"],
                "deviation": deviation,
            })

    return json.dumps(results, indent=2)


# Tool definition for the agent (Foundry FunctionTool format)
ASSESS_CLAIM_TOOL = FunctionTool(
    name="assess_claim",
    description="Assess an insurance claim's metrics against acceptable thresholds. Returns flags if any metrics are outside acceptable ranges (completeness too low, fraud risk too high, etc.).",
    parameters={
        "type": "object",
        "properties": {
            "claim_id": {
                "type": "string",
                "description": "The claim ID (e.g., 'CLM-001') to assess",
            }
        },
        "required": ["claim_id"],
        "additionalProperties": False,
    },
    strict=False,
)


# =============================================================================
# Claims Triage Agent
# =============================================================================

class ClaimsTriageAgent:
    def __init__(self):
        self.agent = None
        self.client = None
        self.openai = None

    def create(self):
        """Create the claims triage agent in Foundry."""
        self.client = AIProjectClient(
            endpoint=PROJECT_CONNECTION_STRING,
            credential=DefaultAzureCredential(),
        )
        self.openai = self.client.get_openai_client()

        system_prompt = """
        You are an insurance claims triage specialist for ClaimSight Insurance.
        When asked to assess claims, use the assess_claim tool for each claim.
        For each claim, report:
        - Claim ID, type, and claimant name
        - Risk classification (normal / warning / critical)
        - Each metric that is flagged: current value, threshold violated, deviation
        - Missing documents if completeness is below threshold
        Use ⚠️ for warning and 🔴 for critical flags.
        If all metrics are within acceptable ranges, mark the claim as normal (✅).
        Be concise and structured.
        """

        self.agent = self.client.agents.create_version(
            agent_name="claims-triage-agent",
            definition=PromptAgentDefinition(
                model=MODEL_DEPLOYMENT_NAME,
                instructions=system_prompt,
                tools=[ASSESS_CLAIM_TOOL],
            ),
        )

        return self.agent

    def run(self, input_text: str) -> str:
        """Run the claims triage agent with the given input."""
        conversation = self.openai.conversations.create()

        response = self.openai.responses.create(
            input=input_text,
            conversation=conversation.id,
            extra_body={"agent_reference": {"name": self.agent.name, "type": "agent_reference"}},
        )

        # Handle function call loops
        while True:
            function_calls = [item for item in response.output if item.type == "function_call"]
            if not function_calls:
                break

            input_list = []
            for item in function_calls:
                if item.name == "assess_claim":
                    args = json.loads(item.arguments)
                    result = assess_claim(args["claim_id"])
                else:
                    result = json.dumps({"error": f"Unknown tool '{item.name}'"})

                input_list.append(
                    FunctionCallOutput(
                        type="function_call_output",
                        call_id=item.call_id,
                        output=result,
                    )
                )

            response = self.openai.responses.create(
                input=input_list,
                conversation=conversation.id,
                extra_body={"agent_reference": {"name": self.agent.name, "type": "agent_reference"}},
            )

        self.openai.conversations.delete(conversation_id=conversation.id)
        return response.output_text

    def cleanup(self):
        """Delete the agent version and close connections."""
        if self.agent:
            self.client.agents.delete_version(
                agent_name=self.agent.name,
                agent_version=self.agent.version,
            )
        if self.client:
            self.client.close()


# =============================================================================
# Claims Decision Agent
# =============================================================================

class ClaimsDecisionAgent:
    def __init__(self):
        self.agent = None
        self.client = None
        self.openai = None

    def create(self):
        """Create the claims decision agent in Foundry."""
        self.client = AIProjectClient(
            endpoint=PROJECT_CONNECTION_STRING,
            credential=DefaultAzureCredential(),
        )
        self.openai = self.client.get_openai_client()

        system_prompt = """
        You are a senior claims adjuster and decision specialist for ClaimSight Insurance.
        Given a list of flags from a claim assessment, your job is to:
        1. Determine the recommended action based on the pattern of flags:
           - High fraud risk score alone → Investigate for potential fraud
           - Low completeness alone → Request missing documents before proceeding
           - High fraud risk + low damage-estimate match → Likely inflated claim, escalate to SIU
           - Low policy coverage match → Partial denial, cover only matched items
           - Multiple critical flags → Compound risk, full investigation required
        2. Recommend specific, actionable next steps for the claims adjuster.
        3. Estimate urgency: IMMEDIATE (potential fraud), WITHIN 48H (missing docs), or STANDARD (routine).
        Be concise. Format your response as:
        RECOMMENDED ACTION: APPROVE / REQUEST DOCUMENTS / INVESTIGATE / DENY
        REASONING: ...
        NEXT STEPS: ...
        URGENCY: ...
        """

        self.agent = self.client.agents.create_version(
            agent_name="claims-decision-agent",
            definition=PromptAgentDefinition(
                model=MODEL_DEPLOYMENT_NAME,
                instructions=system_prompt,
            ),
        )

        return self.agent

    def run(self, input_text: str) -> str:
        """Run the claims decision agent with the given input."""
        conversation = self.openai.conversations.create()

        response = self.openai.responses.create(
            input=input_text,
            conversation=conversation.id,
            extra_body={"agent_reference": {"name": self.agent.name, "type": "agent_reference"}},
        )

        self.openai.conversations.delete(conversation_id=conversation.id)
        return response.output_text

    def cleanup(self):
        """Delete the agent version and close connections."""
        if self.agent:
            self.client.agents.delete_version(
                agent_name=self.agent.name,
                agent_version=self.agent.version,
            )
        if self.client:
            self.client.close()


# =============================================================================
# Main — Test both agents
# =============================================================================

def main():
    if not PROJECT_CONNECTION_STRING:
        print("❌ PROJECT_CONNECTION_STRING not set. Run challenge 0 first!")
        sys.exit(1)

    print("=== Claims Triage Agent ===")
    print("Creating agent...")

    triage_agent = ClaimsTriageAgent()
    triage_agent.create()
    print(f"✅ Created: {triage_agent.agent.name} (version {triage_agent.agent.version})")

    print("\nAssessing all claims...")
    triage_result = triage_agent.run(
        "Assess all 5 claims (CLM-001, CLM-002, CLM-003, CLM-004, CLM-005) "
        "and report which ones have flags. For each flag, state the "
        "metric, its current value, the threshold it violates, and by how much."
    )
    print(triage_result)

    print("\n=== Claims Decision Agent ===")
    print("Creating agent...")

    decision_agent = ClaimsDecisionAgent()
    decision_agent.create()
    print(f"✅ Created: {decision_agent.agent.name} (version {decision_agent.agent.version})")

    print("\nDeciding on critical claim CLM-001...")
    decision_result = decision_agent.run(
        "Claim CLM-001 (auto collision, claimant Maria Torres) has these flags:\n"
        "- Damage vs estimate match: 52% (min threshold: 70%) — 25.7% below min\n"
        "- Fraud risk score: 82 (max threshold: 50) — 64.0% above max\n\n"
        "The claim has full documentation (police report, photos, repair estimate, medical report).\n"
        "Recommend an action and next steps."
    )
    print(decision_result)

    # Cleanup — comment out to keep agents visible in the Foundry portal
    # print("\nCleaning up agents...")
    # triage_agent.cleanup()
    # decision_agent.cleanup()
    # print("✅ Done!")


if __name__ == "__main__":
    main()
