#!/usr/bin/env python3
"""
Flux.1 Image Generation Service

OpenAI-compatible image generation using Flux.1 Schnell model.
Provides /v1/images/generations endpoint.
"""

import base64
import gc
import io
import os
import sys
import time
from typing import Optional

import torch
from diffusers import FluxPipeline
from fastapi import HTTPException
from PIL import Image
from pydantic import BaseModel, Field

# Import base service class (the repository root must be on sys.path first)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService


# Fallback (width, height) used when the requested size cannot be parsed.
DEFAULT_SIZE = (1024, 1024)

# Upper bound on images per request; shared by request validation and the
# root endpoint so the two can never disagree.
MAX_IMAGES = 4


class ImageGenerationRequest(BaseModel):
    """Image generation request (OpenAI-compatible)"""
    model: str = Field(default="flux-schnell", description="Model name")
    prompt: str = Field(..., description="Text description of the image to generate")
    n: int = Field(default=1, ge=1, le=MAX_IMAGES, description="Number of images to generate")
    size: str = Field(default="1024x1024", description="Image size (e.g., 512x512, 1024x1024)")
    response_format: str = Field(default="b64_json", description="Response format: url or b64_json")
    quality: str = Field(default="standard", description="Image quality: standard or hd")
    style: str = Field(default="natural", description="Image style: natural or vivid")


class ImageGenerationResponse(BaseModel):
    """Image generation response (OpenAI-compatible)"""
    created: int = Field(..., description="Unix timestamp")
    data: list = Field(..., description="List of generated images")


class FluxService(GPUService):
    """Flux.1 Schnell image generation service"""

    def __init__(self):
        """Read PORT / MODEL_NAME from the environment and set up the service."""
        # Get port from environment or use default
        port = int(os.getenv("PORT", "8002"))
        super().__init__(name="flux-schnell", port=port)

        # Service-specific attributes; the pipeline stays None until
        # initialize() has loaded it, which the endpoints use as a readiness flag.
        self.pipeline: Optional[FluxPipeline] = None
        self.model_name = os.getenv("MODEL_NAME", "black-forest-labs/FLUX.1-schnell")

    async def initialize(self):
        """Load the Flux.1 pipeline and place it on the best available device.

        The pipeline is only published on ``self.pipeline`` once loading and
        device placement have both succeeded, so a failure part-way through
        leaves the service reporting "not initialized" instead of exposing a
        half-configured pipeline.
        """
        await super().initialize()

        self.logger.info(f"Loading Flux.1 pipeline: {self.model_name}")

        pipeline = FluxPipeline.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=os.getenv("HF_CACHE_DIR", "/workspace/huggingface_cache"),
        )

        if torch.cuda.is_available():
            if hasattr(pipeline, "enable_model_cpu_offload"):
                # Offloading moves each sub-model to the GPU only while it is
                # needed, saving VRAM. The diffusers docs advise against
                # calling .to("cuda") before enabling it, so the two are
                # treated as alternatives here.
                pipeline.enable_model_cpu_offload()
            else:
                pipeline = pipeline.to("cuda")
            self.logger.info("Flux.1 pipeline loaded on GPU")
        else:
            # No offload on this path: it is only meaningful when there is an
            # accelerator to offload from.
            self.logger.warning("GPU not available, running on CPU (very slow)")

        self.pipeline = pipeline
        self.logger.info("Flux.1 pipeline initialized successfully")

    async def cleanup(self):
        """Release the pipeline and any GPU memory cached on its behalf."""
        await super().cleanup()
        if self.pipeline is not None:
            self.logger.info("Flux.1 pipeline cleanup")
            self.pipeline = None
            # Dropping the reference alone does not return cached blocks to
            # the driver; collect garbage first so the allocator can free them.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def parse_size(self, size_str: str) -> tuple[int, int]:
        """Parse a size string like '1024x1024' into (width, height).

        Args:
            size_str: Size in ``<width>x<height>`` form (case-insensitive).

        Returns:
            The parsed ``(width, height)``, or ``(1024, 1024)`` when the
            string is malformed or either dimension is not positive.
        """
        try:
            # Unpacking raises ValueError unless there are exactly two parts.
            width_text, height_text = size_str.lower().split('x')
            width, height = int(width_text), int(height_text)
        except (AttributeError, ValueError):
            return DEFAULT_SIZE

        if width <= 0 or height <= 0:
            return DEFAULT_SIZE
        return (width, height)

    def image_to_base64(self, image: Image.Image) -> str:
        """Encode a PIL image as PNG and return it as a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def create_app(self):
        """Register the FastAPI routes on ``self.app``."""

        @self.app.get("/")
        async def root():
            """Root endpoint"""
            return {
                "service": "Flux.1 Schnell Image Generation",
                "model": self.model_name,
                "max_images": MAX_IMAGES
            }

        @self.app.get("/v1/models")
        async def list_models():
            """List available models (OpenAI-compatible)"""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "flux-schnell",
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "black-forest-labs",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ]
            }

        @self.app.post("/v1/images/generations")
        async def generate_image(request: ImageGenerationRequest) -> ImageGenerationResponse:
            """Generate images from text prompt (OpenAI-compatible).

            Raises:
                HTTPException: 503 when the pipeline has not been loaded,
                    500 when generation fails.
            """
            pipeline = self.pipeline
            if pipeline is None:
                raise HTTPException(status_code=503, detail="Model not initialized")

            self.logger.info(f"Generating {request.n} image(s): {request.prompt[:100]}...")

            try:
                # Parse image size
                width, height = self.parse_size(request.size)
                self.logger.info(f"Size: {width}x{height}")

                if request.response_format != "b64_json":
                    # Serving URLs would require saving and hosting the files,
                    # so every format is answered with base64 for now.
                    self.logger.warning(
                        f"response_format '{request.response_format}' is not supported; "
                        "returning b64_json"
                    )

                # NOTE(review): the pipeline call below is synchronous and runs
                # inline in this coroutine, so it holds the event loop for the
                # duration of each generation.
                images = []
                for i in range(request.n):
                    self.logger.info(f"Generating image {i+1}/{request.n}")

                    image = pipeline(
                        prompt=request.prompt,
                        width=width,
                        height=height,
                        num_inference_steps=4,  # Schnell is optimized for 4 steps
                        guidance_scale=0.0,  # Schnell doesn't use guidance
                    ).images[0]

                    images.append({"b64_json": self.image_to_base64(image)})

                self.logger.info(f"Generated {request.n} image(s) successfully")

                return ImageGenerationResponse(
                    # Real creation time, as the "Unix timestamp" field promises.
                    created=int(time.time()),
                    data=images
                )

            except Exception as e:
                self.logger.error(f"Error generating image: {e}", exc_info=True)
                raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    service = FluxService()
    service.run()