180 lines
6.1 KiB
Python
180 lines
6.1 KiB
Python
|
|
import asyncio
|
||
|
|
import logging
|
||
|
|
import uuid
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
from app.schemas.common import TaskStatus
|
||
|
|
from app.services import freepik_client
|
||
|
|
from app.services.file_manager import download_result
|
||
|
|
|
||
|
|
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# In-memory registry: internal task id -> task record dict (records are
# created by ``submit`` and mutated in place by the poller / webhook handler).
_tasks: dict[str, dict] = dict()

# Background polling tasks keyed by internal task id. Holding the reference
# here also keeps each asyncio.Task alive while it runs (the event loop only
# keeps weak references to tasks).
_poll_tasks: dict[str, asyncio.Task] = dict()
|
||
|
|
|
||
|
|
|
||
|
|
def submit(freepik_task_id: str, metadata: Optional[dict] = None) -> str:
    """Register a Freepik task and start background polling.

    Must be called while an event loop is running: ``asyncio.create_task``
    raises ``RuntimeError`` otherwise. In that case the registration is
    rolled back (so no orphaned "pending" record is left behind) and the
    error propagates to the caller.

    Args:
        freepik_task_id: Task id returned by the Freepik API.
        metadata: Optional caller data stored with the task. A shallow copy
            is stored, so later mutation of the caller's dict does not leak
            into the task record.

    Returns:
        The internal task id (a UUID4 string) used by ``get_task`` etc.
    """
    internal_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc)
    _tasks[internal_id] = {
        'task_id': internal_id,
        'freepik_task_id': freepik_task_id,
        'status': TaskStatus.pending,
        'created_at': now,
        'updated_at': now,
        'progress': None,
        'result_url': None,
        'local_path': None,
        'error': None,
        'metadata': dict(metadata) if metadata else {},
    }

    poll_coro = _poll_loop(internal_id)
    try:
        _poll_tasks[internal_id] = asyncio.create_task(poll_coro)
    except RuntimeError:
        # No running loop: close the coroutine (avoids a "coroutine was never
        # awaited" warning) and drop the record nothing would ever update.
        poll_coro.close()
        _tasks.pop(internal_id, None)
        raise
    return internal_id
|
||
|
|
|
||
|
|
|
||
|
|
def get_task(task_id: str) -> Optional[dict]:
    """Return the live task record for *task_id*, or ``None`` if unknown."""
    try:
        return _tasks[task_id]
    except KeyError:
        return None
|
||
|
|
|
||
|
|
|
||
|
|
def list_tasks(
    status: Optional[TaskStatus] = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict], int]:
    """Return one page of tasks, newest first, plus the total match count.

    Args:
        status: If not ``None``, only tasks whose ``status`` equals it are
            included.
        limit: Maximum number of tasks in the page; negative values are
            treated as 0.
        offset: Number of matching tasks to skip; negative values are
            treated as 0 (Python's negative slicing would otherwise count
            from the end and yield surprising pages).

    Returns:
        ``(page, total)`` where ``total`` counts all matching tasks before
        paging.
    """
    matching = [
        t for t in _tasks.values()
        if status is None or t['status'] == status
    ]
    matching.sort(key=lambda t: t['created_at'], reverse=True)
    start = max(offset, 0)
    stop = start + max(limit, 0)
    return matching[start:stop], len(matching)
|
||
|
|
|
||
|
|
|
||
|
|
def delete_task(task_id: str) -> bool:
    """Forget *task_id*, cancelling its background poll if still running.

    Returns ``True`` if the task existed and was removed, ``False`` otherwise.
    """
    if task_id in _tasks:
        # Stop polling first so the poller cannot keep updating the record.
        poller = _poll_tasks.pop(task_id, None)
        if poller is not None and not poller.done():
            poller.cancel()
        _tasks.pop(task_id, None)
        return True
    return False
|
||
|
|
|
||
|
|
|
||
|
|
def active_count() -> int:
    """Count tasks that are still pending or processing."""
    count = 0
    for record in _tasks.values():
        if record['status'] in (TaskStatus.pending, TaskStatus.processing):
            count += 1
    return count
|
||
|
|
|
||
|
|
|
||
|
|
async def _poll_loop(internal_id: str):
    """Poll the Freepik API until the task finishes, fails or times out.

    Runs as the background ``asyncio.Task`` started by ``submit``. The task
    record in ``_tasks`` is updated in place (``status``, ``progress``,
    ``updated_at``, ``result_url``, ``local_path``, ``error``).

    Freepik status strings are compared case-insensitively. Polling stops
    once ``settings.task_poll_timeout_seconds`` is exhausted. That budget is
    measured both as the sum of poll intervals and on the event loop's
    monotonic clock, so slow API calls or a zero interval cannot keep the
    loop running far past the configured timeout.

    Cancellation (from ``delete_task`` or a webhook) is logged and then
    re-raised, as asyncio expects of cancelled coroutines.
    """
    task = _tasks.get(internal_id)
    if not task:
        return

    freepik_id = task['freepik_task_id']
    timeout = settings.task_poll_timeout_seconds
    interval = settings.task_poll_interval_seconds
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    elapsed = 0

    try:
        while elapsed < timeout and loop.time() < deadline:
            await asyncio.sleep(interval)
            elapsed += interval

            try:
                result = await freepik_client.get_task_status(freepik_id)
            except Exception as exc:
                # Treat API errors as transient; keep polling until timeout.
                logger.warning(f'Poll error for {internal_id}: {exc}')
                continue

            # Responses may wrap the payload in a top-level "data" key.
            data = result.get('data', result)
            # `or ''` also covers an explicit null status.
            fp_status = str(data.get('status') or '').upper()

            task['updated_at'] = datetime.now(timezone.utc)

            if fp_status in ('IN_PROGRESS', 'PROCESSING'):
                task['status'] = TaskStatus.processing
                task['progress'] = data.get('progress')
                continue

            if fp_status in ('COMPLETED', 'DONE'):
                await _finish_completed(internal_id, task, data)
                return

            if fp_status in ('FAILED', 'ERROR'):
                task['status'] = TaskStatus.failed
                # An explicit null/empty "error" falls through to "message".
                task['error'] = (
                    data.get('error') or data.get('message') or 'Unknown error'
                )
                logger.warning(f'Task {internal_id} failed: {task["error"]}')
                return
            # Any other (unrecognized) status: keep polling.

        # Timeout
        task['status'] = TaskStatus.failed
        task['error'] = f'Polling timed out after {timeout}s'
        task['updated_at'] = datetime.now(timezone.utc)
        logger.warning(f'Task {internal_id} timed out')

    except asyncio.CancelledError:
        logger.info(f'Polling cancelled for {internal_id}')
        raise
    except Exception as exc:
        task['status'] = TaskStatus.failed
        task['error'] = str(exc)
        task['updated_at'] = datetime.now(timezone.utc)
        logger.exception(f'Unexpected error polling {internal_id}: {exc}')
    finally:
        _poll_tasks.pop(internal_id, None)


async def _finish_completed(internal_id: str, task: dict, data: dict) -> None:
    """Mark *task* completed, record its result URL and try to download it."""
    task['status'] = TaskStatus.completed
    task['progress'] = 1.0
    # Extract the result URL from various response shapes. "output" is only
    # consulted when it is a dict; a null or list "output" would otherwise
    # raise AttributeError and wrongly mark a completed task as failed.
    output = data.get('output')
    result_url = (
        data.get('result_url')
        or (output.get('url') if isinstance(output, dict) else None)
        or _extract_first_url(data)
    )
    task['result_url'] = result_url
    if result_url:
        try:
            task['local_path'] = await download_result(internal_id, result_url)
        except Exception as exc:
            # Best effort: the task stays completed with result_url recorded.
            logger.error(f'Download failed for {internal_id}: {exc}')
        task['updated_at'] = datetime.now(timezone.utc)
    logger.info(f'Task {internal_id} completed')
|
||
|
|
|
||
|
|
|
||
|
|
def _extract_first_url(data: dict) -> Optional[str]:
    """Return the first result URL found in a Freepik task payload, or None.

    Looks, in order, at the first element of the list fields ``images``,
    ``videos``, ``results``, ``outputs`` and ``generated``. A usable element
    is either a dict with a non-empty string ``url`` or a bare string
    starting with ``http``. If none of these yields a URL, a top-level string
    ``url`` field is used. Unusable entries (e.g. ``{"url": None}``) are
    skipped rather than ending the search early.
    """
    if not isinstance(data, dict):
        return None

    # e.g. {"images": [{"url": "..."}]} or {"generated": ["https://..."]}.
    # "generated" is checked last so existing shapes keep precedence.
    for key in ('images', 'videos', 'results', 'outputs', 'generated'):
        items = data.get(key)
        if not isinstance(items, list) or not items:
            continue
        first = items[0]
        if isinstance(first, dict):
            url = first.get('url')
            if isinstance(url, str) and url:
                return url
        elif isinstance(first, str) and first.startswith('http'):
            return first

    # Direct URL field
    url = data.get('url')
    if isinstance(url, str):
        return url
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def handle_webhook_completion(freepik_task_id: str, result_data: dict):
    """Mark the task tracking *freepik_task_id* as completed from a webhook.

    Finds the task by its Freepik id, sets ``status``, ``progress``,
    ``updated_at`` and ``result_url``, and cancels any background poll.
    Unlike the polling path, no local download is started here, so
    ``local_path`` is left untouched.

    If polling has already marked the task completed, the webhook is ignored.
    Polling may still be downloading the result at that point, and
    cancelling it would abort the download. Unknown Freepik ids are logged
    and ignored.

    Args:
        freepik_task_id: Task id assigned by Freepik.
        result_data: Webhook payload. A top-level ``"data"`` dict is
            unwrapped, matching how the poll loop reads status responses.
    """
    task = next(
        (t for t in _tasks.values() if t['freepik_task_id'] == freepik_task_id),
        None,
    )
    if task is None:
        logger.info(f'Webhook for unknown Freepik task {freepik_task_id}')
        return

    internal_id = task['task_id']
    if task['status'] == TaskStatus.completed:
        logger.info(f'Task {internal_id} already completed; ignoring webhook')
        return

    payload = result_data.get('data', result_data)
    if not isinstance(payload, dict):
        payload = result_data
    # Only treat "output" as a URL container when it is a dict; a null or
    # list value would otherwise raise AttributeError.
    output = payload.get('output')
    result_url = (
        payload.get('result_url')
        or (output.get('url') if isinstance(output, dict) else None)
        or _extract_first_url(payload)
    )

    task['status'] = TaskStatus.completed
    task['progress'] = 1.0
    task['updated_at'] = datetime.now(timezone.utc)
    task['result_url'] = result_url

    # Cancel polling since webhook already notified us
    poll = _poll_tasks.pop(internal_id, None)
    if poll and not poll.done():
        poll.cancel()
    logger.info(f'Task {internal_id} completed via webhook')
|