feat: add LLMX configuration with Crawl4AI RAG MCP server

- Add config.toml with MCP servers configuration
- Add compose.yaml for PostgreSQL+pgvector, PostgREST, and Crawl4AI RAG
- Include forked mcp-crawl4ai-rag with BGE 1024-dim embedding support
- Custom schema (crawled_pages_1024.sql) for BGE embeddings

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-25 08:29:43 +01:00
commit 10bcbb2120
23 changed files with 10224 additions and 0 deletions

View File

@@ -0,0 +1,335 @@
"""
AI Hallucination Detector
Main orchestrator for detecting AI coding assistant hallucinations in Python scripts.
Combines AST analysis, knowledge graph validation, and comprehensive reporting.
"""
import asyncio
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Optional, List
from dotenv import load_dotenv
from ai_script_analyzer import AIScriptAnalyzer, analyze_ai_script
from knowledge_graph_validator import KnowledgeGraphValidator
from hallucination_reporter import HallucinationReporter
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
class AIHallucinationDetector:
"""Main detector class that orchestrates the entire process"""
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
self.validator = KnowledgeGraphValidator(neo4j_uri, neo4j_user, neo4j_password)
self.reporter = HallucinationReporter()
self.analyzer = AIScriptAnalyzer()
async def initialize(self):
"""Initialize connections and components"""
await self.validator.initialize()
logger.info("AI Hallucination Detector initialized successfully")
async def close(self):
"""Close connections"""
await self.validator.close()
async def detect_hallucinations(self, script_path: str,
output_dir: Optional[str] = None,
save_json: bool = True,
save_markdown: bool = True,
print_summary: bool = True) -> dict:
"""
Main detection function that analyzes a script and generates reports
Args:
script_path: Path to the AI-generated Python script
output_dir: Directory to save reports (defaults to script directory)
save_json: Whether to save JSON report
save_markdown: Whether to save Markdown report
print_summary: Whether to print summary to console
Returns:
Complete validation report as dictionary
"""
logger.info(f"Starting hallucination detection for: {script_path}")
# Validate input
if not os.path.exists(script_path):
raise FileNotFoundError(f"Script not found: {script_path}")
if not script_path.endswith('.py'):
raise ValueError("Only Python (.py) files are supported")
# Set output directory
if output_dir is None:
output_dir = str(Path(script_path).parent)
os.makedirs(output_dir, exist_ok=True)
try:
# Step 1: Analyze the script using AST
logger.info("Step 1: Analyzing script structure...")
analysis_result = self.analyzer.analyze_script(script_path)
if analysis_result.errors:
logger.warning(f"Analysis warnings: {analysis_result.errors}")
logger.info(f"Found: {len(analysis_result.imports)} imports, "
f"{len(analysis_result.class_instantiations)} class instantiations, "
f"{len(analysis_result.method_calls)} method calls, "
f"{len(analysis_result.function_calls)} function calls, "
f"{len(analysis_result.attribute_accesses)} attribute accesses")
# Step 2: Validate against knowledge graph
logger.info("Step 2: Validating against knowledge graph...")
validation_result = await self.validator.validate_script(analysis_result)
logger.info(f"Validation complete. Overall confidence: {validation_result.overall_confidence:.1%}")
# Step 3: Generate comprehensive report
logger.info("Step 3: Generating reports...")
report = self.reporter.generate_comprehensive_report(validation_result)
# Step 4: Save reports
script_name = Path(script_path).stem
if save_json:
json_path = os.path.join(output_dir, f"{script_name}_hallucination_report.json")
self.reporter.save_json_report(report, json_path)
if save_markdown:
md_path = os.path.join(output_dir, f"{script_name}_hallucination_report.md")
self.reporter.save_markdown_report(report, md_path)
# Step 5: Print summary
if print_summary:
self.reporter.print_summary(report)
logger.info("Hallucination detection completed successfully")
return report
except Exception as e:
logger.error(f"Error during hallucination detection: {str(e)}")
raise
async def batch_detect(self, script_paths: List[str],
output_dir: Optional[str] = None) -> List[dict]:
"""
Detect hallucinations in multiple scripts
Args:
script_paths: List of paths to Python scripts
output_dir: Directory to save all reports
Returns:
List of validation reports
"""
logger.info(f"Starting batch detection for {len(script_paths)} scripts")
results = []
for i, script_path in enumerate(script_paths, 1):
logger.info(f"Processing script {i}/{len(script_paths)}: {script_path}")
try:
result = await self.detect_hallucinations(
script_path=script_path,
output_dir=output_dir,
print_summary=False # Don't print individual summaries in batch mode
)
results.append(result)
except Exception as e:
logger.error(f"Failed to process {script_path}: {str(e)}")
# Continue with other scripts
continue
# Print batch summary
self._print_batch_summary(results)
return results
def _print_batch_summary(self, results: List[dict]):
"""Print summary of batch processing results"""
if not results:
print("No scripts were successfully processed.")
return
print("\n" + "="*80)
print("🚀 BATCH HALLUCINATION DETECTION SUMMARY")
print("="*80)
total_scripts = len(results)
total_validations = sum(r['validation_summary']['total_validations'] for r in results)
total_valid = sum(r['validation_summary']['valid_count'] for r in results)
total_invalid = sum(r['validation_summary']['invalid_count'] for r in results)
total_not_found = sum(r['validation_summary']['not_found_count'] for r in results)
total_hallucinations = sum(len(r['hallucinations_detected']) for r in results)
avg_confidence = sum(r['validation_summary']['overall_confidence'] for r in results) / total_scripts
print(f"Scripts Processed: {total_scripts}")
print(f"Total Validations: {total_validations}")
print(f"Average Confidence: {avg_confidence:.1%}")
print(f"Total Hallucinations: {total_hallucinations}")
print(f"\nAggregated Results:")
print(f" ✅ Valid: {total_valid} ({total_valid/total_validations:.1%})")
print(f" ❌ Invalid: {total_invalid} ({total_invalid/total_validations:.1%})")
print(f" 🔍 Not Found: {total_not_found} ({total_not_found/total_validations:.1%})")
# Show worst performing scripts
print(f"\n🚨 Scripts with Most Hallucinations:")
sorted_results = sorted(results, key=lambda x: len(x['hallucinations_detected']), reverse=True)
for result in sorted_results[:5]:
script_name = Path(result['analysis_metadata']['script_path']).name
hall_count = len(result['hallucinations_detected'])
confidence = result['validation_summary']['overall_confidence']
print(f" - {script_name}: {hall_count} hallucinations ({confidence:.1%} confidence)")
print("="*80)
async def main():
"""Command-line interface for the AI Hallucination Detector"""
parser = argparse.ArgumentParser(
description="Detect AI coding assistant hallucinations in Python scripts",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Analyze single script
python ai_hallucination_detector.py script.py
# Analyze multiple scripts
python ai_hallucination_detector.py script1.py script2.py script3.py
# Specify output directory
python ai_hallucination_detector.py script.py --output-dir reports/
# Skip markdown report
python ai_hallucination_detector.py script.py --no-markdown
"""
)
parser.add_argument(
'scripts',
nargs='+',
help='Python script(s) to analyze for hallucinations'
)
parser.add_argument(
'--output-dir',
help='Directory to save reports (defaults to script directory)'
)
parser.add_argument(
'--no-json',
action='store_true',
help='Skip JSON report generation'
)
parser.add_argument(
'--no-markdown',
action='store_true',
help='Skip Markdown report generation'
)
parser.add_argument(
'--no-summary',
action='store_true',
help='Skip printing summary to console'
)
parser.add_argument(
'--neo4j-uri',
default=None,
help='Neo4j URI (default: from environment NEO4J_URI)'
)
parser.add_argument(
'--neo4j-user',
default=None,
help='Neo4j username (default: from environment NEO4J_USER)'
)
parser.add_argument(
'--neo4j-password',
default=None,
help='Neo4j password (default: from environment NEO4J_PASSWORD)'
)
parser.add_argument(
'--verbose',
action='store_true',
help='Enable verbose logging'
)
args = parser.parse_args()
if args.verbose:
logging.getLogger().setLevel(logging.INFO)
# Only enable debug for our modules, not neo4j
logging.getLogger('neo4j').setLevel(logging.WARNING)
logging.getLogger('neo4j.pool').setLevel(logging.WARNING)
logging.getLogger('neo4j.io').setLevel(logging.WARNING)
# Load environment variables
load_dotenv()
# Get Neo4j credentials
neo4j_uri = args.neo4j_uri or os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = args.neo4j_user or os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = args.neo4j_password or os.environ.get('NEO4J_PASSWORD', 'password')
if not neo4j_password or neo4j_password == 'password':
logger.error("Please set NEO4J_PASSWORD environment variable or use --neo4j-password")
sys.exit(1)
# Initialize detector
detector = AIHallucinationDetector(neo4j_uri, neo4j_user, neo4j_password)
try:
await detector.initialize()
# Process scripts
if len(args.scripts) == 1:
# Single script mode
await detector.detect_hallucinations(
script_path=args.scripts[0],
output_dir=args.output_dir,
save_json=not args.no_json,
save_markdown=not args.no_markdown,
print_summary=not args.no_summary
)
else:
# Batch mode
await detector.batch_detect(
script_paths=args.scripts,
output_dir=args.output_dir
)
except KeyboardInterrupt:
logger.info("Detection interrupted by user")
sys.exit(1)
except Exception as e:
logger.error(f"Detection failed: {str(e)}")
sys.exit(1)
finally:
await detector.close()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,532 @@
"""
AI Script Analyzer
Parses Python scripts generated by AI coding assistants using AST to extract:
- Import statements and their usage
- Class instantiations and method calls
- Function calls with parameters
- Attribute access patterns
- Variable type tracking
"""
import ast
import logging
from pathlib import Path
from typing import Dict, List, Set, Any, Optional, Tuple
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
@dataclass
class ImportInfo:
"""Information about an import statement"""
module: str
name: str
alias: Optional[str] = None
is_from_import: bool = False
line_number: int = 0
@dataclass
class MethodCall:
"""Information about a method call"""
object_name: str
method_name: str
args: List[str]
kwargs: Dict[str, str]
line_number: int
object_type: Optional[str] = None # Inferred class type
@dataclass
class AttributeAccess:
"""Information about attribute access"""
object_name: str
attribute_name: str
line_number: int
object_type: Optional[str] = None # Inferred class type
@dataclass
class FunctionCall:
"""Information about a function call"""
function_name: str
args: List[str]
kwargs: Dict[str, str]
line_number: int
full_name: Optional[str] = None # Module.function_name
@dataclass
class ClassInstantiation:
"""Information about class instantiation"""
variable_name: str
class_name: str
args: List[str]
kwargs: Dict[str, str]
line_number: int
full_class_name: Optional[str] = None # Module.ClassName
@dataclass
class AnalysisResult:
"""Complete analysis results for a Python script"""
file_path: str
imports: List[ImportInfo] = field(default_factory=list)
class_instantiations: List[ClassInstantiation] = field(default_factory=list)
method_calls: List[MethodCall] = field(default_factory=list)
attribute_accesses: List[AttributeAccess] = field(default_factory=list)
function_calls: List[FunctionCall] = field(default_factory=list)
variable_types: Dict[str, str] = field(default_factory=dict) # variable_name -> class_type
errors: List[str] = field(default_factory=list)
class AIScriptAnalyzer:
"""Analyzes AI-generated Python scripts for validation against knowledge graph"""
def __init__(self):
self.import_map: Dict[str, str] = {} # alias -> actual_module_name
self.variable_types: Dict[str, str] = {} # variable_name -> class_type
self.context_manager_vars: Dict[str, Tuple[int, int, str]] = {} # var_name -> (start_line, end_line, type)
def analyze_script(self, script_path: str) -> AnalysisResult:
"""Analyze a Python script and extract all relevant information"""
try:
with open(script_path, 'r', encoding='utf-8') as f:
content = f.read()
tree = ast.parse(content)
result = AnalysisResult(file_path=script_path)
# Reset state for new analysis
self.import_map.clear()
self.variable_types.clear()
self.context_manager_vars.clear()
# Track processed nodes to avoid duplicates
self.processed_calls = set()
self.method_call_attributes = set()
# First pass: collect imports and build import map
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
self._extract_imports(node, result)
# Second pass: analyze usage patterns
for node in ast.walk(tree):
self._analyze_node(node, result)
# Set inferred types on method calls and attribute accesses
self._infer_object_types(result)
result.variable_types = self.variable_types.copy()
return result
except Exception as e:
error_msg = f"Failed to analyze script {script_path}: {str(e)}"
logger.error(error_msg)
result = AnalysisResult(file_path=script_path)
result.errors.append(error_msg)
return result
def _extract_imports(self, node: ast.AST, result: AnalysisResult):
"""Extract import information and build import mapping"""
line_num = getattr(node, 'lineno', 0)
if isinstance(node, ast.Import):
for alias in node.names:
import_name = alias.name
alias_name = alias.asname or import_name
result.imports.append(ImportInfo(
module=import_name,
name=import_name,
alias=alias.asname,
is_from_import=False,
line_number=line_num
))
self.import_map[alias_name] = import_name
elif isinstance(node, ast.ImportFrom):
module = node.module or ""
for alias in node.names:
import_name = alias.name
alias_name = alias.asname or import_name
result.imports.append(ImportInfo(
module=module,
name=import_name,
alias=alias.asname,
is_from_import=True,
line_number=line_num
))
# Map alias to full module.name
if module:
full_name = f"{module}.{import_name}"
self.import_map[alias_name] = full_name
else:
self.import_map[alias_name] = import_name
def _analyze_node(self, node: ast.AST, result: AnalysisResult):
"""Analyze individual AST nodes for usage patterns"""
line_num = getattr(node, 'lineno', 0)
# Assignments (class instantiations and method call results)
if isinstance(node, ast.Assign):
if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
if isinstance(node.value, ast.Call):
# Check if it's a class instantiation or method call
if isinstance(node.value.func, ast.Name):
# Direct function/class call
self._extract_class_instantiation(node, result)
# Mark this call as processed to avoid duplicate processing
self.processed_calls.add(id(node.value))
elif isinstance(node.value.func, ast.Attribute):
# Method call - track the variable assignment for type inference
var_name = node.targets[0].id
self._track_method_result_assignment(node.value, var_name)
# Still process the method call
self._extract_method_call(node.value, result)
self.processed_calls.add(id(node.value))
# AsyncWith statements (context managers)
elif isinstance(node, ast.AsyncWith):
self._handle_async_with(node, result)
elif isinstance(node, ast.With):
self._handle_with(node, result)
# Method calls and function calls
elif isinstance(node, ast.Call):
# Skip if this call was already processed as part of an assignment
if id(node) in self.processed_calls:
return
if isinstance(node.func, ast.Attribute):
self._extract_method_call(node, result)
# Mark this attribute as used in method call to avoid duplicate processing
self.method_call_attributes.add(id(node.func))
elif isinstance(node.func, ast.Name):
# Check if this is likely a class instantiation (based on imported classes)
func_name = node.func.id
full_name = self._resolve_full_name(func_name)
# If this is a known imported class, treat as class instantiation
if self._is_likely_class_instantiation(func_name, full_name):
self._extract_nested_class_instantiation(node, result)
else:
self._extract_function_call(node, result)
# Attribute access (not in call context)
elif isinstance(node, ast.Attribute):
# Skip if this attribute was already processed as part of a method call
if id(node) in self.method_call_attributes:
return
self._extract_attribute_access(node, result)
def _extract_class_instantiation(self, node: ast.Assign, result: AnalysisResult):
"""Extract class instantiation from assignment"""
target = node.targets[0]
call = node.value
line_num = getattr(node, 'lineno', 0)
if isinstance(target, ast.Name) and isinstance(call, ast.Call):
var_name = target.id
class_name = self._get_name_from_call(call.func)
if class_name:
args = [self._get_arg_representation(arg) for arg in call.args]
kwargs = {
kw.arg: self._get_arg_representation(kw.value)
for kw in call.keywords if kw.arg
}
# Resolve full class name using import map
full_class_name = self._resolve_full_name(class_name)
instantiation = ClassInstantiation(
variable_name=var_name,
class_name=class_name,
args=args,
kwargs=kwargs,
line_number=line_num,
full_class_name=full_class_name
)
result.class_instantiations.append(instantiation)
# Track variable type for later method call analysis
self.variable_types[var_name] = full_class_name or class_name
def _extract_method_call(self, node: ast.Call, result: AnalysisResult):
"""Extract method call information"""
if isinstance(node.func, ast.Attribute):
line_num = getattr(node, 'lineno', 0)
# Get object and method names
obj_name = self._get_name_from_node(node.func.value)
method_name = node.func.attr
if obj_name and method_name:
args = [self._get_arg_representation(arg) for arg in node.args]
kwargs = {
kw.arg: self._get_arg_representation(kw.value)
for kw in node.keywords if kw.arg
}
method_call = MethodCall(
object_name=obj_name,
method_name=method_name,
args=args,
kwargs=kwargs,
line_number=line_num,
object_type=self.variable_types.get(obj_name)
)
result.method_calls.append(method_call)
def _extract_function_call(self, node: ast.Call, result: AnalysisResult):
"""Extract function call information"""
if isinstance(node.func, ast.Name):
line_num = getattr(node, 'lineno', 0)
func_name = node.func.id
args = [self._get_arg_representation(arg) for arg in node.args]
kwargs = {
kw.arg: self._get_arg_representation(kw.value)
for kw in node.keywords if kw.arg
}
# Resolve full function name using import map
full_func_name = self._resolve_full_name(func_name)
function_call = FunctionCall(
function_name=func_name,
args=args,
kwargs=kwargs,
line_number=line_num,
full_name=full_func_name
)
result.function_calls.append(function_call)
def _extract_attribute_access(self, node: ast.Attribute, result: AnalysisResult):
"""Extract attribute access information"""
line_num = getattr(node, 'lineno', 0)
obj_name = self._get_name_from_node(node.value)
attr_name = node.attr
if obj_name and attr_name:
attribute_access = AttributeAccess(
object_name=obj_name,
attribute_name=attr_name,
line_number=line_num,
object_type=self.variable_types.get(obj_name)
)
result.attribute_accesses.append(attribute_access)
def _infer_object_types(self, result: AnalysisResult):
"""Update object types for method calls and attribute accesses"""
for method_call in result.method_calls:
if not method_call.object_type:
# First check context manager variables
obj_type = self._get_context_aware_type(method_call.object_name, method_call.line_number)
if obj_type:
method_call.object_type = obj_type
else:
method_call.object_type = self.variable_types.get(method_call.object_name)
for attr_access in result.attribute_accesses:
if not attr_access.object_type:
# First check context manager variables
obj_type = self._get_context_aware_type(attr_access.object_name, attr_access.line_number)
if obj_type:
attr_access.object_type = obj_type
else:
attr_access.object_type = self.variable_types.get(attr_access.object_name)
def _get_context_aware_type(self, var_name: str, line_number: int) -> Optional[str]:
"""Get the type of a variable considering its context (e.g., async with scope)"""
if var_name in self.context_manager_vars:
start_line, end_line, var_type = self.context_manager_vars[var_name]
if start_line <= line_number <= end_line:
return var_type
return None
def _get_name_from_call(self, node: ast.AST) -> Optional[str]:
"""Get the name from a call node (for class instantiation)"""
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
value_name = self._get_name_from_node(node.value)
if value_name:
return f"{value_name}.{node.attr}"
return None
def _get_name_from_node(self, node: ast.AST) -> Optional[str]:
"""Get string representation of a node (for object names)"""
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
value_name = self._get_name_from_node(node.value)
if value_name:
return f"{value_name}.{node.attr}"
return None
def _get_arg_representation(self, node: ast.AST) -> str:
"""Get string representation of an argument"""
if isinstance(node, ast.Constant):
return repr(node.value)
elif isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
return self._get_name_from_node(node) or "<?>"
elif isinstance(node, ast.Call):
func_name = self._get_name_from_call(node.func)
return f"{func_name}(...)" if func_name else "call(...)"
else:
return f"<{type(node).__name__}>"
def _is_likely_class_instantiation(self, func_name: str, full_name: Optional[str]) -> bool:
"""Determine if a function call is likely a class instantiation"""
# Check if it's a known imported class (classes typically start with uppercase)
if func_name and func_name[0].isupper():
return True
# Check if the full name suggests a class (contains known class patterns)
if full_name:
# Common class patterns in module names
class_patterns = [
'Model', 'Provider', 'Client', 'Agent', 'Manager', 'Handler',
'Builder', 'Factory', 'Service', 'Controller', 'Processor'
]
return any(pattern in full_name for pattern in class_patterns)
return False
def _extract_nested_class_instantiation(self, node: ast.Call, result: AnalysisResult):
"""Extract class instantiation that's not in direct assignment (e.g., as parameter)"""
line_num = getattr(node, 'lineno', 0)
if isinstance(node.func, ast.Name):
class_name = node.func.id
args = [self._get_arg_representation(arg) for arg in node.args]
kwargs = {
kw.arg: self._get_arg_representation(kw.value)
for kw in node.keywords if kw.arg
}
# Resolve full class name using import map
full_class_name = self._resolve_full_name(class_name)
# Use a synthetic variable name since this isn't assigned to a variable
var_name = f"<{class_name.lower()}_instance>"
instantiation = ClassInstantiation(
variable_name=var_name,
class_name=class_name,
args=args,
kwargs=kwargs,
line_number=line_num,
full_class_name=full_class_name
)
result.class_instantiations.append(instantiation)
def _track_method_result_assignment(self, call_node: ast.Call, var_name: str):
"""Track when a variable is assigned the result of a method call"""
if isinstance(call_node.func, ast.Attribute):
# For now, we'll use a generic type hint for method results
# In a more sophisticated system, we could look up the return type
self.variable_types[var_name] = "method_result"
def _handle_async_with(self, node: ast.AsyncWith, result: AnalysisResult):
"""Handle async with statements and track context manager variables"""
for item in node.items:
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
var_name = item.optional_vars.id
# If the context manager is a method call, track the result type
if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute):
# Extract and process the method call
self._extract_method_call(item.context_expr, result)
self.processed_calls.add(id(item.context_expr))
# Track context manager scope for pydantic_ai run_stream calls
obj_name = self._get_name_from_node(item.context_expr.func.value)
method_name = item.context_expr.func.attr
if (obj_name and obj_name in self.variable_types and
'pydantic_ai' in str(self.variable_types[obj_name]) and
method_name == 'run_stream'):
# Calculate the scope of this async with block
start_line = getattr(node, 'lineno', 0)
end_line = getattr(node, 'end_lineno', start_line + 50) # fallback estimate
# For run_stream, the return type is specifically StreamedRunResult
# This is the actual return type, not a generic placeholder
self.context_manager_vars[var_name] = (start_line, end_line, "pydantic_ai.StreamedRunResult")
def _handle_with(self, node: ast.With, result: AnalysisResult):
"""Handle regular with statements and track context manager variables"""
for item in node.items:
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
var_name = item.optional_vars.id
# If the context manager is a method call, track the result type
if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute):
# Extract and process the method call
self._extract_method_call(item.context_expr, result)
self.processed_calls.add(id(item.context_expr))
# Track basic type information
self.variable_types[var_name] = "context_manager_result"
def _resolve_full_name(self, name: str) -> Optional[str]:
"""Resolve a name to its full module.name using import map"""
# Check if it's a direct import mapping
if name in self.import_map:
return self.import_map[name]
# Check if it's a dotted name with first part in import map
parts = name.split('.')
if len(parts) > 1 and parts[0] in self.import_map:
base_module = self.import_map[parts[0]]
return f"{base_module}.{'.'.join(parts[1:])}"
return None
def analyze_ai_script(script_path: str) -> AnalysisResult:
"""Convenience function to analyze a single AI-generated script"""
analyzer = AIScriptAnalyzer()
return analyzer.analyze_script(script_path)
if __name__ == "__main__":
# Example usage
import sys
if len(sys.argv) != 2:
print("Usage: python ai_script_analyzer.py <script_path>")
sys.exit(1)
script_path = sys.argv[1]
result = analyze_ai_script(script_path)
print(f"Analysis Results for: {result.file_path}")
print(f"Imports: {len(result.imports)}")
print(f"Class Instantiations: {len(result.class_instantiations)}")
print(f"Method Calls: {len(result.method_calls)}")
print(f"Function Calls: {len(result.function_calls)}")
print(f"Attribute Accesses: {len(result.attribute_accesses)}")
if result.errors:
print(f"Errors: {result.errors}")

View File

@@ -0,0 +1,523 @@
"""
Hallucination Reporter
Generates comprehensive reports about AI coding assistant hallucinations
detected in Python scripts. Supports multiple output formats.
"""
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Any, Optional
from knowledge_graph_validator import (
ScriptValidationResult, ValidationStatus, ValidationResult
)
logger = logging.getLogger(__name__)
class HallucinationReporter:
"""Generates reports about detected hallucinations"""
def __init__(self):
self.report_timestamp = datetime.now(timezone.utc)
def generate_comprehensive_report(self, validation_result: ScriptValidationResult) -> Dict[str, Any]:
"""Generate a comprehensive report in JSON format"""
# Categorize validations by status (knowledge graph items only)
valid_items = []
invalid_items = []
uncertain_items = []
not_found_items = []
# Process imports (only knowledge graph ones)
for val in validation_result.import_validations:
if not val.validation.details.get('in_knowledge_graph', False):
continue # Skip external libraries
item = {
'type': 'IMPORT',
'name': val.import_info.module,
'line': val.import_info.line_number,
'status': val.validation.status.value,
'confidence': val.validation.confidence,
'message': val.validation.message,
'details': {
'is_from_import': val.import_info.is_from_import,
'alias': val.import_info.alias,
'available_classes': val.available_classes,
'available_functions': val.available_functions
}
}
self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)
# Process classes (only knowledge graph ones)
for val in validation_result.class_validations:
class_name = val.class_instantiation.full_class_name or val.class_instantiation.class_name
if not self._is_from_knowledge_graph(class_name, validation_result):
continue # Skip external classes
item = {
'type': 'CLASS_INSTANTIATION',
'name': val.class_instantiation.class_name,
'full_name': val.class_instantiation.full_class_name,
'variable': val.class_instantiation.variable_name,
'line': val.class_instantiation.line_number,
'status': val.validation.status.value,
'confidence': val.validation.confidence,
'message': val.validation.message,
'details': {
'args_provided': val.class_instantiation.args,
'kwargs_provided': list(val.class_instantiation.kwargs.keys()),
'constructor_params': val.constructor_params,
'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None
}
}
self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)
# Track reported items to avoid duplicates
reported_items = set()
# Process methods (only knowledge graph ones)
for val in validation_result.method_validations:
if not (val.method_call.object_type and self._is_from_knowledge_graph(val.method_call.object_type, validation_result)):
continue # Skip external methods
# Create unique key to avoid duplicates
key = (val.method_call.line_number, val.method_call.method_name, val.method_call.object_type)
if key not in reported_items:
reported_items.add(key)
item = {
'type': 'METHOD_CALL',
'name': val.method_call.method_name,
'object': val.method_call.object_name,
'object_type': val.method_call.object_type,
'line': val.method_call.line_number,
'status': val.validation.status.value,
'confidence': val.validation.confidence,
'message': val.validation.message,
'details': {
'args_provided': val.method_call.args,
'kwargs_provided': list(val.method_call.kwargs.keys()),
'expected_params': val.expected_params,
'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None,
'suggestions': val.validation.suggestions
}
}
self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)
# Process attributes (only knowledge graph ones) - but skip if already reported as method
for val in validation_result.attribute_validations:
if not (val.attribute_access.object_type and self._is_from_knowledge_graph(val.attribute_access.object_type, validation_result)):
continue # Skip external attributes
# Create unique key - if this was already reported as a method, skip it
key = (val.attribute_access.line_number, val.attribute_access.attribute_name, val.attribute_access.object_type)
if key not in reported_items:
reported_items.add(key)
item = {
'type': 'ATTRIBUTE_ACCESS',
'name': val.attribute_access.attribute_name,
'object': val.attribute_access.object_name,
'object_type': val.attribute_access.object_type,
'line': val.attribute_access.line_number,
'status': val.validation.status.value,
'confidence': val.validation.confidence,
'message': val.validation.message,
'details': {
'expected_type': val.expected_type
}
}
self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)
# Process functions (only knowledge graph ones)
for val in validation_result.function_validations:
if not (val.function_call.full_name and self._is_from_knowledge_graph(val.function_call.full_name, validation_result)):
continue # Skip external functions
item = {
'type': 'FUNCTION_CALL',
'name': val.function_call.function_name,
'full_name': val.function_call.full_name,
'line': val.function_call.line_number,
'status': val.validation.status.value,
'confidence': val.validation.confidence,
'message': val.validation.message,
'details': {
'args_provided': val.function_call.args,
'kwargs_provided': list(val.function_call.kwargs.keys()),
'expected_params': val.expected_params,
'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None
}
}
self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)
# Create library summary
library_summary = self._create_library_summary(validation_result)
# Generate report
report = {
'analysis_metadata': {
'script_path': validation_result.script_path,
'analysis_timestamp': self.report_timestamp.isoformat(),
'total_imports': len(validation_result.import_validations),
'total_classes': len(validation_result.class_validations),
'total_methods': len(validation_result.method_validations),
'total_attributes': len(validation_result.attribute_validations),
'total_functions': len(validation_result.function_validations)
},
'validation_summary': {
'overall_confidence': validation_result.overall_confidence,
'total_validations': len(valid_items) + len(invalid_items) + len(uncertain_items) + len(not_found_items),
'valid_count': len(valid_items),
'invalid_count': len(invalid_items),
'uncertain_count': len(uncertain_items),
'not_found_count': len(not_found_items),
'hallucination_rate': len(invalid_items + not_found_items) / max(1, len(valid_items) + len(invalid_items) + len(not_found_items))
},
'libraries_analyzed': library_summary,
'validation_details': {
'valid_items': valid_items,
'invalid_items': invalid_items,
'uncertain_items': uncertain_items,
'not_found_items': not_found_items
},
'hallucinations_detected': validation_result.hallucinations_detected,
'recommendations': self._generate_recommendations(validation_result)
}
return report
def _is_from_knowledge_graph(self, item_name: str, validation_result) -> bool:
"""Check if an item is from a knowledge graph module"""
if not item_name:
return False
# Get knowledge graph modules from import validations
kg_modules = set()
for val in validation_result.import_validations:
if val.validation.details.get('in_knowledge_graph', False):
kg_modules.add(val.import_info.module)
if '.' in val.import_info.module:
kg_modules.add(val.import_info.module.split('.')[0])
# Check if the item belongs to any knowledge graph module
if '.' in item_name:
base_module = item_name.split('.')[0]
return base_module in kg_modules
return any(item_name in module or module.endswith(item_name) for module in kg_modules)
def _serialize_validation_result(self, validation_result) -> Dict[str, Any]:
"""Convert ValidationResult to JSON-serializable dictionary"""
if validation_result is None:
return None
return {
'status': validation_result.status.value,
'confidence': validation_result.confidence,
'message': validation_result.message,
'details': validation_result.details,
'suggestions': validation_result.suggestions
}
def _categorize_item(self, item: Dict[str, Any], status: ValidationStatus,
valid_items: List, invalid_items: List, uncertain_items: List, not_found_items: List):
"""Categorize validation item by status"""
if status == ValidationStatus.VALID:
valid_items.append(item)
elif status == ValidationStatus.INVALID:
invalid_items.append(item)
elif status == ValidationStatus.UNCERTAIN:
uncertain_items.append(item)
elif status == ValidationStatus.NOT_FOUND:
not_found_items.append(item)
def _create_library_summary(self, validation_result: ScriptValidationResult) -> List[Dict[str, Any]]:
"""Create summary of libraries analyzed"""
library_stats = {}
# Aggregate stats by library/module
for val in validation_result.import_validations:
module = val.import_info.module
if module not in library_stats:
library_stats[module] = {
'module_name': module,
'import_status': val.validation.status.value,
'import_confidence': val.validation.confidence,
'classes_used': [],
'methods_called': [],
'attributes_accessed': [],
'functions_called': []
}
# Add class usage
for val in validation_result.class_validations:
class_name = val.class_instantiation.class_name
full_name = val.class_instantiation.full_class_name
# Try to match to library
if full_name:
parts = full_name.split('.')
if len(parts) > 1:
module = '.'.join(parts[:-1])
if module in library_stats:
library_stats[module]['classes_used'].append({
'class_name': class_name,
'status': val.validation.status.value,
'confidence': val.validation.confidence
})
# Add method usage
for val in validation_result.method_validations:
method_name = val.method_call.method_name
object_type = val.method_call.object_type
if object_type:
parts = object_type.split('.')
if len(parts) > 1:
module = '.'.join(parts[:-1])
if module in library_stats:
library_stats[module]['methods_called'].append({
'method_name': method_name,
'class_name': parts[-1],
'status': val.validation.status.value,
'confidence': val.validation.confidence
})
# Add attribute usage
for val in validation_result.attribute_validations:
attr_name = val.attribute_access.attribute_name
object_type = val.attribute_access.object_type
if object_type:
parts = object_type.split('.')
if len(parts) > 1:
module = '.'.join(parts[:-1])
if module in library_stats:
library_stats[module]['attributes_accessed'].append({
'attribute_name': attr_name,
'class_name': parts[-1],
'status': val.validation.status.value,
'confidence': val.validation.confidence
})
# Add function usage
for val in validation_result.function_validations:
func_name = val.function_call.function_name
full_name = val.function_call.full_name
if full_name:
parts = full_name.split('.')
if len(parts) > 1:
module = '.'.join(parts[:-1])
if module in library_stats:
library_stats[module]['functions_called'].append({
'function_name': func_name,
'status': val.validation.status.value,
'confidence': val.validation.confidence
})
return list(library_stats.values())
def _generate_recommendations(self, validation_result: ScriptValidationResult) -> List[str]:
"""Generate recommendations based on validation results"""
recommendations = []
# Only count actual hallucinations (from knowledge graph libraries)
kg_hallucinations = [h for h in validation_result.hallucinations_detected]
if kg_hallucinations:
method_issues = [h for h in kg_hallucinations if h['type'] == 'METHOD_NOT_FOUND']
attr_issues = [h for h in kg_hallucinations if h['type'] == 'ATTRIBUTE_NOT_FOUND']
param_issues = [h for h in kg_hallucinations if h['type'] == 'INVALID_PARAMETERS']
if method_issues:
recommendations.append(
f"Found {len(method_issues)} non-existent methods in knowledge graph libraries. "
"Consider checking the official documentation for correct method names."
)
if attr_issues:
recommendations.append(
f"Found {len(attr_issues)} non-existent attributes in knowledge graph libraries. "
"Verify attribute names against the class documentation."
)
if param_issues:
recommendations.append(
f"Found {len(param_issues)} parameter mismatches in knowledge graph libraries. "
"Check function signatures for correct parameter names and types."
)
else:
recommendations.append(
"No hallucinations detected in knowledge graph libraries. "
"External library usage appears to be working as expected."
)
if validation_result.overall_confidence < 0.7:
recommendations.append(
"Overall confidence is moderate. Most validations were for external libraries not in the knowledge graph."
)
return recommendations
def save_json_report(self, report: Dict[str, Any], output_path: str):
"""Save report as JSON file"""
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False)
logger.info(f"JSON report saved to: {output_path}")
def save_markdown_report(self, report: Dict[str, Any], output_path: str):
"""Save report as Markdown file"""
md_content = self._generate_markdown_content(report)
with open(output_path, 'w', encoding='utf-8') as f:
f.write(md_content)
logger.info(f"Markdown report saved to: {output_path}")
def _generate_markdown_content(self, report: Dict[str, Any]) -> str:
"""Generate Markdown content from report"""
md = []
# Header
md.append("# AI Hallucination Detection Report")
md.append("")
md.append(f"**Script:** `{report['analysis_metadata']['script_path']}`")
md.append(f"**Analysis Date:** {report['analysis_metadata']['analysis_timestamp']}")
md.append(f"**Overall Confidence:** {report['validation_summary']['overall_confidence']:.2%}")
md.append("")
# Summary
summary = report['validation_summary']
md.append("## Summary")
md.append("")
md.append(f"- **Total Validations:** {summary['total_validations']}")
md.append(f"- **Valid:** {summary['valid_count']} ({summary['valid_count']/summary['total_validations']:.1%})")
md.append(f"- **Invalid:** {summary['invalid_count']} ({summary['invalid_count']/summary['total_validations']:.1%})")
md.append(f"- **Not Found:** {summary['not_found_count']} ({summary['not_found_count']/summary['total_validations']:.1%})")
md.append(f"- **Uncertain:** {summary['uncertain_count']} ({summary['uncertain_count']/summary['total_validations']:.1%})")
md.append(f"- **Hallucination Rate:** {summary['hallucination_rate']:.1%}")
md.append("")
# Hallucinations
if report['hallucinations_detected']:
md.append("## 🚨 Hallucinations Detected")
md.append("")
for i, hallucination in enumerate(report['hallucinations_detected'], 1):
md.append(f"### {i}. {hallucination['type'].replace('_', ' ').title()}")
md.append(f"**Location:** {hallucination['location']}")
md.append(f"**Description:** {hallucination['description']}")
if hallucination.get('suggestion'):
md.append(f"**Suggestion:** {hallucination['suggestion']}")
md.append("")
# Libraries
if report['libraries_analyzed']:
md.append("## 📚 Libraries Analyzed")
md.append("")
for lib in report['libraries_analyzed']:
md.append(f"### {lib['module_name']}")
md.append(f"**Import Status:** {lib['import_status']}")
md.append(f"**Import Confidence:** {lib['import_confidence']:.2%}")
if lib['classes_used']:
md.append("**Classes Used:**")
for cls in lib['classes_used']:
status_emoji = "" if cls['status'] == 'VALID' else ""
md.append(f" - {status_emoji} `{cls['class_name']}` ({cls['confidence']:.1%})")
if lib['methods_called']:
md.append("**Methods Called:**")
for method in lib['methods_called']:
status_emoji = "" if method['status'] == 'VALID' else ""
md.append(f" - {status_emoji} `{method['class_name']}.{method['method_name']}()` ({method['confidence']:.1%})")
if lib['attributes_accessed']:
md.append("**Attributes Accessed:**")
for attr in lib['attributes_accessed']:
status_emoji = "" if attr['status'] == 'VALID' else ""
md.append(f" - {status_emoji} `{attr['class_name']}.{attr['attribute_name']}` ({attr['confidence']:.1%})")
if lib['functions_called']:
md.append("**Functions Called:**")
for func in lib['functions_called']:
status_emoji = "" if func['status'] == 'VALID' else ""
md.append(f" - {status_emoji} `{func['function_name']}()` ({func['confidence']:.1%})")
md.append("")
# Recommendations
if report['recommendations']:
md.append("## 💡 Recommendations")
md.append("")
for rec in report['recommendations']:
md.append(f"- {rec}")
md.append("")
# Detailed Results
md.append("## 📋 Detailed Validation Results")
md.append("")
# Invalid items
invalid_items = report['validation_details']['invalid_items']
if invalid_items:
md.append("### ❌ Invalid Items")
md.append("")
for item in invalid_items:
md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}")
md.append("")
# Not found items
not_found_items = report['validation_details']['not_found_items']
if not_found_items:
md.append("### 🔍 Not Found Items")
md.append("")
for item in not_found_items:
md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}")
md.append("")
# Valid items (sample)
valid_items = report['validation_details']['valid_items']
if valid_items:
md.append("### ✅ Valid Items (Sample)")
md.append("")
for item in valid_items[:10]: # Show first 10
md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}")
if len(valid_items) > 10:
md.append(f"- ... and {len(valid_items) - 10} more valid items")
md.append("")
return "\n".join(md)
def print_summary(self, report: Dict[str, Any]):
"""Print a concise summary to console"""
print("\n" + "="*80)
print("🤖 AI HALLUCINATION DETECTION REPORT")
print("="*80)
print(f"Script: {report['analysis_metadata']['script_path']}")
print(f"Overall Confidence: {report['validation_summary']['overall_confidence']:.1%}")
summary = report['validation_summary']
print(f"\nValidation Results:")
print(f" ✅ Valid: {summary['valid_count']}")
print(f" ❌ Invalid: {summary['invalid_count']}")
print(f" 🔍 Not Found: {summary['not_found_count']}")
print(f" ❓ Uncertain: {summary['uncertain_count']}")
print(f" 📊 Hallucination Rate: {summary['hallucination_rate']:.1%}")
if report['hallucinations_detected']:
print(f"\n🚨 {len(report['hallucinations_detected'])} Hallucinations Detected:")
for hall in report['hallucinations_detected'][:5]: # Show first 5
print(f" - {hall['type'].replace('_', ' ').title()} at {hall['location']}")
print(f" {hall['description']}")
if report['recommendations']:
print(f"\n💡 Recommendations:")
for rec in report['recommendations'][:3]: # Show first 3
print(f" - {rec}")
print("="*80)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,858 @@
"""
Direct Neo4j GitHub Code Repository Extractor
Creates nodes and relationships directly in Neo4j without Graphiti:
- File nodes
- Class nodes
- Method nodes
- Function nodes
- Import relationships
Bypasses all LLM processing for maximum speed.
"""
import asyncio
import logging
import os
import subprocess
import shutil
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional, Dict, Any, Set
import ast
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
class Neo4jCodeAnalyzer:
"""Analyzes code for direct Neo4j insertion"""
def __init__(self):
# External modules to ignore
self.external_modules = {
# Python standard library
'os', 'sys', 'json', 'logging', 'datetime', 'pathlib', 'typing', 'collections',
'asyncio', 'subprocess', 'ast', 're', 'string', 'urllib', 'http', 'email',
'time', 'uuid', 'hashlib', 'base64', 'itertools', 'functools', 'operator',
'contextlib', 'copy', 'pickle', 'tempfile', 'shutil', 'glob', 'fnmatch',
'io', 'codecs', 'locale', 'platform', 'socket', 'ssl', 'threading', 'queue',
'multiprocessing', 'concurrent', 'warnings', 'traceback', 'inspect',
'importlib', 'pkgutil', 'types', 'weakref', 'gc', 'dataclasses', 'enum',
'abc', 'numbers', 'decimal', 'fractions', 'math', 'cmath', 'random', 'statistics',
# Common third-party libraries
'requests', 'urllib3', 'httpx', 'aiohttp', 'flask', 'django', 'fastapi',
'pydantic', 'sqlalchemy', 'alembic', 'psycopg2', 'pymongo', 'redis',
'celery', 'pytest', 'unittest', 'mock', 'faker', 'factory', 'hypothesis',
'numpy', 'pandas', 'matplotlib', 'seaborn', 'scipy', 'sklearn', 'torch',
'tensorflow', 'keras', 'opencv', 'pillow', 'boto3', 'botocore', 'azure',
'google', 'openai', 'anthropic', 'langchain', 'transformers', 'huggingface_hub',
'click', 'typer', 'rich', 'colorama', 'tqdm', 'python-dotenv', 'pyyaml',
'toml', 'configargparse', 'marshmallow', 'attrs', 'dataclasses-json',
'jsonschema', 'cerberus', 'voluptuous', 'schema', 'jinja2', 'mako',
'cryptography', 'bcrypt', 'passlib', 'jwt', 'authlib', 'oauthlib'
}
def analyze_python_file(self, file_path: Path, repo_root: Path, project_modules: Set[str]) -> Dict[str, Any]:
"""Extract structure for direct Neo4j insertion"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
tree = ast.parse(content)
relative_path = str(file_path.relative_to(repo_root))
module_name = self._get_importable_module_name(file_path, repo_root, relative_path)
# Extract structure
classes = []
functions = []
imports = []
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
# Extract class with its methods and attributes
methods = []
attributes = []
for item in node.body:
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
if not item.name.startswith('_'): # Public methods only
# Extract comprehensive parameter info
params = self._extract_function_parameters(item)
# Get return type annotation
return_type = self._get_name(item.returns) if item.returns else 'Any'
# Create detailed parameter list for Neo4j storage
params_detailed = []
for p in params:
param_str = f"{p['name']}:{p['type']}"
if p['optional'] and p['default'] is not None:
param_str += f"={p['default']}"
elif p['optional']:
param_str += "=None"
if p['kind'] != 'positional':
param_str = f"[{p['kind']}] {param_str}"
params_detailed.append(param_str)
methods.append({
'name': item.name,
'params': params, # Full parameter objects
'params_detailed': params_detailed, # Detailed string format
'return_type': return_type,
'args': [arg.arg for arg in item.args.args if arg.arg != 'self'] # Keep for backwards compatibility
})
elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
# Type annotated attributes
if not item.target.id.startswith('_'):
attributes.append({
'name': item.target.id,
'type': self._get_name(item.annotation) if item.annotation else 'Any'
})
classes.append({
'name': node.name,
'full_name': f"{module_name}.{node.name}",
'methods': methods,
'attributes': attributes
})
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
# Only top-level functions
if not any(node in cls_node.body for cls_node in ast.walk(tree) if isinstance(cls_node, ast.ClassDef)):
if not node.name.startswith('_'):
# Extract comprehensive parameter info
params = self._extract_function_parameters(node)
# Get return type annotation
return_type = self._get_name(node.returns) if node.returns else 'Any'
# Create detailed parameter list for Neo4j storage
params_detailed = []
for p in params:
param_str = f"{p['name']}:{p['type']}"
if p['optional'] and p['default'] is not None:
param_str += f"={p['default']}"
elif p['optional']:
param_str += "=None"
if p['kind'] != 'positional':
param_str = f"[{p['kind']}] {param_str}"
params_detailed.append(param_str)
# Simple format for backwards compatibility
params_list = [f"{p['name']}:{p['type']}" for p in params]
functions.append({
'name': node.name,
'full_name': f"{module_name}.{node.name}",
'params': params, # Full parameter objects
'params_detailed': params_detailed, # Detailed string format
'params_list': params_list, # Simple string format for backwards compatibility
'return_type': return_type,
'args': [arg.arg for arg in node.args.args] # Keep for backwards compatibility
})
elif isinstance(node, (ast.Import, ast.ImportFrom)):
# Track internal imports only
if isinstance(node, ast.Import):
for alias in node.names:
if self._is_likely_internal(alias.name, project_modules):
imports.append(alias.name)
elif isinstance(node, ast.ImportFrom) and node.module:
if (node.module.startswith('.') or self._is_likely_internal(node.module, project_modules)):
imports.append(node.module)
return {
'module_name': module_name,
'file_path': relative_path,
'classes': classes,
'functions': functions,
'imports': list(set(imports)), # Remove duplicates
'line_count': len(content.splitlines())
}
except Exception as e:
logger.warning(f"Could not analyze {file_path}: {e}")
return None
def _is_likely_internal(self, import_name: str, project_modules: Set[str]) -> bool:
"""Check if an import is likely internal to the project"""
if not import_name:
return False
# Relative imports are definitely internal
if import_name.startswith('.'):
return True
# Check if it's a known external module
base_module = import_name.split('.')[0]
if base_module in self.external_modules:
return False
# Check if it matches any project module
for project_module in project_modules:
if import_name.startswith(project_module):
return True
# If it's not obviously external, consider it internal
if (not any(ext in base_module.lower() for ext in ['test', 'mock', 'fake']) and
not base_module.startswith('_') and
len(base_module) > 2):
return True
return False
def _get_importable_module_name(self, file_path: Path, repo_root: Path, relative_path: str) -> str:
"""Determine the actual importable module name for a Python file"""
# Start with the default: convert file path to module path
default_module = relative_path.replace('/', '.').replace('\\', '.').replace('.py', '')
# Common patterns to detect the actual package root
path_parts = Path(relative_path).parts
# Look for common package indicators
package_roots = []
# Check each directory level for __init__.py to find package boundaries
current_path = repo_root
for i, part in enumerate(path_parts[:-1]): # Exclude the .py file itself
current_path = current_path / part
if (current_path / '__init__.py').exists():
# This is a package directory, mark it as a potential root
package_roots.append(i)
if package_roots:
# Use the first (outermost) package as the root
package_start = package_roots[0]
module_parts = path_parts[package_start:]
module_name = '.'.join(module_parts).replace('.py', '')
return module_name
# Fallback: look for common Python project structures
# Skip common non-package directories
skip_dirs = {'src', 'lib', 'source', 'python', 'pkg', 'packages'}
# Find the first directory that's not in skip_dirs
filtered_parts = []
for part in path_parts:
if part.lower() not in skip_dirs or filtered_parts: # Once we start including, include everything
filtered_parts.append(part)
if filtered_parts:
module_name = '.'.join(filtered_parts).replace('.py', '')
return module_name
# Final fallback: use the default
return default_module
def _extract_function_parameters(self, func_node):
"""Comprehensive parameter extraction from function definition"""
params = []
# Regular positional arguments
for i, arg in enumerate(func_node.args.args):
if arg.arg == 'self':
continue
param_info = {
'name': arg.arg,
'type': self._get_name(arg.annotation) if arg.annotation else 'Any',
'kind': 'positional',
'optional': False,
'default': None
}
# Check if this argument has a default value
defaults_start = len(func_node.args.args) - len(func_node.args.defaults)
if i >= defaults_start:
default_idx = i - defaults_start
if default_idx < len(func_node.args.defaults):
param_info['optional'] = True
param_info['default'] = self._get_default_value(func_node.args.defaults[default_idx])
params.append(param_info)
# *args parameter
if func_node.args.vararg:
params.append({
'name': f"*{func_node.args.vararg.arg}",
'type': self._get_name(func_node.args.vararg.annotation) if func_node.args.vararg.annotation else 'Any',
'kind': 'var_positional',
'optional': True,
'default': None
})
# Keyword-only arguments (after *)
for i, arg in enumerate(func_node.args.kwonlyargs):
param_info = {
'name': arg.arg,
'type': self._get_name(arg.annotation) if arg.annotation else 'Any',
'kind': 'keyword_only',
'optional': True, # All kwonly args are optional unless explicitly required
'default': None
}
# Check for default value
if i < len(func_node.args.kw_defaults) and func_node.args.kw_defaults[i] is not None:
param_info['default'] = self._get_default_value(func_node.args.kw_defaults[i])
else:
param_info['optional'] = False # No default = required kwonly arg
params.append(param_info)
# **kwargs parameter
if func_node.args.kwarg:
params.append({
'name': f"**{func_node.args.kwarg.arg}",
'type': self._get_name(func_node.args.kwarg.annotation) if func_node.args.kwarg.annotation else 'Dict[str, Any]',
'kind': 'var_keyword',
'optional': True,
'default': None
})
return params
def _get_default_value(self, default_node):
"""Extract default value from AST node"""
try:
if isinstance(default_node, ast.Constant):
return repr(default_node.value)
elif isinstance(default_node, ast.Name):
return default_node.id
elif isinstance(default_node, ast.Attribute):
return self._get_name(default_node)
elif isinstance(default_node, ast.List):
return "[]"
elif isinstance(default_node, ast.Dict):
return "{}"
else:
return "..."
except Exception:
return "..."
def _get_name(self, node):
"""Extract name from AST node, handling complex types safely"""
if node is None:
return "Any"
try:
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
if hasattr(node, 'value'):
return f"{self._get_name(node.value)}.{node.attr}"
else:
return node.attr
elif isinstance(node, ast.Subscript):
# Handle List[Type], Dict[K,V], etc.
base = self._get_name(node.value)
if hasattr(node, 'slice'):
if isinstance(node.slice, ast.Name):
return f"{base}[{node.slice.id}]"
elif isinstance(node.slice, ast.Tuple):
elts = [self._get_name(elt) for elt in node.slice.elts]
return f"{base}[{', '.join(elts)}]"
elif isinstance(node.slice, ast.Constant):
return f"{base}[{repr(node.slice.value)}]"
elif isinstance(node.slice, ast.Attribute):
return f"{base}[{self._get_name(node.slice)}]"
elif isinstance(node.slice, ast.Subscript):
return f"{base}[{self._get_name(node.slice)}]"
else:
# Try to get the name of the slice, fallback to Any if it fails
try:
slice_name = self._get_name(node.slice)
return f"{base}[{slice_name}]"
except:
return f"{base}[Any]"
return base
elif isinstance(node, ast.Constant):
return str(node.value)
elif isinstance(node, ast.Str): # Python < 3.8
return f'"{node.s}"'
elif isinstance(node, ast.Tuple):
elts = [self._get_name(elt) for elt in node.elts]
return f"({', '.join(elts)})"
elif isinstance(node, ast.List):
elts = [self._get_name(elt) for elt in node.elts]
return f"[{', '.join(elts)}]"
else:
# Fallback for complex types - return a simple string representation
return "Any"
except Exception:
# If anything goes wrong, return a safe default
return "Any"
class DirectNeo4jExtractor:
"""Creates nodes and relationships directly in Neo4j"""
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
self.neo4j_uri = neo4j_uri
self.neo4j_user = neo4j_user
self.neo4j_password = neo4j_password
self.driver = None
self.analyzer = Neo4jCodeAnalyzer()
async def initialize(self):
"""Initialize Neo4j connection"""
logger.info("Initializing Neo4j connection...")
self.driver = AsyncGraphDatabase.driver(
self.neo4j_uri,
auth=(self.neo4j_user, self.neo4j_password)
)
# Clear existing data
# logger.info("Clearing existing data...")
# async with self.driver.session() as session:
# await session.run("MATCH (n) DETACH DELETE n")
# Create constraints and indexes
logger.info("Creating constraints and indexes...")
async with self.driver.session() as session:
# Create constraints - using MERGE-friendly approach
await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:File) REQUIRE f.path IS UNIQUE")
await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Class) REQUIRE c.full_name IS UNIQUE")
# Remove unique constraints for methods/attributes since they can be duplicated across classes
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (m:Method) REQUIRE m.full_name IS UNIQUE")
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:Function) REQUIRE f.full_name IS UNIQUE")
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (a:Attribute) REQUIRE a.full_name IS UNIQUE")
# Create indexes for performance
await session.run("CREATE INDEX IF NOT EXISTS FOR (f:File) ON (f.name)")
await session.run("CREATE INDEX IF NOT EXISTS FOR (c:Class) ON (c.name)")
await session.run("CREATE INDEX IF NOT EXISTS FOR (m:Method) ON (m.name)")
logger.info("Neo4j initialized successfully")
async def clear_repository_data(self, repo_name: str):
"""Clear all data for a specific repository"""
logger.info(f"Clearing existing data for repository: {repo_name}")
async with self.driver.session() as session:
# Delete in specific order to avoid constraint issues
# 1. Delete methods and attributes (they depend on classes)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
DETACH DELETE m
""", repo_name=repo_name)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
DETACH DELETE a
""", repo_name=repo_name)
# 2. Delete functions (they depend on files)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
DETACH DELETE func
""", repo_name=repo_name)
# 3. Delete classes (they depend on files)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
DETACH DELETE c
""", repo_name=repo_name)
# 4. Delete files (they depend on repository)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)
DETACH DELETE f
""", repo_name=repo_name)
# 5. Finally delete the repository
await session.run("""
MATCH (r:Repository {name: $repo_name})
DETACH DELETE r
""", repo_name=repo_name)
logger.info(f"Cleared data for repository: {repo_name}")
async def close(self):
"""Close Neo4j connection"""
if self.driver:
await self.driver.close()
def clone_repo(self, repo_url: str, target_dir: str) -> str:
"""Clone repository with shallow clone"""
logger.info(f"Cloning repository to: {target_dir}")
if os.path.exists(target_dir):
logger.info(f"Removing existing directory: {target_dir}")
try:
def handle_remove_readonly(func, path, exc):
try:
if os.path.exists(path):
os.chmod(path, 0o777)
func(path)
except PermissionError:
logger.warning(f"Could not remove {path} - file in use, skipping")
pass
shutil.rmtree(target_dir, onerror=handle_remove_readonly)
except Exception as e:
logger.warning(f"Could not fully remove {target_dir}: {e}. Proceeding anyway...")
logger.info(f"Running git clone from {repo_url}")
subprocess.run(['git', 'clone', '--depth', '1', repo_url, target_dir], check=True)
logger.info("Repository cloned successfully")
return target_dir
def get_python_files(self, repo_path: str) -> List[Path]:
"""Get Python files, focusing on main source directories"""
python_files = []
exclude_dirs = {
'tests', 'test', '__pycache__', '.git', 'venv', 'env',
'node_modules', 'build', 'dist', '.pytest_cache', 'docs',
'examples', 'example', 'demo', 'benchmark'
}
for root, dirs, files in os.walk(repo_path):
dirs[:] = [d for d in dirs if d not in exclude_dirs and not d.startswith('.')]
for file in files:
if file.endswith('.py') and not file.startswith('test_'):
file_path = Path(root) / file
if (file_path.stat().st_size < 500_000 and
file not in ['setup.py', 'conftest.py']):
python_files.append(file_path)
return python_files
async def analyze_repository(self, repo_url: str, temp_dir: str = None):
"""Analyze repository and create nodes/relationships in Neo4j"""
repo_name = repo_url.split('/')[-1].replace('.git', '')
logger.info(f"Analyzing repository: {repo_name}")
# Clear existing data for this repository before re-processing
await self.clear_repository_data(repo_name)
# Set default temp_dir to repos folder at script level
if temp_dir is None:
script_dir = Path(__file__).parent
temp_dir = str(script_dir / "repos" / repo_name)
# Clone and analyze
repo_path = Path(self.clone_repo(repo_url, temp_dir))
try:
logger.info("Getting Python files...")
python_files = self.get_python_files(str(repo_path))
logger.info(f"Found {len(python_files)} Python files to analyze")
# First pass: identify project modules
logger.info("Identifying project modules...")
project_modules = set()
for file_path in python_files:
relative_path = str(file_path.relative_to(repo_path))
module_parts = relative_path.replace('/', '.').replace('.py', '').split('.')
if len(module_parts) > 0 and not module_parts[0].startswith('.'):
project_modules.add(module_parts[0])
logger.info(f"Identified project modules: {sorted(project_modules)}")
# Second pass: analyze files and collect data
logger.info("Analyzing Python files...")
modules_data = []
for i, file_path in enumerate(python_files):
if i % 20 == 0:
logger.info(f"Analyzing file {i+1}/{len(python_files)}: {file_path.name}")
analysis = self.analyzer.analyze_python_file(file_path, repo_path, project_modules)
if analysis:
modules_data.append(analysis)
logger.info(f"Found {len(modules_data)} files with content")
# Create nodes and relationships in Neo4j
logger.info("Creating nodes and relationships in Neo4j...")
await self._create_graph(repo_name, modules_data)
# Print summary
total_classes = sum(len(mod['classes']) for mod in modules_data)
total_methods = sum(len(cls['methods']) for mod in modules_data for cls in mod['classes'])
total_functions = sum(len(mod['functions']) for mod in modules_data)
total_imports = sum(len(mod['imports']) for mod in modules_data)
print(f"\\n=== Direct Neo4j Repository Analysis for {repo_name} ===")
print(f"Files processed: {len(modules_data)}")
print(f"Classes created: {total_classes}")
print(f"Methods created: {total_methods}")
print(f"Functions created: {total_functions}")
print(f"Import relationships: {total_imports}")
logger.info(f"Successfully created Neo4j graph for {repo_name}")
finally:
if os.path.exists(temp_dir):
logger.info(f"Cleaning up temporary directory: {temp_dir}")
try:
def handle_remove_readonly(func, path, exc):
try:
if os.path.exists(path):
os.chmod(path, 0o777)
func(path)
except PermissionError:
logger.warning(f"Could not remove {path} - file in use, skipping")
pass
shutil.rmtree(temp_dir, onerror=handle_remove_readonly)
logger.info("Cleanup completed")
except Exception as e:
logger.warning(f"Cleanup failed: {e}. Directory may remain at {temp_dir}")
# Don't fail the whole process due to cleanup issues
async def _create_graph(self, repo_name: str, modules_data: List[Dict]):
"""Create all nodes and relationships in Neo4j"""
async with self.driver.session() as session:
# Create Repository node
await session.run(
"CREATE (r:Repository {name: $repo_name, created_at: datetime()})",
repo_name=repo_name
)
nodes_created = 0
relationships_created = 0
for i, mod in enumerate(modules_data):
# 1. Create File node
await session.run("""
CREATE (f:File {
name: $name,
path: $path,
module_name: $module_name,
line_count: $line_count,
created_at: datetime()
})
""",
name=mod['file_path'].split('/')[-1],
path=mod['file_path'],
module_name=mod['module_name'],
line_count=mod['line_count']
)
nodes_created += 1
# 2. Connect File to Repository
await session.run("""
MATCH (r:Repository {name: $repo_name})
MATCH (f:File {path: $file_path})
CREATE (r)-[:CONTAINS]->(f)
""", repo_name=repo_name, file_path=mod['file_path'])
relationships_created += 1
# 3. Create Class nodes and relationships
for cls in mod['classes']:
# Create Class node using MERGE to avoid duplicates
await session.run("""
MERGE (c:Class {full_name: $full_name})
ON CREATE SET c.name = $name, c.created_at = datetime()
""", name=cls['name'], full_name=cls['full_name'])
nodes_created += 1
# Connect File to Class
await session.run("""
MATCH (f:File {path: $file_path})
MATCH (c:Class {full_name: $class_full_name})
MERGE (f)-[:DEFINES]->(c)
""", file_path=mod['file_path'], class_full_name=cls['full_name'])
relationships_created += 1
# 4. Create Method nodes - use MERGE to avoid duplicates
for method in cls['methods']:
method_full_name = f"{cls['full_name']}.{method['name']}"
# Create method with unique ID to avoid conflicts
method_id = f"{cls['full_name']}::{method['name']}"
await session.run("""
MERGE (m:Method {method_id: $method_id})
ON CREATE SET m.name = $name,
m.full_name = $full_name,
m.args = $args,
m.params_list = $params_list,
m.params_detailed = $params_detailed,
m.return_type = $return_type,
m.created_at = datetime()
""",
name=method['name'],
full_name=method_full_name,
method_id=method_id,
args=method['args'],
params_list=[f"{p['name']}:{p['type']}" for p in method['params']], # Simple format
params_detailed=method.get('params_detailed', []), # Detailed format
return_type=method['return_type']
)
nodes_created += 1
# Connect Class to Method
await session.run("""
MATCH (c:Class {full_name: $class_full_name})
MATCH (m:Method {method_id: $method_id})
MERGE (c)-[:HAS_METHOD]->(m)
""",
class_full_name=cls['full_name'],
method_id=method_id
)
relationships_created += 1
# 5. Create Attribute nodes - use MERGE to avoid duplicates
for attr in cls['attributes']:
attr_full_name = f"{cls['full_name']}.{attr['name']}"
# Create attribute with unique ID to avoid conflicts
attr_id = f"{cls['full_name']}::{attr['name']}"
await session.run("""
MERGE (a:Attribute {attr_id: $attr_id})
ON CREATE SET a.name = $name,
a.full_name = $full_name,
a.type = $type,
a.created_at = datetime()
""",
name=attr['name'],
full_name=attr_full_name,
attr_id=attr_id,
type=attr['type']
)
nodes_created += 1
# Connect Class to Attribute
await session.run("""
MATCH (c:Class {full_name: $class_full_name})
MATCH (a:Attribute {attr_id: $attr_id})
MERGE (c)-[:HAS_ATTRIBUTE]->(a)
""",
class_full_name=cls['full_name'],
attr_id=attr_id
)
relationships_created += 1
# 6. Create Function nodes (top-level) - use MERGE to avoid duplicates
for func in mod['functions']:
func_id = f"{mod['file_path']}::{func['name']}"
await session.run("""
MERGE (f:Function {func_id: $func_id})
ON CREATE SET f.name = $name,
f.full_name = $full_name,
f.args = $args,
f.params_list = $params_list,
f.params_detailed = $params_detailed,
f.return_type = $return_type,
f.created_at = datetime()
""",
name=func['name'],
full_name=func['full_name'],
func_id=func_id,
args=func['args'],
params_list=func.get('params_list', []), # Simple format for backwards compatibility
params_detailed=func.get('params_detailed', []), # Detailed format
return_type=func['return_type']
)
nodes_created += 1
# Connect File to Function
await session.run("""
MATCH (file:File {path: $file_path})
MATCH (func:Function {func_id: $func_id})
MERGE (file)-[:DEFINES]->(func)
""", file_path=mod['file_path'], func_id=func_id)
relationships_created += 1
# 7. Create Import relationships
for import_name in mod['imports']:
# Try to find the target file
await session.run("""
MATCH (source:File {path: $source_path})
OPTIONAL MATCH (target:File)
WHERE target.module_name = $import_name OR target.module_name STARTS WITH $import_name
WITH source, target
WHERE target IS NOT NULL
MERGE (source)-[:IMPORTS]->(target)
""", source_path=mod['file_path'], import_name=import_name)
relationships_created += 1
if (i + 1) % 10 == 0:
logger.info(f"Processed {i + 1}/{len(modules_data)} files...")
logger.info(f"Created {nodes_created} nodes and {relationships_created} relationships")
async def search_graph(self, query_type: str, **kwargs):
"""Search the Neo4j graph directly"""
async with self.driver.session() as session:
if query_type == "files_importing":
target = kwargs.get('target')
result = await session.run("""
MATCH (source:File)-[:IMPORTS]->(target:File)
WHERE target.module_name CONTAINS $target
RETURN source.path as file, target.module_name as imports
""", target=target)
return [{"file": record["file"], "imports": record["imports"]} async for record in result]
elif query_type == "classes_in_file":
file_path = kwargs.get('file_path')
result = await session.run("""
MATCH (f:File {path: $file_path})-[:DEFINES]->(c:Class)
RETURN c.name as class_name, c.full_name as full_name
""", file_path=file_path)
return [{"class_name": record["class_name"], "full_name": record["full_name"]} async for record in result]
elif query_type == "methods_of_class":
class_name = kwargs.get('class_name')
result = await session.run("""
MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
WHERE c.name CONTAINS $class_name OR c.full_name CONTAINS $class_name
RETURN m.name as method_name, m.args as args
""", class_name=class_name)
return [{"method_name": record["method_name"], "args": record["args"]} async for record in result]
async def main():
"""Example usage"""
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
extractor = DirectNeo4jExtractor(neo4j_uri, neo4j_user, neo4j_password)
try:
await extractor.initialize()
# Analyze repository - direct Neo4j, no LLM processing!
# repo_url = "https://github.com/pydantic/pydantic-ai.git"
repo_url = "https://github.com/getzep/graphiti.git"
await extractor.analyze_repository(repo_url)
# Direct graph queries
print("\\n=== Direct Neo4j Queries ===")
# Which files import from models?
results = await extractor.search_graph("files_importing", target="models")
print(f"\\nFiles importing from 'models': {len(results)}")
for result in results[:3]:
print(f"- {result['file']} imports {result['imports']}")
# What classes are in a specific file?
results = await extractor.search_graph("classes_in_file", file_path="pydantic_ai/models/openai.py")
print(f"\\nClasses in openai.py: {len(results)}")
for result in results:
print(f"- {result['class_name']}")
# What methods does OpenAIModel have?
results = await extractor.search_graph("methods_of_class", class_name="OpenAIModel")
print(f"\\nMethods of OpenAIModel: {len(results)}")
for result in results[:5]:
print(f"- {result['method_name']}({', '.join(result['args'])})")
finally:
await extractor.close()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,400 @@
#!/usr/bin/env python3
"""
Knowledge Graph Query Tool
Interactive script to explore what's actually stored in your Neo4j knowledge graph.
Useful for debugging hallucination detection and understanding graph contents.
"""
import asyncio
import os
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from typing import List, Dict, Any
import argparse
class KnowledgeGraphQuerier:
"""Interactive tool to query the knowledge graph"""
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
self.neo4j_uri = neo4j_uri
self.neo4j_user = neo4j_user
self.neo4j_password = neo4j_password
self.driver = None
async def initialize(self):
"""Initialize Neo4j connection"""
self.driver = AsyncGraphDatabase.driver(
self.neo4j_uri,
auth=(self.neo4j_user, self.neo4j_password)
)
print("🔗 Connected to Neo4j knowledge graph")
async def close(self):
"""Close Neo4j connection"""
if self.driver:
await self.driver.close()
async def list_repositories(self):
"""List all repositories in the knowledge graph"""
print("\n📚 Repositories in Knowledge Graph:")
print("=" * 50)
async with self.driver.session() as session:
query = "MATCH (r:Repository) RETURN r.name as name ORDER BY r.name"
result = await session.run(query)
repos = []
async for record in result:
repos.append(record['name'])
if repos:
for i, repo in enumerate(repos, 1):
print(f"{i}. {repo}")
else:
print("No repositories found in knowledge graph.")
return repos
async def explore_repository(self, repo_name: str):
"""Get overview of a specific repository"""
print(f"\n🔍 Exploring Repository: {repo_name}")
print("=" * 60)
async with self.driver.session() as session:
# Get file count
files_query = """
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)
RETURN count(f) as file_count
"""
result = await session.run(files_query, repo_name=repo_name)
file_count = (await result.single())['file_count']
# Get class count
classes_query = """
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
RETURN count(DISTINCT c) as class_count
"""
result = await session.run(classes_query, repo_name=repo_name)
class_count = (await result.single())['class_count']
# Get function count
functions_query = """
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
RETURN count(DISTINCT func) as function_count
"""
result = await session.run(functions_query, repo_name=repo_name)
function_count = (await result.single())['function_count']
print(f"📄 Files: {file_count}")
print(f"🏗️ Classes: {class_count}")
print(f"⚙️ Functions: {function_count}")
async def list_classes(self, repo_name: str = None, limit: int = 20):
"""List classes in the knowledge graph"""
title = f"Classes in {repo_name}" if repo_name else "All Classes"
print(f"\n🏗️ {title} (limit {limit}):")
print("=" * 50)
async with self.driver.session() as session:
if repo_name:
query = """
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
RETURN c.name as name, c.full_name as full_name
ORDER BY c.name
LIMIT $limit
"""
result = await session.run(query, repo_name=repo_name, limit=limit)
else:
query = """
MATCH (c:Class)
RETURN c.name as name, c.full_name as full_name
ORDER BY c.name
LIMIT $limit
"""
result = await session.run(query, limit=limit)
classes = []
async for record in result:
classes.append({
'name': record['name'],
'full_name': record['full_name']
})
if classes:
for i, cls in enumerate(classes, 1):
print(f"{i:2d}. {cls['name']} ({cls['full_name']})")
else:
print("No classes found.")
return classes
async def explore_class(self, class_name: str):
"""Get detailed information about a specific class"""
print(f"\n🔍 Exploring Class: {class_name}")
print("=" * 60)
async with self.driver.session() as session:
# Find the class
class_query = """
MATCH (c:Class)
WHERE c.name = $class_name OR c.full_name = $class_name
RETURN c.name as name, c.full_name as full_name
LIMIT 1
"""
result = await session.run(class_query, class_name=class_name)
class_record = await result.single()
if not class_record:
print(f"❌ Class '{class_name}' not found in knowledge graph.")
return None
actual_name = class_record['name']
full_name = class_record['full_name']
print(f"📋 Name: {actual_name}")
print(f"📋 Full Name: {full_name}")
# Get methods
methods_query = """
MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
WHERE c.name = $class_name OR c.full_name = $class_name
RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed, m.return_type as return_type
ORDER BY m.name
"""
result = await session.run(methods_query, class_name=class_name)
methods = []
async for record in result:
methods.append({
'name': record['name'],
'params_list': record['params_list'] or [],
'params_detailed': record['params_detailed'] or [],
'return_type': record['return_type'] or 'Any'
})
if methods:
print(f"\n⚙️ Methods ({len(methods)}):")
for i, method in enumerate(methods, 1):
# Use detailed params if available, fall back to simple params
params_to_show = method['params_detailed'] or method['params_list']
params = ', '.join(params_to_show) if params_to_show else ''
print(f"{i:2d}. {method['name']}({params}) -> {method['return_type']}")
else:
print("\n⚙️ No methods found.")
# Get attributes
attributes_query = """
MATCH (c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
WHERE c.name = $class_name OR c.full_name = $class_name
RETURN a.name as name, a.type as type
ORDER BY a.name
"""
result = await session.run(attributes_query, class_name=class_name)
attributes = []
async for record in result:
attributes.append({
'name': record['name'],
'type': record['type'] or 'Any'
})
if attributes:
print(f"\n📋 Attributes ({len(attributes)}):")
for i, attr in enumerate(attributes, 1):
print(f"{i:2d}. {attr['name']}: {attr['type']}")
else:
print("\n📋 No attributes found.")
return {'methods': methods, 'attributes': attributes}
async def search_method(self, method_name: str, class_name: str = None):
"""Search for methods by name"""
title = f"Method '{method_name}'"
if class_name:
title += f" in class '{class_name}'"
print(f"\n🔍 Searching for {title}:")
print("=" * 60)
async with self.driver.session() as session:
if class_name:
query = """
MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
WHERE (c.name = $class_name OR c.full_name = $class_name)
AND m.name = $method_name
RETURN c.name as class_name, c.full_name as class_full_name,
m.name as method_name, m.params_list as params_list,
m.return_type as return_type, m.args as args
"""
result = await session.run(query, class_name=class_name, method_name=method_name)
else:
query = """
MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
WHERE m.name = $method_name
RETURN c.name as class_name, c.full_name as class_full_name,
m.name as method_name, m.params_list as params_list,
m.return_type as return_type, m.args as args
ORDER BY c.name
"""
result = await session.run(query, method_name=method_name)
methods = []
async for record in result:
methods.append({
'class_name': record['class_name'],
'class_full_name': record['class_full_name'],
'method_name': record['method_name'],
'params_list': record['params_list'] or [],
'return_type': record['return_type'] or 'Any',
'args': record['args'] or []
})
if methods:
for i, method in enumerate(methods, 1):
params = ', '.join(method['params_list']) if method['params_list'] else ''
print(f"{i}. {method['class_full_name']}.{method['method_name']}({params}) -> {method['return_type']}")
if method['args']:
print(f" Legacy args: {method['args']}")
else:
print(f"❌ Method '{method_name}' not found.")
return methods
async def run_custom_query(self, query: str):
"""Run a custom Cypher query"""
print(f"\n🔍 Running Custom Query:")
print("=" * 60)
print(f"Query: {query}")
print("-" * 60)
async with self.driver.session() as session:
try:
result = await session.run(query)
records = []
async for record in result:
records.append(dict(record))
if records:
for i, record in enumerate(records, 1):
print(f"{i:2d}. {record}")
if i >= 20: # Limit output
print(f"... and {len(records) - 20} more records")
break
else:
print("No results found.")
return records
except Exception as e:
print(f"❌ Query error: {str(e)}")
return None
async def interactive_mode(querier: KnowledgeGraphQuerier):
"""Interactive exploration mode"""
print("\n🚀 Welcome to Knowledge Graph Explorer!")
print("Available commands:")
print(" repos - List all repositories")
print(" explore <repo> - Explore a specific repository")
print(" classes [repo] - List classes (optionally in specific repo)")
print(" class <name> - Explore a specific class")
print(" method <name> [class] - Search for method")
print(" query <cypher> - Run custom Cypher query")
print(" quit - Exit")
print()
while True:
try:
command = input("🔍 > ").strip()
if not command:
continue
elif command == "quit":
break
elif command == "repos":
await querier.list_repositories()
elif command.startswith("explore "):
repo_name = command[8:].strip()
await querier.explore_repository(repo_name)
elif command == "classes":
await querier.list_classes()
elif command.startswith("classes "):
repo_name = command[8:].strip()
await querier.list_classes(repo_name)
elif command.startswith("class "):
class_name = command[6:].strip()
await querier.explore_class(class_name)
elif command.startswith("method "):
parts = command[7:].strip().split()
if len(parts) >= 2:
await querier.search_method(parts[0], parts[1])
else:
await querier.search_method(parts[0])
elif command.startswith("query "):
query = command[6:].strip()
await querier.run_custom_query(query)
else:
print("❌ Unknown command. Type 'quit' to exit.")
except KeyboardInterrupt:
print("\n👋 Goodbye!")
break
except Exception as e:
print(f"❌ Error: {str(e)}")
async def main():
"""Main function with CLI argument support"""
parser = argparse.ArgumentParser(description="Query the knowledge graph")
parser.add_argument('--repos', action='store_true', help='List repositories')
parser.add_argument('--classes', metavar='REPO', nargs='?', const='', help='List classes')
parser.add_argument('--explore', metavar='REPO', help='Explore repository')
parser.add_argument('--class', dest='class_name', metavar='NAME', help='Explore class')
parser.add_argument('--method', nargs='+', metavar=('NAME', 'CLASS'), help='Search method')
parser.add_argument('--query', metavar='CYPHER', help='Run custom query')
parser.add_argument('--interactive', action='store_true', help='Interactive mode')
args = parser.parse_args()
# Load environment
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
querier = KnowledgeGraphQuerier(neo4j_uri, neo4j_user, neo4j_password)
try:
await querier.initialize()
# Execute commands based on arguments
if args.repos:
await querier.list_repositories()
elif args.classes is not None:
await querier.list_classes(args.classes if args.classes else None)
elif args.explore:
await querier.explore_repository(args.explore)
elif args.class_name:
await querier.explore_class(args.class_name)
elif args.method:
if len(args.method) >= 2:
await querier.search_method(args.method[0], args.method[1])
else:
await querier.search_method(args.method[0])
elif args.query:
await querier.run_custom_query(args.query)
elif args.interactive or len(sys.argv) == 1:
await interactive_mode(querier)
else:
parser.print_help()
finally:
await querier.close()
if __name__ == "__main__":
import sys
asyncio.run(main())

View File

@@ -0,0 +1,160 @@
from __future__ import annotations
from typing import Dict, List, Optional
from dataclasses import dataclass
from pydantic import BaseModel, Field
from dotenv import load_dotenv
from rich.markdown import Markdown
from rich.console import Console
from rich.live import Live
import asyncio
import os
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai import Agent, RunContext
from graphiti_core import Graphiti
load_dotenv()
# ========== Define dependencies ==========
@dataclass
class GraphitiDependencies:
"""Dependencies for the Graphiti agent."""
graphiti_client: Graphiti
# ========== Helper function to get model configuration ==========
def get_model():
"""Configure and return the LLM model to use."""
model_choice = os.getenv('MODEL_CHOICE', 'gpt-4.1-mini')
api_key = os.getenv('OPENAI_API_KEY', 'no-api-key-provided')
return OpenAIModel(model_choice, provider=OpenAIProvider(api_key=api_key))
# ========== Create the Graphiti agent ==========
graphiti_agent = Agent(
get_model(),
system_prompt="""You are a helpful assistant with access to a knowledge graph filled with temporal data about LLMs.
When the user asks you a question, use your search tool to query the knowledge graph and then answer honestly.
Be willing to admit when you didn't find the information necessary to answer the question.""",
deps_type=GraphitiDependencies
)
# ========== Define a result model for Graphiti search ==========
class GraphitiSearchResult(BaseModel):
"""Model representing a search result from Graphiti."""
uuid: str = Field(description="The unique identifier for this fact")
fact: str = Field(description="The factual statement retrieved from the knowledge graph")
valid_at: Optional[str] = Field(None, description="When this fact became valid (if known)")
invalid_at: Optional[str] = Field(None, description="When this fact became invalid (if known)")
source_node_uuid: Optional[str] = Field(None, description="UUID of the source node")
# ========== Graphiti search tool ==========
@graphiti_agent.tool
async def search_graphiti(ctx: RunContext[GraphitiDependencies], query: str) -> List[GraphitiSearchResult]:
"""Search the Graphiti knowledge graph with the given query.
Args:
ctx: The run context containing dependencies
query: The search query to find information in the knowledge graph
Returns:
A list of search results containing facts that match the query
"""
# Access the Graphiti client from dependencies
graphiti = ctx.deps.graphiti_client
try:
# Perform the search
results = await graphiti.search(query)
# Format the results
formatted_results = []
for result in results:
formatted_result = GraphitiSearchResult(
uuid=result.uuid,
fact=result.fact,
source_node_uuid=result.source_node_uuid if hasattr(result, 'source_node_uuid') else None
)
# Add temporal information if available
if hasattr(result, 'valid_at') and result.valid_at:
formatted_result.valid_at = str(result.valid_at)
if hasattr(result, 'invalid_at') and result.invalid_at:
formatted_result.invalid_at = str(result.invalid_at)
formatted_results.append(formatted_result)
return formatted_results
except Exception as e:
# Log the error but don't close the connection since it's managed by the dependency
print(f"Error searching Graphiti: {str(e)}")
raise
# ========== Main execution function ==========
async def main():
"""Run the Graphiti agent with user queries."""
print("Graphiti Agent - Powered by Pydantic AI, Graphiti, and Neo4j")
print("Enter 'exit' to quit the program.")
# Neo4j connection parameters
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
# Initialize Graphiti with Neo4j connection
graphiti_client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
# Initialize the graph database with graphiti's indices if needed
try:
await graphiti_client.build_indices_and_constraints()
print("Graphiti indices built successfully.")
except Exception as e:
print(f"Note: {str(e)}")
print("Continuing with existing indices...")
console = Console()
messages = []
try:
while True:
# Get user input
user_input = input("\n[You] ")
# Check if user wants to exit
if user_input.lower() in ['exit', 'quit', 'bye', 'goodbye']:
print("Goodbye!")
break
try:
# Process the user input and output the response
print("\n[Assistant]")
with Live('', console=console, vertical_overflow='visible') as live:
# Pass the Graphiti client as a dependency
deps = GraphitiDependencies(graphiti_client=graphiti_client)
async with graphiti_agent.run_a_stream(
user_input, message_history=messages, deps=deps
) as result:
curr_message = ""
async for message in result.stream_text(delta=True):
curr_message += message
live.update(Markdown(curr_message))
# Add the new messages to the chat history
messages.extend(result.all_messages())
except Exception as e:
print(f"\n[Error] An error occurred: {str(e)}")
finally:
# Close the Graphiti connection when done
await graphiti_client.close()
print("\nGraphiti connection closed.")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\nProgram terminated by user.")
except Exception as e:
print(f"\nUnexpected error: {str(e)}")
raise