"""SQLite database for projects, generations, and presets.""" import json import logging import uuid from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import Any, Optional import aiosqlite logger = logging.getLogger(__name__) @dataclass class Project: """Project entity for organizing generations.""" id: str name: str created_at: datetime updated_at: datetime description: str = "" @classmethod def create(cls, name: str, description: str = "") -> "Project": """Create a new project with generated ID.""" now = datetime.utcnow() return cls( id=f"proj_{uuid.uuid4().hex[:12]}", name=name, created_at=now, updated_at=now, description=description, ) @dataclass class Generation: """Audio generation record.""" id: str project_id: Optional[str] model: str variant: str prompt: str parameters: dict[str, Any] created_at: datetime audio_path: Optional[str] = None duration_seconds: Optional[float] = None sample_rate: Optional[int] = None preset_used: Optional[str] = None conditioning: dict[str, Any] = field(default_factory=dict) name: Optional[str] = None tags: list[str] = field(default_factory=list) notes: Optional[str] = None seed: Optional[int] = None @classmethod def create( cls, model: str, variant: str, prompt: str, parameters: dict[str, Any], project_id: Optional[str] = None, **kwargs, ) -> "Generation": """Create a new generation record.""" return cls( id=f"gen_{uuid.uuid4().hex[:12]}", project_id=project_id, model=model, variant=variant, prompt=prompt, parameters=parameters, created_at=datetime.utcnow(), **kwargs, ) @dataclass class Preset: """Generation parameter preset.""" id: str model: str name: str parameters: dict[str, Any] created_at: datetime description: str = "" is_builtin: bool = False @classmethod def create( cls, model: str, name: str, parameters: dict[str, Any], description: str = "", ) -> "Preset": """Create a new custom preset.""" return cls( id=f"preset_{uuid.uuid4().hex[:12]}", model=model, name=name, parameters=parameters, created_at=datetime.utcnow(), description=description, is_builtin=False, ) class Database: """Async SQLite database for AudioCraft Studio. Handles storage of projects, generations, and presets. """ SCHEMA = """ CREATE TABLE IF NOT EXISTS projects ( id TEXT PRIMARY KEY, name TEXT NOT NULL, description TEXT DEFAULT '', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); CREATE TABLE IF NOT EXISTS generations ( id TEXT PRIMARY KEY, project_id TEXT REFERENCES projects(id) ON DELETE SET NULL, model TEXT NOT NULL, variant TEXT NOT NULL, prompt TEXT NOT NULL, parameters JSON NOT NULL, preset_used TEXT, conditioning JSON, audio_path TEXT, duration_seconds REAL, sample_rate INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, name TEXT, tags JSON, notes TEXT, seed INTEGER ); CREATE TABLE IF NOT EXISTS presets ( id TEXT PRIMARY KEY, model TEXT NOT NULL, name TEXT NOT NULL, description TEXT DEFAULT '', parameters JSON NOT NULL, is_builtin BOOLEAN DEFAULT FALSE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_generations_project ON generations(project_id); CREATE INDEX IF NOT EXISTS idx_generations_created ON generations(created_at DESC); CREATE INDEX IF NOT EXISTS idx_generations_model ON generations(model); CREATE INDEX IF NOT EXISTS idx_presets_model ON presets(model); """ def __init__(self, db_path: Path): """Initialize database. Args: db_path: Path to SQLite database file """ self.db_path = db_path self._connection: Optional[aiosqlite.Connection] = None async def connect(self) -> None: """Open database connection and initialize schema.""" self.db_path.parent.mkdir(parents=True, exist_ok=True) self._connection = await aiosqlite.connect(self.db_path) self._connection.row_factory = aiosqlite.Row # Initialize schema await self._connection.executescript(self.SCHEMA) await self._connection.commit() logger.info(f"Database connected: {self.db_path}") async def initialize(self) -> None: """Alias for connect() for compatibility.""" await self.connect() async def close(self) -> None: """Close database connection.""" if self._connection: await self._connection.close() self._connection = None @property def conn(self) -> aiosqlite.Connection: """Get active connection.""" if not self._connection: raise RuntimeError("Database not connected") return self._connection # Project Methods async def create_project(self, project: Project) -> Project: """Create a new project.""" await self.conn.execute( """ INSERT INTO projects (id, name, description, created_at, updated_at) VALUES (?, ?, ?, ?, ?) """, ( project.id, project.name, project.description, project.created_at.isoformat(), project.updated_at.isoformat(), ), ) await self.conn.commit() return project async def get_project(self, project_id: str) -> Optional[Project]: """Get a project by ID.""" async with self.conn.execute( "SELECT * FROM projects WHERE id = ?", (project_id,) ) as cursor: row = await cursor.fetchone() if row: return Project( id=row["id"], name=row["name"], description=row["description"] or "", created_at=datetime.fromisoformat(row["created_at"]), updated_at=datetime.fromisoformat(row["updated_at"]), ) return None async def list_projects( self, limit: int = 100, offset: int = 0 ) -> list[Project]: """List all projects, ordered by last update.""" async with self.conn.execute( """ SELECT * FROM projects ORDER BY updated_at DESC LIMIT ? OFFSET ? """, (limit, offset), ) as cursor: rows = await cursor.fetchall() return [ Project( id=row["id"], name=row["name"], description=row["description"] or "", created_at=datetime.fromisoformat(row["created_at"]), updated_at=datetime.fromisoformat(row["updated_at"]), ) for row in rows ] async def update_project(self, project: Project) -> None: """Update a project.""" project.updated_at = datetime.utcnow() await self.conn.execute( """ UPDATE projects SET name = ?, description = ?, updated_at = ? WHERE id = ? """, (project.name, project.description, project.updated_at.isoformat(), project.id), ) await self.conn.commit() async def delete_project(self, project_id: str) -> bool: """Delete a project (generations are kept but unlinked).""" result = await self.conn.execute( "DELETE FROM projects WHERE id = ?", (project_id,) ) await self.conn.commit() return result.rowcount > 0 # Generation Methods async def create_generation(self, generation: Generation) -> Generation: """Create a new generation record.""" await self.conn.execute( """ INSERT INTO generations ( id, project_id, model, variant, prompt, parameters, preset_used, conditioning, audio_path, duration_seconds, sample_rate, created_at, name, tags, notes, seed ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( generation.id, generation.project_id, generation.model, generation.variant, generation.prompt, json.dumps(generation.parameters), generation.preset_used, json.dumps(generation.conditioning), generation.audio_path, generation.duration_seconds, generation.sample_rate, generation.created_at.isoformat(), generation.name, json.dumps(generation.tags), generation.notes, generation.seed, ), ) await self.conn.commit() # Update project's updated_at if linked if generation.project_id: await self.conn.execute( "UPDATE projects SET updated_at = ? WHERE id = ?", (datetime.utcnow().isoformat(), generation.project_id), ) await self.conn.commit() return generation async def get_generation(self, generation_id: str) -> Optional[Generation]: """Get a generation by ID.""" async with self.conn.execute( "SELECT * FROM generations WHERE id = ?", (generation_id,) ) as cursor: row = await cursor.fetchone() if row: return self._row_to_generation(row) return None async def list_generations( self, project_id: Optional[str] = None, model: Optional[str] = None, limit: int = 100, offset: int = 0, search: Optional[str] = None, ) -> list[Generation]: """List generations with optional filters.""" conditions = [] params = [] if project_id: conditions.append("project_id = ?") params.append(project_id) if model: conditions.append("model = ?") params.append(model) if search: conditions.append("(prompt LIKE ? OR name LIKE ? OR tags LIKE ?)") search_pattern = f"%{search}%" params.extend([search_pattern, search_pattern, search_pattern]) where_clause = " AND ".join(conditions) if conditions else "1=1" async with self.conn.execute( f""" SELECT * FROM generations WHERE {where_clause} ORDER BY created_at DESC LIMIT ? OFFSET ? """, (*params, limit, offset), ) as cursor: rows = await cursor.fetchall() return [self._row_to_generation(row) for row in rows] async def update_generation(self, generation: Generation) -> None: """Update a generation record.""" await self.conn.execute( """ UPDATE generations SET project_id = ?, name = ?, tags = ?, notes = ?, audio_path = ?, duration_seconds = ?, sample_rate = ? WHERE id = ? """, ( generation.project_id, generation.name, json.dumps(generation.tags), generation.notes, generation.audio_path, generation.duration_seconds, generation.sample_rate, generation.id, ), ) await self.conn.commit() async def delete_generation(self, generation_id: str) -> bool: """Delete a generation record.""" result = await self.conn.execute( "DELETE FROM generations WHERE id = ?", (generation_id,) ) await self.conn.commit() return result.rowcount > 0 async def count_generations( self, project_id: Optional[str] = None, model: Optional[str] = None ) -> int: """Count generations with optional filters.""" conditions = [] params = [] if project_id: conditions.append("project_id = ?") params.append(project_id) if model: conditions.append("model = ?") params.append(model) where_clause = " AND ".join(conditions) if conditions else "1=1" async with self.conn.execute( f"SELECT COUNT(*) FROM generations WHERE {where_clause}", params, ) as cursor: row = await cursor.fetchone() return row[0] if row else 0 def _row_to_generation(self, row: aiosqlite.Row) -> Generation: """Convert database row to Generation object.""" return Generation( id=row["id"], project_id=row["project_id"], model=row["model"], variant=row["variant"], prompt=row["prompt"], parameters=json.loads(row["parameters"]), preset_used=row["preset_used"], conditioning=json.loads(row["conditioning"]) if row["conditioning"] else {}, audio_path=row["audio_path"], duration_seconds=row["duration_seconds"], sample_rate=row["sample_rate"], created_at=datetime.fromisoformat(row["created_at"]), name=row["name"], tags=json.loads(row["tags"]) if row["tags"] else [], notes=row["notes"], seed=row["seed"], ) # Preset Methods async def create_preset(self, preset: Preset) -> Preset: """Create a new preset.""" await self.conn.execute( """ INSERT INTO presets (id, model, name, description, parameters, is_builtin, created_at) VALUES (?, ?, ?, ?, ?, ?, ?) """, ( preset.id, preset.model, preset.name, preset.description, json.dumps(preset.parameters), preset.is_builtin, preset.created_at.isoformat(), ), ) await self.conn.commit() return preset async def get_preset(self, preset_id: str) -> Optional[Preset]: """Get a preset by ID.""" async with self.conn.execute( "SELECT * FROM presets WHERE id = ?", (preset_id,) ) as cursor: row = await cursor.fetchone() if row: return self._row_to_preset(row) return None async def list_presets( self, model: Optional[str] = None, include_builtin: bool = True ) -> list[Preset]: """List presets with optional model filter.""" conditions = [] params = [] if model: conditions.append("model = ?") params.append(model) if not include_builtin: conditions.append("is_builtin = FALSE") where_clause = " AND ".join(conditions) if conditions else "1=1" async with self.conn.execute( f""" SELECT * FROM presets WHERE {where_clause} ORDER BY is_builtin DESC, name ASC """, params, ) as cursor: rows = await cursor.fetchall() return [self._row_to_preset(row) for row in rows] async def delete_preset(self, preset_id: str) -> bool: """Delete a preset (only custom presets can be deleted).""" result = await self.conn.execute( "DELETE FROM presets WHERE id = ? AND is_builtin = FALSE", (preset_id,), ) await self.conn.commit() return result.rowcount > 0 def _row_to_preset(self, row: aiosqlite.Row) -> Preset: """Convert database row to Preset object.""" return Preset( id=row["id"], model=row["model"], name=row["name"], description=row["description"] or "", parameters=json.loads(row["parameters"]), is_builtin=bool(row["is_builtin"]), created_at=datetime.fromisoformat(row["created_at"]), ) # Utility Methods async def get_stats(self) -> dict[str, Any]: """Get database statistics.""" stats = {} async with self.conn.execute("SELECT COUNT(*) FROM projects") as cursor: row = await cursor.fetchone() stats["projects"] = row[0] if row else 0 async with self.conn.execute("SELECT COUNT(*) FROM generations") as cursor: row = await cursor.fetchone() stats["generations"] = row[0] if row else 0 async with self.conn.execute("SELECT COUNT(*) FROM presets") as cursor: row = await cursor.fetchone() stats["presets"] = row[0] if row else 0 async with self.conn.execute( "SELECT model, COUNT(*) as count FROM generations GROUP BY model" ) as cursor: rows = await cursor.fetchall() stats["generations_by_model"] = {row["model"]: row["count"] for row in rows} return stats