# realesrgan-api/app/routers/upscale.py
"""Upscaling endpoints."""
import json
import logging
import os
from time import time
from typing import List, Optional

from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse

from app.schemas.upscale import UpscaleOptions
from app.services import file_manager, realesrgan_bridge, worker
from app.services.model_manager import is_model_available
logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/v1', tags=['upscaling'])
@router.post('/upscale')
async def upscale_sync(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Synchronous image upscaling.

    Upscales an image using Real-ESRGAN and returns the result directly.
    Suitable for small to medium images.

    Raises:
        HTTPException 400: if the requested model is not downloaded.
        HTTPException 500: if the bridge reports failure or an unexpected
            error occurs.
    """
    # NOTE(review): tile_size/tile_pad are accepted here but never forwarded
    # to bridge.upscale (the async endpoints DO forward them to the worker).
    # Confirm whether the bridge supports them and wire them through if so.
    from starlette.background import BackgroundTask  # Starlette ships with FastAPI

    request_dir = file_manager.create_request_dir()
    try:
        # Validate the model before touching the upload.
        if not is_model_available(model):
            raise HTTPException(
                status_code=400,
                detail=f'Model not available: {model}. Download it first using /api/v1/models/download'
            )
        # Persist the upload and derive the output location next to it.
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)
        # Process
        start_time = time()
        bridge = realesrgan_bridge.get_bridge()
        success, message, output_size = bridge.upscale(
            input_path=input_path,
            output_path=output_path,
            model_name=model,
            outscale=outscale,
        )
        if not success:
            raise HTTPException(status_code=500, detail=message)
        processing_time = time() - start_time
        logger.info(f'Sync upscaling completed in {processing_time:.2f}s')
        # BUG FIX: cleanup used to live in a `finally` block, which runs when
        # this coroutine returns -- BEFORE Starlette opens and streams the
        # file -- so the result was deleted prior to being sent. Defer the
        # cleanup to a background task that runs after the response is sent.
        return FileResponse(
            path=output_path,
            media_type='application/octet-stream',
            filename=image.filename,
            headers={'X-Processing-Time': f'{processing_time:.2f}'},
            background=BackgroundTask(file_manager.cleanup_directory, request_dir),
        )
    except HTTPException:
        # Error responses carry no file, so the directory can go immediately.
        file_manager.cleanup_directory(request_dir)
        raise
    except Exception as e:
        logger.error(f'Upscaling failed: {e}', exc_info=True)
        file_manager.cleanup_directory(request_dir)
        raise HTTPException(status_code=500, detail=str(e))
@router.post('/upscale-batch')
async def upscale_batch(
    images: List[UploadFile] = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
):
    """
    Batch upscaling via async jobs.

    Submit multiple images for upscaling. Returns job IDs for monitoring
    via /api/v1/jobs/{job_id}.

    Raises:
        HTTPException 400: on an empty batch, a batch over 100 images, or a
            model that has not been downloaded.
        HTTPException 500: if saving the uploads or submitting jobs fails.
    """
    if not images:
        raise HTTPException(status_code=400, detail='No images provided')
    if len(images) > 100:
        raise HTTPException(status_code=400, detail='Maximum 100 images per request')
    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )
    request_dir = file_manager.create_request_dir()
    job_ids = []
    try:
        # Save all images
        input_paths = await file_manager.save_uploads(images, request_dir)
        wq = worker.get_worker_queue()
        # Submit one job per saved image.
        for input_path in input_paths:
            output_path = file_manager.generate_output_path(input_path, f'_upscaled_{model}')
            job_id = wq.submit_job(
                input_path=input_path,
                output_path=output_path,
                model=model,
                tile_size=tile_size,
                tile_pad=tile_pad,
            )
            job_ids.append(job_id)
        logger.info(f'Submitted batch of {len(job_ids)} upscaling jobs')
        return {
            'success': True,
            'job_ids': job_ids,
            'total': len(job_ids),
            'message': f'Batch processing started for {len(job_ids)} images',
        }
    except Exception as e:
        logger.error(f'Batch submission failed: {e}', exc_info=True)
        # BUG FIX: the saved uploads previously leaked on failure. Only
        # remove the request directory when no job was submitted -- jobs
        # already queued still reference files inside it.
        if not job_ids:
            file_manager.cleanup_directory(request_dir)
        raise HTTPException(status_code=500, detail=str(e))
@router.post('/jobs')
async def create_job(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Create an async upscaling job.

    Submit a single image for asynchronous upscaling.
    Use /api/v1/jobs/{job_id} to check status and download result.

    Raises:
        HTTPException 400: if the requested model is not downloaded.
        HTTPException 500: if saving the upload or submitting the job fails.
    """
    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )
    request_dir = file_manager.create_request_dir()
    try:
        # Save upload
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)
        # Submit job
        wq = worker.get_worker_queue()
        job_id = wq.submit_job(
            input_path=input_path,
            output_path=output_path,
            model=model,
            tile_size=tile_size,
            tile_pad=tile_pad,
            outscale=outscale,
        )
        return {
            'success': True,
            'job_id': job_id,
            'status_url': f'/api/v1/jobs/{job_id}',
            'result_url': f'/api/v1/jobs/{job_id}/result',
        }
    except Exception as e:
        logger.error(f'Job creation failed: {e}', exc_info=True)
        # BUG FIX: the request directory previously leaked when save/submit
        # raised. No job references it at this point, so it is safe to remove.
        file_manager.cleanup_directory(request_dir)
        raise HTTPException(status_code=500, detail=str(e))
@router.get('/jobs/{job_id}')
async def get_job_status(job_id: str):
    """Return the current status record for a single upscaling job."""
    queue = worker.get_worker_queue()
    job = queue.get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')
    # Project the job object onto the public status fields, in order.
    fields = (
        'job_id', 'status', 'model', 'created_at', 'started_at',
        'completed_at', 'processing_time_seconds', 'error',
    )
    return {name: getattr(job, name) for name in fields}
@router.get('/jobs/{job_id}/result')
async def get_job_result(job_id: str):
    """Download result of a completed upscaling job.

    Raises:
        HTTPException 404: unknown job, or the result file is missing.
        HTTPException 202: job is still queued/processing (result not ready).
        HTTPException 500: the job finished with an error.
        HTTPException 400: any other (unexpected) job state.
    """
    wq = worker.get_worker_queue()
    job = wq.get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')
    if job.status in ('queued', 'processing'):
        # 202 tells polling clients to come back later.
        raise HTTPException(
            status_code=202,
            detail=f'Job is still processing: {job.status}'
        )
    if job.status == 'failed':
        raise HTTPException(status_code=500, detail=f'Job failed: {job.error}')
    if job.status != 'completed':
        raise HTTPException(status_code=400, detail=f'Job status: {job.status}')
    # FIX: use the module-level `os` import instead of the __import__('os')
    # hack, which hid the dependency from readers and linters.
    if not job.output_path or not os.path.exists(job.output_path):
        raise HTTPException(status_code=404, detail='Result file not found')
    return FileResponse(
        path=job.output_path,
        media_type='application/octet-stream',
        filename=f'upscaled_{job_id}.png',
    )
@router.get('/jobs')
async def list_jobs(
    status: Optional[str] = None,
    limit: int = 100,
):
    """List all jobs, optionally filtered by status."""
    queue = worker.get_worker_queue()
    registry = queue.get_all_jobs()
    # Summaries for every job that passes the (optional) status filter.
    summaries = [
        {
            'job_id': job.job_id,
            'status': job.status,
            'model': job.model,
            'created_at': job.created_at,
            'processing_time_seconds': job.processing_time_seconds,
        }
        for job in registry.values()
        if not status or job.status == status
    ]
    # Newest first, capped at `limit` entries.
    newest_first = sorted(
        summaries, key=lambda entry: entry['created_at'], reverse=True
    )[:limit]
    return {
        'total': len(registry),
        'returned': len(newest_first),
        'jobs': newest_first,
    }