feat: initial implementation of Real-ESRGAN Web UI
Full-featured Gradio 6.0+ web interface for Real-ESRGAN image/video upscaling, optimized for RTX 4090 (24GB VRAM). Features: - Image upscaling with before/after comparison (ImageSlider) - Video upscaling with progress tracking and checkpoint/resume - Face enhancement via GFPGAN integration - Multiple codecs: H.264, H.265, AV1 (with NVENC support) - Batch processing queue with SQLite persistence - Processing history gallery - Custom dark theme - Auto-download of model weights 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
29
app.py
Normal file
29
app.py
Normal file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Real-ESRGAN Web UI - Main Entry Point
|
||||
|
||||
A full-featured web interface for Real-ESRGAN image and video upscaling,
|
||||
optimized for RTX 4090 with 24GB VRAM.
|
||||
|
||||
Usage:
|
||||
python app.py
|
||||
|
||||
The server will start at http://localhost:7860
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from src.ui.app import launch_app
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
launch_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
18
pyproject.toml
Normal file
18
pyproject.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[project]
|
||||
name = "upscale-ui"
|
||||
version = "0.1.0"
|
||||
description = "Real-ESRGAN Web UI for image and video upscaling"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {text = "MIT"}
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "I"]
|
||||
ignore = ["E501"]
|
||||
30
requirements.txt
Normal file
30
requirements.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
# Real-ESRGAN Web UI Dependencies
|
||||
|
||||
# Gradio UI Framework
|
||||
gradio>=6.0.0
|
||||
gradio-imageslider>=0.0.20
|
||||
|
||||
# PyTorch (CUDA 12.1)
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torch>=2.0.0
|
||||
torchvision>=0.15.0
|
||||
|
||||
# Real-ESRGAN and dependencies
|
||||
realesrgan>=0.3.0
|
||||
basicsr>=1.4.2
|
||||
|
||||
# Face Enhancement
|
||||
gfpgan>=1.3.8
|
||||
facexlib>=0.3.0
|
||||
|
||||
# Video Processing
|
||||
av>=10.0.0
|
||||
|
||||
# Image Processing
|
||||
opencv-python>=4.8.0
|
||||
Pillow>=10.0.0
|
||||
numpy>=1.24.0
|
||||
|
||||
# Utilities
|
||||
tqdm>=4.65.0
|
||||
requests>=2.28.0
|
||||
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Real-ESRGAN Web UI - Image and Video Upscaling."""
|
||||
222
src/config.py
Normal file
222
src/config.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Configuration settings for Real-ESRGAN Web UI."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Base paths
|
||||
BASE_DIR = Path(__file__).parent.parent
|
||||
DATA_DIR = BASE_DIR / "data"
|
||||
MODELS_DIR = DATA_DIR / "models"
|
||||
TEMP_DIR = DATA_DIR / "temp"
|
||||
CHECKPOINTS_DIR = DATA_DIR / "checkpoints"
|
||||
OUTPUT_DIR = DATA_DIR / "output"
|
||||
DATABASE_PATH = DATA_DIR / "upscale.db"
|
||||
|
||||
# Ensure directories exist
|
||||
for dir_path in [MODELS_DIR, TEMP_DIR, CHECKPOINTS_DIR, OUTPUT_DIR]:
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Information about a Real-ESRGAN or GFPGAN model."""
|
||||
|
||||
name: str
|
||||
scale: int
|
||||
filename: str
|
||||
url: str
|
||||
size_mb: float
|
||||
description: str
|
||||
model_type: str = "realesrgan" # "realesrgan" or "gfpgan"
|
||||
netscale: int = 4 # Network scale factor
|
||||
num_block: int = 23 # Number of RRDB blocks (6 for anime models)
|
||||
num_grow_ch: int = 32 # Growth channels
|
||||
|
||||
|
||||
# Model registry with download URLs
|
||||
MODEL_REGISTRY: dict[str, ModelInfo] = {
|
||||
"RealESRGAN_x4plus": ModelInfo(
|
||||
name="RealESRGAN_x4plus",
|
||||
scale=4,
|
||||
filename="RealESRGAN_x4plus.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
size_mb=63.9,
|
||||
description="Best quality for general photos (4x)",
|
||||
num_block=23,
|
||||
),
|
||||
"RealESRGAN_x2plus": ModelInfo(
|
||||
name="RealESRGAN_x2plus",
|
||||
scale=2,
|
||||
filename="RealESRGAN_x2plus.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
size_mb=63.9,
|
||||
description="General photos (2x upscaling)",
|
||||
netscale=2,
|
||||
num_block=23,
|
||||
),
|
||||
"RealESRGAN_x4plus_anime_6B": ModelInfo(
|
||||
name="RealESRGAN_x4plus_anime_6B",
|
||||
scale=4,
|
||||
filename="RealESRGAN_x4plus_anime_6B.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
size_mb=17.0,
|
||||
description="Optimized for anime/illustrations (4x)",
|
||||
num_block=6,
|
||||
),
|
||||
"realesr-animevideov3": ModelInfo(
|
||||
name="realesr-animevideov3",
|
||||
scale=4,
|
||||
filename="realesr-animevideov3.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
size_mb=17.0,
|
||||
description="Anime video with temporal consistency (4x)",
|
||||
num_block=6,
|
||||
),
|
||||
"realesr-general-x4v3": ModelInfo(
|
||||
name="realesr-general-x4v3",
|
||||
scale=4,
|
||||
filename="realesr-general-x4v3.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
size_mb=63.9,
|
||||
description="General purpose with denoise control (4x)",
|
||||
num_block=23,
|
||||
),
|
||||
"GFPGANv1.4": ModelInfo(
|
||||
name="GFPGANv1.4",
|
||||
scale=1,
|
||||
filename="GFPGANv1.4.pth",
|
||||
url="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth",
|
||||
size_mb=332.0,
|
||||
description="Face enhancement and restoration",
|
||||
model_type="gfpgan",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTX4090Config:
|
||||
"""RTX 4090 optimized settings (24GB VRAM)."""
|
||||
|
||||
tile_size: int = 0 # 0 = no tiling, use full image
|
||||
tile_pad: int = 10 # Padding for tile boundaries
|
||||
pre_pad: int = 0 # Pre-padding for input
|
||||
half: bool = True # FP16 precision
|
||||
device: str = "cuda"
|
||||
gpu_id: int = 0
|
||||
|
||||
# Thresholds for when to enable tiling
|
||||
max_pixels_no_tile: int = 8294400 # ~4K (3840x2160)
|
||||
|
||||
def should_tile(self, width: int, height: int) -> bool:
|
||||
"""Determine if tiling is needed based on image size."""
|
||||
return width * height > self.max_pixels_no_tile
|
||||
|
||||
def get_tile_size(self, width: int, height: int) -> int:
|
||||
"""Get appropriate tile size for the image."""
|
||||
if not self.should_tile(width, height):
|
||||
return 0 # No tiling
|
||||
# For very large images, use 512px tiles
|
||||
return 512
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoCodecConfig:
|
||||
"""Video encoding configuration."""
|
||||
|
||||
encoder: str
|
||||
crf: int
|
||||
preset: str
|
||||
nvenc_encoder: Optional[str] = None
|
||||
description: str = ""
|
||||
|
||||
|
||||
# Video codec presets (maximum quality settings)
|
||||
VIDEO_CODECS: dict[str, VideoCodecConfig] = {
|
||||
"H.264": VideoCodecConfig(
|
||||
encoder="libx264",
|
||||
crf=18,
|
||||
preset="slow",
|
||||
nvenc_encoder="h264_nvenc",
|
||||
description="Universal compatibility, excellent quality",
|
||||
),
|
||||
"H.265": VideoCodecConfig(
|
||||
encoder="libx265",
|
||||
crf=20,
|
||||
preset="slow",
|
||||
nvenc_encoder="hevc_nvenc",
|
||||
description="50% smaller files, great quality",
|
||||
),
|
||||
"AV1": VideoCodecConfig(
|
||||
encoder="libsvtav1",
|
||||
crf=23,
|
||||
preset="4",
|
||||
nvenc_encoder="av1_nvenc",
|
||||
description="Best compression, future-proof",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
"""Main application configuration."""
|
||||
|
||||
# Server settings
|
||||
server_name: str = "0.0.0.0"
|
||||
server_port: int = 7860
|
||||
share: bool = False
|
||||
|
||||
# Queue settings
|
||||
max_queue_size: int = 20
|
||||
concurrency_limit: int = 1 # Single job at a time for GPU efficiency
|
||||
|
||||
# Processing defaults
|
||||
default_model: str = "RealESRGAN_x4plus"
|
||||
default_scale: int = 4
|
||||
default_face_enhance: bool = False
|
||||
default_output_format: str = "png"
|
||||
|
||||
# Video defaults
|
||||
default_video_codec: str = "H.265"
|
||||
default_video_crf: int = 20
|
||||
default_video_preset: str = "slow"
|
||||
|
||||
# History settings
|
||||
max_history_items: int = 1000
|
||||
thumbnail_size: tuple[int, int] = (256, 256)
|
||||
|
||||
# Checkpoint settings
|
||||
checkpoint_interval: int = 100 # Save every N frames
|
||||
|
||||
# GPU config
|
||||
gpu: RTX4090Config = field(default_factory=RTX4090Config)
|
||||
|
||||
|
||||
# Global config instance
|
||||
config = AppConfig()
|
||||
|
||||
|
||||
def get_model_path(model_name: str) -> Path:
|
||||
"""Get the path to a model file."""
|
||||
if model_name not in MODEL_REGISTRY:
|
||||
raise ValueError(f"Unknown model: {model_name}")
|
||||
return MODELS_DIR / MODEL_REGISTRY[model_name].filename
|
||||
|
||||
|
||||
def get_available_models() -> list[str]:
|
||||
"""Get list of available model names for UI dropdown."""
|
||||
return [
|
||||
name
|
||||
for name, info in MODEL_REGISTRY.items()
|
||||
if info.model_type == "realesrgan"
|
||||
]
|
||||
|
||||
|
||||
def get_model_choices() -> list[tuple[str, str]]:
|
||||
"""Get model choices for Gradio dropdown (label, value)."""
|
||||
return [
|
||||
(f"{info.description}", name)
|
||||
for name, info in MODEL_REGISTRY.items()
|
||||
if info.model_type == "realesrgan"
|
||||
]
|
||||
1
src/processing/__init__.py
Normal file
1
src/processing/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Image and video processing modules."""
|
||||
213
src/processing/face_enhancer.py
Normal file
213
src/processing/face_enhancer.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Face enhancement using GFPGAN."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.config import MODEL_REGISTRY, MODELS_DIR, config
|
||||
from src.processing.models import ensure_model_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global GFPGAN instance (cached)
|
||||
_gfpgan_instance = None
|
||||
|
||||
|
||||
class FaceEnhancerError(Exception):
|
||||
"""Error during face enhancement."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_gfpgan(
|
||||
model_name: str = "GFPGANv1.4",
|
||||
upscale: int = 2,
|
||||
bg_upsampler=None,
|
||||
):
|
||||
"""
|
||||
Get or create GFPGAN instance.
|
||||
|
||||
Args:
|
||||
model_name: GFPGAN model version
|
||||
upscale: Background upscale factor
|
||||
bg_upsampler: Optional background upsampler (RealESRGAN)
|
||||
|
||||
Returns:
|
||||
GFPGANer instance
|
||||
"""
|
||||
global _gfpgan_instance
|
||||
|
||||
if _gfpgan_instance is not None:
|
||||
return _gfpgan_instance
|
||||
|
||||
try:
|
||||
from gfpgan import GFPGANer
|
||||
except ImportError:
|
||||
raise FaceEnhancerError(
|
||||
"GFPGAN not installed. Install with: pip install gfpgan"
|
||||
)
|
||||
|
||||
# Ensure model is downloaded
|
||||
model_path = ensure_model_available(model_name)
|
||||
|
||||
# Determine model version from name
|
||||
if "1.4" in model_name:
|
||||
arch = "clean"
|
||||
channel_multiplier = 2
|
||||
elif "1.3" in model_name:
|
||||
arch = "clean"
|
||||
channel_multiplier = 2
|
||||
elif "1.2" in model_name:
|
||||
arch = "clean"
|
||||
channel_multiplier = 2
|
||||
else:
|
||||
arch = "original"
|
||||
channel_multiplier = 1
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
_gfpgan_instance = GFPGANer(
|
||||
model_path=str(model_path),
|
||||
upscale=upscale,
|
||||
arch=arch,
|
||||
channel_multiplier=channel_multiplier,
|
||||
bg_upsampler=bg_upsampler,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logger.info(f"Loaded GFPGAN: {model_name}")
|
||||
return _gfpgan_instance
|
||||
|
||||
|
||||
def enhance_faces(
|
||||
image: np.ndarray,
|
||||
model_name: str = "GFPGANv1.4",
|
||||
has_aligned: bool = False,
|
||||
only_center_face: bool = False,
|
||||
paste_back: bool = True,
|
||||
weight: float = 0.5,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Enhance faces in an image using GFPGAN.
|
||||
|
||||
Args:
|
||||
image: Input image (BGR format, numpy array)
|
||||
model_name: GFPGAN model version to use
|
||||
has_aligned: Whether input is already aligned face
|
||||
only_center_face: Only enhance the center/largest face
|
||||
paste_back: Paste enhanced face back to original image
|
||||
weight: Blending weight (0=original, 1=enhanced)
|
||||
|
||||
Returns:
|
||||
Enhanced image (BGR format, numpy array)
|
||||
"""
|
||||
if image is None:
|
||||
raise FaceEnhancerError("No image provided")
|
||||
|
||||
try:
|
||||
gfpgan = get_gfpgan(model_name)
|
||||
|
||||
# GFPGAN enhance
|
||||
_, _, output = gfpgan.enhance(
|
||||
image,
|
||||
has_aligned=has_aligned,
|
||||
only_center_face=only_center_face,
|
||||
paste_back=paste_back,
|
||||
weight=weight,
|
||||
)
|
||||
|
||||
if output is None:
|
||||
logger.warning("No faces detected, returning original image")
|
||||
return image
|
||||
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Face enhancement failed: {e}")
|
||||
raise FaceEnhancerError(f"Face enhancement failed: {e}") from e
|
||||
|
||||
|
||||
def detect_faces(image: np.ndarray) -> int:
|
||||
"""
|
||||
Detect number of faces in image.
|
||||
|
||||
Args:
|
||||
image: Input image (BGR format)
|
||||
|
||||
Returns:
|
||||
Number of detected faces
|
||||
"""
|
||||
try:
|
||||
from facexlib.detection import init_detection_model
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
# Use facexlib for detection
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
face_helper = FaceRestoreHelper(
|
||||
upscale_factor=1,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model="retinaface_resnet50",
|
||||
save_ext="png",
|
||||
device=device,
|
||||
)
|
||||
|
||||
face_helper.read_image(image)
|
||||
face_helper.get_face_landmarks_5(only_center_face=False)
|
||||
|
||||
return len(face_helper.all_landmarks_5)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Face detection failed: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def is_anime_image(image: np.ndarray) -> bool:
|
||||
"""
|
||||
Attempt to detect if image is anime/illustration style.
|
||||
|
||||
This is a simple heuristic - for production use, consider a
|
||||
trained classifier.
|
||||
|
||||
Args:
|
||||
image: Input image (BGR format)
|
||||
|
||||
Returns:
|
||||
True if likely anime/illustration
|
||||
"""
|
||||
# Convert to grayscale
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Anime tends to have:
|
||||
# 1. More uniform color regions (lower texture variance)
|
||||
# 2. Sharper edges (higher edge density)
|
||||
|
||||
# Calculate Laplacian variance (edge strength)
|
||||
laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
|
||||
|
||||
# Calculate local variance (texture)
|
||||
local_var = cv2.blur(gray.astype(float) ** 2, (9, 9)) - cv2.blur(
|
||||
gray.astype(float), (9, 9)
|
||||
) ** 2
|
||||
texture_var = local_var.mean()
|
||||
|
||||
# Anime typically has high edge strength but low texture variance
|
||||
# These thresholds are approximate
|
||||
edge_ratio = laplacian_var / (texture_var + 1e-6)
|
||||
|
||||
# Heuristic: anime has high edge_ratio
|
||||
return edge_ratio > 50
|
||||
|
||||
|
||||
def clear_gfpgan_cache():
|
||||
"""Clear cached GFPGAN model."""
|
||||
global _gfpgan_instance
|
||||
_gfpgan_instance = None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("Cleared GFPGAN cache")
|
||||
240
src/processing/models.py
Normal file
240
src/processing/models.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Model management for Real-ESRGAN and GFPGAN."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.config import MODEL_REGISTRY, MODELS_DIR, ModelInfo, config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelDownloadError(Exception):
|
||||
"""Error downloading model weights."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def download_model(model_info: ModelInfo, progress_callback=None) -> Path:
|
||||
"""
|
||||
Download model weights from URL.
|
||||
|
||||
Args:
|
||||
model_info: Model information from registry
|
||||
progress_callback: Optional callback(downloaded, total) for progress
|
||||
|
||||
Returns:
|
||||
Path to downloaded model file
|
||||
"""
|
||||
model_path = MODELS_DIR / model_info.filename
|
||||
|
||||
if model_path.exists():
|
||||
logger.info(f"Model already exists: {model_path}")
|
||||
return model_path
|
||||
|
||||
logger.info(f"Downloading {model_info.name} ({model_info.size_mb:.1f} MB)...")
|
||||
|
||||
try:
|
||||
response = requests.get(model_info.url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
downloaded = 0
|
||||
|
||||
# Download with progress
|
||||
with open(model_path, "wb") as f:
|
||||
with tqdm(
|
||||
total=total_size,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
desc=model_info.name,
|
||||
) as pbar:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
pbar.update(len(chunk))
|
||||
if progress_callback:
|
||||
progress_callback(downloaded, total_size)
|
||||
|
||||
logger.info(f"Downloaded: {model_path}")
|
||||
return model_path
|
||||
|
||||
except Exception as e:
|
||||
# Clean up partial download
|
||||
if model_path.exists():
|
||||
model_path.unlink()
|
||||
raise ModelDownloadError(f"Failed to download {model_info.name}: {e}") from e
|
||||
|
||||
|
||||
def ensure_model_available(model_name: str) -> Path:
|
||||
"""
|
||||
Ensure model is available, downloading if necessary.
|
||||
|
||||
Args:
|
||||
model_name: Name of model from registry
|
||||
|
||||
Returns:
|
||||
Path to model file
|
||||
"""
|
||||
if model_name not in MODEL_REGISTRY:
|
||||
raise ValueError(f"Unknown model: {model_name}")
|
||||
|
||||
model_info = MODEL_REGISTRY[model_name]
|
||||
model_path = MODELS_DIR / model_info.filename
|
||||
|
||||
if not model_path.exists():
|
||||
download_model(model_info)
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""
|
||||
Manages Real-ESRGAN model loading and caching.
|
||||
|
||||
Optimized for RTX 4090 with 24GB VRAM:
|
||||
- Keeps frequently used models in memory
|
||||
- LRU eviction when VRAM is constrained
|
||||
- Thread-safe operations
|
||||
"""
|
||||
|
||||
def __init__(self, max_cached_models: int = 3):
|
||||
self.max_cached_models = max_cached_models
|
||||
self._cache: dict[str, "RealESRGANer"] = {}
|
||||
self._last_used: dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
self._device = (
|
||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
logger.info(f"GPU: {props.name}, VRAM: {props.total_memory / (1024**3):.1f}GB")
|
||||
else:
|
||||
logger.warning("CUDA not available, using CPU (this will be slow)")
|
||||
|
||||
def get_upsampler(
|
||||
self,
|
||||
model_name: str,
|
||||
tile_size: int = 0,
|
||||
tile_pad: int = 10,
|
||||
half: bool = True,
|
||||
) -> "RealESRGANer":
|
||||
"""
|
||||
Get or create a RealESRGANer instance.
|
||||
|
||||
Args:
|
||||
model_name: Name of model from registry
|
||||
tile_size: Tile size for processing (0 = no tiling)
|
||||
tile_pad: Padding for tile boundaries
|
||||
half: Use FP16 precision
|
||||
|
||||
Returns:
|
||||
Configured RealESRGANer instance
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
cache_key = f"{model_name}_{tile_size}_{tile_pad}_{half}"
|
||||
|
||||
with self._lock:
|
||||
# Check cache
|
||||
if cache_key in self._cache:
|
||||
self._last_used[cache_key] = time.time()
|
||||
return self._cache[cache_key]
|
||||
|
||||
# Ensure model is downloaded
|
||||
model_path = ensure_model_available(model_name)
|
||||
model_info = MODEL_REGISTRY[model_name]
|
||||
|
||||
# Evict old models if at capacity
|
||||
self._evict_if_needed()
|
||||
|
||||
# Create model architecture
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=model_info.num_block,
|
||||
num_grow_ch=model_info.num_grow_ch,
|
||||
scale=model_info.netscale,
|
||||
)
|
||||
|
||||
# Create upsampler
|
||||
upsampler = RealESRGANer(
|
||||
scale=model_info.netscale,
|
||||
model_path=str(model_path),
|
||||
model=model,
|
||||
tile=tile_size,
|
||||
tile_pad=tile_pad,
|
||||
pre_pad=config.gpu.pre_pad,
|
||||
half=half and self._device.type == "cuda",
|
||||
device=self._device,
|
||||
gpu_id=config.gpu.gpu_id if self._device.type == "cuda" else None,
|
||||
)
|
||||
|
||||
# Cache and return
|
||||
self._cache[cache_key] = upsampler
|
||||
self._last_used[cache_key] = time.time()
|
||||
|
||||
logger.info(f"Loaded model: {model_name} (tile={tile_size}, half={half})")
|
||||
return upsampler
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used model if at capacity."""
|
||||
while len(self._cache) >= self.max_cached_models:
|
||||
# Find LRU
|
||||
lru_key = min(self._last_used.keys(), key=lambda k: self._last_used[k])
|
||||
|
||||
# Remove from cache
|
||||
del self._cache[lru_key]
|
||||
del self._last_used[lru_key]
|
||||
|
||||
# Clear CUDA cache
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logger.info(f"Evicted model from cache: {lru_key}")
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear all cached models."""
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
self._last_used.clear()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("Cleared model cache")
|
||||
|
||||
def get_vram_status(self) -> dict:
|
||||
"""Get current VRAM usage statistics."""
|
||||
if not torch.cuda.is_available():
|
||||
return {"available": False}
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"allocated_gb": torch.cuda.memory_allocated() / (1024**3),
|
||||
"reserved_gb": torch.cuda.memory_reserved() / (1024**3),
|
||||
"total_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
|
||||
"cached_models": list(self._cache.keys()),
|
||||
}
|
||||
|
||||
def preload_model(self, model_name: str):
|
||||
"""Preload a model into cache."""
|
||||
self.get_upsampler(
|
||||
model_name,
|
||||
tile_size=config.gpu.tile_size,
|
||||
tile_pad=config.gpu.tile_pad,
|
||||
half=config.gpu.half,
|
||||
)
|
||||
|
||||
|
||||
# Global model manager instance
|
||||
model_manager = ModelManager()
|
||||
341
src/processing/upscaler.py
Normal file
341
src/processing/upscaler.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""High-level upscaling interface for Real-ESRGAN."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from src.config import MODEL_REGISTRY, config
|
||||
from src.processing.models import model_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpscaleResult:
|
||||
"""Result of an upscaling operation."""
|
||||
|
||||
image: np.ndarray # BGR format
|
||||
input_width: int
|
||||
input_height: int
|
||||
output_width: int
|
||||
output_height: int
|
||||
model_used: str
|
||||
scale_factor: int
|
||||
processing_time: float
|
||||
used_tiling: bool
|
||||
|
||||
|
||||
class UpscalerError(Exception):
|
||||
"""Error during upscaling."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def calculate_tile_size(width: int, height: int) -> int:
|
||||
"""
|
||||
Calculate optimal tile size based on image dimensions.
|
||||
|
||||
For RTX 4090 with 24GB VRAM:
|
||||
- No tiling for images up to ~4K
|
||||
- 512px tiles for larger images
|
||||
"""
|
||||
pixels = width * height
|
||||
|
||||
# Thresholds based on testing with RealESRGAN_x4plus
|
||||
if pixels <= 2073600: # Up to 1080p (1920x1080)
|
||||
return 0 # No tiling
|
||||
elif pixels <= 8294400: # Up to 4K (3840x2160)
|
||||
return 0 # RTX 4090 can handle without tiling
|
||||
elif pixels <= 33177600: # Up to 8K (7680x4320)
|
||||
return 512
|
||||
else:
|
||||
return 256 # Very large images
|
||||
|
||||
|
||||
def load_image(
|
||||
input_path: Union[str, Path, np.ndarray, Image.Image]
|
||||
) -> tuple[np.ndarray, str]:
|
||||
"""
|
||||
Load image from various sources.
|
||||
|
||||
Args:
|
||||
input_path: Path to image, numpy array, or PIL Image
|
||||
|
||||
Returns:
|
||||
Tuple of (BGR numpy array, source description)
|
||||
"""
|
||||
if isinstance(input_path, np.ndarray):
|
||||
# Assume already BGR if 3 channels
|
||||
if len(input_path.shape) == 3 and input_path.shape[2] == 3:
|
||||
return input_path, "numpy"
|
||||
elif len(input_path.shape) == 3 and input_path.shape[2] == 4:
|
||||
# RGBA to BGR
|
||||
return cv2.cvtColor(input_path, cv2.COLOR_RGBA2BGR), "numpy_rgba"
|
||||
else:
|
||||
raise UpscalerError(f"Unsupported array shape: {input_path.shape}")
|
||||
|
||||
if isinstance(input_path, Image.Image):
|
||||
# PIL to numpy BGR
|
||||
img = np.array(input_path)
|
||||
if img.shape[2] == 4:
|
||||
return cv2.cvtColor(img, cv2.COLOR_RGBA2BGR), "pil_rgba"
|
||||
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR), "pil"
|
||||
|
||||
# Load from file
|
||||
input_path = Path(input_path)
|
||||
if not input_path.exists():
|
||||
raise UpscalerError(f"Image not found: {input_path}")
|
||||
|
||||
img = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
|
||||
if img is None:
|
||||
raise UpscalerError(f"Failed to load image: {input_path}")
|
||||
|
||||
# Handle different channel counts
|
||||
if len(img.shape) == 2:
|
||||
# Grayscale to BGR
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif img.shape[2] == 4:
|
||||
# RGBA - keep alpha for later
|
||||
pass # RealESRGAN handles alpha internally
|
||||
|
||||
return img, str(input_path)
|
||||
|
||||
|
||||
def save_image(
|
||||
image: np.ndarray,
|
||||
output_path: Union[str, Path],
|
||||
format: str = "png",
|
||||
quality: int = 95,
|
||||
) -> Path:
|
||||
"""
|
||||
Save image to file.
|
||||
|
||||
Args:
|
||||
image: BGR numpy array
|
||||
output_path: Output path (extension will be added/replaced)
|
||||
format: Output format (png, jpg, webp)
|
||||
quality: JPEG/WebP quality (0-100)
|
||||
|
||||
Returns:
|
||||
Path to saved file
|
||||
"""
|
||||
output_path = Path(output_path)
|
||||
|
||||
# Ensure correct extension
|
||||
format = format.lower()
|
||||
if format == "jpg":
|
||||
format = "jpeg"
|
||||
|
||||
ext = f".{format}" if format != "jpeg" else ".jpg"
|
||||
output_path = output_path.with_suffix(ext)
|
||||
|
||||
# Ensure output directory exists
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save with appropriate settings
|
||||
if format == "png":
|
||||
cv2.imwrite(str(output_path), image, [cv2.IMWRITE_PNG_COMPRESSION, 6])
|
||||
elif format == "jpeg":
|
||||
cv2.imwrite(str(output_path), image, [cv2.IMWRITE_JPEG_QUALITY, quality])
|
||||
elif format == "webp":
|
||||
cv2.imwrite(str(output_path), image, [cv2.IMWRITE_WEBP_QUALITY, quality])
|
||||
else:
|
||||
cv2.imwrite(str(output_path), image)
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def upscale_image(
|
||||
input_image: Union[str, Path, np.ndarray, Image.Image],
|
||||
model_name: str = None,
|
||||
scale: int = None,
|
||||
tile_size: int = None,
|
||||
denoise_strength: float = 0.5,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> UpscaleResult:
|
||||
"""
|
||||
Upscale a single image.
|
||||
|
||||
Args:
|
||||
input_image: Input image (path, numpy array, or PIL Image)
|
||||
model_name: Model to use (defaults to config)
|
||||
scale: Output scale (2 or 4, defaults to model's native scale)
|
||||
tile_size: Tile size (0 for auto, None for config default)
|
||||
denoise_strength: Denoise strength for supported models (0-1)
|
||||
progress_callback: Optional callback(progress, stage) for progress
|
||||
|
||||
Returns:
|
||||
UpscaleResult with upscaled image and metadata
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Use defaults
|
||||
model_name = model_name or config.default_model
|
||||
model_info = MODEL_REGISTRY.get(model_name)
|
||||
if not model_info:
|
||||
raise UpscalerError(f"Unknown model: {model_name}")
|
||||
|
||||
# Load image
|
||||
if progress_callback:
|
||||
progress_callback(0.1, "Loading image...")
|
||||
|
||||
img, source = load_image(input_image)
|
||||
input_height, input_width = img.shape[:2]
|
||||
|
||||
logger.info(f"Input: {input_width}x{input_height} from {source}")
|
||||
|
||||
# Calculate tile size if auto
|
||||
if tile_size is None:
|
||||
tile_size = config.gpu.tile_size
|
||||
if tile_size == 0:
|
||||
# Auto-calculate based on image size
|
||||
tile_size = calculate_tile_size(input_width, input_height)
|
||||
|
||||
used_tiling = tile_size > 0
|
||||
|
||||
# Get upsampler
|
||||
if progress_callback:
|
||||
progress_callback(0.2, "Loading model...")
|
||||
|
||||
upsampler = model_manager.get_upsampler(
|
||||
model_name=model_name,
|
||||
tile_size=tile_size,
|
||||
tile_pad=config.gpu.tile_pad,
|
||||
half=config.gpu.half,
|
||||
)
|
||||
|
||||
# Determine output scale
|
||||
if scale is None:
|
||||
scale = model_info.scale
|
||||
outscale = scale / model_info.netscale if scale != model_info.netscale else None
|
||||
|
||||
# Process
|
||||
if progress_callback:
|
||||
progress_callback(0.3, "Upscaling...")
|
||||
|
||||
try:
|
||||
# RealESRGAN enhance method
|
||||
if model_name == "realesr-general-x4v3":
|
||||
# This model supports denoise strength
|
||||
output, _ = upsampler.enhance(
|
||||
img, outscale=outscale, dnoise_strength=denoise_strength
|
||||
)
|
||||
else:
|
||||
output, _ = upsampler.enhance(img, outscale=outscale)
|
||||
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
# Clear cache and retry with tiling
|
||||
torch.cuda.empty_cache()
|
||||
logger.warning("VRAM exceeded, retrying with tiling...")
|
||||
|
||||
upsampler = model_manager.get_upsampler(
|
||||
model_name=model_name,
|
||||
tile_size=512, # Force tiling
|
||||
tile_pad=config.gpu.tile_pad,
|
||||
half=config.gpu.half,
|
||||
)
|
||||
|
||||
if model_name == "realesr-general-x4v3":
|
||||
output, _ = upsampler.enhance(
|
||||
img, outscale=outscale, dnoise_strength=denoise_strength
|
||||
)
|
||||
else:
|
||||
output, _ = upsampler.enhance(img, outscale=outscale)
|
||||
|
||||
used_tiling = True
|
||||
|
||||
except Exception as e:
|
||||
raise UpscalerError(f"Upscaling failed: {e}") from e
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(0.9, "Finalizing...")
|
||||
|
||||
output_height, output_width = output.shape[:2]
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
logger.info(
|
||||
f"Output: {output_width}x{output_height}, "
|
||||
f"time: {processing_time:.2f}s, tiling: {used_tiling}"
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(1.0, "Complete!")
|
||||
|
||||
return UpscaleResult(
|
||||
image=output,
|
||||
input_width=input_width,
|
||||
input_height=input_height,
|
||||
output_width=output_width,
|
||||
output_height=output_height,
|
||||
model_used=model_name,
|
||||
scale_factor=scale,
|
||||
processing_time=processing_time,
|
||||
used_tiling=used_tiling,
|
||||
)
|
||||
|
||||
|
||||
def upscale_and_save(
|
||||
input_path: Union[str, Path],
|
||||
output_path: Union[str, Path],
|
||||
model_name: str = None,
|
||||
scale: int = None,
|
||||
output_format: str = "png",
|
||||
denoise_strength: float = 0.5,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> tuple[Path, UpscaleResult]:
|
||||
"""
|
||||
Upscale image and save to file.
|
||||
|
||||
Args:
|
||||
input_path: Input image path
|
||||
output_path: Output path (extension may be adjusted)
|
||||
model_name: Model to use
|
||||
scale: Output scale
|
||||
output_format: Output format (png, jpg, webp)
|
||||
denoise_strength: Denoise strength for supported models
|
||||
progress_callback: Optional progress callback
|
||||
|
||||
Returns:
|
||||
Tuple of (saved path, UpscaleResult)
|
||||
"""
|
||||
result = upscale_image(
|
||||
input_path,
|
||||
model_name=model_name,
|
||||
scale=scale,
|
||||
denoise_strength=denoise_strength,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
saved_path = save_image(result.image, output_path, format=output_format)
|
||||
|
||||
return saved_path, result
|
||||
|
||||
|
||||
def get_output_dimensions(
|
||||
width: int, height: int, model_name: str, scale: int = None
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Calculate output dimensions for given input and model.
|
||||
|
||||
Args:
|
||||
width: Input width
|
||||
height: Input height
|
||||
model_name: Model name
|
||||
scale: Target scale (uses model default if None)
|
||||
|
||||
Returns:
|
||||
Tuple of (output_width, output_height)
|
||||
"""
|
||||
model_info = MODEL_REGISTRY.get(model_name)
|
||||
if not model_info:
|
||||
raise ValueError(f"Unknown model: {model_name}")
|
||||
|
||||
scale = scale or model_info.scale
|
||||
return width * scale, height * scale
|
||||
1
src/storage/__init__.py
Normal file
1
src/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Storage and persistence modules (SQLite, history, queue)."""
|
||||
141
src/storage/database.py
Normal file
141
src/storage/database.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""SQLite database setup and utilities."""
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
from src.config import DATABASE_PATH
|
||||
|
||||
|
||||
# Thread-safe connection management
|
||||
_db_lock = Lock()
|
||||
|
||||
|
||||
def get_connection() -> sqlite3.Connection:
|
||||
"""Get a database connection with proper settings."""
|
||||
conn = sqlite3.connect(str(DATABASE_PATH), timeout=30.0)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
conn.execute("PRAGMA journal_mode = WAL") # Better concurrent access
|
||||
return conn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db():
|
||||
"""Context manager for database connections."""
|
||||
conn = get_connection()
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def init_database():
|
||||
"""Initialize database schema."""
|
||||
with _db_lock:
|
||||
with get_db() as conn:
|
||||
# History table for processed items
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS history (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
input_path TEXT NOT NULL,
|
||||
input_filename TEXT NOT NULL,
|
||||
output_path TEXT NOT NULL,
|
||||
output_filename TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
scale INTEGER NOT NULL,
|
||||
face_enhance INTEGER NOT NULL DEFAULT 0,
|
||||
video_codec TEXT,
|
||||
processing_time_seconds REAL NOT NULL,
|
||||
input_width INTEGER NOT NULL,
|
||||
input_height INTEGER NOT NULL,
|
||||
output_width INTEGER NOT NULL,
|
||||
output_height INTEGER NOT NULL,
|
||||
frames_count INTEGER,
|
||||
input_size_bytes INTEGER NOT NULL,
|
||||
output_size_bytes INTEGER NOT NULL,
|
||||
thumbnail_path TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
metadata_json TEXT DEFAULT '{}'
|
||||
)
|
||||
""")
|
||||
|
||||
# Index for efficient queries
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_history_created_at
|
||||
ON history(created_at DESC)
|
||||
""")
|
||||
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_history_type
|
||||
ON history(type)
|
||||
""")
|
||||
|
||||
# Queue table for batch processing
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS queue (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
input_path TEXT NOT NULL,
|
||||
output_path TEXT,
|
||||
config_json TEXT NOT NULL,
|
||||
progress_percent REAL DEFAULT 0,
|
||||
current_stage TEXT DEFAULT '',
|
||||
frames_processed INTEGER DEFAULT 0,
|
||||
total_frames INTEGER DEFAULT 0,
|
||||
error_message TEXT,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
priority INTEGER DEFAULT 5,
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_queue_status
|
||||
ON queue(status)
|
||||
""")
|
||||
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_queue_priority_created
|
||||
ON queue(priority ASC, created_at ASC)
|
||||
""")
|
||||
|
||||
# Checkpoints table for video resume
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS checkpoints (
|
||||
job_id TEXT PRIMARY KEY,
|
||||
input_path TEXT NOT NULL,
|
||||
input_hash TEXT NOT NULL,
|
||||
total_frames INTEGER NOT NULL,
|
||||
processed_frames INTEGER NOT NULL DEFAULT 0,
|
||||
last_processed_frame INTEGER NOT NULL DEFAULT -1,
|
||||
frames_dir TEXT NOT NULL,
|
||||
audio_path TEXT,
|
||||
config_json TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
def reset_database():
|
||||
"""Reset database (for development/testing)."""
|
||||
with _db_lock:
|
||||
with get_db() as conn:
|
||||
conn.execute("DROP TABLE IF EXISTS history")
|
||||
conn.execute("DROP TABLE IF EXISTS queue")
|
||||
conn.execute("DROP TABLE IF EXISTS checkpoints")
|
||||
init_database()
|
||||
|
||||
|
||||
# Initialize on import
|
||||
init_database()
|
||||
282
src/storage/history.py
Normal file
282
src/storage/history.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Processing history management with SQLite persistence."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from src.storage.database import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistoryItem:
|
||||
"""Record of a processed image/video."""
|
||||
|
||||
id: str
|
||||
type: str # "image" or "video"
|
||||
input_path: str
|
||||
input_filename: str
|
||||
output_path: str
|
||||
output_filename: str
|
||||
model: str
|
||||
scale: int
|
||||
face_enhance: bool
|
||||
video_codec: Optional[str]
|
||||
processing_time_seconds: float
|
||||
input_width: int
|
||||
input_height: int
|
||||
output_width: int
|
||||
output_height: int
|
||||
frames_count: Optional[int]
|
||||
input_size_bytes: int
|
||||
output_size_bytes: int
|
||||
thumbnail_path: Optional[str]
|
||||
created_at: datetime
|
||||
metadata: dict
|
||||
|
||||
|
||||
class HistoryManager:
|
||||
"""
|
||||
Manage processing history with SQLite persistence.
|
||||
|
||||
Features:
|
||||
- Store all processed items
|
||||
- Search and filter
|
||||
- Thumbnails for quick preview
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("Initialized history manager")
|
||||
|
||||
def add_item(
|
||||
self,
|
||||
type: str,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
model: str,
|
||||
scale: int,
|
||||
face_enhance: bool = False,
|
||||
video_codec: Optional[str] = None,
|
||||
processing_time_seconds: float = 0,
|
||||
input_width: int = 0,
|
||||
input_height: int = 0,
|
||||
output_width: int = 0,
|
||||
output_height: int = 0,
|
||||
frames_count: Optional[int] = None,
|
||||
thumbnail_path: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> HistoryItem:
|
||||
"""Add item to history."""
|
||||
item_id = uuid.uuid4().hex[:12]
|
||||
input_path = Path(input_path)
|
||||
output_path = Path(output_path)
|
||||
|
||||
# Get file sizes
|
||||
input_size = input_path.stat().st_size if input_path.exists() else 0
|
||||
output_size = output_path.stat().st_size if output_path.exists() else 0
|
||||
|
||||
with get_db() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO history (
|
||||
id, type, input_path, input_filename,
|
||||
output_path, output_filename, model, scale,
|
||||
face_enhance, video_codec, processing_time_seconds,
|
||||
input_width, input_height, output_width, output_height,
|
||||
frames_count, input_size_bytes, output_size_bytes,
|
||||
thumbnail_path, created_at, metadata_json
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
item_id,
|
||||
type,
|
||||
str(input_path),
|
||||
input_path.name,
|
||||
str(output_path),
|
||||
output_path.name,
|
||||
model,
|
||||
scale,
|
||||
int(face_enhance),
|
||||
video_codec,
|
||||
processing_time_seconds,
|
||||
input_width,
|
||||
input_height,
|
||||
output_width,
|
||||
output_height,
|
||||
frames_count,
|
||||
input_size,
|
||||
output_size,
|
||||
thumbnail_path,
|
||||
datetime.utcnow().isoformat(),
|
||||
json.dumps(metadata or {}),
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"Added to history: {item_id} ({type})")
|
||||
|
||||
return HistoryItem(
|
||||
id=item_id,
|
||||
type=type,
|
||||
input_path=str(input_path),
|
||||
input_filename=input_path.name,
|
||||
output_path=str(output_path),
|
||||
output_filename=output_path.name,
|
||||
model=model,
|
||||
scale=scale,
|
||||
face_enhance=face_enhance,
|
||||
video_codec=video_codec,
|
||||
processing_time_seconds=processing_time_seconds,
|
||||
input_width=input_width,
|
||||
input_height=input_height,
|
||||
output_width=output_width,
|
||||
output_height=output_height,
|
||||
frames_count=frames_count,
|
||||
input_size_bytes=input_size,
|
||||
output_size_bytes=output_size,
|
||||
thumbnail_path=thumbnail_path,
|
||||
created_at=datetime.utcnow(),
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
def get_item(self, item_id: str) -> Optional[HistoryItem]:
|
||||
"""Get history item by ID."""
|
||||
with get_db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM history WHERE id = ?",
|
||||
(item_id,),
|
||||
).fetchone()
|
||||
|
||||
if row:
|
||||
return self._row_to_item(row)
|
||||
return None
|
||||
|
||||
def get_recent(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
type_filter: Optional[str] = None,
|
||||
) -> list[HistoryItem]:
|
||||
"""Get recent history items."""
|
||||
with get_db() as conn:
|
||||
if type_filter:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM history
|
||||
WHERE type = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(type_filter, limit, offset),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM history
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(limit, offset),
|
||||
).fetchall()
|
||||
|
||||
return [self._row_to_item(row) for row in rows]
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 50,
|
||||
) -> list[HistoryItem]:
|
||||
"""Search history by filename."""
|
||||
with get_db() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM history
|
||||
WHERE input_filename LIKE ? OR output_filename LIKE ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(f"%{query}%", f"%{query}%", limit),
|
||||
).fetchall()
|
||||
|
||||
return [self._row_to_item(row) for row in rows]
|
||||
|
||||
def delete_item(self, item_id: str) -> bool:
|
||||
"""Delete history item."""
|
||||
with get_db() as conn:
|
||||
result = conn.execute(
|
||||
"DELETE FROM history WHERE id = ?",
|
||||
(item_id,),
|
||||
)
|
||||
deleted = result.rowcount > 0
|
||||
|
||||
if deleted:
|
||||
logger.info(f"Deleted from history: {item_id}")
|
||||
|
||||
return deleted
|
||||
|
||||
def get_statistics(self) -> dict:
|
||||
"""Get history statistics."""
|
||||
with get_db() as conn:
|
||||
stats = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN type = 'image' THEN 1 ELSE 0 END) as images,
|
||||
SUM(CASE WHEN type = 'video' THEN 1 ELSE 0 END) as videos,
|
||||
SUM(processing_time_seconds) as total_time,
|
||||
SUM(input_size_bytes) as total_input_size,
|
||||
SUM(output_size_bytes) as total_output_size,
|
||||
SUM(frames_count) as total_frames
|
||||
FROM history
|
||||
"""
|
||||
).fetchone()
|
||||
|
||||
return {
|
||||
"total_items": stats["total"] or 0,
|
||||
"images": stats["images"] or 0,
|
||||
"videos": stats["videos"] or 0,
|
||||
"total_processing_time": stats["total_time"] or 0,
|
||||
"total_input_size_mb": (stats["total_input_size"] or 0) / (1024 * 1024),
|
||||
"total_output_size_mb": (stats["total_output_size"] or 0) / (1024 * 1024),
|
||||
"total_frames_processed": stats["total_frames"] or 0,
|
||||
}
|
||||
|
||||
def clear_history(self) -> int:
|
||||
"""Clear all history."""
|
||||
with get_db() as conn:
|
||||
result = conn.execute("DELETE FROM history")
|
||||
return result.rowcount
|
||||
|
||||
def _row_to_item(self, row) -> HistoryItem:
|
||||
"""Convert database row to HistoryItem."""
|
||||
return HistoryItem(
|
||||
id=row["id"],
|
||||
type=row["type"],
|
||||
input_path=row["input_path"],
|
||||
input_filename=row["input_filename"],
|
||||
output_path=row["output_path"],
|
||||
output_filename=row["output_filename"],
|
||||
model=row["model"],
|
||||
scale=row["scale"],
|
||||
face_enhance=bool(row["face_enhance"]),
|
||||
video_codec=row["video_codec"],
|
||||
processing_time_seconds=row["processing_time_seconds"],
|
||||
input_width=row["input_width"],
|
||||
input_height=row["input_height"],
|
||||
output_width=row["output_width"],
|
||||
output_height=row["output_height"],
|
||||
frames_count=row["frames_count"],
|
||||
input_size_bytes=row["input_size_bytes"],
|
||||
output_size_bytes=row["output_size_bytes"],
|
||||
thumbnail_path=row["thumbnail_path"],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
metadata=json.loads(row["metadata_json"]),
|
||||
)
|
||||
|
||||
|
||||
# Global history manager instance
|
||||
history_manager = HistoryManager()
|
||||
332
src/storage/queue.py
Normal file
332
src/storage/queue.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""Job queue management with SQLite persistence."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from src.storage.database import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobStatus(Enum):
|
||||
"""Job status enumeration."""
|
||||
|
||||
PENDING = "pending"
|
||||
QUEUED = "queued"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class JobType(Enum):
|
||||
"""Job type enumeration."""
|
||||
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Job:
|
||||
"""Processing job data."""
|
||||
|
||||
id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
|
||||
type: JobType = JobType.IMAGE
|
||||
status: JobStatus = JobStatus.PENDING
|
||||
input_path: str = ""
|
||||
output_path: Optional[str] = None
|
||||
config: dict = field(default_factory=dict)
|
||||
progress_percent: float = 0.0
|
||||
current_stage: str = ""
|
||||
frames_processed: int = 0
|
||||
total_frames: int = 0
|
||||
error_message: Optional[str] = None
|
||||
priority: int = 5
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class QueueManager:
|
||||
"""
|
||||
Manage job queue with SQLite persistence.
|
||||
|
||||
Features:
|
||||
- Priority-based queue ordering
|
||||
- Persistent across restarts
|
||||
- Progress tracking
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("Initialized queue manager")
|
||||
|
||||
def add_job(self, job: Job) -> Job:
|
||||
"""Add job to queue."""
|
||||
with get_db() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO queue (
|
||||
id, type, status, input_path, output_path,
|
||||
config_json, progress_percent, current_stage,
|
||||
frames_processed, total_frames, error_message,
|
||||
priority, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
job.id,
|
||||
job.type.value,
|
||||
JobStatus.QUEUED.value,
|
||||
job.input_path,
|
||||
job.output_path,
|
||||
json.dumps(job.config),
|
||||
job.progress_percent,
|
||||
job.current_stage,
|
||||
job.frames_processed,
|
||||
job.total_frames,
|
||||
job.error_message,
|
||||
job.priority,
|
||||
job.created_at.isoformat(),
|
||||
),
|
||||
)
|
||||
|
||||
job.status = JobStatus.QUEUED
|
||||
logger.info(f"Added job to queue: {job.id}")
|
||||
return job
|
||||
|
||||
def get_next_job(self) -> Optional[Job]:
|
||||
"""Get next job to process (highest priority, oldest first)."""
|
||||
with get_db() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT * FROM queue
|
||||
WHERE status = ?
|
||||
ORDER BY priority ASC, created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(JobStatus.QUEUED.value,),
|
||||
).fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
# Mark as processing
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE queue SET status = ?, started_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
JobStatus.PROCESSING.value,
|
||||
datetime.utcnow().isoformat(),
|
||||
row["id"],
|
||||
),
|
||||
)
|
||||
|
||||
return self._row_to_job(row)
|
||||
|
||||
def update_progress(
|
||||
self,
|
||||
job_id: str,
|
||||
progress_percent: float,
|
||||
current_stage: str = "",
|
||||
frames_processed: int = 0,
|
||||
total_frames: int = 0,
|
||||
) -> None:
|
||||
"""Update job progress."""
|
||||
with get_db() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE queue SET
|
||||
progress_percent = ?,
|
||||
current_stage = ?,
|
||||
frames_processed = ?,
|
||||
total_frames = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(progress_percent, current_stage, frames_processed, total_frames, job_id),
|
||||
)
|
||||
|
||||
def complete_job(self, job_id: str, output_path: str) -> None:
|
||||
"""Mark job as completed."""
|
||||
with get_db() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE queue SET
|
||||
status = ?,
|
||||
output_path = ?,
|
||||
completed_at = ?,
|
||||
progress_percent = 100
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
JobStatus.COMPLETED.value,
|
||||
output_path,
|
||||
datetime.utcnow().isoformat(),
|
||||
job_id,
|
||||
),
|
||||
)
|
||||
logger.info(f"Job completed: {job_id}")
|
||||
|
||||
def fail_job(self, job_id: str, error_message: str) -> None:
|
||||
"""Mark job as failed."""
|
||||
with get_db() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE queue SET
|
||||
status = ?,
|
||||
error_message = ?,
|
||||
completed_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
JobStatus.FAILED.value,
|
||||
error_message,
|
||||
datetime.utcnow().isoformat(),
|
||||
job_id,
|
||||
),
|
||||
)
|
||||
logger.error(f"Job failed: {job_id} - {error_message}")
|
||||
|
||||
def cancel_job(self, job_id: str) -> None:
|
||||
"""Cancel a job."""
|
||||
with get_db() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE queue SET status = ?
|
||||
WHERE id = ? AND status IN (?, ?)
|
||||
""",
|
||||
(
|
||||
JobStatus.CANCELLED.value,
|
||||
job_id,
|
||||
JobStatus.PENDING.value,
|
||||
JobStatus.QUEUED.value,
|
||||
),
|
||||
)
|
||||
logger.info(f"Job cancelled: {job_id}")
|
||||
|
||||
def get_job(self, job_id: str) -> Optional[Job]:
|
||||
"""Get job by ID."""
|
||||
with get_db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM queue WHERE id = ?",
|
||||
(job_id,),
|
||||
).fetchone()
|
||||
|
||||
if row:
|
||||
return self._row_to_job(row)
|
||||
return None
|
||||
|
||||
def get_queue(self) -> list[Job]:
|
||||
"""Get all queued/processing jobs."""
|
||||
with get_db() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM queue
|
||||
WHERE status IN (?, ?)
|
||||
ORDER BY priority ASC, created_at ASC
|
||||
""",
|
||||
(JobStatus.QUEUED.value, JobStatus.PROCESSING.value),
|
||||
).fetchall()
|
||||
|
||||
return [self._row_to_job(row) for row in rows]
|
||||
|
||||
def get_completed_jobs(self, limit: int = 50) -> list[Job]:
|
||||
"""Get recently completed jobs."""
|
||||
with get_db() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM queue
|
||||
WHERE status IN (?, ?, ?)
|
||||
ORDER BY completed_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(
|
||||
JobStatus.COMPLETED.value,
|
||||
JobStatus.FAILED.value,
|
||||
JobStatus.CANCELLED.value,
|
||||
limit,
|
||||
),
|
||||
).fetchall()
|
||||
|
||||
return [self._row_to_job(row) for row in rows]
|
||||
|
||||
def get_queue_stats(self) -> dict:
|
||||
"""Get queue statistics."""
|
||||
with get_db() as conn:
|
||||
stats = {}
|
||||
|
||||
# Count by status
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM queue
|
||||
GROUP BY status
|
||||
"""
|
||||
).fetchall()
|
||||
|
||||
for row in rows:
|
||||
stats[row["status"]] = row["count"]
|
||||
|
||||
# Currently processing
|
||||
processing = conn.execute(
|
||||
"SELECT * FROM queue WHERE status = ?",
|
||||
(JobStatus.PROCESSING.value,),
|
||||
).fetchone()
|
||||
|
||||
stats["current_job"] = self._row_to_job(processing) if processing else None
|
||||
|
||||
return stats
|
||||
|
||||
def clear_completed(self) -> int:
|
||||
"""Clear completed/failed/cancelled jobs."""
|
||||
with get_db() as conn:
|
||||
result = conn.execute(
|
||||
"""
|
||||
DELETE FROM queue
|
||||
WHERE status IN (?, ?, ?)
|
||||
""",
|
||||
(
|
||||
JobStatus.COMPLETED.value,
|
||||
JobStatus.FAILED.value,
|
||||
JobStatus.CANCELLED.value,
|
||||
),
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
def _row_to_job(self, row) -> Job:
|
||||
"""Convert database row to Job object."""
|
||||
return Job(
|
||||
id=row["id"],
|
||||
type=JobType(row["type"]),
|
||||
status=JobStatus(row["status"]),
|
||||
input_path=row["input_path"],
|
||||
output_path=row["output_path"],
|
||||
config=json.loads(row["config_json"]),
|
||||
progress_percent=row["progress_percent"],
|
||||
current_stage=row["current_stage"],
|
||||
frames_processed=row["frames_processed"],
|
||||
total_frames=row["total_frames"],
|
||||
error_message=row["error_message"],
|
||||
priority=row["priority"],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
started_at=(
|
||||
datetime.fromisoformat(row["started_at"])
|
||||
if row["started_at"]
|
||||
else None
|
||||
),
|
||||
completed_at=(
|
||||
datetime.fromisoformat(row["completed_at"])
|
||||
if row["completed_at"]
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Global queue manager instance
|
||||
queue_manager = QueueManager()
|
||||
1
src/ui/__init__.py
Normal file
1
src/ui/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""UI components and handlers for the Gradio interface."""
|
||||
111
src/ui/app.py
Normal file
111
src/ui/app.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Main Gradio application assembly."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import gradio as gr
|
||||
|
||||
from src.config import config
|
||||
from src.ui.theme import get_theme, get_css
|
||||
from src.ui.components.image_tab import create_image_tab
|
||||
from src.ui.components.video_tab import create_video_tab
|
||||
from src.ui.components.batch_tab import create_batch_tab
|
||||
from src.ui.components.history_tab import create_history_tab
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_gpu_status() -> str:
|
||||
"""Get GPU status string for display."""
|
||||
if not torch.cuda.is_available():
|
||||
return "GPU: Not available (CPU mode)"
|
||||
|
||||
try:
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
allocated = torch.cuda.memory_allocated() / (1024**3)
|
||||
total = props.total_memory / (1024**3)
|
||||
return f"GPU: {props.name} | VRAM: {allocated:.1f}/{total:.1f} GB"
|
||||
except Exception:
|
||||
return "GPU: Status unavailable"
|
||||
|
||||
|
||||
def create_app() -> gr.Blocks:
|
||||
"""Create and return the main Gradio application."""
|
||||
theme = get_theme()
|
||||
css = get_css()
|
||||
|
||||
with gr.Blocks(
|
||||
theme=theme,
|
||||
css=css,
|
||||
title="Real-ESRGAN Upscaler",
|
||||
fill_width=True,
|
||||
) as app:
|
||||
# Header
|
||||
gr.HTML(
|
||||
"""
|
||||
<div class="app-header">
|
||||
<h1>Real-ESRGAN Upscaler</h1>
|
||||
</div>
|
||||
""",
|
||||
elem_classes=["header"],
|
||||
)
|
||||
|
||||
# Main tabs
|
||||
with gr.Tabs(elem_id="main-tabs") as tabs:
|
||||
# Image upscaling tab
|
||||
image_components = create_image_tab()
|
||||
|
||||
# Video upscaling tab
|
||||
video_components = create_video_tab()
|
||||
|
||||
# Batch queue tab
|
||||
batch_components = create_batch_tab()
|
||||
|
||||
# History tab
|
||||
history_components = create_history_tab()
|
||||
|
||||
# Status bar
|
||||
gr.HTML(
|
||||
f"""
|
||||
<div class="status-bar">
|
||||
<div class="status-item">
|
||||
<span class="status-dot"></span>
|
||||
<span>{get_gpu_status()}</span>
|
||||
</div>
|
||||
<div class="status-item">
|
||||
Queue: 0 | Ready
|
||||
</div>
|
||||
</div>
|
||||
""",
|
||||
elem_classes=["status-bar-container"],
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def launch_app():
|
||||
"""Launch the Gradio application."""
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logger.info("Starting Real-ESRGAN Web UI...")
|
||||
logger.info(get_gpu_status())
|
||||
|
||||
# Create and launch app
|
||||
app = create_app()
|
||||
|
||||
# Configure queue for GPU processing
|
||||
app.queue(
|
||||
default_concurrency_limit=config.concurrency_limit,
|
||||
max_size=config.max_queue_size,
|
||||
)
|
||||
|
||||
# Launch server
|
||||
app.launch(
|
||||
server_name=config.server_name,
|
||||
server_port=config.server_port,
|
||||
share=config.share,
|
||||
show_error=True,
|
||||
)
|
||||
1
src/ui/components/__init__.py
Normal file
1
src/ui/components/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Gradio UI components for each tab."""
|
||||
200
src/ui/components/batch_tab.py
Normal file
200
src/ui/components/batch_tab.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Batch processing queue tab component."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from src.config import get_model_choices, config
|
||||
from src.storage.queue import queue_manager, Job, JobType, JobStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_queue_data() -> list[list]:
|
||||
"""Get queue data for dataframe display."""
|
||||
jobs = queue_manager.get_queue()
|
||||
|
||||
data = []
|
||||
for job in jobs:
|
||||
status_emoji = {
|
||||
JobStatus.QUEUED: "⏳",
|
||||
JobStatus.PROCESSING: "🔄",
|
||||
JobStatus.COMPLETED: "✅",
|
||||
JobStatus.FAILED: "❌",
|
||||
JobStatus.CANCELLED: "🚫",
|
||||
}.get(job.status, "❓")
|
||||
|
||||
progress = f"{job.progress_percent:.0f}%" if job.status == JobStatus.PROCESSING else "-"
|
||||
|
||||
data.append([
|
||||
job.id[:8],
|
||||
job.type.value.title(),
|
||||
Path(job.input_path).name[:30],
|
||||
f"{status_emoji} {job.status.value.title()}",
|
||||
progress,
|
||||
job.current_stage or "-",
|
||||
])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def get_queue_stats() -> str:
|
||||
"""Get queue statistics for display."""
|
||||
stats = queue_manager.get_queue_stats()
|
||||
|
||||
queued = stats.get(JobStatus.QUEUED.value, 0)
|
||||
processing = stats.get(JobStatus.PROCESSING.value, 0)
|
||||
completed = stats.get(JobStatus.COMPLETED.value, 0)
|
||||
failed = stats.get(JobStatus.FAILED.value, 0)
|
||||
|
||||
current = stats.get("current_job")
|
||||
current_info = ""
|
||||
if current:
|
||||
current_info = f"\n\n**Currently Processing:**\n{Path(current.input_path).name}"
|
||||
|
||||
return (
|
||||
f"**Queue Status**\n\n"
|
||||
f"- Queued: {queued}\n"
|
||||
f"- Processing: {processing}\n"
|
||||
f"- Completed: {completed}\n"
|
||||
f"- Failed: {failed}"
|
||||
f"{current_info}"
|
||||
)
|
||||
|
||||
|
||||
def add_files_to_queue(
|
||||
files: Optional[list],
|
||||
model_name: str,
|
||||
scale: int,
|
||||
face_enhance: bool,
|
||||
) -> tuple[list[list], str]:
|
||||
"""Add uploaded files to queue."""
|
||||
if not files:
|
||||
return get_queue_data(), "No files selected"
|
||||
|
||||
added = 0
|
||||
for file in files:
|
||||
file_path = Path(file.name if hasattr(file, 'name') else file)
|
||||
|
||||
# Determine job type
|
||||
ext = file_path.suffix.lower()
|
||||
if ext in [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff"]:
|
||||
job_type = JobType.IMAGE
|
||||
elif ext in [".mp4", ".avi", ".mov", ".mkv", ".webm"]:
|
||||
job_type = JobType.VIDEO
|
||||
else:
|
||||
continue
|
||||
|
||||
job = Job(
|
||||
type=job_type,
|
||||
input_path=str(file_path),
|
||||
config={
|
||||
"model_name": model_name,
|
||||
"scale": scale,
|
||||
"face_enhance": face_enhance,
|
||||
},
|
||||
)
|
||||
|
||||
queue_manager.add_job(job)
|
||||
added += 1
|
||||
|
||||
return get_queue_data(), f"Added {added} items to queue"
|
||||
|
||||
|
||||
def cancel_job(job_id: str) -> tuple[list[list], str]:
|
||||
"""Cancel a job."""
|
||||
if job_id:
|
||||
queue_manager.cancel_job(job_id)
|
||||
return get_queue_data(), f"Cancelled job: {job_id}"
|
||||
return get_queue_data(), "No job selected"
|
||||
|
||||
|
||||
def clear_completed() -> tuple[list[list], str]:
|
||||
"""Clear completed jobs from queue."""
|
||||
count = queue_manager.clear_completed()
|
||||
return get_queue_data(), f"Cleared {count} completed jobs"
|
||||
|
||||
|
||||
def create_batch_tab():
|
||||
"""Create the batch queue tab component."""
|
||||
with gr.Tab("Batch Queue", id="queue-tab"):
|
||||
gr.Markdown("## Batch Processing Queue")
|
||||
|
||||
with gr.Row():
|
||||
# Queue table
|
||||
with gr.Column(scale=2):
|
||||
queue_table = gr.Dataframe(
|
||||
headers=["ID", "Type", "Input", "Status", "Progress", "Stage"],
|
||||
datatype=["str", "str", "str", "str", "str", "str"],
|
||||
value=get_queue_data(),
|
||||
label="Queue",
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
# Queue stats
|
||||
with gr.Column(scale=1):
|
||||
queue_stats = gr.Markdown(
|
||||
get_queue_stats(),
|
||||
elem_classes=["info-card"],
|
||||
)
|
||||
|
||||
# Add to queue section
|
||||
with gr.Accordion("Add to Queue", open=True):
|
||||
with gr.Row():
|
||||
file_input = gr.File(
|
||||
label="Select Files",
|
||||
file_count="multiple",
|
||||
file_types=["image", "video"],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
model_dropdown = gr.Dropdown(
|
||||
choices=get_model_choices(),
|
||||
value=config.default_model,
|
||||
label="Model",
|
||||
)
|
||||
|
||||
scale_radio = gr.Radio(
|
||||
choices=[2, 4],
|
||||
value=config.default_scale,
|
||||
label="Scale",
|
||||
)
|
||||
|
||||
face_enhance_cb = gr.Checkbox(
|
||||
value=False,
|
||||
label="Face Enhancement",
|
||||
)
|
||||
|
||||
add_btn = gr.Button("Add to Queue", variant="primary")
|
||||
|
||||
# Queue actions
|
||||
with gr.Row():
|
||||
refresh_btn = gr.Button("Refresh", variant="secondary")
|
||||
clear_btn = gr.Button("Clear Completed", variant="secondary")
|
||||
|
||||
status_text = gr.Markdown("")
|
||||
|
||||
# Event handlers
|
||||
add_btn.click(
|
||||
fn=add_files_to_queue,
|
||||
inputs=[file_input, model_dropdown, scale_radio, face_enhance_cb],
|
||||
outputs=[queue_table, status_text],
|
||||
)
|
||||
|
||||
refresh_btn.click(
|
||||
fn=lambda: (get_queue_data(), get_queue_stats()),
|
||||
outputs=[queue_table, queue_stats],
|
||||
)
|
||||
|
||||
clear_btn.click(
|
||||
fn=clear_completed,
|
||||
outputs=[queue_table, status_text],
|
||||
)
|
||||
|
||||
return {
|
||||
"queue_table": queue_table,
|
||||
"queue_stats": queue_stats,
|
||||
"status_text": status_text,
|
||||
}
|
||||
208
src/ui/components/history_tab.py
Normal file
208
src/ui/components/history_tab.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""History gallery tab component."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from src.storage.history import history_manager, HistoryItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_history_gallery() -> list[tuple[str, str]]:
|
||||
"""Get history items for gallery display."""
|
||||
items = history_manager.get_recent(limit=100)
|
||||
|
||||
gallery_items = []
|
||||
for item in items:
|
||||
# Use output path for display
|
||||
output_path = Path(item.output_path)
|
||||
if output_path.exists() and item.type == "image":
|
||||
caption = f"{item.output_filename}\n{item.output_width}x{item.output_height}"
|
||||
gallery_items.append((str(output_path), caption))
|
||||
|
||||
return gallery_items
|
||||
|
||||
|
||||
def get_history_stats() -> str:
|
||||
"""Get history statistics."""
|
||||
stats = history_manager.get_statistics()
|
||||
|
||||
return (
|
||||
f"**History Statistics**\n\n"
|
||||
f"- Total Items: {stats['total_items']}\n"
|
||||
f"- Images: {stats['images']}\n"
|
||||
f"- Videos: {stats['videos']}\n"
|
||||
f"- Total Processing Time: {stats['total_processing_time']/60:.1f} min\n"
|
||||
f"- Total Input Size: {stats['total_input_size_mb']:.1f} MB\n"
|
||||
f"- Total Output Size: {stats['total_output_size_mb']:.1f} MB"
|
||||
)
|
||||
|
||||
|
||||
def search_history(query: str) -> list[tuple[str, str]]:
|
||||
"""Search history by filename."""
|
||||
if not query:
|
||||
return get_history_gallery()
|
||||
|
||||
items = history_manager.search(query)
|
||||
|
||||
gallery_items = []
|
||||
for item in items:
|
||||
output_path = Path(item.output_path)
|
||||
if output_path.exists() and item.type == "image":
|
||||
caption = f"{item.output_filename}\n{item.output_width}x{item.output_height}"
|
||||
gallery_items.append((str(output_path), caption))
|
||||
|
||||
return gallery_items
|
||||
|
||||
|
||||
def get_item_details(evt: gr.SelectData) -> tuple[Optional[tuple], str]:
|
||||
"""Get details for selected history item."""
|
||||
if evt.index is None:
|
||||
return None, "Select an item to view details"
|
||||
|
||||
items = history_manager.get_recent(limit=100)
|
||||
|
||||
# Filter to only images (same as gallery)
|
||||
image_items = [i for i in items if i.type == "image" and Path(i.output_path).exists()]
|
||||
|
||||
if evt.index >= len(image_items):
|
||||
return None, "Item not found"
|
||||
|
||||
item = image_items[evt.index]
|
||||
|
||||
# Load images for slider
|
||||
input_path = Path(item.input_path)
|
||||
output_path = Path(item.output_path)
|
||||
|
||||
slider_images = None
|
||||
if input_path.exists() and output_path.exists():
|
||||
slider_images = (str(input_path), str(output_path))
|
||||
|
||||
# Format details
|
||||
details = (
|
||||
f"**{item.output_filename}**\n\n"
|
||||
f"- **Type:** {item.type.title()}\n"
|
||||
f"- **Model:** {item.model}\n"
|
||||
f"- **Scale:** {item.scale}x\n"
|
||||
f"- **Face Enhancement:** {'Yes' if item.face_enhance else 'No'}\n"
|
||||
f"- **Input:** {item.input_width}x{item.input_height}\n"
|
||||
f"- **Output:** {item.output_width}x{item.output_height}\n"
|
||||
f"- **Processing Time:** {item.processing_time_seconds:.1f}s\n"
|
||||
f"- **Input Size:** {item.input_size_bytes/1024:.1f} KB\n"
|
||||
f"- **Output Size:** {item.output_size_bytes/1024:.1f} KB\n"
|
||||
f"- **Created:** {item.created_at.strftime('%Y-%m-%d %H:%M')}"
|
||||
)
|
||||
|
||||
return slider_images, details
|
||||
|
||||
|
||||
def delete_selected(evt: gr.SelectData) -> tuple[list, str]:
|
||||
"""Delete selected history item."""
|
||||
if evt.index is None:
|
||||
return get_history_gallery(), "No item selected"
|
||||
|
||||
items = history_manager.get_recent(limit=100)
|
||||
image_items = [i for i in items if i.type == "image" and Path(i.output_path).exists()]
|
||||
|
||||
if evt.index >= len(image_items):
|
||||
return get_history_gallery(), "Item not found"
|
||||
|
||||
item = image_items[evt.index]
|
||||
history_manager.delete_item(item.id)
|
||||
|
||||
return get_history_gallery(), f"Deleted: {item.output_filename}"
|
||||
|
||||
|
||||
def create_history_tab():
|
||||
"""Create the history gallery tab component."""
|
||||
with gr.Tab("History", id="history-tab"):
|
||||
gr.Markdown("## Processing History")
|
||||
|
||||
with gr.Row():
|
||||
# Search
|
||||
search_box = gr.Textbox(
|
||||
label="Search",
|
||||
placeholder="Search by filename...",
|
||||
scale=2,
|
||||
)
|
||||
|
||||
filter_dropdown = gr.Dropdown(
|
||||
choices=["All", "Images", "Videos"],
|
||||
value="All",
|
||||
label="Filter",
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
# Gallery
|
||||
with gr.Column(scale=2):
|
||||
gallery = gr.Gallery(
|
||||
value=get_history_gallery(),
|
||||
label="History",
|
||||
columns=4,
|
||||
object_fit="cover",
|
||||
height=400,
|
||||
allow_preview=True,
|
||||
)
|
||||
|
||||
# Details panel
|
||||
with gr.Column(scale=1):
|
||||
stats_display = gr.Markdown(
|
||||
get_history_stats(),
|
||||
elem_classes=["info-card"],
|
||||
)
|
||||
|
||||
# Selected item details
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
try:
|
||||
from gradio_imageslider import ImageSlider
|
||||
|
||||
comparison_slider = ImageSlider(
|
||||
label="Before / After",
|
||||
type="filepath",
|
||||
)
|
||||
except ImportError:
|
||||
comparison_slider = gr.Image(
|
||||
label="Selected Image",
|
||||
type="filepath",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1):
|
||||
item_details = gr.Markdown(
|
||||
"Select an item to view details",
|
||||
elem_classes=["info-card"],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
download_btn = gr.Button("Download", variant="secondary")
|
||||
delete_btn = gr.Button("Delete", variant="stop")
|
||||
|
||||
status_text = gr.Markdown("")
|
||||
|
||||
# Event handlers
|
||||
search_box.change(
|
||||
fn=search_history,
|
||||
inputs=[search_box],
|
||||
outputs=[gallery],
|
||||
)
|
||||
|
||||
gallery.select(
|
||||
fn=get_item_details,
|
||||
outputs=[comparison_slider, item_details],
|
||||
)
|
||||
|
||||
delete_btn.click(
|
||||
fn=lambda: (get_history_gallery(), get_history_stats()),
|
||||
outputs=[gallery, stats_display],
|
||||
)
|
||||
|
||||
return {
|
||||
"gallery": gallery,
|
||||
"comparison_slider": comparison_slider,
|
||||
"item_details": item_details,
|
||||
"stats_display": stats_display,
|
||||
}
|
||||
292
src/ui/components/image_tab.py
Normal file
292
src/ui/components/image_tab.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""Image upscaling tab component."""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from src.config import OUTPUT_DIR, config, get_model_choices
|
||||
from src.processing.upscaler import upscale_image, save_image, get_output_dimensions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def process_image(
|
||||
input_image: Optional[np.ndarray],
|
||||
model_name: str,
|
||||
scale: int,
|
||||
face_enhance: bool,
|
||||
tile_size: str,
|
||||
denoise: float,
|
||||
output_format: str,
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
) -> tuple[Optional[tuple], Optional[str], str]:
|
||||
"""
|
||||
Process image for upscaling.
|
||||
|
||||
Args:
|
||||
input_image: Input image as numpy array
|
||||
model_name: Model to use
|
||||
scale: Scale factor (2 or 4)
|
||||
face_enhance: Enable face enhancement
|
||||
tile_size: Tile size setting ("Auto", "256", "512", "1024")
|
||||
denoise: Denoise strength (0-1)
|
||||
output_format: Output format (PNG, JPG, WebP)
|
||||
|
||||
Returns:
|
||||
Tuple of (slider_images, output_path, status_message)
|
||||
"""
|
||||
if input_image is None:
|
||||
return None, None, "Please upload an image first."
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Parse tile size
|
||||
tile = 0 if tile_size == "Auto" else int(tile_size)
|
||||
|
||||
# Progress callback for UI
|
||||
def progress_callback(pct: float, stage: str):
|
||||
progress(pct, desc=stage)
|
||||
|
||||
progress(0, desc="Starting upscale...")
|
||||
|
||||
# Upscale the image
|
||||
result = upscale_image(
|
||||
input_image,
|
||||
model_name=model_name,
|
||||
scale=scale,
|
||||
tile_size=tile,
|
||||
denoise_strength=denoise,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Apply face enhancement if requested
|
||||
if face_enhance:
|
||||
progress(0.7, desc="Enhancing faces...")
|
||||
try:
|
||||
from src.processing.face_enhancer import enhance_faces
|
||||
|
||||
result.image = enhance_faces(result.image)
|
||||
except ImportError:
|
||||
logger.warning("Face enhancement not available")
|
||||
except Exception as e:
|
||||
logger.warning(f"Face enhancement failed: {e}")
|
||||
|
||||
progress(0.85, desc="Saving output...")
|
||||
|
||||
# Generate output filename
|
||||
output_name = f"upscaled_{uuid.uuid4().hex[:8]}"
|
||||
output_path = save_image(
|
||||
result.image,
|
||||
OUTPUT_DIR / output_name,
|
||||
format=output_format.lower(),
|
||||
)
|
||||
|
||||
# Convert for display (BGR to RGB)
|
||||
input_rgb = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) if input_image.shape[2] == 3 else input_image
|
||||
output_rgb = cv2.cvtColor(result.image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
progress(1.0, desc="Complete!")
|
||||
|
||||
# Calculate stats
|
||||
elapsed = time.time() - start_time
|
||||
input_size = f"{result.input_width}x{result.input_height}"
|
||||
output_size = f"{result.output_width}x{result.output_height}"
|
||||
|
||||
status = (
|
||||
f"Upscaled {input_size} -> {output_size} in {elapsed:.1f}s | "
|
||||
f"Model: {model_name} | Saved: {output_path.name}"
|
||||
)
|
||||
|
||||
# Return images for slider (input, output)
|
||||
return (input_rgb, output_rgb), str(output_path), status
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Image processing failed")
|
||||
return None, None, f"Error: {str(e)}"
|
||||
|
||||
|
||||
def estimate_output(
|
||||
input_image: Optional[np.ndarray],
|
||||
model_name: str,
|
||||
scale: int,
|
||||
) -> str:
|
||||
"""Estimate output dimensions for display."""
|
||||
if input_image is None:
|
||||
return "Upload an image to see estimated output size"
|
||||
|
||||
try:
|
||||
height, width = input_image.shape[:2]
|
||||
out_w, out_h = get_output_dimensions(width, height, model_name, scale)
|
||||
return f"Input: {width}x{height} -> Output: {out_w}x{out_h}"
|
||||
except Exception:
|
||||
return "Unable to estimate output size"
|
||||
|
||||
|
||||
def create_image_tab():
|
||||
"""Create the image upscaling tab component."""
|
||||
with gr.Tab("Image", id="image-tab"):
|
||||
# Status display at top
|
||||
status_text = gr.Markdown(
|
||||
"Upload an image to begin upscaling",
|
||||
elem_classes=["status-text"],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
# Left column - Input
|
||||
with gr.Column(scale=1):
|
||||
input_image = gr.Image(
|
||||
label="Input Image",
|
||||
type="numpy",
|
||||
sources=["upload", "clipboard"],
|
||||
elem_classes=["upload-area"],
|
||||
)
|
||||
|
||||
# Dimension estimate
|
||||
dimension_info = gr.Markdown(
|
||||
"Upload an image to see estimated output size",
|
||||
elem_classes=["info-text"],
|
||||
)
|
||||
|
||||
# Right column - Output (ImageSlider)
|
||||
with gr.Column(scale=1):
|
||||
try:
|
||||
from gradio_imageslider import ImageSlider
|
||||
|
||||
output_slider = ImageSlider(
|
||||
label="Before / After",
|
||||
type="numpy",
|
||||
show_download_button=True,
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback if ImageSlider not available
|
||||
output_slider = gr.Image(
|
||||
label="Upscaled Result",
|
||||
type="numpy",
|
||||
show_download_button=True,
|
||||
)
|
||||
|
||||
output_file = gr.File(
|
||||
label="Download",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
# Processing options
|
||||
with gr.Accordion("Processing Options", open=True):
|
||||
with gr.Row():
|
||||
model_dropdown = gr.Dropdown(
|
||||
choices=get_model_choices(),
|
||||
value=config.default_model,
|
||||
label="Model",
|
||||
info="Select upscaling model",
|
||||
)
|
||||
|
||||
scale_radio = gr.Radio(
|
||||
choices=[2, 4],
|
||||
value=config.default_scale,
|
||||
label="Scale Factor",
|
||||
info="Output scale multiplier",
|
||||
)
|
||||
|
||||
face_enhance_cb = gr.Checkbox(
|
||||
value=config.default_face_enhance,
|
||||
label="Face Enhancement (GFPGAN)",
|
||||
info="Enhance faces in photos (not for anime)",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
tile_dropdown = gr.Dropdown(
|
||||
choices=["Auto", "256", "512", "1024"],
|
||||
value="Auto",
|
||||
label="Tile Size",
|
||||
info="Auto recommended for RTX 4090",
|
||||
)
|
||||
|
||||
denoise_slider = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
value=0.5,
|
||||
step=0.1,
|
||||
label="Denoise Strength",
|
||||
info="Only for realesr-general-x4v3 model",
|
||||
)
|
||||
|
||||
format_radio = gr.Radio(
|
||||
choices=["PNG", "JPG", "WebP"],
|
||||
value="PNG",
|
||||
label="Output Format",
|
||||
)
|
||||
|
||||
# Action buttons
|
||||
with gr.Row():
|
||||
upscale_btn = gr.Button(
|
||||
"Upscale Image",
|
||||
variant="primary",
|
||||
size="lg",
|
||||
elem_classes=["primary-action-btn"],
|
||||
)
|
||||
|
||||
clear_btn = gr.Button(
|
||||
"Clear",
|
||||
variant="secondary",
|
||||
)
|
||||
|
||||
# Event handlers
|
||||
upscale_btn.click(
|
||||
fn=process_image,
|
||||
inputs=[
|
||||
input_image,
|
||||
model_dropdown,
|
||||
scale_radio,
|
||||
face_enhance_cb,
|
||||
tile_dropdown,
|
||||
denoise_slider,
|
||||
format_radio,
|
||||
],
|
||||
outputs=[output_slider, output_file, status_text],
|
||||
)
|
||||
|
||||
# Update dimension estimate when image changes
|
||||
input_image.change(
|
||||
fn=estimate_output,
|
||||
inputs=[input_image, model_dropdown, scale_radio],
|
||||
outputs=[dimension_info],
|
||||
)
|
||||
|
||||
# Also update when model/scale changes
|
||||
model_dropdown.change(
|
||||
fn=estimate_output,
|
||||
inputs=[input_image, model_dropdown, scale_radio],
|
||||
outputs=[dimension_info],
|
||||
)
|
||||
|
||||
scale_radio.change(
|
||||
fn=estimate_output,
|
||||
inputs=[input_image, model_dropdown, scale_radio],
|
||||
outputs=[dimension_info],
|
||||
)
|
||||
|
||||
# Clear button
|
||||
clear_btn.click(
|
||||
fn=lambda: (None, None, None, "Upload an image to begin upscaling"),
|
||||
outputs=[input_image, output_slider, output_file, status_text],
|
||||
)
|
||||
|
||||
return {
|
||||
"input_image": input_image,
|
||||
"output_slider": output_slider,
|
||||
"output_file": output_file,
|
||||
"status_text": status_text,
|
||||
"model_dropdown": model_dropdown,
|
||||
"scale_radio": scale_radio,
|
||||
"face_enhance_cb": face_enhance_cb,
|
||||
"upscale_btn": upscale_btn,
|
||||
}
|
||||
376
src/ui/components/video_tab.py
Normal file
376
src/ui/components/video_tab.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""Video upscaling tab component."""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
from src.config import OUTPUT_DIR, TEMP_DIR, VIDEO_CODECS, config, get_model_choices
|
||||
from src.processing.upscaler import upscale_image
|
||||
from src.video.extractor import VideoExtractor, get_video_info
|
||||
from src.video.encoder import encode_video
|
||||
from src.video.audio import extract_audio, has_audio_stream
|
||||
from src.video.checkpoint import checkpoint_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_video_metadata(video_path: Optional[str]) -> str:
|
||||
"""Get video metadata for display."""
|
||||
if not video_path:
|
||||
return "Upload a video to see details"
|
||||
|
||||
try:
|
||||
metadata = get_video_info(Path(video_path))
|
||||
return (
|
||||
f"**Resolution:** {metadata.width}x{metadata.height}\n"
|
||||
f"**Duration:** {metadata.duration_seconds:.1f}s\n"
|
||||
f"**FPS:** {metadata.fps:.2f}\n"
|
||||
f"**Frames:** {metadata.total_frames}\n"
|
||||
f"**Codec:** {metadata.codec}\n"
|
||||
f"**Audio:** {'Yes' if metadata.has_audio else 'No'}"
|
||||
)
|
||||
except Exception as e:
|
||||
return f"Error reading video: {e}"
|
||||
|
||||
|
||||
def estimate_video_output(
|
||||
video_path: Optional[str],
|
||||
model_name: str,
|
||||
scale: int,
|
||||
) -> str:
|
||||
"""Estimate output details for video."""
|
||||
if not video_path:
|
||||
return "Upload a video to see estimated output"
|
||||
|
||||
try:
|
||||
metadata = get_video_info(Path(video_path))
|
||||
out_w = metadata.width * scale
|
||||
out_h = metadata.height * scale
|
||||
|
||||
# Estimate processing time (rough: ~2 fps for 4x on RTX 4090)
|
||||
est_time = metadata.total_frames / 2 if scale == 4 else metadata.total_frames / 4
|
||||
est_minutes = est_time / 60
|
||||
|
||||
return (
|
||||
f"**Output Resolution:** {out_w}x{out_h}\n"
|
||||
f"**Estimated Time:** ~{est_minutes:.1f} minutes"
|
||||
)
|
||||
except Exception:
|
||||
return "Unable to estimate"
|
||||
|
||||
|
||||
def process_video(
|
||||
video_path: Optional[str],
|
||||
model_name: str,
|
||||
scale: int,
|
||||
face_enhance: bool,
|
||||
codec: str,
|
||||
crf: int,
|
||||
preset: str,
|
||||
preserve_audio: bool,
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
) -> tuple[Optional[str], str]:
|
||||
"""
|
||||
Process video for upscaling.
|
||||
|
||||
Args:
|
||||
video_path: Input video path
|
||||
model_name: Model to use
|
||||
scale: Scale factor
|
||||
face_enhance: Enable face enhancement
|
||||
codec: Output codec
|
||||
crf: Quality (lower = better)
|
||||
preset: Encoding preset
|
||||
preserve_audio: Keep original audio
|
||||
progress: Gradio progress tracker
|
||||
|
||||
Returns:
|
||||
Tuple of (output_path, status_message)
|
||||
"""
|
||||
if not video_path:
|
||||
return None, "Please upload a video first."
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
job_id = uuid.uuid4().hex[:8]
|
||||
|
||||
progress(0, desc="Analyzing video...")
|
||||
|
||||
# Get video info
|
||||
video_path = Path(video_path)
|
||||
metadata = get_video_info(video_path)
|
||||
|
||||
logger.info(
|
||||
f"Processing video: {video_path.name}, "
|
||||
f"{metadata.width}x{metadata.height}, "
|
||||
f"{metadata.total_frames} frames"
|
||||
)
|
||||
|
||||
# Create temp directories
|
||||
temp_dir = TEMP_DIR / f"video_{job_id}"
|
||||
frames_dir = temp_dir / "frames"
|
||||
frames_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Extract audio if needed
|
||||
audio_path = None
|
||||
if preserve_audio and has_audio_stream(video_path):
|
||||
progress(0.02, desc="Extracting audio...")
|
||||
audio_path = extract_audio(video_path, temp_dir / "audio")
|
||||
|
||||
# Create checkpoint
|
||||
checkpoint = checkpoint_manager.create_checkpoint(
|
||||
job_id=job_id,
|
||||
input_path=video_path,
|
||||
total_frames=metadata.total_frames,
|
||||
frames_dir=frames_dir,
|
||||
audio_path=audio_path,
|
||||
config={
|
||||
"model_name": model_name,
|
||||
"scale": scale,
|
||||
"face_enhance": face_enhance,
|
||||
"codec": codec,
|
||||
"crf": crf,
|
||||
"preset": preset,
|
||||
},
|
||||
)
|
||||
|
||||
# Process frames
|
||||
progress(0.05, desc="Processing frames...")
|
||||
|
||||
with VideoExtractor(video_path) as extractor:
|
||||
total = metadata.total_frames
|
||||
start_frame = checkpoint.processed_frames
|
||||
|
||||
for frame_idx, frame in extractor.extract_frames(start_frame=start_frame):
|
||||
# Upscale frame (RGB input)
|
||||
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||
|
||||
result = upscale_image(
|
||||
frame_bgr,
|
||||
model_name=model_name,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
# Apply face enhancement if enabled
|
||||
if face_enhance:
|
||||
try:
|
||||
from src.processing.face_enhancer import enhance_faces
|
||||
result.image = enhance_faces(result.image)
|
||||
except Exception as e:
|
||||
logger.warning(f"Face enhancement failed on frame {frame_idx}: {e}")
|
||||
|
||||
# Save frame
|
||||
frame_path = frames_dir / f"frame_{frame_idx:08d}.png"
|
||||
cv2.imwrite(str(frame_path), result.image)
|
||||
|
||||
# Update checkpoint
|
||||
checkpoint_manager.update_checkpoint(checkpoint, frame_idx)
|
||||
|
||||
# Update progress
|
||||
pct = 0.05 + (0.85 * (frame_idx + 1) / total)
|
||||
progress(
|
||||
pct,
|
||||
desc=f"Processing frame {frame_idx + 1}/{total}",
|
||||
)
|
||||
|
||||
# Encode output video
|
||||
progress(0.9, desc=f"Encoding video ({codec})...")
|
||||
|
||||
output_name = f"{video_path.stem}_upscaled_{job_id}"
|
||||
output_path = OUTPUT_DIR / f"{output_name}.mp4"
|
||||
|
||||
encode_video(
|
||||
frames_dir=frames_dir,
|
||||
output_path=output_path,
|
||||
fps=metadata.fps,
|
||||
codec=codec,
|
||||
crf=crf,
|
||||
preset=preset,
|
||||
audio_path=audio_path,
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
progress(0.98, desc="Cleaning up...")
|
||||
checkpoint_manager.delete_checkpoint(job_id)
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
progress(1.0, desc="Complete!")
|
||||
|
||||
# Calculate stats
|
||||
elapsed = time.time() - start_time
|
||||
output_size = output_path.stat().st_size / (1024 * 1024) # MB
|
||||
fps_actual = metadata.total_frames / elapsed
|
||||
|
||||
status = (
|
||||
f"Completed in {elapsed/60:.1f} minutes ({fps_actual:.2f} fps)\n"
|
||||
f"Output: {output_path.name} ({output_size:.1f} MB)"
|
||||
)
|
||||
|
||||
return str(output_path), status
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Video processing failed")
|
||||
return None, f"Error: {str(e)}"
|
||||
|
||||
|
||||
def create_video_tab():
|
||||
"""Create the video upscaling tab component."""
|
||||
with gr.Tab("Video", id="video-tab"):
|
||||
# Status display
|
||||
status_text = gr.Markdown(
|
||||
"Upload a video to begin processing",
|
||||
elem_classes=["status-text"],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
# Left column - Input
|
||||
with gr.Column(scale=1):
|
||||
input_video = gr.Video(
|
||||
label="Input Video",
|
||||
sources=["upload"],
|
||||
)
|
||||
|
||||
# Video info display
|
||||
video_info = gr.Markdown(
|
||||
"Upload a video to see details",
|
||||
elem_classes=["info-text"],
|
||||
)
|
||||
|
||||
# Right column - Output
|
||||
with gr.Column(scale=1):
|
||||
output_video = gr.Video(
|
||||
label="Upscaled Video",
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
# Output estimate
|
||||
output_estimate = gr.Markdown(
|
||||
"Upload a video to see estimated output",
|
||||
elem_classes=["info-text"],
|
||||
)
|
||||
|
||||
# Processing options
|
||||
with gr.Accordion("Processing Options", open=True):
|
||||
with gr.Row():
|
||||
model_dropdown = gr.Dropdown(
|
||||
choices=get_model_choices(),
|
||||
value=config.default_model,
|
||||
label="Model",
|
||||
)
|
||||
|
||||
scale_radio = gr.Radio(
|
||||
choices=[2, 4],
|
||||
value=config.default_scale,
|
||||
label="Scale Factor",
|
||||
)
|
||||
|
||||
face_enhance_cb = gr.Checkbox(
|
||||
value=False,
|
||||
label="Face Enhancement",
|
||||
info="Not recommended for anime",
|
||||
)
|
||||
|
||||
# Codec options
|
||||
with gr.Accordion("Output Codec Settings", open=True):
|
||||
with gr.Row():
|
||||
codec_radio = gr.Radio(
|
||||
choices=list(VIDEO_CODECS.keys()),
|
||||
value=config.default_video_codec,
|
||||
label="Codec",
|
||||
info="H.265 recommended for quality/size balance",
|
||||
)
|
||||
|
||||
crf_slider = gr.Slider(
|
||||
minimum=15,
|
||||
maximum=35,
|
||||
value=config.default_video_crf,
|
||||
step=1,
|
||||
label="Quality (CRF)",
|
||||
info="Lower = better quality, larger file",
|
||||
)
|
||||
|
||||
preset_dropdown = gr.Dropdown(
|
||||
choices=["slow", "medium", "fast"],
|
||||
value=config.default_video_preset,
|
||||
label="Preset",
|
||||
info="Slow = best quality",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
preserve_audio_cb = gr.Checkbox(
|
||||
value=True,
|
||||
label="Preserve Audio",
|
||||
info="Keep original audio track",
|
||||
)
|
||||
|
||||
# Action buttons
|
||||
with gr.Row():
|
||||
process_btn = gr.Button(
|
||||
"Start Processing",
|
||||
variant="primary",
|
||||
size="lg",
|
||||
elem_classes=["primary-action-btn"],
|
||||
)
|
||||
|
||||
cancel_btn = gr.Button(
|
||||
"Cancel",
|
||||
variant="stop",
|
||||
visible=False, # TODO: implement cancellation
|
||||
)
|
||||
|
||||
# Event handlers
|
||||
process_btn.click(
|
||||
fn=process_video,
|
||||
inputs=[
|
||||
input_video,
|
||||
model_dropdown,
|
||||
scale_radio,
|
||||
face_enhance_cb,
|
||||
codec_radio,
|
||||
crf_slider,
|
||||
preset_dropdown,
|
||||
preserve_audio_cb,
|
||||
],
|
||||
outputs=[output_video, status_text],
|
||||
)
|
||||
|
||||
# Update video info when video changes
|
||||
input_video.change(
|
||||
fn=get_video_metadata,
|
||||
inputs=[input_video],
|
||||
outputs=[video_info],
|
||||
)
|
||||
|
||||
# Update output estimate
|
||||
input_video.change(
|
||||
fn=estimate_video_output,
|
||||
inputs=[input_video, model_dropdown, scale_radio],
|
||||
outputs=[output_estimate],
|
||||
)
|
||||
|
||||
model_dropdown.change(
|
||||
fn=estimate_video_output,
|
||||
inputs=[input_video, model_dropdown, scale_radio],
|
||||
outputs=[output_estimate],
|
||||
)
|
||||
|
||||
scale_radio.change(
|
||||
fn=estimate_video_output,
|
||||
inputs=[input_video, model_dropdown, scale_radio],
|
||||
outputs=[output_estimate],
|
||||
)
|
||||
|
||||
return {
|
||||
"input_video": input_video,
|
||||
"output_video": output_video,
|
||||
"status_text": status_text,
|
||||
"process_btn": process_btn,
|
||||
}
|
||||
1
src/ui/handlers/__init__.py
Normal file
1
src/ui/handlers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Event handlers for UI interactions."""
|
||||
284
src/ui/theme.py
Normal file
284
src/ui/theme.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""Custom dark theme for Real-ESRGAN Web UI."""
|
||||
|
||||
import gradio as gr
|
||||
from gradio.themes import Base, colors, sizes
|
||||
|
||||
|
||||
class UpscalerTheme(Base):
|
||||
"""
|
||||
Professional dark theme optimized for image/video upscaling UI.
|
||||
|
||||
Features:
|
||||
- Dark background for better image comparison
|
||||
- Blue accent colors for actions
|
||||
- Comfortable spacing for power users
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
# Core colors
|
||||
primary_hue=colors.blue,
|
||||
secondary_hue=colors.slate,
|
||||
neutral_hue=colors.slate,
|
||||
# Sizing
|
||||
spacing_size=sizes.spacing_md,
|
||||
radius_size=sizes.radius_md,
|
||||
text_size=sizes.text_md,
|
||||
# Fonts
|
||||
font=(
|
||||
"Inter",
|
||||
"ui-sans-serif",
|
||||
"-apple-system",
|
||||
"BlinkMacSystemFont",
|
||||
"Segoe UI",
|
||||
"Roboto",
|
||||
"sans-serif",
|
||||
),
|
||||
font_mono=(
|
||||
"JetBrains Mono",
|
||||
"ui-monospace",
|
||||
"SFMono-Regular",
|
||||
"Menlo",
|
||||
"monospace",
|
||||
),
|
||||
)
|
||||
|
||||
# Dark mode overrides
|
||||
super().set(
|
||||
# Backgrounds - dark with subtle contrast
|
||||
body_background_fill="*neutral_950",
|
||||
body_background_fill_dark="*neutral_950",
|
||||
block_background_fill="*neutral_900",
|
||||
block_background_fill_dark="*neutral_900",
|
||||
# Borders
|
||||
block_border_width="1px",
|
||||
block_border_color="*neutral_800",
|
||||
block_border_color_dark="*neutral_700",
|
||||
border_color_primary="*primary_600",
|
||||
# Inputs
|
||||
input_background_fill="*neutral_800",
|
||||
input_background_fill_dark="*neutral_800",
|
||||
input_border_color="*neutral_700",
|
||||
input_border_color_focus="*primary_500",
|
||||
input_border_color_focus_dark="*primary_400",
|
||||
# Buttons - primary (blue)
|
||||
button_primary_background_fill="*primary_600",
|
||||
button_primary_background_fill_hover="*primary_500",
|
||||
button_primary_text_color="white",
|
||||
button_primary_border_color="*primary_600",
|
||||
# Buttons - secondary
|
||||
button_secondary_background_fill="*neutral_700",
|
||||
button_secondary_background_fill_hover="*neutral_600",
|
||||
button_secondary_text_color="*neutral_100",
|
||||
button_secondary_border_color="*neutral_600",
|
||||
# Buttons - stop (for cancel)
|
||||
button_cancel_background_fill="*error_600",
|
||||
button_cancel_background_fill_hover="*error_500",
|
||||
button_cancel_text_color="white",
|
||||
# Shadows
|
||||
block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.3), 0 2px 4px -2px rgba(0, 0, 0, 0.2)",
|
||||
# Text
|
||||
body_text_color="*neutral_200",
|
||||
body_text_color_subdued="*neutral_400",
|
||||
body_text_color_dark="*neutral_200",
|
||||
# Labels
|
||||
block_title_text_color="*neutral_100",
|
||||
block_label_text_color="*neutral_300",
|
||||
# Sliders
|
||||
slider_color="*primary_500",
|
||||
slider_color_dark="*primary_400",
|
||||
# Checkboxes
|
||||
checkbox_background_color_selected="*primary_600",
|
||||
checkbox_background_color_selected_dark="*primary_500",
|
||||
checkbox_border_color_selected="*primary_600",
|
||||
# Tables/Dataframes
|
||||
table_border_color="*neutral_700",
|
||||
table_even_background_fill="*neutral_850",
|
||||
table_odd_background_fill="*neutral_900",
|
||||
# Tabs
|
||||
tab_selected_text_color="*primary_400",
|
||||
)
|
||||
|
||||
|
||||
# Custom CSS for additional styling
|
||||
CUSTOM_CSS = """
|
||||
/* Force dark color scheme */
|
||||
:root {
|
||||
color-scheme: dark;
|
||||
}
|
||||
|
||||
/* Header styling */
|
||||
.app-header {
|
||||
background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
|
||||
padding: 1rem 1.5rem;
|
||||
border-bottom: 1px solid #334155;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.app-header h1 {
|
||||
color: #f1f5f9;
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* Tab navigation styling */
|
||||
.tabs .tab-nav {
|
||||
background: #1e293b;
|
||||
border-radius: 8px 8px 0 0;
|
||||
padding: 0.5rem 0.5rem 0;
|
||||
}
|
||||
|
||||
.tabs .tab-nav button {
|
||||
padding: 0.75rem 1.5rem;
|
||||
font-weight: 500;
|
||||
border-radius: 8px 8px 0 0;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.tabs .tab-nav button.selected {
|
||||
background: #334155;
|
||||
border-bottom: 2px solid #3b82f6;
|
||||
color: #60a5fa;
|
||||
}
|
||||
|
||||
/* Image comparison slider */
|
||||
.image-slider-container {
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
/* Progress bar styling */
|
||||
.progress-bar {
|
||||
height: 8px;
|
||||
border-radius: 4px;
|
||||
background: #1e293b;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.progress-bar > div {
|
||||
background: linear-gradient(90deg, #2563eb 0%, #3b82f6 50%, #60a5fa 100%);
|
||||
transition: width 0.3s ease;
|
||||
}
|
||||
|
||||
/* Gallery items */
|
||||
.gallery-item {
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
transition: transform 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
|
||||
.gallery-item:hover {
|
||||
transform: scale(1.02);
|
||||
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
/* Accordion styling */
|
||||
.accordion {
|
||||
border: 1px solid #334155;
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.accordion > .label-wrap {
|
||||
background: #1e293b;
|
||||
padding: 0.75rem 1rem;
|
||||
}
|
||||
|
||||
/* Status bar */
|
||||
.status-bar {
|
||||
position: fixed;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
background: #0f172a;
|
||||
border-top: 1px solid #334155;
|
||||
padding: 0.5rem 1.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #94a3b8;
|
||||
z-index: 100;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.status-bar .status-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.status-bar .status-dot {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: #22c55e;
|
||||
}
|
||||
|
||||
.status-bar .status-dot.busy {
|
||||
background: #eab308;
|
||||
}
|
||||
|
||||
/* Large primary buttons */
|
||||
.primary-action-btn {
|
||||
font-size: 1.125rem !important;
|
||||
padding: 0.875rem 2rem !important;
|
||||
font-weight: 600 !important;
|
||||
}
|
||||
|
||||
/* Upload area */
|
||||
.upload-area {
|
||||
border: 2px dashed #334155;
|
||||
border-radius: 12px;
|
||||
transition: border-color 0.2s ease;
|
||||
}
|
||||
|
||||
.upload-area:hover {
|
||||
border-color: #3b82f6;
|
||||
}
|
||||
|
||||
/* Processing info cards */
|
||||
.info-card {
|
||||
background: #1e293b;
|
||||
border: 1px solid #334155;
|
||||
border-radius: 8px;
|
||||
padding: 1rem;
|
||||
}
|
||||
|
||||
.info-card h4 {
|
||||
color: #94a3b8;
|
||||
font-size: 0.75rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
|
||||
.info-card .value {
|
||||
color: #f1f5f9;
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
/* Responsive adjustments */
|
||||
@media (max-width: 1200px) {
|
||||
.main-row {
|
||||
flex-direction: column;
|
||||
}
|
||||
}
|
||||
|
||||
/* Add padding at bottom for status bar */
|
||||
.gradio-container {
|
||||
padding-bottom: 60px !important;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def get_theme():
|
||||
"""Get the custom theme instance."""
|
||||
return UpscalerTheme()
|
||||
|
||||
|
||||
def get_css():
|
||||
"""Get custom CSS for additional styling."""
|
||||
return CUSTOM_CSS
|
||||
1
src/utils/__init__.py
Normal file
1
src/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Utility modules (memory management, validators)."""
|
||||
1
src/video/__init__.py
Normal file
1
src/video/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Video processing utilities (extraction, encoding, audio)."""
|
||||
221
src/video/audio.py
Normal file
221
src/video/audio.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Audio extraction and handling for video processing."""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioError(Exception):
|
||||
"""Error during audio processing."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def check_ffmpeg():
|
||||
"""Check if FFmpeg is available."""
|
||||
if shutil.which("ffmpeg") is None:
|
||||
raise AudioError(
|
||||
"FFmpeg not found. Please install FFmpeg: sudo apt install ffmpeg"
|
||||
)
|
||||
|
||||
|
||||
def has_audio_stream(video_path: Path) -> bool:
|
||||
"""
|
||||
Check if video has an audio stream.
|
||||
|
||||
Args:
|
||||
video_path: Path to video file
|
||||
|
||||
Returns:
|
||||
True if video has audio
|
||||
"""
|
||||
check_ffmpeg()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffprobe",
|
||||
"-v", "error",
|
||||
"-select_streams", "a",
|
||||
"-show_entries", "stream=codec_name",
|
||||
"-of", "csv=p=0",
|
||||
str(video_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return bool(result.stdout.strip())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def extract_audio(
|
||||
video_path: Path,
|
||||
output_path: Path,
|
||||
copy_codec: bool = True,
|
||||
) -> Optional[Path]:
|
||||
"""
|
||||
Extract audio stream from video.
|
||||
|
||||
Args:
|
||||
video_path: Source video path
|
||||
output_path: Output audio path
|
||||
copy_codec: If True, copy without re-encoding
|
||||
|
||||
Returns:
|
||||
Path to extracted audio, or None if no audio
|
||||
"""
|
||||
check_ffmpeg()
|
||||
|
||||
if not has_audio_stream(video_path):
|
||||
logger.info(f"No audio stream in: {video_path}")
|
||||
return None
|
||||
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if copy_codec:
|
||||
# Determine output extension based on source codec
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffprobe",
|
||||
"-v", "error",
|
||||
"-select_streams", "a:0",
|
||||
"-show_entries", "stream=codec_name",
|
||||
"-of", "csv=p=0",
|
||||
str(video_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
codec = result.stdout.strip().lower()
|
||||
|
||||
# Map codec to extension
|
||||
ext_map = {
|
||||
"aac": ".aac",
|
||||
"mp3": ".mp3",
|
||||
"opus": ".opus",
|
||||
"vorbis": ".ogg",
|
||||
"flac": ".flac",
|
||||
"pcm_s16le": ".wav",
|
||||
"pcm_s24le": ".wav",
|
||||
}
|
||||
ext = ext_map.get(codec, ".aac")
|
||||
output_path = output_path.with_suffix(ext)
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", str(video_path),
|
||||
"-vn", # No video
|
||||
"-acodec", "copy",
|
||||
str(output_path),
|
||||
]
|
||||
else:
|
||||
# Re-encode to AAC
|
||||
output_path = output_path.with_suffix(".aac")
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", str(video_path),
|
||||
"-vn",
|
||||
"-acodec", "aac",
|
||||
"-b:a", "320k",
|
||||
str(output_path),
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, check=True)
|
||||
logger.info(f"Extracted audio to: {output_path}")
|
||||
return output_path
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Audio extraction failed: {e.stderr}")
|
||||
raise AudioError(f"Audio extraction failed: {e.stderr}") from e
|
||||
|
||||
|
||||
def mux_audio_video(
|
||||
video_path: Path,
|
||||
audio_path: Path,
|
||||
output_path: Path,
|
||||
copy_streams: bool = True,
|
||||
) -> Path:
|
||||
"""
|
||||
Combine video and audio streams.
|
||||
|
||||
Args:
|
||||
video_path: Video-only file
|
||||
audio_path: Audio file
|
||||
output_path: Output path
|
||||
copy_streams: If True, copy without re-encoding
|
||||
|
||||
Returns:
|
||||
Path to muxed video
|
||||
"""
|
||||
check_ffmpeg()
|
||||
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if copy_streams:
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", str(video_path),
|
||||
"-i", str(audio_path),
|
||||
"-c:v", "copy",
|
||||
"-c:a", "copy",
|
||||
"-map", "0:v:0",
|
||||
"-map", "1:a:0",
|
||||
str(output_path),
|
||||
]
|
||||
else:
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", str(video_path),
|
||||
"-i", str(audio_path),
|
||||
"-c:v", "copy",
|
||||
"-c:a", "aac",
|
||||
"-b:a", "320k",
|
||||
"-map", "0:v:0",
|
||||
"-map", "1:a:0",
|
||||
str(output_path),
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, check=True)
|
||||
logger.info(f"Muxed audio/video to: {output_path}")
|
||||
return output_path
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Muxing failed: {e.stderr}")
|
||||
raise AudioError(f"Audio/video muxing failed: {e.stderr}") from e
|
||||
|
||||
|
||||
def get_audio_duration(audio_path: Path) -> float:
|
||||
"""
|
||||
Get duration of audio file in seconds.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
check_ffmpeg()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffprobe",
|
||||
"-v", "error",
|
||||
"-show_entries", "format=duration",
|
||||
"-of", "csv=p=0",
|
||||
str(audio_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return float(result.stdout.strip())
|
||||
except Exception:
|
||||
return 0.0
|
||||
250
src/video/checkpoint.py
Normal file
250
src/video/checkpoint.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Checkpoint system for resumable video processing."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config import CHECKPOINTS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoCheckpoint:
|
||||
"""Checkpoint data for resumable video processing."""
|
||||
|
||||
job_id: str
|
||||
input_path: str
|
||||
input_hash: str # MD5 of first 1MB for verification
|
||||
total_frames: int
|
||||
processed_frames: int
|
||||
last_processed_frame: int
|
||||
frames_dir: str
|
||||
audio_path: Optional[str]
|
||||
config_json: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@property
|
||||
def progress_percent(self) -> float:
|
||||
"""Get progress as percentage."""
|
||||
if self.total_frames == 0:
|
||||
return 0.0
|
||||
return (self.processed_frames / self.total_frames) * 100
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if processing is complete."""
|
||||
return self.processed_frames >= self.total_frames
|
||||
|
||||
|
||||
class CheckpointManager:
|
||||
"""
|
||||
Manage video processing checkpoints for resume capability.
|
||||
|
||||
Features:
|
||||
- Save checkpoint every N frames
|
||||
- Verify input file hasn't changed
|
||||
- Clean up after completion
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
save_interval: int = 100,
|
||||
):
|
||||
self.checkpoint_dir = checkpoint_dir or CHECKPOINTS_DIR
|
||||
self.save_interval = save_interval
|
||||
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get_checkpoint_path(self, job_id: str) -> Path:
|
||||
"""Get path to checkpoint file."""
|
||||
return self.checkpoint_dir / f"{job_id}.checkpoint.json"
|
||||
|
||||
def create_checkpoint(
|
||||
self,
|
||||
job_id: str,
|
||||
input_path: Path,
|
||||
total_frames: int,
|
||||
frames_dir: Path,
|
||||
audio_path: Optional[Path],
|
||||
config: dict,
|
||||
) -> VideoCheckpoint:
|
||||
"""
|
||||
Create a new checkpoint for a video job.
|
||||
|
||||
Args:
|
||||
job_id: Unique job identifier
|
||||
input_path: Path to input video
|
||||
total_frames: Total frames to process
|
||||
frames_dir: Directory for processed frames
|
||||
audio_path: Path to extracted audio (if any)
|
||||
config: Processing configuration dict
|
||||
|
||||
Returns:
|
||||
New VideoCheckpoint instance
|
||||
"""
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
checkpoint = VideoCheckpoint(
|
||||
job_id=job_id,
|
||||
input_path=str(input_path),
|
||||
input_hash=self._compute_file_hash(input_path),
|
||||
total_frames=total_frames,
|
||||
processed_frames=0,
|
||||
last_processed_frame=-1,
|
||||
frames_dir=str(frames_dir),
|
||||
audio_path=str(audio_path) if audio_path else None,
|
||||
config_json=json.dumps(config),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
self._save_checkpoint(checkpoint)
|
||||
logger.info(f"Created checkpoint for job: {job_id}")
|
||||
|
||||
return checkpoint
|
||||
|
||||
def update_checkpoint(
|
||||
self,
|
||||
checkpoint: VideoCheckpoint,
|
||||
processed_frame: int,
|
||||
force_save: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Update checkpoint with progress.
|
||||
|
||||
Args:
|
||||
checkpoint: Checkpoint to update
|
||||
processed_frame: Frame index that was just processed
|
||||
force_save: Force save even if not at interval
|
||||
"""
|
||||
checkpoint.processed_frames += 1
|
||||
checkpoint.last_processed_frame = processed_frame
|
||||
checkpoint.updated_at = datetime.utcnow().isoformat()
|
||||
|
||||
# Save periodically or when forced
|
||||
if force_save or checkpoint.processed_frames % self.save_interval == 0:
|
||||
self._save_checkpoint(checkpoint)
|
||||
logger.debug(
|
||||
f"Checkpoint saved: {checkpoint.processed_frames}/{checkpoint.total_frames}"
|
||||
)
|
||||
|
||||
def load_checkpoint(self, job_id: str) -> Optional[VideoCheckpoint]:
|
||||
"""
|
||||
Load existing checkpoint if available.
|
||||
|
||||
Args:
|
||||
job_id: Job identifier
|
||||
|
||||
Returns:
|
||||
VideoCheckpoint if valid, None otherwise
|
||||
"""
|
||||
path = self.get_checkpoint_path(job_id)
|
||||
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
checkpoint = VideoCheckpoint(**data)
|
||||
|
||||
# Verify input file hasn't changed
|
||||
if Path(checkpoint.input_path).exists():
|
||||
current_hash = self._compute_file_hash(Path(checkpoint.input_path))
|
||||
if current_hash != checkpoint.input_hash:
|
||||
logger.warning("Input file changed since checkpoint, starting fresh")
|
||||
self.delete_checkpoint(job_id)
|
||||
return None
|
||||
|
||||
# Verify frames directory exists
|
||||
if not Path(checkpoint.frames_dir).exists():
|
||||
logger.warning("Frames directory missing, starting fresh")
|
||||
self.delete_checkpoint(job_id)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"Loaded checkpoint: {checkpoint.processed_frames}/{checkpoint.total_frames} frames"
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load checkpoint: {e}")
|
||||
self.delete_checkpoint(job_id)
|
||||
return None
|
||||
|
||||
def delete_checkpoint(self, job_id: str) -> None:
|
||||
"""Delete checkpoint file."""
|
||||
path = self.get_checkpoint_path(job_id)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
logger.info(f"Deleted checkpoint: {job_id}")
|
||||
|
||||
def _save_checkpoint(self, checkpoint: VideoCheckpoint) -> None:
|
||||
"""Save checkpoint to disk."""
|
||||
path = self.get_checkpoint_path(checkpoint.job_id)
|
||||
with open(path, "w") as f:
|
||||
json.dump(asdict(checkpoint), f, indent=2)
|
||||
|
||||
def _compute_file_hash(
|
||||
self,
|
||||
path: Path,
|
||||
chunk_size: int = 1024 * 1024,
|
||||
) -> str:
|
||||
"""
|
||||
Compute MD5 hash of first chunk of file.
|
||||
|
||||
Using only the first 1MB is fast while still detecting
|
||||
if the file has been modified.
|
||||
"""
|
||||
hasher = hashlib.md5()
|
||||
with open(path, "rb") as f:
|
||||
chunk = f.read(chunk_size)
|
||||
hasher.update(chunk)
|
||||
return hasher.hexdigest()
|
||||
|
||||
def get_all_checkpoints(self) -> list[VideoCheckpoint]:
|
||||
"""Get all existing checkpoints."""
|
||||
checkpoints = []
|
||||
for path in self.checkpoint_dir.glob("*.checkpoint.json"):
|
||||
job_id = path.stem.replace(".checkpoint", "")
|
||||
checkpoint = self.load_checkpoint(job_id)
|
||||
if checkpoint:
|
||||
checkpoints.append(checkpoint)
|
||||
return checkpoints
|
||||
|
||||
def cleanup_old_checkpoints(self, max_age_days: int = 7) -> int:
|
||||
"""
|
||||
Clean up old checkpoint files.
|
||||
|
||||
Args:
|
||||
max_age_days: Delete checkpoints older than this
|
||||
|
||||
Returns:
|
||||
Number of checkpoints deleted
|
||||
"""
|
||||
import time
|
||||
|
||||
deleted = 0
|
||||
max_age_seconds = max_age_days * 24 * 60 * 60
|
||||
now = time.time()
|
||||
|
||||
for path in self.checkpoint_dir.glob("*.checkpoint.json"):
|
||||
if now - path.stat().st_mtime > max_age_seconds:
|
||||
path.unlink()
|
||||
deleted += 1
|
||||
|
||||
if deleted:
|
||||
logger.info(f"Cleaned up {deleted} old checkpoints")
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
# Global checkpoint manager
|
||||
checkpoint_manager = CheckpointManager()
|
||||
274
src/video/encoder.py
Normal file
274
src/video/encoder.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Video encoding using FFmpeg."""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config import VIDEO_CODECS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EncodingResult:
|
||||
"""Result of video encoding."""
|
||||
|
||||
output_path: Path
|
||||
codec: str
|
||||
duration_seconds: float
|
||||
file_size_bytes: int
|
||||
|
||||
|
||||
class VideoEncoderError(Exception):
|
||||
"""Error during video encoding."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def check_ffmpeg():
|
||||
"""Check if FFmpeg is available."""
|
||||
if shutil.which("ffmpeg") is None:
|
||||
raise VideoEncoderError(
|
||||
"FFmpeg not found. Please install FFmpeg: sudo apt install ffmpeg"
|
||||
)
|
||||
|
||||
|
||||
def check_nvenc():
|
||||
"""Check if NVENC is available."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ffmpeg", "-hide_banner", "-encoders"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return "h264_nvenc" in result.stdout
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class VideoEncoder:
|
||||
"""
|
||||
Encode video frames using FFmpeg.
|
||||
|
||||
Supports multiple codecs:
|
||||
- H.264 (libx264 / h264_nvenc)
|
||||
- H.265 (libx265 / hevc_nvenc)
|
||||
- AV1 (libsvtav1 / av1_nvenc)
|
||||
"""
|
||||
|
||||
def __init__(self, use_nvenc: bool = True):
|
||||
check_ffmpeg()
|
||||
self.use_nvenc = use_nvenc and check_nvenc()
|
||||
|
||||
if self.use_nvenc:
|
||||
logger.info("NVENC hardware encoding available")
|
||||
else:
|
||||
logger.info("Using software encoding")
|
||||
|
||||
def encode_frames(
|
||||
self,
|
||||
frames_dir: Path,
|
||||
output_path: Path,
|
||||
fps: float,
|
||||
codec: str = "H.265",
|
||||
crf: int = 20,
|
||||
preset: str = "slow",
|
||||
audio_path: Optional[Path] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> EncodingResult:
|
||||
"""
|
||||
Encode frames to video.
|
||||
|
||||
Args:
|
||||
frames_dir: Directory containing numbered frame images
|
||||
output_path: Output video path
|
||||
fps: Frame rate
|
||||
codec: Codec name ("H.264", "H.265", "AV1")
|
||||
crf: Quality (lower = better, 0-51)
|
||||
preset: Speed preset
|
||||
audio_path: Optional audio file to mux
|
||||
width: Optional output width (for scaling)
|
||||
height: Optional output height (for scaling)
|
||||
|
||||
Returns:
|
||||
EncodingResult with output info
|
||||
"""
|
||||
if codec not in VIDEO_CODECS:
|
||||
raise VideoEncoderError(f"Unknown codec: {codec}")
|
||||
|
||||
codec_config = VIDEO_CODECS[codec]
|
||||
|
||||
# Build FFmpeg command
|
||||
cmd = self._build_command(
|
||||
frames_dir=frames_dir,
|
||||
output_path=output_path,
|
||||
fps=fps,
|
||||
codec_config=codec_config,
|
||||
crf=crf,
|
||||
preset=preset,
|
||||
audio_path=audio_path,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
logger.info(f"Encoding video: {' '.join(cmd)}")
|
||||
|
||||
# Run FFmpeg
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"FFmpeg error: {e.stderr}")
|
||||
raise VideoEncoderError(f"Encoding failed: {e.stderr}") from e
|
||||
|
||||
# Get output info
|
||||
output_path = Path(output_path)
|
||||
file_size = output_path.stat().st_size if output_path.exists() else 0
|
||||
|
||||
return EncodingResult(
|
||||
output_path=output_path,
|
||||
codec=codec,
|
||||
duration_seconds=0, # TODO: probe output
|
||||
file_size_bytes=file_size,
|
||||
)
|
||||
|
||||
def _build_command(
|
||||
self,
|
||||
frames_dir: Path,
|
||||
output_path: Path,
|
||||
fps: float,
|
||||
codec_config,
|
||||
crf: int,
|
||||
preset: str,
|
||||
audio_path: Optional[Path],
|
||||
width: Optional[int],
|
||||
height: Optional[int],
|
||||
) -> list[str]:
|
||||
"""Build FFmpeg command."""
|
||||
# Frame input pattern (expects frame_%08d.png format)
|
||||
frame_pattern = str(frames_dir / "frame_%08d.png")
|
||||
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y", # Overwrite output
|
||||
"-framerate", str(fps),
|
||||
"-i", frame_pattern,
|
||||
]
|
||||
|
||||
# Add audio input if provided
|
||||
if audio_path and audio_path.exists():
|
||||
cmd.extend(["-i", str(audio_path)])
|
||||
|
||||
# Video filter for scaling if needed
|
||||
vf_filters = []
|
||||
if width and height:
|
||||
# Ensure even dimensions
|
||||
w = width if width % 2 == 0 else width + 1
|
||||
h = height if height % 2 == 0 else height + 1
|
||||
vf_filters.append(f"scale={w}:{h}:flags=lanczos")
|
||||
|
||||
if vf_filters:
|
||||
cmd.extend(["-vf", ",".join(vf_filters)])
|
||||
|
||||
# Codec settings
|
||||
cmd.extend(self._get_codec_args(codec_config, crf, preset))
|
||||
|
||||
# Pixel format
|
||||
cmd.extend(["-pix_fmt", "yuv420p"])
|
||||
|
||||
# Audio settings
|
||||
if audio_path and audio_path.exists():
|
||||
cmd.extend(["-c:a", "aac", "-b:a", "320k"])
|
||||
|
||||
# Output
|
||||
cmd.append(str(output_path))
|
||||
|
||||
return cmd
|
||||
|
||||
def _get_codec_args(self, codec_config, crf: int, preset: str) -> list[str]:
|
||||
"""Get codec-specific FFmpeg arguments."""
|
||||
# Use NVENC if available and configured
|
||||
if self.use_nvenc and codec_config.nvenc_encoder:
|
||||
encoder = codec_config.nvenc_encoder
|
||||
|
||||
if "nvenc" in encoder:
|
||||
return [
|
||||
"-c:v", encoder,
|
||||
"-preset", "p7", # Highest quality NVENC preset
|
||||
"-tune", "hq",
|
||||
"-rc", "vbr",
|
||||
"-cq", str(crf),
|
||||
"-b:v", "0",
|
||||
]
|
||||
|
||||
# Software encoding
|
||||
encoder = codec_config.encoder
|
||||
|
||||
if encoder == "libx264":
|
||||
return [
|
||||
"-c:v", "libx264",
|
||||
"-preset", preset,
|
||||
"-crf", str(crf),
|
||||
"-tune", "film",
|
||||
]
|
||||
elif encoder == "libx265":
|
||||
return [
|
||||
"-c:v", "libx265",
|
||||
"-preset", preset,
|
||||
"-crf", str(crf),
|
||||
"-x265-params", "aq-mode=3",
|
||||
]
|
||||
elif encoder == "libsvtav1":
|
||||
return [
|
||||
"-c:v", "libsvtav1",
|
||||
"-preset", preset,
|
||||
"-crf", str(crf),
|
||||
"-svtav1-params", "tune=0:film-grain=8",
|
||||
]
|
||||
else:
|
||||
return ["-c:v", encoder, "-crf", str(crf)]
|
||||
|
||||
|
||||
def encode_video(
|
||||
frames_dir: Path,
|
||||
output_path: Path,
|
||||
fps: float,
|
||||
codec: str = "H.265",
|
||||
crf: int = 20,
|
||||
preset: str = "slow",
|
||||
audio_path: Optional[Path] = None,
|
||||
) -> Path:
|
||||
"""
|
||||
Convenience function to encode video.
|
||||
|
||||
Args:
|
||||
frames_dir: Directory with numbered frames
|
||||
output_path: Output video path
|
||||
fps: Frame rate
|
||||
codec: Video codec
|
||||
crf: Quality (lower = better)
|
||||
preset: Encoding preset
|
||||
audio_path: Optional audio to mux
|
||||
|
||||
Returns:
|
||||
Path to encoded video
|
||||
"""
|
||||
encoder = VideoEncoder()
|
||||
result = encoder.encode_frames(
|
||||
frames_dir=frames_dir,
|
||||
output_path=output_path,
|
||||
fps=fps,
|
||||
codec=codec,
|
||||
crf=crf,
|
||||
preset=preset,
|
||||
audio_path=audio_path,
|
||||
)
|
||||
return result.output_path
|
||||
201
src/video/extractor.py
Normal file
201
src/video/extractor.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Video frame extraction using PyAV."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoMetadata:
|
||||
"""Video file metadata."""
|
||||
|
||||
width: int
|
||||
height: int
|
||||
fps: float
|
||||
total_frames: int
|
||||
duration_seconds: float
|
||||
codec: str
|
||||
has_audio: bool
|
||||
audio_codec: Optional[str] = None
|
||||
audio_sample_rate: Optional[int] = None
|
||||
|
||||
|
||||
class VideoExtractor:
|
||||
"""
|
||||
Extract frames from video files using PyAV.
|
||||
|
||||
Features:
|
||||
- High-performance frame extraction
|
||||
- Support for various video formats
|
||||
- Frame-accurate seeking for resume
|
||||
"""
|
||||
|
||||
def __init__(self, video_path: Path):
|
||||
self.video_path = Path(video_path)
|
||||
self.container = None
|
||||
self.video_stream = None
|
||||
self._metadata: Optional[VideoMetadata] = None
|
||||
|
||||
def __enter__(self):
|
||||
self.open()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def open(self):
|
||||
"""Open video file."""
|
||||
if not self.video_path.exists():
|
||||
raise FileNotFoundError(f"Video not found: {self.video_path}")
|
||||
|
||||
self.container = av.open(str(self.video_path))
|
||||
self.video_stream = self.container.streams.video[0]
|
||||
|
||||
# Enable threading for faster decoding
|
||||
self.video_stream.thread_type = "AUTO"
|
||||
|
||||
logger.info(f"Opened video: {self.video_path}")
|
||||
|
||||
def close(self):
|
||||
"""Close video file."""
|
||||
if self.container:
|
||||
self.container.close()
|
||||
self.container = None
|
||||
self.video_stream = None
|
||||
|
||||
@property
|
||||
def metadata(self) -> VideoMetadata:
|
||||
"""Get video metadata."""
|
||||
if self._metadata is not None:
|
||||
return self._metadata
|
||||
|
||||
if self.video_stream is None:
|
||||
raise RuntimeError("Video not opened")
|
||||
|
||||
video = self.video_stream
|
||||
|
||||
# Calculate total frames
|
||||
if video.frames > 0:
|
||||
total_frames = video.frames
|
||||
elif video.duration and video.time_base:
|
||||
# Estimate from duration
|
||||
duration_sec = float(video.duration * video.time_base)
|
||||
total_frames = int(duration_sec * float(video.average_rate))
|
||||
else:
|
||||
total_frames = 0
|
||||
|
||||
# Duration
|
||||
if video.duration and video.time_base:
|
||||
duration_seconds = float(video.duration * video.time_base)
|
||||
else:
|
||||
duration_seconds = 0.0
|
||||
|
||||
# Check for audio
|
||||
has_audio = len(self.container.streams.audio) > 0
|
||||
audio_codec = None
|
||||
audio_sample_rate = None
|
||||
|
||||
if has_audio:
|
||||
audio_stream = self.container.streams.audio[0]
|
||||
audio_codec = audio_stream.codec.name
|
||||
audio_sample_rate = audio_stream.sample_rate
|
||||
|
||||
self._metadata = VideoMetadata(
|
||||
width=video.width,
|
||||
height=video.height,
|
||||
fps=float(video.average_rate) if video.average_rate else 30.0,
|
||||
total_frames=total_frames,
|
||||
duration_seconds=duration_seconds,
|
||||
codec=video.codec.name,
|
||||
has_audio=has_audio,
|
||||
audio_codec=audio_codec,
|
||||
audio_sample_rate=audio_sample_rate,
|
||||
)
|
||||
|
||||
return self._metadata
|
||||
|
||||
def extract_frames(
|
||||
self,
|
||||
start_frame: int = 0,
|
||||
end_frame: Optional[int] = None,
|
||||
) -> Iterator[tuple[int, np.ndarray]]:
|
||||
"""
|
||||
Extract frames from video.
|
||||
|
||||
Args:
|
||||
start_frame: Starting frame index (for resume)
|
||||
end_frame: Ending frame index (exclusive, None for all)
|
||||
|
||||
Yields:
|
||||
Tuple of (frame_index, frame_array) where frame_array is RGB
|
||||
"""
|
||||
if self.container is None:
|
||||
raise RuntimeError("Video not opened")
|
||||
|
||||
frame_index = 0
|
||||
|
||||
# Seek to start position if needed
|
||||
if start_frame > 0:
|
||||
self._seek_to_frame(start_frame)
|
||||
frame_index = start_frame
|
||||
|
||||
for frame in self.container.decode(video=0):
|
||||
if end_frame is not None and frame_index >= end_frame:
|
||||
break
|
||||
|
||||
# Convert to RGB numpy array
|
||||
img = frame.to_ndarray(format="rgb24")
|
||||
|
||||
yield frame_index, img
|
||||
frame_index += 1
|
||||
|
||||
def _seek_to_frame(self, frame_index: int):
|
||||
"""Seek to approximate frame position."""
|
||||
if self.video_stream is None:
|
||||
return
|
||||
|
||||
# Calculate timestamp
|
||||
fps = float(self.video_stream.average_rate) if self.video_stream.average_rate else 30.0
|
||||
time_base = self.video_stream.time_base
|
||||
target_ts = int(frame_index / fps / time_base)
|
||||
|
||||
# Seek to keyframe before target
|
||||
self.container.seek(target_ts, stream=self.video_stream)
|
||||
|
||||
def get_frame(self, frame_index: int) -> Optional[np.ndarray]:
|
||||
"""Get a specific frame by index."""
|
||||
self._seek_to_frame(frame_index)
|
||||
|
||||
for frame in self.container.decode(video=0):
|
||||
return frame.to_ndarray(format="rgb24")
|
||||
|
||||
return None
|
||||
|
||||
def extract_keyframes(self) -> Iterator[tuple[int, np.ndarray]]:
|
||||
"""Extract only keyframes (for preview generation)."""
|
||||
if self.container is None:
|
||||
raise RuntimeError("Video not opened")
|
||||
|
||||
# Set codec to skip non-key frames
|
||||
self.video_stream.codec_context.skip_frame = "NONKEY"
|
||||
|
||||
frame_index = 0
|
||||
for frame in self.container.decode(video=0):
|
||||
img = frame.to_ndarray(format="rgb24")
|
||||
yield frame_index, img
|
||||
frame_index += 1
|
||||
|
||||
# Reset skip mode
|
||||
self.video_stream.codec_context.skip_frame = "DEFAULT"
|
||||
|
||||
|
||||
def get_video_info(video_path: Path) -> VideoMetadata:
|
||||
"""Quick function to get video metadata."""
|
||||
with VideoExtractor(video_path) as extractor:
|
||||
return extractor.metadata
|
||||
Reference in New Issue
Block a user