feat: initial implementation of Real-ESRGAN Web UI
Full-featured Gradio 6.0+ web interface for Real-ESRGAN image/video upscaling, optimized for RTX 4090 (24GB VRAM). Features: - Image upscaling with before/after comparison (ImageSlider) - Video upscaling with progress tracking and checkpoint/resume - Face enhancement via GFPGAN integration - Multiple codecs: H.264, H.265, AV1 (with NVENC support) - Batch processing queue with SQLite persistence - Processing history gallery - Custom dark theme - Auto-download of model weights 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
29
app.py
Normal file
29
app.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Real-ESRGAN Web UI - Main Entry Point
|
||||||
|
|
||||||
|
A full-featured web interface for Real-ESRGAN image and video upscaling,
|
||||||
|
optimized for RTX 4090 with 24GB VRAM.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python app.py
|
||||||
|
|
||||||
|
The server will start at http://localhost:7860
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add src to path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from src.ui.app import launch_app
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point."""
|
||||||
|
launch_app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
18
pyproject.toml
Normal file
18
pyproject.toml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
[project]
|
||||||
|
name = "upscale-ui"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Real-ESRGAN Web UI for image and video upscaling"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
license = {text = "MIT"}
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
packages = ["src"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py310"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "W", "I"]
|
||||||
|
ignore = ["E501"]
|
||||||
30
requirements.txt
Normal file
30
requirements.txt
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Real-ESRGAN Web UI Dependencies
|
||||||
|
|
||||||
|
# Gradio UI Framework
|
||||||
|
gradio>=6.0.0
|
||||||
|
gradio-imageslider>=0.0.20
|
||||||
|
|
||||||
|
# PyTorch (CUDA 12.1)
|
||||||
|
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||||
|
torch>=2.0.0
|
||||||
|
torchvision>=0.15.0
|
||||||
|
|
||||||
|
# Real-ESRGAN and dependencies
|
||||||
|
realesrgan>=0.3.0
|
||||||
|
basicsr>=1.4.2
|
||||||
|
|
||||||
|
# Face Enhancement
|
||||||
|
gfpgan>=1.3.8
|
||||||
|
facexlib>=0.3.0
|
||||||
|
|
||||||
|
# Video Processing
|
||||||
|
av>=10.0.0
|
||||||
|
|
||||||
|
# Image Processing
|
||||||
|
opencv-python>=4.8.0
|
||||||
|
Pillow>=10.0.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
tqdm>=4.65.0
|
||||||
|
requests>=2.28.0
|
||||||
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Real-ESRGAN Web UI - Image and Video Upscaling."""
|
||||||
222
src/config.py
Normal file
222
src/config.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""Configuration settings for Real-ESRGAN Web UI."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# Base paths
|
||||||
|
BASE_DIR = Path(__file__).parent.parent
|
||||||
|
DATA_DIR = BASE_DIR / "data"
|
||||||
|
MODELS_DIR = DATA_DIR / "models"
|
||||||
|
TEMP_DIR = DATA_DIR / "temp"
|
||||||
|
CHECKPOINTS_DIR = DATA_DIR / "checkpoints"
|
||||||
|
OUTPUT_DIR = DATA_DIR / "output"
|
||||||
|
DATABASE_PATH = DATA_DIR / "upscale.db"
|
||||||
|
|
||||||
|
# Ensure directories exist
|
||||||
|
for dir_path in [MODELS_DIR, TEMP_DIR, CHECKPOINTS_DIR, OUTPUT_DIR]:
|
||||||
|
dir_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelInfo:
|
||||||
|
"""Information about a Real-ESRGAN or GFPGAN model."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
scale: int
|
||||||
|
filename: str
|
||||||
|
url: str
|
||||||
|
size_mb: float
|
||||||
|
description: str
|
||||||
|
model_type: str = "realesrgan" # "realesrgan" or "gfpgan"
|
||||||
|
netscale: int = 4 # Network scale factor
|
||||||
|
num_block: int = 23 # Number of RRDB blocks (6 for anime models)
|
||||||
|
num_grow_ch: int = 32 # Growth channels
|
||||||
|
|
||||||
|
|
||||||
|
# Model registry with download URLs
|
||||||
|
MODEL_REGISTRY: dict[str, ModelInfo] = {
|
||||||
|
"RealESRGAN_x4plus": ModelInfo(
|
||||||
|
name="RealESRGAN_x4plus",
|
||||||
|
scale=4,
|
||||||
|
filename="RealESRGAN_x4plus.pth",
|
||||||
|
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
|
size_mb=63.9,
|
||||||
|
description="Best quality for general photos (4x)",
|
||||||
|
num_block=23,
|
||||||
|
),
|
||||||
|
"RealESRGAN_x2plus": ModelInfo(
|
||||||
|
name="RealESRGAN_x2plus",
|
||||||
|
scale=2,
|
||||||
|
filename="RealESRGAN_x2plus.pth",
|
||||||
|
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
|
size_mb=63.9,
|
||||||
|
description="General photos (2x upscaling)",
|
||||||
|
netscale=2,
|
||||||
|
num_block=23,
|
||||||
|
),
|
||||||
|
"RealESRGAN_x4plus_anime_6B": ModelInfo(
|
||||||
|
name="RealESRGAN_x4plus_anime_6B",
|
||||||
|
scale=4,
|
||||||
|
filename="RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
size_mb=17.0,
|
||||||
|
description="Optimized for anime/illustrations (4x)",
|
||||||
|
num_block=6,
|
||||||
|
),
|
||||||
|
"realesr-animevideov3": ModelInfo(
|
||||||
|
name="realesr-animevideov3",
|
||||||
|
scale=4,
|
||||||
|
filename="realesr-animevideov3.pth",
|
||||||
|
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||||
|
size_mb=17.0,
|
||||||
|
description="Anime video with temporal consistency (4x)",
|
||||||
|
num_block=6,
|
||||||
|
),
|
||||||
|
"realesr-general-x4v3": ModelInfo(
|
||||||
|
name="realesr-general-x4v3",
|
||||||
|
scale=4,
|
||||||
|
filename="realesr-general-x4v3.pth",
|
||||||
|
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||||
|
size_mb=63.9,
|
||||||
|
description="General purpose with denoise control (4x)",
|
||||||
|
num_block=23,
|
||||||
|
),
|
||||||
|
"GFPGANv1.4": ModelInfo(
|
||||||
|
name="GFPGANv1.4",
|
||||||
|
scale=1,
|
||||||
|
filename="GFPGANv1.4.pth",
|
||||||
|
url="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth",
|
||||||
|
size_mb=332.0,
|
||||||
|
description="Face enhancement and restoration",
|
||||||
|
model_type="gfpgan",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RTX4090Config:
|
||||||
|
"""RTX 4090 optimized settings (24GB VRAM)."""
|
||||||
|
|
||||||
|
tile_size: int = 0 # 0 = no tiling, use full image
|
||||||
|
tile_pad: int = 10 # Padding for tile boundaries
|
||||||
|
pre_pad: int = 0 # Pre-padding for input
|
||||||
|
half: bool = True # FP16 precision
|
||||||
|
device: str = "cuda"
|
||||||
|
gpu_id: int = 0
|
||||||
|
|
||||||
|
# Thresholds for when to enable tiling
|
||||||
|
max_pixels_no_tile: int = 8294400 # ~4K (3840x2160)
|
||||||
|
|
||||||
|
def should_tile(self, width: int, height: int) -> bool:
|
||||||
|
"""Determine if tiling is needed based on image size."""
|
||||||
|
return width * height > self.max_pixels_no_tile
|
||||||
|
|
||||||
|
def get_tile_size(self, width: int, height: int) -> int:
|
||||||
|
"""Get appropriate tile size for the image."""
|
||||||
|
if not self.should_tile(width, height):
|
||||||
|
return 0 # No tiling
|
||||||
|
# For very large images, use 512px tiles
|
||||||
|
return 512
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoCodecConfig:
|
||||||
|
"""Video encoding configuration."""
|
||||||
|
|
||||||
|
encoder: str
|
||||||
|
crf: int
|
||||||
|
preset: str
|
||||||
|
nvenc_encoder: Optional[str] = None
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# Video codec presets (maximum quality settings)
|
||||||
|
VIDEO_CODECS: dict[str, VideoCodecConfig] = {
|
||||||
|
"H.264": VideoCodecConfig(
|
||||||
|
encoder="libx264",
|
||||||
|
crf=18,
|
||||||
|
preset="slow",
|
||||||
|
nvenc_encoder="h264_nvenc",
|
||||||
|
description="Universal compatibility, excellent quality",
|
||||||
|
),
|
||||||
|
"H.265": VideoCodecConfig(
|
||||||
|
encoder="libx265",
|
||||||
|
crf=20,
|
||||||
|
preset="slow",
|
||||||
|
nvenc_encoder="hevc_nvenc",
|
||||||
|
description="50% smaller files, great quality",
|
||||||
|
),
|
||||||
|
"AV1": VideoCodecConfig(
|
||||||
|
encoder="libsvtav1",
|
||||||
|
crf=23,
|
||||||
|
preset="4",
|
||||||
|
nvenc_encoder="av1_nvenc",
|
||||||
|
description="Best compression, future-proof",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AppConfig:
|
||||||
|
"""Main application configuration."""
|
||||||
|
|
||||||
|
# Server settings
|
||||||
|
server_name: str = "0.0.0.0"
|
||||||
|
server_port: int = 7860
|
||||||
|
share: bool = False
|
||||||
|
|
||||||
|
# Queue settings
|
||||||
|
max_queue_size: int = 20
|
||||||
|
concurrency_limit: int = 1 # Single job at a time for GPU efficiency
|
||||||
|
|
||||||
|
# Processing defaults
|
||||||
|
default_model: str = "RealESRGAN_x4plus"
|
||||||
|
default_scale: int = 4
|
||||||
|
default_face_enhance: bool = False
|
||||||
|
default_output_format: str = "png"
|
||||||
|
|
||||||
|
# Video defaults
|
||||||
|
default_video_codec: str = "H.265"
|
||||||
|
default_video_crf: int = 20
|
||||||
|
default_video_preset: str = "slow"
|
||||||
|
|
||||||
|
# History settings
|
||||||
|
max_history_items: int = 1000
|
||||||
|
thumbnail_size: tuple[int, int] = (256, 256)
|
||||||
|
|
||||||
|
# Checkpoint settings
|
||||||
|
checkpoint_interval: int = 100 # Save every N frames
|
||||||
|
|
||||||
|
# GPU config
|
||||||
|
gpu: RTX4090Config = field(default_factory=RTX4090Config)
|
||||||
|
|
||||||
|
|
||||||
|
# Global config instance
|
||||||
|
config = AppConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(model_name: str) -> Path:
|
||||||
|
"""Get the path to a model file."""
|
||||||
|
if model_name not in MODEL_REGISTRY:
|
||||||
|
raise ValueError(f"Unknown model: {model_name}")
|
||||||
|
return MODELS_DIR / MODEL_REGISTRY[model_name].filename
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_models() -> list[str]:
|
||||||
|
"""Get list of available model names for UI dropdown."""
|
||||||
|
return [
|
||||||
|
name
|
||||||
|
for name, info in MODEL_REGISTRY.items()
|
||||||
|
if info.model_type == "realesrgan"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_choices() -> list[tuple[str, str]]:
|
||||||
|
"""Get model choices for Gradio dropdown (label, value)."""
|
||||||
|
return [
|
||||||
|
(f"{info.description}", name)
|
||||||
|
for name, info in MODEL_REGISTRY.items()
|
||||||
|
if info.model_type == "realesrgan"
|
||||||
|
]
|
||||||
1
src/processing/__init__.py
Normal file
1
src/processing/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Image and video processing modules."""
|
||||||
213
src/processing/face_enhancer.py
Normal file
213
src/processing/face_enhancer.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
"""Face enhancement using GFPGAN."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.config import MODEL_REGISTRY, MODELS_DIR, config
|
||||||
|
from src.processing.models import ensure_model_available
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global GFPGAN instance (cached)
|
||||||
|
_gfpgan_instance = None
|
||||||
|
|
||||||
|
|
||||||
|
class FaceEnhancerError(Exception):
|
||||||
|
"""Error during face enhancement."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_gfpgan(
|
||||||
|
model_name: str = "GFPGANv1.4",
|
||||||
|
upscale: int = 2,
|
||||||
|
bg_upsampler=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get or create GFPGAN instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: GFPGAN model version
|
||||||
|
upscale: Background upscale factor
|
||||||
|
bg_upsampler: Optional background upsampler (RealESRGAN)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GFPGANer instance
|
||||||
|
"""
|
||||||
|
global _gfpgan_instance
|
||||||
|
|
||||||
|
if _gfpgan_instance is not None:
|
||||||
|
return _gfpgan_instance
|
||||||
|
|
||||||
|
try:
|
||||||
|
from gfpgan import GFPGANer
|
||||||
|
except ImportError:
|
||||||
|
raise FaceEnhancerError(
|
||||||
|
"GFPGAN not installed. Install with: pip install gfpgan"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure model is downloaded
|
||||||
|
model_path = ensure_model_available(model_name)
|
||||||
|
|
||||||
|
# Determine model version from name
|
||||||
|
if "1.4" in model_name:
|
||||||
|
arch = "clean"
|
||||||
|
channel_multiplier = 2
|
||||||
|
elif "1.3" in model_name:
|
||||||
|
arch = "clean"
|
||||||
|
channel_multiplier = 2
|
||||||
|
elif "1.2" in model_name:
|
||||||
|
arch = "clean"
|
||||||
|
channel_multiplier = 2
|
||||||
|
else:
|
||||||
|
arch = "original"
|
||||||
|
channel_multiplier = 1
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
_gfpgan_instance = GFPGANer(
|
||||||
|
model_path=str(model_path),
|
||||||
|
upscale=upscale,
|
||||||
|
arch=arch,
|
||||||
|
channel_multiplier=channel_multiplier,
|
||||||
|
bg_upsampler=bg_upsampler,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loaded GFPGAN: {model_name}")
|
||||||
|
return _gfpgan_instance
|
||||||
|
|
||||||
|
|
||||||
|
def enhance_faces(
|
||||||
|
image: np.ndarray,
|
||||||
|
model_name: str = "GFPGANv1.4",
|
||||||
|
has_aligned: bool = False,
|
||||||
|
only_center_face: bool = False,
|
||||||
|
paste_back: bool = True,
|
||||||
|
weight: float = 0.5,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Enhance faces in an image using GFPGAN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image (BGR format, numpy array)
|
||||||
|
model_name: GFPGAN model version to use
|
||||||
|
has_aligned: Whether input is already aligned face
|
||||||
|
only_center_face: Only enhance the center/largest face
|
||||||
|
paste_back: Paste enhanced face back to original image
|
||||||
|
weight: Blending weight (0=original, 1=enhanced)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enhanced image (BGR format, numpy array)
|
||||||
|
"""
|
||||||
|
if image is None:
|
||||||
|
raise FaceEnhancerError("No image provided")
|
||||||
|
|
||||||
|
try:
|
||||||
|
gfpgan = get_gfpgan(model_name)
|
||||||
|
|
||||||
|
# GFPGAN enhance
|
||||||
|
_, _, output = gfpgan.enhance(
|
||||||
|
image,
|
||||||
|
has_aligned=has_aligned,
|
||||||
|
only_center_face=only_center_face,
|
||||||
|
paste_back=paste_back,
|
||||||
|
weight=weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
if output is None:
|
||||||
|
logger.warning("No faces detected, returning original image")
|
||||||
|
return image
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Face enhancement failed: {e}")
|
||||||
|
raise FaceEnhancerError(f"Face enhancement failed: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def detect_faces(image: np.ndarray) -> int:
|
||||||
|
"""
|
||||||
|
Detect number of faces in image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image (BGR format)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of detected faces
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from facexlib.detection import init_detection_model
|
||||||
|
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
|
|
||||||
|
# Use facexlib for detection
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
face_helper = FaceRestoreHelper(
|
||||||
|
upscale_factor=1,
|
||||||
|
face_size=512,
|
||||||
|
crop_ratio=(1, 1),
|
||||||
|
det_model="retinaface_resnet50",
|
||||||
|
save_ext="png",
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
face_helper.read_image(image)
|
||||||
|
face_helper.get_face_landmarks_5(only_center_face=False)
|
||||||
|
|
||||||
|
return len(face_helper.all_landmarks_5)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Face detection failed: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def is_anime_image(image: np.ndarray) -> bool:
|
||||||
|
"""
|
||||||
|
Attempt to detect if image is anime/illustration style.
|
||||||
|
|
||||||
|
This is a simple heuristic - for production use, consider a
|
||||||
|
trained classifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image (BGR format)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if likely anime/illustration
|
||||||
|
"""
|
||||||
|
# Convert to grayscale
|
||||||
|
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
# Anime tends to have:
|
||||||
|
# 1. More uniform color regions (lower texture variance)
|
||||||
|
# 2. Sharper edges (higher edge density)
|
||||||
|
|
||||||
|
# Calculate Laplacian variance (edge strength)
|
||||||
|
laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
|
||||||
|
|
||||||
|
# Calculate local variance (texture)
|
||||||
|
local_var = cv2.blur(gray.astype(float) ** 2, (9, 9)) - cv2.blur(
|
||||||
|
gray.astype(float), (9, 9)
|
||||||
|
) ** 2
|
||||||
|
texture_var = local_var.mean()
|
||||||
|
|
||||||
|
# Anime typically has high edge strength but low texture variance
|
||||||
|
# These thresholds are approximate
|
||||||
|
edge_ratio = laplacian_var / (texture_var + 1e-6)
|
||||||
|
|
||||||
|
# Heuristic: anime has high edge_ratio
|
||||||
|
return edge_ratio > 50
|
||||||
|
|
||||||
|
|
||||||
|
def clear_gfpgan_cache():
|
||||||
|
"""Clear cached GFPGAN model."""
|
||||||
|
global _gfpgan_instance
|
||||||
|
_gfpgan_instance = None
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
logger.info("Cleared GFPGAN cache")
|
||||||
240
src/processing/models.py
Normal file
240
src/processing/models.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Model management for Real-ESRGAN and GFPGAN."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from src.config import MODEL_REGISTRY, MODELS_DIR, ModelInfo, config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDownloadError(Exception):
|
||||||
|
"""Error downloading model weights."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(model_info: ModelInfo, progress_callback=None) -> Path:
|
||||||
|
"""
|
||||||
|
Download model weights from URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_info: Model information from registry
|
||||||
|
progress_callback: Optional callback(downloaded, total) for progress
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to downloaded model file
|
||||||
|
"""
|
||||||
|
model_path = MODELS_DIR / model_info.filename
|
||||||
|
|
||||||
|
if model_path.exists():
|
||||||
|
logger.info(f"Model already exists: {model_path}")
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
logger.info(f"Downloading {model_info.name} ({model_info.size_mb:.1f} MB)...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(model_info.url, stream=True, timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
total_size = int(response.headers.get("content-length", 0))
|
||||||
|
downloaded = 0
|
||||||
|
|
||||||
|
# Download with progress
|
||||||
|
with open(model_path, "wb") as f:
|
||||||
|
with tqdm(
|
||||||
|
total=total_size,
|
||||||
|
unit="B",
|
||||||
|
unit_scale=True,
|
||||||
|
desc=model_info.name,
|
||||||
|
) as pbar:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
pbar.update(len(chunk))
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(downloaded, total_size)
|
||||||
|
|
||||||
|
logger.info(f"Downloaded: {model_path}")
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Clean up partial download
|
||||||
|
if model_path.exists():
|
||||||
|
model_path.unlink()
|
||||||
|
raise ModelDownloadError(f"Failed to download {model_info.name}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_model_available(model_name: str) -> Path:
|
||||||
|
"""
|
||||||
|
Ensure model is available, downloading if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of model from registry
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to model file
|
||||||
|
"""
|
||||||
|
if model_name not in MODEL_REGISTRY:
|
||||||
|
raise ValueError(f"Unknown model: {model_name}")
|
||||||
|
|
||||||
|
model_info = MODEL_REGISTRY[model_name]
|
||||||
|
model_path = MODELS_DIR / model_info.filename
|
||||||
|
|
||||||
|
if not model_path.exists():
|
||||||
|
download_model(model_info)
|
||||||
|
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManager:
|
||||||
|
"""
|
||||||
|
Manages Real-ESRGAN model loading and caching.
|
||||||
|
|
||||||
|
Optimized for RTX 4090 with 24GB VRAM:
|
||||||
|
- Keeps frequently used models in memory
|
||||||
|
- LRU eviction when VRAM is constrained
|
||||||
|
- Thread-safe operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_cached_models: int = 3):
|
||||||
|
self.max_cached_models = max_cached_models
|
||||||
|
self._cache: dict[str, "RealESRGANer"] = {}
|
||||||
|
self._last_used: dict[str, float] = {}
|
||||||
|
self._lock = Lock()
|
||||||
|
self._device = (
|
||||||
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
props = torch.cuda.get_device_properties(0)
|
||||||
|
logger.info(f"GPU: {props.name}, VRAM: {props.total_memory / (1024**3):.1f}GB")
|
||||||
|
else:
|
||||||
|
logger.warning("CUDA not available, using CPU (this will be slow)")
|
||||||
|
|
||||||
|
def get_upsampler(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
tile_size: int = 0,
|
||||||
|
tile_pad: int = 10,
|
||||||
|
half: bool = True,
|
||||||
|
) -> "RealESRGANer":
|
||||||
|
"""
|
||||||
|
Get or create a RealESRGANer instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of model from registry
|
||||||
|
tile_size: Tile size for processing (0 = no tiling)
|
||||||
|
tile_pad: Padding for tile boundaries
|
||||||
|
half: Use FP16 precision
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured RealESRGANer instance
|
||||||
|
"""
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
|
cache_key = f"{model_name}_{tile_size}_{tile_pad}_{half}"
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
# Check cache
|
||||||
|
if cache_key in self._cache:
|
||||||
|
self._last_used[cache_key] = time.time()
|
||||||
|
return self._cache[cache_key]
|
||||||
|
|
||||||
|
# Ensure model is downloaded
|
||||||
|
model_path = ensure_model_available(model_name)
|
||||||
|
model_info = MODEL_REGISTRY[model_name]
|
||||||
|
|
||||||
|
# Evict old models if at capacity
|
||||||
|
self._evict_if_needed()
|
||||||
|
|
||||||
|
# Create model architecture
|
||||||
|
model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=model_info.num_block,
|
||||||
|
num_grow_ch=model_info.num_grow_ch,
|
||||||
|
scale=model_info.netscale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create upsampler
|
||||||
|
upsampler = RealESRGANer(
|
||||||
|
scale=model_info.netscale,
|
||||||
|
model_path=str(model_path),
|
||||||
|
model=model,
|
||||||
|
tile=tile_size,
|
||||||
|
tile_pad=tile_pad,
|
||||||
|
pre_pad=config.gpu.pre_pad,
|
||||||
|
half=half and self._device.type == "cuda",
|
||||||
|
device=self._device,
|
||||||
|
gpu_id=config.gpu.gpu_id if self._device.type == "cuda" else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache and return
|
||||||
|
self._cache[cache_key] = upsampler
|
||||||
|
self._last_used[cache_key] = time.time()
|
||||||
|
|
||||||
|
logger.info(f"Loaded model: {model_name} (tile={tile_size}, half={half})")
|
||||||
|
return upsampler
|
||||||
|
|
||||||
|
def _evict_if_needed(self):
|
||||||
|
"""Evict least recently used model if at capacity."""
|
||||||
|
while len(self._cache) >= self.max_cached_models:
|
||||||
|
# Find LRU
|
||||||
|
lru_key = min(self._last_used.keys(), key=lambda k: self._last_used[k])
|
||||||
|
|
||||||
|
# Remove from cache
|
||||||
|
del self._cache[lru_key]
|
||||||
|
del self._last_used[lru_key]
|
||||||
|
|
||||||
|
# Clear CUDA cache
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
logger.info(f"Evicted model from cache: {lru_key}")
|
||||||
|
|
||||||
|
def clear_cache(self):
|
||||||
|
"""Clear all cached models."""
|
||||||
|
with self._lock:
|
||||||
|
self._cache.clear()
|
||||||
|
self._last_used.clear()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
logger.info("Cleared model cache")
|
||||||
|
|
||||||
|
def get_vram_status(self) -> dict:
|
||||||
|
"""Get current VRAM usage statistics."""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return {"available": False}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"available": True,
|
||||||
|
"allocated_gb": torch.cuda.memory_allocated() / (1024**3),
|
||||||
|
"reserved_gb": torch.cuda.memory_reserved() / (1024**3),
|
||||||
|
"total_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
|
||||||
|
"cached_models": list(self._cache.keys()),
|
||||||
|
}
|
||||||
|
|
||||||
|
def preload_model(self, model_name: str):
|
||||||
|
"""Preload a model into cache."""
|
||||||
|
self.get_upsampler(
|
||||||
|
model_name,
|
||||||
|
tile_size=config.gpu.tile_size,
|
||||||
|
tile_pad=config.gpu.tile_pad,
|
||||||
|
half=config.gpu.half,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Global model manager instance
|
||||||
|
model_manager = ModelManager()
|
||||||
341
src/processing/upscaler.py
Normal file
341
src/processing/upscaler.py
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
"""High-level upscaling interface for Real-ESRGAN."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from src.config import MODEL_REGISTRY, config
|
||||||
|
from src.processing.models import model_manager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpscaleResult:
|
||||||
|
"""Result of an upscaling operation."""
|
||||||
|
|
||||||
|
image: np.ndarray # BGR format
|
||||||
|
input_width: int
|
||||||
|
input_height: int
|
||||||
|
output_width: int
|
||||||
|
output_height: int
|
||||||
|
model_used: str
|
||||||
|
scale_factor: int
|
||||||
|
processing_time: float
|
||||||
|
used_tiling: bool
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerError(Exception):
|
||||||
|
"""Error during upscaling."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_tile_size(width: int, height: int) -> int:
|
||||||
|
"""
|
||||||
|
Calculate optimal tile size based on image dimensions.
|
||||||
|
|
||||||
|
For RTX 4090 with 24GB VRAM:
|
||||||
|
- No tiling for images up to ~4K
|
||||||
|
- 512px tiles for larger images
|
||||||
|
"""
|
||||||
|
pixels = width * height
|
||||||
|
|
||||||
|
# Thresholds based on testing with RealESRGAN_x4plus
|
||||||
|
if pixels <= 2073600: # Up to 1080p (1920x1080)
|
||||||
|
return 0 # No tiling
|
||||||
|
elif pixels <= 8294400: # Up to 4K (3840x2160)
|
||||||
|
return 0 # RTX 4090 can handle without tiling
|
||||||
|
elif pixels <= 33177600: # Up to 8K (7680x4320)
|
||||||
|
return 512
|
||||||
|
else:
|
||||||
|
return 256 # Very large images
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(
|
||||||
|
input_path: Union[str, Path, np.ndarray, Image.Image]
|
||||||
|
) -> tuple[np.ndarray, str]:
|
||||||
|
"""
|
||||||
|
Load image from various sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Path to image, numpy array, or PIL Image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (BGR numpy array, source description)
|
||||||
|
"""
|
||||||
|
if isinstance(input_path, np.ndarray):
|
||||||
|
# Assume already BGR if 3 channels
|
||||||
|
if len(input_path.shape) == 3 and input_path.shape[2] == 3:
|
||||||
|
return input_path, "numpy"
|
||||||
|
elif len(input_path.shape) == 3 and input_path.shape[2] == 4:
|
||||||
|
# RGBA to BGR
|
||||||
|
return cv2.cvtColor(input_path, cv2.COLOR_RGBA2BGR), "numpy_rgba"
|
||||||
|
else:
|
||||||
|
raise UpscalerError(f"Unsupported array shape: {input_path.shape}")
|
||||||
|
|
||||||
|
if isinstance(input_path, Image.Image):
|
||||||
|
# PIL to numpy BGR
|
||||||
|
img = np.array(input_path)
|
||||||
|
if img.shape[2] == 4:
|
||||||
|
return cv2.cvtColor(img, cv2.COLOR_RGBA2BGR), "pil_rgba"
|
||||||
|
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR), "pil"
|
||||||
|
|
||||||
|
# Load from file
|
||||||
|
input_path = Path(input_path)
|
||||||
|
if not input_path.exists():
|
||||||
|
raise UpscalerError(f"Image not found: {input_path}")
|
||||||
|
|
||||||
|
img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
|
||||||
|
if img is None:
|
||||||
|
raise UpscalerError(f"Failed to load image: {input_path}")
|
||||||
|
|
||||||
|
# Handle different channel counts
|
||||||
|
if len(img.shape) == 2:
|
||||||
|
# Grayscale to BGR
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
elif img.shape[2] == 4:
|
||||||
|
# RGBA - keep alpha for later
|
||||||
|
pass # RealESRGAN handles alpha internally
|
||||||
|
|
||||||
|
return img, str(input_path)
|
||||||
|
|
||||||
|
|
||||||
|
def save_image(
|
||||||
|
image: np.ndarray,
|
||||||
|
output_path: Union[str, Path],
|
||||||
|
format: str = "png",
|
||||||
|
quality: int = 95,
|
||||||
|
) -> Path:
|
||||||
|
"""
|
||||||
|
Save image to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: BGR numpy array
|
||||||
|
output_path: Output path (extension will be added/replaced)
|
||||||
|
format: Output format (png, jpg, webp)
|
||||||
|
quality: JPEG/WebP quality (0-100)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to saved file
|
||||||
|
"""
|
||||||
|
output_path = Path(output_path)
|
||||||
|
|
||||||
|
# Ensure correct extension
|
||||||
|
format = format.lower()
|
||||||
|
if format == "jpg":
|
||||||
|
format = "jpeg"
|
||||||
|
|
||||||
|
ext = f".{format}" if format != "jpeg" else ".jpg"
|
||||||
|
output_path = output_path.with_suffix(ext)
|
||||||
|
|
||||||
|
# Ensure output directory exists
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save with appropriate settings
|
||||||
|
if format == "png":
|
||||||
|
cv2.imwrite(str(output_path), image, [cv2.IMWRITE_PNG_COMPRESSION, 6])
|
||||||
|
elif format == "jpeg":
|
||||||
|
cv2.imwrite(str(output_path), image, [cv2.IMWRITE_JPEG_QUALITY, quality])
|
||||||
|
elif format == "webp":
|
||||||
|
cv2.imwrite(str(output_path), image, [cv2.IMWRITE_WEBP_QUALITY, quality])
|
||||||
|
else:
|
||||||
|
cv2.imwrite(str(output_path), image)
|
||||||
|
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_image(
|
||||||
|
input_image: Union[str, Path, np.ndarray, Image.Image],
|
||||||
|
model_name: str = None,
|
||||||
|
scale: int = None,
|
||||||
|
tile_size: int = None,
|
||||||
|
denoise_strength: float = 0.5,
|
||||||
|
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
) -> UpscaleResult:
|
||||||
|
"""
|
||||||
|
Upscale a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_image: Input image (path, numpy array, or PIL Image)
|
||||||
|
model_name: Model to use (defaults to config)
|
||||||
|
scale: Output scale (2 or 4, defaults to model's native scale)
|
||||||
|
tile_size: Tile size (0 for auto, None for config default)
|
||||||
|
denoise_strength: Denoise strength for supported models (0-1)
|
||||||
|
progress_callback: Optional callback(progress, stage) for progress
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UpscaleResult with upscaled image and metadata
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Use defaults
|
||||||
|
model_name = model_name or config.default_model
|
||||||
|
model_info = MODEL_REGISTRY.get(model_name)
|
||||||
|
if not model_info:
|
||||||
|
raise UpscalerError(f"Unknown model: {model_name}")
|
||||||
|
|
||||||
|
# Load image
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(0.1, "Loading image...")
|
||||||
|
|
||||||
|
img, source = load_image(input_image)
|
||||||
|
input_height, input_width = img.shape[:2]
|
||||||
|
|
||||||
|
logger.info(f"Input: {input_width}x{input_height} from {source}")
|
||||||
|
|
||||||
|
# Calculate tile size if auto
|
||||||
|
if tile_size is None:
|
||||||
|
tile_size = config.gpu.tile_size
|
||||||
|
if tile_size == 0:
|
||||||
|
# Auto-calculate based on image size
|
||||||
|
tile_size = calculate_tile_size(input_width, input_height)
|
||||||
|
|
||||||
|
used_tiling = tile_size > 0
|
||||||
|
|
||||||
|
# Get upsampler
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(0.2, "Loading model...")
|
||||||
|
|
||||||
|
upsampler = model_manager.get_upsampler(
|
||||||
|
model_name=model_name,
|
||||||
|
tile_size=tile_size,
|
||||||
|
tile_pad=config.gpu.tile_pad,
|
||||||
|
half=config.gpu.half,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine output scale
|
||||||
|
if scale is None:
|
||||||
|
scale = model_info.scale
|
||||||
|
outscale = scale / model_info.netscale if scale != model_info.netscale else None
|
||||||
|
|
||||||
|
# Process
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(0.3, "Upscaling...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# RealESRGAN enhance method
|
||||||
|
if model_name == "realesr-general-x4v3":
|
||||||
|
# This model supports denoise strength
|
||||||
|
output, _ = upsampler.enhance(
|
||||||
|
img, outscale=outscale, dnoise_strength=denoise_strength
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output, _ = upsampler.enhance(img, outscale=outscale)
|
||||||
|
|
||||||
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
# Clear cache and retry with tiling
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
logger.warning("VRAM exceeded, retrying with tiling...")
|
||||||
|
|
||||||
|
upsampler = model_manager.get_upsampler(
|
||||||
|
model_name=model_name,
|
||||||
|
tile_size=512, # Force tiling
|
||||||
|
tile_pad=config.gpu.tile_pad,
|
||||||
|
half=config.gpu.half,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_name == "realesr-general-x4v3":
|
||||||
|
output, _ = upsampler.enhance(
|
||||||
|
img, outscale=outscale, dnoise_strength=denoise_strength
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output, _ = upsampler.enhance(img, outscale=outscale)
|
||||||
|
|
||||||
|
used_tiling = True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise UpscalerError(f"Upscaling failed: {e}") from e
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(0.9, "Finalizing...")
|
||||||
|
|
||||||
|
output_height, output_width = output.shape[:2]
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Output: {output_width}x{output_height}, "
|
||||||
|
f"time: {processing_time:.2f}s, tiling: {used_tiling}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(1.0, "Complete!")
|
||||||
|
|
||||||
|
return UpscaleResult(
|
||||||
|
image=output,
|
||||||
|
input_width=input_width,
|
||||||
|
input_height=input_height,
|
||||||
|
output_width=output_width,
|
||||||
|
output_height=output_height,
|
||||||
|
model_used=model_name,
|
||||||
|
scale_factor=scale,
|
||||||
|
processing_time=processing_time,
|
||||||
|
used_tiling=used_tiling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_and_save(
|
||||||
|
input_path: Union[str, Path],
|
||||||
|
output_path: Union[str, Path],
|
||||||
|
model_name: str = None,
|
||||||
|
scale: int = None,
|
||||||
|
output_format: str = "png",
|
||||||
|
denoise_strength: float = 0.5,
|
||||||
|
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
) -> tuple[Path, UpscaleResult]:
|
||||||
|
"""
|
||||||
|
Upscale image and save to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Input image path
|
||||||
|
output_path: Output path (extension may be adjusted)
|
||||||
|
model_name: Model to use
|
||||||
|
scale: Output scale
|
||||||
|
output_format: Output format (png, jpg, webp)
|
||||||
|
denoise_strength: Denoise strength for supported models
|
||||||
|
progress_callback: Optional progress callback
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (saved path, UpscaleResult)
|
||||||
|
"""
|
||||||
|
result = upscale_image(
|
||||||
|
input_path,
|
||||||
|
model_name=model_name,
|
||||||
|
scale=scale,
|
||||||
|
denoise_strength=denoise_strength,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
saved_path = save_image(result.image, output_path, format=output_format)
|
||||||
|
|
||||||
|
return saved_path, result
|
||||||
|
|
||||||
|
|
||||||
|
def get_output_dimensions(
|
||||||
|
width: int, height: int, model_name: str, scale: int = None
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Calculate output dimensions for given input and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width: Input width
|
||||||
|
height: Input height
|
||||||
|
model_name: Model name
|
||||||
|
scale: Target scale (uses model default if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (output_width, output_height)
|
||||||
|
"""
|
||||||
|
model_info = MODEL_REGISTRY.get(model_name)
|
||||||
|
if not model_info:
|
||||||
|
raise ValueError(f"Unknown model: {model_name}")
|
||||||
|
|
||||||
|
scale = scale or model_info.scale
|
||||||
|
return width * scale, height * scale
|
||||||
1
src/storage/__init__.py
Normal file
1
src/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Storage and persistence modules (SQLite, history, queue)."""
|
||||||
141
src/storage/database.py
Normal file
141
src/storage/database.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""SQLite database setup and utilities."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
from src.config import DATABASE_PATH
|
||||||
|
|
||||||
|
|
||||||
|
# Thread-safe connection management
|
||||||
|
_db_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_connection() -> sqlite3.Connection:
|
||||||
|
"""Get a database connection with proper settings."""
|
||||||
|
conn = sqlite3.connect(str(DATABASE_PATH), timeout=30.0)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
conn.execute("PRAGMA journal_mode = WAL") # Better concurrent access
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_db():
|
||||||
|
"""Context manager for database connections."""
|
||||||
|
conn = get_connection()
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
conn.commit()
|
||||||
|
except Exception:
|
||||||
|
conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def init_database():
|
||||||
|
"""Initialize database schema."""
|
||||||
|
with _db_lock:
|
||||||
|
with get_db() as conn:
|
||||||
|
# History table for processed items
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS history (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
input_path TEXT NOT NULL,
|
||||||
|
input_filename TEXT NOT NULL,
|
||||||
|
output_path TEXT NOT NULL,
|
||||||
|
output_filename TEXT NOT NULL,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
scale INTEGER NOT NULL,
|
||||||
|
face_enhance INTEGER NOT NULL DEFAULT 0,
|
||||||
|
video_codec TEXT,
|
||||||
|
processing_time_seconds REAL NOT NULL,
|
||||||
|
input_width INTEGER NOT NULL,
|
||||||
|
input_height INTEGER NOT NULL,
|
||||||
|
output_width INTEGER NOT NULL,
|
||||||
|
output_height INTEGER NOT NULL,
|
||||||
|
frames_count INTEGER,
|
||||||
|
input_size_bytes INTEGER NOT NULL,
|
||||||
|
output_size_bytes INTEGER NOT NULL,
|
||||||
|
thumbnail_path TEXT,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
metadata_json TEXT DEFAULT '{}'
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Index for efficient queries
|
||||||
|
conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_history_created_at
|
||||||
|
ON history(created_at DESC)
|
||||||
|
""")
|
||||||
|
|
||||||
|
conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_history_type
|
||||||
|
ON history(type)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Queue table for batch processing
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS queue (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending',
|
||||||
|
input_path TEXT NOT NULL,
|
||||||
|
output_path TEXT,
|
||||||
|
config_json TEXT NOT NULL,
|
||||||
|
progress_percent REAL DEFAULT 0,
|
||||||
|
current_stage TEXT DEFAULT '',
|
||||||
|
frames_processed INTEGER DEFAULT 0,
|
||||||
|
total_frames INTEGER DEFAULT 0,
|
||||||
|
error_message TEXT,
|
||||||
|
retry_count INTEGER DEFAULT 0,
|
||||||
|
priority INTEGER DEFAULT 5,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
started_at TEXT,
|
||||||
|
completed_at TEXT
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_queue_status
|
||||||
|
ON queue(status)
|
||||||
|
""")
|
||||||
|
|
||||||
|
conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_queue_priority_created
|
||||||
|
ON queue(priority ASC, created_at ASC)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Checkpoints table for video resume
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS checkpoints (
|
||||||
|
job_id TEXT PRIMARY KEY,
|
||||||
|
input_path TEXT NOT NULL,
|
||||||
|
input_hash TEXT NOT NULL,
|
||||||
|
total_frames INTEGER NOT NULL,
|
||||||
|
processed_frames INTEGER NOT NULL DEFAULT 0,
|
||||||
|
last_processed_frame INTEGER NOT NULL DEFAULT -1,
|
||||||
|
frames_dir TEXT NOT NULL,
|
||||||
|
audio_path TEXT,
|
||||||
|
config_json TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def reset_database():
|
||||||
|
"""Reset database (for development/testing)."""
|
||||||
|
with _db_lock:
|
||||||
|
with get_db() as conn:
|
||||||
|
conn.execute("DROP TABLE IF EXISTS history")
|
||||||
|
conn.execute("DROP TABLE IF EXISTS queue")
|
||||||
|
conn.execute("DROP TABLE IF EXISTS checkpoints")
|
||||||
|
init_database()
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize on import
|
||||||
|
init_database()
|
||||||
282
src/storage/history.py
Normal file
282
src/storage/history.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
"""Processing history management with SQLite persistence."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from src.storage.database import get_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HistoryItem:
|
||||||
|
"""Record of a processed image/video."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: str # "image" or "video"
|
||||||
|
input_path: str
|
||||||
|
input_filename: str
|
||||||
|
output_path: str
|
||||||
|
output_filename: str
|
||||||
|
model: str
|
||||||
|
scale: int
|
||||||
|
face_enhance: bool
|
||||||
|
video_codec: Optional[str]
|
||||||
|
processing_time_seconds: float
|
||||||
|
input_width: int
|
||||||
|
input_height: int
|
||||||
|
output_width: int
|
||||||
|
output_height: int
|
||||||
|
frames_count: Optional[int]
|
||||||
|
input_size_bytes: int
|
||||||
|
output_size_bytes: int
|
||||||
|
thumbnail_path: Optional[str]
|
||||||
|
created_at: datetime
|
||||||
|
metadata: dict
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryManager:
|
||||||
|
"""
|
||||||
|
Manage processing history with SQLite persistence.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Store all processed items
|
||||||
|
- Search and filter
|
||||||
|
- Thumbnails for quick preview
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
logger.info("Initialized history manager")
|
||||||
|
|
||||||
|
def add_item(
|
||||||
|
self,
|
||||||
|
type: str,
|
||||||
|
input_path: str,
|
||||||
|
output_path: str,
|
||||||
|
model: str,
|
||||||
|
scale: int,
|
||||||
|
face_enhance: bool = False,
|
||||||
|
video_codec: Optional[str] = None,
|
||||||
|
processing_time_seconds: float = 0,
|
||||||
|
input_width: int = 0,
|
||||||
|
input_height: int = 0,
|
||||||
|
output_width: int = 0,
|
||||||
|
output_height: int = 0,
|
||||||
|
frames_count: Optional[int] = None,
|
||||||
|
thumbnail_path: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
) -> HistoryItem:
|
||||||
|
"""Add item to history."""
|
||||||
|
item_id = uuid.uuid4().hex[:12]
|
||||||
|
input_path = Path(input_path)
|
||||||
|
output_path = Path(output_path)
|
||||||
|
|
||||||
|
# Get file sizes
|
||||||
|
input_size = input_path.stat().st_size if input_path.exists() else 0
|
||||||
|
output_size = output_path.stat().st_size if output_path.exists() else 0
|
||||||
|
|
||||||
|
with get_db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO history (
|
||||||
|
id, type, input_path, input_filename,
|
||||||
|
output_path, output_filename, model, scale,
|
||||||
|
face_enhance, video_codec, processing_time_seconds,
|
||||||
|
input_width, input_height, output_width, output_height,
|
||||||
|
frames_count, input_size_bytes, output_size_bytes,
|
||||||
|
thumbnail_path, created_at, metadata_json
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
item_id,
|
||||||
|
type,
|
||||||
|
str(input_path),
|
||||||
|
input_path.name,
|
||||||
|
str(output_path),
|
||||||
|
output_path.name,
|
||||||
|
model,
|
||||||
|
scale,
|
||||||
|
int(face_enhance),
|
||||||
|
video_codec,
|
||||||
|
processing_time_seconds,
|
||||||
|
input_width,
|
||||||
|
input_height,
|
||||||
|
output_width,
|
||||||
|
output_height,
|
||||||
|
frames_count,
|
||||||
|
input_size,
|
||||||
|
output_size,
|
||||||
|
thumbnail_path,
|
||||||
|
datetime.utcnow().isoformat(),
|
||||||
|
json.dumps(metadata or {}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Added to history: {item_id} ({type})")
|
||||||
|
|
||||||
|
return HistoryItem(
|
||||||
|
id=item_id,
|
||||||
|
type=type,
|
||||||
|
input_path=str(input_path),
|
||||||
|
input_filename=input_path.name,
|
||||||
|
output_path=str(output_path),
|
||||||
|
output_filename=output_path.name,
|
||||||
|
model=model,
|
||||||
|
scale=scale,
|
||||||
|
face_enhance=face_enhance,
|
||||||
|
video_codec=video_codec,
|
||||||
|
processing_time_seconds=processing_time_seconds,
|
||||||
|
input_width=input_width,
|
||||||
|
input_height=input_height,
|
||||||
|
output_width=output_width,
|
||||||
|
output_height=output_height,
|
||||||
|
frames_count=frames_count,
|
||||||
|
input_size_bytes=input_size,
|
||||||
|
output_size_bytes=output_size,
|
||||||
|
thumbnail_path=thumbnail_path,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_item(self, item_id: str) -> Optional[HistoryItem]:
|
||||||
|
"""Get history item by ID."""
|
||||||
|
with get_db() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT * FROM history WHERE id = ?",
|
||||||
|
(item_id,),
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
if row:
|
||||||
|
return self._row_to_item(row)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_recent(
|
||||||
|
self,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
type_filter: Optional[str] = None,
|
||||||
|
) -> list[HistoryItem]:
|
||||||
|
"""Get recent history items."""
|
||||||
|
with get_db() as conn:
|
||||||
|
if type_filter:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM history
|
||||||
|
WHERE type = ?
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
""",
|
||||||
|
(type_filter, limit, offset),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM history
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
""",
|
||||||
|
(limit, offset),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return [self._row_to_item(row) for row in rows]
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> list[HistoryItem]:
|
||||||
|
"""Search history by filename."""
|
||||||
|
with get_db() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM history
|
||||||
|
WHERE input_filename LIKE ? OR output_filename LIKE ?
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(f"%{query}%", f"%{query}%", limit),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return [self._row_to_item(row) for row in rows]
|
||||||
|
|
||||||
|
def delete_item(self, item_id: str) -> bool:
|
||||||
|
"""Delete history item."""
|
||||||
|
with get_db() as conn:
|
||||||
|
result = conn.execute(
|
||||||
|
"DELETE FROM history WHERE id = ?",
|
||||||
|
(item_id,),
|
||||||
|
)
|
||||||
|
deleted = result.rowcount > 0
|
||||||
|
|
||||||
|
if deleted:
|
||||||
|
logger.info(f"Deleted from history: {item_id}")
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
def get_statistics(self) -> dict:
|
||||||
|
"""Get history statistics."""
|
||||||
|
with get_db() as conn:
|
||||||
|
stats = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total,
|
||||||
|
SUM(CASE WHEN type = 'image' THEN 1 ELSE 0 END) as images,
|
||||||
|
SUM(CASE WHEN type = 'video' THEN 1 ELSE 0 END) as videos,
|
||||||
|
SUM(processing_time_seconds) as total_time,
|
||||||
|
SUM(input_size_bytes) as total_input_size,
|
||||||
|
SUM(output_size_bytes) as total_output_size,
|
||||||
|
SUM(frames_count) as total_frames
|
||||||
|
FROM history
|
||||||
|
"""
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_items": stats["total"] or 0,
|
||||||
|
"images": stats["images"] or 0,
|
||||||
|
"videos": stats["videos"] or 0,
|
||||||
|
"total_processing_time": stats["total_time"] or 0,
|
||||||
|
"total_input_size_mb": (stats["total_input_size"] or 0) / (1024 * 1024),
|
||||||
|
"total_output_size_mb": (stats["total_output_size"] or 0) / (1024 * 1024),
|
||||||
|
"total_frames_processed": stats["total_frames"] or 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_history(self) -> int:
|
||||||
|
"""Clear all history."""
|
||||||
|
with get_db() as conn:
|
||||||
|
result = conn.execute("DELETE FROM history")
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
def _row_to_item(self, row) -> HistoryItem:
|
||||||
|
"""Convert database row to HistoryItem."""
|
||||||
|
return HistoryItem(
|
||||||
|
id=row["id"],
|
||||||
|
type=row["type"],
|
||||||
|
input_path=row["input_path"],
|
||||||
|
input_filename=row["input_filename"],
|
||||||
|
output_path=row["output_path"],
|
||||||
|
output_filename=row["output_filename"],
|
||||||
|
model=row["model"],
|
||||||
|
scale=row["scale"],
|
||||||
|
face_enhance=bool(row["face_enhance"]),
|
||||||
|
video_codec=row["video_codec"],
|
||||||
|
processing_time_seconds=row["processing_time_seconds"],
|
||||||
|
input_width=row["input_width"],
|
||||||
|
input_height=row["input_height"],
|
||||||
|
output_width=row["output_width"],
|
||||||
|
output_height=row["output_height"],
|
||||||
|
frames_count=row["frames_count"],
|
||||||
|
input_size_bytes=row["input_size_bytes"],
|
||||||
|
output_size_bytes=row["output_size_bytes"],
|
||||||
|
thumbnail_path=row["thumbnail_path"],
|
||||||
|
created_at=datetime.fromisoformat(row["created_at"]),
|
||||||
|
metadata=json.loads(row["metadata_json"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Global history manager instance
|
||||||
|
history_manager = HistoryManager()
|
||||||
332
src/storage/queue.py
Normal file
332
src/storage/queue.py
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
"""Job queue management with SQLite persistence."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from src.storage.database import get_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JobStatus(Enum):
|
||||||
|
"""Job status enumeration."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
QUEUED = "queued"
|
||||||
|
PROCESSING = "processing"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
class JobType(Enum):
|
||||||
|
"""Job type enumeration."""
|
||||||
|
|
||||||
|
IMAGE = "image"
|
||||||
|
VIDEO = "video"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Job:
|
||||||
|
"""Processing job data."""
|
||||||
|
|
||||||
|
id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
|
||||||
|
type: JobType = JobType.IMAGE
|
||||||
|
status: JobStatus = JobStatus.PENDING
|
||||||
|
input_path: str = ""
|
||||||
|
output_path: Optional[str] = None
|
||||||
|
config: dict = field(default_factory=dict)
|
||||||
|
progress_percent: float = 0.0
|
||||||
|
current_stage: str = ""
|
||||||
|
frames_processed: int = 0
|
||||||
|
total_frames: int = 0
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
priority: int = 5
|
||||||
|
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
started_at: Optional[datetime] = None
|
||||||
|
completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class QueueManager:
|
||||||
|
"""
|
||||||
|
Manage job queue with SQLite persistence.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Priority-based queue ordering
|
||||||
|
- Persistent across restarts
|
||||||
|
- Progress tracking
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
logger.info("Initialized queue manager")
|
||||||
|
|
||||||
|
def add_job(self, job: Job) -> Job:
|
||||||
|
"""Add job to queue."""
|
||||||
|
with get_db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO queue (
|
||||||
|
id, type, status, input_path, output_path,
|
||||||
|
config_json, progress_percent, current_stage,
|
||||||
|
frames_processed, total_frames, error_message,
|
||||||
|
priority, created_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
job.id,
|
||||||
|
job.type.value,
|
||||||
|
JobStatus.QUEUED.value,
|
||||||
|
job.input_path,
|
||||||
|
job.output_path,
|
||||||
|
json.dumps(job.config),
|
||||||
|
job.progress_percent,
|
||||||
|
job.current_stage,
|
||||||
|
job.frames_processed,
|
||||||
|
job.total_frames,
|
||||||
|
job.error_message,
|
||||||
|
job.priority,
|
||||||
|
job.created_at.isoformat(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
job.status = JobStatus.QUEUED
|
||||||
|
logger.info(f"Added job to queue: {job.id}")
|
||||||
|
return job
|
||||||
|
|
||||||
|
def get_next_job(self) -> Optional[Job]:
|
||||||
|
"""Get next job to process (highest priority, oldest first)."""
|
||||||
|
with get_db() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM queue
|
||||||
|
WHERE status = ?
|
||||||
|
ORDER BY priority ASC, created_at ASC
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(JobStatus.QUEUED.value,),
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Mark as processing
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE queue SET status = ?, started_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
JobStatus.PROCESSING.value,
|
||||||
|
datetime.utcnow().isoformat(),
|
||||||
|
row["id"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._row_to_job(row)
|
||||||
|
|
||||||
|
def update_progress(
|
||||||
|
self,
|
||||||
|
job_id: str,
|
||||||
|
progress_percent: float,
|
||||||
|
current_stage: str = "",
|
||||||
|
frames_processed: int = 0,
|
||||||
|
total_frames: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Update job progress."""
|
||||||
|
with get_db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE queue SET
|
||||||
|
progress_percent = ?,
|
||||||
|
current_stage = ?,
|
||||||
|
frames_processed = ?,
|
||||||
|
total_frames = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(progress_percent, current_stage, frames_processed, total_frames, job_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
def complete_job(self, job_id: str, output_path: str) -> None:
|
||||||
|
"""Mark job as completed."""
|
||||||
|
with get_db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE queue SET
|
||||||
|
status = ?,
|
||||||
|
output_path = ?,
|
||||||
|
completed_at = ?,
|
||||||
|
progress_percent = 100
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
JobStatus.COMPLETED.value,
|
||||||
|
output_path,
|
||||||
|
datetime.utcnow().isoformat(),
|
||||||
|
job_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.info(f"Job completed: {job_id}")
|
||||||
|
|
||||||
|
def fail_job(self, job_id: str, error_message: str) -> None:
|
||||||
|
"""Mark job as failed."""
|
||||||
|
with get_db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE queue SET
|
||||||
|
status = ?,
|
||||||
|
error_message = ?,
|
||||||
|
completed_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
JobStatus.FAILED.value,
|
||||||
|
error_message,
|
||||||
|
datetime.utcnow().isoformat(),
|
||||||
|
job_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.error(f"Job failed: {job_id} - {error_message}")
|
||||||
|
|
||||||
|
def cancel_job(self, job_id: str) -> None:
|
||||||
|
"""Cancel a job."""
|
||||||
|
with get_db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE queue SET status = ?
|
||||||
|
WHERE id = ? AND status IN (?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
JobStatus.CANCELLED.value,
|
||||||
|
job_id,
|
||||||
|
JobStatus.PENDING.value,
|
||||||
|
JobStatus.QUEUED.value,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.info(f"Job cancelled: {job_id}")
|
||||||
|
|
||||||
|
def get_job(self, job_id: str) -> Optional[Job]:
|
||||||
|
"""Get job by ID."""
|
||||||
|
with get_db() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT * FROM queue WHERE id = ?",
|
||||||
|
(job_id,),
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
if row:
|
||||||
|
return self._row_to_job(row)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_queue(self) -> list[Job]:
|
||||||
|
"""Get all queued/processing jobs."""
|
||||||
|
with get_db() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM queue
|
||||||
|
WHERE status IN (?, ?)
|
||||||
|
ORDER BY priority ASC, created_at ASC
|
||||||
|
""",
|
||||||
|
(JobStatus.QUEUED.value, JobStatus.PROCESSING.value),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return [self._row_to_job(row) for row in rows]
|
||||||
|
|
||||||
|
def get_completed_jobs(self, limit: int = 50) -> list[Job]:
|
||||||
|
"""Get recently completed jobs."""
|
||||||
|
with get_db() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM queue
|
||||||
|
WHERE status IN (?, ?, ?)
|
||||||
|
ORDER BY completed_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
JobStatus.COMPLETED.value,
|
||||||
|
JobStatus.FAILED.value,
|
||||||
|
JobStatus.CANCELLED.value,
|
||||||
|
limit,
|
||||||
|
),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return [self._row_to_job(row) for row in rows]
|
||||||
|
|
||||||
|
def get_queue_stats(self) -> dict:
|
||||||
|
"""Get queue statistics."""
|
||||||
|
with get_db() as conn:
|
||||||
|
stats = {}
|
||||||
|
|
||||||
|
# Count by status
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT status, COUNT(*) as count
|
||||||
|
FROM queue
|
||||||
|
GROUP BY status
|
||||||
|
"""
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
stats[row["status"]] = row["count"]
|
||||||
|
|
||||||
|
# Currently processing
|
||||||
|
processing = conn.execute(
|
||||||
|
"SELECT * FROM queue WHERE status = ?",
|
||||||
|
(JobStatus.PROCESSING.value,),
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
stats["current_job"] = self._row_to_job(processing) if processing else None
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def clear_completed(self) -> int:
|
||||||
|
"""Clear completed/failed/cancelled jobs."""
|
||||||
|
with get_db() as conn:
|
||||||
|
result = conn.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM queue
|
||||||
|
WHERE status IN (?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
JobStatus.COMPLETED.value,
|
||||||
|
JobStatus.FAILED.value,
|
||||||
|
JobStatus.CANCELLED.value,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
def _row_to_job(self, row) -> Job:
|
||||||
|
"""Convert database row to Job object."""
|
||||||
|
return Job(
|
||||||
|
id=row["id"],
|
||||||
|
type=JobType(row["type"]),
|
||||||
|
status=JobStatus(row["status"]),
|
||||||
|
input_path=row["input_path"],
|
||||||
|
output_path=row["output_path"],
|
||||||
|
config=json.loads(row["config_json"]),
|
||||||
|
progress_percent=row["progress_percent"],
|
||||||
|
current_stage=row["current_stage"],
|
||||||
|
frames_processed=row["frames_processed"],
|
||||||
|
total_frames=row["total_frames"],
|
||||||
|
error_message=row["error_message"],
|
||||||
|
priority=row["priority"],
|
||||||
|
created_at=datetime.fromisoformat(row["created_at"]),
|
||||||
|
started_at=(
|
||||||
|
datetime.fromisoformat(row["started_at"])
|
||||||
|
if row["started_at"]
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
completed_at=(
|
||||||
|
datetime.fromisoformat(row["completed_at"])
|
||||||
|
if row["completed_at"]
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Global queue manager instance
|
||||||
|
queue_manager = QueueManager()
|
||||||
1
src/ui/__init__.py
Normal file
1
src/ui/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""UI components and handlers for the Gradio interface."""
|
||||||
111
src/ui/app.py
Normal file
111
src/ui/app.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Main Gradio application assembly."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
from src.ui.theme import get_theme, get_css
|
||||||
|
from src.ui.components.image_tab import create_image_tab
|
||||||
|
from src.ui.components.video_tab import create_video_tab
|
||||||
|
from src.ui.components.batch_tab import create_batch_tab
|
||||||
|
from src.ui.components.history_tab import create_history_tab
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gpu_status() -> str:
|
||||||
|
"""Get GPU status string for display."""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return "GPU: Not available (CPU mode)"
|
||||||
|
|
||||||
|
try:
|
||||||
|
props = torch.cuda.get_device_properties(0)
|
||||||
|
allocated = torch.cuda.memory_allocated() / (1024**3)
|
||||||
|
total = props.total_memory / (1024**3)
|
||||||
|
return f"GPU: {props.name} | VRAM: {allocated:.1f}/{total:.1f} GB"
|
||||||
|
except Exception:
|
||||||
|
return "GPU: Status unavailable"
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> gr.Blocks:
|
||||||
|
"""Create and return the main Gradio application."""
|
||||||
|
theme = get_theme()
|
||||||
|
css = get_css()
|
||||||
|
|
||||||
|
with gr.Blocks(
|
||||||
|
theme=theme,
|
||||||
|
css=css,
|
||||||
|
title="Real-ESRGAN Upscaler",
|
||||||
|
fill_width=True,
|
||||||
|
) as app:
|
||||||
|
# Header
|
||||||
|
gr.HTML(
|
||||||
|
"""
|
||||||
|
<div class="app-header">
|
||||||
|
<h1>Real-ESRGAN Upscaler</h1>
|
||||||
|
</div>
|
||||||
|
""",
|
||||||
|
elem_classes=["header"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Main tabs
|
||||||
|
with gr.Tabs(elem_id="main-tabs") as tabs:
|
||||||
|
# Image upscaling tab
|
||||||
|
image_components = create_image_tab()
|
||||||
|
|
||||||
|
# Video upscaling tab
|
||||||
|
video_components = create_video_tab()
|
||||||
|
|
||||||
|
# Batch queue tab
|
||||||
|
batch_components = create_batch_tab()
|
||||||
|
|
||||||
|
# History tab
|
||||||
|
history_components = create_history_tab()
|
||||||
|
|
||||||
|
# Status bar
|
||||||
|
gr.HTML(
|
||||||
|
f"""
|
||||||
|
<div class="status-bar">
|
||||||
|
<div class="status-item">
|
||||||
|
<span class="status-dot"></span>
|
||||||
|
<span>{get_gpu_status()}</span>
|
||||||
|
</div>
|
||||||
|
<div class="status-item">
|
||||||
|
Queue: 0 | Ready
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
""",
|
||||||
|
elem_classes=["status-bar-container"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def launch_app():
|
||||||
|
"""Launch the Gradio application."""
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Starting Real-ESRGAN Web UI...")
|
||||||
|
logger.info(get_gpu_status())
|
||||||
|
|
||||||
|
# Create and launch app
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
# Configure queue for GPU processing
|
||||||
|
app.queue(
|
||||||
|
default_concurrency_limit=config.concurrency_limit,
|
||||||
|
max_size=config.max_queue_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Launch server
|
||||||
|
app.launch(
|
||||||
|
server_name=config.server_name,
|
||||||
|
server_port=config.server_port,
|
||||||
|
share=config.share,
|
||||||
|
show_error=True,
|
||||||
|
)
|
||||||
1
src/ui/components/__init__.py
Normal file
1
src/ui/components/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Gradio UI components for each tab."""
|
||||||
200
src/ui/components/batch_tab.py
Normal file
200
src/ui/components/batch_tab.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
"""Batch processing queue tab component."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from src.config import get_model_choices, config
|
||||||
|
from src.storage.queue import queue_manager, Job, JobType, JobStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_queue_data() -> list[list]:
|
||||||
|
"""Get queue data for dataframe display."""
|
||||||
|
jobs = queue_manager.get_queue()
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for job in jobs:
|
||||||
|
status_emoji = {
|
||||||
|
JobStatus.QUEUED: "⏳",
|
||||||
|
JobStatus.PROCESSING: "🔄",
|
||||||
|
JobStatus.COMPLETED: "✅",
|
||||||
|
JobStatus.FAILED: "❌",
|
||||||
|
JobStatus.CANCELLED: "🚫",
|
||||||
|
}.get(job.status, "❓")
|
||||||
|
|
||||||
|
progress = f"{job.progress_percent:.0f}%" if job.status == JobStatus.PROCESSING else "-"
|
||||||
|
|
||||||
|
data.append([
|
||||||
|
job.id[:8],
|
||||||
|
job.type.value.title(),
|
||||||
|
Path(job.input_path).name[:30],
|
||||||
|
f"{status_emoji} {job.status.value.title()}",
|
||||||
|
progress,
|
||||||
|
job.current_stage or "-",
|
||||||
|
])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_queue_stats() -> str:
|
||||||
|
"""Get queue statistics for display."""
|
||||||
|
stats = queue_manager.get_queue_stats()
|
||||||
|
|
||||||
|
queued = stats.get(JobStatus.QUEUED.value, 0)
|
||||||
|
processing = stats.get(JobStatus.PROCESSING.value, 0)
|
||||||
|
completed = stats.get(JobStatus.COMPLETED.value, 0)
|
||||||
|
failed = stats.get(JobStatus.FAILED.value, 0)
|
||||||
|
|
||||||
|
current = stats.get("current_job")
|
||||||
|
current_info = ""
|
||||||
|
if current:
|
||||||
|
current_info = f"\n\n**Currently Processing:**\n{Path(current.input_path).name}"
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"**Queue Status**\n\n"
|
||||||
|
f"- Queued: {queued}\n"
|
||||||
|
f"- Processing: {processing}\n"
|
||||||
|
f"- Completed: {completed}\n"
|
||||||
|
f"- Failed: {failed}"
|
||||||
|
f"{current_info}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_files_to_queue(
|
||||||
|
files: Optional[list],
|
||||||
|
model_name: str,
|
||||||
|
scale: int,
|
||||||
|
face_enhance: bool,
|
||||||
|
) -> tuple[list[list], str]:
|
||||||
|
"""Add uploaded files to queue."""
|
||||||
|
if not files:
|
||||||
|
return get_queue_data(), "No files selected"
|
||||||
|
|
||||||
|
added = 0
|
||||||
|
for file in files:
|
||||||
|
file_path = Path(file.name if hasattr(file, 'name') else file)
|
||||||
|
|
||||||
|
# Determine job type
|
||||||
|
ext = file_path.suffix.lower()
|
||||||
|
if ext in [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff"]:
|
||||||
|
job_type = JobType.IMAGE
|
||||||
|
elif ext in [".mp4", ".avi", ".mov", ".mkv", ".webm"]:
|
||||||
|
job_type = JobType.VIDEO
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
job = Job(
|
||||||
|
type=job_type,
|
||||||
|
input_path=str(file_path),
|
||||||
|
config={
|
||||||
|
"model_name": model_name,
|
||||||
|
"scale": scale,
|
||||||
|
"face_enhance": face_enhance,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
queue_manager.add_job(job)
|
||||||
|
added += 1
|
||||||
|
|
||||||
|
return get_queue_data(), f"Added {added} items to queue"
|
||||||
|
|
||||||
|
|
||||||
|
def cancel_job(job_id: str) -> tuple[list[list], str]:
|
||||||
|
"""Cancel a job."""
|
||||||
|
if job_id:
|
||||||
|
queue_manager.cancel_job(job_id)
|
||||||
|
return get_queue_data(), f"Cancelled job: {job_id}"
|
||||||
|
return get_queue_data(), "No job selected"
|
||||||
|
|
||||||
|
|
||||||
|
def clear_completed() -> tuple[list[list], str]:
|
||||||
|
"""Clear completed jobs from queue."""
|
||||||
|
count = queue_manager.clear_completed()
|
||||||
|
return get_queue_data(), f"Cleared {count} completed jobs"
|
||||||
|
|
||||||
|
|
||||||
|
def create_batch_tab():
|
||||||
|
"""Create the batch queue tab component."""
|
||||||
|
with gr.Tab("Batch Queue", id="queue-tab"):
|
||||||
|
gr.Markdown("## Batch Processing Queue")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Queue table
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
queue_table = gr.Dataframe(
|
||||||
|
headers=["ID", "Type", "Input", "Status", "Progress", "Stage"],
|
||||||
|
datatype=["str", "str", "str", "str", "str", "str"],
|
||||||
|
value=get_queue_data(),
|
||||||
|
label="Queue",
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Queue stats
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
queue_stats = gr.Markdown(
|
||||||
|
get_queue_stats(),
|
||||||
|
elem_classes=["info-card"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to queue section
|
||||||
|
with gr.Accordion("Add to Queue", open=True):
|
||||||
|
with gr.Row():
|
||||||
|
file_input = gr.File(
|
||||||
|
label="Select Files",
|
||||||
|
file_count="multiple",
|
||||||
|
file_types=["image", "video"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
model_dropdown = gr.Dropdown(
|
||||||
|
choices=get_model_choices(),
|
||||||
|
value=config.default_model,
|
||||||
|
label="Model",
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_radio = gr.Radio(
|
||||||
|
choices=[2, 4],
|
||||||
|
value=config.default_scale,
|
||||||
|
label="Scale",
|
||||||
|
)
|
||||||
|
|
||||||
|
face_enhance_cb = gr.Checkbox(
|
||||||
|
value=False,
|
||||||
|
label="Face Enhancement",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_btn = gr.Button("Add to Queue", variant="primary")
|
||||||
|
|
||||||
|
# Queue actions
|
||||||
|
with gr.Row():
|
||||||
|
refresh_btn = gr.Button("Refresh", variant="secondary")
|
||||||
|
clear_btn = gr.Button("Clear Completed", variant="secondary")
|
||||||
|
|
||||||
|
status_text = gr.Markdown("")
|
||||||
|
|
||||||
|
# Event handlers
|
||||||
|
add_btn.click(
|
||||||
|
fn=add_files_to_queue,
|
||||||
|
inputs=[file_input, model_dropdown, scale_radio, face_enhance_cb],
|
||||||
|
outputs=[queue_table, status_text],
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_btn.click(
|
||||||
|
fn=lambda: (get_queue_data(), get_queue_stats()),
|
||||||
|
outputs=[queue_table, queue_stats],
|
||||||
|
)
|
||||||
|
|
||||||
|
clear_btn.click(
|
||||||
|
fn=clear_completed,
|
||||||
|
outputs=[queue_table, status_text],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"queue_table": queue_table,
|
||||||
|
"queue_stats": queue_stats,
|
||||||
|
"status_text": status_text,
|
||||||
|
}
|
||||||
208
src/ui/components/history_tab.py
Normal file
208
src/ui/components/history_tab.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""History gallery tab component."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from src.storage.history import history_manager, HistoryItem
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_history_gallery() -> list[tuple[str, str]]:
|
||||||
|
"""Get history items for gallery display."""
|
||||||
|
items = history_manager.get_recent(limit=100)
|
||||||
|
|
||||||
|
gallery_items = []
|
||||||
|
for item in items:
|
||||||
|
# Use output path for display
|
||||||
|
output_path = Path(item.output_path)
|
||||||
|
if output_path.exists() and item.type == "image":
|
||||||
|
caption = f"{item.output_filename}\n{item.output_width}x{item.output_height}"
|
||||||
|
gallery_items.append((str(output_path), caption))
|
||||||
|
|
||||||
|
return gallery_items
|
||||||
|
|
||||||
|
|
||||||
|
def get_history_stats() -> str:
|
||||||
|
"""Get history statistics."""
|
||||||
|
stats = history_manager.get_statistics()
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"**History Statistics**\n\n"
|
||||||
|
f"- Total Items: {stats['total_items']}\n"
|
||||||
|
f"- Images: {stats['images']}\n"
|
||||||
|
f"- Videos: {stats['videos']}\n"
|
||||||
|
f"- Total Processing Time: {stats['total_processing_time']/60:.1f} min\n"
|
||||||
|
f"- Total Input Size: {stats['total_input_size_mb']:.1f} MB\n"
|
||||||
|
f"- Total Output Size: {stats['total_output_size_mb']:.1f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def search_history(query: str) -> list[tuple[str, str]]:
|
||||||
|
"""Search history by filename."""
|
||||||
|
if not query:
|
||||||
|
return get_history_gallery()
|
||||||
|
|
||||||
|
items = history_manager.search(query)
|
||||||
|
|
||||||
|
gallery_items = []
|
||||||
|
for item in items:
|
||||||
|
output_path = Path(item.output_path)
|
||||||
|
if output_path.exists() and item.type == "image":
|
||||||
|
caption = f"{item.output_filename}\n{item.output_width}x{item.output_height}"
|
||||||
|
gallery_items.append((str(output_path), caption))
|
||||||
|
|
||||||
|
return gallery_items
|
||||||
|
|
||||||
|
|
||||||
|
def get_item_details(evt: gr.SelectData) -> tuple[Optional[tuple], str]:
|
||||||
|
"""Get details for selected history item."""
|
||||||
|
if evt.index is None:
|
||||||
|
return None, "Select an item to view details"
|
||||||
|
|
||||||
|
items = history_manager.get_recent(limit=100)
|
||||||
|
|
||||||
|
# Filter to only images (same as gallery)
|
||||||
|
image_items = [i for i in items if i.type == "image" and Path(i.output_path).exists()]
|
||||||
|
|
||||||
|
if evt.index >= len(image_items):
|
||||||
|
return None, "Item not found"
|
||||||
|
|
||||||
|
item = image_items[evt.index]
|
||||||
|
|
||||||
|
# Load images for slider
|
||||||
|
input_path = Path(item.input_path)
|
||||||
|
output_path = Path(item.output_path)
|
||||||
|
|
||||||
|
slider_images = None
|
||||||
|
if input_path.exists() and output_path.exists():
|
||||||
|
slider_images = (str(input_path), str(output_path))
|
||||||
|
|
||||||
|
# Format details
|
||||||
|
details = (
|
||||||
|
f"**{item.output_filename}**\n\n"
|
||||||
|
f"- **Type:** {item.type.title()}\n"
|
||||||
|
f"- **Model:** {item.model}\n"
|
||||||
|
f"- **Scale:** {item.scale}x\n"
|
||||||
|
f"- **Face Enhancement:** {'Yes' if item.face_enhance else 'No'}\n"
|
||||||
|
f"- **Input:** {item.input_width}x{item.input_height}\n"
|
||||||
|
f"- **Output:** {item.output_width}x{item.output_height}\n"
|
||||||
|
f"- **Processing Time:** {item.processing_time_seconds:.1f}s\n"
|
||||||
|
f"- **Input Size:** {item.input_size_bytes/1024:.1f} KB\n"
|
||||||
|
f"- **Output Size:** {item.output_size_bytes/1024:.1f} KB\n"
|
||||||
|
f"- **Created:** {item.created_at.strftime('%Y-%m-%d %H:%M')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return slider_images, details
|
||||||
|
|
||||||
|
|
||||||
|
def delete_selected(evt: gr.SelectData) -> tuple[list, str]:
|
||||||
|
"""Delete selected history item."""
|
||||||
|
if evt.index is None:
|
||||||
|
return get_history_gallery(), "No item selected"
|
||||||
|
|
||||||
|
items = history_manager.get_recent(limit=100)
|
||||||
|
image_items = [i for i in items if i.type == "image" and Path(i.output_path).exists()]
|
||||||
|
|
||||||
|
if evt.index >= len(image_items):
|
||||||
|
return get_history_gallery(), "Item not found"
|
||||||
|
|
||||||
|
item = image_items[evt.index]
|
||||||
|
history_manager.delete_item(item.id)
|
||||||
|
|
||||||
|
return get_history_gallery(), f"Deleted: {item.output_filename}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_history_tab():
|
||||||
|
"""Create the history gallery tab component."""
|
||||||
|
with gr.Tab("History", id="history-tab"):
|
||||||
|
gr.Markdown("## Processing History")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Search
|
||||||
|
search_box = gr.Textbox(
|
||||||
|
label="Search",
|
||||||
|
placeholder="Search by filename...",
|
||||||
|
scale=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
filter_dropdown = gr.Dropdown(
|
||||||
|
choices=["All", "Images", "Videos"],
|
||||||
|
value="All",
|
||||||
|
label="Filter",
|
||||||
|
scale=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Gallery
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
gallery = gr.Gallery(
|
||||||
|
value=get_history_gallery(),
|
||||||
|
label="History",
|
||||||
|
columns=4,
|
||||||
|
object_fit="cover",
|
||||||
|
height=400,
|
||||||
|
allow_preview=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Details panel
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
stats_display = gr.Markdown(
|
||||||
|
get_history_stats(),
|
||||||
|
elem_classes=["info-card"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Selected item details
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
try:
|
||||||
|
from gradio_imageslider import ImageSlider
|
||||||
|
|
||||||
|
comparison_slider = ImageSlider(
|
||||||
|
label="Before / After",
|
||||||
|
type="filepath",
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
comparison_slider = gr.Image(
|
||||||
|
label="Selected Image",
|
||||||
|
type="filepath",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
item_details = gr.Markdown(
|
||||||
|
"Select an item to view details",
|
||||||
|
elem_classes=["info-card"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
download_btn = gr.Button("Download", variant="secondary")
|
||||||
|
delete_btn = gr.Button("Delete", variant="stop")
|
||||||
|
|
||||||
|
status_text = gr.Markdown("")
|
||||||
|
|
||||||
|
# Event handlers
|
||||||
|
search_box.change(
|
||||||
|
fn=search_history,
|
||||||
|
inputs=[search_box],
|
||||||
|
outputs=[gallery],
|
||||||
|
)
|
||||||
|
|
||||||
|
gallery.select(
|
||||||
|
fn=get_item_details,
|
||||||
|
outputs=[comparison_slider, item_details],
|
||||||
|
)
|
||||||
|
|
||||||
|
delete_btn.click(
|
||||||
|
fn=lambda: (get_history_gallery(), get_history_stats()),
|
||||||
|
outputs=[gallery, stats_display],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"gallery": gallery,
|
||||||
|
"comparison_slider": comparison_slider,
|
||||||
|
"item_details": item_details,
|
||||||
|
"stats_display": stats_display,
|
||||||
|
}
|
||||||
292
src/ui/components/image_tab.py
Normal file
292
src/ui/components/image_tab.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""Image upscaling tab component."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import gradio as gr
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from src.config import OUTPUT_DIR, config, get_model_choices
|
||||||
|
from src.processing.upscaler import upscale_image, save_image, get_output_dimensions
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def process_image(
|
||||||
|
input_image: Optional[np.ndarray],
|
||||||
|
model_name: str,
|
||||||
|
scale: int,
|
||||||
|
face_enhance: bool,
|
||||||
|
tile_size: str,
|
||||||
|
denoise: float,
|
||||||
|
output_format: str,
|
||||||
|
progress=gr.Progress(track_tqdm=True),
|
||||||
|
) -> tuple[Optional[tuple], Optional[str], str]:
|
||||||
|
"""
|
||||||
|
Process image for upscaling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_image: Input image as numpy array
|
||||||
|
model_name: Model to use
|
||||||
|
scale: Scale factor (2 or 4)
|
||||||
|
face_enhance: Enable face enhancement
|
||||||
|
tile_size: Tile size setting ("Auto", "256", "512", "1024")
|
||||||
|
denoise: Denoise strength (0-1)
|
||||||
|
output_format: Output format (PNG, JPG, WebP)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (slider_images, output_path, status_message)
|
||||||
|
"""
|
||||||
|
if input_image is None:
|
||||||
|
return None, None, "Please upload an image first."
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Parse tile size
|
||||||
|
tile = 0 if tile_size == "Auto" else int(tile_size)
|
||||||
|
|
||||||
|
# Progress callback for UI
|
||||||
|
def progress_callback(pct: float, stage: str):
|
||||||
|
progress(pct, desc=stage)
|
||||||
|
|
||||||
|
progress(0, desc="Starting upscale...")
|
||||||
|
|
||||||
|
# Upscale the image
|
||||||
|
result = upscale_image(
|
||||||
|
input_image,
|
||||||
|
model_name=model_name,
|
||||||
|
scale=scale,
|
||||||
|
tile_size=tile,
|
||||||
|
denoise_strength=denoise,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply face enhancement if requested
|
||||||
|
if face_enhance:
|
||||||
|
progress(0.7, desc="Enhancing faces...")
|
||||||
|
try:
|
||||||
|
from src.processing.face_enhancer import enhance_faces
|
||||||
|
|
||||||
|
result.image = enhance_faces(result.image)
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("Face enhancement not available")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Face enhancement failed: {e}")
|
||||||
|
|
||||||
|
progress(0.85, desc="Saving output...")
|
||||||
|
|
||||||
|
# Generate output filename
|
||||||
|
output_name = f"upscaled_{uuid.uuid4().hex[:8]}"
|
||||||
|
output_path = save_image(
|
||||||
|
result.image,
|
||||||
|
OUTPUT_DIR / output_name,
|
||||||
|
format=output_format.lower(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert for display (BGR to RGB)
|
||||||
|
input_rgb = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) if input_image.shape[2] == 3 else input_image
|
||||||
|
output_rgb = cv2.cvtColor(result.image, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
progress(1.0, desc="Complete!")
|
||||||
|
|
||||||
|
# Calculate stats
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
input_size = f"{result.input_width}x{result.input_height}"
|
||||||
|
output_size = f"{result.output_width}x{result.output_height}"
|
||||||
|
|
||||||
|
status = (
|
||||||
|
f"Upscaled {input_size} -> {output_size} in {elapsed:.1f}s | "
|
||||||
|
f"Model: {model_name} | Saved: {output_path.name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return images for slider (input, output)
|
||||||
|
return (input_rgb, output_rgb), str(output_path), status
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Image processing failed")
|
||||||
|
return None, None, f"Error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_output(
|
||||||
|
input_image: Optional[np.ndarray],
|
||||||
|
model_name: str,
|
||||||
|
scale: int,
|
||||||
|
) -> str:
|
||||||
|
"""Estimate output dimensions for display."""
|
||||||
|
if input_image is None:
|
||||||
|
return "Upload an image to see estimated output size"
|
||||||
|
|
||||||
|
try:
|
||||||
|
height, width = input_image.shape[:2]
|
||||||
|
out_w, out_h = get_output_dimensions(width, height, model_name, scale)
|
||||||
|
return f"Input: {width}x{height} -> Output: {out_w}x{out_h}"
|
||||||
|
except Exception:
|
||||||
|
return "Unable to estimate output size"
|
||||||
|
|
||||||
|
|
||||||
|
def create_image_tab():
|
||||||
|
"""Create the image upscaling tab component."""
|
||||||
|
with gr.Tab("Image", id="image-tab"):
|
||||||
|
# Status display at top
|
||||||
|
status_text = gr.Markdown(
|
||||||
|
"Upload an image to begin upscaling",
|
||||||
|
elem_classes=["status-text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Left column - Input
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
input_image = gr.Image(
|
||||||
|
label="Input Image",
|
||||||
|
type="numpy",
|
||||||
|
sources=["upload", "clipboard"],
|
||||||
|
elem_classes=["upload-area"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dimension estimate
|
||||||
|
dimension_info = gr.Markdown(
|
||||||
|
"Upload an image to see estimated output size",
|
||||||
|
elem_classes=["info-text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Right column - Output (ImageSlider)
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
try:
|
||||||
|
from gradio_imageslider import ImageSlider
|
||||||
|
|
||||||
|
output_slider = ImageSlider(
|
||||||
|
label="Before / After",
|
||||||
|
type="numpy",
|
||||||
|
show_download_button=True,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback if ImageSlider not available
|
||||||
|
output_slider = gr.Image(
|
||||||
|
label="Upscaled Result",
|
||||||
|
type="numpy",
|
||||||
|
show_download_button=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_file = gr.File(
|
||||||
|
label="Download",
|
||||||
|
visible=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Processing options
|
||||||
|
with gr.Accordion("Processing Options", open=True):
|
||||||
|
with gr.Row():
|
||||||
|
model_dropdown = gr.Dropdown(
|
||||||
|
choices=get_model_choices(),
|
||||||
|
value=config.default_model,
|
||||||
|
label="Model",
|
||||||
|
info="Select upscaling model",
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_radio = gr.Radio(
|
||||||
|
choices=[2, 4],
|
||||||
|
value=config.default_scale,
|
||||||
|
label="Scale Factor",
|
||||||
|
info="Output scale multiplier",
|
||||||
|
)
|
||||||
|
|
||||||
|
face_enhance_cb = gr.Checkbox(
|
||||||
|
value=config.default_face_enhance,
|
||||||
|
label="Face Enhancement (GFPGAN)",
|
||||||
|
info="Enhance faces in photos (not for anime)",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
tile_dropdown = gr.Dropdown(
|
||||||
|
choices=["Auto", "256", "512", "1024"],
|
||||||
|
value="Auto",
|
||||||
|
label="Tile Size",
|
||||||
|
info="Auto recommended for RTX 4090",
|
||||||
|
)
|
||||||
|
|
||||||
|
denoise_slider = gr.Slider(
|
||||||
|
minimum=0,
|
||||||
|
maximum=1,
|
||||||
|
value=0.5,
|
||||||
|
step=0.1,
|
||||||
|
label="Denoise Strength",
|
||||||
|
info="Only for realesr-general-x4v3 model",
|
||||||
|
)
|
||||||
|
|
||||||
|
format_radio = gr.Radio(
|
||||||
|
choices=["PNG", "JPG", "WebP"],
|
||||||
|
value="PNG",
|
||||||
|
label="Output Format",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Action buttons
|
||||||
|
with gr.Row():
|
||||||
|
upscale_btn = gr.Button(
|
||||||
|
"Upscale Image",
|
||||||
|
variant="primary",
|
||||||
|
size="lg",
|
||||||
|
elem_classes=["primary-action-btn"],
|
||||||
|
)
|
||||||
|
|
||||||
|
clear_btn = gr.Button(
|
||||||
|
"Clear",
|
||||||
|
variant="secondary",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Event handlers
|
||||||
|
upscale_btn.click(
|
||||||
|
fn=process_image,
|
||||||
|
inputs=[
|
||||||
|
input_image,
|
||||||
|
model_dropdown,
|
||||||
|
scale_radio,
|
||||||
|
face_enhance_cb,
|
||||||
|
tile_dropdown,
|
||||||
|
denoise_slider,
|
||||||
|
format_radio,
|
||||||
|
],
|
||||||
|
outputs=[output_slider, output_file, status_text],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update dimension estimate when image changes
|
||||||
|
input_image.change(
|
||||||
|
fn=estimate_output,
|
||||||
|
inputs=[input_image, model_dropdown, scale_radio],
|
||||||
|
outputs=[dimension_info],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also update when model/scale changes
|
||||||
|
model_dropdown.change(
|
||||||
|
fn=estimate_output,
|
||||||
|
inputs=[input_image, model_dropdown, scale_radio],
|
||||||
|
outputs=[dimension_info],
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_radio.change(
|
||||||
|
fn=estimate_output,
|
||||||
|
inputs=[input_image, model_dropdown, scale_radio],
|
||||||
|
outputs=[dimension_info],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear button
|
||||||
|
clear_btn.click(
|
||||||
|
fn=lambda: (None, None, None, "Upload an image to begin upscaling"),
|
||||||
|
outputs=[input_image, output_slider, output_file, status_text],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_image": input_image,
|
||||||
|
"output_slider": output_slider,
|
||||||
|
"output_file": output_file,
|
||||||
|
"status_text": status_text,
|
||||||
|
"model_dropdown": model_dropdown,
|
||||||
|
"scale_radio": scale_radio,
|
||||||
|
"face_enhance_cb": face_enhance_cb,
|
||||||
|
"upscale_btn": upscale_btn,
|
||||||
|
}
|
||||||
376
src/ui/components/video_tab.py
Normal file
376
src/ui/components/video_tab.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
"""Video upscaling tab component."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import gradio as gr
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.config import OUTPUT_DIR, TEMP_DIR, VIDEO_CODECS, config, get_model_choices
|
||||||
|
from src.processing.upscaler import upscale_image
|
||||||
|
from src.video.extractor import VideoExtractor, get_video_info
|
||||||
|
from src.video.encoder import encode_video
|
||||||
|
from src.video.audio import extract_audio, has_audio_stream
|
||||||
|
from src.video.checkpoint import checkpoint_manager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_metadata(video_path: Optional[str]) -> str:
|
||||||
|
"""Get video metadata for display."""
|
||||||
|
if not video_path:
|
||||||
|
return "Upload a video to see details"
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata = get_video_info(Path(video_path))
|
||||||
|
return (
|
||||||
|
f"**Resolution:** {metadata.width}x{metadata.height}\n"
|
||||||
|
f"**Duration:** {metadata.duration_seconds:.1f}s\n"
|
||||||
|
f"**FPS:** {metadata.fps:.2f}\n"
|
||||||
|
f"**Frames:** {metadata.total_frames}\n"
|
||||||
|
f"**Codec:** {metadata.codec}\n"
|
||||||
|
f"**Audio:** {'Yes' if metadata.has_audio else 'No'}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error reading video: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_video_output(
|
||||||
|
video_path: Optional[str],
|
||||||
|
model_name: str,
|
||||||
|
scale: int,
|
||||||
|
) -> str:
|
||||||
|
"""Estimate output details for video."""
|
||||||
|
if not video_path:
|
||||||
|
return "Upload a video to see estimated output"
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata = get_video_info(Path(video_path))
|
||||||
|
out_w = metadata.width * scale
|
||||||
|
out_h = metadata.height * scale
|
||||||
|
|
||||||
|
# Estimate processing time (rough: ~2 fps for 4x on RTX 4090)
|
||||||
|
est_time = metadata.total_frames / 2 if scale == 4 else metadata.total_frames / 4
|
||||||
|
est_minutes = est_time / 60
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"**Output Resolution:** {out_w}x{out_h}\n"
|
||||||
|
f"**Estimated Time:** ~{est_minutes:.1f} minutes"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return "Unable to estimate"
|
||||||
|
|
||||||
|
|
||||||
|
def process_video(
|
||||||
|
video_path: Optional[str],
|
||||||
|
model_name: str,
|
||||||
|
scale: int,
|
||||||
|
face_enhance: bool,
|
||||||
|
codec: str,
|
||||||
|
crf: int,
|
||||||
|
preset: str,
|
||||||
|
preserve_audio: bool,
|
||||||
|
progress=gr.Progress(track_tqdm=True),
|
||||||
|
) -> tuple[Optional[str], str]:
|
||||||
|
"""
|
||||||
|
Process video for upscaling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Input video path
|
||||||
|
model_name: Model to use
|
||||||
|
scale: Scale factor
|
||||||
|
face_enhance: Enable face enhancement
|
||||||
|
codec: Output codec
|
||||||
|
crf: Quality (lower = better)
|
||||||
|
preset: Encoding preset
|
||||||
|
preserve_audio: Keep original audio
|
||||||
|
progress: Gradio progress tracker
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (output_path, status_message)
|
||||||
|
"""
|
||||||
|
if not video_path:
|
||||||
|
return None, "Please upload a video first."
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
job_id = uuid.uuid4().hex[:8]
|
||||||
|
|
||||||
|
progress(0, desc="Analyzing video...")
|
||||||
|
|
||||||
|
# Get video info
|
||||||
|
video_path = Path(video_path)
|
||||||
|
metadata = get_video_info(video_path)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Processing video: {video_path.name}, "
|
||||||
|
f"{metadata.width}x{metadata.height}, "
|
||||||
|
f"{metadata.total_frames} frames"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create temp directories
|
||||||
|
temp_dir = TEMP_DIR / f"video_{job_id}"
|
||||||
|
frames_dir = temp_dir / "frames"
|
||||||
|
frames_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Extract audio if needed
|
||||||
|
audio_path = None
|
||||||
|
if preserve_audio and has_audio_stream(video_path):
|
||||||
|
progress(0.02, desc="Extracting audio...")
|
||||||
|
audio_path = extract_audio(video_path, temp_dir / "audio")
|
||||||
|
|
||||||
|
# Create checkpoint
|
||||||
|
checkpoint = checkpoint_manager.create_checkpoint(
|
||||||
|
job_id=job_id,
|
||||||
|
input_path=video_path,
|
||||||
|
total_frames=metadata.total_frames,
|
||||||
|
frames_dir=frames_dir,
|
||||||
|
audio_path=audio_path,
|
||||||
|
config={
|
||||||
|
"model_name": model_name,
|
||||||
|
"scale": scale,
|
||||||
|
"face_enhance": face_enhance,
|
||||||
|
"codec": codec,
|
||||||
|
"crf": crf,
|
||||||
|
"preset": preset,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process frames
|
||||||
|
progress(0.05, desc="Processing frames...")
|
||||||
|
|
||||||
|
with VideoExtractor(video_path) as extractor:
|
||||||
|
total = metadata.total_frames
|
||||||
|
start_frame = checkpoint.processed_frames
|
||||||
|
|
||||||
|
for frame_idx, frame in extractor.extract_frames(start_frame=start_frame):
|
||||||
|
# Upscale frame (RGB input)
|
||||||
|
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
result = upscale_image(
|
||||||
|
frame_bgr,
|
||||||
|
model_name=model_name,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply face enhancement if enabled
|
||||||
|
if face_enhance:
|
||||||
|
try:
|
||||||
|
from src.processing.face_enhancer import enhance_faces
|
||||||
|
result.image = enhance_faces(result.image)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Face enhancement failed on frame {frame_idx}: {e}")
|
||||||
|
|
||||||
|
# Save frame
|
||||||
|
frame_path = frames_dir / f"frame_{frame_idx:08d}.png"
|
||||||
|
cv2.imwrite(str(frame_path), result.image)
|
||||||
|
|
||||||
|
# Update checkpoint
|
||||||
|
checkpoint_manager.update_checkpoint(checkpoint, frame_idx)
|
||||||
|
|
||||||
|
# Update progress
|
||||||
|
pct = 0.05 + (0.85 * (frame_idx + 1) / total)
|
||||||
|
progress(
|
||||||
|
pct,
|
||||||
|
desc=f"Processing frame {frame_idx + 1}/{total}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encode output video
|
||||||
|
progress(0.9, desc=f"Encoding video ({codec})...")
|
||||||
|
|
||||||
|
output_name = f"{video_path.stem}_upscaled_{job_id}"
|
||||||
|
output_path = OUTPUT_DIR / f"{output_name}.mp4"
|
||||||
|
|
||||||
|
encode_video(
|
||||||
|
frames_dir=frames_dir,
|
||||||
|
output_path=output_path,
|
||||||
|
fps=metadata.fps,
|
||||||
|
codec=codec,
|
||||||
|
crf=crf,
|
||||||
|
preset=preset,
|
||||||
|
audio_path=audio_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
progress(0.98, desc="Cleaning up...")
|
||||||
|
checkpoint_manager.delete_checkpoint(job_id)
|
||||||
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
progress(1.0, desc="Complete!")
|
||||||
|
|
||||||
|
# Calculate stats
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
output_size = output_path.stat().st_size / (1024 * 1024) # MB
|
||||||
|
fps_actual = metadata.total_frames / elapsed
|
||||||
|
|
||||||
|
status = (
|
||||||
|
f"Completed in {elapsed/60:.1f} minutes ({fps_actual:.2f} fps)\n"
|
||||||
|
f"Output: {output_path.name} ({output_size:.1f} MB)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return str(output_path), status
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Video processing failed")
|
||||||
|
return None, f"Error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_video_tab():
|
||||||
|
"""Create the video upscaling tab component."""
|
||||||
|
with gr.Tab("Video", id="video-tab"):
|
||||||
|
# Status display
|
||||||
|
status_text = gr.Markdown(
|
||||||
|
"Upload a video to begin processing",
|
||||||
|
elem_classes=["status-text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Left column - Input
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
input_video = gr.Video(
|
||||||
|
label="Input Video",
|
||||||
|
sources=["upload"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Video info display
|
||||||
|
video_info = gr.Markdown(
|
||||||
|
"Upload a video to see details",
|
||||||
|
elem_classes=["info-text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Right column - Output
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
output_video = gr.Video(
|
||||||
|
label="Upscaled Video",
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output estimate
|
||||||
|
output_estimate = gr.Markdown(
|
||||||
|
"Upload a video to see estimated output",
|
||||||
|
elem_classes=["info-text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Processing options
|
||||||
|
with gr.Accordion("Processing Options", open=True):
|
||||||
|
with gr.Row():
|
||||||
|
model_dropdown = gr.Dropdown(
|
||||||
|
choices=get_model_choices(),
|
||||||
|
value=config.default_model,
|
||||||
|
label="Model",
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_radio = gr.Radio(
|
||||||
|
choices=[2, 4],
|
||||||
|
value=config.default_scale,
|
||||||
|
label="Scale Factor",
|
||||||
|
)
|
||||||
|
|
||||||
|
face_enhance_cb = gr.Checkbox(
|
||||||
|
value=False,
|
||||||
|
label="Face Enhancement",
|
||||||
|
info="Not recommended for anime",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Codec options
|
||||||
|
with gr.Accordion("Output Codec Settings", open=True):
|
||||||
|
with gr.Row():
|
||||||
|
codec_radio = gr.Radio(
|
||||||
|
choices=list(VIDEO_CODECS.keys()),
|
||||||
|
value=config.default_video_codec,
|
||||||
|
label="Codec",
|
||||||
|
info="H.265 recommended for quality/size balance",
|
||||||
|
)
|
||||||
|
|
||||||
|
crf_slider = gr.Slider(
|
||||||
|
minimum=15,
|
||||||
|
maximum=35,
|
||||||
|
value=config.default_video_crf,
|
||||||
|
step=1,
|
||||||
|
label="Quality (CRF)",
|
||||||
|
info="Lower = better quality, larger file",
|
||||||
|
)
|
||||||
|
|
||||||
|
preset_dropdown = gr.Dropdown(
|
||||||
|
choices=["slow", "medium", "fast"],
|
||||||
|
value=config.default_video_preset,
|
||||||
|
label="Preset",
|
||||||
|
info="Slow = best quality",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
preserve_audio_cb = gr.Checkbox(
|
||||||
|
value=True,
|
||||||
|
label="Preserve Audio",
|
||||||
|
info="Keep original audio track",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Action buttons
|
||||||
|
with gr.Row():
|
||||||
|
process_btn = gr.Button(
|
||||||
|
"Start Processing",
|
||||||
|
variant="primary",
|
||||||
|
size="lg",
|
||||||
|
elem_classes=["primary-action-btn"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cancel_btn = gr.Button(
|
||||||
|
"Cancel",
|
||||||
|
variant="stop",
|
||||||
|
visible=False, # TODO: implement cancellation
|
||||||
|
)
|
||||||
|
|
||||||
|
# Event handlers
|
||||||
|
process_btn.click(
|
||||||
|
fn=process_video,
|
||||||
|
inputs=[
|
||||||
|
input_video,
|
||||||
|
model_dropdown,
|
||||||
|
scale_radio,
|
||||||
|
face_enhance_cb,
|
||||||
|
codec_radio,
|
||||||
|
crf_slider,
|
||||||
|
preset_dropdown,
|
||||||
|
preserve_audio_cb,
|
||||||
|
],
|
||||||
|
outputs=[output_video, status_text],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update video info when video changes
|
||||||
|
input_video.change(
|
||||||
|
fn=get_video_metadata,
|
||||||
|
inputs=[input_video],
|
||||||
|
outputs=[video_info],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update output estimate
|
||||||
|
input_video.change(
|
||||||
|
fn=estimate_video_output,
|
||||||
|
inputs=[input_video, model_dropdown, scale_radio],
|
||||||
|
outputs=[output_estimate],
|
||||||
|
)
|
||||||
|
|
||||||
|
model_dropdown.change(
|
||||||
|
fn=estimate_video_output,
|
||||||
|
inputs=[input_video, model_dropdown, scale_radio],
|
||||||
|
outputs=[output_estimate],
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_radio.change(
|
||||||
|
fn=estimate_video_output,
|
||||||
|
inputs=[input_video, model_dropdown, scale_radio],
|
||||||
|
outputs=[output_estimate],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_video": input_video,
|
||||||
|
"output_video": output_video,
|
||||||
|
"status_text": status_text,
|
||||||
|
"process_btn": process_btn,
|
||||||
|
}
|
||||||
1
src/ui/handlers/__init__.py
Normal file
1
src/ui/handlers/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Event handlers for UI interactions."""
|
||||||
284
src/ui/theme.py
Normal file
284
src/ui/theme.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
"""Custom dark theme for Real-ESRGAN Web UI."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from gradio.themes import Base, colors, sizes
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerTheme(Base):
|
||||||
|
"""
|
||||||
|
Professional dark theme optimized for image/video upscaling UI.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Dark background for better image comparison
|
||||||
|
- Blue accent colors for actions
|
||||||
|
- Comfortable spacing for power users
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
# Core colors
|
||||||
|
primary_hue=colors.blue,
|
||||||
|
secondary_hue=colors.slate,
|
||||||
|
neutral_hue=colors.slate,
|
||||||
|
# Sizing
|
||||||
|
spacing_size=sizes.spacing_md,
|
||||||
|
radius_size=sizes.radius_md,
|
||||||
|
text_size=sizes.text_md,
|
||||||
|
# Fonts
|
||||||
|
font=(
|
||||||
|
"Inter",
|
||||||
|
"ui-sans-serif",
|
||||||
|
"-apple-system",
|
||||||
|
"BlinkMacSystemFont",
|
||||||
|
"Segoe UI",
|
||||||
|
"Roboto",
|
||||||
|
"sans-serif",
|
||||||
|
),
|
||||||
|
font_mono=(
|
||||||
|
"JetBrains Mono",
|
||||||
|
"ui-monospace",
|
||||||
|
"SFMono-Regular",
|
||||||
|
"Menlo",
|
||||||
|
"monospace",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dark mode overrides
|
||||||
|
super().set(
|
||||||
|
# Backgrounds - dark with subtle contrast
|
||||||
|
body_background_fill="*neutral_950",
|
||||||
|
body_background_fill_dark="*neutral_950",
|
||||||
|
block_background_fill="*neutral_900",
|
||||||
|
block_background_fill_dark="*neutral_900",
|
||||||
|
# Borders
|
||||||
|
block_border_width="1px",
|
||||||
|
block_border_color="*neutral_800",
|
||||||
|
block_border_color_dark="*neutral_700",
|
||||||
|
border_color_primary="*primary_600",
|
||||||
|
# Inputs
|
||||||
|
input_background_fill="*neutral_800",
|
||||||
|
input_background_fill_dark="*neutral_800",
|
||||||
|
input_border_color="*neutral_700",
|
||||||
|
input_border_color_focus="*primary_500",
|
||||||
|
input_border_color_focus_dark="*primary_400",
|
||||||
|
# Buttons - primary (blue)
|
||||||
|
button_primary_background_fill="*primary_600",
|
||||||
|
button_primary_background_fill_hover="*primary_500",
|
||||||
|
button_primary_text_color="white",
|
||||||
|
button_primary_border_color="*primary_600",
|
||||||
|
# Buttons - secondary
|
||||||
|
button_secondary_background_fill="*neutral_700",
|
||||||
|
button_secondary_background_fill_hover="*neutral_600",
|
||||||
|
button_secondary_text_color="*neutral_100",
|
||||||
|
button_secondary_border_color="*neutral_600",
|
||||||
|
# Buttons - stop (for cancel)
|
||||||
|
button_cancel_background_fill="*error_600",
|
||||||
|
button_cancel_background_fill_hover="*error_500",
|
||||||
|
button_cancel_text_color="white",
|
||||||
|
# Shadows
|
||||||
|
block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.3), 0 2px 4px -2px rgba(0, 0, 0, 0.2)",
|
||||||
|
# Text
|
||||||
|
body_text_color="*neutral_200",
|
||||||
|
body_text_color_subdued="*neutral_400",
|
||||||
|
body_text_color_dark="*neutral_200",
|
||||||
|
# Labels
|
||||||
|
block_title_text_color="*neutral_100",
|
||||||
|
block_label_text_color="*neutral_300",
|
||||||
|
# Sliders
|
||||||
|
slider_color="*primary_500",
|
||||||
|
slider_color_dark="*primary_400",
|
||||||
|
# Checkboxes
|
||||||
|
checkbox_background_color_selected="*primary_600",
|
||||||
|
checkbox_background_color_selected_dark="*primary_500",
|
||||||
|
checkbox_border_color_selected="*primary_600",
|
||||||
|
# Tables/Dataframes
|
||||||
|
table_border_color="*neutral_700",
|
||||||
|
table_even_background_fill="*neutral_850",
|
||||||
|
table_odd_background_fill="*neutral_900",
|
||||||
|
# Tabs
|
||||||
|
tab_selected_text_color="*primary_400",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Custom CSS for additional styling
|
||||||
|
CUSTOM_CSS = """
|
||||||
|
/* Force dark color scheme */
|
||||||
|
:root {
|
||||||
|
color-scheme: dark;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Header styling */
|
||||||
|
.app-header {
|
||||||
|
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
|
||||||
|
padding: 1rem 1.5rem;
|
||||||
|
border-bottom: 1px solid #334155;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.app-header h1 {
|
||||||
|
color: #f1f5f9;
|
||||||
|
font-size: 1.5rem;
|
||||||
|
font-weight: 600;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Tab navigation styling */
|
||||||
|
.tabs .tab-nav {
|
||||||
|
background: #1e293b;
|
||||||
|
border-radius: 8px 8px 0 0;
|
||||||
|
padding: 0.5rem 0.5rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tabs .tab-nav button {
|
||||||
|
padding: 0.75rem 1.5rem;
|
||||||
|
font-weight: 500;
|
||||||
|
border-radius: 8px 8px 0 0;
|
||||||
|
transition: all 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tabs .tab-nav button.selected {
|
||||||
|
background: #334155;
|
||||||
|
border-bottom: 2px solid #3b82f6;
|
||||||
|
color: #60a5fa;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Image comparison slider */
|
||||||
|
.image-slider-container {
|
||||||
|
border-radius: 12px;
|
||||||
|
overflow: hidden;
|
||||||
|
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Progress bar styling */
|
||||||
|
.progress-bar {
|
||||||
|
height: 8px;
|
||||||
|
border-radius: 4px;
|
||||||
|
background: #1e293b;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-bar > div {
|
||||||
|
background: linear-gradient(90deg, #2563eb 0%, #3b82f6 50%, #60a5fa 100%);
|
||||||
|
transition: width 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Gallery items */
|
||||||
|
.gallery-item {
|
||||||
|
border-radius: 8px;
|
||||||
|
overflow: hidden;
|
||||||
|
transition: transform 0.2s ease, box-shadow 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gallery-item:hover {
|
||||||
|
transform: scale(1.02);
|
||||||
|
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Accordion styling */
|
||||||
|
.accordion {
|
||||||
|
border: 1px solid #334155;
|
||||||
|
border-radius: 8px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.accordion > .label-wrap {
|
||||||
|
background: #1e293b;
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Status bar */
|
||||||
|
.status-bar {
|
||||||
|
position: fixed;
|
||||||
|
bottom: 0;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
background: #0f172a;
|
||||||
|
border-top: 1px solid #334155;
|
||||||
|
padding: 0.5rem 1.5rem;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
color: #94a3b8;
|
||||||
|
z-index: 100;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-bar .status-item {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-bar .status-dot {
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background: #22c55e;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-bar .status-dot.busy {
|
||||||
|
background: #eab308;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Large primary buttons */
|
||||||
|
.primary-action-btn {
|
||||||
|
font-size: 1.125rem !important;
|
||||||
|
padding: 0.875rem 2rem !important;
|
||||||
|
font-weight: 600 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Upload area */
|
||||||
|
.upload-area {
|
||||||
|
border: 2px dashed #334155;
|
||||||
|
border-radius: 12px;
|
||||||
|
transition: border-color 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.upload-area:hover {
|
||||||
|
border-color: #3b82f6;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Processing info cards */
|
||||||
|
.info-card {
|
||||||
|
background: #1e293b;
|
||||||
|
border: 1px solid #334155;
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.info-card h4 {
|
||||||
|
color: #94a3b8;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.05em;
|
||||||
|
margin-bottom: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.info-card .value {
|
||||||
|
color: #f1f5f9;
|
||||||
|
font-size: 1.25rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Responsive adjustments */
|
||||||
|
@media (max-width: 1200px) {
|
||||||
|
.main-row {
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Add padding at bottom for status bar */
|
||||||
|
.gradio-container {
|
||||||
|
padding-bottom: 60px !important;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_theme():
|
||||||
|
"""Get the custom theme instance."""
|
||||||
|
return UpscalerTheme()
|
||||||
|
|
||||||
|
|
||||||
|
def get_css():
|
||||||
|
"""Get custom CSS for additional styling."""
|
||||||
|
return CUSTOM_CSS
|
||||||
1
src/utils/__init__.py
Normal file
1
src/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Utility modules (memory management, validators)."""
|
||||||
1
src/video/__init__.py
Normal file
1
src/video/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Video processing utilities (extraction, encoding, audio)."""
|
||||||
221
src/video/audio.py
Normal file
221
src/video/audio.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
"""Audio extraction and handling for video processing."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioError(Exception):
|
||||||
|
"""Error during audio processing."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def check_ffmpeg():
|
||||||
|
"""Check if FFmpeg is available."""
|
||||||
|
if shutil.which("ffmpeg") is None:
|
||||||
|
raise AudioError(
|
||||||
|
"FFmpeg not found. Please install FFmpeg: sudo apt install ffmpeg"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def has_audio_stream(video_path: Path) -> bool:
|
||||||
|
"""
|
||||||
|
Check if video has an audio stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Path to video file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if video has audio
|
||||||
|
"""
|
||||||
|
check_ffmpeg()
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
"ffprobe",
|
||||||
|
"-v", "error",
|
||||||
|
"-select_streams", "a",
|
||||||
|
"-show_entries", "stream=codec_name",
|
||||||
|
"-of", "csv=p=0",
|
||||||
|
str(video_path),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
return bool(result.stdout.strip())
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def extract_audio(
|
||||||
|
video_path: Path,
|
||||||
|
output_path: Path,
|
||||||
|
copy_codec: bool = True,
|
||||||
|
) -> Optional[Path]:
|
||||||
|
"""
|
||||||
|
Extract audio stream from video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Source video path
|
||||||
|
output_path: Output audio path
|
||||||
|
copy_codec: If True, copy without re-encoding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to extracted audio, or None if no audio
|
||||||
|
"""
|
||||||
|
check_ffmpeg()
|
||||||
|
|
||||||
|
if not has_audio_stream(video_path):
|
||||||
|
logger.info(f"No audio stream in: {video_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if copy_codec:
|
||||||
|
# Determine output extension based on source codec
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
"ffprobe",
|
||||||
|
"-v", "error",
|
||||||
|
"-select_streams", "a:0",
|
||||||
|
"-show_entries", "stream=codec_name",
|
||||||
|
"-of", "csv=p=0",
|
||||||
|
str(video_path),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
codec = result.stdout.strip().lower()
|
||||||
|
|
||||||
|
# Map codec to extension
|
||||||
|
ext_map = {
|
||||||
|
"aac": ".aac",
|
||||||
|
"mp3": ".mp3",
|
||||||
|
"opus": ".opus",
|
||||||
|
"vorbis": ".ogg",
|
||||||
|
"flac": ".flac",
|
||||||
|
"pcm_s16le": ".wav",
|
||||||
|
"pcm_s24le": ".wav",
|
||||||
|
}
|
||||||
|
ext = ext_map.get(codec, ".aac")
|
||||||
|
output_path = output_path.with_suffix(ext)
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg", "-y",
|
||||||
|
"-i", str(video_path),
|
||||||
|
"-vn", # No video
|
||||||
|
"-acodec", "copy",
|
||||||
|
str(output_path),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Re-encode to AAC
|
||||||
|
output_path = output_path.with_suffix(".aac")
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg", "-y",
|
||||||
|
"-i", str(video_path),
|
||||||
|
"-vn",
|
||||||
|
"-acodec", "aac",
|
||||||
|
"-b:a", "320k",
|
||||||
|
str(output_path),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(cmd, capture_output=True, check=True)
|
||||||
|
logger.info(f"Extracted audio to: {output_path}")
|
||||||
|
return output_path
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.error(f"Audio extraction failed: {e.stderr}")
|
||||||
|
raise AudioError(f"Audio extraction failed: {e.stderr}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def mux_audio_video(
|
||||||
|
video_path: Path,
|
||||||
|
audio_path: Path,
|
||||||
|
output_path: Path,
|
||||||
|
copy_streams: bool = True,
|
||||||
|
) -> Path:
|
||||||
|
"""
|
||||||
|
Combine video and audio streams.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Video-only file
|
||||||
|
audio_path: Audio file
|
||||||
|
output_path: Output path
|
||||||
|
copy_streams: If True, copy without re-encoding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to muxed video
|
||||||
|
"""
|
||||||
|
check_ffmpeg()
|
||||||
|
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if copy_streams:
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg", "-y",
|
||||||
|
"-i", str(video_path),
|
||||||
|
"-i", str(audio_path),
|
||||||
|
"-c:v", "copy",
|
||||||
|
"-c:a", "copy",
|
||||||
|
"-map", "0:v:0",
|
||||||
|
"-map", "1:a:0",
|
||||||
|
str(output_path),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg", "-y",
|
||||||
|
"-i", str(video_path),
|
||||||
|
"-i", str(audio_path),
|
||||||
|
"-c:v", "copy",
|
||||||
|
"-c:a", "aac",
|
||||||
|
"-b:a", "320k",
|
||||||
|
"-map", "0:v:0",
|
||||||
|
"-map", "1:a:0",
|
||||||
|
str(output_path),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(cmd, capture_output=True, check=True)
|
||||||
|
logger.info(f"Muxed audio/video to: {output_path}")
|
||||||
|
return output_path
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.error(f"Muxing failed: {e.stderr}")
|
||||||
|
raise AudioError(f"Audio/video muxing failed: {e.stderr}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_duration(audio_path: Path) -> float:
|
||||||
|
"""
|
||||||
|
Get duration of audio file in seconds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_path: Path to audio file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Duration in seconds
|
||||||
|
"""
|
||||||
|
check_ffmpeg()
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
"ffprobe",
|
||||||
|
"-v", "error",
|
||||||
|
"-show_entries", "format=duration",
|
||||||
|
"-of", "csv=p=0",
|
||||||
|
str(audio_path),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
return float(result.stdout.strip())
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
250
src/video/checkpoint.py
Normal file
250
src/video/checkpoint.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""Checkpoint system for resumable video processing."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.config import CHECKPOINTS_DIR
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoCheckpoint:
|
||||||
|
"""Checkpoint data for resumable video processing."""
|
||||||
|
|
||||||
|
job_id: str
|
||||||
|
input_path: str
|
||||||
|
input_hash: str # MD5 of first 1MB for verification
|
||||||
|
total_frames: int
|
||||||
|
processed_frames: int
|
||||||
|
last_processed_frame: int
|
||||||
|
frames_dir: str
|
||||||
|
audio_path: Optional[str]
|
||||||
|
config_json: str
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def progress_percent(self) -> float:
|
||||||
|
"""Get progress as percentage."""
|
||||||
|
if self.total_frames == 0:
|
||||||
|
return 0.0
|
||||||
|
return (self.processed_frames / self.total_frames) * 100
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_complete(self) -> bool:
|
||||||
|
"""Check if processing is complete."""
|
||||||
|
return self.processed_frames >= self.total_frames
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointManager:
|
||||||
|
"""
|
||||||
|
Manage video processing checkpoints for resume capability.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Save checkpoint every N frames
|
||||||
|
- Verify input file hasn't changed
|
||||||
|
- Clean up after completion
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
checkpoint_dir: Optional[Path] = None,
|
||||||
|
save_interval: int = 100,
|
||||||
|
):
|
||||||
|
self.checkpoint_dir = checkpoint_dir or CHECKPOINTS_DIR
|
||||||
|
self.save_interval = save_interval
|
||||||
|
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def get_checkpoint_path(self, job_id: str) -> Path:
|
||||||
|
"""Get path to checkpoint file."""
|
||||||
|
return self.checkpoint_dir / f"{job_id}.checkpoint.json"
|
||||||
|
|
||||||
|
def create_checkpoint(
|
||||||
|
self,
|
||||||
|
job_id: str,
|
||||||
|
input_path: Path,
|
||||||
|
total_frames: int,
|
||||||
|
frames_dir: Path,
|
||||||
|
audio_path: Optional[Path],
|
||||||
|
config: dict,
|
||||||
|
) -> VideoCheckpoint:
|
||||||
|
"""
|
||||||
|
Create a new checkpoint for a video job.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: Unique job identifier
|
||||||
|
input_path: Path to input video
|
||||||
|
total_frames: Total frames to process
|
||||||
|
frames_dir: Directory for processed frames
|
||||||
|
audio_path: Path to extracted audio (if any)
|
||||||
|
config: Processing configuration dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New VideoCheckpoint instance
|
||||||
|
"""
|
||||||
|
now = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
checkpoint = VideoCheckpoint(
|
||||||
|
job_id=job_id,
|
||||||
|
input_path=str(input_path),
|
||||||
|
input_hash=self._compute_file_hash(input_path),
|
||||||
|
total_frames=total_frames,
|
||||||
|
processed_frames=0,
|
||||||
|
last_processed_frame=-1,
|
||||||
|
frames_dir=str(frames_dir),
|
||||||
|
audio_path=str(audio_path) if audio_path else None,
|
||||||
|
config_json=json.dumps(config),
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._save_checkpoint(checkpoint)
|
||||||
|
logger.info(f"Created checkpoint for job: {job_id}")
|
||||||
|
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
def update_checkpoint(
|
||||||
|
self,
|
||||||
|
checkpoint: VideoCheckpoint,
|
||||||
|
processed_frame: int,
|
||||||
|
force_save: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update checkpoint with progress.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint: Checkpoint to update
|
||||||
|
processed_frame: Frame index that was just processed
|
||||||
|
force_save: Force save even if not at interval
|
||||||
|
"""
|
||||||
|
checkpoint.processed_frames += 1
|
||||||
|
checkpoint.last_processed_frame = processed_frame
|
||||||
|
checkpoint.updated_at = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
# Save periodically or when forced
|
||||||
|
if force_save or checkpoint.processed_frames % self.save_interval == 0:
|
||||||
|
self._save_checkpoint(checkpoint)
|
||||||
|
logger.debug(
|
||||||
|
f"Checkpoint saved: {checkpoint.processed_frames}/{checkpoint.total_frames}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_checkpoint(self, job_id: str) -> Optional[VideoCheckpoint]:
|
||||||
|
"""
|
||||||
|
Load existing checkpoint if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: Job identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VideoCheckpoint if valid, None otherwise
|
||||||
|
"""
|
||||||
|
path = self.get_checkpoint_path(job_id)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
checkpoint = VideoCheckpoint(**data)
|
||||||
|
|
||||||
|
# Verify input file hasn't changed
|
||||||
|
if Path(checkpoint.input_path).exists():
|
||||||
|
current_hash = self._compute_file_hash(Path(checkpoint.input_path))
|
||||||
|
if current_hash != checkpoint.input_hash:
|
||||||
|
logger.warning("Input file changed since checkpoint, starting fresh")
|
||||||
|
self.delete_checkpoint(job_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Verify frames directory exists
|
||||||
|
if not Path(checkpoint.frames_dir).exists():
|
||||||
|
logger.warning("Frames directory missing, starting fresh")
|
||||||
|
self.delete_checkpoint(job_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Loaded checkpoint: {checkpoint.processed_frames}/{checkpoint.total_frames} frames"
|
||||||
|
)
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load checkpoint: {e}")
|
||||||
|
self.delete_checkpoint(job_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_checkpoint(self, job_id: str) -> None:
|
||||||
|
"""Delete checkpoint file."""
|
||||||
|
path = self.get_checkpoint_path(job_id)
|
||||||
|
if path.exists():
|
||||||
|
path.unlink()
|
||||||
|
logger.info(f"Deleted checkpoint: {job_id}")
|
||||||
|
|
||||||
|
def _save_checkpoint(self, checkpoint: VideoCheckpoint) -> None:
|
||||||
|
"""Save checkpoint to disk."""
|
||||||
|
path = self.get_checkpoint_path(checkpoint.job_id)
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(asdict(checkpoint), f, indent=2)
|
||||||
|
|
||||||
|
def _compute_file_hash(
|
||||||
|
self,
|
||||||
|
path: Path,
|
||||||
|
chunk_size: int = 1024 * 1024,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Compute MD5 hash of first chunk of file.
|
||||||
|
|
||||||
|
Using only the first 1MB is fast while still detecting
|
||||||
|
if the file has been modified.
|
||||||
|
"""
|
||||||
|
hasher = hashlib.md5()
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
chunk = f.read(chunk_size)
|
||||||
|
hasher.update(chunk)
|
||||||
|
return hasher.hexdigest()
|
||||||
|
|
||||||
|
def get_all_checkpoints(self) -> list[VideoCheckpoint]:
|
||||||
|
"""Get all existing checkpoints."""
|
||||||
|
checkpoints = []
|
||||||
|
for path in self.checkpoint_dir.glob("*.checkpoint.json"):
|
||||||
|
job_id = path.stem.replace(".checkpoint", "")
|
||||||
|
checkpoint = self.load_checkpoint(job_id)
|
||||||
|
if checkpoint:
|
||||||
|
checkpoints.append(checkpoint)
|
||||||
|
return checkpoints
|
||||||
|
|
||||||
|
def cleanup_old_checkpoints(self, max_age_days: int = 7) -> int:
|
||||||
|
"""
|
||||||
|
Clean up old checkpoint files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_age_days: Delete checkpoints older than this
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of checkpoints deleted
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
deleted = 0
|
||||||
|
max_age_seconds = max_age_days * 24 * 60 * 60
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
for path in self.checkpoint_dir.glob("*.checkpoint.json"):
|
||||||
|
if now - path.stat().st_mtime > max_age_seconds:
|
||||||
|
path.unlink()
|
||||||
|
deleted += 1
|
||||||
|
|
||||||
|
if deleted:
|
||||||
|
logger.info(f"Cleaned up {deleted} old checkpoints")
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
|
||||||
|
# Global checkpoint manager
|
||||||
|
checkpoint_manager = CheckpointManager()
|
||||||
274
src/video/encoder.py
Normal file
274
src/video/encoder.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
"""Video encoding using FFmpeg."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.config import VIDEO_CODECS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EncodingResult:
|
||||||
|
"""Result of video encoding."""
|
||||||
|
|
||||||
|
output_path: Path
|
||||||
|
codec: str
|
||||||
|
duration_seconds: float
|
||||||
|
file_size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class VideoEncoderError(Exception):
|
||||||
|
"""Error during video encoding."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def check_ffmpeg():
|
||||||
|
"""Check if FFmpeg is available."""
|
||||||
|
if shutil.which("ffmpeg") is None:
|
||||||
|
raise VideoEncoderError(
|
||||||
|
"FFmpeg not found. Please install FFmpeg: sudo apt install ffmpeg"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_nvenc():
|
||||||
|
"""Check if NVENC is available."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["ffmpeg", "-hide_banner", "-encoders"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
return "h264_nvenc" in result.stdout
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class VideoEncoder:
|
||||||
|
"""
|
||||||
|
Encode video frames using FFmpeg.
|
||||||
|
|
||||||
|
Supports multiple codecs:
|
||||||
|
- H.264 (libx264 / h264_nvenc)
|
||||||
|
- H.265 (libx265 / hevc_nvenc)
|
||||||
|
- AV1 (libsvtav1 / av1_nvenc)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, use_nvenc: bool = True):
|
||||||
|
check_ffmpeg()
|
||||||
|
self.use_nvenc = use_nvenc and check_nvenc()
|
||||||
|
|
||||||
|
if self.use_nvenc:
|
||||||
|
logger.info("NVENC hardware encoding available")
|
||||||
|
else:
|
||||||
|
logger.info("Using software encoding")
|
||||||
|
|
||||||
|
def encode_frames(
|
||||||
|
self,
|
||||||
|
frames_dir: Path,
|
||||||
|
output_path: Path,
|
||||||
|
fps: float,
|
||||||
|
codec: str = "H.265",
|
||||||
|
crf: int = 20,
|
||||||
|
preset: str = "slow",
|
||||||
|
audio_path: Optional[Path] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
) -> EncodingResult:
|
||||||
|
"""
|
||||||
|
Encode frames to video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frames_dir: Directory containing numbered frame images
|
||||||
|
output_path: Output video path
|
||||||
|
fps: Frame rate
|
||||||
|
codec: Codec name ("H.264", "H.265", "AV1")
|
||||||
|
crf: Quality (lower = better, 0-51)
|
||||||
|
preset: Speed preset
|
||||||
|
audio_path: Optional audio file to mux
|
||||||
|
width: Optional output width (for scaling)
|
||||||
|
height: Optional output height (for scaling)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EncodingResult with output info
|
||||||
|
"""
|
||||||
|
if codec not in VIDEO_CODECS:
|
||||||
|
raise VideoEncoderError(f"Unknown codec: {codec}")
|
||||||
|
|
||||||
|
codec_config = VIDEO_CODECS[codec]
|
||||||
|
|
||||||
|
# Build FFmpeg command
|
||||||
|
cmd = self._build_command(
|
||||||
|
frames_dir=frames_dir,
|
||||||
|
output_path=output_path,
|
||||||
|
fps=fps,
|
||||||
|
codec_config=codec_config,
|
||||||
|
crf=crf,
|
||||||
|
preset=preset,
|
||||||
|
audio_path=audio_path,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Encoding video: {' '.join(cmd)}")
|
||||||
|
|
||||||
|
# Run FFmpeg
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.error(f"FFmpeg error: {e.stderr}")
|
||||||
|
raise VideoEncoderError(f"Encoding failed: {e.stderr}") from e
|
||||||
|
|
||||||
|
# Get output info
|
||||||
|
output_path = Path(output_path)
|
||||||
|
file_size = output_path.stat().st_size if output_path.exists() else 0
|
||||||
|
|
||||||
|
return EncodingResult(
|
||||||
|
output_path=output_path,
|
||||||
|
codec=codec,
|
||||||
|
duration_seconds=0, # TODO: probe output
|
||||||
|
file_size_bytes=file_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_command(
|
||||||
|
self,
|
||||||
|
frames_dir: Path,
|
||||||
|
output_path: Path,
|
||||||
|
fps: float,
|
||||||
|
codec_config,
|
||||||
|
crf: int,
|
||||||
|
preset: str,
|
||||||
|
audio_path: Optional[Path],
|
||||||
|
width: Optional[int],
|
||||||
|
height: Optional[int],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build FFmpeg command."""
|
||||||
|
# Frame input pattern (expects frame_%08d.png format)
|
||||||
|
frame_pattern = str(frames_dir / "frame_%08d.png")
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-y", # Overwrite output
|
||||||
|
"-framerate", str(fps),
|
||||||
|
"-i", frame_pattern,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add audio input if provided
|
||||||
|
if audio_path and audio_path.exists():
|
||||||
|
cmd.extend(["-i", str(audio_path)])
|
||||||
|
|
||||||
|
# Video filter for scaling if needed
|
||||||
|
vf_filters = []
|
||||||
|
if width and height:
|
||||||
|
# Ensure even dimensions
|
||||||
|
w = width if width % 2 == 0 else width + 1
|
||||||
|
h = height if height % 2 == 0 else height + 1
|
||||||
|
vf_filters.append(f"scale={w}:{h}:flags=lanczos")
|
||||||
|
|
||||||
|
if vf_filters:
|
||||||
|
cmd.extend(["-vf", ",".join(vf_filters)])
|
||||||
|
|
||||||
|
# Codec settings
|
||||||
|
cmd.extend(self._get_codec_args(codec_config, crf, preset))
|
||||||
|
|
||||||
|
# Pixel format
|
||||||
|
cmd.extend(["-pix_fmt", "yuv420p"])
|
||||||
|
|
||||||
|
# Audio settings
|
||||||
|
if audio_path and audio_path.exists():
|
||||||
|
cmd.extend(["-c:a", "aac", "-b:a", "320k"])
|
||||||
|
|
||||||
|
# Output
|
||||||
|
cmd.append(str(output_path))
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
def _get_codec_args(self, codec_config, crf: int, preset: str) -> list[str]:
|
||||||
|
"""Get codec-specific FFmpeg arguments."""
|
||||||
|
# Use NVENC if available and configured
|
||||||
|
if self.use_nvenc and codec_config.nvenc_encoder:
|
||||||
|
encoder = codec_config.nvenc_encoder
|
||||||
|
|
||||||
|
if "nvenc" in encoder:
|
||||||
|
return [
|
||||||
|
"-c:v", encoder,
|
||||||
|
"-preset", "p7", # Highest quality NVENC preset
|
||||||
|
"-tune", "hq",
|
||||||
|
"-rc", "vbr",
|
||||||
|
"-cq", str(crf),
|
||||||
|
"-b:v", "0",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Software encoding
|
||||||
|
encoder = codec_config.encoder
|
||||||
|
|
||||||
|
if encoder == "libx264":
|
||||||
|
return [
|
||||||
|
"-c:v", "libx264",
|
||||||
|
"-preset", preset,
|
||||||
|
"-crf", str(crf),
|
||||||
|
"-tune", "film",
|
||||||
|
]
|
||||||
|
elif encoder == "libx265":
|
||||||
|
return [
|
||||||
|
"-c:v", "libx265",
|
||||||
|
"-preset", preset,
|
||||||
|
"-crf", str(crf),
|
||||||
|
"-x265-params", "aq-mode=3",
|
||||||
|
]
|
||||||
|
elif encoder == "libsvtav1":
|
||||||
|
return [
|
||||||
|
"-c:v", "libsvtav1",
|
||||||
|
"-preset", preset,
|
||||||
|
"-crf", str(crf),
|
||||||
|
"-svtav1-params", "tune=0:film-grain=8",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return ["-c:v", encoder, "-crf", str(crf)]
|
||||||
|
|
||||||
|
|
||||||
|
def encode_video(
|
||||||
|
frames_dir: Path,
|
||||||
|
output_path: Path,
|
||||||
|
fps: float,
|
||||||
|
codec: str = "H.265",
|
||||||
|
crf: int = 20,
|
||||||
|
preset: str = "slow",
|
||||||
|
audio_path: Optional[Path] = None,
|
||||||
|
) -> Path:
|
||||||
|
"""
|
||||||
|
Convenience function to encode video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frames_dir: Directory with numbered frames
|
||||||
|
output_path: Output video path
|
||||||
|
fps: Frame rate
|
||||||
|
codec: Video codec
|
||||||
|
crf: Quality (lower = better)
|
||||||
|
preset: Encoding preset
|
||||||
|
audio_path: Optional audio to mux
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to encoded video
|
||||||
|
"""
|
||||||
|
encoder = VideoEncoder()
|
||||||
|
result = encoder.encode_frames(
|
||||||
|
frames_dir=frames_dir,
|
||||||
|
output_path=output_path,
|
||||||
|
fps=fps,
|
||||||
|
codec=codec,
|
||||||
|
crf=crf,
|
||||||
|
preset=preset,
|
||||||
|
audio_path=audio_path,
|
||||||
|
)
|
||||||
|
return result.output_path
|
||||||
201
src/video/extractor.py
Normal file
201
src/video/extractor.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Video frame extraction using PyAV."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoMetadata:
|
||||||
|
"""Video file metadata."""
|
||||||
|
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
fps: float
|
||||||
|
total_frames: int
|
||||||
|
duration_seconds: float
|
||||||
|
codec: str
|
||||||
|
has_audio: bool
|
||||||
|
audio_codec: Optional[str] = None
|
||||||
|
audio_sample_rate: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class VideoExtractor:
|
||||||
|
"""
|
||||||
|
Extract frames from video files using PyAV.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- High-performance frame extraction
|
||||||
|
- Support for various video formats
|
||||||
|
- Frame-accurate seeking for resume
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, video_path: Path):
|
||||||
|
self.video_path = Path(video_path)
|
||||||
|
self.container = None
|
||||||
|
self.video_stream = None
|
||||||
|
self._metadata: Optional[VideoMetadata] = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.open()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
"""Open video file."""
|
||||||
|
if not self.video_path.exists():
|
||||||
|
raise FileNotFoundError(f"Video not found: {self.video_path}")
|
||||||
|
|
||||||
|
self.container = av.open(str(self.video_path))
|
||||||
|
self.video_stream = self.container.streams.video[0]
|
||||||
|
|
||||||
|
# Enable threading for faster decoding
|
||||||
|
self.video_stream.thread_type = "AUTO"
|
||||||
|
|
||||||
|
logger.info(f"Opened video: {self.video_path}")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close video file."""
|
||||||
|
if self.container:
|
||||||
|
self.container.close()
|
||||||
|
self.container = None
|
||||||
|
self.video_stream = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metadata(self) -> VideoMetadata:
|
||||||
|
"""Get video metadata."""
|
||||||
|
if self._metadata is not None:
|
||||||
|
return self._metadata
|
||||||
|
|
||||||
|
if self.video_stream is None:
|
||||||
|
raise RuntimeError("Video not opened")
|
||||||
|
|
||||||
|
video = self.video_stream
|
||||||
|
|
||||||
|
# Calculate total frames
|
||||||
|
if video.frames > 0:
|
||||||
|
total_frames = video.frames
|
||||||
|
elif video.duration and video.time_base:
|
||||||
|
# Estimate from duration
|
||||||
|
duration_sec = float(video.duration * video.time_base)
|
||||||
|
total_frames = int(duration_sec * float(video.average_rate))
|
||||||
|
else:
|
||||||
|
total_frames = 0
|
||||||
|
|
||||||
|
# Duration
|
||||||
|
if video.duration and video.time_base:
|
||||||
|
duration_seconds = float(video.duration * video.time_base)
|
||||||
|
else:
|
||||||
|
duration_seconds = 0.0
|
||||||
|
|
||||||
|
# Check for audio
|
||||||
|
has_audio = len(self.container.streams.audio) > 0
|
||||||
|
audio_codec = None
|
||||||
|
audio_sample_rate = None
|
||||||
|
|
||||||
|
if has_audio:
|
||||||
|
audio_stream = self.container.streams.audio[0]
|
||||||
|
audio_codec = audio_stream.codec.name
|
||||||
|
audio_sample_rate = audio_stream.sample_rate
|
||||||
|
|
||||||
|
self._metadata = VideoMetadata(
|
||||||
|
width=video.width,
|
||||||
|
height=video.height,
|
||||||
|
fps=float(video.average_rate) if video.average_rate else 30.0,
|
||||||
|
total_frames=total_frames,
|
||||||
|
duration_seconds=duration_seconds,
|
||||||
|
codec=video.codec.name,
|
||||||
|
has_audio=has_audio,
|
||||||
|
audio_codec=audio_codec,
|
||||||
|
audio_sample_rate=audio_sample_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._metadata
|
||||||
|
|
||||||
|
def extract_frames(
|
||||||
|
self,
|
||||||
|
start_frame: int = 0,
|
||||||
|
end_frame: Optional[int] = None,
|
||||||
|
) -> Iterator[tuple[int, np.ndarray]]:
|
||||||
|
"""
|
||||||
|
Extract frames from video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_frame: Starting frame index (for resume)
|
||||||
|
end_frame: Ending frame index (exclusive, None for all)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Tuple of (frame_index, frame_array) where frame_array is RGB
|
||||||
|
"""
|
||||||
|
if self.container is None:
|
||||||
|
raise RuntimeError("Video not opened")
|
||||||
|
|
||||||
|
frame_index = 0
|
||||||
|
|
||||||
|
# Seek to start position if needed
|
||||||
|
if start_frame > 0:
|
||||||
|
self._seek_to_frame(start_frame)
|
||||||
|
frame_index = start_frame
|
||||||
|
|
||||||
|
for frame in self.container.decode(video=0):
|
||||||
|
if end_frame is not None and frame_index >= end_frame:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Convert to RGB numpy array
|
||||||
|
img = frame.to_ndarray(format="rgb24")
|
||||||
|
|
||||||
|
yield frame_index, img
|
||||||
|
frame_index += 1
|
||||||
|
|
||||||
|
def _seek_to_frame(self, frame_index: int):
|
||||||
|
"""Seek to approximate frame position."""
|
||||||
|
if self.video_stream is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate timestamp
|
||||||
|
fps = float(self.video_stream.average_rate) if self.video_stream.average_rate else 30.0
|
||||||
|
time_base = self.video_stream.time_base
|
||||||
|
target_ts = int(frame_index / fps / time_base)
|
||||||
|
|
||||||
|
# Seek to keyframe before target
|
||||||
|
self.container.seek(target_ts, stream=self.video_stream)
|
||||||
|
|
||||||
|
def get_frame(self, frame_index: int) -> Optional[np.ndarray]:
|
||||||
|
"""Get a specific frame by index."""
|
||||||
|
self._seek_to_frame(frame_index)
|
||||||
|
|
||||||
|
for frame in self.container.decode(video=0):
|
||||||
|
return frame.to_ndarray(format="rgb24")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extract_keyframes(self) -> Iterator[tuple[int, np.ndarray]]:
|
||||||
|
"""Extract only keyframes (for preview generation)."""
|
||||||
|
if self.container is None:
|
||||||
|
raise RuntimeError("Video not opened")
|
||||||
|
|
||||||
|
# Set codec to skip non-key frames
|
||||||
|
self.video_stream.codec_context.skip_frame = "NONKEY"
|
||||||
|
|
||||||
|
frame_index = 0
|
||||||
|
for frame in self.container.decode(video=0):
|
||||||
|
img = frame.to_ndarray(format="rgb24")
|
||||||
|
yield frame_index, img
|
||||||
|
frame_index += 1
|
||||||
|
|
||||||
|
# Reset skip mode
|
||||||
|
self.video_stream.codec_context.skip_frame = "DEFAULT"
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_info(video_path: Path) -> VideoMetadata:
|
||||||
|
"""Quick function to get video metadata."""
|
||||||
|
with VideoExtractor(video_path) as extractor:
|
||||||
|
return extractor.metadata
|
||||||
Reference in New Issue
Block a user