Running Fine-tuned Models in Production

Baking a cake is easy. Running a bakery is hard. You need consistent recipes, quality control, inventory management, and the ability to handle a bad batch without ruining the whole day. The baking skill is necessary but not sufficient.

Fine-tuning follows the same pattern. Training a model on your data is the tutorial step. Running it in production requires checkpointing, evaluation pipelines, rollback procedures, A/B testing, and versioning. The training is necessary but not sufficient.

The Production Fine-tuning Pipeline

def production_pipeline():
    return {
        "stage_1_data": {
            "steps": [
                "Data collection and validation",
                "Quality filtering",
                "Train/val/test split",
                "Versioning (DVC, Git LFS)",
            ],
            "output": "Versioned, validated dataset",
        },

        "stage_2_training": {
            "steps": [
                "Hyperparameter selection",
                "Training with checkpointing",
                "Validation metrics at each checkpoint",
                "Early stopping if needed",
            ],
            "output": "Multiple checkpoint candidates",
        },

        "stage_3_evaluation": {
            "steps": [
                "Evaluate all checkpoints on test set",
                "Task-specific eval suite",
                "Safety and regression testing",
                "Compare to baseline",
            ],
            "output": "Best checkpoint selected",
        },

        "stage_4_deployment": {
            "steps": [
                "Canary deployment (1% traffic)",
                "A/B testing against baseline",
                "Gradual rollout",
                "Production monitoring",
            ],
            "output": "Model serving production traffic",
        },

        "stage_5_maintenance": {
            "steps": [
                "Continuous quality monitoring",
                "Drift detection",
                "Retraining triggers",
                "Rollback if needed",
            ],
            "output": "Sustained quality over time",
        },
    }

Checkpointing Strategy

import json
import shutil
from datetime import datetime

class CheckpointManager:
    """
    Save and manage training checkpoints
    """

    def __init__(self, output_dir: str, keep_last_n: int = 5):
        self.output_dir = output_dir
        self.keep_last_n = keep_last_n
        self.checkpoints = []

    def save_checkpoint(self, model, step: int, metrics: dict):
        """Save checkpoint with metadata"""
        checkpoint_dir = f"{self.output_dir}/checkpoint-{step}"

        # Save model
        model.save_pretrained(checkpoint_dir)

        # Save metadata
        metadata = {
            "step": step,
            "timestamp": datetime.now().isoformat(),
            "metrics": metrics,
        }
        with open(f"{checkpoint_dir}/metadata.json", "w") as f:
            json.dump(metadata, f)

        self.checkpoints.append(checkpoint_dir)

        # Cleanup old checkpoints
        while len(self.checkpoints) > self.keep_last_n:
            old_checkpoint = self.checkpoints.pop(0)
            shutil.rmtree(old_checkpoint)

    def get_best_checkpoint(self, metric: str = "eval_loss") -> str:
        """Return the checkpoint with the lowest value of a lower-is-better metric"""
        best_checkpoint = None
        best_value = float("inf")

        for checkpoint_dir in self.checkpoints:
            with open(f"{checkpoint_dir}/metadata.json") as f:
                metadata = json.load(f)
                value = metadata["metrics"].get(metric, float("inf"))
                if value < best_value:
                    best_value = value
                    best_checkpoint = checkpoint_dir

        return best_checkpoint
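
A minimal sketch of how the manager might sit inside a training loop. The training step, validation call, and 500-step save interval are placeholders, not a real trainer API:

manager = CheckpointManager(output_dir="runs/my-finetune", keep_last_n=5)

for step, batch in enumerate(train_batches, start=1):  # train_batches: your data loader
    train_step(model, batch)  # train_step: your framework's update function

    if step % 500 == 0:
        metrics = evaluate(model, val_set)  # e.g. returns {"eval_loss": 1.82}
        manager.save_checkpoint(model, step=step, metrics=metrics)

best = manager.get_best_checkpoint(metric="eval_loss")
print(f"Best checkpoint by eval_loss: {best}")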

Evaluation Pipeline

class EvaluationPipeline:
    """
    Comprehensive evaluation before deployment
    """

    def __init__(self, baseline_model, test_datasets: dict):
        self.baseline = baseline_model
        self.test_datasets = test_datasets
        self.required_thresholds = {
            "task_accuracy": 0.85,
            "safety_score": 0.99,
            "regression_delta": -0.05,  # No more than 5% regression
        }

    def evaluate_checkpoint(self, checkpoint_path: str) -> dict:
        """Full evaluation of a checkpoint"""
        # load_checkpoint, evaluate_task, evaluate_safety, and run_regression_suite
        # are task-specific and left to your own eval harness
        model = self.load_checkpoint(checkpoint_path)

        results = {
            "task_metrics": self.evaluate_task(model),
            "safety_metrics": self.evaluate_safety(model),
            "regression_tests": self.run_regression_suite(model),
            "baseline_comparison": self.compare_to_baseline(model),
        }

        results["deployment_ready"] = self.check_thresholds(results)
        return results

    def check_thresholds(self, results: dict) -> bool:
        """Check if results meet deployment thresholds"""
        checks = {
            "task": results["task_metrics"]["accuracy"] >= self.required_thresholds["task_accuracy"],
            "safety": results["safety_metrics"]["score"] >= self.required_thresholds["safety_score"],
            "regression": results["baseline_comparison"]["delta"] >= self.required_thresholds["regression_delta"],
        }

        return all(checks.values())

    def compare_to_baseline(self, model) -> dict:
        """Compare fine-tuned model to baseline"""
        baseline_score = self.evaluate_task(self.baseline)["accuracy"]
        model_score = self.evaluate_task(model)["accuracy"]

        return {
            "baseline_score": baseline_score,
            "model_score": model_score,
            "delta": model_score - baseline_score,
            "improved": model_score > baseline_score,
        }
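
A sketch of screening checkpoint candidates before anything touches production. Here baseline_model and test_sets are assumed to exist, and the candidate list comes from the CheckpointManager above:

pipeline = EvaluationPipeline(baseline_model=baseline_model, test_datasets=test_sets)

deployable = None
for checkpoint_path in manager.checkpoints:  # candidates saved during training
    results = pipeline.evaluate_checkpoint(checkpoint_path)
    if results["deployment_ready"]:
        deployable = (checkpoint_path, results)
        break

if deployable is None:
    raise RuntimeError("No checkpoint met the deployment thresholds; nothing ships")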

Model Versioning

from datetime import datetime

class ModelRegistry:
    """
    Version and track fine-tuned models.
    Assumes self.db, get_production_model, get_previous_production,
    update_status, and log_event are backed by your storage of choice.
    """

    def register_model(
        self,
        model_path: str,
        version: str,
        metadata: dict
    ) -> str:
        """Register a new model version"""
        model_id = f"{metadata['name']}-{version}"

        entry = {
            "model_id": model_id,
            "version": version,
            "path": model_path,
            "created_at": datetime.now().isoformat(),
            "training_data": metadata["training_data_version"],
            "base_model": metadata["base_model"],
            "eval_results": metadata["eval_results"],
            "status": "staging",  # staging, production, deprecated
        }

        self.db.insert(entry)
        return model_id

    def promote_to_production(self, model_id: str):
        """Promote a model to production"""
        # Demote current production
        current_prod = self.get_production_model()
        if current_prod:
            self.update_status(current_prod["model_id"], "deprecated")

        # Promote new model
        self.update_status(model_id, "production")

        # Log the promotion
        self.log_event("promotion", {
            "model_id": model_id,
            "previous": current_prod["model_id"] if current_prod else None,
        })

    def rollback(self, reason: str):
        """Roll back to the previous production model"""
        current = self.get_production_model()
        previous = self.get_previous_production()

        if not current:
            raise ValueError("No production model to roll back from")
        if not previous:
            raise ValueError("No previous model to roll back to")

        self.update_status(current["model_id"], "failed")
        self.update_status(previous["model_id"], "production")

        self.log_event("rollback", {
            "from": current["model_id"],
            "to": previous["model_id"],
            "reason": reason,
        })
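
A sketch of registering and promoting the winning checkpoint. The constructor, storage backend, and every metadata value below are illustrative, not prescribed:

registry = ModelRegistry()  # assumes whatever persistence backend self.db points at

checkpoint_path, eval_results = deployable
model_id = registry.register_model(
    model_path=checkpoint_path,
    version="1.4.0",
    metadata={
        "name": "support-assistant",
        "training_data_version": "dataset-2024-06-v3",
        "base_model": "base-8b-instruct",
        "eval_results": eval_results,
    },
)

# Promotion happens only after the canary and A/B phases below succeed
registry.promote_to_production(model_id)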

Canary Deployment

import random

class CanaryDeployment:
    """
    Gradual rollout of fine-tuned models
    """

    def __init__(self, baseline_model, new_model):
        self.baseline = baseline_model
        self.new_model = new_model
        self.traffic_percentage = 0.01  # Start at 1%

    def route_request(self, request):
        """Route a request to the canary or the baseline model"""
        if random.random() < self.traffic_percentage:
            return self.new_model
        return self.baseline

    def update_traffic(self, new_percentage: float):
        """Update canary traffic percentage"""
        if new_percentage > self.traffic_percentage * 2:
            raise ValueError("Increase traffic gradually (max 2x)")
        self.traffic_percentage = new_percentage

    def monitor_and_decide(self, metrics: dict) -> str:
        """Decide whether to continue rollout"""
        baseline_metrics = metrics["baseline"]
        canary_metrics = metrics["canary"]

        # Check for degradation
        for metric in ["latency_p99", "error_rate"]:
            if canary_metrics[metric] > baseline_metrics[metric] * 1.1:
                return "rollback"  # >10% degradation

        # Check quality
        if canary_metrics["quality"] < baseline_metrics["quality"] - 0.02:
            return "rollback"  # >2% quality drop

        # Safe to continue
        if self.traffic_percentage < 1.0:
            return "increase_traffic"
        return "complete_rollout"

Production Monitoring

class ProductionMonitor:
    """
    Monitor a fine-tuned model in production.
    The metric getters (get_quality_score, get_latency_percentile, get_error_rate,
    get_request_count), self.registry, and self.notify are hooks into your own
    metrics store, model registry, and alerting stack.
    """

    def __init__(self, model_id: str, baseline_metrics: dict):
        self.model_id = model_id
        self.baseline = baseline_metrics
        self.alert_thresholds = {
            "quality_drop": 0.05,
            "latency_increase": 1.5,  # 50% increase
            "error_rate": 0.02,
        }

    def collect_metrics(self, window_hours: int = 1) -> dict:
        """Collect recent metrics"""
        return {
            "quality_score": self.get_quality_score(window_hours),
            "latency_p99": self.get_latency_percentile(99, window_hours),
            "error_rate": self.get_error_rate(window_hours),
            "traffic_volume": self.get_request_count(window_hours),
        }

    def check_health(self) -> dict:
        """Check model health against baseline"""
        current = self.collect_metrics()
        alerts = []

        if current["quality_score"] < self.baseline["quality_score"] - self.alert_thresholds["quality_drop"]:
            alerts.append({
                "type": "quality_regression",
                "severity": "high",
                "current": current["quality_score"],
                "baseline": self.baseline["quality_score"],
            })

        if current["latency_p99"] > self.baseline["latency_p99"] * self.alert_thresholds["latency_increase"]:
            alerts.append({
                "type": "latency_regression",
                "severity": "medium",
                "current": current["latency_p99"],
                "baseline": self.baseline["latency_p99"],
            })

        if current["error_rate"] > self.alert_thresholds["error_rate"]:
            alerts.append({
                "type": "elevated_errors",
                "severity": "high",
                "current": current["error_rate"],
                "threshold": self.alert_thresholds["error_rate"],
            })

        return {
            "healthy": len(alerts) == 0,
            "alerts": alerts,
            "metrics": current,
        }

    def trigger_rollback_if_needed(self, health: dict):
        """Automatic rollback on severe issues"""
        high_severity_alerts = [a for a in health["alerts"] if a["severity"] == "high"]

        if len(high_severity_alerts) >= 2:
            self.registry.rollback(f"Auto-rollback: {high_severity_alerts}")
            self.notify("Auto-rollback triggered", high_severity_alerts)
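
A sketch of a scheduled health check (cron, Airflow, whatever runs your jobs). The baseline numbers are made-up placeholders; in practice they come from your pre-launch evaluation:

monitor = ProductionMonitor(
    model_id=model_id,
    baseline_metrics={
        "quality_score": 0.90,  # placeholder launch baselines
        "latency_p99": 850,     # milliseconds
    },
)

health = monitor.check_health()
if not health["healthy"]:
    monitor.trigger_rollback_if_needed(health)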

The Full Picture

def production_checklist():
    return {
        "before_training": [
            "[ ] Data versioned and validated",
            "[ ] Baseline model performance documented",
            "[ ] Evaluation suite prepared",
            "[ ] Training infrastructure tested",
        ],
        "during_training": [
            "[ ] Checkpoints saved regularly",
            "[ ] Validation metrics tracked",
            "[ ] Training logs preserved",
        ],
        "before_deployment": [
            "[ ] Best checkpoint selected",
            "[ ] Full eval suite passed",
            "[ ] Safety tests passed",
            "[ ] Regression tests passed",
            "[ ] Model registered with version",
        ],
        "deployment": [
            "[ ] Canary deployment (1% traffic)",
            "[ ] Monitoring in place",
            "[ ] Rollback procedure tested",
            "[ ] Gradual traffic increase",
        ],
        "post_deployment": [
            "[ ] Continuous monitoring active",
            "[ ] Alert thresholds configured",
            "[ ] Retraining triggers defined",
            "[ ] Documentation updated",
        ],
    }

Fine-tuning in production is an ongoing process, not a one-time event. The training is the beginning, not the end. Build the infrastructure for evaluation, versioning, deployment, and monitoring before you need it.