- Add initialize() method as alias for connect() - Remove invalid tab_nav_background_fill theme parameter 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
555 lines
17 KiB
Python
555 lines
17 KiB
Python
"""SQLite database for projects, generations, and presets."""
|
|
|
|
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

import aiosqlite
|
|
|
|
# Module-level logger; Database.connect() reports the opened database path here.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class Project:
    """Project entity for organizing generations.

    Persisted in the ``projects`` table. Timestamps produced by ``create()``
    are naive datetimes holding UTC wall-clock time.
    """

    id: str
    name: str
    created_at: datetime
    updated_at: datetime
    description: str = ""

    @classmethod
    def create(cls, name: str, description: str = "") -> "Project":
        """Create a new project with a generated ID.

        Args:
            name: Human-readable project name.
            description: Optional free-text description.

        Returns:
            A new ``Project`` whose ``id`` is ``"proj_"`` plus 12 hex
            characters, with ``created_at`` and ``updated_at`` set to the same
            naive UTC timestamp.
        """
        # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
        # "now" and drop tzinfo so the value (and its isoformat() as stored in
        # the database) stays naive, exactly as before.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        return cls(
            id=f"proj_{uuid.uuid4().hex[:12]}",
            name=name,
            created_at=now,
            updated_at=now,
            description=description,
        )
|
|
|
|
|
|
@dataclass
class Generation:
    """Audio generation record, persisted in the ``generations`` table.

    ``parameters``, ``conditioning`` and ``tags`` are stored as JSON text by
    the database layer. ``created_at`` from ``create()`` is a naive datetime
    holding UTC wall-clock time.
    """

    id: str
    project_id: Optional[str]
    model: str
    variant: str
    prompt: str
    parameters: dict[str, Any]
    created_at: datetime
    audio_path: Optional[str] = None
    duration_seconds: Optional[float] = None
    sample_rate: Optional[int] = None
    preset_used: Optional[str] = None
    # default_factory gives every record its own dict/list instead of a shared one.
    conditioning: dict[str, Any] = field(default_factory=dict)
    name: Optional[str] = None
    tags: list[str] = field(default_factory=list)
    notes: Optional[str] = None
    seed: Optional[int] = None

    @classmethod
    def create(
        cls,
        model: str,
        variant: str,
        prompt: str,
        parameters: dict[str, Any],
        project_id: Optional[str] = None,
        **kwargs,
    ) -> "Generation":
        """Create a new generation record with a generated ID.

        Args:
            model: Model name.
            variant: Model variant.
            prompt: Text prompt used for the generation.
            parameters: Generation parameters.
            project_id: Optional owning project ID.
            **kwargs: Any of the optional dataclass fields (``audio_path``,
                ``seed``, ``tags``, ...). Passing ``id`` or ``created_at``
                here raises ``TypeError`` because both are set by this method.

        Returns:
            A new ``Generation`` whose ``id`` is ``"gen_"`` plus 12 hex
            characters and whose ``created_at`` is the current naive UTC time.
        """
        # datetime.utcnow() is deprecated since Python 3.12; this produces the
        # same naive UTC value without the deprecation warning.
        created_at = datetime.now(timezone.utc).replace(tzinfo=None)
        return cls(
            id=f"gen_{uuid.uuid4().hex[:12]}",
            project_id=project_id,
            model=model,
            variant=variant,
            prompt=prompt,
            parameters=parameters,
            created_at=created_at,
            **kwargs,
        )
|
|
|
|
|
|
@dataclass
class Preset:
    """Generation parameter preset, persisted in the ``presets`` table.

    ``is_builtin`` distinguishes shipped presets from user-created ones;
    ``create()`` always produces a custom (non-builtin) preset.
    """

    id: str
    model: str
    name: str
    parameters: dict[str, Any]
    created_at: datetime
    description: str = ""
    is_builtin: bool = False

    @classmethod
    def create(
        cls,
        model: str,
        name: str,
        parameters: dict[str, Any],
        description: str = "",
    ) -> "Preset":
        """Create a new custom preset with a generated ID.

        Args:
            model: Model the preset applies to.
            name: Display name.
            parameters: Generation parameters captured by the preset.
            description: Optional free-text description.

        Returns:
            A new ``Preset`` whose ``id`` is ``"preset_"`` plus 12 hex
            characters, ``is_builtin=False``, and ``created_at`` set to the
            current naive UTC time.
        """
        # datetime.utcnow() is deprecated since Python 3.12; this produces the
        # same naive UTC value without the deprecation warning.
        created_at = datetime.now(timezone.utc).replace(tzinfo=None)
        return cls(
            id=f"preset_{uuid.uuid4().hex[:12]}",
            model=model,
            name=name,
            parameters=parameters,
            created_at=created_at,
            description=description,
            is_builtin=False,
        )
|
|
|
|
|
|
class Database:
    """Async SQLite database for AudioCraft Studio.

    Handles storage of projects, generations, and presets. Await
    ``connect()`` (or its alias ``initialize()``) before calling any other
    coroutine and ``close()`` when finished; the ``conn`` property raises
    ``RuntimeError`` while no connection is open.
    """

    SCHEMA = """
    CREATE TABLE IF NOT EXISTS projects (
        id TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        description TEXT DEFAULT '',
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE TABLE IF NOT EXISTS generations (
        id TEXT PRIMARY KEY,
        project_id TEXT REFERENCES projects(id) ON DELETE SET NULL,
        model TEXT NOT NULL,
        variant TEXT NOT NULL,
        prompt TEXT NOT NULL,
        parameters JSON NOT NULL,
        preset_used TEXT,
        conditioning JSON,
        audio_path TEXT,
        duration_seconds REAL,
        sample_rate INTEGER,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        name TEXT,
        tags JSON,
        notes TEXT,
        seed INTEGER
    );

    CREATE TABLE IF NOT EXISTS presets (
        id TEXT PRIMARY KEY,
        model TEXT NOT NULL,
        name TEXT NOT NULL,
        description TEXT DEFAULT '',
        parameters JSON NOT NULL,
        is_builtin BOOLEAN DEFAULT FALSE,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE INDEX IF NOT EXISTS idx_generations_project ON generations(project_id);
    CREATE INDEX IF NOT EXISTS idx_generations_created ON generations(created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_generations_model ON generations(model);
    CREATE INDEX IF NOT EXISTS idx_presets_model ON presets(model);
    """

    def __init__(self, db_path: Path):
        """Initialize database.

        Args:
            db_path: Path to SQLite database file. Nothing is opened until
                ``connect()`` is awaited.
        """
        self.db_path = db_path
        self._connection: Optional[aiosqlite.Connection] = None

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value as the deprecated ``datetime.utcnow()``, so stored
        isoformat strings keep their previous (offset-less) format.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def connect(self) -> None:
        """Open the database connection and initialize the schema.

        Calling this again while a connection is open is a no-op; previously
        a second call (e.g. ``connect()`` followed by ``initialize()``)
        replaced the open connection without closing it.

        Raises:
            Any error from opening the database or running ``SCHEMA``; in the
            latter case the new connection is closed before re-raising.
        """
        if self._connection is not None:
            return

        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        connection = await aiosqlite.connect(self.db_path)
        connection.row_factory = aiosqlite.Row

        try:
            # Every statement in SCHEMA uses IF NOT EXISTS, so re-running it
            # against an existing database is harmless.
            await connection.executescript(self.SCHEMA)
            await connection.commit()
        except Exception:
            await connection.close()
            raise

        self._connection = connection
        logger.info("Database connected: %s", self.db_path)

    async def initialize(self) -> None:
        """Alias for connect() for compatibility."""
        await self.connect()

    async def close(self) -> None:
        """Close the database connection if one is open."""
        if self._connection:
            await self._connection.close()
            self._connection = None

    @property
    def conn(self) -> aiosqlite.Connection:
        """Return the active connection.

        Raises:
            RuntimeError: If ``connect()`` has not been awaited (or ``close()``
                has been).
        """
        if not self._connection:
            raise RuntimeError("Database not connected")
        return self._connection

    # Project Methods

    def _row_to_project(self, row: aiosqlite.Row) -> Project:
        """Convert a ``projects`` row to a Project object."""
        return Project(
            id=row["id"],
            name=row["name"],
            description=row["description"] or "",
            created_at=datetime.fromisoformat(row["created_at"]),
            updated_at=datetime.fromisoformat(row["updated_at"]),
        )

    async def create_project(self, project: Project) -> Project:
        """Insert a new project row and return the same object."""
        await self.conn.execute(
            """
            INSERT INTO projects (id, name, description, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                project.id,
                project.name,
                project.description,
                project.created_at.isoformat(),
                project.updated_at.isoformat(),
            ),
        )
        await self.conn.commit()
        return project

    async def get_project(self, project_id: str) -> Optional[Project]:
        """Get a project by ID, or None if no such project exists."""
        async with self.conn.execute(
            "SELECT * FROM projects WHERE id = ?", (project_id,)
        ) as cursor:
            row = await cursor.fetchone()
        return self._row_to_project(row) if row else None

    async def list_projects(
        self, limit: int = 100, offset: int = 0
    ) -> list[Project]:
        """List projects, most recently updated first, paginated by limit/offset."""
        async with self.conn.execute(
            """
            SELECT * FROM projects
            ORDER BY updated_at DESC
            LIMIT ? OFFSET ?
            """,
            (limit, offset),
        ) as cursor:
            rows = await cursor.fetchall()
        return [self._row_to_project(row) for row in rows]

    async def update_project(self, project: Project) -> None:
        """Persist name/description and bump ``updated_at`` (also on the object)."""
        project.updated_at = self._utcnow()
        await self.conn.execute(
            """
            UPDATE projects SET name = ?, description = ?, updated_at = ?
            WHERE id = ?
            """,
            (project.name, project.description, project.updated_at.isoformat(), project.id),
        )
        await self.conn.commit()

    async def delete_project(self, project_id: str) -> bool:
        """Delete a project (generations are kept but unlinked).

        The schema declares ``ON DELETE SET NULL`` on ``generations.project_id``,
        but SQLite only enforces foreign-key actions when ``PRAGMA foreign_keys``
        is enabled on the connection, and ``connect()`` does not enable it.
        The generations are therefore unlinked explicitly, in the same
        transaction as the delete.

        Returns:
            True if a project row was deleted.
        """
        conn = self.conn
        try:
            await conn.execute(
                "UPDATE generations SET project_id = NULL WHERE project_id = ?",
                (project_id,),
            )
            result = await conn.execute(
                "DELETE FROM projects WHERE id = ?", (project_id,)
            )
            await conn.commit()
        except Exception:
            await conn.rollback()
            raise
        return result.rowcount > 0

    # Generation Methods

    async def create_generation(self, generation: Generation) -> Generation:
        """Insert a generation record and touch its project's ``updated_at``.

        Both writes are committed together; if either fails the transaction
        is rolled back and the error re-raised.
        """
        conn = self.conn
        try:
            await conn.execute(
                """
                INSERT INTO generations (
                    id, project_id, model, variant, prompt, parameters,
                    preset_used, conditioning, audio_path, duration_seconds,
                    sample_rate, created_at, name, tags, notes, seed
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    generation.id,
                    generation.project_id,
                    generation.model,
                    generation.variant,
                    generation.prompt,
                    json.dumps(generation.parameters),
                    generation.preset_used,
                    json.dumps(generation.conditioning),
                    generation.audio_path,
                    generation.duration_seconds,
                    generation.sample_rate,
                    generation.created_at.isoformat(),
                    generation.name,
                    # ensure_ascii=False keeps non-ASCII tags as literal text
                    # so list_generations' `tags LIKE ?` search can match them.
                    json.dumps(generation.tags, ensure_ascii=False),
                    generation.notes,
                    generation.seed,
                ),
            )

            # Update project's updated_at if linked
            if generation.project_id:
                await conn.execute(
                    "UPDATE projects SET updated_at = ? WHERE id = ?",
                    (self._utcnow().isoformat(), generation.project_id),
                )

            await conn.commit()
        except Exception:
            await conn.rollback()
            raise

        return generation

    async def get_generation(self, generation_id: str) -> Optional[Generation]:
        """Get a generation by ID, or None if no such record exists."""
        async with self.conn.execute(
            "SELECT * FROM generations WHERE id = ?", (generation_id,)
        ) as cursor:
            row = await cursor.fetchone()
        return self._row_to_generation(row) if row else None

    @staticmethod
    def _generation_filters(
        project_id: Optional[str],
        model: Optional[str],
        search: Optional[str],
    ) -> tuple[str, list[Any]]:
        """Build the WHERE clause and bound parameters shared by list/count.

        Falsy filter values are ignored. ``search`` is matched as a literal
        substring of prompt, name, or the JSON-encoded tags: LIKE wildcards
        (``%``, ``_``) and the escape character in it are escaped.
        """
        conditions: list[str] = []
        params: list[Any] = []

        if project_id:
            conditions.append("project_id = ?")
            params.append(project_id)

        if model:
            conditions.append("model = ?")
            params.append(model)

        if search:
            escaped = (
                search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            )
            pattern = f"%{escaped}%"
            conditions.append(
                "(prompt LIKE ? ESCAPE '\\' OR name LIKE ? ESCAPE '\\' "
                "OR tags LIKE ? ESCAPE '\\')"
            )
            params.extend([pattern, pattern, pattern])

        where_clause = " AND ".join(conditions) if conditions else "1=1"
        return where_clause, params

    async def list_generations(
        self,
        project_id: Optional[str] = None,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
        search: Optional[str] = None,
    ) -> list[Generation]:
        """List generations, newest first, with optional filters.

        Args:
            project_id: Only generations linked to this project.
            model: Only generations for this model.
            limit: Maximum number of rows returned.
            offset: Number of rows to skip (for pagination).
            search: Literal substring to look for in prompt, name, or tags.
        """
        where_clause, params = self._generation_filters(project_id, model, search)

        # where_clause contains only fixed SQL fragments with ? placeholders;
        # all user-supplied values travel in params.
        async with self.conn.execute(
            f"""
            SELECT * FROM generations
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?
            """,
            (*params, limit, offset),
        ) as cursor:
            rows = await cursor.fetchall()
        return [self._row_to_generation(row) for row in rows]

    async def update_generation(self, generation: Generation) -> None:
        """Persist the editable fields of a generation record.

        Only project_id, name, tags, notes, audio_path, duration_seconds and
        sample_rate are written; other columns are left unchanged.
        """
        await self.conn.execute(
            """
            UPDATE generations SET
                project_id = ?, name = ?, tags = ?, notes = ?,
                audio_path = ?, duration_seconds = ?, sample_rate = ?
            WHERE id = ?
            """,
            (
                generation.project_id,
                generation.name,
                json.dumps(generation.tags, ensure_ascii=False),
                generation.notes,
                generation.audio_path,
                generation.duration_seconds,
                generation.sample_rate,
                generation.id,
            ),
        )
        await self.conn.commit()

    async def delete_generation(self, generation_id: str) -> bool:
        """Delete a generation record; return True if a row was deleted."""
        result = await self.conn.execute(
            "DELETE FROM generations WHERE id = ?", (generation_id,)
        )
        await self.conn.commit()
        return result.rowcount > 0

    async def count_generations(
        self,
        project_id: Optional[str] = None,
        model: Optional[str] = None,
        search: Optional[str] = None,
    ) -> int:
        """Count generations matching the same filters as list_generations.

        ``search`` is optional and appended last, so existing positional
        callers are unaffected; passing it lets pagination totals agree with
        a searched ``list_generations`` result.
        """
        where_clause, params = self._generation_filters(project_id, model, search)

        async with self.conn.execute(
            f"SELECT COUNT(*) FROM generations WHERE {where_clause}",
            params,
        ) as cursor:
            row = await cursor.fetchone()
        return row[0] if row else 0

    def _row_to_generation(self, row: aiosqlite.Row) -> Generation:
        """Convert a ``generations`` row to a Generation object."""
        return Generation(
            id=row["id"],
            project_id=row["project_id"],
            model=row["model"],
            variant=row["variant"],
            prompt=row["prompt"],
            parameters=json.loads(row["parameters"]),
            preset_used=row["preset_used"],
            conditioning=json.loads(row["conditioning"]) if row["conditioning"] else {},
            audio_path=row["audio_path"],
            duration_seconds=row["duration_seconds"],
            sample_rate=row["sample_rate"],
            created_at=datetime.fromisoformat(row["created_at"]),
            name=row["name"],
            tags=json.loads(row["tags"]) if row["tags"] else [],
            notes=row["notes"],
            seed=row["seed"],
        )

    # Preset Methods

    async def create_preset(self, preset: Preset) -> Preset:
        """Insert a new preset row and return the same object."""
        await self.conn.execute(
            """
            INSERT INTO presets (id, model, name, description, parameters, is_builtin, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                preset.id,
                preset.model,
                preset.name,
                preset.description,
                json.dumps(preset.parameters),
                preset.is_builtin,
                preset.created_at.isoformat(),
            ),
        )
        await self.conn.commit()
        return preset

    async def get_preset(self, preset_id: str) -> Optional[Preset]:
        """Get a preset by ID, or None if no such preset exists."""
        async with self.conn.execute(
            "SELECT * FROM presets WHERE id = ?", (preset_id,)
        ) as cursor:
            row = await cursor.fetchone()
        return self._row_to_preset(row) if row else None

    async def list_presets(
        self, model: Optional[str] = None, include_builtin: bool = True
    ) -> list[Preset]:
        """List presets, builtins first then by name, with optional filters.

        Args:
            model: Only presets for this model.
            include_builtin: When False, only custom presets are returned.
        """
        conditions: list[str] = []
        params: list[Any] = []

        if model:
            conditions.append("model = ?")
            params.append(model)

        if not include_builtin:
            conditions.append("is_builtin = FALSE")

        where_clause = " AND ".join(conditions) if conditions else "1=1"

        async with self.conn.execute(
            f"""
            SELECT * FROM presets
            WHERE {where_clause}
            ORDER BY is_builtin DESC, name ASC
            """,
            params,
        ) as cursor:
            rows = await cursor.fetchall()
        return [self._row_to_preset(row) for row in rows]

    async def delete_preset(self, preset_id: str) -> bool:
        """Delete a custom preset; builtin rows are never matched.

        Returns:
            True if a row was deleted.
        """
        result = await self.conn.execute(
            "DELETE FROM presets WHERE id = ? AND is_builtin = FALSE",
            (preset_id,),
        )
        await self.conn.commit()
        return result.rowcount > 0

    def _row_to_preset(self, row: aiosqlite.Row) -> Preset:
        """Convert a ``presets`` row to a Preset object."""
        return Preset(
            id=row["id"],
            model=row["model"],
            name=row["name"],
            description=row["description"] or "",
            parameters=json.loads(row["parameters"]),
            is_builtin=bool(row["is_builtin"]),
            created_at=datetime.fromisoformat(row["created_at"]),
        )

    # Utility Methods

    async def get_stats(self) -> dict[str, Any]:
        """Return row counts per table plus a per-model generation count.

        Keys: ``projects``, ``generations``, ``presets`` (ints) and
        ``generations_by_model`` (dict of model name -> count).
        """
        stats: dict[str, Any] = {}

        # Table names come from this fixed tuple, never from callers.
        for table in ("projects", "generations", "presets"):
            async with self.conn.execute(f"SELECT COUNT(*) FROM {table}") as cursor:
                row = await cursor.fetchone()
            stats[table] = row[0] if row else 0

        async with self.conn.execute(
            "SELECT model, COUNT(*) as count FROM generations GROUP BY model"
        ) as cursor:
            rows = await cursor.fetchall()
        stats["generations_by_model"] = {row["model"]: row["count"] for row in rows}

        return stats
|