"""Image generation and analysis API methods."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Any, Optional, Tuple
|
||
|
|
|
||
|
|
from freepik_cli.api.client import FreepikClient
|
||
|
|
from freepik_cli.api.models import (
|
||
|
|
IMAGE_POST_ENDPOINTS,
|
||
|
|
IMAGE_STATUS_ENDPOINTS,
|
||
|
|
ImageModel,
|
||
|
|
get_output_urls,
|
||
|
|
get_status,
|
||
|
|
get_task_id,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class ImageAPI:
    """Freepik image API: generation, image-to-prompt, and expansion.

    Every submit-style method (``generate``, ``describe_submit``,
    ``expand_submit``) returns a task id; poll the matching ``*_status``
    method until the task completes, then extract results from the raw
    response with the ``get_*`` helpers.
    """

    # Single source of truth for the outpainting endpoints. Previously
    # this mapping was duplicated verbatim in expand_submit and
    # expand_status, which invited drift when adding a model.
    _EXPAND_ENDPOINTS: dict[str, str] = {
        "flux-pro": "/v1/ai/image-expand/flux-pro",
        "ideogram": "/v1/ai/image-expand/ideogram",
        "seedream-v4-5": "/v1/ai/image-expand/seedream-v4-5",
    }
    # Fallback when an unrecognized model name is passed (matches the
    # original .get(..., flux-pro) behavior).
    _EXPAND_DEFAULT_ENDPOINT = "/v1/ai/image-expand/flux-pro"

    def __init__(self, client: FreepikClient) -> None:
        self._client = client

    def generate(self, model: ImageModel, payload: dict[str, Any]) -> str:
        """Submit a generation task. Returns task_id."""
        endpoint = IMAGE_POST_ENDPOINTS[model]
        raw = self._client.post(endpoint, json=payload)
        return get_task_id(raw)

    def get_status(self, model: ImageModel, task_id: str) -> Tuple[str, dict[str, Any]]:
        """Poll a generation task. Returns (status_str, raw_response)."""
        endpoint = IMAGE_STATUS_ENDPOINTS[model].format(task_id=task_id)
        raw = self._client.get(endpoint)
        return get_status(raw), raw

    def get_output_urls(self, raw: dict[str, Any]) -> list[str]:
        """Extract output asset URLs from a completed task response."""
        return get_output_urls(raw)

    # ------------------------------------------------------------------
    # Image-to-prompt (describe)
    # ------------------------------------------------------------------

    def describe_submit(self, image_b64: str) -> str:
        """Submit image-to-prompt task. Returns task_id."""
        raw = self._client.post("/v1/ai/image-to-prompt", json={"image": image_b64})
        return get_task_id(raw)

    def describe_status(self, task_id: str) -> Tuple[str, dict[str, Any]]:
        """Poll an image-to-prompt task. Returns (status_str, raw_response)."""
        raw = self._client.get(f"/v1/ai/image-to-prompt/{task_id}")
        return get_status(raw), raw

    def get_prompt_text(self, raw: dict[str, Any]) -> str:
        """Extract generated prompt text from a completed describe response.

        Tries the known response shapes in order and returns "" when no
        prompt-like field is present.
        """
        # "or raw" (rather than a .get default) so a response that carries
        # an explicit {"data": None} falls back instead of raising
        # AttributeError on the lookups below.
        data = raw.get("data") or raw
        return (
            data.get("prompt")
            or data.get("description")
            or data.get("text")
            # Same None-safety for an explicit {"result": None}.
            or (data.get("result") or {}).get("prompt", "")
            or ""
        )

    # ------------------------------------------------------------------
    # Image expansion (outpainting)
    # ------------------------------------------------------------------

    def expand_submit(
        self,
        model: str,
        image_b64: str,
        left: int = 0,
        right: int = 0,
        top: int = 0,
        bottom: int = 0,
        prompt: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> str:
        """Submit an expansion (outpainting) task. Returns task_id.

        left/right/top/bottom are the per-side growth amounts; prompt and
        seed are forwarded only when provided (empty prompt is omitted,
        seed=0 is still sent).
        """
        payload: dict[str, Any] = {
            "image": image_b64,
            "left": left,
            "right": right,
            "top": top,
            "bottom": bottom,
        }
        if prompt:
            payload["prompt"] = prompt
        if seed is not None:
            payload["seed"] = seed

        raw = self._client.post(self._expand_endpoint(model), json=payload)
        return get_task_id(raw)

    def expand_status(self, model: str, task_id: str) -> Tuple[str, dict[str, Any]]:
        """Poll an expansion task. Returns (status_str, raw_response)."""
        raw = self._client.get(f"{self._expand_endpoint(model)}/{task_id}")
        return get_status(raw), raw

    def _expand_endpoint(self, model: str) -> str:
        """Map a model name to its expansion endpoint (flux-pro fallback)."""
        return self._EXPAND_ENDPOINTS.get(model, self._EXPAND_DEFAULT_ENDPOINT)
|