260 lines
7.9 KiB
Python
260 lines
7.9 KiB
Python
"""
|
|
Real-ESRGAN API Python Client
|
|
|
|
Example usage:
|
|
client = RealESRGANClient('http://localhost:8000')
|
|
|
|
# Synchronous upscaling
|
|
result = await client.upscale_sync('input.jpg', 'RealESRGAN_x4plus', 'output.jpg')
|
|
|
|
# Asynchronous job
|
|
job_id = await client.create_job('input.jpg', 'RealESRGAN_x4plus')
|
|
while True:
|
|
status = await client.get_job_status(job_id)
|
|
if status['status'] == 'completed':
|
|
await client.download_result(job_id, 'output.jpg')
|
|
break
|
|
await asyncio.sleep(5)
|
|
"""
|
|
import asyncio
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RealESRGANClient:
    """Async Python client for the Real-ESRGAN HTTP API.

    Wraps an ``httpx.AsyncClient`` bound to ``base_url``. Use it as an async
    context manager (``async with RealESRGANClient() as client: ...``) or call
    :meth:`close` explicitly so the underlying HTTP client is released.

    Every request method calls ``raise_for_status()`` on the response, so
    error status codes surface as ``httpx.HTTPStatusError``.
    """

    def __init__(self, base_url: str = 'http://localhost:8000', timeout: float = 300):
        """
        Initialize client.

        Args:
            base_url: API base URL; any trailing slash is stripped.
            timeout: Request timeout in seconds, passed to ``httpx.AsyncClient``.
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.client = httpx.AsyncClient(base_url=self.base_url, timeout=timeout)

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self.client.aclose()

    async def __aenter__(self) -> 'RealESRGANClient':
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        # Always release the HTTP client; exceptions from the body still propagate.
        await self.close()

    async def health_check(self) -> dict:
        """Return the decoded JSON body of ``GET /api/v1/health``."""
        response = await self.client.get('/api/v1/health')
        response.raise_for_status()
        return response.json()

    async def get_system_info(self) -> dict:
        """Return the decoded JSON body of ``GET /api/v1/system``."""
        response = await self.client.get('/api/v1/system')
        response.raise_for_status()
        return response.json()

    async def list_models(self) -> dict:
        """Return the decoded JSON body of ``GET /api/v1/models``."""
        response = await self.client.get('/api/v1/models')
        response.raise_for_status()
        return response.json()

    async def download_models(self, model_names: list[str]) -> dict:
        """
        Ask the server to download the given models.

        Args:
            model_names: Model names, sent as ``{"models": [...]}``.

        Returns:
            The decoded JSON response.
        """
        response = await self.client.post(
            '/api/v1/models/download',
            json={'models': model_names},
        )
        response.raise_for_status()
        return response.json()

    async def _post_image(self, endpoint: str, input_path: str, data: dict):
        """
        Upload ``input_path`` as multipart field ``image`` together with form ``data``.

        Args:
            endpoint: API path to POST to.
            input_path: Path to the local image file.
            data: Extra form fields sent alongside the image.

        Returns:
            The ``httpx.Response`` after ``raise_for_status()`` succeeded.

        Raises:
            FileNotFoundError: If ``input_path`` does not exist.
        """
        input_file = Path(input_path)
        if not input_file.exists():
            raise FileNotFoundError(f'Input file not found: {input_path}')

        # ``with`` closes the handle even if the request itself raises.
        with input_file.open('rb') as image:
            response = await self.client.post(endpoint, files={'image': image}, data=data)
        response.raise_for_status()
        return response

    async def upscale_sync(
        self,
        input_path: str,
        model: str = 'RealESRGAN_x4plus',
        output_path: Optional[str] = None,
        tile_size: Optional[int] = None,
    ) -> dict:
        """
        Upscale an image in a single request and wait for the result.

        The whole upscaled image is read into memory. It is not streamed.

        Args:
            input_path: Path to input image.
            model: Model name to use.
            output_path: Where to save the output. If falsy, the image bytes are
                returned under ``'content'`` instead.
            tile_size: Optional tile size override.

        Returns:
            ``{'success': True, 'output_path' | 'content': ..., 'processing_time': float}``.
            ``processing_time`` comes from the ``X-Processing-Time`` response
            header and is ``0.0`` when the header is absent.

        Raises:
            FileNotFoundError: If ``input_path`` does not exist.
        """
        data: dict = {'model': model}
        if tile_size is not None:
            data['tile_size'] = tile_size

        response = await self._post_image('/api/v1/upscale', input_path, data)
        processing_time = float(response.headers.get('X-Processing-Time', 0))

        if output_path:
            Path(output_path).write_bytes(response.content)
            return {
                'success': True,
                'output_path': output_path,
                'processing_time': processing_time,
            }
        return {
            'success': True,
            'content': response.content,
            'processing_time': processing_time,
        }

    async def create_job(
        self,
        input_path: str,
        model: str = 'RealESRGAN_x4plus',
        tile_size: Optional[int] = None,
        outscale: Optional[float] = None,
    ) -> str:
        """
        Create an asynchronous upscaling job.

        Args:
            input_path: Path to input image.
            model: Model name to use.
            tile_size: Optional tile size.
            outscale: Optional output scale.

        Returns:
            The ``job_id`` field of the server's JSON response.

        Raises:
            FileNotFoundError: If ``input_path`` does not exist.
        """
        data: dict = {'model': model}
        if tile_size is not None:
            data['tile_size'] = tile_size
        if outscale is not None:
            data['outscale'] = outscale

        response = await self._post_image('/api/v1/jobs', input_path, data)
        return response.json()['job_id']

    async def get_job_status(self, job_id: str) -> dict:
        """Return the decoded JSON body of ``GET /api/v1/jobs/{job_id}``."""
        response = await self.client.get(f'/api/v1/jobs/{job_id}')
        response.raise_for_status()
        return response.json()

    async def download_result(self, job_id: str, output_path: str) -> bool:
        """
        Download a job's result image and write it to ``output_path``.

        Returns:
            Always ``True``. Failures raise instead.
        """
        response = await self.client.get(f'/api/v1/jobs/{job_id}/result')
        response.raise_for_status()
        Path(output_path).write_bytes(response.content)
        return True

    async def wait_for_job(
        self,
        job_id: str,
        poll_interval: float = 5,
        max_wait: Optional[float] = None,
    ) -> dict:
        """
        Poll a job until its status is ``'completed'`` or ``'failed'``.

        Args:
            job_id: Job ID to wait for.
            poll_interval: Seconds between status checks.
            max_wait: Maximum seconds to wait. ``None`` waits indefinitely;
                ``0`` checks the status once and then times out.

        Returns:
            The final job status dict.

        Raises:
            TimeoutError: If the job is still not finished after ``max_wait`` seconds.
        """
        # The event-loop clock is monotonic, so wall-clock jumps cannot distort the timeout.
        loop = asyncio.get_running_loop()
        deadline = None if max_wait is None else loop.time() + max_wait

        while True:
            status = await self.get_job_status(job_id)
            if status['status'] in ('completed', 'failed'):
                return status

            if deadline is None:
                await asyncio.sleep(poll_interval)
                continue

            remaining = deadline - loop.time()
            if remaining <= 0:
                raise TimeoutError(f'Job {job_id} did not complete within {max_wait}s')
            # Don't oversleep the deadline; do one last status check right at it.
            await asyncio.sleep(min(poll_interval, remaining))

    async def list_jobs(self, status: Optional[str] = None, limit: int = 100) -> dict:
        """
        List jobs, optionally filtered by status.

        Args:
            status: Only include jobs with this status (ignored when falsy).
            limit: Maximum number of jobs to return.
        """
        params: dict = {'limit': limit}
        if status:
            params['status'] = status

        response = await self.client.get('/api/v1/jobs', params=params)
        response.raise_for_status()
        return response.json()

    async def cleanup_jobs(self, hours: int = 24) -> dict:
        """
        Ask the server to clean up old jobs.

        Args:
            hours: Sent as the ``hours`` query parameter.
        """
        # Let httpx encode the query string, as list_jobs does.
        response = await self.client.post('/api/v1/cleanup', params={'hours': hours})
        response.raise_for_status()
        return response.json()
|
|
|
|
|
|
async def main():
    """Demonstrate the client against an API on the default base URL.

    Prints the API health status, then the names of the available models.
    The commented-out snippets sketch the single-request and job-based
    upscaling flows; they need a real ``input.jpg`` to run.
    """
    async with RealESRGANClient() as api:
        # Health probe
        health_info = await api.health_check()
        print(f'API Status: {health_info["status"]}')

        # Model catalogue
        catalog = await api.list_models()
        model_names = [entry["name"] for entry in catalog["available_models"]]
        print(f'Available Models: {model_names}')

        # Single-request upscaling:
        # result = await api.upscale_sync('input.jpg', output_path='output.jpg')
        # print(f'Upscaled in {result["processing_time"]:.2f}s')

        # Job-based upscaling:
        # job_id = await api.create_job('input.jpg')
        # final_status = await api.wait_for_job(job_id)
        # if final_status['status'] == 'completed':
        #     await api.download_result(job_id, 'output.jpg')
        # print(f'Job finished: {final_status}')
|
|
|
|
|
|
if __name__ == '__main__':
    # Configure root logging at INFO level before running the demo coroutine.
    logging.basicConfig(level=logging.INFO)
    demo = main()
    asyncio.run(demo)
|