import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Optional

from app.config import settings
from app.schemas.common import TaskStatus
from app.services import freepik_client
from app.services.file_manager import download_result

logger = logging.getLogger(__name__)

# In-memory registry of task records, keyed by the internal task id.
_tasks: dict[str, dict] = {}
# Background polling tasks, keyed by the internal task id. Keeping the
# reference lets delete_task() and handle_webhook_completion() cancel them.
_poll_tasks: dict[str, asyncio.Task] = {}


def _first_generated_url(payload: dict) -> Optional[str]:
    """Return the first entry of ``payload['generated']``, or None if empty."""
    # Freepik returns results in data.generated[] (list of URLs)
    generated = payload.get('generated', [])
    return generated[0] if generated else None


def submit(freepik_task_id: str, status_path: str, metadata: Optional[dict] = None) -> str:
    """Register a Freepik task and start background polling.

    Must be called while an event loop is running, because the poller is
    started with ``asyncio.create_task``.

    Args:
        freepik_task_id: The task_id returned by Freepik.
        status_path: The per-endpoint GET path for polling,
            e.g. '/v1/ai/text-to-image/flux-dev/{task-id}'.
        metadata: Optional metadata to attach to the task.

    Returns:
        The newly generated internal task id (a UUID4 string).

    Raises:
        RuntimeError: If no event loop is running; the task record is
            removed again in that case so no orphaned entry is left behind.
    """
    internal_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc)
    _tasks[internal_id] = {
        'task_id': internal_id,
        'freepik_task_id': freepik_task_id,
        'status_path': status_path,
        'status': TaskStatus.pending,
        'created_at': now,
        'updated_at': now,
        'progress': None,
        'result_url': None,
        'local_path': None,
        'error': None,
        'metadata': metadata or {},
    }

    poll_coro = _poll_loop(internal_id)
    try:
        _poll_tasks[internal_id] = asyncio.create_task(poll_coro)
    except RuntimeError:
        # No running loop: undo the registration so the record does not sit
        # in 'pending' forever, and close the never-started coroutine.
        _tasks.pop(internal_id, None)
        poll_coro.close()
        raise
    return internal_id


def get_task(task_id: str) -> Optional[dict]:
    """Return the task record for ``task_id``, or None if it is unknown."""
    return _tasks.get(task_id)


def list_tasks(
    status: Optional[TaskStatus] = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict], int]:
    """Return one page of task records, newest first, plus the total count.

    Args:
        status: When given, only tasks with exactly this status are listed.
        limit: Maximum number of records in the page (negative is treated as 0).
        offset: Number of records to skip (negative is treated as 0).

    Returns:
        ``(page, total)`` where ``total`` counts all records matching the
        filter, before pagination.
    """
    matching = [
        t for t in _tasks.values()
        if status is None or t['status'] == status
    ]
    matching.sort(key=lambda t: t['created_at'], reverse=True)

    # Clamp so a negative value cannot turn into a from-the-end slice.
    start = max(offset, 0)
    stop = start + max(limit, 0)
    return matching[start:stop], len(matching)


def delete_task(task_id: str) -> bool:
    """Remove a task record and cancel its poller if it is still running.

    Returns:
        True if the task existed and was removed, False otherwise.
    """
    if task_id not in _tasks:
        return False

    poll = _poll_tasks.pop(task_id, None)
    if poll is not None and not poll.done():
        poll.cancel()
    _tasks.pop(task_id, None)
    return True


def active_count() -> int:
    """Return how many tasks are currently pending or processing."""
    active = (TaskStatus.pending, TaskStatus.processing)
    return sum(1 for t in _tasks.values() if t['status'] in active)


async def _poll_loop(internal_id: str) -> None:
    """Poll Freepik API using the per-endpoint status path until done.

    The loop ends when Freepik reports COMPLETED or FAILED, when the
    configured timeout elapses, or when the asyncio task is cancelled.
    The task record in ``_tasks`` is updated in place.
    """
    task = _tasks.get(internal_id)
    if task is None:
        return

    status_path = task['status_path']
    try:
        # Use the loop's monotonic clock for the deadline so time spent in
        # HTTP calls counts too, and an interval of 0 still times out.
        clock = asyncio.get_running_loop().time
        deadline = clock() + settings.task_poll_timeout_seconds

        while clock() < deadline:
            await asyncio.sleep(settings.task_poll_interval_seconds)

            try:
                result = await freepik_client.get_task_status(status_path)
            except Exception as exc:
                # Transient poll failures are retried until the deadline.
                logger.warning(f'Poll error for {internal_id}: {exc}')
                continue

            # The payload may be wrapped in 'data'; fall back to the top
            # level when the key is missing or null.
            data = result.get('data') or result
            fp_status = str(data.get('status', '')).upper()
            task['updated_at'] = datetime.now(timezone.utc)

            if fp_status in ('CREATED', 'IN_PROGRESS', 'PROCESSING'):
                task['status'] = TaskStatus.processing
                continue

            if fp_status == 'COMPLETED':
                task['status'] = TaskStatus.completed
                task['progress'] = 1.0
                result_url = _first_generated_url(data)
                task['result_url'] = result_url
                if result_url:
                    try:
                        task['local_path'] = await download_result(
                            internal_id, result_url
                        )
                    except Exception as exc:
                        # Best effort: the task stays completed with its
                        # remote URL even if the local download fails.
                        logger.error(f'Download failed for {internal_id}: {exc}')
                logger.info(f'Task {internal_id} completed')
                return

            if fp_status == 'FAILED':
                task['status'] = TaskStatus.failed
                task['error'] = data.get('error', data.get('message', 'Unknown error'))
                logger.warning(f'Task {internal_id} failed: {task["error"]}')
                return

            # Any other status is ignored and polled again.

        # Timeout
        task['status'] = TaskStatus.failed
        task['error'] = f'Polling timed out after {settings.task_poll_timeout_seconds}s'
        logger.warning(f'Task {internal_id} timed out')
    except asyncio.CancelledError:
        logger.info(f'Polling cancelled for {internal_id}')
        # Re-raise so asyncio records the task as cancelled.
        raise
    except Exception as exc:
        task['status'] = TaskStatus.failed
        task['error'] = str(exc)
        logger.exception(f'Unexpected error polling {internal_id}: {exc}')
    finally:
        _poll_tasks.pop(internal_id, None)


def handle_webhook_completion(freepik_task_id: str, result_data: dict) -> None:
    """Called when a webhook notification arrives for a completed task.

    Marks the first task whose ``freepik_task_id`` matches as completed,
    stores the first generated URL and cancels its background poller.

    Args:
        freepik_task_id: The Freepik-side task id carried by the webhook.
        result_data: Webhook payload; its 'generated' list is read.
    """
    # NOTE(review): unlike the polling path, the result is not downloaded
    # here, so 'local_path' stays unset - confirm that this is intended.
    for task in _tasks.values():
        if task['freepik_task_id'] != freepik_task_id:
            continue

        task['status'] = TaskStatus.completed
        task['progress'] = 1.0
        task['updated_at'] = datetime.now(timezone.utc)
        task['result_url'] = _first_generated_url(result_data)

        poll = _poll_tasks.pop(task['task_id'], None)
        if poll is not None and not poll.done():
            poll.cancel()
        logger.info(f'Task {task["task_id"]} completed via webhook')
        break
    else:
        logger.info(f'Webhook received for unknown Freepik task {freepik_task_id}')