"""
OSDU Common Utilities

Shared functionality for OSDU scripts including configuration management,
authentication, and GUID resolution.

This module consolidates duplicate code from osdu_users.py, osdu_manage.py,
and preshipping.py to ensure consistency and reduce maintenance burden.
"""

import json
import os
import platform
import re
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import httpx

# Session-scoped cache of Azure AD lookups made by resolve_guid(): maps the
# looked-up ID to the resolved display name, or to None when no lookup
# produced a name. Emptied by clear_cache().
_ad_cache: dict[str, str | None] = {}

# Passed as shell= to subprocess.run for the az calls; Windows requires
# shell=True for the az.cmd batch file.
_use_shell = platform.system() == "Windows"

# Timeout in seconds for each individual az ad lookup (kept short so that
# failures return quickly).
_az_timeout = 5


# =============================================================================
# Configuration
# =============================================================================

def get_config() -> dict:
    """Read the OSDU connection settings from the process environment.

    Returns:
        dict with keys: host, partition, client_id, client_secret, tenant_id,
        domain. A variable that is not set yields None, except for domain,
        which falls back to "dataservices.energy".
    """
    env = os.environ
    # Settings without a default: config key -> environment variable name.
    settings = {
        key: env.get(variable)
        for key, variable in (
            ("host", "AI_OSDU_HOST"),
            ("partition", "AI_OSDU_DATA_PARTITION"),
            ("client_id", "AI_OSDU_CLIENT"),
            ("client_secret", "AI_OSDU_SECRET"),
            ("tenant_id", "AI_OSDU_TENANT_ID"),
        )
    }
    # The entitlements domain is the only setting with a built-in default.
    settings["domain"] = env.get("AI_OSDU_DOMAIN", "dataservices.energy")
    return settings


def validate_config(config: dict) -> list[str]:
    """Report which required settings are absent from *config*.

    A setting counts as missing when its key is absent or its value is falsy
    (for example None or an empty string). The domain setting is not required.

    Args:
        config: Configuration dict, e.g. from get_config()

    Returns:
        Environment variable names of the missing settings, in the fixed
        order host, partition, client_id, client_secret, tenant_id. Empty
        when nothing is missing.
    """
    required_env = (
        ("host", "AI_OSDU_HOST"),
        ("partition", "AI_OSDU_DATA_PARTITION"),
        ("client_id", "AI_OSDU_CLIENT"),
        ("client_secret", "AI_OSDU_SECRET"),
        ("tenant_id", "AI_OSDU_TENANT_ID"),
    )
    return [env_var for key, env_var in required_env if not config.get(key)]


# =============================================================================
# Authentication
# =============================================================================

def get_access_token(config: dict, timeout: float = 30.0) -> str:
    """Obtain an access token via the client credentials flow (v1 endpoint).

    Posts the client ID and secret as form data to the tenant's
    login.microsoftonline.com ``/oauth2/token`` endpoint, using the client ID
    as the requested resource as well.

    Args:
        config: Configuration dict from get_config(); the tenant_id,
            client_id and client_secret entries are read
        timeout: HTTP request timeout in seconds

    Returns:
        The "access_token" value of the JSON response

    Raises:
        httpx.HTTPStatusError: If the token endpoint answers with an error status
        KeyError: If a required config entry or the "access_token" field is absent
    """
    tenant = config["tenant_id"]
    endpoint = f"https://login.microsoftonline.com/{tenant}/oauth2/token"

    app_id = config["client_id"]
    form = {
        "grant_type": "client_credentials",
        "client_id": app_id,
        "client_secret": config["client_secret"],
        # The application requests a token for itself.
        "resource": app_id,
    }

    with httpx.Client(timeout=timeout) as session:
        reply = session.post(endpoint, data=form)
        reply.raise_for_status()
        payload = reply.json()

    return payload["access_token"]


# =============================================================================
# Role Groups
# =============================================================================

def get_role_groups(partition: str, domain: str) -> dict[str, str]:
    """Build the email address of each role group.

    Every address has the form
    ``users.datalake.<role>@<partition>.<domain>``.

    Args:
        partition: Data partition ID
        domain: Entitlements domain

    Returns:
        Dict mapping role name (Viewer, Editor, Admin, Ops) to group email
    """
    suffix = f"{partition}.{domain}"
    # Role name -> local-part suffix used in the group address.
    role_slugs = (
        ("Viewer", "viewers"),
        ("Editor", "editors"),
        ("Admin", "admins"),
        ("Ops", "ops"),
    )
    return {role: f"users.datalake.{slug}@{suffix}" for role, slug in role_slugs}


def get_root_group(partition: str, domain: str) -> str:
    """Build the root group email address.

    Args:
        partition: Data partition ID
        domain: Entitlements domain

    Returns:
        ``users.data.root@<partition>.<domain>``
    """
    return "users.data.root@{}.{}".format(partition, domain)


# =============================================================================
# GUID Resolution
# =============================================================================

# Canonical 8-4-4-4-12 hexadecimal GUID layout (either letter case).
# Compiled once at import time rather than on every is_guid() call.
_GUID_PATTERN = re.compile(
    r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
)


def is_guid(value: str) -> bool:
    """Check if a string is a GUID.

    The whole string must be five hyphen-separated groups of 8, 4, 4, 4 and
    12 hexadecimal digits. Surrounding whitespace, braces, or a trailing
    newline make the value a non-GUID.

    Args:
        value: String to test

    Returns:
        True if *value* is exactly one GUID, False otherwise
    """
    # fullmatch (instead of match with a "$" anchor) also rejects a trailing
    # newline, which "$" would silently tolerate.
    return _GUID_PATTERN.fullmatch(value) is not None


# Lookup order used by resolve_guid(): (az ad sub-command, prefix put in front
# of the resolved display name).
_AD_LOOKUPS = (
    ("user", ""),
    ("sp", "[SP] "),
    ("app", "[App] "),
)


def _az_ad_display_name(object_type: str, guid: str) -> str | None:
    """Run ``az ad <object_type> show`` for *guid* and return its display name.

    Returns None when the command exits with a non-zero status, prints
    nothing but whitespace, times out, or the az executable cannot be found.
    """
    try:
        result = subprocess.run(
            ["az", "ad", object_type, "show", "--id", guid, "--query", "displayName", "-o", "tsv"],
            capture_output=True,
            text=True,
            timeout=_az_timeout,
            shell=_use_shell,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError):
        # Best effort: a slow or missing az CLI just means "not resolved".
        return None

    if result.returncode != 0:
        return None
    return result.stdout.strip() or None


def resolve_guid(guid: str) -> str | None:
    """Resolve a GUID to a display name using az ad commands.

    Tries in order: user, service principal, application, stopping at the
    first lookup that yields a name. Results are cached for the session.

    Args:
        guid: Azure AD object ID (GUID)

    Returns:
        Display name if found, None otherwise.
        Service principals are prefixed with [SP].
        Applications are prefixed with [App].
    """
    if guid in _ad_cache:
        return _ad_cache[guid]

    display_name = None
    for object_type, prefix in _AD_LOOKUPS:
        name = _az_ad_display_name(object_type, guid)
        if name:
            display_name = f"{prefix}{name}"
            break

    # NOTE: a None result is cached too, so an ID whose lookups all failed
    # (including by timeout) is not retried until clear_cache() is called.
    _ad_cache[guid] = display_name
    return display_name


def resolve_guids_parallel(guids: list[str], max_workers: int = 10) -> dict[str, str | None]:
    """Resolve multiple GUIDs in parallel for faster lookups.

    Args:
        guids: List of GUIDs to resolve
        max_workers: Maximum concurrent lookups (default 10)

    Returns:
        Dict mapping GUID to resolved name (or None if not found)
    """
    results: dict[str, str | None] = {}

    # Filter to only GUIDs we haven't cached
    guids_to_resolve = [g for g in guids if g not in _ad_cache]

    # Return cached results for already-resolved GUIDs
    for guid in guids:
        if guid in _ad_cache:
            results[guid] = _ad_cache[guid]

    if not guids_to_resolve:
        return results

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_guid = {executor.submit(resolve_guid, guid): guid for guid in guids_to_resolve}

        for future in as_completed(future_to_guid):
            guid = future_to_guid[future]
            try:
                results[guid] = future.result()
            except Exception:
                results[guid] = None

    return results


def clear_cache() -> None:
    """Clear the GUID resolution cache.

    Empties the module-level ``_ad_cache`` dict in place, so later
    resolve_guid() calls perform their lookups again instead of returning
    previously cached results.
    """
    _ad_cache.clear()


# =============================================================================
# API Helpers
# =============================================================================

def get_group_members(config: dict, token: str, group_email: str, timeout: float = 30.0) -> list[dict]:
    """Fetch the members of an OSDU entitlement group.

    Issues a GET to the entitlements v2 ``groups/<group_email>/members``
    endpoint on the configured host.

    Args:
        config: Configuration dict from get_config(); host and partition are read
        token: Access token from get_access_token()
        group_email: Full group email address
        timeout: HTTP request timeout

    Returns:
        The list stored under "members" in the JSON response (member dicts,
        expected to carry 'email' and 'role' keys; not checked here). Empty
        when that key is absent or the endpoint answers 404.

    Raises:
        httpx.HTTPStatusError: For any error status other than 404
    """
    endpoint = f"https://{config['host']}/api/entitlements/v2/groups/{group_email}/members"

    request_headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/json",
        "data-partition-id": config["partition"],
    }

    with httpx.Client(timeout=timeout) as session:
        reply = session.get(endpoint, headers=request_headers)
        if reply.status_code != 404:
            reply.raise_for_status()
            return reply.json().get("members", [])

    # 404 is reported as "no members" rather than as an error.
    return []


def add_member_to_group(
    config: dict,
    token: str,
    group_email: str,
    user_email: str,
    member_role: str = "MEMBER",
    timeout: float = 30.0
) -> dict:
    """Add a user to an OSDU entitlement group.

    Posts an ``{"email", "role"}`` JSON body to the entitlements v2
    ``groups/<group_email>/members`` endpoint on the configured host.

    Args:
        config: Configuration dict from get_config(); host and partition are read
        token: Access token from get_access_token()
        group_email: Full group email address
        user_email: User email or GUID to add
        member_role: OWNER or MEMBER
        timeout: HTTP request timeout

    Returns:
        Result dict. On success: ``success`` (True) and ``message``. When the
        endpoint answers 409: ``success`` (False), ``error``
        ("already_exists") and ``message``.

    Raises:
        httpx.HTTPStatusError: For any error status other than 409
    """
    endpoint = f"https://{config['host']}/api/entitlements/v2/groups/{group_email}/members"

    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "data-partition-id": config["partition"],
    }

    payload = {"email": user_email, "role": member_role}

    with httpx.Client(timeout=timeout) as session:
        reply = session.post(endpoint, headers=request_headers, json=payload)
        if reply.status_code != 409:
            reply.raise_for_status()
            return {"success": True, "message": f"Added {user_email} to group as {member_role}"}

    # 409 Conflict: reported as a soft failure instead of an exception.
    return {"success": False, "error": "already_exists", "message": "User already in group"}


def remove_member_from_group(
    config: dict,
    token: str,
    group_email: str,
    user_email: str,
    timeout: float = 30.0
) -> dict:
    """Remove a user from an OSDU entitlement group.

    Issues a DELETE to the entitlements v2
    ``groups/<group_email>/members/<user_email>`` endpoint on the configured
    host.

    Args:
        config: Configuration dict from get_config(); host and partition are read
        token: Access token from get_access_token()
        group_email: Full group email address
        user_email: User email or GUID to remove
        timeout: HTTP request timeout

    Returns:
        Result dict. On success: ``success`` (True) and ``message``. When the
        endpoint answers 404: ``success`` (False), ``error`` ("not_found")
        and ``message``.

    Raises:
        httpx.HTTPStatusError: For any error status other than 404
    """
    endpoint = f"https://{config['host']}/api/entitlements/v2/groups/{group_email}/members/{user_email}"

    request_headers = {
        "Authorization": f"Bearer {token}",
        "data-partition-id": config["partition"],
    }

    with httpx.Client(timeout=timeout) as session:
        reply = session.delete(endpoint, headers=request_headers)
        if reply.status_code != 404:
            reply.raise_for_status()
            return {"success": True, "message": f"Removed {user_email} from group"}

    # 404 Not Found: reported as a soft failure instead of an exception.
    return {"success": False, "error": "not_found", "message": "User not in group"}
