"""
Nixtla Forecasting Experiment Harness
Generated by nixtla-experiment-architect skill

This script:
1. Loads data into Nixtla schema (unique_id, ds, y)
2. Runs StatsForecast baseline models
3. Runs MLForecast with lag features (if enabled)
4. Calls TimeGPT (if API key configured)
5. Performs cross-validation
6. Computes evaluation metrics (SMAPE, MASE, RMSE)
7. Saves results to the directory set by output.results_dir in forecasting/config.yml
"""

import os
from pathlib import Path

import pandas as pd
import yaml

# Load configuration. The file is decoded explicitly as UTF-8 so the result
# does not depend on the platform's default text encoding.
with open("forecasting/config.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# ========== DATA LOADING ==========


def load_data(config):
    """
    Load data from source and transform to Nixtla schema.

    The source type is detected from ``config["data"]["source"]``:

    * a path ending in ``.csv`` or ``.parquet`` (extension matched
      case-insensitively);
    * a SQL query whose first word is ``SELECT`` or ``WITH`` (any case,
      leading whitespace ignored) -- not implemented yet;
    * a dbt reference containing ``{{ ... }}`` -- not implemented yet.

    Expected output schema:
        unique_id: Series identifier (str or int)
        ds: Timestamp (datetime)
        y: Target value (float)
    plus any columns listed in ``config["data"]["exog_vars"]``.

    If ``config["data"]["unique_id"]`` is missing or empty, all rows are
    treated as one series named ``"series_1"``. Rows whose target cannot be
    parsed as a number are dropped. The result is sorted by
    (unique_id, ds) and has a fresh RangeIndex.

    Raises:
        NotImplementedError: for SQL or dbt sources.
        ValueError: if the source format is not recognized.
    """
    import re

    data_cfg = config["data"]
    source = data_cfg["source"]

    # Normalized copy used only to detect the source type.
    probe = source.strip().lower()

    if probe.endswith(".csv"):
        df = pd.read_csv(source)
    elif probe.endswith(".parquet"):
        df = pd.read_parquet(source)
    elif re.match(r"\s*(select|with)\b", source, re.IGNORECASE):
        # TODO: Add database connection logic, e.g.
        #   from sqlalchemy import create_engine
        #   engine = create_engine(os.getenv("DATABASE_URL"))
        #   df = pd.read_sql(source, engine)
        raise NotImplementedError("SQL queries require database connection setup")
    elif "{{" in source and "}}" in source:
        # TODO: Execute dbt and load result, e.g.
        #   subprocess.run(["dbt", "run", "--select", model_name], check=True)
        #   df = pd.read_parquet("target/fct_sales.parquet")
        raise NotImplementedError("dbt models require dbt CLI setup")
    else:
        raise ValueError(f"Unknown data source format: {source}")

    # Transform to Nixtla schema. A missing "unique_id" key means single series.
    uid_col = data_cfg.get("unique_id")
    renames = {data_cfg["ds"]: "ds", data_cfg["y"]: "y"}
    if uid_col:
        renames[uid_col] = "unique_id"
    df_nixtla = df.rename(columns=renames)  # rename returns a new frame
    if not uid_col:
        df_nixtla["unique_id"] = "series_1"

    df_nixtla["ds"] = pd.to_datetime(df_nixtla["ds"])
    # Unparseable targets become NaN and are dropped with the other nulls.
    df_nixtla["y"] = pd.to_numeric(df_nixtla["y"], errors="coerce")
    df_nixtla = df_nixtla.dropna(subset=["y"])

    cols = ["unique_id", "ds", "y"]
    cols.extend(data_cfg.get("exog_vars") or [])

    # Deterministic (unique_id, ds) order for downstream libraries and slicing.
    return (
        df_nixtla.loc[:, cols]
        .sort_values(["unique_id", "ds"], kind="stable")
        .reset_index(drop=True)
    )


# ========== STATSFORECAST MODELS ==========


def run_statsforecast(df, config):
    """
    Run StatsForecast baseline models and return their cross-validation forecasts.

    Args:
        df: Long-format frame with ``unique_id``, ``ds`` and ``y`` columns.
        config: Parsed experiment config. Reads
            ``config["forecast"]["season_length"]`` and ``["freq"]``,
            ``config["models"]["statsforecast"]`` (list of model names) and
            ``config["cv"]`` (``method``, ``h``, ``step_size``, ``n_windows``).

    Returns:
        The DataFrame from ``StatsForecast.cross_validation``, with
        ``unique_id`` as a column. Returns None if statsforecast is not
        installed or no known model is configured.
    """
    try:
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoETS, AutoTheta, SeasonalNaive
    except ImportError:
        print("⚠️  statsforecast not installed. Install with:")
        print("   pip install statsforecast")
        return None

    season_length = config["forecast"]["season_length"]

    # Classes rather than instances, so only configured models are built. Every
    # model gets season_length. Without it, AutoARIMA defaults to a
    # non-seasonal search, unlike the other baselines.
    model_classes = {
        "SeasonalNaive": SeasonalNaive,
        "AutoARIMA": AutoARIMA,
        "AutoETS": AutoETS,
        "AutoTheta": AutoTheta,
    }

    requested = config["models"]["statsforecast"] or []
    unknown = [name for name in requested if name not in model_classes]
    if unknown:
        print(f"⚠️  Ignoring unknown StatsForecast models: {unknown}")
    models = [
        model_classes[name](season_length=season_length)
        for name in requested
        if name in model_classes
    ]

    if not models:
        print("⚠️  No StatsForecast models enabled in config")
        return None

    sf = StatsForecast(
        models=models, freq=config["forecast"]["freq"], n_jobs=-1  # Parallel processing
    )

    # No separate sf.fit(df): cross_validation fits its own models per window,
    # so a full-data fit would be computed and thrown away.
    cv_cfg = config["cv"]
    print(
        f"🔄 Running {cv_cfg['method']} cross-validation "
        f"(h={cv_cfg['h']}, windows={cv_cfg['n_windows']})..."
    )

    # NOTE(review): "expanding" maps to refitting each window. A truly rolling
    # (fixed-size) window would need StatsForecast's input_size argument.
    cv_df = sf.cross_validation(
        df=df,
        h=cv_cfg["h"],
        step_size=cv_cfg["step_size"],
        n_windows=cv_cfg["n_windows"],
        refit=(cv_cfg["method"] == "expanding"),
    )

    # Some statsforecast releases return unique_id as the index; make it a
    # column so the frame can be joined with the other model families.
    if "unique_id" not in cv_df.columns and cv_df.index.name == "unique_id":
        cv_df = cv_df.reset_index()

    return cv_df


# ========== MLFORECAST MODELS ==========


def run_mlforecast(df, config):
    """
    Run MLForecast with lag features and sklearn-style regressors.

    Args:
        df: Long-format frame with ``unique_id``, ``ds`` and ``y`` columns.
        config: Parsed experiment config. Reads ``config["models"]["mlforecast"]``:

            * ``enabled`` -- whether to run this family.
            * ``base_models`` -- a list of model names.
            * ``lags`` -- lags to build as features.
            * ``lag_transforms`` -- a flag.
            * ``date_features`` -- optional; defaults to dayofweek, month and quarter.

            Also reads ``config["forecast"]["freq"]`` and ``config["cv"]``.

    Returns:
        The DataFrame from ``MLForecast.cross_validation``. Returns None if the
        family is disabled, a required library is missing, or no usable model
        is configured.
    """
    import importlib

    ml_cfg = config["models"]["mlforecast"]
    if not ml_cfg["enabled"]:
        print("ℹ️  MLForecast disabled in config")
        return None

    try:
        from mlforecast import MLForecast
    except ImportError:
        print("⚠️  mlforecast not installed. Install with:")
        print("   pip install mlforecast")
        return None

    # name -> (module, class, pip package). Each model's library is imported
    # lazily, so a missing optional dependency only disables the models that
    # need it.
    model_specs = {
        "RandomForestRegressor": ("sklearn.ensemble", "RandomForestRegressor", "scikit-learn"),
        "LGBMRegressor": ("lightgbm", "LGBMRegressor", "lightgbm"),
    }

    models = []
    for name in ml_cfg["base_models"] or []:
        spec = model_specs.get(name)
        if spec is None:
            print(f"⚠️  Unknown MLForecast model '{name}' - skipping")
            continue
        module_name, class_name, package = spec
        try:
            model_cls = getattr(importlib.import_module(module_name), class_name)
        except ImportError:
            print(f"⚠️  {package} not installed (needed for {name}). Install with:")
            print(f"   pip install {package}")
            continue
        models.append(model_cls(n_estimators=100, random_state=42))

    if not models:
        print("⚠️  No MLForecast models enabled in config")
        return None

    lag_transforms = None
    if ml_cfg.get("lag_transforms"):
        # mlforecast.lag_transforms exposes transform classes (RollingStd,
        # ExpandingMean, ...), not lowercase functions.
        try:
            from mlforecast.lag_transforms import ExpandingMean, RollingStd
        except ImportError:
            print("⚠️  This mlforecast version lacks mlforecast.lag_transforms. Upgrade with:")
            print("   pip install -U mlforecast")
            return None
        # Rolling std of lag 7 and expanding mean of lag 28. The 7-step rolling
        # window is a default choice; tune it to the series frequency.
        lag_transforms = {7: [RollingStd(window_size=7)], 28: [ExpandingMean()]}

    mlf = MLForecast(
        models=models,
        freq=config["forecast"]["freq"],
        lags=ml_cfg["lags"],
        lag_transforms=lag_transforms,
        date_features=ml_cfg.get("date_features", ["dayofweek", "month", "quarter"]),
    )

    cv_cfg = config["cv"]
    print("🤖 Running MLForecast cross-validation...")

    return mlf.cross_validation(
        df=df,
        h=cv_cfg["h"],
        step_size=cv_cfg["step_size"],
        n_windows=cv_cfg["n_windows"],
        refit=(cv_cfg["method"] == "expanding"),
    )


# ========== TIMEGPT ==========


def run_timegpt(df, config):
    """
    Run TimeGPT cross-validation through the Nixtla API.

    Args:
        df: Long-format frame with ``unique_id``, ``ds`` and ``y`` columns.
        config: Parsed experiment config. Reads ``config["models"]["timegpt"]``
            (``enabled``, optional ``level`` and ``finetune_steps``),
            ``config["cv"]``, and optionally ``config["forecast"]["freq"]``.

    Returns:
        The DataFrame from ``NixtlaClient.cross_validation``, which contains a
        ``TimeGPT`` forecast column and interval columns for each requested
        level. Returns None if TimeGPT is disabled, ``NIXTLA_API_KEY`` is
        unset, or the ``nixtla`` package is missing.
    """
    tg_cfg = config["models"]["timegpt"]
    if not tg_cfg["enabled"]:
        print("ℹ️  TimeGPT disabled in config")
        return None

    api_key = os.getenv("NIXTLA_API_KEY")
    if not api_key:
        print("⚠️  TimeGPT enabled but NIXTLA_API_KEY not set")
        print("   Set environment variable: export NIXTLA_API_KEY='your-key'")
        return None

    try:
        from nixtla import NixtlaClient
    except ImportError:
        print("⚠️  nixtla package not installed. Install with:")
        print("   pip install nixtla")
        return None

    client = NixtlaClient(api_key=api_key)
    cv_cfg = config["cv"]

    print("🚀 Running TimeGPT cross-validation...")

    # The forecast column is already named "TimeGPT", so no renaming is needed.
    # The configured freq is passed for consistency with the other families;
    # None lets the client infer it.
    return client.cross_validation(
        df=df,
        h=cv_cfg["h"],
        freq=(config.get("forecast") or {}).get("freq"),
        n_windows=cv_cfg["n_windows"],
        step_size=cv_cfg["step_size"],
        level=tg_cfg.get("level", [80, 90]),
        finetune_steps=tg_cfg.get("finetune_steps", 0),
    )


# ========== EVALUATION ==========


def _combine_cv_results(cv_results):
    """
    Join per-family cross-validation frames side by side.

    Each family's frame holds only its own model columns. The frames are
    inner-joined on (unique_id, ds, cutoff), so every model's forecast for a
    point is on one row and all models are scored on the same rows. The
    target ``y`` (and any other shared non-key column) comes from the first
    frame.

    Returns:
        A ``(combined DataFrame or None, list of contributing family names)`` tuple.
    """
    keys = ["unique_id", "ds", "cutoff"]
    combined = None
    families = []
    for family, cv_df in cv_results.items():
        if cv_df is None:
            continue
        # Some library versions return unique_id as the index.
        if "unique_id" not in cv_df.columns and cv_df.index.name == "unique_id":
            cv_df = cv_df.reset_index()
        if combined is None:
            combined = cv_df
        else:
            shared = [c for c in cv_df.columns if c in combined.columns and c not in keys]
            combined = combined.merge(cv_df.drop(columns=shared), on=keys, how="inner")
        families.append(family)
    if combined is not None:
        combined = combined.reset_index(drop=True)
    return combined, families


def _model_columns(combined):
    """Return forecast columns: everything except keys, target and `-lo-`/`-hi-` interval columns."""
    reserved = {"unique_id", "ds", "cutoff", "y"}
    return [
        c
        for c in combined.columns
        if c not in reserved and "-lo-" not in c and "-hi-" not in c
    ]


def _pre_cutoff_history(train_df, combined):
    """Keep train_df rows at or before each series' first CV cutoff (no test leakage into MASE's scale)."""
    first_cutoff = combined.groupby("unique_id")["cutoff"].min()
    # Series absent from the CV results map to NaN/NaT and compare False.
    mask = train_df["ds"] <= train_df["unique_id"].map(first_cutoff)
    return train_df.loc[mask, ["unique_id", "ds", "y"]].sort_values(["unique_id", "ds"])


def evaluate_models(cv_results, config, train_df=None):
    """
    Compute evaluation metrics for every model column across all CV results.

    Args:
        cv_results: Mapping of family name (e.g. "StatsForecast") to that
            family's cross-validation DataFrame, or None.
        config: Parsed config. Reads ``config["metrics"]`` (any of smape,
            mase, rmse, mae; case-insensitive) and, for MASE,
            ``config["forecast"]["season_length"]`` (default 1).
        train_df: Optional history with ``unique_id``, ``ds`` and ``y``
            columns. MASE needs it and is skipped with a warning without it.
            Only rows up to each series' first CV cutoff are used.

    Returns:
        The utilsforecast ``evaluate`` result, or None. The result has one
        row per series and metric and one column per model. None is returned
        if utilsforecast is missing or there is nothing to evaluate.
    """
    from functools import partial

    try:
        from utilsforecast.evaluation import evaluate
        from utilsforecast.losses import mae, mase, rmse, smape
    except ImportError:
        print("⚠️  utilsforecast not installed. Install with:")
        print("   pip install utilsforecast")
        return None

    combined, families = _combine_cv_results(cv_results)
    if combined is None:
        print("❌ No cross-validation results to evaluate")
        return None
    if combined.empty:
        # NOTE(review): usually a dtype mismatch in ds/cutoff between libraries.
        print("❌ Cross-validation results share no (unique_id, ds, cutoff) rows")
        return None

    models = _model_columns(combined)
    if not models:
        print("❌ No model forecast columns found in cross-validation results")
        return None

    season_length = (config.get("forecast") or {}).get("season_length", 1)
    metric_map = {
        "smape": smape,
        # MASE needs the seasonal period plus in-sample history for its scale.
        "mase": partial(mase, seasonality=season_length),
        "rmse": rmse,
        "mae": mae,
    }

    metrics = []
    for name in config["metrics"] or []:
        key = str(name).lower()
        if key not in metric_map:
            print(f"⚠️  Unknown metric '{name}' - skipping")
            continue
        if key == "mase" and train_df is None:
            print("⚠️  MASE needs training data (train_df) - skipping")
            continue
        metrics.append(metric_map[key])

    if not metrics:
        print("❌ No valid metrics configured")
        return None

    train = _pre_cutoff_history(train_df, combined) if train_df is not None else None

    print(
        f"📈 Computing {len(metrics)} metrics across {len(models)} models "
        f"from {len(families)} model families..."
    )

    return evaluate(combined, metrics=metrics, models=models, train_df=train)


# ========== MAIN EXPERIMENT ==========


def main():
    """
    Run the complete forecasting experiment described by the module-level ``config``.

    Steps:
        1. Load the data.
        2. Cross-validate each enabled model family.
        3. Evaluate all model columns together.
        4. Write the per-family CV frames and the metric table to
           ``config["output"]["results_dir"]``.
    """
    print("=" * 60)
    print("🔬 Nixtla Forecasting Experiment")
    print("=" * 60)

    print("\n📂 Loading data...")
    df = load_data(config)
    print(f"   Loaded {len(df)} rows, {df['unique_id'].nunique()} series")

    cv_results = {}

    if config["models"]["statsforecast"]:
        sf_cv = run_statsforecast(df, config)
        if sf_cv is not None:
            cv_results["StatsForecast"] = sf_cv

    if config["models"]["mlforecast"]["enabled"]:
        mlf_cv = run_mlforecast(df, config)
        if mlf_cv is not None:
            cv_results["MLForecast"] = mlf_cv

    if config["models"]["timegpt"]["enabled"]:
        tg_cv = run_timegpt(df, config)
        if tg_cv is not None:
            cv_results["TimeGPT"] = tg_cv

    print("\n📊 Evaluating models...")
    # The loaded history is passed so MASE can compute its in-sample scale.
    metrics_df = evaluate_models(cv_results, config, train_df=df)

    if metrics_df is not None:
        print("\n" + "=" * 60)
        print("📈 RESULTS")
        print("=" * 60)
        # evaluate() returns one row per (series, metric) with a column per
        # model; average over series for the summary.
        summary = (
            metrics_df.drop(columns=["unique_id", "cutoff"], errors="ignore")
            .groupby("metric")
            .mean(numeric_only=True)
        )
        print(summary)

        output_cfg = config["output"]
        results_dir = Path(output_cfg["results_dir"])
        results_dir.mkdir(parents=True, exist_ok=True)

        if output_cfg.get("save_cv_results", True):
            for model_name, cv_df in cv_results.items():
                cv_path = results_dir / f"cv_{model_name.lower()}.csv"
                cv_df.to_csv(cv_path, index=False)
                print(f"💾 Saved: {cv_path}")

        metrics_path = results_dir / "metrics_summary.csv"
        metrics_df.to_csv(metrics_path, index=False)
        print(f"💾 Saved: {metrics_path}")

    summary_hint = (
        Path((config.get("output") or {}).get("results_dir", "forecasting/results"))
        / "metrics_summary.csv"
    )
    print("\n✅ Experiment complete!")
    print("\nNext steps:")
    print(f"1. Review metrics in {summary_hint}")
    print("2. Visualize forecasts (optional: add plotting code)")
    print("3. Select best model and deploy to production")


if __name__ == "__main__":
    main()
