"""Model management utilities.""" import json import logging import os from typing import Dict, List, Optional from app.config import settings from app.services import file_manager logger = logging.getLogger(__name__) # Known models for Easy Real-ESRGAN KNOWN_MODELS = { 'RealESRGAN_x2plus': { 'scale': 2, 'description': '2x upscaling', 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 'size_mb': 66.7, }, 'RealESRGAN_x3plus': { 'scale': 3, 'description': '3x upscaling', 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x3plus.pth', 'size_mb': 101.7, }, 'RealESRGAN_x4plus': { 'scale': 4, 'description': '4x upscaling (general purpose)', 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 'size_mb': 101.7, }, 'RealESRGAN_x4plus_anime_6B': { 'scale': 4, 'description': '4x upscaling (anime/art)', 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2/RealESRGAN_x4plus_anime_6B.pth', 'size_mb': 18.9, }, } def get_available_models() -> List[Dict]: """Get list of available models with details.""" models = [] for model_name, metadata in KNOWN_MODELS.items(): model_path = os.path.join(settings.models_dir, f'{model_name}.pth') available = os.path.exists(model_path) size_mb = metadata['size_mb'] if available: try: actual_size = os.path.getsize(model_path) / (1024 * 1024) size_mb = actual_size except OSError: pass models.append({ 'name': model_name, 'scale': metadata['scale'], 'description': metadata['description'], 'available': available, 'size_mb': size_mb, 'size_bytes': int(size_mb * 1024 * 1024), }) return models def is_model_available(model_name: str) -> bool: """Check if a model is available locally.""" model_path = os.path.join(settings.models_dir, f'{model_name}.pth') return os.path.exists(model_path) def get_model_scale(model_name: str) -> Optional[int]: """Get upscaling factor for a model.""" if model_name in KNOWN_MODELS: return KNOWN_MODELS[model_name]['scale'] # Try to infer from model name if 'x2' in model_name.lower(): return 2 elif 'x3' in model_name.lower(): return 3 elif 'x4' in model_name.lower(): return 4 return None async def download_model(model_name: str) -> tuple[bool, str]: """ Download a model. Returns: (success, message) """ if model_name not in KNOWN_MODELS: return False, f'Unknown model: {model_name}' if is_model_available(model_name): return True, f'Model already available: {model_name}' metadata = KNOWN_MODELS[model_name] url = metadata['url'] model_path = os.path.join(settings.models_dir, f'{model_name}.pth') try: logger.info(f'Downloading model: {model_name} from {url}') import urllib.request os.makedirs(settings.models_dir, exist_ok=True) def download_progress(count, block_size, total_size): downloaded = count * block_size percent = min(downloaded * 100 / total_size, 100) logger.debug(f'Download progress: {percent:.1f}%') urllib.request.urlretrieve(url, model_path, download_progress) if os.path.exists(model_path): size_mb = os.path.getsize(model_path) / (1024 * 1024) logger.info(f'Model downloaded: {model_name} ({size_mb:.1f} MB)') return True, f'Model downloaded: {model_name} ({size_mb:.1f} MB)' else: return False, f'Failed to download model: {model_name}' except Exception as e: logger.error(f'Failed to download model {model_name}: {e}', exc_info=True) return False, f'Download failed: {str(e)}' async def download_models(model_names: List[str]) -> Dict[str, tuple[bool, str]]: """Download multiple models.""" results = {} for model_name in model_names: success, message = await download_model(model_name) results[model_name] = (success, message) logger.info(f'Download result for {model_name}: {success} - {message}') return results def get_models_directory_info() -> Dict: """Get information about the models directory.""" models_dir = settings.models_dir return { 'path': models_dir, 'size_mb': file_manager.get_directory_size_mb(models_dir), 'model_count': len([f for f in os.listdir(models_dir) if f.endswith('.pth')]) if os.path.isdir(models_dir) else 0, }