Files
realesrgan-api/app/services/realesrgan_bridge.py
2026-02-16 19:56:25 +01:00

201 lines
6.5 KiB
Python

"""Real-ESRGAN model management and processing."""
import logging
import os
from typing import Optional, Tuple
import cv2
import numpy as np
from app.config import settings
logger = logging.getLogger(__name__)

# Real-ESRGAN is an optional dependency: record whether it imported so the
# bridge can report a clean error instead of crashing at import time.
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
except ImportError:
    REALESRGAN_AVAILABLE = False
    logger.warning('Real-ESRGAN not available. Install via: pip install realesrgan')
else:
    REALESRGAN_AVAILABLE = True
class RealESRGANBridge:
    """Bridge to Real-ESRGAN functionality.

    Holds at most one ``RealESRGANer`` at a time and rebuilds it when a
    different model is requested. Model weights are looked up as
    ``<settings.models_dir>/<model_name>.pth``.

    NOTE(review): ``get_bridge()`` shares one instance per process and
    ``upscale()`` may replace ``self.upsampler``; if the API serves requests
    concurrently, requests for different models could interfere — confirm
    how callers use this before relying on it.
    """

    def __init__(self):
        """Create an empty bridge; weights are loaded lazily on first use."""
        self.upsampler: Optional[RealESRGANer] = None
        self.current_model: Optional[str] = None
        # True once any model (default or explicitly requested) is loaded.
        self.initialized = False

    @staticmethod
    def _infer_scale(model_name: str) -> int:
        """Return the scale implied by ``model_name``: 2 for 'x2', 3 for 'x3', else 4.

        NOTE(review): the network is always an RRDBNet; confirm that any
        'x3' weights actually match an RRDBNet built with ``scale=3``.
        """
        name = model_name.lower()
        if 'x2' in name:
            return 2
        if 'x3' in name:
            return 3
        return 4

    @staticmethod
    def _resolve_model_path(model_name: str, models_dir: str) -> Optional[str]:
        """Return ``<models_dir>/<model_name>.pth``, or None if it escapes ``models_dir``.

        ``model_name`` presumably arrives from API requests (verify against
        callers); rejecting names such as ``'../x'`` or absolute paths keeps
        the loader from being pointed at arbitrary files on disk.
        """
        model_path = os.path.join(models_dir, f'{model_name}.pth')
        root = os.path.abspath(models_dir)
        try:
            inside = os.path.commonpath([root, os.path.abspath(model_path)]) == root
        except ValueError:  # e.g. paths on different drives on Windows
            inside = False
        return model_path if inside else None

    def _activate_model(self, model_name: str) -> bool:
        """Build an upsampler for ``model_name`` and make it the current one.

        Returns False (after logging) when the weights file is unusable.
        Errors raised while constructing the network propagate to the caller.
        On any failure the previously loaded upsampler is left in place.
        """
        model_path = self._resolve_model_path(model_name, settings.models_dir)
        if model_path is None:
            logger.error('Invalid model name: %r', model_name)
            return False
        if not os.path.exists(model_path):
            # No download is attempted: fail here with a clear message rather
            # than deep inside the RealESRGANer constructor.
            logger.error('Model file not found: %s', model_path)
            return False

        scale = self._infer_scale(model_name)
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=scale,
        )
        upsampler = RealESRGANer(
            scale=scale,
            model_path=model_path,
            model=model,
            tile=settings.tile_size,
            tile_pad=settings.tile_pad,
            pre_pad=0,
            # Half precision only when CUDA is among the configured providers.
            half=('cuda' in settings.get_execution_providers()),
        )
        # Swap in only after construction succeeded.
        self.upsampler = upsampler
        self.current_model = model_name
        self.initialized = True
        return True

    def initialize(self) -> bool:
        """Load ``settings.default_model`` as the current model.

        Returns:
            True on success; False if Real-ESRGAN is not installed or the
            model could not be loaded (details are logged).
        """
        if not REALESRGAN_AVAILABLE:
            logger.error('Real-ESRGAN library not available')
            return False
        try:
            logger.info('Initializing Real-ESRGAN upsampler...')
            model_name = settings.default_model
            if not self._activate_model(model_name):
                return False
            logger.info('Real-ESRGAN initialized with model: %s', model_name)
            return True
        except Exception as e:
            logger.error('Failed to initialize Real-ESRGAN: %s', e, exc_info=True)
            return False

    def load_model(self, model_name: str) -> bool:
        """Switch to the model stored as ``<settings.models_dir>/<model_name>.pth``.

        The network scale is inferred from the name (x2 -> 2, x3 -> 3, else 4).
        On success the bridge counts as initialized, so a later ``upscale()``
        without ``model_name`` keeps this model instead of reloading the
        default.

        Returns:
            True on success; False otherwise (details are logged).
        """
        if not REALESRGAN_AVAILABLE:
            logger.error('Real-ESRGAN not available')
            return False
        try:
            logger.info('Loading model: %s', model_name)
            if not self._activate_model(model_name):
                return False
            logger.info('Model loaded: %s', model_name)
            return True
        except Exception as e:
            logger.error('Failed to load model %s: %s', model_name, e, exc_info=True)
            return False

    def upscale(
        self,
        input_path: str,
        output_path: str,
        model_name: Optional[str] = None,
        outscale: Optional[float] = None,
    ) -> Tuple[bool, str, Optional[Tuple[int, int]]]:
        """
        Upscale an image file and write the result to ``output_path``.

        Args:
            input_path: Image to read (loaded with ``cv2.IMREAD_UNCHANGED``).
            output_path: Destination file; OpenCV picks the encoder from its
                extension.
            model_name: Model to use; the bridge switches models when it
                differs from the current one. If omitted, the current model
                is used (the default model is loaded on first use).
            outscale: Final scale relative to the input. ``None`` or ``0``
                means 4, independent of the model's native scale.

        Returns:
            ``(success, message, output_size)`` where ``output_size`` is the
            output's ``(height, width)`` on success and ``None`` otherwise.
        """
        try:
            if model_name and model_name != self.current_model:
                # Load the requested model directly; initializing first would
                # load the default weights only to discard them.
                if not self.load_model(model_name):
                    return False, f'Failed to load model: {model_name}', None
            elif not self.initialized:
                if not self.initialize():
                    return False, 'Failed to initialize Real-ESRGAN', None
            if self.upsampler is None:
                return False, 'Upsampler not initialized', None

            logger.info('Reading image: %s', input_path)
            input_img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
            if input_img is None:
                return False, f'Failed to read image: {input_path}', None
            logger.info('Input image shape: %s', input_img.shape[:2])

            logger.info('Upscaling with model: %s', self.current_model)
            output, _ = self.upsampler.enhance(input_img, outscale=outscale or 4)

            # cv2.imwrite reports many failures (e.g. a missing output
            # directory) by returning False instead of raising.
            if not cv2.imwrite(str(output_path), output):
                return False, f'Failed to write image: {output_path}', None
            output_shape = tuple(output.shape[:2])
            logger.info('Output image shape: %s', output_shape)
            logger.info('Upscaled image saved: %s', output_path)
            return True, 'Upscaling completed successfully', output_shape
        except Exception as e:
            logger.error('Upscaling failed: %s', e, exc_info=True)
            return False, f'Upscaling failed: {e}', None

    def get_upscale_factor(self) -> int:
        """Return the loaded upsampler's scale, or 4 when no model is loaded."""
        if self.upsampler is not None:
            return self.upsampler.scale
        return 4

    def clear_memory(self) -> None:
        """Ask PyTorch to release cached GPU memory; best-effort, never raises.

        Does nothing when torch is not installed.
        """
        try:
            import torch
        except ImportError:
            return
        try:
            torch.cuda.empty_cache()
            logger.debug('GPU memory cleared')
        except Exception as e:  # cleanup must never break a request
            logger.debug('Could not clear GPU memory: %s', e)
# Lazily created, process-wide bridge shared by every get_bridge() caller.
_bridge: Optional[RealESRGANBridge] = None


def get_bridge() -> RealESRGANBridge:
    """Return the shared RealESRGANBridge, constructing it on the first call."""
    global _bridge
    if _bridge is not None:
        return _bridge
    _bridge = RealESRGANBridge()
    return _bridge