"""
|
|
Neo4j client helpers for the AI Analysis Service.
|
|
|
|
This module wraps the official Neo4j async driver and exposes a minimal
|
|
set of convenience methods that we can reuse across the service without
|
|
sprinkling Cypher execution boilerplate everywhere.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence
|
|
|
|
from neo4j import AsyncGraphDatabase
|
|
|
|
try:
|
|
from neo4j import AsyncResult, AsyncSession # type: ignore
|
|
except ImportError: # pragma: no cover - fallback for older/newer driver versions
|
|
AsyncResult = Any # type: ignore
|
|
AsyncSession = Any # type: ignore
|
|
|
|
|
|
def _json_dumps(value: Any) -> str:
    """Serialize complex values so we can persist them as strings safely."""
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    try:
        return json.dumps(value, default=str)
    except Exception:
        return str(value)

@dataclass
class Neo4jConfig:
    """Connection settings for the Neo4j graph database."""

    uri: str
    user: str
    password: str
    database: Optional[str] = None
    fetch_size: int = 1000

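# Illustrative sketch (not part of the original module): one way a caller
# might build a Neo4jConfig from environment variables. The variable names
# (NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD, NEO4J_DATABASE) are assumptions,
# not settings the service is known to use.
#
#     import os
#
#     def config_from_env() -> Neo4jConfig:
#         return Neo4jConfig(
#             uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
#             user=os.environ.get("NEO4J_USER", "neo4j"),
#             password=os.environ["NEO4J_PASSWORD"],
#             database=os.environ.get("NEO4J_DATABASE") or None,
#         )
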
class Neo4jGraphClient:
    """
    Thin wrapper around the Neo4j async driver that provides helpers for
    writing analysis artefacts into the graph and querying them back.
    """

    def __init__(self, config: Neo4jConfig) -> None:
        self._config = config
        self._driver = AsyncGraphDatabase.driver(
            config.uri,
            auth=(config.user, config.password),
            # Allow long running operations while the analysis progresses
            max_connection_lifetime=3600,
        )

    async def close(self) -> None:
        if self._driver:
            await self._driver.close()

    @asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        kwargs: Dict[str, Any] = {}
        if self._config.database:
            kwargs["database"] = self._config.database
        if self._config.fetch_size:
            kwargs["fetch_size"] = self._config.fetch_size
        async with self._driver.session(**kwargs) as session:
            yield session

    async def _run_write(self, query: str, **params: Any) -> None:
        """Execute a write query inside a managed (retryable) transaction."""
        async with self.session() as session:

            async def _write(tx):
                result = await tx.run(query, **params)
                await result.consume()

            await session.execute_write(_write)

    async def _run_read(self, query: str, **params: Any) -> List[Dict[str, Any]]:
        """Execute a read query and return the records as plain dictionaries."""
        async with self.session() as session:
            result: AsyncResult = await session.run(query, **params)
            records = await result.data()
            return records

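    # Illustrative sketch (not part of the original module): the session()
    # context manager can also be used directly for ad-hoc Cypher, mirroring
    # what _run_read does internally. `client` is a hypothetical instance.
    #
    #     async with client.session() as session:
    #         result = await session.run("RETURN 1 AS ok")
    #         print(await result.data())
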
    # ------------------------------------------------------------------ #
    # Write helpers
    # ------------------------------------------------------------------ #

    async def upsert_run(self, run_id: str, repository_id: str) -> None:
        await self._run_write(
            """
            MERGE (r:Run {run_id: $run_id})
            ON CREATE SET
                r.repository_id = $repository_id,
                r.created_at = datetime(),
                r.updated_at = datetime()
            ON MATCH SET
                r.repository_id = $repository_id,
                r.updated_at = datetime()
            """,
            run_id=run_id,
            repository_id=repository_id,
        )

    async def clear_run(self, run_id: str) -> None:
        await self._run_write(
            """
            MATCH (r:Run {run_id: $run_id})
            // DETACH DELETE removes the node together with all of its relationships
            DETACH DELETE r
            """,
            run_id=run_id,
        )

    async def upsert_module_graph(
        self,
        run_id: str,
        repository_id: str,
        module_props: Dict[str, Any],
        files: Sequence[Dict[str, Any]],
        findings: Sequence[Dict[str, Any]],
        dependencies: Sequence[Dict[str, Any]],
    ) -> None:
        """
        Persist module level artefacts in a single transaction.
        """
        # Neo4j properties cannot hold nested maps/lists, so serialise any
        # complex values to JSON strings before writing them.
        module_props = {
            k: _json_dumps(v) if isinstance(v, (dict, list, tuple, set)) else v
            for k, v in module_props.items()
        }
        files_payload = [
            {
                "path": item["path"],
                "props": {
                    key: _json_dumps(value) if isinstance(value, (dict, list, tuple, set)) else value
                    for key, value in item.get("props", {}).items()
                },
            }
            for item in files
        ]
        findings_payload = [
            {
                "id": item["id"],
                "props": {
                    key: _json_dumps(value) if isinstance(value, (dict, list, tuple, set)) else value
                    for key, value in item.get("props", {}).items()
                },
                "file_path": item.get("file_path"),
            }
            for item in findings
        ]
        dependencies_payload = [
            {
                "target": dependency.get("target"),
                "kind": dependency.get("kind", "depends_on"),
                "metadata": _json_dumps(dependency.get("metadata", {})),
            }
            for dependency in dependencies
        ]

        # FOREACH is used for the per-item writes instead of UNWIND: UNWIND over
        # an empty list yields zero rows and would silently skip the remaining
        # clauses (e.g. findings and dependencies when a module has no files).
        await self._run_write(
            """
            MERGE (r:Run {run_id: $run_id})
            ON CREATE SET
                r.repository_id = $repository_id,
                r.created_at = datetime(),
                r.updated_at = datetime()
            ON MATCH SET
                r.repository_id = $repository_id,
                r.updated_at = datetime()

            MERGE (m:Module {run_id: $run_id, name: $module_name})
            SET m += $module_props,
                m.updated_at = datetime()

            MERGE (r)-[:RUN_HAS_MODULE]->(m)

            FOREACH (file_data IN $files |
                MERGE (f:File {run_id: $run_id, path: file_data.path})
                SET f += file_data.props,
                    f.updated_at = datetime()
                MERGE (m)-[:MODULE_INCLUDES_FILE]->(f)
            )

            FOREACH (finding_data IN $findings |
                MERGE (fd:Finding {run_id: $run_id, finding_id: finding_data.id})
                SET fd += finding_data.props,
                    fd.updated_at = datetime()
                MERGE (m)-[:MODULE_HAS_FINDING]->(fd)
                FOREACH (fp IN CASE WHEN finding_data.file_path IS NULL THEN [] ELSE [finding_data.file_path] END |
                    MERGE (ff:File {run_id: $run_id, path: fp})
                    MERGE (fd)-[:FINDING_TOUCHES_FILE]->(ff)
                )
            )

            FOREACH (dependency IN [d IN $dependencies WHERE d.target IS NOT NULL] |
                MERGE (target:Module {run_id: $run_id, name: dependency.target})
                MERGE (m)-[rel:MODULE_DEPENDENCY {kind: dependency.kind}]->(target)
                SET rel.metadata = dependency.metadata,
                    rel.updated_at = datetime()
            )
            """,
            run_id=run_id,
            repository_id=repository_id,
            module_name=module_props.get("name"),
            module_props=module_props,
            files=files_payload,
            findings=findings_payload,
            dependencies=dependencies_payload,
        )

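    # Illustrative call (not from the original service): the concrete values
    # below are made up, and the payload shapes are inferred from how
    # upsert_module_graph reads its arguments.
    #
    #     await client.upsert_module_graph(
    #         run_id="run-123",
    #         repository_id="repo-456",
    #         module_props={"name": "auth", "language": "python"},
    #         files=[{"path": "auth/login.py", "props": {"loc": 120}}],
    #         findings=[
    #             {"id": "F-1", "props": {"severity": "high"}, "file_path": "auth/login.py"},
    #         ],
    #         dependencies=[{"target": "database", "kind": "imports", "metadata": {"calls": 3}}],
    #     )
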
    async def upsert_run_state(self, run_id: str, state: Dict[str, Any]) -> None:
        await self._run_write(
            """
            MERGE (r:Run {run_id: $run_id})
            SET r.analysis_state = $state,
                r.state_updated_at = datetime()
            """,
            run_id=run_id,
            state=_json_dumps(state),
        )

    async def upsert_synthesis(self, run_id: str, synthesis: Dict[str, Any]) -> None:
        await self._run_write(
            """
            MERGE (r:Run {run_id: $run_id})
            SET r.synthesis_analysis = $synthesis,
                r.synthesis_updated_at = datetime()
            """,
            run_id=run_id,
            synthesis=_json_dumps(synthesis),
        )

    # ------------------------------------------------------------------ #
    # Read helpers
    # ------------------------------------------------------------------ #

    async def fetch_modules(self, run_id: str) -> List[Dict[str, Any]]:
        records = await self._run_read(
            """
            MATCH (r:Run {run_id: $run_id})-[:RUN_HAS_MODULE]->(m:Module)
            OPTIONAL MATCH (m)-[:MODULE_INCLUDES_FILE]->(f:File)
            OPTIONAL MATCH (m)-[:MODULE_HAS_FINDING]->(fd:Finding)
            OPTIONAL MATCH (fd)-[:FINDING_TOUCHES_FILE]->(ff:File)
            RETURN
                m,
                collect(DISTINCT properties(f)) AS files,
                collect(DISTINCT properties(fd)) AS findings,
                collect(DISTINCT properties(ff)) AS finding_files
            """,
            run_id=run_id,
        )
        modules: List[Dict[str, Any]] = []
        for record in records:
            module_node = record.get("m", {})
            files = record.get("files", [])
            findings = record.get("findings", [])
            finding_files = record.get("finding_files", [])
            modules.append(
                {
                    "module": module_node,
                    "files": files,
                    "findings": findings,
                    "finding_files": finding_files,
                }
            )
        return modules

    async def fetch_run_state(self, run_id: str) -> Optional[Dict[str, Any]]:
        records = await self._run_read(
            """
            MATCH (r:Run {run_id: $run_id})
            RETURN r.analysis_state AS analysis_state
            """,
            run_id=run_id,
        )
        if not records:
            return None
        raw_state = records[0].get("analysis_state")
        if not raw_state:
            return None
        try:
            return json.loads(raw_state)
        except json.JSONDecodeError:
            return {"raw": raw_state}

    async def fetch_synthesis(self, run_id: str) -> Optional[Dict[str, Any]]:
        records = await self._run_read(
            """
            MATCH (r:Run {run_id: $run_id})
            RETURN r.synthesis_analysis AS synthesis
            """,
            run_id=run_id,
        )
        if not records:
            return None
        raw_synthesis = records[0].get("synthesis")
        if not raw_synthesis:
            return None
        try:
            return json.loads(raw_synthesis)
        except json.JSONDecodeError:
            return {"raw": raw_synthesis}

    async def fetch_run_metadata(self, run_id: str) -> Optional[Dict[str, Any]]:
        records = await self._run_read(
            """
            MATCH (r:Run {run_id: $run_id})
            RETURN r
            """,
            run_id=run_id,
        )
        if not records:
            return None
        run_node = records[0].get("r")
        if not run_node:
            return None
        metadata = dict(run_node)
        if "created_at" in metadata and isinstance(metadata["created_at"], datetime):
            metadata["created_at"] = metadata["created_at"].isoformat()
        if "updated_at" in metadata and isinstance(metadata["updated_at"], datetime):
            metadata["updated_at"] = metadata["updated_at"].isoformat()
        return metadata
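

# ------------------------------------------------------------------ #
# Example usage (illustrative only)
# ------------------------------------------------------------------ #
# The connection details below are placeholders rather than values used by
# the service; this sketch only shows the intended call pattern of the client.
#
#     import asyncio
#
#     async def _demo() -> None:
#         client = Neo4jGraphClient(
#             Neo4jConfig(uri="bolt://localhost:7687", user="neo4j", password="secret")
#         )
#         try:
#             await client.upsert_run("run-123", "repo-456")
#             print(await client.fetch_run_metadata("run-123"))
#         finally:
#             await client.close()
#
#     asyncio.run(_demo())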