"""OOM (Out of Memory) handling and recovery strategies."""
|
|
|
|
import functools
|
|
import gc
|
|
import logging
|
|
import time
|
|
from typing import Any, Callable, Optional, ParamSpec, TypeVar
|
|
|
|
import torch
|
|
|
|
from src.core.gpu_manager import GPUMemoryManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
P = ParamSpec("P")
|
|
R = TypeVar("R")
|
|
|
|
|
|
class OOMRecoveryError(Exception):
|
|
"""Raised when OOM recovery fails after all strategies exhausted."""
|
|
|
|
pass
|
|
|
|
|
|
class OOMHandler:
|
|
"""Handles CUDA Out of Memory errors with multi-level recovery strategies.
|
|
|
|
Recovery levels:
|
|
1. Clear PyTorch CUDA cache
|
|
2. Evict unused models from registry
|
|
3. Request ComfyUI to yield VRAM
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
gpu_manager: GPUMemoryManager,
|
|
model_registry: Optional[Any] = None, # Avoid circular import
|
|
max_retries: int = 3,
|
|
retry_delay: float = 0.5,
|
|
):
|
|
"""Initialize OOM handler.
|
|
|
|
Args:
|
|
gpu_manager: GPU memory manager instance
|
|
model_registry: Optional model registry for eviction
|
|
max_retries: Maximum recovery attempts
|
|
retry_delay: Delay between retries in seconds
|
|
"""
|
|
self.gpu_manager = gpu_manager
|
|
self.model_registry = model_registry
|
|
self.max_retries = max_retries
|
|
self.retry_delay = retry_delay
|
|
|
|
# Track OOM events for monitoring
|
|
self._oom_count = 0
|
|
self._last_oom_time: Optional[float] = None
|
|
|
|
@property
|
|
def oom_count(self) -> int:
|
|
"""Number of OOM events handled."""
|
|
return self._oom_count
|
|
|
|
def set_model_registry(self, registry: Any) -> None:
|
|
"""Set model registry (to avoid circular import at init time)."""
|
|
self.model_registry = registry
|
|
|
|
def with_oom_recovery(self, func: Callable[P, R]) -> Callable[P, R]:
|
|
"""Decorator that wraps function with OOM recovery logic.
|
|
|
|
Usage:
|
|
@oom_handler.with_oom_recovery
|
|
def generate_audio(...):
|
|
...
|
|
|
|
Args:
|
|
func: Function to wrap
|
|
|
|
Returns:
|
|
Wrapped function with OOM recovery
|
|
"""
|
|
|
|
@functools.wraps(func)
|
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
last_exception = None
|
|
|
|
for attempt in range(self.max_retries + 1):
|
|
try:
|
|
if attempt > 0:
|
|
logger.info(f"Retry attempt {attempt}/{self.max_retries}")
|
|
time.sleep(self.retry_delay)
|
|
|
|
return func(*args, **kwargs)
|
|
|
|
except torch.cuda.OutOfMemoryError as e:
|
|
last_exception = e
|
|
self._oom_count += 1
|
|
self._last_oom_time = time.time()
|
|
|
|
logger.warning(f"CUDA OOM detected (attempt {attempt + 1}): {e}")
|
|
|
|
if attempt < self.max_retries:
|
|
self._execute_recovery_strategy(attempt)
|
|
else:
|
|
logger.error(
|
|
f"OOM recovery failed after {self.max_retries} attempts"
|
|
)
|
|
|
|
raise OOMRecoveryError(
|
|
f"OOM recovery failed after {self.max_retries} attempts"
|
|
) from last_exception
|
|
|
|
return wrapper
|
|
|
|
def _execute_recovery_strategy(self, level: int) -> None:
|
|
"""Execute recovery strategy based on severity level.
|
|
|
|
Args:
|
|
level: Recovery level (0-2)
|
|
"""
|
|
strategies = [
|
|
self._strategy_clear_cache,
|
|
self._strategy_evict_models,
|
|
self._strategy_request_comfyui_yield,
|
|
]
|
|
|
|
# Execute all strategies up to and including current level
|
|
for i in range(min(level + 1, len(strategies))):
|
|
logger.info(f"Executing recovery strategy {i + 1}: {strategies[i].__name__}")
|
|
strategies[i]()
|
|
|
|
def _strategy_clear_cache(self) -> None:
|
|
"""Level 1: Clear PyTorch CUDA cache.
|
|
|
|
This is the fastest and least disruptive recovery strategy.
|
|
Clears cached memory that PyTorch holds for future allocations.
|
|
"""
|
|
logger.info("Clearing CUDA cache...")
|
|
|
|
gc.collect()
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.synchronize()
|
|
|
|
# Reset peak memory stats for monitoring
|
|
torch.cuda.reset_peak_memory_stats()
|
|
|
|
freed = self.gpu_manager.force_cleanup()
|
|
logger.info(f"Cache cleared, freed approximately {freed}MB")
|
|
|
|
def _strategy_evict_models(self) -> None:
|
|
"""Level 2: Evict non-essential models from registry.
|
|
|
|
Unloads all models that don't have active references,
|
|
freeing their VRAM for the current operation.
|
|
"""
|
|
if self.model_registry is None:
|
|
logger.warning("No model registry available for eviction")
|
|
self._strategy_clear_cache()
|
|
return
|
|
|
|
logger.info("Evicting unused models...")
|
|
|
|
# Get list of loaded models
|
|
loaded = self.model_registry.get_loaded_models()
|
|
evicted = []
|
|
|
|
for model_info in loaded:
|
|
# Only evict models with no active references
|
|
if model_info["ref_count"] == 0:
|
|
model_id = model_info["model_id"]
|
|
variant = model_info["variant"]
|
|
logger.info(f"Evicting {model_id}/{variant}")
|
|
self.model_registry.unload_model(model_id, variant)
|
|
evicted.append(f"{model_id}/{variant}")
|
|
|
|
# Clear cache after eviction
|
|
self._strategy_clear_cache()
|
|
|
|
logger.info(f"Evicted {len(evicted)} model(s): {evicted}")
|
|
|
|
def _strategy_request_comfyui_yield(self) -> None:
|
|
"""Level 3: Request ComfyUI to yield VRAM.
|
|
|
|
Uses the coordination protocol to ask ComfyUI to
|
|
temporarily release GPU memory.
|
|
"""
|
|
logger.info("Requesting ComfyUI to yield VRAM...")
|
|
|
|
# First, evict our own models
|
|
self._strategy_evict_models()
|
|
|
|
# Calculate how much VRAM we need
|
|
budget = self.gpu_manager.get_available_budget()
|
|
needed = max(4096, budget.total_mb // 4) # Request at least 4GB or 25% of total
|
|
|
|
# Request priority from ComfyUI
|
|
success = self.gpu_manager.request_priority(needed, timeout=15.0)
|
|
|
|
if success:
|
|
logger.info("ComfyUI yielded VRAM successfully")
|
|
else:
|
|
logger.warning("ComfyUI did not yield VRAM within timeout")
|
|
|
|
# Final cache clear
|
|
self._strategy_clear_cache()
|
|
|
|
def recover_from_oom(self, level: int = 0) -> bool:
|
|
"""Manually trigger OOM recovery.
|
|
|
|
Args:
|
|
level: Recovery level to execute (0-2)
|
|
|
|
Returns:
|
|
True if recovery was successful (memory was freed)
|
|
"""
|
|
before = self.gpu_manager.get_memory_info()
|
|
|
|
self._execute_recovery_strategy(level)
|
|
|
|
after = self.gpu_manager.get_memory_info()
|
|
freed = before["used"] - after["used"]
|
|
|
|
logger.info(f"Manual recovery freed {freed}MB")
|
|
return freed > 0
|
|
|
|
def check_memory_for_operation(self, required_mb: int) -> bool:
|
|
"""Check if there's enough memory for an operation.
|
|
|
|
If not enough, attempts recovery strategies.
|
|
|
|
Args:
|
|
required_mb: Memory required in megabytes
|
|
|
|
Returns:
|
|
True if enough memory is available (possibly after recovery)
|
|
"""
|
|
budget = self.gpu_manager.get_available_budget()
|
|
|
|
if budget.available_mb >= required_mb:
|
|
return True
|
|
|
|
logger.info(
|
|
f"Need {required_mb}MB but only {budget.available_mb}MB available. "
|
|
"Attempting recovery..."
|
|
)
|
|
|
|
# Try progressively more aggressive recovery
|
|
for level in range(3):
|
|
self._execute_recovery_strategy(level)
|
|
budget = self.gpu_manager.get_available_budget()
|
|
|
|
if budget.available_mb >= required_mb:
|
|
logger.info(f"Recovery successful at level {level + 1}")
|
|
return True
|
|
|
|
logger.error(
|
|
f"Could not free enough memory. Need {required_mb}MB, "
|
|
f"have {budget.available_mb}MB"
|
|
)
|
|
return False
|
|
|
|
def get_stats(self) -> dict[str, Any]:
|
|
"""Get OOM handling statistics.
|
|
|
|
Returns:
|
|
Dictionary with OOM stats
|
|
"""
|
|
return {
|
|
"oom_count": self._oom_count,
|
|
"last_oom_time": self._last_oom_time,
|
|
"max_retries": self.max_retries,
|
|
"has_registry": self.model_registry is not None,
|
|
}
|
|
|
|
|
|
# Module-level convenience function
|
|
def oom_safe(
|
|
gpu_manager: GPUMemoryManager,
|
|
model_registry: Optional[Any] = None,
|
|
max_retries: int = 3,
|
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
|
"""Decorator factory for OOM-safe functions.
|
|
|
|
Usage:
|
|
@oom_safe(gpu_manager, model_registry)
|
|
def generate_audio(...):
|
|
...
|
|
|
|
Args:
|
|
gpu_manager: GPU memory manager
|
|
model_registry: Optional model registry for eviction
|
|
max_retries: Maximum recovery attempts
|
|
|
|
Returns:
|
|
Decorator function
|
|
"""
|
|
handler = OOMHandler(gpu_manager, model_registry, max_retries)
|
|
return handler.with_oom_recovery
|
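

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed by this module). It assumes a
# GPUMemoryManager can be constructed with no arguments and that `registry`
# exposes get_loaded_models()/unload_model(); both are assumptions for the
# sake of the example, not guarantees made by this module.
#
#   gpu_manager = GPUMemoryManager()
#   handler = OOMHandler(gpu_manager, max_retries=3)
#   handler.set_model_registry(registry)  # attach later to avoid circular import
#
#   @handler.with_oom_recovery
#   def generate_audio(prompt: str):
#       ...
#
#   # Equivalent wiring via the module-level factory:
#   @oom_safe(gpu_manager, registry)
#   def generate_music(prompt: str):
#       ...
#
#   # Optional pre-flight check before a large generation:
#   if handler.check_memory_for_operation(required_mb=8192):
#       generate_audio("ambient pad")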