Initial Real-ESRGAN API project setup

This commit is contained in:
Developer
2026-02-16 19:56:25 +01:00
commit 0e59652575
34 changed files with 3668 additions and 0 deletions

1
app/services/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""API services."""

View File

@@ -0,0 +1,126 @@
"""File management utilities."""
import logging
import os
import shutil
import uuid
from typing import List, Tuple
from fastapi import UploadFile
from app.config import settings
logger = logging.getLogger(__name__)
def ensure_directories() -> None:
    """Ensure all required directories exist."""
    required = [
        settings.upload_dir,
        settings.output_dir,
        settings.models_dir,
        settings.temp_dir,
        settings.jobs_dir,
    ]
    for directory in required:
        # exist_ok makes repeated startup calls idempotent.
        os.makedirs(directory, exist_ok=True)
        logger.info(f'Directory ensured: {directory}')
def create_request_dir() -> str:
    """Create a unique request directory."""
    unique_id = uuid.uuid4()
    path = os.path.join(settings.upload_dir, str(unique_id))
    os.makedirs(path, exist_ok=True)
    return path
async def save_upload(file: UploadFile, directory: str) -> str:
    """Save uploaded file to directory."""
    original_name = file.filename or ''
    # Keep the client's extension; fall back to .jpg when there is none.
    extension = os.path.splitext(original_name)[1] or '.jpg'
    destination = os.path.join(directory, f'{uuid.uuid4()}{extension}')
    with open(destination, 'wb') as out:
        # Stream in 1 MiB chunks so large uploads never sit fully in memory.
        chunk = await file.read(1024 * 1024)
        while chunk:
            out.write(chunk)
            chunk = await file.read(1024 * 1024)
    logger.debug(f'File saved: {destination}')
    return destination
async def save_uploads(files: List[UploadFile], directory: str) -> List[str]:
    """Save multiple uploaded files to directory."""
    # Files are saved sequentially, preserving input order in the result.
    return [await save_upload(upload, directory) for upload in files]
def generate_output_path(input_path: str, suffix: str = '_upscaled') -> str:
    """Generate output path for processed image."""
    stem, extension = os.path.splitext(input_path)
    output_name = os.path.basename(stem) + suffix + extension
    return os.path.join(settings.output_dir, output_name)
def cleanup_directory(directory: str) -> None:
    """Remove directory and all contents."""
    if not os.path.isdir(directory):
        return
    # ignore_errors: best-effort cleanup, never raises on partial failure.
    shutil.rmtree(directory, ignore_errors=True)
    logger.debug(f'Cleaned up directory: {directory}')
def cleanup_file(filepath: str) -> None:
    """Remove a file."""
    if not os.path.isfile(filepath):
        return
    os.remove(filepath)
    logger.debug(f'Cleaned up file: {filepath}')
def get_directory_size_mb(directory: str) -> float:
    """Get total size of directory in MB."""
    total_bytes = 0
    for root, _dirs, names in os.walk(directory):
        # Re-check existence: entries can vanish between walk and stat.
        total_bytes += sum(
            os.path.getsize(os.path.join(root, name))
            for name in names
            if os.path.exists(os.path.join(root, name))
        )
    return total_bytes / (1024 * 1024)
def list_model_files() -> List[Tuple[str, str, int]]:
    """Return list of (name, path, size_bytes) for all .pth/.onnx files in models dir."""
    results: List[Tuple[str, str, int]] = []
    models_dir = settings.models_dir
    if not os.path.isdir(models_dir):
        return results
    weight_extensions = ('.pth', '.onnx', '.pt', '.safetensors')
    for entry in sorted(os.listdir(models_dir)):
        if not entry.endswith(weight_extensions):
            continue
        full_path = os.path.join(models_dir, entry)
        try:
            results.append((entry, full_path, os.path.getsize(full_path)))
        except OSError:
            # File disappeared or is unreadable; skip it but keep going.
            logger.warning(f'Could not get size of model: {full_path}')
    return results
def cleanup_old_jobs(hours: int = 24) -> int:
    """Clean up old job directories (older than specified hours)."""
    import time
    cutoff = time.time() - hours * 3600
    removed = 0
    jobs_root = settings.jobs_dir
    if not os.path.isdir(jobs_root):
        return removed
    for entry in os.listdir(jobs_root):
        candidate = os.path.join(jobs_root, entry)
        if not os.path.isdir(candidate):
            continue
        try:
            is_old = os.path.getmtime(candidate) < cutoff
        except OSError:
            # Directory vanished between listdir and stat; ignore it.
            continue
        if is_old:
            cleanup_directory(candidate)
            removed += 1
    if removed > 0:
        logger.info(f'Cleaned up {removed} old job directories')
    return removed

View File

@@ -0,0 +1,154 @@
"""Model management utilities."""
import json
import logging
import os
from typing import Dict, List, Optional
from app.config import settings
from app.services import file_manager
logger = logging.getLogger(__name__)
# Known models for Easy Real-ESRGAN
# Maps model name -> metadata: native upscale factor ('scale'), a short
# human-readable description, the download URL (official Real-ESRGAN GitHub
# releases), and 'size_mb' — an approximate size used as a fallback by
# get_available_models() when the weight file is not present locally.
KNOWN_MODELS = {
    'RealESRGAN_x2plus': {
        'scale': 2,
        'description': '2x upscaling',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        'size_mb': 66.7,
    },
    'RealESRGAN_x3plus': {
        'scale': 3,
        'description': '3x upscaling',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x3plus.pth',
        'size_mb': 101.7,
    },
    'RealESRGAN_x4plus': {
        'scale': 4,
        'description': '4x upscaling (general purpose)',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth',
        'size_mb': 101.7,
    },
    'RealESRGAN_x4plus_anime_6B': {
        'scale': 4,
        'description': '4x upscaling (anime/art)',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2/RealESRGAN_x4plus_anime_6B.pth',
        'size_mb': 18.9,
    },
}
def get_available_models() -> List[Dict]:
    """Get list of available models with details."""
    catalog: List[Dict] = []
    for name, meta in KNOWN_MODELS.items():
        weight_path = os.path.join(settings.models_dir, f'{name}.pth')
        is_present = os.path.exists(weight_path)
        # Prefer the real on-disk size; fall back to the catalog estimate.
        reported_size = meta['size_mb']
        if is_present:
            try:
                reported_size = os.path.getsize(weight_path) / (1024 * 1024)
            except OSError:
                pass
        catalog.append({
            'name': name,
            'scale': meta['scale'],
            'description': meta['description'],
            'available': is_present,
            'size_mb': reported_size,
            'size_bytes': int(reported_size * 1024 * 1024),
        })
    return catalog
def is_model_available(model_name: str) -> bool:
    """Check if a model is available locally."""
    expected = os.path.join(settings.models_dir, model_name + '.pth')
    return os.path.exists(expected)
def get_model_scale(model_name: str) -> Optional[int]:
    """Get upscaling factor for a model."""
    known = KNOWN_MODELS.get(model_name)
    if known is not None:
        return known['scale']
    # Unknown model: infer the factor from an x2/x3/x4 marker in the name.
    lowered = model_name.lower()
    for factor in (2, 3, 4):
        if f'x{factor}' in lowered:
            return factor
    return None
async def download_model(model_name: str) -> tuple[bool, str]:
    """
    Download a model.

    Args:
        model_name: Key into KNOWN_MODELS naming the weight file to fetch.

    Returns: (success, message)
    """
    if model_name not in KNOWN_MODELS:
        return False, f'Unknown model: {model_name}'
    if is_model_available(model_name):
        return True, f'Model already available: {model_name}'
    metadata = KNOWN_MODELS[model_name]
    url = metadata['url']
    model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
    try:
        logger.info(f'Downloading model: {model_name} from {url}')
        # NOTE(review): urlretrieve blocks the event loop; consider running
        # in a thread executor if this is called from request handlers.
        import urllib.request
        os.makedirs(settings.models_dir, exist_ok=True)

        def download_progress(count, block_size, total_size):
            # urlretrieve passes total_size <= 0 when the server omits
            # Content-Length; guard the division to avoid ZeroDivisionError
            # aborting the whole download.
            if total_size > 0:
                percent = min(count * block_size * 100 / total_size, 100)
                logger.debug(f'Download progress: {percent:.1f}%')

        urllib.request.urlretrieve(url, model_path, download_progress)
        if os.path.exists(model_path):
            size_mb = os.path.getsize(model_path) / (1024 * 1024)
            logger.info(f'Model downloaded: {model_name} ({size_mb:.1f} MB)')
            return True, f'Model downloaded: {model_name} ({size_mb:.1f} MB)'
        return False, f'Failed to download model: {model_name}'
    except Exception as e:
        logger.error(f'Failed to download model {model_name}: {e}', exc_info=True)
        # Remove any truncated partial file so it doesn't later pass the
        # is_model_available() existence check and get loaded as a model.
        try:
            if os.path.exists(model_path):
                os.remove(model_path)
        except OSError:
            pass
        return False, f'Download failed: {str(e)}'
async def download_models(model_names: List[str]) -> Dict[str, tuple[bool, str]]:
    """Download multiple models."""
    outcome: Dict[str, tuple[bool, str]] = {}
    for model_name in model_names:
        success, message = await download_model(model_name)
        outcome[model_name] = (success, message)
        logger.info(f'Download result for {model_name}: {success} - {message}')
    return outcome
def get_models_directory_info() -> Dict:
    """Get information about the models directory."""
    models_dir = settings.models_dir
    if os.path.isdir(models_dir):
        model_count = sum(1 for entry in os.listdir(models_dir) if entry.endswith('.pth'))
    else:
        model_count = 0
    return {
        'path': models_dir,
        'size_mb': file_manager.get_directory_size_mb(models_dir),
        'model_count': model_count,
    }

View File

@@ -0,0 +1,200 @@
"""Real-ESRGAN model management and processing."""
import logging
import os
from typing import Optional, Tuple
import cv2
import numpy as np
from app.config import settings
logger = logging.getLogger(__name__)
# Optional dependency guard: the realesrgan/basicsr stack may be absent.
# When it is, REALESRGAN_AVAILABLE stays False and initialize() below
# refuses to run with a logged error instead of crashing at import time.
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    REALESRGAN_AVAILABLE = True
except ImportError:
    REALESRGAN_AVAILABLE = False
    logger.warning('Real-ESRGAN not available. Install via: pip install realesrgan')
class RealESRGANBridge:
    """Bridge to Real-ESRGAN functionality.

    Wraps a single RealESRGANer instance; models are loaded from
    settings.models_dir as '<model_name>.pth' files.
    NOTE(review): there is no locking around upsampler creation/swap —
    presumably single-threaded use; confirm before sharing across threads.
    """

    def __init__(self):
        """Initialize the Real-ESRGAN bridge."""
        # Active upsampler; stays None until initialize()/load_model() succeeds.
        self.upsampler: Optional[RealESRGANer] = None
        # Name of the currently loaded model (also the .pth filename stem).
        self.current_model: Optional[str] = None
        # Set True by a successful initialize(); checked by upscale().
        self.initialized = False

    def initialize(self) -> bool:
        """Initialize Real-ESRGAN upsampler.

        Builds an RRDBNet for settings.default_model and wraps it in a
        RealESRGANer configured from settings (tile size/padding).

        Returns:
            True on success; False if the library is missing or setup failed.
        """
        if not REALESRGAN_AVAILABLE:
            logger.error('Real-ESRGAN library not available')
            return False
        try:
            logger.info('Initializing Real-ESRGAN upsampler...')
            # Setup model loader
            # NOTE(review): scale is hard-coded to 4 regardless of which model
            # settings.default_model names — confirm the default is a 4x model.
            scale = 4
            model_name = settings.default_model
            # Determine model path
            model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
            if not os.path.exists(model_path):
                # Presumably RealESRGANer fetches weights itself when
                # model_path is None (passed below) — TODO confirm.
                logger.warning(f'Model not found at {model_path}, will attempt to auto-download')
            # Load model
            # Hyper-parameters match the stock RealESRGAN_x4plus RRDBNet architecture.
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=scale
            )
            self.upsampler = RealESRGANer(
                scale=scale,
                model_path=model_path if os.path.exists(model_path) else None,
                model=model,
                tile=settings.tile_size,
                tile_pad=settings.tile_pad,
                pre_pad=0,
                # NOTE(review): list membership — True only if the provider list
                # contains an element exactly equal to 'cuda'; verify against what
                # settings.get_execution_providers() actually returns.
                half=('cuda' in settings.get_execution_providers()),
            )
            self.current_model = model_name
            self.initialized = True
            logger.info(f'Real-ESRGAN initialized with model: {model_name}')
            return True
        except Exception as e:
            logger.error(f'Failed to initialize Real-ESRGAN: {e}', exc_info=True)
            return False

    def load_model(self, model_name: str) -> bool:
        """Load a specific upscaling model.

        Args:
            model_name: Stem of a .pth file in settings.models_dir; the scale
                is inferred from an 'x2'/'x3'/'x4' substring (defaults to 4).

        Returns:
            True on success; False if the file is missing or loading failed.
        """
        try:
            if not REALESRGAN_AVAILABLE:
                logger.error('Real-ESRGAN not available')
                return False
            logger.info(f'Loading model: {model_name}')
            # Extract scale from model name
            scale = 4
            if 'x2' in model_name.lower():
                scale = 2
            elif 'x3' in model_name.lower():
                scale = 3
            elif 'x4' in model_name.lower():
                scale = 4
            model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
            if not os.path.exists(model_path):
                logger.error(f'Model file not found: {model_path}')
                return False
            # NOTE(review): always uses the 23-block x4plus hyper-parameters;
            # anime_6B variants use num_block=6 — confirm those models load here.
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=scale
            )
            self.upsampler = RealESRGANer(
                scale=scale,
                model_path=model_path,
                model=model,
                tile=settings.tile_size,
                tile_pad=settings.tile_pad,
                pre_pad=0,
                half=('cuda' in settings.get_execution_providers()),
            )
            self.current_model = model_name
            logger.info(f'Model loaded: {model_name}')
            return True
        except Exception as e:
            logger.error(f'Failed to load model {model_name}: {e}', exc_info=True)
            return False

    def upscale(
        self,
        input_path: str,
        output_path: str,
        model_name: Optional[str] = None,
        outscale: Optional[float] = None,
    ) -> Tuple[bool, str, Optional[Tuple[int, int]]]:
        """
        Upscale an image.

        Args:
            input_path: Source image path (read via cv2.imread).
            output_path: Destination path (written via cv2.imwrite).
            model_name: Optional model to switch to before processing.
            outscale: Final output scale passed to RealESRGANer.enhance.

        Returns: (success, message, output_size)
        """
        try:
            if not self.initialized:
                if not self.initialize():
                    return False, 'Failed to initialize Real-ESRGAN', None
            if model_name and model_name != self.current_model:
                if not self.load_model(model_name):
                    return False, f'Failed to load model: {model_name}', None
            if not self.upsampler:
                return False, 'Upsampler not initialized', None
            # Read image
            logger.info(f'Reading image: {input_path}')
            input_img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
            if input_img is None:
                return False, f'Failed to read image: {input_path}', None
            input_shape = input_img.shape[:2]
            logger.info(f'Input image shape: {input_shape}')
            # Upscale
            logger.info(f'Upscaling with model: {self.current_model}')
            # NOTE(review): falls back to 4 when outscale is None (or 0) even
            # for 2x/3x models — confirm this default is intended.
            output, _ = self.upsampler.enhance(input_img, outscale=outscale or 4)
            # Save output
            cv2.imwrite(str(output_path), output)
            output_shape = output.shape[:2]
            logger.info(f'Output image shape: {output_shape}')
            logger.info(f'Upscaled image saved: {output_path}')
            return True, 'Upscaling completed successfully', tuple(output_shape)
        except Exception as e:
            logger.error(f'Upscaling failed: {e}', exc_info=True)
            return False, f'Upscaling failed: {str(e)}', None

    def get_upscale_factor(self) -> int:
        """Get current upscaling factor (4 when no upsampler is loaded)."""
        if self.upsampler:
            return self.upsampler.scale
        return 4

    def clear_memory(self) -> None:
        """Clear GPU memory if available; silently a no-op without torch/CUDA."""
        try:
            import torch
            torch.cuda.empty_cache()
            logger.debug('GPU memory cleared')
        except Exception:
            pass
# Global instance (lazily created on first get_bridge() call)
_bridge: Optional[RealESRGANBridge] = None


def get_bridge() -> RealESRGANBridge:
    """Return the process-wide Real-ESRGAN bridge, creating it on first use."""
    global _bridge
    if _bridge is None:
        _bridge = RealESRGANBridge()
    return _bridge

217
app/services/worker.py Normal file
View File

@@ -0,0 +1,217 @@
"""Async job worker queue."""
import json
import logging
import os
import threading
import time
import uuid
from dataclasses import dataclass, asdict
from datetime import datetime
from queue import Queue
from typing import Callable, Dict, Optional
from app.config import settings
logger = logging.getLogger(__name__)
@dataclass
class Job:
    """Async job data."""

    job_id: str
    status: str  # queued, processing, completed, failed
    input_path: str
    output_path: str
    model: str
    tile_size: Optional[int] = None
    tile_pad: Optional[int] = None
    outscale: Optional[float] = None
    created_at: str = ''
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    processing_time_seconds: Optional[float] = None
    error: Optional[str] = None

    def __post_init__(self):
        # Stamp creation time only when the caller did not supply one
        # (load_metadata passes the stored value through).
        if not self.created_at:
            self.created_at = datetime.utcnow().isoformat()

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)

    def save_metadata(self) -> None:
        """Save job metadata to JSON file."""
        target_dir = os.path.join(settings.jobs_dir, self.job_id)
        os.makedirs(target_dir, exist_ok=True)
        with open(os.path.join(target_dir, 'metadata.json'), 'w') as fh:
            json.dump(self.to_dict(), fh, indent=2)

    @classmethod
    def load_metadata(cls, job_id: str) -> Optional['Job']:
        """Load job metadata from JSON file."""
        metadata_path = os.path.join(settings.jobs_dir, job_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            return None
        try:
            with open(metadata_path, 'r') as fh:
                return cls(**json.load(fh))
        except Exception as e:
            logger.error(f'Failed to load job metadata: {e}')
            return None
class WorkerQueue:
    """Thread pool worker queue for processing jobs.

    Jobs are queued via submit_job() and picked up by daemon worker threads.
    Job state is mirrored in self.jobs (guarded by self.lock) and persisted
    through Job.save_metadata().
    """

    def __init__(self, worker_func: Callable, num_workers: int = 2):
        """
        Initialize worker queue.

        Args:
            worker_func: Function to process jobs (job: Job) -> None
            num_workers: Number of worker threads
        """
        self.queue: Queue = Queue()
        self.worker_func = worker_func
        self.num_workers = num_workers
        self.workers = []
        self.running = False
        self.jobs: Dict[str, Job] = {}
        self.lock = threading.Lock()

    def start(self) -> None:
        """Start worker threads (no-op if already running)."""
        if self.running:
            return
        self.running = True
        for _ in range(self.num_workers):
            worker = threading.Thread(target=self._worker_loop, daemon=True)
            worker.start()
            self.workers.append(worker)
        logger.info(f'Started {self.num_workers} worker threads')

    def stop(self, timeout: int = 10) -> None:
        """Stop worker threads gracefully.

        Args:
            timeout: Seconds to wait for each worker thread to join.
        """
        self.running = False
        # One sentinel per worker so every blocked get() wakes up.
        for _ in range(self.num_workers):
            self.queue.put(None)
        # Wait for workers to finish
        for worker in self.workers:
            worker.join(timeout=timeout)
        logger.info('Worker threads stopped')

    def submit_job(
        self,
        input_path: str,
        output_path: str,
        model: str,
        tile_size: Optional[int] = None,
        tile_pad: Optional[int] = None,
        outscale: Optional[float] = None,
    ) -> str:
        """
        Submit a job for processing.

        Returns: job_id
        """
        job_id = str(uuid.uuid4())
        job = Job(
            job_id=job_id,
            status='queued',
            input_path=input_path,
            output_path=output_path,
            model=model,
            tile_size=tile_size,
            tile_pad=tile_pad,
            outscale=outscale,
        )
        with self.lock:
            self.jobs[job_id] = job
        # Persist outside the lock: file I/O should not hold up other threads.
        job.save_metadata()
        self.queue.put(job)
        logger.info(f'Job submitted: {job_id}')
        return job_id

    def get_job(self, job_id: str) -> Optional[Job]:
        """Get job by ID."""
        with self.lock:
            return self.jobs.get(job_id)

    def get_all_jobs(self) -> Dict[str, Job]:
        """Get all jobs (shallow copy of the registry)."""
        with self.lock:
            return dict(self.jobs)

    def _worker_loop(self) -> None:
        """Worker thread main loop: pull jobs until stopped."""
        # Local import: the file-level import only brings in Queue.
        from queue import Empty
        logger.info('Worker thread started')
        while self.running:
            try:
                job = self.queue.get(timeout=1)
            except Empty:
                # Timeout is normal — it lets us re-check self.running.
                # (Catching only Empty, instead of a bare `except Exception:
                # pass`, keeps real bugs from being silently swallowed.)
                continue
            if job is None:  # Stop signal from stop()
                break
            self._process_job(job)
        logger.info('Worker thread stopped')

    def _process_job(self, job: Job) -> None:
        """Process a single job, updating status/timestamps and persisting metadata."""
        try:
            with self.lock:
                job.status = 'processing'
                job.started_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job
            job.save_metadata()
            start_time = time.time()
            self.worker_func(job)
            job.processing_time_seconds = time.time() - start_time
            with self.lock:
                job.status = 'completed'
                job.completed_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job
            logger.info(f'Job completed: {job.job_id} ({job.processing_time_seconds:.2f}s)')
        except Exception as e:
            logger.error(f'Job failed: {job.job_id}: {e}', exc_info=True)
            with self.lock:
                job.status = 'failed'
                job.error = str(e)
                job.completed_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job
        finally:
            # Persist the terminal state (completed or failed) in both paths.
            job.save_metadata()
# Global instance (lazily created by get_worker_queue)
_worker_queue: Optional[WorkerQueue] = None


def get_worker_queue(worker_func: Optional[Callable] = None, num_workers: int = 2) -> WorkerQueue:
    """Get or create the global worker queue.

    Args:
        worker_func: Job-processing callable; required only on the first call.
        num_workers: Worker thread count; used only when the queue is created.

    Raises:
        ValueError: If called before initialization without a worker_func.
    """
    global _worker_queue
    if _worker_queue is None:
        if worker_func is None:
            raise ValueError('worker_func required for first initialization')
        _worker_queue = WorkerQueue(worker_func, num_workers)
    return _worker_queue