"""Upscaling endpoints.""" import json import logging from time import time from typing import List, Optional from fastapi import APIRouter, File, Form, HTTPException, UploadFile from fastapi.responses import FileResponse from app.schemas.upscale import UpscaleOptions from app.services import file_manager, realesrgan_bridge, worker from app.services.model_manager import is_model_available logger = logging.getLogger(__name__) router = APIRouter(prefix='/api/v1', tags=['upscaling']) @router.post('/upscale') async def upscale_sync( image: UploadFile = File(...), model: str = Form('RealESRGAN_x4plus'), tile_size: Optional[int] = Form(None), tile_pad: Optional[int] = Form(None), outscale: Optional[float] = Form(None), ): """ Synchronous image upscaling. Upscales an image using Real-ESRGAN and returns the result directly. Suitable for small to medium images. """ request_dir = file_manager.create_request_dir() try: # Validate model if not is_model_available(model): raise HTTPException( status_code=400, detail=f'Model not available: {model}. Download it first using /api/v1/models/download' ) # Save upload input_path = await file_manager.save_upload(image, request_dir) output_path = file_manager.generate_output_path(input_path) # Process start_time = time() bridge = realesrgan_bridge.get_bridge() success, message, output_size = bridge.upscale( input_path=input_path, output_path=output_path, model_name=model, outscale=outscale, ) if not success: raise HTTPException(status_code=500, detail=message) processing_time = time() - start_time logger.info(f'Sync upscaling completed in {processing_time:.2f}s') return FileResponse( path=output_path, media_type='application/octet-stream', filename=image.filename, headers={'X-Processing-Time': f'{processing_time:.2f}'}, ) except HTTPException: raise except Exception as e: logger.error(f'Upscaling failed: {e}', exc_info=True) raise HTTPException(status_code=500, detail=str(e)) finally: file_manager.cleanup_directory(request_dir) @router.post('/upscale-batch') async def upscale_batch( images: List[UploadFile] = File(...), model: str = Form('RealESRGAN_x4plus'), tile_size: Optional[int] = Form(None), tile_pad: Optional[int] = Form(None), ): """ Batch upscaling via async jobs. Submit multiple images for upscaling. Returns job IDs for monitoring. """ if not images: raise HTTPException(status_code=400, detail='No images provided') if len(images) > 100: raise HTTPException(status_code=400, detail='Maximum 100 images per request') if not is_model_available(model): raise HTTPException( status_code=400, detail=f'Model not available: {model}' ) request_dir = file_manager.create_request_dir() job_ids = [] try: # Save all images input_paths = await file_manager.save_uploads(images, request_dir) wq = worker.get_worker_queue() # Submit jobs for input_path in input_paths: output_path = file_manager.generate_output_path(input_path, f'_upscaled_{model}') job_id = wq.submit_job( input_path=input_path, output_path=output_path, model=model, tile_size=tile_size, tile_pad=tile_pad, ) job_ids.append(job_id) logger.info(f'Submitted batch of {len(job_ids)} upscaling jobs') return { 'success': True, 'job_ids': job_ids, 'total': len(job_ids), 'message': f'Batch processing started for {len(job_ids)} images', } except Exception as e: logger.error(f'Batch submission failed: {e}', exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.post('/jobs') async def create_job( image: UploadFile = File(...), model: str = Form('RealESRGAN_x4plus'), tile_size: Optional[int] = Form(None), tile_pad: Optional[int] = Form(None), outscale: Optional[float] = Form(None), ): """ Create an async upscaling job. Submit a single image for asynchronous upscaling. Use /api/v1/jobs/{job_id} to check status and download result. """ if not is_model_available(model): raise HTTPException( status_code=400, detail=f'Model not available: {model}' ) request_dir = file_manager.create_request_dir() try: # Save upload input_path = await file_manager.save_upload(image, request_dir) output_path = file_manager.generate_output_path(input_path) # Submit job wq = worker.get_worker_queue() job_id = wq.submit_job( input_path=input_path, output_path=output_path, model=model, tile_size=tile_size, tile_pad=tile_pad, outscale=outscale, ) return { 'success': True, 'job_id': job_id, 'status_url': f'/api/v1/jobs/{job_id}', 'result_url': f'/api/v1/jobs/{job_id}/result', } except Exception as e: logger.error(f'Job creation failed: {e}', exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.get('/jobs/{job_id}') async def get_job_status(job_id: str): """Get status of an upscaling job.""" wq = worker.get_worker_queue() job = wq.get_job(job_id) if not job: raise HTTPException(status_code=404, detail=f'Job not found: {job_id}') return { 'job_id': job.job_id, 'status': job.status, 'model': job.model, 'created_at': job.created_at, 'started_at': job.started_at, 'completed_at': job.completed_at, 'processing_time_seconds': job.processing_time_seconds, 'error': job.error, } @router.get('/jobs/{job_id}/result') async def get_job_result(job_id: str): """Download result of a completed upscaling job.""" wq = worker.get_worker_queue() job = wq.get_job(job_id) if not job: raise HTTPException(status_code=404, detail=f'Job not found: {job_id}') if job.status == 'queued' or job.status == 'processing': raise HTTPException( status_code=202, detail=f'Job is still processing: {job.status}' ) if job.status == 'failed': raise HTTPException(status_code=500, detail=f'Job failed: {job.error}') if job.status != 'completed': raise HTTPException(status_code=400, detail=f'Job status: {job.status}') if not job.output_path or not __import__('os').path.exists(job.output_path): raise HTTPException(status_code=404, detail='Result file not found') return FileResponse( path=job.output_path, media_type='application/octet-stream', filename=f'upscaled_{job_id}.png', ) @router.get('/jobs') async def list_jobs( status: Optional[str] = None, limit: int = 100, ): """List all jobs, optionally filtered by status.""" wq = worker.get_worker_queue() all_jobs = wq.get_all_jobs() jobs = [] for job in all_jobs.values(): if status and job.status != status: continue jobs.append({ 'job_id': job.job_id, 'status': job.status, 'model': job.model, 'created_at': job.created_at, 'processing_time_seconds': job.processing_time_seconds, }) # Sort by creation time (newest first) and limit jobs.sort(key=lambda x: x['created_at'], reverse=True) jobs = jobs[:limit] return { 'total': len(all_jobs), 'returned': len(jobs), 'jobs': jobs, }