Initial Real-ESRGAN API project setup
This commit is contained in:
200
app/services/realesrgan_bridge.py
Normal file
200
app/services/realesrgan_bridge.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Real-ESRGAN model management and processing."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
REALESRGAN_AVAILABLE = True
|
||||
except ImportError:
|
||||
REALESRGAN_AVAILABLE = False
|
||||
logger.warning('Real-ESRGAN not available. Install via: pip install realesrgan')
|
||||
|
||||
|
||||
class RealESRGANBridge:
|
||||
"""Bridge to Real-ESRGAN functionality."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Real-ESRGAN bridge."""
|
||||
self.upsampler: Optional[RealESRGANer] = None
|
||||
self.current_model: Optional[str] = None
|
||||
self.initialized = False
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Initialize Real-ESRGAN upsampler."""
|
||||
if not REALESRGAN_AVAILABLE:
|
||||
logger.error('Real-ESRGAN library not available')
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info('Initializing Real-ESRGAN upsampler...')
|
||||
|
||||
# Setup model loader
|
||||
scale = 4
|
||||
model_name = settings.default_model
|
||||
|
||||
# Determine model path
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
if not os.path.exists(model_path):
|
||||
logger.warning(f'Model not found at {model_path}, will attempt to auto-download')
|
||||
|
||||
# Load model
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale
|
||||
)
|
||||
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=scale,
|
||||
model_path=model_path if os.path.exists(model_path) else None,
|
||||
model=model,
|
||||
tile=settings.tile_size,
|
||||
tile_pad=settings.tile_pad,
|
||||
pre_pad=0,
|
||||
half=('cuda' in settings.get_execution_providers()),
|
||||
)
|
||||
|
||||
self.current_model = model_name
|
||||
self.initialized = True
|
||||
logger.info(f'Real-ESRGAN initialized with model: {model_name}')
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to initialize Real-ESRGAN: {e}', exc_info=True)
|
||||
return False
|
||||
|
||||
def load_model(self, model_name: str) -> bool:
|
||||
"""Load a specific upscaling model."""
|
||||
try:
|
||||
if not REALESRGAN_AVAILABLE:
|
||||
logger.error('Real-ESRGAN not available')
|
||||
return False
|
||||
|
||||
logger.info(f'Loading model: {model_name}')
|
||||
|
||||
# Extract scale from model name
|
||||
scale = 4
|
||||
if 'x2' in model_name.lower():
|
||||
scale = 2
|
||||
elif 'x3' in model_name.lower():
|
||||
scale = 3
|
||||
elif 'x4' in model_name.lower():
|
||||
scale = 4
|
||||
|
||||
model_path = os.path.join(settings.models_dir, f'{model_name}.pth')
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f'Model file not found: {model_path}')
|
||||
return False
|
||||
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale
|
||||
)
|
||||
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=scale,
|
||||
model_path=model_path,
|
||||
model=model,
|
||||
tile=settings.tile_size,
|
||||
tile_pad=settings.tile_pad,
|
||||
pre_pad=0,
|
||||
half=('cuda' in settings.get_execution_providers()),
|
||||
)
|
||||
|
||||
self.current_model = model_name
|
||||
logger.info(f'Model loaded: {model_name}')
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to load model {model_name}: {e}', exc_info=True)
|
||||
return False
|
||||
|
||||
def upscale(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
model_name: Optional[str] = None,
|
||||
outscale: Optional[float] = None,
|
||||
) -> Tuple[bool, str, Optional[Tuple[int, int]]]:
|
||||
"""
|
||||
Upscale an image.
|
||||
|
||||
Returns: (success, message, output_size)
|
||||
"""
|
||||
try:
|
||||
if not self.initialized:
|
||||
if not self.initialize():
|
||||
return False, 'Failed to initialize Real-ESRGAN', None
|
||||
|
||||
if model_name and model_name != self.current_model:
|
||||
if not self.load_model(model_name):
|
||||
return False, f'Failed to load model: {model_name}', None
|
||||
|
||||
if not self.upsampler:
|
||||
return False, 'Upsampler not initialized', None
|
||||
|
||||
# Read image
|
||||
logger.info(f'Reading image: {input_path}')
|
||||
input_img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
|
||||
if input_img is None:
|
||||
return False, f'Failed to read image: {input_path}', None
|
||||
|
||||
input_shape = input_img.shape[:2]
|
||||
logger.info(f'Input image shape: {input_shape}')
|
||||
|
||||
# Upscale
|
||||
logger.info(f'Upscaling with model: {self.current_model}')
|
||||
output, _ = self.upsampler.enhance(input_img, outscale=outscale or 4)
|
||||
|
||||
# Save output
|
||||
cv2.imwrite(str(output_path), output)
|
||||
output_shape = output.shape[:2]
|
||||
logger.info(f'Output image shape: {output_shape}')
|
||||
logger.info(f'Upscaled image saved: {output_path}')
|
||||
|
||||
return True, 'Upscaling completed successfully', tuple(output_shape)
|
||||
except Exception as e:
|
||||
logger.error(f'Upscaling failed: {e}', exc_info=True)
|
||||
return False, f'Upscaling failed: {str(e)}', None
|
||||
|
||||
def get_upscale_factor(self) -> int:
|
||||
"""Get current upscaling factor."""
|
||||
if self.upsampler:
|
||||
return self.upsampler.scale
|
||||
return 4
|
||||
|
||||
def clear_memory(self) -> None:
|
||||
"""Clear GPU memory if available."""
|
||||
try:
|
||||
import torch
|
||||
torch.cuda.empty_cache()
|
||||
logger.debug('GPU memory cleared')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Global instance
|
||||
_bridge: Optional[RealESRGANBridge] = None
|
||||
|
||||
|
||||
def get_bridge() -> RealESRGANBridge:
|
||||
"""Get or create the global Real-ESRGAN bridge."""
|
||||
global _bridge
|
||||
if _bridge is None:
|
||||
_bridge = RealESRGANBridge()
|
||||
return _bridge
|
||||
Reference in New Issue
Block a user