Initial Real-ESRGAN API project setup
This commit is contained in:
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Real-ESRGAN API application."""
|
||||
40
app/config.py
Normal file
40
app/config.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration loaded from environment variables.

    Every field can be overridden via an env var prefixed with ``RSR_``
    (e.g. ``RSR_UPLOAD_DIR``), per the pydantic-settings ``env_prefix``.
    """

    model_config = {'env_prefix': 'RSR_'}

    # Paths -- all default under a /data volume (container-oriented layout).
    upload_dir: str = '/data/uploads'
    output_dir: str = '/data/outputs'
    models_dir: str = '/data/models'
    temp_dir: str = '/data/temp'
    jobs_dir: str = '/data/jobs'

    # Real-ESRGAN defaults.
    # NOTE: list-valued settings are stored as JSON strings so they can be
    # supplied through a single environment variable; use the getters below.
    execution_providers: str = '["cpu"]'
    execution_thread_count: int = 4
    default_model: str = 'RealESRGAN_x4plus'
    auto_model_download: bool = True
    download_providers: str = '["huggingface"]'
    tile_size: int = 400
    tile_pad: int = 10
    log_level: str = 'info'

    # Limits.
    max_upload_size_mb: int = 500
    max_image_dimension: int = 8192
    sync_timeout_seconds: int = 300
    auto_cleanup_hours: int = 24

    def get_execution_providers(self) -> List[str]:
        """Decode the JSON-encoded execution provider list.

        Raises json.JSONDecodeError if the env value is not valid JSON.
        """
        return json.loads(self.execution_providers)

    def get_download_providers(self) -> List[str]:
        """Decode the JSON-encoded download provider list."""
        return json.loads(self.download_providers)


# Singleton settings instance shared by the whole application.
settings = Settings()
|
||||
111
app/main.py
Normal file
111
app/main.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Real-ESRGAN API application."""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# Ensure app is importable
|
||||
_app_path = os.path.dirname(__file__)
|
||||
if _app_path not in sys.path:
|
||||
sys.path.insert(0, _app_path)
|
||||
|
||||
from app.routers import health, models, upscale
|
||||
from app.services import file_manager, realesrgan_bridge, worker
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s %(levelname)s %(name)s: %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _process_upscale_job(job) -> None:
    """Worker function to process a single upscaling job.

    Runs on a worker thread. Raises on failure so the worker queue can
    mark the job as failed with the bridge's error message.
    """
    # Imported lazily so importing this module stays cheap.
    from app.services import realesrgan_bridge

    bridge = realesrgan_bridge.get_bridge()
    success, message, _ = bridge.upscale(
        input_path=job.input_path,
        output_path=job.output_path,
        model_name=job.model,
        outscale=job.outscale,
        # TODO(review): jobs are submitted with tile_size/tile_pad but they
        # are not forwarded here -- confirm whether bridge.upscale accepts
        # them and wire them through.
    )

    if not success:
        # RuntimeError instead of a bare Exception: callers can catch a
        # narrower type while broad handlers in the queue still see it.
        raise RuntimeError(message)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle manager.

    Startup: ensure data directories exist, warm up the Real-ESRGAN
    bridge and start the worker pool. Shutdown: stop the workers.
    """
    # Startup
    logger.info('Starting Real-ESRGAN API...')
    file_manager.ensure_directories()

    bridge = realesrgan_bridge.get_bridge()
    if not bridge.initialize():
        # Non-fatal: per the log message, initialization is retried lazily.
        logger.warning('Real-ESRGAN initialization failed (will attempt on first use)')

    # Two worker threads consume the shared upscale-job queue.
    wq = worker.get_worker_queue(_process_upscale_job, num_workers=2)
    wq.start()

    logger.info('Real-ESRGAN API ready')
    yield

    # Shutdown
    logger.info('Shutting down Real-ESRGAN API...')
    wq.stop()
    logger.info('Real-ESRGAN API stopped')
|
||||
|
||||
|
||||
# FastAPI application instance; lifespan handles worker startup/shutdown.
app = FastAPI(
    title='Real-ESRGAN API',
    version='1.0.0',
    description='REST API for Real-ESRGAN image upscaling with async job processing',
    lifespan=lifespan,
)

# CORS middleware
# NOTE(review): allow_origins=['*'] together with allow_credentials=True is
# forbidden by the CORS spec (browsers reject wildcard origins when
# credentials are enabled) -- confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

# Include routers: health/system, model management, upscaling/jobs.
app.include_router(health.router)
app.include_router(models.router)
app.include_router(upscale.router)
|
||||
|
||||
|
||||
@app.get('/')
async def root():
    """API root endpoint: service banner plus a map of the main routes."""
    endpoint_map = {
        'health': '/api/v1/health',
        'system': '/api/v1/system',
        'models': '/api/v1/models',
        'upscale': '/api/v1/upscale',
        'jobs': '/api/v1/jobs',
    }
    return {
        'name': 'Real-ESRGAN API',
        'version': '1.0.0',
        'docs': '/docs',
        'redoc': '/redoc',
        'endpoints': endpoint_map,
    }
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Dev entry point; binds all interfaces on port 8000 without reload.
    import uvicorn
    uvicorn.run(
        'app.main:app',
        host='0.0.0.0',
        port=8000,
        reload=False,
    )
|
||||
1
app/routers/__init__.py
Normal file
1
app/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API routers."""
|
||||
146
app/routers/health.py
Normal file
146
app/routers/health.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Health check and system information endpoints."""
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.config import settings
|
||||
from app.schemas.health import HealthResponse, RequestStats, SystemInfo
|
||||
from app.services import file_manager, worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix='/api/v1', tags=['system'])

# Process start time; used by the endpoints below to report uptime.
_start_time = time.time()

# Request statistics returned by /stats.
# NOTE(review): nothing in this module increments these counters -- confirm
# they are updated elsewhere (middleware?), otherwise /stats always reports 0.
_stats = {
    'total_requests': 0,
    'successful_requests': 0,
    'failed_requests': 0,
    'total_processing_time': 0.0,
    'total_images_processed': 0,
}
|
||||
|
||||
|
||||
@router.get('/health')
async def health_check() -> HealthResponse:
    """Basic health check: 200 with uptime while the process is alive."""
    return HealthResponse(
        status='ok',
        version='1.0.0',
        uptime_seconds=time.time() - _start_time,
        message='Real-ESRGAN API is running',
    )
|
||||
|
||||
|
||||
@router.get('/health/ready')
async def readiness_check():
    """Kubernetes readiness probe: 503 until the bridge has initialized."""
    from app.services import realesrgan_bridge

    if not realesrgan_bridge.get_bridge().initialized:
        raise HTTPException(status_code=503, detail='Not ready')
    return {'ready': True}
|
||||
|
||||
|
||||
@router.get('/health/live')
async def liveness_check():
    """Kubernetes liveness probe."""
    # Reaching this handler at all proves the event loop is running.
    payload = {'alive': True}
    return payload
|
||||
|
||||
|
||||
@router.get('/system')
async def get_system_info() -> SystemInfo:
    """Get comprehensive system information.

    Collects uptime, CPU/memory/disk utilisation, best-effort GPU memory
    figures (via torch, when importable), the models-directory size and
    the current worker queue length.
    """
    try:
        # Uptime
        uptime = time.time() - _start_time

        # CPU and memory. NOTE: cpu_percent(interval=1) blocks this request
        # handler for one second while sampling.
        cpu_percent = psutil.cpu_percent(interval=1)
        memory = psutil.virtual_memory()
        memory_percent = memory.percent

        # Disk usage of the root filesystem.
        disk = psutil.disk_usage('/')
        disk_percent = disk.percent

        # GPU: best-effort via torch; any failure (missing torch, no CUDA)
        # leaves the fields at their defaults.
        gpu_available = False
        gpu_memory_mb = None
        gpu_memory_used_mb = None

        try:
            import torch
            gpu_available = torch.cuda.is_available()
            if gpu_available:
                gpu_memory_mb = int(torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
                gpu_memory_used_mb = int(torch.cuda.memory_allocated(0) / (1024 * 1024))
        except Exception:
            pass

        # Models directory size (MB, recursive).
        models_size = file_manager.get_directory_size_mb(settings.models_dir)

        # Current length of the worker job queue.
        wq = worker.get_worker_queue()
        queue_length = wq.queue.qsize()

        return SystemInfo(
            status='ok',
            version='1.0.0',
            uptime_seconds=uptime,
            cpu_usage_percent=cpu_percent,
            memory_usage_percent=memory_percent,
            disk_usage_percent=disk_percent,
            gpu_available=gpu_available,
            gpu_memory_mb=gpu_memory_mb,
            gpu_memory_used_mb=gpu_memory_used_mb,
            execution_providers=settings.get_execution_providers(),
            models_dir_size_mb=models_size,
            jobs_queue_length=queue_length,
        )
    except Exception as e:
        logger.error(f'Failed to get system info: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get('/stats')
async def get_stats() -> RequestStats:
    """Get request statistics (counters plus a derived mean latency)."""
    successes = _stats['successful_requests']
    mean_time = (_stats['total_processing_time'] / successes) if successes > 0 else 0.0

    return RequestStats(
        total_requests=_stats['total_requests'],
        successful_requests=successes,
        failed_requests=_stats['failed_requests'],
        average_processing_time_seconds=mean_time,
        total_images_processed=_stats['total_images_processed'],
    )
|
||||
|
||||
|
||||
@router.post('/cleanup')
async def cleanup_old_jobs(hours: int = 24):
    """Clean up job directories older than *hours*.

    Rejects non-positive values: with hours <= 0 the cutoff would lie in
    the present or future and EVERY job directory would be deleted.
    """
    if hours < 1:
        raise HTTPException(status_code=400, detail='hours must be >= 1')

    try:
        cleaned = file_manager.cleanup_old_jobs(hours)
        return {
            'success': True,
            'cleaned_jobs': cleaned,
            'message': f'Cleaned up {cleaned} job directories older than {hours} hours',
        }
    except Exception as e:
        logger.error(f'Cleanup failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
109
app/routers/models.py
Normal file
109
app/routers/models.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Model management endpoints."""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.models import ModelDownloadRequest, ModelDownloadResponse, ModelListResponse
|
||||
from app.services import model_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix='/api/v1', tags=['models'])
|
||||
|
||||
|
||||
@router.get('/models')
async def list_models() -> ModelListResponse:
    """List all available models with how many are present locally."""
    try:
        catalogue = model_manager.get_available_models()
        downloaded = [entry for entry in catalogue if entry['available']]
        return ModelListResponse(
            available_models=catalogue,
            total_models=len(catalogue),
            local_models=len(downloaded),
        )
    except Exception as exc:
        logger.error(f'Failed to list models: {exc}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post('/models/download')
async def download_models(request: ModelDownloadRequest) -> ModelDownloadResponse:
    """Download one or more models; reports per-model success/failure."""
    if not request.models:
        raise HTTPException(status_code=400, detail='No models specified')

    try:
        logger.info(f'Downloading models: {request.models}')
        results = await model_manager.download_models(request.models)

        # Split per-model (success, message) results into the response shape.
        downloaded, failed, errors = [], [], {}
        for name, (ok, detail) in results.items():
            if ok:
                downloaded.append(name)
            else:
                failed.append(name)
                errors[name] = detail

        return ModelDownloadResponse(
            success=not failed,
            message=f'Downloaded {len(downloaded)} model(s)',
            downloaded=downloaded,
            failed=failed,
            errors=errors,
        )
    except Exception as exc:
        logger.error(f'Model download failed: {exc}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get('/models/{model_name}')
async def get_model_info(model_name: str):
    """Get information about a specific model, or 404 if unknown."""
    for entry in model_manager.get_available_models():
        if entry['name'] == model_name:
            return entry

    raise HTTPException(status_code=404, detail=f'Model not found: {model_name}')
|
||||
|
||||
|
||||
@router.post('/models/{model_name}/download')
async def download_model(model_name: str):
    """Download a specific model; 500 with the manager's message on failure."""
    try:
        ok, detail = await model_manager.download_model(model_name)

        if not ok:
            raise HTTPException(status_code=500, detail=detail)

        return {
            'success': True,
            'message': detail,
            'model': model_name,
        }
    except HTTPException:
        # Re-raise our own HTTP errors untouched.
        raise
    except Exception as exc:
        logger.error(f'Failed to download model {model_name}: {exc}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get('/models-info')
async def get_models_directory_info():
    """Get path, total size and model count for the models directory."""
    try:
        info = model_manager.get_models_directory_info()
    except Exception as exc:
        logger.error(f'Failed to get models directory info: {exc}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))

    return {
        'models_directory': info['path'],
        'total_size_mb': round(info['size_mb'], 2),
        'model_count': info['model_count'],
    }
|
||||
265
app/routers/upscale.py
Normal file
265
app/routers/upscale.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Upscaling endpoints."""
|
||||
import json
|
||||
import logging
|
||||
from time import time
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from app.schemas.upscale import UpscaleOptions
|
||||
from app.services import file_manager, realesrgan_bridge, worker
|
||||
from app.services.model_manager import is_model_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix='/api/v1', tags=['upscaling'])
|
||||
|
||||
|
||||
@router.post('/upscale')
async def upscale_sync(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    # NOTE(review): tile_size/tile_pad are accepted here but never forwarded
    # to bridge.upscale below -- confirm intent or wire them through.
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Synchronous image upscaling.

    Upscales an image using Real-ESRGAN and returns the result directly.
    Suitable for small to medium images.
    """
    # Per-request scratch directory; always removed in the finally block.
    request_dir = file_manager.create_request_dir()

    try:
        # Validate model
        if not is_model_available(model):
            raise HTTPException(
                status_code=400,
                detail=f'Model not available: {model}. Download it first using /api/v1/models/download'
            )

        # Save upload
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)

        # Process
        start_time = time()
        bridge = realesrgan_bridge.get_bridge()
        success, message, output_size = bridge.upscale(
            input_path=input_path,
            output_path=output_path,
            model_name=model,
            outscale=outscale,
        )

        if not success:
            raise HTTPException(status_code=500, detail=message)

        processing_time = time() - start_time
        logger.info(f'Sync upscaling completed in {processing_time:.2f}s')

        # Serve the upscaled file under the original upload's filename.
        return FileResponse(
            path=output_path,
            media_type='application/octet-stream',
            filename=image.filename,
            headers={'X-Processing-Time': f'{processing_time:.2f}'},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f'Upscaling failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Removes the scratch dir (inputs). NOTE(review): the output file in
        # settings.output_dir is not removed here -- presumably reclaimed by
        # the periodic cleanup; confirm.
        file_manager.cleanup_directory(request_dir)
|
||||
|
||||
|
||||
@router.post('/upscale-batch')
async def upscale_batch(
    images: List[UploadFile] = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
):
    """
    Batch upscaling via async jobs.

    Submit multiple images for upscaling. Returns job IDs for monitoring.
    Raises 400 for an empty batch, a batch over 100 images, or a model
    that has not been downloaded; 500 on submission failure.
    """
    if not images:
        raise HTTPException(status_code=400, detail='No images provided')

    if len(images) > 100:
        raise HTTPException(status_code=400, detail='Maximum 100 images per request')

    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )

    request_dir = file_manager.create_request_dir()
    job_ids = []

    try:
        # Save all images into the per-request directory.
        input_paths = await file_manager.save_uploads(images, request_dir)
        wq = worker.get_worker_queue()

        # Submit one job per saved image.
        for input_path in input_paths:
            output_path = file_manager.generate_output_path(input_path, f'_upscaled_{model}')
            job_id = wq.submit_job(
                input_path=input_path,
                output_path=output_path,
                model=model,
                tile_size=tile_size,
                tile_pad=tile_pad,
            )
            job_ids.append(job_id)

        logger.info(f'Submitted batch of {len(job_ids)} upscaling jobs')

        return {
            'success': True,
            'job_ids': job_ids,
            'total': len(job_ids),
            'message': f'Batch processing started for {len(job_ids)} images',
        }
    except Exception as e:
        # Fix: the request directory previously leaked on failure. Only
        # remove it when no job was submitted -- already-submitted jobs
        # still read their inputs from this directory.
        if not job_ids:
            file_manager.cleanup_directory(request_dir)
        logger.error(f'Batch submission failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post('/jobs')
async def create_job(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Create an async upscaling job.

    Submit a single image for asynchronous upscaling.
    Use /api/v1/jobs/{job_id} to check status and download result.
    """
    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )

    request_dir = file_manager.create_request_dir()
    job_id = None

    try:
        # Save upload
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)

        # Submit job
        wq = worker.get_worker_queue()
        job_id = wq.submit_job(
            input_path=input_path,
            output_path=output_path,
            model=model,
            tile_size=tile_size,
            tile_pad=tile_pad,
            outscale=outscale,
        )

        return {
            'success': True,
            'job_id': job_id,
            'status_url': f'/api/v1/jobs/{job_id}',
            'result_url': f'/api/v1/jobs/{job_id}/result',
        }
    except Exception as e:
        # Fix: the request directory previously leaked on failure. Only
        # remove it if the job was never submitted (a submitted job still
        # needs its input file).
        if job_id is None:
            file_manager.cleanup_directory(request_dir)
        logger.error(f'Job creation failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get('/jobs/{job_id}')
async def get_job_status(job_id: str):
    """Get status of an upscaling job; 404 for an unknown id."""
    job = worker.get_worker_queue().get_job(job_id)

    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')

    reported_fields = (
        'job_id', 'status', 'model', 'created_at', 'started_at',
        'completed_at', 'processing_time_seconds', 'error',
    )
    return {field: getattr(job, field) for field in reported_fields}
|
||||
|
||||
|
||||
@router.get('/jobs/{job_id}/result')
async def get_job_result(job_id: str):
    """Download the result of a completed upscaling job.

    Responses: 404 unknown job or missing result file, 202 while queued or
    processing, 500 for a failed job, 400 for any other status.
    """
    import os  # local: this module does not import os at top level

    wq = worker.get_worker_queue()
    job = wq.get_job(job_id)

    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')

    if job.status == 'queued' or job.status == 'processing':
        raise HTTPException(
            status_code=202,
            detail=f'Job is still processing: {job.status}'
        )

    if job.status == 'failed':
        raise HTTPException(status_code=500, detail=f'Job failed: {job.error}')

    if job.status != 'completed':
        raise HTTPException(status_code=400, detail=f'Job status: {job.status}')

    # Fix: previously used an inline __import__('os') hack; a plain import
    # (above) is clearer and behaves identically.
    if not job.output_path or not os.path.exists(job.output_path):
        raise HTTPException(status_code=404, detail='Result file not found')

    return FileResponse(
        path=job.output_path,
        media_type='application/octet-stream',
        filename=f'upscaled_{job_id}.png',
    )
|
||||
|
||||
|
||||
@router.get('/jobs')
async def list_jobs(
    status: Optional[str] = None,
    limit: int = 100,
):
    """List all jobs, optionally filtered by status, newest first."""
    all_jobs = worker.get_worker_queue().get_all_jobs()

    summaries = [
        {
            'job_id': job.job_id,
            'status': job.status,
            'model': job.model,
            'created_at': job.created_at,
            'processing_time_seconds': job.processing_time_seconds,
        }
        for job in all_jobs.values()
        if not status or job.status == status
    ]

    # Sort by creation time (newest first) and cap at *limit*.
    summaries.sort(key=lambda entry: entry['created_at'], reverse=True)
    summaries = summaries[:limit]

    return {
        'total': len(all_jobs),
        'returned': len(summaries),
        'jobs': summaries,
    }
|
||||
1
app/schemas/__init__.py
Normal file
1
app/schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Pydantic schemas for request/response validation."""
|
||||
37
app/schemas/health.py
Normal file
37
app/schemas/health.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Schemas for health check and system information."""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """API health check response."""
    status: str             # 'ok' (the /health handler always sends 'ok')
    version: str            # API version string
    uptime_seconds: float   # seconds since process start
    message: str            # human-readable status message
|
||||
|
||||
|
||||
class SystemInfo(BaseModel):
    """System information returned by /api/v1/system."""
    status: str
    version: str
    uptime_seconds: float
    cpu_usage_percent: float
    memory_usage_percent: float
    disk_usage_percent: float
    gpu_available: bool
    # GPU memory figures are only populated when torch reports a CUDA device.
    gpu_memory_mb: Optional[int] = None
    gpu_memory_used_mb: Optional[int] = None
    execution_providers: list   # decoded from settings.execution_providers
    models_dir_size_mb: float
    jobs_queue_length: int      # pending entries in the worker queue
|
||||
|
||||
|
||||
class RequestStats(BaseModel):
    """API request statistics returned by /api/v1/stats."""
    total_requests: int
    successful_requests: int
    failed_requests: int
    # Mean over successful requests; 0.0 when there have been none.
    average_processing_time_seconds: float
    total_images_processed: int
|
||||
31
app/schemas/models.py
Normal file
31
app/schemas/models.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Schemas for model management operations."""
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ModelDownloadRequest(BaseModel):
    """Request to download models."""
    models: List[str] = Field(
        description='List of model names to download'
    )
    # NOTE(review): the visible /models/download endpoint only reads
    # request.models -- confirm whether 'provider' is honoured anywhere.
    provider: str = Field(
        default='huggingface',
        description='Repository provider (huggingface, gdrive, etc.)'
    )
|
||||
|
||||
|
||||
class ModelDownloadResponse(BaseModel):
    """Response from a model download request."""
    success: bool                                   # True only when nothing failed
    message: str
    downloaded: List[str] = Field(default_factory=list)
    failed: List[str] = Field(default_factory=list)
    errors: dict = Field(default_factory=dict)      # failed model name -> error message
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
    """Response containing the model catalogue."""
    available_models: List[dict]   # entries built by model_manager.get_available_models
    total_models: int
    local_models: int              # how many have weights on disk
|
||||
57
app/schemas/upscale.py
Normal file
57
app/schemas/upscale.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Schemas for upscaling operations."""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UpscaleOptions(BaseModel):
    """Options for image upscaling.

    NOTE(review): the upscale routers accept these values as individual
    Form fields and this schema is imported but not referenced in them --
    confirm whether it is used elsewhere.
    """
    model: str = Field(
        default='RealESRGAN_x4plus',
        description='Model to use for upscaling (RealESRGAN_x2plus, RealESRGAN_x3plus, RealESRGAN_x4plus, etc.)'
    )
    tile_size: Optional[int] = Field(
        default=None,
        description='Tile size for processing large images to avoid OOM'
    )
    tile_pad: Optional[int] = Field(
        default=None,
        description='Padding between tiles'
    )
    outscale: Optional[float] = Field(
        default=None,
        description='Upsampling scale factor'
    )
|
||||
|
||||
|
||||
class JobStatus(BaseModel):
    """Job status information."""
    job_id: str
    status: str  # queued, processing, completed, failed
    model: str
    progress: float = Field(default=0.0, description='Progress as percentage 0-100')
    result_path: Optional[str] = None
    error: Optional[str] = None
    created_at: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    processing_time_seconds: Optional[float] = None
|
||||
|
||||
|
||||
class UpscaleResult(BaseModel):
    """Upscaling result summary."""
    success: bool
    message: str
    processing_time_seconds: float
    model: str
    # Pixel dimensions; presumably (width, height) -- confirm against bridge.
    input_size: Optional[tuple] = None
    output_size: Optional[tuple] = None
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
    """Information about an available model."""
    name: str
    scale: int        # upscaling factor (2/3/4)
    path: str
    size_mb: float
    available: bool   # True when weights exist locally
|
||||
1
app/services/__init__.py
Normal file
1
app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API services."""
|
||||
126
app/services/file_manager.py
Normal file
126
app/services/file_manager.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""File management utilities."""
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ensure_directories() -> None:
    """Create every configured data directory if it does not exist yet."""
    required = (
        settings.upload_dir,
        settings.output_dir,
        settings.models_dir,
        settings.temp_dir,
        settings.jobs_dir,
    )
    for directory in required:
        os.makedirs(directory, exist_ok=True)
        logger.info(f'Directory ensured: {directory}')
|
||||
|
||||
|
||||
def create_request_dir() -> str:
    """Create and return a unique per-request directory under upload_dir."""
    path = os.path.join(settings.upload_dir, str(uuid.uuid4()))
    os.makedirs(path, exist_ok=True)
    return path
|
||||
|
||||
|
||||
async def save_upload(file: UploadFile, directory: str) -> str:
    """Stream an uploaded file into *directory* under a random name.

    The original extension is kept (defaulting to .jpg when absent).
    NOTE(review): settings.max_upload_size_mb is not enforced here --
    confirm whether a size check belongs at this layer.
    """
    extension = os.path.splitext(file.filename or '')[1] or '.jpg'
    destination = os.path.join(directory, f'{uuid.uuid4()}{extension}')

    with open(destination, 'wb') as out:
        # Copy in 1 MiB chunks to bound memory use.
        while data := await file.read(1024 * 1024):
            out.write(data)

    logger.debug(f'File saved: {destination}')
    return destination
|
||||
|
||||
|
||||
async def save_uploads(files: List[UploadFile], directory: str) -> List[str]:
    """Save multiple uploads sequentially; returns their paths in order."""
    saved = []
    for upload in files:
        saved.append(await save_upload(upload, directory))
    return saved
|
||||
|
||||
|
||||
def generate_output_path(input_path: str, suffix: str = '_upscaled') -> str:
    """Build the path in output_dir for an input file, inserting *suffix*."""
    stem, extension = os.path.splitext(os.path.basename(input_path))
    return os.path.join(settings.output_dir, f'{stem}{suffix}{extension}')
|
||||
|
||||
|
||||
def cleanup_directory(directory: str) -> None:
    """Recursively remove *directory* if it exists; removal errors are ignored."""
    if not os.path.isdir(directory):
        return
    shutil.rmtree(directory, ignore_errors=True)
    logger.debug(f'Cleaned up directory: {directory}')
|
||||
|
||||
|
||||
def cleanup_file(filepath: str) -> None:
    """Delete *filepath* if it is an existing regular file."""
    if not os.path.isfile(filepath):
        return
    os.remove(filepath)
    logger.debug(f'Cleaned up file: {filepath}')
|
||||
|
||||
|
||||
def get_directory_size_mb(directory: str) -> float:
    """Return the total size of *directory* (recursive) in megabytes."""
    total_bytes = 0
    for root, _dirs, files in os.walk(directory):
        for filename in files:
            full_path = os.path.join(root, filename)
            # A file may vanish between listing and stat; skip it quietly.
            if os.path.exists(full_path):
                total_bytes += os.path.getsize(full_path)
    return total_bytes / (1024 * 1024)
|
||||
|
||||
|
||||
def list_model_files() -> List[Tuple[str, str, int]]:
    """Return (name, path, size_bytes) for weight files in the models dir.

    Recognised extensions: .pth, .onnx, .pt, .safetensors. Entries whose
    size cannot be read are logged and skipped.
    """
    found: List[Tuple[str, str, int]] = []
    root = settings.models_dir
    if not os.path.isdir(root):
        return found

    weight_suffixes = ('.pth', '.onnx', '.pt', '.safetensors')
    for filename in sorted(os.listdir(root)):
        if not filename.endswith(weight_suffixes):
            continue
        full_path = os.path.join(root, filename)
        try:
            found.append((filename, full_path, os.path.getsize(full_path)))
        except OSError:
            logger.warning(f'Could not get size of model: {full_path}')
    return found
|
||||
|
||||
|
||||
def cleanup_old_jobs(hours: int = 24) -> int:
    """Remove job directories older than *hours*; returns how many were removed."""
    import time

    cutoff = time.time() - hours * 3600
    removed = 0

    if not os.path.isdir(settings.jobs_dir):
        return removed

    for entry in os.listdir(settings.jobs_dir):
        candidate = os.path.join(settings.jobs_dir, entry)
        if not os.path.isdir(candidate):
            continue
        try:
            # Age is judged by directory mtime.
            is_stale = os.path.getmtime(candidate) < cutoff
        except OSError:
            continue
        if is_stale:
            cleanup_directory(candidate)
            removed += 1

    if removed > 0:
        logger.info(f'Cleaned up {removed} old job directories')
    return removed
|
||||
154
app/services/model_manager.py
Normal file
154
app/services/model_manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Model management utilities."""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.services import file_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Known models for Easy Real-ESRGAN.
# Maps model name -> metadata: 'scale' (upscaling factor), human-readable
# 'description', release-asset 'url', and nominal 'size_mb' used as a
# fallback when the weights are not yet on disk.
KNOWN_MODELS = {
    'RealESRGAN_x2plus': {
        'scale': 2,
        'description': '2x upscaling',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        'size_mb': 66.7,
    },
    'RealESRGAN_x3plus': {
        'scale': 3,
        'description': '3x upscaling',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x3plus.pth',
        'size_mb': 101.7,
    },
    'RealESRGAN_x4plus': {
        'scale': 4,
        'description': '4x upscaling (general purpose)',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth',
        'size_mb': 101.7,
    },
    'RealESRGAN_x4plus_anime_6B': {
        'scale': 4,
        'description': '4x upscaling (anime/art)',
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2/RealESRGAN_x4plus_anime_6B.pth',
        'size_mb': 18.9,
    },
}
|
||||
|
||||
|
||||
def get_available_models() -> List[Dict]:
    """Return details for every known model, including local availability."""
    entries: List[Dict] = []

    for name, meta in KNOWN_MODELS.items():
        path = os.path.join(settings.models_dir, f'{name}.pth')
        on_disk = os.path.exists(path)

        # Prefer the real on-disk size over the published estimate when the
        # file is present; fall back silently if stat fails.
        reported_mb = meta['size_mb']
        if on_disk:
            try:
                reported_mb = os.path.getsize(path) / (1024 * 1024)
            except OSError:
                pass

        entries.append({
            'name': name,
            'scale': meta['scale'],
            'description': meta['description'],
            'available': on_disk,
            'size_mb': reported_mb,
            'size_bytes': int(reported_mb * 1024 * 1024),
        })

    return entries
|
||||
|
||||
|
||||
def is_model_available(model_name: str) -> bool:
    """Return True when the model's .pth file exists in the models directory."""
    return os.path.exists(
        os.path.join(settings.models_dir, f'{model_name}.pth')
    )
|
||||
|
||||
|
||||
def get_model_scale(model_name: str) -> Optional[int]:
    """Return the upscaling factor for a model, inferring it from the name
    for models outside the known registry. None when undeterminable."""
    known = KNOWN_MODELS.get(model_name)
    if known is not None:
        return known['scale']

    # Fall back to the conventional xN marker embedded in the model name.
    lowered = model_name.lower()
    for marker, factor in (('x2', 2), ('x3', 3), ('x4', 4)):
        if marker in lowered:
            return factor

    return None
|
||||
|
||||
|
||||
async def download_model(model_name: str) -> tuple[bool, str]:
    """
    Download a known model into the models directory.

    The file is first written to a temporary '.part' path and only renamed to
    its final name on success, so an interrupted download is never mistaken
    for an available model by is_model_available().

    Args:
        model_name: Key into KNOWN_MODELS.

    Returns: (success, message)
    """
    if model_name not in KNOWN_MODELS:
        return False, f'Unknown model: {model_name}'

    if is_model_available(model_name):
        return True, f'Model already available: {model_name}'

    url = KNOWN_MODELS[model_name]['url']
    model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
    partial_path = model_path + '.part'

    try:
        logger.info(f'Downloading model: {model_name} from {url}')

        import urllib.request
        os.makedirs(settings.models_dir, exist_ok=True)

        def download_progress(count, block_size, total_size):
            # total_size can be 0 or -1 when the server omits Content-Length;
            # guard against division by zero.
            if total_size > 0:
                percent = min(count * block_size * 100 / total_size, 100)
                logger.debug(f'Download progress: {percent:.1f}%')

        urllib.request.urlretrieve(url, partial_path, download_progress)
        # Atomically publish the finished file under its final name.
        os.replace(partial_path, model_path)

        size_mb = os.path.getsize(model_path) / (1024 * 1024)
        logger.info(f'Model downloaded: {model_name} ({size_mb:.1f} MB)')
        return True, f'Model downloaded: {model_name} ({size_mb:.1f} MB)'

    except Exception as e:
        logger.error(f'Failed to download model {model_name}: {e}', exc_info=True)
        # Best-effort cleanup of a partially written file.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return False, f'Download failed: {str(e)}'
|
||||
|
||||
|
||||
async def download_models(model_names: List[str]) -> Dict[str, tuple[bool, str]]:
    """Download each requested model, returning per-model (success, message)."""
    outcomes: Dict[str, tuple[bool, str]] = {}

    for name in model_names:
        ok, message = await download_model(name)
        outcomes[name] = (ok, message)
        logger.info(f'Download result for {name}: {ok} - {message}')

    return outcomes
|
||||
|
||||
|
||||
def get_models_directory_info() -> Dict:
    """Get information about the models directory."""
    models_dir = settings.models_dir

    # Count only .pth checkpoint files; a missing directory counts as zero.
    if os.path.isdir(models_dir):
        model_count = sum(
            1 for entry in os.listdir(models_dir) if entry.endswith('.pth')
        )
    else:
        model_count = 0

    return {
        'path': models_dir,
        'size_mb': file_manager.get_directory_size_mb(models_dir),
        'model_count': model_count,
    }
|
||||
200
app/services/realesrgan_bridge.py
Normal file
200
app/services/realesrgan_bridge.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Real-ESRGAN model management and processing."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
REALESRGAN_AVAILABLE = True
|
||||
except ImportError:
|
||||
REALESRGAN_AVAILABLE = False
|
||||
logger.warning('Real-ESRGAN not available. Install via: pip install realesrgan')
|
||||
|
||||
|
||||
class RealESRGANBridge:
    """Bridge to Real-ESRGAN functionality.

    Wraps a single RealESRGANer instance and tracks which model is loaded.
    Not thread-safe: callers must serialize access externally.
    """

    def __init__(self):
        """Initialize the Real-ESRGAN bridge."""
        # Lazily created; stays None until initialize()/load_model() succeeds.
        self.upsampler: Optional[RealESRGANer] = None
        self.current_model: Optional[str] = None
        self.initialized = False

    @staticmethod
    def _infer_scale(model_name: str) -> int:
        """Infer the native upscaling factor from the model name (default 4)."""
        lowered = model_name.lower()
        if 'x2' in lowered:
            return 2
        if 'x3' in lowered:
            return 3
        return 4

    def _build_upsampler(self, model_name: str, model_path: Optional[str]) -> None:
        """Construct the RRDBNet + RealESRGANer pair for *model_name*.

        Shared by initialize() and load_model() so the network topology and
        tiling settings stay consistent. model_path may be None to let
        RealESRGANer resolve the weights itself.
        """
        scale = self._infer_scale(model_name)

        # The anime_6B checkpoint is a compact 6-block variant; the other
        # RealESRGAN_*plus checkpoints use the standard 23-block RRDBNet.
        # Loading anime_6B with num_block=23 would fail with a state-dict
        # mismatch.
        num_block = 6 if 'anime_6b' in model_name.lower() else 23

        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=num_block,
            num_grow_ch=32,
            scale=scale,
        )

        self.upsampler = RealESRGANer(
            scale=scale,
            model_path=model_path,
            model=model,
            tile=settings.tile_size,
            tile_pad=settings.tile_pad,
            pre_pad=0,
            # fp16 inference only makes sense when a CUDA provider is enabled.
            half=('cuda' in settings.get_execution_providers()),
        )
        self.current_model = model_name

    def initialize(self) -> bool:
        """Initialize the upsampler with the configured default model.

        Returns True on success. Unlike load_model(), a missing weights file
        is tolerated here (model_path=None) to allow auto-download paths.
        """
        if not REALESRGAN_AVAILABLE:
            logger.error('Real-ESRGAN library not available')
            return False

        try:
            logger.info('Initializing Real-ESRGAN upsampler...')

            model_name = settings.default_model
            model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
            if not os.path.exists(model_path):
                logger.warning(f'Model not found at {model_path}, will attempt to auto-download')
                model_path = None

            # Previously the scale was hard-coded to 4 here, which broke
            # non-4x default models; _build_upsampler infers it instead.
            self._build_upsampler(model_name, model_path)
            self.initialized = True
            logger.info(f'Real-ESRGAN initialized with model: {model_name}')
            return True
        except Exception as e:
            logger.error(f'Failed to initialize Real-ESRGAN: {e}', exc_info=True)
            return False

    def load_model(self, model_name: str) -> bool:
        """Load a specific upscaling model. Returns True on success.

        The weights file must already exist in settings.models_dir.
        """
        try:
            if not REALESRGAN_AVAILABLE:
                logger.error('Real-ESRGAN not available')
                return False

            logger.info(f'Loading model: {model_name}')

            model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
            if not os.path.exists(model_path):
                logger.error(f'Model file not found: {model_path}')
                return False

            self._build_upsampler(model_name, model_path)
            logger.info(f'Model loaded: {model_name}')
            return True
        except Exception as e:
            logger.error(f'Failed to load model {model_name}: {e}', exc_info=True)
            return False

    def upscale(
        self,
        input_path: str,
        output_path: str,
        model_name: Optional[str] = None,
        outscale: Optional[float] = None,
    ) -> Tuple[bool, str, Optional[Tuple[int, int]]]:
        """
        Upscale an image file and write the result to output_path.

        Args:
            input_path: Source image path (read via cv2.IMREAD_UNCHANGED).
            output_path: Destination path for the upscaled image.
            model_name: Optional model to switch to before processing.
            outscale: Final rescale factor; defaults to the loaded model's
                native factor (was hard-coded to 4, which produced wrong
                output sizes for x2/x3 models).

        Returns: (success, message, output_size) where output_size is the
        (height, width) of the result, or None on failure.
        """
        try:
            if not self.initialized:
                if not self.initialize():
                    return False, 'Failed to initialize Real-ESRGAN', None

            if model_name and model_name != self.current_model:
                if not self.load_model(model_name):
                    return False, f'Failed to load model: {model_name}', None

            if not self.upsampler:
                return False, 'Upsampler not initialized', None

            # Read image
            logger.info(f'Reading image: {input_path}')
            input_img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
            if input_img is None:
                return False, f'Failed to read image: {input_path}', None

            input_shape = input_img.shape[:2]
            logger.info(f'Input image shape: {input_shape}')

            # Upscale using the model's native factor unless overridden.
            logger.info(f'Upscaling with model: {self.current_model}')
            output, _ = self.upsampler.enhance(
                input_img, outscale=outscale or self.get_upscale_factor()
            )

            # Save output
            cv2.imwrite(str(output_path), output)
            output_shape = output.shape[:2]
            logger.info(f'Output image shape: {output_shape}')
            logger.info(f'Upscaled image saved: {output_path}')

            return True, 'Upscaling completed successfully', tuple(output_shape)
        except Exception as e:
            logger.error(f'Upscaling failed: {e}', exc_info=True)
            return False, f'Upscaling failed: {str(e)}', None

    def get_upscale_factor(self) -> int:
        """Return the current upsampler's scale, or 4 when none is loaded."""
        if self.upsampler:
            return self.upsampler.scale
        return 4

    def clear_memory(self) -> None:
        """Clear GPU memory if available; silently a no-op without CUDA/torch."""
        try:
            import torch
            torch.cuda.empty_cache()
            logger.debug('GPU memory cleared')
        except Exception:
            pass
|
||||
|
||||
|
||||
# Global instance (created lazily on first access)
_bridge: Optional[RealESRGANBridge] = None


def get_bridge() -> RealESRGANBridge:
    """Return the process-wide Real-ESRGAN bridge, creating it on first use."""
    global _bridge
    if _bridge is not None:
        return _bridge
    _bridge = RealESRGANBridge()
    return _bridge
|
||||
217
app/services/worker.py
Normal file
217
app/services/worker.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Async job worker queue."""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from queue import Queue
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Job:
    """Metadata for one async upscaling job, persisted as JSON on disk."""
    job_id: str
    status: str  # queued, processing, completed, failed
    input_path: str
    output_path: str
    model: str
    tile_size: Optional[int] = None
    tile_pad: Optional[int] = None
    outscale: Optional[float] = None
    created_at: str = ''
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    processing_time_seconds: Optional[float] = None
    error: Optional[str] = None

    def __post_init__(self):
        # Stamp creation time unless the caller (or a metadata reload)
        # supplied one already.
        self.created_at = self.created_at or datetime.utcnow().isoformat()

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)

    def save_metadata(self) -> None:
        """Save job metadata to JSON file."""
        job_dir = os.path.join(settings.jobs_dir, self.job_id)
        os.makedirs(job_dir, exist_ok=True)
        with open(os.path.join(job_dir, 'metadata.json'), 'w') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load_metadata(cls, job_id: str) -> Optional['Job']:
        """Load job metadata from JSON file."""
        metadata_path = os.path.join(settings.jobs_dir, job_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            return None

        try:
            with open(metadata_path, 'r') as f:
                return cls(**json.load(f))
        except Exception as e:
            logger.error(f'Failed to load job metadata: {e}')
            return None
|
||||
|
||||
|
||||
class WorkerQueue:
    """Thread pool worker queue for processing jobs.

    Jobs are tracked in an in-memory dict guarded by a lock; metadata is
    additionally persisted via Job.save_metadata() at each state change.
    """

    def __init__(self, worker_func: Callable, num_workers: int = 2):
        """
        Initialize worker queue.

        Args:
            worker_func: Function to process jobs (job: Job) -> None
            num_workers: Number of worker threads
        """
        self.queue: Queue = Queue()
        self.worker_func = worker_func
        self.num_workers = num_workers
        self.workers = []
        self.running = False
        # In-memory job index; every access goes through self.lock.
        self.jobs: Dict[str, Job] = {}
        self.lock = threading.Lock()

    def start(self) -> None:
        """Start worker threads (idempotent)."""
        if self.running:
            return

        self.running = True
        for _ in range(self.num_workers):
            worker = threading.Thread(target=self._worker_loop, daemon=True)
            worker.start()
            self.workers.append(worker)

        logger.info(f'Started {self.num_workers} worker threads')

    def stop(self, timeout: int = 10) -> None:
        """Stop worker threads gracefully.

        Args:
            timeout: Seconds to wait for each worker thread to exit.
        """
        self.running = False

        # Signal workers to stop (one sentinel per worker).
        for _ in range(self.num_workers):
            self.queue.put(None)

        # Wait for workers to finish
        for worker in self.workers:
            worker.join(timeout=timeout)

        logger.info('Worker threads stopped')

    def submit_job(
        self,
        input_path: str,
        output_path: str,
        model: str,
        tile_size: Optional[int] = None,
        tile_pad: Optional[int] = None,
        outscale: Optional[float] = None,
    ) -> str:
        """
        Submit a job for processing.

        Returns: job_id
        """
        job_id = str(uuid.uuid4())
        job = Job(
            job_id=job_id,
            status='queued',
            input_path=input_path,
            output_path=output_path,
            model=model,
            tile_size=tile_size,
            tile_pad=tile_pad,
            outscale=outscale,
        )

        with self.lock:
            self.jobs[job_id] = job

        job.save_metadata()
        self.queue.put(job)
        logger.info(f'Job submitted: {job_id}')

        return job_id

    def get_job(self, job_id: str) -> Optional[Job]:
        """Get job by ID."""
        with self.lock:
            return self.jobs.get(job_id)

    def get_all_jobs(self) -> Dict[str, Job]:
        """Get all jobs (shallow copy of the index)."""
        with self.lock:
            return dict(self.jobs)

    def _worker_loop(self) -> None:
        """Worker thread main loop.

        The previous implementation caught bare Exception and passed, which
        silently swallowed every error — not just queue-timeout wakeups.
        """
        from queue import Empty  # Empty is not imported at module level

        logger.info('Worker thread started')

        while self.running:
            try:
                job = self.queue.get(timeout=1)
            except Empty:
                continue  # periodic wakeup to re-check self.running

            if job is None:  # Stop signal
                break

            try:
                self._process_job(job)
            except Exception:
                # _process_job handles its own errors; this guards against
                # failures in metadata persistence so the worker survives.
                logger.exception(f'Unhandled error while processing job {job.job_id}')

        logger.info('Worker thread stopped')

    def _process_job(self, job: Job) -> None:
        """Process a single job, updating status/timestamps and persisting them."""
        try:
            with self.lock:
                job.status = 'processing'
                job.started_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job

            job.save_metadata()

            start_time = time.time()
            self.worker_func(job)
            job.processing_time_seconds = time.time() - start_time

            with self.lock:
                job.status = 'completed'
                job.completed_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job

            logger.info(f'Job completed: {job.job_id} ({job.processing_time_seconds:.2f}s)')
        except Exception as e:
            logger.error(f'Job failed: {job.job_id}: {e}', exc_info=True)
            with self.lock:
                job.status = 'failed'
                job.error = str(e)
                job.completed_at = datetime.utcnow().isoformat()
                self.jobs[job.job_id] = job

        job.save_metadata()
|
||||
|
||||
|
||||
# Global instance (created lazily on first call with a worker_func)
_worker_queue: Optional[WorkerQueue] = None


def get_worker_queue(
    worker_func: Optional[Callable] = None, num_workers: int = 2
) -> WorkerQueue:
    """
    Get or create the global worker queue.

    Args:
        worker_func: Job-processing callable; required only on the first call.
            (Annotation fixed: the default of None requires Optional[Callable].)
        num_workers: Worker thread count used only when the queue is created.

    Raises:
        ValueError: If the queue does not exist yet and worker_func is None.
    """
    global _worker_queue
    if _worker_queue is None:
        if worker_func is None:
            raise ValueError('worker_func required for first initialization')
        _worker_queue = WorkerQueue(worker_func, num_workers)
    return _worker_queue
|
||||
Reference in New Issue
Block a user