Initial Real-ESRGAN API project setup

This commit is contained in:
Developer
2026-02-16 19:56:25 +01:00
commit 0e59652575
34 changed files with 3668 additions and 0 deletions

1
app/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Real-ESRGAN API application."""

40
app/config.py Normal file
View File

@@ -0,0 +1,40 @@
import json
from typing import List
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Application settings loaded from the environment.

    Every field can be overridden via an ``RSR_``-prefixed environment
    variable (e.g. ``RSR_UPLOAD_DIR``) thanks to pydantic-settings.
    """
    model_config = {'env_prefix': 'RSR_'}
    # Paths
    upload_dir: str = '/data/uploads'    # incoming request files (one subdir per request)
    output_dir: str = '/data/outputs'    # upscaled result images
    models_dir: str = '/data/models'     # downloaded .pth model weights
    temp_dir: str = '/data/temp'         # scratch space
    jobs_dir: str = '/data/jobs'         # per-job metadata directories
    # Real-ESRGAN defaults
    execution_providers: str = '["cpu"]'  # JSON-encoded list; decode via get_execution_providers()
    execution_thread_count: int = 4
    default_model: str = 'RealESRGAN_x4plus'
    auto_model_download: bool = True
    download_providers: str = '["huggingface"]'  # JSON-encoded list; decode via get_download_providers()
    tile_size: int = 400  # tile edge for tiled inference on large images
    tile_pad: int = 10    # overlap padding between tiles
    log_level: str = 'info'
    # Limits
    max_upload_size_mb: int = 500
    max_image_dimension: int = 8192
    sync_timeout_seconds: int = 300
    auto_cleanup_hours: int = 24
    def get_execution_providers(self) -> List[str]:
        """Decode the JSON-encoded ``execution_providers`` setting into a list."""
        return json.loads(self.execution_providers)
    def get_download_providers(self) -> List[str]:
        """Decode the JSON-encoded ``download_providers`` setting into a list."""
        return json.loads(self.download_providers)
# Module-level singleton imported throughout the app.
settings = Settings()

111
app/main.py Normal file
View File

@@ -0,0 +1,111 @@
"""Real-ESRGAN API application."""
import logging
import os
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Ensure app is importable
# Allows running this module directly (``python app/main.py``) by putting the
# package directory on sys.path before the ``app.*`` imports below.
_app_path = os.path.dirname(__file__)
if _app_path not in sys.path:
    sys.path.insert(0, _app_path)
from app.routers import health, models, upscale
from app.services import file_manager, realesrgan_bridge, worker
# Process-wide logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
)
logger = logging.getLogger(__name__)
def _process_upscale_job(job) -> None:
    """Worker function to process upscaling jobs.

    Runs inside a worker thread; raises to signal failure so the queue can
    mark the job as failed.
    """
    from app.services import realesrgan_bridge
    ok, detail, _ = realesrgan_bridge.get_bridge().upscale(
        input_path=job.input_path,
        output_path=job.output_path,
        model_name=job.model,
        outscale=job.outscale,
    )
    if not ok:
        raise Exception(detail)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle manager.

    Startup: ensure data directories exist, initialize the Real-ESRGAN
    bridge (failure is tolerated; initialization is retried on first use),
    and start the background worker queue. Shutdown: stop the workers.
    """
    # Startup
    logger.info('Starting Real-ESRGAN API...')
    file_manager.ensure_directories()
    bridge = realesrgan_bridge.get_bridge()
    if not bridge.initialize():
        logger.warning('Real-ESRGAN initialization failed (will attempt on first use)')
    # Two worker threads consume jobs via _process_upscale_job.
    wq = worker.get_worker_queue(_process_upscale_job, num_workers=2)
    wq.start()
    logger.info('Real-ESRGAN API ready')
    yield
    # Shutdown
    logger.info('Shutting down Real-ESRGAN API...')
    wq.stop()
    logger.info('Real-ESRGAN API stopped')
app = FastAPI(
    title='Real-ESRGAN API',
    version='1.0.0',
    description='REST API for Real-ESRGAN image upscaling with async job processing',
    lifespan=lifespan,
)
# CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is
# rejected by browsers under the CORS spec — confirm the intended origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
# Include routers
app.include_router(health.router)
app.include_router(models.router)
app.include_router(upscale.router)
@app.get('/')
async def root():
    """API root endpoint.

    Returns service metadata plus a map of the main endpoint groups.
    """
    endpoint_map = {
        'health': '/api/v1/health',
        'system': '/api/v1/system',
        'models': '/api/v1/models',
        'upscale': '/api/v1/upscale',
        'jobs': '/api/v1/jobs',
    }
    return {
        'name': 'Real-ESRGAN API',
        'version': '1.0.0',
        'docs': '/docs',
        'redoc': '/redoc',
        'endpoints': endpoint_map,
    }
if __name__ == '__main__':
    # Development entry point; in production run uvicorn/gunicorn externally.
    import uvicorn
    uvicorn.run(
        'app.main:app',
        host='0.0.0.0',
        port=8000,
        reload=False,
    )

1
app/routers/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""API routers."""

146
app/routers/health.py Normal file
View File

@@ -0,0 +1,146 @@
"""Health check and system information endpoints."""
import logging
import os
import time
from typing import Optional
import psutil
from fastapi import APIRouter, HTTPException
from app.config import settings
from app.schemas.health import HealthResponse, RequestStats, SystemInfo
from app.services import file_manager, worker
logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/v1', tags=['system'])
# Track uptime
# Process start timestamp; /health and /system report uptime relative to it.
_start_time = time.time()
# Request statistics
# NOTE(review): nothing in this module increments these counters — presumably
# middleware elsewhere updates them; verify before trusting /stats output.
_stats = {
    'total_requests': 0,
    'successful_requests': 0,
    'failed_requests': 0,
    'total_processing_time': 0.0,
    'total_images_processed': 0,
}
@router.get('/health')
async def health_check() -> HealthResponse:
    """API health check.

    Always reports 'ok' with the process uptime; liveness of the model
    bridge is covered by /health/ready instead.
    """
    elapsed = time.time() - _start_time
    return HealthResponse(
        status='ok',
        version='1.0.0',
        uptime_seconds=elapsed,
        message='Real-ESRGAN API is running',
    )
@router.get('/health/ready')
async def readiness_check():
    """Kubernetes readiness probe.

    503 until the Real-ESRGAN bridge has been initialized.
    """
    from app.services import realesrgan_bridge
    if realesrgan_bridge.get_bridge().initialized:
        return {'ready': True}
    raise HTTPException(status_code=503, detail='Not ready')
@router.get('/health/live')
async def liveness_check():
    """Kubernetes liveness probe — unconditionally alive."""
    return dict(alive=True)
@router.get('/system')
async def get_system_info() -> SystemInfo:
    """Get comprehensive system information.

    Returns uptime, CPU/memory/disk usage, GPU availability (when torch +
    CUDA are present), models directory size and the job-queue length.
    Raises HTTP 500 on any failure.
    """
    try:
        # Uptime
        uptime = time.time() - _start_time
        # CPU and memory.
        # FIX: interval=1 blocked the event loop for a full second per call,
        # stalling every other request handled by this worker. interval=None
        # returns usage measured since the previous call without blocking.
        cpu_percent = psutil.cpu_percent(interval=None)
        memory = psutil.virtual_memory()
        memory_percent = memory.percent
        # Disk
        disk = psutil.disk_usage('/')
        disk_percent = disk.percent
        # GPU — optional: torch may be absent or CUDA unavailable.
        gpu_available = False
        gpu_memory_mb = None
        gpu_memory_used_mb = None
        try:
            import torch
            gpu_available = torch.cuda.is_available()
            if gpu_available:
                gpu_memory_mb = int(torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
                gpu_memory_used_mb = int(torch.cuda.memory_allocated(0) / (1024 * 1024))
        except Exception:
            pass
        # Models directory size
        models_size = file_manager.get_directory_size_mb(settings.models_dir)
        # Jobs queue
        wq = worker.get_worker_queue()
        queue_length = wq.queue.qsize()
        return SystemInfo(
            status='ok',
            version='1.0.0',
            uptime_seconds=uptime,
            cpu_usage_percent=cpu_percent,
            memory_usage_percent=memory_percent,
            disk_usage_percent=disk_percent,
            gpu_available=gpu_available,
            gpu_memory_mb=gpu_memory_mb,
            gpu_memory_used_mb=gpu_memory_used_mb,
            execution_providers=settings.get_execution_providers(),
            models_dir_size_mb=models_size,
            jobs_queue_length=queue_length,
        )
    except Exception as e:
        logger.error(f'Failed to get system info: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get('/stats')
async def get_stats() -> RequestStats:
    """Get request statistics.

    Average processing time is 0.0 until at least one request succeeded.
    """
    ok_count = _stats['successful_requests']
    avg_time = (_stats['total_processing_time'] / ok_count) if ok_count > 0 else 0.0
    return RequestStats(
        total_requests=_stats['total_requests'],
        successful_requests=ok_count,
        failed_requests=_stats['failed_requests'],
        average_processing_time_seconds=avg_time,
        total_images_processed=_stats['total_images_processed'],
    )
@router.post('/cleanup')
async def cleanup_old_jobs(hours: int = 24):
    """Clean up old job directories.

    Delegates to file_manager; 500 on failure.
    """
    try:
        cleaned = file_manager.cleanup_old_jobs(hours)
    except Exception as e:
        logger.error(f'Cleanup failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {
        'success': True,
        'cleaned_jobs': cleaned,
        'message': f'Cleaned up {cleaned} job directories older than {hours} hours',
    }

109
app/routers/models.py Normal file
View File

@@ -0,0 +1,109 @@
"""Model management endpoints."""
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.models import ModelDownloadRequest, ModelDownloadResponse, ModelListResponse
from app.services import model_manager
logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/v1', tags=['models'])
@router.get('/models')
async def list_models() -> ModelListResponse:
    """List all available models.

    Counts how many of the known models are present on disk.
    """
    try:
        available = model_manager.get_available_models()
        present_count = sum(1 for entry in available if entry['available'])
        return ModelListResponse(
            available_models=available,
            total_models=len(available),
            local_models=present_count,
        )
    except Exception as e:
        logger.error(f'Failed to list models: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.post('/models/download')
async def download_models(request: ModelDownloadRequest) -> ModelDownloadResponse:
    """Download one or more models.

    Partial failure is reported per-model via the `failed`/`errors` fields.
    """
    if not request.models:
        raise HTTPException(status_code=400, detail='No models specified')
    try:
        logger.info(f'Downloading models: {request.models}')
        results = await model_manager.download_models(request.models)
        downloaded = [name for name, (ok, _msg) in results.items() if ok]
        failed = [name for name, (ok, _msg) in results.items() if not ok]
        errors = {name: msg for name, (ok, msg) in results.items() if not ok}
        return ModelDownloadResponse(
            success=len(failed) == 0,
            message=f'Downloaded {len(downloaded)} model(s)',
            downloaded=downloaded,
            failed=failed,
            errors=errors,
        )
    except Exception as e:
        logger.error(f'Model download failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get('/models/{model_name}')
async def get_model_info(model_name: str):
    """Get information about a specific model; 404 if unknown."""
    match = next(
        (entry for entry in model_manager.get_available_models()
         if entry['name'] == model_name),
        None,
    )
    if match is not None:
        return match
    raise HTTPException(status_code=404, detail=f'Model not found: {model_name}')
@router.post('/models/{model_name}/download')
async def download_model(model_name: str):
    """Download a specific model; 500 with the failure message on error."""
    try:
        success, message = await model_manager.download_model(model_name)
        if not success:
            raise HTTPException(status_code=500, detail=message)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f'Failed to download model {model_name}: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {
        'success': True,
        'message': message,
        'model': model_name,
    }
@router.get('/models-info')
async def get_models_directory_info():
    """Get information about the models directory."""
    try:
        info = model_manager.get_models_directory_info()
        payload = {
            'models_directory': info['path'],
            'total_size_mb': round(info['size_mb'], 2),
            'model_count': info['model_count'],
        }
        return payload
    except Exception as e:
        logger.error(f'Failed to get models directory info: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

265
app/routers/upscale.py Normal file
View File

@@ -0,0 +1,265 @@
"""Upscaling endpoints."""
import json
import logging
from time import time
from typing import List, Optional
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse
from app.schemas.upscale import UpscaleOptions
from app.services import file_manager, realesrgan_bridge, worker
from app.services.model_manager import is_model_available
logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/v1', tags=['upscaling'])
@router.post('/upscale')
async def upscale_sync(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Synchronous image upscaling.
    Upscales an image using Real-ESRGAN and returns the result directly.
    Suitable for small to medium images.
    """
    # NOTE(review): tile_size and tile_pad are accepted but never forwarded to
    # bridge.upscale() below — presumably the bridge falls back to its
    # configured defaults; confirm whether these form fields should be wired
    # through.
    request_dir = file_manager.create_request_dir()
    try:
        # Validate model
        if not is_model_available(model):
            raise HTTPException(
                status_code=400,
                detail=f'Model not available: {model}. Download it first using /api/v1/models/download'
            )
        # Save upload
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)
        # Process
        start_time = time()
        bridge = realesrgan_bridge.get_bridge()
        success, message, output_size = bridge.upscale(
            input_path=input_path,
            output_path=output_path,
            model_name=model,
            outscale=outscale,
        )
        if not success:
            raise HTTPException(status_code=500, detail=message)
        processing_time = time() - start_time
        logger.info(f'Sync upscaling completed in {processing_time:.2f}s')
        # The output file lives in the output directory (not request_dir), so
        # the cleanup in `finally` does not delete it before streaming.
        return FileResponse(
            path=output_path,
            media_type='application/octet-stream',
            filename=image.filename,
            headers={'X-Processing-Time': f'{processing_time:.2f}'},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f'Upscaling failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the uploaded input directory for this request.
        file_manager.cleanup_directory(request_dir)
@router.post('/upscale-batch')
async def upscale_batch(
    images: List[UploadFile] = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
):
    """
    Batch upscaling via async jobs.
    Submit multiple images for upscaling. Returns job IDs for monitoring.

    Raises 400 for an empty/oversized batch or unavailable model, 500 when
    saving uploads or submitting jobs fails.
    """
    if not images:
        raise HTTPException(status_code=400, detail='No images provided')
    if len(images) > 100:
        raise HTTPException(status_code=400, detail='Maximum 100 images per request')
    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )
    request_dir = file_manager.create_request_dir()
    job_ids = []
    try:
        # Save all images
        input_paths = await file_manager.save_uploads(images, request_dir)
        wq = worker.get_worker_queue()
        # Submit jobs
        for input_path in input_paths:
            output_path = file_manager.generate_output_path(input_path, f'_upscaled_{model}')
            job_id = wq.submit_job(
                input_path=input_path,
                output_path=output_path,
                model=model,
                tile_size=tile_size,
                tile_pad=tile_pad,
            )
            job_ids.append(job_id)
        logger.info(f'Submitted batch of {len(job_ids)} upscaling jobs')
        return {
            'success': True,
            'job_ids': job_ids,
            'total': len(job_ids),
            'message': f'Batch processing started for {len(job_ids)} images',
        }
    except Exception as e:
        # FIX: previously the uploaded files leaked on failure. Remove the
        # request directory only when nothing was queued — once jobs exist,
        # the workers still need the input files inside it.
        if not job_ids:
            file_manager.cleanup_directory(request_dir)
        logger.error(f'Batch submission failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.post('/jobs')
async def create_job(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Create an async upscaling job.
    Submit a single image for asynchronous upscaling.
    Use /api/v1/jobs/{job_id} to check status and download result.
    """
    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )
    request_dir = file_manager.create_request_dir()
    job_id = None
    try:
        # Save upload
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)
        # Submit job
        wq = worker.get_worker_queue()
        job_id = wq.submit_job(
            input_path=input_path,
            output_path=output_path,
            model=model,
            tile_size=tile_size,
            tile_pad=tile_pad,
            outscale=outscale,
        )
        return {
            'success': True,
            'job_id': job_id,
            'status_url': f'/api/v1/jobs/{job_id}',
            'result_url': f'/api/v1/jobs/{job_id}/result',
        }
    except Exception as e:
        # FIX: previously the uploaded file leaked on failure. Clean up only
        # when no job was submitted — a submitted job still needs its input.
        if job_id is None:
            file_manager.cleanup_directory(request_dir)
        logger.error(f'Job creation failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get('/jobs/{job_id}')
async def get_job_status(job_id: str):
    """Get status of an upscaling job; 404 if unknown."""
    job = worker.get_worker_queue().get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')
    reported = ('job_id', 'status', 'model', 'created_at', 'started_at',
                'completed_at', 'processing_time_seconds', 'error')
    return {field: getattr(job, field) for field in reported}
@router.get('/jobs/{job_id}/result')
async def get_job_result(job_id: str):
    """Download result of a completed upscaling job.

    404 when the job or its output file is missing, 202 while queued or
    processing, 500 when the job failed, 400 for any other status.
    """
    # FIX: replaced the inline __import__('os') hack with a normal import
    # (os is not imported at this module's top level).
    import os
    wq = worker.get_worker_queue()
    job = wq.get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')
    if job.status == 'queued' or job.status == 'processing':
        raise HTTPException(
            status_code=202,
            detail=f'Job is still processing: {job.status}'
        )
    if job.status == 'failed':
        raise HTTPException(status_code=500, detail=f'Job failed: {job.error}')
    if job.status != 'completed':
        raise HTTPException(status_code=400, detail=f'Job status: {job.status}')
    if not job.output_path or not os.path.exists(job.output_path):
        raise HTTPException(status_code=404, detail='Result file not found')
    return FileResponse(
        path=job.output_path,
        media_type='application/octet-stream',
        filename=f'upscaled_{job_id}.png',
    )
@router.get('/jobs')
async def list_jobs(
    status: Optional[str] = None,
    limit: int = 100,
):
    """List all jobs, optionally filtered by status."""
    wq = worker.get_worker_queue()
    all_jobs = wq.get_all_jobs()
    selected = [
        {
            'job_id': j.job_id,
            'status': j.status,
            'model': j.model,
            'created_at': j.created_at,
            'processing_time_seconds': j.processing_time_seconds,
        }
        for j in all_jobs.values()
        if not status or j.status == status
    ]
    # Newest first, capped at `limit`.
    selected = sorted(selected, key=lambda entry: entry['created_at'], reverse=True)[:limit]
    return {
        'total': len(all_jobs),
        'returned': len(selected),
        'jobs': selected,
    }

1
app/schemas/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Pydantic schemas for request/response validation."""

37
app/schemas/health.py Normal file
View File

@@ -0,0 +1,37 @@
"""Schemas for health check and system information."""
from typing import Optional
from pydantic import BaseModel
class HealthResponse(BaseModel):
    """API health check response."""
    status: str            # 'ok' when the service is up
    version: str           # API version string
    uptime_seconds: float  # seconds since process start
    message: str           # human-readable status line
class SystemInfo(BaseModel):
    """System information returned by /api/v1/system."""
    status: str
    version: str
    uptime_seconds: float
    cpu_usage_percent: float
    memory_usage_percent: float
    disk_usage_percent: float
    gpu_available: bool
    # GPU memory figures are None when no CUDA device is detected.
    gpu_memory_mb: Optional[int] = None
    gpu_memory_used_mb: Optional[int] = None
    execution_providers: list   # decoded from settings (e.g. ['cpu'])
    models_dir_size_mb: float
    jobs_queue_length: int      # pending jobs in the worker queue
class RequestStats(BaseModel):
    """API request statistics returned by /api/v1/stats."""
    total_requests: int
    successful_requests: int
    failed_requests: int
    # 0.0 until at least one request has succeeded.
    average_processing_time_seconds: float
    total_images_processed: int

31
app/schemas/models.py Normal file
View File

@@ -0,0 +1,31 @@
"""Schemas for model management operations."""
from typing import List
from pydantic import BaseModel, Field
class ModelDownloadRequest(BaseModel):
    """Request to download models."""
    models: List[str] = Field(
        description='List of model names to download'
    )
    # NOTE(review): the download service resolves URLs from its own catalog;
    # confirm whether `provider` is actually honored anywhere.
    provider: str = Field(
        default='huggingface',
        description='Repository provider (huggingface, gdrive, etc.)'
    )
class ModelDownloadResponse(BaseModel):
    """Response from model download."""
    success: bool  # True only if every requested model downloaded
    message: str
    downloaded: List[str] = Field(default_factory=list)  # successfully fetched
    failed: List[str] = Field(default_factory=list)      # failed model names
    errors: dict = Field(default_factory=dict)           # model name -> error message
class ModelListResponse(BaseModel):
    """Response containing list of models."""
    available_models: List[dict]  # per-model detail dicts
    total_models: int             # size of the known-model catalog
    local_models: int             # how many are present on disk

57
app/schemas/upscale.py Normal file
View File

@@ -0,0 +1,57 @@
"""Schemas for upscaling operations."""
from typing import Optional
from pydantic import BaseModel, Field
class UpscaleOptions(BaseModel):
    """Options for image upscaling (form-level parameters)."""
    model: str = Field(
        default='RealESRGAN_x4plus',
        description='Model to use for upscaling (RealESRGAN_x2plus, RealESRGAN_x3plus, RealESRGAN_x4plus, etc.)'
    )
    tile_size: Optional[int] = Field(
        default=None,
        description='Tile size for processing large images to avoid OOM'
    )
    tile_pad: Optional[int] = Field(
        default=None,
        description='Padding between tiles'
    )
    outscale: Optional[float] = Field(
        default=None,
        description='Upsampling scale factor'
    )
class JobStatus(BaseModel):
    """Job status information."""
    job_id: str
    status: str  # queued, processing, completed, failed
    model: str
    progress: float = Field(default=0.0, description='Progress as percentage 0-100')
    result_path: Optional[str] = None  # set once the job completes
    error: Optional[str] = None        # set only for failed jobs
    # Timestamps are ISO-8601 strings.
    created_at: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    processing_time_seconds: Optional[float] = None
class UpscaleResult(BaseModel):
    """Upscaling result."""
    success: bool
    message: str
    processing_time_seconds: float
    model: str
    # Image dimensions as (height, width); None when unknown.
    input_size: Optional[tuple] = None
    output_size: Optional[tuple] = None
class ModelInfo(BaseModel):
    """Information about an available model."""
    name: str
    scale: int       # native upscaling factor (2/3/4)
    path: str        # expected on-disk location of the weights
    size_mb: float
    available: bool  # True when the weights file exists locally

1
app/services/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""API services."""

View File

@@ -0,0 +1,126 @@
"""File management utilities."""
import logging
import os
import shutil
import uuid
from typing import List, Tuple
from fastapi import UploadFile
from app.config import settings
logger = logging.getLogger(__name__)
def ensure_directories() -> None:
    """Ensure all required data directories exist (idempotent)."""
    required = [
        settings.upload_dir,
        settings.output_dir,
        settings.models_dir,
        settings.temp_dir,
        settings.jobs_dir,
    ]
    for path in required:
        os.makedirs(path, exist_ok=True)
        logger.info(f'Directory ensured: {path}')
def create_request_dir() -> str:
    """Create and return a unique per-request upload directory."""
    request_dir = os.path.join(settings.upload_dir, str(uuid.uuid4()))
    os.makedirs(request_dir, exist_ok=True)
    return request_dir
async def save_upload(file: UploadFile, directory: str) -> str:
    """Stream an uploaded file to `directory` in 1 MiB chunks.

    The stored name is a fresh UUID with the upload's extension
    (defaulting to '.jpg' when none is present). Returns the saved path.
    """
    extension = os.path.splitext(file.filename or '')[1] or '.jpg'
    filepath = os.path.join(directory, f'{uuid.uuid4()}{extension}')
    with open(filepath, 'wb') as sink:
        chunk = await file.read(1024 * 1024)
        while chunk:
            sink.write(chunk)
            chunk = await file.read(1024 * 1024)
    logger.debug(f'File saved: {filepath}')
    return filepath
async def save_uploads(files: List[UploadFile], directory: str) -> List[str]:
    """Save multiple uploaded files sequentially; returns their paths in order."""
    return [await save_upload(upload, directory) for upload in files]
def generate_output_path(input_path: str, suffix: str = '_upscaled') -> str:
    """Build the output-directory path for a processed image.

    Keeps the input's base name and extension, inserting `suffix` between.
    """
    stem, extension = os.path.splitext(input_path)
    basename = os.path.basename(stem)
    return os.path.join(settings.output_dir, f'{basename}{suffix}{extension}')
def cleanup_directory(directory: str) -> None:
    """Remove a directory tree, ignoring errors; no-op if it doesn't exist."""
    if not os.path.isdir(directory):
        return
    shutil.rmtree(directory, ignore_errors=True)
    logger.debug(f'Cleaned up directory: {directory}')
def cleanup_file(filepath: str) -> None:
    """Remove a single file; no-op if it doesn't exist."""
    if not os.path.isfile(filepath):
        return
    os.remove(filepath)
    logger.debug(f'Cleaned up file: {filepath}')
def get_directory_size_mb(directory: str) -> float:
    """Return the recursive size of `directory` in MiB (0.0 if missing/empty).

    Skips entries that vanish between listing and stat (the exists() check).
    """
    total_bytes = sum(
        os.path.getsize(os.path.join(dirpath, name))
        for dirpath, _dirnames, filenames in os.walk(directory)
        for name in filenames
        if os.path.exists(os.path.join(dirpath, name))
    )
    return total_bytes / (1024 * 1024)
def list_model_files() -> List[Tuple[str, str, int]]:
    """Return list of (name, path, size_bytes) for all .pth/.onnx files in models dir."""
    found: List[Tuple[str, str, int]] = []
    models_dir = settings.models_dir
    if not os.path.isdir(models_dir):
        return found
    weight_extensions = ('.pth', '.onnx', '.pt', '.safetensors')
    for name in sorted(os.listdir(models_dir)):
        if not name.endswith(weight_extensions):
            continue
        path = os.path.join(models_dir, name)
        try:
            found.append((name, path, os.path.getsize(path)))
        except OSError:
            logger.warning(f'Could not get size of model: {path}')
    return found
def cleanup_old_jobs(hours: int = 24) -> int:
    """Remove job directories whose mtime is older than `hours` hours.

    Returns the number of directories removed; missing jobs dir yields 0.
    """
    import time
    cutoff_time = time.time() - hours * 3600
    cleaned = 0
    jobs_root = settings.jobs_dir
    if not os.path.isdir(jobs_root):
        return cleaned
    for entry in os.listdir(jobs_root):
        entry_path = os.path.join(jobs_root, entry)
        if not os.path.isdir(entry_path):
            continue
        try:
            expired = os.path.getmtime(entry_path) < cutoff_time
        except OSError:
            continue  # directory vanished or is unreadable — skip it
        if expired:
            cleanup_directory(entry_path)
            cleaned += 1
    if cleaned > 0:
        logger.info(f'Cleaned up {cleaned} old job directories')
    return cleaned

View File

@@ -0,0 +1,154 @@
"""Model management utilities."""
import json
import logging
import os
from typing import Dict, List, Optional
from app.config import settings
from app.services import file_manager
logger = logging.getLogger(__name__)
# Known models for Easy Real-ESRGAN
# Catalog of downloadable weight files: native scale factor, short blurb,
# official release URL, and approximate download size in MB (used for
# display until the real on-disk size is known).
KNOWN_MODELS = {
    'RealESRGAN_x2plus': {
        'scale': 2,
        'description': '2x upscaling',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        'size_mb': 66.7,
    },
    'RealESRGAN_x3plus': {
        'scale': 3,
        'description': '3x upscaling',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x3plus.pth',
        'size_mb': 101.7,
    },
    'RealESRGAN_x4plus': {
        'scale': 4,
        'description': '4x upscaling (general purpose)',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth',
        'size_mb': 101.7,
    },
    'RealESRGAN_x4plus_anime_6B': {
        'scale': 4,
        'description': '4x upscaling (anime/art)',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2/RealESRGAN_x4plus_anime_6B.pth',
        'size_mb': 18.9,
    },
}
def get_available_models() -> List[Dict]:
    """Describe every catalog model, marking which are present on disk.

    When a weights file exists locally, its actual size replaces the
    catalog's approximate size.
    """
    catalog = []
    for name, meta in KNOWN_MODELS.items():
        weights_path = os.path.join(settings.models_dir, f'{name}.pth')
        present = os.path.exists(weights_path)
        reported_mb = meta['size_mb']
        if present:
            try:
                reported_mb = os.path.getsize(weights_path) / (1024 * 1024)
            except OSError:
                pass  # keep the catalog estimate
        catalog.append({
            'name': name,
            'scale': meta['scale'],
            'description': meta['description'],
            'available': present,
            'size_mb': reported_mb,
            'size_bytes': int(reported_mb * 1024 * 1024),
        })
    return catalog
def is_model_available(model_name: str) -> bool:
    """Return True when the model's .pth file exists in the models directory."""
    return os.path.exists(os.path.join(settings.models_dir, f'{model_name}.pth'))
def get_model_scale(model_name: str) -> Optional[int]:
    """Return the upscaling factor for a model, or None if undeterminable.

    Catalog entries win; otherwise the factor is inferred from an 'x2'/'x3'/'x4'
    token in the name (checked in that order).
    """
    known = KNOWN_MODELS.get(model_name)
    if known is not None:
        return known['scale']
    lowered = model_name.lower()
    for token, factor in (('x2', 2), ('x3', 3), ('x4', 4)):
        if token in lowered:
            return factor
    return None
async def download_model(model_name: str) -> tuple[bool, str]:
    """
    Download a model.
    Returns: (success, message)
    """
    if model_name not in KNOWN_MODELS:
        return False, f'Unknown model: {model_name}'
    if is_model_available(model_name):
        return True, f'Model already available: {model_name}'
    metadata = KNOWN_MODELS[model_name]
    url = metadata['url']
    model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
    try:
        logger.info(f'Downloading model: {model_name} from {url}')
        import asyncio
        import urllib.request
        os.makedirs(settings.models_dir, exist_ok=True)
        def download_progress(count, block_size, total_size):
            # FIX: urlretrieve passes total_size=-1 when the server omits
            # Content-Length; guard before dividing.
            if total_size > 0:
                percent = min(count * block_size * 100 / total_size, 100)
                logger.debug(f'Download progress: {percent:.1f}%')
        # FIX: urlretrieve is blocking; run it in a worker thread so this
        # coroutine does not stall the event loop for the whole download.
        await asyncio.to_thread(urllib.request.urlretrieve, url, model_path, download_progress)
        if os.path.exists(model_path):
            size_mb = os.path.getsize(model_path) / (1024 * 1024)
            logger.info(f'Model downloaded: {model_name} ({size_mb:.1f} MB)')
            return True, f'Model downloaded: {model_name} ({size_mb:.1f} MB)'
        else:
            return False, f'Failed to download model: {model_name}'
    except Exception as e:
        logger.error(f'Failed to download model {model_name}: {e}', exc_info=True)
        return False, f'Download failed: {str(e)}'
async def download_models(model_names: List[str]) -> Dict[str, tuple[bool, str]]:
    """Download models one at a time; returns name -> (success, message)."""
    results: Dict[str, tuple[bool, str]] = {}
    for model_name in model_names:
        outcome = await download_model(model_name)
        results[model_name] = outcome
        success, message = outcome
        logger.info(f'Download result for {model_name}: {success} - {message}')
    return results
def get_models_directory_info() -> Dict:
    """Summarize the models directory: path, total size, and .pth count."""
    models_dir = settings.models_dir
    size_mb = file_manager.get_directory_size_mb(models_dir)
    if os.path.isdir(models_dir):
        pth_count = sum(1 for entry in os.listdir(models_dir) if entry.endswith('.pth'))
    else:
        pth_count = 0
    return {
        'path': models_dir,
        'size_mb': size_mb,
        'model_count': pth_count,
    }

View File

@@ -0,0 +1,200 @@
"""Real-ESRGAN model management and processing."""
import logging
import os
from typing import Optional, Tuple
import cv2
import numpy as np
from app.config import settings
logger = logging.getLogger(__name__)
# Real-ESRGAN is an optional dependency: the API can start (and report
# not-ready) without it, so an import failure is downgraded to a flag that
# the bridge checks before every operation.
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    REALESRGAN_AVAILABLE = True
except ImportError:
    REALESRGAN_AVAILABLE = False
    logger.warning('Real-ESRGAN not available. Install via: pip install realesrgan')
class RealESRGANBridge:
"""Bridge to Real-ESRGAN functionality."""
def __init__(self):
"""Initialize the Real-ESRGAN bridge."""
self.upsampler: Optional[RealESRGANer] = None
self.current_model: Optional[str] = None
self.initialized = False
def initialize(self) -> bool:
"""Initialize Real-ESRGAN upsampler."""
if not REALESRGAN_AVAILABLE:
logger.error('Real-ESRGAN library not available')
return False
try:
logger.info('Initializing Real-ESRGAN upsampler...')
# Setup model loader
scale = 4
model_name = settings.default_model
# Determine model path
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
if not os.path.exists(model_path):
logger.warning(f'Model not found at {model_path}, will attempt to auto-download')
# Load model
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale
)
self.upsampler = RealESRGANer(
scale=scale,
model_path=model_path if os.path.exists(model_path) else None,
model=model,
tile=settings.tile_size,
tile_pad=settings.tile_pad,
pre_pad=0,
half=('cuda' in settings.get_execution_providers()),
)
self.current_model = model_name
self.initialized = True
logger.info(f'Real-ESRGAN initialized with model: {model_name}')
return True
except Exception as e:
logger.error(f'Failed to initialize Real-ESRGAN: {e}', exc_info=True)
return False
def load_model(self, model_name: str) -> bool:
"""Load a specific upscaling model."""
try:
if not REALESRGAN_AVAILABLE:
logger.error('Real-ESRGAN not available')
return False
logger.info(f'Loading model: {model_name}')
# Extract scale from model name
scale = 4
if 'x2' in model_name.lower():
scale = 2
elif 'x3' in model_name.lower():
scale = 3
elif 'x4' in model_name.lower():
scale = 4
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
if not os.path.exists(model_path):
logger.error(f'Model file not found: {model_path}')
return False
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale
)
self.upsampler = RealESRGANer(
scale=scale,
model_path=model_path,
model=model,
tile=settings.tile_size,
tile_pad=settings.tile_pad,
pre_pad=0,
half=('cuda' in settings.get_execution_providers()),
)
self.current_model = model_name
logger.info(f'Model loaded: {model_name}')
return True
except Exception as e:
logger.error(f'Failed to load model {model_name}: {e}', exc_info=True)
return False
def upscale(
self,
input_path: str,
output_path: str,
model_name: Optional[str] = None,
outscale: Optional[float] = None,
) -> Tuple[bool, str, Optional[Tuple[int, int]]]:
"""
Upscale an image.
Returns: (success, message, output_size)
"""
try:
if not self.initialized:
if not self.initialize():
return False, 'Failed to initialize Real-ESRGAN', None
if model_name and model_name != self.current_model:
if not self.load_model(model_name):
return False, f'Failed to load model: {model_name}', None
if not self.upsampler:
return False, 'Upsampler not initialized', None
# Read image
logger.info(f'Reading image: {input_path}')
input_img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
if input_img is None:
return False, f'Failed to read image: {input_path}', None
input_shape = input_img.shape[:2]
logger.info(f'Input image shape: {input_shape}')
# Upscale
logger.info(f'Upscaling with model: {self.current_model}')
output, _ = self.upsampler.enhance(input_img, outscale=outscale or 4)
# Save output
cv2.imwrite(str(output_path), output)
output_shape = output.shape[:2]
logger.info(f'Output image shape: {output_shape}')
logger.info(f'Upscaled image saved: {output_path}')
return True, 'Upscaling completed successfully', tuple(output_shape)
except Exception as e:
logger.error(f'Upscaling failed: {e}', exc_info=True)
return False, f'Upscaling failed: {str(e)}', None
def get_upscale_factor(self) -> int:
    """Return the scale factor of the loaded upsampler, or 4 when none is loaded."""
    return self.upsampler.scale if self.upsampler else 4
def clear_memory(self) -> None:
    """Best-effort release of cached GPU memory; all failures are ignored."""
    try:
        import torch
        torch.cuda.empty_cache()
        logger.debug('GPU memory cleared')
    except Exception:
        pass  # torch missing or built without CUDA -- nothing to clear
# Global instance (process-wide singleton, created lazily)
_bridge: Optional[RealESRGANBridge] = None


def get_bridge() -> RealESRGANBridge:
    """Return the shared Real-ESRGAN bridge, constructing it on first call."""
    global _bridge
    if _bridge is None:
        _bridge = RealESRGANBridge()
    return _bridge

217
app/services/worker.py Normal file
View File

@@ -0,0 +1,217 @@
"""Async job worker queue."""
import json
import logging
import os
import threading
import time
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime
from queue import Empty, Queue
from typing import Callable, Dict, Optional

from app.config import settings
logger = logging.getLogger(__name__)
@dataclass
class Job:
    """Metadata and state for one asynchronous upscaling job."""
    job_id: str
    status: str  # queued, processing, completed, failed
    input_path: str
    output_path: str
    model: str
    tile_size: Optional[int] = None
    tile_pad: Optional[int] = None
    outscale: Optional[float] = None
    created_at: str = ''  # ISO-8601 timestamp; auto-filled in __post_init__
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    processing_time_seconds: Optional[float] = None
    error: Optional[str] = None

    def __post_init__(self):
        # Stamp creation time unless the caller (or load_metadata) supplied one.
        if not self.created_at:
            self.created_at = datetime.utcnow().isoformat()

    def to_dict(self) -> dict:
        """Convert to a plain JSON-serializable dictionary."""
        return asdict(self)

    def save_metadata(self) -> None:
        """Save job metadata to <jobs_dir>/<job_id>/metadata.json."""
        job_dir = os.path.join(settings.jobs_dir, self.job_id)
        os.makedirs(job_dir, exist_ok=True)
        metadata_path = os.path.join(job_dir, 'metadata.json')
        # Explicit encoding: the platform default differs across systems.
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load_metadata(cls, job_id: str) -> Optional['Job']:
        """Load job metadata from disk; return None if missing or unreadable."""
        metadata_path = os.path.join(settings.jobs_dir, job_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            return None
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return cls(**data)
        except Exception as e:
            logger.error(f'Failed to load job metadata: {e}')
            return None
class WorkerQueue:
    """Thread pool worker queue for processing jobs."""

    def __init__(self, worker_func: Callable, num_workers: int = 2):
        """
        Initialize worker queue.

        Args:
            worker_func: Function to process jobs (job: Job) -> None
            num_workers: Number of worker threads
        """
        self.queue: Queue = Queue()
        self.worker_func = worker_func
        self.num_workers = num_workers
        self.workers = []
        self.running = False
        self.jobs: Dict[str, 'Job'] = {}
        # Protects self.jobs and per-job state mutation across threads.
        self.lock = threading.Lock()

    def start(self) -> None:
        """Start worker threads (no-op if already running)."""
        if self.running:
            return
        self.running = True
        for _ in range(self.num_workers):
            worker = threading.Thread(target=self._worker_loop, daemon=True)
            worker.start()
            self.workers.append(worker)
        logger.info(f'Started {self.num_workers} worker threads')

    def stop(self, timeout: int = 10) -> None:
        """Stop worker threads gracefully.

        Args:
            timeout: Seconds to wait for each worker thread to join.
        """
        self.running = False
        # One None sentinel per worker wakes each blocked queue.get().
        for _ in range(self.num_workers):
            self.queue.put(None)
        # Wait for workers to finish
        for worker in self.workers:
            worker.join(timeout=timeout)
        logger.info('Worker threads stopped')

    def submit_job(
        self,
        input_path: str,
        output_path: str,
        model: str,
        tile_size: Optional[int] = None,
        tile_pad: Optional[int] = None,
        outscale: Optional[float] = None,
    ) -> str:
        """
        Submit a job for processing.

        Returns: job_id
        """
        job_id = str(uuid.uuid4())
        job = Job(
            job_id=job_id,
            status='queued',
            input_path=input_path,
            output_path=output_path,
            model=model,
            tile_size=tile_size,
            tile_pad=tile_pad,
            outscale=outscale,
        )
        with self.lock:
            self.jobs[job_id] = job
        job.save_metadata()
        self.queue.put(job)
        logger.info(f'Job submitted: {job_id}')
        return job_id

    def get_job(self, job_id: str) -> Optional['Job']:
        """Get job by ID (None if unknown)."""
        with self.lock:
            return self.jobs.get(job_id)

    def get_all_jobs(self) -> Dict[str, 'Job']:
        """Get a snapshot copy of all tracked jobs."""
        with self.lock:
            return dict(self.jobs)

    def _worker_loop(self) -> None:
        """Worker thread main loop; exits on stop sentinel or running=False."""
        logger.info('Worker thread started')
        while self.running:
            try:
                job = self.queue.get(timeout=1)
            except Empty:
                # Periodic wake-up so the loop can re-check self.running.
                continue
            if job is None:  # Stop signal
                break
            try:
                self._process_job(job)
            except Exception:
                # _process_job handles its own errors; this guard only keeps
                # an unexpected failure from killing the worker thread.
                logger.exception('Unexpected error while processing job')
        logger.info('Worker thread stopped')

    def _process_job(self, job: 'Job') -> None:
        """Process a single job, updating its status and persisted metadata."""
        try:
            with self.lock:
                job.status = 'processing'
                job.started_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job
            job.save_metadata()
            start_time = time.time()
            self.worker_func(job)
            job.processing_time_seconds = time.time() - start_time
            with self.lock:
                job.status = 'completed'
                job.completed_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job
            logger.info(f'Job completed: {job.job_id} ({job.processing_time_seconds:.2f}s)')
        except Exception as e:
            logger.error(f'Job failed: {job.job_id}: {e}', exc_info=True)
            with self.lock:
                job.status = 'failed'
                job.error = str(e)
                job.completed_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job
        # Persist the terminal state (completed or failed) in both paths.
        job.save_metadata()
# Global instance (process-wide singleton, created lazily)
_worker_queue: Optional['WorkerQueue'] = None


def get_worker_queue(worker_func: Optional[Callable] = None, num_workers: int = 2) -> 'WorkerQueue':
    """Get or create the global worker queue.

    Args:
        worker_func: Job-processing callable; required only on the first call.
        num_workers: Worker thread count, used only when the queue is created.

    Raises:
        ValueError: if called before initialization without a worker_func.
    """
    global _worker_queue
    if _worker_queue is None:
        if worker_func is None:
            raise ValueError('worker_func required for first initialization')
        _worker_queue = WorkerQueue(worker_func, num_workers)
    return _worker_queue