• Getting Started
  • Core Concepts
  • Reinforcement Learning
  • Model Context Protocol (MCP)
  • Workflow Patterns
  • Advanced Agent Patterns
  • Guides

Reinforcement Learning

Production Deployment

Deploying RL systems in production environments.

Best practices for deploying and maintaining RL systems in production environments.

🏭 Production Configuration

Configure RL systems for production environments.

Production-Ready Setup

from azcore.rl.rl_manager import RLManager, ExplorationStrategy
from azcore.rl.rewards import HeuristicRewardCalculator

# Production RL Manager configuration.
# Every setting below leans toward stability rather than fast adaptation.
production_rl = RLManager(
    # Candidate tools and the on-disk location of the Q-table
    tool_names=["search", "calculate", "weather", "email", "analyze"],
    q_table_path="/var/lib/app/rl_data/production.pkl",

    # Learning: a conservative step size keeps Q-values stable
    learning_rate=0.05,
    discount_factor=0.95,

    # Exploration: start low (5%), decay very slowly, floor at 1%
    exploration_strategy=ExplorationStrategy.EPSILON_DECAY,
    exploration_rate=0.05,
    epsilon_decay_rate=0.9999,
    min_exploration_rate=0.01,

    # Embedding-based state matching, for better generalization
    use_embeddings=True,
    similarity_threshold=0.75,

    # Memory management: state cache plus Q-table pruning
    state_cache_size=5000,
    enable_q_table_pruning=True,
    prune_threshold=10000,
    min_visits_to_keep=10,

    # Non-blocking, batched persistence
    enable_async_persistence=True,
    batch_update_size=50,
)

print("Production RL Manager initialized")
print(f"Exploration rate: {production_rl.exploration_rate:.2%}")
print(f"Using embeddings: {production_rl.use_embeddings}")

Environment-Specific Configurations

import os
from typing import Dict

class RLConfigManager:
    """Manage RL configuration across environments.

    A preset ('development', 'staging' or 'production') is selected once, at
    construction time, and exposed as ``self.config``.
    """

    def __init__(self, environment: str):
        """
        Args:
            environment: 'development', 'staging', 'production'. The lookup
                ignores surrounding whitespace and letter case. Unknown names
                fall back to the development preset and emit a RuntimeWarning.
        """
        self.environment = environment
        self.config = self._get_config()

    def _get_config(self) -> Dict:
        """Return the configuration dict for ``self.environment``.

        Unknown environments fall back to the development preset (30%
        exploration, debug on). A typo such as "prod" would otherwise silently
        run development settings in production, so the fallback is warned.
        """
        import warnings

        configs = {
            "development": {
                "exploration_rate": 0.3,
                "learning_rate": 0.2,
                "q_table_path": "rl_data/dev.pkl",
                "enable_debug": True,
                "enable_async_persistence": False,
                "enable_q_table_pruning": False
            },
            "staging": {
                "exploration_rate": 0.15,
                "learning_rate": 0.1,
                "q_table_path": "/var/lib/app/rl_data/staging.pkl",
                "enable_debug": False,
                "enable_async_persistence": True,
                "enable_q_table_pruning": True
            },
            "production": {
                "exploration_rate": 0.05,
                "learning_rate": 0.05,
                "q_table_path": "/var/lib/app/rl_data/production.pkl",
                "enable_debug": False,
                "enable_async_persistence": True,
                "enable_q_table_pruning": True,
                "prune_threshold": 10000,
                "min_visits_to_keep": 10
            }
        }

        # Normalize so values like "Production" or " staging\n" from env vars match.
        key = str(self.environment).strip().lower()
        if key not in configs:
            warnings.warn(
                f"Unknown RL environment {self.environment!r}; "
                "falling back to 'development' configuration",
                RuntimeWarning,
                stacklevel=2,
            )
            key = "development"
        return configs[key]

    def create_rl_manager(self, tool_names: list):
        """Create an RLManager configured from ``self.config``.

        Presets without pruning parameters (development, staging) use
        prune_threshold=5000 and min_visits_to_keep=5.

        Args:
            tool_names: Tools the manager chooses between.

        Returns:
            A new RLManager instance.
        """
        return RLManager(
            tool_names=tool_names,
            q_table_path=self.config["q_table_path"],
            exploration_strategy=ExplorationStrategy.EPSILON_DECAY,
            exploration_rate=self.config["exploration_rate"],
            learning_rate=self.config["learning_rate"],
            use_embeddings=True,
            enable_async_persistence=self.config.get("enable_async_persistence", False),
            enable_q_table_pruning=self.config.get("enable_q_table_pruning", False),
            prune_threshold=self.config.get("prune_threshold", 5000),
            min_visits_to_keep=self.config.get("min_visits_to_keep", 5)
        )

# Usage: the preset comes from the ENVIRONMENT variable (default: development)
environment = os.getenv("ENVIRONMENT", "development")
config_manager = RLConfigManager(environment)

# Build the manager for this deployment's tool set
rl_manager = config_manager.create_rl_manager(
    tool_names=["search", "calculate", "weather", "email"],
)

# Report which preset was applied
print(f"RL Manager created for {environment} environment")
print(f"Configuration: {config_manager.config}")

🔄 High Availability Setup

Deploy RL systems with high availability.

Load-Balanced RL Deployment

import threading
import time
from queue import Queue
from typing import List

class RLLoadBalancer:
    """Round-robin load balancer over a fixed pool of RL instances.

    Each instance has a health flag. Unhealthy instances are skipped by
    get_next_instance until they are marked healthy again, either explicitly
    or by health_check_loop.
    """

    def __init__(self, rl_instances: List["RLManager"]):
        """
        Args:
            rl_instances: List of RL manager instances
        """
        self.rl_instances = rl_instances
        self.current_index = 0
        # Guards current_index and health_status. Both are shared between
        # request threads and the background health-check thread.
        self.lock = threading.Lock()
        self.health_status = [True] * len(rl_instances)

    def get_next_instance(self) -> "RLManager":
        """Return the next healthy instance in round-robin order.

        If every instance is unhealthy, the first instance is returned as a
        last resort so callers always have something to try.

        Raises:
            IndexError: if the balancer has no instances.
        """
        with self.lock:
            count = len(self.rl_instances)
            if count == 0:
                raise IndexError("RLLoadBalancer has no RL instances")
            for _ in range(count):
                index = self.current_index
                self.current_index = (index + 1) % count
                if self.health_status[index]:
                    return self.rl_instances[index]

            # All instances unhealthy - return first one
            return self.rl_instances[0]

    def _set_health(self, instance: "RLManager", healthy: bool):
        """Set the health flag of a pooled instance; unknown instances are ignored."""
        with self.lock:
            try:
                index = self.rl_instances.index(instance)
            except ValueError:
                return
            self.health_status[index] = healthy
            print(f"Instance {index} marked {'healthy' if healthy else 'unhealthy'}")

    def mark_unhealthy(self, instance: "RLManager"):
        """Mark instance as unhealthy."""
        self._set_health(instance, False)

    def mark_healthy(self, instance: "RLManager"):
        """Mark instance as healthy."""
        self._set_health(instance, True)

    def health_check_loop(self, interval=60, stop_event=None):
        """Periodically refresh health flags (meant for a background thread).

        An instance counts as healthy when ``get_statistics()`` returns a
        truthy value without raising.

        Args:
            interval: Seconds to wait between full passes.
            stop_event: Optional threading.Event. When it is set, the loop
                exits after the current pass; it also cuts the wait short.
                Without it the loop runs forever, as before.
        """
        while stop_event is None or not stop_event.is_set():
            for i, instance in enumerate(self.rl_instances):
                # Probe outside the lock, since get_statistics may be slow.
                try:
                    healthy = bool(instance.get_statistics())
                except Exception as e:
                    print(f"Health check failed for instance {i}: {e}")
                    healthy = False
                # Write under the lock used by mark_healthy/mark_unhealthy.
                with self.lock:
                    self.health_status[i] = healthy

            if stop_event is None:
                time.sleep(interval)
            else:
                stop_event.wait(interval)

# Create multiple RL instances, each persisting to its own Q-table file
rl_instances = [
    RLManager(
        tool_names=["search", "calculate", "weather", "email"],
        q_table_path=f"/var/lib/app/rl_data/instance_{i}.pkl",
        exploration_rate=0.05,
        use_embeddings=True
    )
    for i in range(3)
]

# Create load balancer
load_balancer = RLLoadBalancer(rl_instances)

# Start health check in background (daemon thread: does not block interpreter exit)
health_check_thread = threading.Thread(
    target=load_balancer.health_check_loop,
    args=(60,),
    daemon=True
)
health_check_thread.start()

# Use load balancer
def handle_query(query: str, max_attempts=None):
    """Handle a query with load balancing and bounded failover.

    Each instance that raises is marked unhealthy, and the query is retried on
    the next instance. The retry count is bounded: if every instance is
    unhealthy, get_next_instance keeps returning the first one, and unbounded
    retrying would recurse forever.

    Args:
        query: The user query.
        max_attempts: Maximum instances to try. Defaults to the pool size.

    Returns:
        (selected_tools, state_key, rl_instance) from the first success.

    Raises:
        RuntimeError: if every attempt fails; chained to the last error.
    """
    if max_attempts is None:
        max_attempts = len(load_balancer.rl_instances)
    max_attempts = max(1, max_attempts)

    last_error = None
    for _ in range(max_attempts):
        rl_instance = load_balancer.get_next_instance()
        try:
            selected, state_key = rl_instance.select_tools(query, top_n=2)
            return selected, state_key, rl_instance
        except Exception as e:
            print(f"Error with instance: {e}")
            load_balancer.mark_unhealthy(rl_instance)
            last_error = e

    raise RuntimeError(
        f"All {max_attempts} attempt(s) failed for query {query!r}"
    ) from last_error

# Process queries
for i in range(100):
    query = f"Query {i}"
    selected, state_key, instance = handle_query(query)
    print(f"Query {i} handled by instance: {rl_instances.index(instance)}")

Shared Q-Table Storage

import redis
import pickle
from contextlib import contextmanager

class SharedRLStorage:
    """Shared RL storage using Redis.

    The Q-table and visit counts are stored as two pickled values under
    ``{key_prefix}:q_table`` and ``{key_prefix}:visit_counts``.
    """

    # Delete the lock only if it still holds our token. This runs atomically
    # on the Redis server, so we never remove a lock that expired and was
    # re-acquired by another process.
    _RELEASE_LOCK_SCRIPT = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('del', KEYS[1]) else return 0 end"
    )

    def __init__(self, redis_host="localhost", redis_port=6379):
        """
        Args:
            redis_host: Redis host
            redis_port: Redis port
        """
        self.redis_client = redis.Redis(
            host=redis_host,
            port=redis_port,
            decode_responses=False
        )

    def save_q_table(self, rl_manager: "RLManager", key_prefix="rl"):
        """Save the Q-table and visit counts to Redis.

        Both keys are written in one MULTI/EXEC transaction. A concurrent
        reader therefore never sees a Q-table paired with visit counts from a
        different save.
        """
        q_table_key = f"{key_prefix}:q_table"
        visit_counts_key = f"{key_prefix}:visit_counts"

        # dict(...) drops the outer defaultdict's lambda factory (as installed
        # by load_q_table), which pickle cannot serialize.
        q_table_bytes = pickle.dumps(dict(rl_manager.q_table))
        visit_counts_bytes = pickle.dumps(dict(rl_manager.visit_counts))

        pipe = self.redis_client.pipeline()
        pipe.set(q_table_key, q_table_bytes)
        pipe.set(visit_counts_key, visit_counts_bytes)
        pipe.execute()

        print(f"Q-table saved to Redis under {key_prefix}")

    def load_q_table(self, rl_manager: "RLManager", key_prefix="rl"):
        """Load the Q-table and visit counts from Redis into ``rl_manager``.

        Returns:
            True if both values were present and loaded, else False.
        """
        from collections import defaultdict

        q_table_key = f"{key_prefix}:q_table"
        visit_counts_key = f"{key_prefix}:visit_counts"

        # One round trip, consistent with the transactional save.
        q_table_bytes, visit_counts_bytes = self.redis_client.mget(
            [q_table_key, visit_counts_key]
        )

        if q_table_bytes and visit_counts_bytes:
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # payload. Only point this at a trusted, access-controlled Redis.
            q_table = pickle.loads(q_table_bytes)
            visit_counts = pickle.loads(visit_counts_bytes)

            # Rebuild both levels as defaultdicts, so unseen (state, tool)
            # pairs read as 0 whatever inner dict type was pickled.
            rl_manager.q_table = defaultdict(
                lambda: defaultdict(float),
                {state: defaultdict(float, values) for state, values in q_table.items()},
            )
            rl_manager.visit_counts = defaultdict(
                lambda: defaultdict(int),
                {state: defaultdict(int, counts) for state, counts in visit_counts.items()},
            )

            print(f"Q-table loaded from Redis ({len(q_table)} states)")
            return True

        print("No Q-table found in Redis")
        return False

    @contextmanager
    def atomic_update(self, key_prefix="rl"):
        """Hold a Redis lock (30 s TTL) for the duration of the ``with`` body.

        Raises:
            Exception: if the lock is already held by someone else.
        """
        import uuid

        lock_key = f"{key_prefix}:lock"
        # A unique token identifies this holder for the safe release below.
        token = uuid.uuid4().hex

        lock_acquired = self.redis_client.set(lock_key, token, nx=True, ex=30)
        if not lock_acquired:
            raise Exception("Could not acquire lock for Q-table update")

        try:
            yield
        finally:
            self.redis_client.eval(self._RELEASE_LOCK_SCRIPT, 1, lock_key, token)

# Usage in production
storage = SharedRLStorage(redis_host="redis.example.com")

# Load shared Q-table on startup
rl_manager = RLManager(
    tool_names=["search", "calculate", "weather", "email"],
    q_table_path="/tmp/local_cache.pkl",  # Local cache
    exploration_rate=0.05,
    use_embeddings=True
)

# Load from shared storage
storage.load_q_table(rl_manager, key_prefix="production")

# Periodic sync to shared storage
def sync_to_shared_storage():
    """Push the local Q-table to shared storage while holding the Redis lock.

    NOTE: each save replaces the shared copy wholesale (last writer wins).
    The lock only prevents interleaved writes; it does not merge what
    different instances have learned.

    Several instances syncing on the same schedule will regularly find the
    lock busy. Failures are therefore reported and the cycle is skipped,
    rather than raising out of the scheduler loop; the next cycle retries.

    Returns:
        True if the sync succeeded, False if it was skipped.
    """
    try:
        with storage.atomic_update(key_prefix="production"):
            storage.save_q_table(rl_manager, key_prefix="production")
            print("Synced Q-table to shared storage")
        return True
    except Exception as e:
        print(f"Skipped shared-storage sync this cycle: {e}")
        return False

# Schedule periodic sync (e.g., every 5 minutes).
# The schedule library only runs due jobs when schedule.run_pending() is
# called, e.g. from a loop in the service's main thread.
import schedule
schedule.every(5).minutes.do(sync_to_shared_storage)

📊 Monitoring Integration

Integrate RL monitoring with production monitoring systems.

Prometheus Metrics

from prometheus_client import Counter, Gauge, Histogram, start_http_server
import time

# Define Prometheus metrics.
# Each collector is created once, at module level, and registered with the
# prometheus_client default registry. Creating a second collector with the
# same name in that registry is rejected, so keep these definitions here.

# Counter: one increment per selected tool, labelled by tool and environment
rl_tool_selections = Counter(
    name='rl_tool_selections_total',
    documentation='Total number of tool selections',
    labelnames=['tool', 'environment'],
)

# Histogram: distribution of observed rewards
rl_rewards = Histogram(
    name='rl_rewards',
    documentation='Distribution of rewards',
    labelnames=['environment'],
)

# Gauges: point-in-time snapshots taken from RLManager.get_statistics()
rl_exploration_rate = Gauge(
    name='rl_exploration_rate',
    documentation='Current exploration rate',
    labelnames=['environment'],
)
rl_q_table_size = Gauge(
    name='rl_q_table_size',
    documentation='Number of states in Q-table',
    labelnames=['environment'],
)
rl_avg_q_value = Gauge(
    name='rl_avg_q_value',
    documentation='Average Q-value across all states',
    labelnames=['environment'],
)

# Histogram: end-to-end latency of one monitored query, in seconds
rl_query_latency = Histogram(
    name='rl_query_latency_seconds',
    documentation='Latency of RL query processing',
    labelnames=['environment'],
)

class PrometheusRLMonitor:
    """Record RL selection, reward, gauge and latency metrics for Prometheus.

    Writes to the module-level collectors rl_tool_selections, rl_rewards,
    rl_exploration_rate, rl_q_table_size, rl_avg_q_value and
    rl_query_latency. Every sample is labelled with ``environment``.
    """

    def __init__(self, rl_manager: "RLManager", environment="production"):
        """
        Args:
            rl_manager: RLManager instance
            environment: Environment label
        """
        self.rl_manager = rl_manager
        self.environment = environment

    def record_tool_selection(self, tools: list):
        """Increment the selection counter once per selected tool."""
        for tool in tools:
            rl_tool_selections.labels(
                tool=tool,
                environment=self.environment
            ).inc()

    def record_reward(self, reward: float):
        """Observe one reward in the reward histogram."""
        rl_rewards.labels(
            environment=self.environment
        ).observe(reward)

    def update_gauges(self):
        """Refresh gauges from ``rl_manager.get_statistics()``.

        Reads the 'exploration_rate', 'total_states' and 'avg_q_value' keys.
        """
        stats = self.rl_manager.get_statistics()

        rl_exploration_rate.labels(
            environment=self.environment
        ).set(stats['exploration_rate'])

        rl_q_table_size.labels(
            environment=self.environment
        ).set(stats['total_states'])

        rl_avg_q_value.labels(
            environment=self.environment
        ).set(stats['avg_q_value'])

    def handle_query_with_metrics(self, query: str, reward_calculator):
        """Select tools, compute a reward, update Q-values and record metrics.

        Latency is recorded only for queries that complete successfully.

        Returns:
            (selected_tools, reward)

        Raises:
            Re-raises any exception after printing it.
        """
        # perf_counter is monotonic. time.time() follows the wall clock and
        # can jump (NTP, manual changes), which would corrupt latency samples.
        start_time = time.perf_counter()

        try:
            # Select tools
            selected, state_key = self.rl_manager.select_tools(query, top_n=2)

            # Record selection
            self.record_tool_selection(selected)

            # Simulate agent execution (replace with actual agent)
            result = {"messages": [{"content": "Success"}]}

            # Calculate reward
            reward = reward_calculator.calculate(None, result, query)
            self.record_reward(reward)

            # Update Q-values: every selected tool gets the same reward
            for tool in selected:
                self.rl_manager.update(state_key, tool, reward)

            # Update gauges
            self.update_gauges()

            # Record latency
            latency = time.perf_counter() - start_time
            rl_query_latency.labels(
                environment=self.environment
            ).observe(latency)

            return selected, reward

        except Exception as e:
            print(f"Error handling query: {e}")
            raise

# Expose the /metrics endpoint on port 8000 for Prometheus to scrape
start_http_server(8000)
print("Prometheus metrics available at http://localhost:8000/metrics")

# Wire a monitor to the shared RL manager
from azcore.rl.rewards import HeuristicRewardCalculator

monitor = PrometheusRLMonitor(rl_manager, environment="production")
reward_calc = HeuristicRewardCalculator()

# Drive 100 synthetic queries through the monitor, pausing 100 ms between them
for query_index in range(100):
    query = f"Query {query_index}"
    selected, reward = monitor.handle_query_with_metrics(query, reward_calc)
    print(f"Query {query_index}: Selected {selected}, Reward {reward:.2f}")
    time.sleep(0.1)

Structured Logging

import logging
import json
from datetime import datetime

class RLStructuredLogger:
    """Emit RL events as one-line JSON log records."""

    def __init__(self, rl_manager: "RLManager", logger_name="rl_system"):
        """
        Args:
            rl_manager: RLManager instance
            logger_name: Logger name
        """
        self.rl_manager = rl_manager
        self.logger = logging.getLogger(logger_name)

        # getLogger returns the same object for the same name. Attach our JSON
        # handler only once; otherwise each new instance adds another handler
        # and every event is printed once per instance ever created.
        if not any(getattr(h, "_rl_structured", False) for h in self.logger.handlers):
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(message)s'))
            handler._rl_structured = True
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

    @staticmethod
    def _timestamp() -> str:
        """Local time with an explicit UTC offset, so logs from hosts in different timezones compare correctly."""
        return datetime.now().astimezone().isoformat()

    @staticmethod
    def _short_key(state_key):
        """Truncate a state key to 50 chars. None (no RL state) is kept as JSON null."""
        return None if state_key is None else state_key[:50]

    def log_tool_selection(self, query: str, selected: list, state_key: str):
        """Log tool selection event."""
        log_entry = {
            "timestamp": self._timestamp(),
            "event": "tool_selection",
            "query": query,
            "selected_tools": selected,
            "state_key": self._short_key(state_key),
            "exploration_rate": self.rl_manager.exploration_rate,
            "q_table_size": len(self.rl_manager.q_table)
        }
        self.logger.info(json.dumps(log_entry))

    def log_q_update(self, state_key: str, tool: str, reward: float, old_q: float, new_q: float):
        """Log Q-value update event, including the change ``new_q - old_q``."""
        log_entry = {
            "timestamp": self._timestamp(),
            "event": "q_value_update",
            "state_key": self._short_key(state_key),
            "tool": tool,
            "reward": reward,
            "old_q_value": old_q,
            "new_q_value": new_q,
            "delta": new_q - old_q
        }
        self.logger.info(json.dumps(log_entry))

    def log_anomaly(self, anomaly_type: str, message: str, severity="WARNING"):
        """Log anomaly detection event at the level named by ``severity``.

        ``severity`` is a standard logging level name (e.g. "ERROR",
        "CRITICAL"). Unrecognized names are logged at WARNING.
        """
        log_entry = {
            "timestamp": self._timestamp(),
            "event": "anomaly_detected",
            "type": anomaly_type,
            "severity": severity,
            "message": message,
            "exploration_rate": self.rl_manager.exploration_rate,
            "avg_q_value": self.rl_manager.get_statistics()['avg_q_value']
        }
        # getLevelName maps a known name to its numeric level and returns a
        # string for unknown ones.
        level = logging.getLevelName(str(severity).upper())
        if not isinstance(level, int):
            level = logging.WARNING
        self.logger.log(level, json.dumps(log_entry))

    def log_error(self, error_type: str, error_message: str, context: dict = None):
        """Log error event."""
        log_entry = {
            "timestamp": self._timestamp(),
            "event": "error",
            "error_type": error_type,
            "error_message": error_message,
            "context": context or {}
        }
        self.logger.error(json.dumps(log_entry))

# Usage
structured_logger = RLStructuredLogger(rl_manager)

# Select tools for a sample query and record the decision
query = "Search for documentation"
selected, state_key = rl_manager.select_tools(query, top_n=2)
structured_logger.log_tool_selection(query, selected, state_key)

# Apply one Q-value update to the top-ranked tool and record before/after
top_tool = selected[0]
old_q = rl_manager.q_table[state_key].get(top_tool, 0.0)
reward = 0.8
rl_manager.update(state_key, top_tool, reward)
new_q = rl_manager.q_table[state_key][top_tool]
structured_logger.log_q_update(state_key, top_tool, reward, old_q, new_q)

🔄 Gradual Rollout

Gradually roll out RL systems to production.

A/B Testing Framework

import random
from typing import Dict

class RLABTesting:
    """A/B testing framework for RL rollout.

    Each query is randomly routed either to the RL manager (the "rl" arm) or
    to a fixed baseline tool list (the "baseline" arm). Per-arm reward
    statistics are accumulated so the two arms can be compared.
    """

    def __init__(
        self,
        rl_manager: "RLManager",
        baseline_tools: list,
        rollout_percentage: float = 0.2
    ):
        """
        Args:
            rl_manager: RLManager instance
            baseline_tools: Baseline tool set (no RL)
            rollout_percentage: Percentage of traffic to RL (0.0-1.0)
        """
        self.rl_manager = rl_manager
        self.baseline_tools = baseline_tools
        self.rollout_percentage = rollout_percentage

        # Experiment tracking
        self.rl_metrics = self._new_metrics()
        self.baseline_metrics = self._new_metrics()

    @staticmethod
    def _new_metrics() -> Dict:
        """Return an empty per-arm metrics record."""
        return {
            "queries": 0,
            "total_reward": 0.0,
            "avg_reward": 0.0,
            "tool_selections": {}
        }

    def should_use_rl(self) -> bool:
        """Decide if this query should use RL."""
        return random.random() < self.rollout_percentage

    def handle_query(self, query: str, reward_calculator) -> Dict:
        """
        Handle query with A/B testing.

        Returns: Dict with results and experiment info
        """
        use_rl = self.should_use_rl()

        if use_rl:
            # RL variant
            selected, state_key = self.rl_manager.select_tools(query, top_n=2)
            variant = "rl"
            metrics = self.rl_metrics
        else:
            # Baseline variant: always the first two baseline tools
            selected = self.baseline_tools[:2]
            state_key = None
            variant = "baseline"
            metrics = self.baseline_metrics

        # Simulate execution and calculate reward
        result = {"messages": [{"content": "Success"}]}
        reward = reward_calculator.calculate(None, result, query)

        # Update metrics
        metrics["queries"] += 1
        metrics["total_reward"] += reward
        metrics["avg_reward"] = metrics["total_reward"] / metrics["queries"]

        for tool in selected:
            metrics["tool_selections"][tool] = metrics["tool_selections"].get(tool, 0) + 1

        # Only the RL arm learns from the reward
        if use_rl:
            for tool in selected:
                self.rl_manager.update(state_key, tool, reward)

        return {
            "query": query,
            "variant": variant,
            "selected_tools": selected,
            "reward": reward,
            "use_rl": use_rl
        }

    def _relative_improvement(self):
        """Percent change of the RL arm's average reward vs the baseline's.

        Dividing by ``abs(baseline)`` keeps the sign meaningful when rewards
        are negative. Returns 0 when either arm has no queries yet, or when
        the baseline average is exactly 0. An empty arm's 0.0 average is not
        a real measurement and would otherwise report -100% or +100%.
        """
        rl_avg = self.rl_metrics["avg_reward"]
        base_avg = self.baseline_metrics["avg_reward"]
        if (
            self.rl_metrics["queries"] == 0
            or self.baseline_metrics["queries"] == 0
            or base_avg == 0
        ):
            return 0
        return (rl_avg - base_avg) / abs(base_avg) * 100

    def get_experiment_results(self) -> Dict:
        """Get A/B test results."""
        return {
            "rollout_percentage": self.rollout_percentage,
            "rl_metrics": self.rl_metrics,
            "baseline_metrics": self.baseline_metrics,
            "rl_improvement": self._relative_improvement()
        }

    def increase_rollout(self, increment: float = 0.1):
        """Adjust the RL rollout percentage by ``increment``, clamped to [0.0, 1.0]."""
        self.rollout_percentage = max(0.0, min(1.0, self.rollout_percentage + increment))
        print(f"RL rollout increased to {self.rollout_percentage:.1%}")

    def increase_rollout(self, increment: float = 0.1):
        """Gradually increase RL rollout percentage."""
        self.rollout_percentage = min(1.0, self.rollout_percentage + increment)
        print(f"RL rollout increased to {self.rollout_percentage:.1%}")

# Usage
from azcore.rl.rewards import HeuristicRewardCalculator

ab_test = RLABTesting(
    rl_manager=rl_manager,
    baseline_tools=["search", "calculate", "weather", "email"],
    rollout_percentage=0.2  # Start with 20% RL traffic
)

reward_calc = HeuristicRewardCalculator()

print("=== A/B Testing RL Rollout ===\n")

# Phase 1: 20% rollout
print("Phase 1: 20% RL traffic")
for i in range(100):
    query = f"Query {i}"
    result = ab_test.handle_query(query, reward_calc)

results_phase1 = ab_test.get_experiment_results()
print(f"\nPhase 1 Results:")
print(f"  RL Avg Reward: {results_phase1['rl_metrics']['avg_reward']:.3f}")
print(f"  Baseline Avg Reward: {results_phase1['baseline_metrics']['avg_reward']:.3f}")
print(f"  Improvement: {results_phase1['rl_improvement']:.1f}%")

# Phase 2: Increase to 50% if positive results
if results_phase1['rl_improvement'] > 0:
    print("\n✓ Positive results - increasing rollout to 50%")
    ab_test.increase_rollout(0.3)

    for i in range(100, 200):
        query = f"Query {i}"
        result = ab_test.handle_query(query, reward_calc)

    results_phase2 = ab_test.get_experiment_results()
    print(f"\nPhase 2 Results (50% rollout):")
    print(f"  RL Avg Reward: {results_phase2['rl_metrics']['avg_reward']:.3f}")
    print(f"  Baseline Avg Reward: {results_phase2['baseline_metrics']['avg_reward']:.3f}")
    print(f"  Improvement: {results_phase2['rl_improvement']:.1f}%")

    # Phase 3: Full rollout if still positive. This check is nested inside
    # Phase 2 because results_phase2 only exists when Phase 2 actually ran.
    if results_phase2['rl_improvement'] > 0:
        print("\n✓ Still positive - full rollout (100%)")
        ab_test.increase_rollout(0.5)
else:
    print(f"\n✗ No improvement - keeping rollout at {ab_test.rollout_percentage:.0%}")

Feature Flags

class RLFeatureFlags:
    """In-memory feature flags for controlling the RL system at runtime."""

    def __init__(self):
        """Initialize all known feature flags to enabled."""
        self.flags = {
            "rl_enabled": True,
            "rl_learning_enabled": True,
            "rl_exploration_enabled": True,
            "rl_pruning_enabled": True,
            "rl_async_persist_enabled": True
        }

    def is_enabled(self, flag: str) -> bool:
        """Check if a feature is enabled. Unknown flags read as disabled."""
        return self.flags.get(flag, False)

    def set_flag(self, flag: str, value: bool):
        """Set a known flag's value and print the transition. Unknown flags are reported and ignored."""
        if flag in self.flags:
            old_value = self.flags[flag]
            self.flags[flag] = value
            # An explicit separator is needed; without it the values fuse ("TrueFalse").
            print(f"Feature flag '{flag}' changed: {old_value} → {value}")
        else:
            print(f"Unknown feature flag: {flag}")

    def emergency_disable_rl(self):
        """Kill switch: disable RL selection, learning and exploration.

        Pruning and async persistence flags are left unchanged.
        """
        print("🚨 EMERGENCY: Disabling all RL features")
        self.flags["rl_enabled"] = False
        self.flags["rl_learning_enabled"] = False
        self.flags["rl_exploration_enabled"] = False

class RLWithFeatureFlags:
    """Wrap an RLManager so that selection and learning respect feature flags."""

    def __init__(self, rl_manager: "RLManager", feature_flags: "RLFeatureFlags"):
        """
        Args:
            rl_manager: RLManager instance
            feature_flags: Feature flags instance
        """
        self.rl_manager = rl_manager
        self.feature_flags = feature_flags

    def select_tools(self, query: str, top_n: int = 2):
        """Select tools, honoring the flags.

        - rl_enabled off: return the first ``top_n`` tool names and a None state key.
        - rl_exploration_enabled off: select with exploration_rate temporarily set to 0.

        NOTE: the exploration override mutates the shared manager for the
        duration of the call, so it is not safe under concurrent callers.
        """
        if not self.feature_flags.is_enabled("rl_enabled"):
            # Fall back to default tools
            return self.rl_manager.tool_names[:top_n], None

        # Override exploration if disabled
        if not self.feature_flags.is_enabled("rl_exploration_enabled"):
            original_rate = self.rl_manager.exploration_rate
            self.rl_manager.exploration_rate = 0.0
            try:
                return self.rl_manager.select_tools(query, top_n=top_n)
            finally:
                # Restore even if selection raises; otherwise a single error
                # would leave exploration disabled permanently.
                self.rl_manager.exploration_rate = original_rate

        return self.rl_manager.select_tools(query, top_n=top_n)

    def update(self, state_key: str, tool: str, reward: float):
        """Update Q-values unless learning is disabled or there is no RL state.

        A None ``state_key`` comes from the rl_enabled-off fallback in
        select_tools. Learning from it would create a bogus state.
        """
        if not self.feature_flags.is_enabled("rl_learning_enabled"):
            return  # Skip learning
        if state_key is None:
            return  # Tools were not chosen by RL; nothing to learn from

        self.rl_manager.update(state_key, tool, reward)

# Usage
feature_flags = RLFeatureFlags()
rl_with_flags = RLWithFeatureFlags(rl_manager, feature_flags)

# Normal operation: select, then reinforce the top-ranked tool
selected, state_key = rl_with_flags.select_tools("Query", top_n=2)
top_tool = selected[0]
rl_with_flags.update(state_key, top_tool, 0.9)

# Exploit only: switch exploration off, then select again
feature_flags.set_flag("rl_exploration_enabled", False)
selected, state_key = rl_with_flags.select_tools("Query", top_n=2)

# Kill switch: turns off RL selection, learning and exploration together
feature_flags.emergency_disable_rl()

💾 Backup and Recovery

Implement backup and recovery for RL systems.

Automated Backup System

import shutil
from datetime import datetime, timedelta
from pathlib import Path

class RLBackupManager:
    """Automated backup manager for RL Q-tables.

    Backups are copies of the manager's Q-table file, named
    ``qtable_<YYYYmmdd_HHMMSS>[_<label>][_<n>].pkl`` inside ``backup_dir``.
    """

    def __init__(
        self,
        rl_manager: "RLManager",
        backup_dir: str = "backups",
        retention_days: int = 7
    ):
        """
        Args:
            rl_manager: RLManager instance
            backup_dir: Directory for backups (created if missing)
            retention_days: Days to retain backups
        """
        self.rl_manager = rl_manager
        self.backup_dir = Path(backup_dir)
        self.retention_days = retention_days

        # Create backup directory
        self.backup_dir.mkdir(parents=True, exist_ok=True)

    def create_backup(self, label: str = ""):
        """
        Create Q-table backup.

        Args:
            label: Optional label for backup

        Returns:
            Path of the new backup, or None if the Q-table file does not exist.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        label_suffix = f"_{label}" if label else ""
        backup_path = self.backup_dir / f"qtable_{timestamp}{label_suffix}.pkl"

        # The name only has 1-second resolution. Two backups with the same
        # label in the same second would overwrite each other, so add a counter.
        counter = 1
        while backup_path.exists():
            backup_path = self.backup_dir / f"qtable_{timestamp}{label_suffix}_{counter}.pkl"
            counter += 1

        # Force persistence to ensure Q-table is saved
        self.rl_manager.force_persist()

        # Copy Q-table file
        source_path = Path(self.rl_manager.q_table_path)
        if source_path.exists():
            shutil.copy(source_path, backup_path)
            print(f"✓ Backup created: {backup_path}")
            return backup_path
        else:
            print(f"✗ Source Q-table not found: {source_path}")
            return None

    def restore_backup(self, backup_path: str):
        """
        Restore Q-table from backup, first taking a "pre_restore" safety backup.

        Args:
            backup_path: Path to backup file

        Returns:
            True on success, False if the backup file does not exist.
        """
        backup_file = Path(backup_path)

        if not backup_file.exists():
            print(f"✗ Backup file not found: {backup_path}")
            return False

        # Create safety backup of current state
        self.create_backup(label="pre_restore")

        # Copy to a temp file next to the destination, then atomically swap it
        # in. A crash mid-copy therefore cannot leave a truncated Q-table.
        dest_path = Path(self.rl_manager.q_table_path)
        tmp_path = dest_path.with_name(dest_path.name + ".restore.tmp")
        shutil.copy(backup_file, tmp_path)
        tmp_path.replace(dest_path)

        # Reload Q-table
        self.rl_manager._load_q_table()

        print(f"✓ Restored from backup: {backup_path}")
        return True

    def cleanup_old_backups(self):
        """Remove backups older than the retention period (by mtime).

        "pre_restore" backups are always kept.

        Returns:
            Number of files removed.
        """
        cutoff_date = datetime.now() - timedelta(days=self.retention_days)
        removed_count = 0

        for backup_file in self.backup_dir.glob("qtable_*.pkl"):
            # Skip pre_restore backups (keep indefinitely)
            if "pre_restore" in backup_file.name:
                continue

            # Get file modification time
            file_time = datetime.fromtimestamp(backup_file.stat().st_mtime)

            if file_time < cutoff_date:
                backup_file.unlink()
                removed_count += 1

        if removed_count > 0:
            print(f"✓ Cleaned up {removed_count} old backup(s)")

        return removed_count

    def list_backups(self):
        """List backups, newest filename first, with mtime and size in KB."""
        backups = []
        for backup_file in sorted(self.backup_dir.glob("qtable_*.pkl"), reverse=True):
            try:
                stat = backup_file.stat()
            except FileNotFoundError:
                continue  # Removed concurrently (e.g. by cleanup); skip it

            backups.append({
                "filename": backup_file.name,
                "path": str(backup_file),
                "timestamp": datetime.fromtimestamp(stat.st_mtime),
                "size_kb": stat.st_size / 1024
            })

        return backups

    def get_latest_backup(self):
        """Get the path of the latest backup (by filename order), or None."""
        backups = self.list_backups()
        if backups:
            return backups[0]["path"]
        return None

# Usage
backup_manager = RLBackupManager(
    rl_manager=rl_manager,
    backup_dir="/var/lib/app/rl_backups",
    retention_days=7
)

# Periodic backups (e.g., every 6 hours), pruning expired ones afterwards
import schedule

def scheduled_backup():
    """Take a "scheduled" backup, then delete backups past the retention window."""
    backup_manager.create_backup(label="scheduled")
    backup_manager.cleanup_old_backups()

schedule.every(6).hours.do(scheduled_backup)

# Snapshot before risky operations
backup_manager.create_backup(label="before_update")

# Train or update RL
# ...

# On failure, roll back to the most recent snapshot:
# backup_manager.restore_backup(backup_manager.get_latest_backup())

# Show every available backup
print("\n=== Available Backups ===")
for entry in backup_manager.list_backups():
    when = entry['timestamp'].strftime('%Y-%m-%d %H:%M:%S')
    print(f"{entry['filename']}: {when} ({entry['size_kb']:.1f} KB)")

⚡ Performance Optimization

Optimize RL system performance for production.

Async Persistence

import asyncio
from threading import Thread
from queue import Queue

class AsyncPersistenceManager:
    """Persist the Q-table from a background thread in batches.

    Every ``batch_size`` calls to :meth:`record_update` enqueue one persist
    request. A daemon worker thread drains the queue and calls
    ``rl_manager.force_persist()`` for each request. :meth:`shutdown` stops
    the worker and then performs one final synchronous persist.
    """

    # Sentinel put on the queue by shutdown() to wake the worker immediately.
    _STOP = object()

    def __init__(self, rl_manager: RLManager, batch_size=50):
        """
        Args:
            rl_manager: RLManager instance (must provide ``force_persist()``)
            batch_size: Number of updates before persisting
        """
        self.rl_manager = rl_manager
        self.batch_size = batch_size
        self.update_count = 0
        self.persist_queue = Queue()
        self.running = True

        # Start persistence thread
        self.persist_thread = Thread(target=self._persist_worker, daemon=True)
        self.persist_thread.start()

    def _persist_worker(self):
        """Background loop: persist once per queued request until stopped."""
        from queue import Empty

        while self.running:
            try:
                request = self.persist_queue.get(timeout=1)
            except Empty:
                # Idle tick, not an error: loop around to re-check self.running.
                continue

            if request is self._STOP:
                break

            try:
                self.rl_manager.force_persist()
                print("✓ Q-table persisted (async)")
            except Exception as e:
                # Keep the worker alive; the next batch triggers another
                # attempt and shutdown() always does a final persist.
                print(f"Persistence error: {e}")

    def record_update(self):
        """Count one update; enqueue a persist request every ``batch_size`` updates."""
        self.update_count += 1

        if self.update_count >= self.batch_size:
            self.persist_queue.put("persist")
            self.update_count = 0

    def shutdown(self):
        """Stop the worker thread, then persist synchronously one last time."""
        self.running = False
        # Wake the worker if it is blocked in get() so join() returns promptly.
        self.persist_queue.put(self._STOP)
        self.persist_thread.join(timeout=5)

        # Final persist
        self.rl_manager.force_persist()

# Usage
async_persist = AsyncPersistenceManager(rl_manager, batch_size=50)

# Train with async persistence. try/finally guarantees shutdown() (which
# stops the worker and performs a final persist) runs even if training raises.
try:
    for i in range(200):
        query = f"Query {i}"
        selected, state_key = rl_manager.select_tools(query, top_n=2)

        reward = 0.8  # constant reward, for illustration only
        for tool in selected:
            rl_manager.update(state_key, tool, reward)
            async_persist.record_update()
finally:
    # Shutdown
    async_persist.shutdown()

Caching Layer

from functools import lru_cache
import hashlib

class RLCachingLayer:
    """LRU cache in front of ``RLManager.select_tools``.

    Results are memoized per ``(query, top_n)``. A cache hit skips
    ``rl_manager.select_tools`` entirely. Any exploration or Q-value changes
    inside the RL manager are therefore not reflected for cached queries
    until :meth:`clear_cache` is called.
    """

    def __init__(self, rl_manager: RLManager, cache_size=1000):
        """
        Args:
            rl_manager: RLManager instance
            cache_size: Maximum number of (query, top_n) entries kept in the LRU cache
        """
        self.rl_manager = rl_manager

        # Key the cache on the query text itself. A separate hash->query map
        # is not evicted together with LRU entries, so it would grow without
        # bound while the cache stays at cache_size.
        self.cached_select_tools = lru_cache(maxsize=cache_size)(
            self._select_tools_impl
        )

    def _select_tools_impl(self, query: str, top_n: int):
        """Uncached call into the RL manager (wrapped by the LRU cache)."""
        selected, state_key = self.rl_manager.select_tools(query, top_n=top_n)
        return selected, state_key

    def select_tools(self, query: str, top_n: int = 2):
        """Select tools with caching.

        Returns the ``(selected, state_key)`` pair produced by the RL manager.
        The same objects are returned on every cache hit, so callers should
        not mutate ``selected`` in place.
        """
        return self.cached_select_tools(query, top_n)

    def clear_cache(self):
        """Clear cache."""
        self.cached_select_tools.cache_clear()

    def get_cache_stats(self):
        """Return hits, misses, hit_rate (0 when unused), current size and max size."""
        info = self.cached_select_tools.cache_info()
        lookups = info.hits + info.misses
        return {
            "hits": info.hits,
            "misses": info.misses,
            "hit_rate": info.hits / lookups if lookups > 0 else 0,
            "size": info.currsize,
            "max_size": info.maxsize
        }

# Usage
caching_layer = RLCachingLayer(rl_manager, cache_size=1000)

# Use with caching: 100 calls over only 10 distinct queries, so every call
# after the first one for each query is served from the cache.
for i in range(100):
    query = f"Repeated query {i % 10}"  # Many repeats
    selected, state_key = caching_layer.select_tools(query, top_n=2)

# Check cache performance
stats = caching_layer.get_cache_stats()
print("\n=== Cache Statistics ===")
print(f"Hit rate: {stats['hit_rate']:.1%}")
print(f"Hits: {stats['hits']}, Misses: {stats['misses']}")
print(f"Cache size: {stats['size']}/{stats['max_size']}")

🔒 Security Considerations

Secure your production RL deployment.

Input Validation

class RLInputValidator:
    """Validate and sanitize user queries before they reach the RL system.

    The pattern check is a small case-insensitive deny-list. It rejects a few
    obvious payloads but is not a substitute for proper output escaping.
    """

    # Lower-case substrings that cause a query to be rejected (checked in order).
    DANGEROUS_PATTERNS = ("<script>", "javascript:", "eval(", "__import__")

    def __init__(self, max_query_length=1000):
        """
        Args:
            max_query_length: Maximum query length (applied to the raw input)
        """
        self.max_query_length = max_query_length

    def validate_query(self, query: str) -> tuple[bool, str]:
        """
        Validate query input.

        The length limit applies to the raw input. The emptiness and pattern
        checks run on the sanitized text, which is what actually gets
        processed. As a result, a control character embedded inside a
        blocked pattern cannot hide it. A query consisting only of
        whitespace or control characters is also rejected as empty.

        Returns: (is_valid, error_message); error_message is "" when valid.
        """
        if not query:
            return False, "Query cannot be empty"

        if len(query) > self.max_query_length:
            return False, f"Query exceeds maximum length ({self.max_query_length})"

        cleaned = self.sanitize_query(query)
        if not cleaned:
            return False, "Query cannot be empty"

        # Check for injection attempts (case-insensitive; lower once, not per pattern)
        lowered = cleaned.lower()
        for pattern in self.DANGEROUS_PATTERNS:
            if pattern in lowered:
                return False, f"Potentially malicious pattern detected: {pattern}"

        return True, ""

    def sanitize_query(self, query: str) -> str:
        """Drop ASCII control characters (except newline), then trim whitespace.

        Trimming happens last, so whitespace that sat behind a leading or
        trailing control character is still removed.
        """
        # Remove control characters (code points below 32) other than '\n'
        kept = ''.join(char for char in query if ord(char) >= 32 or char == '\n')

        # Strip whitespace
        return kept.strip()

# Usage
validator = RLInputValidator(max_query_length=1000)

def safe_handle_query(query: str):
    """Sanitize and validate *query*, then select tools for it.

    Sanitizing before validating means the checks see exactly the text that
    will be processed. Validating first would let a control character hide a
    blocked pattern that sanitization then reassembles. It would also let an
    all-whitespace query pass as non-empty.

    Returns the selected tools, or None if the query is rejected.
    """
    # Sanitize (None/"" stay empty so validation reports them)
    cleaned = validator.sanitize_query(query) if query else ""

    # Validate
    is_valid, error_msg = validator.validate_query(cleaned)

    if not is_valid:
        print(f"✗ Invalid query: {error_msg}")
        return None

    # Process
    selected, state_key = rl_manager.select_tools(cleaned, top_n=2)
    return selected

# Test validation
safe_handle_query("Normal query")  # ✓ Valid
safe_handle_query("<script>alert('xss')</script>")  # ✗ Invalid
safe_handle_query("A" * 2000)  # ✗ Too long

Rate Limiting

import time
from collections import defaultdict

class RLRateLimiter:
    """Sliding-window rate limiter keyed by client id.

    A client may make at most ``max_requests_per_minute`` allowed requests in
    any trailing ``window_seconds`` interval; rejected requests are not
    recorded. State lives in this object only and no locking is performed,
    so guard calls externally if it is shared between threads.
    """

    def __init__(self, max_requests_per_minute=100, window_seconds=60):
        """
        Args:
            max_requests_per_minute: Maximum requests allowed per window
            window_seconds: Length of the sliding window in seconds
                (default 60, i.e. "per minute")
        """
        from collections import deque

        self.max_requests = max_requests_per_minute
        self.window_seconds = window_seconds
        # client_id -> deque of monotonic timestamps, oldest on the left
        self.requests = defaultdict(deque)

    def is_allowed(self, client_id: str) -> bool:
        """Return True and record the request if *client_id* is under its limit."""
        # monotonic() is unaffected by wall-clock adjustments (NTP, manual
        # changes) that would otherwise stretch or collapse the window.
        now = time.monotonic()
        cutoff = now - self.window_seconds
        window = self.requests[client_id]

        # Timestamps are appended in order, so expired ones sit at the left;
        # popping them is O(1) each instead of rebuilding the whole list.
        while window and window[0] <= cutoff:
            window.popleft()

        # Check limit
        if len(window) >= self.max_requests:
            return False

        # Record request
        window.append(now)
        return True

# Usage
rate_limiter = RLRateLimiter(max_requests_per_minute=100)

def rate_limited_handle_query(query: str, client_id: str):
    """Select tools for *query* unless *client_id* has used up its quota.

    Returns the selected tools, or None when the request is rejected.
    """
    allowed = rate_limiter.is_allowed(client_id)
    if allowed:
        # Under quota: delegate tool selection to the RL manager.
        tools, _ = rl_manager.select_tools(query, top_n=2)
        return tools

    print(f"✗ Rate limit exceeded for client {client_id}")
    return None

📋 Production Checklist

Comprehensive checklist for production deployment.

class ProductionChecklist:
    """Production readiness checklist.

    Each ``check_*`` method stores a result dict under ``self.checks[name]``
    with keys ``passed`` (bool) and ``recommendation`` (str), plus ``value``
    for measured checks. Call :meth:`run_all_checks`, then :meth:`print_report`.

    NOTE: the async-persistence, monitoring and backup checks are
    placeholders that always pass; replace them with real probes before
    relying on this report to gate a deployment.
    """

    def __init__(self, rl_manager: RLManager):
        """
        Args:
            rl_manager: RLManager instance; must expose ``exploration_rate``
                and ``get_statistics()`` returning a dict with ``'total_states'``
        """
        self.rl_manager = rl_manager
        self.checks = {}

    def check_exploration_rate(self):
        """Pass when the exploration rate is within [1%, 10%] inclusive."""
        rate = self.rl_manager.exploration_rate
        passed = 0.01 <= rate <= 0.1

        self.checks["exploration_rate"] = {
            "passed": passed,
            "value": rate,
            "recommendation": "Should be between 1% and 10% for production"
        }

    def check_async_persistence(self):
        """Check if async persistence is enabled (placeholder: always passes)."""
        # TODO: inspect the real configuration; hard-coded to pass for now.
        passed = True

        self.checks["async_persistence"] = {
            "passed": passed,
            "recommendation": "Enable async persistence for better performance"
        }

    def check_monitoring(self):
        """Check if monitoring is set up (placeholder: always passes)."""
        # TODO: probe the actual monitoring setup; hard-coded to pass for now.
        passed = True

        self.checks["monitoring"] = {
            "passed": passed,
            "recommendation": "Set up Prometheus metrics and structured logging"
        }

    def check_backup_system(self):
        """Check if backups are configured (placeholder: always passes)."""
        # TODO: verify the backup configuration; hard-coded to pass for now.
        passed = True

        self.checks["backup_system"] = {
            "passed": passed,
            "recommendation": "Configure automated backups with 7-day retention"
        }

    def check_q_table_size(self):
        """Pass when the Q-table already contains at least one state."""
        stats = self.rl_manager.get_statistics()
        size = stats['total_states']
        passed = size > 0

        self.checks["q_table_size"] = {
            "passed": passed,
            "value": size,
            "recommendation": "Q-table should have been trained before production"
        }

    def run_all_checks(self):
        """Run every check, (re)populating ``self.checks``."""
        self.check_exploration_rate()
        self.check_async_persistence()
        self.check_monitoring()
        self.check_backup_system()
        self.check_q_table_size()

    def print_report(self):
        """Print the checklist report.

        Returns:
            True only if at least one check has been recorded and every
            recorded check passed. An empty checklist is reported as not
            ready, so forgetting run_all_checks() cannot approve a deployment.
        """
        print("=" * 60)
        print("PRODUCTION READINESS CHECKLIST")
        print("=" * 60)

        for check_name, result in self.checks.items():
            status = "✓" if result["passed"] else "✗"
            print(f"\n{status} {check_name.replace('_', ' ').title()}")

            if "value" in result:
                print(f"  Value: {result['value']}")

            print(f"  {result['recommendation']}")

        # Require at least one recorded check, and all of them must pass.
        all_passed = bool(self.checks) and all(
            result["passed"] for result in self.checks.values()
        )

        print("\n" + "=" * 60)

        if all_passed:
            print("✓ All checks passed - ready for production!")
        elif not self.checks:
            print("✗ No checks have been run - call run_all_checks() first")
        else:
            print("✗ Some checks failed - review recommendations before deploying")

        print("=" * 60)

        return all_passed

# Usage: gate the deployment on the readiness report
checklist = ProductionChecklist(rl_manager)
checklist.run_all_checks()
is_ready = checklist.print_report()

print("\nProceeding with deployment..." if is_ready else "\nDeployment blocked - fix issues first")

🎯 Summary

Production deployment best practices:

  1. Configuration: Environment-specific configs with conservative exploration
  2. High Availability: Load balancing and shared storage
  3. Monitoring: Prometheus metrics and structured logging
  4. Gradual Rollout: A/B testing and feature flags
  5. Backup & Recovery: Automated backups with retention policies
  6. Performance: Async persistence and caching
  7. Security: Input validation and rate limiting
  8. Checklist: Comprehensive production readiness validation

Follow these practices to ensure reliable, performant, and secure RL deployment in production.

Edit this page on GitHub
AzrienLabs logo

AzrienLabs

Crafted by Team AzrienLabs