"""Async task polling with beautiful Rich Live display.""" from __future__ import annotations import time from dataclasses import dataclass, field from typing import Any, Callable, Literal, Tuple from rich.console import Console from rich.live import Live from rich.panel import Panel from rich.spinner import Spinner from rich.table import Table from rich.text import Text TaskStatus = Literal["PENDING", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED", "CREATED"] STATUS_COLORS: dict[str, str] = { "CREATED": "yellow", "PENDING": "yellow", "IN_PROGRESS": "cyan", "COMPLETED": "green", "FAILED": "red", "CANCELLED": "dim red", } STATUS_ICONS: dict[str, str] = { "CREATED": "○", "PENDING": "○", "IN_PROGRESS": "◐", "COMPLETED": "●", "FAILED": "✗", "CANCELLED": "✗", } TASK_TYPE_LABELS: dict[str, str] = { "image": "Image Generation", "video": "Video Generation", "upscale-image": "Image Upscaling", "upscale-video": "Video Upscaling", "icon": "Icon Generation", "expand": "Image Expansion", "describe": "Image Analysis", "relight": "Image Relighting", "style-transfer": "Style Transfer", } class FreepikTaskError(Exception): pass class FreepikTimeoutError(Exception): pass @dataclass class PollConfig: initial_delay: float = 2.0 min_interval: float = 2.0 max_interval: float = 15.0 backoff_factor: float = 1.5 max_wait: float = 600.0 task_type: str = "image" def _render_panel( task_id: str, status: str, elapsed: float, task_type: str, extra_info: dict[str, str] | None = None, ) -> Panel: color = STATUS_COLORS.get(status.upper(), "white") icon = STATUS_ICONS.get(status.upper(), "?") label = TASK_TYPE_LABELS.get(task_type, task_type.replace("-", " ").title()) grid = Table.grid(padding=(0, 2)) grid.add_column(style="dim", width=14, no_wrap=True) grid.add_column(overflow="fold") spinner = Spinner("dots", style="bold magenta") # Header row with spinner header = Text() header.append(f"{label}", style="bold white") grid.add_row(spinner, header) grid.add_row("", "") # spacer short_id = task_id[:12] + "…" if len(task_id) > 12 else task_id grid.add_row("Task ID", f"[bold blue]{short_id}[/bold blue]") grid.add_row( "Status", f"[{color}]{icon} {status.replace('_', ' ')}[/{color}]", ) mins, secs = divmod(int(elapsed), 60) time_str = f"{mins}m {secs:02d}s" if mins else f"{secs}s" grid.add_row("Elapsed", f"[dim]{time_str}[/dim]") if extra_info: for k, v in extra_info.items(): grid.add_row(k, v) return Panel( grid, title="[bold magenta]~ Freepik AI ~[/bold magenta]", border_style="magenta", padding=(1, 2), width=52, ) def poll_task( check_fn: Callable[[str], Tuple[str, dict[str, Any]]], task_id: str, config: PollConfig, console: Console, extra_info: dict[str, str] | None = None, ) -> dict[str, Any]: """ Poll until COMPLETED or FAILED, displaying a live status panel. Args: check_fn: Callable(task_id) → (status_str, raw_response_dict) task_id: The task ID to poll config: Polling configuration console: Rich console instance extra_info: Extra rows to display in the status panel Returns: The raw response dict when status is COMPLETED """ start = time.monotonic() interval = config.initial_delay current_status = "PENDING" result: dict[str, Any] = {} with Live( _render_panel(task_id, current_status, 0, config.task_type, extra_info), console=console, refresh_per_second=8, transient=False, ) as live: time.sleep(config.initial_delay) while True: elapsed = time.monotonic() - start if elapsed > config.max_wait: live.stop() raise FreepikTimeoutError( f"Task {task_id} timed out after {config.max_wait:.0f}s" ) try: current_status, result = check_fn(task_id) except Exception as exc: live.stop() raise exc live.update( _render_panel(task_id, current_status, elapsed, config.task_type, extra_info) ) upper = current_status.upper() if upper == "COMPLETED": live.update( _render_panel(task_id, "COMPLETED", elapsed, config.task_type, extra_info) ) return result if upper in ("FAILED", "CANCELLED"): live.stop() data = result.get("data", result) error_msg = ( data.get("error", {}).get("message") or data.get("message") or f"Task ended with status {upper}" ) raise FreepikTaskError(f"{upper}: {error_msg}") time.sleep(interval) interval = min(interval * config.backoff_factor, config.max_interval)