All checks were successful
Build and Push Docker Image / build (push) Successful in 8m35s
212 lines
7.2 KiB
Python
212 lines
7.2 KiB
Python
"""Real-ESRGAN model management and processing."""
|
|
import logging
|
|
import os
|
|
from typing import Optional, Tuple
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from app.config import settings
|
|
from app.services import model_manager
|
|
import urllib.request
|
|
|
|
logger = logging.getLogger(__name__)

# Real-ESRGAN and basicsr are optional: when either import fails the module
# still loads, REALESRGAN_AVAILABLE is False, and a warning is logged once.
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
except ImportError:
    REALESRGAN_AVAILABLE = False
    logger.warning('Real-ESRGAN not available. Install via: pip install realesrgan')
else:
    REALESRGAN_AVAILABLE = True
|
|
|
|
|
|
class RealESRGANBridge:
    """Bridge to Real-ESRGAN functionality.

    Holds at most one ``RealESRGANer`` upsampler at a time. The upsampler is
    built from weights stored at ``<settings.models_dir>/<model_name>.pth``.
    The class exposes a file-to-file ``upscale`` API on top of it.
    """

    # Network scale used when a model name carries no x2/x3/x4 tag.
    _DEFAULT_SCALE = 4

    def __init__(self):
        """Initialize the Real-ESRGAN bridge without loading any model."""
        self.upsampler: Optional[RealESRGANer] = None
        self.current_model: Optional[str] = None
        self.initialized = False

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _scale_from_name(model_name: str) -> int:
        """Infer the network scale from an ``x2``/``x3``/``x4`` tag in *model_name*.

        Tags are checked case-insensitively in that order. Names without a
        tag fall back to 4.
        """
        lowered = model_name.lower()
        for tag, factor in (('x2', 2), ('x3', 3), ('x4', 4)):
            if tag in lowered:
                return factor
        return RealESRGANBridge._DEFAULT_SCALE

    @staticmethod
    def _model_path(model_name: str) -> str:
        """Return the expected weights path for *model_name*."""
        return os.path.join(settings.models_dir, f'{model_name}.pth')

    @staticmethod
    def _download_model(model_name: str, model_path: str) -> None:
        """Best-effort download of *model_name*'s weights to *model_path*.

        The URL is looked up in ``model_manager.KNOWN_MODELS``. The file is
        fetched to a ``.part`` sibling and renamed into place only after the
        download completes. This way an interrupted download cannot leave a
        truncated file at *model_path* that later existence checks would
        accept. Failures are logged, never raised.
        """
        try:
            meta = model_manager.KNOWN_MODELS.get(model_name)
            if not (meta and meta.get('url')):
                logger.error('No download URL known for model: %s', model_name)
                return
            os.makedirs(settings.models_dir, exist_ok=True)
            url = meta['url']
            tmp_path = f'{model_path}.part'
            logger.info('Downloading model %s from %s', model_name, url)
            try:
                urllib.request.urlretrieve(url, tmp_path)
                os.replace(tmp_path, model_path)
            finally:
                # Only still present if the download or the rename failed.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            logger.info('Model downloaded to %s', model_path)
        except Exception as e:
            logger.error('Automatic model download failed: %s', e, exc_info=True)

    @staticmethod
    def _build_upsampler(model_path: str, scale: int) -> 'RealESRGANer':
        """Construct a ``RealESRGANer`` around an RRDBNet of the given *scale*."""
        # NOTE(review): the same RRDBNet hyper-parameters (23 blocks, 64 feats)
        # are used for every model name; weights trained with a different
        # architecture would presumably fail to load here — confirm against
        # the models offered in model_manager.KNOWN_MODELS.
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=scale,
        )
        return RealESRGANer(
            scale=scale,
            model_path=model_path,
            model=model,
            tile=settings.tile_size,
            tile_pad=settings.tile_pad,
            pre_pad=0,
            half=False,
        )

    # --------------------------------------------------------------- public API

    def initialize(self) -> bool:
        """Build the upsampler for ``settings.default_model``.

        If the weights are missing and ``settings.auto_model_download`` is
        enabled, a download is attempted first.

        Returns:
            True on success. False if Real-ESRGAN is not installed, the weights
            cannot be found or downloaded, or construction raises (the error is
            logged).
        """
        if not REALESRGAN_AVAILABLE:
            logger.error('Real-ESRGAN library not available')
            return False

        try:
            logger.info('Initializing Real-ESRGAN upsampler...')
            model_name = settings.default_model
            model_path = self._model_path(model_name)

            if not os.path.exists(model_path):
                logger.warning('Model not found at %s, will attempt to auto-download', model_path)
                if settings.auto_model_download:
                    self._download_model(model_name, model_path)
                if not os.path.exists(model_path):
                    logger.error('Model file still not found: %s - aborting initialization', model_path)
                    return False

            # Derive the scale from the name (as load_model does) rather than
            # assuming 4, so a non-x4 default model gets a matching network.
            self.upsampler = self._build_upsampler(model_path, self._scale_from_name(model_name))
            self.current_model = model_name
            self.initialized = True
            logger.info('Real-ESRGAN initialized with model: %s', model_name)
            return True
        except Exception as e:
            logger.error('Failed to initialize Real-ESRGAN: %s', e, exc_info=True)
            return False

    def load_model(self, model_name: str) -> bool:
        """Load a specific upscaling model, replacing the current upsampler.

        The network scale is inferred from an x2/x3/x4 tag in *model_name*.
        Missing weights are downloaded when ``settings.auto_model_download``
        is enabled. On failure the previous upsampler, if any, is kept.

        Returns:
            True if the model was loaded, otherwise False (the error is logged).
        """
        try:
            if not REALESRGAN_AVAILABLE:
                logger.error('Real-ESRGAN not available')
                return False

            logger.info('Loading model: %s', model_name)
            model_path = self._model_path(model_name)

            if not os.path.exists(model_path) and settings.auto_model_download:
                self._download_model(model_name, model_path)
            if not os.path.exists(model_path):
                logger.error('Model file not found: %s', model_path)
                return False

            self.upsampler = self._build_upsampler(model_path, self._scale_from_name(model_name))
            self.current_model = model_name
            # A loaded model makes the bridge usable; without this flag the next
            # upscale() would re-run initialize() and swap back to the default.
            self.initialized = True
            logger.info('Model loaded: %s', model_name)
            return True
        except Exception as e:
            logger.error('Failed to load model %s: %s', model_name, e, exc_info=True)
            return False

    def upscale(
        self,
        input_path: str,
        output_path: str,
        model_name: Optional[str] = None,
        outscale: Optional[float] = None,
    ) -> Tuple[bool, str, Optional[Tuple[int, int]]]:
        """
        Upscale an image.

        Args:
            input_path: Image to read with ``cv2.imread(..., IMREAD_UNCHANGED)``.
            output_path: Destination passed to ``cv2.imwrite``.
            model_name: Model to use. It is loaded first if it differs from the
                current one. ``None`` keeps the current/default model.
            outscale: Final scale passed to ``RealESRGANer.enhance``. A falsy
                value means 4, regardless of the model's native scale.

        Returns: (success, message, output_size). output_size is the
        (height, width) of the written image, or None on failure.
        """
        try:
            if not self.initialized and not self.initialize():
                return False, 'Failed to initialize Real-ESRGAN', None

            if model_name and model_name != self.current_model:
                if not self.load_model(model_name):
                    return False, f'Failed to load model: {model_name}', None

            if not self.upsampler:
                return False, 'Upsampler not initialized', None

            logger.info('Reading image: %s', input_path)
            input_img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
            if input_img is None:
                return False, f'Failed to read image: {input_path}', None
            logger.info('Input image shape: %s', input_img.shape[:2])

            logger.info('Upscaling with model: %s', self.current_model)
            output, _ = self.upsampler.enhance(input_img, outscale=outscale or 4)

            # cv2.imwrite can report failure (e.g. an unwritable path) by
            # returning False instead of raising, so check it explicitly.
            if not cv2.imwrite(str(output_path), output):
                return False, f'Failed to write image: {output_path}', None

            output_shape = tuple(output.shape[:2])
            logger.info('Output image shape: %s', output_shape)
            logger.info('Upscaled image saved: %s', output_path)
            return True, 'Upscaling completed successfully', output_shape
        except Exception as e:
            logger.error('Upscaling failed: %s', e, exc_info=True)
            return False, f'Upscaling failed: {str(e)}', None

    def get_upscale_factor(self) -> int:
        """Return the loaded network's native scale, or 4 if nothing is loaded."""
        if self.upsampler:
            return self.upsampler.scale
        return 4
|
|
|
|
|
|
# Process-wide bridge, created lazily on the first get_bridge() call.
_bridge: Optional[RealESRGANBridge] = None


def get_bridge() -> RealESRGANBridge:
    """Return the shared Real-ESRGAN bridge, constructing it on first use."""
    global _bridge
    if _bridge is not None:
        return _bridge
    _bridge = RealESRGANBridge()
    return _bridge
|