#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "httpx",
#     "click",
#     "rich",
#     "azure-identity",
# ]
# ///
"""
Unified User Invitation Script

Invite external guest users to Azure AD tenant with optional:
- AD group assignments
- OSDU preshipping provisioning
- Copy setup from existing user (--like)

Usage:
    uv run invite.py invite --email user@company.com
    uv run invite.py invite --email user@company.com --groups "Group1,Group2"
    uv run invite.py invite --email user@company.com --groups "Group1" --preshipping
    uv run invite.py invite --email user@company.com --like existing@company.com
    uv run invite.py check
    uv run invite.py --help

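Authentication:
    Graph API: token from `az account get-access-token` (requires `az login`),
    falling back to DefaultAzureCredential. AD group lookups and membership
    changes call the az CLI directly.
    OSDU preshipping: AI_OSDU_* environment variables (HOST, DATA_PARTITION,
    CLIENT, SECRET, TENANT_ID).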
"""

import json
import os
import re
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import click
import httpx
from azure.identity import DefaultAzureCredential
from rich.console import Console
from rich.table import Table

# Rich console bound to stderr so human-readable progress output never mixes
# with the JSON document that --json mode prints to stdout via print().
console = Console(stderr=True)

# Microsoft Graph: REST base URL and the OAuth scope requested when minting tokens.
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
GRAPH_SCOPE = "https://graph.microsoft.com/.default"
# Where invited guests land after redeeming their invitation.
DEFAULT_REDIRECT_URL = "https://myapps.microsoft.com"

# Sibling preshipping tool, located from this file: go up to the grandparent of
# this script's directory, then into osdu-preshipping/scripts/.
PRESHIPPING_SCRIPT = Path(__file__).parent.parent.parent.joinpath("osdu-preshipping", "scripts", "preshipping.py")


# =============================================================================
# Validation
# =============================================================================

def is_valid_email(email: str) -> bool:
    """Return True if *email* looks like ``local@domain.tld``.

    This is a pragmatic ASCII-only format check, not full RFC 5322 validation:
    the local part allows letters, digits and ``._%+-``; the domain must end in
    a dot followed by at least two letters.
    """
    # fullmatch (rather than match + "$") so the entire string must match:
    # "$" also matches just before a trailing newline, which let values like
    # "user@company.com\n" through and on to the Graph API.
    pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
    return re.fullmatch(pattern, email) is not None


# =============================================================================
# Azure AD / Graph API Functions
# =============================================================================

def get_graph_token() -> str:
    """Return a Microsoft Graph access token.

    First asks the az CLI (``az account get-access-token``), which reuses an
    existing ``az login`` session. If az is unavailable, times out, fails, or
    prints nothing, falls back to ``DefaultAzureCredential`` for GRAPH_SCOPE.

    Raises:
        Whatever ``DefaultAzureCredential.get_token`` raises when no credential
        in its chain can authenticate (propagated unchanged).
    """
    az_cmd = [
        "az", "account", "get-access-token",
        "--resource", "https://graph.microsoft.com",
        "--query", "accessToken",
        "-o", "tsv",
    ]
    try:
        result = subprocess.run(az_cmd, capture_output=True, text=True, timeout=30)
    except (subprocess.TimeoutExpired, OSError):
        # OSError covers az missing (FileNotFoundError) as well as not being
        # executable (PermissionError); any of these means "use the SDK instead".
        result = None

    if result is not None and result.returncode == 0:
        token = result.stdout.strip()
        if token:
            return token

    # Use the credential as a context manager so its HTTP transports are closed.
    with DefaultAzureCredential() as credential:
        return credential.get_token(GRAPH_SCOPE).token


def check_user_exists(token: str, email: str) -> dict | None:
    """Return the first tenant user whose ``mail`` or ``otherMails`` matches *email*.

    Args:
        token: Graph API bearer token.
        email: Address to look up.

    Returns:
        A dict with ``id``, ``displayName``, ``mail`` and ``userType`` for the
        first match, or None when there is no match. This is a best-effort
        lookup: a non-200 response is also reported as None, so the caller goes
        on to send an invitation (which reports its own errors).
    """
    url = f"{GRAPH_API_BASE}/users"
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
    # OData string literals escape an embedded single quote by doubling it;
    # without this, an apostrophe would break (or alter) the filter expression.
    literal = email.replace("'", "''")
    params = {
        "$filter": f"mail eq '{literal}' or otherMails/any(x:x eq '{literal}')",
        "$select": "id,displayName,mail,userType",
    }

    with httpx.Client(timeout=30.0) as client:
        response = client.get(url, headers=headers, params=params)

    if response.status_code != 200:
        return None
    users = response.json().get("value") or []
    return users[0] if users else None


def invite_guest_user(
    token: str,
    email: str,
    redirect_url: str = DEFAULT_REDIRECT_URL,
    send_invitation: bool = True,
) -> dict:
    """Invite *email* as a guest user via the Graph ``/invitations`` endpoint.

    Args:
        token: Graph API bearer token.
        email: Address to invite.
        redirect_url: Where the guest is sent after redeeming the invitation.
        send_invitation: Whether Graph should email the invitation.

    Returns:
        ``{"success": True, "status": ..., "user_id": ...}`` on HTTP 201, otherwise
        ``{"success": False, "error": "permission_denied" | "api_error", "message": ...}``.
    """
    url = f"{GRAPH_API_BASE}/invitations"
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
    body = {
        "invitedUserEmailAddress": email,
        "inviteRedirectUrl": redirect_url,
        "sendInvitationMessage": send_invitation,
    }

    with httpx.Client(timeout=30.0) as client:
        response = client.post(url, headers=headers, json=body)

    if response.status_code == 201:
        data = response.json()
        # "invitedUser" may be absent or null; treat both as "no user id".
        invited_user = data.get("invitedUser") or {}
        return {
            "success": True,
            "status": data.get("status"),
            "user_id": invited_user.get("id"),
        }

    if response.status_code == 403:
        return {
            "success": False,
            "error": "permission_denied",
            "message": "Insufficient permissions. Requires User.Invite.All and Guest Inviter role.",
        }

    # Graph errors are normally {"error": {"message": ...}}, but an intermediary
    # can return a non-JSON (or differently shaped) body; fall back to the raw
    # text instead of raising while building the error result.
    try:
        error_msg = response.json().get("error", {}).get("message", response.text)
    except (ValueError, AttributeError):
        error_msg = response.text
    if not error_msg:
        error_msg = f"HTTP {response.status_code}"
    return {"success": False, "error": "api_error", "message": error_msg}


def get_user_groups(email: str) -> list[str]:
    """Return display names of the AD groups *email* belongs to, via the az CLI.

    Used by ``--like`` to copy an existing user's setup. The user is resolved
    by ``az ad user show --id`` first, then by a ``mail`` filter. Best-effort:
    az missing or timing out, not logged in, user not found, or unparsable
    output all yield an empty list rather than an exception.

    NOTE(review): ``az ad user get-member-groups`` may include groups reached
    through nested membership, not only direct memberships — confirm that is
    the intended set to copy.
    """
    try:
        # First try direct lookup by ID (works for member users).
        result = subprocess.run(
            ["az", "ad", "user", "show", "--id", email, "--query", "id", "-o", "tsv"],
            capture_output=True,
            text=True,
            timeout=30,
        )
        user_id = result.stdout.strip() if result.returncode == 0 else None

        # Fallback: search by mail property (works for guest users).
        if not user_id:
            # --like input is not validated; OData escapes a single quote by
            # doubling it, so an apostrophe cannot break the filter expression.
            odata_email = email.replace("'", "''")
            result = subprocess.run(
                ["az", "ad", "user", "list", "--filter", f"mail eq '{odata_email}'", "--query", "[0].id", "-o", "tsv"],
                capture_output=True,
                text=True,
                timeout=30,
            )
            user_id = result.stdout.strip() if result.returncode == 0 else None

        if not user_id:
            return []

        # Get group memberships (returns objects with displayName).
        result = subprocess.run(
            ["az", "ad", "user", "get-member-groups", "--id", user_id, "-o", "json"],
            capture_output=True,
            text=True,
            timeout=60,
        )
        if result.returncode != 0:
            return []

        groups = json.loads(result.stdout)
    except (subprocess.SubprocessError, OSError, ValueError):
        # SubprocessError: timeout; OSError: az missing/not executable;
        # ValueError: non-JSON or undecodable output. All mean "no groups".
        return []

    if not isinstance(groups, list):
        return []
    # Keep only well-formed entries that carry a display name.
    return [g["displayName"] for g in groups if isinstance(g, dict) and g.get("displayName")]


def get_group_id(group_name: str) -> str | None:
    """Get AD group object ID by name."""
    try:
        result = subprocess.run(
            ["az", "ad", "group", "show", "--group", group_name, "--query", "id", "-o", "tsv"],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0:
            return result.stdout.strip()
    except Exception:
        pass
    return None


def add_user_to_group(user_id: str, group_id: str, group_name: str) -> dict:
    """Add a user (by object ID) to an AD group (by object ID) via ``az ad group member add``.

    Returns a result dict whose ``group`` key is the human-readable *group_name*:
    ``success`` is True when az succeeds or reports the membership already
    exists (the latter also sets ``skipped`` and ``message``); otherwise
    ``success`` is False and ``message`` holds az's stderr or the exception text.
    """
    command = ["az", "ad", "group", "member", "add", "--group", group_id, "--member-id", user_id]
    try:
        proc = subprocess.run(command, capture_output=True, text=True, timeout=30)
    except Exception as exc:
        return {"success": False, "group": group_name, "message": str(exc)}

    if proc.returncode == 0:
        return {"success": True, "group": group_name}

    # An existing membership surfaces as an az error mentioning "already exist";
    # treat it as a no-op success rather than a failure.
    if "already exist" in proc.stderr.lower():
        return {"success": True, "group": group_name, "skipped": True, "message": "Already a member"}

    return {"success": False, "group": group_name, "message": proc.stderr.strip()}


def add_user_to_groups_parallel(user_id: str, groups: list[str], max_workers: int = 5) -> list[dict]:
    """Add a user to several AD groups concurrently.

    Both steps run in a thread pool of *max_workers*: resolving each group name
    to an object ID (``get_group_id``), then the membership adds
    (``add_user_to_group``). Duplicate names are collapsed.

    Args:
        user_id: Object ID of the user to add.
        groups: AD group display names.
        max_workers: Thread-pool size.

    Returns:
        One result dict per distinct group name, in the order the names first
        appear in *groups*. A name that cannot be resolved yields
        ``{"success": False, "group": name, "message": "Group not found"}``.
    """
    # dict.fromkeys dedupes while keeping first-seen order.
    unique_groups = list(dict.fromkeys(groups))
    if not unique_groups:
        return []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Each lookup is a separate az call with its own timeout, so resolve them
        # concurrently instead of one after another. map() preserves input order.
        resolved_ids = list(executor.map(get_group_id, unique_groups))

        pending = {
            name: executor.submit(add_user_to_group, user_id, group_id, name)
            for name, group_id in zip(unique_groups, resolved_ids)
            if group_id
        }

        # Emit results in input order (not completion order) so output is stable.
        results = []
        for name in unique_groups:
            if name in pending:
                results.append(pending[name].result())
            else:
                results.append({"success": False, "group": name, "message": "Group not found"})

    return results


# =============================================================================
# OSDU Preshipping Functions
# =============================================================================

def check_osdu_configured() -> bool:
    """Report whether all AI_OSDU_* variables needed for preshipping are set.

    A variable that is present but empty counts as missing.
    """
    needed = ("AI_OSDU_HOST", "AI_OSDU_DATA_PARTITION", "AI_OSDU_CLIENT", "AI_OSDU_SECRET", "AI_OSDU_TENANT_ID")
    for name in needed:
        if not os.environ.get(name):
            return False
    return True


def add_to_preshipping(email: str, level: str = "viewer", dry_run: bool = False) -> dict:
    """Add user to OSDU preshipping groups by calling preshipping.py.

    Runs ``uv run preshipping.py add --user EMAIL --level LEVEL --json`` (plus
    ``--dry-run`` when requested) and returns its parsed JSON output.

    Args:
        email: User's email address
        level: Access level - 'viewer' (read-only) or 'admin' (Admin/Ops)
        dry_run: Preview without making changes

    Returns:
        The script's JSON object on success. Otherwise a dict with
        ``success: False``, an ``error`` code (``script_not_found``,
        ``not_configured``, ``script_error``, ``parse_error`` or ``exception``)
        and a ``message``.
    """
    if not PRESHIPPING_SCRIPT.exists():
        return {"success": False, "error": "script_not_found", "message": f"Preshipping script not found: {PRESHIPPING_SCRIPT}"}

    if not check_osdu_configured():
        return {"success": False, "error": "not_configured", "message": "OSDU environment variables not set"}

    cmd = ["uv", "run", str(PRESHIPPING_SCRIPT), "add", "--user", email, "--level", level, "--json"]
    if dry_run:
        cmd.append("--dry-run")

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
    except Exception as e:
        # Best-effort boundary: uv missing, timeout, undecodable output, etc.
        return {"success": False, "error": "exception", "message": str(e)}

    if result.returncode != 0:
        # With --json the script may report its error on stdout and leave stderr
        # empty; fall back to stdout so the caller never shows a blank message.
        message = result.stderr.strip() or result.stdout.strip()
        return {"success": False, "error": "script_error", "message": message}

    try:
        parsed = json.loads(result.stdout)
    except json.JSONDecodeError:
        return {"success": False, "error": "parse_error", "message": result.stdout}

    # Callers read this with dict.get(); anything other than a JSON object is unusable.
    if not isinstance(parsed, dict):
        return {"success": False, "error": "parse_error", "message": result.stdout}
    return parsed


def get_user_osdu_status(email: str) -> dict | None:
    """Check if user has OSDU preshipping access (for --like functionality)."""
    # This would require querying OSDU API - for now, return None
    # The --like flag primarily copies AD groups
    return None


# =============================================================================
# CLI Commands
# =============================================================================

@click.group(invoke_without_command=True)
@click.pass_context
def cli(ctx):
    """Unified User Invitation - Invite guests with AD groups and OSDU access."""
    # invoke_without_command=True makes this callback run even when no
    # subcommand is given; in that case print the help text instead of
    # silently doing nothing.
    if ctx.invoked_subcommand is not None:
        return
    click.echo(ctx.get_help())


def _overall_success(results: dict) -> bool:
    """Return True only if every requested step of ``invite`` succeeded.

    The tenant entry must report success, every AD group entry must carry a
    truthy ``success`` (an empty group list counts as success), and
    preshipping must either not have been requested (None) or report success.
    """
    tenant = results.get("tenant") or {}
    preshipping = results.get("preshipping")
    return (
        bool(tenant.get("success"))
        and all(entry.get("success") for entry in results.get("groups") or [])
        and (preshipping is None or bool(preshipping.get("success")))
    )


@cli.command()
@click.option("--email", required=True, help="Email address to invite")
@click.option("--groups", default="", help="Comma-separated AD group names")
@click.option("--preshipping", is_flag=True, help="Also provision OSDU preshipping access")
@click.option("--level", type=click.Choice(["viewer", "admin"]), default="viewer", help="OSDU access level: viewer (read-only) or admin (Admin/Ops)")
@click.option("--like", "like_user", default=None, help="Copy setup from existing user")
@click.option("--dry-run", is_flag=True, help="Preview without making changes")
@click.option("--json", "output_json", is_flag=True, help="Output as JSON")
def invite(email: str, groups: str, preshipping: bool, level: str, like_user: str | None, dry_run: bool, output_json: bool):
    """Invite a guest user with optional groups and preshipping access."""
    # Runs up to three steps -- tenant invitation, AD group membership, OSDU
    # preshipping -- and records a per-step result. Progress goes to the stderr
    # console; with --json one JSON document is printed to stdout instead.
    # Exits with status 1 on an invalid email or when any requested step fails.

    # Validate email
    if not is_valid_email(email):
        result = {"success": False, "error": "invalid_email", "message": f"Invalid email format: {email}"}
        if output_json:
            print(json.dumps(result))
        else:
            console.print(f"[red]✗[/red] Invalid email format: {email}")
        sys.exit(1)

    # Parse groups (dict.fromkeys drops repeated names but keeps their order)
    group_list = list(dict.fromkeys(g.strip() for g in groups.split(",") if g.strip())) if groups else []

    # No auto AD group for preshipping - use --like or --groups to specify AD groups

    # Handle --like flag: append the template user's groups after explicit ones, skipping duplicates
    if like_user:
        if not output_json:
            console.print(f"\n[bold blue]Copying setup from:[/bold blue] {like_user}")
        like_groups = get_user_groups(like_user)
        if like_groups:
            for g in like_groups:
                if g not in group_list:
                    group_list.append(g)
            if not output_json:
                console.print(f"  Found {len(like_groups)} groups: {', '.join(like_groups)}")
        elif not output_json:
            console.print(f"  [yellow]⚠[/yellow] No groups found for {like_user}")

    # Show plan
    if not output_json:
        console.print("\n[bold blue]Invite Plan[/bold blue]")
        console.print(f"Email: {email}")
        if group_list:
            console.print(f"AD Groups: {', '.join(group_list)}")
        if preshipping:
            console.print(f"OSDU Preshipping: Yes (level: {level})")
        if dry_run:
            console.print("[yellow]DRY RUN - No changes will be made[/yellow]")
        console.print()

    results = {
        "email": email,
        "tenant": None,
        "groups": [],
        "preshipping": None,
    }

    # Step 1: Tenant invitation
    if not output_json:
        console.print("[bold]Step 1: Tenant Invitation[/bold]")

    user_id = None
    try:
        token = get_graph_token()
        existing_user = check_user_exists(token, email)

        if existing_user:
            user_id = existing_user.get("id")
            results["tenant"] = {
                "success": True,
                "existing": True,
                "user_id": user_id,
                "display_name": existing_user.get("displayName"),
                "user_type": existing_user.get("userType"),
            }
            if not output_json:
                console.print(f"  [yellow]⚠[/yellow] User already exists: {existing_user.get('displayName')} ({existing_user.get('userType')})")
                console.print(f"      User ID: {user_id}")
        elif dry_run:
            results["tenant"] = {"success": True, "dry_run": True, "message": "Would send invitation"}
            if not output_json:
                console.print("  [dim]Would send invitation[/dim]")
        else:
            invite_result = invite_guest_user(token, email)
            results["tenant"] = invite_result
            if invite_result["success"]:
                user_id = invite_result.get("user_id")
                if not output_json:
                    console.print(f"  [green]✓[/green] Invitation sent (Status: {invite_result.get('status')})")
                    console.print(f"      User ID: {user_id}")
            elif not output_json:
                console.print(f"  [red]✗[/red] {invite_result.get('message')}")

    except Exception as e:
        results["tenant"] = {"success": False, "error": "auth_failed", "message": str(e)}
        user_id = None
        if not output_json:
            console.print(f"  [red]✗[/red] Authentication failed: {e}")

    tenant_ok = bool(results["tenant"].get("success"))

    # Step 2: AD Group membership
    if group_list:
        if not output_json:
            console.print("\n[bold]Step 2: AD Group Membership[/bold]")

        if dry_run and tenant_ok:
            # A dry run never creates the user, so a new invitee has no user_id
            # yet; still preview the adds, and mark them successful like the
            # tenant dry-run entry so the summary does not report "Partial".
            results["groups"] = [{"group": g, "success": True, "dry_run": True} for g in group_list]
            if not output_json:
                for g in group_list:
                    console.print(f"  [dim]Would add to {g}[/dim]")
        elif user_id:
            group_results = add_user_to_groups_parallel(user_id, group_list)
            results["groups"] = group_results
            if not output_json:
                for gr in group_results:
                    if gr.get("success"):
                        if gr.get("skipped"):
                            console.print(f"  [dim]-[/dim] {gr['group']}: Already a member")
                        else:
                            console.print(f"  [green]✓[/green] Added to {gr['group']}")
                    else:
                        console.print(f"  [red]✗[/red] {gr['group']}: {gr.get('message')}")
        else:
            if not output_json:
                console.print("  [yellow]⚠[/yellow] Skipped - no user ID available")
            results["groups"] = [
                {"group": g, "success": False, "skipped": True, "message": "No user ID"} for g in group_list
            ]

    # Step 3: OSDU Preshipping
    if preshipping:
        if not output_json:
            console.print("\n[bold]Step 3: OSDU Preshipping[/bold]")

        if not check_osdu_configured():
            results["preshipping"] = {"success": False, "error": "not_configured", "message": "OSDU env vars not set"}
            if not output_json:
                console.print("  [red]✗[/red] OSDU environment variables not configured")
        else:
            preshipping_result = add_to_preshipping(email, level, dry_run)
            summary = preshipping_result.get("summary") or {}
            added = summary.get("added") or 0
            skipped = summary.get("skipped") or 0
            failed = summary.get("failed") or 0

            # The script can report per-group failures inside an otherwise
            # successful run; mark the step failed in BOTH output modes.
            if failed > 0:
                preshipping_result["success"] = False
            results["preshipping"] = preshipping_result

            if not output_json:
                if preshipping_result.get("dry_run"):
                    console.print(f"  [dim]Would add to {preshipping_result.get('total_operations', '?')} preshipping groups[/dim]")
                elif failed > 0:
                    console.print(f"  [red]✗[/red] Failed to add to {failed} groups (auth/permission error)")
                    if added > 0:
                        console.print(f"      Added to {added} groups")
                    if skipped > 0:
                        console.print(f"      Already in {skipped} groups")
                elif added > 0 or skipped > 0:
                    console.print(f"  [green]✓[/green] Added to {added} groups")
                    if skipped > 0:
                        console.print(f"      Already in {skipped} groups")
                else:
                    console.print(f"  [red]✗[/red] {preshipping_result.get('message', 'Unknown error')}")

    # Summary: one success computation shared by both output modes.
    results["success"] = _overall_success(results)
    if output_json:
        print(json.dumps(results, indent=2))
    else:
        console.print("\n[bold]Summary[/bold]")
        status = "[green]✓ Complete[/green]" if results["success"] else "[yellow]⚠ Partial[/yellow]"
        console.print(f"  Status: {status}")

    # Non-zero exit on failure, consistent with the invalid-email path above.
    if not results["success"]:
        sys.exit(1)


@cli.command()
@click.option("--json", "output_json", is_flag=True, help="Output as JSON")
def check(output_json: bool):
    """Check authentication and configuration status."""
    # Two probes: a Graph /me call (token + API reachability) and presence of
    # the AI_OSDU_* variables. Console output goes to stderr; --json prints
    # the collected report to stdout instead.
    show = not output_json
    report = {
        "graph_api": {"configured": False, "authenticated": False},
        "osdu": {"configured": False},
    }

    if show:
        console.print("\n[bold blue]Configuration Check[/bold blue]\n")
        console.print("[bold]Azure AD / Graph API:[/bold]")

    try:
        bearer = get_graph_token()
        with httpx.Client(timeout=30.0) as http:
            me = http.get(f"{GRAPH_API_BASE}/me", headers={"Authorization": f"Bearer {bearer}"})
        if me.status_code != 200:
            report["graph_api"]["error"] = f"API returned {me.status_code}"
            if show:
                console.print(f"  [red]✗[/red] Graph API error: {me.status_code}")
        else:
            profile = me.json()
            report["graph_api"] = {
                "configured": True,
                "authenticated": True,
                "user": profile.get("displayName"),
                "email": profile.get("mail") or profile.get("userPrincipalName"),
            }
            if show:
                console.print(f"  [green]✓[/green] Authenticated as {profile.get('displayName')}")
    except Exception as exc:
        report["graph_api"]["error"] = str(exc)
        if show:
            console.print(f"  [red]✗[/red] Auth failed: {exc}")
            console.print("      Run 'az login' to authenticate")

    if show:
        console.print("\n[bold]OSDU Preshipping:[/bold]")

    osdu_vars = ["AI_OSDU_HOST", "AI_OSDU_DATA_PARTITION", "AI_OSDU_CLIENT", "AI_OSDU_SECRET", "AI_OSDU_TENANT_ID"]
    if not check_osdu_configured():
        missing = [name for name in osdu_vars if not os.environ.get(name)]
        report["osdu"]["missing"] = missing
        if show:
            console.print(f"  [yellow]⚠[/yellow] Not configured (missing: {', '.join(missing)})")
            console.print("      Preshipping provisioning will be unavailable")
    else:
        host = os.environ.get("AI_OSDU_HOST")
        partition = os.environ.get("AI_OSDU_DATA_PARTITION")
        report["osdu"] = {
            "configured": True,
            "host": host,
            "partition": partition,
        }
        if show:
            console.print("  [green]✓[/green] Configured")
            console.print(f"      Host: {host}")
            console.print(f"      Partition: {partition}")

    if output_json:
        print(json.dumps(report, indent=2))


if __name__ == "__main__":
    # Direct execution (e.g. `uv run invite.py ...`): hand argv to the click
    # group, which dispatches to the `invite` or `check` subcommand.
    cli()
