"""
SpurrinAI - Intelligent Document Processing and Question Answering System

Copyright (c) 2024 Tech4biz. All rights reserved.

This module implements the main Flask application for the SpurrinAI system,
providing REST APIs for document processing, vector storage, and question answering
using RAG (Retrieval Augmented Generation) architecture.

Author: Tech4biz Development Team
Version: 1.0.0
Last Updated: 2024-01-19
"""

# Standard library imports
import os
import re
import sys
import json
import time
import threading
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum

# Third-party imports
import spacy
import redis
import aiomysql
from dotenv import load_dotenv
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from tqdm import tqdm
from tqdm.asyncio import tqdm as tqdm_async
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from openai import OpenAI
from rapidfuzz import process
from threading import Lock

# Local imports
from model_manager import ModelManager

# Suppress warnings
import warnings

warnings.filterwarnings("ignore")

# Initialize NLTK
import nltk

nltk.download("punkt")

# Configure logging
import logging
import logging.handlers

app = Flask(__name__)
CORS(app)

script_dir = os.path.dirname(os.path.abspath(__file__))
log_file_path = os.path.join(script_dir, "error.log")
logging.basicConfig(filename=log_file_path, level=logging.INFO)

# Configure logging
def setup_logging():
    log_dir = os.path.join(script_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    main_log = os.path.join(log_dir, "app.log")
    error_log = os.path.join(log_dir, "error.log")
    access_log = os.path.join(log_dir, "access.log")
    perf_log = os.path.join(log_dir, "performance.log")

    # Create formatters
    detailed_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
    )
    access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Main logger setup
    main_handler = logging.handlers.RotatingFileHandler(
        main_log, maxBytes=10485760, backupCount=5
    )
    main_handler.setFormatter(detailed_formatter)
    main_handler.setLevel(logging.INFO)

    # Error logger setup
    error_handler = logging.handlers.RotatingFileHandler(
        error_log, maxBytes=10485760, backupCount=5
    )
    error_handler.setFormatter(detailed_formatter)
    error_handler.setLevel(logging.ERROR)

    # Access logger setup
    access_handler = logging.handlers.TimedRotatingFileHandler(
        access_log, when="midnight", interval=1, backupCount=30
    )
    access_handler.setFormatter(access_formatter)
    access_handler.setLevel(logging.INFO)

    # Performance logger setup
    perf_handler = logging.handlers.RotatingFileHandler(
        perf_log, maxBytes=10485760, backupCount=5
    )
    perf_handler.setFormatter(detailed_formatter)
    perf_handler.setLevel(logging.INFO)

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(main_handler)
    root_logger.addHandler(error_handler)

    # Create specific loggers
    access_logger = logging.getLogger("access")
    access_logger.addHandler(access_handler)
    access_logger.setLevel(logging.INFO)

    perf_logger = logging.getLogger("performance")
    perf_logger.addHandler(perf_handler)
    perf_logger.setLevel(logging.INFO)

    return root_logger, access_logger, perf_logger

# Initialize loggers
logger, access_logger, perf_logger = setup_logging()

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'flaskuser',
#     'password': 'Flask@123',
#     'database': 'spurrinai',
# }

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'spurrindevuser',
#     'password': 'Admin@123',
#     'database': 'spurrindev',
# }

# Redis Configuration
# REDIS_CONFIG = {
#     "host": "localhost",
#     "port": 6379,
#     "db": 0,
#     "decode_responses": True,  # For string operations
# }

# DB_CONFIG = {
#     "host": os.getenv("DB_HOST", "localhost"),
#     "user": os.getenv("DB_USER", "spurrinuser"),
#     "password": os.getenv("DB_PASSWORD", "Admin@123"),
#     "database": os.getenv("DB_NAME", "spurrin-live"),
# }

REDIS_CONFIG = {
    "host": "localhost",
    "port": 6379,
    "db": 0,
    "decode_responses": True,  # For string operations
}

DB_CONFIG = {
    "host": os.getenv("DB_HOST", "localhost"),
    "user": os.getenv("DB_USER", "testuser"),
    "password": os.getenv("DB_PASSWORD", "Admin@123"),
    "database": os.getenv("DB_NAME", "spurrintest"),
}

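# The values above can be overridden via environment variables (read with
# os.getenv) before the app starts, e.g. from the .env file loaded by
# load_dotenv() below. Illustrative placeholder values:
#   DB_HOST=localhost
#   DB_USER=testuser
#   DB_PASSWORD=<your-password>
#   DB_NAME=spurrintest
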
# Redis connection pool
redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
redis_binary_pool = redis.ConnectionPool(
    host="localhost", port=6379, db=1, decode_responses=False
)

def get_redis_client(binary=False):
    """Get Redis client from pool"""
    logger.debug(f"Getting Redis client with binary={binary}")
    try:
        pool = redis_binary_pool if binary else redis_pool
        client = redis.Redis(connection_pool=pool)
        logger.debug("Redis client created successfully")
        return client
    except Exception as e:
        logger.error(f"Failed to create Redis client: {e}", exc_info=True)
        raise

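# Example usage (illustrative): plain string values go through the default
# pool (db 0, decode_responses=True); binary payloads use the separate pool
# on db 1 that returns raw bytes.
#   text_client = get_redis_client()               # decoded strings, db 0
#   binary_client = get_redis_client(binary=True)  # raw bytes, db 1
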
def fetch_cached_answer(cache_key):
    logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
    start_time = time.time()
    try:
        redis_client = get_redis_client()
        cached_answer = redis_client.get(cache_key)
        fetch_time = time.time() - start_time
        perf_logger.info(
            f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
        )
        return cached_answer
    except Exception as e:
        logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
        return None

# Cache TTL configurations
CACHE_TTL = {
    "vector_store": timedelta(hours=24),
    "chat_completion": timedelta(hours=1),
    "document_metadata": timedelta(days=7),
}

DATA_DIR = os.path.join(script_dir, "hospital_data")
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")
uploads_dir = os.path.join(script_dir, "llm-uploads")

if not os.path.exists(uploads_dir):
    os.makedirs(uploads_dir)

nlp = spacy.load("en_core_web_sm")

load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

client = OpenAI(api_key=OPENAI_API_KEY)
embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
)
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
hospital_vector_stores = {}
vector_store_lock = threading.Lock()

@dataclass
class Document:
    doc_id: int
    page_num: int
    content: str


class DocumentStatus(Enum):
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"

async def get_db_pool():
    return await aiomysql.create_pool(
        host=DB_CONFIG["host"],
        user=DB_CONFIG["user"],
        password=DB_CONFIG["password"],
        db=DB_CONFIG["database"],
        autocommit=True,
    )

async def get_hospital_id(hospital_code):
    pool = None
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(
                    "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
                    (hospital_code,),
                )
                result = await cursor.fetchone()
                return result["id"] if result else None
    except Exception as error:
        logging.error(f"Database error: {error}")
        return None
    finally:
        # Guard against get_db_pool() failing before the pool was created
        if pool is not None:
            pool.close()
            await pool.wait_closed()

CHUNK_SIZE = 1000
CHUNK_OVERLAP = 50
BATCH_SIZE = 1000

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    # length_function=len,
    # separators=["\n\n", "\n", ". ", " ", ""]
)

# Update the JSON_PATH to be dynamic based on hospital_id
def get_icd_json_path(hospital_id):
    hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
    os.makedirs(hospital_data_dir, exist_ok=True)
    return os.path.join(hospital_data_dir, "icd_data.json")

# Module-level lock so concurrent writers to the same ICD JSON file are
# serialized (a fresh threading.Lock() created per call would provide no
# mutual exclusion).
icd_json_lock = threading.Lock()


def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
    """Extract and process ICD codes with optimized processing and optional JSON saving"""
    try:
        # Initialize pattern compilation once
        pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)

        # Process in chunks for large content
        chunk_size = 50000  # Process 50KB at a time
        icd_data = []

        current_code = None
        current_description = []

        # Split content into manageable chunks
        content_chunks = [
            content[i : i + chunk_size] for i in range(0, len(content), chunk_size)
        ]

        # Process each chunk
        for chunk in content_chunks:
            lines = chunk.splitlines()

            for line in lines:
                line = line.strip()
                if not line:
                    if current_code and current_description:
                        icd_data.append(
                            {
                                "code": current_code,
                                "description": " ".join(current_description).strip(),
                            }
                        )
                        current_code = None
                        current_description = []
                    continue

                match = pattern.match(line)
                if match:
                    if current_code and current_description:
                        icd_data.append(
                            {
                                "code": current_code,
                                "description": " ".join(current_description).strip(),
                            }
                        )
                    current_code, description = match.groups()
                    current_description = [description.strip()]
                elif current_code:
                    current_description.append(line)

        # Add final entry if exists
        if current_code and current_description:
            icd_data.append(
                {
                    "code": current_code,
                    "description": " ".join(current_description).strip(),
                }
            )

        # Save to hospital-specific JSON if requested
        if save_to_json and icd_data:
            try:
                json_path = get_icd_json_path(hospital_id)

                # Use the shared module-level lock for thread safety
                with icd_json_lock:
                    if os.path.exists(json_path):
                        with open(json_path, "r", encoding="utf-8") as f:
                            try:
                                existing_data = json.load(f)
                            except json.JSONDecodeError:
                                existing_data = []
                    else:
                        existing_data = []

                    # Efficient deduplication using dictionary
                    seen_codes = {item["code"]: item for item in existing_data}
                    for item in icd_data:
                        seen_codes[item["code"]] = item

                    unique_data = list(seen_codes.values())

                    # Write atomically using temporary file
                    temp_path = f"{json_path}.tmp"
                    with open(temp_path, "w", encoding="utf-8") as f:
                        json.dump(unique_data, f, indent=2, ensure_ascii=False)
                    os.replace(temp_path, json_path)

                    logging.info(
                        f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
                    )

            except Exception as e:
                logging.error(
                    f"Error saving ICD data to JSON for hospital {hospital_id}: {e}"
                )

        return icd_data

    except Exception as e:
        logging.error(f"Error in extract_and_process_icd_data: {e}")
        return []

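# Illustrative input/output for the extraction above: page text such as
#   "A150   Tuberculosis of lung, confirmed by sputum microscopy"
#   "J449   Chronic obstructive pulmonary disease, unspecified"
# yields entries like
#   {"code": "A150", "description": "Tuberculosis of lung, confirmed by sputum microscopy"}
# Continuation lines without a leading code are appended to the current
# description until a blank line or the next code is reached.
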
def load_icd_entries(hospital_id):
    """Load ICD entries from hospital-specific JSON file"""
    json_path = get_icd_json_path(hospital_id)
    try:
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return []
    except Exception as e:
        logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
        return []

# Update the process_icd_codes function to include hospital_id
async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
    """Process and store ICD codes using the optimized extraction function"""
    try:
        # Extract and save codes with hospital_id
        extract_and_process_icd_data(content, hospital_id, save_to_json=True)
    except Exception as e:
        logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")


async def initialize_icd_vector_store(hospital_id):
    """This function is deprecated. ICD codes are now handled through JSON search."""
    logging.warning(
        "initialize_icd_vector_store is deprecated - using JSON search instead"
    )
    return None

def extract_pdf_contents(pdf_path, hospital_id):
    """Extract PDF contents with optimized chunking and code extraction"""
    try:
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()
        pages_content = []

        for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
            text = page.page_content.strip()

            # Extract ICD codes from the page
            icd_codes = extract_and_process_icd_data(
                text, hospital_id
            )  # We'll set doc_id later

            pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})

        return pages_content

    except Exception as e:
        logging.error(f"Error in extract_pdf_contents: {e}")
        raise

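# Illustrative shape of the value returned by extract_pdf_contents:
#   [
#       {"page": 1, "text": "ADMISSION POLICY ...", "codes": [{"code": "A150", "description": "..."}]},
#       {"page": 2, "text": "...", "codes": []},
#   ]
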
async def insert_content_into_db(content, metadata, doc_id):
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
                content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"

                metadata_values = [
                    (doc_id, key[:100], value)
                    for key, value in metadata.items()
                    if value
                ]
                content_values = [
                    (doc_id, page_content["page"], page_content["text"])
                    for page_content in content
                ]

                if metadata_values:
                    await cursor.executemany(metadata_query, metadata_values)
                if content_values:
                    await cursor.executemany(content_query, content_values)

                await conn.commit()
                return {"message": "Success"}
            except Exception as e:
                await conn.rollback()
                return {"error": str(e)}

async def initialize_or_load_vector_store(hospital_id, user_id="default"):
    """Initialize or load vector store with Redis caching and thread safety"""
    store_key = f"{hospital_id}:{user_id}"

    try:
        # Check if we already have it loaded - with lock for thread safety
        with vector_store_lock:
            if store_key in hospital_vector_stores:
                return hospital_vector_stores[store_key]

        # Initialize vector store
        redis_client = get_redis_client(binary=True)
        cache_key = f"vector_store_data:{hospital_id}:{user_id}"
        hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")

        if os.path.exists(hospital_dir):
            logging.info(
                f"Loading vector store for hospital {hospital_id} and user {user_id}"
            )
            vector_store = await asyncio.to_thread(
                lambda: Chroma(
                    collection_name=f"hospital_{hospital_id}",
                    persist_directory=hospital_dir,
                    embedding_function=embeddings,
                )
            )
        else:
            logging.info(f"Creating vector store for hospital {hospital_id}")
            os.makedirs(hospital_dir, exist_ok=True)
            vector_store = await asyncio.to_thread(
                lambda: Chroma(
                    collection_name=f"hospital_{hospital_id}",
                    persist_directory=hospital_dir,
                    embedding_function=embeddings,
                )
            )

        # Store with lock for thread safety
        with vector_store_lock:
            hospital_vector_stores[store_key] = vector_store

        return vector_store
    except Exception as e:
        logging.error(f"Error initializing vector store: {e}", exc_info=True)
        raise

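# Illustrative on-disk layout created by the initialization above (relative
# to this script's directory):
#   hospital_data/chroma_db/hospital_1/     <- Chroma collection "hospital_1"
#   hospital_data/chroma_db/hospital_2/     <- Chroma collection "hospital_2"
#   hospital_data/hospital_1/icd_data.json  <- per-hospital ICD code index
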
async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
    """Delete all vectors associated with a specific document from ChromaDB"""
    try:
        # Initialize vector store for the hospital
        vector_store = await initialize_or_load_vector_store(hospital_id)

        # Delete vectors with matching doc_id
        await asyncio.to_thread(
            lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
        )

        # Persist changes
        await asyncio.to_thread(vector_store.persist)

        # Clear Redis cache for this document
        redis_client = get_redis_client()
        pattern = f"vector_store_data:{hospital_id}:*"
        for key in redis_client.scan_iter(pattern):
            redis_client.delete(key)

        logging.info(
            f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
        )
        return True

    except Exception as e:
        logging.error(f"Error deleting document vectors: {e}", exc_info=True)
        return False

async def add_document_to_index(doc_id, hospital_id):
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                vector_store = await initialize_or_load_vector_store(hospital_id)

                await cursor.execute(
                    "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
                    (doc_id,),
                )
                rows = await cursor.fetchall()

                total_pages = len(rows)
                logging.info(f"Processing {total_pages} pages for document {doc_id}")
                page_bar = tqdm_async(total=total_pages, desc="Processing pages")

                async def process_page(page_data):
                    page_num, content = page_data
                    try:
                        icd_data = extract_and_process_icd_data(
                            content, hospital_id, save_to_json=False
                        )
                        chunks = text_splitter.split_text(content)
                        await asyncio.sleep(0)  # Yield control
                        return page_num, chunks, icd_data
                    except Exception as e:
                        logging.error(f"Error processing page {page_num}: {e}")
                        return page_num, [], []

                tasks = [asyncio.create_task(process_page(row)) for row in rows]
                results = []

                for coro in asyncio.as_completed(tasks):
                    result = await coro
                    results.append(result)
                    page_bar.update(1)

                page_bar.close()

                # Vector addition progress bar
                all_icd_data = []
                all_chunks = []
                all_metadatas = []

                chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)

                for result in results:
                    page_num, chunks, icd_data = result
                    all_icd_data.extend(icd_data)

                    for i, chunk in enumerate(chunks):
                        all_chunks.append(chunk)
                        all_metadatas.append(
                            {
                                "doc_id": str(doc_id),
                                "hospital_id": str(hospital_id),
                                "page_number": str(page_num),
                                "chunk_index": str(i),
                            }
                        )

                    if len(all_chunks) >= BATCH_SIZE:
                        chunk_add_bar.total += len(all_chunks)
                        chunk_add_bar.refresh()
                        await asyncio.to_thread(
                            vector_store.add_texts,
                            texts=all_chunks,
                            metadatas=all_metadatas,
                        )
                        all_chunks = []
                        all_metadatas = []
                        chunk_add_bar.update(BATCH_SIZE)

                # Final batch
                if all_chunks:
                    chunk_add_bar.total += len(all_chunks)
                    chunk_add_bar.refresh()
                    await asyncio.to_thread(
                        vector_store.add_texts,
                        texts=all_chunks,
                        metadatas=all_metadatas,
                    )
                    chunk_add_bar.update(len(all_chunks))

                chunk_add_bar.close()

                if all_icd_data:
                    logging.info(f"Saving {len(all_icd_data)} ICD codes")
                    extract_and_process_icd_data("", hospital_id, save_to_json=True)

                await asyncio.to_thread(vector_store.persist)
                logging.info(f"Successfully indexed document {doc_id}")
                return True

    except Exception as e:
        logging.error(f"Error adding document: {e}")
        return False

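# Every chunk is stored with doc_id / hospital_id / page_number / chunk_index
# metadata; delete_document_vectors() and the retrieval filters rely on these
# fields. Illustrative metadata record for a single chunk:
#   {"doc_id": "42", "hospital_id": "7", "page_number": "3", "chunk_index": "0"}
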
def is_table_request(query: str) -> bool:
    """
    Determine if the user is requesting a response in tabular format.
    """
    table_keywords = [
        "table",
        "tabular",
        "in a table",
        "in table format",
        "in tabular format",
        "chart",
        "data",
        "comparison",
        "as a table",
        "table format",
        "in rows and columns",
        "in a grid",
        "breakdown",
        "spreadsheet",
        "comparison table",
        "data table",
        "structured table",
        "tabular form",
        "table form",
    ]

    query_lower = query.lower()
    return any(keyword in query_lower for keyword in table_keywords)

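# Illustrative behaviour of the keyword check above:
#   is_table_request("Show a comparison of ICU room rates")  -> True  (matches "comparison")
#   is_table_request("What are the visiting hours?")         -> False
# Broad keywords such as "data" make the check deliberately permissive.
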
def ensure_html_response(text: str) -> str:
    """
    Ensure the response is properly formatted in HTML.
    This function handles plain text conversion to HTML.
    """
    if "<html" in text.lower() or "<body" in text.lower():
        return text

    has_html_tags = bool(re.search(r"<[a-z]+.*?>", text))

    if not has_html_tags:
        paragraphs = text.split("\n\n")
        html_parts = []
        in_ordered_list = False
        in_unordered_list = False

        for para in paragraphs:
            if para.strip():
                if re.match(r"^\s*[\*\-\•]\s", para):
                    if not in_unordered_list:
                        html_parts.append("<ul>")
                        in_unordered_list = True

                    lines = para.split("\n")
                    for line in lines:
                        if line.strip():
                            item = re.sub(r"^\s*[\*\-\•]\s*", "", line)
                            html_parts.append(f"<li>{item}</li>")

                elif re.match(r"^\s*\d+\.\s", para):
                    if not in_ordered_list:
                        html_parts.append("<ol>")
                        in_ordered_list = True

                    lines = para.split("\n")
                    for line in lines:
                        match = re.match(r"^\s*\d+\.\s*(.*)", line)
                        if match:
                            html_parts.append(f"<li>{match.group(1)}</li>")

                else:  # Close any open lists before adding a new paragraph
                    if in_ordered_list:
                        html_parts.append("</ol>")
                        in_ordered_list = False
                    if in_unordered_list:
                        html_parts.append("</ul>")
                        in_unordered_list = False

                    html_parts.append(f"<p>{para}</p>")

        if in_ordered_list:
            html_parts.append("</ol>")
        if in_unordered_list:
            html_parts.append("</ul>")

        return "".join(html_parts)

    else:
        if not any(tag in text for tag in ("<p>", "<div>", "<ul>", "<ol>")):
            paragraphs = text.split("\n\n")
            html_parts = [f"<p>{para}</p>" for para in paragraphs if para.strip()]
            return "".join(html_parts)

    return text

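# Illustrative conversion performed by ensure_html_response:
#   "Visiting hours:\n\n- Morning 9-11\n- Evening 5-7"
#   -> "<p>Visiting hours:</p><ul><li>Morning 9-11</li><li>Evening 5-7</li></ul>"
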
class RAGConversationManager:
    """
    Conversation manager that uses Redis for RAG-based conversations only.
    """

    def __init__(self, redis_client, ttl=3600, max_history_items=5):
        self.redis_client = redis_client
        self.ttl = ttl
        self.max_history_items = max_history_items
        self.lock = Lock()

    def _get_redis_key(self, user_id, hospital_id, session_id=None):
        """Create Redis key for document-based conversations."""
        if session_id:
            return f"conv_history:{user_id}:{hospital_id}:{session_id}"
        return f"conv_history:{user_id}:{hospital_id}"

    async def add_rag_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Add document-based (RAG) interaction to Redis."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        history = self.get_rag_history(user_id, hospital_id, session_id)

        # Add new interaction
        history.append(
            {
                "question": question,
                "answer": answer,
                "timestamp": time.time(),
                "type": "rag",  # Mark as RAG-based interaction
            }
        )

        # Keep only last N interactions
        history = history[-self.max_history_items :]

        # Store updated history
        try:
            self.redis_client.setex(key, self.ttl, json.dumps(history))
            logging.info(
                f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}"
            )
        except Exception as e:
            logging.error(f"Failed to store RAG interaction in Redis: {e}")

    def get_rag_history(self, user_id, hospital_id, session_id=None):
        """Get document-based (RAG) conversation history from Redis."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        try:
            history_data = self.redis_client.get(key)
            return json.loads(history_data) if history_data else []
        except Exception as e:
            logging.error(f"Failed to retrieve RAG history from Redis: {e}")
            return []

    def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2):
        """Get the most recent interactions for context."""
        history = self.get_rag_history(user_id, hospital_id, session_id)
        # Sort by timestamp (oldest first) for context window
        sorted_history = sorted(history, key=lambda x: x.get("timestamp", 0))
        return sorted_history[-window_size:] if sorted_history else []

    def clear_history(self, user_id, hospital_id):
        """Clear conversation history."""
        redis_key = self._get_redis_key(user_id, hospital_id)
        try:
            self.redis_client.delete(redis_key)
        except Exception as e:
            logging.error(f"Failed to clear Redis history: {e}")

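# Histories are stored per user/hospital(/session) under keys of the form
# "conv_history:<user_id>:<hospital_id>:<session_id>" with a 1-hour TTL and at
# most max_history_items entries. Illustrative stored value (a JSON list):
#   [{"question": "...", "answer": "<p>...</p>", "timestamp": 1705650000.0, "type": "rag"}]
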
class ContextMapper:
    """Enhanced context mapping using shared model manager"""

    def __init__(self):
        self.model_manager = ModelManager()
        self.context_cache = {}
        self.similarity_threshold = 0.6

    def get_semantic_similarity(self, text1, text2):
        """Get semantic similarity using global model manager"""
        return self.model_manager.get_semantic_similarity(text1, text2)

    def extract_key_concepts(self, text):
        """Extract key concepts using NLP techniques"""
        doc = nlp(text)
        concepts = []

        entities = [(ent.text, ent.label_) for ent in doc.ents]
        noun_phrases = [chunk.text for chunk in doc.noun_chunks]
        important_words = [
            token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"]
        ]

        concepts.extend([e[0] for e in entities])
        concepts.extend(noun_phrases)
        concepts.extend(important_words)

        return list(set(concepts))

    def map_conversation_context(
        self, current_query, conversation_history, context_window=3
    ):
        """Map conversation context using enhanced NLP techniques"""
        if not conversation_history:
            return current_query

        recent_context = conversation_history[-context_window:]
        context_concepts = []

        # Extract concepts from recent conversations
        for interaction in recent_context:
            q_concepts = self.extract_key_concepts(interaction["question"])
            a_concepts = self.extract_key_concepts(interaction["answer"])
            context_concepts.extend(q_concepts)
            context_concepts.extend(a_concepts)

        # Extract concepts from current query
        query_concepts = self.extract_key_concepts(current_query)

        # Find related concepts
        related_concepts = []
        for q_concept in query_concepts:
            for c_concept in context_concepts:
                similarity = self.get_semantic_similarity(q_concept, c_concept)
                if similarity > self.similarity_threshold:
                    related_concepts.append(c_concept)

        # Build enhanced query
        if related_concepts:
            enhanced_query = (
                f"{current_query} in context of {', '.join(related_concepts)}"
            )
        else:
            enhanced_query = current_query

        return enhanced_query


# Initialize the context mapper
context_mapper = ContextMapper()

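# Illustrative effect of map_conversation_context: if the previous turn was
# about "ICU admission criteria" and the new query is "what documents are
# needed for it", the query may be expanded to something like
#   "what documents are needed for it in context of ICU admission, criteria"
# depending on the similarity scores returned by ModelManager.
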
async def generate_contextual_query(
    question: str, user_id: str, hospital_id: int, conversation_manager
) -> str:
    """Generate enhanced contextual query"""
    context_window = conversation_manager.get_context_window(user_id, hospital_id)

    if not context_window:
        return question

    # Enhanced context mapping
    last_interaction = context_window[-1]
    enhanced_context = f"""
    Previous question: {last_interaction['question']}
    Previous answer: {last_interaction['answer']}
    Current question: {question}

    Please generate a detailed search query that combines the context from the previous answer
    with the current question, especially if the current question uses words like 'it', 'this',
    'that', or asks for more details about the previous topic.
    """

    try:
        response = await asyncio.to_thread(
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a context-aware query generator.",
                    },
                    {"role": "user", "content": enhanced_context},
                ],
                temperature=0.3,
                max_tokens=150,
            )
        )
        contextual_query = response.choices[0].message.content.strip()
        logging.info(f"Enhanced contextual query: {contextual_query}")
        return contextual_query
    except Exception as e:
        logging.error(f"Error generating contextual query: {e}")
        return question

def is_follow_up(current_question: str, conversation_history: list) -> bool:
    """Enhanced follow-up detection using NLP techniques"""
    if not conversation_history:
        return False

    last_interaction = conversation_history[-1]

    # Get semantic similarity with higher threshold
    similarity = context_mapper.get_semantic_similarity(
        current_question, f"{last_interaction['question']} {last_interaction['answer']}"
    )

    # Enhanced referential check
    doc = nlp(current_question.lower())
    has_referential = any(
        token.lemma_
        in [
            "it",
            "this",
            "that",
            "these",
            "those",
            "they",
            "he",
            "she",
            "about",
            "more",
        ]
        for token in doc
    )

    # Extract concepts with improved entity detection
    current_concepts = set(context_mapper.extract_key_concepts(current_question))
    last_concepts = set(
        context_mapper.extract_key_concepts(
            f"{last_interaction['question']} {last_interaction['answer']}"
        )
    )

    # Calculate enhanced concept overlap
    concept_overlap = (
        len(current_concepts & last_concepts) / len(current_concepts | last_concepts)
        if current_concepts
        else 0
    )

    # More aggressive follow-up detection
    return (
        similarity > 0.3  # Lowered threshold
        or has_referential
        or concept_overlap > 0.2  # Lowered threshold
        or any(
            word in current_question.lower()
            for word in ["more", "about", "elaborate", "explain"]
        )
    )

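# Illustrative follow-up detection: after "What is the ICU visiting policy?",
# a question like "Can you elaborate on that?" is treated as a follow-up
# (referential "that" plus the keyword "elaborate"), while an unrelated
# "How do I reset my password?" normally is not.
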
async def get_relevant_context(question, hospital_id, doc_id=None):
    try:
        cache_key = f"context:hospital_{hospital_id}"
        if doc_id:
            cache_key += f":doc_{doc_id}"
        cache_key += f":{question.lower().strip()}"

        redis_client = get_redis_client()

        cached_context = redis_client.get(cache_key)
        if cached_context:
            logging.info(f"Cache hit for key: {cache_key}")
            return (
                cached_context.decode("utf-8")
                if isinstance(cached_context, bytes)
                else cached_context
            )

        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return ""

        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": 10,
                "fetch_k": 20,
                "lambda_mult": 0.6,
                # "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
            },
        )

        docs = await asyncio.to_thread(retriever.get_relevant_documents, question)
        if not docs:
            return ""

        sorted_docs = sorted(
            docs,
            key=lambda x: (
                int(x.metadata.get("page_number", 0)),
                int(x.metadata.get("chunk_index", 0)),
            ),
        )

        context_parts = [doc.page_content for doc in sorted_docs]
        context = "\n\n".join(context_parts)

        try:
            redis_client.setex(
                cache_key,
                int(CACHE_TTL["vector_store"].total_seconds()),
                context.encode("utf-8") if isinstance(context, str) else context,
            )
            logging.info(f"Cached context for key: {cache_key}")
        except Exception as cache_error:
            logging.error(f"Failed to cache context: {cache_error}")

        return context
    except Exception as e:
        logging.error(f"Error getting relevant context: {e}")
        return ""

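# Retrieval sketch: MMR (maximal marginal relevance) first fetches the
# fetch_k=20 most similar chunks, then keeps k=10 while trading relevance
# against diversity (lambda_mult=0.6 leans toward relevance). The selected
# chunks are re-sorted by page and chunk index so the prompt reads in document
# order, then cached in Redis for 24 hours (CACHE_TTL["vector_store"]).
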
def format_conversation_context(conv_history):
    """Format conversation history into a string"""
    if not conv_history:
        return "No previous conversation."
    return "\n".join(
        [
            f"Q: {interaction['question']}\nA: {interaction['answer']}"
            for interaction in conv_history
        ]
    )

def get_icd_context_from_question(question, hospital_id):
    """Extract any valid ICD codes from the question and return context"""
    icd_data = load_icd_entries(hospital_id)
    matches = []
    code_pattern = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", question.upper())

    seen = set()
    for code in code_pattern:
        for entry in icd_data:
            if entry["code"] == code and code not in seen:
                matches.append(f"{entry['code']}: {entry['description']}")
                seen.add(code)
    return "\n".join(matches)

def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70):
    """Get fuzzy matches for ICD codes from the question"""
    icd_data = load_icd_entries(hospital_id)
    descriptions = [entry["description"] for entry in icd_data]
    matches = process.extract(
        question, descriptions, limit=top_n, score_cutoff=threshold
    )

    matched_context = []
    for desc, score, _ in matches:
        for entry in icd_data:
            if entry["description"] == desc:
                matched_context.append(f"{entry['code']}: {entry['description']}")
                break

    return "\n".join(matched_context)

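# Two lookup paths against the per-hospital icd_data.json: exact code hits
# (get_icd_context_from_question) and fuzzy description matches scored by
# rapidfuzz (get_fuzzy_icd_context, default score cutoff 70). Illustrative call:
#   get_fuzzy_icd_context("codes for lung tuberculosis", hospital_id=1)
#   -> "A150: Tuberculosis of lung, confirmed by sputum microscopy"  (example output)
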
async def generate_answer_with_rag(
    question,
    hospital_id,
    client,
    doc_id=None,
    user_id="default",
    conversation_manager=None,
    session_id=None,
):
    """Generate an answer using RAG with improved conversation flow"""
    try:
        html_instruction = """
        IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
        - Use <p> tags for paragraphs
        - Use <h2>, <h3> tags for headings and subheadings
        - Use <ul>, <li> tags for bullet points
        - Use <ol>, <li> tags for numbered lists
        - Use <blockquote> for quoted text
        - Use <strong> for bold text and <em> for emphasis
        """

        table_instruction = """
        - For tables, use proper HTML table structure:
        <table border="1">
            <thead>
                <tr>
                    <th colspan="{total_columns}">{table_title}</th>
                </tr>
                <tr>
                    {table_headers}
                </tr>
            </thead>
            <tbody>
                {table_rows}
            </tbody>
        </table>
        """
        # Get conversation history first
        conv_history = (
            conversation_manager.get_context_window(user_id, hospital_id, session_id)
            if conversation_manager
            else []
        )

        # Get contextual query and relevant context first
        contextual_query = await generate_contextual_query(
            question, user_id, hospital_id, conversation_manager
        )

        # Get document context
        doc_context = await get_relevant_context(contextual_query, hospital_id, doc_id)

        if not doc_context:
            return {
                "answer": """
                <p>I apologize, but I couldn't find any relevant information in the hospital documents to answer your question.</p>
                <p>Please try rephrasing your question or asking about a different topic that might be covered in the documents.</p>
                """
            }, 200

        # Generate RAG answer with enhanced context
        prompt_template = f"""Based on the following context and conversation history, provide a detailed answer to the question.
        Previous conversation:
        {format_conversation_context(conv_history)}

        Context from documents:
        {doc_context}

        Current question: {question}

        Instructions:
        1. ONLY use information from the provided document context to answer the question
        2. If the answer cannot be fully derived from the context, state that clearly
        3. Do not use any external knowledge or make assumptions
        4. When providing medical codes (ICD, CPT, etc.):
            - Only use codes that appear in the provided context
            - Do not invent or use codes from external knowledge
            - If multiple codes are relevant, list them all
            - Remove all decimal points (e.g., use 'A150' instead of 'A15.0')
        5. Format the response in clear HTML with appropriate tags
        6. If the context doesn't contain enough information to answer the question completely,
            acknowledge this limitation and only provide the information that is available

        {html_instruction}
        {table_instruction if is_table_request(question) else ""}
        """

        response = await asyncio.to_thread(
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo-16k",
                messages=[
                    {"role": "system", "content": prompt_template},
                    {"role": "user", "content": question},
                ],
                temperature=0.3,
                max_tokens=1000,
            )
        )

        answer = ensure_html_response(response.choices[0].message.content)
        logging.info(f"Generated RAG answer for question: {question}")

        # Store interaction in history
        if conversation_manager:
            await conversation_manager.add_rag_interaction(
                user_id, hospital_id, question, answer, session_id
            )

        return {"answer": answer}, 200

    except Exception as e:
        logging.error(f"Error in generate_answer_with_rag: {e}")
        return {"answer": f"<p>Error: {str(e)}</p>"}, 500

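# End-to-end flow of generate_answer_with_rag:
#   1. pull the recent conversation window from Redis,
#   2. rewrite the question into a standalone search query (generate_contextual_query),
#   3. retrieve and cache document context via MMR search (get_relevant_context),
#   4. build a grounded prompt and call gpt-3.5-turbo-16k,
#   5. normalize the reply to HTML and append the turn to the conversation history.
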
async def load_existing_vector_stores():
    """Load existing Chroma vector stores for each hospital"""
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                await cursor.execute("SELECT DISTINCT id FROM hospitals")
                hospital_ids = [row[0] for row in await cursor.fetchall()]

                for hospital_id in hospital_ids:
                    try:
                        await initialize_or_load_vector_store(hospital_id)
                    except Exception as e:
                        logging.error(
                            f"Failed to load vector store for hospital {hospital_id}: {e}"
                        )
                        continue

            except Exception as e:
                logging.error(f"Error loading vector stores: {e}")

async def get_failed_page(doc_id):
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                await cursor.execute(
                    "SELECT failed_page FROM documents WHERE id = %s", (doc_id,)
                )
                result = await cursor.fetchone()
                return result[0] if result and result[0] else None
            except Exception as e:
                logging.error(f"Database error checking failed_page: {e}")
                return None

async def update_document_status(doc_id, status, failed_page=None):
    """Update document status with enum validation"""
    if isinstance(status, str):
        status = DocumentStatus[status.upper()].value

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                if failed_page:
                    await cursor.execute(
                        "UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s",
                        (status, failed_page, doc_id),
                    )
                else:
                    await cursor.execute(
                        "UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s",
                        (status, doc_id),
                    )
                await conn.commit()
                return True
            except Exception as e:
                logging.error(f"Database update error: {e}")
                return False

thread_pool = ThreadPoolExecutor(max_workers=10)


def async_to_sync(coroutine):
    """Helper function to run async code in sync context"""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine)
    finally:
        loop.close()

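# Bridge between Flask's sync request handlers and the async helpers above:
# each call creates a fresh event loop inside a thread_pool worker.
# Illustrative usage inside a route ("HOSP001" is a placeholder code):
#   future = thread_pool.submit(lambda: async_to_sync(get_hospital_id("HOSP001")))
#   hospital_id = future.result()
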
@app.route("/flask-api", methods=["GET"])
def health_check():
    """Health check endpoint"""
    access_logger.info(f"Health check request received from {request.remote_addr}")
    return jsonify({"status": "ok"}), 200

@app.route("/flask-api/process-pdf", methods=["POST"])
def process_pdf():
    access_logger.info(f"PDF processing request received from {request.remote_addr}")
    file_path = None
    try:
        file = request.files.get("pdf")
        hospital_id = request.form.get("hospital_id")
        doc_id = request.form.get("doc_id")

        logging.info(
            f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}"
        )

        if not all([file, hospital_id, doc_id]):
            return jsonify({"error": "Missing required parameters"}), 400

        def process_in_background():
            nonlocal file_path
            try:
                async_to_sync(update_document_status(doc_id, "processing"))

                # Add progress logging
                logging.info(f"Starting processing of document {doc_id}")

                filename = f"doc_{doc_id}_{file.filename}"
                file_path = os.path.join(uploads_dir, filename)

                with open(file_path, "wb") as f:
                    file.save(f)

                logging.info("Extracting PDF contents...")
                content = extract_pdf_contents(file_path, int(hospital_id))

                logging.info("Inserting content into database...")
                metadata = {"filename": filename}
                result = async_to_sync(
                    insert_content_into_db(content, metadata, doc_id)
                )

                if "error" in result:
                    async_to_sync(update_document_status(doc_id, "failed", 1))
                    return False

                logging.info("Creating embeddings and indexing...")
                success = async_to_sync(add_document_to_index(doc_id, hospital_id))

                if success:
                    logging.info("Document processing completed successfully")
                    async_to_sync(update_document_status(doc_id, "processed"))
                    return True
                else:
                    logging.error("Document processing failed during indexing")
                    async_to_sync(update_document_status(doc_id, "failed"))
                    return False

            except Exception as e:
                logging.error(f"Processing error: {e}")
                async_to_sync(update_document_status(doc_id, "failed"))
                return False
            finally:
                if file_path and os.path.exists(file_path):
                    try:
                        os.remove(file_path)
                    except Exception as e:
                        logging.error(f"Error removing temporary file: {e}")

        # Execute processing and wait for result
        future = thread_pool.submit(process_in_background)
        success = future.result()

        if success:
            return jsonify({"message": "Document processed successfully"}), 200
        else:
            return jsonify({"error": "Document processing failed"}), 500

    except Exception as e:
        logging.error(f"API error: {e}")
        if file_path and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except Exception as file_e:
                logging.error(f"Error removing temporary file: {file_e}")
        return jsonify({"error": str(e)}), 500

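# Example request (illustrative values; app assumed to run on localhost:5000):
#   curl -X POST http://localhost:5000/flask-api/process-pdf \
#        -F "pdf=@policy_manual.pdf" \
#        -F "hospital_id=1" \
#        -F "doc_id=42"
# The call blocks until extraction, DB insertion and indexing finish, and the
# document row ends up marked "processed" or "failed".
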
# Initialize the RAG conversation manager
redis_client = get_redis_client()
conversation_manager = RAGConversationManager(redis_client)

@app.route("/flask-api/generate-answer", methods=["POST"])
def rag_answer_api():
    """Sync API endpoint for RAG-based question answering with conversation history."""
    access_logger.info(f"Generate answer request received from {request.remote_addr}")
    try:
        data = request.json
        question = data.get("question", "").strip().lower()
        hospital_code = data.get("hospital_code")
        doc_id = data.get("doc_id")
        user_id = data.get("user_id", "default")
        session_id = data.get("session_id", None)

        logging.info(f"Received question from user {user_id}: {question}")
        logging.info(f"Received hospital code: {hospital_code}")
        logging.info(f"Received session_id: {session_id}")

        def process_rag_answer():
            try:
                hospital_id = async_to_sync(get_hospital_id(hospital_code))
                logging.info(f"Resolved hospital ID: {hospital_id}")

                if not hospital_id:
                    return {
                        "error": "Invalid or missing 'hospital_code' in request"
                    }, 400

                # Regular RAG answer
                return async_to_sync(
                    generate_answer_with_rag(
                        question=question,
                        hospital_id=hospital_id,
                        client=client,
                        doc_id=doc_id,
                        user_id=user_id,
                        conversation_manager=conversation_manager,
                        session_id=session_id,
                    )
                )
            except Exception as e:
                logging.error(f"Error in process_rag_answer: {e}")
                return {"error": str(e)}, 500

        if not question:
            return jsonify({"error": "Missing 'question' in request"}), 400

        future = thread_pool.submit(process_rag_answer)
        result, status_code = future.result()

        return jsonify(result), status_code

    except Exception as e:
        logging.error(f"API error: {str(e)}")
        return jsonify({"error": str(e)}), 500

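# Example request (illustrative values):
#   curl -X POST http://localhost:5000/flask-api/generate-answer \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is the ICU visiting policy?", "hospital_code": "HOSP001", "user_id": "u1", "session_id": "s1"}'
# Response body: {"answer": "<p>...</p>"} with a matching HTTP status code.
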
@app.route("/flask-api/delete-document-vectors", methods=["DELETE"])
def delete_document_vectors_endpoint():
    """Endpoint to delete document vectors from ChromaDB"""
    try:
        data = request.json
        hospital_id = data.get("hospital_id")
        doc_id = data.get("doc_id")

        if not all([hospital_id, doc_id]):
            return jsonify({"error": "Missing required parameters"}), 400

        logging.info(
            f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}"
        )

        def process_deletion():
            try:
                success = async_to_sync(delete_document_vectors(hospital_id, doc_id))
                if success:
                    return {"message": "Document vectors deleted successfully"}, 200
                else:
                    return {"error": "Failed to delete document vectors"}, 500
            except Exception as e:
                logging.error(f"Error in vector deletion process: {e}")
                return {"error": str(e)}, 500

        future = thread_pool.submit(process_deletion)
        result, status_code = future.result()

        return jsonify(result), status_code

    except Exception as e:
        logging.error(f"API error: {str(e)}")
        return jsonify({"error": str(e)}), 500

@app.route("/flask-api/get-chroma-content", methods=["GET"])
def get_chroma_content_endpoint():
    """API endpoint to get ChromaDB content by hospital_id"""
    try:
        hospital_id = request.args.get("hospital_id")
        limit = int(request.args.get("limit", 30000))

        if not hospital_id:
            return jsonify({"error": "Missing required parameter: hospital_id"}), 400

        def process_fetch():
            try:
                result, status_code = async_to_sync(
                    get_chroma_content_by_hospital(
                        hospital_id=int(hospital_id), limit=limit
                    )
                )
                return result, status_code
            except Exception as e:
                logging.error(f"Error in ChromaDB fetch process: {e}")
                return {"error": str(e)}, 500

        future = thread_pool.submit(process_fetch)
        result, status_code = future.result()

        return jsonify(result), status_code

    except Exception as e:
        logging.error(f"API error: {str(e)}")
        return jsonify({"error": str(e)}), 500

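# Example request (illustrative): inspect up to 100 stored chunks for hospital 1.
#   curl "http://localhost:5000/flask-api/get-chroma-content?hospital_id=1&limit=100"
# Response shape: {"data": [{"id": ..., "content": ..., "metadata": {...}}], "count": N}
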
async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100):
    """Fetch content from ChromaDB for a specific hospital"""
    try:
        # Initialize vector store
        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return {"error": "Vector store not found"}, 404

        # Get collection
        collection = vector_store._collection

        # Query the collection with hospital_id filter
        results = await asyncio.to_thread(
            lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit)
        )

        if not results or not results["ids"]:
            return {"data": [], "count": 0}, 200

        # Format the response
        formatted_results = []
        for i in range(len(results["ids"])):
            formatted_results.append(
                {
                    "id": results["ids"][i],
                    "content": results["documents"][i],
                    "metadata": results["metadatas"][i],
                }
            )

        return {"data": formatted_results, "count": len(formatted_results)}, 200

    except Exception as e:
        logging.error(f"Error fetching ChromaDB content: {e}")
        return {"error": str(e)}, 500

@app.before_request
def before_request():
    request._start_time = time.time()


@app.after_request
def after_request(response):
    if hasattr(request, "_start_time"):
        duration = time.time() - request._start_time
        access_logger.info(
            f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - '
            f"IP: {request.remote_addr}"
        )
    return response

if __name__ == "__main__":
    logger.info("Starting SpurrinAI application")
    logger.info(f"Python version: {sys.version}")
    logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}")

    try:
        model_manager = ModelManager()
        logger.info("Model manager initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize model manager: {e}")
        sys.exit(1)

    # Initialize directories
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(CHROMA_DIR, exist_ok=True)
    logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}")

    # Clear Redis cache
    redis_client = get_redis_client()
    cleared_keys = 0
    for key in redis_client.scan_iter("vector_store_data:*"):
        redis_client.delete(key)
        cleared_keys += 1
    logger.info(f"Cleared {cleared_keys} Redis cache keys")

    # Load vector stores
    logger.info("Loading existing vector stores...")
    async_to_sync(load_existing_vector_stores())
    logger.info("Vector stores loaded successfully")

    # Start application
    logger.info("Starting Flask application on port 5000")
    app.run(port=5000, debug=False)
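
# Typical startup (assumptions: OPENAI_API_KEY and DB_* are set, and local
# Redis/MySQL instances are reachable; the filename used here is illustrative):
#   python main.py
# The Flask app then listens on port 5000.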