feat: initial implementation of Real-ESRGAN Web UI
Full-featured Gradio 6.0+ web interface for Real-ESRGAN image/video upscaling, optimized for RTX 4090 (24 GB VRAM).

Features:
- Image upscaling with before/after comparison (ImageSlider)
- Video upscaling with progress tracking and checkpoint/resume
- Face enhancement via GFPGAN integration
- Multiple codecs: H.264, H.265, AV1 (with NVENC support)
- Batch processing queue with SQLite persistence
- Processing history gallery
- Custom dark theme
- Auto-download of model weights

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
src/ui/components/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Gradio UI components for each tab."""
src/ui/components/batch_tab.py (new file, 200 lines)
@@ -0,0 +1,200 @@
"""Batch processing queue tab component."""

import logging
from pathlib import Path
from typing import Optional

import gradio as gr

from src.config import get_model_choices, config
from src.storage.queue import queue_manager, Job, JobType, JobStatus

logger = logging.getLogger(__name__)


def get_queue_data() -> list[list]:
    """Get queue data for dataframe display."""
    jobs = queue_manager.get_queue()

    data = []
    for job in jobs:
        status_emoji = {
            JobStatus.QUEUED: "⏳",
            JobStatus.PROCESSING: "🔄",
            JobStatus.COMPLETED: "✅",
            JobStatus.FAILED: "❌",
            JobStatus.CANCELLED: "🚫",
        }.get(job.status, "❓")

        progress = f"{job.progress_percent:.0f}%" if job.status == JobStatus.PROCESSING else "-"

        data.append([
            job.id[:8],
            job.type.value.title(),
            Path(job.input_path).name[:30],
            f"{status_emoji} {job.status.value.title()}",
            progress,
            job.current_stage or "-",
        ])

    return data


def get_queue_stats() -> str:
    """Get queue statistics for display."""
    stats = queue_manager.get_queue_stats()

    queued = stats.get(JobStatus.QUEUED.value, 0)
    processing = stats.get(JobStatus.PROCESSING.value, 0)
    completed = stats.get(JobStatus.COMPLETED.value, 0)
    failed = stats.get(JobStatus.FAILED.value, 0)

    current = stats.get("current_job")
    current_info = ""
    if current:
        current_info = f"\n\n**Currently Processing:**\n{Path(current.input_path).name}"

    return (
        f"**Queue Status**\n\n"
        f"- Queued: {queued}\n"
        f"- Processing: {processing}\n"
        f"- Completed: {completed}\n"
        f"- Failed: {failed}"
        f"{current_info}"
    )


def add_files_to_queue(
    files: Optional[list],
    model_name: str,
    scale: int,
    face_enhance: bool,
) -> tuple[list[list], str]:
    """Add uploaded files to queue."""
    if not files:
        return get_queue_data(), "No files selected"

    added = 0
    for file in files:
        file_path = Path(file.name if hasattr(file, 'name') else file)

        # Determine job type
        ext = file_path.suffix.lower()
        if ext in [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff"]:
            job_type = JobType.IMAGE
        elif ext in [".mp4", ".avi", ".mov", ".mkv", ".webm"]:
            job_type = JobType.VIDEO
        else:
            continue

        job = Job(
            type=job_type,
            input_path=str(file_path),
            config={
                "model_name": model_name,
                "scale": scale,
                "face_enhance": face_enhance,
            },
        )

        queue_manager.add_job(job)
        added += 1

    return get_queue_data(), f"Added {added} items to queue"


def cancel_job(job_id: str) -> tuple[list[list], str]:
    """Cancel a job."""
    if job_id:
        queue_manager.cancel_job(job_id)
        return get_queue_data(), f"Cancelled job: {job_id}"
    return get_queue_data(), "No job selected"


def clear_completed() -> tuple[list[list], str]:
    """Clear completed jobs from queue."""
    count = queue_manager.clear_completed()
    return get_queue_data(), f"Cleared {count} completed jobs"


def create_batch_tab():
    """Create the batch queue tab component."""
    with gr.Tab("Batch Queue", id="queue-tab"):
        gr.Markdown("## Batch Processing Queue")

        with gr.Row():
            # Queue table
            with gr.Column(scale=2):
                queue_table = gr.Dataframe(
                    headers=["ID", "Type", "Input", "Status", "Progress", "Stage"],
                    datatype=["str", "str", "str", "str", "str", "str"],
                    value=get_queue_data(),
                    label="Queue",
                    interactive=False,
                )

            # Queue stats
            with gr.Column(scale=1):
                queue_stats = gr.Markdown(
                    get_queue_stats(),
                    elem_classes=["info-card"],
                )

        # Add to queue section
        with gr.Accordion("Add to Queue", open=True):
            with gr.Row():
                file_input = gr.File(
                    label="Select Files",
                    file_count="multiple",
                    file_types=["image", "video"],
                )

            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=get_model_choices(),
                    value=config.default_model,
                    label="Model",
                )

                scale_radio = gr.Radio(
                    choices=[2, 4],
                    value=config.default_scale,
                    label="Scale",
                )

                face_enhance_cb = gr.Checkbox(
                    value=False,
                    label="Face Enhancement",
                )

            add_btn = gr.Button("Add to Queue", variant="primary")

        # Queue actions
        with gr.Row():
            refresh_btn = gr.Button("Refresh", variant="secondary")
            clear_btn = gr.Button("Clear Completed", variant="secondary")

        status_text = gr.Markdown("")

        # Event handlers
        add_btn.click(
            fn=add_files_to_queue,
            inputs=[file_input, model_dropdown, scale_radio, face_enhance_cb],
            outputs=[queue_table, status_text],
        )

        refresh_btn.click(
            fn=lambda: (get_queue_data(), get_queue_stats()),
            outputs=[queue_table, queue_stats],
        )

        clear_btn.click(
            fn=clear_completed,
            outputs=[queue_table, status_text],
        )

        return {
            "queue_table": queue_table,
            "queue_stats": queue_stats,
            "status_text": status_text,
        }
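This tab imports `queue_manager`, `Job`, `JobType`, and `JobStatus` from `src.storage.queue`, which is outside this hunk. A minimal sketch of the interface the tab depends on, assuming the SQLite persistence named in the commit message (the table layout, JSON encoding, and field defaults are guesses, not the actual module):

```python
# Hypothetical sketch of src/storage/queue.py; only the imported names
# (Job, JobType, JobStatus, queue_manager) come from the diff above.
import json
import sqlite3
import uuid
from dataclasses import dataclass, field
from enum import Enum


class JobType(Enum):
    IMAGE = "image"
    VIDEO = "video"


class JobStatus(Enum):
    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class Job:
    type: JobType
    input_path: str
    config: dict
    id: str = field(default_factory=lambda: uuid.uuid4().hex)
    status: JobStatus = JobStatus.QUEUED
    progress_percent: float = 0.0
    current_stage: str = ""


class QueueManager:
    """SQLite-backed queue; jobs stored as JSON blobs (assumed scheme)."""

    def __init__(self, db_path: str = "queue.db") -> None:
        self._db = sqlite3.connect(db_path, check_same_thread=False)
        self._db.execute("CREATE TABLE IF NOT EXISTS jobs (id TEXT PRIMARY KEY, data TEXT)")

    def _save(self, job: Job) -> None:
        data = json.dumps({
            "type": job.type.value, "input_path": job.input_path,
            "config": job.config, "status": job.status.value,
            "progress_percent": job.progress_percent, "current_stage": job.current_stage,
        })
        self._db.execute(
            "INSERT INTO jobs (id, data) VALUES (?, ?) "
            "ON CONFLICT(id) DO UPDATE SET data = excluded.data",
            (job.id, data),
        )
        self._db.commit()

    def add_job(self, job: Job) -> None:
        self._save(job)

    def get_queue(self) -> list[Job]:
        jobs = []
        for job_id, raw in self._db.execute("SELECT id, data FROM jobs"):
            d = json.loads(raw)
            jobs.append(Job(
                type=JobType(d["type"]), input_path=d["input_path"],
                config=d["config"], id=job_id, status=JobStatus(d["status"]),
                progress_percent=d["progress_percent"], current_stage=d["current_stage"],
            ))
        return jobs

    def get_queue_stats(self) -> dict:
        stats: dict = {}
        for job in self.get_queue():
            stats[job.status.value] = stats.get(job.status.value, 0) + 1
            if job.status is JobStatus.PROCESSING:
                stats["current_job"] = job
        return stats

    def cancel_job(self, job_id: str) -> None:
        # Prefix match, so the 8-char id shown in the dataframe works too
        for job in self.get_queue():
            if job.id.startswith(job_id) and job.status is JobStatus.QUEUED:
                job.status = JobStatus.CANCELLED
                self._save(job)

    def clear_completed(self) -> int:
        cur = self._db.execute(
            "DELETE FROM jobs WHERE json_extract(data, '$.status') = 'completed'"
        )
        self._db.commit()
        return cur.rowcount


queue_manager = QueueManager()
```

Storing each job as one JSON blob keeps the schema flexible while SQLite's `json_extract` still supports status queries; whether the real module does this is an assumption.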
src/ui/components/history_tab.py (new file, 208 lines)
@@ -0,0 +1,208 @@
"""History gallery tab component."""

import logging
from pathlib import Path
from typing import Optional

import gradio as gr

from src.storage.history import history_manager, HistoryItem

logger = logging.getLogger(__name__)


def _to_gallery(items: list[HistoryItem]) -> list[tuple[str, str]]:
    """Build (path, caption) gallery pairs for image items whose output still exists."""
    gallery_items = []
    for item in items:
        output_path = Path(item.output_path)
        if output_path.exists() and item.type == "image":
            caption = f"{item.output_filename}\n{item.output_width}x{item.output_height}"
            gallery_items.append((str(output_path), caption))
    return gallery_items


def _visible_image_items() -> list[HistoryItem]:
    """Image items in the same order the gallery displays them."""
    items = history_manager.get_recent(limit=100)
    return [i for i in items if i.type == "image" and Path(i.output_path).exists()]


def get_history_gallery() -> list[tuple[str, str]]:
    """Get history items for gallery display."""
    return _to_gallery(history_manager.get_recent(limit=100))


def get_history_stats() -> str:
    """Get history statistics."""
    stats = history_manager.get_statistics()

    return (
        f"**History Statistics**\n\n"
        f"- Total Items: {stats['total_items']}\n"
        f"- Images: {stats['images']}\n"
        f"- Videos: {stats['videos']}\n"
        f"- Total Processing Time: {stats['total_processing_time']/60:.1f} min\n"
        f"- Total Input Size: {stats['total_input_size_mb']:.1f} MB\n"
        f"- Total Output Size: {stats['total_output_size_mb']:.1f} MB"
    )


def search_history(query: str) -> list[tuple[str, str]]:
    """Search history by filename."""
    if not query:
        return get_history_gallery()
    return _to_gallery(history_manager.search(query))


def get_item_details(evt: gr.SelectData) -> tuple[Optional[tuple], str, Optional[int]]:
    """Get details for the selected history item and remember its gallery index."""
    if evt.index is None:
        return None, "Select an item to view details", None

    image_items = _visible_image_items()
    if evt.index >= len(image_items):
        return None, "Item not found", None

    item = image_items[evt.index]

    # Load images for slider
    input_path = Path(item.input_path)
    output_path = Path(item.output_path)

    slider_images = None
    if input_path.exists() and output_path.exists():
        slider_images = (str(input_path), str(output_path))

    # Format details
    details = (
        f"**{item.output_filename}**\n\n"
        f"- **Type:** {item.type.title()}\n"
        f"- **Model:** {item.model}\n"
        f"- **Scale:** {item.scale}x\n"
        f"- **Face Enhancement:** {'Yes' if item.face_enhance else 'No'}\n"
        f"- **Input:** {item.input_width}x{item.input_height}\n"
        f"- **Output:** {item.output_width}x{item.output_height}\n"
        f"- **Processing Time:** {item.processing_time_seconds:.1f}s\n"
        f"- **Input Size:** {item.input_size_bytes/1024:.1f} KB\n"
        f"- **Output Size:** {item.output_size_bytes/1024:.1f} KB\n"
        f"- **Created:** {item.created_at.strftime('%Y-%m-%d %H:%M')}"
    )

    return slider_images, details, evt.index


def delete_selected(index: Optional[int]) -> tuple[list, str]:
    """Delete the history item currently selected in the gallery."""
    if index is None:
        return get_history_gallery(), "No item selected"

    image_items = _visible_image_items()
    if index >= len(image_items):
        return get_history_gallery(), "Item not found"

    item = image_items[index]
    history_manager.delete_item(item.id)

    return get_history_gallery(), f"Deleted: {item.output_filename}"


def create_history_tab():
    """Create the history gallery tab component."""
    with gr.Tab("History", id="history-tab"):
        gr.Markdown("## Processing History")

        # Tracks the gallery index of the current selection so the Delete
        # button knows which item to remove.
        selected_index = gr.State(None)

        with gr.Row():
            # Search
            search_box = gr.Textbox(
                label="Search",
                placeholder="Search by filename...",
                scale=2,
            )

            filter_dropdown = gr.Dropdown(
                choices=["All", "Images", "Videos"],
                value="All",
                label="Filter",
                scale=1,
            )  # TODO: wire type filtering

        with gr.Row():
            # Gallery
            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    value=get_history_gallery(),
                    label="History",
                    columns=4,
                    object_fit="cover",
                    height=400,
                    allow_preview=True,
                )

            # Details panel
            with gr.Column(scale=1):
                stats_display = gr.Markdown(
                    get_history_stats(),
                    elem_classes=["info-card"],
                )

        # Selected item details
        with gr.Row():
            with gr.Column(scale=2):
                try:
                    from gradio_imageslider import ImageSlider

                    comparison_slider = ImageSlider(
                        label="Before / After",
                        type="filepath",
                    )
                except ImportError:
                    comparison_slider = gr.Image(
                        label="Selected Image",
                        type="filepath",
                    )

            with gr.Column(scale=1):
                item_details = gr.Markdown(
                    "Select an item to view details",
                    elem_classes=["info-card"],
                )

        with gr.Row():
            download_btn = gr.Button("Download", variant="secondary")  # TODO: wire download
            delete_btn = gr.Button("Delete", variant="stop")

        status_text = gr.Markdown("")

        # Event handlers
        search_box.change(
            fn=search_history,
            inputs=[search_box],
            outputs=[gallery],
        )

        gallery.select(
            fn=get_item_details,
            outputs=[comparison_slider, item_details, selected_index],
        )

        delete_btn.click(
            fn=delete_selected,
            inputs=[selected_index],
            outputs=[gallery, status_text],
        )

        return {
            "gallery": gallery,
            "comparison_slider": comparison_slider,
            "item_details": item_details,
            "stats_display": stats_display,
        }
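`history_manager` and `HistoryItem` come from `src.storage.history`, also outside this hunk. The dataclass fields and method signatures below are inferred from the call sites in this file; the storage backend is deliberately left abstract:

```python
# Hypothetical interface sketch of src/storage/history.py; fields and method
# signatures are inferred from how history_tab.py uses them.
from dataclasses import dataclass
from datetime import datetime


@dataclass
class HistoryItem:
    id: str
    type: str  # "image" or "video"
    model: str
    scale: int
    face_enhance: bool
    input_path: str
    output_path: str
    output_filename: str
    input_width: int
    input_height: int
    output_width: int
    output_height: int
    input_size_bytes: int
    output_size_bytes: int
    processing_time_seconds: float
    created_at: datetime


class HistoryManager:
    def get_recent(self, limit: int = 100) -> list[HistoryItem]:
        """Most recent items, newest first (assumed ordering)."""
        raise NotImplementedError

    def search(self, query: str) -> list[HistoryItem]:
        """Items whose filename matches the query."""
        raise NotImplementedError

    def get_statistics(self) -> dict:
        """Keys used by the tab: total_items, images, videos,
        total_processing_time, total_input_size_mb, total_output_size_mb."""
        raise NotImplementedError

    def delete_item(self, item_id: str) -> None:
        raise NotImplementedError


history_manager = HistoryManager()
```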
src/ui/components/image_tab.py (new file, 292 lines)
@@ -0,0 +1,292 @@
"""Image upscaling tab component."""

import logging
import time
import uuid
from pathlib import Path
from typing import Optional

import cv2
import gradio as gr
import numpy as np

from src.config import OUTPUT_DIR, config, get_model_choices
from src.processing.upscaler import upscale_image, save_image, get_output_dimensions

logger = logging.getLogger(__name__)


def process_image(
    input_image: Optional[np.ndarray],
    model_name: str,
    scale: int,
    face_enhance: bool,
    tile_size: str,
    denoise: float,
    output_format: str,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[Optional[tuple], Optional[str], str]:
    """
    Process image for upscaling.

    Args:
        input_image: Input image as an RGB numpy array (Gradio convention)
        model_name: Model to use
        scale: Scale factor (2 or 4)
        face_enhance: Enable face enhancement
        tile_size: Tile size setting ("Auto", "256", "512", "1024")
        denoise: Denoise strength (0-1)
        output_format: Output format (PNG, JPG, WebP)

    Returns:
        Tuple of (slider_images, output_path, status_message)
    """
    if input_image is None:
        return None, None, "Please upload an image first."

    try:
        start_time = time.time()

        # Parse tile size
        tile = 0 if tile_size == "Auto" else int(tile_size)

        # Progress callback for UI
        def progress_callback(pct: float, stage: str):
            progress(pct, desc=stage)

        progress(0, desc="Starting upscale...")

        # Gradio delivers RGB; the upscaler works in BGR (cv2 convention,
        # matching the video pipeline), so convert before upscaling.
        if input_image.ndim == 3 and input_image.shape[2] == 3:
            input_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
        else:
            input_bgr = input_image

        # Upscale the image
        result = upscale_image(
            input_bgr,
            model_name=model_name,
            scale=scale,
            tile_size=tile,
            denoise_strength=denoise,
            progress_callback=progress_callback,
        )

        # Apply face enhancement if requested
        if face_enhance:
            progress(0.7, desc="Enhancing faces...")
            try:
                from src.processing.face_enhancer import enhance_faces

                result.image = enhance_faces(result.image)
            except ImportError:
                logger.warning("Face enhancement not available")
            except Exception as e:
                logger.warning(f"Face enhancement failed: {e}")

        progress(0.85, desc="Saving output...")

        # Generate output filename
        output_name = f"upscaled_{uuid.uuid4().hex[:8]}"
        output_path = save_image(
            result.image,
            OUTPUT_DIR / output_name,
            format=output_format.lower(),
        )

        # Convert the BGR result back to RGB for display
        output_rgb = cv2.cvtColor(result.image, cv2.COLOR_BGR2RGB)

        progress(1.0, desc="Complete!")

        # Calculate stats
        elapsed = time.time() - start_time
        input_size = f"{result.input_width}x{result.input_height}"
        output_size = f"{result.output_width}x{result.output_height}"

        status = (
            f"Upscaled {input_size} -> {output_size} in {elapsed:.1f}s | "
            f"Model: {model_name} | Saved: {output_path.name}"
        )

        # Return images for slider (input, output); the input is already RGB
        return (input_image, output_rgb), str(output_path), status

    except Exception as e:
        logger.exception("Image processing failed")
        return None, None, f"Error: {str(e)}"


def estimate_output(
    input_image: Optional[np.ndarray],
    model_name: str,
    scale: int,
) -> str:
    """Estimate output dimensions for display."""
    if input_image is None:
        return "Upload an image to see estimated output size"

    try:
        height, width = input_image.shape[:2]
        out_w, out_h = get_output_dimensions(width, height, model_name, scale)
        return f"Input: {width}x{height} -> Output: {out_w}x{out_h}"
    except Exception:
        return "Unable to estimate output size"


def create_image_tab():
    """Create the image upscaling tab component."""
    with gr.Tab("Image", id="image-tab"):
        # Status display at top
        status_text = gr.Markdown(
            "Upload an image to begin upscaling",
            elem_classes=["status-text"],
        )

        with gr.Row():
            # Left column - Input
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Input Image",
                    type="numpy",
                    sources=["upload", "clipboard"],
                    elem_classes=["upload-area"],
                )

                # Dimension estimate
                dimension_info = gr.Markdown(
                    "Upload an image to see estimated output size",
                    elem_classes=["info-text"],
                )

            # Right column - Output (ImageSlider)
            with gr.Column(scale=1):
                try:
                    from gradio_imageslider import ImageSlider

                    output_slider = ImageSlider(
                        label="Before / After",
                        type="numpy",
                        show_download_button=True,
                    )
                except ImportError:
                    # Fallback if ImageSlider not available
                    output_slider = gr.Image(
                        label="Upscaled Result",
                        type="numpy",
                        show_download_button=True,
                    )

                output_file = gr.File(
                    label="Download",
                    visible=False,
                )

        # Processing options
        with gr.Accordion("Processing Options", open=True):
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=get_model_choices(),
                    value=config.default_model,
                    label="Model",
                    info="Select upscaling model",
                )

                scale_radio = gr.Radio(
                    choices=[2, 4],
                    value=config.default_scale,
                    label="Scale Factor",
                    info="Output scale multiplier",
                )

                face_enhance_cb = gr.Checkbox(
                    value=config.default_face_enhance,
                    label="Face Enhancement (GFPGAN)",
                    info="Enhance faces in photos (not for anime)",
                )

            with gr.Row():
                tile_dropdown = gr.Dropdown(
                    choices=["Auto", "256", "512", "1024"],
                    value="Auto",
                    label="Tile Size",
                    info="Auto recommended for RTX 4090",
                )

                denoise_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.5,
                    step=0.1,
                    label="Denoise Strength",
                    info="Only for realesr-general-x4v3 model",
                )

                format_radio = gr.Radio(
                    choices=["PNG", "JPG", "WebP"],
                    value="PNG",
                    label="Output Format",
                )

        # Action buttons
        with gr.Row():
            upscale_btn = gr.Button(
                "Upscale Image",
                variant="primary",
                size="lg",
                elem_classes=["primary-action-btn"],
            )

            clear_btn = gr.Button(
                "Clear",
                variant="secondary",
            )

        # Event handlers
        upscale_btn.click(
            fn=process_image,
            inputs=[
                input_image,
                model_dropdown,
                scale_radio,
                face_enhance_cb,
                tile_dropdown,
                denoise_slider,
                format_radio,
            ],
            outputs=[output_slider, output_file, status_text],
        )

        # Update dimension estimate when image changes
        input_image.change(
            fn=estimate_output,
            inputs=[input_image, model_dropdown, scale_radio],
            outputs=[dimension_info],
        )

        # Also update when model/scale changes
        model_dropdown.change(
            fn=estimate_output,
            inputs=[input_image, model_dropdown, scale_radio],
            outputs=[dimension_info],
        )

        scale_radio.change(
            fn=estimate_output,
            inputs=[input_image, model_dropdown, scale_radio],
            outputs=[dimension_info],
        )

        # Clear button
        clear_btn.click(
            fn=lambda: (None, None, None, "Upload an image to begin upscaling"),
            outputs=[input_image, output_slider, output_file, status_text],
        )

        return {
            "input_image": input_image,
            "output_slider": output_slider,
            "output_file": output_file,
            "status_text": status_text,
            "model_dropdown": model_dropdown,
            "scale_radio": scale_radio,
            "face_enhance_cb": face_enhance_cb,
            "upscale_btn": upscale_btn,
        }
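The processing path depends on `src.processing.upscaler`, not shown in this commit. A contract sketch inferred from the call sites (the `UpscaleResult` name is an assumption; only its fields appear in the diff). The BGR convention follows the video pipeline, which converts frames with `cv2.COLOR_RGB2BGR` before calling `upscale_image`:

```python
# Hypothetical contract sketch of src/processing/upscaler.py; argument names
# and result fields come from call sites in this diff, bodies are guesses.
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional

import cv2
import numpy as np


@dataclass
class UpscaleResult:  # name assumed; the tabs only touch these fields
    image: np.ndarray  # upscaled BGR image
    input_width: int
    input_height: int
    output_width: int
    output_height: int


def upscale_image(
    image: np.ndarray,
    model_name: str,
    scale: int = 4,
    tile_size: int = 0,             # 0 = no tiling ("Auto" in the UI)
    denoise_strength: float = 0.5,  # only used by realesr-general-x4v3
    progress_callback: Optional[Callable[[float, str], None]] = None,
) -> UpscaleResult:
    """Run Real-ESRGAN on a BGR image; presumably wraps RealESRGANer.enhance()."""
    raise NotImplementedError


def save_image(image: np.ndarray, path_stem: Path, format: str = "png") -> Path:
    """Write the BGR image as path_stem.<format> and return the full path."""
    out = path_stem.with_suffix(f".{format}")
    cv2.imwrite(str(out), image)
    return out


def get_output_dimensions(width: int, height: int, model_name: str, scale: int) -> tuple[int, int]:
    """Output size is input size times the requested scale (model caps omitted)."""
    return width * scale, height * scale
```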
src/ui/components/video_tab.py (new file, 376 lines)
@@ -0,0 +1,376 @@
"""Video upscaling tab component."""

import logging
import shutil
import time
import uuid
from pathlib import Path
from typing import Optional

import cv2
import gradio as gr

from src.config import OUTPUT_DIR, TEMP_DIR, VIDEO_CODECS, config, get_model_choices
from src.processing.upscaler import upscale_image
from src.video.extractor import VideoExtractor, get_video_info
from src.video.encoder import encode_video
from src.video.audio import extract_audio, has_audio_stream
from src.video.checkpoint import checkpoint_manager

logger = logging.getLogger(__name__)


def get_video_metadata(video_path: Optional[str]) -> str:
    """Get video metadata for display."""
    if not video_path:
        return "Upload a video to see details"

    try:
        metadata = get_video_info(Path(video_path))
        return (
            f"**Resolution:** {metadata.width}x{metadata.height}\n"
            f"**Duration:** {metadata.duration_seconds:.1f}s\n"
            f"**FPS:** {metadata.fps:.2f}\n"
            f"**Frames:** {metadata.total_frames}\n"
            f"**Codec:** {metadata.codec}\n"
            f"**Audio:** {'Yes' if metadata.has_audio else 'No'}"
        )
    except Exception as e:
        return f"Error reading video: {e}"


def estimate_video_output(
    video_path: Optional[str],
    model_name: str,
    scale: int,
) -> str:
    """Estimate output details for video."""
    if not video_path:
        return "Upload a video to see estimated output"

    try:
        metadata = get_video_info(Path(video_path))
        out_w = metadata.width * scale
        out_h = metadata.height * scale

        # Estimate processing time (rough: ~2 fps for 4x on RTX 4090)
        est_time = metadata.total_frames / 2 if scale == 4 else metadata.total_frames / 4
        est_minutes = est_time / 60

        return (
            f"**Output Resolution:** {out_w}x{out_h}\n"
            f"**Estimated Time:** ~{est_minutes:.1f} minutes"
        )
    except Exception:
        return "Unable to estimate"


def process_video(
    video_path: Optional[str],
    model_name: str,
    scale: int,
    face_enhance: bool,
    codec: str,
    crf: int,
    preset: str,
    preserve_audio: bool,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[Optional[str], str]:
    """
    Process video for upscaling.

    Args:
        video_path: Input video path
        model_name: Model to use
        scale: Scale factor
        face_enhance: Enable face enhancement
        codec: Output codec
        crf: Quality (lower = better)
        preset: Encoding preset
        preserve_audio: Keep original audio
        progress: Gradio progress tracker

    Returns:
        Tuple of (output_path, status_message)
    """
    if not video_path:
        return None, "Please upload a video first."

    try:
        start_time = time.time()
        job_id = uuid.uuid4().hex[:8]

        progress(0, desc="Analyzing video...")

        # Get video info
        video_path = Path(video_path)
        metadata = get_video_info(video_path)

        logger.info(
            f"Processing video: {video_path.name}, "
            f"{metadata.width}x{metadata.height}, "
            f"{metadata.total_frames} frames"
        )

        # Create temp directories
        temp_dir = TEMP_DIR / f"video_{job_id}"
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(parents=True, exist_ok=True)

        # Extract audio if needed
        audio_path = None
        if preserve_audio and has_audio_stream(video_path):
            progress(0.02, desc="Extracting audio...")
            audio_path = extract_audio(video_path, temp_dir / "audio")

        # Create checkpoint (or resume from an existing one)
        checkpoint = checkpoint_manager.create_checkpoint(
            job_id=job_id,
            input_path=video_path,
            total_frames=metadata.total_frames,
            frames_dir=frames_dir,
            audio_path=audio_path,
            config={
                "model_name": model_name,
                "scale": scale,
                "face_enhance": face_enhance,
                "codec": codec,
                "crf": crf,
                "preset": preset,
            },
        )

        # Process frames
        progress(0.05, desc="Processing frames...")

        with VideoExtractor(video_path) as extractor:
            total = metadata.total_frames
            start_frame = checkpoint.processed_frames

            for frame_idx, frame in extractor.extract_frames(start_frame=start_frame):
                # Upscale frame (RGB input)
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                result = upscale_image(
                    frame_bgr,
                    model_name=model_name,
                    scale=scale,
                )

                # Apply face enhancement if enabled
                if face_enhance:
                    try:
                        from src.processing.face_enhancer import enhance_faces
                        result.image = enhance_faces(result.image)
                    except Exception as e:
                        logger.warning(f"Face enhancement failed on frame {frame_idx}: {e}")

                # Save frame
                frame_path = frames_dir / f"frame_{frame_idx:08d}.png"
                cv2.imwrite(str(frame_path), result.image)

                # Update checkpoint
                checkpoint_manager.update_checkpoint(checkpoint, frame_idx)

                # Update progress
                pct = 0.05 + (0.85 * (frame_idx + 1) / total)
                progress(
                    pct,
                    desc=f"Processing frame {frame_idx + 1}/{total}",
                )

        # Encode output video
        progress(0.9, desc=f"Encoding video ({codec})...")

        output_name = f"{video_path.stem}_upscaled_{job_id}"
        output_path = OUTPUT_DIR / f"{output_name}.mp4"

        encode_video(
            frames_dir=frames_dir,
            output_path=output_path,
            fps=metadata.fps,
            codec=codec,
            crf=crf,
            preset=preset,
            audio_path=audio_path,
        )

        # Cleanup
        progress(0.98, desc="Cleaning up...")
        checkpoint_manager.delete_checkpoint(job_id)
        shutil.rmtree(temp_dir, ignore_errors=True)

        progress(1.0, desc="Complete!")

        # Calculate stats
        elapsed = time.time() - start_time
        output_size = output_path.stat().st_size / (1024 * 1024)  # MB
        fps_actual = metadata.total_frames / elapsed

        status = (
            f"Completed in {elapsed/60:.1f} minutes ({fps_actual:.2f} fps)\n"
            f"Output: {output_path.name} ({output_size:.1f} MB)"
        )

        return str(output_path), status

    except Exception as e:
        logger.exception("Video processing failed")
        return None, f"Error: {str(e)}"


def create_video_tab():
    """Create the video upscaling tab component."""
    with gr.Tab("Video", id="video-tab"):
        # Status display
        status_text = gr.Markdown(
            "Upload a video to begin processing",
            elem_classes=["status-text"],
        )

        with gr.Row():
            # Left column - Input
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Input Video",
                    sources=["upload"],
                )

                # Video info display
                video_info = gr.Markdown(
                    "Upload a video to see details",
                    elem_classes=["info-text"],
                )

            # Right column - Output
            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Upscaled Video",
                    interactive=False,
                )

                # Output estimate
                output_estimate = gr.Markdown(
                    "Upload a video to see estimated output",
                    elem_classes=["info-text"],
                )

        # Processing options
        with gr.Accordion("Processing Options", open=True):
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=get_model_choices(),
                    value=config.default_model,
                    label="Model",
                )

                scale_radio = gr.Radio(
                    choices=[2, 4],
                    value=config.default_scale,
                    label="Scale Factor",
                )

                face_enhance_cb = gr.Checkbox(
                    value=False,
                    label="Face Enhancement",
                    info="Not recommended for anime",
                )

        # Codec options
        with gr.Accordion("Output Codec Settings", open=True):
            with gr.Row():
                codec_radio = gr.Radio(
                    choices=list(VIDEO_CODECS.keys()),
                    value=config.default_video_codec,
                    label="Codec",
                    info="H.265 recommended for quality/size balance",
                )

                crf_slider = gr.Slider(
                    minimum=15,
                    maximum=35,
                    value=config.default_video_crf,
                    step=1,
                    label="Quality (CRF)",
                    info="Lower = better quality, larger file",
                )

                preset_dropdown = gr.Dropdown(
                    choices=["slow", "medium", "fast"],
                    value=config.default_video_preset,
                    label="Preset",
                    info="Slow = best quality",
                )

            with gr.Row():
                preserve_audio_cb = gr.Checkbox(
                    value=True,
                    label="Preserve Audio",
                    info="Keep original audio track",
                )

        # Action buttons
        with gr.Row():
            process_btn = gr.Button(
                "Start Processing",
                variant="primary",
                size="lg",
                elem_classes=["primary-action-btn"],
            )

            cancel_btn = gr.Button(
                "Cancel",
                variant="stop",
                visible=False,  # TODO: implement cancellation
            )

        # Event handlers
        process_btn.click(
            fn=process_video,
            inputs=[
                input_video,
                model_dropdown,
                scale_radio,
                face_enhance_cb,
                codec_radio,
                crf_slider,
                preset_dropdown,
                preserve_audio_cb,
            ],
            outputs=[output_video, status_text],
        )

        # Update video info when video changes
        input_video.change(
            fn=get_video_metadata,
            inputs=[input_video],
            outputs=[video_info],
        )

        # Update output estimate
        input_video.change(
            fn=estimate_video_output,
            inputs=[input_video, model_dropdown, scale_radio],
            outputs=[output_estimate],
        )

        model_dropdown.change(
            fn=estimate_video_output,
            inputs=[input_video, model_dropdown, scale_radio],
            outputs=[output_estimate],
        )

        scale_radio.change(
            fn=estimate_video_output,
            inputs=[input_video, model_dropdown, scale_radio],
            outputs=[output_estimate],
        )

        return {
            "input_video": input_video,
            "output_video": output_video,
            "status_text": status_text,
            "process_btn": process_btn,
        }
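`checkpoint_manager` from `src.video.checkpoint` backs the checkpoint/resume behavior in `process_video()`. A sketch matching the call sites; the JSON-on-disk layout and the resume-by-job-id logic are assumptions:

```python
# Hypothetical sketch of src/video/checkpoint.py; method signatures match the
# call sites in process_video(), everything else is assumed.
import json
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional

CHECKPOINT_DIR = Path("temp/checkpoints")  # location assumed


@dataclass
class Checkpoint:
    job_id: str
    input_path: str
    total_frames: int
    frames_dir: str
    audio_path: Optional[str]
    config: dict
    processed_frames: int = 0


class CheckpointManager:
    def _path(self, job_id: str) -> Path:
        return CHECKPOINT_DIR / f"{job_id}.json"

    def create_checkpoint(self, job_id, input_path, total_frames, frames_dir,
                          audio_path, config) -> Checkpoint:
        """Load an existing checkpoint (resume) or start a fresh one."""
        path = self._path(job_id)
        if path.exists():
            return Checkpoint(**json.loads(path.read_text()))
        cp = Checkpoint(job_id, str(input_path), total_frames, str(frames_dir),
                        str(audio_path) if audio_path else None, config)
        self._write(cp)
        return cp

    def update_checkpoint(self, cp: Checkpoint, frame_idx: int) -> None:
        # Persist after every frame so an interrupted job can resume
        cp.processed_frames = frame_idx + 1
        self._write(cp)

    def delete_checkpoint(self, job_id: str) -> None:
        self._path(job_id).unlink(missing_ok=True)

    def _write(self, cp: Checkpoint) -> None:
        CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
        self._path(cp.job_id).write_text(json.dumps(asdict(cp)))


checkpoint_manager = CheckpointManager()
```

Note that `process_video()` generates a fresh `job_id` per call, so resume would only trigger if the caller reuses an id; the real module may key checkpoints differently.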
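Likewise, `encode_video` from `src.video.encoder` is not in this hunk. A minimal sketch assuming ffmpeg on PATH and software encoders only; the real module would also map the AV1 and NVENC options (e.g. `hevc_nvenc`) advertised in the commit message, which use different rate-control flags:

```python
# Hypothetical sketch of src/video/encoder.py's encode_video(); the signature
# matches the call site in video_tab.py, the ffmpeg invocation is assumed.
import subprocess
from pathlib import Path
from typing import Optional

# Software encoders only in this sketch; the codec-name mapping is a guess.
CODEC_ARGS = {"H.264": "libx264", "H.265": "libx265"}


def encode_video(
    frames_dir: Path,
    output_path: Path,
    fps: float,
    codec: str = "H.265",
    crf: int = 20,
    preset: str = "medium",
    audio_path: Optional[Path] = None,
) -> None:
    """Assemble PNG frames (frame_%08d.png) into an MP4 with ffmpeg."""
    cmd = [
        "ffmpeg", "-y",
        "-framerate", f"{fps}",
        "-i", str(frames_dir / "frame_%08d.png"),
    ]
    if audio_path is not None:
        cmd += ["-i", str(audio_path)]
    cmd += [
        "-c:v", CODEC_ARGS[codec],
        "-crf", str(crf),
        "-preset", preset,
        "-pix_fmt", "yuv420p",  # broad player compatibility
    ]
    if audio_path is not None:
        cmd += ["-c:a", "copy", "-shortest"]
    cmd.append(str(output_path))
    subprocess.run(cmd, check=True)
```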