CI/CD for ML

Automating the ML deployment pipeline:

ML-Specific CI/CD Challenges:

  • Testing data dependencies
  • Model quality gates
  • Larger artifact sizes
  • Environment reproducibility
  • Specialized infrastructure
  • Model-specific rollback strategies

Example GitHub Actions CI/CD Pipeline:

# GitHub Actions workflow for ML model CI/CD
name: ML Model CI/CD Pipeline

on:
  push:
    branches: [ main ]
    paths:
      - 'src/**'
      - 'models/**'
      - 'data/**'

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'
      - name: Install dependencies
        run: pip install -r requirements.txt
      - name: Run unit tests
        run: pytest tests/unit/

  model-evaluation:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'
      - name: Install dependencies
        run: pip install -r requirements.txt
      - name: Evaluate model
        run: python src/evaluation/evaluate_model.py
      - name: Check model metrics
        run: python src/evaluation/check_metrics.py

  build-and-push:
    needs: model-evaluation
    if: github.ref == 'refs/heads/main'
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Log in to the container registry
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Build and push Docker image
        uses: docker/build-push-action@v4
        with:
          push: true
          tags: myorg/ml-model:latest,myorg/ml-model:${{ github.sha }}
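
The "Check model metrics" step only works as a quality gate if the script fails the job when thresholds are missed. A minimal sketch of such a script, assuming the evaluation step writes its results to a metrics.json file (the file name and threshold values are illustrative):

# src/evaluation/check_metrics.py -- CI quality gate
import json
import sys

# Illustrative thresholds; set these to your model's actual requirements
THRESHOLDS = {
    'accuracy': 0.85,
    'f1_score': 0.80,
}

def main():
    # Assumes evaluate_model.py wrote its metrics here
    with open('metrics.json') as f:
        metrics = json.load(f)

    failures = [
        f"{name}: {metrics.get(name, 0.0):.3f} < {minimum:.3f}"
        for name, minimum in THRESHOLDS.items()
        if metrics.get(name, 0.0) < minimum
    ]

    if failures:
        print("Quality gate failed:\n" + "\n".join(failures))
        sys.exit(1)  # non-zero exit fails the workflow job

    print("Quality gate passed.")

if __name__ == '__main__':
    main()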

ML CI/CD Best Practices:

  • Automate model evaluation
  • Implement quality gates
  • Version models and data
  • Use canary deployments (see the sketch after this list)
  • Implement automated rollbacks
  • Monitor deployment impact
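
A canary deployment routes a small share of traffic to the new model version and rolls back automatically if it underperforms the stable version. A minimal sketch of the decision logic; route_traffic and get_error_rate are hypothetical helpers standing in for your serving platform and metrics store:

# Canary rollout gate: promote or roll back based on observed error rates
import time

def canary_rollout(candidate, stable, canary_share=0.05,
                   max_error_ratio=1.1, observation_minutes=30):
    # Send a small share of traffic to the candidate model
    # (route_traffic is a hypothetical serving-platform helper)
    route_traffic({candidate: canary_share, stable: 1.0 - canary_share})

    # Let the canary accumulate enough traffic for a fair comparison
    time.sleep(observation_minutes * 60)

    # get_error_rate is a hypothetical metrics-store helper
    canary_errors = get_error_rate(candidate)
    stable_errors = get_error_rate(stable)

    # Automated rollback: return all traffic to the stable version
    if canary_errors > stable_errors * max_error_ratio:
        route_traffic({stable: 1.0})
        return False

    # Promote the candidate to take all traffic
    route_traffic({candidate: 1.0})
    return True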

Model Monitoring and Maintenance

Model Performance Monitoring

Tracking model behavior in production:

Key Monitoring Metrics:

  • Prediction accuracy
  • Feature distributions
  • Model drift
  • Data drift
  • Latency and throughput (see the serving-metrics sketch after this list)
  • Error rates and exceptions
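
Latency, throughput, and error counts are typically captured in the serving layer itself. A minimal sketch using the prometheus_client library; the metric names and the model.predict call are illustrative:

# Serving-layer metrics with prometheus_client
from prometheus_client import Counter, Histogram, start_http_server

# Histogram captures the latency distribution; Counter tracks totals
PREDICTION_LATENCY = Histogram(
    'model_prediction_latency_seconds',
    'Time spent producing a prediction')
PREDICTION_ERRORS = Counter(
    'model_prediction_errors_total',
    'Number of failed prediction requests')

def predict_with_metrics(model, features):
    # Histogram.time() times the enclosed block
    with PREDICTION_LATENCY.time():
        try:
            return model.predict(features)
        except Exception:
            PREDICTION_ERRORS.inc()
            raise

# Expose metrics for Prometheus to scrape at :8000/metrics
start_http_server(8000)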

Example Drift Detection Implementation:

# Data drift detection with Evidently's Report API (evidently >= 0.2)
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset

def detect_drift(reference_data, current_data, column_mapping, threshold=0.2):
    """
    Detect data drift between reference and current datasets.

    Note: the meaning of drift_score depends on the statistical test
    Evidently selects per column (a p-value for some tests, a distance
    for others), so tune `threshold` per use case or rely on the
    library's own per-column drift_detected flag.
    """
    # Run the data drift preset and extract results as a dictionary
    report = Report(metrics=[DataDriftPreset()])
    report.run(
        reference_data=reference_data,
        current_data=current_data,
        column_mapping=column_mapping,
    )
    result = report.as_dict()

    # Index the preset's metrics by name rather than by position
    metrics_by_name = {m['metric']: m['result'] for m in result['metrics']}
    dataset_result = metrics_by_name['DatasetDriftMetric']
    drift_by_columns = metrics_by_name['DataDriftTable']['drift_by_columns']

    # Collect features whose drift score exceeds the threshold
    drifted_features = []
    for feature, column_result in drift_by_columns.items():
        if column_result['drift_score'] > threshold:
            drifted_features.append({
                'feature': feature,
                'drift_score': column_result['drift_score'],
            })

    # Create drift report
    return {
        'drift_detected': dataset_result['dataset_drift'],
        'share_of_drifted_columns': dataset_result['share_of_drifted_columns'],
        'number_of_drifted_features': len(drifted_features),
        'drifted_features': drifted_features,
    }
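
Calling the function requires an Evidently ColumnMapping describing the dataset. A short usage sketch with illustrative feature names:

# Example usage; reference_df and current_df are pandas DataFrames
from evidently import ColumnMapping

column_mapping = ColumnMapping(
    numerical_features=['age', 'income'],    # illustrative features
    categorical_features=['segment'],
)

drift_report = detect_drift(reference_df, current_df, column_mapping)
if drift_report['drift_detected']:
    print(f"Drift detected in {drift_report['number_of_drifted_features']} features")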

Monitoring Best Practices:

  • Monitor both technical and business metrics
  • Establish baseline performance
  • Set appropriate alerting thresholds (see the sketch after this list)
  • Implement automated retraining triggers
  • Maintain monitoring dashboards
  • Document monitoring procedures
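
As a concrete example of alerting thresholds, the check below compares a live metric against the established baseline and fires an alert when the gap exceeds a tolerance. A minimal sketch; get_live_metric and send_alert are hypothetical stand-ins for your metrics store and paging system:

# Threshold-based alerting against a recorded baseline
def check_accuracy_alert(model_id, baseline_accuracy, tolerance=0.05):
    # get_live_metric is a hypothetical metrics-store helper
    live_accuracy = get_live_metric(model_id, 'accuracy', window_hours=24)

    # Alert when accuracy drops more than `tolerance` below the baseline
    if baseline_accuracy - live_accuracy > tolerance:
        # send_alert is a hypothetical paging/notification helper
        send_alert(
            severity='warning',
            message=(f"Model {model_id} accuracy {live_accuracy:.3f} is "
                     f"more than {tolerance} below baseline "
                     f"{baseline_accuracy:.3f}"))
        return True
    return False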

Model Retraining

Keeping models up-to-date:

Retraining Triggers:

  • Schedule-based (time intervals)
  • Performance-based (accuracy drop)
  • Data-based (drift detection)
  • Business-based (requirement changes)
  • Event-based (external factors)

Automated Retraining Pipeline:

# Automated retraining pipeline. model_registry and the helper functions
# used below (get_reference_data, get_production_data, prepare_training_data,
# retrain_model, evaluate_model, evaluate_model_performance, register_model,
# deploy_model) are placeholders for your registry, data access, and
# training code.
def automated_retraining_pipeline(
    model_id,
    drift_threshold=0.2,
    performance_threshold=0.05
):
    """
    Automated retraining pipeline that checks for drift and performance degradation.
    """
    # Get model info from registry
    model_info = model_registry.get_model(model_id)
    
    # Get reference and current data
    reference_data = get_reference_data(model_id)
    current_data = get_production_data(model_id, days=7)
    
    # Check for data drift
    drift_report = detect_drift(
        reference_data,
        current_data,
        model_info['column_mapping'],
        threshold=drift_threshold
    )
    
    # Check for performance degradation
    performance_report = evaluate_model_performance(model_id, current_data)
    
    performance_degradation = (
        model_info['baseline_performance'] - performance_report['current_performance']
    ) > performance_threshold
    
    # Determine if retraining is needed
    retraining_needed = drift_report['drift_detected'] or performance_degradation
    
    if retraining_needed:
        # Prepare training data
        training_data = prepare_training_data(model_id)
        
        # Retrain model
        new_model, training_metrics = retrain_model(
            model_id,
            training_data,
            model_info['hyperparameters']
        )
        
        # Evaluate new model
        evaluation_metrics = evaluate_model(new_model, training_data['test'])
        
        # If new model is better, register it
        if evaluation_metrics['primary_metric'] >= model_info['baseline_performance']:
            # Register new model version
            new_model_id = register_model(
                model_id,
                new_model,
                evaluation_metrics,
                training_metrics,
                drift_report
            )
            
            # Deploy new model
            deploy_model(new_model_id)
            
            return True, {
                'model_id': new_model_id,
                'retraining_reason': 'drift' if drift_report['drift_detected'] else 'performance',
                'improvement': evaluation_metrics['primary_metric'] - model_info['baseline_performance']
            }
    
    # Reached when no retraining was needed, or when the retrained model
    # failed to beat the baseline and was discarded
    return False, {
        'model_id': model_id,
        'retraining_needed': retraining_needed,
        'drift_detected': drift_report['drift_detected'],
        'performance_degradation': performance_degradation
    }
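
Schedule-based triggers can wrap this pipeline directly. A minimal sketch using the third-party schedule package (pip install schedule); the model ID and cadence are illustrative:

# Run the retraining check on a fixed schedule
import time
import schedule

def retraining_job():
    deployed, details = automated_retraining_pipeline('churn-model-v1')
    print(f"Retraining check finished: deployed={deployed}, details={details}")

# Illustrative cadence: check once a week, off-peak
schedule.every().monday.at("02:00").do(retraining_job)

while True:
    schedule.run_pending()
    time.sleep(60)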

Retraining Best Practices:

  • Automate the retraining process
  • Maintain training data history
  • Implement A/B testing for new models (see the sketch after this list)
  • Document retraining decisions
  • Monitor retraining effectiveness
  • Establish model retirement criteria
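
For A/B testing, hashing a stable identifier keeps each user pinned to the same variant for the duration of the test. A minimal sketch of deterministic assignment; the variant names and split are illustrative:

# Deterministic A/B assignment: hash a stable user ID into a bucket
import hashlib

def assign_variant(user_id, treatment_share=0.1):
    # The same user always lands in the same bucket across requests
    digest = hashlib.sha256(user_id.encode('utf-8')).hexdigest()
    bucket = (int(digest, 16) % 10_000) / 10_000
    return 'candidate_model' if bucket < treatment_share else 'current_model'

# Route roughly 10% of users to the candidate model
variant = assign_variant('user-42')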

ML Infrastructure and Tooling