"""OOM (Out of Memory) handling and recovery strategies."""
import functools
import gc
import logging
import time
from typing import Any, Callable, Optional, ParamSpec, TypeVar
import torch
from src.core.gpu_manager import GPUMemoryManager
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")


class OOMRecoveryError(Exception):
    """Raised when OOM recovery fails after all strategies are exhausted."""


class OOMHandler:
    """Handles CUDA Out of Memory errors with multi-level recovery strategies.

    Recovery levels:
        1. Clear PyTorch CUDA cache
        2. Evict unused models from registry
        3. Request ComfyUI to yield VRAM
    """

    def __init__(
        self,
        gpu_manager: GPUMemoryManager,
        model_registry: Optional[Any] = None,  # Avoid circular import
        max_retries: int = 3,
        retry_delay: float = 0.5,
    ):
        """Initialize OOM handler.

        Args:
            gpu_manager: GPU memory manager instance
            model_registry: Optional model registry for eviction
            max_retries: Maximum recovery attempts
            retry_delay: Delay between retries in seconds
        """
        self.gpu_manager = gpu_manager
        self.model_registry = model_registry
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Track OOM events for monitoring
        self._oom_count = 0
        self._last_oom_time: Optional[float] = None

    @property
    def oom_count(self) -> int:
        """Number of OOM events handled."""
        return self._oom_count

    def set_model_registry(self, registry: Any) -> None:
        """Set model registry (to avoid circular import at init time)."""
        self.model_registry = registry

    def with_oom_recovery(self, func: Callable[P, R]) -> Callable[P, R]:
        """Decorator that wraps a function with OOM recovery logic.

        Usage:
            @oom_handler.with_oom_recovery
            def generate_audio(...):
                ...

        Args:
            func: Function to wrap

        Returns:
            Wrapped function with OOM recovery
        """

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            last_exception: Optional[torch.cuda.OutOfMemoryError] = None
            for attempt in range(self.max_retries + 1):
                try:
                    if attempt > 0:
                        logger.info(f"Retry attempt {attempt}/{self.max_retries}")
                        time.sleep(self.retry_delay)
                    return func(*args, **kwargs)
                except torch.cuda.OutOfMemoryError as e:
                    last_exception = e
                    self._oom_count += 1
                    self._last_oom_time = time.time()
                    logger.warning(f"CUDA OOM detected (attempt {attempt + 1}): {e}")
                    if attempt < self.max_retries:
                        self._execute_recovery_strategy(attempt)
                    else:
                        logger.error(
                            f"OOM recovery failed after {self.max_retries} attempts"
                        )
                        raise OOMRecoveryError(
                            f"OOM recovery failed after {self.max_retries} attempts"
                        ) from last_exception
            # Unreachable: the final attempt either returns or raises above,
            # but an explicit raise keeps type checkers satisfied.
            raise OOMRecoveryError("OOM recovery exhausted without result")

        return wrapper

    def _execute_recovery_strategy(self, level: int) -> None:
        """Execute recovery strategy based on severity level.

        Args:
            level: Recovery level (0-2)
        """
        strategies = [
            self._strategy_clear_cache,
            self._strategy_evict_models,
            self._strategy_request_comfyui_yield,
        ]

        # Execute all strategies up to and including the current level
        for i in range(min(level + 1, len(strategies))):
            logger.info(f"Executing recovery strategy {i + 1}: {strategies[i].__name__}")
            strategies[i]()

    def _strategy_clear_cache(self) -> None:
        """Level 1: Clear PyTorch CUDA cache.

        This is the fastest and least disruptive recovery strategy.
        Clears cached memory that PyTorch holds for future allocations.
        """
        logger.info("Clearing CUDA cache...")
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        # Reset peak memory stats for monitoring
        torch.cuda.reset_peak_memory_stats()

        freed = self.gpu_manager.force_cleanup()
        logger.info(f"Cache cleared, freed approximately {freed}MB")

    def _strategy_evict_models(self) -> None:
        """Level 2: Evict non-essential models from registry.

        Unloads all models that don't have active references,
        freeing their VRAM for the current operation.
        """
        if self.model_registry is None:
            logger.warning("No model registry available for eviction")
            self._strategy_clear_cache()
            return

        logger.info("Evicting unused models...")

        # Get list of loaded models
        loaded = self.model_registry.get_loaded_models()
        evicted = []
        for model_info in loaded:
            # Only evict models with no active references
            if model_info["ref_count"] == 0:
                model_id = model_info["model_id"]
                variant = model_info["variant"]
                logger.info(f"Evicting {model_id}/{variant}")
                self.model_registry.unload_model(model_id, variant)
                evicted.append(f"{model_id}/{variant}")

        # Clear cache after eviction
        self._strategy_clear_cache()
        logger.info(f"Evicted {len(evicted)} model(s): {evicted}")

    def _strategy_request_comfyui_yield(self) -> None:
        """Level 3: Request ComfyUI to yield VRAM.

        Uses the coordination protocol to ask ComfyUI to
        temporarily release GPU memory.
        """
        logger.info("Requesting ComfyUI to yield VRAM...")

        # First, evict our own models
        self._strategy_evict_models()

        # Calculate how much VRAM we need: at least 4096 MB (4 GB),
        # or 25% of total VRAM, whichever is larger
        budget = self.gpu_manager.get_available_budget()
        needed = max(4096, budget.total_mb // 4)

        # Request priority from ComfyUI
        success = self.gpu_manager.request_priority(needed, timeout=15.0)
        if success:
            logger.info("ComfyUI yielded VRAM successfully")
        else:
            logger.warning("ComfyUI did not yield VRAM within timeout")

        # Final cache clear
        self._strategy_clear_cache()

    def recover_from_oom(self, level: int = 0) -> bool:
        """Manually trigger OOM recovery.

        Args:
            level: Recovery level to execute (0-2)

        Returns:
            True if recovery was successful (memory was freed)
        """
        before = self.gpu_manager.get_memory_info()
        self._execute_recovery_strategy(level)
        after = self.gpu_manager.get_memory_info()

        freed = before["used"] - after["used"]
        logger.info(f"Manual recovery freed {freed}MB")
        return freed > 0

    def check_memory_for_operation(self, required_mb: int) -> bool:
        """Check if there's enough memory for an operation.

        If not enough is available, recovery strategies are attempted
        progressively until the requirement is met or all levels are exhausted.

        Args:
            required_mb: Memory required in megabytes

        Returns:
            True if enough memory is available (possibly after recovery)
        """
        budget = self.gpu_manager.get_available_budget()
        if budget.available_mb >= required_mb:
            return True

        logger.info(
            f"Need {required_mb}MB but only {budget.available_mb}MB available. "
            "Attempting recovery..."
        )

        # Try progressively more aggressive recovery
        for level in range(3):
            self._execute_recovery_strategy(level)
            budget = self.gpu_manager.get_available_budget()
            if budget.available_mb >= required_mb:
                logger.info(f"Recovery successful at level {level + 1}")
                return True

        logger.error(
            f"Could not free enough memory. Need {required_mb}MB, "
            f"have {budget.available_mb}MB"
        )
        return False

    def get_stats(self) -> dict[str, Any]:
        """Get OOM handling statistics.

        Returns:
            Dictionary with OOM stats
        """
        return {
            "oom_count": self._oom_count,
            "last_oom_time": self._last_oom_time,
            "max_retries": self.max_retries,
            "has_registry": self.model_registry is not None,
        }


# Module-level convenience function
def oom_safe(
    gpu_manager: GPUMemoryManager,
    model_registry: Optional[Any] = None,
    max_retries: int = 3,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator factory for OOM-safe functions.

    Usage:
        @oom_safe(gpu_manager, model_registry)
        def generate_audio(...):
            ...

    Args:
        gpu_manager: GPU memory manager
        model_registry: Optional model registry for eviction
        max_retries: Maximum recovery attempts

    Returns:
        Decorator function
    """
    handler = OOMHandler(gpu_manager, model_registry, max_retries)
    return handler.with_oom_recovery
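

# Illustrative usage sketch (not part of the module's public API): shows how the
# handler might be wired up. `GPUMemoryManager()` is assumed here to take no
# constructor arguments, and `run_musicgen` is a hypothetical generation function
# standing in for real code -- adjust both to the actual signatures in this project.
if __name__ == "__main__":
    gpu_manager = GPUMemoryManager()  # Assumption: zero-argument constructor
    handler = OOMHandler(gpu_manager)

    # Pattern 1: wrap a callable so CUDA OOM triggers the recovery/retry loop
    @handler.with_oom_recovery
    def run_musicgen() -> str:  # Hypothetical generation function
        return "ok"

    # Pattern 2: check headroom up front, letting the handler attempt recovery
    if handler.check_memory_for_operation(required_mb=2048):
        print(run_musicgen())

    print(handler.get_stats())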