feat: add LLMX configuration with Crawl4AI RAG MCP server
- Add config.toml with MCP servers configuration
- Add compose.yaml for PostgreSQL+pgvector, PostgREST, and Crawl4AI RAG
- Include forked mcp-crawl4ai-rag with BGE 1024-dim embedding support
- Custom schema (crawled_pages_1024.sql) for BGE embeddings

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,335 @@
"""
AI Hallucination Detector

Main orchestrator for detecting AI coding assistant hallucinations in Python scripts.
Combines AST analysis, knowledge graph validation, and comprehensive reporting.
"""

import asyncio
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Optional, List

from dotenv import load_dotenv

from ai_script_analyzer import AIScriptAnalyzer, analyze_ai_script
from knowledge_graph_validator import KnowledgeGraphValidator
from hallucination_reporter import HallucinationReporter

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


class AIHallucinationDetector:
    """Main detector class that orchestrates the entire process"""

    def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
        self.validator = KnowledgeGraphValidator(neo4j_uri, neo4j_user, neo4j_password)
        self.reporter = HallucinationReporter()
        self.analyzer = AIScriptAnalyzer()

    async def initialize(self):
        """Initialize connections and components"""
        await self.validator.initialize()
        logger.info("AI Hallucination Detector initialized successfully")

    async def close(self):
        """Close connections"""
        await self.validator.close()

    async def detect_hallucinations(self, script_path: str,
                                    output_dir: Optional[str] = None,
                                    save_json: bool = True,
                                    save_markdown: bool = True,
                                    print_summary: bool = True) -> dict:
        """
        Main detection function that analyzes a script and generates reports

        Args:
            script_path: Path to the AI-generated Python script
            output_dir: Directory to save reports (defaults to script directory)
            save_json: Whether to save JSON report
            save_markdown: Whether to save Markdown report
            print_summary: Whether to print summary to console

        Returns:
            Complete validation report as dictionary
        """
        logger.info(f"Starting hallucination detection for: {script_path}")

        # Validate input
        if not os.path.exists(script_path):
            raise FileNotFoundError(f"Script not found: {script_path}")

        if not script_path.endswith('.py'):
            raise ValueError("Only Python (.py) files are supported")

        # Set output directory
        if output_dir is None:
            output_dir = str(Path(script_path).parent)

        os.makedirs(output_dir, exist_ok=True)

        try:
            # Step 1: Analyze the script using AST
            logger.info("Step 1: Analyzing script structure...")
            analysis_result = self.analyzer.analyze_script(script_path)

            if analysis_result.errors:
                logger.warning(f"Analysis warnings: {analysis_result.errors}")

            logger.info(f"Found: {len(analysis_result.imports)} imports, "
                        f"{len(analysis_result.class_instantiations)} class instantiations, "
                        f"{len(analysis_result.method_calls)} method calls, "
                        f"{len(analysis_result.function_calls)} function calls, "
                        f"{len(analysis_result.attribute_accesses)} attribute accesses")

            # Step 2: Validate against knowledge graph
            logger.info("Step 2: Validating against knowledge graph...")
            validation_result = await self.validator.validate_script(analysis_result)

            logger.info(f"Validation complete. Overall confidence: {validation_result.overall_confidence:.1%}")

            # Step 3: Generate comprehensive report
            logger.info("Step 3: Generating reports...")
            report = self.reporter.generate_comprehensive_report(validation_result)

            # Step 4: Save reports
            script_name = Path(script_path).stem

            if save_json:
                json_path = os.path.join(output_dir, f"{script_name}_hallucination_report.json")
                self.reporter.save_json_report(report, json_path)

            if save_markdown:
                md_path = os.path.join(output_dir, f"{script_name}_hallucination_report.md")
                self.reporter.save_markdown_report(report, md_path)

            # Step 5: Print summary
            if print_summary:
                self.reporter.print_summary(report)

            logger.info("Hallucination detection completed successfully")
            return report

        except Exception as e:
            logger.error(f"Error during hallucination detection: {str(e)}")
            raise
    async def batch_detect(self, script_paths: List[str],
                           output_dir: Optional[str] = None) -> List[dict]:
        """
        Detect hallucinations in multiple scripts

        Args:
            script_paths: List of paths to Python scripts
            output_dir: Directory to save all reports

        Returns:
            List of validation reports
        """
        logger.info(f"Starting batch detection for {len(script_paths)} scripts")

        results = []
        for i, script_path in enumerate(script_paths, 1):
            logger.info(f"Processing script {i}/{len(script_paths)}: {script_path}")

            try:
                result = await self.detect_hallucinations(
                    script_path=script_path,
                    output_dir=output_dir,
                    print_summary=False  # Don't print individual summaries in batch mode
                )
                results.append(result)

            except Exception as e:
                logger.error(f"Failed to process {script_path}: {str(e)}")
                # Continue with other scripts
                continue

        # Print batch summary
        self._print_batch_summary(results)

        return results

    def _print_batch_summary(self, results: List[dict]):
        """Print summary of batch processing results"""
        if not results:
            print("No scripts were successfully processed.")
            return

        print("\n" + "="*80)
        print("🚀 BATCH HALLUCINATION DETECTION SUMMARY")
        print("="*80)

        total_scripts = len(results)
        total_validations = sum(r['validation_summary']['total_validations'] for r in results)
        total_valid = sum(r['validation_summary']['valid_count'] for r in results)
        total_invalid = sum(r['validation_summary']['invalid_count'] for r in results)
        total_not_found = sum(r['validation_summary']['not_found_count'] for r in results)
        total_hallucinations = sum(len(r['hallucinations_detected']) for r in results)

        avg_confidence = sum(r['validation_summary']['overall_confidence'] for r in results) / total_scripts

        print(f"Scripts Processed: {total_scripts}")
        print(f"Total Validations: {total_validations}")
        print(f"Average Confidence: {avg_confidence:.1%}")
        print(f"Total Hallucinations: {total_hallucinations}")

        print(f"\nAggregated Results:")
        print(f"  ✅ Valid: {total_valid} ({total_valid/total_validations:.1%})")
        print(f"  ❌ Invalid: {total_invalid} ({total_invalid/total_validations:.1%})")
        print(f"  🔍 Not Found: {total_not_found} ({total_not_found/total_validations:.1%})")

        # Show worst performing scripts
        print(f"\n🚨 Scripts with Most Hallucinations:")
        sorted_results = sorted(results, key=lambda x: len(x['hallucinations_detected']), reverse=True)
        for result in sorted_results[:5]:
            script_name = Path(result['analysis_metadata']['script_path']).name
            hall_count = len(result['hallucinations_detected'])
            confidence = result['validation_summary']['overall_confidence']
            print(f"  - {script_name}: {hall_count} hallucinations ({confidence:.1%} confidence)")

        print("="*80)
async def main():
    """Command-line interface for the AI Hallucination Detector"""
    parser = argparse.ArgumentParser(
        description="Detect AI coding assistant hallucinations in Python scripts",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Analyze single script
  python ai_hallucination_detector.py script.py

  # Analyze multiple scripts
  python ai_hallucination_detector.py script1.py script2.py script3.py

  # Specify output directory
  python ai_hallucination_detector.py script.py --output-dir reports/

  # Skip markdown report
  python ai_hallucination_detector.py script.py --no-markdown
        """
    )

    parser.add_argument(
        'scripts',
        nargs='+',
        help='Python script(s) to analyze for hallucinations'
    )

    parser.add_argument(
        '--output-dir',
        help='Directory to save reports (defaults to script directory)'
    )

    parser.add_argument(
        '--no-json',
        action='store_true',
        help='Skip JSON report generation'
    )

    parser.add_argument(
        '--no-markdown',
        action='store_true',
        help='Skip Markdown report generation'
    )

    parser.add_argument(
        '--no-summary',
        action='store_true',
        help='Skip printing summary to console'
    )

    parser.add_argument(
        '--neo4j-uri',
        default=None,
        help='Neo4j URI (default: from environment NEO4J_URI)'
    )

    parser.add_argument(
        '--neo4j-user',
        default=None,
        help='Neo4j username (default: from environment NEO4J_USER)'
    )

    parser.add_argument(
        '--neo4j-password',
        default=None,
        help='Neo4j password (default: from environment NEO4J_PASSWORD)'
    )

    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Enable verbose logging'
    )
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.INFO)
        # Only enable debug for our modules, not neo4j
        logging.getLogger('neo4j').setLevel(logging.WARNING)
        logging.getLogger('neo4j.pool').setLevel(logging.WARNING)
        logging.getLogger('neo4j.io').setLevel(logging.WARNING)

    # Load environment variables
    load_dotenv()

    # Get Neo4j credentials
    neo4j_uri = args.neo4j_uri or os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
    neo4j_user = args.neo4j_user or os.environ.get('NEO4J_USER', 'neo4j')
    neo4j_password = args.neo4j_password or os.environ.get('NEO4J_PASSWORD', 'password')

    if not neo4j_password or neo4j_password == 'password':
        logger.error("Please set NEO4J_PASSWORD environment variable or use --neo4j-password")
        sys.exit(1)

    # Initialize detector
    detector = AIHallucinationDetector(neo4j_uri, neo4j_user, neo4j_password)

    try:
        await detector.initialize()

        # Process scripts
        if len(args.scripts) == 1:
            # Single script mode
            await detector.detect_hallucinations(
                script_path=args.scripts[0],
                output_dir=args.output_dir,
                save_json=not args.no_json,
                save_markdown=not args.no_markdown,
                print_summary=not args.no_summary
            )
        else:
            # Batch mode
            await detector.batch_detect(
                script_paths=args.scripts,
                output_dir=args.output_dir
            )

    except KeyboardInterrupt:
        logger.info("Detection interrupted by user")
        sys.exit(1)

    except Exception as e:
        logger.error(f"Detection failed: {str(e)}")
        sys.exit(1)

    finally:
        await detector.close()


if __name__ == "__main__":
    asyncio.run(main())
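For reference, a minimal programmatic usage sketch (not part of this commit): it assumes the detector module above is importable under the file name implied by its own CLI examples, that a Neo4j instance is reachable, and that the script path and password below are placeholders.

import asyncio
from ai_hallucination_detector import AIHallucinationDetector

async def run_example():
    # Placeholder connection details; in practice these come from the NEO4J_* environment variables
    detector = AIHallucinationDetector("bolt://localhost:7687", "neo4j", "example-password")
    await detector.initialize()
    try:
        # Analyzes the script, validates it against the knowledge graph, and writes JSON/Markdown reports
        report = await detector.detect_hallucinations("example_agent.py", output_dir="reports/")
        print(report["validation_summary"]["overall_confidence"])
    finally:
        await detector.close()

asyncio.run(run_example())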
servers/mcp-crawl4ai-rag/knowledge_graphs/ai_script_analyzer.py (new file, 532 lines added)
@@ -0,0 +1,532 @@
"""
AI Script Analyzer

Parses Python scripts generated by AI coding assistants using AST to extract:
- Import statements and their usage
- Class instantiations and method calls
- Function calls with parameters
- Attribute access patterns
- Variable type tracking
"""

import ast
import logging
from pathlib import Path
from typing import Dict, List, Set, Any, Optional, Tuple
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class ImportInfo:
    """Information about an import statement"""
    module: str
    name: str
    alias: Optional[str] = None
    is_from_import: bool = False
    line_number: int = 0


@dataclass
class MethodCall:
    """Information about a method call"""
    object_name: str
    method_name: str
    args: List[str]
    kwargs: Dict[str, str]
    line_number: int
    object_type: Optional[str] = None  # Inferred class type


@dataclass
class AttributeAccess:
    """Information about attribute access"""
    object_name: str
    attribute_name: str
    line_number: int
    object_type: Optional[str] = None  # Inferred class type


@dataclass
class FunctionCall:
    """Information about a function call"""
    function_name: str
    args: List[str]
    kwargs: Dict[str, str]
    line_number: int
    full_name: Optional[str] = None  # Module.function_name


@dataclass
class ClassInstantiation:
    """Information about class instantiation"""
    variable_name: str
    class_name: str
    args: List[str]
    kwargs: Dict[str, str]
    line_number: int
    full_class_name: Optional[str] = None  # Module.ClassName


@dataclass
class AnalysisResult:
    """Complete analysis results for a Python script"""
    file_path: str
    imports: List[ImportInfo] = field(default_factory=list)
    class_instantiations: List[ClassInstantiation] = field(default_factory=list)
    method_calls: List[MethodCall] = field(default_factory=list)
    attribute_accesses: List[AttributeAccess] = field(default_factory=list)
    function_calls: List[FunctionCall] = field(default_factory=list)
    variable_types: Dict[str, str] = field(default_factory=dict)  # variable_name -> class_type
    errors: List[str] = field(default_factory=list)
class AIScriptAnalyzer:
    """Analyzes AI-generated Python scripts for validation against knowledge graph"""

    def __init__(self):
        self.import_map: Dict[str, str] = {}  # alias -> actual_module_name
        self.variable_types: Dict[str, str] = {}  # variable_name -> class_type
        self.context_manager_vars: Dict[str, Tuple[int, int, str]] = {}  # var_name -> (start_line, end_line, type)

    def analyze_script(self, script_path: str) -> AnalysisResult:
        """Analyze a Python script and extract all relevant information"""
        try:
            with open(script_path, 'r', encoding='utf-8') as f:
                content = f.read()

            tree = ast.parse(content)
            result = AnalysisResult(file_path=script_path)

            # Reset state for new analysis
            self.import_map.clear()
            self.variable_types.clear()
            self.context_manager_vars.clear()

            # Track processed nodes to avoid duplicates
            self.processed_calls = set()
            self.method_call_attributes = set()

            # First pass: collect imports and build import map
            for node in ast.walk(tree):
                if isinstance(node, (ast.Import, ast.ImportFrom)):
                    self._extract_imports(node, result)

            # Second pass: analyze usage patterns
            for node in ast.walk(tree):
                self._analyze_node(node, result)

            # Set inferred types on method calls and attribute accesses
            self._infer_object_types(result)

            result.variable_types = self.variable_types.copy()

            return result

        except Exception as e:
            error_msg = f"Failed to analyze script {script_path}: {str(e)}"
            logger.error(error_msg)
            result = AnalysisResult(file_path=script_path)
            result.errors.append(error_msg)
            return result

    def _extract_imports(self, node: ast.AST, result: AnalysisResult):
        """Extract import information and build import mapping"""
        line_num = getattr(node, 'lineno', 0)

        if isinstance(node, ast.Import):
            for alias in node.names:
                import_name = alias.name
                alias_name = alias.asname or import_name

                result.imports.append(ImportInfo(
                    module=import_name,
                    name=import_name,
                    alias=alias.asname,
                    is_from_import=False,
                    line_number=line_num
                ))

                self.import_map[alias_name] = import_name

        elif isinstance(node, ast.ImportFrom):
            module = node.module or ""
            for alias in node.names:
                import_name = alias.name
                alias_name = alias.asname or import_name

                result.imports.append(ImportInfo(
                    module=module,
                    name=import_name,
                    alias=alias.asname,
                    is_from_import=True,
                    line_number=line_num
                ))

                # Map alias to full module.name
                if module:
                    full_name = f"{module}.{import_name}"
                    self.import_map[alias_name] = full_name
                else:
                    self.import_map[alias_name] = import_name
    def _analyze_node(self, node: ast.AST, result: AnalysisResult):
        """Analyze individual AST nodes for usage patterns"""
        line_num = getattr(node, 'lineno', 0)

        # Assignments (class instantiations and method call results)
        if isinstance(node, ast.Assign):
            if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
                if isinstance(node.value, ast.Call):
                    # Check if it's a class instantiation or method call
                    if isinstance(node.value.func, ast.Name):
                        # Direct function/class call
                        self._extract_class_instantiation(node, result)
                        # Mark this call as processed to avoid duplicate processing
                        self.processed_calls.add(id(node.value))
                    elif isinstance(node.value.func, ast.Attribute):
                        # Method call - track the variable assignment for type inference
                        var_name = node.targets[0].id
                        self._track_method_result_assignment(node.value, var_name)
                        # Still process the method call
                        self._extract_method_call(node.value, result)
                        self.processed_calls.add(id(node.value))

        # AsyncWith statements (context managers)
        elif isinstance(node, ast.AsyncWith):
            self._handle_async_with(node, result)
        elif isinstance(node, ast.With):
            self._handle_with(node, result)

        # Method calls and function calls
        elif isinstance(node, ast.Call):
            # Skip if this call was already processed as part of an assignment
            if id(node) in self.processed_calls:
                return

            if isinstance(node.func, ast.Attribute):
                self._extract_method_call(node, result)
                # Mark this attribute as used in method call to avoid duplicate processing
                self.method_call_attributes.add(id(node.func))
            elif isinstance(node.func, ast.Name):
                # Check if this is likely a class instantiation (based on imported classes)
                func_name = node.func.id
                full_name = self._resolve_full_name(func_name)

                # If this is a known imported class, treat as class instantiation
                if self._is_likely_class_instantiation(func_name, full_name):
                    self._extract_nested_class_instantiation(node, result)
                else:
                    self._extract_function_call(node, result)

        # Attribute access (not in call context)
        elif isinstance(node, ast.Attribute):
            # Skip if this attribute was already processed as part of a method call
            if id(node) in self.method_call_attributes:
                return
            self._extract_attribute_access(node, result)

    def _extract_class_instantiation(self, node: ast.Assign, result: AnalysisResult):
        """Extract class instantiation from assignment"""
        target = node.targets[0]
        call = node.value
        line_num = getattr(node, 'lineno', 0)

        if isinstance(target, ast.Name) and isinstance(call, ast.Call):
            var_name = target.id
            class_name = self._get_name_from_call(call.func)

            if class_name:
                args = [self._get_arg_representation(arg) for arg in call.args]
                kwargs = {
                    kw.arg: self._get_arg_representation(kw.value)
                    for kw in call.keywords if kw.arg
                }

                # Resolve full class name using import map
                full_class_name = self._resolve_full_name(class_name)

                instantiation = ClassInstantiation(
                    variable_name=var_name,
                    class_name=class_name,
                    args=args,
                    kwargs=kwargs,
                    line_number=line_num,
                    full_class_name=full_class_name
                )

                result.class_instantiations.append(instantiation)

                # Track variable type for later method call analysis
                self.variable_types[var_name] = full_class_name or class_name
    def _extract_method_call(self, node: ast.Call, result: AnalysisResult):
        """Extract method call information"""
        if isinstance(node.func, ast.Attribute):
            line_num = getattr(node, 'lineno', 0)

            # Get object and method names
            obj_name = self._get_name_from_node(node.func.value)
            method_name = node.func.attr

            if obj_name and method_name:
                args = [self._get_arg_representation(arg) for arg in node.args]
                kwargs = {
                    kw.arg: self._get_arg_representation(kw.value)
                    for kw in node.keywords if kw.arg
                }

                method_call = MethodCall(
                    object_name=obj_name,
                    method_name=method_name,
                    args=args,
                    kwargs=kwargs,
                    line_number=line_num,
                    object_type=self.variable_types.get(obj_name)
                )

                result.method_calls.append(method_call)

    def _extract_function_call(self, node: ast.Call, result: AnalysisResult):
        """Extract function call information"""
        if isinstance(node.func, ast.Name):
            line_num = getattr(node, 'lineno', 0)
            func_name = node.func.id

            args = [self._get_arg_representation(arg) for arg in node.args]
            kwargs = {
                kw.arg: self._get_arg_representation(kw.value)
                for kw in node.keywords if kw.arg
            }

            # Resolve full function name using import map
            full_func_name = self._resolve_full_name(func_name)

            function_call = FunctionCall(
                function_name=func_name,
                args=args,
                kwargs=kwargs,
                line_number=line_num,
                full_name=full_func_name
            )

            result.function_calls.append(function_call)

    def _extract_attribute_access(self, node: ast.Attribute, result: AnalysisResult):
        """Extract attribute access information"""
        line_num = getattr(node, 'lineno', 0)

        obj_name = self._get_name_from_node(node.value)
        attr_name = node.attr

        if obj_name and attr_name:
            attribute_access = AttributeAccess(
                object_name=obj_name,
                attribute_name=attr_name,
                line_number=line_num,
                object_type=self.variable_types.get(obj_name)
            )

            result.attribute_accesses.append(attribute_access)

    def _infer_object_types(self, result: AnalysisResult):
        """Update object types for method calls and attribute accesses"""
        for method_call in result.method_calls:
            if not method_call.object_type:
                # First check context manager variables
                obj_type = self._get_context_aware_type(method_call.object_name, method_call.line_number)
                if obj_type:
                    method_call.object_type = obj_type
                else:
                    method_call.object_type = self.variable_types.get(method_call.object_name)

        for attr_access in result.attribute_accesses:
            if not attr_access.object_type:
                # First check context manager variables
                obj_type = self._get_context_aware_type(attr_access.object_name, attr_access.line_number)
                if obj_type:
                    attr_access.object_type = obj_type
                else:
                    attr_access.object_type = self.variable_types.get(attr_access.object_name)
    def _get_context_aware_type(self, var_name: str, line_number: int) -> Optional[str]:
        """Get the type of a variable considering its context (e.g., async with scope)"""
        if var_name in self.context_manager_vars:
            start_line, end_line, var_type = self.context_manager_vars[var_name]
            if start_line <= line_number <= end_line:
                return var_type
        return None

    def _get_name_from_call(self, node: ast.AST) -> Optional[str]:
        """Get the name from a call node (for class instantiation)"""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            value_name = self._get_name_from_node(node.value)
            if value_name:
                return f"{value_name}.{node.attr}"
        return None

    def _get_name_from_node(self, node: ast.AST) -> Optional[str]:
        """Get string representation of a node (for object names)"""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            value_name = self._get_name_from_node(node.value)
            if value_name:
                return f"{value_name}.{node.attr}"
        return None

    def _get_arg_representation(self, node: ast.AST) -> str:
        """Get string representation of an argument"""
        if isinstance(node, ast.Constant):
            return repr(node.value)
        elif isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            return self._get_name_from_node(node) or "<?>"
        elif isinstance(node, ast.Call):
            func_name = self._get_name_from_call(node.func)
            return f"{func_name}(...)" if func_name else "call(...)"
        else:
            return f"<{type(node).__name__}>"

    def _is_likely_class_instantiation(self, func_name: str, full_name: Optional[str]) -> bool:
        """Determine if a function call is likely a class instantiation"""
        # Check if it's a known imported class (classes typically start with uppercase)
        if func_name and func_name[0].isupper():
            return True

        # Check if the full name suggests a class (contains known class patterns)
        if full_name:
            # Common class patterns in module names
            class_patterns = [
                'Model', 'Provider', 'Client', 'Agent', 'Manager', 'Handler',
                'Builder', 'Factory', 'Service', 'Controller', 'Processor'
            ]
            return any(pattern in full_name for pattern in class_patterns)

        return False

    def _extract_nested_class_instantiation(self, node: ast.Call, result: AnalysisResult):
        """Extract class instantiation that's not in direct assignment (e.g., as parameter)"""
        line_num = getattr(node, 'lineno', 0)

        if isinstance(node.func, ast.Name):
            class_name = node.func.id

            args = [self._get_arg_representation(arg) for arg in node.args]
            kwargs = {
                kw.arg: self._get_arg_representation(kw.value)
                for kw in node.keywords if kw.arg
            }

            # Resolve full class name using import map
            full_class_name = self._resolve_full_name(class_name)

            # Use a synthetic variable name since this isn't assigned to a variable
            var_name = f"<{class_name.lower()}_instance>"

            instantiation = ClassInstantiation(
                variable_name=var_name,
                class_name=class_name,
                args=args,
                kwargs=kwargs,
                line_number=line_num,
                full_class_name=full_class_name
            )

            result.class_instantiations.append(instantiation)
    def _track_method_result_assignment(self, call_node: ast.Call, var_name: str):
        """Track when a variable is assigned the result of a method call"""
        if isinstance(call_node.func, ast.Attribute):
            # For now, we'll use a generic type hint for method results
            # In a more sophisticated system, we could look up the return type
            self.variable_types[var_name] = "method_result"

    def _handle_async_with(self, node: ast.AsyncWith, result: AnalysisResult):
        """Handle async with statements and track context manager variables"""
        for item in node.items:
            if item.optional_vars and isinstance(item.optional_vars, ast.Name):
                var_name = item.optional_vars.id

                # If the context manager is a method call, track the result type
                if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute):
                    # Extract and process the method call
                    self._extract_method_call(item.context_expr, result)
                    self.processed_calls.add(id(item.context_expr))

                    # Track context manager scope for pydantic_ai run_stream calls
                    obj_name = self._get_name_from_node(item.context_expr.func.value)
                    method_name = item.context_expr.func.attr

                    if (obj_name and obj_name in self.variable_types and
                        'pydantic_ai' in str(self.variable_types[obj_name]) and
                        method_name == 'run_stream'):

                        # Calculate the scope of this async with block
                        start_line = getattr(node, 'lineno', 0)
                        end_line = getattr(node, 'end_lineno', start_line + 50)  # fallback estimate

                        # For run_stream, the return type is specifically StreamedRunResult
                        # This is the actual return type, not a generic placeholder
                        self.context_manager_vars[var_name] = (start_line, end_line, "pydantic_ai.StreamedRunResult")

    def _handle_with(self, node: ast.With, result: AnalysisResult):
        """Handle regular with statements and track context manager variables"""
        for item in node.items:
            if item.optional_vars and isinstance(item.optional_vars, ast.Name):
                var_name = item.optional_vars.id

                # If the context manager is a method call, track the result type
                if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute):
                    # Extract and process the method call
                    self._extract_method_call(item.context_expr, result)
                    self.processed_calls.add(id(item.context_expr))

                    # Track basic type information
                    self.variable_types[var_name] = "context_manager_result"

    def _resolve_full_name(self, name: str) -> Optional[str]:
        """Resolve a name to its full module.name using import map"""
        # Check if it's a direct import mapping
        if name in self.import_map:
            return self.import_map[name]

        # Check if it's a dotted name with first part in import map
        parts = name.split('.')
        if len(parts) > 1 and parts[0] in self.import_map:
            base_module = self.import_map[parts[0]]
            return f"{base_module}.{'.'.join(parts[1:])}"

        return None


def analyze_ai_script(script_path: str) -> AnalysisResult:
    """Convenience function to analyze a single AI-generated script"""
    analyzer = AIScriptAnalyzer()
    return analyzer.analyze_script(script_path)


if __name__ == "__main__":
    # Example usage
    import sys

    if len(sys.argv) != 2:
        print("Usage: python ai_script_analyzer.py <script_path>")
        sys.exit(1)

    script_path = sys.argv[1]
    result = analyze_ai_script(script_path)

    print(f"Analysis Results for: {result.file_path}")
    print(f"Imports: {len(result.imports)}")
    print(f"Class Instantiations: {len(result.class_instantiations)}")
    print(f"Method Calls: {len(result.method_calls)}")
    print(f"Function Calls: {len(result.function_calls)}")
    print(f"Attribute Accesses: {len(result.attribute_accesses)}")

    if result.errors:
        print(f"Errors: {result.errors}")
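As an illustration (not part of this commit), here is roughly what the analyzer above extracts from a small pydantic_ai snippet; the snippet contents and temporary file are made up for the example.

import tempfile
from ai_script_analyzer import analyze_ai_script

snippet = (
    "from pydantic_ai import Agent\n"
    "agent = Agent('openai:gpt-4o')\n"
    "result = agent.run_sync('hello')\n"
)

# Write the snippet to a temporary .py file so analyze_script can read it from disk
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(snippet)
    tmp_path = f.name

analysis = analyze_ai_script(tmp_path)
# Expected: one from-import, one ClassInstantiation (Agent resolved to pydantic_ai.Agent),
# and one MethodCall (agent.run_sync) whose object_type comes from the tracked variable types.
print([imp.module for imp in analysis.imports])
print([(c.class_name, c.full_class_name) for c in analysis.class_instantiations])
print([(m.object_name, m.method_name, m.object_type) for m in analysis.method_calls])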
@@ -0,0 +1,523 @@
"""
Hallucination Reporter

Generates comprehensive reports about AI coding assistant hallucinations
detected in Python scripts. Supports multiple output formats.
"""

import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Any, Optional

from knowledge_graph_validator import (
    ScriptValidationResult, ValidationStatus, ValidationResult
)

logger = logging.getLogger(__name__)
class HallucinationReporter:
    """Generates reports about detected hallucinations"""

    def __init__(self):
        self.report_timestamp = datetime.now(timezone.utc)

    def generate_comprehensive_report(self, validation_result: ScriptValidationResult) -> Dict[str, Any]:
        """Generate a comprehensive report in JSON format"""

        # Categorize validations by status (knowledge graph items only)
        valid_items = []
        invalid_items = []
        uncertain_items = []
        not_found_items = []

        # Process imports (only knowledge graph ones)
        for val in validation_result.import_validations:
            if not val.validation.details.get('in_knowledge_graph', False):
                continue  # Skip external libraries
            item = {
                'type': 'IMPORT',
                'name': val.import_info.module,
                'line': val.import_info.line_number,
                'status': val.validation.status.value,
                'confidence': val.validation.confidence,
                'message': val.validation.message,
                'details': {
                    'is_from_import': val.import_info.is_from_import,
                    'alias': val.import_info.alias,
                    'available_classes': val.available_classes,
                    'available_functions': val.available_functions
                }
            }
            self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)

        # Process classes (only knowledge graph ones)
        for val in validation_result.class_validations:
            class_name = val.class_instantiation.full_class_name or val.class_instantiation.class_name
            if not self._is_from_knowledge_graph(class_name, validation_result):
                continue  # Skip external classes
            item = {
                'type': 'CLASS_INSTANTIATION',
                'name': val.class_instantiation.class_name,
                'full_name': val.class_instantiation.full_class_name,
                'variable': val.class_instantiation.variable_name,
                'line': val.class_instantiation.line_number,
                'status': val.validation.status.value,
                'confidence': val.validation.confidence,
                'message': val.validation.message,
                'details': {
                    'args_provided': val.class_instantiation.args,
                    'kwargs_provided': list(val.class_instantiation.kwargs.keys()),
                    'constructor_params': val.constructor_params,
                    'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None
                }
            }
            self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)
        # Track reported items to avoid duplicates
        reported_items = set()

        # Process methods (only knowledge graph ones)
        for val in validation_result.method_validations:
            if not (val.method_call.object_type and self._is_from_knowledge_graph(val.method_call.object_type, validation_result)):
                continue  # Skip external methods

            # Create unique key to avoid duplicates
            key = (val.method_call.line_number, val.method_call.method_name, val.method_call.object_type)
            if key not in reported_items:
                reported_items.add(key)
                item = {
                    'type': 'METHOD_CALL',
                    'name': val.method_call.method_name,
                    'object': val.method_call.object_name,
                    'object_type': val.method_call.object_type,
                    'line': val.method_call.line_number,
                    'status': val.validation.status.value,
                    'confidence': val.validation.confidence,
                    'message': val.validation.message,
                    'details': {
                        'args_provided': val.method_call.args,
                        'kwargs_provided': list(val.method_call.kwargs.keys()),
                        'expected_params': val.expected_params,
                        'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None,
                        'suggestions': val.validation.suggestions
                    }
                }
                self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)

        # Process attributes (only knowledge graph ones) - but skip if already reported as method
        for val in validation_result.attribute_validations:
            if not (val.attribute_access.object_type and self._is_from_knowledge_graph(val.attribute_access.object_type, validation_result)):
                continue  # Skip external attributes

            # Create unique key - if this was already reported as a method, skip it
            key = (val.attribute_access.line_number, val.attribute_access.attribute_name, val.attribute_access.object_type)
            if key not in reported_items:
                reported_items.add(key)
                item = {
                    'type': 'ATTRIBUTE_ACCESS',
                    'name': val.attribute_access.attribute_name,
                    'object': val.attribute_access.object_name,
                    'object_type': val.attribute_access.object_type,
                    'line': val.attribute_access.line_number,
                    'status': val.validation.status.value,
                    'confidence': val.validation.confidence,
                    'message': val.validation.message,
                    'details': {
                        'expected_type': val.expected_type
                    }
                }
                self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)

        # Process functions (only knowledge graph ones)
        for val in validation_result.function_validations:
            if not (val.function_call.full_name and self._is_from_knowledge_graph(val.function_call.full_name, validation_result)):
                continue  # Skip external functions
            item = {
                'type': 'FUNCTION_CALL',
                'name': val.function_call.function_name,
                'full_name': val.function_call.full_name,
                'line': val.function_call.line_number,
                'status': val.validation.status.value,
                'confidence': val.validation.confidence,
                'message': val.validation.message,
                'details': {
                    'args_provided': val.function_call.args,
                    'kwargs_provided': list(val.function_call.kwargs.keys()),
                    'expected_params': val.expected_params,
                    'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None
                }
            }
            self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items)
        # Create library summary
        library_summary = self._create_library_summary(validation_result)

        # Generate report
        report = {
            'analysis_metadata': {
                'script_path': validation_result.script_path,
                'analysis_timestamp': self.report_timestamp.isoformat(),
                'total_imports': len(validation_result.import_validations),
                'total_classes': len(validation_result.class_validations),
                'total_methods': len(validation_result.method_validations),
                'total_attributes': len(validation_result.attribute_validations),
                'total_functions': len(validation_result.function_validations)
            },
            'validation_summary': {
                'overall_confidence': validation_result.overall_confidence,
                'total_validations': len(valid_items) + len(invalid_items) + len(uncertain_items) + len(not_found_items),
                'valid_count': len(valid_items),
                'invalid_count': len(invalid_items),
                'uncertain_count': len(uncertain_items),
                'not_found_count': len(not_found_items),
                'hallucination_rate': len(invalid_items + not_found_items) / max(1, len(valid_items) + len(invalid_items) + len(not_found_items))
            },
            'libraries_analyzed': library_summary,
            'validation_details': {
                'valid_items': valid_items,
                'invalid_items': invalid_items,
                'uncertain_items': uncertain_items,
                'not_found_items': not_found_items
            },
            'hallucinations_detected': validation_result.hallucinations_detected,
            'recommendations': self._generate_recommendations(validation_result)
        }

        return report
    def _is_from_knowledge_graph(self, item_name: str, validation_result) -> bool:
        """Check if an item is from a knowledge graph module"""
        if not item_name:
            return False

        # Get knowledge graph modules from import validations
        kg_modules = set()
        for val in validation_result.import_validations:
            if val.validation.details.get('in_knowledge_graph', False):
                kg_modules.add(val.import_info.module)
                if '.' in val.import_info.module:
                    kg_modules.add(val.import_info.module.split('.')[0])

        # Check if the item belongs to any knowledge graph module
        if '.' in item_name:
            base_module = item_name.split('.')[0]
            return base_module in kg_modules

        return any(item_name in module or module.endswith(item_name) for module in kg_modules)

    def _serialize_validation_result(self, validation_result) -> Dict[str, Any]:
        """Convert ValidationResult to JSON-serializable dictionary"""
        if validation_result is None:
            return None

        return {
            'status': validation_result.status.value,
            'confidence': validation_result.confidence,
            'message': validation_result.message,
            'details': validation_result.details,
            'suggestions': validation_result.suggestions
        }

    def _categorize_item(self, item: Dict[str, Any], status: ValidationStatus,
                         valid_items: List, invalid_items: List, uncertain_items: List, not_found_items: List):
        """Categorize validation item by status"""
        if status == ValidationStatus.VALID:
            valid_items.append(item)
        elif status == ValidationStatus.INVALID:
            invalid_items.append(item)
        elif status == ValidationStatus.UNCERTAIN:
            uncertain_items.append(item)
        elif status == ValidationStatus.NOT_FOUND:
            not_found_items.append(item)
    def _create_library_summary(self, validation_result: ScriptValidationResult) -> List[Dict[str, Any]]:
        """Create summary of libraries analyzed"""
        library_stats = {}

        # Aggregate stats by library/module
        for val in validation_result.import_validations:
            module = val.import_info.module
            if module not in library_stats:
                library_stats[module] = {
                    'module_name': module,
                    'import_status': val.validation.status.value,
                    'import_confidence': val.validation.confidence,
                    'classes_used': [],
                    'methods_called': [],
                    'attributes_accessed': [],
                    'functions_called': []
                }

        # Add class usage
        for val in validation_result.class_validations:
            class_name = val.class_instantiation.class_name
            full_name = val.class_instantiation.full_class_name

            # Try to match to library
            if full_name:
                parts = full_name.split('.')
                if len(parts) > 1:
                    module = '.'.join(parts[:-1])
                    if module in library_stats:
                        library_stats[module]['classes_used'].append({
                            'class_name': class_name,
                            'status': val.validation.status.value,
                            'confidence': val.validation.confidence
                        })

        # Add method usage
        for val in validation_result.method_validations:
            method_name = val.method_call.method_name
            object_type = val.method_call.object_type

            if object_type:
                parts = object_type.split('.')
                if len(parts) > 1:
                    module = '.'.join(parts[:-1])
                    if module in library_stats:
                        library_stats[module]['methods_called'].append({
                            'method_name': method_name,
                            'class_name': parts[-1],
                            'status': val.validation.status.value,
                            'confidence': val.validation.confidence
                        })

        # Add attribute usage
        for val in validation_result.attribute_validations:
            attr_name = val.attribute_access.attribute_name
            object_type = val.attribute_access.object_type

            if object_type:
                parts = object_type.split('.')
                if len(parts) > 1:
                    module = '.'.join(parts[:-1])
                    if module in library_stats:
                        library_stats[module]['attributes_accessed'].append({
                            'attribute_name': attr_name,
                            'class_name': parts[-1],
                            'status': val.validation.status.value,
                            'confidence': val.validation.confidence
                        })

        # Add function usage
        for val in validation_result.function_validations:
            func_name = val.function_call.function_name
            full_name = val.function_call.full_name

            if full_name:
                parts = full_name.split('.')
                if len(parts) > 1:
                    module = '.'.join(parts[:-1])
                    if module in library_stats:
                        library_stats[module]['functions_called'].append({
                            'function_name': func_name,
                            'status': val.validation.status.value,
                            'confidence': val.validation.confidence
                        })

        return list(library_stats.values())
    def _generate_recommendations(self, validation_result: ScriptValidationResult) -> List[str]:
        """Generate recommendations based on validation results"""
        recommendations = []

        # Only count actual hallucinations (from knowledge graph libraries)
        kg_hallucinations = [h for h in validation_result.hallucinations_detected]

        if kg_hallucinations:
            method_issues = [h for h in kg_hallucinations if h['type'] == 'METHOD_NOT_FOUND']
            attr_issues = [h for h in kg_hallucinations if h['type'] == 'ATTRIBUTE_NOT_FOUND']
            param_issues = [h for h in kg_hallucinations if h['type'] == 'INVALID_PARAMETERS']

            if method_issues:
                recommendations.append(
                    f"Found {len(method_issues)} non-existent methods in knowledge graph libraries. "
                    "Consider checking the official documentation for correct method names."
                )

            if attr_issues:
                recommendations.append(
                    f"Found {len(attr_issues)} non-existent attributes in knowledge graph libraries. "
                    "Verify attribute names against the class documentation."
                )

            if param_issues:
                recommendations.append(
                    f"Found {len(param_issues)} parameter mismatches in knowledge graph libraries. "
                    "Check function signatures for correct parameter names and types."
                )
        else:
            recommendations.append(
                "No hallucinations detected in knowledge graph libraries. "
                "External library usage appears to be working as expected."
            )

        if validation_result.overall_confidence < 0.7:
            recommendations.append(
                "Overall confidence is moderate. Most validations were for external libraries not in the knowledge graph."
            )

        return recommendations
    def save_json_report(self, report: Dict[str, Any], output_path: str):
        """Save report as JSON file"""
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, ensure_ascii=False)

        logger.info(f"JSON report saved to: {output_path}")

    def save_markdown_report(self, report: Dict[str, Any], output_path: str):
        """Save report as Markdown file"""
        md_content = self._generate_markdown_content(report)

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(md_content)

        logger.info(f"Markdown report saved to: {output_path}")
    def _generate_markdown_content(self, report: Dict[str, Any]) -> str:
        """Generate Markdown content from report"""
        md = []

        # Header
        md.append("# AI Hallucination Detection Report")
        md.append("")
        md.append(f"**Script:** `{report['analysis_metadata']['script_path']}`")
        md.append(f"**Analysis Date:** {report['analysis_metadata']['analysis_timestamp']}")
        md.append(f"**Overall Confidence:** {report['validation_summary']['overall_confidence']:.2%}")
        md.append("")

        # Summary
        summary = report['validation_summary']
        md.append("## Summary")
        md.append("")
        md.append(f"- **Total Validations:** {summary['total_validations']}")
        md.append(f"- **Valid:** {summary['valid_count']} ({summary['valid_count']/summary['total_validations']:.1%})")
        md.append(f"- **Invalid:** {summary['invalid_count']} ({summary['invalid_count']/summary['total_validations']:.1%})")
        md.append(f"- **Not Found:** {summary['not_found_count']} ({summary['not_found_count']/summary['total_validations']:.1%})")
        md.append(f"- **Uncertain:** {summary['uncertain_count']} ({summary['uncertain_count']/summary['total_validations']:.1%})")
        md.append(f"- **Hallucination Rate:** {summary['hallucination_rate']:.1%}")
        md.append("")

        # Hallucinations
        if report['hallucinations_detected']:
            md.append("## 🚨 Hallucinations Detected")
            md.append("")
            for i, hallucination in enumerate(report['hallucinations_detected'], 1):
                md.append(f"### {i}. {hallucination['type'].replace('_', ' ').title()}")
                md.append(f"**Location:** {hallucination['location']}")
                md.append(f"**Description:** {hallucination['description']}")
                if hallucination.get('suggestion'):
                    md.append(f"**Suggestion:** {hallucination['suggestion']}")
                md.append("")

        # Libraries
        if report['libraries_analyzed']:
            md.append("## 📚 Libraries Analyzed")
            md.append("")
            for lib in report['libraries_analyzed']:
                md.append(f"### {lib['module_name']}")
                md.append(f"**Import Status:** {lib['import_status']}")
                md.append(f"**Import Confidence:** {lib['import_confidence']:.2%}")

                if lib['classes_used']:
                    md.append("**Classes Used:**")
                    for cls in lib['classes_used']:
                        status_emoji = "✅" if cls['status'] == 'VALID' else "❌"
                        md.append(f"  - {status_emoji} `{cls['class_name']}` ({cls['confidence']:.1%})")

                if lib['methods_called']:
                    md.append("**Methods Called:**")
                    for method in lib['methods_called']:
                        status_emoji = "✅" if method['status'] == 'VALID' else "❌"
                        md.append(f"  - {status_emoji} `{method['class_name']}.{method['method_name']}()` ({method['confidence']:.1%})")

                if lib['attributes_accessed']:
                    md.append("**Attributes Accessed:**")
                    for attr in lib['attributes_accessed']:
                        status_emoji = "✅" if attr['status'] == 'VALID' else "❌"
                        md.append(f"  - {status_emoji} `{attr['class_name']}.{attr['attribute_name']}` ({attr['confidence']:.1%})")

                if lib['functions_called']:
                    md.append("**Functions Called:**")
                    for func in lib['functions_called']:
                        status_emoji = "✅" if func['status'] == 'VALID' else "❌"
                        md.append(f"  - {status_emoji} `{func['function_name']}()` ({func['confidence']:.1%})")

                md.append("")

        # Recommendations
        if report['recommendations']:
            md.append("## 💡 Recommendations")
            md.append("")
            for rec in report['recommendations']:
                md.append(f"- {rec}")
            md.append("")

        # Detailed Results
        md.append("## 📋 Detailed Validation Results")
        md.append("")

        # Invalid items
        invalid_items = report['validation_details']['invalid_items']
        if invalid_items:
            md.append("### ❌ Invalid Items")
            md.append("")
            for item in invalid_items:
                md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}")
            md.append("")

        # Not found items
        not_found_items = report['validation_details']['not_found_items']
        if not_found_items:
            md.append("### 🔍 Not Found Items")
            md.append("")
            for item in not_found_items:
                md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}")
            md.append("")

        # Valid items (sample)
        valid_items = report['validation_details']['valid_items']
        if valid_items:
            md.append("### ✅ Valid Items (Sample)")
            md.append("")
            for item in valid_items[:10]:  # Show first 10
                md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}")
            if len(valid_items) > 10:
                md.append(f"- ... and {len(valid_items) - 10} more valid items")
            md.append("")

        return "\n".join(md)
    def print_summary(self, report: Dict[str, Any]):
        """Print a concise summary to console"""
        print("\n" + "="*80)
        print("🤖 AI HALLUCINATION DETECTION REPORT")
        print("="*80)

        print(f"Script: {report['analysis_metadata']['script_path']}")
        print(f"Overall Confidence: {report['validation_summary']['overall_confidence']:.1%}")

        summary = report['validation_summary']
        print(f"\nValidation Results:")
        print(f"  ✅ Valid: {summary['valid_count']}")
        print(f"  ❌ Invalid: {summary['invalid_count']}")
        print(f"  🔍 Not Found: {summary['not_found_count']}")
        print(f"  ❓ Uncertain: {summary['uncertain_count']}")
        print(f"  📊 Hallucination Rate: {summary['hallucination_rate']:.1%}")

        if report['hallucinations_detected']:
            print(f"\n🚨 {len(report['hallucinations_detected'])} Hallucinations Detected:")
            for hall in report['hallucinations_detected'][:5]:  # Show first 5
                print(f"  - {hall['type'].replace('_', ' ').title()} at {hall['location']}")
                print(f"    {hall['description']}")

        if report['recommendations']:
            print(f"\n💡 Recommendations:")
            for rec in report['recommendations'][:3]:  # Show first 3
                print(f"  - {rec}")

        print("="*80)
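A short sketch of consuming a saved report (the file name follows the {script_name}_hallucination_report.json pattern used by the detector; the path below is hypothetical).

import json

with open("reports/example_agent_hallucination_report.json", encoding="utf-8") as f:
    report = json.load(f)

summary = report["validation_summary"]
# hallucination_rate = (invalid + not_found) / max(1, valid + invalid + not_found), as computed above
print(f"confidence={summary['overall_confidence']:.1%}, hallucination rate={summary['hallucination_rate']:.1%}")

for item in report["validation_details"]["invalid_items"]:
    print(item["type"], item["name"], "line", item["line"], "-", item["message"])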
File diff suppressed because it is too large
@@ -0,0 +1,858 @@
"""
Direct Neo4j GitHub Code Repository Extractor

Creates nodes and relationships directly in Neo4j without Graphiti:
- File nodes
- Class nodes
- Method nodes
- Function nodes
- Import relationships

Bypasses all LLM processing for maximum speed.
"""

import asyncio
import logging
import os
import subprocess
import shutil
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional, Dict, Any, Set
import ast

from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)


class Neo4jCodeAnalyzer:
    """Analyzes code for direct Neo4j insertion"""

    def __init__(self):
        # External modules to ignore
        self.external_modules = {
            # Python standard library
            'os', 'sys', 'json', 'logging', 'datetime', 'pathlib', 'typing', 'collections',
            'asyncio', 'subprocess', 'ast', 're', 'string', 'urllib', 'http', 'email',
            'time', 'uuid', 'hashlib', 'base64', 'itertools', 'functools', 'operator',
            'contextlib', 'copy', 'pickle', 'tempfile', 'shutil', 'glob', 'fnmatch',
            'io', 'codecs', 'locale', 'platform', 'socket', 'ssl', 'threading', 'queue',
            'multiprocessing', 'concurrent', 'warnings', 'traceback', 'inspect',
            'importlib', 'pkgutil', 'types', 'weakref', 'gc', 'dataclasses', 'enum',
            'abc', 'numbers', 'decimal', 'fractions', 'math', 'cmath', 'random', 'statistics',

            # Common third-party libraries
            'requests', 'urllib3', 'httpx', 'aiohttp', 'flask', 'django', 'fastapi',
            'pydantic', 'sqlalchemy', 'alembic', 'psycopg2', 'pymongo', 'redis',
            'celery', 'pytest', 'unittest', 'mock', 'faker', 'factory', 'hypothesis',
            'numpy', 'pandas', 'matplotlib', 'seaborn', 'scipy', 'sklearn', 'torch',
            'tensorflow', 'keras', 'opencv', 'pillow', 'boto3', 'botocore', 'azure',
            'google', 'openai', 'anthropic', 'langchain', 'transformers', 'huggingface_hub',
            'click', 'typer', 'rich', 'colorama', 'tqdm', 'python-dotenv', 'pyyaml',
            'toml', 'configargparse', 'marshmallow', 'attrs', 'dataclasses-json',
            'jsonschema', 'cerberus', 'voluptuous', 'schema', 'jinja2', 'mako',
            'cryptography', 'bcrypt', 'passlib', 'jwt', 'authlib', 'oauthlib'
        }
def analyze_python_file(self, file_path: Path, repo_root: Path, project_modules: Set[str]) -> Optional[Dict[str, Any]]:
|
||||
"""Extract structure for direct Neo4j insertion"""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
relative_path = str(file_path.relative_to(repo_root))
|
||||
module_name = self._get_importable_module_name(file_path, repo_root, relative_path)
|
||||
|
||||
# Extract structure
|
||||
classes = []
|
||||
functions = []
|
||||
imports = []
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# Extract class with its methods and attributes
|
||||
methods = []
|
||||
attributes = []
|
||||
|
||||
for item in node.body:
|
||||
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if not item.name.startswith('_'): # Public methods only
|
||||
# Extract comprehensive parameter info
|
||||
params = self._extract_function_parameters(item)
|
||||
|
||||
# Get return type annotation
|
||||
return_type = self._get_name(item.returns) if item.returns else 'Any'
|
||||
|
||||
# Create detailed parameter list for Neo4j storage
|
||||
params_detailed = []
|
||||
for p in params:
|
||||
param_str = f"{p['name']}:{p['type']}"
|
||||
if p['optional'] and p['default'] is not None:
|
||||
param_str += f"={p['default']}"
|
||||
elif p['optional']:
|
||||
param_str += "=None"
|
||||
if p['kind'] != 'positional':
|
||||
param_str = f"[{p['kind']}] {param_str}"
|
||||
params_detailed.append(param_str)
|
||||
|
||||
methods.append({
|
||||
'name': item.name,
|
||||
'params': params, # Full parameter objects
|
||||
'params_detailed': params_detailed, # Detailed string format
|
||||
'return_type': return_type,
|
||||
'args': [arg.arg for arg in item.args.args if arg.arg != 'self'] # Keep for backwards compatibility
|
||||
})
|
||||
elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
|
||||
# Type annotated attributes
|
||||
if not item.target.id.startswith('_'):
|
||||
attributes.append({
|
||||
'name': item.target.id,
|
||||
'type': self._get_name(item.annotation) if item.annotation else 'Any'
|
||||
})
|
||||
|
||||
classes.append({
|
||||
'name': node.name,
|
||||
'full_name': f"{module_name}.{node.name}",
|
||||
'methods': methods,
|
||||
'attributes': attributes
|
||||
})
|
||||
|
||||
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
# Only top-level functions
|
||||
if not any(node in cls_node.body for cls_node in ast.walk(tree) if isinstance(cls_node, ast.ClassDef)):
|
||||
if not node.name.startswith('_'):
|
||||
# Extract comprehensive parameter info
|
||||
params = self._extract_function_parameters(node)
|
||||
|
||||
# Get return type annotation
|
||||
return_type = self._get_name(node.returns) if node.returns else 'Any'
|
||||
|
||||
# Create detailed parameter list for Neo4j storage
|
||||
params_detailed = []
|
||||
for p in params:
|
||||
param_str = f"{p['name']}:{p['type']}"
|
||||
if p['optional'] and p['default'] is not None:
|
||||
param_str += f"={p['default']}"
|
||||
elif p['optional']:
|
||||
param_str += "=None"
|
||||
if p['kind'] != 'positional':
|
||||
param_str = f"[{p['kind']}] {param_str}"
|
||||
params_detailed.append(param_str)
|
||||
|
||||
# Simple format for backwards compatibility
|
||||
params_list = [f"{p['name']}:{p['type']}" for p in params]
|
||||
|
||||
functions.append({
|
||||
'name': node.name,
|
||||
'full_name': f"{module_name}.{node.name}",
|
||||
'params': params, # Full parameter objects
|
||||
'params_detailed': params_detailed, # Detailed string format
|
||||
'params_list': params_list, # Simple string format for backwards compatibility
|
||||
'return_type': return_type,
|
||||
'args': [arg.arg for arg in node.args.args] # Keep for backwards compatibility
|
||||
})
|
||||
|
||||
elif isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
# Track internal imports only
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
if self._is_likely_internal(alias.name, project_modules):
|
||||
imports.append(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
if (node.module.startswith('.') or self._is_likely_internal(node.module, project_modules)):
|
||||
imports.append(node.module)
|
||||
|
||||
return {
|
||||
'module_name': module_name,
|
||||
'file_path': relative_path,
|
||||
'classes': classes,
|
||||
'functions': functions,
|
||||
'imports': list(set(imports)), # Remove duplicates
|
||||
'line_count': len(content.splitlines())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not analyze {file_path}: {e}")
|
||||
return None
|
||||
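As a quick illustration of how the analyzer above might be driven on its own (the repository path and module names here are hypothetical), each successfully parsed file yields a plain dict, and parse failures yield None:

```python
from pathlib import Path

# Hypothetical standalone use of Neo4jCodeAnalyzer.analyze_python_file().
analyzer = Neo4jCodeAnalyzer()
repo_root = Path("repos/my_project")                    # assumed checkout location
source_file = repo_root / "my_project" / "models.py"    # assumed source file

analysis = analyzer.analyze_python_file(source_file, repo_root, project_modules={"my_project"})
if analysis:
    print(analysis["module_name"])                  # e.g. "my_project.models"
    print(len(analysis["classes"]), "classes")
    print(len(analysis["functions"]), "top-level functions")
    print(analysis["imports"])                      # internal imports only
```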
|
||||
def _is_likely_internal(self, import_name: str, project_modules: Set[str]) -> bool:
|
||||
"""Check if an import is likely internal to the project"""
|
||||
if not import_name:
|
||||
return False
|
||||
|
||||
# Relative imports are definitely internal
|
||||
if import_name.startswith('.'):
|
||||
return True
|
||||
|
||||
# Check if it's a known external module
|
||||
base_module = import_name.split('.')[0]
|
||||
if base_module in self.external_modules:
|
||||
return False
|
||||
|
||||
# Check if it matches any project module
|
||||
for project_module in project_modules:
|
||||
if import_name.startswith(project_module):
|
||||
return True
|
||||
|
||||
# If it's not obviously external, consider it internal
|
||||
if (not any(ext in base_module.lower() for ext in ['test', 'mock', 'fake']) and
|
||||
not base_module.startswith('_') and
|
||||
len(base_module) > 2):
|
||||
return True
|
||||
|
||||
return False
|
||||
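In practice this heuristic errs on the side of treating unknown imports as internal. A few hypothetical examples of how it classifies module names:

```python
# Illustrative classifications produced by _is_likely_internal() above.
analyzer = Neo4jCodeAnalyzer()
project_modules = {"my_project"}

analyzer._is_likely_internal(".utils", project_modules)          # True  (relative import)
analyzer._is_likely_internal("my_project.db", project_modules)   # True  (matches a project module)
analyzer._is_likely_internal("requests", project_modules)        # False (known external library)
analyzer._is_likely_internal("somelib", project_modules)         # True  (unknown, assumed internal)
```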
|
||||
def _get_importable_module_name(self, file_path: Path, repo_root: Path, relative_path: str) -> str:
|
||||
"""Determine the actual importable module name for a Python file"""
|
||||
# Start with the default: convert file path to module path
|
||||
default_module = relative_path.replace('/', '.').replace('\\', '.').replace('.py', '')
|
||||
|
||||
# Common patterns to detect the actual package root
|
||||
path_parts = Path(relative_path).parts
|
||||
|
||||
# Look for common package indicators
|
||||
package_roots = []
|
||||
|
||||
# Check each directory level for __init__.py to find package boundaries
|
||||
current_path = repo_root
|
||||
for i, part in enumerate(path_parts[:-1]): # Exclude the .py file itself
|
||||
current_path = current_path / part
|
||||
if (current_path / '__init__.py').exists():
|
||||
# This is a package directory, mark it as a potential root
|
||||
package_roots.append(i)
|
||||
|
||||
if package_roots:
|
||||
# Use the first (outermost) package as the root
|
||||
package_start = package_roots[0]
|
||||
module_parts = path_parts[package_start:]
|
||||
module_name = '.'.join(module_parts).replace('.py', '')
|
||||
return module_name
|
||||
|
||||
# Fallback: look for common Python project structures
|
||||
# Skip common non-package directories
|
||||
skip_dirs = {'src', 'lib', 'source', 'python', 'pkg', 'packages'}
|
||||
|
||||
# Find the first directory that's not in skip_dirs
|
||||
filtered_parts = []
|
||||
for part in path_parts:
|
||||
if part.lower() not in skip_dirs or filtered_parts: # Once we start including, include everything
|
||||
filtered_parts.append(part)
|
||||
|
||||
if filtered_parts:
|
||||
module_name = '.'.join(filtered_parts).replace('.py', '')
|
||||
return module_name
|
||||
|
||||
# Final fallback: use the default
|
||||
return default_module
|
||||
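A small, self-contained sketch of the mapping this produces, using a hypothetical `src/` layout created in a temporary directory; the point is that the wrapper directory does not leak into the importable name:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    repo_root = Path(tmp)
    pkg = repo_root / "src" / "my_project" / "models"
    pkg.mkdir(parents=True)
    (repo_root / "src" / "my_project" / "__init__.py").touch()
    (pkg / "__init__.py").touch()
    (pkg / "user.py").touch()

    analyzer = Neo4jCodeAnalyzer()
    name = analyzer._get_importable_module_name(
        pkg / "user.py", repo_root, "src/my_project/models/user.py"
    )
    print(name)  # -> "my_project.models.user" ("src" has no __init__.py, so "my_project" is the package root)
```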
|
||||
def _extract_function_parameters(self, func_node):
|
||||
"""Comprehensive parameter extraction from function definition"""
|
||||
params = []
|
||||
|
||||
# Regular positional arguments
|
||||
for i, arg in enumerate(func_node.args.args):
|
||||
if arg.arg == 'self':
|
||||
continue
|
||||
|
||||
param_info = {
|
||||
'name': arg.arg,
|
||||
'type': self._get_name(arg.annotation) if arg.annotation else 'Any',
|
||||
'kind': 'positional',
|
||||
'optional': False,
|
||||
'default': None
|
||||
}
|
||||
|
||||
# Check if this argument has a default value
|
||||
defaults_start = len(func_node.args.args) - len(func_node.args.defaults)
|
||||
if i >= defaults_start:
|
||||
default_idx = i - defaults_start
|
||||
if default_idx < len(func_node.args.defaults):
|
||||
param_info['optional'] = True
|
||||
param_info['default'] = self._get_default_value(func_node.args.defaults[default_idx])
|
||||
|
||||
params.append(param_info)
|
||||
|
||||
# *args parameter
|
||||
if func_node.args.vararg:
|
||||
params.append({
|
||||
'name': f"*{func_node.args.vararg.arg}",
|
||||
'type': self._get_name(func_node.args.vararg.annotation) if func_node.args.vararg.annotation else 'Any',
|
||||
'kind': 'var_positional',
|
||||
'optional': True,
|
||||
'default': None
|
||||
})
|
||||
|
||||
# Keyword-only arguments (after *)
|
||||
for i, arg in enumerate(func_node.args.kwonlyargs):
|
||||
param_info = {
|
||||
'name': arg.arg,
|
||||
'type': self._get_name(arg.annotation) if arg.annotation else 'Any',
|
||||
'kind': 'keyword_only',
|
||||
'optional': True, # All kwonly args are optional unless explicitly required
|
||||
'default': None
|
||||
}
|
||||
|
||||
# Check for default value
|
||||
if i < len(func_node.args.kw_defaults) and func_node.args.kw_defaults[i] is not None:
|
||||
param_info['default'] = self._get_default_value(func_node.args.kw_defaults[i])
|
||||
else:
|
||||
param_info['optional'] = False # No default = required kwonly arg
|
||||
|
||||
params.append(param_info)
|
||||
|
||||
# **kwargs parameter
|
||||
if func_node.args.kwarg:
|
||||
params.append({
|
||||
'name': f"**{func_node.args.kwarg.arg}",
|
||||
'type': self._get_name(func_node.args.kwarg.annotation) if func_node.args.kwarg.annotation else 'Dict[str, Any]',
|
||||
'kind': 'var_keyword',
|
||||
'optional': True,
|
||||
'default': None
|
||||
})
|
||||
|
||||
return params
|
||||
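A hypothetical signature run through this extraction yields one dict per parameter; the sketch below parses a source snippet with `ast` and prints the fields built above:

```python
import ast

# Hypothetical function used only to illustrate the extracted structure.
source = "def fetch(url: str, timeout: int = 30, *args, verify: bool = True, **kwargs) -> dict: ..."
func_node = ast.parse(source).body[0]

analyzer = Neo4jCodeAnalyzer()
for p in analyzer._extract_function_parameters(func_node):
    print(p["name"], p["type"], p["kind"], p["optional"], p["default"])

# Expected output, roughly:
#   url str positional False None
#   timeout int positional True 30
#   *args Any var_positional True None
#   verify bool keyword_only True True
#   **kwargs Dict[str, Any] var_keyword True None
```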
|
||||
def _get_default_value(self, default_node):
|
||||
"""Extract default value from AST node"""
|
||||
try:
|
||||
if isinstance(default_node, ast.Constant):
|
||||
return repr(default_node.value)
|
||||
elif isinstance(default_node, ast.Name):
|
||||
return default_node.id
|
||||
elif isinstance(default_node, ast.Attribute):
|
||||
return self._get_name(default_node)
|
||||
elif isinstance(default_node, ast.List):
|
||||
return "[]"
|
||||
elif isinstance(default_node, ast.Dict):
|
||||
return "{}"
|
||||
else:
|
||||
return "..."
|
||||
except Exception:
|
||||
return "..."
|
||||
|
||||
def _get_name(self, node):
|
||||
"""Extract name from AST node, handling complex types safely"""
|
||||
if node is None:
|
||||
return "Any"
|
||||
|
||||
try:
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
elif isinstance(node, ast.Attribute):
|
||||
if hasattr(node, 'value'):
|
||||
return f"{self._get_name(node.value)}.{node.attr}"
|
||||
else:
|
||||
return node.attr
|
||||
elif isinstance(node, ast.Subscript):
|
||||
# Handle List[Type], Dict[K,V], etc.
|
||||
base = self._get_name(node.value)
|
||||
if hasattr(node, 'slice'):
|
||||
if isinstance(node.slice, ast.Name):
|
||||
return f"{base}[{node.slice.id}]"
|
||||
elif isinstance(node.slice, ast.Tuple):
|
||||
elts = [self._get_name(elt) for elt in node.slice.elts]
|
||||
return f"{base}[{', '.join(elts)}]"
|
||||
elif isinstance(node.slice, ast.Constant):
|
||||
return f"{base}[{repr(node.slice.value)}]"
|
||||
elif isinstance(node.slice, ast.Attribute):
|
||||
return f"{base}[{self._get_name(node.slice)}]"
|
||||
elif isinstance(node.slice, ast.Subscript):
|
||||
return f"{base}[{self._get_name(node.slice)}]"
|
||||
else:
|
||||
# Try to get the name of the slice, fallback to Any if it fails
|
||||
try:
|
||||
slice_name = self._get_name(node.slice)
|
||||
return f"{base}[{slice_name}]"
|
||||
except Exception:
|
||||
return f"{base}[Any]"
|
||||
return base
|
||||
elif isinstance(node, ast.Constant):
|
||||
return str(node.value)
|
||||
elif isinstance(node, ast.Str): # Python < 3.8
|
||||
return f'"{node.s}"'
|
||||
elif isinstance(node, ast.Tuple):
|
||||
elts = [self._get_name(elt) for elt in node.elts]
|
||||
return f"({', '.join(elts)})"
|
||||
elif isinstance(node, ast.List):
|
||||
elts = [self._get_name(elt) for elt in node.elts]
|
||||
return f"[{', '.join(elts)}]"
|
||||
else:
|
||||
# Fallback for complex types - return a simple string representation
|
||||
return "Any"
|
||||
except Exception:
|
||||
# If anything goes wrong, return a safe default
|
||||
return "Any"
|
||||
|
||||
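The annotation-to-string conversion above recurses through nested generics; a short, hypothetical illustration (on Python 3.9+, where subscript slices are plain AST nodes):

```python
import ast

# Annotation strings rendered by _get_name(); output shown in the comments.
analyzer = Neo4jCodeAnalyzer()
for annotation in ["Optional[List[str]]", "Dict[str, Any]", "pathlib.Path"]:
    node = ast.parse(annotation, mode="eval").body
    print(annotation, "->", analyzer._get_name(node))

# Optional[List[str]] -> Optional[List[str]]
# Dict[str, Any]      -> Dict[str, Any]
# pathlib.Path        -> pathlib.Path
```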
|
||||
class DirectNeo4jExtractor:
|
||||
"""Creates nodes and relationships directly in Neo4j"""
|
||||
|
||||
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
|
||||
self.neo4j_uri = neo4j_uri
|
||||
self.neo4j_user = neo4j_user
|
||||
self.neo4j_password = neo4j_password
|
||||
self.driver = None
|
||||
self.analyzer = Neo4jCodeAnalyzer()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize Neo4j connection"""
|
||||
logger.info("Initializing Neo4j connection...")
|
||||
self.driver = AsyncGraphDatabase.driver(
|
||||
self.neo4j_uri,
|
||||
auth=(self.neo4j_user, self.neo4j_password)
|
||||
)
|
||||
|
||||
# Clear existing data
|
||||
# logger.info("Clearing existing data...")
|
||||
# async with self.driver.session() as session:
|
||||
# await session.run("MATCH (n) DETACH DELETE n")
|
||||
|
||||
# Create constraints and indexes
|
||||
logger.info("Creating constraints and indexes...")
|
||||
async with self.driver.session() as session:
|
||||
# Create constraints - using MERGE-friendly approach
|
||||
await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:File) REQUIRE f.path IS UNIQUE")
|
||||
await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Class) REQUIRE c.full_name IS UNIQUE")
|
||||
# Remove unique constraints for methods/attributes since they can be duplicated across classes
|
||||
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (m:Method) REQUIRE m.full_name IS UNIQUE")
|
||||
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:Function) REQUIRE f.full_name IS UNIQUE")
|
||||
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (a:Attribute) REQUIRE a.full_name IS UNIQUE")
|
||||
|
||||
# Create indexes for performance
|
||||
await session.run("CREATE INDEX IF NOT EXISTS FOR (f:File) ON (f.name)")
|
||||
await session.run("CREATE INDEX IF NOT EXISTS FOR (c:Class) ON (c.name)")
|
||||
await session.run("CREATE INDEX IF NOT EXISTS FOR (m:Method) ON (m.name)")
|
||||
|
||||
logger.info("Neo4j initialized successfully")
|
||||
|
||||
async def clear_repository_data(self, repo_name: str):
|
||||
"""Clear all data for a specific repository"""
|
||||
logger.info(f"Clearing existing data for repository: {repo_name}")
|
||||
async with self.driver.session() as session:
|
||||
# Delete in specific order to avoid constraint issues
|
||||
|
||||
# 1. Delete methods and attributes (they depend on classes)
|
||||
await session.run("""
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
|
||||
DETACH DELETE m
|
||||
""", repo_name=repo_name)
|
||||
|
||||
await session.run("""
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
|
||||
DETACH DELETE a
|
||||
""", repo_name=repo_name)
|
||||
|
||||
# 2. Delete functions (they depend on files)
|
||||
await session.run("""
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
|
||||
DETACH DELETE func
|
||||
""", repo_name=repo_name)
|
||||
|
||||
# 3. Delete classes (they depend on files)
|
||||
await session.run("""
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
|
||||
DETACH DELETE c
|
||||
""", repo_name=repo_name)
|
||||
|
||||
# 4. Delete files (they depend on repository)
|
||||
await session.run("""
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)
|
||||
DETACH DELETE f
|
||||
""", repo_name=repo_name)
|
||||
|
||||
# 5. Finally delete the repository
|
||||
await session.run("""
|
||||
MATCH (r:Repository {name: $repo_name})
|
||||
DETACH DELETE r
|
||||
""", repo_name=repo_name)
|
||||
|
||||
logger.info(f"Cleared data for repository: {repo_name}")
|
||||
|
||||
async def close(self):
|
||||
"""Close Neo4j connection"""
|
||||
if self.driver:
|
||||
await self.driver.close()
|
||||
|
||||
def clone_repo(self, repo_url: str, target_dir: str) -> str:
|
||||
"""Clone repository with shallow clone"""
|
||||
logger.info(f"Cloning repository to: {target_dir}")
|
||||
if os.path.exists(target_dir):
|
||||
logger.info(f"Removing existing directory: {target_dir}")
|
||||
try:
|
||||
def handle_remove_readonly(func, path, exc):
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.chmod(path, 0o777)
|
||||
func(path)
|
||||
except PermissionError:
|
||||
logger.warning(f"Could not remove {path} - file in use, skipping")
|
||||
pass
|
||||
shutil.rmtree(target_dir, onerror=handle_remove_readonly)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fully remove {target_dir}: {e}. Proceeding anyway...")
|
||||
|
||||
logger.info(f"Running git clone from {repo_url}")
|
||||
subprocess.run(['git', 'clone', '--depth', '1', repo_url, target_dir], check=True)
|
||||
logger.info("Repository cloned successfully")
|
||||
return target_dir
|
||||
|
||||
def get_python_files(self, repo_path: str) -> List[Path]:
|
||||
"""Get Python files, focusing on main source directories"""
|
||||
python_files = []
|
||||
exclude_dirs = {
|
||||
'tests', 'test', '__pycache__', '.git', 'venv', 'env',
|
||||
'node_modules', 'build', 'dist', '.pytest_cache', 'docs',
|
||||
'examples', 'example', 'demo', 'benchmark'
|
||||
}
|
||||
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
dirs[:] = [d for d in dirs if d not in exclude_dirs and not d.startswith('.')]
|
||||
|
||||
for file in files:
|
||||
if file.endswith('.py') and not file.startswith('test_'):
|
||||
file_path = Path(root) / file
|
||||
if (file_path.stat().st_size < 500_000 and
|
||||
file not in ['setup.py', 'conftest.py']):
|
||||
python_files.append(file_path)
|
||||
|
||||
return python_files
|
||||
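As a rough summary of the filtering rules (the layout below is hypothetical), only application code survives; tests, docs, examples, oversized files, and packaging boilerplate are skipped:

```python
# Hypothetical repository layout and what get_python_files() keeps:
#
#   my_project/api/server.py        -> kept
#   my_project/utils/helpers.py     -> kept
#   tests/test_server.py            -> skipped (tests/ is excluded, test_*.py is excluded)
#   docs/conf.py                    -> skipped (docs/ is excluded)
#   examples/demo.py                -> skipped (examples/ is excluded)
#   setup.py                        -> skipped (explicitly excluded file name)
#   my_project/generated_bundle.py  -> skipped if it is 500 KB or larger
```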
|
||||
async def analyze_repository(self, repo_url: str, temp_dir: str = None):
|
||||
"""Analyze repository and create nodes/relationships in Neo4j"""
|
||||
repo_name = repo_url.split('/')[-1].replace('.git', '')
|
||||
logger.info(f"Analyzing repository: {repo_name}")
|
||||
|
||||
# Clear existing data for this repository before re-processing
|
||||
await self.clear_repository_data(repo_name)
|
||||
|
||||
# Set default temp_dir to repos folder at script level
|
||||
if temp_dir is None:
|
||||
script_dir = Path(__file__).parent
|
||||
temp_dir = str(script_dir / "repos" / repo_name)
|
||||
|
||||
# Clone and analyze
|
||||
repo_path = Path(self.clone_repo(repo_url, temp_dir))
|
||||
|
||||
try:
|
||||
logger.info("Getting Python files...")
|
||||
python_files = self.get_python_files(str(repo_path))
|
||||
logger.info(f"Found {len(python_files)} Python files to analyze")
|
||||
|
||||
# First pass: identify project modules
|
||||
logger.info("Identifying project modules...")
|
||||
project_modules = set()
|
||||
for file_path in python_files:
|
||||
relative_path = str(file_path.relative_to(repo_path))
|
||||
module_parts = relative_path.replace('/', '.').replace('.py', '').split('.')
|
||||
if len(module_parts) > 0 and not module_parts[0].startswith('.'):
|
||||
project_modules.add(module_parts[0])
|
||||
|
||||
logger.info(f"Identified project modules: {sorted(project_modules)}")
|
||||
|
||||
# Second pass: analyze files and collect data
|
||||
logger.info("Analyzing Python files...")
|
||||
modules_data = []
|
||||
for i, file_path in enumerate(python_files):
|
||||
if i % 20 == 0:
|
||||
logger.info(f"Analyzing file {i+1}/{len(python_files)}: {file_path.name}")
|
||||
|
||||
analysis = self.analyzer.analyze_python_file(file_path, repo_path, project_modules)
|
||||
if analysis:
|
||||
modules_data.append(analysis)
|
||||
|
||||
logger.info(f"Found {len(modules_data)} files with content")
|
||||
|
||||
# Create nodes and relationships in Neo4j
|
||||
logger.info("Creating nodes and relationships in Neo4j...")
|
||||
await self._create_graph(repo_name, modules_data)
|
||||
|
||||
# Print summary
|
||||
total_classes = sum(len(mod['classes']) for mod in modules_data)
|
||||
total_methods = sum(len(cls['methods']) for mod in modules_data for cls in mod['classes'])
|
||||
total_functions = sum(len(mod['functions']) for mod in modules_data)
|
||||
total_imports = sum(len(mod['imports']) for mod in modules_data)
|
||||
|
||||
print(f"\n=== Direct Neo4j Repository Analysis for {repo_name} ===")
|
||||
print(f"Files processed: {len(modules_data)}")
|
||||
print(f"Classes created: {total_classes}")
|
||||
print(f"Methods created: {total_methods}")
|
||||
print(f"Functions created: {total_functions}")
|
||||
print(f"Import relationships: {total_imports}")
|
||||
|
||||
logger.info(f"Successfully created Neo4j graph for {repo_name}")
|
||||
|
||||
finally:
|
||||
if os.path.exists(temp_dir):
|
||||
logger.info(f"Cleaning up temporary directory: {temp_dir}")
|
||||
try:
|
||||
def handle_remove_readonly(func, path, exc):
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.chmod(path, 0o777)
|
||||
func(path)
|
||||
except PermissionError:
|
||||
logger.warning(f"Could not remove {path} - file in use, skipping")
|
||||
pass
|
||||
|
||||
shutil.rmtree(temp_dir, onerror=handle_remove_readonly)
|
||||
logger.info("Cleanup completed")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cleanup failed: {e}. Directory may remain at {temp_dir}")
|
||||
# Don't fail the whole process due to cleanup issues
|
||||
|
||||
async def _create_graph(self, repo_name: str, modules_data: List[Dict]):
|
||||
"""Create all nodes and relationships in Neo4j"""
|
||||
|
||||
async with self.driver.session() as session:
|
||||
# Create Repository node
|
||||
await session.run(
|
||||
"CREATE (r:Repository {name: $repo_name, created_at: datetime()})",
|
||||
repo_name=repo_name
|
||||
)
|
||||
|
||||
nodes_created = 0
|
||||
relationships_created = 0
|
||||
|
||||
for i, mod in enumerate(modules_data):
|
||||
# 1. Create File node
|
||||
await session.run("""
|
||||
CREATE (f:File {
|
||||
name: $name,
|
||||
path: $path,
|
||||
module_name: $module_name,
|
||||
line_count: $line_count,
|
||||
created_at: datetime()
|
||||
})
|
||||
""",
|
||||
name=mod['file_path'].split('/')[-1],
|
||||
path=mod['file_path'],
|
||||
module_name=mod['module_name'],
|
||||
line_count=mod['line_count']
|
||||
)
|
||||
nodes_created += 1
|
||||
|
||||
# 2. Connect File to Repository
|
||||
await session.run("""
|
||||
MATCH (r:Repository {name: $repo_name})
|
||||
MATCH (f:File {path: $file_path})
|
||||
CREATE (r)-[:CONTAINS]->(f)
|
||||
""", repo_name=repo_name, file_path=mod['file_path'])
|
||||
relationships_created += 1
|
||||
|
||||
# 3. Create Class nodes and relationships
|
||||
for cls in mod['classes']:
|
||||
# Create Class node using MERGE to avoid duplicates
|
||||
await session.run("""
|
||||
MERGE (c:Class {full_name: $full_name})
|
||||
ON CREATE SET c.name = $name, c.created_at = datetime()
|
||||
""", name=cls['name'], full_name=cls['full_name'])
|
||||
nodes_created += 1
|
||||
|
||||
# Connect File to Class
|
||||
await session.run("""
|
||||
MATCH (f:File {path: $file_path})
|
||||
MATCH (c:Class {full_name: $class_full_name})
|
||||
MERGE (f)-[:DEFINES]->(c)
|
||||
""", file_path=mod['file_path'], class_full_name=cls['full_name'])
|
||||
relationships_created += 1
|
||||
|
||||
# 4. Create Method nodes - use MERGE to avoid duplicates
|
||||
for method in cls['methods']:
|
||||
method_full_name = f"{cls['full_name']}.{method['name']}"
|
||||
# Create method with unique ID to avoid conflicts
|
||||
method_id = f"{cls['full_name']}::{method['name']}"
|
||||
|
||||
await session.run("""
|
||||
MERGE (m:Method {method_id: $method_id})
|
||||
ON CREATE SET m.name = $name,
|
||||
m.full_name = $full_name,
|
||||
m.args = $args,
|
||||
m.params_list = $params_list,
|
||||
m.params_detailed = $params_detailed,
|
||||
m.return_type = $return_type,
|
||||
m.created_at = datetime()
|
||||
""",
|
||||
name=method['name'],
|
||||
full_name=method_full_name,
|
||||
method_id=method_id,
|
||||
args=method['args'],
|
||||
params_list=[f"{p['name']}:{p['type']}" for p in method['params']], # Simple format
|
||||
params_detailed=method.get('params_detailed', []), # Detailed format
|
||||
return_type=method['return_type']
|
||||
)
|
||||
nodes_created += 1
|
||||
|
||||
# Connect Class to Method
|
||||
await session.run("""
|
||||
MATCH (c:Class {full_name: $class_full_name})
|
||||
MATCH (m:Method {method_id: $method_id})
|
||||
MERGE (c)-[:HAS_METHOD]->(m)
|
||||
""",
|
||||
class_full_name=cls['full_name'],
|
||||
method_id=method_id
|
||||
)
|
||||
relationships_created += 1
|
||||
|
||||
# 5. Create Attribute nodes - use MERGE to avoid duplicates
|
||||
for attr in cls['attributes']:
|
||||
attr_full_name = f"{cls['full_name']}.{attr['name']}"
|
||||
# Create attribute with unique ID to avoid conflicts
|
||||
attr_id = f"{cls['full_name']}::{attr['name']}"
|
||||
await session.run("""
|
||||
MERGE (a:Attribute {attr_id: $attr_id})
|
||||
ON CREATE SET a.name = $name,
|
||||
a.full_name = $full_name,
|
||||
a.type = $type,
|
||||
a.created_at = datetime()
|
||||
""",
|
||||
name=attr['name'],
|
||||
full_name=attr_full_name,
|
||||
attr_id=attr_id,
|
||||
type=attr['type']
|
||||
)
|
||||
nodes_created += 1
|
||||
|
||||
# Connect Class to Attribute
|
||||
await session.run("""
|
||||
MATCH (c:Class {full_name: $class_full_name})
|
||||
MATCH (a:Attribute {attr_id: $attr_id})
|
||||
MERGE (c)-[:HAS_ATTRIBUTE]->(a)
|
||||
""",
|
||||
class_full_name=cls['full_name'],
|
||||
attr_id=attr_id
|
||||
)
|
||||
relationships_created += 1
|
||||
|
||||
# 6. Create Function nodes (top-level) - use MERGE to avoid duplicates
|
||||
for func in mod['functions']:
|
||||
func_id = f"{mod['file_path']}::{func['name']}"
|
||||
await session.run("""
|
||||
MERGE (f:Function {func_id: $func_id})
|
||||
ON CREATE SET f.name = $name,
|
||||
f.full_name = $full_name,
|
||||
f.args = $args,
|
||||
f.params_list = $params_list,
|
||||
f.params_detailed = $params_detailed,
|
||||
f.return_type = $return_type,
|
||||
f.created_at = datetime()
|
||||
""",
|
||||
name=func['name'],
|
||||
full_name=func['full_name'],
|
||||
func_id=func_id,
|
||||
args=func['args'],
|
||||
params_list=func.get('params_list', []), # Simple format for backwards compatibility
|
||||
params_detailed=func.get('params_detailed', []), # Detailed format
|
||||
return_type=func['return_type']
|
||||
)
|
||||
nodes_created += 1
|
||||
|
||||
# Connect File to Function
|
||||
await session.run("""
|
||||
MATCH (file:File {path: $file_path})
|
||||
MATCH (func:Function {func_id: $func_id})
|
||||
MERGE (file)-[:DEFINES]->(func)
|
||||
""", file_path=mod['file_path'], func_id=func_id)
|
||||
relationships_created += 1
|
||||
|
||||
# 7. Create Import relationships
|
||||
for import_name in mod['imports']:
|
||||
# Try to find the target file
|
||||
await session.run("""
|
||||
MATCH (source:File {path: $source_path})
|
||||
OPTIONAL MATCH (target:File)
|
||||
WHERE target.module_name = $import_name OR target.module_name STARTS WITH $import_name
|
||||
WITH source, target
|
||||
WHERE target IS NOT NULL
|
||||
MERGE (source)-[:IMPORTS]->(target)
|
||||
""", source_path=mod['file_path'], import_name=import_name)
|
||||
relationships_created += 1
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
logger.info(f"Processed {i + 1}/{len(modules_data)} files...")
|
||||
|
||||
logger.info(f"Created {nodes_created} nodes and {relationships_created} relationships")
|
||||
|
||||
async def search_graph(self, query_type: str, **kwargs):
|
||||
"""Search the Neo4j graph directly"""
|
||||
async with self.driver.session() as session:
|
||||
if query_type == "files_importing":
|
||||
target = kwargs.get('target')
|
||||
result = await session.run("""
|
||||
MATCH (source:File)-[:IMPORTS]->(target:File)
|
||||
WHERE target.module_name CONTAINS $target
|
||||
RETURN source.path as file, target.module_name as imports
|
||||
""", target=target)
|
||||
return [{"file": record["file"], "imports": record["imports"]} async for record in result]
|
||||
|
||||
elif query_type == "classes_in_file":
|
||||
file_path = kwargs.get('file_path')
|
||||
result = await session.run("""
|
||||
MATCH (f:File {path: $file_path})-[:DEFINES]->(c:Class)
|
||||
RETURN c.name as class_name, c.full_name as full_name
|
||||
""", file_path=file_path)
|
||||
return [{"class_name": record["class_name"], "full_name": record["full_name"]} async for record in result]
|
||||
|
||||
elif query_type == "methods_of_class":
|
||||
class_name = kwargs.get('class_name')
|
||||
result = await session.run("""
|
||||
MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
|
||||
WHERE c.name CONTAINS $class_name OR c.full_name CONTAINS $class_name
|
||||
RETURN m.name as method_name, m.args as args
|
||||
""", class_name=class_name)
|
||||
return [{"method_name": record["method_name"], "args": record["args"]} async for record in result]
|
||||
|
||||
|
||||
async def main():
|
||||
"""Example usage"""
|
||||
load_dotenv()
|
||||
|
||||
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
||||
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
|
||||
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
|
||||
|
||||
extractor = DirectNeo4jExtractor(neo4j_uri, neo4j_user, neo4j_password)
|
||||
|
||||
try:
|
||||
await extractor.initialize()
|
||||
|
||||
# Analyze repository - direct Neo4j, no LLM processing!
|
||||
# repo_url = "https://github.com/pydantic/pydantic-ai.git"
|
||||
repo_url = "https://github.com/getzep/graphiti.git"
|
||||
await extractor.analyze_repository(repo_url)
|
||||
|
||||
# Direct graph queries
|
||||
print("\n=== Direct Neo4j Queries ===")
|
||||
|
||||
# Which files import from models?
|
||||
results = await extractor.search_graph("files_importing", target="models")
|
||||
print(f"\nFiles importing from 'models': {len(results)}")
|
||||
for result in results[:3]:
|
||||
print(f"- {result['file']} imports {result['imports']}")
|
||||
|
||||
# What classes are in a specific file?
|
||||
results = await extractor.search_graph("classes_in_file", file_path="pydantic_ai/models/openai.py")
|
||||
print(f"\nClasses in openai.py: {len(results)}")
|
||||
for result in results:
|
||||
print(f"- {result['class_name']}")
|
||||
|
||||
# What methods does OpenAIModel have?
|
||||
results = await extractor.search_graph("methods_of_class", class_name="OpenAIModel")
|
||||
print(f"\nMethods of OpenAIModel: {len(results)}")
|
||||
for result in results[:5]:
|
||||
print(f"- {result['method_name']}({', '.join(result['args'])})")
|
||||
|
||||
finally:
|
||||
await extractor.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
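Once a repository has been extracted, the resulting graph can also be queried with plain Cypher outside of `search_graph`. A minimal sketch, assuming the same Neo4j instance and a repository name that was previously analyzed (connection details below are placeholders):

```python
import asyncio
from neo4j import AsyncGraphDatabase

async def largest_classes(repo_name: str = "graphiti"):
    # Placeholder credentials; match the NEO4J_* values used by the extractor.
    driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    try:
        async with driver.session() as session:
            result = await session.run("""
                MATCH (:Repository {name: $repo_name})-[:CONTAINS]->(:File)
                      -[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
                RETURN c.full_name AS class_name, count(m) AS method_count
                ORDER BY method_count DESC
                LIMIT 10
            """, repo_name=repo_name)
            async for record in result:
                print(record["class_name"], record["method_count"])
    finally:
        await driver.close()

asyncio.run(largest_classes())
```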
@@ -0,0 +1,400 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Knowledge Graph Query Tool
|
||||
|
||||
Interactive script to explore what's actually stored in your Neo4j knowledge graph.
|
||||
Useful for debugging hallucination detection and understanding graph contents.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
import sys
|
||||
from dotenv import load_dotenv
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from typing import List, Dict, Any
|
||||
import argparse
|
||||
|
||||
|
||||
class KnowledgeGraphQuerier:
|
||||
"""Interactive tool to query the knowledge graph"""
|
||||
|
||||
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
|
||||
self.neo4j_uri = neo4j_uri
|
||||
self.neo4j_user = neo4j_user
|
||||
self.neo4j_password = neo4j_password
|
||||
self.driver = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize Neo4j connection"""
|
||||
self.driver = AsyncGraphDatabase.driver(
|
||||
self.neo4j_uri,
|
||||
auth=(self.neo4j_user, self.neo4j_password)
|
||||
)
|
||||
print("🔗 Connected to Neo4j knowledge graph")
|
||||
|
||||
async def close(self):
|
||||
"""Close Neo4j connection"""
|
||||
if self.driver:
|
||||
await self.driver.close()
|
||||
|
||||
async def list_repositories(self):
|
||||
"""List all repositories in the knowledge graph"""
|
||||
print("\n📚 Repositories in Knowledge Graph:")
|
||||
print("=" * 50)
|
||||
|
||||
async with self.driver.session() as session:
|
||||
query = "MATCH (r:Repository) RETURN r.name as name ORDER BY r.name"
|
||||
result = await session.run(query)
|
||||
|
||||
repos = []
|
||||
async for record in result:
|
||||
repos.append(record['name'])
|
||||
|
||||
if repos:
|
||||
for i, repo in enumerate(repos, 1):
|
||||
print(f"{i}. {repo}")
|
||||
else:
|
||||
print("No repositories found in knowledge graph.")
|
||||
|
||||
return repos
|
||||
|
||||
async def explore_repository(self, repo_name: str):
|
||||
"""Get overview of a specific repository"""
|
||||
print(f"\n🔍 Exploring Repository: {repo_name}")
|
||||
print("=" * 60)
|
||||
|
||||
async with self.driver.session() as session:
|
||||
# Get file count
|
||||
files_query = """
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)
|
||||
RETURN count(f) as file_count
|
||||
"""
|
||||
result = await session.run(files_query, repo_name=repo_name)
|
||||
file_count = (await result.single())['file_count']
|
||||
|
||||
# Get class count
|
||||
classes_query = """
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
|
||||
RETURN count(DISTINCT c) as class_count
|
||||
"""
|
||||
result = await session.run(classes_query, repo_name=repo_name)
|
||||
class_count = (await result.single())['class_count']
|
||||
|
||||
# Get function count
|
||||
functions_query = """
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
|
||||
RETURN count(DISTINCT func) as function_count
|
||||
"""
|
||||
result = await session.run(functions_query, repo_name=repo_name)
|
||||
function_count = (await result.single())['function_count']
|
||||
|
||||
print(f"📄 Files: {file_count}")
|
||||
print(f"🏗️ Classes: {class_count}")
|
||||
print(f"⚙️ Functions: {function_count}")
|
||||
|
||||
async def list_classes(self, repo_name: str = None, limit: int = 20):
|
||||
"""List classes in the knowledge graph"""
|
||||
title = f"Classes in {repo_name}" if repo_name else "All Classes"
|
||||
print(f"\n🏗️ {title} (limit {limit}):")
|
||||
print("=" * 50)
|
||||
|
||||
async with self.driver.session() as session:
|
||||
if repo_name:
|
||||
query = """
|
||||
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
|
||||
RETURN c.name as name, c.full_name as full_name
|
||||
ORDER BY c.name
|
||||
LIMIT $limit
|
||||
"""
|
||||
result = await session.run(query, repo_name=repo_name, limit=limit)
|
||||
else:
|
||||
query = """
|
||||
MATCH (c:Class)
|
||||
RETURN c.name as name, c.full_name as full_name
|
||||
ORDER BY c.name
|
||||
LIMIT $limit
|
||||
"""
|
||||
result = await session.run(query, limit=limit)
|
||||
|
||||
classes = []
|
||||
async for record in result:
|
||||
classes.append({
|
||||
'name': record['name'],
|
||||
'full_name': record['full_name']
|
||||
})
|
||||
|
||||
if classes:
|
||||
for i, cls in enumerate(classes, 1):
|
||||
print(f"{i:2d}. {cls['name']} ({cls['full_name']})")
|
||||
else:
|
||||
print("No classes found.")
|
||||
|
||||
return classes
|
||||
|
||||
async def explore_class(self, class_name: str):
|
||||
"""Get detailed information about a specific class"""
|
||||
print(f"\n🔍 Exploring Class: {class_name}")
|
||||
print("=" * 60)
|
||||
|
||||
async with self.driver.session() as session:
|
||||
# Find the class
|
||||
class_query = """
|
||||
MATCH (c:Class)
|
||||
WHERE c.name = $class_name OR c.full_name = $class_name
|
||||
RETURN c.name as name, c.full_name as full_name
|
||||
LIMIT 1
|
||||
"""
|
||||
result = await session.run(class_query, class_name=class_name)
|
||||
class_record = await result.single()
|
||||
|
||||
if not class_record:
|
||||
print(f"❌ Class '{class_name}' not found in knowledge graph.")
|
||||
return None
|
||||
|
||||
actual_name = class_record['name']
|
||||
full_name = class_record['full_name']
|
||||
|
||||
print(f"📋 Name: {actual_name}")
|
||||
print(f"📋 Full Name: {full_name}")
|
||||
|
||||
# Get methods
|
||||
methods_query = """
|
||||
MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
|
||||
WHERE c.name = $class_name OR c.full_name = $class_name
|
||||
RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed, m.return_type as return_type
|
||||
ORDER BY m.name
|
||||
"""
|
||||
result = await session.run(methods_query, class_name=class_name)
|
||||
|
||||
methods = []
|
||||
async for record in result:
|
||||
methods.append({
|
||||
'name': record['name'],
|
||||
'params_list': record['params_list'] or [],
|
||||
'params_detailed': record['params_detailed'] or [],
|
||||
'return_type': record['return_type'] or 'Any'
|
||||
})
|
||||
|
||||
if methods:
|
||||
print(f"\n⚙️ Methods ({len(methods)}):")
|
||||
for i, method in enumerate(methods, 1):
|
||||
# Use detailed params if available, fall back to simple params
|
||||
params_to_show = method['params_detailed'] or method['params_list']
|
||||
params = ', '.join(params_to_show) if params_to_show else ''
|
||||
print(f"{i:2d}. {method['name']}({params}) -> {method['return_type']}")
|
||||
else:
|
||||
print("\n⚙️ No methods found.")
|
||||
|
||||
# Get attributes
|
||||
attributes_query = """
|
||||
MATCH (c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
|
||||
WHERE c.name = $class_name OR c.full_name = $class_name
|
||||
RETURN a.name as name, a.type as type
|
||||
ORDER BY a.name
|
||||
"""
|
||||
result = await session.run(attributes_query, class_name=class_name)
|
||||
|
||||
attributes = []
|
||||
async for record in result:
|
||||
attributes.append({
|
||||
'name': record['name'],
|
||||
'type': record['type'] or 'Any'
|
||||
})
|
||||
|
||||
if attributes:
|
||||
print(f"\n📋 Attributes ({len(attributes)}):")
|
||||
for i, attr in enumerate(attributes, 1):
|
||||
print(f"{i:2d}. {attr['name']}: {attr['type']}")
|
||||
else:
|
||||
print("\n📋 No attributes found.")
|
||||
|
||||
return {'methods': methods, 'attributes': attributes}
|
||||
|
||||
async def search_method(self, method_name: str, class_name: str = None):
|
||||
"""Search for methods by name"""
|
||||
title = f"Method '{method_name}'"
|
||||
if class_name:
|
||||
title += f" in class '{class_name}'"
|
||||
|
||||
print(f"\n🔍 Searching for {title}:")
|
||||
print("=" * 60)
|
||||
|
||||
async with self.driver.session() as session:
|
||||
if class_name:
|
||||
query = """
|
||||
MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
|
||||
WHERE (c.name = $class_name OR c.full_name = $class_name)
|
||||
AND m.name = $method_name
|
||||
RETURN c.name as class_name, c.full_name as class_full_name,
|
||||
m.name as method_name, m.params_list as params_list,
|
||||
m.return_type as return_type, m.args as args
|
||||
"""
|
||||
result = await session.run(query, class_name=class_name, method_name=method_name)
|
||||
else:
|
||||
query = """
|
||||
MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
|
||||
WHERE m.name = $method_name
|
||||
RETURN c.name as class_name, c.full_name as class_full_name,
|
||||
m.name as method_name, m.params_list as params_list,
|
||||
m.return_type as return_type, m.args as args
|
||||
ORDER BY c.name
|
||||
"""
|
||||
result = await session.run(query, method_name=method_name)
|
||||
|
||||
methods = []
|
||||
async for record in result:
|
||||
methods.append({
|
||||
'class_name': record['class_name'],
|
||||
'class_full_name': record['class_full_name'],
|
||||
'method_name': record['method_name'],
|
||||
'params_list': record['params_list'] or [],
|
||||
'return_type': record['return_type'] or 'Any',
|
||||
'args': record['args'] or []
|
||||
})
|
||||
|
||||
if methods:
|
||||
for i, method in enumerate(methods, 1):
|
||||
params = ', '.join(method['params_list']) if method['params_list'] else ''
|
||||
print(f"{i}. {method['class_full_name']}.{method['method_name']}({params}) -> {method['return_type']}")
|
||||
if method['args']:
|
||||
print(f" Legacy args: {method['args']}")
|
||||
else:
|
||||
print(f"❌ Method '{method_name}' not found.")
|
||||
|
||||
return methods
|
||||
|
||||
async def run_custom_query(self, query: str):
|
||||
"""Run a custom Cypher query"""
|
||||
print(f"\n🔍 Running Custom Query:")
|
||||
print("=" * 60)
|
||||
print(f"Query: {query}")
|
||||
print("-" * 60)
|
||||
|
||||
async with self.driver.session() as session:
|
||||
try:
|
||||
result = await session.run(query)
|
||||
|
||||
records = []
|
||||
async for record in result:
|
||||
records.append(dict(record))
|
||||
|
||||
if records:
|
||||
for i, record in enumerate(records, 1):
|
||||
print(f"{i:2d}. {record}")
|
||||
if i >= 20: # Limit output
|
||||
print(f"... and {len(records) - 20} more records")
|
||||
break
|
||||
else:
|
||||
print("No results found.")
|
||||
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Query error: {str(e)}")
|
||||
return None
|
||||
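The custom-query escape hatch above accepts arbitrary Cypher; two illustrative read-only queries against the schema produced by the repository extractor (assuming `querier` is an initialized instance):

```python
# Hypothetical ad-hoc queries; run inside an async context.
async def sample_queries(querier):
    # Top-level functions whose annotated return type mentions Dict
    await querier.run_custom_query(
        "MATCH (f:Function) WHERE f.return_type CONTAINS 'Dict' "
        "RETURN f.full_name AS function, f.return_type AS returns LIMIT 10"
    )
    # Node counts per label, a quick sanity check after extraction
    await querier.run_custom_query(
        "MATCH (n) RETURN labels(n)[0] AS label, count(*) AS nodes ORDER BY nodes DESC"
    )
```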
|
||||
|
||||
async def interactive_mode(querier: KnowledgeGraphQuerier):
|
||||
"""Interactive exploration mode"""
|
||||
print("\n🚀 Welcome to Knowledge Graph Explorer!")
|
||||
print("Available commands:")
|
||||
print(" repos - List all repositories")
|
||||
print(" explore <repo> - Explore a specific repository")
|
||||
print(" classes [repo] - List classes (optionally in specific repo)")
|
||||
print(" class <name> - Explore a specific class")
|
||||
print(" method <name> [class] - Search for method")
|
||||
print(" query <cypher> - Run custom Cypher query")
|
||||
print(" quit - Exit")
|
||||
print()
|
||||
|
||||
while True:
|
||||
try:
|
||||
command = input("🔍 > ").strip()
|
||||
|
||||
if not command:
|
||||
continue
|
||||
elif command == "quit":
|
||||
break
|
||||
elif command == "repos":
|
||||
await querier.list_repositories()
|
||||
elif command.startswith("explore "):
|
||||
repo_name = command[8:].strip()
|
||||
await querier.explore_repository(repo_name)
|
||||
elif command == "classes":
|
||||
await querier.list_classes()
|
||||
elif command.startswith("classes "):
|
||||
repo_name = command[8:].strip()
|
||||
await querier.list_classes(repo_name)
|
||||
elif command.startswith("class "):
|
||||
class_name = command[6:].strip()
|
||||
await querier.explore_class(class_name)
|
||||
elif command.startswith("method "):
|
||||
parts = command[7:].strip().split()
|
||||
if len(parts) >= 2:
|
||||
await querier.search_method(parts[0], parts[1])
|
||||
else:
|
||||
await querier.search_method(parts[0])
|
||||
elif command.startswith("query "):
|
||||
query = command[6:].strip()
|
||||
await querier.run_custom_query(query)
|
||||
else:
|
||||
print("❌ Unknown command. Type 'quit' to exit.")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {str(e)}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function with CLI argument support"""
|
||||
parser = argparse.ArgumentParser(description="Query the knowledge graph")
|
||||
parser.add_argument('--repos', action='store_true', help='List repositories')
|
||||
parser.add_argument('--classes', metavar='REPO', nargs='?', const='', help='List classes')
|
||||
parser.add_argument('--explore', metavar='REPO', help='Explore repository')
|
||||
parser.add_argument('--class', dest='class_name', metavar='NAME', help='Explore class')
|
||||
parser.add_argument('--method', nargs='+', metavar=('NAME', 'CLASS'), help='Search method')
|
||||
parser.add_argument('--query', metavar='CYPHER', help='Run custom query')
|
||||
parser.add_argument('--interactive', action='store_true', help='Interactive mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load environment
|
||||
load_dotenv()
|
||||
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
||||
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
|
||||
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
|
||||
|
||||
querier = KnowledgeGraphQuerier(neo4j_uri, neo4j_user, neo4j_password)
|
||||
|
||||
try:
|
||||
await querier.initialize()
|
||||
|
||||
# Execute commands based on arguments
|
||||
if args.repos:
|
||||
await querier.list_repositories()
|
||||
elif args.classes is not None:
|
||||
await querier.list_classes(args.classes if args.classes else None)
|
||||
elif args.explore:
|
||||
await querier.explore_repository(args.explore)
|
||||
elif args.class_name:
|
||||
await querier.explore_class(args.class_name)
|
||||
elif args.method:
|
||||
if len(args.method) >= 2:
|
||||
await querier.search_method(args.method[0], args.method[1])
|
||||
else:
|
||||
await querier.search_method(args.method[0])
|
||||
elif args.query:
|
||||
await querier.run_custom_query(args.query)
|
||||
elif args.interactive or len(sys.argv) == 1:
|
||||
await interactive_mode(querier)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
finally:
|
||||
await querier.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
asyncio.run(main())
|
||||
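Besides the CLI flags above, the querier can be driven programmatically; a minimal sketch (connection values are placeholders and should match the rest of the tooling):

```python
import asyncio

async def inspect_graph():
    querier = KnowledgeGraphQuerier("bolt://localhost:7687", "neo4j", "password")
    try:
        await querier.initialize()
        repos = await querier.list_repositories()
        if repos:
            await querier.explore_repository(repos[0])
            await querier.list_classes(repos[0], limit=10)
    finally:
        await querier.close()

asyncio.run(inspect_graph())
```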
160  servers/mcp-crawl4ai-rag/knowledge_graphs/test_script.py  Normal file
@@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from pydantic import BaseModel, Field
|
||||
from dotenv import load_dotenv
|
||||
from rich.markdown import Markdown
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.models.openai import OpenAIModel
|
||||
from pydantic_ai import Agent, RunContext
|
||||
from graphiti_core import Graphiti
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# ========== Define dependencies ==========
|
||||
@dataclass
|
||||
class GraphitiDependencies:
|
||||
"""Dependencies for the Graphiti agent."""
|
||||
graphiti_client: Graphiti
|
||||
|
||||
# ========== Helper function to get model configuration ==========
|
||||
def get_model():
|
||||
"""Configure and return the LLM model to use."""
|
||||
model_choice = os.getenv('MODEL_CHOICE', 'gpt-4.1-mini')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'no-api-key-provided')
|
||||
|
||||
return OpenAIModel(model_choice, provider=OpenAIProvider(api_key=api_key))
|
||||
|
||||
# ========== Create the Graphiti agent ==========
|
||||
graphiti_agent = Agent(
|
||||
get_model(),
|
||||
system_prompt="""You are a helpful assistant with access to a knowledge graph filled with temporal data about LLMs.
|
||||
When the user asks you a question, use your search tool to query the knowledge graph and then answer honestly.
|
||||
Be willing to admit when you didn't find the information necessary to answer the question.""",
|
||||
deps_type=GraphitiDependencies
|
||||
)
|
||||
|
||||
# ========== Define a result model for Graphiti search ==========
|
||||
class GraphitiSearchResult(BaseModel):
|
||||
"""Model representing a search result from Graphiti."""
|
||||
uuid: str = Field(description="The unique identifier for this fact")
|
||||
fact: str = Field(description="The factual statement retrieved from the knowledge graph")
|
||||
valid_at: Optional[str] = Field(None, description="When this fact became valid (if known)")
|
||||
invalid_at: Optional[str] = Field(None, description="When this fact became invalid (if known)")
|
||||
source_node_uuid: Optional[str] = Field(None, description="UUID of the source node")
|
||||
|
||||
# ========== Graphiti search tool ==========
|
||||
@graphiti_agent.tool
|
||||
async def search_graphiti(ctx: RunContext[GraphitiDependencies], query: str) -> List[GraphitiSearchResult]:
|
||||
"""Search the Graphiti knowledge graph with the given query.
|
||||
|
||||
Args:
|
||||
ctx: The run context containing dependencies
|
||||
query: The search query to find information in the knowledge graph
|
||||
|
||||
Returns:
|
||||
A list of search results containing facts that match the query
|
||||
"""
|
||||
# Access the Graphiti client from dependencies
|
||||
graphiti = ctx.deps.graphiti_client
|
||||
|
||||
try:
|
||||
# Perform the search
|
||||
results = await graphiti.search(query)
|
||||
|
||||
# Format the results
|
||||
formatted_results = []
|
||||
for result in results:
|
||||
formatted_result = GraphitiSearchResult(
|
||||
uuid=result.uuid,
|
||||
fact=result.fact,
|
||||
source_node_uuid=result.source_node_uuid if hasattr(result, 'source_node_uuid') else None
|
||||
)
|
||||
|
||||
# Add temporal information if available
|
||||
if hasattr(result, 'valid_at') and result.valid_at:
|
||||
formatted_result.valid_at = str(result.valid_at)
|
||||
if hasattr(result, 'invalid_at') and result.invalid_at:
|
||||
formatted_result.invalid_at = str(result.invalid_at)
|
||||
|
||||
formatted_results.append(formatted_result)
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
# Log the error but don't close the connection since it's managed by the dependency
|
||||
print(f"Error searching Graphiti: {str(e)}")
|
||||
raise
|
||||
|
||||
# ========== Main execution function ==========
|
||||
async def main():
|
||||
"""Run the Graphiti agent with user queries."""
|
||||
print("Graphiti Agent - Powered by Pydantic AI, Graphiti, and Neo4j")
|
||||
print("Enter 'exit' to quit the program.")
|
||||
|
||||
# Neo4j connection parameters
|
||||
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
||||
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
|
||||
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
|
||||
|
||||
# Initialize Graphiti with Neo4j connection
|
||||
graphiti_client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
|
||||
# Initialize the graph database with graphiti's indices if needed
|
||||
try:
|
||||
await graphiti_client.build_indices_and_constraints()
|
||||
print("Graphiti indices built successfully.")
|
||||
except Exception as e:
|
||||
print(f"Note: {str(e)}")
|
||||
print("Continuing with existing indices...")
|
||||
|
||||
console = Console()
|
||||
messages = []
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get user input
|
||||
user_input = input("\n[You] ")
|
||||
|
||||
# Check if user wants to exit
|
||||
if user_input.lower() in ['exit', 'quit', 'bye', 'goodbye']:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
try:
|
||||
# Process the user input and output the response
|
||||
print("\n[Assistant]")
|
||||
with Live('', console=console, vertical_overflow='visible') as live:
|
||||
# Pass the Graphiti client as a dependency
|
||||
deps = GraphitiDependencies(graphiti_client=graphiti_client)
|
||||
|
||||
async with graphiti_agent.run_a_stream(
|
||||
user_input, message_history=messages, deps=deps
|
||||
) as result:
|
||||
curr_message = ""
|
||||
async for message in result.stream_text(delta=True):
|
||||
curr_message += message
|
||||
live.update(Markdown(curr_message))
|
||||
|
||||
# Add the new messages to the chat history
|
||||
messages.extend(result.all_messages())
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n[Error] An error occurred: {str(e)}")
|
||||
finally:
|
||||
# Close the Graphiti connection when done
|
||||
await graphiti_client.close()
|
||||
print("\nGraphiti connection closed.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nProgram terminated by user.")
|
||||
except Exception as e:
|
||||
print(f"\nUnexpected error: {str(e)}")
|
||||
raise
|
||||