#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "httpx",
#     "click",
#     "rich",
# ]
# ///
"""
Access Audit

Cross-references users between an Azure AD tenant and the OSDU preshipping
environment to generate consolidated access reports.

Usage:
    # Audit all users from a company
    uv run audit.py --company acme
    uv run audit.py --company contoso --json

    # Audit a single user
    uv run audit.py --user john@example.com
    uv run audit.py --user john@example.com --json

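Prerequisites:
    The Azure CLI (`az`) must be installed and authenticated against the
    tenant; all Azure AD / Microsoft Graph lookups go through it.

Environment variables required:
    AI_OSDU_HOST            - OSDU instance hostname
    AI_OSDU_DATA_PARTITION  - Data partition ID
    AI_OSDU_CLIENT          - App registration client ID
    AI_OSDU_SECRET          - App registration secret
    AI_OSDU_TENANT_ID       - Azure AD tenant ID

Optional environment variables:
    AI_OSDU_DOMAIN          - Entitlements group domain (default: dataservices.energy)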
"""

import json
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime

import click
import httpx
from rich.console import Console
from rich.table import Table

# Shared Rich console for status messages, progress spinners and errors.
# It writes to stderr so stdout stays free for the report / --json output.
console = Console(stderr=True)


# =============================================================================
# Azure AD Functions
# =============================================================================

def get_tenant_users(company: str) -> list[dict]:
    """Get all users from the Azure AD tenant matching a company name/domain.

    Lookup runs in two phases:

    1. ``az ad user list`` filtered by JMESPath ``contains(mail, company)``.
       If that fails or matches nobody, the filter is retried against
       ``userPrincipalName`` (users with no ``mail`` value are invisible to
       the first query).
    2. Each match is enriched with status fields (userType, accountEnabled,
       externalUserState) from a per-user Microsoft Graph call via
       ``az rest``. If that call fails for a user, the user is still returned
       with ``type="Unknown"``, ``enabled=True`` and ``inviteState=None``.

    Args:
        company: Company name or email domain to search for (substring match).

    Returns:
        List of user dicts with keys name, mail, id, type, enabled,
        inviteState. Empty when nothing matched or the listing itself failed
        (the failure is reported on the console).
    """
    # JMESPath raw string literals are single-quoted; escape embedded quotes
    # so a value such as "o'neil" cannot terminate the literal early.
    needle = company.replace("'", "\\'")

    def _list_users(query: str) -> list[dict]:
        """Run ``az ad user list`` with *query*; [] on non-zero exit or no output."""
        result = subprocess.run(
            ["az", "ad", "user", "list", "--query", query, "-o", "json"],
            capture_output=True,
            text=True,
            timeout=120,
        )
        if result.returncode != 0 or not result.stdout.strip():
            return []
        return json.loads(result.stdout) or []

    def _enrich(basic: dict) -> dict:
        """Add Graph status fields to a phase-1 record, falling back on failure."""
        user_id = basic.get("id")
        fallback = {
            "name": basic.get("name") or "Unknown",
            "mail": basic.get("mail") or "",
            "id": user_id,
            "type": "Unknown",
            "enabled": True,
            "inviteState": None,
        }
        if not user_id:
            return fallback

        graph_url = f"https://graph.microsoft.com/v1.0/users/{user_id}?$select=displayName,mail,id,userType,accountEnabled,externalUserState"
        try:
            detail = subprocess.run(
                ["az", "rest", "--method", "GET", "--url", graph_url, "-o", "json"],
                capture_output=True,
                text=True,
                timeout=30,
            )
            if detail.returncode != 0 or not detail.stdout.strip():
                return fallback
            u = json.loads(detail.stdout)
        except (subprocess.TimeoutExpired, json.JSONDecodeError):
            # One slow or garbled Graph response must not discard the whole
            # audit; keep this user with the phase-1 data instead.
            return fallback

        # Selected properties may come back as explicit nulls, so apply the
        # fallbacks with `or` rather than .get() defaults.
        enabled = u.get("accountEnabled")
        return {
            "name": u.get("displayName") or fallback["name"],
            "mail": u.get("mail") or fallback["mail"],
            "id": u.get("id") or user_id,
            "type": u.get("userType") or "Unknown",
            "enabled": True if enabled is None else enabled,
            "inviteState": u.get("externalUserState"),
        }

    try:
        basic_users = _list_users(
            f"[?mail && contains(mail, '{needle}')].{{name:displayName, mail:mail, id:id}}"
        )
        if not basic_users:
            # Fallback covers both a failed query and an empty match list.
            basic_users = _list_users(
                f"[?contains(userPrincipalName, '{needle}')].{{name:displayName, mail:userPrincipalName, id:id}}"
            )
        return [_enrich(user) for user in basic_users]
    except (subprocess.TimeoutExpired, json.JSONDecodeError, FileNotFoundError) as e:
        console.print(f"[red]Error fetching tenant users: {e}[/red]")
        return []


def get_single_user(email: str) -> dict | None:
    """Get a single user from Azure AD by email.

    Args:
        email: User's email address

    Returns:
        User dict with id, name, mail, type, enabled, inviteState or None if not found
    """
    try:
        # Use MS Graph API to get full user details including account status
        # Filter by mail property (works for both member and guest users)
        graph_url = f"https://graph.microsoft.com/v1.0/users?$filter=mail eq '{email}'&$select=displayName,mail,id,userType,accountEnabled,externalUserState"
        result = subprocess.run(
            ["az", "rest", "--method", "GET", "--url", graph_url, "-o", "json"],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0 and result.stdout.strip():
            data = json.loads(result.stdout)
            users = data.get("value", [])
            if users:
                u = users[0]
                return {
                    "name": u.get("displayName", "Unknown"),
                    "mail": u.get("mail", email),
                    "id": u.get("id", ""),
                    "type": u.get("userType", "Unknown"),
                    "enabled": u.get("accountEnabled", True),
                    "inviteState": u.get("externalUserState"),
                }

        # Fallback: try by userPrincipalName for member users
        graph_url = f"https://graph.microsoft.com/v1.0/users?$filter=userPrincipalName eq '{email}'&$select=displayName,mail,id,userType,accountEnabled,externalUserState"
        result = subprocess.run(
            ["az", "rest", "--method", "GET", "--url", graph_url, "-o", "json"],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0 and result.stdout.strip():
            data = json.loads(result.stdout)
            users = data.get("value", [])
            if users:
                u = users[0]
                return {
                    "name": u.get("displayName", "Unknown"),
                    "mail": u.get("mail", email),
                    "id": u.get("id", ""),
                    "type": u.get("userType", "Unknown"),
                    "enabled": u.get("accountEnabled", True),
                    "inviteState": u.get("externalUserState"),
                }
        return None
    except (subprocess.TimeoutExpired, json.JSONDecodeError, FileNotFoundError):
        return None


def get_user_groups(user_id: str) -> list[str]:
    """Get Azure AD group memberships for a user.

    Calls Microsoft Graph ``GET /users/{id}/memberOf`` through ``az rest`` and
    follows ``@odata.nextLink`` so memberships beyond the first result page
    are included.

    Args:
        user_id: Azure AD user object ID

    Returns:
        Display names of the returned memberships, in Graph's order. Entries
        without a display name are skipped. An empty list is returned if any
        page cannot be fetched or parsed.
    """
    url: str | None = f"https://graph.microsoft.com/v1.0/users/{user_id}/memberOf"
    names: list[str] = []
    try:
        while url:
            result = subprocess.run(
                ["az", "rest", "--method", "GET", "--url", url, "-o", "json"],
                capture_output=True,
                text=True,
                timeout=30,
            )
            if result.returncode != 0 or not result.stdout.strip():
                return []
            page = json.loads(result.stdout)
            names.extend(
                entry["displayName"]
                for entry in page.get("value", [])
                if entry.get("displayName")
            )
            # Graph pages memberOf results; nextLink is absent on the last page.
            url = page.get("@odata.nextLink")
        return names
    except (subprocess.TimeoutExpired, json.JSONDecodeError, FileNotFoundError):
        return []


def get_user_groups_parallel(user_ids: list[str], max_workers: int = 10) -> dict[str, list[str]]:
    """Fetch AD group memberships for many users concurrently.

    Each user ID is looked up with ``get_user_groups`` on a thread pool.

    Args:
        user_ids: Azure AD user object IDs to look up
        max_workers: Upper bound on concurrent lookups

    Returns:
        Mapping of user ID to its group names; a lookup that raises maps to [].
    """
    groups_by_id: dict[str, list[str]] = {}

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for uid in user_ids:
            pending[pool.submit(get_user_groups, uid)] = uid

        for done in as_completed(pending):
            try:
                memberships = done.result()
            except Exception:
                # Never let one failed lookup abort the whole batch.
                memberships = []
            groups_by_id[pending[done]] = memberships

    return groups_by_id


# =============================================================================
# OSDU Functions
# =============================================================================

def get_osdu_config() -> dict:
    """Read OSDU connection settings from the environment.

    Unset required variables come back as None (see ``validate_osdu_config``);
    ``domain`` defaults to ``"dataservices.energy"`` when AI_OSDU_DOMAIN is unset.
    """
    env = os.environ
    env_names = (
        ("host", "AI_OSDU_HOST"),
        ("partition", "AI_OSDU_DATA_PARTITION"),
        ("client_id", "AI_OSDU_CLIENT"),
        ("client_secret", "AI_OSDU_SECRET"),
        ("tenant_id", "AI_OSDU_TENANT_ID"),
    )
    config = {key: env.get(name) for key, name in env_names}
    config["domain"] = env.get("AI_OSDU_DOMAIN", "dataservices.energy")
    return config


def validate_osdu_config(config: dict) -> list[str]:
    """Return the env-var names of required OSDU settings that are empty or unset.

    The result is ordered host, partition, client, secret, tenant; an empty
    list means the configuration is complete.
    """
    env_var_for = (
        ("host", "AI_OSDU_HOST"),
        ("partition", "AI_OSDU_DATA_PARTITION"),
        ("client_id", "AI_OSDU_CLIENT"),
        ("client_secret", "AI_OSDU_SECRET"),
        ("tenant_id", "AI_OSDU_TENANT_ID"),
    )
    return [env_name for key, env_name in env_var_for if not config.get(key)]


def get_osdu_token(config: dict) -> str:
    """Obtain an OSDU bearer token via the OAuth2 client-credentials grant.

    Posts to the tenant's ``/oauth2/token`` endpoint on login.microsoftonline.com,
    requesting a token whose ``resource`` is the app's own client ID.

    Args:
        config: OSDU configuration; ``tenant_id``, ``client_id`` and
            ``client_secret`` are read.

    Returns:
        The ``access_token`` string from the token response.

    Raises:
        httpx.HTTPError: On transport failure or an error status from the
            token endpoint.
        KeyError: If the response JSON has no ``access_token``.
    """
    tenant = config["tenant_id"]
    app_id = config["client_id"]
    form = {
        "grant_type": "client_credentials",
        "client_id": app_id,
        "client_secret": config["client_secret"],
        "resource": app_id,
    }

    with httpx.Client(timeout=30.0) as http:
        reply = http.post(
            f"https://login.microsoftonline.com/{tenant}/oauth2/token", data=form
        )
        reply.raise_for_status()
        token_payload = reply.json()

    return token_payload["access_token"]


def get_osdu_users(config: dict, token: str) -> dict[str, dict]:
    """Get all OSDU users with their datalake roles and root-group status.

    Reads the members of the four ``users.datalake.*`` role groups and of the
    ``users.data.root`` group through the Entitlements v2 API.

    Args:
        config: OSDU configuration (host, partition, domain are read).
        token: OSDU access token.

    Returns:
        Dict mapping each member's ``email`` value exactly as OSDU returns it
        (for users this may be the Azure AD GUID, possibly with a trailing
        comma) to ``{"email": str, "roles": {role_name: membership},
        "is_root": bool}``. Members of the root group are included even when
        they hold no datalake role. A group that cannot be read is skipped
        with a warning on the console, so its members are missing from the
        result.
    """
    partition = config["partition"]
    domain = config["domain"]
    base_url = f"https://{config['host']}/api/entitlements/v2/groups"

    role_groups = {
        "Viewer": f"users.datalake.viewers@{partition}.{domain}",
        "Editor": f"users.datalake.editors@{partition}.{domain}",
        "Admin": f"users.datalake.admins@{partition}.{domain}",
        "Ops": f"users.datalake.ops@{partition}.{domain}",
    }
    root_group = f"users.data.root@{partition}.{domain}"

    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/json",
        "data-partition-id": partition,
    }

    users: dict[str, dict] = {}

    def _members(client: httpx.Client, group_email: str) -> list[dict]:
        """Return the member entries of *group_email*, or [] with a warning."""
        try:
            response = client.get(f"{base_url}/{group_email}/members", headers=headers)
        except httpx.HTTPError as e:
            console.print(
                f"Warning: could not read OSDU group {group_email}: {e}",
                style="yellow",
                markup=False,
            )
            return []
        if response.status_code != 200:
            console.print(
                f"Warning: could not read OSDU group {group_email} "
                f"(HTTP {response.status_code})",
                style="yellow",
                markup=False,
            )
            return []
        return response.json().get("members", [])

    def _entry(email: str) -> dict:
        """Return (creating if needed) the result record for *email*."""
        return users.setdefault(email, {"email": email, "roles": {}, "is_root": False})

    with httpx.Client(timeout=30.0) as client:
        for role_name, group_email in role_groups.items():
            for member in _members(client, group_email):
                email = member.get("email")
                if not email:
                    continue  # no identity to key on
                _entry(email)["roles"][role_name] = member.get("role", "MEMBER")

        # Root membership counts as OSDU access on its own (consistent with
        # get_single_user_osdu), so root-only members get a record too.
        for member in _members(client, root_group):
            email = member.get("email")
            if email:
                _entry(email)["is_root"] = True

    return users


def get_single_user_osdu(config: dict, token: str, user_id: str) -> dict | None:
    """Get OSDU roles for a single user by their Azure AD GUID.

    Reads the members of the four ``users.datalake.*`` role groups and the
    ``users.data.root`` group and looks for the user in each.

    Args:
        config: OSDU configuration (host, partition, domain are read)
        token: OSDU access token
        user_id: Azure AD user object ID (GUID)

    Returns:
        ``{"roles": {role_name: membership}, "is_root": bool}`` if the user is
        in at least one of those groups, otherwise None. A group that cannot
        be read is skipped with a warning on the console, so read failures can
        make the user appear to have fewer roles (or None) than they do.
    """
    partition = config["partition"]
    domain = config["domain"]
    base_url = f"https://{config['host']}/api/entitlements/v2/groups"

    role_groups = {
        "Viewer": f"users.datalake.viewers@{partition}.{domain}",
        "Editor": f"users.datalake.editors@{partition}.{domain}",
        "Admin": f"users.datalake.admins@{partition}.{domain}",
        "Ops": f"users.datalake.ops@{partition}.{domain}",
    }
    root_group = f"users.data.root@{partition}.{domain}"

    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/json",
        "data-partition-id": partition,
    }

    def _is_user(member: dict) -> bool:
        # Match by GUID, with or without a trailing comma.
        member_email = member.get("email") or ""
        return member_email == user_id or member_email.rstrip(",") == user_id

    def _find_member(client: httpx.Client, group_email: str) -> dict | None:
        """Return the user's member entry in *group_email*, or None."""
        try:
            response = client.get(f"{base_url}/{group_email}/members", headers=headers)
        except httpx.HTTPError as e:
            console.print(
                f"Warning: could not read OSDU group {group_email}: {e}",
                style="yellow",
                markup=False,
            )
            return None
        if response.status_code != 200:
            console.print(
                f"Warning: could not read OSDU group {group_email} "
                f"(HTTP {response.status_code})",
                style="yellow",
                markup=False,
            )
            return None
        members = response.json().get("members", [])
        return next((m for m in members if _is_user(m)), None)

    user_data: dict = {"roles": {}, "is_root": False}

    with httpx.Client(timeout=30.0) as client:
        for role_name, group_email in role_groups.items():
            member = _find_member(client, group_email)
            if member is not None:
                user_data["roles"][role_name] = member.get("role", "MEMBER")
        user_data["is_root"] = _find_member(client, root_group) is not None

    found = bool(user_data["roles"]) or user_data["is_root"]
    return user_data if found else None


# =============================================================================
# Audit Logic
# =============================================================================

def cross_reference_users(
    tenant_users: list[dict],
    osdu_users: dict[str, dict],
    ad_groups: dict[str, list[str]]
) -> list[dict]:
    """Cross-reference tenant users with OSDU users.

    Args:
        tenant_users: Users from Azure AD tenant
        osdu_users: Users from OSDU (keyed by email/GUID, possibly with a
            trailing comma)
        ad_groups: AD group memberships (keyed by user ID)

    Returns:
        One consolidated record per tenant user, in input order, with keys
        id, name, email, type, enabled, invite_state, in_tenant, ad_groups,
        in_osdu, osdu_roles, is_root.
    """
    # Index OSDU entries by GUID with trailing commas stripped, built once
    # instead of scanning every OSDU entry per tenant user. setdefault keeps
    # the first entry in iteration order, matching a first-match scan.
    osdu_by_guid: dict[str, dict] = {}
    for osdu_key, osdu_data in osdu_users.items():
        if not osdu_key:
            continue  # an empty key can never match a tenant GUID
        osdu_by_guid.setdefault(osdu_key.rstrip(","), osdu_data)

    results = []
    for user in tenant_users:
        user_id = user.get("id", "")
        osdu_match = osdu_by_guid.get(user_id) if user_id else None

        results.append({
            "id": user_id,
            "name": user.get("name", "Unknown"),
            "email": user.get("mail", ""),
            "type": user.get("type", "Unknown"),
            "enabled": user.get("enabled", True),
            "invite_state": user.get("inviteState"),
            "in_tenant": True,
            "ad_groups": ad_groups.get(user_id, []),
            "in_osdu": osdu_match is not None,
            "osdu_roles": osdu_match.get("roles", {}) if osdu_match else {},
            "is_root": osdu_match.get("is_root", False) if osdu_match else False,
        })

    return results


def format_roles(roles: dict[str, str]) -> str:
    """Render roles alphabetically as "A, B*", starring OWNER memberships; "-" if none."""
    if not roles:
        return "-"
    labels = (
        f"{name}*" if kind == "OWNER" else name
        for name, kind in sorted(roles.items())
    )
    return ", ".join(labels)


def analyze_results(results: list[dict]) -> dict:
    """Analyze audit results to generate insights.

    Args:
        results: Consolidated records as produced by ``cross_reference_users``
            (keys read: name, in_osdu, osdu_roles, is_root, ad_groups).

    Returns dict with lists of human-readable strings:
        - role_patterns: Analysis of OSDU role distribution
        - ad_group_correlation: How AD groups correlate with OSDU access
          (only groups with 2+ audited members whose OSDU rate is 0% or >= 50%)
        - security_observations: Notable security findings
        - recommendations: Actionable suggestions
    """
    users_with_osdu = [r for r in results if r["in_osdu"]]
    users_without_osdu = [r for r in results if not r["in_osdu"]]

    analysis: dict[str, list[str]] = {
        "role_patterns": [],
        "ad_group_correlation": [],
        "security_observations": [],
        "recommendations": [],
    }

    # A user is "all owner" only if they hold at least one role and every
    # role is OWNER; a user with no datalake roles (e.g. root-only) must not
    # make this vacuously true.
    all_owners = bool(users_with_osdu) and all(
        u["osdu_roles"] and all(m == "OWNER" for m in u["osdu_roles"].values())
        for u in users_with_osdu
    )

    # === Role Pattern Analysis ===
    if users_with_osdu:
        unique_role_sets = {frozenset(u["osdu_roles"].items()) for u in users_with_osdu}

        if len(unique_role_sets) == 1:
            sample_roles = users_with_osdu[0]["osdu_roles"]
            role_str = ", ".join(
                f"{r}({'Owner' if m == 'OWNER' else 'Member'})"
                for r, m in sorted(sample_roles.items())
            ) or "none"
            analysis["role_patterns"].append(
                f"All {len(users_with_osdu)} OSDU users have identical roles: {role_str}"
            )
        else:
            analysis["role_patterns"].append(
                f"{len(unique_role_sets)} different role combinations across {len(users_with_osdu)} users"
            )

        # Note datalake roles nobody holds, in canonical order.
        held_roles = {role for u in users_with_osdu for role in u["osdu_roles"]}
        missing_roles = [r for r in ("Viewer", "Editor", "Admin", "Ops") if r not in held_roles]
        if missing_roles:
            analysis["role_patterns"].append(
                f"No users have these roles: {', '.join(missing_roles)}"
            )

        root_users = [u for u in users_with_osdu if u.get("is_root")]
        if root_users:
            analysis["security_observations"].append(
                f"{len(root_users)} user(s) have Root access: {', '.join(u['name'] for u in root_users)}"
            )

    # === AD Group Correlation ===
    group_stats: dict[str, dict[str, int]] = {}
    for user in results:
        for group in user["ad_groups"]:
            stats = group_stats.setdefault(group, {"total": 0, "with_osdu": 0})
            stats["total"] += 1
            if user["in_osdu"]:
                stats["with_osdu"] += 1

    # Largest groups first; sorted() is stable, so ties keep first-seen order.
    for group, stats in sorted(group_stats.items(), key=lambda item: -item[1]["total"]):
        total, with_osdu = stats["total"], stats["with_osdu"]
        if total < 2:  # Only report groups with 2+ users
            continue
        if with_osdu == total:
            analysis["ad_group_correlation"].append(
                f"{group}: {total} users, all have OSDU access"
            )
        elif with_osdu == 0:
            analysis["ad_group_correlation"].append(
                f"{group}: {total} users, none have OSDU access"
            )
        elif with_osdu * 2 >= total:
            pct = (with_osdu / total) * 100
            analysis["ad_group_correlation"].append(
                f"{group}: {with_osdu}/{total} users ({pct:.0f}%) have OSDU access"
            )

    # === Security Observations ===
    if users_with_osdu:
        if all_owners and len(users_with_osdu) > 1:
            analysis["security_observations"].append(
                "All OSDU users have Owner privileges (can manage group membership)"
            )

        no_group_osdu = [u for u in users_with_osdu if not u["ad_groups"]]
        if no_group_osdu:
            analysis["security_observations"].append(
                f"{len(no_group_osdu)} user(s) have OSDU access but no AD group memberships"
            )

    # === Recommendations ===
    if all_owners and len(users_with_osdu) > 3:
        analysis["recommendations"].append(
            "Consider if all users need Owner privileges, or if some should be Members only"
        )

    # Users without OSDU access who sit in OSDU-looking AD groups may be gaps.
    groups_without_osdu: dict[str, list[str]] = {}
    for user in users_without_osdu:
        for group in user["ad_groups"]:
            lowered = group.lower()
            if "osdu" in lowered or "preship" in lowered:
                groups_without_osdu.setdefault(group, []).append(user["name"])

    for group, names in groups_without_osdu.items():
        analysis["recommendations"].append(
            f"Review if {len(names)} user(s) in '{group}' should have OSDU access"
        )

    return analysis


def format_analysis(analysis: dict) -> list[str]:
    """Render an ``analyze_results`` dict as Markdown lines.

    Always starts with a horizontal rule and an "## Analysis" heading; each
    non-empty category then becomes a "###" section of bullet points.
    """
    sections = (
        ("role_patterns", "### Role Patterns"),
        ("ad_group_correlation", "### AD Group Correlation"),
        ("security_observations", "### Security Observations"),
        ("recommendations", "### Recommendations"),
    )

    out = ["", "---", "", "## Analysis", ""]
    for key, heading in sections:
        items = analysis[key]
        if not items:
            continue
        out.append(heading)
        out.extend(f"- {item}" for item in items)
        out.append("")
    return out


def format_single_user_report(
    user: dict,
    ad_groups: list[str],
    osdu_data: dict | None,
    osdu_config: dict
) -> str:
    """Format a single user audit report.

    Args:
        user: Azure AD user data
        ad_groups: List of AD group names
        osdu_data: OSDU roles and root status, or None
        osdu_config: OSDU configuration

    Returns:
        Formatted markdown report
    """
    timestamp = datetime.now().strftime("%B %d, %Y")

    # Format account status
    enabled = user.get("enabled", True)
    account_status = "Enabled" if enabled else "Disabled"

    # Format invitation state for guests
    invite_state = user.get("inviteState")
    invite_display = invite_state if invite_state else "N/A"

    lines = [
        f"## User Access Audit: {user['name']}",
        "",
        f"**Generated:** {timestamp}",
        f"**Email:** {user.get('mail', 'N/A')}",
        "",
        "### Azure AD Tenant",
        "",
        "| Attribute | Value |",
        "|-----------|-------|",
        f"| Status | In Tenant |",
        f"| User Type | {user.get('type', 'Unknown')} |",
        f"| Account | {account_status} |",
        f"| Invitation | {invite_display} |",
        f"| Object ID | {user.get('id', 'N/A')} |",
        "",
    ]

    # AD Groups section
    lines.append("### AD Group Memberships")
    lines.append("")
    if ad_groups:
        for group in sorted(ad_groups):
            lines.append(f"- {group}")
    else:
        lines.append("- None")
    lines.append("")

    # OSDU section
    lines.append("### OSDU Preshipping Access")
    lines.append("")
    lines.append(f"**Instance:** {osdu_config['host']}")
    lines.append("")

    if osdu_data and osdu_data.get("roles"):
        lines.append("| Role | Membership |")
        lines.append("|------|------------|")
        for role, membership in sorted(osdu_data["roles"].items()):
            lines.append(f"| {role} | {membership} |")
        lines.append("")

        if osdu_data.get("is_root"):
            lines.append("**Root Access:** Yes")
            lines.append("")
    else:
        lines.append("**Status:** No OSDU access")
        lines.append("")

    return "\n".join(lines)


# =============================================================================
# CLI
# =============================================================================

# NOTE(review): hard-coded tenant label reported by every audit; presumably it
# must match the tenant the `az` CLI is logged into — confirm before reuse.
_TENANT_DOMAIN = "azureglobal1.onmicrosoft.com"


def _progress(message: str, enabled: bool):
    """Return a rich spinner context showing *message*, or a no-op context.

    Args:
        message: Rich-markup text displayed next to the spinner.
        enabled: False in JSON mode, where no spinner is drawn.
    """
    from contextlib import nullcontext

    return console.status(message) if enabled else nullcontext()


def _write_file(path: str, text: str) -> None:
    """Write *text* to *path* as UTF-8.

    The markdown report contains non-ASCII status glyphs (✓ ✗ ⏸), which the
    platform default encoding (e.g. cp1252 on Windows) cannot always encode.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)


def _report_failure(output_json: bool, code: str, message: str, rich_text: str) -> None:
    """Report an early-exit condition as a JSON error object or a console message.

    Args:
        output_json: Print a ``{"success": false, ...}`` object on stdout instead
            of console text.
        code: Machine-readable error code for the JSON payload.
        message: Human-readable message for the JSON payload.
        rich_text: Rich-markup message printed to the console in text mode
            (callers must escape any user-supplied parts).
    """
    if output_json:
        print(json.dumps({"success": False, "error": code, "message": message}))
    else:
        console.print(rich_text)


def _emit_json(payload: dict, output: str | None) -> None:
    """Serialize *payload* as indented JSON to the *output* file, or to stdout."""
    from rich.markup import escape

    json_output = json.dumps(payload, indent=2)
    if output:
        _write_file(output, json_output)
        console.print(f"[green]Output written to {escape(output)}[/green]")
    else:
        print(json_output)


def _emit_report(report: str, output: str | None) -> None:
    """Write the markdown *report* to the *output* file, or echo it on the console."""
    from rich.markup import escape

    if output:
        _write_file(output, report)
        console.print(f"\n[green]Report written to {escape(output)}[/green]")
        return
    console.print()
    # The report is plain markdown built from directory data: disable rich markup
    # and highlighting so bracketed text in names/groups is not parsed as style
    # tags, and soft-wrap so table rows are not split at the console width.
    console.print(report, markup=False, highlight=False, soft_wrap=True)


def _md_cell(value) -> str:
    """Escape ``|`` so *value* cannot split a markdown table cell."""
    return str(value).replace("|", "\\|")


def _status_icon(row: dict) -> str:
    """Return the status glyph: ✗ disabled, ⏸ invitation pending, ✓ otherwise."""
    if not row.get("enabled", True):
        return "✗"
    invite = row.get("invite_state")
    if invite and invite.lower() == "pendingacceptance":
        return "⏸"
    return "✓"


def _summarize_groups(groups: list[str] | None) -> str:
    """Show the first two AD groups plus a ``(+N)`` overflow count, or ``None``."""
    if not groups:
        return "None"
    summary = ", ".join(groups[:2])
    if len(groups) > 2:
        summary += f" (+{len(groups) - 2})"
    return summary


def _audit_single_user(
    user: str,
    osdu_config: dict,
    output_json: bool,
    output: str | None,
    no_groups: bool,
) -> None:
    """Audit one user by email: Azure AD record, AD groups, and OSDU roles.

    Exits with status 0 if the user is not in the tenant, and status 1 if an
    OSDU HTTP call fails.
    """
    from rich.markup import escape

    show_progress = not output_json

    # Step 1: Get user from Azure AD
    if show_progress:
        console.print(f"\n[bold blue]User Access Audit: {escape(user)}[/bold blue]\n")
    with _progress("[bold green]Looking up user in Azure AD...", show_progress):
        user_data = get_single_user(user)

    if not user_data:
        _report_failure(
            output_json,
            "user_not_found",
            f"User '{user}' not found in tenant",
            f"[yellow]User '{escape(user)}' not found in tenant[/yellow]",
        )
        sys.exit(0)

    if show_progress:
        display_name = str(user_data.get("name", "Unknown"))
        console.print(f"Found user: [cyan]{escape(display_name)}[/cyan]")

    user_id = user_data.get("id")

    # Step 2: Get AD group memberships
    user_ad_groups: list[str] = []
    if not no_groups and user_id:
        with _progress("[bold green]Fetching AD group memberships...", show_progress):
            user_ad_groups = get_user_groups(user_id)

    # Step 3: Get OSDU access. Without an object ID there is nothing to look up,
    # so the user is reported as having no OSDU access instead of crashing.
    osdu_data = None
    if user_id:
        try:
            with _progress("[bold green]Checking OSDU access...", show_progress):
                token = get_osdu_token(osdu_config)
                osdu_data = get_single_user_osdu(osdu_config, token, user_id)
        except httpx.HTTPError as e:
            _report_failure(
                output_json,
                "osdu_auth",
                str(e),
                f"[red]OSDU authentication failed: {escape(str(e))}[/red]",
            )
            sys.exit(1)

    # Generate output
    timestamp = datetime.now().strftime("%B %d, %Y")

    if output_json:
        _emit_json(
            {
                "success": True,
                "generated": timestamp,
                "user": {
                    "name": user_data.get("name", "Unknown"),
                    "email": user_data.get("mail", user),
                    "type": user_data.get("type", "Unknown"),
                    "id": user_data.get("id", ""),
                    "enabled": user_data.get("enabled", True),
                    "invite_state": user_data.get("inviteState"),
                },
                "tenant": _TENANT_DOMAIN,
                "osdu_instance": osdu_config["host"],
                "ad_groups": user_ad_groups,
                "osdu": osdu_data if osdu_data else {"roles": {}, "is_root": False},
                "has_osdu_access": osdu_data is not None and bool(osdu_data.get("roles")),
            },
            output,
        )
    else:
        report = format_single_user_report(user_data, user_ad_groups, osdu_data, osdu_config)
        _emit_report(report, output)


def _audit_company(
    company: str,
    osdu_config: dict,
    output_json: bool,
    output: str | None,
    no_groups: bool,
) -> None:
    """Audit every tenant user matching *company* against the OSDU user list.

    Exits with status 0 if no tenant users match, and status 1 if an OSDU HTTP
    call fails.
    """
    from rich.markup import escape

    show_progress = not output_json
    label = company.upper()

    # Step 1: Get tenant users
    if show_progress:
        console.print(f"\n[bold blue]Company Access Audit: {escape(label)}[/bold blue]\n")
    with _progress("[bold green]Fetching Azure AD tenant users...", show_progress):
        tenant_users = get_tenant_users(company)

    if not tenant_users:
        _report_failure(
            output_json,
            "no_users",
            f"No users found for '{company}'",
            f"[yellow]No users found matching '{escape(company)}'[/yellow]",
        )
        sys.exit(0)

    if show_progress:
        console.print(f"Found [cyan]{len(tenant_users)}[/cyan] users in tenant")

    # Step 2: Get AD group memberships (parallel)
    ad_groups: dict[str, list[str]] = {}
    if not no_groups:
        user_ids = [u["id"] for u in tenant_users if u.get("id")]
        with _progress(
            f"[bold green]Fetching AD group memberships ({len(user_ids)} users)...",
            show_progress,
        ):
            ad_groups = get_user_groups_parallel(user_ids)

    # Step 3: Get OSDU users
    try:
        with _progress("[bold green]Fetching OSDU users...", show_progress):
            token = get_osdu_token(osdu_config)
            osdu_users = get_osdu_users(osdu_config, token)
    except httpx.HTTPError as e:
        _report_failure(
            output_json,
            "osdu_auth",
            str(e),
            f"[red]OSDU authentication failed: {escape(str(e))}[/red]",
        )
        sys.exit(1)

    if show_progress:
        console.print(f"Found [cyan]{len(osdu_users)}[/cyan] users in OSDU")

    # Step 4: Cross-reference, then summarize
    results = cross_reference_users(tenant_users, osdu_users, ad_groups)
    users_with_osdu = sum(1 for r in results if r["in_osdu"])
    users_without_osdu = len(results) - users_with_osdu
    analysis = analyze_results(results)
    timestamp = datetime.now().strftime("%B %d, %Y")

    if output_json:
        _emit_json(
            {
                "success": True,
                "generated": timestamp,
                "company": label,
                "tenant": _TENANT_DOMAIN,
                "osdu_instance": osdu_config["host"],
                "summary": {
                    "total_users": len(results),
                    "with_osdu_access": users_with_osdu,
                    "without_osdu_access": users_without_osdu,
                },
                "analysis": analysis,
                "users": results,
            },
            output,
        )
        return

    # Build markdown report
    lines = [
        f"## {label} User Access Audit",
        "",
        f"**Generated:** {timestamp}",
        f"**Tenant:** {_TENANT_DOMAIN}",
        f"**OSDU Instance:** {osdu_config['host']}",
        "",
        "| Metric | Count |",
        "|--------|-------|",
        f"| Users in Tenant | {len(results)} |",
        f"| Users with OSDU Access | {users_with_osdu} |",
        f"| Users without OSDU Access | {users_without_osdu} |",
        "",
        "| Name | Email | Status | AD Groups | OSDU | Roles |",
        "|------|-------|--------|-----------|------|-------|",
    ]

    # Case-insensitive ordering; a missing/None display name sorts first
    # instead of raising TypeError.
    for row in sorted(results, key=lambda r: (r.get("name") or "").casefold()):
        name = _md_cell((row.get("name") or "")[:30])
        email = _md_cell(row["email"])
        groups = _md_cell(_summarize_groups(row.get("ad_groups")))
        osdu_check = "✓" if row["in_osdu"] else "✗"
        roles = format_roles(row["osdu_roles"])
        lines.append(
            f"| {name} | {email} | {_status_icon(row)} | {groups} | {osdu_check} | {roles} |"
        )

    lines.extend([
        "",
        "**Legend:** * = Owner | Status: ✓ Enabled, ✗ Disabled, ⏸ Invited",
    ])

    # Add analysis section
    lines.extend(format_analysis(analysis))

    _emit_report("\n".join(lines), output)


@click.command()
@click.option("--company", "-c", help="Company name or email domain to search")
@click.option("--user", "-u", help="Single user email to audit")
@click.option("--json", "output_json", is_flag=True, help="Output as JSON")
@click.option("--output", "-o", type=click.Path(), help="Write output to file")
@click.option("--no-groups", is_flag=True, help="Skip AD group lookups (faster)")
def main(company: str, user: str, output_json: bool, output: str, no_groups: bool):
    """Generate access audit across Azure AD and OSDU.

    Either --company or --user must be specified.
    """
    # (The docstring above doubles as click's --help text; keep it user-facing.)

    # Exactly one of --company / --user is required.
    if not company and not user:
        console.print("[red]Error: Either --company or --user must be specified[/red]")
        sys.exit(1)

    if company and user:
        console.print("[red]Error: Cannot specify both --company and --user[/red]")
        sys.exit(1)

    # Validate OSDU config before doing any network work.
    osdu_config = get_osdu_config()
    missing = validate_osdu_config(osdu_config)
    if missing:
        console.print(f"[red]Missing environment variables: {', '.join(missing)}[/red]")
        sys.exit(1)

    if user:
        _audit_single_user(user, osdu_config, output_json, output, no_groups)
        return

    _audit_company(company, osdu_config, output_json, output, no_groups)


if __name__ == "__main__":
    # Script entry point: invoke the click command, which parses sys.argv and
    # handles usage errors and the process exit status itself.
    main()
