"""
Coordinator Agent - Orchestrates the investigation swarm.
"""

import asyncio
from datetime import datetime
from typing import Dict, Any, List, Optional
from pathlib import Path

import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from .base import BaseAgent, TaskResult
from .db_agent import DBAgent
from .jira_agent import JiraAgent
from .analysis_agent import AnalysisAgent
from .learning_agent import LearningAgent
from swarm.bus import MessageBus
from swarm.store import ResultStore
from swarm.context import InvestigationContext, AnalysisResult
from swarm.pool import ConnectionPool


class CoordinatorAgent(BaseAgent):
    """
    Coordinator Agent - Orchestrates the multi-agent swarm.

    Owns four child agents (DB, Jira, Analysis, Learning) and drives a
    4-phase investigation workflow:
    1. Init: Parse ticket, load learning suggestions, get customer_id (parallel)
    2. Data: Execute queries + screenshots, build similar-ticket JQL (parallel)
    3. Analysis: Detect errors / determine RCA, find similar solutions (parallel)
    4. Output: Generate summary, format "To CAS" text, record learning (sequential)
    """

    def __init__(self, bus: MessageBus, store: ResultStore, pool: ConnectionPool):
        super().__init__("coordinator", bus, store)
        # Connection pool handed to the DB agent and reported by _get_status.
        self.pool = pool

        # Child agents; created in _on_start, so they are None until then.
        self.db_agent: Optional[DBAgent] = None
        self.jira_agent: Optional[JiraAgent] = None
        self.analysis_agent: Optional[AnalysisAgent] = None
        self.learning_agent: Optional[LearningAgent] = None

        # Investigation state
        self._start_time: Optional[datetime] = None
        # Values set in this class: idle, init, data, analysis, output, complete, error.
        self._phase = "idle"

    # ═══════════════════════════════════════════════════════════════
    # Lifecycle
    # ═══════════════════════════════════════════════════════════════

    def _child_agents(self) -> List[Optional[BaseAgent]]:
        """Return the child agents in a fixed order (entries may be None)."""
        return [self.db_agent, self.jira_agent, self.analysis_agent, self.learning_agent]

    async def _on_start(self):
        """Create the child agents and start them concurrently."""
        self.log("Starting child agents...")

        # Only the DB agent needs the connection pool.
        self.db_agent = DBAgent(self.bus, self.store, self.pool)
        self.jira_agent = JiraAgent(self.bus, self.store)
        self.analysis_agent = AnalysisAgent(self.bus, self.store)
        self.learning_agent = LearningAgent(self.bus, self.store)

        await asyncio.gather(
            self.db_agent.start(),
            self.jira_agent.start(),
            self.analysis_agent.start(),
            self.learning_agent.start()
        )

        self.log("All agents ready")

    async def _on_stop(self):
        """Stop every child agent that was created.

        All ``stop()`` coroutines are awaited to completion, even if some of
        them fail, so no agent is left mid-shutdown. Each failure is logged,
        and the first one is re-raised after every agent has been handled.
        """
        self.log("Stopping child agents...")

        agents = [agent for agent in self._child_agents() if agent]
        outcomes = await asyncio.gather(
            *(agent.stop() for agent in agents),
            return_exceptions=True
        )

        failures = [
            (agent, outcome)
            for agent, outcome in zip(agents, outcomes)
            if isinstance(outcome, Exception)
        ]
        for agent, err in failures:
            self.log(f"Failed to stop {type(agent).__name__}: {err}", level="error")
        if failures:
            raise failures[0][1]

        self.log("All agents stopped")

    async def _execute_task(self, task: str, **kwargs) -> Any:
        """Dispatch a coordinator task by name.

        Supported tasks: ``investigate`` (requires ``context``) and
        ``get_status`` (no arguments).

        Raises:
            ValueError: if ``task`` is not a known task name.
        """
        tasks = {
            "investigate": self._investigate,
            "get_status": self._get_status
        }

        if task not in tasks:
            raise ValueError(f"Unknown task: {task}")

        return await tasks[task](**kwargs)

    # ═══════════════════════════════════════════════════════════════
    # Helpers
    # ═══════════════════════════════════════════════════════════════

    def _enter_phase(self, phase: str, title: str) -> None:
        """Record ``phase`` as current and print its console banner."""
        self._phase = phase
        rule = "─" * 50
        print(f"\n{rule}\n{title}\n{rule}", flush=True)

    def _elapsed_ms(self) -> int:
        """Return whole milliseconds since the current investigation started."""
        return int((datetime.now() - self._start_time).total_seconds() * 1000)

    @staticmethod
    def _fallback_analysis() -> AnalysisResult:
        """Build the low-confidence result used when analysis cannot complete."""
        return AnalysisResult(
            findings=["Analysis failed"],
            errors=[],
            root_cause="Unable to determine",
            conclusion="Manual review required",
            confidence="low"
        )

    # ═══════════════════════════════════════════════════════════════
    # Investigation
    # ═══════════════════════════════════════════════════════════════

    async def _investigate(self, context: InvestigationContext) -> Dict[str, Any]:
        """
        Run full investigation workflow.

        Args:
            context: Investigation context with ticket info

        Returns:
            On success: dict with ``success=True``, summary, to_cas, root_cause,
            conclusion, confidence, findings, errors, similar_tickets,
            queries_executed, screenshots and duration_ms.
            On failure: dict with ``success=False``, ``error``, ``phase`` (the
            phase that was running when the failure occurred) and duration_ms.
        """
        self._start_time = datetime.now()
        self._phase = "init"

        # Every child agent works against the same investigation context.
        for agent in self._child_agents():
            agent.set_context(context)

        try:
            # ═══════════════════════════════════════════════════════════════
            # PHASE 1: INIT (Parallel)
            # ═══════════════════════════════════════════════════════════════
            self._enter_phase("init", "📋 PHASE 1: INIT - Getting customer info...")

            jira_result, learning_result, db_result = await asyncio.gather(
                self._init_jira(context),
                self._init_learning(context),
                self._init_db(context),
                return_exceptions=True
            )

            # Jira/learning failures are only warnings; a DB failure aborts.
            if isinstance(jira_result, Exception):
                self.log(f"Jira init error: {jira_result}", level="warn")
            if isinstance(learning_result, Exception):
                self.log(f"Learning init error: {learning_result}", level="warn")
            if isinstance(db_result, Exception):
                self.log(f"DB init error: {db_result}", level="error")
                raise db_result

            # ═══════════════════════════════════════════════════════════════
            # PHASE 2: DATA (Parallel)
            # ═══════════════════════════════════════════════════════════════
            self._enter_phase("data", "📊 PHASE 2: DATA - Running queries...")

            db_data, similar_tickets = await asyncio.gather(
                self._collect_db_data(context),
                self._search_similar_tickets(context),
                return_exceptions=True
            )

            # Query results are stored on the context by _collect_db_data, so
            # db_data is only inspected here for failure reporting.
            if isinstance(db_data, Exception):
                self.log(f"DB data error: {db_data}", level="error")
            if isinstance(similar_tickets, Exception):
                self.log(f"Similar ticket search error: {similar_tickets}", level="warn")
            if not isinstance(similar_tickets, list):
                similar_tickets = []

            await context.set_similar_tickets(similar_tickets)

            # ═══════════════════════════════════════════════════════════════
            # PHASE 3: ANALYSIS (Parallel)
            # ═══════════════════════════════════════════════════════════════
            self._enter_phase("analysis", "🔍 PHASE 3: ANALYSIS - Detecting errors...")

            analysis, similar_solutions = await asyncio.gather(
                self._analyze_results(context),
                self._find_similar_solutions(context),
                return_exceptions=True
            )

            if isinstance(analysis, Exception):
                self.log(f"Analysis error: {analysis}", level="error")
                analysis = self._fallback_analysis()
            # NOTE(review): similar_solutions is fetched but not yet used in the result.
            if isinstance(similar_solutions, Exception):
                self.log(f"Similar solutions lookup error: {similar_solutions}", level="warn")

            await context.set_analysis(analysis)

            # ═══════════════════════════════════════════════════════════════
            # PHASE 4: OUTPUT (Sequential)
            # ═══════════════════════════════════════════════════════════════
            self._enter_phase("output", "📝 PHASE 4: OUTPUT - Generating results...")

            summary = await self.analysis_agent.execute(
                "generate_summary",
                context=context.to_dict(),
                analysis=analysis
            )

            to_cas = await self.jira_agent.execute(
                "format_to_cas",
                root_cause=analysis.root_cause,
                conclusion=analysis.conclusion,
                details=""
            )

            # Only record a pattern when the analysis reached a usable conclusion.
            if analysis.confidence != "low":
                await self.learning_agent.execute(
                    "learn_pattern",
                    ticket_type=context.ticket_type,
                    error_pattern=analysis.errors[0].get("type", "") if analysis.errors else "",
                    root_cause=analysis.root_cause,
                    solution=analysis.conclusion,
                    queries_used=[qr.name for qr in context.query_results],
                    webid=context.webid
                )

            duration_ms = self._elapsed_ms()

            await self.learning_agent.execute(
                "update_stats",
                ticket_type=context.ticket_type,
                duration_ms=duration_ms,
                success=analysis.confidence != "low",
                webid=context.webid
            )

            self._phase = "complete"
            self.log(f"Investigation complete in {duration_ms}ms")

            return {
                "success": True,
                "summary": summary.data if summary.success else "Summary generation failed",
                "to_cas": to_cas.data if to_cas.success else analysis.conclusion,
                "root_cause": analysis.root_cause,
                "conclusion": analysis.conclusion,
                "confidence": analysis.confidence,
                "findings": analysis.findings,
                "errors": analysis.errors,
                "similar_tickets": similar_tickets,
                "queries_executed": len(context.query_results),
                "screenshots": [qr.screenshot for qr in context.query_results if qr.screenshot],
                "duration_ms": duration_ms
            }

        except Exception as e:
            # Capture the phase that failed *before* switching state to
            # "error"; otherwise the returned "phase" would always be "error".
            failed_phase = self._phase
            self._phase = "error"
            self.log(f"Investigation failed: {e}", level="error")
            return {
                "success": False,
                "error": str(e),
                "phase": failed_phase,
                "duration_ms": self._elapsed_ms()
            }

    # ═══════════════════════════════════════════════════════════════
    # Phase 1 Tasks
    # ═══════════════════════════════════════════════════════════════

    async def _init_jira(self, context: InvestigationContext) -> Dict:
        """Parse the ticket text via the Jira agent when the context carries it.

        Returns the parsed data, or ``{}`` when there is no ticket key, no
        ``ticket_summary`` attribute, or the Jira agent reports failure.
        """
        if not (context.ticket_key and hasattr(context, 'ticket_summary')):
            return {}

        result = await self.jira_agent.execute(
            "parse_ticket",
            summary=context.ticket_summary or "",
            # Read defensively: only ticket_summary's presence was checked above.
            description=getattr(context, 'ticket_description', None) or ""
        )
        return result.data if result.success else {}

    async def _init_learning(self, context: InvestigationContext) -> List[Dict]:
        """Fetch suggestions for this ticket type from the learning agent ([] on failure)."""
        result = await self.learning_agent.execute(
            "get_suggestions",
            ticket_type=context.ticket_type,
            findings=[]
        )
        return result.data if result.success else []

    async def _init_db(self, context: InvestigationContext) -> Optional[str]:
        """Look up the customer ID and store it on the context.

        Returns the customer ID, or None when username/webid are missing or
        the lookup yields nothing.
        """
        if context.username and context.webid:
            result = await self.db_agent.execute(
                "get_customer_id",
                webid=context.webid,
                username=context.username
            )
            if result.success and result.data:
                await context.update(customer_id=result.data)
                return result.data
        return None

    # ═══════════════════════════════════════════════════════════════
    # Phase 2 Tasks
    # ═══════════════════════════════════════════════════════════════

    async def _collect_db_data(self, context: InvestigationContext) -> List:
        """Execute all queries for the ticket type and request screenshots.

        Each query result is added to the context. Returns the list of query
        results, or ``[]`` when there is no customer_id or the queries fail.
        """
        if not context.customer_id:
            self.log("No customer_id - skipping queries", level="warn")
            return []

        params = {
            "webId": context.webid,
            "username": context.username,
            "customerId": context.customer_id
        }

        result = await self.db_agent.execute(
            "execute_queries",
            ticket_type=context.ticket_type,
            params=params
        )

        if not result.success:
            return []

        # A successful result may still carry no data; treat that as empty.
        query_results = result.data or []
        for qr in query_results:
            await context.add_query_result(qr)

        if query_results:
            # NOTE(review): presumably populates qr.screenshot, which
            # _investigate reads — confirm against DBAgent.
            shots = await self.db_agent.execute(
                "batch_screenshots",
                query_results=query_results,
                prefix=context.ticket_key or "investigation"
            )
            if not shots.success:
                self.log("Screenshot capture failed", level="warn")

        return query_results

    async def _search_similar_tickets(self, context: InvestigationContext) -> List[Dict]:
        """Build a JQL query for similar resolved ("Done") tickets.

        The search itself is not executed here: the returned single-element
        list carries the JQL plus a note naming the MCP tool to run it with.
        Returns ``[]`` if JQL construction fails.
        """
        jql_result = await self.jira_agent.execute(
            "build_similar_jql",
            ticket_type=context.ticket_type,
            status="Done"
        )

        if jql_result.success:
            return [{
                "jql": jql_result.data,
                "note": "Execute via mcp__jira__jira_search_issues"
            }]

        return []

    # ═══════════════════════════════════════════════════════════════
    # Phase 3 Tasks
    # ═══════════════════════════════════════════════════════════════

    async def _analyze_results(self, context: InvestigationContext) -> AnalysisResult:
        """Analyze all query results; return a low-confidence fallback on failure."""
        result = await self.analysis_agent.execute(
            "analyze_results",
            query_results=context.query_results,
            ticket_type=context.ticket_type
        )

        if result.success:
            return result.data

        return self._fallback_analysis()

    async def _find_similar_solutions(self, context: InvestigationContext) -> List[Dict]:
        """Ask the learning agent for past solutions to similar tickets ([] on failure)."""
        result = await self.learning_agent.execute(
            "find_similar",
            ticket_type=context.ticket_type,
            error_pattern="",
            webid=context.webid
        )

        return result.data if result.success else []

    # ═══════════════════════════════════════════════════════════════
    # Status
    # ═══════════════════════════════════════════════════════════════

    async def _get_status(self) -> Dict[str, Any]:
        """Return coordinator, child-agent and pool stats plus the current phase.

        Entries for agents (or the pool) that do not exist yet are None.
        """
        return {
            "coordinator": self.get_stats(),
            "phase": self._phase,
            "agents": {
                "db_agent": self.db_agent.get_stats() if self.db_agent else None,
                "jira_agent": self.jira_agent.get_stats() if self.jira_agent else None,
                "analysis_agent": self.analysis_agent.get_stats() if self.analysis_agent else None,
                "learning_agent": self.learning_agent.get_stats() if self.learning_agent else None,
            },
            "pool": self.pool.get_stats() if self.pool else None
        }
