robust-ai
Building robust AI systems including model monitoring, drift detection, reliability engineering, and failure handling for production ML.
When & Why to Use This Skill
This Claude skill provides a comprehensive framework for building and maintaining production-grade AI systems. It covers MLOps essentials such as real-time model monitoring, statistical drift detection, and reliability engineering. With uncertainty estimation and automated fallback strategies, it helps machine learning models stay robust, accurate, and performant even as data distributions evolve over time.
Use Cases
- Production Monitoring: Setting up Prometheus and Grafana integration to track real-time metrics like prediction latency, confidence scores, and throughput.
- Drift Detection & Analysis: Using statistical tests (KS-test, Chi-square) and tools like Evidently AI to identify data and concept drift before they impact business outcomes.
- Reliability Engineering: Implementing multi-tier fallback mechanisms (primary model to lightweight model to rule-based logic) to handle low-confidence or failed predictions.
- Uncertainty Quantification: Applying Monte Carlo Dropout and Deep Ensembles to measure model uncertainty, essential for high-stakes applications like finance or healthcare.
- Automated MLOps Pipelines: Configuring automated retraining triggers based on performance degradation or drift scores to maintain model freshness without manual intervention.
| name | robust-ai |
|---|---|
| description | Building robust AI systems including model monitoring, drift detection, reliability engineering, and failure handling for production ML. |
Robust AI
Building reliable and robust ML systems.
Robustness Framework
┌─────────────────────────────────────────────────────────────┐
│ AI ROBUSTNESS LAYERS │
├─────────────────────────────────────────────────────────────┤
│ │
│ DATA QUALITY MODEL QUALITY SYSTEM QUALITY │
│ ──────────── ───────────── ────────────── │
│ Validation Testing Monitoring │
│ Anomaly detection Adversarial test Alerting │
│ Drift detection Uncertainty Fallbacks │
│ │
│ FAILURE MODES: │
│ ├── Data drift: Input distribution changes │
│ ├── Concept drift: Input-output relationship changes │
│ ├── Model degradation: Performance decline over time │
│ ├── Silent failures: Wrong predictions with high confidence│
│ └── System failures: Infrastructure and latency issues │
│ │
└─────────────────────────────────────────────────────────────┘
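The data-quality layer in the diagram above (validation, anomaly detection) is the only one not illustrated with code later in this document. Below is a minimal input-validation sketch; the feature names, ranges, and the validate_input helper are illustrative placeholders, not part of any specific library.

import pandas as pd

# Illustrative schema: expected dtypes and value ranges for a hypothetical feature set
SCHEMA = {
    'age':    {'dtype': 'int64',   'min': 0,   'max': 120},
    'income': {'dtype': 'float64', 'min': 0.0, 'max': 1e7},
}

def validate_input(df: pd.DataFrame) -> list:
    """Return a list of violations; an empty list means the batch passed validation."""
    errors = []
    for col, rules in SCHEMA.items():
        if col not in df.columns:
            errors.append(f'missing column: {col}')
            continue
        if str(df[col].dtype) != rules['dtype']:
            errors.append(f'{col}: expected {rules["dtype"]}, got {df[col].dtype}')
        if df[col].isna().any():
            errors.append(f'{col}: contains nulls')
        out_of_range = ((df[col] < rules['min']) | (df[col] > rules['max'])).sum()
        if out_of_range:
            errors.append(f'{col}: {out_of_range} values out of range')
    return errors

Rejecting or quarantining batches that fail these checks prevents many silent failures before the model ever sees the data.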
Model Monitoring
Prometheus + Grafana Setup
import torch
from prometheus_client import Counter, Histogram, Gauge, start_http_server

# Metrics
PREDICTIONS = Counter('model_predictions_total', 'Total predictions', ['model', 'class'])
LATENCY = Histogram('model_latency_seconds', 'Prediction latency', ['model'])
CONFIDENCE = Histogram('model_confidence', 'Prediction confidence', ['model'], buckets=[0.5, 0.7, 0.9, 0.95, 0.99])
DRIFT_SCORE = Gauge('model_drift_score', 'Data drift score', ['model', 'feature'])

class MonitoredModel:
    def __init__(self, model, model_name):
        self.model = model
        self.model_name = model_name

    def predict(self, x):
        # Time only the forward pass
        with LATENCY.labels(model=self.model_name).time():
            output = self.model(x)
        probs = torch.softmax(output, dim=1)
        pred_class = probs.argmax(dim=1).item()
        confidence = probs.max().item()
        # 'class' is a Python keyword, so pass label values positionally
        PREDICTIONS.labels(self.model_name, str(pred_class)).inc()
        CONFIDENCE.labels(model=self.model_name).observe(confidence)
        return output

# Start metrics server
start_http_server(8000)
Evidently AI Monitoring
from evidently import ColumnMapping
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset, DataQualityPreset
from evidently.metrics import ColumnDriftMetric, DatasetDriftMetric

# Define column mapping
column_mapping = ColumnMapping(
    target='target',
    prediction='prediction',
    numerical_features=['age', 'income', 'score'],
    categorical_features=['category', 'region']
)

# Create drift report
report = Report(metrics=[
    DataDriftPreset(),
    DataQualityPreset(),
    ColumnDriftMetric(column_name='age'),
    DatasetDriftMetric()
])

report.run(
    reference_data=reference_df,
    current_data=current_df,
    column_mapping=column_mapping
)

# Save report
report.save_html('drift_report.html')

# Get drift results
results = report.as_dict()
drift_detected = results['metrics'][0]['result']['dataset_drift']
Drift Detection
Statistical Drift Detection
from scipy import stats
import numpy as np

class DriftDetector:
    def __init__(self, reference_data, significance_level=0.05):
        self.reference = reference_data
        self.significance = significance_level

    def detect_drift(self, current_data):
        results = {}
        for col in self.reference.columns:
            if self.reference[col].dtype in ['float64', 'int64']:
                # Kolmogorov-Smirnov test for numerical features
                stat, p_value = stats.ks_2samp(
                    self.reference[col],
                    current_data[col]
                )
            else:
                # Chi-square test for categorical features: align categories and
                # scale expected counts to the current sample size
                ref_counts = self.reference[col].value_counts()
                cur_counts = current_data[col].value_counts()
                categories = ref_counts.index.union(cur_counts.index)
                ref_aligned = ref_counts.reindex(categories, fill_value=0) + 1  # smooth zero counts
                cur_aligned = cur_counts.reindex(categories, fill_value=0) + 1
                expected = ref_aligned / ref_aligned.sum() * cur_aligned.sum()
                stat, p_value = stats.chisquare(cur_aligned, f_exp=expected)
            results[col] = {
                'statistic': stat,
                'p_value': p_value,
                'drift_detected': p_value < self.significance
            }
        return results

# Population Stability Index (PSI)
def calculate_psi(reference, current, bins=10):
    ref_counts, bin_edges = np.histogram(reference, bins=bins)
    cur_counts, _ = np.histogram(current, bins=bin_edges)
    ref_pct = ref_counts / len(reference)
    cur_pct = cur_counts / len(current)
    # Avoid division by zero
    ref_pct = np.clip(ref_pct, 0.0001, None)
    cur_pct = np.clip(cur_pct, 0.0001, None)
    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return psi  # PSI > 0.25 indicates significant drift
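A minimal usage sketch that publishes per-feature PSI to the DRIFT_SCORE gauge defined in the monitoring section, so drift appears alongside latency and confidence metrics. The model name, feature list, and reference_df/current_df variables are placeholders; the 0.1 and 0.25 cut-offs are the commonly used rules of thumb for moderate and significant drift.

# Publish per-feature PSI to Prometheus (feature names and dataframes are placeholders)
for feature in ['age', 'income', 'score']:
    psi = calculate_psi(reference_df[feature], current_df[feature])
    DRIFT_SCORE.labels(model='churn_model', feature=feature).set(psi)
    if psi > 0.25:
        print(f'{feature}: significant drift (PSI={psi:.3f})')
    elif psi > 0.1:
        print(f'{feature}: moderate drift (PSI={psi:.3f})')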
Concept Drift Detection
from river import drift

class ConceptDriftMonitor:
    def __init__(self):
        self.adwin = drift.ADWIN()
        # Note: in recent river releases DDM is available as drift.binary.DDM
        self.ddm = drift.DDM()
        self.performance_window = []

    def update(self, y_true, y_pred):
        error = int(y_true != y_pred)
        # ADWIN for gradual drift
        self.adwin.update(error)
        adwin_drift = self.adwin.drift_detected
        # DDM for sudden drift
        self.ddm.update(error)
        ddm_drift = self.ddm.drift_detected
        return {
            'adwin_drift': adwin_drift,
            'ddm_drift': ddm_drift,
            'error_rate': self.adwin.estimation
        }

# Performance-based drift detection
class PerformanceDriftDetector:
    def __init__(self, window_size=1000, threshold=0.1):
        self.window_size = window_size
        self.threshold = threshold
        self.baseline_accuracy = None
        self.current_window = []

    def update(self, y_true, y_pred):
        self.current_window.append(int(y_true == y_pred))
        # Keep memory bounded to the evaluation window
        if len(self.current_window) > self.window_size:
            self.current_window = self.current_window[-self.window_size:]
        if len(self.current_window) >= self.window_size:
            current_accuracy = np.mean(self.current_window)
            if self.baseline_accuracy is None:
                self.baseline_accuracy = current_accuracy
            drift_detected = (self.baseline_accuracy - current_accuracy) > self.threshold
            return {
                'baseline': self.baseline_accuracy,
                'current': current_accuracy,
                'drift_detected': drift_detected
            }
        return None
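A short usage sketch for the performance-based detector, assuming a hypothetical labeled_stream iterator that yields (label, prediction) pairs as ground truth arrives.

detector = PerformanceDriftDetector(window_size=500, threshold=0.05)
for y_true, y_pred in labeled_stream:  # hypothetical iterator of (label, prediction) pairs
    status = detector.update(y_true, y_pred)
    if status and status['drift_detected']:
        print(f"Accuracy dropped from {status['baseline']:.3f} to {status['current']:.3f}")
        break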
Uncertainty Estimation
import torch
import torch.nn as nn
import torch.nn.functional as F

class MCDropoutModel(nn.Module):
    """Monte Carlo Dropout for uncertainty estimation."""
    def __init__(self, base_model, dropout_rate=0.1):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, num_samples=30):
        # Keep dropout active at inference. For full MC Dropout the dropout layers
        # inside base_model should stay active as well; applying dropout to the output
        # is a lightweight approximation. Note that train() also switches any
        # batch-norm layers to training mode.
        self.train()
        outputs = []
        for _ in range(num_samples):
            out = self.base_model(x)
            out = self.dropout(out)
            outputs.append(out)
        outputs = torch.stack(outputs)
        mean = outputs.mean(dim=0)
        variance = outputs.var(dim=0)
        epistemic_uncertainty = variance.mean(dim=-1)
        return mean, epistemic_uncertainty

# Deep Ensembles
class EnsembleModel:
    def __init__(self, models):
        self.models = models

    def predict_with_uncertainty(self, x):
        predictions = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                pred = model(x)
            predictions.append(pred)
        predictions = torch.stack(predictions)
        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)
        return mean, variance

# Calibration (Temperature Scaling)
class TemperatureScaling(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x):
        logits = self.model(x)
        return logits / self.temperature

    def calibrate(self, val_loader):
        # Fit the temperature on a held-out validation set by minimising NLL
        nll_criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def eval_loss():
            optimizer.zero_grad()
            total_loss = 0
            for x, y in val_loader:
                logits = self.forward(x)
                loss = nll_criterion(logits, y)
                total_loss += loss
            total_loss.backward()
            return total_loss

        optimizer.step(eval_loss)
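To check whether temperature scaling actually improved calibration, a standard diagnostic is expected calibration error (ECE): bin predictions by confidence and compare average confidence to accuracy in each bin. A minimal sketch follows; the bin count and the val_loader/calibrated_model names are assumptions.

def expected_calibration_error(model, loader, n_bins=10):
    # Bin predictions by confidence and compare confidence vs. accuracy per bin
    confidences, correct = [], []
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            probs = torch.softmax(model(x), dim=1)
            conf, pred = probs.max(dim=1)
            confidences.append(conf)
            correct.append((pred == y).float())
    confidences = torch.cat(confidences)
    correct = torch.cat(correct)
    ece = torch.zeros(1)
    bin_edges = torch.linspace(0, 1, n_bins + 1)
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            ece += mask.float().mean() * (correct[mask].mean() - confidences[mask].mean()).abs()
    return ece.item()

# Compare before and after calibration on the same validation loader
# print(expected_calibration_error(model, val_loader), expected_calibration_error(calibrated_model, val_loader))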
Fallback Strategies
class RobustInferenceService:
    def __init__(self, primary_model, fallback_model, confidence_threshold=0.7):
        self.primary = primary_model
        self.fallback = fallback_model
        self.threshold = confidence_threshold
        self.rule_based_fallback = RuleBasedModel()  # project-specific rule-based backstop

    def predict(self, x):
        try:
            # Try primary model
            output = self.primary(x)
            confidence = torch.softmax(output, dim=1).max().item()
            if confidence >= self.threshold:
                return {
                    'prediction': output.argmax().item(),
                    'confidence': confidence,
                    'model': 'primary'
                }

            # Low confidence - use fallback
            output = self.fallback(x)
            confidence = torch.softmax(output, dim=1).max().item()
            if confidence >= self.threshold * 0.8:
                return {
                    'prediction': output.argmax().item(),
                    'confidence': confidence,
                    'model': 'fallback'
                }

            # Still low confidence - use rules
            return {
                'prediction': self.rule_based_fallback(x),
                'confidence': None,
                'model': 'rule_based'
            }
        except Exception as e:
            # System failure - use cached/default
            return {
                'prediction': self.get_default_prediction(),
                'confidence': None,
                'model': 'default',
                'error': str(e)
            }

    def get_default_prediction(self):
        # Return most common class or safe default
        return 0
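A wiring sketch showing the intended tiering; large_model, distilled_model, and x_batch are placeholder names for whatever models and inputs your service uses.

service = RobustInferenceService(
    primary_model=large_model,       # full-size neural network
    fallback_model=distilled_model,  # smaller, cheaper model
    confidence_threshold=0.7
)
result = service.predict(x_batch)
if result['model'] != 'primary':
    print(f"Served by degraded path: {result['model']}")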
Automated Retraining
class AutoRetrainTrigger:
    def __init__(self, drift_threshold=0.2, accuracy_threshold=0.85):
        self.drift_threshold = drift_threshold
        self.accuracy_threshold = accuracy_threshold
        self.metrics_history = []

    def should_retrain(self, metrics):
        self.metrics_history.append(metrics)

        # Check data drift
        if metrics.get('drift_score', 0) > self.drift_threshold:
            return True, 'data_drift'

        # Check accuracy degradation
        if metrics.get('accuracy', 1.0) < self.accuracy_threshold:
            return True, 'accuracy_drop'

        # Check trend: strictly declining accuracy over the last 7 evaluations
        if len(self.metrics_history) >= 7:
            recent = [m['accuracy'] for m in self.metrics_history[-7:]]
            if all(recent[i] < recent[i-1] for i in range(1, len(recent))):
                return True, 'declining_trend'

        return False, None

    def trigger_retrain(self, reason):
        # Trigger the retraining pipeline via Airflow
        from airflow.api.client.local_client import Client
        client = Client(None, None)
        client.trigger_dag(
            dag_id='model_retraining',
            conf={'trigger_reason': reason}
        )
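A sketch of how the trigger fits into a scheduled monitoring job; evaluate_current_metrics is a hypothetical stand-in for whatever collects accuracy and drift scores in your pipeline.

# Hypothetical daily monitoring job
trigger = AutoRetrainTrigger(drift_threshold=0.2, accuracy_threshold=0.85)

metrics = evaluate_current_metrics()  # e.g. {'accuracy': 0.82, 'drift_score': 0.27}
should, reason = trigger.should_retrain(metrics)
if should:
    trigger.trigger_retrain(reason)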
Commands
- /omgops:monitor - Set up monitoring
- /omgops:drift - Drift detection
- /omgops:retrain - Trigger retraining
- /omgtrain:evaluate - Evaluate model
Best Practices
- Monitor predictions, not just system metrics
- Set up automated drift detection
- Implement graceful degradation
- Use uncertainty estimation
- Have clear retraining triggers