Initial implementation of AudioCraft Studio
Complete web interface for Meta's AudioCraft AI audio generation:

- Gradio UI with tabs for all 5 model families (MusicGen, AudioGen, MAGNeT, MusicGen Style, JASCO)
- REST API with FastAPI, OpenAPI docs, and API key auth
- VRAM management with ComfyUI coexistence support
- SQLite database for project/generation history
- Batch processing queue for async generation
- Docker deployment optimized for RunPod with RTX 4090

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
src/core/oom_handler.py (new file, 297 lines)
@@ -0,0 +1,297 @@
"""OOM (Out of Memory) handling and recovery strategies."""

import functools
import gc
import logging
import time
from typing import Any, Callable, Optional, ParamSpec, TypeVar

import torch

from src.core.gpu_manager import GPUMemoryManager

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")


class OOMRecoveryError(Exception):
    """Raised when OOM recovery fails after all strategies exhausted."""

    pass


class OOMHandler:
    """Handles CUDA Out of Memory errors with multi-level recovery strategies.

    Recovery levels:
        1. Clear PyTorch CUDA cache
        2. Evict unused models from registry
        3. Request ComfyUI to yield VRAM
    """

    def __init__(
        self,
        gpu_manager: GPUMemoryManager,
        model_registry: Optional[Any] = None,  # Avoid circular import
        max_retries: int = 3,
        retry_delay: float = 0.5,
    ):
        """Initialize OOM handler.

        Args:
            gpu_manager: GPU memory manager instance
            model_registry: Optional model registry for eviction
            max_retries: Maximum recovery attempts
            retry_delay: Delay between retries in seconds
        """
        self.gpu_manager = gpu_manager
        self.model_registry = model_registry
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Track OOM events for monitoring
        self._oom_count = 0
        self._last_oom_time: Optional[float] = None

    @property
    def oom_count(self) -> int:
        """Number of OOM events handled."""
        return self._oom_count

    def set_model_registry(self, registry: Any) -> None:
        """Set model registry (to avoid circular import at init time)."""
        self.model_registry = registry

    def with_oom_recovery(self, func: Callable[P, R]) -> Callable[P, R]:
        """Decorator that wraps function with OOM recovery logic.

        Usage:
            @oom_handler.with_oom_recovery
            def generate_audio(...):
                ...

        Args:
            func: Function to wrap

        Returns:
            Wrapped function with OOM recovery
        """

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            last_exception = None

            for attempt in range(self.max_retries + 1):
                try:
                    if attempt > 0:
                        logger.info(f"Retry attempt {attempt}/{self.max_retries}")
                        time.sleep(self.retry_delay)

                    return func(*args, **kwargs)

                except torch.cuda.OutOfMemoryError as e:
                    last_exception = e
                    self._oom_count += 1
                    self._last_oom_time = time.time()

                    logger.warning(f"CUDA OOM detected (attempt {attempt + 1}): {e}")

                    if attempt < self.max_retries:
                        self._execute_recovery_strategy(attempt)
                    else:
                        logger.error(
                            f"OOM recovery failed after {self.max_retries} attempts"
                        )

            raise OOMRecoveryError(
                f"OOM recovery failed after {self.max_retries} attempts"
            ) from last_exception

        return wrapper

    def _execute_recovery_strategy(self, level: int) -> None:
        """Execute recovery strategy based on severity level.

        Args:
            level: Recovery level (0-2)
        """
        strategies = [
            self._strategy_clear_cache,
            self._strategy_evict_models,
            self._strategy_request_comfyui_yield,
        ]

        # Execute all strategies up to and including current level
        for i in range(min(level + 1, len(strategies))):
            logger.info(f"Executing recovery strategy {i + 1}: {strategies[i].__name__}")
            strategies[i]()

    def _strategy_clear_cache(self) -> None:
        """Level 1: Clear PyTorch CUDA cache.

        This is the fastest and least disruptive recovery strategy.
        Clears cached memory that PyTorch holds for future allocations.
        """
        logger.info("Clearing CUDA cache...")

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # Reset peak memory stats for monitoring
        torch.cuda.reset_peak_memory_stats()

        freed = self.gpu_manager.force_cleanup()
        logger.info(f"Cache cleared, freed approximately {freed}MB")

    def _strategy_evict_models(self) -> None:
        """Level 2: Evict non-essential models from registry.

        Unloads all models that don't have active references,
        freeing their VRAM for the current operation.
        """
        if self.model_registry is None:
            logger.warning("No model registry available for eviction")
            self._strategy_clear_cache()
            return

        logger.info("Evicting unused models...")

        # Get list of loaded models
        loaded = self.model_registry.get_loaded_models()
        evicted = []

        for model_info in loaded:
            # Only evict models with no active references
            if model_info["ref_count"] == 0:
                model_id = model_info["model_id"]
                variant = model_info["variant"]
                logger.info(f"Evicting {model_id}/{variant}")
                self.model_registry.unload_model(model_id, variant)
                evicted.append(f"{model_id}/{variant}")

        # Clear cache after eviction
        self._strategy_clear_cache()

        logger.info(f"Evicted {len(evicted)} model(s): {evicted}")

    def _strategy_request_comfyui_yield(self) -> None:
        """Level 3: Request ComfyUI to yield VRAM.

        Uses the coordination protocol to ask ComfyUI to
        temporarily release GPU memory.
        """
        logger.info("Requesting ComfyUI to yield VRAM...")

        # First, evict our own models
        self._strategy_evict_models()

        # Calculate how much VRAM we need
        budget = self.gpu_manager.get_available_budget()
        needed = max(4096, budget.total_mb // 4)  # Request at least 4GB or 25% of total

        # Request priority from ComfyUI
        success = self.gpu_manager.request_priority(needed, timeout=15.0)

        if success:
            logger.info("ComfyUI yielded VRAM successfully")
        else:
            logger.warning("ComfyUI did not yield VRAM within timeout")

        # Final cache clear
        self._strategy_clear_cache()

    def recover_from_oom(self, level: int = 0) -> bool:
        """Manually trigger OOM recovery.

        Args:
            level: Recovery level to execute (0-2)

        Returns:
            True if recovery was successful (memory was freed)
        """
        before = self.gpu_manager.get_memory_info()

        self._execute_recovery_strategy(level)

        after = self.gpu_manager.get_memory_info()
        freed = before["used"] - after["used"]

        logger.info(f"Manual recovery freed {freed}MB")
        return freed > 0

    def check_memory_for_operation(self, required_mb: int) -> bool:
        """Check if there's enough memory for an operation.

        If not enough, attempts recovery strategies.

        Args:
            required_mb: Memory required in megabytes

        Returns:
            True if enough memory is available (possibly after recovery)
        """
        budget = self.gpu_manager.get_available_budget()

        if budget.available_mb >= required_mb:
            return True

        logger.info(
            f"Need {required_mb}MB but only {budget.available_mb}MB available. "
            "Attempting recovery..."
        )

        # Try progressively more aggressive recovery
        for level in range(3):
            self._execute_recovery_strategy(level)
            budget = self.gpu_manager.get_available_budget()

            if budget.available_mb >= required_mb:
                logger.info(f"Recovery successful at level {level + 1}")
                return True

        logger.error(
            f"Could not free enough memory. Need {required_mb}MB, "
            f"have {budget.available_mb}MB"
        )
        return False

    def get_stats(self) -> dict[str, Any]:
        """Get OOM handling statistics.

        Returns:
            Dictionary with OOM stats
        """
        return {
            "oom_count": self._oom_count,
            "last_oom_time": self._last_oom_time,
            "max_retries": self.max_retries,
            "has_registry": self.model_registry is not None,
        }


# Module-level convenience function
def oom_safe(
    gpu_manager: GPUMemoryManager,
    model_registry: Optional[Any] = None,
    max_retries: int = 3,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator factory for OOM-safe functions.

    Usage:
        @oom_safe(gpu_manager, model_registry)
        def generate_audio(...):
            ...

    Args:
        gpu_manager: GPU memory manager
        model_registry: Optional model registry for eviction
        max_retries: Maximum recovery attempts

    Returns:
        Decorator function
    """
    handler = OOMHandler(gpu_manager, model_registry, max_retries)
    return handler.with_oom_recovery
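A minimal usage sketch of the new handler follows. It assumes a `GPUMemoryManager` that can be constructed with defaults (its constructor is not part of this commit), and `generate_audio` is a placeholder for an AudioCraft generation call:

# Hypothetical wiring; GPUMemoryManager() defaults and the generate_audio body
# are placeholders for illustration, not APIs defined in this commit.
from src.core.gpu_manager import GPUMemoryManager
from src.core.oom_handler import OOMHandler, OOMRecoveryError

gpu_manager = GPUMemoryManager()
handler = OOMHandler(gpu_manager)
# handler.set_model_registry(model_registry)  # attach a registry once one exists


@handler.with_oom_recovery
def generate_audio(prompt: str) -> None:
    ...  # load a model and run generation here


# Optional pre-flight check before a large generation (6 GB is an example value).
if handler.check_memory_for_operation(required_mb=6144):
    try:
        generate_audio("lo-fi hip hop beat with soft piano")
    except OOMRecoveryError:
        # Cache clearing, model eviction, and the ComfyUI yield request were
        # all exhausted; surface the failure to the caller.
        raise

On CUDA OOM the decorator retries up to `max_retries` times, escalating from a cache clear to model eviction to a ComfyUI yield request between attempts.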