Add root_path setting to configure Gradio for HTTPS reverse proxy. Set AUDIOCRAFT_ROOT_PATH env var to external URL (e.g., https://example.com). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
180 lines
5.4 KiB
Python
180 lines
5.4 KiB
Python
#!/usr/bin/env python3
|
|
"""Main entry point for AudioCraft Studio."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Directory containing this script. ``resolve()`` (unlike ``absolute()``)
# follows symlinks, so launching through a symlinked entry point still points
# at the real project tree rather than the symlink's directory.
PROJECT_ROOT = Path(__file__).resolve().parent

# Change to the project root so relative paths (such as the log file
# configured below) are interpreted against the project, not the caller's
# working directory.
os.chdir(PROJECT_ROOT)

# Make the project's top-level packages (``config``, ``src``) importable.
sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from config.settings import get_settings
|
|
from src.core.gpu_manager import GPUMemoryManager
|
|
from src.core.model_registry import ModelRegistry
|
|
from src.services.generation_service import GenerationService
|
|
from src.services.batch_processor import BatchProcessor
|
|
from src.services.project_service import ProjectService
|
|
from src.storage.database import Database
|
|
from src.ui.app import create_app
|
|
from src.models import register_all_adapters
|
|
|
|
|
|
# Configure root logging: INFO and above go to the console and to
# "audiocraft.log" in the current working directory. Note that
# logging.basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        # Explicit UTF-8 so non-ASCII log text is written correctly regardless
        # of the platform's locale encoding (e.g. cp1252 on Windows).
        logging.FileHandler("audiocraft.log", encoding="utf-8"),
    ],
)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _create_model_registry(settings, gpu_manager):
    """Build the model registry from the models config and register all adapters."""
    logger.info("Initializing model registry...")
    config_path = Path(settings.models_config)
    # Relative config paths are interpreted against the project root.
    if not config_path.is_absolute():
        config_path = PROJECT_ROOT / config_path
    logger.info(f"Models config path: {config_path} (exists: {config_path.exists()})")

    model_registry = ModelRegistry(
        config_path=config_path,
        gpu_manager=gpu_manager,
        max_cached_models=settings.max_loaded_models,
        idle_timeout_minutes=settings.idle_unload_minutes,
    )

    logger.info("Registering model adapters...")
    register_all_adapters(model_registry)

    # Log a sample (first 5) of the available models for debugging.
    available_models = model_registry.list_models()
    logger.info(f"Available models: {len(available_models)}")
    for model in available_models[:5]:
        logger.info(f"  - {model['model_id']}/{model['variant']} (has_adapter: {model['has_adapter']})")

    return model_registry


async def initialize_services():
    """Initialize all application services.

    Builds, in order: the database, the GPU memory manager (device 0), the
    model registry with all adapters registered, the generation service, the
    batch processor, and the project service.

    Returns:
        A dict with keys "db", "gpu_manager", "model_registry",
        "generation_service", "batch_processor" and "project_service".

    Raises:
        Any exception raised while constructing the services. If a step after
        the database has been initialized fails, the database is closed
        before the exception propagates.
    """
    settings = get_settings()

    # Initialize database
    logger.info("Initializing database...")
    db = Database(settings.database_path)
    await db.initialize()

    try:
        # Initialize GPU manager
        logger.info("Initializing GPU manager...")
        gpu_manager = GPUMemoryManager(
            device_id=0,
            comfyui_reserve_gb=settings.comfyui_reserve_gb,
        )

        model_registry = _create_model_registry(settings, gpu_manager)

        # Initialize services
        logger.info("Initializing services...")
        generation_service = GenerationService(
            registry=model_registry,
            gpu_manager=gpu_manager,
            database=db,
            output_dir=settings.output_dir,
        )

        batch_processor = BatchProcessor(
            generation_service=generation_service,
            max_queue_size=settings.max_queue_size,
        )

        project_service = ProjectService(
            database=db,
            output_dir=settings.output_dir,
        )
    except Exception:
        # main() falls back to demo mode with no services on failure, so it
        # never sees this db object; close it here rather than leak it.
        try:
            await db.close()
        except Exception:
            logger.exception("Failed to close database after initialization error")
        raise

    return {
        "db": db,
        "gpu_manager": gpu_manager,
        "model_registry": model_registry,
        "generation_service": generation_service,
        "batch_processor": batch_processor,
        "project_service": project_service,
    }
|
|
|
|
|
|
def _build_launch_kwargs(settings):
    """Assemble the keyword arguments passed to ``app.launch()`` from settings."""
    # Gradio 6.x: show_api replaced with footer_links
    footer = ["api", "gradio"] if settings.api_enabled else ["gradio"]
    launch_kwargs = {
        "server_name": settings.host,
        "server_port": settings.gradio_port,
        "share": False,
        "footer_links": footer,
    }
    # Add root_path for HTTPS reverse proxy support
    if settings.root_path:
        launch_kwargs["root_path"] = settings.root_path
        logger.info(f"Using root_path: {settings.root_path}")
    return launch_kwargs


def _shutdown_services(services, batch_processor):
    """Best-effort teardown: stop the batch processor, then close the database.

    Each step is guarded separately and its failure is logged, so an error
    while stopping the batch processor does not prevent the database from
    being closed.
    """
    if batch_processor is not None:
        try:
            asyncio.run(batch_processor.stop())
        except Exception:
            logger.exception("Error while stopping batch processor")

    db = services.get("db")
    if db is not None:
        try:
            asyncio.run(db.close())
        except Exception:
            logger.exception("Error while closing database")


def main():
    """Main entry point: initialize services, build the Gradio UI and serve it.

    If backend initialization fails, the UI is still launched in demo mode
    with no services. Cleanup (batch processor stop, database close) runs when
    the server returns, on Ctrl+C, or if any later startup step raises.
    """
    settings = get_settings()

    logger.info("=" * 60)
    logger.info("AudioCraft Studio")
    logger.info("=" * 60)
    logger.info(f"Host: {settings.host}")
    logger.info(f"Gradio Port: {settings.gradio_port}")
    logger.info(f"API Port: {settings.api_port}")
    logger.info(f"Output Dir: {settings.output_dir}")
    logger.info("=" * 60)

    # Initialize services
    logger.info("Initializing services...")
    try:
        services = asyncio.run(initialize_services())
    except Exception as e:
        # logger.exception keeps the traceback, which is needed to diagnose
        # why the app dropped into demo mode.
        logger.exception(f"Failed to initialize services: {e}")
        logger.warning("Starting in demo mode without backend services")
        services = {}

    batch_processor = services.get("batch_processor")
    # Only stop the processor at shutdown if we actually tried to start it.
    batch_processor_to_stop = None

    # Everything after initialization sits inside try/finally so the database
    # is closed even if app creation or processor startup fails.
    try:
        # Create and launch app
        logger.info("Creating Gradio application...")
        app = create_app(
            generation_service=services.get("generation_service"),
            batch_processor=services.get("batch_processor"),
            project_service=services.get("project_service"),
            gpu_manager=services.get("gpu_manager"),
            model_registry=services.get("model_registry"),
        )

        # Explicit None check: a processor object could be falsy (e.g. if it
        # defined __len__ and its queue were empty) and then never start.
        if batch_processor is not None:
            logger.info("Starting batch processor...")
            batch_processor_to_stop = batch_processor
            # NOTE(review): asyncio.run() closes its event loop once start()
            # returns (cancelling tasks still pending on it). If start()
            # spawns a background worker task, that worker will not keep
            # running — confirm how BatchProcessor schedules its work.
            asyncio.run(batch_processor.start())

        # Launch the app
        logger.info("Launching application...")
        app.launch(**_build_launch_kwargs(settings))
    except KeyboardInterrupt:
        logger.info("Shutting down...")
    finally:
        _shutdown_services(services, batch_processor_to_stop)

    logger.info("Shutdown complete")
|
|
|
|
|
|
# Script entry point: importing this module defines names without launching
# the application; running it as a script starts the app. main() returns
# None, so the process exits with status 0 on a normal return.
if __name__ == "__main__":
    sys.exit(main())
|