"""Real-ESRGAN model management and processing.""" import logging import os from typing import Optional, Tuple import cv2 import numpy as np from app.config import settings from app.services import model_manager import urllib.request logger = logging.getLogger(__name__) try: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer REALESRGAN_AVAILABLE = True except ImportError: REALESRGAN_AVAILABLE = False logger.warning('Real-ESRGAN not available. Install via: pip install realesrgan') class RealESRGANBridge: """Bridge to Real-ESRGAN functionality.""" def __init__(self): """Initialize the Real-ESRGAN bridge.""" self.upsampler: Optional[RealESRGANer] = None self.current_model: Optional[str] = None self.initialized = False def initialize(self) -> bool: """Initialize Real-ESRGAN upsampler.""" if not REALESRGAN_AVAILABLE: logger.error('Real-ESRGAN library not available') return False try: logger.info('Initializing Real-ESRGAN upsampler...') # Setup model loader scale = 4 model_name = settings.default_model # Determine model path model_path = os.path.join(settings.models_dir, f'{model_name}.pth') if not os.path.exists(model_path): logger.warning(f'Model not found at {model_path}, will attempt to auto-download') if settings.auto_model_download: # Try to locate known model URL and download it try: meta = model_manager.KNOWN_MODELS.get(model_name) if meta and meta.get('url'): os.makedirs(settings.models_dir, exist_ok=True) url = meta['url'] logger.info(f'Downloading model {model_name} from {url}') urllib.request.urlretrieve(url, model_path) logger.info(f'Model downloaded to {model_path}') else: logger.error(f'No download URL known for model: {model_name}') except Exception as e: logger.error(f'Automatic model download failed: {e}', exc_info=True) if not os.path.exists(model_path): logger.error(f'Model file still not found: {model_path} - aborting initialization') return False # Load model model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale ) self.upsampler = RealESRGANer( scale=scale, model_path=model_path, model=model, tile=settings.tile_size, tile_pad=settings.tile_pad, pre_pad=0, half=False, ) self.current_model = model_name self.initialized = True logger.info(f'Real-ESRGAN initialized with model: {model_name}') return True except Exception as e: logger.error(f'Failed to initialize Real-ESRGAN: {e}', exc_info=True) return False def load_model(self, model_name: str) -> bool: """Load a specific upscaling model.""" try: if not REALESRGAN_AVAILABLE: logger.error('Real-ESRGAN not available') return False logger.info(f'Loading model: {model_name}') # Extract scale from model name scale = 4 if 'x2' in model_name.lower(): scale = 2 elif 'x3' in model_name.lower(): scale = 3 elif 'x4' in model_name.lower(): scale = 4 model_path = os.path.join(settings.models_dir, f'{model_name}.pth') if not os.path.exists(model_path): logger.error(f'Model file not found: {model_path}') return False model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale ) self.upsampler = RealESRGANer( scale=scale, model_path=model_path, model=model, tile=settings.tile_size, tile_pad=settings.tile_pad, pre_pad=0, half=False, ) self.current_model = model_name logger.info(f'Model loaded: {model_name}') return True except Exception as e: logger.error(f'Failed to load model {model_name}: {e}', exc_info=True) return False def upscale( self, input_path: str, output_path: str, model_name: Optional[str] = None, outscale: Optional[float] = None, ) -> Tuple[bool, str, Optional[Tuple[int, int]]]: """ Upscale an image. Returns: (success, message, output_size) """ try: if not self.initialized: if not self.initialize(): return False, 'Failed to initialize Real-ESRGAN', None if model_name and model_name != self.current_model: if not self.load_model(model_name): return False, f'Failed to load model: {model_name}', None if not self.upsampler: return False, 'Upsampler not initialized', None # Read image logger.info(f'Reading image: {input_path}') input_img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED) if input_img is None: return False, f'Failed to read image: {input_path}', None input_shape = input_img.shape[:2] logger.info(f'Input image shape: {input_shape}') # Upscale logger.info(f'Upscaling with model: {self.current_model}') output, _ = self.upsampler.enhance(input_img, outscale=outscale or 4) # Save output cv2.imwrite(str(output_path), output) output_shape = output.shape[:2] logger.info(f'Output image shape: {output_shape}') logger.info(f'Upscaled image saved: {output_path}') return True, 'Upscaling completed successfully', tuple(output_shape) except Exception as e: logger.error(f'Upscaling failed: {e}', exc_info=True) return False, f'Upscaling failed: {str(e)}', None def get_upscale_factor(self) -> int: """Get current upscaling factor.""" if self.upsampler: return self.upsampler.scale return 4 # Global instance _bridge: Optional[RealESRGANBridge] = None def get_bridge() -> RealESRGANBridge: """Get or create the global Real-ESRGAN bridge.""" global _bridge if _bridge is None: _bridge = RealESRGANBridge() return _bridge