Files
audiocraft-ui/src/storage/database.py
Sebastian Krüger 64a94e7ab7 Fix Database.initialize and remove invalid theme param
- Add initialize() method as alias for connect()
- Remove invalid tab_nav_background_fill theme parameter

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-26 23:24:08 +01:00

555 lines
17 KiB
Python

"""SQLite database for projects, generations, and presets."""
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

import aiosqlite
logger = logging.getLogger(__name__)
@dataclass
class Project:
    """Project entity for organizing generations.

    ``Project.create()`` stamps ``created_at``/``updated_at`` with naive
    datetimes holding the current UTC time.
    """

    id: str
    name: str
    created_at: datetime
    updated_at: datetime
    description: str = ""

    @classmethod
    def create(cls, name: str, description: str = "") -> "Project":
        """Create a new, not-yet-persisted project with a generated ID.

        Args:
            name: Display name of the project.
            description: Optional free-text description.

        Returns:
            A Project with a ``proj_``-prefixed ID and identical
            ``created_at``/``updated_at`` timestamps.
        """
        # Same value as the deprecated datetime.utcnow(): naive UTC, so it
        # remains comparable with timestamps already stored as naive values.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        return cls(
            id=f"proj_{uuid.uuid4().hex[:12]}",
            name=name,
            created_at=now,
            updated_at=now,
            description=description,
        )
@dataclass
class Generation:
    """Audio generation record.

    ``parameters``, ``conditioning`` and ``tags`` are JSON-serializable values
    (the database layer stores them with ``json.dumps``).
    """

    id: str
    project_id: Optional[str]
    model: str
    variant: str
    prompt: str
    parameters: dict[str, Any]
    created_at: datetime
    audio_path: Optional[str] = None
    duration_seconds: Optional[float] = None
    sample_rate: Optional[int] = None
    preset_used: Optional[str] = None
    conditioning: dict[str, Any] = field(default_factory=dict)
    name: Optional[str] = None
    tags: list[str] = field(default_factory=list)
    notes: Optional[str] = None
    seed: Optional[int] = None

    @classmethod
    def create(
        cls,
        model: str,
        variant: str,
        prompt: str,
        parameters: dict[str, Any],
        project_id: Optional[str] = None,
        **kwargs,
    ) -> "Generation":
        """Create a new generation record with a generated ID.

        Args:
            model: Model family name.
            variant: Model variant name.
            prompt: Text prompt used for the generation.
            parameters: Generation parameters (stored as given, not copied).
            project_id: Optional owning project ID.
            **kwargs: Any of the optional dataclass fields (``audio_path``,
                ``seed``, ``tags``, ...). Passing ``id`` or ``created_at``
                raises TypeError because they are set here.

        Returns:
            A Generation with a ``gen_``-prefixed ID, stamped with the current
            UTC time as a naive datetime.
        """
        return cls(
            id=f"gen_{uuid.uuid4().hex[:12]}",
            project_id=project_id,
            model=model,
            variant=variant,
            prompt=prompt,
            parameters=parameters,
            # Non-deprecated equivalent of datetime.utcnow() (naive UTC).
            created_at=datetime.now(timezone.utc).replace(tzinfo=None),
            **kwargs,
        )
@dataclass
class Preset:
    """Generation parameter preset for a given model."""

    id: str
    model: str
    name: str
    parameters: dict[str, Any]
    created_at: datetime
    description: str = ""
    is_builtin: bool = False

    @classmethod
    def create(
        cls,
        model: str,
        name: str,
        parameters: dict[str, Any],
        description: str = "",
    ) -> "Preset":
        """Create a new custom (non-builtin) preset with a generated ID.

        Args:
            model: Model the preset applies to.
            name: Display name of the preset.
            parameters: Generation parameters captured by the preset.
            description: Optional free-text description.

        Returns:
            A Preset with a ``preset_``-prefixed ID and ``is_builtin=False``.
        """
        return cls(
            id=f"preset_{uuid.uuid4().hex[:12]}",
            model=model,
            name=name,
            parameters=parameters,
            # Non-deprecated equivalent of datetime.utcnow() (naive UTC).
            created_at=datetime.now(timezone.utc).replace(tzinfo=None),
            description=description,
            is_builtin=False,
        )
class Database:
    """Async SQLite database for AudioCraft Studio.

    Handles storage of projects, generations, and presets. Timestamps are
    written as ``datetime.isoformat()`` strings of naive UTC datetimes and read
    back with ``datetime.fromisoformat``; JSON columns (parameters,
    conditioning, tags) hold ``json.dumps`` output.

    Call ``connect()`` (or its alias ``initialize()``) before any other method
    and ``close()`` when done.
    """

    SCHEMA = """
    CREATE TABLE IF NOT EXISTS projects (
        id TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        description TEXT DEFAULT '',
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE TABLE IF NOT EXISTS generations (
        id TEXT PRIMARY KEY,
        project_id TEXT REFERENCES projects(id) ON DELETE SET NULL,
        model TEXT NOT NULL,
        variant TEXT NOT NULL,
        prompt TEXT NOT NULL,
        parameters JSON NOT NULL,
        preset_used TEXT,
        conditioning JSON,
        audio_path TEXT,
        duration_seconds REAL,
        sample_rate INTEGER,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        name TEXT,
        tags JSON,
        notes TEXT,
        seed INTEGER
    );

    CREATE TABLE IF NOT EXISTS presets (
        id TEXT PRIMARY KEY,
        model TEXT NOT NULL,
        name TEXT NOT NULL,
        description TEXT DEFAULT '',
        parameters JSON NOT NULL,
        is_builtin BOOLEAN DEFAULT FALSE,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE INDEX IF NOT EXISTS idx_generations_project ON generations(project_id);
    CREATE INDEX IF NOT EXISTS idx_generations_created ON generations(created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_generations_model ON generations(model);
    CREATE INDEX IF NOT EXISTS idx_presets_model ON presets(model);
    """

    def __init__(self, db_path: Path):
        """Initialize database (no I/O happens until ``connect()``).

        Args:
            db_path: Path to SQLite database file. Its parent directory is
                created by ``connect()`` if missing.
        """
        self.db_path = db_path
        self._connection: Optional[aiosqlite.Connection] = None

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Non-deprecated equivalent of ``datetime.utcnow()``; kept naive so new
        values stay comparable with timestamps already stored.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE wildcards so *text* is matched literally.

        Backslash is the escape character; callers must add the matching
        ``ESCAPE`` clause to each LIKE.
        """
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    async def connect(self) -> None:
        """Open the database connection and ensure the schema exists.

        Calling this while already connected is a no-op, so ``connect()`` and
        ``initialize()`` can both be called without leaking a connection.
        """
        if self._connection is not None:
            return

        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        connection = await aiosqlite.connect(self.db_path)
        try:
            connection.row_factory = aiosqlite.Row
            # Initialize schema (all statements are IF NOT EXISTS).
            await connection.executescript(self.SCHEMA)
            await connection.commit()
        except Exception:
            # Don't leak the handle if schema setup fails.
            await connection.close()
            raise

        self._connection = connection
        logger.info("Database connected: %s", self.db_path)

    async def initialize(self) -> None:
        """Alias for connect() for compatibility."""
        await self.connect()

    async def close(self) -> None:
        """Close the database connection if one is open."""
        if self._connection:
            await self._connection.close()
            self._connection = None

    @property
    def conn(self) -> aiosqlite.Connection:
        """Get the active connection.

        Raises:
            RuntimeError: If ``connect()`` has not been called (or after
                ``close()``).
        """
        if not self._connection:
            raise RuntimeError("Database not connected")
        return self._connection

    # ------------------------------------------------------------------
    # Project methods
    # ------------------------------------------------------------------

    async def create_project(self, project: Project) -> Project:
        """Insert *project* and return it unchanged."""
        await self.conn.execute(
            """
            INSERT INTO projects (id, name, description, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                project.id,
                project.name,
                project.description,
                project.created_at.isoformat(),
                project.updated_at.isoformat(),
            ),
        )
        await self.conn.commit()
        return project

    async def get_project(self, project_id: str) -> Optional[Project]:
        """Return the project with *project_id*, or None if it does not exist."""
        async with self.conn.execute(
            "SELECT * FROM projects WHERE id = ?", (project_id,)
        ) as cursor:
            row = await cursor.fetchone()
        return self._row_to_project(row) if row else None

    async def list_projects(
        self, limit: int = 100, offset: int = 0
    ) -> list[Project]:
        """List projects, most recently updated first.

        Args:
            limit: Maximum number of projects to return.
            offset: Number of projects to skip (for pagination).
        """
        async with self.conn.execute(
            """
            SELECT * FROM projects
            ORDER BY updated_at DESC
            LIMIT ? OFFSET ?
            """,
            (limit, offset),
        ) as cursor:
            rows = await cursor.fetchall()
        return [self._row_to_project(row) for row in rows]

    async def update_project(self, project: Project) -> None:
        """Persist *project*'s name and description.

        Also sets ``project.updated_at`` to now (mutating the passed object)
        and writes that timestamp.
        """
        project.updated_at = self._utcnow()
        await self.conn.execute(
            """
            UPDATE projects SET name = ?, description = ?, updated_at = ?
            WHERE id = ?
            """,
            (project.name, project.description, project.updated_at.isoformat(), project.id),
        )
        await self.conn.commit()

    async def delete_project(self, project_id: str) -> bool:
        """Delete a project (generations are kept but unlinked).

        The schema declares ``ON DELETE SET NULL``, but SQLite only enforces
        it when ``PRAGMA foreign_keys`` is on, which is off by default and
        not enabled by ``connect()``. So the unlinking is done explicitly, in
        the same transaction as the delete.

        Returns:
            True if a project row was deleted.
        """
        await self.conn.execute(
            "UPDATE generations SET project_id = NULL WHERE project_id = ?",
            (project_id,),
        )
        result = await self.conn.execute(
            "DELETE FROM projects WHERE id = ?", (project_id,)
        )
        await self.conn.commit()
        return result.rowcount > 0

    def _row_to_project(self, row: aiosqlite.Row) -> Project:
        """Convert database row to Project object."""
        return Project(
            id=row["id"],
            name=row["name"],
            description=row["description"] or "",
            created_at=datetime.fromisoformat(row["created_at"]),
            updated_at=datetime.fromisoformat(row["updated_at"]),
        )

    # ------------------------------------------------------------------
    # Generation methods
    # ------------------------------------------------------------------

    async def create_generation(self, generation: Generation) -> Generation:
        """Insert *generation* and return it unchanged.

        If the generation is linked to a project, that project's
        ``updated_at`` is bumped in the same transaction.
        """
        await self.conn.execute(
            """
            INSERT INTO generations (
                id, project_id, model, variant, prompt, parameters,
                preset_used, conditioning, audio_path, duration_seconds,
                sample_rate, created_at, name, tags, notes, seed
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                generation.id,
                generation.project_id,
                generation.model,
                generation.variant,
                generation.prompt,
                json.dumps(generation.parameters),
                generation.preset_used,
                json.dumps(generation.conditioning),
                generation.audio_path,
                generation.duration_seconds,
                generation.sample_rate,
                generation.created_at.isoformat(),
                generation.name,
                json.dumps(generation.tags),
                generation.notes,
                generation.seed,
            ),
        )

        # Update project's updated_at if linked; committed together with the
        # insert so the two writes cannot diverge.
        if generation.project_id:
            await self.conn.execute(
                "UPDATE projects SET updated_at = ? WHERE id = ?",
                (self._utcnow().isoformat(), generation.project_id),
            )

        await self.conn.commit()
        return generation

    async def get_generation(self, generation_id: str) -> Optional[Generation]:
        """Return the generation with *generation_id*, or None if absent."""
        async with self.conn.execute(
            "SELECT * FROM generations WHERE id = ?", (generation_id,)
        ) as cursor:
            row = await cursor.fetchone()
        return self._row_to_generation(row) if row else None

    def _generation_filters(
        self,
        project_id: Optional[str],
        model: Optional[str],
        search: Optional[str] = None,
    ) -> tuple[str, list[Any]]:
        """Build the WHERE clause and bound parameters for generation queries.

        Empty/None filters are ignored. *search* is matched as a literal
        substring (LIKE wildcards in it are escaped) against prompt, name and
        the JSON-encoded tags.

        Returns:
            ``(where_clause, params)``; the clause is ``"1=1"`` when no
            filter applies.
        """
        conditions: list[str] = []
        params: list[Any] = []

        if project_id:
            conditions.append("project_id = ?")
            params.append(project_id)

        if model:
            conditions.append("model = ?")
            params.append(model)

        if search:
            conditions.append(
                "(prompt LIKE ? ESCAPE '\\' OR name LIKE ? ESCAPE '\\' "
                "OR tags LIKE ? ESCAPE '\\')"
            )
            search_pattern = f"%{self._escape_like(search)}%"
            params.extend([search_pattern] * 3)

        where_clause = " AND ".join(conditions) if conditions else "1=1"
        return where_clause, params

    async def list_generations(
        self,
        project_id: Optional[str] = None,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
        search: Optional[str] = None,
    ) -> list[Generation]:
        """List generations, newest first, with optional filters.

        Args:
            project_id: Only generations linked to this project.
            model: Only generations from this model.
            limit: Maximum number of records to return.
            offset: Number of records to skip (for pagination).
            search: Literal substring to find in prompt, name or tags.
        """
        where_clause, params = self._generation_filters(project_id, model, search)

        async with self.conn.execute(
            f"""
            SELECT * FROM generations
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?
            """,
            (*params, limit, offset),
        ) as cursor:
            rows = await cursor.fetchall()
        return [self._row_to_generation(row) for row in rows]

    async def update_generation(self, generation: Generation) -> None:
        """Persist the editable fields of *generation*.

        Writes project_id, name, tags, notes, audio_path, duration_seconds
        and sample_rate. Other fields (model, prompt, parameters, seed, ...)
        are not written by this method.
        """
        await self.conn.execute(
            """
            UPDATE generations SET
                project_id = ?, name = ?, tags = ?, notes = ?,
                audio_path = ?, duration_seconds = ?, sample_rate = ?
            WHERE id = ?
            """,
            (
                generation.project_id,
                generation.name,
                json.dumps(generation.tags),
                generation.notes,
                generation.audio_path,
                generation.duration_seconds,
                generation.sample_rate,
                generation.id,
            ),
        )
        await self.conn.commit()

    async def delete_generation(self, generation_id: str) -> bool:
        """Delete a generation record.

        Returns:
            True if a row was deleted.
        """
        result = await self.conn.execute(
            "DELETE FROM generations WHERE id = ?", (generation_id,)
        )
        await self.conn.commit()
        return result.rowcount > 0

    async def count_generations(
        self,
        project_id: Optional[str] = None,
        model: Optional[str] = None,
        search: Optional[str] = None,
    ) -> int:
        """Count generations matching the same filters as list_generations().

        Args:
            project_id: Only generations linked to this project.
            model: Only generations from this model.
            search: Literal substring to find in prompt, name or tags.
        """
        where_clause, params = self._generation_filters(project_id, model, search)

        async with self.conn.execute(
            f"SELECT COUNT(*) FROM generations WHERE {where_clause}",
            params,
        ) as cursor:
            row = await cursor.fetchone()
        return row[0] if row else 0

    def _row_to_generation(self, row: aiosqlite.Row) -> Generation:
        """Convert database row to Generation object."""
        return Generation(
            id=row["id"],
            project_id=row["project_id"],
            model=row["model"],
            variant=row["variant"],
            prompt=row["prompt"],
            parameters=json.loads(row["parameters"]),
            preset_used=row["preset_used"],
            conditioning=json.loads(row["conditioning"]) if row["conditioning"] else {},
            audio_path=row["audio_path"],
            duration_seconds=row["duration_seconds"],
            sample_rate=row["sample_rate"],
            created_at=datetime.fromisoformat(row["created_at"]),
            name=row["name"],
            tags=json.loads(row["tags"]) if row["tags"] else [],
            notes=row["notes"],
            seed=row["seed"],
        )

    # ------------------------------------------------------------------
    # Preset methods
    # ------------------------------------------------------------------

    async def create_preset(self, preset: Preset) -> Preset:
        """Insert *preset* and return it unchanged."""
        await self.conn.execute(
            """
            INSERT INTO presets (id, model, name, description, parameters, is_builtin, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                preset.id,
                preset.model,
                preset.name,
                preset.description,
                json.dumps(preset.parameters),
                preset.is_builtin,
                preset.created_at.isoformat(),
            ),
        )
        await self.conn.commit()
        return preset

    async def get_preset(self, preset_id: str) -> Optional[Preset]:
        """Return the preset with *preset_id*, or None if it does not exist."""
        async with self.conn.execute(
            "SELECT * FROM presets WHERE id = ?", (preset_id,)
        ) as cursor:
            row = await cursor.fetchone()
        return self._row_to_preset(row) if row else None

    async def list_presets(
        self, model: Optional[str] = None, include_builtin: bool = True
    ) -> list[Preset]:
        """List presets, builtin ones first, then by name.

        Args:
            model: Only presets for this model.
            include_builtin: When False, only custom presets are returned.
        """
        conditions: list[str] = []
        params: list[Any] = []

        if model:
            conditions.append("model = ?")
            params.append(model)

        if not include_builtin:
            conditions.append("is_builtin = FALSE")

        where_clause = " AND ".join(conditions) if conditions else "1=1"

        async with self.conn.execute(
            f"""
            SELECT * FROM presets
            WHERE {where_clause}
            ORDER BY is_builtin DESC, name ASC
            """,
            params,
        ) as cursor:
            rows = await cursor.fetchall()
        return [self._row_to_preset(row) for row in rows]

    async def delete_preset(self, preset_id: str) -> bool:
        """Delete a preset (only custom presets can be deleted).

        Returns:
            True if a custom preset row was deleted; False for builtin or
            unknown IDs.
        """
        result = await self.conn.execute(
            "DELETE FROM presets WHERE id = ? AND is_builtin = FALSE",
            (preset_id,),
        )
        await self.conn.commit()
        return result.rowcount > 0

    def _row_to_preset(self, row: aiosqlite.Row) -> Preset:
        """Convert database row to Preset object."""
        return Preset(
            id=row["id"],
            model=row["model"],
            name=row["name"],
            description=row["description"] or "",
            parameters=json.loads(row["parameters"]),
            is_builtin=bool(row["is_builtin"]),
            created_at=datetime.fromisoformat(row["created_at"]),
        )

    # ------------------------------------------------------------------
    # Utility methods
    # ------------------------------------------------------------------

    async def get_stats(self) -> dict[str, Any]:
        """Return row counts per table plus generation counts per model.

        Keys: ``projects``, ``generations``, ``presets`` (ints) and
        ``generations_by_model`` (model name -> count).
        """
        stats: dict[str, Any] = {}

        # Table names come from this fixed tuple, never from callers, so the
        # f-string interpolation is safe.
        for table in ("projects", "generations", "presets"):
            async with self.conn.execute(f"SELECT COUNT(*) FROM {table}") as cursor:
                row = await cursor.fetchone()
            stats[table] = row[0] if row else 0

        async with self.conn.execute(
            "SELECT model, COUNT(*) as count FROM generations GROUP BY model"
        ) as cursor:
            rows = await cursor.fetchall()
        stats["generations_by_model"] = {row["model"]: row["count"] for row in rows}

        return stats