"""OOM (Out of Memory) handling and recovery strategies.""" import functools import gc import logging import time from typing import Any, Callable, Optional, ParamSpec, TypeVar import torch from src.core.gpu_manager import GPUMemoryManager logger = logging.getLogger(__name__) P = ParamSpec("P") R = TypeVar("R") class OOMRecoveryError(Exception): """Raised when OOM recovery fails after all strategies exhausted.""" pass class OOMHandler: """Handles CUDA Out of Memory errors with multi-level recovery strategies. Recovery levels: 1. Clear PyTorch CUDA cache 2. Evict unused models from registry 3. Request ComfyUI to yield VRAM """ def __init__( self, gpu_manager: GPUMemoryManager, model_registry: Optional[Any] = None, # Avoid circular import max_retries: int = 3, retry_delay: float = 0.5, ): """Initialize OOM handler. Args: gpu_manager: GPU memory manager instance model_registry: Optional model registry for eviction max_retries: Maximum recovery attempts retry_delay: Delay between retries in seconds """ self.gpu_manager = gpu_manager self.model_registry = model_registry self.max_retries = max_retries self.retry_delay = retry_delay # Track OOM events for monitoring self._oom_count = 0 self._last_oom_time: Optional[float] = None @property def oom_count(self) -> int: """Number of OOM events handled.""" return self._oom_count def set_model_registry(self, registry: Any) -> None: """Set model registry (to avoid circular import at init time).""" self.model_registry = registry def with_oom_recovery(self, func: Callable[P, R]) -> Callable[P, R]: """Decorator that wraps function with OOM recovery logic. Usage: @oom_handler.with_oom_recovery def generate_audio(...): ... Args: func: Function to wrap Returns: Wrapped function with OOM recovery """ @functools.wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: last_exception = None for attempt in range(self.max_retries + 1): try: if attempt > 0: logger.info(f"Retry attempt {attempt}/{self.max_retries}") time.sleep(self.retry_delay) return func(*args, **kwargs) except torch.cuda.OutOfMemoryError as e: last_exception = e self._oom_count += 1 self._last_oom_time = time.time() logger.warning(f"CUDA OOM detected (attempt {attempt + 1}): {e}") if attempt < self.max_retries: self._execute_recovery_strategy(attempt) else: logger.error( f"OOM recovery failed after {self.max_retries} attempts" ) raise OOMRecoveryError( f"OOM recovery failed after {self.max_retries} attempts" ) from last_exception return wrapper def _execute_recovery_strategy(self, level: int) -> None: """Execute recovery strategy based on severity level. Args: level: Recovery level (0-2) """ strategies = [ self._strategy_clear_cache, self._strategy_evict_models, self._strategy_request_comfyui_yield, ] # Execute all strategies up to and including current level for i in range(min(level + 1, len(strategies))): logger.info(f"Executing recovery strategy {i + 1}: {strategies[i].__name__}") strategies[i]() def _strategy_clear_cache(self) -> None: """Level 1: Clear PyTorch CUDA cache. This is the fastest and least disruptive recovery strategy. Clears cached memory that PyTorch holds for future allocations. """ logger.info("Clearing CUDA cache...") gc.collect() torch.cuda.empty_cache() torch.cuda.synchronize() # Reset peak memory stats for monitoring torch.cuda.reset_peak_memory_stats() freed = self.gpu_manager.force_cleanup() logger.info(f"Cache cleared, freed approximately {freed}MB") def _strategy_evict_models(self) -> None: """Level 2: Evict non-essential models from registry. 

    def _strategy_clear_cache(self) -> None:
        """Level 1: Clear PyTorch CUDA cache.

        This is the fastest and least disruptive recovery strategy.
        Clears cached memory that PyTorch holds for future allocations.
        """
        logger.info("Clearing CUDA cache...")

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # Reset peak memory stats for monitoring
        torch.cuda.reset_peak_memory_stats()

        freed = self.gpu_manager.force_cleanup()
        logger.info(f"Cache cleared, freed approximately {freed}MB")

    def _strategy_evict_models(self) -> None:
        """Level 2: Evict non-essential models from the registry.

        Unloads all models that don't have active references, freeing
        their VRAM for the current operation.
        """
        if self.model_registry is None:
            logger.warning("No model registry available for eviction")
            self._strategy_clear_cache()
            return

        logger.info("Evicting unused models...")

        # Get list of loaded models
        loaded = self.model_registry.get_loaded_models()
        evicted = []

        for model_info in loaded:
            # Only evict models with no active references
            if model_info["ref_count"] == 0:
                model_id = model_info["model_id"]
                variant = model_info["variant"]
                logger.info(f"Evicting {model_id}/{variant}")
                self.model_registry.unload_model(model_id, variant)
                evicted.append(f"{model_id}/{variant}")

        # Clear cache after eviction
        self._strategy_clear_cache()
        logger.info(f"Evicted {len(evicted)} model(s): {evicted}")

    def _strategy_request_comfyui_yield(self) -> None:
        """Level 3: Request ComfyUI to yield VRAM.

        Uses the coordination protocol to ask ComfyUI to temporarily
        release GPU memory.
        """
        logger.info("Requesting ComfyUI to yield VRAM...")

        # First, evict our own models
        self._strategy_evict_models()

        # Calculate how much VRAM we need: at least 4GB or 25% of total
        budget = self.gpu_manager.get_available_budget()
        needed = max(4096, budget.total_mb // 4)

        # Request priority from ComfyUI
        success = self.gpu_manager.request_priority(needed, timeout=15.0)

        if success:
            logger.info("ComfyUI yielded VRAM successfully")
        else:
            logger.warning("ComfyUI did not yield VRAM within timeout")

        # Final cache clear
        self._strategy_clear_cache()

    def recover_from_oom(self, level: int = 0) -> bool:
        """Manually trigger OOM recovery.

        Args:
            level: Recovery level to execute (0-2)

        Returns:
            True if recovery was successful (memory was freed)
        """
        before = self.gpu_manager.get_memory_info()
        self._execute_recovery_strategy(level)
        after = self.gpu_manager.get_memory_info()

        freed = before["used"] - after["used"]
        logger.info(f"Manual recovery freed {freed}MB")
        return freed > 0

    def check_memory_for_operation(self, required_mb: int) -> bool:
        """Check whether there is enough memory for an operation.

        If there is not, recovery strategies are attempted.

        Args:
            required_mb: Memory required in megabytes

        Returns:
            True if enough memory is available (possibly after recovery)
        """
        budget = self.gpu_manager.get_available_budget()

        if budget.available_mb >= required_mb:
            return True

        logger.info(
            f"Need {required_mb}MB but only {budget.available_mb}MB available. "
            "Attempting recovery..."
        )

        # Try progressively more aggressive recovery
        for level in range(3):
            self._execute_recovery_strategy(level)
            budget = self.gpu_manager.get_available_budget()
            if budget.available_mb >= required_mb:
                logger.info(f"Recovery successful at level {level + 1}")
                return True

        logger.error(
            f"Could not free enough memory. Need {required_mb}MB, "
            f"have {budget.available_mb}MB"
        )
        return False

    def get_stats(self) -> dict[str, Any]:
        """Get OOM handling statistics.

        Returns:
            Dictionary with OOM stats
        """
        return {
            "oom_count": self._oom_count,
            "last_oom_time": self._last_oom_time,
            "max_retries": self.max_retries,
            "has_registry": self.model_registry is not None,
        }
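

# Wiring sketch (illustration only; ``registry`` and the manager construction
# are assumptions about the surrounding application):
#
#     gpu_manager = GPUMemoryManager()
#     handler = OOMHandler(gpu_manager, max_retries=3)
#     handler.set_model_registry(registry)  # injected late to avoid the circular import
#
#     @handler.with_oom_recovery
#     def generate_audio(prompt: str) -> torch.Tensor:
#         ...
#
#     stats = handler.get_stats()  # {"oom_count": ..., "last_oom_time": ..., ...}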


# Module-level convenience function
def oom_safe(
    gpu_manager: GPUMemoryManager,
    model_registry: Optional[Any] = None,
    max_retries: int = 3,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator factory for OOM-safe functions.

    Usage:
        @oom_safe(gpu_manager, model_registry)
        def generate_audio(...):
            ...

    Args:
        gpu_manager: GPU memory manager
        model_registry: Optional model registry for eviction
        max_retries: Maximum recovery attempts

    Returns:
        Decorator function
    """
    handler = OOMHandler(gpu_manager, model_registry, max_retries)
    return handler.with_oom_recovery
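

# Preflight sketch (illustration only; 8192MB is a hypothetical requirement
# for a large generation, and ``handler``/``generate_audio`` refer to the
# wiring sketch above):
#
#     if handler.check_memory_for_operation(required_mb=8192):
#         waveform = generate_audio("rain on a tin roof")
#     else:
#         raise OOMRecoveryError("insufficient VRAM even after recovery")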