#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "httpx",
#     "click",
#     "rich",
# ]
# ///
"""
OSDU Preshipping Provisioning

Bulk provision users and service principals for OSDU preshipping environments.
Adds entities to all required entitlement groups defined in the configuration.

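Usage:
    uv run preshipping.py add --user EMAIL [--level viewer|admin]
    uv run preshipping.py add --oid OID [--oid OID2 ...]
    uv run preshipping.py add --user EMAIL --dry-run
    uv run preshipping.py remove --user EMAIL [--oid OID ...] [--dry-run]
    uv run preshipping.py list-groups
    uv run preshipping.py check

Every command accepts --json for machine-readable output.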

Environment variables required:
    AI_OSDU_HOST            - OSDU instance hostname
    AI_OSDU_DATA_PARTITION  - Data partition ID
    AI_OSDU_CLIENT          - App registration client ID
    AI_OSDU_SECRET          - App registration secret
    AI_OSDU_TENANT_ID       - Azure AD tenant ID

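Optional:
    AI_OSDU_DOMAIN          - Override domain from config

Resolving --user emails to Azure AD object IDs shells out to the Azure CLI
(`az ad user ...`), so `az` must be installed and logged in.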
"""

import json
import os
import subprocess
import sys
from pathlib import Path

import click
import httpx
from rich.console import Console
from rich.table import Table

# Rich console for human-readable output. It writes to stderr so that the JSON
# documents emitted with print() (stdout) under --json stay machine-parseable.
console = Console(stderr=True)


def resolve_email_to_guid(email: str) -> str | None:
    """Resolve a user email to their Azure AD GUID.

    Args:
        email: User's email address

    Returns:
        Azure AD GUID or None if not found
    """
    try:
        # First try direct lookup by ID (works for member users)
        result = subprocess.run(
            ["az", "ad", "user", "show", "--id", email, "--query", "id", "-o", "tsv"],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0 and result.stdout.strip():
            return result.stdout.strip()

        # Fallback: search by mail property (works for guest users)
        result = subprocess.run(
            ["az", "ad", "user", "list", "--filter", f"mail eq '{email}'", "--query", "[0].id", "-o", "tsv"],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0 and result.stdout.strip():
            return result.stdout.strip()

        return None
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return None

# Location of the groups configuration file, resolved from this script's own
# directory: <script dir>/../config/groups.json.
CONFIG_PATH = Path(__file__).parent.parent / "config" / "groups.json"


def load_config() -> dict:
    """Load the preshipping groups configuration from ``CONFIG_PATH``.

    Returns:
        The parsed JSON document. Callers read keys such as ``domain``,
        ``default_role``, ``groups`` and ``levels`` from it.

    If the file is missing or is not valid JSON, prints an error to the
    console and exits with status 1. A malformed config would otherwise
    surface as a raw traceback.
    """
    try:
        with open(CONFIG_PATH, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        console.print(f"[red]Config not found: {CONFIG_PATH}[/red]")
        sys.exit(1)
    except json.JSONDecodeError as e:
        console.print(f"[red]Invalid JSON in config {CONFIG_PATH}: {e}[/red]")
        sys.exit(1)


def get_env_config() -> dict:
    """Read the OSDU connection settings from the process environment.

    Every key is always present in the returned dict. A variable that is not
    set maps to ``None``.
    """
    env_names = {
        "host": "AI_OSDU_HOST",
        "partition": "AI_OSDU_DATA_PARTITION",
        "client_id": "AI_OSDU_CLIENT",
        "client_secret": "AI_OSDU_SECRET",
        "tenant_id": "AI_OSDU_TENANT_ID",
        "domain_override": "AI_OSDU_DOMAIN",
    }
    return {key: os.environ.get(var_name) for key, var_name in env_names.items()}


def validate_env_config(config: dict) -> list[str]:
    """Return the required environment variable names that are unset or empty.

    ``domain_override`` is optional and never reported. The names come back
    in a fixed order, ready to be joined into a user-facing message.
    """
    required = (
        ("host", "AI_OSDU_HOST"),
        ("partition", "AI_OSDU_DATA_PARTITION"),
        ("client_id", "AI_OSDU_CLIENT"),
        ("client_secret", "AI_OSDU_SECRET"),
        ("tenant_id", "AI_OSDU_TENANT_ID"),
    )
    return [env_var for key, env_var in required if not config.get(key)]


def get_access_token(config: dict) -> str:
    """Obtain an OAuth access token via the client-credentials grant.

    Posts to the Azure AD v1 token endpoint for the configured tenant. The
    requested ``resource`` is the app registration's own client ID.

    Raises:
        httpx.HTTPStatusError: the token endpoint answered with an error status.
    """
    tenant = config["tenant_id"]
    form = {
        "grant_type": "client_credentials",
        "client_id": config["client_id"],
        "client_secret": config["client_secret"],
        "resource": config["client_id"],
    }
    response = httpx.post(f"https://login.microsoftonline.com/{tenant}/oauth2/token", data=form)
    response.raise_for_status()
    payload = response.json()
    return payload["access_token"]


def get_group_email(group_name: str, partition: str, domain: str) -> str:
    """Compose an entitlement group address of the form ``<group>@<partition>.<domain>``."""
    template = "{name}@{partition}.{domain}"
    return template.format(name=group_name, partition=partition, domain=domain)


def add_member_to_group(
    env_config: dict, token: str, group_email: str, member_email: str, member_role: str = "OWNER"
) -> dict:
    """Add a member to an entitlement group."""
    url = f"https://{env_config['host']}/api/entitlements/v2/groups/{group_email}/members"

    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "data-partition-id": env_config["partition"],
    }

    body = {
        "email": member_email,
        "role": member_role,
    }

    with httpx.Client() as client:
        response = client.post(url, headers=headers, json=body)
        if response.status_code == 409:
            return {"success": True, "skipped": True, "message": "Already in group"}
        if response.status_code == 404:
            return {"success": False, "error": "not_found", "message": "Group not found"}
        response.raise_for_status()
        return {"success": True, "skipped": False, "message": "Added to group"}


def remove_member_from_group(
    env_config: dict, token: str, group_email: str, member_email: str
) -> dict:
    """DELETE a member from an entitlement group through the Entitlements v2 API.

    Returns a result dict with ``success``, ``skipped`` and ``message``. Any
    HTTP 404 counts as "not in group" and is reported as a skipped success.
    That includes a 404 caused by a group that does not exist.

    Raises:
        httpx.HTTPStatusError: for any other non-2xx response.
    """
    host = env_config["host"]
    member_url = f"https://{host}/api/entitlements/v2/groups/{group_email}/members/{member_email}"
    request_headers = {
        "Authorization": f"Bearer {token}",
        "data-partition-id": env_config["partition"],
    }

    with httpx.Client() as http:
        resp = http.delete(member_url, headers=request_headers)
        if resp.status_code != 404:
            resp.raise_for_status()
            return {"success": True, "skipped": False, "message": "Removed from group"}
        return {"success": True, "skipped": True, "message": "Not in group"}


@click.group()
def cli():
    """OSDU Preshipping - Bulk provision users for preshipping environments."""
    # Click shows the docstring above as the top-level --help text. The
    # subcommands are attached below with @cli.command().
    pass


@cli.command()
@click.option("--user", "users", multiple=True, help="User email address (can specify multiple)")
@click.option("--oid", "oids", multiple=True, help="Service principal OID (can specify multiple)")
@click.option("--level", type=click.Choice(["viewer", "admin"]), default="viewer", help="Access level: viewer (read-only) or admin (Admin/Ops)")
@click.option("--dry-run", is_flag=True, help="Preview without making changes")
@click.option("--json", "output_json", is_flag=True, help="Output as JSON")
def add(users: tuple, oids: tuple, level: str, dry_run: bool, output_json: bool):
    """Add users or OIDs to preshipping entitlement groups."""
    # NOTE: the docstring above is the command's --help text (click), so
    # implementation notes live in comments instead.
    #
    # Flow: validate args/env -> pick the level's groups and role -> resolve
    # --user emails to Azure AD GUIDs -> (dry run: report and stop) ->
    # authenticate -> add every entity to every group -> summarize.
    # Exits with status 1 if anything requested could not be done: bad
    # arguments or config, auth failure, unresolved users, failed operations.

    def fail(result: dict, text: str) -> None:
        """Report a fatal error in the selected output mode and exit 1 (never returns)."""
        if output_json:
            print(json.dumps(result))
        else:
            console.print(f"[red]{text}[/red]")
        sys.exit(1)

    if not users and not oids:
        fail(
            {"success": False, "error": "invalid_args", "message": "Specify --user or --oid"},
            "Specify at least one --user EMAIL or --oid OID",
        )

    # Load configurations
    groups_config = load_config()
    env_config = get_env_config()
    missing = validate_env_config(env_config)
    if missing:
        message = f"Missing: {', '.join(missing)}"
        fail({"success": False, "error": "missing_config", "message": message}, message)

    # Determine domain (env override > config > default)
    domain = env_config.get("domain_override") or groups_config.get("domain", "contoso.com")
    partition = env_config["partition"]

    # Level-specific configuration; fall back to the legacy flat format
    # (top-level "groups" / "default_role") for backwards compatibility.
    level_config = (groups_config.get("levels") or {}).get(level)
    if level_config:
        default_role = level_config.get("default_role", "MEMBER")
        groups = level_config.get("groups", [])
    else:
        default_role = groups_config.get("default_role", "OWNER")
        groups = groups_config.get("groups", [])

    # Resolve user emails to Azure AD GUIDs: members are registered by GUID.
    resolved_users: list[str] = []
    failed_resolutions: list[str] = []
    if users and not output_json:
        console.print("\n[bold]Resolving user emails to Azure AD GUIDs...[/bold]")
    for email in users:
        guid = resolve_email_to_guid(email)
        if guid:
            resolved_users.append(guid)
            if not output_json:
                console.print(f"  [green]✓[/green] {email} → {guid}")
        else:
            failed_resolutions.append(email)
            if not output_json:
                console.print(f"  [red]✗[/red] {email} → Not found in Azure AD")

    if failed_resolutions and not output_json:
        console.print(f"\n[yellow]Warning: {len(failed_resolutions)} user(s) could not be resolved[/yellow]")

    # Combine resolved users and OIDs. Deduplicate while keeping first-seen
    # order (the same user given twice, or as both email and OID) so no
    # group operation runs twice.
    entities = list(dict.fromkeys([*resolved_users, *oids]))

    if not entities:
        fail(
            {
                "success": False,
                "error": "no_entities",
                "message": "No valid entities to provision",
                "unresolved_users": failed_resolutions,
            },
            "No valid entities to provision",
        )

    if not output_json:
        console.print("\n[bold blue]OSDU Preshipping Provisioning[/bold blue]")
        if level_config:
            console.print(f"Level: {level} ({level_config.get('description', '')})")
        else:
            console.print("Level: legacy")
        console.print(f"Partition: {partition}")
        console.print(f"Domain: {domain}")
        console.print(f"Entities: {len(entities)}")
        console.print(f"Groups: {len(groups)}")
        console.print(f"Role: {default_role}")
        if dry_run:
            console.print("[yellow]DRY RUN - No changes will be made[/yellow]")
        console.print()

    if dry_run:
        # Show what would be done
        if output_json:
            result = {
                "success": True,
                "dry_run": True,
                "entities": entities,
                "unresolved_users": failed_resolutions,
                "groups": [get_group_email(g["name"], partition, domain) for g in groups],
                "role": default_role,
                "total_operations": len(entities) * len(groups),
            }
            print(json.dumps(result))
        else:
            console.print("[bold]Would add these entities:[/bold]")
            for entity in entities:
                console.print(f"  - {entity}")
            console.print(f"\n[bold]To {len(groups)} groups as {default_role}[/bold]")
            console.print(f"\nTotal operations: {len(entities) * len(groups)}")
        return

    # Authenticate. httpx.HTTPError covers both error statuses and transport
    # failures (DNS, refused connection, timeout).
    try:
        if output_json:
            token = get_access_token(env_config)
        else:
            with console.status("[bold green]Authenticating..."):
                token = get_access_token(env_config)
    except httpx.HTTPError as e:
        fail({"success": False, "error": "auth_failed", "message": str(e)}, f"Authentication failed: {e}")

    # Provision each entity to all groups
    all_results = []
    for entity in entities:
        entity_results = {"entity": entity, "groups": []}

        if not output_json:
            console.print(f"\n[bold]Provisioning: {entity}[/bold]")

        for group_def in groups:
            name = group_def["name"]
            group_email = get_group_email(name, partition, domain)

            try:
                result = add_member_to_group(env_config, token, group_email, entity, default_role)
            except httpx.HTTPStatusError as e:
                result = {
                    "success": False,
                    "error": "api_error",
                    "status_code": e.response.status_code,
                    "message": str(e),
                }
                detail = f"HTTP {e.response.status_code}"
            except httpx.RequestError as e:
                # Transport failure for this one call: record it and continue
                # rather than aborting the batch and losing earlier results.
                result = {"success": False, "error": "request_error", "message": str(e)}
                detail = type(e).__name__
            else:
                detail = result.get("message")
            result["group"] = group_email
            entity_results["groups"].append(result)

            if not output_json:
                if result.get("skipped"):
                    console.print(f"  [dim]-[/dim] {name}: Already member")
                elif result["success"]:
                    console.print(f"  [green]✓[/green] {name}")
                else:
                    console.print(f"  [red]✗[/red] {name}: {detail}")

        all_results.append(entity_results)

    # Summary, computed once for both output modes.
    outcomes = [gr for er in all_results for gr in er["groups"]]
    total_added = sum(1 for gr in outcomes if gr.get("success") and not gr.get("skipped"))
    total_skipped = sum(1 for gr in outcomes if gr.get("skipped"))
    total_failed = sum(1 for gr in outcomes if not gr.get("success"))
    # A partial failure must not look like success to scripts. The JSON
    # "success" flag and the exit status both reflect failed operations and
    # users that could not be resolved.
    succeeded = total_failed == 0 and not failed_resolutions

    if output_json:
        print(
            json.dumps(
                {
                    "success": succeeded,
                    "results": all_results,
                    "unresolved_users": failed_resolutions,
                    "summary": {
                        "entities": len(entities),
                        "groups": len(groups),
                        "added": total_added,
                        "skipped": total_skipped,
                        "failed": total_failed,
                        "unresolved": len(failed_resolutions),
                    },
                }
            )
        )
    else:
        console.print("\n[bold]Summary:[/bold]")
        console.print(f"  Added: {total_added}")
        console.print(f"  Already member: {total_skipped}")
        if total_failed > 0:
            console.print(f"  [red]Failed: {total_failed}[/red]")
        if failed_resolutions:
            console.print(f"  [red]Unresolved users: {', '.join(failed_resolutions)}[/red]")

    if not succeeded:
        sys.exit(1)


@cli.command()
@click.option("--user", "users", multiple=True, help="User email address (can specify multiple)")
@click.option("--oid", "oids", multiple=True, help="Service principal OID (can specify multiple)")
@click.option("--dry-run", is_flag=True, help="Preview without making changes")
@click.option("--json", "output_json", is_flag=True, help="Output as JSON")
def remove(users: tuple, oids: tuple, dry_run: bool, output_json: bool):
    """Remove users or OIDs from all preshipping entitlement groups."""
    # NOTE: the docstring above is the command's --help text (click).
    #
    # --user emails are resolved to Azure AD GUIDs first, exactly like `add`,
    # because `add` registers users by GUID. Deleting by raw email would
    # target a member that does not exist and be reported as "Not a member".
    # The target groups are the union of the legacy flat "groups" list and
    # every level's list, so entities added with any --level are removed.
    # Exits with status 1 if anything requested could not be done.

    def fail(result: dict, text: str) -> None:
        """Report a fatal error in the selected output mode and exit 1 (never returns)."""
        if output_json:
            print(json.dumps(result))
        else:
            console.print(f"[red]{text}[/red]")
        sys.exit(1)

    if not users and not oids:
        fail(
            {"success": False, "error": "invalid_args", "message": "Specify --user or --oid"},
            "Specify at least one --user EMAIL or --oid OID",
        )

    # Load configurations
    groups_config = load_config()
    env_config = get_env_config()
    missing = validate_env_config(env_config)
    if missing:
        message = f"Missing: {', '.join(missing)}"
        fail({"success": False, "error": "missing_config", "message": message}, message)

    # Determine domain (env override > config > default)
    domain = env_config.get("domain_override") or groups_config.get("domain", "contoso.com")
    partition = env_config["partition"]

    # Every group the config knows about, deduplicated by name (first wins).
    group_lists = [groups_config.get("groups", [])]
    group_lists.extend(
        level_config.get("groups", []) for level_config in (groups_config.get("levels") or {}).values()
    )
    unique_groups: dict[str, dict] = {}
    for group_list in group_lists:
        for group_def in group_list:
            unique_groups.setdefault(group_def["name"], group_def)
    groups = list(unique_groups.values())

    # Resolve user emails to Azure AD GUIDs (the identity `add` registered).
    resolved_users: list[str] = []
    failed_resolutions: list[str] = []
    if users and not output_json:
        console.print("\n[bold]Resolving user emails to Azure AD GUIDs...[/bold]")
    for email in users:
        guid = resolve_email_to_guid(email)
        if guid:
            resolved_users.append(guid)
            if not output_json:
                console.print(f"  [green]✓[/green] {email} → {guid}")
        else:
            failed_resolutions.append(email)
            if not output_json:
                console.print(f"  [red]✗[/red] {email} → Not found in Azure AD")

    if failed_resolutions and not output_json:
        console.print(f"\n[yellow]Warning: {len(failed_resolutions)} user(s) could not be resolved[/yellow]")

    # Combine users and OIDs, deduplicated in first-seen order.
    entities = list(dict.fromkeys([*resolved_users, *oids]))

    if not entities:
        fail(
            {
                "success": False,
                "error": "no_entities",
                "message": "No valid entities to remove",
                "unresolved_users": failed_resolutions,
            },
            "No valid entities to remove",
        )

    if not output_json:
        console.print("\n[bold red]OSDU Preshipping Removal[/bold red]")
        console.print(f"Partition: {partition}")
        console.print(f"Domain: {domain}")
        console.print(f"Entities: {len(entities)}")
        console.print(f"Groups: {len(groups)}")
        if dry_run:
            console.print("[yellow]DRY RUN - No changes will be made[/yellow]")
        console.print()

    if dry_run:
        # Show what would be done
        if output_json:
            result = {
                "success": True,
                "dry_run": True,
                "entities": entities,
                "unresolved_users": failed_resolutions,
                "groups": [get_group_email(g["name"], partition, domain) for g in groups],
                "total_operations": len(entities) * len(groups),
            }
            print(json.dumps(result))
        else:
            console.print("[bold]Would remove these entities:[/bold]")
            for entity in entities:
                console.print(f"  - {entity}")
            console.print(f"\n[bold]From {len(groups)} groups[/bold]")
            console.print(f"\nTotal operations: {len(entities) * len(groups)}")
        return

    # Authenticate. httpx.HTTPError covers both error statuses and transport
    # failures (DNS, refused connection, timeout).
    try:
        if output_json:
            token = get_access_token(env_config)
        else:
            with console.status("[bold green]Authenticating..."):
                token = get_access_token(env_config)
    except httpx.HTTPError as e:
        fail({"success": False, "error": "auth_failed", "message": str(e)}, f"Authentication failed: {e}")

    # Remove each entity from all groups
    all_results = []
    for entity in entities:
        entity_results = {"entity": entity, "groups": []}

        if not output_json:
            console.print(f"\n[bold]Removing: {entity}[/bold]")

        for group_def in groups:
            name = group_def["name"]
            group_email = get_group_email(name, partition, domain)

            try:
                result = remove_member_from_group(env_config, token, group_email, entity)
            except httpx.HTTPStatusError as e:
                result = {
                    "success": False,
                    "error": "api_error",
                    "status_code": e.response.status_code,
                    "message": str(e),
                }
                detail = f"HTTP {e.response.status_code}"
            except httpx.RequestError as e:
                # Transport failure for this one call: record it and continue
                # rather than aborting the batch and losing earlier results.
                result = {"success": False, "error": "request_error", "message": str(e)}
                detail = type(e).__name__
            else:
                detail = result.get("message")
            result["group"] = group_email
            entity_results["groups"].append(result)

            if not output_json:
                if result.get("skipped"):
                    console.print(f"  [dim]-[/dim] {name}: Not a member")
                elif result["success"]:
                    console.print(f"  [green]✓[/green] {name}")
                else:
                    console.print(f"  [red]✗[/red] {name}: {detail}")

        all_results.append(entity_results)

    # Summary, computed once for both output modes.
    outcomes = [gr for er in all_results for gr in er["groups"]]
    total_removed = sum(1 for gr in outcomes if gr.get("success") and not gr.get("skipped"))
    total_skipped = sum(1 for gr in outcomes if gr.get("skipped"))
    total_failed = sum(1 for gr in outcomes if not gr.get("success"))
    # Partial failure is reported through both the JSON flag and exit status.
    succeeded = total_failed == 0 and not failed_resolutions

    if output_json:
        print(
            json.dumps(
                {
                    "success": succeeded,
                    "results": all_results,
                    "unresolved_users": failed_resolutions,
                    "summary": {
                        "entities": len(entities),
                        "groups": len(groups),
                        "removed": total_removed,
                        "skipped": total_skipped,
                        "failed": total_failed,
                        "unresolved": len(failed_resolutions),
                    },
                }
            )
        )
    else:
        console.print("\n[bold]Summary:[/bold]")
        console.print(f"  Removed: {total_removed}")
        console.print(f"  Not a member: {total_skipped}")
        if total_failed > 0:
            console.print(f"  [red]Failed: {total_failed}[/red]")
        if failed_resolutions:
            console.print(f"  [red]Unresolved users: {', '.join(failed_resolutions)}[/red]")

    if not succeeded:
        sys.exit(1)


@cli.command("list-groups")
@click.option("--json", "output_json", is_flag=True, help="Output as JSON")
def list_groups(output_json: bool):
    """List all preshipping entitlement groups."""
    # NOTE: the docstring above is the command's --help text (click).
    #
    # Shows the legacy flat "groups" list and, when configured, each access
    # level's groups (the lists `add --level` provisions into). OSDU
    # environment variables are optional for this command.
    groups_config = load_config()
    env_config = get_env_config()

    domain = env_config.get("domain_override") or groups_config.get("domain", "contoso.com")
    # get_env_config() always contains a "partition" key (None when the
    # variable is unset), so a .get() default would never apply. Fall back on
    # falsiness so group emails show a "{partition}" placeholder, not "None".
    partition = env_config.get("partition") or "{partition}"
    default_role = groups_config.get("default_role", "OWNER")
    groups = groups_config.get("groups", [])
    levels = groups_config.get("levels") or {}

    def describe(group_list: list) -> list[dict]:
        """JSON-ready entries (name, full group email, description)."""
        return [
            {
                "name": g["name"],
                "email": get_group_email(g["name"], partition, domain),
                "description": g.get("description", ""),
            }
            for g in group_list
        ]

    def print_table(group_list: list) -> None:
        """Render one group list as a table followed by its total."""
        table = Table(show_header=True, header_style="bold")
        table.add_column("Group Name", style="cyan")
        table.add_column("Description")
        for group in group_list:
            table.add_row(group["name"], group.get("description", ""))
        console.print(table)
        console.print(f"\n[dim]Total: {len(group_list)} groups[/dim]")

    if output_json:
        result = {
            "success": True,
            "domain": domain,
            "partition": partition,
            "default_role": default_role,
            "groups": describe(groups),
            "count": len(groups),
            "levels": {
                name: {
                    "description": level_config.get("description", ""),
                    # Same fallback role `add` uses for a level without one.
                    "default_role": level_config.get("default_role", "MEMBER"),
                    "groups": describe(level_config.get("groups", [])),
                    "count": len(level_config.get("groups", [])),
                }
                for name, level_config in levels.items()
            },
        }
        print(json.dumps(result))
        return

    console.print("\n[bold blue]Preshipping Groups[/bold blue]")
    console.print(f"Domain: {domain}")
    console.print(f"Partition: {partition}")
    # Show the legacy list when it has entries, or when no levels exist, so
    # legacy-only configs render exactly as before.
    if groups or not levels:
        console.print(f"Default Role: {default_role}")
        console.print()
        print_table(groups)
    for name, level_config in levels.items():
        description = level_config.get("description", "")
        heading = f"\n[bold]Level: {name}[/bold]"
        console.print(f"{heading} ({description})" if description else heading)
        console.print(f"Default Role: {level_config.get('default_role', 'MEMBER')}")
        console.print()
        print_table(level_config.get("groups", []))


@cli.command()
@click.option("--json", "output_json", is_flag=True, help="Output as JSON")
def check(output_json: bool):
    """Check configuration and connectivity."""
    # NOTE: the docstring above is the command's --help text (click).
    #
    # Verifies that the required environment variables are set and that a
    # token can be obtained from Azure AD. Exits with status 1 if either fails.
    groups_config = load_config()
    env_config = get_env_config()
    missing = validate_env_config(env_config)

    if missing:
        result = {"success": False, "configured": False, "missing": missing}
        if output_json:
            print(json.dumps(result))
        else:
            console.print("[red]Missing environment variables:[/red]")
            for var in missing:
                console.print(f"  - {var}")
        sys.exit(1)

    domain = env_config.get("domain_override") or groups_config.get("domain", "contoso.com")
    legacy_group_count = len(groups_config.get("groups", []))
    level_group_counts = {
        name: len(level_config.get("groups", []))
        for name, level_config in (groups_config.get("levels") or {}).items()
    }

    # Only whether a token can be obtained matters; the token is discarded.
    # httpx.HTTPError covers error statuses AND transport failures (DNS,
    # refused connection, timeout), which a connectivity check must report
    # rather than crash on.
    try:
        if output_json:
            get_access_token(env_config)
        else:
            with console.status("[bold green]Testing authentication..."):
                get_access_token(env_config)
    except httpx.HTTPError as e:
        result = {
            "success": False,
            "configured": True,
            "authenticated": False,
            "error": str(e),
        }
        if output_json:
            print(json.dumps(result))
        else:
            console.print("[green]✓[/green] Configuration present")
            console.print(f"[red]✗[/red] Authentication failed: {e}")
        sys.exit(1)

    if output_json:
        result = {
            "success": True,
            "configured": True,
            "host": env_config["host"],
            "partition": env_config["partition"],
            "domain": domain,
            "groups_count": legacy_group_count,
            "authenticated": True,
            "level_groups_count": level_group_counts,
        }
        print(json.dumps(result))
    else:
        console.print("[green]✓[/green] Configuration valid")
        console.print(f"  Host: {env_config['host']}")
        console.print(f"  Partition: {env_config['partition']}")
        console.print(f"  Domain: {domain}")
        console.print(f"  Groups: {legacy_group_count}")
        for name, count in level_group_counts.items():
            console.print(f"  Level {name}: {count} groups")
        console.print("[green]✓[/green] Authentication successful")


if __name__ == "__main__":
    # Dispatch to the click command group (add / remove / list-groups / check).
    cli()
