""" Neo4j client helpers for the AI Analysis Service. This module wraps the official Neo4j async driver and exposes a minimal set of convenience methods that we can reuse across the service without sprinkling Cypher execution boilerplate everywhere. """ from __future__ import annotations import json from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime from typing import Any, AsyncIterator, Dict, List, Optional, Sequence from neo4j import AsyncGraphDatabase try: from neo4j import AsyncResult, AsyncSession # type: ignore except ImportError: # pragma: no cover - fallback for older/newer driver versions AsyncResult = Any # type: ignore AsyncSession = Any # type: ignore def _json_dumps(value: Any) -> str: """Serialize complex values so we can persist them as strings safely.""" if value is None: return "" if isinstance(value, (str, int, float, bool)): return str(value) try: return json.dumps(value, default=str) except Exception: return str(value) @dataclass class Neo4jConfig: uri: str user: str password: str database: Optional[str] = None fetch_size: int = 1000 class Neo4jGraphClient: """ Thin wrapper around the Neo4j async driver that provides helpers for writing analysis artefacts into the graph and querying them back. """ def __init__(self, config: Neo4jConfig) -> None: self._config = config self._driver = AsyncGraphDatabase.driver( config.uri, auth=(config.user, config.password), # Allow long running operations while the analysis progresses max_connection_lifetime=3600, ) async def close(self) -> None: if self._driver: await self._driver.close() @asynccontextmanager async def session(self) -> AsyncIterator[AsyncSession]: kwargs: Dict[str, Any] = {} if self._config.database: kwargs["database"] = self._config.database if self._config.fetch_size: kwargs["fetch_size"] = self._config.fetch_size async with self._driver.session(**kwargs) as session: yield session async def _run_write(self, query: str, **params: Any) -> None: async with self.session() as session: async def _write(tx): result = await tx.run(query, **params) await result.consume() await session.execute_write(_write) async def _run_read(self, query: str, **params: Any) -> List[Dict[str, Any]]: async with self.session() as session: result: AsyncResult = await session.run(query, **params) records = await result.data() return records # ------------------------------------------------------------------ # # Write helpers # ------------------------------------------------------------------ # async def upsert_run(self, run_id: str, repository_id: str) -> None: await self._run_write( """ MERGE (r:Run {run_id: $run_id}) ON CREATE SET r.repository_id = $repository_id, r.created_at = datetime(), r.updated_at = datetime() ON MATCH SET r.repository_id = $repository_id, r.updated_at = datetime() """, run_id=run_id, repository_id=repository_id, ) async def clear_run(self, run_id: str) -> None: await self._run_write( """ MATCH (r:Run {run_id: $run_id}) OPTIONAL MATCH (r)-[rel]-() DETACH DELETE r """, run_id=run_id, ) async def upsert_module_graph( self, run_id: str, repository_id: str, module_props: Dict[str, Any], files: Sequence[Dict[str, Any]], findings: Sequence[Dict[str, Any]], dependencies: Sequence[Dict[str, Any]], ) -> None: """ Persist module level artefacts in a single transaction. """ # Ensure strings module_props = {k: _json_dumps(v) if isinstance(v, (dict, list, tuple, set)) else v for k, v in module_props.items()} files_payload = [ { "path": item["path"], "props": { key: _json_dumps(value) if isinstance(value, (dict, list, tuple, set)) else value for key, value in item.get("props", {}).items() }, } for item in files ] findings_payload = [ { "id": item["id"], "props": { key: _json_dumps(value) if isinstance(value, (dict, list, tuple, set)) else value for key, value in item.get("props", {}).items() }, "file_path": item.get("file_path"), } for item in findings ] dependencies_payload = [ { "target": dependency.get("target"), "kind": dependency.get("kind", "depends_on"), "metadata": _json_dumps(dependency.get("metadata", {})), } for dependency in dependencies ] await self._run_write( """ MERGE (r:Run {run_id: $run_id}) ON CREATE SET r.repository_id = $repository_id, r.created_at = datetime(), r.updated_at = datetime() ON MATCH SET r.repository_id = $repository_id, r.updated_at = datetime() MERGE (m:Module {run_id: $run_id, name: $module_name}) SET m += $module_props, m.updated_at = datetime() MERGE (r)-[:RUN_HAS_MODULE]->(m) WITH m UNWIND $files AS file_data MERGE (f:File {run_id: $run_id, path: file_data.path}) SET f += file_data.props, f.updated_at = datetime() MERGE (m)-[:MODULE_INCLUDES_FILE]->(f) WITH m UNWIND $findings AS finding_data MERGE (fd:Finding {run_id: $run_id, finding_id: finding_data.id}) SET fd += finding_data.props, fd.updated_at = datetime() MERGE (m)-[:MODULE_HAS_FINDING]->(fd) FOREACH (fp IN CASE WHEN finding_data.file_path IS NULL THEN [] ELSE [finding_data.file_path] END | MERGE (ff:File {run_id: $run_id, path: fp}) MERGE (fd)-[:FINDING_TOUCHES_FILE]->(ff) ) WITH m UNWIND $dependencies AS dependency FOREACH (_ IN CASE WHEN dependency.target IS NULL THEN [] ELSE [1] END | MERGE (dep:Module {run_id: $run_id, name: dependency.target}) MERGE (m)-[rel:MODULE_DEPENDENCY {kind: dependency.kind}]->(dep) SET rel.metadata = dependency.metadata, rel.updated_at = datetime() ) """, run_id=run_id, repository_id=repository_id, module_name=module_props.get("name"), module_props=module_props, files=files_payload, findings=findings_payload, dependencies=dependencies_payload, ) async def upsert_run_state(self, run_id: str, state: Dict[str, Any]) -> None: await self._run_write( """ MERGE (r:Run {run_id: $run_id}) SET r.analysis_state = $state, r.state_updated_at = datetime() """, run_id=run_id, state=_json_dumps(state), ) async def upsert_synthesis(self, run_id: str, synthesis: Dict[str, Any]) -> None: await self._run_write( """ MERGE (r:Run {run_id: $run_id}) SET r.synthesis_analysis = $synthesis, r.synthesis_updated_at = datetime() """, run_id=run_id, synthesis=_json_dumps(synthesis), ) # ------------------------------------------------------------------ # # Read helpers # ------------------------------------------------------------------ # async def fetch_modules(self, run_id: str) -> List[Dict[str, Any]]: records = await self._run_read( """ MATCH (r:Run {run_id: $run_id})-[:RUN_HAS_MODULE]->(m:Module) OPTIONAL MATCH (m)-[:MODULE_INCLUDES_FILE]->(f:File) OPTIONAL MATCH (m)-[:MODULE_HAS_FINDING]->(fd:Finding) OPTIONAL MATCH (fd)-[:FINDING_TOUCHES_FILE]->(ff:File) RETURN m, collect(DISTINCT properties(f)) AS files, collect(DISTINCT properties(fd)) AS findings, collect(DISTINCT properties(ff)) AS finding_files """, run_id=run_id, ) modules: List[Dict[str, Any]] = [] for record in records: module_node = record.get("m", {}) files = record.get("files", []) findings = record.get("findings", []) finding_files = record.get("finding_files", []) modules.append( { "module": module_node, "files": files, "findings": findings, "finding_files": finding_files, } ) return modules async def fetch_run_state(self, run_id: str) -> Optional[Dict[str, Any]]: records = await self._run_read( """ MATCH (r:Run {run_id: $run_id}) RETURN r.analysis_state AS analysis_state """, run_id=run_id, ) if not records: return None raw_state = records[0].get("analysis_state") if not raw_state: return None try: return json.loads(raw_state) except json.JSONDecodeError: return {"raw": raw_state} async def fetch_synthesis(self, run_id: str) -> Optional[Dict[str, Any]]: records = await self._run_read( """ MATCH (r:Run {run_id: $run_id}) RETURN r.synthesis_analysis AS synthesis """, run_id=run_id, ) if not records: return None raw_synthesis = records[0].get("synthesis") if not raw_synthesis: return None try: return json.loads(raw_synthesis) except json.JSONDecodeError: return {"raw": raw_synthesis} async def fetch_run_metadata(self, run_id: str) -> Optional[Dict[str, Any]]: records = await self._run_read( """ MATCH (r:Run {run_id: $run_id}) RETURN r """, run_id=run_id, ) if not records: return None run_node = records[0].get("r") if not run_node: return None metadata = dict(run_node) if "created_at" in metadata and isinstance(metadata["created_at"], datetime): metadata["created_at"] = metadata["created_at"].isoformat() if "updated_at" in metadata and isinstance(metadata["updated_at"], datetime): metadata["updated_at"] = metadata["updated_at"].isoformat() return metadata