Initial Real-ESRGAN API project setup
This commit contains the following new files:
1
app/routers/__init__.py
Normal file
1
app/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API routers."""
|
||||
146
app/routers/health.py
Normal file
146
app/routers/health.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Health check and system information endpoints."""
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.config import settings
|
||||
from app.schemas.health import HealthResponse, RequestStats, SystemInfo
|
||||
from app.services import file_manager, worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix='/api/v1', tags=['system'])
|
||||
|
||||
# Track uptime
|
||||
_start_time = time.time()
|
||||
|
||||
# Request statistics
|
||||
_stats = {
|
||||
'total_requests': 0,
|
||||
'successful_requests': 0,
|
||||
'failed_requests': 0,
|
||||
'total_processing_time': 0.0,
|
||||
'total_images_processed': 0,
|
||||
}
|
||||
|
||||
|
||||
@router.get('/health')
async def health_check() -> HealthResponse:
    """Basic API health check.

    Reports service status, API version, and seconds elapsed since the
    module was loaded.
    """
    return HealthResponse(
        status='ok',
        version='1.0.0',
        uptime_seconds=time.time() - _start_time,
        message='Real-ESRGAN API is running',
    )
|
||||
|
||||
|
||||
@router.get('/health/ready')
async def readiness_check():
    """Kubernetes readiness probe.

    Responds 503 until the Real-ESRGAN bridge reports it is initialized.
    """
    # Imported lazily so the probe does not pull in heavy model code at
    # module import time.
    from app.services import realesrgan_bridge

    if realesrgan_bridge.get_bridge().initialized:
        return {'ready': True}
    raise HTTPException(status_code=503, detail='Not ready')
|
||||
|
||||
|
||||
@router.get('/health/live')
async def liveness_check():
    """Kubernetes liveness probe — succeeds whenever the process is responsive."""
    return {'alive': True}
|
||||
|
||||
|
||||
def _probe_gpu():
    """Best-effort CUDA probe: returns (available, total_mb, used_mb); never raises."""
    try:
        import torch

        if torch.cuda.is_available():
            mb = 1024 * 1024
            total = int(torch.cuda.get_device_properties(0).total_memory / mb)
            used = int(torch.cuda.memory_allocated(0) / mb)
            return True, total, used
    except Exception:
        # torch missing or CUDA query failed — report "no GPU" rather than
        # failing the whole /system endpoint.
        pass
    return False, None, None


@router.get('/system')
async def get_system_info() -> SystemInfo:
    """Get comprehensive system information.

    Collects uptime, CPU/memory/disk utilization, GPU availability and
    memory, models-directory size, and the current job-queue length.

    Raises:
        HTTPException: 500 if any probe (other than the GPU probe, which is
            best-effort) fails.
    """
    try:
        uptime = time.time() - _start_time

        # interval=None returns utilization since the previous call without
        # sleeping; interval=1 would block the event loop for a full second
        # inside an async handler. (The very first call may report 0.0.)
        cpu_percent = psutil.cpu_percent(interval=None)
        memory_percent = psutil.virtual_memory().percent
        disk_percent = psutil.disk_usage('/').percent

        gpu_available, gpu_memory_mb, gpu_memory_used_mb = _probe_gpu()

        # Models directory size and queued-job count.
        models_size = file_manager.get_directory_size_mb(settings.models_dir)
        queue_length = worker.get_worker_queue().queue.qsize()

        return SystemInfo(
            status='ok',
            version='1.0.0',
            uptime_seconds=uptime,
            cpu_usage_percent=cpu_percent,
            memory_usage_percent=memory_percent,
            disk_usage_percent=disk_percent,
            gpu_available=gpu_available,
            gpu_memory_mb=gpu_memory_mb,
            gpu_memory_used_mb=gpu_memory_used_mb,
            execution_providers=settings.get_execution_providers(),
            models_dir_size_mb=models_size,
            jobs_queue_length=queue_length,
        )
    except Exception as e:
        logger.error(f'Failed to get system info: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get('/stats')
async def get_stats() -> RequestStats:
    """Get request statistics.

    The average processing time is computed over successful requests only;
    it is 0.0 when none have succeeded yet.
    """
    successes = _stats['successful_requests']
    avg_time = _stats['total_processing_time'] / successes if successes > 0 else 0.0

    return RequestStats(
        total_requests=_stats['total_requests'],
        successful_requests=_stats['successful_requests'],
        failed_requests=_stats['failed_requests'],
        average_processing_time_seconds=avg_time,
        total_images_processed=_stats['total_images_processed'],
    )
|
||||
|
||||
|
||||
@router.post('/cleanup')
async def cleanup_old_jobs(hours: int = 24):
    """Clean up job directories older than *hours*.

    Args:
        hours: Minimum age, in hours, of job directories to remove
            (default 24). Must be a positive integer.

    Raises:
        HTTPException: 400 for non-positive *hours*, 500 if cleanup fails.
    """
    # Guard against hours<=0, which would sweep away directories for jobs
    # that are still in flight.
    if hours < 1:
        raise HTTPException(status_code=400, detail='hours must be a positive integer')

    try:
        cleaned = file_manager.cleanup_old_jobs(hours)
        return {
            'success': True,
            'cleaned_jobs': cleaned,
            'message': f'Cleaned up {cleaned} job directories older than {hours} hours',
        }
    except Exception as e:
        logger.error(f'Cleanup failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
109
app/routers/models.py
Normal file
109
app/routers/models.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Model management endpoints."""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.models import ModelDownloadRequest, ModelDownloadResponse, ModelListResponse
|
||||
from app.services import model_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix='/api/v1', tags=['models'])
|
||||
|
||||
|
||||
@router.get('/models')
async def list_models() -> ModelListResponse:
    """List all available models.

    Returns every known model along with a count of how many are present
    locally (entries whose 'available' flag is set).
    """
    try:
        available = model_manager.get_available_models()
        return ModelListResponse(
            available_models=available,
            total_models=len(available),
            local_models=sum(1 for entry in available if entry['available']),
        )
    except Exception as e:
        logger.error(f'Failed to list models: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post('/models/download')
async def download_models(request: ModelDownloadRequest) -> ModelDownloadResponse:
    """Download one or more models.

    Returns a per-model breakdown of which downloads succeeded and which
    failed, with an error message for every failure.
    """
    if not request.models:
        raise HTTPException(status_code=400, detail='No models specified')

    try:
        logger.info(f'Downloading models: {request.models}')
        # Each result maps model name -> (success flag, message).
        results = await model_manager.download_models(request.models)

        downloaded = [name for name, (ok, _) in results.items() if ok]
        failed = [name for name, (ok, _) in results.items() if not ok]
        errors = {name: msg for name, (ok, msg) in results.items() if not ok}

        return ModelDownloadResponse(
            success=not failed,
            message=f'Downloaded {len(downloaded)} model(s)',
            downloaded=downloaded,
            failed=failed,
            errors=errors,
        )
    except Exception as e:
        logger.error(f'Model download failed: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get('/models/{model_name}')
async def get_model_info(model_name: str):
    """Get information about a specific model.

    Raises:
        HTTPException: 404 when no model matches *model_name*.
    """
    found = next(
        (entry for entry in model_manager.get_available_models()
         if entry['name'] == model_name),
        None,
    )
    if found is not None:
        return found

    raise HTTPException(status_code=404, detail=f'Model not found: {model_name}')
|
||||
|
||||
|
||||
@router.post('/models/{model_name}/download')
async def download_model(model_name: str):
    """Download a specific model by name.

    Raises:
        HTTPException: 500 when the download fails or errors unexpectedly.
    """
    try:
        ok, message = await model_manager.download_model(model_name)
        if not ok:
            raise HTTPException(status_code=500, detail=message)

        return {
            'success': True,
            'message': message,
            'model': model_name,
        }
    except HTTPException:
        # Re-raise the 500 above (or any other HTTPException) unwrapped.
        raise
    except Exception as e:
        logger.error(f'Failed to download model {model_name}: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get('/models-info')
async def get_models_directory_info():
    """Get information about the models directory.

    Reports the directory path, its total size in MB (rounded to two
    decimals), and the number of models it contains.
    """
    try:
        info = model_manager.get_models_directory_info()
    except Exception as e:
        logger.error(f'Failed to get models directory info: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

    return {
        'models_directory': info['path'],
        'total_size_mb': round(info['size_mb'], 2),
        'model_count': info['model_count'],
    }
|
||||
265
app/routers/upscale.py
Normal file
265
app/routers/upscale.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Upscaling endpoints."""
|
||||
import json
|
||||
import logging
|
||||
from time import time
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from app.schemas.upscale import UpscaleOptions
|
||||
from app.services import file_manager, realesrgan_bridge, worker
|
||||
from app.services.model_manager import is_model_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix='/api/v1', tags=['upscaling'])
|
||||
|
||||
|
||||
@router.post('/upscale')
async def upscale_sync(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Synchronous image upscaling.

    Upscales an image using Real-ESRGAN and returns the result directly.
    Suitable for small to medium images.

    NOTE(review): tile_size/tile_pad are accepted but not forwarded to
    bridge.upscale — confirm whether the bridge supports them (the batch
    and job endpoints do pass them to the worker).
    """
    # BackgroundTask ships with starlette, FastAPI's own core dependency;
    # imported locally to keep this module's top-level imports unchanged.
    from starlette.background import BackgroundTask

    request_dir = file_manager.create_request_dir()

    try:
        # Validate the model before doing any file work.
        if not is_model_available(model):
            raise HTTPException(
                status_code=400,
                detail=f'Model not available: {model}. Download it first using /api/v1/models/download'
            )

        # Save the upload and derive an output location beside it.
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)

        # Process
        start_time = time()
        bridge = realesrgan_bridge.get_bridge()
        success, message, output_size = bridge.upscale(
            input_path=input_path,
            output_path=output_path,
            model_name=model,
            outscale=outscale,
        )

        if not success:
            raise HTTPException(status_code=500, detail=message)

        processing_time = time() - start_time
        logger.info(f'Sync upscaling completed in {processing_time:.2f}s')

        # Clean up AFTER the response body has been sent: FileResponse opens
        # the file lazily when the response is streamed, so deleting
        # request_dir in a `finally` here would remove the output before the
        # client receives it.
        return FileResponse(
            path=output_path,
            media_type='application/octet-stream',
            filename=image.filename,
            headers={'X-Processing-Time': f'{processing_time:.2f}'},
            background=BackgroundTask(file_manager.cleanup_directory, request_dir),
        )
    except HTTPException:
        # No response will stream the files, so reclaim the directory now.
        file_manager.cleanup_directory(request_dir)
        raise
    except Exception as e:
        logger.error(f'Upscaling failed: {e}', exc_info=True)
        file_manager.cleanup_directory(request_dir)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post('/upscale-batch')
async def upscale_batch(
    images: List[UploadFile] = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
):
    """
    Batch upscaling via async jobs.

    Submit multiple images for upscaling. Returns job IDs for monitoring.

    Raises:
        HTTPException: 400 for an empty batch, more than 100 images, or an
            unavailable model; 500 if job submission fails.
    """
    if not images:
        raise HTTPException(status_code=400, detail='No images provided')

    if len(images) > 100:
        raise HTTPException(status_code=400, detail='Maximum 100 images per request')

    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )

    request_dir = file_manager.create_request_dir()
    job_ids = []

    try:
        # Save all images, then queue one job per saved file.
        input_paths = await file_manager.save_uploads(images, request_dir)
        wq = worker.get_worker_queue()

        for input_path in input_paths:
            output_path = file_manager.generate_output_path(input_path, f'_upscaled_{model}')
            job_id = wq.submit_job(
                input_path=input_path,
                output_path=output_path,
                model=model,
                tile_size=tile_size,
                tile_pad=tile_pad,
            )
            job_ids.append(job_id)

        logger.info(f'Submitted batch of {len(job_ids)} upscaling jobs')

        return {
            'success': True,
            'job_ids': job_ids,
            'total': len(job_ids),
            'message': f'Batch processing started for {len(job_ids)} images',
        }
    except Exception as e:
        logger.error(f'Batch submission failed: {e}', exc_info=True)
        if not job_ids:
            # Nothing was queued, so no worker will ever read these uploads;
            # reclaim the directory instead of leaking it. If some jobs were
            # queued, the directory must survive for the workers to use.
            file_manager.cleanup_directory(request_dir)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post('/jobs')
async def create_job(
    image: UploadFile = File(...),
    model: str = Form('RealESRGAN_x4plus'),
    tile_size: Optional[int] = Form(None),
    tile_pad: Optional[int] = Form(None),
    outscale: Optional[float] = Form(None),
):
    """
    Create an async upscaling job.

    Submit a single image for asynchronous upscaling.
    Use /api/v1/jobs/{job_id} to check status and download result.

    Raises:
        HTTPException: 400 when the model is unavailable, 500 when saving
            the upload or queueing the job fails.
    """
    if not is_model_available(model):
        raise HTTPException(
            status_code=400,
            detail=f'Model not available: {model}'
        )

    request_dir = file_manager.create_request_dir()

    try:
        # Save the upload and derive an output path for the worker to write.
        input_path = await file_manager.save_upload(image, request_dir)
        output_path = file_manager.generate_output_path(input_path)

        # Queue the job for background processing.
        wq = worker.get_worker_queue()
        job_id = wq.submit_job(
            input_path=input_path,
            output_path=output_path,
            model=model,
            tile_size=tile_size,
            tile_pad=tile_pad,
            outscale=outscale,
        )

        return {
            'success': True,
            'job_id': job_id,
            'status_url': f'/api/v1/jobs/{job_id}',
            'result_url': f'/api/v1/jobs/{job_id}/result',
        }
    except Exception as e:
        logger.error(f'Job creation failed: {e}', exc_info=True)
        # The job was never queued (failures can only happen before/at
        # submit_job), so the working directory is orphaned — reclaim it.
        file_manager.cleanup_directory(request_dir)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get('/jobs/{job_id}')
async def get_job_status(job_id: str):
    """Get status of an upscaling job.

    Raises:
        HTTPException: 404 when the job id is unknown.
    """
    job = worker.get_worker_queue().get_job(job_id)

    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')

    return {
        'job_id': job.job_id,
        'status': job.status,
        'model': job.model,
        'created_at': job.created_at,
        'started_at': job.started_at,
        'completed_at': job.completed_at,
        'processing_time_seconds': job.processing_time_seconds,
        'error': job.error,
    }
|
||||
|
||||
|
||||
@router.get('/jobs/{job_id}/result')
async def get_job_result(job_id: str):
    """Download result of a completed upscaling job.

    Raises:
        HTTPException: 404 for an unknown job or a missing result file,
            202 while the job is queued/processing, 500 when the job
            failed, 400 for any other non-completed status.
    """
    # Plain local import replaces the original `__import__('os')`
    # obfuscation; this module has no top-level `os` import.
    import os

    wq = worker.get_worker_queue()
    job = wq.get_job(job_id)

    if not job:
        raise HTTPException(status_code=404, detail=f'Job not found: {job_id}')

    if job.status in ('queued', 'processing'):
        raise HTTPException(
            status_code=202,
            detail=f'Job is still processing: {job.status}'
        )

    if job.status == 'failed':
        raise HTTPException(status_code=500, detail=f'Job failed: {job.error}')

    if job.status != 'completed':
        raise HTTPException(status_code=400, detail=f'Job status: {job.status}')

    if not job.output_path or not os.path.exists(job.output_path):
        raise HTTPException(status_code=404, detail='Result file not found')

    return FileResponse(
        path=job.output_path,
        media_type='application/octet-stream',
        filename=f'upscaled_{job_id}.png',
    )
|
||||
|
||||
|
||||
@router.get('/jobs')
async def list_jobs(
    status: Optional[str] = None,
    limit: int = 100,
):
    """List all jobs, optionally filtered by status.

    Results are sorted newest-first by creation time and truncated to
    *limit* entries; `total` reflects all jobs before filtering.
    """
    all_jobs = worker.get_worker_queue().get_all_jobs()

    summaries = [
        {
            'job_id': job.job_id,
            'status': job.status,
            'model': job.model,
            'created_at': job.created_at,
            'processing_time_seconds': job.processing_time_seconds,
        }
        for job in all_jobs.values()
        if not (status and job.status != status)
    ]

    # Newest first, capped at `limit`.
    summaries.sort(key=lambda entry: entry['created_at'], reverse=True)
    summaries = summaries[:limit]

    return {
        'total': len(all_jobs),
        'returned': len(summaries),
        'jobs': summaries,
    }
|
||||
Reference in New Issue
Block a user