Initial Real-ESRGAN API project setup
This commit is contained in:
1
app/services/__init__.py
Normal file
1
app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API services."""
|
||||
126
app/services/file_manager.py
Normal file
126
app/services/file_manager.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""File management utilities."""
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ensure_directories() -> None:
|
||||
"""Ensure all required directories exist."""
|
||||
for path in (settings.upload_dir, settings.output_dir, settings.models_dir,
|
||||
settings.temp_dir, settings.jobs_dir):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
logger.info(f'Directory ensured: {path}')
|
||||
|
||||
|
||||
def create_request_dir() -> str:
|
||||
"""Create a unique request directory."""
|
||||
request_id = str(uuid.uuid4())
|
||||
request_dir = os.path.join(settings.upload_dir, request_id)
|
||||
os.makedirs(request_dir, exist_ok=True)
|
||||
return request_dir
|
||||
|
||||
|
||||
async def save_upload(file: UploadFile, directory: str) -> str:
|
||||
"""Save uploaded file to directory."""
|
||||
ext = os.path.splitext(file.filename or '')[1] or '.jpg'
|
||||
filename = f'{uuid.uuid4()}{ext}'
|
||||
filepath = os.path.join(directory, filename)
|
||||
|
||||
with open(filepath, 'wb') as f:
|
||||
while chunk := await file.read(1024 * 1024):
|
||||
f.write(chunk)
|
||||
|
||||
logger.debug(f'File saved: {filepath}')
|
||||
return filepath
|
||||
|
||||
|
||||
async def save_uploads(files: List[UploadFile], directory: str) -> List[str]:
|
||||
"""Save multiple uploaded files to directory."""
|
||||
paths = []
|
||||
for file in files:
|
||||
path = await save_upload(file, directory)
|
||||
paths.append(path)
|
||||
return paths
|
||||
|
||||
|
||||
def generate_output_path(input_path: str, suffix: str = '_upscaled') -> str:
|
||||
"""Generate output path for processed image."""
|
||||
base, ext = os.path.splitext(input_path)
|
||||
name = os.path.basename(base)
|
||||
filename = f'{name}{suffix}{ext}'
|
||||
return os.path.join(settings.output_dir, filename)
|
||||
|
||||
|
||||
def cleanup_directory(directory: str) -> None:
|
||||
"""Remove directory and all contents."""
|
||||
if os.path.isdir(directory):
|
||||
shutil.rmtree(directory, ignore_errors=True)
|
||||
logger.debug(f'Cleaned up directory: {directory}')
|
||||
|
||||
|
||||
def cleanup_file(filepath: str) -> None:
|
||||
"""Remove a file."""
|
||||
if os.path.isfile(filepath):
|
||||
os.remove(filepath)
|
||||
logger.debug(f'Cleaned up file: {filepath}')
|
||||
|
||||
|
||||
def get_directory_size_mb(directory: str) -> float:
|
||||
"""Get total size of directory in MB."""
|
||||
total = 0
|
||||
for dirpath, dirnames, filenames in os.walk(directory):
|
||||
for f in filenames:
|
||||
fp = os.path.join(dirpath, f)
|
||||
if os.path.exists(fp):
|
||||
total += os.path.getsize(fp)
|
||||
return total / (1024 * 1024)
|
||||
|
||||
|
||||
def list_model_files() -> List[Tuple[str, str, int]]:
|
||||
"""Return list of (name, path, size_bytes) for all .pth/.onnx files in models dir."""
|
||||
models = []
|
||||
models_dir = settings.models_dir
|
||||
if not os.path.isdir(models_dir):
|
||||
return models
|
||||
|
||||
for name in sorted(os.listdir(models_dir)):
|
||||
if name.endswith(('.pth', '.onnx', '.pt', '.safetensors')):
|
||||
path = os.path.join(models_dir, name)
|
||||
try:
|
||||
size = os.path.getsize(path)
|
||||
models.append((name, path, size))
|
||||
except OSError:
|
||||
logger.warning(f'Could not get size of model: {path}')
|
||||
return models
|
||||
|
||||
|
||||
def cleanup_old_jobs(hours: int = 24) -> int:
|
||||
"""Clean up old job directories (older than specified hours)."""
|
||||
import time
|
||||
cutoff_time = time.time() - (hours * 3600)
|
||||
cleaned = 0
|
||||
|
||||
if not os.path.isdir(settings.jobs_dir):
|
||||
return cleaned
|
||||
|
||||
for item in os.listdir(settings.jobs_dir):
|
||||
item_path = os.path.join(settings.jobs_dir, item)
|
||||
if os.path.isdir(item_path):
|
||||
try:
|
||||
if os.path.getmtime(item_path) < cutoff_time:
|
||||
cleanup_directory(item_path)
|
||||
cleaned += 1
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info(f'Cleaned up {cleaned} old job directories')
|
||||
return cleaned
|
||||
154
app/services/model_manager.py
Normal file
154
app/services/model_manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Model management utilities."""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.services import file_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Known models for Easy Real-ESRGAN
|
||||
KNOWN_MODELS = {
|
||||
'RealESRGAN_x2plus': {
|
||||
'scale': 2,
|
||||
'description': '2x upscaling',
|
||||
'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
||||
'size_mb': 66.7,
|
||||
},
|
||||
'RealESRGAN_x3plus': {
|
||||
'scale': 3,
|
||||
'description': '3x upscaling',
|
||||
'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x3plus.pth',
|
||||
'size_mb': 101.7,
|
||||
},
|
||||
'RealESRGAN_x4plus': {
|
||||
'scale': 4,
|
||||
'description': '4x upscaling (general purpose)',
|
||||
'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth',
|
||||
'size_mb': 101.7,
|
||||
},
|
||||
'RealESRGAN_x4plus_anime_6B': {
|
||||
'scale': 4,
|
||||
'description': '4x upscaling (anime/art)',
|
||||
'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2/RealESRGAN_x4plus_anime_6B.pth',
|
||||
'size_mb': 18.9,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_available_models() -> List[Dict]:
|
||||
"""Get list of available models with details."""
|
||||
models = []
|
||||
|
||||
for model_name, metadata in KNOWN_MODELS.items():
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
available = os.path.exists(model_path)
|
||||
|
||||
size_mb = metadata['size_mb']
|
||||
if available:
|
||||
try:
|
||||
actual_size = os.path.getsize(model_path) / (1024 * 1024)
|
||||
size_mb = actual_size
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
models.append({
|
||||
'name': model_name,
|
||||
'scale': metadata['scale'],
|
||||
'description': metadata['description'],
|
||||
'available': available,
|
||||
'size_mb': size_mb,
|
||||
'size_bytes': int(size_mb * 1024 * 1024),
|
||||
})
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def is_model_available(model_name: str) -> bool:
|
||||
"""Check if a model is available locally."""
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
return os.path.exists(model_path)
|
||||
|
||||
|
||||
def get_model_scale(model_name: str) -> Optional[int]:
|
||||
"""Get upscaling factor for a model."""
|
||||
if model_name in KNOWN_MODELS:
|
||||
return KNOWN_MODELS[model_name]['scale']
|
||||
|
||||
# Try to infer from model name
|
||||
if 'x2' in model_name.lower():
|
||||
return 2
|
||||
elif 'x3' in model_name.lower():
|
||||
return 3
|
||||
elif 'x4' in model_name.lower():
|
||||
return 4
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def download_model(model_name: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Download a model.
|
||||
|
||||
Returns: (success, message)
|
||||
"""
|
||||
if model_name not in KNOWN_MODELS:
|
||||
return False, f'Unknown model: {model_name}'
|
||||
|
||||
if is_model_available(model_name):
|
||||
return True, f'Model already available: {model_name}'
|
||||
|
||||
metadata = KNOWN_MODELS[model_name]
|
||||
url = metadata['url']
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
|
||||
try:
|
||||
logger.info(f'Downloading model: {model_name} from {url}')
|
||||
|
||||
import urllib.request
|
||||
os.makedirs(settings.models_dir, exist_ok=True)
|
||||
|
||||
def download_progress(count, block_size, total_size):
|
||||
downloaded = count * block_size
|
||||
percent = min(downloaded * 100 / total_size, 100)
|
||||
logger.debug(f'Download progress: {percent:.1f}%')
|
||||
|
||||
urllib.request.urlretrieve(url, model_path, download_progress)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
size_mb = os.path.getsize(model_path) / (1024 * 1024)
|
||||
logger.info(f'Model downloaded: {model_name} ({size_mb:.1f} MB)')
|
||||
return True, f'Model downloaded: {model_name} ({size_mb:.1f} MB)'
|
||||
else:
|
||||
return False, f'Failed to download model: {model_name}'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to download model {model_name}: {e}', exc_info=True)
|
||||
return False, f'Download failed: {str(e)}'
|
||||
|
||||
|
||||
async def download_models(model_names: List[str]) -> Dict[str, tuple[bool, str]]:
|
||||
"""Download multiple models."""
|
||||
results = {}
|
||||
|
||||
for model_name in model_names:
|
||||
success, message = await download_model(model_name)
|
||||
results[model_name] = (success, message)
|
||||
logger.info(f'Download result for {model_name}: {success} - {message}')
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_models_directory_info() -> Dict:
|
||||
"""Get information about the models directory."""
|
||||
models_dir = settings.models_dir
|
||||
|
||||
return {
|
||||
'path': models_dir,
|
||||
'size_mb': file_manager.get_directory_size_mb(models_dir),
|
||||
'model_count': len([f for f in os.listdir(models_dir) if f.endswith('.pth')])
|
||||
if os.path.isdir(models_dir) else 0,
|
||||
}
|
||||
200
app/services/realesrgan_bridge.py
Normal file
200
app/services/realesrgan_bridge.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Real-ESRGAN model management and processing."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
REALESRGAN_AVAILABLE = True
|
||||
except ImportError:
|
||||
REALESRGAN_AVAILABLE = False
|
||||
logger.warning('Real-ESRGAN not available. Install via: pip install realesrgan')
|
||||
|
||||
|
||||
class RealESRGANBridge:
|
||||
"""Bridge to Real-ESRGAN functionality."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Real-ESRGAN bridge."""
|
||||
self.upsampler: Optional[RealESRGANer] = None
|
||||
self.current_model: Optional[str] = None
|
||||
self.initialized = False
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Initialize Real-ESRGAN upsampler."""
|
||||
if not REALESRGAN_AVAILABLE:
|
||||
logger.error('Real-ESRGAN library not available')
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info('Initializing Real-ESRGAN upsampler...')
|
||||
|
||||
# Setup model loader
|
||||
scale = 4
|
||||
model_name = settings.default_model
|
||||
|
||||
# Determine model path
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
if not os.path.exists(model_path):
|
||||
logger.warning(f'Model not found at {model_path}, will attempt to auto-download')
|
||||
|
||||
# Load model
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale
|
||||
)
|
||||
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=scale,
|
||||
model_path=model_path if os.path.exists(model_path) else None,
|
||||
model=model,
|
||||
tile=settings.tile_size,
|
||||
tile_pad=settings.tile_pad,
|
||||
pre_pad=0,
|
||||
half=('cuda' in settings.get_execution_providers()),
|
||||
)
|
||||
|
||||
self.current_model = model_name
|
||||
self.initialized = True
|
||||
logger.info(f'Real-ESRGAN initialized with model: {model_name}')
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to initialize Real-ESRGAN: {e}', exc_info=True)
|
||||
return False
|
||||
|
||||
def load_model(self, model_name: str) -> bool:
|
||||
"""Load a specific upscaling model."""
|
||||
try:
|
||||
if not REALESRGAN_AVAILABLE:
|
||||
logger.error('Real-ESRGAN not available')
|
||||
return False
|
||||
|
||||
logger.info(f'Loading model: {model_name}')
|
||||
|
||||
# Extract scale from model name
|
||||
scale = 4
|
||||
if 'x2' in model_name.lower():
|
||||
scale = 2
|
||||
elif 'x3' in model_name.lower():
|
||||
scale = 3
|
||||
elif 'x4' in model_name.lower():
|
||||
scale = 4
|
||||
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f'Model file not found: {model_path}')
|
||||
return False
|
||||
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale
|
||||
)
|
||||
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=scale,
|
||||
model_path=model_path,
|
||||
model=model,
|
||||
tile=settings.tile_size,
|
||||
tile_pad=settings.tile_pad,
|
||||
pre_pad=0,
|
||||
half=('cuda' in settings.get_execution_providers()),
|
||||
)
|
||||
|
||||
self.current_model = model_name
|
||||
logger.info(f'Model loaded: {model_name}')
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to load model {model_name}: {e}', exc_info=True)
|
||||
return False
|
||||
|
||||
def upscale(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
model_name: Optional[str] = None,
|
||||
outscale: Optional[float] = None,
|
||||
) -> Tuple[bool, str, Optional[Tuple[int, int]]]:
|
||||
"""
|
||||
Upscale an image.
|
||||
|
||||
Returns: (success, message, output_size)
|
||||
"""
|
||||
try:
|
||||
if not self.initialized:
|
||||
if not self.initialize():
|
||||
return False, 'Failed to initialize Real-ESRGAN', None
|
||||
|
||||
if model_name and model_name != self.current_model:
|
||||
if not self.load_model(model_name):
|
||||
return False, f'Failed to load model: {model_name}', None
|
||||
|
||||
if not self.upsampler:
|
||||
return False, 'Upsampler not initialized', None
|
||||
|
||||
# Read image
|
||||
logger.info(f'Reading image: {input_path}')
|
||||
input_img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
|
||||
if input_img is None:
|
||||
return False, f'Failed to read image: {input_path}', None
|
||||
|
||||
input_shape = input_img.shape[:2]
|
||||
logger.info(f'Input image shape: {input_shape}')
|
||||
|
||||
# Upscale
|
||||
logger.info(f'Upscaling with model: {self.current_model}')
|
||||
output, _ = self.upsampler.enhance(input_img, outscale=outscale or 4)
|
||||
|
||||
# Save output
|
||||
cv2.imwrite(str(output_path), output)
|
||||
output_shape = output.shape[:2]
|
||||
logger.info(f'Output image shape: {output_shape}')
|
||||
logger.info(f'Upscaled image saved: {output_path}')
|
||||
|
||||
return True, 'Upscaling completed successfully', tuple(output_shape)
|
||||
except Exception as e:
|
||||
logger.error(f'Upscaling failed: {e}', exc_info=True)
|
||||
return False, f'Upscaling failed: {str(e)}', None
|
||||
|
||||
def get_upscale_factor(self) -> int:
|
||||
"""Get current upscaling factor."""
|
||||
if self.upsampler:
|
||||
return self.upsampler.scale
|
||||
return 4
|
||||
|
||||
def clear_memory(self) -> None:
|
||||
"""Clear GPU memory if available."""
|
||||
try:
|
||||
import torch
|
||||
torch.cuda.empty_cache()
|
||||
logger.debug('GPU memory cleared')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Global instance
|
||||
_bridge: Optional[RealESRGANBridge] = None
|
||||
|
||||
|
||||
def get_bridge() -> RealESRGANBridge:
|
||||
"""Get or create the global Real-ESRGAN bridge."""
|
||||
global _bridge
|
||||
if _bridge is None:
|
||||
_bridge = RealESRGANBridge()
|
||||
return _bridge
|
||||
217
app/services/worker.py
Normal file
217
app/services/worker.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Async job worker queue."""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from queue import Queue
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Job:
|
||||
"""Async job data."""
|
||||
job_id: str
|
||||
status: str # queued, processing, completed, failed
|
||||
input_path: str
|
||||
output_path: str
|
||||
model: str
|
||||
tile_size: Optional[int] = None
|
||||
tile_pad: Optional[int] = None
|
||||
outscale: Optional[float] = None
|
||||
created_at: str = ''
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
processing_time_seconds: Optional[float] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.created_at:
|
||||
self.created_at = datetime.utcnow().isoformat()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return asdict(self)
|
||||
|
||||
def save_metadata(self) -> None:
|
||||
"""Save job metadata to JSON file."""
|
||||
job_dir = os.path.join(settings.jobs_dir, self.job_id)
|
||||
os.makedirs(job_dir, exist_ok=True)
|
||||
|
||||
metadata_path = os.path.join(job_dir, 'metadata.json')
|
||||
with open(metadata_path, 'w') as f:
|
||||
json.dump(self.to_dict(), f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def load_metadata(cls, job_id: str) -> Optional['Job']:
|
||||
"""Load job metadata from JSON file."""
|
||||
metadata_path = os.path.join(settings.jobs_dir, job_id, 'metadata.json')
|
||||
if not os.path.exists(metadata_path):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(metadata_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
return cls(**data)
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to load job metadata: {e}')
|
||||
return None
|
||||
|
||||
|
||||
class WorkerQueue:
|
||||
"""Thread pool worker queue for processing jobs."""
|
||||
|
||||
def __init__(self, worker_func: Callable, num_workers: int = 2):
|
||||
"""
|
||||
Initialize worker queue.
|
||||
|
||||
Args:
|
||||
worker_func: Function to process jobs (job: Job) -> None
|
||||
num_workers: Number of worker threads
|
||||
"""
|
||||
self.queue: Queue = Queue()
|
||||
self.worker_func = worker_func
|
||||
self.num_workers = num_workers
|
||||
self.workers = []
|
||||
self.running = False
|
||||
self.jobs: Dict[str, Job] = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start worker threads."""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
for i in range(self.num_workers):
|
||||
worker = threading.Thread(target=self._worker_loop, daemon=True)
|
||||
worker.start()
|
||||
self.workers.append(worker)
|
||||
|
||||
logger.info(f'Started {self.num_workers} worker threads')
|
||||
|
||||
def stop(self, timeout: int = 10) -> None:
|
||||
"""Stop worker threads gracefully."""
|
||||
self.running = False
|
||||
|
||||
# Signal workers to stop
|
||||
for _ in range(self.num_workers):
|
||||
self.queue.put(None)
|
||||
|
||||
# Wait for workers to finish
|
||||
for worker in self.workers:
|
||||
worker.join(timeout=timeout)
|
||||
|
||||
logger.info('Worker threads stopped')
|
||||
|
||||
def submit_job(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
model: str,
|
||||
tile_size: Optional[int] = None,
|
||||
tile_pad: Optional[int] = None,
|
||||
outscale: Optional[float] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Submit a job for processing.
|
||||
|
||||
Returns: job_id
|
||||
"""
|
||||
job_id = str(uuid.uuid4())
|
||||
job = Job(
|
||||
job_id=job_id,
|
||||
status='queued',
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
model=model,
|
||||
tile_size=tile_size,
|
||||
tile_pad=tile_pad,
|
||||
outscale=outscale,
|
||||
)
|
||||
|
||||
with self.lock:
|
||||
self.jobs[job_id] = job
|
||||
|
||||
job.save_metadata()
|
||||
self.queue.put(job)
|
||||
logger.info(f'Job submitted: {job_id}')
|
||||
|
||||
return job_id
|
||||
|
||||
def get_job(self, job_id: str) -> Optional[Job]:
|
||||
"""Get job by ID."""
|
||||
with self.lock:
|
||||
return self.jobs.get(job_id)
|
||||
|
||||
def get_all_jobs(self) -> Dict[str, Job]:
|
||||
"""Get all jobs."""
|
||||
with self.lock:
|
||||
return dict(self.jobs)
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
"""Worker thread main loop."""
|
||||
logger.info(f'Worker thread started')
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
job = self.queue.get(timeout=1)
|
||||
if job is None: # Stop signal
|
||||
break
|
||||
|
||||
self._process_job(job)
|
||||
except Exception:
|
||||
pass # Timeout is normal
|
||||
|
||||
logger.info(f'Worker thread stopped')
|
||||
|
||||
def _process_job(self, job: Job) -> None:
|
||||
"""Process a single job."""
|
||||
try:
|
||||
with self.lock:
|
||||
job.status = 'processing'
|
||||
job.started_at = datetime.utcnow().isoformat()
|
||||
self.jobs[job.job_id] = job
|
||||
|
||||
job.save_metadata()
|
||||
|
||||
start_time = time.time()
|
||||
self.worker_func(job)
|
||||
job.processing_time_seconds = time.time() - start_time
|
||||
|
||||
with self.lock:
|
||||
job.status = 'completed'
|
||||
job.completed_at = datetime.utcnow().isoformat()
|
||||
self.jobs[job.job_id] = job
|
||||
|
||||
logger.info(f'Job completed: {job.job_id} ({job.processing_time_seconds:.2f}s)')
|
||||
except Exception as e:
|
||||
logger.error(f'Job failed: {job.job_id}: {e}', exc_info=True)
|
||||
with self.lock:
|
||||
job.status = 'failed'
|
||||
job.error = str(e)
|
||||
job.completed_at = datetime.utcnow().isoformat()
|
||||
self.jobs[job.job_id] = job
|
||||
|
||||
job.save_metadata()
|
||||
|
||||
|
||||
# Global instance
|
||||
_worker_queue: Optional[WorkerQueue] = None
|
||||
|
||||
|
||||
def get_worker_queue(worker_func: Callable = None, num_workers: int = 2) -> WorkerQueue:
|
||||
"""Get or create the global worker queue."""
|
||||
global _worker_queue
|
||||
if _worker_queue is None:
|
||||
if worker_func is None:
|
||||
raise ValueError('worker_func required for first initialization')
|
||||
_worker_queue = WorkerQueue(worker_func, num_workers)
|
||||
return _worker_queue
|
||||
Reference in New Issue
Block a user