Files

89 lines
2.8 KiB
Python
Raw Permalink Normal View History

2026-04-08 10:56:45 +02:00
"""Image and video upscaling API methods."""
from __future__ import annotations
from typing import Any, Optional, Tuple
from freepik_cli.api.client import FreepikClient
from freepik_cli.api.models import (
UPSCALE_POST_ENDPOINTS,
UPSCALE_STATUS_ENDPOINTS,
VIDEO_UPSCALE_POST_ENDPOINTS,
VIDEO_UPSCALE_STATUS_ENDPOINT,
UpscaleMode,
VideoUpscaleMode,
get_output_urls,
get_status,
get_task_id,
)
class UpscaleAPI:
    """Wrapper over the Freepik image and video upscaling endpoints.

    Each ``upscale_*`` method submits a task and returns its task id; the
    matching ``*_status`` method fetches that task and returns
    ``(status, raw_response)``.
    """

    def __init__(self, client: FreepikClient) -> None:
        # Every request goes through this client's ``post``/``get`` methods.
        self._client = client

    # ------------------------------------------------------------------
    # Image upscaling
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_scale_factor(scale_factor: str | int) -> int:
        """Convert a scale factor such as ``"2x"`` / ``"4X"`` to an int.

        Integers are returned unchanged. For strings, surrounding whitespace
        and any trailing ``x``/``X`` characters are removed before parsing,
        so ``" 4x "`` becomes ``4``. Anything that still does not parse as
        an integer falls back to ``2`` (best-effort, preserving this
        client's existing behaviour rather than raising).
        """
        if isinstance(scale_factor, int):
            return scale_factor
        # Strip whitespace *before* the suffix: with "4x " the old
        # rstrip("xX") stopped at the trailing space, kept the "x", and the
        # value silently fell back to 2.
        digits = scale_factor.strip().rstrip("xX")
        try:
            return int(digits)
        except ValueError:
            return 2

    def upscale_image(
        self,
        mode: UpscaleMode,
        image_b64: str,
        scale_factor: Optional[str | int] = None,
        creativity: Optional[int] = None,
        prompt: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> str:
        """Submit an image upscale task. Returns task_id.

        Args:
            mode: Upscaler variant; selects the POST endpoint via
                ``UPSCALE_POST_ENDPOINTS[mode]``.
            image_b64: Image data sent as the ``"image"`` field (the name
                suggests base64 — encoding is the caller's job).
            scale_factor: Optional factor such as ``"2x"``/``"4x"`` or an
                int. Left out of the payload when falsy (``None``/``""``);
                strings that cannot be parsed fall back to 2.
            creativity: Added to the payload only for
                ``UpscaleMode.CREATIVE`` and only when not ``None``.
            prompt: Added only for ``UpscaleMode.CREATIVE`` and only when
                non-empty.
            seed: Added only for ``UpscaleMode.CREATIVE`` and only when not
                ``None``.

        Returns:
            The task id extracted from the POST response by ``get_task_id``.
        """
        payload: dict[str, Any] = {"image": image_b64}
        if scale_factor:
            payload["scale_factor"] = self._parse_scale_factor(scale_factor)
        if mode == UpscaleMode.CREATIVE:
            # These knobs are forwarded only for the creative upscaler; for
            # every other mode they are dropped without error.
            if creativity is not None:
                payload["creativity"] = creativity
            if prompt:
                payload["prompt"] = prompt
            if seed is not None:
                payload["seed"] = seed
        endpoint = UPSCALE_POST_ENDPOINTS[mode]
        raw = self._client.post(endpoint, json=payload)
        return get_task_id(raw)

    def upscale_image_status(
        self, mode: UpscaleMode, task_id: str
    ) -> Tuple[str, dict[str, Any]]:
        """Fetch the state of an image upscale task.

        ``mode`` must match the mode the task was submitted with, because it
        selects the status URL template from ``UPSCALE_STATUS_ENDPOINTS``.

        Returns:
            ``(status, raw)`` where ``status`` is ``get_status(raw)`` and
            ``raw`` is the unmodified GET response.
        """
        endpoint = UPSCALE_STATUS_ENDPOINTS[mode].format(task_id=task_id)
        raw = self._client.get(endpoint)
        return get_status(raw), raw

    # ------------------------------------------------------------------
    # Video upscaling
    # ------------------------------------------------------------------

    def upscale_video(
        self,
        mode: VideoUpscaleMode,
        video_b64: str,
    ) -> str:
        """Submit a video upscale task. Returns task_id.

        Args:
            mode: Selects the POST endpoint via
                ``VIDEO_UPSCALE_POST_ENDPOINTS[mode]``.
            video_b64: Video data sent as the ``"video"`` field (the name
                suggests base64 — encoding is the caller's job).
        """
        payload: dict[str, Any] = {"video": video_b64}
        endpoint = VIDEO_UPSCALE_POST_ENDPOINTS[mode]
        raw = self._client.post(endpoint, json=payload)
        return get_task_id(raw)

    def upscale_video_status(self, task_id: str) -> Tuple[str, dict[str, Any]]:
        """Fetch the state of a video upscale task.

        Uses the single ``VIDEO_UPSCALE_STATUS_ENDPOINT`` template regardless
        of which mode the task was submitted with, so no ``mode`` is needed.

        Returns:
            ``(status, raw)`` where ``status`` is ``get_status(raw)`` and
            ``raw`` is the unmodified GET response.
        """
        endpoint = VIDEO_UPSCALE_STATUS_ENDPOINT.format(task_id=task_id)
        raw = self._client.get(endpoint)
        return get_status(raw), raw

    def get_output_urls(self, raw: dict[str, Any]) -> list[str]:
        """Return the output URLs contained in a raw status response.

        Delegates to the imported module-level ``get_output_urls``. Inside a
        method body the bare name resolves to that global function, not to
        this method, so this call is not recursive.
        """
        return get_output_urls(raw)