266 lines
7.9 KiB
Python
266 lines
7.9 KiB
Python
"""Upscaling endpoints."""
|
|
import json
|
|
import logging
|
|
from time import time
|
|
from typing import List, Optional
|
|
|
|
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
|
from fastapi.responses import FileResponse
|
|
|
|
from app.schemas.upscale import UpscaleOptions
|
|
from app.services import file_manager, realesrgan_bridge, worker
|
|
from app.services.model_manager import is_model_available
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Every endpoint registered on this router is exposed under /api/v1 and
# grouped under the 'upscaling' tag.
router = APIRouter(tags=['upscaling'], prefix='/api/v1')
|
|
|
|
|
|
@router.post('/upscale')
async def upscale_sync(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Synchronous image upscaling.

    Upscales an image using Real-ESRGAN and returns the result directly.
    Suitable for small to medium images.

    Args:
        image: Uploaded source image.
        model: Model name; must already be available locally.
        tile_size: Accepted but not currently used by this endpoint.
        tile_pad: Accepted but not currently used by this endpoint.
        outscale: Optional output scale factor forwarded to the bridge.

    Returns:
        A ``FileResponse`` with the upscaled image and the elapsed processing
        time in seconds in the ``X-Processing-Time`` header.

    Raises:
        HTTPException: 400 if the model is not available; 500 if the bridge
            reports failure or any unexpected error occurs.
    """
    # Local import: starlette ships with FastAPI and provides the hook that
    # FileResponse runs after the response body has been sent.
    from starlette.background import BackgroundTask

    # Validate before touching the filesystem so a bad request does not
    # create (and then have to delete) a scratch directory.
    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}. Download it first using /api/v1/models/download'
        )

    request_dir = file_manager.create_request_dir()
    # FileResponse reads the file lazily, after this function has returned.
    # Deleting request_dir in ``finally`` on the success path would remove
    # the upload (and, presumably, the output generated next to it) before
    # it is sent, so on success cleanup is handed to a background task that
    # runs once the response has been delivered.
    cleanup_deferred = False

    try:
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)

        start_time = time()
        bridge = realesrgan_bridge.get_bridge()
        # NOTE(review): bridge.upscale is a synchronous call made from an
        # async endpoint, so it blocks the event loop while it runs; consider
        # run_in_threadpool once the bridge is confirmed thread-safe.
        # NOTE(review): tile_size / tile_pad are not forwarded here, unlike
        # the job endpoints - confirm whether the bridge supports them.
        success, message, _output_size = bridge.upscale(
            input_path=input_path,
            output_path=output_path,
            model_name=model,
            outscale=outscale,
        )
        if not success:
            raise HTTPException(status_code=500, detail=message)

        processing_time = time() - start_time
        logger.info(f'Sync upscaling completed in {processing_time:.2f}s')

        response = FileResponse(
            path=output_path,
            media_type='application/octet-stream',
            filename=image.filename,
            headers={'X-Processing-Time': f'{processing_time:.2f}'},
            background=BackgroundTask(file_manager.cleanup_directory, request_dir),
        )
        cleanup_deferred = True
        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f'Upscaling failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Only clean up here when no response will do it for us.
        if not cleanup_deferred:
            file_manager.cleanup_directory(request_dir)
|
|
|
|
|
|
@router.post('/upscale-batch')
async def upscale_batch(
    images: List[UploadFile] = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Batch upscaling via async jobs.

    Submit multiple images for upscaling. Returns job IDs for monitoring.

    Args:
        images: Uploaded source images (at most 100).
        model: Model name; must already be available locally.
        tile_size: Forwarded to every submitted job.
        tile_pad: Forwarded to every submitted job.
        outscale: Optional output scale factor forwarded to every job; left
            out of the job submission entirely when not supplied.

    Returns:
        dict with ``success``, ``job_ids``, ``total`` and ``message``.

    Raises:
        HTTPException: 400 for an empty or oversized batch or an unavailable
            model; HTTPExceptions raised while saving are passed through
            unchanged; any other failure becomes a 500.
    """
    if not images:
        raise HTTPException(status_code=400, detail='No images provided')

    if len(images) > 100:
        raise HTTPException(status_code=400, detail='Maximum 100 images per request')

    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )

    # Only forward outscale when the client supplied it, so batches that
    # omit it submit jobs exactly as before.
    extra_job_kwargs = {} if outscale is None else {'outscale': outscale}

    request_dir = file_manager.create_request_dir()
    job_ids = []

    try:
        # Save all images
        input_paths = await file_manager.save_uploads(images, request_dir)
        wq = worker.get_worker_queue()

        # Submit jobs
        for input_path in input_paths:
            output_path = file_manager.generate_output_path(input_path, f'_upscaled_{model}')
            job_id = wq.submit_job(
                input_path=input_path,
                output_path=output_path,
                model=model,
                tile_size=tile_size,
                tile_pad=tile_pad,
                **extra_job_kwargs,
            )
            job_ids.append(job_id)

        logger.info(f'Submitted batch of {len(job_ids)} upscaling jobs')

        return {
            'success': True,
            'job_ids': job_ids,
            'total': len(job_ids),
            'message': f'Batch processing started for {len(job_ids)} images',
        }

    except Exception as e:
        if not job_ids:
            # Nothing was queued, so nothing else references the saved
            # inputs and the scratch directory can be removed.
            file_manager.cleanup_directory(request_dir)
        else:
            # Queued jobs still need their inputs; keep the directory but
            # record the IDs, since the client will not receive them.
            logger.warning(f'Batch failed after submitting jobs: {job_ids}')
        if isinstance(e, HTTPException):
            raise
        logger.error(f'Batch submission failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post('/jobs')
async def create_job(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Create an async upscaling job.

    Submit a single image for asynchronous upscaling.
    Use /api/v1/jobs/{job_id} to check status and download result.

    Args:
        image: Uploaded source image.
        model: Model name; must already be available locally.
        tile_size: Forwarded to the worker job.
        tile_pad: Forwarded to the worker job.
        outscale: Forwarded to the worker job.

    Returns:
        dict with ``success``, ``job_id``, ``status_url`` and ``result_url``.

    Raises:
        HTTPException: 400 if the model is not available; HTTPExceptions
            raised while saving are passed through unchanged; any other
            failure becomes a 500.
    """
    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )

    request_dir = file_manager.create_request_dir()

    try:
        # Save upload
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)

        # Submit job
        wq = worker.get_worker_queue()
        job_id = wq.submit_job(
            input_path=input_path,
            output_path=output_path,
            model=model,
            tile_size=tile_size,
            tile_pad=tile_pad,
            outscale=outscale,
        )
    except Exception as e:
        # The job was never queued, so nothing else references the upload;
        # remove the scratch directory instead of leaking it.
        file_manager.cleanup_directory(request_dir)
        if isinstance(e, HTTPException):
            raise
        logger.error(f'Job creation failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        'success': True,
        'job_id': job_id,
        'status_url': f'/api/v1/jobs/{job_id}',
        'result_url': f'/api/v1/jobs/{job_id}/result',
    }
|
|
|
|
|
|
@router.get('/jobs/{job_id}')
async def get_job_status(job_id: str):
    """
    Report the current state of a single upscaling job.

    Raises:
        HTTPException: 404 when the worker queue has no job with this ID.
    """
    job = worker.get_worker_queue().get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')

    # Response keys mirror the job's attribute names one-to-one, in order.
    reported_fields = (
        'job_id',
        'status',
        'model',
        'created_at',
        'started_at',
        'completed_at',
        'processing_time_seconds',
        'error',
    )
    return {field: getattr(job, field) for field in reported_fields}
|
|
|
|
|
|
@router.get('/jobs/{job_id}/result')
async def get_job_result(job_id: str):
    """
    Download result of a completed upscaling job.

    Raises:
        HTTPException: 404 if the job or its output file is missing; 202 while
            the job is queued or processing; 500 if the job failed; 400 for
            any other non-completed status.
    """
    import os

    wq = worker.get_worker_queue()
    job = wq.get_job(job_id)

    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')

    if job.status in ('queued', 'processing'):
        # Raised rather than returned so the client gets the same JSON
        # ``detail`` body shape as the error paths, with a 202 status.
        raise HTTPException(
            status_code=202,
            detail=f'Job is still processing: {job.status}'
        )

    if job.status == 'failed':
        raise HTTPException(status_code=500, detail=f'Job failed: {job.error}')

    if job.status != 'completed':
        raise HTTPException(status_code=400, detail=f'Job status: {job.status}')

    # isfile (not exists) so a directory at output_path is reported as a
    # missing result instead of failing inside FileResponse.
    if not job.output_path or not os.path.isfile(job.output_path):
        raise HTTPException(status_code=404, detail='Result file not found')

    # Name the download after the file actually produced; fall back to the
    # previously hard-coded .png when the output path has no extension.
    extension = os.path.splitext(job.output_path)[1] or '.png'
    return FileResponse(
        path=job.output_path,
        media_type='application/octet-stream',
        filename=f'upscaled_{job_id}{extension}',
    )
|
|
|
|
|
|
@router.get('/jobs')
async def list_jobs(
    status: Optional[str] = None,
    limit: int = 100,
):
    """
    List all jobs, optionally filtered by status.

    Args:
        status: When non-empty, only jobs whose ``status`` equals this value
            are returned.
        limit: Maximum number of jobs to return, newest first. Negative
            values are treated as 0.

    Returns:
        dict with ``total`` (number of jobs known to the queue, before
        filtering), ``returned`` (number of summaries in the response) and
        ``jobs`` (the summaries).
    """
    wq = worker.get_worker_queue()
    all_jobs = wq.get_all_jobs()

    jobs = [
        {
            'job_id': job.job_id,
            'status': job.status,
            'model': job.model,
            'created_at': job.created_at,
            'processing_time_seconds': job.processing_time_seconds,
        }
        for job in all_jobs.values()
        if not status or job.status == status
    ]

    # Sort by creation time (newest first) and limit. Clamp the limit at 0:
    # a negative slice bound would mean "all but the last N", not a limit.
    jobs.sort(key=lambda x: x['created_at'], reverse=True)
    jobs = jobs[:max(limit, 0)]

    return {
        'total': len(all_jobs),
        'returned': len(jobs),
        'jobs': jobs,
    }
|