Engineering in the Wild

Graceful Degradation: Dropping Log Events to Keep Training Alive

link: graceful-degradation


I've been working on ML training infrastructure lately, and one problem kept surfacing: logging was becoming a bottleneck. Training loops would pause for network calls to MLflow, disk writes would block gradient updates, and the whole system felt sluggish. So I built an async batch logger to decouple logging from the critical path.

The Core Problem

In ML training, you're logging constantly—metrics every step, parameters at startup, artifacts periodically. Each log call hits a backend (MLflow, filesystem, remote API), and these operations have unpredictable latency. A few milliseconds here and there adds up when you're doing thousands of training steps.

The obvious solution is background processing, but Python's threading story is... complicated. The GIL limits true parallelism, but for I/O-bound work like logging, threads still provide concurrency benefits. The real challenge is getting the synchronization right without introducing subtle bugs.

How I Approached It

The AsyncBatchExperimentLogger implements a producer-consumer pattern with a background worker thread:

sequenceDiagram
    participant Main as Main Thread
    participant Queue as Event Queue
    participant Worker as Worker Thread
    participant Backend as Backend Function
    
    Main->>AsyncBatchExperimentLogger: Create logger
    AsyncBatchExperimentLogger->>Worker: Start background thread
    
    loop Main Thread Operation
        Main->>Queue: log(event) [with lock]
        Main-->>Main: Continue immediately (non-blocking)
    end
    
    loop Worker Thread (every flush_interval_s)
        Worker->>Queue: _collect_batch() [with lock]
        Worker->>Backend: _backend_fn(batch)
        Backend-->>Worker: Return Ok/Err
        Worker->>Worker: Update metrics
        Worker->>Worker: Wait until next interval
    end
    
    Main->>AsyncBatchExperimentLogger: flush_and_stop()
    AsyncBatchExperimentLogger->>Worker: Set stop_event
    AsyncBatchExperimentLogger->>Queue: Flush remaining events
    Worker-->>AsyncBatchExperimentLogger: Join thread
    AsyncBatchExperimentLogger-->>Main: Return statistics

The main thread dumps events into a queue and returns immediately. A daemon thread wakes up periodically, batches events, and ships them to the backend. Simple, but getting the details right took some thinking.

Thread Safety: The Lock

I use a single threading.Lock() to protect the shared queue:

self._queue_lock = threading.Lock()

Both the main thread (adding events) and worker thread (collecting batches) need to access the queue, so every access goes through the lock:

with self._queue_lock:
    if len(self._queue) >= self._max_queue_size:
        self._queue.popleft()
        self._dropped_events += 1
    self._queue.append(event)

I considered lock-free data structures, but Python's collections.deque is already implemented in C and quite fast. The lock contention is minimal since the critical sections are small and the worker thread only acquires the lock during batch collection.

Graceful Degradation

One design decision I'm happy with: when the queue fills up, I drop the oldest events rather than blocking or crashing:

if len(self._queue) >= self._max_queue_size:
    self._queue.popleft()  # Drop oldest
    self._dropped_events += 1

This implements graceful degradation. If the backend is slow or failing, the main thread never blocks. You lose some log data, but the training continues. The alternative—blocking the main thread—would be much worse for an ML training loop.

I track dropped events as a metric, so you can monitor whether your logging system is keeping up.

Backend Interface Design

The backend function interface is deliberately simple:

class ExperimentBatchLogFn(Protocol):
    def __call__(self, events: List[BaseExperimentLogEvent]) -> BaseLogResult: ...

It takes a batch of events and returns a result type (not exceptions). This makes error handling explicit and prevents exceptions from killing the worker thread. The backend doesn't need to know about threading—it just processes batches.

I use a Protocol rather than an abstract base class because it's more flexible. Any callable with the right signature works, whether it's a function, method, or callable object.
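
To make that concrete, here's a minimal sketch of a backend that is just a plain function writing batches as JSON lines to stderr; it satisfies ExperimentBatchLogFn without subclassing anything. The function name write_events_to_stderr is made up for this example.

import json
import sys
from dataclasses import asdict
from typing import List

def write_events_to_stderr(events: List[BaseExperimentLogEvent]) -> BaseLogResult:
    """A plain function is enough to satisfy the ExperimentBatchLogFn protocol."""
    try:
        for event in events:
            print(json.dumps(asdict(event), default=str), file=sys.stderr)
        return LogSuccess()
    except Exception as exc:
        return LogError(error=str(exc))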

Error Handling Philosophy

I return result types instead of raising exceptions:

@dataclass(frozen=True)
class LogSuccess(BaseLogResult):
    value: Any = None

@dataclass(frozen=True)
class LogError(BaseLogResult):
    error: str

This approach makes error handling explicit and prevents exceptions from propagating across thread boundaries unexpectedly. The worker thread can handle failures gracefully without terminating:

result = self._backend_fn(batch)
if isinstance(result, LogSuccess):
    self._successful_batches += 1
elif isinstance(result, LogError):
    self._failed_batches += 1
    print(f"Logging batch failed: {result.error}", file=sys.stderr)

Failed batches get logged to stderr, but they don't crash the system or affect the main thread.
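
If the underlying backend can raise, one way to keep that guarantee is to wrap it once at the boundary so the worker only ever sees result types. A sketch; the helper name as_result_fn is my own label for the idea, not part of the implementation below.

from typing import Callable, List

def as_result_fn(raw_fn: Callable[[List[BaseExperimentLogEvent]], None]) -> ExperimentBatchLogFn:
    """Convert a backend callable that may raise into one returning result types."""
    def wrapped(events: List[BaseExperimentLogEvent]) -> BaseLogResult:
        try:
            raw_fn(events)
            return LogSuccess()
        except Exception as exc:
            return LogError(error=str(exc))
    return wrapped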

Shutdown Semantics

Clean shutdown is trickier than it looks. When the main process exits, you want to flush any remaining events, but you can't wait forever:

def flush_and_stop(self, flush_timeout_s: int = 10) -> Dict[str, int]:
    # Flush remaining events with timeout
    start_time = time.monotonic()
    remaining = self._collect_batch()
    while remaining and time.monotonic() - start_time < flush_timeout_s:
        self._backend_fn(remaining)
        remaining = self._collect_batch()
    
    # Stop worker thread
    self._stop_event.set()
    self._worker_thread.join(timeout=5.0)
    
    return self.logger_metrics

I use atexit.register() to automatically call this during normal process termination. The timeout ensures we don't hang indefinitely if the backend is unresponsive.
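
The registration itself is a one-liner; a sketch, assuming logger is an already-constructed instance (whether it lives in __init__ or at the call site is a detail I'm not pinning down here):

import atexit

# `logger` is assumed to be an AsyncEventLogger instance created elsewhere.
# flush_and_stop takes only defaulted arguments, so the bound method can be
# registered directly; it runs during normal interpreter shutdown.
atexit.register(logger.flush_and_stop)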

Performance Characteristics

In practice, this approach works well for ML workloads: the worker thread spends most of its time sleeping, so CPU overhead is minimal, and network I/O happens off the main thread, so training throughput isn't affected by backend latency.

Why This Pattern Works

This isn't groundbreaking—it's a standard producer-consumer pattern with a background worker. But it addresses several Python-specific constraints:

  1. GIL limitations: I/O-bound operations benefit from threading even with the GIL
  2. Exception handling: Result types prevent exceptions from crossing thread boundaries
  3. Resource management: Daemon threads and explicit shutdown prevent resource leaks
  4. Memory bounds: Fixed queue size prevents unbounded memory growth

The pattern is particularly suited to ML training, where backend latency is unpredictable and dropping a few log events is far cheaper than stalling the training loop.

The code is straightforward enough that new team members can understand and modify it without deep threading expertise.
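
To give a feel for the API, here is roughly what usage looks like in a training loop, using the event types and AsyncEventLogger class from the full listing below. The backend here is a stand-in lambda and train_step is a placeholder for a real gradient step.

num_steps = 100

def train_step() -> float:
    return 0.123  # placeholder for a real training step

# A backend that accepts every batch; swap in a real one (MLflow, file, HTTP).
logger = AsyncEventLogger(logging_callback=lambda events: LogSuccess())

logger.log(ParamEvent(key="lr", value="3e-4", prefix="train"))
for step in range(num_steps):
    loss = train_step()
    logger.log(MetricEvent(key="loss", value=loss, step=step, prefix="train"))

stats = logger.flush_and_stop()
print(f"dropped {stats['dropped_events']} events")  # watch this to spot back-pressure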


# The complete implementation with event types

from __future__ import annotations

import sys
import threading
import time
import typing as T
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, Dict, List, Optional, Protocol


class BaseLogResult:
    """Marker base class for backend results."""


@dataclass(frozen=True)
class LogSuccess(BaseLogResult):
    value: Any = None


@dataclass(frozen=True)
class LogError(BaseLogResult):
    error: str


class ExperimentBatchLogFn(Protocol):
    """Protocol for batch logging functions."""
    def __call__(self, events: List[BaseExperimentLogEvent]) -> BaseLogResult: ...

@dataclass(frozen=True)
class BaseExperimentLogEvent:
    pass

@dataclass(frozen=True)
class MetricEvent(BaseExperimentLogEvent):
    """Immutable metric event with flexible prefixing."""
    key: str
    value: float
    step: Optional[int] = None
    prefix: str = ""
    timestamp_ns: int = field(default_factory=lambda: time.time_ns())

    @property
    def full_key(self) -> str:
        if self.prefix:
            return f"{self.prefix}/{self.key}"
        return self.key

@dataclass(frozen=True)
class ParamEvent(BaseExperimentLogEvent):
    """Immutable parameter event with flexible prefixing."""
    key: str
    value: str
    prefix: str = ""
    timestamp_ns: int = field(default_factory=lambda: time.time_ns())

    @property
    def full_key(self) -> str:
        if self.prefix:
            return f"{self.prefix}/{self.key}"
        return self.key

@dataclass(frozen=True)
class ArtifactEvent(BaseExperimentLogEvent):
    """Immutable artifact event."""
    local_path: str
    artifact_path: Optional[str] = None
    timestamp_ns: int = field(default_factory=lambda: time.time_ns())

TLoggerFn = T.TypeVar("TLoggerFn", bound=ExperimentBatchLogFn)
TEvent = T.TypeVar("TEvent", bound=BaseExperimentLogEvent)


class AsyncEventLogger(T.Generic[TLoggerFn, TEvent]):
    """
    Simple async batch logger
    
    Design principles:
    - Never block the training loop
    - Batch operations for efficiency
    - Graceful degradation on failures
    - Bounded memory usage
    """
    
    def __init__(
        self,
        logging_callback: TLoggerFn,
        batch_size: int = 100,
        flush_interval_s: int = 3,
        max_queue_size: int = 10000,
    ):
        self._backend_fn = logging_callback
        self._batch_size = batch_size
        self._flush_interval_s = flush_interval_s
        self._max_queue_size = max_queue_size

        # Thread-safe queue
        self._queue: Deque[TEvent] = deque()
        self._queue_lock = threading.Lock()

        # Performance metrics (set before the worker starts, so the worker
        # never touches an attribute that doesn't exist yet)
        self._dropped_events = 0
        self._successful_batches = 0
        self._failed_batches = 0

        # Background worker
        self._stop_event = threading.Event()
        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

    def log(self, event: TEvent) -> BaseLogResult:
        """Non-blocking log - returns immediately."""
        with self._queue_lock:
            if len(self._queue) >= self._max_queue_size:
                self._queue.popleft()
                self._dropped_events += 1
            self._queue.append(event)
        return LogSuccess()
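
    @property
    def logger_metrics(self) -> Dict[str, int]:
        """Snapshot of the health counters, as returned by flush_and_stop.

        (The exact key names here are illustrative.)
        """
        return {
            "dropped_events": self._dropped_events,
            "successful_batches": self._successful_batches,
            "failed_batches": self._failed_batches,
        }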

    def _worker_loop(self):
        """Background worker that batches and flushes events."""
        while not self._stop_event.is_set():
            batch = self._collect_batch()
            
            if batch:
                result = self._backend_fn(batch)
                if isinstance(result, LogSuccess):
                    self._successful_batches += 1
                elif isinstance(result, LogError):
                    self._failed_batches += 1
                    print(f"Logging batch failed: {result.error}", file=sys.stderr)
            
            self._stop_event.wait(self._flush_interval_s)

    def _collect_batch(self) -> List[TEvent]:
        """Collect up to batch_size events from queue."""
        batch = []
        with self._queue_lock:
            for _ in range(min(self._batch_size, len(self._queue))):
                if self._queue:
                    batch.append(self._queue.popleft())
        return batch

    def flush_and_stop(self, flush_timeout_s: int = 10) -> Dict[str, int]:
        """Graceful shutdown - flush remaining events and return stats."""
        start_time = time.monotonic()
        remaining = self._collect_batch()
        while remaining and time.monotonic() - start_time < flush_timeout_s:
            self._backend_fn(remaining)
            remaining = self._collect_batch()

        self._stop_event.set()
        self._worker_thread.join(timeout=5.0)
        return self.logger_metrics
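
Since MLflow is the backend I keep mentioning, here is a minimal sketch of what an MLflow-backed batch function could look like, assuming MLflow's fluent API (mlflow.log_metric, mlflow.log_param, mlflow.log_artifact) and an already-active run; the factory name make_mlflow_batch_log_fn and the dispatch-by-event-type structure are illustrative, not taken from the listing above.

import mlflow

def make_mlflow_batch_log_fn() -> ExperimentBatchLogFn:
    """Build a batch-logging function that forwards events to MLflow."""
    def _log_batch(events: List[BaseExperimentLogEvent]) -> BaseLogResult:
        try:
            for event in events:
                if isinstance(event, MetricEvent):
                    mlflow.log_metric(event.full_key, event.value, step=event.step)
                elif isinstance(event, ParamEvent):
                    mlflow.log_param(event.full_key, event.value)
                elif isinstance(event, ArtifactEvent):
                    mlflow.log_artifact(event.local_path, event.artifact_path)
            return LogSuccess()
        except Exception as exc:  # keep backend failures inside result types
            return LogError(error=str(exc))
    return _log_batch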