Running Fine-tuned Models in Production
Baking a cake is easy. Running a bakery is hard. You need consistent recipes, quality control, inventory management, and the ability to handle a bad batch without ruining the whole day. The baking skill is necessary but not sufficient.
Fine-tuning follows the same pattern. Training a model on your data is the tutorial step. Running it in production requires checkpointing, evaluation pipelines, rollback procedures, A/B testing, and versioning. The training is necessary but not sufficient.
The Production Fine-tuning Pipeline
def production_pipeline():
    return {
        "stage_1_data": {
            "steps": [
                "Data collection and validation",
                "Quality filtering",
                "Train/val/test split",
                "Versioning (DVC, Git LFS)",
            ],
            "output": "Versioned, validated dataset",
        },
        "stage_2_training": {
            "steps": [
                "Hyperparameter selection",
                "Training with checkpointing",
                "Validation metrics at each checkpoint",
                "Early stopping if needed",
            ],
            "output": "Multiple checkpoint candidates",
        },
        "stage_3_evaluation": {
            "steps": [
                "Evaluate all checkpoints on test set",
                "Task-specific eval suite",
                "Safety and regression testing",
                "Compare to baseline",
            ],
            "output": "Best checkpoint selected",
        },
        "stage_4_deployment": {
            "steps": [
                "Canary deployment (1% traffic)",
                "A/B testing against baseline",
                "Gradual rollout",
                "Production monitoring",
            ],
            "output": "Model serving production traffic",
        },
        "stage_5_maintenance": {
            "steps": [
                "Continuous quality monitoring",
                "Drift detection",
                "Retraining triggers",
                "Rollback if needed",
            ],
            "output": "Sustained quality over time",
        },
    }
Checkpointing Strategy
import json
import shutil
from datetime import datetime


class CheckpointManager:
    """
    Save and manage training checkpoints
    """
    def __init__(self, output_dir: str, keep_last_n: int = 5):
        self.output_dir = output_dir
        self.keep_last_n = keep_last_n
        self.checkpoints = []

    def save_checkpoint(self, model, step: int, metrics: dict):
        """Save checkpoint with metadata"""
        checkpoint_dir = f"{self.output_dir}/checkpoint-{step}"

        # Save model weights
        model.save_pretrained(checkpoint_dir)

        # Save metadata alongside the weights
        metadata = {
            "step": step,
            "timestamp": datetime.now().isoformat(),
            "metrics": metrics,
        }
        with open(f"{checkpoint_dir}/metadata.json", "w") as f:
            json.dump(metadata, f)

        self.checkpoints.append(checkpoint_dir)

        # Clean up old checkpoints, keeping only the most recent N
        while len(self.checkpoints) > self.keep_last_n:
            old_checkpoint = self.checkpoints.pop(0)
            shutil.rmtree(old_checkpoint)

    def get_best_checkpoint(self, metric: str = "eval_loss") -> str:
        """Return the checkpoint with the best (lowest) metric"""
        best_checkpoint = None
        best_value = float("inf")
        for checkpoint_dir in self.checkpoints:
            with open(f"{checkpoint_dir}/metadata.json") as f:
                metadata = json.load(f)
            value = metadata["metrics"].get(metric, float("inf"))
            if value < best_value:
                best_value = value
                best_checkpoint = checkpoint_dir
        return best_checkpoint
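As a rough sketch of how this fits into a training loop: save a checkpoint at every evaluation interval, then ask the manager for the best one afterward. The trainer pieces here (train_one_step, evaluate, total_steps, eval_every) are hypothetical placeholders, not part of any specific framework.

# Hypothetical training-loop wiring; train_one_step() and evaluate() are placeholders
manager = CheckpointManager(output_dir="runs/support-classifier", keep_last_n=5)

for step in range(1, total_steps + 1):
    train_one_step(model, train_loader)           # assumed helper
    if step % eval_every == 0:
        metrics = {"eval_loss": evaluate(model)}  # assumed helper returning a float
        manager.save_checkpoint(model, step=step, metrics=metrics)

best = manager.get_best_checkpoint(metric="eval_loss")
print(f"Best checkpoint: {best}")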
Evaluation Pipeline
class EvaluationPipeline:
    """
    Comprehensive evaluation before deployment
    """
    def __init__(self, baseline_model, test_datasets: dict):
        self.baseline = baseline_model
        self.test_datasets = test_datasets
        self.required_thresholds = {
            "task_accuracy": 0.85,
            "safety_score": 0.99,
            "regression_delta": -0.05,  # No more than 5% regression
        }

    def evaluate_checkpoint(self, checkpoint_path: str) -> dict:
        """Full evaluation of a checkpoint"""
        model = self.load_checkpoint(checkpoint_path)
        results = {
            "task_metrics": self.evaluate_task(model),
            "safety_metrics": self.evaluate_safety(model),
            "regression_tests": self.run_regression_suite(model),
            "baseline_comparison": self.compare_to_baseline(model),
        }
        results["deployment_ready"] = self.check_thresholds(results)
        return results

    def check_thresholds(self, results: dict) -> bool:
        """Check if results meet deployment thresholds"""
        checks = {
            "task": results["task_metrics"]["accuracy"] >= self.required_thresholds["task_accuracy"],
            "safety": results["safety_metrics"]["score"] >= self.required_thresholds["safety_score"],
            "regression": results["baseline_comparison"]["delta"] >= self.required_thresholds["regression_delta"],
        }
        return all(checks.values())

    def compare_to_baseline(self, model) -> dict:
        """Compare fine-tuned model to baseline"""
        baseline_score = self.evaluate_task(self.baseline)["accuracy"]
        model_score = self.evaluate_task(model)["accuracy"]
        return {
            "baseline_score": baseline_score,
            "model_score": model_score,
            "delta": model_score - baseline_score,
            "improved": model_score > baseline_score,
        }
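The two pieces compose naturally: run the eval pipeline over every surviving checkpoint and keep the deployment-ready one with the best task accuracy. A minimal sketch, assuming the baseline_model and test_sets objects exist and manager is the CheckpointManager from earlier:

# Assumes `manager` is the CheckpointManager above; baseline_model / test_sets are placeholders
pipeline = EvaluationPipeline(baseline_model, test_datasets=test_sets)

candidates = []
for checkpoint_dir in manager.checkpoints:
    results = pipeline.evaluate_checkpoint(checkpoint_dir)
    if results["deployment_ready"]:
        candidates.append((results["task_metrics"]["accuracy"], checkpoint_dir))

if not candidates:
    raise RuntimeError("No checkpoint cleared the deployment thresholds")
best_accuracy, best_checkpoint = max(candidates)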
Model Versioning
from datetime import datetime


class ModelRegistry:
    """
    Version and track fine-tuned models.

    Assumes a backing store (self.db) plus helpers such as
    get_production_model(), get_previous_production(), update_status(),
    and log_event() are wired up elsewhere.
    """
    def register_model(
        self,
        model_path: str,
        version: str,
        metadata: dict,
    ) -> str:
        """Register a new model version"""
        model_id = f"{metadata['name']}-{version}"
        entry = {
            "model_id": model_id,
            "version": version,
            "path": model_path,
            "created_at": datetime.now().isoformat(),
            "training_data": metadata["training_data_version"],
            "base_model": metadata["base_model"],
            "eval_results": metadata["eval_results"],
            "status": "staging",  # staging, production, deprecated
        }
        self.db.insert(entry)
        return model_id

    def promote_to_production(self, model_id: str):
        """Promote a model to production"""
        # Demote the current production model, if any
        current_prod = self.get_production_model()
        if current_prod:
            self.update_status(current_prod["model_id"], "deprecated")

        # Promote the new model
        self.update_status(model_id, "production")

        # Log the promotion
        self.log_event("promotion", {
            "model_id": model_id,
            "previous": current_prod["model_id"] if current_prod else None,
        })

    def rollback(self, reason: str):
        """Roll back to the previous production model"""
        current = self.get_production_model()
        previous = self.get_previous_production()
        if not previous:
            raise ValueError("No previous model to roll back to")

        self.update_status(current["model_id"], "failed")
        self.update_status(previous["model_id"], "production")
        self.log_event("rollback", {
            "from": current["model_id"],
            "to": previous["model_id"],
            "reason": reason,
        })
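A small sketch of the registry in use: register the selected checkpoint under a version, then promote it once the rollout completes. All of the names and versions here (support-classifier, dataset-v7, and so on) are illustrative, and `results` is the eval output from the previous stage.

# Hypothetical wiring; `registry` is a ModelRegistry backed by your own store
model_id = registry.register_model(
    model_path=best_checkpoint,
    version="1.3.0",
    metadata={
        "name": "support-classifier",          # illustrative model name
        "training_data_version": "dataset-v7",  # illustrative dataset version
        "base_model": "base-model-8b",           # illustrative base model
        "eval_results": results,
    },
)

# Later, once the canary has reached 100% traffic:
registry.promote_to_production(model_id)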
Canary Deployment
import random


class CanaryDeployment:
    """
    Gradual rollout of fine-tuned models
    """
    def __init__(self, baseline_model, new_model):
        self.baseline = baseline_model
        self.new_model = new_model
        self.traffic_percentage = 0.01  # Start at 1%

    def route_request(self, request):
        """Route a request to the appropriate model"""
        if random.random() < self.traffic_percentage:
            return self.new_model
        return self.baseline

    def update_traffic(self, new_percentage: float):
        """Update canary traffic percentage"""
        if new_percentage > self.traffic_percentage * 2:
            raise ValueError("Increase traffic gradually (max 2x per step)")
        self.traffic_percentage = new_percentage

    def monitor_and_decide(self, metrics: dict) -> str:
        """Decide whether to continue the rollout"""
        baseline_metrics = metrics["baseline"]
        canary_metrics = metrics["canary"]

        # Check for latency or error-rate degradation
        for metric in ["latency_p99", "error_rate"]:
            if canary_metrics[metric] > baseline_metrics[metric] * 1.1:
                return "rollback"  # >10% degradation

        # Check quality
        if canary_metrics["quality"] < baseline_metrics["quality"] - 0.02:
            return "rollback"  # >2% quality drop

        # Safe to continue
        if self.traffic_percentage < 1.0:
            return "increase_traffic"
        return "complete_rollout"
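Around this class sits a rollout loop: collect side-by-side metrics on some cadence, then act on the decision. A rough sketch, where collect_comparison_metrics() is a hypothetical helper returning per-model latency, error rate, and quality, and the doubling schedule is just one reasonable policy; registry and model_id reuse the earlier sketches.

# Hypothetical rollout loop; collect_comparison_metrics() is a placeholder
canary = CanaryDeployment(baseline_model, new_model)

while True:
    metrics = collect_comparison_metrics()  # {"baseline": {...}, "canary": {...}}
    decision = canary.monitor_and_decide(metrics)
    if decision == "rollback":
        registry.rollback("Canary degraded against baseline")
        break
    if decision == "complete_rollout":
        registry.promote_to_production(model_id)
        break
    # Otherwise double the canary traffic, capped at 100%
    canary.update_traffic(min(canary.traffic_percentage * 2, 1.0))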
Production Monitoring
class ProductionMonitor:
    """
    Monitor a fine-tuned model in production.

    Assumes metric helpers (get_quality_score, get_latency_percentile, ...)
    plus self.registry and self.notify are wired up elsewhere.
    """
    def __init__(self, model_id: str, baseline_metrics: dict):
        self.model_id = model_id
        self.baseline = baseline_metrics
        self.alert_thresholds = {
            "quality_drop": 0.05,
            "latency_increase": 1.5,  # 50% increase
            "error_rate": 0.02,
        }

    def collect_metrics(self, window_hours: int = 1) -> dict:
        """Collect recent metrics"""
        return {
            "quality_score": self.get_quality_score(window_hours),
            "latency_p99": self.get_latency_percentile(99, window_hours),
            "error_rate": self.get_error_rate(window_hours),
            "traffic_volume": self.get_request_count(window_hours),
        }

    def check_health(self) -> dict:
        """Check model health against baseline"""
        current = self.collect_metrics()
        alerts = []

        if current["quality_score"] < self.baseline["quality_score"] - self.alert_thresholds["quality_drop"]:
            alerts.append({
                "type": "quality_regression",
                "severity": "high",
                "current": current["quality_score"],
                "baseline": self.baseline["quality_score"],
            })

        if current["latency_p99"] > self.baseline["latency_p99"] * self.alert_thresholds["latency_increase"]:
            alerts.append({
                "type": "latency_regression",
                "severity": "medium",
                "current": current["latency_p99"],
                "baseline": self.baseline["latency_p99"],
            })

        return {
            "healthy": len(alerts) == 0,
            "alerts": alerts,
            "metrics": current,
        }

    def trigger_rollback_if_needed(self, health: dict):
        """Automatic rollback on severe issues"""
        high_severity_alerts = [a for a in health["alerts"] if a["severity"] == "high"]
        # Any high-severity alert triggers an automatic rollback
        if high_severity_alerts:
            self.registry.rollback(f"Auto-rollback: {high_severity_alerts}")
            self.notify("Auto-rollback triggered", high_severity_alerts)
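In practice this tends to run as a scheduled job. A minimal sketch, assuming the monitor's metric helpers are implemented and that the baseline numbers (all illustrative here) were captured when the model went live:

import time

# Hypothetical scheduled loop; baseline numbers are illustrative, captured at rollout time
monitor = ProductionMonitor(
    model_id=model_id,
    baseline_metrics={"quality_score": 0.91, "latency_p99": 850, "error_rate": 0.004},
)

while True:
    health = monitor.check_health()
    if not health["healthy"]:
        monitor.trigger_rollback_if_needed(health)
    time.sleep(3600)  # re-check hourly, matching the 1-hour metrics window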
The Full Picture
def production_checklist():
    return {
        "before_training": [
            "[ ] Data versioned and validated",
            "[ ] Baseline model performance documented",
            "[ ] Evaluation suite prepared",
            "[ ] Training infrastructure tested",
        ],
        "during_training": [
            "[ ] Checkpoints saved regularly",
            "[ ] Validation metrics tracked",
            "[ ] Training logs preserved",
        ],
        "before_deployment": [
            "[ ] Best checkpoint selected",
            "[ ] Full eval suite passed",
            "[ ] Safety tests passed",
            "[ ] Regression tests passed",
            "[ ] Model registered with version",
        ],
        "deployment": [
            "[ ] Canary deployment (1% traffic)",
            "[ ] Monitoring in place",
            "[ ] Rollback procedure tested",
            "[ ] Gradual traffic increase",
        ],
        "post_deployment": [
            "[ ] Continuous monitoring active",
            "[ ] Alert thresholds configured",
            "[ ] Retraining triggers defined",
            "[ ] Documentation updated",
        ],
    }
Fine-tuning in production is an ongoing process, not a one-time event. The training is the beginning, not the end. Build the infrastructure for evaluation, versioning, deployment, and monitoring before you need it.