#!/usr/bin/env python3
"""
Full-Featured Gemini Deep Research Agent

Production-ready implementation with:
- Multiple research modes (sync, streaming, multi-turn)
- File integration support
- Cost tracking
- Configurable timeouts
- Export to multiple formats

Install: pip install google-genai
Set: export GEMINI_API_KEY="your-api-key"
"""

import html
import json
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Generator, Optional

from google import genai


class OutputFormat(Enum):
    """Supported output formats for research reports.

    Each member's value is the file extension (without the leading dot)
    that the exporter applies to the output path.
    """
    MARKDOWN = "md"  # Markdown document with a metadata header
    JSON = "json"  # Machine-readable dump of the result fields
    TEXT = "txt"  # Raw report text only
    HTML = "html"  # Standalone HTML page


@dataclass
class ResearchConfig:
    """Configuration for the research agent.

    Values are validated on construction so that a bad setting fails
    immediately instead of causing a busy loop or an obscure error deep
    inside a long-running research call.

    Raises:
        ValueError: If ``poll_interval`` is not positive, or if any of
            ``max_wait_time``, ``max_reconnects``, ``reconnect_delay`` or
            ``input_cost_per_million`` is negative.
    """
    # Seconds to sleep between status polls in synchronous research.
    poll_interval: int = 10
    # Upper bound, in seconds, on how long synchronous research may poll.
    max_wait_time: int = 3600  # 60 minutes
    # Forwarded to the agent config of streaming requests.
    thinking_summaries: str = "auto"  # "auto" or "none"
    # Maximum number of times a dropped stream is resumed (0 disables resuming).
    max_reconnects: int = 5
    # Seconds to wait before each stream resume attempt.
    reconnect_delay: int = 2
    # When False, cost estimates are reported as 0.0.
    track_costs: bool = True
    # Price used to turn the estimated token count into a dollar estimate.
    input_cost_per_million: float = 2.0  # $2 per million input tokens

    def __post_init__(self) -> None:
        """Reject settings that would make polling or reconnecting misbehave."""
        # A zero interval would poll the API in a tight loop.
        if self.poll_interval <= 0:
            raise ValueError("poll_interval must be positive")
        for name in (
            "max_wait_time",
            "max_reconnects",
            "reconnect_delay",
            "input_cost_per_million",
        ):
            if getattr(self, name) < 0:
                raise ValueError(f"{name} must be non-negative")


def _current_timestamp() -> str:
    """Return the current local time as an ISO 8601 string."""
    return datetime.now().isoformat()


@dataclass
class ResearchResult:
    """Outcome of a single research query plus bookkeeping metadata.

    The first four fields are required; the rest default to "empty" values
    and are filled in by whoever builds the result.
    """
    # The question that was asked.
    query: str
    # Final report text produced for the query.
    report: str
    # Identifier of the interaction that produced the report.
    interaction_id: str
    # Wall-clock time the research took, in seconds.
    duration_seconds: float
    # Creation time of this record (naive local time, ISO 8601).
    timestamp: str = field(default_factory=_current_timestamp)
    # Per-instance list; a fresh one is created for every result.
    thinking_summaries: list = field(default_factory=list)
    # Rough token count for query + report.
    estimated_tokens: int = 0
    # Rough dollar cost derived from the token estimate.
    estimated_cost: float = 0.0


class GeminiDeepResearchAgent:
    """
    Production-ready Gemini Deep Research Agent.

    Supports synchronous research (create + poll), streaming with automatic
    stream resumption, multi-turn conversations that chain interactions, and
    file-enhanced research. Completed synchronous results are kept in
    ``research_history`` so session totals can be computed.
    """

    AGENT_ID = 'deep-research-pro-preview-12-2025'

    # Statuses after which polling can never reach "completed".
    _TERMINAL_FAILURE_STATUSES = ("failed", "cancelled")

    # Resource-name prefix expected for file search stores.
    _FILE_STORE_PREFIX = "fileSearchStores/"

    def __init__(
        self,
        api_key: Optional[str] = None,
        config: Optional[ResearchConfig] = None
    ):
        """
        Initialize the research agent.

        Args:
            api_key: Gemini API key (or set GEMINI_API_KEY env var)
            config: Research configuration options; defaults to
                ``ResearchConfig()`` when omitted
        """
        self.client = genai.Client(api_key=api_key or os.getenv("GEMINI_API_KEY"))
        self.config = config or ResearchConfig()
        # ID of the most recent interaction; used to chain follow-up queries.
        self.conversation_id: Optional[str] = None
        self.research_history: list[ResearchResult] = []

    def research(
        self,
        query: str,
        continue_conversation: bool = False,
        file_stores: Optional[list[str]] = None
    ) -> ResearchResult:
        """
        Execute synchronous research with polling.

        Args:
            query: Research question or topic
            continue_conversation: Build on previous research context
            file_stores: List of file store names for document context

        Returns:
            ResearchResult with report and metadata

        Raises:
            RuntimeError: If the interaction fails, is cancelled, or
                completes without a text report
            TimeoutError: If the research exceeds ``config.max_wait_time``
        """
        start_time = time.time()

        # Build request parameters
        params = self._build_request_params(query, continue_conversation, file_stores)

        # Create interaction
        interaction = self.client.interactions.create(**params)
        interaction_id = interaction.id

        # Poll for completion
        report = self._poll_for_completion(interaction)

        # Build result; the estimates cover the query and the report together.
        duration = time.time() - start_time
        combined_text = query + report
        result = ResearchResult(
            query=query,
            report=report,
            interaction_id=interaction_id,
            duration_seconds=duration,
            estimated_tokens=self._estimate_tokens(combined_text),
            estimated_cost=self._estimate_cost(combined_text)
        )

        # Track history
        self.conversation_id = interaction_id
        self.research_history.append(result)

        return result

    def research_stream(
        self,
        query: str,
        continue_conversation: bool = False
    ) -> Generator[str, None, None]:
        """
        Stream research results in real-time.

        If the stream drops before completion, it is resumed from the last
        seen event up to ``config.max_reconnects`` times. Stream errors are
        reported on stdout rather than raised, so a consumer always gets
        whatever text could be retrieved.

        Args:
            query: Research question
            continue_conversation: Continue from prior research

        Yields:
            Text chunks as they arrive
        """
        params = self._build_request_params(query, continue_conversation)
        params["stream"] = True
        params["agent_config"] = {
            "type": "deep-research",
            "thinking_summaries": self.config.thinking_summaries
        }

        # Mutable progress record shared with _process_stream.
        state = {
            "interaction_id": None,
            "last_event_id": None,
            "is_complete": False,
            "reconnect_count": 0
        }

        # Initial stream
        try:
            stream = self.client.interactions.create(**params)
            yield from self._process_stream(stream, state)
        except Exception as e:
            print(f"[Stream error: {e}]")

        # Reconnection loop: only possible once the interaction ID is known.
        while (not state["is_complete"]
               and state["interaction_id"]
               and state["reconnect_count"] < self.config.max_reconnects):

            state["reconnect_count"] += 1
            time.sleep(self.config.reconnect_delay)

            try:
                resume_params = {"id": state["interaction_id"], "stream": True}
                if state["last_event_id"]:
                    # Resume after the last event we saw to avoid duplicates.
                    resume_params["last_event_id"] = state["last_event_id"]

                resume_stream = self.client.interactions.get(**resume_params)
                yield from self._process_stream(resume_stream, state)
            except Exception as e:
                # Best effort: report the failure and try again until the
                # reconnect budget is used up.
                print(
                    f"[Stream reconnect {state['reconnect_count']}/"
                    f"{self.config.max_reconnects} failed: {e}]"
                )
                continue

        if state["interaction_id"]:
            self.conversation_id = state["interaction_id"]

    def multi_turn_research(self, queries: list[str]) -> list[ResearchResult]:
        """
        Conduct multi-turn research, building context across queries.

        The first query starts without prior context; every later query
        continues from the interaction before it.

        Args:
            queries: List of research questions

        Returns:
            List of ResearchResults for each query
        """
        results = []
        for i, query in enumerate(queries):
            print(f"[Research {i+1}/{len(queries)}: {query[:40]}...]")
            result = self.research(query, continue_conversation=(i > 0))
            results.append(result)
        return results

    def export_result(
        self,
        result: ResearchResult,
        output_path: str,
        format: OutputFormat = OutputFormat.MARKDOWN
    ) -> Path:
        """
        Export research result to file.

        The file is always written as UTF-8, matching the charset declared
        by the HTML export.

        Args:
            result: ResearchResult to export
            output_path: Output file path; its suffix is set to the
                format's extension (an existing suffix is replaced)
            format: Output format

        Returns:
            Path to created file
        """
        path = Path(output_path).with_suffix(f".{format.value}")

        if format == OutputFormat.MARKDOWN:
            content = self._format_markdown(result)
        elif format == OutputFormat.JSON:
            content = json.dumps({
                "query": result.query,
                "report": result.report,
                "interaction_id": result.interaction_id,
                "duration_seconds": result.duration_seconds,
                "timestamp": result.timestamp,
                "estimated_cost": result.estimated_cost
            }, indent=2)
        elif format == OutputFormat.HTML:
            content = self._format_html(result)
        else:
            # Plain text: the report only, no metadata.
            content = result.report

        # Explicit encoding: the platform default may be unable to encode
        # characters that appear in reports.
        path.write_text(content, encoding="utf-8")
        return path

    def get_total_cost(self) -> float:
        """Get total estimated cost of all research in this session."""
        return sum(r.estimated_cost for r in self.research_history)

    # Private methods

    def _build_request_params(
        self,
        query: str,
        continue_conversation: bool,
        file_stores: Optional[list[str]] = None
    ) -> dict:
        """Build request parameters for interaction creation.

        Args:
            query: Research question sent as the interaction input
            continue_conversation: Link to the previous interaction, if any
            file_stores: File store names, with or without the
                ``fileSearchStores/`` prefix

        Returns:
            Keyword arguments for ``client.interactions.create``
        """
        params = {
            "input": query,
            "agent": self.AGENT_ID,
            "background": True
        }

        if continue_conversation and self.conversation_id:
            params["previous_interaction_id"] = self.conversation_id

        if file_stores:
            prefix = self._FILE_STORE_PREFIX
            params["tools"] = [{
                "type": "file_search",
                "file_search_store_names": [
                    # Do not double-prefix names that are already qualified.
                    store if store.startswith(prefix) else f"{prefix}{store}"
                    for store in file_stores
                ]
            }]

        return params

    def _poll_for_completion(self, interaction) -> str:
        """Poll interaction until completion.

        The timeout is measured against a monotonic clock, so time spent
        inside the API calls counts toward ``config.max_wait_time``.

        Args:
            interaction: The interaction returned by the create call

        Returns:
            Text of the interaction's final output

        Raises:
            RuntimeError: If the interaction fails, is cancelled, or
                completes without a text output
            TimeoutError: If ``config.max_wait_time`` seconds elapse first
        """
        deadline = time.monotonic() + self.config.max_wait_time
        while time.monotonic() < deadline:
            interaction = self.client.interactions.get(interaction.id)
            status = interaction.status

            if status == "completed":
                return self._extract_report(interaction)
            if status in self._TERMINAL_FAILURE_STATUSES:
                error = getattr(interaction, 'error', None) or 'Unknown error'
                raise RuntimeError(f"Research {status}: {error}")

            # Never sleep past the deadline.
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(self.config.poll_interval, remaining))

        raise TimeoutError(
            f"Research exceeded {self.config.max_wait_time}s timeout"
        )

    @staticmethod
    def _extract_report(interaction) -> str:
        """Return the text of the last output of a completed interaction.

        Raises:
            RuntimeError: If there are no outputs or the last one has no text
        """
        outputs = getattr(interaction, 'outputs', None)
        if not outputs:
            raise RuntimeError("Research completed without any output")
        text = getattr(outputs[-1], 'text', None)
        if text is None:
            raise RuntimeError("Research completed but the final output has no text")
        return text

    def _process_stream(self, stream, state: dict) -> Generator[str, None, None]:
        """Process stream events and yield content.

        Updates ``state`` in place: records the interaction ID, the last
        event ID seen (for resuming), and whether a terminal event arrived.

        Args:
            stream: Iterable of stream events
            state: Progress record owned by the caller

        Yields:
            Text carried by content delta events
        """
        for event in stream:
            event_type = getattr(event, 'event_type', None)

            if event_type == "interaction.start":
                state["interaction_id"] = event.interaction.id

            event_id = getattr(event, 'event_id', None)
            if event_id:
                state["last_event_id"] = event_id

            if event_type == "content.delta":
                # Deltas without text (or with text set to None) are skipped
                # so consumers only ever receive strings.
                text = getattr(getattr(event, 'delta', None), 'text', None)
                if text is not None:
                    yield text

            if event_type in ('interaction.complete', 'error'):
                state["is_complete"] = True
                break

    def _estimate_tokens(self, text: str) -> int:
        """Rough token estimation (4 chars per token)."""
        return len(text) // 4

    def _estimate_cost(self, text: str) -> float:
        """Estimate cost based on token count.

        Returns 0.0 when cost tracking is disabled. All estimated tokens
        are priced at ``config.input_cost_per_million``.
        """
        if not self.config.track_costs:
            return 0.0
        tokens = self._estimate_tokens(text)
        return (tokens / 1_000_000) * self.config.input_cost_per_million

    def _format_markdown(self, result: ResearchResult) -> str:
        """Format result as markdown document."""
        return f"""# Research Report

**Query:** {result.query}

**Generated:** {result.timestamp}

**Duration:** {result.duration_seconds:.1f} seconds

**Estimated Cost:** ${result.estimated_cost:.4f}

---

{result.report}

---

*Generated by Gemini Deep Research Agent*
*Interaction ID: {result.interaction_id}*
"""

    def _format_html(self, result: ResearchResult) -> str:
        """Format result as HTML document.

        Query, timestamp and report are HTML-escaped before insertion so
        that characters such as ``<`` or ``&`` show up literally instead of
        being interpreted as markup. Newlines in the report become ``<br>``.
        """
        safe_query = html.escape(result.query)
        safe_timestamp = html.escape(result.timestamp)
        # Escape first, then add the line-break tags, so the tags survive.
        report_html = html.escape(result.report).replace("\n", "<br>")
        return f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>Research Report</title>
    <style>
        body {{ font-family: system-ui, sans-serif; max-width: 800px; margin: 40px auto; padding: 20px; }}
        .meta {{ color: #666; font-size: 0.9em; }}
        hr {{ margin: 20px 0; }}
    </style>
</head>
<body>
    <h1>Research Report</h1>
    <p class="meta"><strong>Query:</strong> {safe_query}</p>
    <p class="meta"><strong>Generated:</strong> {safe_timestamp}</p>
    <p class="meta"><strong>Duration:</strong> {result.duration_seconds:.1f}s</p>
    <hr>
    <div class="content">
        {report_html}
    </div>
    <hr>
    <p class="meta"><em>Generated by Gemini Deep Research Agent</em></p>
</body>
</html>
"""


def main():
    """Run the demo: one query with export, then a three-question session."""
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    # Agent with explicit (default-valued) settings.
    settings = ResearchConfig(
        poll_interval=10,
        max_wait_time=3600,
        thinking_summaries="auto",
        track_costs=True,
    )
    agent = GeminiDeepResearchAgent(config=settings)

    print("\n" + heavy_rule)
    print("GEMINI DEEP RESEARCH AGENT - FULL DEMO")
    print(heavy_rule)

    # Part 1: a single query.
    print("\n[1] Single Research Query")
    print(light_rule)

    question = (
        "What are the key differences between transformer and state-space models for language?"
    )
    single = agent.research(question)

    print(f"\nQuery: {single.query}")
    print(f"Duration: {single.duration_seconds:.1f}s")
    print(f"Estimated cost: ${single.estimated_cost:.4f}")
    print(f"\nReport preview:\n{single.report[:500]}...")

    # Save the report as Markdown.
    exported_to = agent.export_result(single, "research_report", OutputFormat.MARKDOWN)
    print(f"\nExported to: {exported_to}")

    # Part 2: follow-up questions that build on each other.
    print("\n" + light_rule)
    print("[2] Multi-Turn Research")
    print(light_rule)

    follow_ups = [
        "Explain the architecture of GPT-4",
        "What are the known limitations?",
        "How does it compare to Claude?",
    ]

    for turn in agent.multi_turn_research(follow_ups):
        print(f"\n- {turn.query[:40]}... ({turn.duration_seconds:.1f}s)")

    # Session totals.
    print("\n" + heavy_rule)
    print("SESSION SUMMARY")
    print(heavy_rule)
    print(f"Total queries: {len(agent.research_history)}")
    print(f"Total cost: ${agent.get_total_cost():.4f}")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
