Initial implementation of AudioCraft Studio
Complete web interface for Meta's AudioCraft AI audio generation: - Gradio UI with tabs for all 5 model families (MusicGen, AudioGen, MAGNeT, MusicGen Style, JASCO) - REST API with FastAPI, OpenAPI docs, and API key auth - VRAM management with ComfyUI coexistence support - SQLite database for project/generation history - Batch processing queue for async generation - Docker deployment optimized for RunPod with RTX 4090 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
550
src/storage/database.py
Normal file
550
src/storage/database.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""SQLite database for projects, generations, and presets."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Project:
|
||||
"""Project entity for organizing generations."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
description: str = ""
|
||||
|
||||
@classmethod
|
||||
def create(cls, name: str, description: str = "") -> "Project":
|
||||
"""Create a new project with generated ID."""
|
||||
now = datetime.utcnow()
|
||||
return cls(
|
||||
id=f"proj_{uuid.uuid4().hex[:12]}",
|
||||
name=name,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Generation:
|
||||
"""Audio generation record."""
|
||||
|
||||
id: str
|
||||
project_id: Optional[str]
|
||||
model: str
|
||||
variant: str
|
||||
prompt: str
|
||||
parameters: dict[str, Any]
|
||||
created_at: datetime
|
||||
audio_path: Optional[str] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
sample_rate: Optional[int] = None
|
||||
preset_used: Optional[str] = None
|
||||
conditioning: dict[str, Any] = field(default_factory=dict)
|
||||
name: Optional[str] = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
notes: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model: str,
|
||||
variant: str,
|
||||
prompt: str,
|
||||
parameters: dict[str, Any],
|
||||
project_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> "Generation":
|
||||
"""Create a new generation record."""
|
||||
return cls(
|
||||
id=f"gen_{uuid.uuid4().hex[:12]}",
|
||||
project_id=project_id,
|
||||
model=model,
|
||||
variant=variant,
|
||||
prompt=prompt,
|
||||
parameters=parameters,
|
||||
created_at=datetime.utcnow(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Preset:
|
||||
"""Generation parameter preset."""
|
||||
|
||||
id: str
|
||||
model: str
|
||||
name: str
|
||||
parameters: dict[str, Any]
|
||||
created_at: datetime
|
||||
description: str = ""
|
||||
is_builtin: bool = False
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model: str,
|
||||
name: str,
|
||||
parameters: dict[str, Any],
|
||||
description: str = "",
|
||||
) -> "Preset":
|
||||
"""Create a new custom preset."""
|
||||
return cls(
|
||||
id=f"preset_{uuid.uuid4().hex[:12]}",
|
||||
model=model,
|
||||
name=name,
|
||||
parameters=parameters,
|
||||
created_at=datetime.utcnow(),
|
||||
description=description,
|
||||
is_builtin=False,
|
||||
)
|
||||
|
||||
|
||||
class Database:
|
||||
"""Async SQLite database for AudioCraft Studio.
|
||||
|
||||
Handles storage of projects, generations, and presets.
|
||||
"""
|
||||
|
||||
SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS projects (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS generations (
|
||||
id TEXT PRIMARY KEY,
|
||||
project_id TEXT REFERENCES projects(id) ON DELETE SET NULL,
|
||||
model TEXT NOT NULL,
|
||||
variant TEXT NOT NULL,
|
||||
prompt TEXT NOT NULL,
|
||||
parameters JSON NOT NULL,
|
||||
preset_used TEXT,
|
||||
conditioning JSON,
|
||||
audio_path TEXT,
|
||||
duration_seconds REAL,
|
||||
sample_rate INTEGER,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
name TEXT,
|
||||
tags JSON,
|
||||
notes TEXT,
|
||||
seed INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS presets (
|
||||
id TEXT PRIMARY KEY,
|
||||
model TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
parameters JSON NOT NULL,
|
||||
is_builtin BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_generations_project ON generations(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_generations_created ON generations(created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_generations_model ON generations(model);
|
||||
CREATE INDEX IF NOT EXISTS idx_presets_model ON presets(model);
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
"""Initialize database.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._connection: Optional[aiosqlite.Connection] = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Open database connection and initialize schema."""
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._connection = await aiosqlite.connect(self.db_path)
|
||||
self._connection.row_factory = aiosqlite.Row
|
||||
|
||||
# Initialize schema
|
||||
await self._connection.executescript(self.SCHEMA)
|
||||
await self._connection.commit()
|
||||
|
||||
logger.info(f"Database connected: {self.db_path}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close database connection."""
|
||||
if self._connection:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
@property
|
||||
def conn(self) -> aiosqlite.Connection:
|
||||
"""Get active connection."""
|
||||
if not self._connection:
|
||||
raise RuntimeError("Database not connected")
|
||||
return self._connection
|
||||
|
||||
# Project Methods
|
||||
|
||||
async def create_project(self, project: Project) -> Project:
|
||||
"""Create a new project."""
|
||||
await self.conn.execute(
|
||||
"""
|
||||
INSERT INTO projects (id, name, description, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
project.id,
|
||||
project.name,
|
||||
project.description,
|
||||
project.created_at.isoformat(),
|
||||
project.updated_at.isoformat(),
|
||||
),
|
||||
)
|
||||
await self.conn.commit()
|
||||
return project
|
||||
|
||||
async def get_project(self, project_id: str) -> Optional[Project]:
|
||||
"""Get a project by ID."""
|
||||
async with self.conn.execute(
|
||||
"SELECT * FROM projects WHERE id = ?", (project_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return Project(
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"] or "",
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
updated_at=datetime.fromisoformat(row["updated_at"]),
|
||||
)
|
||||
return None
|
||||
|
||||
async def list_projects(
|
||||
self, limit: int = 100, offset: int = 0
|
||||
) -> list[Project]:
|
||||
"""List all projects, ordered by last update."""
|
||||
async with self.conn.execute(
|
||||
"""
|
||||
SELECT * FROM projects
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(limit, offset),
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
Project(
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"] or "",
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
updated_at=datetime.fromisoformat(row["updated_at"]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def update_project(self, project: Project) -> None:
|
||||
"""Update a project."""
|
||||
project.updated_at = datetime.utcnow()
|
||||
await self.conn.execute(
|
||||
"""
|
||||
UPDATE projects SET name = ?, description = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(project.name, project.description, project.updated_at.isoformat(), project.id),
|
||||
)
|
||||
await self.conn.commit()
|
||||
|
||||
async def delete_project(self, project_id: str) -> bool:
|
||||
"""Delete a project (generations are kept but unlinked)."""
|
||||
result = await self.conn.execute(
|
||||
"DELETE FROM projects WHERE id = ?", (project_id,)
|
||||
)
|
||||
await self.conn.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
# Generation Methods
|
||||
|
||||
async def create_generation(self, generation: Generation) -> Generation:
|
||||
"""Create a new generation record."""
|
||||
await self.conn.execute(
|
||||
"""
|
||||
INSERT INTO generations (
|
||||
id, project_id, model, variant, prompt, parameters,
|
||||
preset_used, conditioning, audio_path, duration_seconds,
|
||||
sample_rate, created_at, name, tags, notes, seed
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
generation.id,
|
||||
generation.project_id,
|
||||
generation.model,
|
||||
generation.variant,
|
||||
generation.prompt,
|
||||
json.dumps(generation.parameters),
|
||||
generation.preset_used,
|
||||
json.dumps(generation.conditioning),
|
||||
generation.audio_path,
|
||||
generation.duration_seconds,
|
||||
generation.sample_rate,
|
||||
generation.created_at.isoformat(),
|
||||
generation.name,
|
||||
json.dumps(generation.tags),
|
||||
generation.notes,
|
||||
generation.seed,
|
||||
),
|
||||
)
|
||||
await self.conn.commit()
|
||||
|
||||
# Update project's updated_at if linked
|
||||
if generation.project_id:
|
||||
await self.conn.execute(
|
||||
"UPDATE projects SET updated_at = ? WHERE id = ?",
|
||||
(datetime.utcnow().isoformat(), generation.project_id),
|
||||
)
|
||||
await self.conn.commit()
|
||||
|
||||
return generation
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[Generation]:
|
||||
"""Get a generation by ID."""
|
||||
async with self.conn.execute(
|
||||
"SELECT * FROM generations WHERE id = ?", (generation_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_generation(row)
|
||||
return None
|
||||
|
||||
async def list_generations(
|
||||
self,
|
||||
project_id: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
search: Optional[str] = None,
|
||||
) -> list[Generation]:
|
||||
"""List generations with optional filters."""
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if project_id:
|
||||
conditions.append("project_id = ?")
|
||||
params.append(project_id)
|
||||
|
||||
if model:
|
||||
conditions.append("model = ?")
|
||||
params.append(model)
|
||||
|
||||
if search:
|
||||
conditions.append("(prompt LIKE ? OR name LIKE ? OR tags LIKE ?)")
|
||||
search_pattern = f"%{search}%"
|
||||
params.extend([search_pattern, search_pattern, search_pattern])
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
async with self.conn.execute(
|
||||
f"""
|
||||
SELECT * FROM generations
|
||||
WHERE {where_clause}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(*params, limit, offset),
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [self._row_to_generation(row) for row in rows]
|
||||
|
||||
async def update_generation(self, generation: Generation) -> None:
|
||||
"""Update a generation record."""
|
||||
await self.conn.execute(
|
||||
"""
|
||||
UPDATE generations SET
|
||||
project_id = ?, name = ?, tags = ?, notes = ?,
|
||||
audio_path = ?, duration_seconds = ?, sample_rate = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
generation.project_id,
|
||||
generation.name,
|
||||
json.dumps(generation.tags),
|
||||
generation.notes,
|
||||
generation.audio_path,
|
||||
generation.duration_seconds,
|
||||
generation.sample_rate,
|
||||
generation.id,
|
||||
),
|
||||
)
|
||||
await self.conn.commit()
|
||||
|
||||
async def delete_generation(self, generation_id: str) -> bool:
|
||||
"""Delete a generation record."""
|
||||
result = await self.conn.execute(
|
||||
"DELETE FROM generations WHERE id = ?", (generation_id,)
|
||||
)
|
||||
await self.conn.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def count_generations(
|
||||
self, project_id: Optional[str] = None, model: Optional[str] = None
|
||||
) -> int:
|
||||
"""Count generations with optional filters."""
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if project_id:
|
||||
conditions.append("project_id = ?")
|
||||
params.append(project_id)
|
||||
|
||||
if model:
|
||||
conditions.append("model = ?")
|
||||
params.append(model)
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
async with self.conn.execute(
|
||||
f"SELECT COUNT(*) FROM generations WHERE {where_clause}",
|
||||
params,
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def _row_to_generation(self, row: aiosqlite.Row) -> Generation:
|
||||
"""Convert database row to Generation object."""
|
||||
return Generation(
|
||||
id=row["id"],
|
||||
project_id=row["project_id"],
|
||||
model=row["model"],
|
||||
variant=row["variant"],
|
||||
prompt=row["prompt"],
|
||||
parameters=json.loads(row["parameters"]),
|
||||
preset_used=row["preset_used"],
|
||||
conditioning=json.loads(row["conditioning"]) if row["conditioning"] else {},
|
||||
audio_path=row["audio_path"],
|
||||
duration_seconds=row["duration_seconds"],
|
||||
sample_rate=row["sample_rate"],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
name=row["name"],
|
||||
tags=json.loads(row["tags"]) if row["tags"] else [],
|
||||
notes=row["notes"],
|
||||
seed=row["seed"],
|
||||
)
|
||||
|
||||
# Preset Methods
|
||||
|
||||
async def create_preset(self, preset: Preset) -> Preset:
|
||||
"""Create a new preset."""
|
||||
await self.conn.execute(
|
||||
"""
|
||||
INSERT INTO presets (id, model, name, description, parameters, is_builtin, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
preset.id,
|
||||
preset.model,
|
||||
preset.name,
|
||||
preset.description,
|
||||
json.dumps(preset.parameters),
|
||||
preset.is_builtin,
|
||||
preset.created_at.isoformat(),
|
||||
),
|
||||
)
|
||||
await self.conn.commit()
|
||||
return preset
|
||||
|
||||
async def get_preset(self, preset_id: str) -> Optional[Preset]:
|
||||
"""Get a preset by ID."""
|
||||
async with self.conn.execute(
|
||||
"SELECT * FROM presets WHERE id = ?", (preset_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_preset(row)
|
||||
return None
|
||||
|
||||
async def list_presets(
|
||||
self, model: Optional[str] = None, include_builtin: bool = True
|
||||
) -> list[Preset]:
|
||||
"""List presets with optional model filter."""
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if model:
|
||||
conditions.append("model = ?")
|
||||
params.append(model)
|
||||
|
||||
if not include_builtin:
|
||||
conditions.append("is_builtin = FALSE")
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
async with self.conn.execute(
|
||||
f"""
|
||||
SELECT * FROM presets
|
||||
WHERE {where_clause}
|
||||
ORDER BY is_builtin DESC, name ASC
|
||||
""",
|
||||
params,
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [self._row_to_preset(row) for row in rows]
|
||||
|
||||
async def delete_preset(self, preset_id: str) -> bool:
|
||||
"""Delete a preset (only custom presets can be deleted)."""
|
||||
result = await self.conn.execute(
|
||||
"DELETE FROM presets WHERE id = ? AND is_builtin = FALSE",
|
||||
(preset_id,),
|
||||
)
|
||||
await self.conn.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
def _row_to_preset(self, row: aiosqlite.Row) -> Preset:
|
||||
"""Convert database row to Preset object."""
|
||||
return Preset(
|
||||
id=row["id"],
|
||||
model=row["model"],
|
||||
name=row["name"],
|
||||
description=row["description"] or "",
|
||||
parameters=json.loads(row["parameters"]),
|
||||
is_builtin=bool(row["is_builtin"]),
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
)
|
||||
|
||||
# Utility Methods
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""Get database statistics."""
|
||||
stats = {}
|
||||
|
||||
async with self.conn.execute("SELECT COUNT(*) FROM projects") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
stats["projects"] = row[0] if row else 0
|
||||
|
||||
async with self.conn.execute("SELECT COUNT(*) FROM generations") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
stats["generations"] = row[0] if row else 0
|
||||
|
||||
async with self.conn.execute("SELECT COUNT(*) FROM presets") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
stats["presets"] = row[0] if row else 0
|
||||
|
||||
async with self.conn.execute(
|
||||
"SELECT model, COUNT(*) as count FROM generations GROUP BY model"
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
stats["generations_by_model"] = {row["model"]: row["count"] for row in rows}
|
||||
|
||||
return stats
|
||||
Reference in New Issue
Block a user