Files
freepik-api/app/services/task_tracker.py

161 lines
5.4 KiB
Python
Raw Permalink Normal View History

import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Optional
from app.config import settings
from app.schemas.common import TaskStatus
from app.services import freepik_client
from app.services.file_manager import download_result
logger = logging.getLogger(__name__)
# In-memory task registry keyed by the internal task id minted in submit();
# each value is the task record dict built there. It is a plain module-level
# dict, so its contents are not persisted and are lost when the process exits.
_tasks: dict[str, dict] = {}
# Background polling asyncio.Task per internal task id. Entries are removed
# when polling finishes (see the finally clause in _poll_loop) or when the
# task is deleted / completed via webhook.
_poll_tasks: dict[str, asyncio.Task] = {}
def submit(freepik_task_id: str, status_path: str, metadata: Optional[dict] = None) -> str:
    """Register a Freepik task in the in-memory store and begin polling it.

    A fresh internal UUID is minted, a ``pending`` record is stored in
    ``_tasks``, and a background ``_poll_loop`` coroutine is scheduled with
    ``asyncio.create_task`` (so this must be called while an event loop is
    running).

    Args:
        freepik_task_id: The task_id returned by Freepik.
        status_path: The per-endpoint GET path for polling, e.g.
            '/v1/ai/text-to-image/flux-dev/{task-id}'.
        metadata: Optional metadata to attach to the task.

    Returns:
        The internal task id (a UUID4 string) used by the other functions
        in this module.
    """
    task_id = str(uuid.uuid4())
    created = datetime.now(timezone.utc)
    record = {
        'task_id': task_id,
        'freepik_task_id': freepik_task_id,
        'status_path': status_path,
        'status': TaskStatus.pending,
        'created_at': created,
        'updated_at': created,
        'progress': None,
        'result_url': None,
        'local_path': None,
        'error': None,
        'metadata': metadata or {},
    }
    _tasks[task_id] = record
    # Hold a strong reference to the poller so it is not garbage-collected
    # mid-flight, and so delete/webhook handlers can look it up to cancel it.
    _poll_tasks[task_id] = asyncio.create_task(_poll_loop(task_id))
    return task_id
def get_task(task_id: str) -> Optional[dict]:
    """Return the stored record for ``task_id``, or ``None`` if it is unknown."""
    try:
        return _tasks[task_id]
    except KeyError:
        return None
def list_tasks(
    status: Optional[TaskStatus] = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict], int]:
    """Return one page of tracked tasks, newest first, plus the match count.

    Args:
        status: If given, only tasks whose ``status`` equals this value are
            included; ``None`` means no status filter.
        limit: Maximum number of tasks in the returned page. Negative values
            are treated as 0.
        offset: Number of matching tasks to skip before the page starts.
            Negative values are treated as 0.

    Returns:
        ``(page, total)`` where ``page`` is sorted by ``created_at``
        descending and ``total`` is the number of matching tasks before
        pagination.
    """
    if status is None:
        matching = list(_tasks.values())
    else:
        matching = [t for t in _tasks.values() if t['status'] == status]
    matching.sort(key=lambda t: t['created_at'], reverse=True)
    # Clamp so a negative offset/limit cannot turn the slice into an
    # end-relative one (e.g. offset=-1 would otherwise return just the oldest
    # task instead of the first page).
    start = max(offset, 0)
    stop = start + max(limit, 0)
    return matching[start:stop], len(matching)
def delete_task(task_id: str) -> bool:
    """Forget a tracked task, cancelling its background poller if still running.

    Returns:
        ``True`` if the task existed and was removed, ``False`` otherwise.
    """
    if task_id not in _tasks:
        return False
    poller = _poll_tasks.pop(task_id, None)
    if poller is not None and not poller.done():
        poller.cancel()
    _tasks.pop(task_id, None)
    return True
def active_count() -> int:
    """Return how many tracked tasks are still ``pending`` or ``processing``."""
    in_flight = (TaskStatus.pending, TaskStatus.processing)
    return len([t for t in _tasks.values() if t['status'] in in_flight])
async def _poll_loop(internal_id: str):
    """Poll Freepik's per-endpoint status path until the task finishes.

    Every ``settings.task_poll_interval_seconds`` the status path stored on
    the task record is queried and the record is updated in place:

    * ``CREATED`` / ``IN_PROGRESS`` / ``PROCESSING`` -> ``processing``;
    * ``COMPLETED`` -> ``completed``; the first entry of ``data.generated``
      becomes ``result_url`` and is fetched with ``download_result`` (a
      download failure is logged and the task still counts as completed);
    * ``FAILED`` -> ``failed`` with Freepik's error/message text.

    Errors from the status request itself are logged and retried. If no
    terminal status arrives within ``settings.task_poll_timeout_seconds``
    (measured on the event loop's monotonic clock, so time spent in status
    requests counts toward the budget) the task is marked failed. Any other
    unexpected error marks the task failed. Cancellation is logged and
    re-raised. The task's entry in ``_poll_tasks`` is always removed on exit.
    """
    task = _tasks.get(internal_id)
    if not task:
        return
    status_path = task['status_path']
    timeout = settings.task_poll_timeout_seconds
    interval = settings.task_poll_interval_seconds
    loop = asyncio.get_running_loop()
    # A real deadline instead of summing sleep intervals: the old counter
    # ignored time spent awaiting the status request, so slow API calls
    # stretched polling well past the configured timeout.
    deadline = loop.time() + timeout
    try:
        while loop.time() < deadline:
            await asyncio.sleep(interval)
            try:
                result = await freepik_client.get_task_status(status_path)
            except Exception as exc:
                logger.warning('Poll error for %s: %s', internal_id, exc)
                continue

            data = result.get('data', result)
            fp_status = str(data.get('status', '')).upper()
            task['updated_at'] = datetime.now(timezone.utc)

            if fp_status in ('CREATED', 'IN_PROGRESS', 'PROCESSING'):
                task['status'] = TaskStatus.processing
                continue

            if fp_status == 'COMPLETED':
                task['status'] = TaskStatus.completed
                task['progress'] = 1.0
                # Freepik returns results in data.generated[] (list of URLs)
                generated = data.get('generated', [])
                result_url = generated[0] if generated else None
                task['result_url'] = result_url
                if result_url:
                    try:
                        task['local_path'] = await download_result(
                            internal_id, result_url
                        )
                    except Exception as exc:
                        logger.exception(
                            'Download failed for %s: %s', internal_id, exc
                        )
                logger.info('Task %s completed', internal_id)
                return

            if fp_status == 'FAILED':
                task['status'] = TaskStatus.failed
                # `or` chain rather than nested .get defaults, so an explicit
                # null/empty "error" field still falls back to "message".
                task['error'] = (
                    data.get('error') or data.get('message') or 'Unknown error'
                )
                logger.warning('Task %s failed: %s', internal_id, task['error'])
                return

        # Timeout
        task['status'] = TaskStatus.failed
        task['error'] = f'Polling timed out after {timeout}s'
        logger.warning('Task %s timed out', internal_id)
    except asyncio.CancelledError:
        logger.info('Polling cancelled for %s', internal_id)
        # Re-raise so the asyncio.Task actually ends in the cancelled state;
        # swallowing CancelledError breaks asyncio's cancellation contract.
        raise
    except Exception as exc:
        task['status'] = TaskStatus.failed
        task['error'] = str(exc)
        logger.exception('Unexpected error polling %s: %s', internal_id, exc)
    finally:
        _poll_tasks.pop(internal_id, None)
def handle_webhook_completion(freepik_task_id: str, result_data: dict):
    """Mark the tracked task for ``freepik_task_id`` as completed from a webhook.

    Only the first record whose ``freepik_task_id`` matches is touched. Its
    status, progress, timestamp and ``result_url`` are set; ``result_url`` is
    the first entry of ``result_data['generated']``, or ``None``. Its
    background poller is then popped from ``_poll_tasks`` and cancelled.
    An unknown ``freepik_task_id`` is ignored silently.

    NOTE(review): unlike the polling path this never calls
    ``download_result``, so ``local_path`` is left as-is (``None`` unless the
    poller already downloaded it), and cancelling the poller may interrupt an
    in-progress download. Confirm that is intended.
    """
    match = next(
        (t for t in _tasks.values() if t['freepik_task_id'] == freepik_task_id),
        None,
    )
    if match is None:
        return
    match['status'] = TaskStatus.completed
    match['progress'] = 1.0
    match['updated_at'] = datetime.now(timezone.utc)
    urls = result_data.get('generated', [])
    match['result_url'] = urls[0] if urls else None
    poller = _poll_tasks.pop(match['task_id'], None)
    if poller is not None and not poller.done():
        poller.cancel()
    logger.info(f'Task {match["task_id"]} completed via webhook')