192 lines
5.8 KiB
Python
192 lines
5.8 KiB
Python
"""
Test suite for the Real-ESRGAN API.

Exercises the HTTP endpoints (health, models, system/stats, jobs, root)
through FastAPI's ``TestClient`` and unit-tests the ``file_manager`` and
``worker`` services.

Run with: pytest tests/
"""
|
|
import asyncio
|
|
import json
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from PIL import Image
|
|
|
|
from app.main import app
|
|
from app.services import file_manager, worker
|
|
|
|
# Single HTTP test client shared by every endpoint test in this module.
# NOTE(review): Starlette's TestClient only runs the app's startup/lifespan
# handlers when used as a context manager (``with TestClient(app) as c``);
# confirm the endpoints exercised below do not depend on them.
client = TestClient(app=app)
|
|
|
|
|
|
@pytest.fixture
def temp_dir():
    """Yield a fresh scratch directory as a ``Path``; it is deleted after the test."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        # Remove the directory and everything the test wrote into it.
        scratch.cleanup()
|
|
|
|
|
|
@pytest.fixture
def test_image(temp_dir):
    """Write a 256x256 solid-colour RGB JPEG into ``temp_dir`` and return its path."""
    image_path = temp_dir / 'test.jpg'
    Image.new('RGB', (256, 256), color=(73, 109, 137)).save(image_path)
    return image_path
|
|
|
|
|
|
class TestHealth:
    """Health-check and liveness-probe endpoints."""

    def test_health_check(self):
        """GET /api/v1/health reports status, version and uptime."""
        resp = client.get('/api/v1/health')
        assert resp.status_code == 200

        body = resp.json()
        assert body['status'] == 'ok'
        assert body['version'] == '1.0.0'
        assert 'uptime_seconds' in body

    def test_liveness(self):
        """GET /api/v1/health/live answers with ``alive`` set to True."""
        resp = client.get('/api/v1/health/live')
        assert resp.status_code == 200

        alive_flag = resp.json()['alive']
        assert alive_flag is True
|
|
|
|
|
|
class TestModels:
    """Model-catalogue endpoints."""

    def test_list_models(self):
        """GET /api/v1/models returns the model listing and counts."""
        resp = client.get('/api/v1/models')
        assert resp.status_code == 200

        body = resp.json()
        missing = [
            key
            for key in ('available_models', 'total_models', 'local_models')
            if key not in body
        ]
        assert not missing, missing
        assert isinstance(body['available_models'], list)

    def test_models_info(self):
        """GET /api/v1/models-info describes the models directory."""
        resp = client.get('/api/v1/models-info')
        assert resp.status_code == 200

        body = resp.json()
        missing = [
            key
            for key in ('models_directory', 'total_size_mb', 'model_count')
            if key not in body
        ]
        assert not missing, missing
|
|
|
|
|
|
class TestSystem:
    """System-information and statistics endpoints."""

    def test_system_info(self):
        """GET /api/v1/system reports ok status plus resource and provider data."""
        resp = client.get('/api/v1/system')
        assert resp.status_code == 200

        body = resp.json()
        assert body['status'] == 'ok'
        missing = [
            key
            for key in (
                'cpu_usage_percent',
                'memory_usage_percent',
                'disk_usage_percent',
                'execution_providers',
            )
            if key not in body
        ]
        assert not missing, missing

    def test_stats(self):
        """GET /api/v1/stats exposes the request counters."""
        resp = client.get('/api/v1/stats')
        assert resp.status_code == 200

        body = resp.json()
        missing = [
            key
            for key in ('total_requests', 'successful_requests', 'failed_requests')
            if key not in body
        ]
        assert not missing, missing
|
|
|
|
|
|
class TestFileManager:
    """Unit tests for the ``file_manager`` service helpers."""

    def test_ensure_directories(self, temp_dir):
        """ensure_directories() creates the configured upload and output dirs."""
        uploads = temp_dir / 'uploads'
        outputs = temp_dir / 'outputs'

        # Redirect both configured directories into the scratch dir for the
        # duration of this call only (one patcher instead of two nested ones).
        with patch.multiple(
            'app.services.file_manager.settings',
            upload_dir=str(uploads),
            output_dir=str(outputs),
        ):
            file_manager.ensure_directories()

        assert uploads.is_dir()
        assert outputs.is_dir()

    def test_generate_output_path(self):
        """Generated output path keeps the .jpg extension and is tagged 'upscaled'."""
        input_path = '/tmp/test.jpg'
        output_path = file_manager.generate_output_path(input_path)

        assert output_path.endswith('.jpg')
        assert 'upscaled' in output_path

    def test_cleanup_directory(self, temp_dir):
        """cleanup_directory() removes a non-empty directory."""
        test_dir = temp_dir / 'test_cleanup'
        test_dir.mkdir()
        (test_dir / 'file.txt').write_text('test')
        assert test_dir.exists()

        file_manager.cleanup_directory(str(test_dir))

        assert not test_dir.exists()
|
|
|
|
|
|
class TestWorker:
    """Unit tests for ``worker.Job``."""

    @staticmethod
    def _make_job():
        """Build the queued sample job shared by the tests below."""
        return worker.Job(
            job_id='test-id',
            status='queued',
            input_path='/tmp/input.jpg',
            output_path='/tmp/output.jpg',
            model='RealESRGAN_x4plus',
        )

    def test_job_creation(self):
        """A new Job exposes its id and status and serialises a creation time."""
        job = self._make_job()

        assert job.job_id == 'test-id'
        assert job.status == 'queued'
        assert 'created_at' in job.to_dict()

    def test_job_metadata_save(self, temp_dir):
        """save_metadata() writes <jobs_dir>/<job_id>/metadata.json."""
        # The job is built inside the patch in case Job resolves jobs_dir
        # when it is constructed rather than when it saves -- TODO confirm.
        with patch('app.services.worker.settings.jobs_dir', str(temp_dir)):
            job = self._make_job()
            job.save_metadata()

            metadata_file = temp_dir / 'test-id' / 'metadata.json'
            assert metadata_file.exists()

            data = json.loads(metadata_file.read_text(encoding='utf-8'))
            assert data['job_id'] == 'test-id'
            assert data['status'] == 'queued'
|
|
|
|
|
|
class TestJobEndpoints:
    """Job listing and lookup endpoints."""

    def test_list_jobs(self):
        """GET /api/v1/jobs returns the total, the page of jobs and its size."""
        resp = client.get('/api/v1/jobs')
        assert resp.status_code == 200

        body = resp.json()
        missing = [key for key in ('total', 'jobs', 'returned') if key not in body]
        assert not missing, missing

    def test_job_not_found(self):
        """Looking up an unknown job id yields HTTP 404."""
        resp = client.get('/api/v1/jobs/nonexistent-id')
        assert resp.status_code == 404
|
|
|
|
|
|
class TestRootEndpoint:
    """The API's root (/) endpoint."""

    def test_root(self):
        """GET / describes the service: name, version and endpoint list."""
        resp = client.get('/')
        assert resp.status_code == 200

        body = resp.json()
        missing = [key for key in ('name', 'version', 'endpoints') if key not in body]
        assert not missing, missing
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow `python <this file>` as a shortcut for running this module verbosely.
    # pytest.main() returns the exit status instead of exiting; propagate it
    # so a failing run is reported to the shell/CI as a non-zero exit code.
    raise SystemExit(pytest.main([__file__, '-v']))
|