From a6d20cf087634b9b8020f0433f7eaca8d11b25da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Kr=C3=BCger?= Date: Thu, 27 Nov 2025 11:56:59 +0100 Subject: [PATCH] feat: initial implementation of Real-ESRGAN Web UI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Full-featured Gradio 6.0+ web interface for Real-ESRGAN image/video upscaling, optimized for RTX 4090 (24GB VRAM). Features: - Image upscaling with before/after comparison (ImageSlider) - Video upscaling with progress tracking and checkpoint/resume - Face enhancement via GFPGAN integration - Multiple codecs: H.264, H.265, AV1 (with NVENC support) - Batch processing queue with SQLite persistence - Processing history gallery - Custom dark theme - Auto-download of model weights 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- app.py | 29 +++ pyproject.toml | 18 ++ requirements.txt | 30 +++ src/__init__.py | 1 + src/config.py | 222 ++++++++++++++++++ src/processing/__init__.py | 1 + src/processing/face_enhancer.py | 213 +++++++++++++++++ src/processing/models.py | 240 ++++++++++++++++++++ src/processing/upscaler.py | 341 ++++++++++++++++++++++++++++ src/storage/__init__.py | 1 + src/storage/database.py | 141 ++++++++++++ src/storage/history.py | 282 +++++++++++++++++++++++ src/storage/queue.py | 332 +++++++++++++++++++++++++++ src/ui/__init__.py | 1 + src/ui/app.py | 111 +++++++++ src/ui/components/__init__.py | 1 + src/ui/components/batch_tab.py | 200 ++++++++++++++++ src/ui/components/history_tab.py | 208 +++++++++++++++++ src/ui/components/image_tab.py | 292 ++++++++++++++++++++++++ src/ui/components/video_tab.py | 376 +++++++++++++++++++++++++++++++ src/ui/handlers/__init__.py | 1 + src/ui/theme.py | 284 +++++++++++++++++++++++ src/utils/__init__.py | 1 + src/video/__init__.py | 1 + src/video/audio.py | 221 ++++++++++++++++++ src/video/checkpoint.py | 250 ++++++++++++++++++++ src/video/encoder.py | 274 ++++++++++++++++++++++ src/video/extractor.py | 201 +++++++++++++++++ 28 files changed, 4273 insertions(+) create mode 100644 app.py create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/config.py create mode 100644 src/processing/__init__.py create mode 100644 src/processing/face_enhancer.py create mode 100644 src/processing/models.py create mode 100644 src/processing/upscaler.py create mode 100644 src/storage/__init__.py create mode 100644 src/storage/database.py create mode 100644 src/storage/history.py create mode 100644 src/storage/queue.py create mode 100644 src/ui/__init__.py create mode 100644 src/ui/app.py create mode 100644 src/ui/components/__init__.py create mode 100644 src/ui/components/batch_tab.py create mode 100644 src/ui/components/history_tab.py create mode 100644 src/ui/components/image_tab.py create mode 100644 src/ui/components/video_tab.py create mode 100644 src/ui/handlers/__init__.py create mode 100644 src/ui/theme.py create mode 100644 src/utils/__init__.py create mode 100644 src/video/__init__.py create mode 100644 src/video/audio.py create mode 100644 src/video/checkpoint.py create mode 100644 src/video/encoder.py create mode 100644 src/video/extractor.py diff --git a/app.py b/app.py new file mode 100644 index 0000000..3984f4b --- /dev/null +++ b/app.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +""" +Real-ESRGAN Web UI - Main Entry Point + +A full-featured web interface for Real-ESRGAN image and video upscaling, +optimized for RTX 4090 with 24GB 
VRAM. + +Usage: + python app.py + +The server will start at http://localhost:7860 +""" + +import sys +from pathlib import Path + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent)) + +from src.ui.app import launch_app + + +def main(): + """Main entry point.""" + launch_app() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5ffa3e2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[project] +name = "upscale-ui" +version = "0.1.0" +description = "Real-ESRGAN Web UI for image and video upscaling" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "MIT"} + +[tool.setuptools] +packages = ["src"] + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "W", "I"] +ignore = ["E501"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3e131d0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,30 @@ +# Real-ESRGAN Web UI Dependencies + +# Gradio UI Framework +gradio>=6.0.0 +gradio-imageslider>=0.0.20 + +# PyTorch (CUDA 12.1) +--extra-index-url https://download.pytorch.org/whl/cu121 +torch>=2.0.0 +torchvision>=0.15.0 + +# Real-ESRGAN and dependencies +realesrgan>=0.3.0 +basicsr>=1.4.2 + +# Face Enhancement +gfpgan>=1.3.8 +facexlib>=0.3.0 + +# Video Processing +av>=10.0.0 + +# Image Processing +opencv-python>=4.8.0 +Pillow>=10.0.0 +numpy>=1.24.0 + +# Utilities +tqdm>=4.65.0 +requests>=2.28.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..0071bd5 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +"""Real-ESRGAN Web UI - Image and Video Upscaling.""" diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..ca63ab4 --- /dev/null +++ b/src/config.py @@ -0,0 +1,222 @@ +"""Configuration settings for Real-ESRGAN Web UI.""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + + +# Base paths +BASE_DIR = Path(__file__).parent.parent +DATA_DIR = BASE_DIR / "data" +MODELS_DIR = DATA_DIR / "models" +TEMP_DIR = DATA_DIR / "temp" +CHECKPOINTS_DIR = DATA_DIR / "checkpoints" +OUTPUT_DIR = DATA_DIR / "output" +DATABASE_PATH = DATA_DIR / "upscale.db" + +# Ensure directories exist +for dir_path in [MODELS_DIR, TEMP_DIR, CHECKPOINTS_DIR, OUTPUT_DIR]: + dir_path.mkdir(parents=True, exist_ok=True) + + +@dataclass +class ModelInfo: + """Information about a Real-ESRGAN or GFPGAN model.""" + + name: str + scale: int + filename: str + url: str + size_mb: float + description: str + model_type: str = "realesrgan" # "realesrgan" or "gfpgan" + netscale: int = 4 # Network scale factor + num_block: int = 23 # Number of RRDB blocks (6 for anime models) + num_grow_ch: int = 32 # Growth channels + + +# Model registry with download URLs +MODEL_REGISTRY: dict[str, ModelInfo] = { + "RealESRGAN_x4plus": ModelInfo( + name="RealESRGAN_x4plus", + scale=4, + filename="RealESRGAN_x4plus.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + size_mb=63.9, + description="Best quality for general photos (4x)", + num_block=23, + ), + "RealESRGAN_x2plus": ModelInfo( + name="RealESRGAN_x2plus", + scale=2, + filename="RealESRGAN_x2plus.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + size_mb=63.9, + description="General photos (2x upscaling)", + netscale=2, + num_block=23, + ), + "RealESRGAN_x4plus_anime_6B": ModelInfo( + name="RealESRGAN_x4plus_anime_6B", + scale=4, 
+ filename="RealESRGAN_x4plus_anime_6B.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + size_mb=17.0, + description="Optimized for anime/illustrations (4x)", + num_block=6, + ), + "realesr-animevideov3": ModelInfo( + name="realesr-animevideov3", + scale=4, + filename="realesr-animevideov3.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", + size_mb=17.0, + description="Anime video with temporal consistency (4x)", + num_block=6, + ), + "realesr-general-x4v3": ModelInfo( + name="realesr-general-x4v3", + scale=4, + filename="realesr-general-x4v3.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + size_mb=63.9, + description="General purpose with denoise control (4x)", + num_block=23, + ), + "GFPGANv1.4": ModelInfo( + name="GFPGANv1.4", + scale=1, + filename="GFPGANv1.4.pth", + url="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth", + size_mb=332.0, + description="Face enhancement and restoration", + model_type="gfpgan", + ), +} + + +@dataclass +class RTX4090Config: + """RTX 4090 optimized settings (24GB VRAM).""" + + tile_size: int = 0 # 0 = no tiling, use full image + tile_pad: int = 10 # Padding for tile boundaries + pre_pad: int = 0 # Pre-padding for input + half: bool = True # FP16 precision + device: str = "cuda" + gpu_id: int = 0 + + # Thresholds for when to enable tiling + max_pixels_no_tile: int = 8294400 # ~4K (3840x2160) + + def should_tile(self, width: int, height: int) -> bool: + """Determine if tiling is needed based on image size.""" + return width * height > self.max_pixels_no_tile + + def get_tile_size(self, width: int, height: int) -> int: + """Get appropriate tile size for the image.""" + if not self.should_tile(width, height): + return 0 # No tiling + # For very large images, use 512px tiles + return 512 + + +@dataclass +class VideoCodecConfig: + """Video encoding configuration.""" + + encoder: str + crf: int + preset: str + nvenc_encoder: Optional[str] = None + description: str = "" + + +# Video codec presets (maximum quality settings) +VIDEO_CODECS: dict[str, VideoCodecConfig] = { + "H.264": VideoCodecConfig( + encoder="libx264", + crf=18, + preset="slow", + nvenc_encoder="h264_nvenc", + description="Universal compatibility, excellent quality", + ), + "H.265": VideoCodecConfig( + encoder="libx265", + crf=20, + preset="slow", + nvenc_encoder="hevc_nvenc", + description="50% smaller files, great quality", + ), + "AV1": VideoCodecConfig( + encoder="libsvtav1", + crf=23, + preset="4", + nvenc_encoder="av1_nvenc", + description="Best compression, future-proof", + ), +} + + +@dataclass +class AppConfig: + """Main application configuration.""" + + # Server settings + server_name: str = "0.0.0.0" + server_port: int = 7860 + share: bool = False + + # Queue settings + max_queue_size: int = 20 + concurrency_limit: int = 1 # Single job at a time for GPU efficiency + + # Processing defaults + default_model: str = "RealESRGAN_x4plus" + default_scale: int = 4 + default_face_enhance: bool = False + default_output_format: str = "png" + + # Video defaults + default_video_codec: str = "H.265" + default_video_crf: int = 20 + default_video_preset: str = "slow" + + # History settings + max_history_items: int = 1000 + thumbnail_size: tuple[int, int] = (256, 256) + + # Checkpoint settings + checkpoint_interval: int = 100 # Save every N frames + + # GPU config + gpu: RTX4090Config = 
field(default_factory=RTX4090Config) + + +# Global config instance +config = AppConfig() + + +def get_model_path(model_name: str) -> Path: + """Get the path to a model file.""" + if model_name not in MODEL_REGISTRY: + raise ValueError(f"Unknown model: {model_name}") + return MODELS_DIR / MODEL_REGISTRY[model_name].filename + + +def get_available_models() -> list[str]: + """Get list of available model names for UI dropdown.""" + return [ + name + for name, info in MODEL_REGISTRY.items() + if info.model_type == "realesrgan" + ] + + +def get_model_choices() -> list[tuple[str, str]]: + """Get model choices for Gradio dropdown (label, value).""" + return [ + (f"{info.description}", name) + for name, info in MODEL_REGISTRY.items() + if info.model_type == "realesrgan" + ] diff --git a/src/processing/__init__.py b/src/processing/__init__.py new file mode 100644 index 0000000..f107c8c --- /dev/null +++ b/src/processing/__init__.py @@ -0,0 +1 @@ +"""Image and video processing modules.""" diff --git a/src/processing/face_enhancer.py b/src/processing/face_enhancer.py new file mode 100644 index 0000000..ce1aa02 --- /dev/null +++ b/src/processing/face_enhancer.py @@ -0,0 +1,213 @@ +"""Face enhancement using GFPGAN.""" + +import logging +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import torch + +from src.config import MODEL_REGISTRY, MODELS_DIR, config +from src.processing.models import ensure_model_available + +logger = logging.getLogger(__name__) + +# Global GFPGAN instance (cached) +_gfpgan_instance = None + + +class FaceEnhancerError(Exception): + """Error during face enhancement.""" + + pass + + +def get_gfpgan( + model_name: str = "GFPGANv1.4", + upscale: int = 2, + bg_upsampler=None, +): + """ + Get or create GFPGAN instance. + + Args: + model_name: GFPGAN model version + upscale: Background upscale factor + bg_upsampler: Optional background upsampler (RealESRGAN) + + Returns: + GFPGANer instance + """ + global _gfpgan_instance + + if _gfpgan_instance is not None: + return _gfpgan_instance + + try: + from gfpgan import GFPGANer + except ImportError: + raise FaceEnhancerError( + "GFPGAN not installed. Install with: pip install gfpgan" + ) + + # Ensure model is downloaded + model_path = ensure_model_available(model_name) + + # Determine model version from name + if "1.4" in model_name: + arch = "clean" + channel_multiplier = 2 + elif "1.3" in model_name: + arch = "clean" + channel_multiplier = 2 + elif "1.2" in model_name: + arch = "clean" + channel_multiplier = 2 + else: + arch = "original" + channel_multiplier = 1 + + device = "cuda" if torch.cuda.is_available() else "cpu" + + _gfpgan_instance = GFPGANer( + model_path=str(model_path), + upscale=upscale, + arch=arch, + channel_multiplier=channel_multiplier, + bg_upsampler=bg_upsampler, + device=device, + ) + + logger.info(f"Loaded GFPGAN: {model_name}") + return _gfpgan_instance + + +def enhance_faces( + image: np.ndarray, + model_name: str = "GFPGANv1.4", + has_aligned: bool = False, + only_center_face: bool = False, + paste_back: bool = True, + weight: float = 0.5, +) -> np.ndarray: + """ + Enhance faces in an image using GFPGAN. 
+ + Args: + image: Input image (BGR format, numpy array) + model_name: GFPGAN model version to use + has_aligned: Whether input is already aligned face + only_center_face: Only enhance the center/largest face + paste_back: Paste enhanced face back to original image + weight: Blending weight (0=original, 1=enhanced) + + Returns: + Enhanced image (BGR format, numpy array) + """ + if image is None: + raise FaceEnhancerError("No image provided") + + try: + gfpgan = get_gfpgan(model_name) + + # GFPGAN enhance + _, _, output = gfpgan.enhance( + image, + has_aligned=has_aligned, + only_center_face=only_center_face, + paste_back=paste_back, + weight=weight, + ) + + if output is None: + logger.warning("No faces detected, returning original image") + return image + + return output + + except Exception as e: + logger.error(f"Face enhancement failed: {e}") + raise FaceEnhancerError(f"Face enhancement failed: {e}") from e + + +def detect_faces(image: np.ndarray) -> int: + """ + Detect number of faces in image. + + Args: + image: Input image (BGR format) + + Returns: + Number of detected faces + """ + try: + from facexlib.detection import init_detection_model + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + + # Use facexlib for detection + device = "cuda" if torch.cuda.is_available() else "cpu" + + face_helper = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + device=device, + ) + + face_helper.read_image(image) + face_helper.get_face_landmarks_5(only_center_face=False) + + return len(face_helper.all_landmarks_5) + + except Exception as e: + logger.warning(f"Face detection failed: {e}") + return 0 + + +def is_anime_image(image: np.ndarray) -> bool: + """ + Attempt to detect if image is anime/illustration style. + + This is a simple heuristic - for production use, consider a + trained classifier. + + Args: + image: Input image (BGR format) + + Returns: + True if likely anime/illustration + """ + # Convert to grayscale + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # Anime tends to have: + # 1. More uniform color regions (lower texture variance) + # 2. 
Sharper edges (higher edge density) + + # Calculate Laplacian variance (edge strength) + laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var() + + # Calculate local variance (texture) + local_var = cv2.blur(gray.astype(float) ** 2, (9, 9)) - cv2.blur( + gray.astype(float), (9, 9) + ) ** 2 + texture_var = local_var.mean() + + # Anime typically has high edge strength but low texture variance + # These thresholds are approximate + edge_ratio = laplacian_var / (texture_var + 1e-6) + + # Heuristic: anime has high edge_ratio + return edge_ratio > 50 + + +def clear_gfpgan_cache(): + """Clear cached GFPGAN model.""" + global _gfpgan_instance + _gfpgan_instance = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info("Cleared GFPGAN cache") diff --git a/src/processing/models.py b/src/processing/models.py new file mode 100644 index 0000000..3c94938 --- /dev/null +++ b/src/processing/models.py @@ -0,0 +1,240 @@ +"""Model management for Real-ESRGAN and GFPGAN.""" + +import logging +import time +from pathlib import Path +from threading import Lock +from typing import Optional + +import requests +import torch +from tqdm import tqdm + +from src.config import MODEL_REGISTRY, MODELS_DIR, ModelInfo, config + +logger = logging.getLogger(__name__) + + +class ModelDownloadError(Exception): + """Error downloading model weights.""" + + pass + + +def download_model(model_info: ModelInfo, progress_callback=None) -> Path: + """ + Download model weights from URL. + + Args: + model_info: Model information from registry + progress_callback: Optional callback(downloaded, total) for progress + + Returns: + Path to downloaded model file + """ + model_path = MODELS_DIR / model_info.filename + + if model_path.exists(): + logger.info(f"Model already exists: {model_path}") + return model_path + + logger.info(f"Downloading {model_info.name} ({model_info.size_mb:.1f} MB)...") + + try: + response = requests.get(model_info.url, stream=True, timeout=30) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + downloaded = 0 + + # Download with progress + with open(model_path, "wb") as f: + with tqdm( + total=total_size, + unit="B", + unit_scale=True, + desc=model_info.name, + ) as pbar: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + downloaded += len(chunk) + pbar.update(len(chunk)) + if progress_callback: + progress_callback(downloaded, total_size) + + logger.info(f"Downloaded: {model_path}") + return model_path + + except Exception as e: + # Clean up partial download + if model_path.exists(): + model_path.unlink() + raise ModelDownloadError(f"Failed to download {model_info.name}: {e}") from e + + +def ensure_model_available(model_name: str) -> Path: + """ + Ensure model is available, downloading if necessary. + + Args: + model_name: Name of model from registry + + Returns: + Path to model file + """ + if model_name not in MODEL_REGISTRY: + raise ValueError(f"Unknown model: {model_name}") + + model_info = MODEL_REGISTRY[model_name] + model_path = MODELS_DIR / model_info.filename + + if not model_path.exists(): + download_model(model_info) + + return model_path + + +class ModelManager: + """ + Manages Real-ESRGAN model loading and caching. 
+ + Optimized for RTX 4090 with 24GB VRAM: + - Keeps frequently used models in memory + - LRU eviction when VRAM is constrained + - Thread-safe operations + """ + + def __init__(self, max_cached_models: int = 3): + self.max_cached_models = max_cached_models + self._cache: dict[str, "RealESRGANer"] = {} + self._last_used: dict[str, float] = {} + self._lock = Lock() + self._device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + + if torch.cuda.is_available(): + props = torch.cuda.get_device_properties(0) + logger.info(f"GPU: {props.name}, VRAM: {props.total_memory / (1024**3):.1f}GB") + else: + logger.warning("CUDA not available, using CPU (this will be slow)") + + def get_upsampler( + self, + model_name: str, + tile_size: int = 0, + tile_pad: int = 10, + half: bool = True, + ) -> "RealESRGANer": + """ + Get or create a RealESRGANer instance. + + Args: + model_name: Name of model from registry + tile_size: Tile size for processing (0 = no tiling) + tile_pad: Padding for tile boundaries + half: Use FP16 precision + + Returns: + Configured RealESRGANer instance + """ + # Import here to avoid circular imports + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + + cache_key = f"{model_name}_{tile_size}_{tile_pad}_{half}" + + with self._lock: + # Check cache + if cache_key in self._cache: + self._last_used[cache_key] = time.time() + return self._cache[cache_key] + + # Ensure model is downloaded + model_path = ensure_model_available(model_name) + model_info = MODEL_REGISTRY[model_name] + + # Evict old models if at capacity + self._evict_if_needed() + + # Create model architecture + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=model_info.num_block, + num_grow_ch=model_info.num_grow_ch, + scale=model_info.netscale, + ) + + # Create upsampler + upsampler = RealESRGANer( + scale=model_info.netscale, + model_path=str(model_path), + model=model, + tile=tile_size, + tile_pad=tile_pad, + pre_pad=config.gpu.pre_pad, + half=half and self._device.type == "cuda", + device=self._device, + gpu_id=config.gpu.gpu_id if self._device.type == "cuda" else None, + ) + + # Cache and return + self._cache[cache_key] = upsampler + self._last_used[cache_key] = time.time() + + logger.info(f"Loaded model: {model_name} (tile={tile_size}, half={half})") + return upsampler + + def _evict_if_needed(self): + """Evict least recently used model if at capacity.""" + while len(self._cache) >= self.max_cached_models: + # Find LRU + lru_key = min(self._last_used.keys(), key=lambda k: self._last_used[k]) + + # Remove from cache + del self._cache[lru_key] + del self._last_used[lru_key] + + # Clear CUDA cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + logger.info(f"Evicted model from cache: {lru_key}") + + def clear_cache(self): + """Clear all cached models.""" + with self._lock: + self._cache.clear() + self._last_used.clear() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info("Cleared model cache") + + def get_vram_status(self) -> dict: + """Get current VRAM usage statistics.""" + if not torch.cuda.is_available(): + return {"available": False} + + return { + "available": True, + "allocated_gb": torch.cuda.memory_allocated() / (1024**3), + "reserved_gb": torch.cuda.memory_reserved() / (1024**3), + "total_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3), + "cached_models": list(self._cache.keys()), + } + + def preload_model(self, model_name: str): + """Preload a model into 
cache.""" + self.get_upsampler( + model_name, + tile_size=config.gpu.tile_size, + tile_pad=config.gpu.tile_pad, + half=config.gpu.half, + ) + + +# Global model manager instance +model_manager = ModelManager() diff --git a/src/processing/upscaler.py b/src/processing/upscaler.py new file mode 100644 index 0000000..2ea4a87 --- /dev/null +++ b/src/processing/upscaler.py @@ -0,0 +1,341 @@ +"""High-level upscaling interface for Real-ESRGAN.""" + +import logging +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Optional, Union + +import cv2 +import numpy as np +import torch +from PIL import Image + +from src.config import MODEL_REGISTRY, config +from src.processing.models import model_manager + +logger = logging.getLogger(__name__) + + +@dataclass +class UpscaleResult: + """Result of an upscaling operation.""" + + image: np.ndarray # BGR format + input_width: int + input_height: int + output_width: int + output_height: int + model_used: str + scale_factor: int + processing_time: float + used_tiling: bool + + +class UpscalerError(Exception): + """Error during upscaling.""" + + pass + + +def calculate_tile_size(width: int, height: int) -> int: + """ + Calculate optimal tile size based on image dimensions. + + For RTX 4090 with 24GB VRAM: + - No tiling for images up to ~4K + - 512px tiles for larger images + """ + pixels = width * height + + # Thresholds based on testing with RealESRGAN_x4plus + if pixels <= 2073600: # Up to 1080p (1920x1080) + return 0 # No tiling + elif pixels <= 8294400: # Up to 4K (3840x2160) + return 0 # RTX 4090 can handle without tiling + elif pixels <= 33177600: # Up to 8K (7680x4320) + return 512 + else: + return 256 # Very large images + + +def load_image( + input_path: Union[str, Path, np.ndarray, Image.Image] +) -> tuple[np.ndarray, str]: + """ + Load image from various sources. + + Args: + input_path: Path to image, numpy array, or PIL Image + + Returns: + Tuple of (BGR numpy array, source description) + """ + if isinstance(input_path, np.ndarray): + # Assume already BGR if 3 channels + if len(input_path.shape) == 3 and input_path.shape[2] == 3: + return input_path, "numpy" + elif len(input_path.shape) == 3 and input_path.shape[2] == 4: + # RGBA to BGR + return cv2.cvtColor(input_path, cv2.COLOR_RGBA2BGR), "numpy_rgba" + else: + raise UpscalerError(f"Unsupported array shape: {input_path.shape}") + + if isinstance(input_path, Image.Image): + # PIL to numpy BGR + img = np.array(input_path) + if img.shape[2] == 4: + return cv2.cvtColor(img, cv2.COLOR_RGBA2BGR), "pil_rgba" + return cv2.cvtColor(img, cv2.COLOR_RGB2BGR), "pil" + + # Load from file + input_path = Path(input_path) + if not input_path.exists(): + raise UpscalerError(f"Image not found: {input_path}") + + img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED) + if img is None: + raise UpscalerError(f"Failed to load image: {input_path}") + + # Handle different channel counts + if len(img.shape) == 2: + # Grayscale to BGR + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: + # RGBA - keep alpha for later + pass # RealESRGAN handles alpha internally + + return img, str(input_path) + + +def save_image( + image: np.ndarray, + output_path: Union[str, Path], + format: str = "png", + quality: int = 95, +) -> Path: + """ + Save image to file. 
+ + Args: + image: BGR numpy array + output_path: Output path (extension will be added/replaced) + format: Output format (png, jpg, webp) + quality: JPEG/WebP quality (0-100) + + Returns: + Path to saved file + """ + output_path = Path(output_path) + + # Ensure correct extension + format = format.lower() + if format == "jpg": + format = "jpeg" + + ext = f".{format}" if format != "jpeg" else ".jpg" + output_path = output_path.with_suffix(ext) + + # Ensure output directory exists + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Save with appropriate settings + if format == "png": + cv2.imwrite(str(output_path), image, [cv2.IMWRITE_PNG_COMPRESSION, 6]) + elif format == "jpeg": + cv2.imwrite(str(output_path), image, [cv2.IMWRITE_JPEG_QUALITY, quality]) + elif format == "webp": + cv2.imwrite(str(output_path), image, [cv2.IMWRITE_WEBP_QUALITY, quality]) + else: + cv2.imwrite(str(output_path), image) + + return output_path + + +def upscale_image( + input_image: Union[str, Path, np.ndarray, Image.Image], + model_name: str = None, + scale: int = None, + tile_size: int = None, + denoise_strength: float = 0.5, + progress_callback: Optional[Callable[[float, str], None]] = None, +) -> UpscaleResult: + """ + Upscale a single image. + + Args: + input_image: Input image (path, numpy array, or PIL Image) + model_name: Model to use (defaults to config) + scale: Output scale (2 or 4, defaults to model's native scale) + tile_size: Tile size (0 for auto, None for config default) + denoise_strength: Denoise strength for supported models (0-1) + progress_callback: Optional callback(progress, stage) for progress + + Returns: + UpscaleResult with upscaled image and metadata + """ + start_time = time.time() + + # Use defaults + model_name = model_name or config.default_model + model_info = MODEL_REGISTRY.get(model_name) + if not model_info: + raise UpscalerError(f"Unknown model: {model_name}") + + # Load image + if progress_callback: + progress_callback(0.1, "Loading image...") + + img, source = load_image(input_image) + input_height, input_width = img.shape[:2] + + logger.info(f"Input: {input_width}x{input_height} from {source}") + + # Calculate tile size if auto + if tile_size is None: + tile_size = config.gpu.tile_size + if tile_size == 0: + # Auto-calculate based on image size + tile_size = calculate_tile_size(input_width, input_height) + + used_tiling = tile_size > 0 + + # Get upsampler + if progress_callback: + progress_callback(0.2, "Loading model...") + + upsampler = model_manager.get_upsampler( + model_name=model_name, + tile_size=tile_size, + tile_pad=config.gpu.tile_pad, + half=config.gpu.half, + ) + + # Determine output scale + if scale is None: + scale = model_info.scale + outscale = scale / model_info.netscale if scale != model_info.netscale else None + + # Process + if progress_callback: + progress_callback(0.3, "Upscaling...") + + try: + # RealESRGAN enhance method + if model_name == "realesr-general-x4v3": + # This model supports denoise strength + output, _ = upsampler.enhance( + img, outscale=outscale, dnoise_strength=denoise_strength + ) + else: + output, _ = upsampler.enhance(img, outscale=outscale) + + except torch.cuda.OutOfMemoryError: + # Clear cache and retry with tiling + torch.cuda.empty_cache() + logger.warning("VRAM exceeded, retrying with tiling...") + + upsampler = model_manager.get_upsampler( + model_name=model_name, + tile_size=512, # Force tiling + tile_pad=config.gpu.tile_pad, + half=config.gpu.half, + ) + + if model_name == "realesr-general-x4v3": + output, _ = 
upsampler.enhance( + img, outscale=outscale, dnoise_strength=denoise_strength + ) + else: + output, _ = upsampler.enhance(img, outscale=outscale) + + used_tiling = True + + except Exception as e: + raise UpscalerError(f"Upscaling failed: {e}") from e + + if progress_callback: + progress_callback(0.9, "Finalizing...") + + output_height, output_width = output.shape[:2] + processing_time = time.time() - start_time + + logger.info( + f"Output: {output_width}x{output_height}, " + f"time: {processing_time:.2f}s, tiling: {used_tiling}" + ) + + if progress_callback: + progress_callback(1.0, "Complete!") + + return UpscaleResult( + image=output, + input_width=input_width, + input_height=input_height, + output_width=output_width, + output_height=output_height, + model_used=model_name, + scale_factor=scale, + processing_time=processing_time, + used_tiling=used_tiling, + ) + + +def upscale_and_save( + input_path: Union[str, Path], + output_path: Union[str, Path], + model_name: str = None, + scale: int = None, + output_format: str = "png", + denoise_strength: float = 0.5, + progress_callback: Optional[Callable[[float, str], None]] = None, +) -> tuple[Path, UpscaleResult]: + """ + Upscale image and save to file. + + Args: + input_path: Input image path + output_path: Output path (extension may be adjusted) + model_name: Model to use + scale: Output scale + output_format: Output format (png, jpg, webp) + denoise_strength: Denoise strength for supported models + progress_callback: Optional progress callback + + Returns: + Tuple of (saved path, UpscaleResult) + """ + result = upscale_image( + input_path, + model_name=model_name, + scale=scale, + denoise_strength=denoise_strength, + progress_callback=progress_callback, + ) + + saved_path = save_image(result.image, output_path, format=output_format) + + return saved_path, result + + +def get_output_dimensions( + width: int, height: int, model_name: str, scale: int = None +) -> tuple[int, int]: + """ + Calculate output dimensions for given input and model. 
+ + Args: + width: Input width + height: Input height + model_name: Model name + scale: Target scale (uses model default if None) + + Returns: + Tuple of (output_width, output_height) + """ + model_info = MODEL_REGISTRY.get(model_name) + if not model_info: + raise ValueError(f"Unknown model: {model_name}") + + scale = scale or model_info.scale + return width * scale, height * scale diff --git a/src/storage/__init__.py b/src/storage/__init__.py new file mode 100644 index 0000000..058cdaf --- /dev/null +++ b/src/storage/__init__.py @@ -0,0 +1 @@ +"""Storage and persistence modules (SQLite, history, queue).""" diff --git a/src/storage/database.py b/src/storage/database.py new file mode 100644 index 0000000..90da137 --- /dev/null +++ b/src/storage/database.py @@ -0,0 +1,141 @@ +"""SQLite database setup and utilities.""" + +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from threading import Lock + +from src.config import DATABASE_PATH + + +# Thread-safe connection management +_db_lock = Lock() + + +def get_connection() -> sqlite3.Connection: + """Get a database connection with proper settings.""" + conn = sqlite3.connect(str(DATABASE_PATH), timeout=30.0) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") + conn.execute("PRAGMA journal_mode = WAL") # Better concurrent access + return conn + + +@contextmanager +def get_db(): + """Context manager for database connections.""" + conn = get_connection() + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +def init_database(): + """Initialize database schema.""" + with _db_lock: + with get_db() as conn: + # History table for processed items + conn.execute(""" + CREATE TABLE IF NOT EXISTS history ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + input_path TEXT NOT NULL, + input_filename TEXT NOT NULL, + output_path TEXT NOT NULL, + output_filename TEXT NOT NULL, + model TEXT NOT NULL, + scale INTEGER NOT NULL, + face_enhance INTEGER NOT NULL DEFAULT 0, + video_codec TEXT, + processing_time_seconds REAL NOT NULL, + input_width INTEGER NOT NULL, + input_height INTEGER NOT NULL, + output_width INTEGER NOT NULL, + output_height INTEGER NOT NULL, + frames_count INTEGER, + input_size_bytes INTEGER NOT NULL, + output_size_bytes INTEGER NOT NULL, + thumbnail_path TEXT, + created_at TEXT NOT NULL, + metadata_json TEXT DEFAULT '{}' + ) + """) + + # Index for efficient queries + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_history_created_at + ON history(created_at DESC) + """) + + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_history_type + ON history(type) + """) + + # Queue table for batch processing + conn.execute(""" + CREATE TABLE IF NOT EXISTS queue ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + input_path TEXT NOT NULL, + output_path TEXT, + config_json TEXT NOT NULL, + progress_percent REAL DEFAULT 0, + current_stage TEXT DEFAULT '', + frames_processed INTEGER DEFAULT 0, + total_frames INTEGER DEFAULT 0, + error_message TEXT, + retry_count INTEGER DEFAULT 0, + priority INTEGER DEFAULT 5, + created_at TEXT NOT NULL, + started_at TEXT, + completed_at TEXT + ) + """) + + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_queue_status + ON queue(status) + """) + + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_queue_priority_created + ON queue(priority ASC, created_at ASC) + """) + + # Checkpoints table for video resume + conn.execute(""" + CREATE TABLE IF NOT EXISTS checkpoints ( + 
job_id TEXT PRIMARY KEY, + input_path TEXT NOT NULL, + input_hash TEXT NOT NULL, + total_frames INTEGER NOT NULL, + processed_frames INTEGER NOT NULL DEFAULT 0, + last_processed_frame INTEGER NOT NULL DEFAULT -1, + frames_dir TEXT NOT NULL, + audio_path TEXT, + config_json TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + """) + + +def reset_database(): + """Reset database (for development/testing).""" + with _db_lock: + with get_db() as conn: + conn.execute("DROP TABLE IF EXISTS history") + conn.execute("DROP TABLE IF EXISTS queue") + conn.execute("DROP TABLE IF EXISTS checkpoints") + init_database() + + +# Initialize on import +init_database() diff --git a/src/storage/history.py b/src/storage/history.py new file mode 100644 index 0000000..864e61b --- /dev/null +++ b/src/storage/history.py @@ -0,0 +1,282 @@ +"""Processing history management with SQLite persistence.""" + +import json +import logging +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Optional +import uuid + +from src.storage.database import get_db + +logger = logging.getLogger(__name__) + + +@dataclass +class HistoryItem: + """Record of a processed image/video.""" + + id: str + type: str # "image" or "video" + input_path: str + input_filename: str + output_path: str + output_filename: str + model: str + scale: int + face_enhance: bool + video_codec: Optional[str] + processing_time_seconds: float + input_width: int + input_height: int + output_width: int + output_height: int + frames_count: Optional[int] + input_size_bytes: int + output_size_bytes: int + thumbnail_path: Optional[str] + created_at: datetime + metadata: dict + + +class HistoryManager: + """ + Manage processing history with SQLite persistence. + + Features: + - Store all processed items + - Search and filter + - Thumbnails for quick preview + """ + + def __init__(self): + logger.info("Initialized history manager") + + def add_item( + self, + type: str, + input_path: str, + output_path: str, + model: str, + scale: int, + face_enhance: bool = False, + video_codec: Optional[str] = None, + processing_time_seconds: float = 0, + input_width: int = 0, + input_height: int = 0, + output_width: int = 0, + output_height: int = 0, + frames_count: Optional[int] = None, + thumbnail_path: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> HistoryItem: + """Add item to history.""" + item_id = uuid.uuid4().hex[:12] + input_path = Path(input_path) + output_path = Path(output_path) + + # Get file sizes + input_size = input_path.stat().st_size if input_path.exists() else 0 + output_size = output_path.stat().st_size if output_path.exists() else 0 + + with get_db() as conn: + conn.execute( + """ + INSERT INTO history ( + id, type, input_path, input_filename, + output_path, output_filename, model, scale, + face_enhance, video_codec, processing_time_seconds, + input_width, input_height, output_width, output_height, + frames_count, input_size_bytes, output_size_bytes, + thumbnail_path, created_at, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + item_id, + type, + str(input_path), + input_path.name, + str(output_path), + output_path.name, + model, + scale, + int(face_enhance), + video_codec, + processing_time_seconds, + input_width, + input_height, + output_width, + output_height, + frames_count, + input_size, + output_size, + thumbnail_path, + datetime.utcnow().isoformat(), + json.dumps(metadata or {}), + ), + ) + + logger.info(f"Added to history: {item_id} ({type})") + + return HistoryItem( + id=item_id, + type=type, + input_path=str(input_path), + input_filename=input_path.name, + output_path=str(output_path), + output_filename=output_path.name, + model=model, + scale=scale, + face_enhance=face_enhance, + video_codec=video_codec, + processing_time_seconds=processing_time_seconds, + input_width=input_width, + input_height=input_height, + output_width=output_width, + output_height=output_height, + frames_count=frames_count, + input_size_bytes=input_size, + output_size_bytes=output_size, + thumbnail_path=thumbnail_path, + created_at=datetime.utcnow(), + metadata=metadata or {}, + ) + + def get_item(self, item_id: str) -> Optional[HistoryItem]: + """Get history item by ID.""" + with get_db() as conn: + row = conn.execute( + "SELECT * FROM history WHERE id = ?", + (item_id,), + ).fetchone() + + if row: + return self._row_to_item(row) + return None + + def get_recent( + self, + limit: int = 50, + offset: int = 0, + type_filter: Optional[str] = None, + ) -> list[HistoryItem]: + """Get recent history items.""" + with get_db() as conn: + if type_filter: + rows = conn.execute( + """ + SELECT * FROM history + WHERE type = ? + ORDER BY created_at DESC + LIMIT ? OFFSET ? + """, + (type_filter, limit, offset), + ).fetchall() + else: + rows = conn.execute( + """ + SELECT * FROM history + ORDER BY created_at DESC + LIMIT ? OFFSET ? + """, + (limit, offset), + ).fetchall() + + return [self._row_to_item(row) for row in rows] + + def search( + self, + query: str, + limit: int = 50, + ) -> list[HistoryItem]: + """Search history by filename.""" + with get_db() as conn: + rows = conn.execute( + """ + SELECT * FROM history + WHERE input_filename LIKE ? OR output_filename LIKE ? + ORDER BY created_at DESC + LIMIT ? 
+ """, + (f"%{query}%", f"%{query}%", limit), + ).fetchall() + + return [self._row_to_item(row) for row in rows] + + def delete_item(self, item_id: str) -> bool: + """Delete history item.""" + with get_db() as conn: + result = conn.execute( + "DELETE FROM history WHERE id = ?", + (item_id,), + ) + deleted = result.rowcount > 0 + + if deleted: + logger.info(f"Deleted from history: {item_id}") + + return deleted + + def get_statistics(self) -> dict: + """Get history statistics.""" + with get_db() as conn: + stats = conn.execute( + """ + SELECT + COUNT(*) as total, + SUM(CASE WHEN type = 'image' THEN 1 ELSE 0 END) as images, + SUM(CASE WHEN type = 'video' THEN 1 ELSE 0 END) as videos, + SUM(processing_time_seconds) as total_time, + SUM(input_size_bytes) as total_input_size, + SUM(output_size_bytes) as total_output_size, + SUM(frames_count) as total_frames + FROM history + """ + ).fetchone() + + return { + "total_items": stats["total"] or 0, + "images": stats["images"] or 0, + "videos": stats["videos"] or 0, + "total_processing_time": stats["total_time"] or 0, + "total_input_size_mb": (stats["total_input_size"] or 0) / (1024 * 1024), + "total_output_size_mb": (stats["total_output_size"] or 0) / (1024 * 1024), + "total_frames_processed": stats["total_frames"] or 0, + } + + def clear_history(self) -> int: + """Clear all history.""" + with get_db() as conn: + result = conn.execute("DELETE FROM history") + return result.rowcount + + def _row_to_item(self, row) -> HistoryItem: + """Convert database row to HistoryItem.""" + return HistoryItem( + id=row["id"], + type=row["type"], + input_path=row["input_path"], + input_filename=row["input_filename"], + output_path=row["output_path"], + output_filename=row["output_filename"], + model=row["model"], + scale=row["scale"], + face_enhance=bool(row["face_enhance"]), + video_codec=row["video_codec"], + processing_time_seconds=row["processing_time_seconds"], + input_width=row["input_width"], + input_height=row["input_height"], + output_width=row["output_width"], + output_height=row["output_height"], + frames_count=row["frames_count"], + input_size_bytes=row["input_size_bytes"], + output_size_bytes=row["output_size_bytes"], + thumbnail_path=row["thumbnail_path"], + created_at=datetime.fromisoformat(row["created_at"]), + metadata=json.loads(row["metadata_json"]), + ) + + +# Global history manager instance +history_manager = HistoryManager() diff --git a/src/storage/queue.py b/src/storage/queue.py new file mode 100644 index 0000000..8c5c1d1 --- /dev/null +++ b/src/storage/queue.py @@ -0,0 +1,332 @@ +"""Job queue management with SQLite persistence.""" + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Optional +import uuid + +from src.storage.database import get_db + +logger = logging.getLogger(__name__) + + +class JobStatus(Enum): + """Job status enumeration.""" + + PENDING = "pending" + QUEUED = "queued" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class JobType(Enum): + """Job type enumeration.""" + + IMAGE = "image" + VIDEO = "video" + + +@dataclass +class Job: + """Processing job data.""" + + id: str = field(default_factory=lambda: uuid.uuid4().hex[:12]) + type: JobType = JobType.IMAGE + status: JobStatus = JobStatus.PENDING + input_path: str = "" + output_path: Optional[str] = None + config: dict = field(default_factory=dict) + progress_percent: float = 0.0 + current_stage: str 
= "" + frames_processed: int = 0 + total_frames: int = 0 + error_message: Optional[str] = None + priority: int = 5 + created_at: datetime = field(default_factory=datetime.utcnow) + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + +class QueueManager: + """ + Manage job queue with SQLite persistence. + + Features: + - Priority-based queue ordering + - Persistent across restarts + - Progress tracking + """ + + def __init__(self): + logger.info("Initialized queue manager") + + def add_job(self, job: Job) -> Job: + """Add job to queue.""" + with get_db() as conn: + conn.execute( + """ + INSERT INTO queue ( + id, type, status, input_path, output_path, + config_json, progress_percent, current_stage, + frames_processed, total_frames, error_message, + priority, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + job.id, + job.type.value, + JobStatus.QUEUED.value, + job.input_path, + job.output_path, + json.dumps(job.config), + job.progress_percent, + job.current_stage, + job.frames_processed, + job.total_frames, + job.error_message, + job.priority, + job.created_at.isoformat(), + ), + ) + + job.status = JobStatus.QUEUED + logger.info(f"Added job to queue: {job.id}") + return job + + def get_next_job(self) -> Optional[Job]: + """Get next job to process (highest priority, oldest first).""" + with get_db() as conn: + row = conn.execute( + """ + SELECT * FROM queue + WHERE status = ? + ORDER BY priority ASC, created_at ASC + LIMIT 1 + """, + (JobStatus.QUEUED.value,), + ).fetchone() + + if not row: + return None + + # Mark as processing + conn.execute( + """ + UPDATE queue SET status = ?, started_at = ? + WHERE id = ? + """, + ( + JobStatus.PROCESSING.value, + datetime.utcnow().isoformat(), + row["id"], + ), + ) + + return self._row_to_job(row) + + def update_progress( + self, + job_id: str, + progress_percent: float, + current_stage: str = "", + frames_processed: int = 0, + total_frames: int = 0, + ) -> None: + """Update job progress.""" + with get_db() as conn: + conn.execute( + """ + UPDATE queue SET + progress_percent = ?, + current_stage = ?, + frames_processed = ?, + total_frames = ? + WHERE id = ? + """, + (progress_percent, current_stage, frames_processed, total_frames, job_id), + ) + + def complete_job(self, job_id: str, output_path: str) -> None: + """Mark job as completed.""" + with get_db() as conn: + conn.execute( + """ + UPDATE queue SET + status = ?, + output_path = ?, + completed_at = ?, + progress_percent = 100 + WHERE id = ? + """, + ( + JobStatus.COMPLETED.value, + output_path, + datetime.utcnow().isoformat(), + job_id, + ), + ) + logger.info(f"Job completed: {job_id}") + + def fail_job(self, job_id: str, error_message: str) -> None: + """Mark job as failed.""" + with get_db() as conn: + conn.execute( + """ + UPDATE queue SET + status = ?, + error_message = ?, + completed_at = ? + WHERE id = ? + """, + ( + JobStatus.FAILED.value, + error_message, + datetime.utcnow().isoformat(), + job_id, + ), + ) + logger.error(f"Job failed: {job_id} - {error_message}") + + def cancel_job(self, job_id: str) -> None: + """Cancel a job.""" + with get_db() as conn: + conn.execute( + """ + UPDATE queue SET status = ? + WHERE id = ? AND status IN (?, ?) 
+ """, + ( + JobStatus.CANCELLED.value, + job_id, + JobStatus.PENDING.value, + JobStatus.QUEUED.value, + ), + ) + logger.info(f"Job cancelled: {job_id}") + + def get_job(self, job_id: str) -> Optional[Job]: + """Get job by ID.""" + with get_db() as conn: + row = conn.execute( + "SELECT * FROM queue WHERE id = ?", + (job_id,), + ).fetchone() + + if row: + return self._row_to_job(row) + return None + + def get_queue(self) -> list[Job]: + """Get all queued/processing jobs.""" + with get_db() as conn: + rows = conn.execute( + """ + SELECT * FROM queue + WHERE status IN (?, ?) + ORDER BY priority ASC, created_at ASC + """, + (JobStatus.QUEUED.value, JobStatus.PROCESSING.value), + ).fetchall() + + return [self._row_to_job(row) for row in rows] + + def get_completed_jobs(self, limit: int = 50) -> list[Job]: + """Get recently completed jobs.""" + with get_db() as conn: + rows = conn.execute( + """ + SELECT * FROM queue + WHERE status IN (?, ?, ?) + ORDER BY completed_at DESC + LIMIT ? + """, + ( + JobStatus.COMPLETED.value, + JobStatus.FAILED.value, + JobStatus.CANCELLED.value, + limit, + ), + ).fetchall() + + return [self._row_to_job(row) for row in rows] + + def get_queue_stats(self) -> dict: + """Get queue statistics.""" + with get_db() as conn: + stats = {} + + # Count by status + rows = conn.execute( + """ + SELECT status, COUNT(*) as count + FROM queue + GROUP BY status + """ + ).fetchall() + + for row in rows: + stats[row["status"]] = row["count"] + + # Currently processing + processing = conn.execute( + "SELECT * FROM queue WHERE status = ?", + (JobStatus.PROCESSING.value,), + ).fetchone() + + stats["current_job"] = self._row_to_job(processing) if processing else None + + return stats + + def clear_completed(self) -> int: + """Clear completed/failed/cancelled jobs.""" + with get_db() as conn: + result = conn.execute( + """ + DELETE FROM queue + WHERE status IN (?, ?, ?) 
+ """, + ( + JobStatus.COMPLETED.value, + JobStatus.FAILED.value, + JobStatus.CANCELLED.value, + ), + ) + return result.rowcount + + def _row_to_job(self, row) -> Job: + """Convert database row to Job object.""" + return Job( + id=row["id"], + type=JobType(row["type"]), + status=JobStatus(row["status"]), + input_path=row["input_path"], + output_path=row["output_path"], + config=json.loads(row["config_json"]), + progress_percent=row["progress_percent"], + current_stage=row["current_stage"], + frames_processed=row["frames_processed"], + total_frames=row["total_frames"], + error_message=row["error_message"], + priority=row["priority"], + created_at=datetime.fromisoformat(row["created_at"]), + started_at=( + datetime.fromisoformat(row["started_at"]) + if row["started_at"] + else None + ), + completed_at=( + datetime.fromisoformat(row["completed_at"]) + if row["completed_at"] + else None + ), + ) + + +# Global queue manager instance +queue_manager = QueueManager() diff --git a/src/ui/__init__.py b/src/ui/__init__.py new file mode 100644 index 0000000..009cd09 --- /dev/null +++ b/src/ui/__init__.py @@ -0,0 +1 @@ +"""UI components and handlers for the Gradio interface.""" diff --git a/src/ui/app.py b/src/ui/app.py new file mode 100644 index 0000000..d5c0477 --- /dev/null +++ b/src/ui/app.py @@ -0,0 +1,111 @@ +"""Main Gradio application assembly.""" + +import logging +import torch +import gradio as gr + +from src.config import config +from src.ui.theme import get_theme, get_css +from src.ui.components.image_tab import create_image_tab +from src.ui.components.video_tab import create_video_tab +from src.ui.components.batch_tab import create_batch_tab +from src.ui.components.history_tab import create_history_tab + +logger = logging.getLogger(__name__) + + +def get_gpu_status() -> str: + """Get GPU status string for display.""" + if not torch.cuda.is_available(): + return "GPU: Not available (CPU mode)" + + try: + props = torch.cuda.get_device_properties(0) + allocated = torch.cuda.memory_allocated() / (1024**3) + total = props.total_memory / (1024**3) + return f"GPU: {props.name} | VRAM: {allocated:.1f}/{total:.1f} GB" + except Exception: + return "GPU: Status unavailable" + + +def create_app() -> gr.Blocks: + """Create and return the main Gradio application.""" + theme = get_theme() + css = get_css() + + with gr.Blocks( + theme=theme, + css=css, + title="Real-ESRGAN Upscaler", + fill_width=True, + ) as app: + # Header + gr.HTML( + """ +
+            <div>
+                <h1>Real-ESRGAN Upscaler</h1>
+            </div>
+ """, + elem_classes=["header"], + ) + + # Main tabs + with gr.Tabs(elem_id="main-tabs") as tabs: + # Image upscaling tab + image_components = create_image_tab() + + # Video upscaling tab + video_components = create_video_tab() + + # Batch queue tab + batch_components = create_batch_tab() + + # History tab + history_components = create_history_tab() + + # Status bar + gr.HTML( + f""" +
+            <div>
+                <span>{get_gpu_status()}</span>
+                <span>Queue: 0 | Ready</span>
+            </div>
+ """, + elem_classes=["status-bar-container"], + ) + + return app + + +def launch_app(): + """Launch the Gradio application.""" + # Configure logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + logger.info("Starting Real-ESRGAN Web UI...") + logger.info(get_gpu_status()) + + # Create and launch app + app = create_app() + + # Configure queue for GPU processing + app.queue( + default_concurrency_limit=config.concurrency_limit, + max_size=config.max_queue_size, + ) + + # Launch server + app.launch( + server_name=config.server_name, + server_port=config.server_port, + share=config.share, + show_error=True, + ) diff --git a/src/ui/components/__init__.py b/src/ui/components/__init__.py new file mode 100644 index 0000000..db9356d --- /dev/null +++ b/src/ui/components/__init__.py @@ -0,0 +1 @@ +"""Gradio UI components for each tab.""" diff --git a/src/ui/components/batch_tab.py b/src/ui/components/batch_tab.py new file mode 100644 index 0000000..783300d --- /dev/null +++ b/src/ui/components/batch_tab.py @@ -0,0 +1,200 @@ +"""Batch processing queue tab component.""" + +import logging +from pathlib import Path +from typing import Optional + +import gradio as gr + +from src.config import get_model_choices, config +from src.storage.queue import queue_manager, Job, JobType, JobStatus + +logger = logging.getLogger(__name__) + + +def get_queue_data() -> list[list]: + """Get queue data for dataframe display.""" + jobs = queue_manager.get_queue() + + data = [] + for job in jobs: + status_emoji = { + JobStatus.QUEUED: "⏳", + JobStatus.PROCESSING: "🔄", + JobStatus.COMPLETED: "✅", + JobStatus.FAILED: "❌", + JobStatus.CANCELLED: "🚫", + }.get(job.status, "❓") + + progress = f"{job.progress_percent:.0f}%" if job.status == JobStatus.PROCESSING else "-" + + data.append([ + job.id[:8], + job.type.value.title(), + Path(job.input_path).name[:30], + f"{status_emoji} {job.status.value.title()}", + progress, + job.current_stage or "-", + ]) + + return data + + +def get_queue_stats() -> str: + """Get queue statistics for display.""" + stats = queue_manager.get_queue_stats() + + queued = stats.get(JobStatus.QUEUED.value, 0) + processing = stats.get(JobStatus.PROCESSING.value, 0) + completed = stats.get(JobStatus.COMPLETED.value, 0) + failed = stats.get(JobStatus.FAILED.value, 0) + + current = stats.get("current_job") + current_info = "" + if current: + current_info = f"\n\n**Currently Processing:**\n{Path(current.input_path).name}" + + return ( + f"**Queue Status**\n\n" + f"- Queued: {queued}\n" + f"- Processing: {processing}\n" + f"- Completed: {completed}\n" + f"- Failed: {failed}" + f"{current_info}" + ) + + +def add_files_to_queue( + files: Optional[list], + model_name: str, + scale: int, + face_enhance: bool, +) -> tuple[list[list], str]: + """Add uploaded files to queue.""" + if not files: + return get_queue_data(), "No files selected" + + added = 0 + for file in files: + file_path = Path(file.name if hasattr(file, 'name') else file) + + # Determine job type + ext = file_path.suffix.lower() + if ext in [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff"]: + job_type = JobType.IMAGE + elif ext in [".mp4", ".avi", ".mov", ".mkv", ".webm"]: + job_type = JobType.VIDEO + else: + continue + + job = Job( + type=job_type, + input_path=str(file_path), + config={ + "model_name": model_name, + "scale": scale, + "face_enhance": face_enhance, + }, + ) + + queue_manager.add_job(job) + added += 1 + + return get_queue_data(), f"Added {added} items to queue" + 
+ +def cancel_job(job_id: str) -> tuple[list[list], str]: + """Cancel a job.""" + if job_id: + queue_manager.cancel_job(job_id) + return get_queue_data(), f"Cancelled job: {job_id}" + return get_queue_data(), "No job selected" + + +def clear_completed() -> tuple[list[list], str]: + """Clear completed jobs from queue.""" + count = queue_manager.clear_completed() + return get_queue_data(), f"Cleared {count} completed jobs" + + +def create_batch_tab(): + """Create the batch queue tab component.""" + with gr.Tab("Batch Queue", id="queue-tab"): + gr.Markdown("## Batch Processing Queue") + + with gr.Row(): + # Queue table + with gr.Column(scale=2): + queue_table = gr.Dataframe( + headers=["ID", "Type", "Input", "Status", "Progress", "Stage"], + datatype=["str", "str", "str", "str", "str", "str"], + value=get_queue_data(), + label="Queue", + interactive=False, + ) + + # Queue stats + with gr.Column(scale=1): + queue_stats = gr.Markdown( + get_queue_stats(), + elem_classes=["info-card"], + ) + + # Add to queue section + with gr.Accordion("Add to Queue", open=True): + with gr.Row(): + file_input = gr.File( + label="Select Files", + file_count="multiple", + file_types=["image", "video"], + ) + + with gr.Row(): + model_dropdown = gr.Dropdown( + choices=get_model_choices(), + value=config.default_model, + label="Model", + ) + + scale_radio = gr.Radio( + choices=[2, 4], + value=config.default_scale, + label="Scale", + ) + + face_enhance_cb = gr.Checkbox( + value=False, + label="Face Enhancement", + ) + + add_btn = gr.Button("Add to Queue", variant="primary") + + # Queue actions + with gr.Row(): + refresh_btn = gr.Button("Refresh", variant="secondary") + clear_btn = gr.Button("Clear Completed", variant="secondary") + + status_text = gr.Markdown("") + + # Event handlers + add_btn.click( + fn=add_files_to_queue, + inputs=[file_input, model_dropdown, scale_radio, face_enhance_cb], + outputs=[queue_table, status_text], + ) + + refresh_btn.click( + fn=lambda: (get_queue_data(), get_queue_stats()), + outputs=[queue_table, queue_stats], + ) + + clear_btn.click( + fn=clear_completed, + outputs=[queue_table, status_text], + ) + + return { + "queue_table": queue_table, + "queue_stats": queue_stats, + "status_text": status_text, + } diff --git a/src/ui/components/history_tab.py b/src/ui/components/history_tab.py new file mode 100644 index 0000000..7e04a82 --- /dev/null +++ b/src/ui/components/history_tab.py @@ -0,0 +1,208 @@ +"""History gallery tab component.""" + +import logging +from pathlib import Path +from typing import Optional + +import gradio as gr + +from src.storage.history import history_manager, HistoryItem + +logger = logging.getLogger(__name__) + + +def get_history_gallery() -> list[tuple[str, str]]: + """Get history items for gallery display.""" + items = history_manager.get_recent(limit=100) + + gallery_items = [] + for item in items: + # Use output path for display + output_path = Path(item.output_path) + if output_path.exists() and item.type == "image": + caption = f"{item.output_filename}\n{item.output_width}x{item.output_height}" + gallery_items.append((str(output_path), caption)) + + return gallery_items + + +def get_history_stats() -> str: + """Get history statistics.""" + stats = history_manager.get_statistics() + + return ( + f"**History Statistics**\n\n" + f"- Total Items: {stats['total_items']}\n" + f"- Images: {stats['images']}\n" + f"- Videos: {stats['videos']}\n" + f"- Total Processing Time: {stats['total_processing_time']/60:.1f} min\n" + f"- Total Input Size: 
{stats['total_input_size_mb']:.1f} MB\n" + f"- Total Output Size: {stats['total_output_size_mb']:.1f} MB" + ) + + +def search_history(query: str) -> list[tuple[str, str]]: + """Search history by filename.""" + if not query: + return get_history_gallery() + + items = history_manager.search(query) + + gallery_items = [] + for item in items: + output_path = Path(item.output_path) + if output_path.exists() and item.type == "image": + caption = f"{item.output_filename}\n{item.output_width}x{item.output_height}" + gallery_items.append((str(output_path), caption)) + + return gallery_items + + +def get_item_details(evt: gr.SelectData) -> tuple[Optional[tuple], str]: + """Get details for selected history item.""" + if evt.index is None: + return None, "Select an item to view details" + + items = history_manager.get_recent(limit=100) + + # Filter to only images (same as gallery) + image_items = [i for i in items if i.type == "image" and Path(i.output_path).exists()] + + if evt.index >= len(image_items): + return None, "Item not found" + + item = image_items[evt.index] + + # Load images for slider + input_path = Path(item.input_path) + output_path = Path(item.output_path) + + slider_images = None + if input_path.exists() and output_path.exists(): + slider_images = (str(input_path), str(output_path)) + + # Format details + details = ( + f"**{item.output_filename}**\n\n" + f"- **Type:** {item.type.title()}\n" + f"- **Model:** {item.model}\n" + f"- **Scale:** {item.scale}x\n" + f"- **Face Enhancement:** {'Yes' if item.face_enhance else 'No'}\n" + f"- **Input:** {item.input_width}x{item.input_height}\n" + f"- **Output:** {item.output_width}x{item.output_height}\n" + f"- **Processing Time:** {item.processing_time_seconds:.1f}s\n" + f"- **Input Size:** {item.input_size_bytes/1024:.1f} KB\n" + f"- **Output Size:** {item.output_size_bytes/1024:.1f} KB\n" + f"- **Created:** {item.created_at.strftime('%Y-%m-%d %H:%M')}" + ) + + return slider_images, details + + +def delete_selected(evt: gr.SelectData) -> tuple[list, str]: + """Delete selected history item.""" + if evt.index is None: + return get_history_gallery(), "No item selected" + + items = history_manager.get_recent(limit=100) + image_items = [i for i in items if i.type == "image" and Path(i.output_path).exists()] + + if evt.index >= len(image_items): + return get_history_gallery(), "Item not found" + + item = image_items[evt.index] + history_manager.delete_item(item.id) + + return get_history_gallery(), f"Deleted: {item.output_filename}" + + +def create_history_tab(): + """Create the history gallery tab component.""" + with gr.Tab("History", id="history-tab"): + gr.Markdown("## Processing History") + + with gr.Row(): + # Search + search_box = gr.Textbox( + label="Search", + placeholder="Search by filename...", + scale=2, + ) + + filter_dropdown = gr.Dropdown( + choices=["All", "Images", "Videos"], + value="All", + label="Filter", + scale=1, + ) + + with gr.Row(): + # Gallery + with gr.Column(scale=2): + gallery = gr.Gallery( + value=get_history_gallery(), + label="History", + columns=4, + object_fit="cover", + height=400, + allow_preview=True, + ) + + # Details panel + with gr.Column(scale=1): + stats_display = gr.Markdown( + get_history_stats(), + elem_classes=["info-card"], + ) + + # Selected item details + with gr.Row(): + with gr.Column(scale=2): + try: + from gradio_imageslider import ImageSlider + + comparison_slider = ImageSlider( + label="Before / After", + type="filepath", + ) + except ImportError: + comparison_slider = gr.Image( + 
label="Selected Image", + type="filepath", + ) + + with gr.Column(scale=1): + item_details = gr.Markdown( + "Select an item to view details", + elem_classes=["info-card"], + ) + + with gr.Row(): + download_btn = gr.Button("Download", variant="secondary") + delete_btn = gr.Button("Delete", variant="stop") + + status_text = gr.Markdown("") + + # Event handlers + search_box.change( + fn=search_history, + inputs=[search_box], + outputs=[gallery], + ) + + gallery.select( + fn=get_item_details, + outputs=[comparison_slider, item_details], + ) + + delete_btn.click( + fn=lambda: (get_history_gallery(), get_history_stats()), + outputs=[gallery, stats_display], + ) + + return { + "gallery": gallery, + "comparison_slider": comparison_slider, + "item_details": item_details, + "stats_display": stats_display, + } diff --git a/src/ui/components/image_tab.py b/src/ui/components/image_tab.py new file mode 100644 index 0000000..190f40d --- /dev/null +++ b/src/ui/components/image_tab.py @@ -0,0 +1,292 @@ +"""Image upscaling tab component.""" + +import logging +import tempfile +import time +import uuid +from pathlib import Path +from typing import Optional + +import cv2 +import gradio as gr +import numpy as np +from PIL import Image + +from src.config import OUTPUT_DIR, config, get_model_choices +from src.processing.upscaler import upscale_image, save_image, get_output_dimensions + +logger = logging.getLogger(__name__) + + +def process_image( + input_image: Optional[np.ndarray], + model_name: str, + scale: int, + face_enhance: bool, + tile_size: str, + denoise: float, + output_format: str, + progress=gr.Progress(track_tqdm=True), +) -> tuple[Optional[tuple], Optional[str], str]: + """ + Process image for upscaling. + + Args: + input_image: Input image as numpy array + model_name: Model to use + scale: Scale factor (2 or 4) + face_enhance: Enable face enhancement + tile_size: Tile size setting ("Auto", "256", "512", "1024") + denoise: Denoise strength (0-1) + output_format: Output format (PNG, JPG, WebP) + + Returns: + Tuple of (slider_images, output_path, status_message) + """ + if input_image is None: + return None, None, "Please upload an image first." 
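    # The steps below: parse the tile-size setting, run Real-ESRGAN upscaling,
    # optionally apply GFPGAN face enhancement, save the result to OUTPUT_DIR,
    # then convert both images from BGR to RGB for the before/after slider.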
+ + try: + start_time = time.time() + + # Parse tile size + tile = 0 if tile_size == "Auto" else int(tile_size) + + # Progress callback for UI + def progress_callback(pct: float, stage: str): + progress(pct, desc=stage) + + progress(0, desc="Starting upscale...") + + # Upscale the image + result = upscale_image( + input_image, + model_name=model_name, + scale=scale, + tile_size=tile, + denoise_strength=denoise, + progress_callback=progress_callback, + ) + + # Apply face enhancement if requested + if face_enhance: + progress(0.7, desc="Enhancing faces...") + try: + from src.processing.face_enhancer import enhance_faces + + result.image = enhance_faces(result.image) + except ImportError: + logger.warning("Face enhancement not available") + except Exception as e: + logger.warning(f"Face enhancement failed: {e}") + + progress(0.85, desc="Saving output...") + + # Generate output filename + output_name = f"upscaled_{uuid.uuid4().hex[:8]}" + output_path = save_image( + result.image, + OUTPUT_DIR / output_name, + format=output_format.lower(), + ) + + # Convert for display (BGR to RGB) + input_rgb = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) if input_image.shape[2] == 3 else input_image + output_rgb = cv2.cvtColor(result.image, cv2.COLOR_BGR2RGB) + + progress(1.0, desc="Complete!") + + # Calculate stats + elapsed = time.time() - start_time + input_size = f"{result.input_width}x{result.input_height}" + output_size = f"{result.output_width}x{result.output_height}" + + status = ( + f"Upscaled {input_size} -> {output_size} in {elapsed:.1f}s | " + f"Model: {model_name} | Saved: {output_path.name}" + ) + + # Return images for slider (input, output) + return (input_rgb, output_rgb), str(output_path), status + + except Exception as e: + logger.exception("Image processing failed") + return None, None, f"Error: {str(e)}" + + +def estimate_output( + input_image: Optional[np.ndarray], + model_name: str, + scale: int, +) -> str: + """Estimate output dimensions for display.""" + if input_image is None: + return "Upload an image to see estimated output size" + + try: + height, width = input_image.shape[:2] + out_w, out_h = get_output_dimensions(width, height, model_name, scale) + return f"Input: {width}x{height} -> Output: {out_w}x{out_h}" + except Exception: + return "Unable to estimate output size" + + +def create_image_tab(): + """Create the image upscaling tab component.""" + with gr.Tab("Image", id="image-tab"): + # Status display at top + status_text = gr.Markdown( + "Upload an image to begin upscaling", + elem_classes=["status-text"], + ) + + with gr.Row(): + # Left column - Input + with gr.Column(scale=1): + input_image = gr.Image( + label="Input Image", + type="numpy", + sources=["upload", "clipboard"], + elem_classes=["upload-area"], + ) + + # Dimension estimate + dimension_info = gr.Markdown( + "Upload an image to see estimated output size", + elem_classes=["info-text"], + ) + + # Right column - Output (ImageSlider) + with gr.Column(scale=1): + try: + from gradio_imageslider import ImageSlider + + output_slider = ImageSlider( + label="Before / After", + type="numpy", + show_download_button=True, + ) + except ImportError: + # Fallback if ImageSlider not available + output_slider = gr.Image( + label="Upscaled Result", + type="numpy", + show_download_button=True, + ) + + output_file = gr.File( + label="Download", + visible=False, + ) + + # Processing options + with gr.Accordion("Processing Options", open=True): + with gr.Row(): + model_dropdown = gr.Dropdown( + choices=get_model_choices(), + 
value=config.default_model, + label="Model", + info="Select upscaling model", + ) + + scale_radio = gr.Radio( + choices=[2, 4], + value=config.default_scale, + label="Scale Factor", + info="Output scale multiplier", + ) + + face_enhance_cb = gr.Checkbox( + value=config.default_face_enhance, + label="Face Enhancement (GFPGAN)", + info="Enhance faces in photos (not for anime)", + ) + + with gr.Row(): + tile_dropdown = gr.Dropdown( + choices=["Auto", "256", "512", "1024"], + value="Auto", + label="Tile Size", + info="Auto recommended for RTX 4090", + ) + + denoise_slider = gr.Slider( + minimum=0, + maximum=1, + value=0.5, + step=0.1, + label="Denoise Strength", + info="Only for realesr-general-x4v3 model", + ) + + format_radio = gr.Radio( + choices=["PNG", "JPG", "WebP"], + value="PNG", + label="Output Format", + ) + + # Action buttons + with gr.Row(): + upscale_btn = gr.Button( + "Upscale Image", + variant="primary", + size="lg", + elem_classes=["primary-action-btn"], + ) + + clear_btn = gr.Button( + "Clear", + variant="secondary", + ) + + # Event handlers + upscale_btn.click( + fn=process_image, + inputs=[ + input_image, + model_dropdown, + scale_radio, + face_enhance_cb, + tile_dropdown, + denoise_slider, + format_radio, + ], + outputs=[output_slider, output_file, status_text], + ) + + # Update dimension estimate when image changes + input_image.change( + fn=estimate_output, + inputs=[input_image, model_dropdown, scale_radio], + outputs=[dimension_info], + ) + + # Also update when model/scale changes + model_dropdown.change( + fn=estimate_output, + inputs=[input_image, model_dropdown, scale_radio], + outputs=[dimension_info], + ) + + scale_radio.change( + fn=estimate_output, + inputs=[input_image, model_dropdown, scale_radio], + outputs=[dimension_info], + ) + + # Clear button + clear_btn.click( + fn=lambda: (None, None, None, "Upload an image to begin upscaling"), + outputs=[input_image, output_slider, output_file, status_text], + ) + + return { + "input_image": input_image, + "output_slider": output_slider, + "output_file": output_file, + "status_text": status_text, + "model_dropdown": model_dropdown, + "scale_radio": scale_radio, + "face_enhance_cb": face_enhance_cb, + "upscale_btn": upscale_btn, + } diff --git a/src/ui/components/video_tab.py b/src/ui/components/video_tab.py new file mode 100644 index 0000000..a465341 --- /dev/null +++ b/src/ui/components/video_tab.py @@ -0,0 +1,376 @@ +"""Video upscaling tab component.""" + +import logging +import shutil +import tempfile +import time +import uuid +from pathlib import Path +from typing import Optional + +import cv2 +import gradio as gr +import numpy as np + +from src.config import OUTPUT_DIR, TEMP_DIR, VIDEO_CODECS, config, get_model_choices +from src.processing.upscaler import upscale_image +from src.video.extractor import VideoExtractor, get_video_info +from src.video.encoder import encode_video +from src.video.audio import extract_audio, has_audio_stream +from src.video.checkpoint import checkpoint_manager + +logger = logging.getLogger(__name__) + + +def get_video_metadata(video_path: Optional[str]) -> str: + """Get video metadata for display.""" + if not video_path: + return "Upload a video to see details" + + try: + metadata = get_video_info(Path(video_path)) + return ( + f"**Resolution:** {metadata.width}x{metadata.height}\n" + f"**Duration:** {metadata.duration_seconds:.1f}s\n" + f"**FPS:** {metadata.fps:.2f}\n" + f"**Frames:** {metadata.total_frames}\n" + f"**Codec:** {metadata.codec}\n" + f"**Audio:** {'Yes' if 
metadata.has_audio else 'No'}" + ) + except Exception as e: + return f"Error reading video: {e}" + + +def estimate_video_output( + video_path: Optional[str], + model_name: str, + scale: int, +) -> str: + """Estimate output details for video.""" + if not video_path: + return "Upload a video to see estimated output" + + try: + metadata = get_video_info(Path(video_path)) + out_w = metadata.width * scale + out_h = metadata.height * scale + + # Estimate processing time (rough: ~2 fps for 4x on RTX 4090) + est_time = metadata.total_frames / 2 if scale == 4 else metadata.total_frames / 4 + est_minutes = est_time / 60 + + return ( + f"**Output Resolution:** {out_w}x{out_h}\n" + f"**Estimated Time:** ~{est_minutes:.1f} minutes" + ) + except Exception: + return "Unable to estimate" + + +def process_video( + video_path: Optional[str], + model_name: str, + scale: int, + face_enhance: bool, + codec: str, + crf: int, + preset: str, + preserve_audio: bool, + progress=gr.Progress(track_tqdm=True), +) -> tuple[Optional[str], str]: + """ + Process video for upscaling. + + Args: + video_path: Input video path + model_name: Model to use + scale: Scale factor + face_enhance: Enable face enhancement + codec: Output codec + crf: Quality (lower = better) + preset: Encoding preset + preserve_audio: Keep original audio + progress: Gradio progress tracker + + Returns: + Tuple of (output_path, status_message) + """ + if not video_path: + return None, "Please upload a video first." + + try: + start_time = time.time() + job_id = uuid.uuid4().hex[:8] + + progress(0, desc="Analyzing video...") + + # Get video info + video_path = Path(video_path) + metadata = get_video_info(video_path) + + logger.info( + f"Processing video: {video_path.name}, " + f"{metadata.width}x{metadata.height}, " + f"{metadata.total_frames} frames" + ) + + # Create temp directories + temp_dir = TEMP_DIR / f"video_{job_id}" + frames_dir = temp_dir / "frames" + frames_dir.mkdir(parents=True, exist_ok=True) + + # Extract audio if needed + audio_path = None + if preserve_audio and has_audio_stream(video_path): + progress(0.02, desc="Extracting audio...") + audio_path = extract_audio(video_path, temp_dir / "audio") + + # Create checkpoint + checkpoint = checkpoint_manager.create_checkpoint( + job_id=job_id, + input_path=video_path, + total_frames=metadata.total_frames, + frames_dir=frames_dir, + audio_path=audio_path, + config={ + "model_name": model_name, + "scale": scale, + "face_enhance": face_enhance, + "codec": codec, + "crf": crf, + "preset": preset, + }, + ) + + # Process frames + progress(0.05, desc="Processing frames...") + + with VideoExtractor(video_path) as extractor: + total = metadata.total_frames + start_frame = checkpoint.processed_frames + + for frame_idx, frame in extractor.extract_frames(start_frame=start_frame): + # Upscale frame (RGB input) + frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + result = upscale_image( + frame_bgr, + model_name=model_name, + scale=scale, + ) + + # Apply face enhancement if enabled + if face_enhance: + try: + from src.processing.face_enhancer import enhance_faces + result.image = enhance_faces(result.image) + except Exception as e: + logger.warning(f"Face enhancement failed on frame {frame_idx}: {e}") + + # Save frame + frame_path = frames_dir / f"frame_{frame_idx:08d}.png" + cv2.imwrite(str(frame_path), result.image) + + # Update checkpoint + checkpoint_manager.update_checkpoint(checkpoint, frame_idx) + + # Update progress + pct = 0.05 + (0.85 * (frame_idx + 1) / total) + progress( + pct, + 
desc=f"Processing frame {frame_idx + 1}/{total}", + ) + + # Encode output video + progress(0.9, desc=f"Encoding video ({codec})...") + + output_name = f"{video_path.stem}_upscaled_{job_id}" + output_path = OUTPUT_DIR / f"{output_name}.mp4" + + encode_video( + frames_dir=frames_dir, + output_path=output_path, + fps=metadata.fps, + codec=codec, + crf=crf, + preset=preset, + audio_path=audio_path, + ) + + # Cleanup + progress(0.98, desc="Cleaning up...") + checkpoint_manager.delete_checkpoint(job_id) + shutil.rmtree(temp_dir, ignore_errors=True) + + progress(1.0, desc="Complete!") + + # Calculate stats + elapsed = time.time() - start_time + output_size = output_path.stat().st_size / (1024 * 1024) # MB + fps_actual = metadata.total_frames / elapsed + + status = ( + f"Completed in {elapsed/60:.1f} minutes ({fps_actual:.2f} fps)\n" + f"Output: {output_path.name} ({output_size:.1f} MB)" + ) + + return str(output_path), status + + except Exception as e: + logger.exception("Video processing failed") + return None, f"Error: {str(e)}" + + +def create_video_tab(): + """Create the video upscaling tab component.""" + with gr.Tab("Video", id="video-tab"): + # Status display + status_text = gr.Markdown( + "Upload a video to begin processing", + elem_classes=["status-text"], + ) + + with gr.Row(): + # Left column - Input + with gr.Column(scale=1): + input_video = gr.Video( + label="Input Video", + sources=["upload"], + ) + + # Video info display + video_info = gr.Markdown( + "Upload a video to see details", + elem_classes=["info-text"], + ) + + # Right column - Output + with gr.Column(scale=1): + output_video = gr.Video( + label="Upscaled Video", + interactive=False, + ) + + # Output estimate + output_estimate = gr.Markdown( + "Upload a video to see estimated output", + elem_classes=["info-text"], + ) + + # Processing options + with gr.Accordion("Processing Options", open=True): + with gr.Row(): + model_dropdown = gr.Dropdown( + choices=get_model_choices(), + value=config.default_model, + label="Model", + ) + + scale_radio = gr.Radio( + choices=[2, 4], + value=config.default_scale, + label="Scale Factor", + ) + + face_enhance_cb = gr.Checkbox( + value=False, + label="Face Enhancement", + info="Not recommended for anime", + ) + + # Codec options + with gr.Accordion("Output Codec Settings", open=True): + with gr.Row(): + codec_radio = gr.Radio( + choices=list(VIDEO_CODECS.keys()), + value=config.default_video_codec, + label="Codec", + info="H.265 recommended for quality/size balance", + ) + + crf_slider = gr.Slider( + minimum=15, + maximum=35, + value=config.default_video_crf, + step=1, + label="Quality (CRF)", + info="Lower = better quality, larger file", + ) + + preset_dropdown = gr.Dropdown( + choices=["slow", "medium", "fast"], + value=config.default_video_preset, + label="Preset", + info="Slow = best quality", + ) + + with gr.Row(): + preserve_audio_cb = gr.Checkbox( + value=True, + label="Preserve Audio", + info="Keep original audio track", + ) + + # Action buttons + with gr.Row(): + process_btn = gr.Button( + "Start Processing", + variant="primary", + size="lg", + elem_classes=["primary-action-btn"], + ) + + cancel_btn = gr.Button( + "Cancel", + variant="stop", + visible=False, # TODO: implement cancellation + ) + + # Event handlers + process_btn.click( + fn=process_video, + inputs=[ + input_video, + model_dropdown, + scale_radio, + face_enhance_cb, + codec_radio, + crf_slider, + preset_dropdown, + preserve_audio_cb, + ], + outputs=[output_video, status_text], + ) + + # Update video info when video 
changes + input_video.change( + fn=get_video_metadata, + inputs=[input_video], + outputs=[video_info], + ) + + # Update output estimate + input_video.change( + fn=estimate_video_output, + inputs=[input_video, model_dropdown, scale_radio], + outputs=[output_estimate], + ) + + model_dropdown.change( + fn=estimate_video_output, + inputs=[input_video, model_dropdown, scale_radio], + outputs=[output_estimate], + ) + + scale_radio.change( + fn=estimate_video_output, + inputs=[input_video, model_dropdown, scale_radio], + outputs=[output_estimate], + ) + + return { + "input_video": input_video, + "output_video": output_video, + "status_text": status_text, + "process_btn": process_btn, + } diff --git a/src/ui/handlers/__init__.py b/src/ui/handlers/__init__.py new file mode 100644 index 0000000..c355e96 --- /dev/null +++ b/src/ui/handlers/__init__.py @@ -0,0 +1 @@ +"""Event handlers for UI interactions.""" diff --git a/src/ui/theme.py b/src/ui/theme.py new file mode 100644 index 0000000..7a684c3 --- /dev/null +++ b/src/ui/theme.py @@ -0,0 +1,284 @@ +"""Custom dark theme for Real-ESRGAN Web UI.""" + +import gradio as gr +from gradio.themes import Base, colors, sizes + + +class UpscalerTheme(Base): + """ + Professional dark theme optimized for image/video upscaling UI. + + Features: + - Dark background for better image comparison + - Blue accent colors for actions + - Comfortable spacing for power users + """ + + def __init__(self): + super().__init__( + # Core colors + primary_hue=colors.blue, + secondary_hue=colors.slate, + neutral_hue=colors.slate, + # Sizing + spacing_size=sizes.spacing_md, + radius_size=sizes.radius_md, + text_size=sizes.text_md, + # Fonts + font=( + "Inter", + "ui-sans-serif", + "-apple-system", + "BlinkMacSystemFont", + "Segoe UI", + "Roboto", + "sans-serif", + ), + font_mono=( + "JetBrains Mono", + "ui-monospace", + "SFMono-Regular", + "Menlo", + "monospace", + ), + ) + + # Dark mode overrides + super().set( + # Backgrounds - dark with subtle contrast + body_background_fill="*neutral_950", + body_background_fill_dark="*neutral_950", + block_background_fill="*neutral_900", + block_background_fill_dark="*neutral_900", + # Borders + block_border_width="1px", + block_border_color="*neutral_800", + block_border_color_dark="*neutral_700", + border_color_primary="*primary_600", + # Inputs + input_background_fill="*neutral_800", + input_background_fill_dark="*neutral_800", + input_border_color="*neutral_700", + input_border_color_focus="*primary_500", + input_border_color_focus_dark="*primary_400", + # Buttons - primary (blue) + button_primary_background_fill="*primary_600", + button_primary_background_fill_hover="*primary_500", + button_primary_text_color="white", + button_primary_border_color="*primary_600", + # Buttons - secondary + button_secondary_background_fill="*neutral_700", + button_secondary_background_fill_hover="*neutral_600", + button_secondary_text_color="*neutral_100", + button_secondary_border_color="*neutral_600", + # Buttons - stop (for cancel) + button_cancel_background_fill="*error_600", + button_cancel_background_fill_hover="*error_500", + button_cancel_text_color="white", + # Shadows + block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.3), 0 2px 4px -2px rgba(0, 0, 0, 0.2)", + # Text + body_text_color="*neutral_200", + body_text_color_subdued="*neutral_400", + body_text_color_dark="*neutral_200", + # Labels + block_title_text_color="*neutral_100", + block_label_text_color="*neutral_300", + # Sliders + slider_color="*primary_500", + slider_color_dark="*primary_400", + 
# Checkboxes + checkbox_background_color_selected="*primary_600", + checkbox_background_color_selected_dark="*primary_500", + checkbox_border_color_selected="*primary_600", + # Tables/Dataframes + table_border_color="*neutral_700", + table_even_background_fill="*neutral_850", + table_odd_background_fill="*neutral_900", + # Tabs + tab_selected_text_color="*primary_400", + ) + + +# Custom CSS for additional styling +CUSTOM_CSS = """ +/* Force dark color scheme */ +:root { + color-scheme: dark; +} + +/* Header styling */ +.app-header { + background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%); + padding: 1rem 1.5rem; + border-bottom: 1px solid #334155; + margin-bottom: 1rem; +} + +.app-header h1 { + color: #f1f5f9; + font-size: 1.5rem; + font-weight: 600; + margin: 0; +} + +/* Tab navigation styling */ +.tabs .tab-nav { + background: #1e293b; + border-radius: 8px 8px 0 0; + padding: 0.5rem 0.5rem 0; +} + +.tabs .tab-nav button { + padding: 0.75rem 1.5rem; + font-weight: 500; + border-radius: 8px 8px 0 0; + transition: all 0.2s ease; +} + +.tabs .tab-nav button.selected { + background: #334155; + border-bottom: 2px solid #3b82f6; + color: #60a5fa; +} + +/* Image comparison slider */ +.image-slider-container { + border-radius: 12px; + overflow: hidden; + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4); +} + +/* Progress bar styling */ +.progress-bar { + height: 8px; + border-radius: 4px; + background: #1e293b; + overflow: hidden; +} + +.progress-bar > div { + background: linear-gradient(90deg, #2563eb 0%, #3b82f6 50%, #60a5fa 100%); + transition: width 0.3s ease; +} + +/* Gallery items */ +.gallery-item { + border-radius: 8px; + overflow: hidden; + transition: transform 0.2s ease, box-shadow 0.2s ease; +} + +.gallery-item:hover { + transform: scale(1.02); + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.5); +} + +/* Accordion styling */ +.accordion { + border: 1px solid #334155; + border-radius: 8px; + overflow: hidden; +} + +.accordion > .label-wrap { + background: #1e293b; + padding: 0.75rem 1rem; +} + +/* Status bar */ +.status-bar { + position: fixed; + bottom: 0; + left: 0; + right: 0; + background: #0f172a; + border-top: 1px solid #334155; + padding: 0.5rem 1.5rem; + font-size: 0.875rem; + color: #94a3b8; + z-index: 100; + display: flex; + justify-content: space-between; + align-items: center; +} + +.status-bar .status-item { + display: flex; + align-items: center; + gap: 0.5rem; +} + +.status-bar .status-dot { + width: 8px; + height: 8px; + border-radius: 50%; + background: #22c55e; +} + +.status-bar .status-dot.busy { + background: #eab308; +} + +/* Large primary buttons */ +.primary-action-btn { + font-size: 1.125rem !important; + padding: 0.875rem 2rem !important; + font-weight: 600 !important; +} + +/* Upload area */ +.upload-area { + border: 2px dashed #334155; + border-radius: 12px; + transition: border-color 0.2s ease; +} + +.upload-area:hover { + border-color: #3b82f6; +} + +/* Processing info cards */ +.info-card { + background: #1e293b; + border: 1px solid #334155; + border-radius: 8px; + padding: 1rem; +} + +.info-card h4 { + color: #94a3b8; + font-size: 0.75rem; + text-transform: uppercase; + letter-spacing: 0.05em; + margin-bottom: 0.25rem; +} + +.info-card .value { + color: #f1f5f9; + font-size: 1.25rem; + font-weight: 600; +} + +/* Responsive adjustments */ +@media (max-width: 1200px) { + .main-row { + flex-direction: column; + } +} + +/* Add padding at bottom for status bar */ +.gradio-container { + padding-bottom: 60px !important; +} +""" + + +def get_theme(): + """Get the 
custom theme instance.""" + return UpscalerTheme() + + +def get_css(): + """Get custom CSS for additional styling.""" + return CUSTOM_CSS diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e1ad603 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1 @@ +"""Utility modules (memory management, validators).""" diff --git a/src/video/__init__.py b/src/video/__init__.py new file mode 100644 index 0000000..d077dcd --- /dev/null +++ b/src/video/__init__.py @@ -0,0 +1 @@ +"""Video processing utilities (extraction, encoding, audio).""" diff --git a/src/video/audio.py b/src/video/audio.py new file mode 100644 index 0000000..39ef5d9 --- /dev/null +++ b/src/video/audio.py @@ -0,0 +1,221 @@ +"""Audio extraction and handling for video processing.""" + +import logging +import subprocess +import shutil +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + + +class AudioError(Exception): + """Error during audio processing.""" + + pass + + +def check_ffmpeg(): + """Check if FFmpeg is available.""" + if shutil.which("ffmpeg") is None: + raise AudioError( + "FFmpeg not found. Please install FFmpeg: sudo apt install ffmpeg" + ) + + +def has_audio_stream(video_path: Path) -> bool: + """ + Check if video has an audio stream. + + Args: + video_path: Path to video file + + Returns: + True if video has audio + """ + check_ffmpeg() + + try: + result = subprocess.run( + [ + "ffprobe", + "-v", "error", + "-select_streams", "a", + "-show_entries", "stream=codec_name", + "-of", "csv=p=0", + str(video_path), + ], + capture_output=True, + text=True, + ) + return bool(result.stdout.strip()) + except Exception: + return False + + +def extract_audio( + video_path: Path, + output_path: Path, + copy_codec: bool = True, +) -> Optional[Path]: + """ + Extract audio stream from video. 
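
    With copy_codec=True the audio stream is copied without re-encoding and the
    output extension is derived from the source codec; otherwise the audio is
    re-encoded to AAC at 320 kb/s.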
+ + Args: + video_path: Source video path + output_path: Output audio path + copy_codec: If True, copy without re-encoding + + Returns: + Path to extracted audio, or None if no audio + """ + check_ffmpeg() + + if not has_audio_stream(video_path): + logger.info(f"No audio stream in: {video_path}") + return None + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + if copy_codec: + # Determine output extension based on source codec + result = subprocess.run( + [ + "ffprobe", + "-v", "error", + "-select_streams", "a:0", + "-show_entries", "stream=codec_name", + "-of", "csv=p=0", + str(video_path), + ], + capture_output=True, + text=True, + ) + + codec = result.stdout.strip().lower() + + # Map codec to extension + ext_map = { + "aac": ".aac", + "mp3": ".mp3", + "opus": ".opus", + "vorbis": ".ogg", + "flac": ".flac", + "pcm_s16le": ".wav", + "pcm_s24le": ".wav", + } + ext = ext_map.get(codec, ".aac") + output_path = output_path.with_suffix(ext) + + cmd = [ + "ffmpeg", "-y", + "-i", str(video_path), + "-vn", # No video + "-acodec", "copy", + str(output_path), + ] + else: + # Re-encode to AAC + output_path = output_path.with_suffix(".aac") + cmd = [ + "ffmpeg", "-y", + "-i", str(video_path), + "-vn", + "-acodec", "aac", + "-b:a", "320k", + str(output_path), + ] + + try: + subprocess.run(cmd, capture_output=True, check=True) + logger.info(f"Extracted audio to: {output_path}") + return output_path + except subprocess.CalledProcessError as e: + logger.error(f"Audio extraction failed: {e.stderr}") + raise AudioError(f"Audio extraction failed: {e.stderr}") from e + + +def mux_audio_video( + video_path: Path, + audio_path: Path, + output_path: Path, + copy_streams: bool = True, +) -> Path: + """ + Combine video and audio streams. + + Args: + video_path: Video-only file + audio_path: Audio file + output_path: Output path + copy_streams: If True, copy without re-encoding + + Returns: + Path to muxed video + """ + check_ffmpeg() + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + if copy_streams: + cmd = [ + "ffmpeg", "-y", + "-i", str(video_path), + "-i", str(audio_path), + "-c:v", "copy", + "-c:a", "copy", + "-map", "0:v:0", + "-map", "1:a:0", + str(output_path), + ] + else: + cmd = [ + "ffmpeg", "-y", + "-i", str(video_path), + "-i", str(audio_path), + "-c:v", "copy", + "-c:a", "aac", + "-b:a", "320k", + "-map", "0:v:0", + "-map", "1:a:0", + str(output_path), + ] + + try: + subprocess.run(cmd, capture_output=True, check=True) + logger.info(f"Muxed audio/video to: {output_path}") + return output_path + except subprocess.CalledProcessError as e: + logger.error(f"Muxing failed: {e.stderr}") + raise AudioError(f"Audio/video muxing failed: {e.stderr}") from e + + +def get_audio_duration(audio_path: Path) -> float: + """ + Get duration of audio file in seconds. 
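
    Probes the file with ffprobe; returns 0.0 if the duration cannot be read.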
+ + Args: + audio_path: Path to audio file + + Returns: + Duration in seconds + """ + check_ffmpeg() + + try: + result = subprocess.run( + [ + "ffprobe", + "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(audio_path), + ], + capture_output=True, + text=True, + ) + return float(result.stdout.strip()) + except Exception: + return 0.0 diff --git a/src/video/checkpoint.py b/src/video/checkpoint.py new file mode 100644 index 0000000..41d8244 --- /dev/null +++ b/src/video/checkpoint.py @@ -0,0 +1,250 @@ +"""Checkpoint system for resumable video processing.""" + +import hashlib +import json +import logging +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Optional + +from src.config import CHECKPOINTS_DIR + +logger = logging.getLogger(__name__) + + +@dataclass +class VideoCheckpoint: + """Checkpoint data for resumable video processing.""" + + job_id: str + input_path: str + input_hash: str # MD5 of first 1MB for verification + total_frames: int + processed_frames: int + last_processed_frame: int + frames_dir: str + audio_path: Optional[str] + config_json: str + created_at: str + updated_at: str + + @property + def progress_percent(self) -> float: + """Get progress as percentage.""" + if self.total_frames == 0: + return 0.0 + return (self.processed_frames / self.total_frames) * 100 + + @property + def is_complete(self) -> bool: + """Check if processing is complete.""" + return self.processed_frames >= self.total_frames + + +class CheckpointManager: + """ + Manage video processing checkpoints for resume capability. + + Features: + - Save checkpoint every N frames + - Verify input file hasn't changed + - Clean up after completion + """ + + def __init__( + self, + checkpoint_dir: Optional[Path] = None, + save_interval: int = 100, + ): + self.checkpoint_dir = checkpoint_dir or CHECKPOINTS_DIR + self.save_interval = save_interval + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + def get_checkpoint_path(self, job_id: str) -> Path: + """Get path to checkpoint file.""" + return self.checkpoint_dir / f"{job_id}.checkpoint.json" + + def create_checkpoint( + self, + job_id: str, + input_path: Path, + total_frames: int, + frames_dir: Path, + audio_path: Optional[Path], + config: dict, + ) -> VideoCheckpoint: + """ + Create a new checkpoint for a video job. + + Args: + job_id: Unique job identifier + input_path: Path to input video + total_frames: Total frames to process + frames_dir: Directory for processed frames + audio_path: Path to extracted audio (if any) + config: Processing configuration dict + + Returns: + New VideoCheckpoint instance + """ + now = datetime.utcnow().isoformat() + + checkpoint = VideoCheckpoint( + job_id=job_id, + input_path=str(input_path), + input_hash=self._compute_file_hash(input_path), + total_frames=total_frames, + processed_frames=0, + last_processed_frame=-1, + frames_dir=str(frames_dir), + audio_path=str(audio_path) if audio_path else None, + config_json=json.dumps(config), + created_at=now, + updated_at=now, + ) + + self._save_checkpoint(checkpoint) + logger.info(f"Created checkpoint for job: {job_id}") + + return checkpoint + + def update_checkpoint( + self, + checkpoint: VideoCheckpoint, + processed_frame: int, + force_save: bool = False, + ) -> None: + """ + Update checkpoint with progress. 
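
    Increments the processed-frame counter and writes the checkpoint to disk
    every `save_interval` frames, or immediately when force_save is True.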
+ + Args: + checkpoint: Checkpoint to update + processed_frame: Frame index that was just processed + force_save: Force save even if not at interval + """ + checkpoint.processed_frames += 1 + checkpoint.last_processed_frame = processed_frame + checkpoint.updated_at = datetime.utcnow().isoformat() + + # Save periodically or when forced + if force_save or checkpoint.processed_frames % self.save_interval == 0: + self._save_checkpoint(checkpoint) + logger.debug( + f"Checkpoint saved: {checkpoint.processed_frames}/{checkpoint.total_frames}" + ) + + def load_checkpoint(self, job_id: str) -> Optional[VideoCheckpoint]: + """ + Load existing checkpoint if available. + + Args: + job_id: Job identifier + + Returns: + VideoCheckpoint if valid, None otherwise + """ + path = self.get_checkpoint_path(job_id) + + if not path.exists(): + return None + + try: + with open(path, "r") as f: + data = json.load(f) + + checkpoint = VideoCheckpoint(**data) + + # Verify input file hasn't changed + if Path(checkpoint.input_path).exists(): + current_hash = self._compute_file_hash(Path(checkpoint.input_path)) + if current_hash != checkpoint.input_hash: + logger.warning("Input file changed since checkpoint, starting fresh") + self.delete_checkpoint(job_id) + return None + + # Verify frames directory exists + if not Path(checkpoint.frames_dir).exists(): + logger.warning("Frames directory missing, starting fresh") + self.delete_checkpoint(job_id) + return None + + logger.info( + f"Loaded checkpoint: {checkpoint.processed_frames}/{checkpoint.total_frames} frames" + ) + return checkpoint + + except Exception as e: + logger.error(f"Failed to load checkpoint: {e}") + self.delete_checkpoint(job_id) + return None + + def delete_checkpoint(self, job_id: str) -> None: + """Delete checkpoint file.""" + path = self.get_checkpoint_path(job_id) + if path.exists(): + path.unlink() + logger.info(f"Deleted checkpoint: {job_id}") + + def _save_checkpoint(self, checkpoint: VideoCheckpoint) -> None: + """Save checkpoint to disk.""" + path = self.get_checkpoint_path(checkpoint.job_id) + with open(path, "w") as f: + json.dump(asdict(checkpoint), f, indent=2) + + def _compute_file_hash( + self, + path: Path, + chunk_size: int = 1024 * 1024, + ) -> str: + """ + Compute MD5 hash of first chunk of file. + + Using only the first 1MB is fast while still detecting + if the file has been modified. + """ + hasher = hashlib.md5() + with open(path, "rb") as f: + chunk = f.read(chunk_size) + hasher.update(chunk) + return hasher.hexdigest() + + def get_all_checkpoints(self) -> list[VideoCheckpoint]: + """Get all existing checkpoints.""" + checkpoints = [] + for path in self.checkpoint_dir.glob("*.checkpoint.json"): + job_id = path.stem.replace(".checkpoint", "") + checkpoint = self.load_checkpoint(job_id) + if checkpoint: + checkpoints.append(checkpoint) + return checkpoints + + def cleanup_old_checkpoints(self, max_age_days: int = 7) -> int: + """ + Clean up old checkpoint files. 
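
    Age is determined from each checkpoint file's modification time.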
+ + Args: + max_age_days: Delete checkpoints older than this + + Returns: + Number of checkpoints deleted + """ + import time + + deleted = 0 + max_age_seconds = max_age_days * 24 * 60 * 60 + now = time.time() + + for path in self.checkpoint_dir.glob("*.checkpoint.json"): + if now - path.stat().st_mtime > max_age_seconds: + path.unlink() + deleted += 1 + + if deleted: + logger.info(f"Cleaned up {deleted} old checkpoints") + + return deleted + + +# Global checkpoint manager +checkpoint_manager = CheckpointManager() diff --git a/src/video/encoder.py b/src/video/encoder.py new file mode 100644 index 0000000..03a5eaa --- /dev/null +++ b/src/video/encoder.py @@ -0,0 +1,274 @@ +"""Video encoding using FFmpeg.""" + +import logging +import subprocess +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from src.config import VIDEO_CODECS + +logger = logging.getLogger(__name__) + + +@dataclass +class EncodingResult: + """Result of video encoding.""" + + output_path: Path + codec: str + duration_seconds: float + file_size_bytes: int + + +class VideoEncoderError(Exception): + """Error during video encoding.""" + + pass + + +def check_ffmpeg(): + """Check if FFmpeg is available.""" + if shutil.which("ffmpeg") is None: + raise VideoEncoderError( + "FFmpeg not found. Please install FFmpeg: sudo apt install ffmpeg" + ) + + +def check_nvenc(): + """Check if NVENC is available.""" + try: + result = subprocess.run( + ["ffmpeg", "-hide_banner", "-encoders"], + capture_output=True, + text=True, + ) + return "h264_nvenc" in result.stdout + except Exception: + return False + + +class VideoEncoder: + """ + Encode video frames using FFmpeg. + + Supports multiple codecs: + - H.264 (libx264 / h264_nvenc) + - H.265 (libx265 / hevc_nvenc) + - AV1 (libsvtav1 / av1_nvenc) + """ + + def __init__(self, use_nvenc: bool = True): + check_ffmpeg() + self.use_nvenc = use_nvenc and check_nvenc() + + if self.use_nvenc: + logger.info("NVENC hardware encoding available") + else: + logger.info("Using software encoding") + + def encode_frames( + self, + frames_dir: Path, + output_path: Path, + fps: float, + codec: str = "H.265", + crf: int = 20, + preset: str = "slow", + audio_path: Optional[Path] = None, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> EncodingResult: + """ + Encode frames to video. 
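
    Frames are read from frames_dir using the frame_%08d.png naming pattern.
    NVENC hardware encoders are selected automatically when available and
    enabled; if audio_path is given, the audio is muxed in as AAC at 320 kb/s.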
+ + Args: + frames_dir: Directory containing numbered frame images + output_path: Output video path + fps: Frame rate + codec: Codec name ("H.264", "H.265", "AV1") + crf: Quality (lower = better, 0-51) + preset: Speed preset + audio_path: Optional audio file to mux + width: Optional output width (for scaling) + height: Optional output height (for scaling) + + Returns: + EncodingResult with output info + """ + if codec not in VIDEO_CODECS: + raise VideoEncoderError(f"Unknown codec: {codec}") + + codec_config = VIDEO_CODECS[codec] + + # Build FFmpeg command + cmd = self._build_command( + frames_dir=frames_dir, + output_path=output_path, + fps=fps, + codec_config=codec_config, + crf=crf, + preset=preset, + audio_path=audio_path, + width=width, + height=height, + ) + + logger.info(f"Encoding video: {' '.join(cmd)}") + + # Run FFmpeg + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + logger.error(f"FFmpeg error: {e.stderr}") + raise VideoEncoderError(f"Encoding failed: {e.stderr}") from e + + # Get output info + output_path = Path(output_path) + file_size = output_path.stat().st_size if output_path.exists() else 0 + + return EncodingResult( + output_path=output_path, + codec=codec, + duration_seconds=0, # TODO: probe output + file_size_bytes=file_size, + ) + + def _build_command( + self, + frames_dir: Path, + output_path: Path, + fps: float, + codec_config, + crf: int, + preset: str, + audio_path: Optional[Path], + width: Optional[int], + height: Optional[int], + ) -> list[str]: + """Build FFmpeg command.""" + # Frame input pattern (expects frame_%08d.png format) + frame_pattern = str(frames_dir / "frame_%08d.png") + + cmd = [ + "ffmpeg", + "-y", # Overwrite output + "-framerate", str(fps), + "-i", frame_pattern, + ] + + # Add audio input if provided + if audio_path and audio_path.exists(): + cmd.extend(["-i", str(audio_path)]) + + # Video filter for scaling if needed + vf_filters = [] + if width and height: + # Ensure even dimensions + w = width if width % 2 == 0 else width + 1 + h = height if height % 2 == 0 else height + 1 + vf_filters.append(f"scale={w}:{h}:flags=lanczos") + + if vf_filters: + cmd.extend(["-vf", ",".join(vf_filters)]) + + # Codec settings + cmd.extend(self._get_codec_args(codec_config, crf, preset)) + + # Pixel format + cmd.extend(["-pix_fmt", "yuv420p"]) + + # Audio settings + if audio_path and audio_path.exists(): + cmd.extend(["-c:a", "aac", "-b:a", "320k"]) + + # Output + cmd.append(str(output_path)) + + return cmd + + def _get_codec_args(self, codec_config, crf: int, preset: str) -> list[str]: + """Get codec-specific FFmpeg arguments.""" + # Use NVENC if available and configured + if self.use_nvenc and codec_config.nvenc_encoder: + encoder = codec_config.nvenc_encoder + + if "nvenc" in encoder: + return [ + "-c:v", encoder, + "-preset", "p7", # Highest quality NVENC preset + "-tune", "hq", + "-rc", "vbr", + "-cq", str(crf), + "-b:v", "0", + ] + + # Software encoding + encoder = codec_config.encoder + + if encoder == "libx264": + return [ + "-c:v", "libx264", + "-preset", preset, + "-crf", str(crf), + "-tune", "film", + ] + elif encoder == "libx265": + return [ + "-c:v", "libx265", + "-preset", preset, + "-crf", str(crf), + "-x265-params", "aq-mode=3", + ] + elif encoder == "libsvtav1": + return [ + "-c:v", "libsvtav1", + "-preset", preset, + "-crf", str(crf), + "-svtav1-params", "tune=0:film-grain=8", + ] + else: + return ["-c:v", encoder, "-crf", str(crf)] + + +def encode_video( 
+ frames_dir: Path, + output_path: Path, + fps: float, + codec: str = "H.265", + crf: int = 20, + preset: str = "slow", + audio_path: Optional[Path] = None, +) -> Path: + """ + Convenience function to encode video. + + Args: + frames_dir: Directory with numbered frames + output_path: Output video path + fps: Frame rate + codec: Video codec + crf: Quality (lower = better) + preset: Encoding preset + audio_path: Optional audio to mux + + Returns: + Path to encoded video + """ + encoder = VideoEncoder() + result = encoder.encode_frames( + frames_dir=frames_dir, + output_path=output_path, + fps=fps, + codec=codec, + crf=crf, + preset=preset, + audio_path=audio_path, + ) + return result.output_path diff --git a/src/video/extractor.py b/src/video/extractor.py new file mode 100644 index 0000000..3148abc --- /dev/null +++ b/src/video/extractor.py @@ -0,0 +1,201 @@ +"""Video frame extraction using PyAV.""" + +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Iterator, Optional + +import av +import numpy as np + +logger = logging.getLogger(__name__) + + +@dataclass +class VideoMetadata: + """Video file metadata.""" + + width: int + height: int + fps: float + total_frames: int + duration_seconds: float + codec: str + has_audio: bool + audio_codec: Optional[str] = None + audio_sample_rate: Optional[int] = None + + +class VideoExtractor: + """ + Extract frames from video files using PyAV. + + Features: + - High-performance frame extraction + - Support for various video formats + - Frame-accurate seeking for resume + """ + + def __init__(self, video_path: Path): + self.video_path = Path(video_path) + self.container = None + self.video_stream = None + self._metadata: Optional[VideoMetadata] = None + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def open(self): + """Open video file.""" + if not self.video_path.exists(): + raise FileNotFoundError(f"Video not found: {self.video_path}") + + self.container = av.open(str(self.video_path)) + self.video_stream = self.container.streams.video[0] + + # Enable threading for faster decoding + self.video_stream.thread_type = "AUTO" + + logger.info(f"Opened video: {self.video_path}") + + def close(self): + """Close video file.""" + if self.container: + self.container.close() + self.container = None + self.video_stream = None + + @property + def metadata(self) -> VideoMetadata: + """Get video metadata.""" + if self._metadata is not None: + return self._metadata + + if self.video_stream is None: + raise RuntimeError("Video not opened") + + video = self.video_stream + + # Calculate total frames + if video.frames > 0: + total_frames = video.frames + elif video.duration and video.time_base: + # Estimate from duration + duration_sec = float(video.duration * video.time_base) + total_frames = int(duration_sec * float(video.average_rate)) + else: + total_frames = 0 + + # Duration + if video.duration and video.time_base: + duration_seconds = float(video.duration * video.time_base) + else: + duration_seconds = 0.0 + + # Check for audio + has_audio = len(self.container.streams.audio) > 0 + audio_codec = None + audio_sample_rate = None + + if has_audio: + audio_stream = self.container.streams.audio[0] + audio_codec = audio_stream.codec.name + audio_sample_rate = audio_stream.sample_rate + + self._metadata = VideoMetadata( + width=video.width, + height=video.height, + fps=float(video.average_rate) if video.average_rate else 30.0, + total_frames=total_frames, + 
duration_seconds=duration_seconds, + codec=video.codec.name, + has_audio=has_audio, + audio_codec=audio_codec, + audio_sample_rate=audio_sample_rate, + ) + + return self._metadata + + def extract_frames( + self, + start_frame: int = 0, + end_frame: Optional[int] = None, + ) -> Iterator[tuple[int, np.ndarray]]: + """ + Extract frames from video. + + Args: + start_frame: Starting frame index (for resume) + end_frame: Ending frame index (exclusive, None for all) + + Yields: + Tuple of (frame_index, frame_array) where frame_array is RGB + """ + if self.container is None: + raise RuntimeError("Video not opened") + + frame_index = 0 + + # Seek to start position if needed + if start_frame > 0: + self._seek_to_frame(start_frame) + frame_index = start_frame + + for frame in self.container.decode(video=0): + if end_frame is not None and frame_index >= end_frame: + break + + # Convert to RGB numpy array + img = frame.to_ndarray(format="rgb24") + + yield frame_index, img + frame_index += 1 + + def _seek_to_frame(self, frame_index: int): + """Seek to approximate frame position.""" + if self.video_stream is None: + return + + # Calculate timestamp + fps = float(self.video_stream.average_rate) if self.video_stream.average_rate else 30.0 + time_base = self.video_stream.time_base + target_ts = int(frame_index / fps / time_base) + + # Seek to keyframe before target + self.container.seek(target_ts, stream=self.video_stream) + + def get_frame(self, frame_index: int) -> Optional[np.ndarray]: + """Get a specific frame by index.""" + self._seek_to_frame(frame_index) + + for frame in self.container.decode(video=0): + return frame.to_ndarray(format="rgb24") + + return None + + def extract_keyframes(self) -> Iterator[tuple[int, np.ndarray]]: + """Extract only keyframes (for preview generation).""" + if self.container is None: + raise RuntimeError("Video not opened") + + # Set codec to skip non-key frames + self.video_stream.codec_context.skip_frame = "NONKEY" + + frame_index = 0 + for frame in self.container.decode(video=0): + img = frame.to_ndarray(format="rgb24") + yield frame_index, img + frame_index += 1 + + # Reset skip mode + self.video_stream.codec_context.skip_frame = "DEFAULT" + + +def get_video_info(video_path: Path) -> VideoMetadata: + """Quick function to get video metadata.""" + with VideoExtractor(video_path) as extractor: + return extractor.metadata
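# Editor's end-to-end sketch (not part of the original patch): how the video
# modules above compose outside the Gradio UI: extract frames, upscale each
# one, then encode. It mirrors process_video() in src/ui/components/video_tab.py;
# the function name, paths and model settings are placeholder assumptions.
import cv2
from pathlib import Path

from src.processing.upscaler import upscale_image
from src.video.audio import extract_audio, has_audio_stream
from src.video.encoder import encode_video
from src.video.extractor import VideoExtractor, get_video_info


def upscale_video_headless(input_path: Path, output_path: Path, work_dir: Path) -> Path:
    """Minimal headless pipeline: extract frames, upscale, encode, keep audio."""
    meta = get_video_info(input_path)
    frames_dir = work_dir / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)

    # Pull the original audio track (if any) so it can be muxed back in.
    audio = extract_audio(input_path, work_dir / "audio") if has_audio_stream(input_path) else None

    with VideoExtractor(input_path) as extractor:
        for idx, frame_rgb in extractor.extract_frames():
            # Mirror the video tab: frames are yielded as RGB, so convert to BGR
            # before upscaling and write the BGR result with OpenCV.
            frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
            result = upscale_image(frame_bgr, model_name="RealESRGAN_x4plus", scale=4)
            cv2.imwrite(str(frames_dir / f"frame_{idx:08d}.png"), result.image)

    return encode_video(
        frames_dir=frames_dir,
        output_path=output_path,
        fps=meta.fps,
        codec="H.265",
        crf=20,
        preset="slow",
        audio_path=audio,
    )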