"""
DB Agent - Handles all database operations with connection pooling.
"""

import asyncio
import json
from datetime import datetime
from typing import Dict, Any, List, Optional
from pathlib import Path

import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from .base import BaseAgent
from swarm.bus import MessageBus
from swarm.store import ResultStore
from swarm.pool import ConnectionPool
from swarm.context import QueryResult

# Root directory of the "artemis-debug-secure" skill under the user's home.
SKILL_DIR = Path.home().joinpath(".claude", "skills", "artemis-debug-secure")
# JSON file holding the query definitions, keyed by ticket type.
QUERIES_FILE = SKILL_DIR.joinpath("scripts", "queries.json")
# Output directory handed to the pool when taking screenshots.
SCREENSHOT_DIR = Path.home().joinpath(".playwright-mcp")


class DBAgent(BaseAgent):
    """
    Database Agent - Handles all Artemis database operations.

    Features:
    - Connection pooling through the injected ``ConnectionPool``
      (originally documented as 3 sessions; the size is owned by the pool)
    - Parallel query execution by database
    - Batch screenshot capture
    - Query result cache storage (``_cache`` / ``_cache_ttl``). NOTE: no
      method in this class reads or writes the cache yet; it is only
      cleared on stop.
    """

    # JavaScript evaluated on the connection's page around screenshots:
    # zoom out so more rows fit in the capture, then restore the zoom.
    _ZOOM_OUT_JS = 'document.body.style.zoom = "0.67"'
    _ZOOM_RESET_JS = 'document.body.style.zoom = "1"'
    # Pause (ms) after zooming out before the screenshot is taken.
    _ZOOM_SETTLE_MS = 200

    def __init__(self, bus: MessageBus, store: ResultStore, pool: ConnectionPool):
        """Create the agent, registered with the base class as ``db_agent``.

        Args:
            bus: Message bus forwarded to ``BaseAgent``.
            store: Result store forwarded to ``BaseAgent``.
            pool: Connection pool used for every query and screenshot.
        """
        super().__init__("db_agent", bus, store)
        self.pool = pool
        self._queries: Dict[str, Any] = {}
        self._cache: Dict[str, Any] = {}
        self._cache_ttl = 300  # 5 minutes

    async def _on_start(self):
        """Load query definitions from ``QUERIES_FILE`` when it exists.

        A missing file leaves ``_queries`` empty and logs a warning.
        Malformed JSON propagates as an exception from ``json.load``.
        """
        if QUERIES_FILE.exists():
            # Explicit encoding so the load does not depend on the locale.
            with open(QUERIES_FILE, encoding="utf-8") as f:
                self._queries = json.load(f)
        else:
            self.log(f"Queries file not found: {QUERIES_FILE}", level="warn")
        self.log(f"Loaded {len(self._queries)} query types")

    async def _on_stop(self):
        """Clear cache."""
        self._cache.clear()

    async def _execute_task(self, task: str, **kwargs) -> Any:
        """Dispatch a named database task to its handler.

        Args:
            task: One of ``get_customer_id``, ``execute_queries``,
                ``execute_query``, ``take_screenshot``, ``batch_screenshots``.
            **kwargs: Passed through unchanged to the handler.

        Returns:
            Whatever the selected handler returns.

        Raises:
            ValueError: If ``task`` is not a known task name.
        """
        tasks = {
            "get_customer_id": self._get_customer_id,
            "execute_queries": self._execute_queries,
            "execute_query": self._execute_query,
            "take_screenshot": self._take_screenshot,
            "batch_screenshots": self._batch_screenshots,
        }

        if task not in tasks:
            raise ValueError(f"Unknown task: {task}")

        return await tasks[task](**kwargs)

    async def _acquire_connection(self):
        """Acquire a pooled connection or fail loudly.

        Returns:
            The connection object handed out by the pool.

        Raises:
            RuntimeError: If the pool returns a falsy connection, so callers
                never operate on (or release) a missing connection.
        """
        conn = await self.pool.acquire()
        if not conn:
            raise RuntimeError("Failed to acquire connection from pool")
        return conn

    @staticmethod
    def _escape_sql_literal(value: Any) -> str:
        """Return ``value`` as text with single quotes doubled for SQL literals."""
        return str(value).replace("'", "''")

    @staticmethod
    def _extract_customer_id(row: Dict[str, Any]) -> Any:
        """Pick the customer id out of a result row.

        Column names are matched case-insensitively and with surrounding
        square brackets ignored (``CustomerId``, ``customerid``,
        ``[CustomerId]``, ...). If no such column holds a truthy value, the
        first value of the row is returned. An empty row yields ``None``.
        """
        if not row:
            return None
        for key, value in row.items():
            if str(key).strip("[]").lower() == "customerid" and value:
                return value
        # The query selects a single column, so fall back to the first value.
        return next(iter(row.values()))

    async def _get_customer_id(self, webid: str, username: str) -> Optional[str]:
        """Get CustomerId from username.

        Args:
            webid: Numeric WebId; anything that is not an integer is rejected.
            username: Username to look up; embedded as an escaped SQL literal.

        Returns:
            The customer id as a string, or ``None`` when the input is
            invalid, no connection is available, the connection is not
            logged in, the query fails, or no customer matches.
        """
        # WebId is interpolated unquoted, so only accept a plain integer.
        try:
            webid_value = int(str(webid).strip())
        except ValueError:
            self.log(f"Invalid WebId: {webid!r}", level="error")
            return None

        # Double single quotes so the username cannot break out of the literal.
        safe_username = self._escape_sql_literal(username)
        sql = f"""
        SELECT TOP 1 [CustomerId]
        FROM [Main].[dbo].[Customer] WITH(NOLOCK)
        WHERE [WebId] = {webid_value} AND [Username] = '{safe_username}'
        """

        self.log(f"Getting customer_id for {username} on WebId {webid}")

        conn = await self.pool.acquire()
        if not conn:
            self.log("Failed to acquire connection from pool", level="error")
            return None

        if not conn.logged_in:
            self.log("Connection is not logged in", level="error")
            await self.pool.release(conn)
            return None

        try:
            self.log("Executing customer lookup query...")
            result = await self.pool.execute_query(conn, sql, "Main")
            self.log(f"Query result: {result.get('count', 0)} rows, error: {result.get('error')}")

            if result.get('error'):
                self.log(f"Query error: {result['error']}", level="error")
                return None

            rows = result.get('rows')
            if not rows:
                self.log("No customer found", level="warn")
                return None

            first_row = rows[0]
            self.log(f"First row keys: {list(first_row.keys())}")
            self.log(f"First row: {first_row}")

            customer_id = self._extract_customer_id(first_row)
            if customer_id:
                self.log(f"Found customer_id: {customer_id}")
                return str(customer_id)

            self.log("Could not extract customer_id from row", level="warn")
            return None
        except Exception as e:
            # Best-effort lookup: report the failure and signal "not found".
            self.log(f"Exception in get_customer_id: {e}", level="error")
            return None
        finally:
            await self.pool.release(conn)

    async def _execute_queries(self, ticket_type: str, params: Dict[str, str]) -> List[QueryResult]:
        """
        Execute all queries for a ticket type in parallel.

        Queries are grouped by their ``database`` field; each group runs on
        its own pooled connection and the groups run concurrently. A group
        that raises is logged and its results are omitted.

        Args:
            ticket_type: Type of investigation (promotion, payment, etc.)
            params: Dict with webId, username, customerId

        Returns:
            List of QueryResult, ordered like the queries in the playbook.

        Raises:
            ValueError: If ``ticket_type`` has no loaded query definitions.
        """
        if ticket_type not in self._queries:
            raise ValueError(f"Unknown ticket type: {ticket_type}")

        playbook = self._queries[ticket_type]
        queries = playbook.get("queries", [])

        # Group by database for parallel execution
        by_database: Dict[str, List[Dict]] = {}
        for q in queries:
            by_database.setdefault(q["database"], []).append(q)

        # Execute each database group in parallel and wait for all of them
        results_by_db = await asyncio.gather(
            *(
                self._execute_db_group(db, db_queries, params)
                for db, db_queries in by_database.items()
            ),
            return_exceptions=True,
        )

        # Flatten
        all_results: List[QueryResult] = []
        for result in results_by_db:
            if isinstance(result, Exception):
                self.log(f"DB group error: {result}", level="error")
            else:
                all_results.extend(result)

        # Sort by original order: grouping by database reorders the queries,
        # so restore each result's position from the playbook by query id.
        position: Dict[Any, int] = {}
        for index, q in enumerate(queries):
            position.setdefault(q.get("id"), index)
        all_results.sort(key=lambda qr: position.get(qr.id, len(position)))

        return all_results

    async def _execute_db_group(self, database: str, queries: List[Dict],
                                 params: Dict[str, str]) -> List[QueryResult]:
        """Execute a group of queries for one database on a single connection.

        Args:
            database: Database name passed to the pool for every query.
            queries: Query definitions with ``id``, ``name`` and ``sql`` keys.
            params: Values substituted into each query's ``sql`` template.

        Returns:
            One QueryResult per query, keeping at most the first 10 rows.

        Raises:
            RuntimeError: If no connection could be acquired from the pool.
        """
        results = []
        conn = await self._acquire_connection()

        try:
            for q in queries:
                sql = self._fill_params(q["sql"], params)
                start = datetime.now()

                result = await self.pool.execute_query(conn, sql, database)
                duration = int((datetime.now() - start).total_seconds() * 1000)

                qr = QueryResult(
                    id=q["id"],
                    name=q["name"],
                    database=database,
                    sql=sql,
                    count=result['count'],
                    rows=result['rows'][:10],  # First 10 rows
                    error=result['error'],
                    duration_ms=duration
                )
                results.append(qr)

                self.log(f"  {q['name']}: {result['count']} rows ({duration}ms)")

        finally:
            await self.pool.release(conn)

        return results

    async def _execute_query(self, sql: str, database: str) -> QueryResult:
        """Execute a single query.

        Args:
            sql: SQL text sent to the pool as-is (no parameter filling).
            database: Database name passed to the pool.

        Returns:
            A QueryResult with id ``custom`` holding all returned rows.

        Raises:
            RuntimeError: If no connection could be acquired from the pool.
        """
        conn = await self._acquire_connection()
        try:
            start = datetime.now()
            result = await self.pool.execute_query(conn, sql, database)
            duration = int((datetime.now() - start).total_seconds() * 1000)

            return QueryResult(
                id="custom",
                name="Custom Query",
                database=database,
                sql=sql,
                count=result['count'],
                rows=result['rows'],
                error=result['error'],
                duration_ms=duration
            )
        finally:
            await self.pool.release(conn)

    async def _take_screenshot(self, filename: str) -> str:
        """Take a screenshot of the connection's page at reduced zoom.

        Args:
            filename: File name passed to the pool's ``take_screenshot``.

        Returns:
            The path returned by the pool's ``take_screenshot``.

        Raises:
            RuntimeError: If no connection could be acquired from the pool.
        """
        conn = await self._acquire_connection()
        try:
            try:
                # Zoom out for more rows
                await conn.page.evaluate(self._ZOOM_OUT_JS)
                await conn.page.wait_for_timeout(self._ZOOM_SETTLE_MS)

                return await self.pool.take_screenshot(conn, filename, str(SCREENSHOT_DIR))
            finally:
                # Reset zoom even on failure so the pooled connection is not
                # handed back in a zoomed-out state.
                await conn.page.evaluate(self._ZOOM_RESET_JS)
        finally:
            await self.pool.release(conn)

    async def _batch_screenshots(self, query_results: List[QueryResult],
                                  prefix: str) -> List[str]:
        """Take screenshots in batch for efficiency.

        One connection is used for the whole batch and the zoom is changed
        once at the start and restored once at the end. Each result's
        ``screenshot`` attribute is set to the path of its capture.

        Args:
            query_results: Results to capture; files are named
                ``{prefix}_{id}_{index}.png``.
            prefix: File name prefix for every screenshot.

        Returns:
            Screenshot paths in the order of ``query_results``; an empty
            list when there is nothing to capture.

        Raises:
            RuntimeError: If no connection could be acquired from the pool.
        """
        screenshots: List[str] = []
        if not query_results:
            # Nothing to capture: do not tie up a pooled connection.
            return screenshots

        conn = await self._acquire_connection()

        try:
            try:
                # Zoom once at start
                await conn.page.evaluate(self._ZOOM_OUT_JS)
                await conn.page.wait_for_timeout(self._ZOOM_SETTLE_MS)

                for i, qr in enumerate(query_results):
                    filename = f"{prefix}_{qr.id}_{i}.png"
                    path = await self.pool.take_screenshot(conn, filename, str(SCREENSHOT_DIR))
                    screenshots.append(path)
                    qr.screenshot = path
            finally:
                # Reset zoom once at end, also when a capture failed, so the
                # pooled connection is not handed back zoomed out.
                await conn.page.evaluate(self._ZOOM_RESET_JS)

        finally:
            await self.pool.release(conn)

        return screenshots

    def _fill_params(self, sql: str, params: Dict[str, str]) -> str:
        """Fill ``{key}`` placeholders in a SQL template.

        Every value is converted to text and has its single quotes doubled
        before substitution, so a value placed inside a quoted SQL literal
        cannot terminate that literal. Values without quotes are inserted
        unchanged.

        Args:
            sql: Template containing ``{key}`` placeholders.
            params: Mapping of placeholder name to value.

        Returns:
            The SQL text with every known placeholder replaced.
        """
        result = sql
        for key, val in params.items():
            result = result.replace(f"{{{key}}}", self._escape_sql_literal(val))
        return result

    def get_queries_for_type(self, ticket_type: str) -> List[Dict]:
        """Get query definitions for a ticket type.

        Returns:
            The ``queries`` list of the ticket type's playbook, or an empty
            list when the type is unknown or has no ``queries`` entry.
        """
        if ticket_type in self._queries:
            return self._queries[ticket_type].get("queries", [])
        return []
