#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "httpx",
#     "click",
#     "rich",
# ]
# ///
"""
OSDU User Roles Report

Lists the members of an OSDU instance's entitlements role groups
and shows their roles (Viewer, Editor, Admin, Ops) and root-group membership.

GUID Resolution:
    This script automatically resolves GUIDs to human-readable names
    via Azure AD lookups. The resolution follows this order:
    1. Check session cache (prevents repeated lookups)
    2. Try Azure AD user lookup
    3. Try Azure AD service principal lookup (prefixed with [SP])
    4. Try Azure AD application lookup (prefixed with [App])

    Use --no-resolve to skip GUID resolution for faster output.
    Azure CLI must be authenticated (az login) for resolution to work.

Usage:
    uv run osdu_users.py [OPTIONS]

Environment variables required:
    AI_OSDU_HOST            - OSDU instance hostname (e.g., osdu.energy.azure.com)
    AI_OSDU_DATA_PARTITION  - Data partition ID (e.g., opendes)
    AI_OSDU_CLIENT          - App registration client ID
    AI_OSDU_SECRET          - App registration client secret
    AI_OSDU_TENANT_ID       - Azure AD tenant ID

Optional:
    AI_OSDU_DOMAIN          - Entitlements domain (default: dataservices.energy)
"""

import json
import os
import platform
import re
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed

import click
import httpx
from rich.console import Console
from rich.table import Table

# Progress and diagnostics go to stderr so that the --json document printed
# on stdout stays machine-readable.
console = Console(stderr=True)

# Session cache of Azure AD lookups: GUID -> display name, or None when no
# user, service principal or application matched (misses are cached too, so
# they are not retried within the same run).
_ad_cache: dict[str, str | None] = {}

# On Windows `az` is a batch file (az.cmd), which subprocess can only launch
# through the shell.
_use_shell = platform.system() == "Windows"

# Per-command timeout in seconds for each `az ad ... show` lookup. Kept short
# so an unreachable or slow Azure AD fails fast instead of stalling the report.
_az_timeout = 5


def get_config() -> dict:
    """Read the OSDU connection settings from environment variables.

    Returns a dict with the keys ``host``, ``partition``, ``client_id``,
    ``client_secret``, ``tenant_id`` and ``domain``. Unset variables map to
    None (validate_config reports them). ``domain`` falls back to
    "dataservices.energy" when AI_OSDU_DOMAIN is unset *or* empty.
    """
    env = os.environ
    return {
        "host": env.get("AI_OSDU_HOST"),
        "partition": env.get("AI_OSDU_DATA_PARTITION"),
        "client_id": env.get("AI_OSDU_CLIENT"),
        "client_secret": env.get("AI_OSDU_SECRET"),
        "tenant_id": env.get("AI_OSDU_TENANT_ID"),
        # `or` instead of a .get() default: an empty value (e.g. a bare
        # `AI_OSDU_DOMAIN=` line in a .env file) would otherwise produce group
        # emails ending in "@<partition>." that match nothing.
        "domain": env.get("AI_OSDU_DOMAIN") or "dataservices.energy",
    }


def validate_config(config: dict) -> list[str]:
    """Report which required settings are absent from *config*.

    A setting counts as absent when its key is missing or its value is falsy
    (None or an empty string). The environment-variable names of the absent
    settings are returned in a fixed order; an empty list means the
    configuration is complete.
    """
    required = (
        ("host", "AI_OSDU_HOST"),
        ("partition", "AI_OSDU_DATA_PARTITION"),
        ("client_id", "AI_OSDU_CLIENT"),
        ("client_secret", "AI_OSDU_SECRET"),
        ("tenant_id", "AI_OSDU_TENANT_ID"),
    )
    return [env_var for key, env_var in required if not config.get(key)]


def get_role_groups(partition: str, domain: str) -> dict[str, str]:
    """Map each role name to its entitlements group email.

    Every group follows the pattern
    ``users.datalake.<group>@<partition>.<domain>``; the returned dict is
    ordered Viewer, Editor, Admin, Ops.
    """
    suffix = f"@{partition}.{domain}"
    group_for_role = {
        "Viewer": "viewers",
        "Editor": "editors",
        "Admin": "admins",
        "Ops": "ops",
    }
    return {role: f"users.datalake.{group}{suffix}" for role, group in group_for_role.items()}


def get_root_group(partition: str, domain: str) -> str:
    """Return the email of the partition's root data group.

    The address has the form ``users.data.root@<partition>.<domain>``.
    """
    return "users.data.root@{}.{}".format(partition, domain)


# Canonical 8-4-4-4-12 hexadecimal GUID, compiled once at import time rather
# than on every call.
_GUID_RE = re.compile(
    r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
)


def is_guid(value: str) -> bool:
    """Return True if *value* is exactly a GUID in 8-4-4-4-12 hex form.

    ``fullmatch`` is used rather than ``match`` with a ``$`` anchor, because
    ``$`` also matches just before a trailing newline and would therefore
    accept a GUID followed by a newline.
    """
    return _GUID_RE.fullmatch(value) is not None


def resolve_guid(guid: str) -> str | None:
    """Resolve an Azure AD object ID to a display name using the Azure CLI.

    Tries ``az ad user show``, then ``az ad sp show``, then ``az ad app show``.
    The first non-empty ``displayName`` wins; service-principal and
    application names are prefixed with "[SP] " and "[App] " respectively.
    Results, including misses (None), are memoised in ``_ad_cache``.

    Returns None for non-GUID input, when ``az`` cannot be launched, or when
    every lookup fails or times out.

    On Windows, uses shell=True because 'az' is a batch file (az.cmd).
    """
    if guid in _ad_cache:
        return _ad_cache[guid]

    # The value comes from the entitlements API and, on Windows, is passed
    # through cmd.exe (shell=True); only well-formed GUIDs are ever forwarded.
    if not is_guid(guid):
        return None

    display_name: str | None = None
    # (az ad sub-command, prefix for the displayed name), in lookup order.
    lookups = (("user", ""), ("sp", "[SP] "), ("app", "[App] "))
    for kind, prefix in lookups:
        try:
            result = subprocess.run(
                ["az", "ad", kind, "show", "--id", guid, "--query", "displayName", "-o", "tsv"],
                capture_output=True,
                text=True,
                timeout=_az_timeout,
                shell=_use_shell,
            )
        except OSError:
            # az is not installed / not executable: the remaining lookups
            # would fail the same way, so stop here.
            break
        except subprocess.TimeoutExpired:
            continue
        name = result.stdout.strip()
        if result.returncode == 0 and name:
            display_name = f"{prefix}{name}"
            break

    _ad_cache[guid] = display_name
    return display_name


def resolve_guids_parallel(guids: list[str], max_workers: int = 10) -> dict[str, str | None]:
    """Resolve multiple GUIDs in parallel for faster lookups.

    Args:
        guids: List of GUIDs to resolve
        max_workers: Maximum concurrent lookups (default 10)

    Returns:
        Dict mapping GUID to resolved name (or None if not found)
    """
    results: dict[str, str | None] = {}

    # Filter to only GUIDs we haven't cached
    guids_to_resolve = [g for g in guids if g not in _ad_cache]

    # Return cached results for already-resolved GUIDs
    for guid in guids:
        if guid in _ad_cache:
            results[guid] = _ad_cache[guid]

    if not guids_to_resolve:
        return results

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_guid = {executor.submit(resolve_guid, guid): guid for guid in guids_to_resolve}

        for future in as_completed(future_to_guid):
            guid = future_to_guid[future]
            try:
                results[guid] = future.result()
            except Exception:
                results[guid] = None

    return results


def get_access_token(config: dict) -> str:
    """Obtain a bearer token via the Azure AD v1 client-credentials grant.

    The token is requested with the app registration's own client ID as the
    ``resource``. Raises httpx.HTTPStatusError if the token endpoint returns
    a non-2xx status.
    """
    tenant = config["tenant_id"]
    endpoint = f"https://login.microsoftonline.com/{tenant}/oauth2/token"
    form = {
        "grant_type": "client_credentials",
        "client_id": config["client_id"],
        "client_secret": config["client_secret"],
        "resource": config["client_id"],
    }

    with httpx.Client() as http:
        reply = http.post(endpoint, data=form)
    # The body of a non-streaming response is already read, so it is safe to
    # inspect it after the client has been closed.
    reply.raise_for_status()
    payload = reply.json()
    return payload["access_token"]


def get_group_members(config: dict, token: str, group_email: str) -> list[dict]:
    """Return the member records of an entitlements group.

    A 404 response is treated as an empty group and yields []. Any other
    non-2xx status raises httpx.HTTPStatusError. The records are returned as
    the API sends them, taken from the response's "members" list (this
    script reads their "email" and "role" keys).
    """
    endpoint = f"https://{config['host']}/api/entitlements/v2/groups/{group_email}/members"
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/json",
        "data-partition-id": config["partition"],
    }

    with httpx.Client() as http:
        reply = http.get(endpoint, headers=request_headers)

    if reply.status_code == 404:
        return []
    reply.raise_for_status()
    body = reply.json()
    return body.get("members", [])


# Rich colour used for each role's column in the table output.
_ROLE_COLORS = {"Viewer": "green", "Editor": "blue", "Admin": "yellow", "Ops": "red"}


def _fetch_memberships(
    config: dict,
    token: str,
    root_group: str,
    groups_to_check: dict[str, str],
    verbose: bool,
) -> tuple[set[str], dict[str, dict[str, str]], set[str]]:
    """Fetch the root group and each requested role group.

    Returns ``(root_users, role_members, all_users)``, where ``role_members``
    maps role name -> {identity: entitlement role, e.g. "OWNER"/"MEMBER"}.
    A group that cannot be fetched (HTTP error status or network failure) is
    skipped, and reported on the console when *verbose*, so one failing group
    does not abort the whole report.
    """
    root_users: set[str] = set()
    role_members: dict[str, dict[str, str]] = {r: {} for r in groups_to_check}
    all_users: set[str] = set()

    def fetch(label: str, group_email: str) -> list[dict] | None:
        """Return a group's members, or None (after reporting) on failure."""
        try:
            return get_group_members(config, token, group_email)
        except httpx.HTTPStatusError as e:
            reason = str(e.response.status_code)
        except httpx.RequestError as e:
            # Network-level failure (timeout, DNS, refused connection, ...).
            reason = type(e).__name__
        if verbose:
            console.print(f"[yellow]![/yellow] {label}: Failed - {reason}")
        return None

    members = fetch("Root", root_group)
    if members is not None:
        for member in members:
            if email := member.get("email", ""):
                root_users.add(email)
                all_users.add(email)
        if verbose:
            console.print(f"[magenta]✓[/magenta] Root: {len(members)} members")

    for role_name, group_email in groups_to_check.items():
        members = fetch(role_name, group_email)
        if members is None:
            continue
        for member in members:
            if email := member.get("email", ""):
                role_members[role_name][email] = member.get("role", "MEMBER")
                all_users.add(email)
        if verbose:
            console.print(f"[green]✓[/green] {role_name}: {len(members)} members")

    return root_users, role_members, all_users


def _resolve_names(all_users: set[str], no_resolve: bool, output_json: bool) -> dict[str, str | None]:
    """Map each identity to an Azure AD display name, or None.

    Only GUID identities are looked up; anything else (e.g. an email) is
    already human-readable and maps to None. With --no-resolve the mapping is
    empty.
    """
    if no_resolve:
        return {}

    resolved: dict[str, str | None] = {}
    guids: list[str] = []
    for identity in all_users:
        if is_guid(identity):
            guids.append(identity)
        else:
            resolved[identity] = None

    if not output_json:
        console.print("")
        if guids:
            console.print(f"[dim]Resolving {len(guids)} identities via Azure AD (parallel)...[/dim]")
    if guids:
        resolved.update(resolve_guids_parallel(guids, max_workers=10))
    return resolved


def _role_symbol(role_members: dict[str, dict[str, str]], role_name: str, email: str) -> str:
    """Return the Rich-markup table cell for *email* in *role_name*'s column."""
    members = role_members.get(role_name, {})
    if email not in members:
        return ""
    color = _ROLE_COLORS.get(role_name, "white")
    if members[email] == "OWNER":
        return f"[bold {color}]★[/bold {color}]"
    return f"[{color}]✓[/{color}]"


def _print_json(
    config: dict,
    all_users: set[str],
    root_users: set[str],
    role_members: dict[str, dict[str, str]],
    groups_to_check: dict[str, str],
    resolved_names: dict[str, str | None],
) -> None:
    """Print the report as a JSON document on stdout."""
    users_data = []
    for email in sorted(all_users):
        roles = {
            role_name: role_members[role_name][email]
            for role_name in groups_to_check
            if email in role_members[role_name]
        }
        users_data.append(
            {
                "email": email,
                "display_name": resolved_names.get(email),
                "is_root": email in root_users,
                "roles": roles,
            }
        )

    result = {
        "success": True,
        "host": config["host"],
        "partition": config["partition"],
        "total_users": len(users_data),
        "users": users_data,
    }
    print(json.dumps(result, indent=2))


def _print_table(
    all_users: set[str],
    root_users: set[str],
    role_members: dict[str, dict[str, str]],
    groups_to_check: dict[str, str],
    resolved_names: dict[str, str | None],
) -> None:
    """Print the report as a Rich table followed by a summary."""
    console.print("\n")
    table = Table(title="[bold]User Roles[/bold]", show_header=True, header_style="bold cyan")
    table.add_column("User", style="bold")
    table.add_column("Identity", style="dim")
    table.add_column("Root", justify="center")
    for role_name in groups_to_check:
        table.add_column(role_name, justify="center")

    # Root users first, then alphabetically by resolved name (or identity).
    sorted_users = sorted(
        all_users,
        key=lambda e: (e not in root_users, (resolved_names.get(e) or e).lower()),
    )

    for email in sorted_users:
        display_name = resolved_names.get(email)
        row = [
            display_name or email,
            email if display_name else "",
            "[bold magenta]★[/bold magenta]" if email in root_users else "",
        ]
        row.extend(_role_symbol(role_members, role_name, email) for role_name in groups_to_check)
        table.add_row(*row)

    console.print(table)
    console.print("[dim]★ = Owner, ✓ = Member[/dim]")

    # Counts are restricted to the users shown, so they agree with the table
    # when --user narrows the listing.
    console.print("\n[bold]Summary:[/bold]")
    console.print(f"  Total users: {len(all_users)}")
    console.print(f"  Root: {len(root_users & all_users)}")
    for role_name in groups_to_check:
        # Pluralize correctly (Ops stays Ops, others add 's')
        plural = role_name if role_name.endswith("s") else f"{role_name}s"
        shown = sum(1 for email in role_members[role_name] if email in all_users)
        console.print(f"  {plural}: {shown}")


@click.command()
@click.option("--json", "output_json", is_flag=True, help="Output as JSON")
@click.option("--role", type=click.Choice(["Viewer", "Editor", "Admin", "Ops"]), help="Filter to specific role")
@click.option("--user", help="Show roles for specific user (email or GUID)")
@click.option("--no-resolve", is_flag=True, help="Skip GUID resolution (faster)")
def main(output_json: bool, role: str | None, user: str | None, no_resolve: bool):
    """List OSDU users and their roles."""
    # NOTE: the docstring above is click's --help text; keep it short.
    # Local import: escape() keeps exception text and user input from being
    # parsed as Rich markup.
    from rich.markup import escape

    config = get_config()
    missing = validate_config(config)

    if missing:
        console.print(f"[red]Missing required environment variables: {', '.join(missing)}[/red]")
        sys.exit(1)

    role_groups = get_role_groups(config["partition"], config["domain"])
    root_group = get_root_group(config["partition"], config["domain"])

    if not output_json:
        console.print("\n[bold blue]OSDU User Roles Report[/bold blue]")
        console.print(f"Host: {config['host']}")
        console.print(f"Partition: {config['partition']}\n")

    # Authenticate. httpx.HTTPError covers both a rejected request
    # (HTTPStatusError) and network failures (RequestError), so neither
    # escapes as a raw traceback.
    try:
        if output_json:
            token = get_access_token(config)
        else:
            with console.status("[bold green]Authenticating..."):
                token = get_access_token(config)
                console.print("[green]✓[/green] Authenticated successfully\n")
    except httpx.HTTPError as e:
        if output_json:
            print(json.dumps({"success": False, "error": "auth_failed", "message": str(e)}))
        else:
            console.print(f"[red]Authentication failed: {escape(str(e))}[/red]")
        sys.exit(1)

    # Filter groups if role specified
    groups_to_check = {role: role_groups[role]} if role else role_groups

    if output_json:
        # JSON mode - silent fetching
        root_users, role_members, all_users = _fetch_memberships(
            config, token, root_group, groups_to_check, verbose=False
        )
    else:
        with console.status("[bold green]Fetching role memberships..."):
            root_users, role_members, all_users = _fetch_memberships(
                config, token, root_group, groups_to_check, verbose=True
            )

    # Filter to specific user if requested (case-insensitive substring match)
    if user:
        needle = user.lower()
        all_users = {u for u in all_users if needle in u.lower()}
        if not all_users:
            if output_json:
                print(json.dumps({"success": True, "users": [], "message": f"User '{user}' not found"}))
            else:
                console.print(f"\n[yellow]User '{escape(user)}' not found in any role groups.[/yellow]")
            return

    if not all_users:
        if output_json:
            print(json.dumps({"success": True, "users": []}))
        else:
            console.print("\n[yellow]No users found in any role groups.[/yellow]")
        return

    resolved_names = _resolve_names(all_users, no_resolve, output_json)

    if output_json:
        _print_json(config, all_users, root_users, role_members, groups_to_check, resolved_names)
    else:
        _print_table(all_users, root_users, role_members, groups_to_check, resolved_names)


if __name__ == "__main__":
    main()
