Graceful Degradation: Dropping Log Events to Keep Training Alive
link: graceful-degradation
I've been working on ML training infrastructure lately, and one problem kept surfacing: logging was becoming a bottleneck. Training loops would pause for network calls to MLflow, disk writes would block gradient updates, and the whole system felt sluggish. So I built an async batch logger to decouple logging from the critical path.
The Core Problem
In ML training, you're logging constantly: metrics every step, parameters at startup, artifacts periodically. Each log call hits a backend (MLflow, filesystem, remote API), and these operations have unpredictable latency. A few milliseconds here and there add up when you're doing thousands of training steps.
The obvious solution is background processing, but Python's threading story is... complicated. The GIL limits true parallelism, but for I/O-bound work like logging, threads still provide concurrency benefits. The real challenge is getting the synchronization right without introducing subtle bugs.
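To make the problem concrete, here's a minimal sketch of the synchronous pattern I was trying to get away from; train_step and log_metric_to_backend are hypothetical stand-ins for the real training and backend calls:

import random
import time

def log_metric_to_backend(key: str, value: float, step: int) -> None:
    # Hypothetical backend call: network or disk I/O with unpredictable latency.
    time.sleep(random.uniform(0.001, 0.050))

def train_step(step: int) -> float:
    return 1.0 / (step + 1)  # placeholder for the real forward/backward pass

for step in range(1000):
    loss = train_step(step)
    # Synchronous logging: every step pays the full backend latency.
    log_metric_to_backend("loss", loss, step)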
How I Approached It
The AsyncBatchExperimentLogger implements a producer-consumer pattern with a background worker thread:
sequenceDiagram
    participant Main as Main Thread
    participant AsyncBatchExperimentLogger
    participant Queue as Event Queue
    participant Worker as Worker Thread
    participant Backend as Backend Function
    Main->>AsyncBatchExperimentLogger: Create logger
    AsyncBatchExperimentLogger->>Worker: Start background thread
    loop Main Thread Operation
        Main->>Queue: log(event) [with lock]
        Main-->>Main: Continue immediately (non-blocking)
    end
    loop Worker Thread (every flush_interval_s)
        Worker->>Queue: _collect_batch() [with lock]
        Worker->>Backend: _backend_fn(batch)
        Backend-->>Worker: Return Ok/Err
        Worker->>Worker: Update metrics
        Worker->>Worker: Wait until next interval
    end
    Main->>AsyncBatchExperimentLogger: flush_and_stop()
    AsyncBatchExperimentLogger->>Worker: Set stop_event
    AsyncBatchExperimentLogger->>Queue: Flush remaining events
    Worker-->>AsyncBatchExperimentLogger: Join thread
    AsyncBatchExperimentLogger-->>Main: Return statistics
The main thread dumps events into a queue and returns immediately. A daemon thread wakes up periodically, batches events, and ships them to the backend. Simple, but getting the details right took some thinking.
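In practice the call site looks something like the sketch below; the stub backend and the loop body are placeholders, and the classes used here are the ones defined in the full listing at the end of the post:

def print_backend(events):
    # Stub backend: a real one would write to MLflow, disk, or an HTTP API.
    print(f"flushing {len(events)} events")
    return LogSuccess()

logger = AsyncBatchExperimentLogger(
    logging_callback=print_backend,
    batch_size=100,
    flush_interval_s=3,
)

for step in range(10_000):
    loss = 1.0 / (step + 1)  # placeholder loss
    # Returns immediately; the worker thread handles the actual I/O.
    logger.log(MetricEvent(key="loss", value=loss, step=step, prefix="train"))

stats = logger.flush_and_stop()
print(stats)  # e.g. dropped_events, successful_batches, failed_batches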
Thread Safety: The Lock
I use a single threading.Lock() to protect the shared queue:
self._queue_lock = threading.Lock()
Both the main thread (adding events) and worker thread (collecting batches) need to access the queue, so every access goes through the lock:
with self._queue_lock:
    if len(self._queue) >= self._max_queue_size:
        self._queue.popleft()
        self._dropped_events += 1
    self._queue.append(event)
I considered lock-free data structures, but Python's collections.deque is already implemented in C and quite fast. The lock contention is minimal since the critical sections are small and the worker thread only acquires the lock during batch collection.
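For comparison, the standard library's bounded queue.Queue gives you thread safety for free, but it can't atomically drop the oldest entry when full. A sketch of what that alternative would look like, assuming you're willing to drop the newest event instead:

import queue

q = queue.Queue(maxsize=10_000)

def log_with_stdlib_queue(event) -> bool:
    try:
        q.put_nowait(event)  # never blocks the training loop
        return True
    except queue.Full:
        return False         # drops the *newest* event, not the oldest

Dropping the newest events biases the log toward stale data, which is why the deque-plus-lock version drops from the left instead.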
Graceful Degradation
One design decision I'm happy with: when the queue fills up, I drop the oldest events rather than blocking or crashing:
if len(self._queue) >= self._max_queue_size:
    self._queue.popleft()  # Drop oldest
    self._dropped_events += 1
This implements graceful degradation. If the backend is slow or failing, the main thread never blocks. You lose some log data, but the training continues. The alternative, blocking the main thread, would be much worse for an ML training loop.
I track dropped events as a metric, so you can monitor whether your logging system is keeping up.
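Checking that metric is a one-liner; this sketch assumes the logger_metrics property defined in the full implementation at the end of the post, and a logger instance created as shown earlier:

import sys

stats = logger.logger_metrics  # snapshot of the logger's counters
if stats["dropped_events"] > 0:
    # The backend isn't keeping up: consider a larger queue, a bigger
    # batch_size, or a shorter flush interval.
    print(f"warning: dropped {stats['dropped_events']} log events", file=sys.stderr)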
Backend Interface Design
The backend function interface is deliberately simple:
class ExperimentBatchLogFn(Protocol):
    def __call__(self, events: List[BaseExperimentLogEvent]) -> BaseLogResult: ...
It takes a batch of events and returns a result type (not exceptions). This makes error handling explicit and prevents exceptions from killing the worker thread. The backend doesn't need to know about threading; it just processes batches.
I use a Protocol rather than an abstract base class because it's more flexible. Any callable with the right signature works, whether it's a function, method, or callable object.
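As an example of that flexibility, here's a rough sketch of an MLflow-backed batch function. It uses MLflow's fluent logging API, assumes an active run, and dispatches on the event types defined in the full listing below, so treat it as illustrative rather than production-ready:

from typing import List

import mlflow

def mlflow_batch_log_fn(events: List[BaseExperimentLogEvent]) -> BaseLogResult:
    """Satisfies ExperimentBatchLogFn: takes a batch, returns a result type."""
    try:
        for event in events:
            if isinstance(event, MetricEvent):
                mlflow.log_metric(event.full_key, event.value, step=event.step)
            elif isinstance(event, ParamEvent):
                mlflow.log_param(event.full_key, event.value)
            elif isinstance(event, ArtifactEvent):
                mlflow.log_artifact(event.local_path, artifact_path=event.artifact_path)
        return LogSuccess()
    except Exception as exc:  # never let backend failures escape to the worker thread
        return LogError(error=str(exc))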
Error Handling Philosophy
I return result types instead of raising exceptions:
@dataclass(frozen=True)
class LogSuccess(BaseLogResult):
    value: Any = None

@dataclass(frozen=True)
class LogError(BaseLogResult):
    error: str
This approach makes error handling explicit and prevents exceptions from propagating across thread boundaries unexpectedly. The worker thread can handle failures gracefully without terminating:
result = self._backend_fn(batch)
if isinstance(result, LogSuccess):
    self._successful_batches += 1
elif isinstance(result, LogError):
    self._failed_batches += 1
    print(f"Logging batch failed: {result.error}", file=sys.stderr)
Failed batches get logged to stderr, but they don't crash the system or affect the main thread.
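If a backend you want to reuse raises exceptions instead of returning results, a thin wrapper (my own sketch, not part of the logger) keeps the contract intact:

def as_result_type(raising_backend):
    """Adapt a backend that raises into one that returns result types."""
    def wrapped(events: List[BaseExperimentLogEvent]) -> BaseLogResult:
        try:
            raising_backend(events)
            return LogSuccess()
        except Exception as exc:
            return LogError(error=f"{type(exc).__name__}: {exc}")
    return wrapped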
Shutdown Semantics
Clean shutdown is trickier than it looks. When the main process exits, you want to flush any remaining events, but you can't wait forever:
def flush_and_stop(self, flush_timeout_s: int = 10) -> Dict[str, int]:
    # Flush remaining events with timeout
    start_time = time.monotonic()
    remaining = self._collect_batch()
    while remaining and time.monotonic() - start_time < flush_timeout_s:
        self._backend_fn(remaining)
        remaining = self._collect_batch()
    # Stop worker thread
    self._stop_event.set()
    self._worker_thread.join(timeout=5.0)
    return self.logger_metrics
I use atexit.register() to automatically call this during normal process termination. The timeout ensures we don't hang indefinitely if the backend is unresponsive.
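The registration itself is a one-liner; this sketch assumes logger is the instance created at experiment setup time:

import atexit

# Flush queued events and stop the worker when the interpreter exits normally.
atexit.register(logger.flush_and_stop)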
Performance Characteristics
In practice, this approach works well for ML workloads:
- Log calls complete in microseconds (just queue operations)
- Batching reduces backend overhead significantly
- Memory usage is bounded by the queue size
- The system degrades gracefully under load
The worker thread spends most of its time sleeping, so CPU overhead is minimal. Network I/O happens asynchronously from the main thread, so training isn't affected by backend latency.
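To sanity-check the microseconds claim on your own hardware, here is a crude timing sketch using the classes from the full implementation below; the absolute numbers will vary with your Python build and machine:

import time

logger = AsyncBatchExperimentLogger(logging_callback=lambda events: LogSuccess())

n = 100_000
start = time.perf_counter()
for i in range(n):
    logger.log(MetricEvent(key="loss", value=0.0, step=i))
elapsed = time.perf_counter() - start

# Per-call latency of the non-blocking log() path, in microseconds.
print(f"{elapsed / n * 1e6:.2f} µs per log() call")
logger.flush_and_stop()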
Why This Pattern Works
This isn't groundbreaking: it's a standard producer-consumer pattern with a background worker. But it addresses several Python-specific constraints:
- GIL limitations: I/O-bound operations benefit from threading even with the GIL
- Exception handling: Result types prevent exceptions from crossing thread boundaries
- Resource management: Daemon threads and explicit shutdown prevent resource leaks
- Memory bounds: Fixed queue size prevents unbounded memory growth
The pattern is particularly suited to ML training where:
- The main thread's performance is critical
- Occasional data loss is acceptable
- Backend operations can be batched efficiently
- Clean shutdown is important for long-running jobs
The code is straightforward enough that new team members can understand and modify it without deep threading expertise.
# The complete implementation with the event types and result types it depends on
import sys
import threading
import time
import typing as T
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, Dict, List, Optional, Protocol


class BaseLogResult:
    """Marker base type for log results."""


@dataclass(frozen=True)
class LogSuccess(BaseLogResult):
    value: Any = None


@dataclass(frozen=True)
class LogError(BaseLogResult):
    error: str


@dataclass(frozen=True)
class BaseExperimentLogEvent:
    pass


class ExperimentBatchLogFn(Protocol):
    """Protocol for batch logging functions."""

    def __call__(self, events: List[BaseExperimentLogEvent]) -> BaseLogResult: ...


@dataclass(frozen=True)
class MetricEvent(BaseExperimentLogEvent):
    """Immutable metric event with flexible prefixing."""

    key: str
    value: float
    step: Optional[int] = None
    prefix: str = ""
    timestamp_ns: int = field(default_factory=lambda: time.time_ns())

    @property
    def full_key(self) -> str:
        if self.prefix:
            return f"{self.prefix}/{self.key}"
        return self.key


@dataclass(frozen=True)
class ParamEvent(BaseExperimentLogEvent):
    """Immutable parameter event with flexible prefixing."""

    key: str
    value: str
    prefix: str = ""
    timestamp_ns: int = field(default_factory=lambda: time.time_ns())

    @property
    def full_key(self) -> str:
        if self.prefix:
            return f"{self.prefix}/{self.key}"
        return self.key


@dataclass(frozen=True)
class ArtifactEvent(BaseExperimentLogEvent):
    """Immutable artifact event."""

    local_path: str
    artifact_path: Optional[str] = None
    timestamp_ns: int = field(default_factory=lambda: time.time_ns())


TLoggerFn = T.TypeVar("TLoggerFn", bound=ExperimentBatchLogFn)
TEvent = T.TypeVar("TEvent", bound=BaseExperimentLogEvent)


class AsyncBatchExperimentLogger(T.Generic[TLoggerFn, TEvent]):
    """
    Simple async batch logger.

    Design principles:
    - Never block the training loop
    - Batch operations for efficiency
    - Graceful degradation on failures
    - Bounded memory usage
    """

    def __init__(
        self,
        logging_callback: TLoggerFn,
        batch_size: int = 100,
        flush_interval_s: int = 3,
        max_queue_size: int = 10000,
    ):
        self._backend_fn = logging_callback
        self._batch_size = batch_size
        self._flush_interval_s = flush_interval_s
        self._max_queue_size = max_queue_size

        # Thread-safe queue
        self._queue: Deque[TEvent] = deque()
        self._queue_lock = threading.Lock()

        # Performance metrics (initialized before the worker starts)
        self._dropped_events = 0
        self._successful_batches = 0
        self._failed_batches = 0

        # Background worker
        self._stop_event = threading.Event()
        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

    @property
    def logger_metrics(self) -> Dict[str, int]:
        """Snapshot of the logger's counters."""
        return {
            "dropped_events": self._dropped_events,
            "successful_batches": self._successful_batches,
            "failed_batches": self._failed_batches,
        }

    def log(self, event: TEvent) -> BaseLogResult:
        """Non-blocking log - returns immediately."""
        with self._queue_lock:
            if len(self._queue) >= self._max_queue_size:
                self._queue.popleft()  # Drop oldest to stay within the bound
                self._dropped_events += 1
            self._queue.append(event)
        return LogSuccess()

    def _worker_loop(self):
        """Background worker that batches and flushes events."""
        while not self._stop_event.is_set():
            batch = self._collect_batch()
            if batch:
                result = self._backend_fn(batch)
                if isinstance(result, LogSuccess):
                    self._successful_batches += 1
                elif isinstance(result, LogError):
                    self._failed_batches += 1
                    print(f"Logging batch failed: {result.error}", file=sys.stderr)
            self._stop_event.wait(self._flush_interval_s)

    def _collect_batch(self) -> List[TEvent]:
        """Collect up to batch_size events from the queue."""
        batch = []
        with self._queue_lock:
            for _ in range(min(self._batch_size, len(self._queue))):
                if self._queue:
                    batch.append(self._queue.popleft())
        return batch

    def flush_and_stop(self, flush_timeout_s: int = 10) -> Dict[str, int]:
        """Graceful shutdown - flush remaining events and return stats."""
        start_time = time.monotonic()
        remaining = self._collect_batch()
        while remaining and time.monotonic() - start_time < flush_timeout_s:
            self._backend_fn(remaining)
            remaining = self._collect_batch()
        self._stop_event.set()
        self._worker_thread.join(timeout=5.0)
        return self.logger_metrics
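To tie the pieces together, here's a hedged sketch of a small factory that wires the MLflow backend sketched earlier into the logger and registers the shutdown hook; make_mlflow_logger is my name for it, not part of the listing above:

import atexit

def make_mlflow_logger(
    batch_size: int = 100,
    flush_interval_s: int = 3,
    max_queue_size: int = 10_000,
) -> AsyncBatchExperimentLogger:
    """Hypothetical convenience factory: MLflow backend + async logger + exit hook."""
    logger = AsyncBatchExperimentLogger(
        logging_callback=mlflow_batch_log_fn,  # the backend sketch from earlier
        batch_size=batch_size,
        flush_interval_s=flush_interval_s,
        max_queue_size=max_queue_size,
    )
    atexit.register(logger.flush_and_stop)
    return logger

logger = make_mlflow_logger()
logger.log(ParamEvent(key="lr", value="3e-4", prefix="optim"))
logger.log(MetricEvent(key="loss", value=0.42, step=0, prefix="train"))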