89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
|
|
"""Image and video upscaling API methods."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Any, Optional, Tuple
|
||
|
|
|
||
|
|
from freepik_cli.api.client import FreepikClient
|
||
|
|
from freepik_cli.api.models import (
|
||
|
|
UPSCALE_POST_ENDPOINTS,
|
||
|
|
UPSCALE_STATUS_ENDPOINTS,
|
||
|
|
VIDEO_UPSCALE_POST_ENDPOINTS,
|
||
|
|
VIDEO_UPSCALE_STATUS_ENDPOINT,
|
||
|
|
UpscaleMode,
|
||
|
|
VideoUpscaleMode,
|
||
|
|
get_output_urls,
|
||
|
|
get_status,
|
||
|
|
get_task_id,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class UpscaleAPI:
    """Wrapper around the image and video upscale endpoints.

    Each ``upscale_*`` method POSTs a job and returns the task id extracted
    from the response. The matching ``*_status`` method GETs the task by id
    and returns ``(status, raw_response)``.
    """

    def __init__(self, client: FreepikClient) -> None:
        # All requests go through this client's ``post`` / ``get`` methods.
        self._client = client

    # ------------------------------------------------------------------
    # Image upscaling
    # ------------------------------------------------------------------

    def upscale_image(
        self,
        mode: UpscaleMode,
        image_b64: str,
        scale_factor: Optional[str] = None,
        creativity: Optional[int] = None,
        prompt: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> str:
        """Submit an image upscale task and return its task id.

        Args:
            mode: Selects the POST endpoint from ``UPSCALE_POST_ENDPOINTS``.
            image_b64: Image data, sent as the ``"image"`` field (presumably
                base64-encoded, per the parameter name).
            scale_factor: Optional factor such as ``"2x"`` or ``"4X"``.
                Surrounding whitespace and a trailing ``x``/``X`` are ignored,
                and a bare number (``"4"``, or ``4``) is also accepted.
                Anything that still does not parse as an integer falls back
                to ``2``. A falsy value omits the field entirely.
            creativity: Forwarded only when ``mode`` is
                ``UpscaleMode.CREATIVE`` and the value is not ``None``.
            prompt: Forwarded only in creative mode, and only if non-empty.
            seed: Forwarded for every mode when not ``None``.

        Returns:
            The task id extracted from the POST response by ``get_task_id``.
        """
        payload: dict[str, Any] = {"image": image_b64}

        if scale_factor:
            # Convert "2x" -> 2, " 4X " -> 4, "4" -> 4. Whitespace is stripped
            # both before and after removing the "x" suffix so that inputs
            # like "4x " or "4 x" are not silently downgraded to the fallback.
            factor = str(scale_factor).strip().rstrip("xX").strip()
            try:
                payload["scale_factor"] = int(factor)
            except ValueError:
                # Unparseable input falls back to 2 rather than raising.
                payload["scale_factor"] = 2

        if mode == UpscaleMode.CREATIVE:
            # creativity/prompt are only forwarded in creative mode;
            # for any other mode they are ignored.
            if creativity is not None:
                payload["creativity"] = creativity
            if prompt:
                payload["prompt"] = prompt

        if seed is not None:
            payload["seed"] = seed

        endpoint = UPSCALE_POST_ENDPOINTS[mode]
        raw = self._client.post(endpoint, json=payload)
        return get_task_id(raw)

    def upscale_image_status(
        self, mode: UpscaleMode, task_id: str
    ) -> Tuple[str, dict[str, Any]]:
        """Fetch an image upscale task and return ``(status, raw_response)``.

        ``status`` is whatever ``get_status`` extracts from the response.
        ``raw_response`` is the unmodified object returned by
        ``self._client.get``, so it can be passed on to
        :meth:`get_output_urls`.
        """
        endpoint = UPSCALE_STATUS_ENDPOINTS[mode].format(task_id=task_id)
        raw = self._client.get(endpoint)
        return get_status(raw), raw

    # ------------------------------------------------------------------
    # Video upscaling
    # ------------------------------------------------------------------

    def upscale_video(
        self,
        mode: VideoUpscaleMode,
        video_b64: str,
    ) -> str:
        """Submit a video upscale task and return its task id.

        Args:
            mode: Selects the POST endpoint from
                ``VIDEO_UPSCALE_POST_ENDPOINTS``.
            video_b64: Video data, sent as the ``"video"`` field (presumably
                base64-encoded, per the parameter name).

        Returns:
            The task id extracted from the POST response by ``get_task_id``.
        """
        payload: dict[str, Any] = {"video": video_b64}
        endpoint = VIDEO_UPSCALE_POST_ENDPOINTS[mode]
        raw = self._client.post(endpoint, json=payload)
        return get_task_id(raw)

    def upscale_video_status(self, task_id: str) -> Tuple[str, dict[str, Any]]:
        """Fetch a video upscale task and return ``(status, raw_response)``.

        All video modes share a single status endpoint
        (``VIDEO_UPSCALE_STATUS_ENDPOINT``), so no ``mode`` is needed here.
        """
        endpoint = VIDEO_UPSCALE_STATUS_ENDPOINT.format(task_id=task_id)
        raw = self._client.get(endpoint)
        return get_status(raw), raw

    def get_output_urls(self, raw: dict[str, Any]) -> list[str]:
        """Return the output URLs found in a raw task response.

        Delegates to the module-level ``get_output_urls`` helper. Inside this
        body the bare name resolves to that imported function, not to this
        method, because class attributes are not part of a method's scope.
        """
        return get_output_urls(raw)
|