194 lines
6.8 KiB
Python
194 lines
6.8 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Flux.1 Image Generation Service
|
||
|
|
|
||
|
|
OpenAI-compatible image generation using Flux.1 Schnell model.
|
||
|
|
Provides the /v1/images/generations endpoint, plus /v1/models and a / service-info route.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
import base64
import io
import os
import sys
import threading
import time
from typing import Optional

import torch
from diffusers import FluxPipeline
from fastapi import HTTPException
from PIL import Image
from pydantic import BaseModel, Field

# Import base service class
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService
|
||
|
|
|
||
|
|
|
||
|
|
class ImageGenerationRequest(BaseModel):
    """Request body for ``POST /v1/images/generations`` (OpenAI-compatible).

    Only ``prompt`` is required; every other field carries a default.
    """

    # Model identifier supplied by the client.
    model: str = Field("flux-schnell", description="Model name")

    # Free-text description of the desired image (required).
    prompt: str = Field(
        ...,
        description="Text description of the image to generate",
    )

    # How many images to produce; pydantic enforces the 1..4 range.
    n: int = Field(
        1,
        ge=1,
        le=4,
        description="Number of images to generate",
    )

    # "<width>x<height>" string, e.g. "512x512".
    size: str = Field(
        "1024x1024",
        description="Image size (e.g., 512x512, 1024x1024)",
    )

    # Requested output encoding ("url" or "b64_json").
    response_format: str = Field(
        "b64_json",
        description="Response format: url or b64_json",
    )

    # Accepted for OpenAI compatibility; not read by the generation endpoint in this file.
    quality: str = Field(
        "standard",
        description="Image quality: standard or hd",
    )

    # Accepted for OpenAI compatibility; not read by the generation endpoint in this file.
    style: str = Field(
        "natural",
        description="Image style: natural or vivid",
    )
|
||
|
|
|
||
|
|
|
||
|
|
class ImageGenerationResponse(BaseModel):
    """Response body for ``POST /v1/images/generations`` (OpenAI-compatible)."""

    # Creation time of the response as a Unix timestamp.
    created: int = Field(
        ...,
        description="Unix timestamp",
    )

    # One entry per generated image; the endpoint in this file fills each
    # entry with a {"b64_json": ...} dict.
    data: list = Field(
        ...,
        description="List of generated images",
    )
|
||
|
|
|
||
|
|
|
||
|
|
class FluxService(GPUService):
    """Flux.1 Schnell image generation service.

    Serves a diffusers ``FluxPipeline`` behind OpenAI-style routes:
    ``GET /``, ``GET /v1/models`` and ``POST /v1/images/generations``.

    Environment variables:
        PORT: HTTP port (default ``8002``).
        MODEL_NAME: Hugging Face model id
            (default ``black-forest-labs/FLUX.1-schnell``).
        HF_CACHE_DIR: model download cache directory
            (default ``/workspace/huggingface_cache``).
    """

    # Returned by parse_size() when a size string is malformed or non-positive.
    DEFAULT_SIZE: tuple[int, int] = (1024, 1024)

    def __init__(self):
        # Get port from environment or use default
        port = int(os.getenv("PORT", "8002"))
        super().__init__(name="flux-schnell", port=port)

        # Service-specific attributes
        self.pipeline: Optional[FluxPipeline] = None
        self.model_name = os.getenv("MODEL_NAME", "black-forest-labs/FLUX.1-schnell")
        # Serializes pipeline calls. Inference runs in worker threads, and a
        # pipeline object holds mutable per-run state (e.g. scheduler
        # timesteps), so two generations must not share it concurrently. A
        # threading lock taken inside the worker keeps this guarantee even if
        # the awaiting request task is cancelled mid-generation.
        self._pipeline_lock = threading.Lock()

    async def initialize(self):
        """Load the Flux.1 pipeline and place it on the best available device.

        With CUDA available, model CPU offload is preferred when the pipeline
        supports it (components are moved to the GPU only while they run,
        saving VRAM); otherwise the whole pipeline is moved to the GPU.
        Without CUDA the pipeline stays on the CPU.
        """
        await super().initialize()

        self.logger.info(f"Loading Flux.1 pipeline: {self.model_name}")

        pipeline = FluxPipeline.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=os.getenv("HF_CACHE_DIR", "/workspace/huggingface_cache"),
        )

        if torch.cuda.is_available():
            if hasattr(pipeline, "enable_model_cpu_offload"):
                # Do NOT call .to("cuda") first: that would materialize the
                # whole model in VRAM (and can OOM) before offloading moves
                # it back to the CPU.
                pipeline.enable_model_cpu_offload()
                self.logger.info("Flux.1 pipeline loaded with model CPU offload (GPU)")
            else:
                pipeline = pipeline.to("cuda")
                self.logger.info("Flux.1 pipeline loaded on GPU")
        else:
            # Model CPU offload targets an accelerator, so it is skipped here.
            self.logger.warning("GPU not available, running on CPU (very slow)")

        # Publish only once fully set up, so requests never see a half-placed pipeline.
        self.pipeline = pipeline
        self.logger.info("Flux.1 pipeline initialized successfully")

    async def cleanup(self):
        """Release the pipeline and return cached CUDA memory to the driver."""
        await super().cleanup()
        if self.pipeline is not None:
            self.logger.info("Flux.1 pipeline cleanup")
            self.pipeline = None
            if torch.cuda.is_available():
                # Dropping the reference alone leaves blocks in PyTorch's allocator cache.
                torch.cuda.empty_cache()

    def parse_size(self, size_str: str) -> tuple[int, int]:
        """Parse a size string like ``'1024x1024'`` into ``(width, height)``.

        Matching of the ``x`` separator is case-insensitive. Falls back to
        ``DEFAULT_SIZE`` when the value is not a ``<int>x<int>`` string or
        either dimension is not positive.
        """
        try:
            width_str, height_str = size_str.lower().split("x")
            width, height = int(width_str), int(height_str)
        except (AttributeError, ValueError):
            # Not a string, wrong number of parts, or a non-integer dimension.
            return self.DEFAULT_SIZE
        if width <= 0 or height <= 0:
            return self.DEFAULT_SIZE
        return (width, height)

    def image_to_base64(self, image: Image.Image) -> str:
        """Encode a PIL image as PNG and return it as a base64 text string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_bytes = buffered.getvalue()
        return base64.b64encode(img_bytes).decode('utf-8')

    def _generate_images(
        self,
        pipeline: FluxPipeline,
        prompt: str,
        n: int,
        width: int,
        height: int,
    ) -> list[dict]:
        """Run ``n`` blocking pipeline calls and return OpenAI-style data entries.

        Meant to run in a worker thread. It holds ``_pipeline_lock`` for the
        whole batch, so concurrent requests are processed one at a time.
        """
        images = []
        with self._pipeline_lock:
            for i in range(n):
                self.logger.info(f"Generating image {i+1}/{n}")

                # Flux.1 Schnell uses 4 inference steps for speed
                image = pipeline(
                    prompt=prompt,
                    width=width,
                    height=height,
                    num_inference_steps=4,  # Schnell is optimized for 4 steps
                    guidance_scale=0.0,  # Schnell doesn't use guidance
                ).images[0]

                # Only base64 output is implemented (see generate_image).
                images.append({"b64_json": self.image_to_base64(image)})
        return images

    def create_app(self):
        """Register the FastAPI routes on ``self.app``."""

        @self.app.get("/")
        async def root():
            """Root endpoint"""
            return {
                "service": "Flux.1 Schnell Image Generation",
                "model": self.model_name,
                "max_images": 4
            }

        @self.app.get("/v1/models")
        async def list_models():
            """List available models (OpenAI-compatible)"""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "flux-schnell",
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "black-forest-labs",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ]
            }

        @self.app.post("/v1/images/generations")
        async def generate_image(request: ImageGenerationRequest) -> ImageGenerationResponse:
            """Generate images from a text prompt (OpenAI-compatible).

            Inference is blocking, so it runs in a worker thread to keep the
            event loop responsive. Only base64 output is produced.
            ``model``, ``quality`` and ``style`` are accepted but ignored.

            Raises:
                HTTPException: 503 if the pipeline is not loaded, 500 if
                    generation fails.
            """
            # Snapshot: cleanup() may reset self.pipeline while we are running.
            pipeline = self.pipeline
            if pipeline is None:
                raise HTTPException(status_code=503, detail="Model not initialized")

            self.logger.info(f"Generating {request.n} image(s): {request.prompt[:100]}...")

            try:
                # Parse image size
                width, height = self.parse_size(request.size)
                self.logger.info(f"Size: {width}x{height}")

                if request.response_format != "b64_json":
                    # URL output would need file storage and a static route;
                    # until then every format is answered with b64_json.
                    self.logger.warning(
                        f"response_format={request.response_format!r} not supported; "
                        "returning b64_json"
                    )

                images = await asyncio.to_thread(
                    self._generate_images,
                    pipeline,
                    request.prompt,
                    request.n,
                    width,
                    height,
                )

                self.logger.info(f"Generated {request.n} image(s) successfully")

                return ImageGenerationResponse(
                    created=int(time.time()),
                    data=images
                )

            except Exception as e:
                self.logger.error(f"Error generating image: {e}", exc_info=True)
                raise HTTPException(status_code=500, detail=str(e)) from e
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Script entry point: FluxService reads PORT / MODEL_NAME / HF_CACHE_DIR
    # from the environment; run() is provided by the GPUService base class.
    FluxService().run()
|