# codenuk_backend_mine/services/ai-analysis-service/knowledge_graph/neo4j_client.py
"""
Neo4j client helpers for the AI Analysis Service.
This module wraps the official Neo4j async driver and exposes a minimal
set of convenience methods that we can reuse across the service without
sprinkling Cypher execution boilerplate everywhere.
"""
from __future__ import annotations

import json
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence

from neo4j import AsyncGraphDatabase

try:
    from neo4j import AsyncResult, AsyncSession  # type: ignore
except ImportError:  # pragma: no cover - fallback for older/newer driver versions
    AsyncResult = Any  # type: ignore
    AsyncSession = Any  # type: ignore


def _json_dumps(value: Any) -> str:
    """Serialize complex values so we can persist them as strings safely."""
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    try:
        return json.dumps(value, default=str)
    except Exception:
        return str(value)
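
# Illustrative behaviour of the helper above (values chosen for the example):
#
#     _json_dumps({"k": [1, 2]}) -> '{"k": [1, 2]}'
#     _json_dumps(None)          -> ''
#     _json_dumps(True)          -> 'True'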


@dataclass
class Neo4jConfig:
    uri: str
    user: str
    password: str
    database: Optional[str] = None
    fetch_size: int = 1000
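

# Construction sketch (the environment variable names below are illustrative
# assumptions, not read anywhere in this module):
#
#     config = Neo4jConfig(
#         uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
#         user=os.environ.get("NEO4J_USER", "neo4j"),
#         password=os.environ["NEO4J_PASSWORD"],
#     )
#     client = Neo4jGraphClient(config)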


class Neo4jGraphClient:
    """
    Thin wrapper around the Neo4j async driver that provides helpers for
    writing analysis artefacts into the graph and querying them back.
    """

    def __init__(self, config: Neo4jConfig) -> None:
        self._config = config
        self._driver = AsyncGraphDatabase.driver(
            config.uri,
            auth=(config.user, config.password),
            # Allow long running operations while the analysis progresses
            max_connection_lifetime=3600,
        )

    async def close(self) -> None:
        if self._driver:
            await self._driver.close()

    @asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        kwargs: Dict[str, Any] = {}
        if self._config.database:
            kwargs["database"] = self._config.database
        if self._config.fetch_size:
            kwargs["fetch_size"] = self._config.fetch_size
        async with self._driver.session(**kwargs) as session:
            yield session

    async def _run_write(self, query: str, **params: Any) -> None:
        async with self.session() as session:

            async def _write(tx):
                result = await tx.run(query, **params)
                await result.consume()

            await session.execute_write(_write)

    async def _run_read(self, query: str, **params: Any) -> List[Dict[str, Any]]:
        async with self.session() as session:
            result: AsyncResult = await session.run(query, **params)
            records = await result.data()
            return records
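
    # _run_read issues the query in an auto-commit transaction. A sketch of a
    # retry-aware alternative (assuming driver 5.x, which provides
    # AsyncSession.execute_read) would mirror _run_write:
    #
    #     async def _read(tx):
    #         result = await tx.run(query, **params)
    #         return await result.data()
    #
    #     return await session.execute_read(_read)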

    # ------------------------------------------------------------------ #
    # Write helpers
    # ------------------------------------------------------------------ #
    async def upsert_run(self, run_id: str, repository_id: str) -> None:
        await self._run_write(
            """
            MERGE (r:Run {run_id: $run_id})
            ON CREATE SET
                r.repository_id = $repository_id,
                r.created_at = datetime(),
                r.updated_at = datetime()
            ON MATCH SET
                r.repository_id = $repository_id,
                r.updated_at = datetime()
            """,
            run_id=run_id,
            repository_id=repository_id,
        )

    async def clear_run(self, run_id: str) -> None:
        # DETACH DELETE removes the node together with all of its
        # relationships in one step, so no separate relationship MATCH
        # is needed.
        await self._run_write(
            """
            MATCH (r:Run {run_id: $run_id})
            DETACH DELETE r
            """,
            run_id=run_id,
        )

    async def upsert_module_graph(
        self,
        run_id: str,
        repository_id: str,
        module_props: Dict[str, Any],
        files: Sequence[Dict[str, Any]],
        findings: Sequence[Dict[str, Any]],
        dependencies: Sequence[Dict[str, Any]],
    ) -> None:
        """
        Persist module-level artefacts in a single transaction.
        """
        # Neo4j properties cannot hold nested maps, so stringify any complex
        # values before sending them to the server.
        module_props = {
            k: _json_dumps(v) if isinstance(v, (dict, list, tuple, set)) else v
            for k, v in module_props.items()
        }
        files_payload = [
            {
                "path": item["path"],
                "props": {
                    key: _json_dumps(value) if isinstance(value, (dict, list, tuple, set)) else value
                    for key, value in item.get("props", {}).items()
                },
            }
            for item in files
        ]
        findings_payload = [
            {
                "id": item["id"],
                "props": {
                    key: _json_dumps(value) if isinstance(value, (dict, list, tuple, set)) else value
                    for key, value in item.get("props", {}).items()
                },
                "file_path": item.get("file_path"),
            }
            for item in findings
        ]
        dependencies_payload = [
            {
                "target": dependency.get("target"),
                "kind": dependency.get("kind", "depends_on"),
                "metadata": _json_dumps(dependency.get("metadata", {})),
            }
            for dependency in dependencies
        ]
        # UNWIND over an empty list yields zero rows and would silently skip
        # every clause after it, so each collection is written with FOREACH,
        # which iterates safely over empty lists.
        await self._run_write(
            """
            MERGE (r:Run {run_id: $run_id})
            ON CREATE SET
                r.repository_id = $repository_id,
                r.created_at = datetime(),
                r.updated_at = datetime()
            ON MATCH SET
                r.repository_id = $repository_id,
                r.updated_at = datetime()
            MERGE (m:Module {run_id: $run_id, name: $module_name})
            SET m += $module_props,
                m.updated_at = datetime()
            MERGE (r)-[:RUN_HAS_MODULE]->(m)
            FOREACH (file_data IN $files |
                MERGE (f:File {run_id: $run_id, path: file_data.path})
                SET f += file_data.props,
                    f.updated_at = datetime()
                MERGE (m)-[:MODULE_INCLUDES_FILE]->(f)
            )
            FOREACH (finding_data IN $findings |
                MERGE (fd:Finding {run_id: $run_id, finding_id: finding_data.id})
                SET fd += finding_data.props,
                    fd.updated_at = datetime()
                MERGE (m)-[:MODULE_HAS_FINDING]->(fd)
                FOREACH (fp IN CASE WHEN finding_data.file_path IS NULL THEN [] ELSE [finding_data.file_path] END |
                    MERGE (ff:File {run_id: $run_id, path: fp})
                    MERGE (fd)-[:FINDING_TOUCHES_FILE]->(ff)
                )
            )
            FOREACH (dependency IN [d IN $dependencies WHERE d.target IS NOT NULL] |
                MERGE (dep:Module {run_id: $run_id, name: dependency.target})
                MERGE (m)-[rel:MODULE_DEPENDENCY {kind: dependency.kind}]->(dep)
                SET rel.metadata = dependency.metadata,
                    rel.updated_at = datetime()
            )
            """,
            run_id=run_id,
            repository_id=repository_id,
            module_name=module_props.get("name"),
            module_props=module_props,
            files=files_payload,
            findings=findings_payload,
            dependencies=dependencies_payload,
        )
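
    # Expected payload shapes, derived from the keys the payload builders read
    # above (the values themselves are illustrative only):
    #
    #     files        = [{"path": "src/app.py", "props": {"language": "python"}}]
    #     findings     = [{"id": "F-1", "props": {"severity": "high"},
    #                      "file_path": "src/app.py"}]
    #     dependencies = [{"target": "other-module", "kind": "imports",
    #                      "metadata": {"via": "static-analysis"}}]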

    async def upsert_run_state(self, run_id: str, state: Dict[str, Any]) -> None:
        await self._run_write(
            """
            MERGE (r:Run {run_id: $run_id})
            SET r.analysis_state = $state,
                r.state_updated_at = datetime()
            """,
            run_id=run_id,
            state=_json_dumps(state),
        )

    async def upsert_synthesis(self, run_id: str, synthesis: Dict[str, Any]) -> None:
        await self._run_write(
            """
            MERGE (r:Run {run_id: $run_id})
            SET r.synthesis_analysis = $synthesis,
                r.synthesis_updated_at = datetime()
            """,
            run_id=run_id,
            synthesis=_json_dumps(synthesis),
        )

    # ------------------------------------------------------------------ #
    # Read helpers
    # ------------------------------------------------------------------ #
    async def fetch_modules(self, run_id: str) -> List[Dict[str, Any]]:
        records = await self._run_read(
            """
            MATCH (r:Run {run_id: $run_id})-[:RUN_HAS_MODULE]->(m:Module)
            OPTIONAL MATCH (m)-[:MODULE_INCLUDES_FILE]->(f:File)
            OPTIONAL MATCH (m)-[:MODULE_HAS_FINDING]->(fd:Finding)
            OPTIONAL MATCH (fd)-[:FINDING_TOUCHES_FILE]->(ff:File)
            RETURN
                m,
                collect(DISTINCT properties(f)) AS files,
                collect(DISTINCT properties(fd)) AS findings,
                collect(DISTINCT properties(ff)) AS finding_files
            """,
            run_id=run_id,
        )
        modules: List[Dict[str, Any]] = []
        for record in records:
            module_node = record.get("m", {})
            files = record.get("files", [])
            findings = record.get("findings", [])
            finding_files = record.get("finding_files", [])
            modules.append(
                {
                    "module": module_node,
                    "files": files,
                    "findings": findings,
                    "finding_files": finding_files,
                }
            )
        return modules

    async def fetch_run_state(self, run_id: str) -> Optional[Dict[str, Any]]:
        records = await self._run_read(
            """
            MATCH (r:Run {run_id: $run_id})
            RETURN r.analysis_state AS analysis_state
            """,
            run_id=run_id,
        )
        if not records:
            return None
        raw_state = records[0].get("analysis_state")
        if not raw_state:
            return None
        try:
            return json.loads(raw_state)
        except json.JSONDecodeError:
            return {"raw": raw_state}

    async def fetch_synthesis(self, run_id: str) -> Optional[Dict[str, Any]]:
        records = await self._run_read(
            """
            MATCH (r:Run {run_id: $run_id})
            RETURN r.synthesis_analysis AS synthesis
            """,
            run_id=run_id,
        )
        if not records:
            return None
        raw_synthesis = records[0].get("synthesis")
        if not raw_synthesis:
            return None
        try:
            return json.loads(raw_synthesis)
        except json.JSONDecodeError:
            return {"raw": raw_synthesis}

    async def fetch_run_metadata(self, run_id: str) -> Optional[Dict[str, Any]]:
        records = await self._run_read(
            """
            MATCH (r:Run {run_id: $run_id})
            RETURN r
            """,
            run_id=run_id,
        )
        if not records:
            return None
        run_node = records[0].get("r")
        if not run_node:
            return None
        metadata = dict(run_node)
        # The driver returns temporal properties as neo4j.time.DateTime, which
        # is not a datetime.datetime subclass; convert via to_native() first so
        # the isoformat() branch actually fires.
        for key in ("created_at", "updated_at"):
            value = metadata.get(key)
            if hasattr(value, "to_native"):
                value = value.to_native()
            if isinstance(value, datetime):
                metadata[key] = value.isoformat()
        return metadata
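

if __name__ == "__main__":  # pragma: no cover
    # Minimal round-trip sketch. The URI and credentials below are
    # illustrative assumptions, not configuration this module ships with.
    import asyncio

    async def _demo() -> None:
        client = Neo4jGraphClient(
            Neo4jConfig(uri="bolt://localhost:7687", user="neo4j", password="password")
        )
        try:
            await client.upsert_run("run-1", "repo-1")
            await client.upsert_run_state("run-1", {"phase": "started"})
            print(await client.fetch_run_metadata("run-1"))
        finally:
            await client.close()

    asyncio.run(_demo())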