#!/usr/bin/env python3
# ABOUTME: JSON-driven interactive party combat for D&D encounters
# ABOUTME: Players control characters, AI controls monsters, tracks state in-memory

import json
import argparse
import sys
import subprocess
import random
from pathlib import Path

# Directory containing this script; helper scripts such as roll_dice.py are
# resolved relative to it.
SCRIPTS_DIR = Path(__file__).parent
# Encounter state lives in the directory the command is run from, so successive
# CLI invocations (start / attack / cast / show) share it.
STATE_FILE = Path.cwd().joinpath("encounter-state.json")


def roll_dice(dice_expr):
    """Roll *dice_expr* by running the sibling roll_dice.py script.

    The script is executed with the current interpreter. Its stdout is expected
    to end with "= <total>" (for example "1d20: [15] = 15"); that trailing
    integer is returned.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
        ValueError: if the text after the last "=" is not an integer.
    """
    script_path = SCRIPTS_DIR / "roll_dice.py"
    completed = subprocess.run(
        [sys.executable, str(script_path), dice_expr],
        check=True,
        text=True,
        capture_output=True,
    )
    # Everything after the final "=" is the total; with no "=", the whole line is used.
    _, _, total_text = completed.stdout.strip().rpartition("=")
    return int(total_text.strip())


def calculate_modifier(score):
    """Return the ability modifier for *score*, i.e. floor((score - 10) / 2)."""
    quotient, _remainder = divmod(score - 10, 2)
    return quotient


def load_json_file(filepath):
    """Parse and return the JSON document stored at *filepath*.

    The file is decoded as UTF-8, the encoding JSON requires, rather than the
    platform's locale default. Non-ASCII character names therefore load the
    same way on every OS.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
        json.JSONDecodeError: if the contents are not valid JSON.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_state(state):
    """Write *state* to STATE_FILE as 2-space-indented JSON, replacing any prior save."""
    with STATE_FILE.open('w') as handle:
        json.dump(state, handle, indent=2)


def load_state():
    """Return the saved encounter state from STATE_FILE.

    If no state file exists, prints an error telling the user to run
    'encounter.py start' and exits with status 1.
    """
    if STATE_FILE.exists():
        with STATE_FILE.open('r') as handle:
            return json.load(handle)

    print("Error: No active encounter found. Use 'encounter.py start' first.")
    sys.exit(1)


def clear_state():
    """Delete STATE_FILE if it exists; doing nothing when it is absent."""
    if not STATE_FILE.exists():
        return
    STATE_FILE.unlink()


def roll_initiative(combatants):
    """Roll initiative (d20 + DEX modifier) for every combatant.

    Each combatant dict gains 'initiative' (the total) and 'initiative_roll'
    (the raw d20). The list is sorted in place so that the highest initiative
    comes first, with the higher raw DEX score breaking ties. The same list is
    returned for convenience.
    """
    for entry in combatants:
        natural = roll_dice("1d20")
        dex_bonus = calculate_modifier(entry['abilities']['dex'])
        entry['initiative'] = natural + dex_bonus
        entry['initiative_roll'] = natural

    def turn_order(entry):
        # Primary key: initiative total; tiebreaker: raw DEX score.
        return entry['initiative'], entry['abilities']['dex']

    combatants.sort(key=turn_order, reverse=True)
    return combatants


# Ability scores assumed for any ability a combatant's JSON does not specify.
_DEFAULT_ABILITIES = {'str': 10, 'dex': 10, 'con': 10, 'int': 10, 'wis': 10, 'cha': 10}


def _prepare_combatants(entries, kind):
    """Tag each entry as *kind* ('hero'/'foe') and fill in missing abilities and current HP."""
    for entry in entries:
        entry['type'] = kind
        entry['is_defeated'] = False
        # Merge over the defaults so a partial 'abilities' block (e.g. no 'dex')
        # does not raise KeyError later during initiative or attack rolls.
        entry['abilities'] = {**_DEFAULT_ABILITIES, **(entry.get('abilities') or {})}
        # A combatant listed with only a maximum starts the fight at full HP.
        if 'hp_current' not in entry and 'hp_max' in entry:
            entry['hp_current'] = entry['hp_max']


def _make_names_unique(combatants):
    """Append " 1", " 2", ... to repeated names so every combatant can be addressed.

    Turn, attack and damage lookups all match combatants by name and stop at the
    first match. Without this, a second foe called "Goblin" could never be
    targeted, and damage meant for it would land on the first one.
    """
    totals = {}
    for c in combatants:
        totals[c['name']] = totals.get(c['name'], 0) + 1

    taken = set(totals)  # every original name, so generated names never collide with one
    next_suffix = {}
    for c in combatants:
        base = c['name']
        if totals[base] < 2:
            continue
        suffix = next_suffix.get(base, 1)
        while f"{base} {suffix}" in taken:
            suffix += 1
        unique = f"{base} {suffix}"
        next_suffix[base] = suffix + 1
        taken.add(unique)
        c['name'] = unique
        print(f"Note: duplicate name '{base}' renamed to '{unique}'")


def start_encounter(args):
    """Start a new encounter from the heroes/foes JSON files named in *args*.

    Reads the top-level "heroes" list from args.heroes_file and the "foes" list
    from args.foes_file, exiting with status 1 if either is empty. It then:
    - normalises each combatant;
    - gives duplicate names numeric suffixes;
    - rolls initiative and saves the new state;
    - prints the initiative order and starts the first turn.
    """
    heroes_data = load_json_file(args.heroes_file)
    foes_data = load_json_file(args.foes_file)

    heroes = heroes_data.get('heroes', [])
    foes = foes_data.get('foes', [])

    if not heroes:
        print("Error: No heroes found in heroes file")
        sys.exit(1)

    if not foes:
        print("Error: No foes found in foes file")
        sys.exit(1)

    _prepare_combatants(heroes, 'hero')
    _prepare_combatants(foes, 'foe')

    all_combatants = heroes + foes
    _make_names_unique(all_combatants)
    all_combatants = roll_initiative(all_combatants)

    state = {
        'round': 1,
        'turn_index': 0,
        'combatants': all_combatants,
        'combat_log': []
    }

    save_state(state)

    print("\n" + "="*60)
    print("ENCOUNTER STARTED!")
    print("="*60)

    print("\nInitiative Order:")
    for i, c in enumerate(all_combatants, 1):
        type_icon = "👤" if c['type'] == 'hero' else "👹"
        dex_mod = calculate_modifier(c['abilities']['dex'])
        # ':+' always shows the sign: "+2", "+0", "-1".
        print(f"  {i:2}. [{c['initiative']:2}] {type_icon} {c['name']:20} (d20: {c['initiative_roll']}, DEX: {dex_mod:+})")

    print("\n" + "="*60)
    # If a foe won initiative, this runs its turn immediately.
    show_current_turn(state)


def show_current_turn(state):
    """Announce the active combatant's turn.

    For a hero, prints the CLI commands the player may run next and every
    living foe they can target. For a foe, hands control straight to
    run_monster_turn.
    """
    active = state['combatants'][state['turn_index']]
    is_hero = active['type'] == 'hero'

    print(f"\nRound {state['round']}")
    print(f"Turn: {active['name']} ({'Hero' if is_hero else 'Foe'})")
    print(f"HP: {active['hp_current']}/{active['hp_max']}")

    if not is_hero:
        # Foes are AI-driven and act without waiting for input.
        run_monster_turn(state, active)
        return

    name = active['name']
    print("\nAvailable actions:")
    print(f"  encounter.py attack {name} TARGET")
    if active.get('class') in ('wizard', 'sorcerer', 'cleric', 'druid', 'warlock'):
        print(f"  encounter.py cast {name} SPELL TARGET")
    print("  encounter.py show")

    living_foes = [
        c for c in state['combatants']
        if c['type'] == 'foe' and not c['is_defeated']
    ]
    if living_foes:
        print("\nTargets:")
        for foe in living_foes:
            print(f"  • {foe['name']} (HP: {foe['hp_current']}/{foe['hp_max']}, AC: {foe['ac']})")


def run_monster_turn(state, monster):
    """Resolve one AI-controlled foe turn, then advance to the next combatant.

    The foe attacks the living hero with the lowest current HP, using its
    'attack' entry (keys 'name', 'bonus', 'damage_dice'). Any of those keys
    missing from the monster's data fall back to a +0 bonus and 1d4 damage.
    If no hero is left standing, the encounter ends in defeat instead.
    """
    heroes_alive = [c for c in state['combatants'] if c['type'] == 'hero' and not c['is_defeated']]

    if not heroes_alive:
        end_encounter(state, 'defeat')
        return

    # Simple AI: focus the most wounded hero; min() keeps the first one in
    # initiative order on ties.
    target = min(heroes_alive, key=lambda h: h['hp_current'])

    # Merge over defaults so a partial 'attack' block (e.g. no 'bonus') still works.
    attack = {'name': 'Attack', 'bonus': 0, 'damage_dice': '1d4', **(monster.get('attack') or {})}
    bonus = attack['bonus']
    attack_roll = roll_dice("1d20") + bonus

    print(f"\n{monster['name']} attacks {target['name']}!")
    # ':+' prints the sign explicitly: "1d20+3", "1d20+0", "1d20-1" (not "1d20+-1").
    print(f"  Attack roll: 1d20{bonus:+} = {attack_roll}")

    hit = attack_roll >= target['ac']
    damage = 0
    if hit:
        damage = roll_dice(attack['damage_dice'])
        print(f"  HIT! Damage: {attack['damage_dice']} = {damage}")
        apply_damage(state, target['name'], damage)
    else:
        print(f"  MISS! (AC {target['ac']})")

    state['combat_log'].append({
        'round': state['round'],
        'actor': monster['name'],
        'action': 'attack',
        'target': target['name'],
        'hit': hit,
        'damage': damage
    })

    advance_turn(state)


def apply_damage(state, target_name, damage):
    """Subtract *damage* from the first combatant named *target_name*.

    HP is floored at 0. A combatant that reaches 0 HP is marked defeated,
    and the defeat is announced only once. Negative damage (e.g. a low roll
    plus a negative ability modifier) is treated as 0, so an attack can never
    heal. If no combatant has that name, nothing happens.
    """
    damage = max(0, damage)
    for combatant in state['combatants']:
        if combatant['name'] != target_name:
            continue

        old_hp = combatant['hp_current']
        combatant['hp_current'] = max(0, old_hp - damage)
        print(f"  {target_name}: {old_hp} → {combatant['hp_current']} HP")

        if combatant['hp_current'] <= 0 and not combatant['is_defeated']:
            combatant['is_defeated'] = True
            print(f"  💀 {target_name} has been DEFEATED!")
        break


def advance_turn(state):
    """Hand the turn to the next living combatant, or end the encounter.

    Victory/defeat is checked before moving the turn pointer. An encounter
    won on the last turn of a round is therefore reported with that round's
    number, not the next one. Otherwise the pointer skips defeated combatants,
    wrapping to a new round past the end of the order. The state is then saved
    and the next turn is shown.
    """
    combatants = state['combatants']

    heroes_alive = any(c['type'] == 'hero' and not c['is_defeated'] for c in combatants)
    foes_alive = any(c['type'] == 'foe' and not c['is_defeated'] for c in combatants)

    if not foes_alive:
        end_encounter(state, 'victory')
        return
    if not heroes_alive:
        end_encounter(state, 'defeat')
        return

    def first_living(start):
        # Index of the first non-defeated combatant at or after *start*, else None.
        for idx in range(start, len(combatants)):
            if not combatants[idx]['is_defeated']:
                return idx
        return None

    next_index = first_living(state['turn_index'] + 1)
    if next_index is None:
        # Everyone after the current combatant is down: wrap to a new round.
        # At least one combatant is alive (checked above), so this finds one.
        state['round'] += 1
        next_index = first_living(0)
    state['turn_index'] = next_index

    save_state(state)
    show_current_turn(state)


def character_attack(args):
    """Resolve a hero's weapon attack against a foe, then advance the turn.

    Exits with status 1, without using the turn, if args.character is not a
    hero, it is not that hero's turn, or args.target is not a living foe.

    The attack roll is d20 + STR modifier + proficiency bonus. Proficiency is
    2 at level 1 and rises by 1 every four levels. On a hit the damage is
    1d8 + STR for martial classes and 1d6 + STR otherwise, never below 0.
    args.weapon only changes the flavour text.
    """
    state = load_state()
    combatants = state['combatants']

    attacker = next(
        (c for c in combatants if c['name'] == args.character and c['type'] == 'hero'),
        None,
    )
    if attacker is None:
        print(f"Error: Hero '{args.character}' not found")
        sys.exit(1)

    current = combatants[state['turn_index']]
    if current['name'] != args.character:
        print(f"Error: Not {args.character}'s turn (current: {current['name']})")
        sys.exit(1)

    target = next(
        (c for c in combatants if c['name'] == args.target and c['type'] == 'foe'),
        None,
    )
    if target is None:
        print(f"Error: Target '{args.target}' not found")
        sys.exit(1)

    if target['is_defeated']:
        print(f"Error: {args.target} is already defeated")
        sys.exit(1)

    str_mod = calculate_modifier(attacker['abilities']['str'])
    prof_bonus = 2 + ((attacker.get('level', 1) - 1) // 4)
    attack_bonus = str_mod + prof_bonus

    attack_roll = roll_dice("1d20") + attack_bonus

    weapon = args.weapon or "weapon"
    print(f"\n{attacker['name']} attacks {target['name']} with {weapon}!")
    # ':+' prints the sign explicitly, so a negative bonus reads "1d20-1", not "1d20+-1".
    print(f"  Attack roll: 1d20{attack_bonus:+} = {attack_roll}")

    hit = attack_roll >= target['ac']
    total_damage = 0
    if hit:
        # Default weapon damage: 1d8 + STR for martial classes, 1d6 + STR otherwise.
        damage_dice = "1d8" if attacker.get('class') in ['fighter', 'paladin', 'ranger', 'barbarian'] else "1d6"
        base_damage = roll_dice(damage_dice)
        # A negative STR modifier must not push damage below 0; otherwise the
        # hit would heal the target.
        total_damage = max(0, base_damage + str_mod)

        print(f"  HIT! Damage: {damage_dice}{str_mod:+} = {total_damage}")

        apply_damage(state, target['name'], total_damage)
    else:
        print(f"  MISS! (AC {target['ac']})")

    state['combat_log'].append({
        'round': state['round'],
        'actor': attacker['name'],
        'action': 'attack',
        'target': target['name'],
        'hit': hit,
        'damage': total_damage
    })

    advance_turn(state)


def character_cast(args):
    """Resolve a hero's spell against any combatant, then advance the turn.

    Only Fire Bolt is implemented, matched case-insensitively as "fire bolt"
    or "firebolt". It is a spell attack of d20 + INT modifier + proficiency
    bonus against the target's AC, dealing 1d10 on a hit. The target may be
    any combatant, hero or foe.

    Exits with status 1, without using the turn, if the caster is not a hero,
    it is not the caster's turn, the target is missing or already defeated,
    or the spell is not implemented.
    """
    state = load_state()
    combatants = state['combatants']

    caster = next(
        (c for c in combatants if c['name'] == args.character and c['type'] == 'hero'),
        None,
    )
    if caster is None:
        print(f"Error: Hero '{args.character}' not found")
        sys.exit(1)

    current = combatants[state['turn_index']]
    if current['name'] != args.character:
        print(f"Error: Not {args.character}'s turn (current: {current['name']})")
        sys.exit(1)

    target = next((c for c in combatants if c['name'] == args.target), None)
    if target is None:
        print(f"Error: Target '{args.target}' not found")
        sys.exit(1)

    # Matches character_attack: don't let a turn be spent on a downed combatant.
    if target['is_defeated']:
        print(f"Error: {args.target} is already defeated")
        sys.exit(1)

    spell_lower = args.spell.lower()
    if 'fire bolt' not in spell_lower and 'firebolt' not in spell_lower:
        print(f"Error: Spell '{args.spell}' not implemented (try 'Fire Bolt')")
        sys.exit(1)

    # Fire Bolt: spell attack, 1d10 fire damage.
    int_mod = calculate_modifier(caster['abilities']['int'])
    prof_bonus = 2 + ((caster.get('level', 1) - 1) // 4)
    spell_attack = int_mod + prof_bonus

    attack_roll = roll_dice("1d20") + spell_attack

    print(f"\n{caster['name']} casts Fire Bolt at {target['name']}!")
    # ':+' prints the sign explicitly, so a negative bonus reads "1d20-1", not "1d20+-1".
    print(f"  Spell attack roll: 1d20{spell_attack:+} = {attack_roll}")

    hit = attack_roll >= target['ac']
    damage = 0
    if hit:
        damage = roll_dice("1d10")
        print(f"  HIT! Fire damage: 1d10 = {damage}")
        apply_damage(state, target['name'], damage)
    else:
        print(f"  MISS! (AC {target['ac']})")

    state['combat_log'].append({
        'round': state['round'],
        'actor': caster['name'],
        'action': 'cast',
        'spell': args.spell,
        'target': target['name'],
        'hit': hit,
        'damage': damage
    })

    advance_turn(state)


def show_state(args):
    """Print every hero's and foe's status, plus whose turn it is.

    Living heroes show HP, living foes show HP and AC, and defeated combatants
    are marked DEFEATED. *args* is accepted for CLI dispatch and is unused.
    """
    state = load_state()
    divider = "=" * 60

    print("\n" + divider)
    print(f"ENCOUNTER STATE - Round {state['round']}")
    print(divider)

    for heading, side in (("\nHeroes:", 'hero'), ("\nFoes:", 'foe')):
        print(heading)
        for member in state['combatants']:
            if member['type'] != side:
                continue
            if member['is_defeated']:
                status = "💀 DEFEATED"
            elif side == 'hero':
                status = f"HP: {member['hp_current']}/{member['hp_max']}"
            else:
                status = f"HP: {member['hp_current']}/{member['hp_max']}, AC: {member['ac']}"
            print(f"  • {member['name']:20} {status}")

    print("\n" + divider)

    active = state['combatants'][state['turn_index']]
    side_label = 'Hero' if active['type'] == 'hero' else 'Foe'
    print(f"Current turn: {active['name']} ({side_label})")
    print(divider + "\n")


# XP awarded per defeated monster by challenge rating (D&D 5e table, CR 0-30).
_CR_TO_XP = {
    0: 10, 0.125: 25, 0.25: 50, 0.5: 100,
    1: 200, 2: 450, 3: 700, 4: 1100, 5: 1800,
    6: 2300, 7: 2900, 8: 3900, 9: 5000, 10: 5900,
    11: 7200, 12: 8400, 13: 10000, 14: 11500, 15: 13000,
    16: 15000, 17: 18000, 18: 20000, 19: 22000, 20: 25000,
    21: 33000, 22: 41000, 23: 50000, 24: 62000, 25: 75000,
    26: 90000, 27: 105000, 28: 120000, 29: 135000, 30: 155000,
}


def _xp_for_cr(cr):
    """Return the XP for challenge rating *cr* (a number or a string like "1/4"); 0 if unknown."""
    from fractions import Fraction

    if isinstance(cr, str):
        try:
            # Fraction parses "1/4", "0.5" and "3"; equal numbers hash alike, so
            # Fraction(1, 4) finds the 0.25 key.
            cr = Fraction(cr.strip())
        except (ValueError, ZeroDivisionError):
            return 0
    try:
        return _CR_TO_XP.get(cr, 0)
    except TypeError:  # unhashable CR value
        return 0


def end_encounter(state, outcome):
    """Summarise a finished encounter, write the results file, and exit.

    Args:
        state: the encounter state dict ('round', 'combatants', 'combat_log').
        outcome: 'victory' or 'defeat'; printed upper-cased and stored in the results.

    Steps:
    - Totals the XP of every defeated foe from its 'cr'.
    - Splits that total evenly across all heroes (integer division); only
      heroes still standing are credited their share.
    - Writes encounter-results.json to the current directory.
    - Deletes the saved state and exits with status 0.
    """
    print("\n" + "="*60)
    print(f"ENCOUNTER ENDED: {outcome.upper()}!")
    print("="*60)

    heroes = [c for c in state['combatants'] if c['type'] == 'hero']
    foes = [c for c in state['combatants'] if c['type'] == 'foe']

    total_xp = sum(_xp_for_cr(foe.get('cr', 0)) for foe in foes if foe['is_defeated'])

    heroes_alive = [h for h in heroes if not h['is_defeated']]
    # NOTE(review): the share is computed over *all* heroes, but defeated heroes
    # are credited 0, so part of total_xp is never awarded. Confirm this is intended.
    xp_per_hero = total_xp // len(heroes) if heroes else 0

    results = {
        'outcome': outcome,
        'rounds': state['round'],
        'heroes': [
            {
                'name': h['name'],
                'hp_final': h['hp_current'],
                'hp_max': h['hp_max'],
                'is_defeated': h['is_defeated'],
                'xp_earned': xp_per_hero if not h['is_defeated'] else 0
            }
            for h in heroes
        ],
        'foes': [
            {
                'name': f['name'],
                'hp_final': f['hp_current'],
                'is_defeated': f['is_defeated']
            }
            for f in foes
        ],
        'total_xp': total_xp,
        'xp_per_hero': xp_per_hero,
        'combat_log': state['combat_log']
    }

    results_file = Path.cwd() / "encounter-results.json"
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print("\n📊 Combat Summary:")
    print(f"  Rounds: {state['round']}")
    print(f"  Total XP: {total_xp}")
    print(f"  XP per hero: {xp_per_hero}")

    print(f"\n📄 Results saved to: {results_file}")

    if outcome == 'victory':
        print("\n💡 Update characters with:")
        for hero in heroes_alive:
            print(f"  progression.py award {hero['name']} {xp_per_hero}")

    clear_state()
    sys.exit(0)


def main():
    """Parse the command line and dispatch to the selected encounter sub-command.

    Prints help and exits with status 1 when no sub-command is given.
    """
    parser = argparse.ArgumentParser(description='D&D Party Encounter (Interactive)')
    subparsers = parser.add_subparsers(dest='command', help='Commands')

    start_cmd = subparsers.add_parser('start', help='Start new encounter from JSON files')
    start_cmd.add_argument('heroes_file', help='Heroes JSON file (e.g., heroes.json)')
    start_cmd.add_argument('foes_file', help='Foes JSON file (e.g., foes.json)')

    attack_cmd = subparsers.add_parser('attack', help='Character attacks target')
    attack_cmd.add_argument('character', help='Character name')
    attack_cmd.add_argument('target', help='Target name')
    attack_cmd.add_argument('weapon', nargs='?', help='Weapon name (optional)')

    cast_cmd = subparsers.add_parser('cast', help='Character casts spell')
    cast_cmd.add_argument('character', help='Character name')
    cast_cmd.add_argument('spell', help='Spell name')
    cast_cmd.add_argument('target', help='Target name')

    # 'show' takes no arguments of its own.
    subparsers.add_parser('show', help='Show current state')

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    handlers = {
        'start': start_encounter,
        'attack': character_attack,
        'cast': character_cast,
        'show': show_state,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)


if __name__ == '__main__':
    main()
