Best practices for deploying and maintaining RL systems in production environments.
🏭 Production Configuration
Configure RL systems for production environments.
Production-Ready Setup
from azcore.rl.rl_manager import RLManager, ExplorationStrategy
from azcore.rl.rewards import HeuristicRewardCalculator
# Production RL Manager configuration
production_rl = RLManager(
tool_names=["search", "calculate", "weather", "email", "analyze"],
q_table_path="/var/lib/app/rl_data/production.pkl",
# Conservative exploration for production
exploration_strategy=ExplorationStrategy.EPSILON_DECAY,
exploration_rate=0.05, # Low exploration (5%)
epsilon_decay_rate=0.9999, # Very slow decay
min_exploration_rate=0.01, # Floor at 1%
# Moderate learning for stability
learning_rate=0.05, # Conservative learning
discount_factor=0.95,
# Performance optimizations
use_embeddings=True, # Better generalization
similarity_threshold=0.75,
# Maintenance
enable_async_persistence=True, # Non-blocking persistence
batch_update_size=50,
# Q-table pruning for memory management
enable_q_table_pruning=True,
prune_threshold=10000,
min_visits_to_keep=10,
# Caching
state_cache_size=5000
)
print("Production RL Manager initialized")
print(f"Exploration rate: {production_rl.exploration_rate:.2%}")
print(f"Using embeddings: {production_rl.use_embeddings}")
Environment-Specific Configurations
import os
from typing import Dict
class RLConfigManager:
"""Manage RL configuration across environments."""
def __init__(self, environment: str):
"""
Args:
environment: 'development', 'staging', 'production'
"""
self.environment = environment
self.config = self._get_config()
def _get_config(self) -> Dict:
"""Get environment-specific configuration."""
configs = {
"development": {
"exploration_rate": 0.3,
"learning_rate": 0.2,
"q_table_path": "rl_data/dev.pkl",
"enable_debug": True,
"enable_async_persistence": False,
"enable_q_table_pruning": False
},
"staging": {
"exploration_rate": 0.15,
"learning_rate": 0.1,
"q_table_path": "/var/lib/app/rl_data/staging.pkl",
"enable_debug": False,
"enable_async_persistence": True,
"enable_q_table_pruning": True
},
"production": {
"exploration_rate": 0.05,
"learning_rate": 0.05,
"q_table_path": "/var/lib/app/rl_data/production.pkl",
"enable_debug": False,
"enable_async_persistence": True,
"enable_q_table_pruning": True,
"prune_threshold": 10000,
"min_visits_to_keep": 10
}
}
return configs.get(self.environment, configs["development"])
def create_rl_manager(self, tool_names: list):
"""Create RL manager with environment config."""
return RLManager(
tool_names=tool_names,
q_table_path=self.config["q_table_path"],
exploration_strategy=ExplorationStrategy.EPSILON_DECAY,
exploration_rate=self.config["exploration_rate"],
learning_rate=self.config["learning_rate"],
use_embeddings=True,
enable_async_persistence=self.config.get("enable_async_persistence", False),
enable_q_table_pruning=self.config.get("enable_q_table_pruning", False),
prune_threshold=self.config.get("prune_threshold", 5000),
min_visits_to_keep=self.config.get("min_visits_to_keep", 5)
)
# Usage
environment = os.getenv("ENVIRONMENT", "development")
config_manager = RLConfigManager(environment)
rl_manager = config_manager.create_rl_manager(
tool_names=["search", "calculate", "weather", "email"]
)
print(f"RL Manager created for {environment} environment")
print(f"Configuration: {config_manager.config}")
🔄 High Availability Setup
Deploy RL systems with high availability.
Load-Balanced RL Deployment
import threading
import time
from typing import List
class RLLoadBalancer:
"""Load balancer for RL instances."""
def __init__(self, rl_instances: List[RLManager]):
"""
Args:
rl_instances: List of RL manager instances
"""
self.rl_instances = rl_instances
self.current_index = 0
self.lock = threading.Lock()
self.health_status = [True] * len(rl_instances)
def get_next_instance(self) -> RLManager:
"""Get next healthy RL instance (round-robin)."""
with self.lock:
attempts = 0
while attempts < len(self.rl_instances):
instance = self.rl_instances[self.current_index]
is_healthy = self.health_status[self.current_index]
self.current_index = (self.current_index + 1) % len(self.rl_instances)
if is_healthy:
return instance
attempts += 1
# All instances unhealthy - return first one
return self.rl_instances[0]
def mark_unhealthy(self, instance: RLManager):
"""Mark instance as unhealthy."""
with self.lock:
try:
index = self.rl_instances.index(instance)
self.health_status[index] = False
print(f"Instance {index} marked unhealthy")
except ValueError:
pass
def mark_healthy(self, instance: RLManager):
"""Mark instance as healthy."""
with self.lock:
try:
index = self.rl_instances.index(instance)
self.health_status[index] = True
print(f"Instance {index} marked healthy")
except ValueError:
pass
def health_check_loop(self, interval=60):
"""Periodic health check (runs in background thread)."""
while True:
for i, instance in enumerate(self.rl_instances):
try:
# Health check: Get statistics
stats = instance.get_statistics()
if stats:
self.health_status[i] = True
else:
self.health_status[i] = False
except Exception as e:
print(f"Health check failed for instance {i}: {e}")
self.health_status[i] = False
time.sleep(interval)
# Create multiple RL instances
rl_instances = [
RLManager(
tool_names=["search", "calculate", "weather", "email"],
q_table_path=f"/var/lib/app/rl_data/instance_{i}.pkl",
exploration_rate=0.05,
use_embeddings=True
)
for i in range(3)
]
# Create load balancer
load_balancer = RLLoadBalancer(rl_instances)
# Start health check in background
health_check_thread = threading.Thread(
target=load_balancer.health_check_loop,
args=(60,),
daemon=True
)
health_check_thread.start()
# Use load balancer
def handle_query(query: str, max_retries: int = 3):
    """Handle query with load balancing and bounded retries."""
    last_error = None
    for _ in range(max_retries):
        rl_instance = load_balancer.get_next_instance()
        try:
            selected, state_key = rl_instance.select_tools(query, top_n=2)
            return selected, state_key, rl_instance
        except Exception as e:
            print(f"Error with instance: {e}")
            load_balancer.mark_unhealthy(rl_instance)
            last_error = e
    # Bounded retries avoid infinite recursion when every instance is failing
    raise RuntimeError(f"All RL instances failed: {last_error}")
# Process queries
for i in range(100):
query = f"Query {i}"
selected, state_key, instance = handle_query(query)
print(f"Query {i} handled by instance: {rl_instances.index(instance)}")
Shared Q-Table Storage
import redis
import pickle
from contextlib import contextmanager
class SharedRLStorage:
"""Shared RL storage using Redis."""
def __init__(self, redis_host="localhost", redis_port=6379):
"""
Args:
redis_host: Redis host
redis_port: Redis port
"""
self.redis_client = redis.Redis(
host=redis_host,
port=redis_port,
decode_responses=False
)
def save_q_table(self, rl_manager: RLManager, key_prefix="rl"):
"""Save Q-table to Redis."""
q_table_key = f"{key_prefix}:q_table"
visit_counts_key = f"{key_prefix}:visit_counts"
# Serialize and save
q_table_bytes = pickle.dumps(dict(rl_manager.q_table))
visit_counts_bytes = pickle.dumps(dict(rl_manager.visit_counts))
self.redis_client.set(q_table_key, q_table_bytes)
self.redis_client.set(visit_counts_key, visit_counts_bytes)
print(f"Q-table saved to Redis under {key_prefix}")
def load_q_table(self, rl_manager: RLManager, key_prefix="rl"):
"""Load Q-table from Redis."""
q_table_key = f"{key_prefix}:q_table"
visit_counts_key = f"{key_prefix}:visit_counts"
# Load and deserialize
q_table_bytes = self.redis_client.get(q_table_key)
visit_counts_bytes = self.redis_client.get(visit_counts_key)
if q_table_bytes and visit_counts_bytes:
q_table = pickle.loads(q_table_bytes)
visit_counts = pickle.loads(visit_counts_bytes)
            # Update RL manager (rebuild nested defaultdicts so missing keys stay safe)
            from collections import defaultdict
            rl_manager.q_table = defaultdict(lambda: defaultdict(float),
                {s: defaultdict(float, a) for s, a in q_table.items()})
            rl_manager.visit_counts = defaultdict(lambda: defaultdict(int),
                {s: defaultdict(int, c) for s, c in visit_counts.items()})
print(f"Q-table loaded from Redis ({len(q_table)} states)")
return True
print("No Q-table found in Redis")
return False
@contextmanager
def atomic_update(self, key_prefix="rl"):
"""Context manager for atomic Q-table updates."""
lock_key = f"{key_prefix}:lock"
lock_acquired = False
try:
# Acquire lock with timeout
lock_acquired = self.redis_client.set(
lock_key, "locked", nx=True, ex=30
)
if lock_acquired:
yield
else:
raise Exception("Could not acquire lock for Q-table update")
finally:
if lock_acquired:
self.redis_client.delete(lock_key)
# Usage in production
storage = SharedRLStorage(redis_host="redis.example.com")
# Load shared Q-table on startup
rl_manager = RLManager(
tool_names=["search", "calculate", "weather", "email"],
q_table_path="/tmp/local_cache.pkl", # Local cache
exploration_rate=0.05,
use_embeddings=True
)
# Load from shared storage
storage.load_q_table(rl_manager, key_prefix="production")
# Periodic sync to shared storage
def sync_to_shared_storage():
"""Sync local Q-table to shared storage."""
with storage.atomic_update(key_prefix="production"):
storage.save_q_table(rl_manager, key_prefix="production")
print("Synced Q-table to shared storage")
# Schedule periodic sync (e.g., every 5 minutes)
import schedule
schedule.every(5).minutes.do(sync_to_shared_storage)
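Note that the schedule library only fires registered jobs when schedule.run_pending() is called, so drive it from a small background loop:
# Run the scheduler loop in the background so the 5-minute sync actually fires
import threading
import time
def run_scheduler():
    while True:
        schedule.run_pending()
        time.sleep(1)
threading.Thread(target=run_scheduler, daemon=True).start()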
📊 Monitoring Integration
Integrate RL monitoring with production monitoring systems.
Prometheus Metrics
from prometheus_client import Counter, Gauge, Histogram, start_http_server
import time
# Define Prometheus metrics
rl_tool_selections = Counter(
'rl_tool_selections_total',
'Total number of tool selections',
['tool', 'environment']
)
rl_rewards = Histogram(
'rl_rewards',
'Distribution of rewards',
['environment']
)
rl_exploration_rate = Gauge(
'rl_exploration_rate',
'Current exploration rate',
['environment']
)
rl_q_table_size = Gauge(
'rl_q_table_size',
'Number of states in Q-table',
['environment']
)
rl_avg_q_value = Gauge(
'rl_avg_q_value',
'Average Q-value across all states',
['environment']
)
rl_query_latency = Histogram(
'rl_query_latency_seconds',
'Latency of RL query processing',
['environment']
)
class PrometheusRLMonitor:
"""Monitor RL system with Prometheus metrics."""
def __init__(self, rl_manager: RLManager, environment="production"):
"""
Args:
rl_manager: RLManager instance
environment: Environment label
"""
self.rl_manager = rl_manager
self.environment = environment
def record_tool_selection(self, tools: list):
"""Record tool selection metrics."""
for tool in tools:
rl_tool_selections.labels(
tool=tool,
environment=self.environment
).inc()
def record_reward(self, reward: float):
"""Record reward metric."""
rl_rewards.labels(
environment=self.environment
).observe(reward)
def update_gauges(self):
"""Update gauge metrics."""
stats = self.rl_manager.get_statistics()
rl_exploration_rate.labels(
environment=self.environment
).set(stats['exploration_rate'])
rl_q_table_size.labels(
environment=self.environment
).set(stats['total_states'])
rl_avg_q_value.labels(
environment=self.environment
).set(stats['avg_q_value'])
def handle_query_with_metrics(self, query: str, reward_calculator):
"""Handle query and record metrics."""
start_time = time.time()
try:
# Select tools
selected, state_key = self.rl_manager.select_tools(query, top_n=2)
# Record selection
self.record_tool_selection(selected)
# Simulate agent execution (replace with actual agent)
result = {"messages": [{"content": "Success"}]}
# Calculate reward
reward = reward_calculator.calculate(None, result, query)
self.record_reward(reward)
# Update Q-values
for tool in selected:
self.rl_manager.update(state_key, tool, reward)
# Update gauges
self.update_gauges()
# Record latency
latency = time.time() - start_time
rl_query_latency.labels(
environment=self.environment
).observe(latency)
return selected, reward
except Exception as e:
print(f"Error handling query: {e}")
raise
# Start Prometheus metrics server
start_http_server(8000)
print("Prometheus metrics available at http://localhost:8000/metrics")
# Create monitor
from azcore.rl.rewards import HeuristicRewardCalculator
monitor = PrometheusRLMonitor(rl_manager, environment="production")
reward_calc = HeuristicRewardCalculator()
# Handle queries with monitoring
for i in range(100):
query = f"Query {i}"
selected, reward = monitor.handle_query_with_metrics(query, reward_calc)
print(f"Query {i}: Selected {selected}, Reward {reward:.2f}")
time.sleep(0.1)
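The gauges above only change inside handle_query_with_metrics, so they go stale during quiet periods. A small background refresher, sketched here, keeps them current:
# Refresh gauge metrics periodically even when no queries arrive (sketch)
import threading
def refresh_gauges_loop(rl_monitor: PrometheusRLMonitor, interval: int = 30):
    while True:
        try:
            rl_monitor.update_gauges()
        except Exception as e:
            print(f"Gauge refresh failed: {e}")
        time.sleep(interval)
threading.Thread(target=refresh_gauges_loop, args=(monitor,), daemon=True).start()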
Structured Logging
import logging
import json
from datetime import datetime
class RLStructuredLogger:
"""Structured logging for RL systems."""
def __init__(self, rl_manager: RLManager, logger_name="rl_system"):
"""
Args:
rl_manager: RLManager instance
logger_name: Logger name
"""
self.rl_manager = rl_manager
self.logger = logging.getLogger(logger_name)
# Configure JSON logging
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(message)s'))
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
def log_tool_selection(self, query: str, selected: list, state_key: str):
"""Log tool selection event."""
log_entry = {
"timestamp": datetime.now().isoformat(),
"event": "tool_selection",
"query": query,
"selected_tools": selected,
"state_key": state_key[:50],
"exploration_rate": self.rl_manager.exploration_rate,
"q_table_size": len(self.rl_manager.q_table)
}
self.logger.info(json.dumps(log_entry))
def log_q_update(self, state_key: str, tool: str, reward: float, old_q: float, new_q: float):
"""Log Q-value update event."""
log_entry = {
"timestamp": datetime.now().isoformat(),
"event": "q_value_update",
"state_key": state_key[:50],
"tool": tool,
"reward": reward,
"old_q_value": old_q,
"new_q_value": new_q,
"delta": new_q - old_q
}
self.logger.info(json.dumps(log_entry))
def log_anomaly(self, anomaly_type: str, message: str, severity="WARNING"):
"""Log anomaly detection event."""
log_entry = {
"timestamp": datetime.now().isoformat(),
"event": "anomaly_detected",
"type": anomaly_type,
"severity": severity,
"message": message,
"exploration_rate": self.rl_manager.exploration_rate,
"avg_q_value": self.rl_manager.get_statistics()['avg_q_value']
}
self.logger.warning(json.dumps(log_entry))
def log_error(self, error_type: str, error_message: str, context: dict = None):
"""Log error event."""
log_entry = {
"timestamp": datetime.now().isoformat(),
"event": "error",
"error_type": error_type,
"error_message": error_message,
"context": context or {}
}
self.logger.error(json.dumps(log_entry))
# Usage
structured_logger = RLStructuredLogger(rl_manager)
# Log tool selection
query = "Search for documentation"
selected, state_key = rl_manager.select_tools(query, top_n=2)
structured_logger.log_tool_selection(query, selected, state_key)
# Log Q-value update
old_q = rl_manager.q_table[state_key].get(selected[0], 0.0)
reward = 0.8
rl_manager.update(state_key, selected[0], reward)
new_q = rl_manager.q_table[state_key][selected[0]]
structured_logger.log_q_update(state_key, selected[0], reward, old_q, new_q)
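An alternative to hand-building JSON strings is a logging.Formatter subclass, so every record on the logger (including ones emitted by other code) comes out structured. A stdlib-only sketch:
# JSON formatter sketch: attach in place of the plain handler above
class JsonFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        return json.dumps({
            "timestamp": datetime.now().isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        })
logger = logging.getLogger("rl_system")
logger.handlers.clear()
json_handler = logging.StreamHandler()
json_handler.setFormatter(JsonFormatter())
logger.addHandler(json_handler)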
🔄 Gradual Rollout
Gradually roll out RL systems to production.
A/B Testing Framework
import random
from typing import Dict
class RLABTesting:
"""A/B testing framework for RL rollout."""
def __init__(
self,
rl_manager: RLManager,
baseline_tools: list,
rollout_percentage: float = 0.2
):
"""
Args:
rl_manager: RLManager instance
baseline_tools: Baseline tool set (no RL)
rollout_percentage: Percentage of traffic to RL (0.0-1.0)
"""
self.rl_manager = rl_manager
self.baseline_tools = baseline_tools
self.rollout_percentage = rollout_percentage
# Experiment tracking
self.rl_metrics = {
"queries": 0,
"total_reward": 0.0,
"avg_reward": 0.0,
"tool_selections": {}
}
self.baseline_metrics = {
"queries": 0,
"total_reward": 0.0,
"avg_reward": 0.0,
"tool_selections": {}
}
def should_use_rl(self) -> bool:
"""Decide if this query should use RL."""
return random.random() < self.rollout_percentage
def handle_query(self, query: str, reward_calculator) -> Dict:
"""
Handle query with A/B testing.
Returns: Dict with results and experiment info
"""
use_rl = self.should_use_rl()
if use_rl:
# RL variant
selected, state_key = self.rl_manager.select_tools(query, top_n=2)
variant = "rl"
metrics = self.rl_metrics
else:
# Baseline variant
selected = self.baseline_tools[:2]
state_key = None
variant = "baseline"
metrics = self.baseline_metrics
# Simulate execution and calculate reward
result = {"messages": [{"content": "Success"}]}
reward = reward_calculator.calculate(None, result, query)
# Update metrics
metrics["queries"] += 1
metrics["total_reward"] += reward
metrics["avg_reward"] = metrics["total_reward"] / metrics["queries"]
for tool in selected:
metrics["tool_selections"][tool] = metrics["tool_selections"].get(tool, 0) + 1
# Update RL if used
if use_rl:
for tool in selected:
self.rl_manager.update(state_key, tool, reward)
return {
"query": query,
"variant": variant,
"selected_tools": selected,
"reward": reward,
"use_rl": use_rl
}
def get_experiment_results(self) -> Dict:
"""Get A/B test results."""
return {
"rollout_percentage": self.rollout_percentage,
"rl_metrics": self.rl_metrics,
"baseline_metrics": self.baseline_metrics,
"rl_improvement": (
(self.rl_metrics["avg_reward"] - self.baseline_metrics["avg_reward"])
/ self.baseline_metrics["avg_reward"] * 100
if self.baseline_metrics["avg_reward"] > 0 else 0
)
}
def increase_rollout(self, increment: float = 0.1):
"""Gradually increase RL rollout percentage."""
self.rollout_percentage = min(1.0, self.rollout_percentage + increment)
print(f"RL rollout increased to {self.rollout_percentage:.1%}")
# Usage
from azcore.rl.rewards import HeuristicRewardCalculator
ab_test = RLABTesting(
rl_manager=rl_manager,
baseline_tools=["search", "calculate", "weather", "email"],
rollout_percentage=0.2 # Start with 20% RL traffic
)
reward_calc = HeuristicRewardCalculator()
print("=== A/B Testing RL Rollout ===\n")
# Phase 1: 20% rollout
print("Phase 1: 20% RL traffic")
for i in range(100):
query = f"Query {i}"
result = ab_test.handle_query(query, reward_calc)
results_phase1 = ab_test.get_experiment_results()
print(f"\nPhase 1 Results:")
print(f" RL Avg Reward: {results_phase1['rl_metrics']['avg_reward']:.3f}")
print(f" Baseline Avg Reward: {results_phase1['baseline_metrics']['avg_reward']:.3f}")
print(f" Improvement: {results_phase1['rl_improvement']:.1f}%")
# Phase 2: Increase to 50% if positive results
if results_phase1['rl_improvement'] > 0:
print("\n✓ Positive results - increasing rollout to 50%")
ab_test.increase_rollout(0.3)
for i in range(100, 200):
query = f"Query {i}"
result = ab_test.handle_query(query, reward_calc)
results_phase2 = ab_test.get_experiment_results()
print(f"\nPhase 2 Results (50% rollout):")
print(f" RL Avg Reward: {results_phase2['rl_metrics']['avg_reward']:.3f}")
print(f" Baseline Avg Reward: {results_phase2['baseline_metrics']['avg_reward']:.3f}")
print(f" Improvement: {results_phase2['rl_improvement']:.1f}%")
# Phase 3: Full rollout if still positive
if results_phase2['rl_improvement'] > 0:
print("\n✓ Still positive - full rollout (100%)")
ab_test.increase_rollout(0.5)
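Per-query random assignment can flip the same user between variants. If a stable client identifier is available (hypothetical here), hashing it gives sticky bucketing so each user consistently sees one variant:
# Sticky variant assignment (sketch); client_id is a hypothetical stable identifier
import hashlib
def should_use_rl_sticky(client_id: str, rollout_percentage: float) -> bool:
    bucket = int(hashlib.md5(client_id.encode()).hexdigest(), 16) % 100
    return bucket < rollout_percentage * 100
print(should_use_rl_sticky("user-42", 0.2))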
Feature Flags
class RLFeatureFlags:
"""Feature flags for RL system control."""
def __init__(self):
"""Initialize feature flags."""
self.flags = {
"rl_enabled": True,
"rl_learning_enabled": True,
"rl_exploration_enabled": True,
"rl_pruning_enabled": True,
"rl_async_persist_enabled": True
}
def is_enabled(self, flag: str) -> bool:
"""Check if feature is enabled."""
return self.flags.get(flag, False)
def set_flag(self, flag: str, value: bool):
"""Set feature flag value."""
if flag in self.flags:
old_value = self.flags[flag]
self.flags[flag] = value
print(f"Feature flag '{flag}' changed: {old_value} → {value}")
else:
print(f"Unknown feature flag: {flag}")
def emergency_disable_rl(self):
"""Emergency disable all RL features."""
print("🚨 EMERGENCY: Disabling all RL features")
self.flags["rl_enabled"] = False
self.flags["rl_learning_enabled"] = False
self.flags["rl_exploration_enabled"] = False
class RLWithFeatureFlags:
"""RL system with feature flag support."""
def __init__(self, rl_manager: RLManager, feature_flags: RLFeatureFlags):
"""
Args:
rl_manager: RLManager instance
feature_flags: Feature flags instance
"""
self.rl_manager = rl_manager
self.feature_flags = feature_flags
def select_tools(self, query: str, top_n: int = 2):
"""Select tools with feature flag checks."""
if not self.feature_flags.is_enabled("rl_enabled"):
# Fall back to default tools
return self.rl_manager.tool_names[:top_n], None
# Override exploration if disabled
if not self.feature_flags.is_enabled("rl_exploration_enabled"):
original_rate = self.rl_manager.exploration_rate
self.rl_manager.exploration_rate = 0.0
selected, state_key = self.rl_manager.select_tools(query, top_n=top_n)
self.rl_manager.exploration_rate = original_rate
return selected, state_key
return self.rl_manager.select_tools(query, top_n=top_n)
def update(self, state_key: str, tool: str, reward: float):
"""Update Q-values with feature flag checks."""
if not self.feature_flags.is_enabled("rl_learning_enabled"):
return # Skip learning
self.rl_manager.update(state_key, tool, reward)
# Usage
feature_flags = RLFeatureFlags()
rl_with_flags = RLWithFeatureFlags(rl_manager, feature_flags)
# Normal operation
selected, state_key = rl_with_flags.select_tools("Query", top_n=2)
rl_with_flags.update(state_key, selected[0], 0.9)
# Disable exploration (exploit only)
feature_flags.set_flag("rl_exploration_enabled", False)
selected, state_key = rl_with_flags.select_tools("Query", top_n=2)
# Emergency disable
feature_flags.emergency_disable_rl()
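Hard-coded flag defaults require a redeploy to change. A small loader, sketched below with illustrative variable names, lets operators override them from the environment at restart:
# Load flag overrides from environment variables (sketch; names are illustrative)
import os
def load_flag_overrides(flags: RLFeatureFlags):
    for flag in list(flags.flags):
        raw = os.getenv(flag.upper())  # e.g. RL_ENABLED=false
        if raw is not None:
            flags.set_flag(flag, raw.strip().lower() in ("1", "true", "yes"))
load_flag_overrides(feature_flags)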
💾 Backup and Recovery
Implement backup and recovery for RL systems.
Automated Backup System
import shutil
from datetime import datetime, timedelta
from pathlib import Path
class RLBackupManager:
"""Automated backup manager for RL Q-tables."""
def __init__(
self,
rl_manager: RLManager,
backup_dir: str = "backups",
retention_days: int = 7
):
"""
Args:
rl_manager: RLManager instance
backup_dir: Directory for backups
retention_days: Days to retain backups
"""
self.rl_manager = rl_manager
self.backup_dir = Path(backup_dir)
self.retention_days = retention_days
# Create backup directory
self.backup_dir.mkdir(parents=True, exist_ok=True)
def create_backup(self, label: str = ""):
"""
Create Q-table backup.
Args:
label: Optional label for backup
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
label_suffix = f"_{label}" if label else ""
backup_filename = f"qtable_{timestamp}{label_suffix}.pkl"
backup_path = self.backup_dir / backup_filename
# Force persistence to ensure Q-table is saved
self.rl_manager.force_persist()
# Copy Q-table file
source_path = Path(self.rl_manager.q_table_path)
if source_path.exists():
shutil.copy(source_path, backup_path)
print(f"✓ Backup created: {backup_path}")
return backup_path
else:
print(f"✗ Source Q-table not found: {source_path}")
return None
def restore_backup(self, backup_path: str):
"""
Restore Q-table from backup.
Args:
backup_path: Path to backup file
"""
backup_file = Path(backup_path)
if not backup_file.exists():
print(f"✗ Backup file not found: {backup_path}")
return False
# Create safety backup of current state
self.create_backup(label="pre_restore")
# Restore from backup
dest_path = Path(self.rl_manager.q_table_path)
shutil.copy(backup_file, dest_path)
# Reload Q-table
self.rl_manager._load_q_table()
print(f"✓ Restored from backup: {backup_path}")
return True
def cleanup_old_backups(self):
"""Remove backups older than retention period."""
cutoff_date = datetime.now() - timedelta(days=self.retention_days)
removed_count = 0
for backup_file in self.backup_dir.glob("qtable_*.pkl"):
# Skip pre_restore backups (keep indefinitely)
if "pre_restore" in backup_file.name:
continue
# Get file modification time
file_time = datetime.fromtimestamp(backup_file.stat().st_mtime)
if file_time < cutoff_date:
backup_file.unlink()
removed_count += 1
if removed_count > 0:
print(f"✓ Cleaned up {removed_count} old backup(s)")
return removed_count
def list_backups(self):
"""List all available backups."""
backups = []
for backup_file in sorted(self.backup_dir.glob("qtable_*.pkl"), reverse=True):
file_time = datetime.fromtimestamp(backup_file.stat().st_mtime)
file_size = backup_file.stat().st_size
backups.append({
"filename": backup_file.name,
"path": str(backup_file),
"timestamp": file_time,
"size_kb": file_size / 1024
})
return backups
def get_latest_backup(self):
"""Get path to latest backup."""
backups = self.list_backups()
if backups:
return backups[0]["path"]
return None
# Usage
backup_manager = RLBackupManager(
rl_manager=rl_manager,
backup_dir="/var/lib/app/rl_backups",
retention_days=7
)
# Create periodic backups (e.g., every 6 hours)
import schedule
def scheduled_backup():
backup_manager.create_backup(label="scheduled")
backup_manager.cleanup_old_backups()
schedule.every(6).hours.do(scheduled_backup)  # requires a schedule.run_pending() loop, as shown earlier
# Create backup before risky operations
backup_manager.create_backup(label="before_update")
# Train or update RL
# ...
# If something goes wrong, restore
# backup_manager.restore_backup(backup_manager.get_latest_backup())
# List all backups
print("\n=== Available Backups ===")
for backup in backup_manager.list_backups():
print(f"{backup['filename']}: {backup['timestamp'].strftime('%Y-%m-%d %H:%M:%S')} "
f"({backup['size_kb']:.1f} KB)")
⚡ Performance Optimization
Optimize RL system performance for production.
Async Persistence
from threading import Thread
from queue import Queue, Empty
class AsyncPersistenceManager:
"""Asynchronous Q-table persistence."""
def __init__(self, rl_manager: RLManager, batch_size=50):
"""
Args:
rl_manager: RLManager instance
batch_size: Number of updates before persisting
"""
self.rl_manager = rl_manager
self.batch_size = batch_size
self.update_count = 0
self.persist_queue = Queue()
self.running = True
# Start persistence thread
self.persist_thread = Thread(target=self._persist_worker, daemon=True)
self.persist_thread.start()
    def _persist_worker(self):
        """Background worker for persistence."""
        while self.running:
            try:
                # Wait for a persist signal; Empty just means nothing is queued yet
                self.persist_queue.get(timeout=1)
            except Empty:
                continue
            try:
                # Persist Q-table
                self.rl_manager.force_persist()
                print("✓ Q-table persisted (async)")
            except Exception as e:
                print(f"Persistence error: {e}")
def record_update(self):
"""Record an update (may trigger persist)."""
self.update_count += 1
if self.update_count >= self.batch_size:
self.persist_queue.put("persist")
self.update_count = 0
def shutdown(self):
"""Shutdown persistence manager."""
self.running = False
self.persist_thread.join(timeout=5)
# Final persist
self.rl_manager.force_persist()
# Usage
async_persist = AsyncPersistenceManager(rl_manager, batch_size=50)
# Train with async persistence
for i in range(200):
query = f"Query {i}"
selected, state_key = rl_manager.select_tools(query, top_n=2)
reward = 0.8
for tool in selected:
rl_manager.update(state_key, tool, reward)
async_persist.record_update()
# Shutdown
async_persist.shutdown()
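To make the final persist happen even when the process exits without reaching an explicit shutdown() call, register it with atexit right after the manager is created:
# Register the final persist as an exit hook (sketch)
import atexit
atexit.register(async_persist.shutdown)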
Caching Layer
from functools import lru_cache
import hashlib
class RLCachingLayer:
"""Caching layer for RL operations."""
def __init__(self, rl_manager: RLManager, cache_size=1000):
"""
Args:
rl_manager: RLManager instance
cache_size: Size of LRU cache
"""
        self.rl_manager = rl_manager
        # Map query hashes back to the original query text for the cached lookup
        self.query_map = {}
        # Create cached methods
        self.cached_select_tools = lru_cache(maxsize=cache_size)(
            self._select_tools_impl
        )
    def _select_tools_impl(self, query_hash: str, top_n: int):
        """Implementation of select_tools (cacheable by query hash)."""
        selected, state_key = self.rl_manager.select_tools(
            self.query_map[query_hash],
            top_n=top_n
        )
        return selected, state_key
    def select_tools(self, query: str, top_n: int = 2):
        """Select tools with caching."""
        # Hash the query so the cache key stays small and hashable
        query_hash = hashlib.md5(query.encode()).hexdigest()
        self.query_map[query_hash] = query
        # Use cached method
        return self.cached_select_tools(query_hash, top_n)
def clear_cache(self):
"""Clear cache."""
self.cached_select_tools.cache_clear()
def get_cache_stats(self):
"""Get cache statistics."""
info = self.cached_select_tools.cache_info()
return {
"hits": info.hits,
"misses": info.misses,
"hit_rate": info.hits / (info.hits + info.misses) if (info.hits + info.misses) > 0 else 0,
"size": info.currsize,
"max_size": info.maxsize
}
# Usage
caching_layer = RLCachingLayer(rl_manager, cache_size=1000)
# Use with caching
for i in range(100):
query = f"Repeated query {i % 10}" # Many repeats
selected, state_key = caching_layer.select_tools(query, top_n=2)
# Check cache performance
stats = caching_layer.get_cache_stats()
print(f"\n=== Cache Statistics ===")
print(f"Hit rate: {stats['hit_rate']:.1%}")
print(f"Hits: {stats['hits']}, Misses: {stats['misses']}")
print(f"Cache size: {stats['size']}/{stats['max_size']}")
🔒 Security Considerations
Secure your production RL deployment.
Input Validation
class RLInputValidator:
"""Validate inputs to RL system."""
def __init__(self, max_query_length=1000):
"""
Args:
max_query_length: Maximum query length
"""
self.max_query_length = max_query_length
def validate_query(self, query: str) -> tuple[bool, str]:
"""
Validate query input.
Returns: (is_valid, error_message)
"""
if not query:
return False, "Query cannot be empty"
if len(query) > self.max_query_length:
return False, f"Query exceeds maximum length ({self.max_query_length})"
# Check for injection attempts
dangerous_patterns = ["<script>", "javascript:", "eval(", "__import__"]
for pattern in dangerous_patterns:
if pattern.lower() in query.lower():
return False, f"Potentially malicious pattern detected: {pattern}"
return True, ""
def sanitize_query(self, query: str) -> str:
"""Sanitize query input."""
# Strip whitespace
query = query.strip()
# Remove control characters
query = ''.join(char for char in query if ord(char) >= 32 or char == '\n')
return query
# Usage
validator = RLInputValidator(max_query_length=1000)
def safe_handle_query(query: str):
"""Handle query with validation."""
# Validate
is_valid, error_msg = validator.validate_query(query)
if not is_valid:
print(f"✗ Invalid query: {error_msg}")
return None
# Sanitize
query = validator.sanitize_query(query)
# Process
selected, state_key = rl_manager.select_tools(query, top_n=2)
return selected
# Test validation
safe_handle_query("Normal query") # ✓ Valid
safe_handle_query("<script>alert('xss')</script>") # ✗ Invalid
safe_handle_query("A" * 2000) # ✗ Too long
Rate Limiting
import time
from collections import defaultdict
class RLRateLimiter:
"""Rate limiter for RL system."""
def __init__(self, max_requests_per_minute=100):
"""
Args:
max_requests_per_minute: Maximum requests per minute
"""
self.max_requests = max_requests_per_minute
self.requests = defaultdict(list)
def is_allowed(self, client_id: str) -> bool:
"""Check if client is allowed to make request."""
now = time.time()
cutoff = now - 60 # 1 minute ago
# Remove old requests
self.requests[client_id] = [
req_time for req_time in self.requests[client_id]
if req_time > cutoff
]
# Check limit
if len(self.requests[client_id]) >= self.max_requests:
return False
# Record request
self.requests[client_id].append(now)
return True
# Usage
rate_limiter = RLRateLimiter(max_requests_per_minute=100)
def rate_limited_handle_query(query: str, client_id: str):
"""Handle query with rate limiting."""
if not rate_limiter.is_allowed(client_id):
print(f"✗ Rate limit exceeded for client {client_id}")
return None
# Process query
selected, state_key = rl_manager.select_tools(query, top_n=2)
return selected
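The in-memory request lists are not protected against concurrent access. If queries arrive from multiple threads, a lock-guarded variant is a safer starting point:
# Thread-safe variant (sketch): guard the per-client request lists with a lock
import threading
class ThreadSafeRateLimiter(RLRateLimiter):
    def __init__(self, max_requests_per_minute=100):
        super().__init__(max_requests_per_minute)
        self._lock = threading.Lock()
    def is_allowed(self, client_id: str) -> bool:
        with self._lock:
            return super().is_allowed(client_id)
rate_limiter = ThreadSafeRateLimiter(max_requests_per_minute=100)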
📋 Production Checklist
Comprehensive checklist for production deployment.
class ProductionChecklist:
"""Production readiness checklist."""
def __init__(self, rl_manager: RLManager):
"""
Args:
rl_manager: RLManager instance
"""
self.rl_manager = rl_manager
self.checks = {}
def check_exploration_rate(self):
"""Check if exploration rate is appropriate."""
rate = self.rl_manager.exploration_rate
passed = 0.01 <= rate <= 0.1
self.checks["exploration_rate"] = {
"passed": passed,
"value": rate,
"recommendation": "Should be between 1% and 10% for production"
}
def check_async_persistence(self):
"""Check if async persistence is enabled."""
# This would check actual configuration
passed = True # Placeholder
self.checks["async_persistence"] = {
"passed": passed,
"recommendation": "Enable async persistence for better performance"
}
def check_monitoring(self):
"""Check if monitoring is set up."""
# Placeholder - would check actual monitoring
passed = True
self.checks["monitoring"] = {
"passed": passed,
"recommendation": "Set up Prometheus metrics and structured logging"
}
def check_backup_system(self):
"""Check if backup system is configured."""
# Placeholder
passed = True
self.checks["backup_system"] = {
"passed": passed,
"recommendation": "Configure automated backups with 7-day retention"
}
def check_q_table_size(self):
"""Check Q-table size."""
stats = self.rl_manager.get_statistics()
size = stats['total_states']
passed = size > 0
self.checks["q_table_size"] = {
"passed": passed,
"value": size,
"recommendation": "Q-table should have been trained before production"
}
def run_all_checks(self):
"""Run all checks."""
self.check_exploration_rate()
self.check_async_persistence()
self.check_monitoring()
self.check_backup_system()
self.check_q_table_size()
def print_report(self):
"""Print checklist report."""
print("=" * 60)
print("PRODUCTION READINESS CHECKLIST")
print("=" * 60)
all_passed = True
for check_name, result in self.checks.items():
status = "✓" if result["passed"] else "✗"
print(f"\n{status} {check_name.replace('_', ' ').title()}")
if "value" in result:
print(f" Value: {result['value']}")
print(f" {result['recommendation']}")
if not result["passed"]:
all_passed = False
print("\n" + "=" * 60)
if all_passed:
print("✓ All checks passed - ready for production!")
else:
print("✗ Some checks failed - review recommendations before deploying")
print("=" * 60)
return all_passed
# Usage
checklist = ProductionChecklist(rl_manager)
checklist.run_all_checks()
is_ready = checklist.print_report()
if is_ready:
print("\nProceeding with deployment...")
else:
print("\nDeployment blocked - fix issues first")
🎯 Summary
Production deployment best practices:
- Configuration: Environment-specific configs with conservative exploration
- High Availability: Load balancing and shared storage
- Monitoring: Prometheus metrics and structured logging
- Gradual Rollout: A/B testing and feature flags
- Backup & Recovery: Automated backups with retention policies
- Performance: Async persistence and caching
- Security: Input validation and rate limiting
- Checklist: Comprehensive production readiness validation
Follow these practices to ensure reliable, performant, and secure RL deployment in production.