Initial Real-ESRGAN API project setup
This commit is contained in:
154
app/services/model_manager.py
Normal file
154
app/services/model_manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Model management utilities."""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.services import file_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Known models for Easy Real-ESRGAN
|
||||
KNOWN_MODELS = {
|
||||
'RealESRGAN_x2plus': {
|
||||
'scale': 2,
|
||||
'description': '2x upscaling',
|
||||
'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
||||
'size_mb': 66.7,
|
||||
},
|
||||
'RealESRGAN_x3plus': {
|
||||
'scale': 3,
|
||||
'description': '3x upscaling',
|
||||
'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x3plus.pth',
|
||||
'size_mb': 101.7,
|
||||
},
|
||||
'RealESRGAN_x4plus': {
|
||||
'scale': 4,
|
||||
'description': '4x upscaling (general purpose)',
|
||||
'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth',
|
||||
'size_mb': 101.7,
|
||||
},
|
||||
'RealESRGAN_x4plus_anime_6B': {
|
||||
'scale': 4,
|
||||
'description': '4x upscaling (anime/art)',
|
||||
'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2/RealESRGAN_x4plus_anime_6B.pth',
|
||||
'size_mb': 18.9,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_available_models() -> List[Dict]:
|
||||
"""Get list of available models with details."""
|
||||
models = []
|
||||
|
||||
for model_name, metadata in KNOWN_MODELS.items():
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
available = os.path.exists(model_path)
|
||||
|
||||
size_mb = metadata['size_mb']
|
||||
if available:
|
||||
try:
|
||||
actual_size = os.path.getsize(model_path) / (1024 * 1024)
|
||||
size_mb = actual_size
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
models.append({
|
||||
'name': model_name,
|
||||
'scale': metadata['scale'],
|
||||
'description': metadata['description'],
|
||||
'available': available,
|
||||
'size_mb': size_mb,
|
||||
'size_bytes': int(size_mb * 1024 * 1024),
|
||||
})
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def is_model_available(model_name: str) -> bool:
|
||||
"""Check if a model is available locally."""
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
return os.path.exists(model_path)
|
||||
|
||||
|
||||
def get_model_scale(model_name: str) -> Optional[int]:
|
||||
"""Get upscaling factor for a model."""
|
||||
if model_name in KNOWN_MODELS:
|
||||
return KNOWN_MODELS[model_name]['scale']
|
||||
|
||||
# Try to infer from model name
|
||||
if 'x2' in model_name.lower():
|
||||
return 2
|
||||
elif 'x3' in model_name.lower():
|
||||
return 3
|
||||
elif 'x4' in model_name.lower():
|
||||
return 4
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def download_model(model_name: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Download a model.
|
||||
|
||||
Returns: (success, message)
|
||||
"""
|
||||
if model_name not in KNOWN_MODELS:
|
||||
return False, f'Unknown model: {model_name}'
|
||||
|
||||
if is_model_available(model_name):
|
||||
return True, f'Model already available: {model_name}'
|
||||
|
||||
metadata = KNOWN_MODELS[model_name]
|
||||
url = metadata['url']
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
|
||||
try:
|
||||
logger.info(f'Downloading model: {model_name} from {url}')
|
||||
|
||||
import urllib.request
|
||||
os.makedirs(settings.models_dir, exist_ok=True)
|
||||
|
||||
def download_progress(count, block_size, total_size):
|
||||
downloaded = count * block_size
|
||||
percent = min(downloaded * 100 / total_size, 100)
|
||||
logger.debug(f'Download progress: {percent:.1f}%')
|
||||
|
||||
urllib.request.urlretrieve(url, model_path, download_progress)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
size_mb = os.path.getsize(model_path) / (1024 * 1024)
|
||||
logger.info(f'Model downloaded: {model_name} ({size_mb:.1f} MB)')
|
||||
return True, f'Model downloaded: {model_name} ({size_mb:.1f} MB)'
|
||||
else:
|
||||
return False, f'Failed to download model: {model_name}'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to download model {model_name}: {e}', exc_info=True)
|
||||
return False, f'Download failed: {str(e)}'
|
||||
|
||||
|
||||
async def download_models(model_names: List[str]) -> Dict[str, tuple[bool, str]]:
|
||||
"""Download multiple models."""
|
||||
results = {}
|
||||
|
||||
for model_name in model_names:
|
||||
success, message = await download_model(model_name)
|
||||
results[model_name] = (success, message)
|
||||
logger.info(f'Download result for {model_name}: {success} - {message}')
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_models_directory_info() -> Dict:
|
||||
"""Get information about the models directory."""
|
||||
models_dir = settings.models_dir
|
||||
|
||||
return {
|
||||
'path': models_dir,
|
||||
'size_mb': file_manager.get_directory_size_mb(models_dir),
|
||||
'model_count': len([f for f in os.listdir(models_dir) if f.endswith('.pth')])
|
||||
if os.path.isdir(models_dir) else 0,
|
||||
}
|
||||
Reference in New Issue
Block a user