Files

185 lines
5.1 KiB
Python
Raw Permalink Normal View History

2026-04-08 10:56:45 +02:00
"""Async task polling with beautiful Rich Live display."""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Tuple
from rich.console import Console
from rich.live import Live
from rich.panel import Panel
from rich.spinner import Spinner
from rich.table import Table
from rich.text import Text
# Task status values this module knows how to display (they match the keys of
# the status color/icon tables).
TaskStatus = Literal[
    "PENDING",
    "IN_PROGRESS",
    "COMPLETED",
    "FAILED",
    "CANCELLED",
    "CREATED",
]
# Rich style used for each (upper-cased) status in the status panel; statuses
# missing from this table are looked up with a "white" fallback.
STATUS_COLORS: dict[str, str] = dict(
    CREATED="yellow",
    PENDING="yellow",
    IN_PROGRESS="cyan",
    COMPLETED="green",
    FAILED="red",
    CANCELLED="dim red",
)
# Glyph shown in front of each (upper-cased) status label.
# NOTE(review): every entry is an empty string — the glyphs may have been lost
# when this file was copied; confirm the intended icons.
STATUS_ICONS: dict[str, str] = dict.fromkeys(
    ("CREATED", "PENDING", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"),
    "",
)
# Human-readable heading shown in the status panel for each task type key.
TASK_TYPE_LABELS: dict[str, str] = dict(
    [
        ("image", "Image Generation"),
        ("video", "Video Generation"),
        ("upscale-image", "Image Upscaling"),
        ("upscale-video", "Video Upscaling"),
        ("icon", "Icon Generation"),
        ("expand", "Image Expansion"),
        ("describe", "Image Analysis"),
        ("relight", "Image Relighting"),
        ("style-transfer", "Style Transfer"),
    ]
)
class FreepikError(Exception):
    """Common base class for errors raised while polling Freepik tasks.

    Lets callers handle every polling failure with a single ``except``.
    """


class FreepikTaskError(FreepikError):
    """Raised when a polled task reports a FAILED or CANCELLED status."""


class FreepikTimeoutError(FreepikError):
    """Raised when a task has not finished within the configured max wait."""
@dataclass
class PollConfig:
    """Timing and display settings for polling a task.

    Attributes:
        initial_delay: Seconds to wait before the first status check.
        min_interval: Smallest wait, in seconds, the poller should use
            between status checks.
        max_interval: Largest wait, in seconds, between status checks.
        backoff_factor: Multiplier applied to the wait after each check.
        max_wait: Overall time budget in seconds before giving up.
        task_type: Task type key, used to pick the status panel heading.

    Raises:
        ValueError: On construction, if a delay/interval is negative or
            ``backoff_factor`` is not positive.
    """

    initial_delay: float = 2.0
    min_interval: float = 2.0
    max_interval: float = 15.0
    backoff_factor: float = 1.5
    max_wait: float = 600.0
    task_type: str = "image"

    def __post_init__(self) -> None:
        # Fail fast: a negative wait would otherwise only surface as a
        # ValueError from time.sleep() part-way through polling, and a
        # non-positive factor drives the back-off wait to zero or below.
        for name in ("initial_delay", "min_interval", "max_interval"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be non-negative, got {value!r}")
        if self.backoff_factor <= 0:
            raise ValueError(
                f"backoff_factor must be positive, got {self.backoff_factor!r}"
            )
def _render_panel(
    task_id: str,
    status: str,
    elapsed: float,
    task_type: str,
    extra_info: dict[str, str] | None = None,
) -> Panel:
    """Build the status panel shown while a task is being polled.

    Args:
        task_id: Task identifier. IDs longer than 12 characters are shown as
            their first 12 characters followed by an ellipsis.
        status: Status string; looked up case-insensitively in STATUS_COLORS
            and STATUS_ICONS, and displayed with underscores as spaces.
        elapsed: Seconds since polling started, shown as "Ns" or "Mm SSs".
        task_type: Key into TASK_TYPE_LABELS; unknown keys are displayed
            title-cased with hyphens replaced by spaces.
        extra_info: Optional extra label/value rows appended in order. Values
            are handed to Rich unchanged, so they may contain markup.

    Returns:
        A 52-column-wide Panel wrapping the status grid.
    """
    key = status.upper()
    color = STATUS_COLORS.get(key, "white")
    icon = STATUS_ICONS.get(key, "?")
    label = TASK_TYPE_LABELS.get(task_type, task_type.replace("-", " ").title())

    grid = Table.grid(padding=(0, 2))
    grid.add_column(style="dim", width=14, no_wrap=True)
    grid.add_column(overflow="fold")

    # The caller may leave the last frame on screen after the live display
    # stops (poll_task uses transient=False), so finished tasks get a static
    # marker instead of a spinner frozen mid-animation.
    if key in ("COMPLETED", "FAILED", "CANCELLED"):
        indicator = Text(icon or "●", style=color)
    else:
        indicator = Spinner("dots", style="bold magenta")

    # Header row: activity indicator + task type label.
    grid.add_row(indicator, Text(label, style="bold white"))
    grid.add_row("", "")  # spacer

    # Text objects (rather than markup strings) keep API-supplied values that
    # contain "[" from being parsed as Rich markup.
    short_id = f"{task_id[:12]}…" if len(task_id) > 12 else task_id
    grid.add_row("Task ID", Text(short_id, style="bold blue"))

    status_label = status.replace("_", " ")
    # Skip the separator when there is no icon, avoiding a leading space.
    status_text = f"{icon} {status_label}" if icon else status_label
    grid.add_row("Status", Text(status_text, style=color))

    mins, secs = divmod(int(elapsed), 60)
    time_str = f"{mins}m {secs:02d}s" if mins else f"{secs}s"
    grid.add_row("Elapsed", Text(time_str, style="dim"))

    for row_label, value in (extra_info or {}).items():
        grid.add_row(row_label, value)

    return Panel(
        grid,
        title="[bold magenta]~ Freepik AI ~[/bold magenta]",
        border_style="magenta",
        padding=(1, 2),
        width=52,
    )
def _extract_error_message(result: dict[str, Any], status: str) -> str:
    """Pick the most specific error message out of a failed task response.

    Looks inside ``result["data"]`` when that is a dict (otherwise at
    ``result`` itself) and tries ``error.message``, a plain-string ``error``,
    then ``message``, before falling back to a generic status message.
    Tolerates null/non-dict ``data`` and ``error`` values.
    """
    data = result.get("data") if isinstance(result, dict) else None
    if not isinstance(data, dict):
        data = result if isinstance(result, dict) else {}
    error = data.get("error")
    if isinstance(error, dict) and error.get("message"):
        return str(error["message"])
    if isinstance(error, str) and error:
        return error
    if data.get("message"):
        return str(data["message"])
    return f"Task ended with status {status}"


def poll_task(
    check_fn: Callable[[str], Tuple[str, dict[str, Any]]],
    task_id: str,
    config: PollConfig,
    console: Console,
    extra_info: dict[str, str] | None = None,
) -> dict[str, Any]:
    """Poll a task until it finishes, displaying a live status panel.

    The first check happens after ``config.initial_delay`` seconds. The wait
    between later checks starts at ``config.min_interval`` and is multiplied
    by ``config.backoff_factor`` after each check, kept within
    ``[min_interval, max_interval]``. No wait extends past ``config.max_wait``;
    when a wait is shortened to hit that deadline, one more check is made
    before timing out.

    Args:
        check_fn: Callable(task_id) -> (status_str, raw_response_dict).
            Exceptions it raises propagate unchanged.
        task_id: The task ID to poll.
        config: Polling configuration.
        console: Rich console the live panel is drawn on.
        extra_info: Extra rows to display in the status panel.

    Returns:
        The raw response dict from the check that reported COMPLETED.

    Raises:
        FreepikTaskError: The task reported FAILED or CANCELLED.
        FreepikTimeoutError: The task was still unfinished once
            ``config.max_wait`` seconds had elapsed.
    """
    start = time.monotonic()
    deadline = start + config.max_wait

    def _time_left() -> float:
        # Clamped at zero so sleeps never receive a negative duration.
        return max(0.0, deadline - time.monotonic())

    interval = min(config.min_interval, config.max_interval)

    # Live.__exit__ stops the display on every exit path (return or raise),
    # so no explicit live.stop() calls are needed before raising.
    with Live(
        _render_panel(task_id, "PENDING", 0, config.task_type, extra_info),
        console=console,
        refresh_per_second=8,
        transient=False,
    ) as live:
        time.sleep(min(config.initial_delay, _time_left()))
        while True:
            current_status, result = check_fn(task_id)
            elapsed = time.monotonic() - start
            upper = current_status.upper()

            # Show the normalised label on completion so the final frame
            # left on screen reads "COMPLETED".
            shown = "COMPLETED" if upper == "COMPLETED" else current_status
            live.update(
                _render_panel(task_id, shown, elapsed, config.task_type, extra_info)
            )

            if upper == "COMPLETED":
                return result
            if upper in ("FAILED", "CANCELLED"):
                raise FreepikTaskError(
                    f"{upper}: {_extract_error_message(result, upper)}"
                )
            if elapsed >= config.max_wait:
                raise FreepikTimeoutError(
                    f"Task {task_id} timed out after {config.max_wait:.0f}s"
                )

            time.sleep(min(interval, _time_left()))
            interval = min(
                max(interval * config.backoff_factor, config.min_interval),
                config.max_interval,
            )