"""
|
||
SpurrinAI - Intelligent Document Processing and Question Answering System
|
||
Copyright (c) 2024 Tech4biz. All rights reserved.
|
||
|
||
This module implements the main Flask application for the SpurrinAI system,
|
||
providing REST APIs for document processing, vector storage, and question answering
|
||
using RAG (Retrieval Augmented Generation) architecture.
|
||
|
||
Author: Tech4biz Development Team
|
||
Version: 1.0.0
|
||
Last Updated: 2024-01-19
|
||
"""
|
||
|
||
# Standard library imports
import os
import re
import sys
import json
import time
import threading
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from threading import Lock

# Third-party imports
import spacy
import redis
import aiomysql
from dotenv import load_dotenv
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from tqdm import tqdm
from tqdm.asyncio import tqdm as tqdm_async
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from openai import OpenAI
from rapidfuzz import process

# Local imports
from model_manager import ModelManager

# Suppress warnings
import warnings

warnings.filterwarnings("ignore")

# Initialize NLTK
import nltk

nltk.download("punkt")

# Logging imports
import logging
import logging.handlers

app = Flask(__name__)
CORS(app)

script_dir = os.path.dirname(os.path.abspath(__file__))
log_file_path = os.path.join(script_dir, "error.log")
# Basic fallback configuration; setup_logging() below attaches the rotating
# file handlers used for the application, error, access and performance logs.
logging.basicConfig(filename=log_file_path, level=logging.INFO)

# Configure application-wide logging with rotating file handlers
def setup_logging():
    log_dir = os.path.join(script_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    main_log = os.path.join(log_dir, "app.log")
    error_log = os.path.join(log_dir, "error.log")
    access_log = os.path.join(log_dir, "access.log")
    perf_log = os.path.join(log_dir, "performance.log")

    # Create formatters
    detailed_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
    )
    access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Main logger setup (10 MB per file, 5 backups)
    main_handler = logging.handlers.RotatingFileHandler(
        main_log, maxBytes=10485760, backupCount=5
    )
    main_handler.setFormatter(detailed_formatter)
    main_handler.setLevel(logging.INFO)

    # Error logger setup
    error_handler = logging.handlers.RotatingFileHandler(
        error_log, maxBytes=10485760, backupCount=5
    )
    error_handler.setFormatter(detailed_formatter)
    error_handler.setLevel(logging.ERROR)

    # Access logger setup (rotated daily at midnight, 30 days retained)
    access_handler = logging.handlers.TimedRotatingFileHandler(
        access_log, when="midnight", interval=1, backupCount=30
    )
    access_handler.setFormatter(access_formatter)
    access_handler.setLevel(logging.INFO)

    # Performance logger setup
    perf_handler = logging.handlers.RotatingFileHandler(
        perf_log, maxBytes=10485760, backupCount=5
    )
    perf_handler.setFormatter(detailed_formatter)
    perf_handler.setLevel(logging.INFO)

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(main_handler)
    root_logger.addHandler(error_handler)

    # Create specific loggers
    access_logger = logging.getLogger("access")
    access_logger.addHandler(access_handler)
    access_logger.setLevel(logging.INFO)

    perf_logger = logging.getLogger("performance")
    perf_logger.addHandler(perf_handler)
    perf_logger.setLevel(logging.INFO)

    return root_logger, access_logger, perf_logger


# Initialize loggers
logger, access_logger, perf_logger = setup_logging()

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'flaskuser',
#     'password': 'Flask@123',
#     'database': 'spurrinai',
# }

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'spurrindevuser',
#     'password': 'Admin@123',
#     'database': 'spurrindev',
# }

# Redis Configuration
REDIS_CONFIG = {
    "host": "localhost",
    "port": 6379,
    "db": 0,
    "decode_responses": True,  # Decode responses to str for plain string operations
}

DB_CONFIG = {
    "host": os.getenv("DB_HOST", "localhost"),
    "user": os.getenv("DB_USER", "testuser"),
    "password": os.getenv("DB_PASSWORD", "Admin@123"),
    "database": os.getenv("DB_NAME", "spurrintest"),
}

# Redis connection pools: db 0 for decoded strings, db 1 for binary payloads
redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
redis_binary_pool = redis.ConnectionPool(
    host="localhost", port=6379, db=1, decode_responses=False
)


def get_redis_client(binary=False):
    """Get a Redis client from the appropriate connection pool."""
    logger.debug(f"Getting Redis client with binary={binary}")
    try:
        pool = redis_binary_pool if binary else redis_pool
        client = redis.Redis(connection_pool=pool)
        logger.debug("Redis client created successfully")
        return client
    except Exception as e:
        logger.error(f"Failed to create Redis client: {e}", exc_info=True)
        raise

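# Usage sketch (assumes a local Redis on the default port; illustrative only):
#
#     >>> r = get_redis_client()               # db 0, decoded str responses
#     >>> rb = get_redis_client(binary=True)   # db 1, raw bytes
#     >>> r.setex("demo:key", 60, "hello")
#     True
#     >>> r.get("demo:key")
#     'hello'
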
def fetch_cached_answer(cache_key):
    logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
    start_time = time.time()
    try:
        redis_client = get_redis_client()
        cached_answer = redis_client.get(cache_key)
        fetch_time = time.time() - start_time
        perf_logger.info(
            f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
        )
        return cached_answer
    except Exception as e:
        logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
        return None


# Cache TTL configurations
CACHE_TTL = {
    "vector_store": timedelta(hours=24),
    "chat_completion": timedelta(hours=1),
    "document_metadata": timedelta(days=7),
}

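# The timedelta values above are converted to integer seconds when handed to
# Redis, as done later in this module ("cache_key" and "payload" here are
# illustrative names):
#
#     ttl_seconds = int(CACHE_TTL["vector_store"].total_seconds())  # 86400
#     redis_client.setex(cache_key, ttl_seconds, payload)
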
DATA_DIR = os.path.join(script_dir, "hospital_data")
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")
uploads_dir = os.path.join(script_dir, "llm-uploads")

if not os.path.exists(uploads_dir):
    os.makedirs(uploads_dir)

nlp = spacy.load("en_core_web_sm")

load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

client = OpenAI(api_key=OPENAI_API_KEY)
embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
)
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
hospital_vector_stores = {}
vector_store_lock = threading.Lock()


@dataclass
class Document:
    doc_id: int
    page_num: int
    content: str


class DocumentStatus(Enum):
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"


async def get_db_pool():
    # Note: callers own the returned pool and are responsible for closing it.
    return await aiomysql.create_pool(
        host=DB_CONFIG["host"],
        user=DB_CONFIG["user"],
        password=DB_CONFIG["password"],
        db=DB_CONFIG["database"],
        autocommit=True,
    )


async def get_hospital_id(hospital_code):
    pool = None
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(
                    "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
                    (hospital_code,),
                )
                result = await cursor.fetchone()
                return result["id"] if result else None
    except Exception as error:
        logging.error(f"Database error: {error}")
        return None
    finally:
        # Guard against pool creation itself failing, which would otherwise
        # raise a NameError here.
        if pool is not None:
            pool.close()
            await pool.wait_closed()


CHUNK_SIZE = 1000
CHUNK_OVERLAP = 50
BATCH_SIZE = 1000

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    # length_function=len,
    # separators=["\n\n", "\n", ". ", " ", ""]
)

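# A minimal sketch of the chunking behaviour (toy sizes, made-up text; exact
# boundaries depend on the splitter's separator fallbacks):
#
#     demo = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=3)
#     chunks = demo.split_text("a long run of text without separators")
#     # consecutive chunks share up to 3 trailing/leading characters
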
# The JSON path is dynamic, based on hospital_id
def get_icd_json_path(hospital_id):
    hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
    os.makedirs(hospital_data_dir, exist_ok=True)
    return os.path.join(hospital_data_dir, "icd_data.json")


# Module-level lock so concurrent writers actually serialize; the previous
# per-call `with threading.Lock():` created a fresh lock object on every call
# and therefore protected nothing.
icd_json_lock = threading.Lock()


def save_icd_data_to_json(icd_data, hospital_id):
    """Merge ICD entries into the hospital-specific JSON file, deduplicated by code."""
    try:
        json_path = get_icd_json_path(hospital_id)

        with icd_json_lock:
            if os.path.exists(json_path):
                with open(json_path, "r", encoding="utf-8") as f:
                    try:
                        existing_data = json.load(f)
                    except json.JSONDecodeError:
                        existing_data = []
            else:
                existing_data = []

            # Efficient deduplication using a dictionary keyed by code
            seen_codes = {item["code"]: item for item in existing_data}
            for item in icd_data:
                seen_codes[item["code"]] = item
            unique_data = list(seen_codes.values())

            # Write atomically using a temporary file
            temp_path = f"{json_path}.tmp"
            with open(temp_path, "w", encoding="utf-8") as f:
                json.dump(unique_data, f, indent=2, ensure_ascii=False)
            os.replace(temp_path, json_path)

        logging.info(
            f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
        )
    except Exception as e:
        logging.error(f"Error saving ICD data to JSON for hospital {hospital_id}: {e}")


def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
    """Extract ICD code/description pairs from raw text, optionally persisting them."""
    try:
        # Compile the code pattern once
        pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)

        icd_data = []
        current_code = None
        current_description = []

        # Iterate over whole lines; the previous fixed-size (50 KB) character
        # chunking could split a line across chunk boundaries and corrupt
        # code/description pairs.
        for line in content.splitlines():
            line = line.strip()
            if not line:
                if current_code and current_description:
                    icd_data.append(
                        {
                            "code": current_code,
                            "description": " ".join(current_description).strip(),
                        }
                    )
                current_code = None
                current_description = []
                continue

            match = pattern.match(line)
            if match:
                if current_code and current_description:
                    icd_data.append(
                        {
                            "code": current_code,
                            "description": " ".join(current_description).strip(),
                        }
                    )
                current_code, description = match.groups()
                current_description = [description.strip()]
            elif current_code:
                current_description.append(line)

        # Add the final entry if one is still open
        if current_code and current_description:
            icd_data.append(
                {
                    "code": current_code,
                    "description": " ".join(current_description).strip(),
                }
            )

        # Save to the hospital-specific JSON if requested
        if save_to_json and icd_data:
            save_icd_data_to_json(icd_data, hospital_id)

        return icd_data

    except Exception as e:
        logging.error(f"Error in extract_and_process_icd_data: {e}")
        return []

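# Illustrative match for the extraction pattern used above (made-up line):
#
#     >>> p = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)
#     >>> p.match("A150 Tuberculosis of lung").groups()
#     ('A150', 'Tuberculosis of lung')
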
def load_icd_entries(hospital_id):
    """Load ICD entries from the hospital-specific JSON file"""
    json_path = get_icd_json_path(hospital_id)
    try:
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return []
    except Exception as e:
        logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
        return []


# process_icd_codes includes hospital_id so codes land in the right JSON file
async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
    """Process and store ICD codes using the optimized extraction function"""
    try:
        # Extract and save codes for this hospital
        extract_and_process_icd_data(content, hospital_id, save_to_json=True)
    except Exception as e:
        logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")


async def initialize_icd_vector_store(hospital_id):
    """This function is deprecated. ICD codes are now handled through JSON search."""
    logging.warning(
        "initialize_icd_vector_store is deprecated - using JSON search instead"
    )
    return None


def extract_pdf_contents(pdf_path, hospital_id):
    """Extract PDF contents with optimized chunking and code extraction"""
    try:
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()
        pages_content = []

        for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
            text = page.page_content.strip()

            # Extract ICD codes from the page (the doc_id is associated later)
            icd_codes = extract_and_process_icd_data(text, hospital_id)

            pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})

        return pages_content

    except Exception as e:
        logging.error(f"Error in extract_pdf_contents: {e}")
        raise


async def insert_content_into_db(content, metadata, doc_id):
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
                content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"

                metadata_values = [
                    (doc_id, key[:100], value)
                    for key, value in metadata.items()
                    if value
                ]
                content_values = [
                    (doc_id, page_content["page"], page_content["text"])
                    for page_content in content
                ]

                if metadata_values:
                    await cursor.executemany(metadata_query, metadata_values)
                if content_values:
                    await cursor.executemany(content_query, content_values)

                await conn.commit()
                return {"message": "Success"}
            except Exception as e:
                await conn.rollback()
                return {"error": str(e)}


async def initialize_or_load_vector_store(hospital_id, user_id="default"):
    """Initialize or load a Chroma vector store, memoized per hospital/user with thread safety"""
    store_key = f"{hospital_id}:{user_id}"

    try:
        # Check if we already have it loaded - with lock for thread safety
        with vector_store_lock:
            if store_key in hospital_vector_stores:
                return hospital_vector_stores[store_key]

        hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")

        if os.path.exists(hospital_dir):
            logging.info(
                f"Loading vector store for hospital {hospital_id} and user {user_id}"
            )
            vector_store = await asyncio.to_thread(
                lambda: Chroma(
                    collection_name=f"hospital_{hospital_id}",
                    persist_directory=hospital_dir,
                    embedding_function=embeddings,
                )
            )
        else:
            logging.info(f"Creating vector store for hospital {hospital_id}")
            os.makedirs(hospital_dir, exist_ok=True)
            vector_store = await asyncio.to_thread(
                lambda: Chroma(
                    collection_name=f"hospital_{hospital_id}",
                    persist_directory=hospital_dir,
                    embedding_function=embeddings,
                )
            )

        # Store with lock for thread safety
        with vector_store_lock:
            hospital_vector_stores[store_key] = vector_store

        return vector_store
    except Exception as e:
        logging.error(f"Error initializing vector store: {e}", exc_info=True)
        raise


async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
    """Delete all vectors associated with a specific document from ChromaDB"""
    try:
        # Initialize the vector store for the hospital
        vector_store = await initialize_or_load_vector_store(hospital_id)

        # Delete vectors with a matching doc_id (uses Chroma's underlying collection)
        await asyncio.to_thread(
            lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
        )

        # Persist changes
        await asyncio.to_thread(vector_store.persist)

        # Clear the Redis cache entries for this hospital
        redis_client = get_redis_client()
        pattern = f"vector_store_data:{hospital_id}:*"
        for key in redis_client.scan_iter(pattern):
            redis_client.delete(key)

        logging.info(
            f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
        )
        return True

    except Exception as e:
        logging.error(f"Error deleting document vectors: {e}", exc_info=True)
        return False


async def add_document_to_index(doc_id, hospital_id):
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                vector_store = await initialize_or_load_vector_store(hospital_id)

                await cursor.execute(
                    "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
                    (doc_id,),
                )
                rows = await cursor.fetchall()

                total_pages = len(rows)
                logging.info(f"Processing {total_pages} pages for document {doc_id}")
                page_bar = tqdm_async(total=total_pages, desc="Processing pages")

                async def process_page(page_data):
                    page_num, content = page_data
                    try:
                        icd_data = extract_and_process_icd_data(
                            content, hospital_id, save_to_json=False
                        )
                        chunks = text_splitter.split_text(content)
                        await asyncio.sleep(0)  # Yield control to the event loop
                        return page_num, chunks, icd_data
                    except Exception as e:
                        logging.error(f"Error processing page {page_num}: {e}")
                        return page_num, [], []

                tasks = [asyncio.create_task(process_page(row)) for row in rows]
                results = []

                for coro in asyncio.as_completed(tasks):
                    result = await coro
                    results.append(result)
                    page_bar.update(1)

                page_bar.close()

                # Vector addition progress bar
                all_icd_data = []
                all_chunks = []
                all_metadatas = []

                chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)

                for result in results:
                    page_num, chunks, icd_data = result
                    all_icd_data.extend(icd_data)

                    for i, chunk in enumerate(chunks):
                        all_chunks.append(chunk)
                        all_metadatas.append(
                            {
                                "doc_id": str(doc_id),
                                "hospital_id": str(hospital_id),
                                "page_number": str(page_num),
                                "chunk_index": str(i),
                            }
                        )

                    if len(all_chunks) >= BATCH_SIZE:
                        # Report the actual number of chunks flushed, which can
                        # exceed BATCH_SIZE when a page adds several at once.
                        batch_count = len(all_chunks)
                        chunk_add_bar.total += batch_count
                        chunk_add_bar.refresh()
                        await asyncio.to_thread(
                            vector_store.add_texts,
                            texts=all_chunks,
                            metadatas=all_metadatas,
                        )
                        all_chunks = []
                        all_metadatas = []
                        chunk_add_bar.update(batch_count)

                # Final batch
                if all_chunks:
                    chunk_add_bar.total += len(all_chunks)
                    chunk_add_bar.refresh()
                    await asyncio.to_thread(
                        vector_store.add_texts,
                        texts=all_chunks,
                        metadatas=all_metadatas,
                    )
                    chunk_add_bar.update(len(all_chunks))

                chunk_add_bar.close()

                if all_icd_data:
                    logging.info(f"Saving {len(all_icd_data)} ICD codes")
                    # Persist the codes collected across all pages; the previous
                    # call passed empty content here and therefore saved nothing.
                    save_icd_data_to_json(all_icd_data, hospital_id)

                await asyncio.to_thread(vector_store.persist)
                logging.info(f"Successfully indexed document {doc_id}")
                return True

    except Exception as e:
        logging.error(f"Error adding document: {e}")
        return False


def is_general_knowledge_question(
    query: str, context: str, conversation_context=None
) -> bool:
    """
    Determine if a question is likely a general knowledge question not covered in the documents.
    Takes conversation history into account to reduce repeated confirmations.
    """
    query_lower = query.lower()
    context_lower = context.lower()

    if conversation_context:
        for interaction in conversation_context:
            prev_question = interaction.get("question", "").lower()
            # Parenthesized to avoid the and/or precedence trap: an empty
            # prev_question must not match ("" is a substring of everything).
            if prev_question and (
                query_lower in prev_question or prev_question in query_lower
            ):
                logging.info(
                    "Question is similar to previous conversation, skipping confirmation"
                )
                return False

    stop_words = {
        "search", "query:", "can", "you", "some", "at", "the", "a", "an",
        "in", "on", "to", "for", "with", "by", "about", "give", "full",
        "is", "are", "was", "were", "define", "what", "how", "why", "when",
        "where", "year", "list", "form", "table", "who", "which", "me",
        "tell", "explain", "describe", "of", "and", "or", "there", "their",
        "please", "could", "would", "various", "different", "type", "types",
        "kind", "kinds", "has", "have", "had", "many", "say",
    }

    key_words = [
        word for word in query_lower.split() if word not in stop_words and len(word) > 2
    ]
    logging.info(f"Key words: {key_words}")

    if not key_words:
        logging.info("No significant keywords found, directing to general knowledge")
        return True

    matches = sum(1 for word in key_words if word in context_lower)
    logging.info(f"Matches: {matches} out of {len(key_words)} keywords")

    match_ratio = matches / len(key_words)
    logging.info(f"Match ratio: {match_ratio}")

    # Treat the question as general knowledge when fewer than 40% of its
    # keywords appear anywhere in the retrieved context.
    return match_ratio < 0.4

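# Worked example of the keyword-ratio heuristic (made-up strings): for the
# query "what are the visiting hours", the key words are ["visiting", "hours"];
# if only "hours" appears in the retrieved context, match_ratio = 1/2 = 0.5,
# which is >= 0.4, so the question is treated as covered by the documents.
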
def is_table_request(query: str) -> bool:
    """
    Determine if the user is requesting a response in tabular format.
    """
    table_keywords = [
        "table", "tabular", "in a table", "in table format", "in tabular format",
        "chart", "data", "comparison", "as a table", "table format",
        "in rows and columns", "in a grid", "breakdown", "spreadsheet",
        "comparison table", "data table", "structured table", "tabular form",
        "table form",
    ]

    query_lower = query.lower()
    return any(keyword in query_lower for keyword in table_keywords)


def ensure_html_response(text: str) -> str:
    """
    Ensure the response is properly formatted as HTML,
    converting plain text to HTML when necessary.
    """
    if "<html" in text.lower() or "<body" in text.lower():
        return text

    has_html_tags = bool(re.search(r"<[a-z]+.*?>", text))

    if not has_html_tags:
        paragraphs = text.split("\n\n")
        html_parts = []
        in_ordered_list = False
        in_unordered_list = False

        for para in paragraphs:
            if para.strip():
                if re.match(r"^\s*[\*\-\•]\s", para):
                    # Close an open ordered list before starting bullets
                    if in_ordered_list:
                        html_parts.append("</ol>")
                        in_ordered_list = False
                    if not in_unordered_list:
                        html_parts.append("<ul>")
                        in_unordered_list = True

                    lines = para.split("\n")
                    for line in lines:
                        if line.strip():
                            item = re.sub(r"^\s*[\*\-\•]\s*", "", line)
                            html_parts.append(f"<li>{item}</li>")

                elif re.match(r"^\s*\d+\.\s", para):
                    # Close an open unordered list before starting numbers
                    if in_unordered_list:
                        html_parts.append("</ul>")
                        in_unordered_list = False
                    if not in_ordered_list:
                        html_parts.append("<ol>")
                        in_ordered_list = True

                    lines = para.split("\n")
                    for line in lines:
                        match = re.match(r"^\s*\d+\.\s*(.*)", line)
                        if match:
                            html_parts.append(f"<li>{match.group(1)}</li>")

                else:  # Close any open lists before adding a new paragraph
                    if in_ordered_list:
                        html_parts.append("</ol>")
                        in_ordered_list = False
                    if in_unordered_list:
                        html_parts.append("</ul>")
                        in_unordered_list = False

                    html_parts.append(f"<p>{para}</p>")

        if in_ordered_list:
            html_parts.append("</ol>")
        if in_unordered_list:
            html_parts.append("</ul>")

        return "".join(html_parts)

    else:
        if not any(tag in text for tag in ("<p>", "<div>", "<ul>", "<ol>")):
            paragraphs = text.split("\n\n")
            html_parts = [f"<p>{para}</p>" for para in paragraphs if para.strip()]
            return "".join(html_parts)

        return text

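# Illustrative conversion (made-up input):
#
#     >>> ensure_html_response("First point:\n\n* apples\n* pears")
#     '<p>First point:</p><ul><li>apples</li><li>pears</li></ul>'
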
class HybridConversationManager:
    """
    Hybrid conversation manager that uses Redis for RAG-based conversations
    and in-memory storage for general knowledge conversations.
    """

    def __init__(self, redis_client, ttl=3600, max_history_items=5):
        self.redis_client = redis_client
        self.ttl = ttl
        self.max_history_items = max_history_items

        # For general knowledge questions (in-memory)
        self.general_knowledge_histories = {}
        self.lock = Lock()

    def _get_redis_key(self, user_id, hospital_id, session_id=None):
        """Create the Redis key for document-based conversations."""
        if session_id:
            return f"conv_history:{user_id}:{hospital_id}:{session_id}"
        return f"conv_history:{user_id}:{hospital_id}"

    def _get_memory_key(self, user_id, hospital_id, session_id=None):
        """Create the memory key for general knowledge conversations."""
        if session_id:
            return f"{user_id}:{hospital_id}:{session_id}"
        return f"{user_id}:{hospital_id}"

    async def add_rag_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Add a document-based (RAG) interaction to Redis."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        history = self.get_rag_history(user_id, hospital_id, session_id)

        # Add the new interaction
        history.append(
            {
                "question": question,
                "answer": answer,
                "timestamp": time.time(),
                "type": "rag",  # Mark as RAG-based interaction
            }
        )

        # Keep only the last N interactions
        history = history[-self.max_history_items :]

        # Store the updated history
        try:
            self.redis_client.setex(key, self.ttl, json.dumps(history))
            logging.info(
                f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}"
            )
        except Exception as e:
            logging.error(f"Failed to store RAG interaction in Redis: {e}")

    def add_general_knowledge_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Add a general knowledge interaction to the in-memory store."""
        key = self._get_memory_key(user_id, hospital_id, session_id)

        with self.lock:
            if key not in self.general_knowledge_histories:
                self.general_knowledge_histories[key] = []

            self.general_knowledge_histories[key].append(
                {
                    "question": question,
                    "answer": answer,
                    "timestamp": time.time(),
                    "type": "general",  # Mark as general knowledge interaction
                }
            )

            # Keep only the most recent interactions
            if len(self.general_knowledge_histories[key]) > self.max_history_items:
                self.general_knowledge_histories[key] = (
                    self.general_knowledge_histories[key][-self.max_history_items :]
                )

        logging.info(
            f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}"
        )

    def get_rag_history(self, user_id, hospital_id, session_id=None):
        """Get document-based (RAG) conversation history from Redis."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        try:
            history_data = self.redis_client.get(key)
            return json.loads(history_data) if history_data else []
        except Exception as e:
            logging.error(f"Failed to retrieve RAG history from Redis: {e}")
            return []

    def get_general_knowledge_history(self, user_id, hospital_id, session_id=None):
        """Get general knowledge conversation history from memory."""
        key = self._get_memory_key(user_id, hospital_id, session_id)

        with self.lock:
            return self.general_knowledge_histories.get(key, []).copy()

    def get_combined_history(self, user_id, hospital_id, session_id=None):
        """Get combined conversation history from both sources, sorted by timestamp."""
        rag_history = self.get_rag_history(user_id, hospital_id, session_id)
        general_history = self.get_general_knowledge_history(
            user_id, hospital_id, session_id
        )

        # Combine histories
        combined_history = rag_history + general_history

        # Sort by timestamp (newest first)
        combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True)

        # Return the most recent N items
        return combined_history[: self.max_history_items]

    def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2):
        """Get the most recent interactions for context from the combined history."""
        combined_history = self.get_combined_history(user_id, hospital_id, session_id)
        # Sort by timestamp (oldest first) for the context window
        sorted_history = sorted(combined_history, key=lambda x: x.get("timestamp", 0))
        return sorted_history[-window_size:] if sorted_history else []

    def clear_history(self, user_id, hospital_id):
        """Clear conversation history from both stores."""
        # Clear Redis history
        redis_key = self._get_redis_key(user_id, hospital_id)
        try:
            self.redis_client.delete(redis_key)
        except Exception as e:
            logging.error(f"Failed to clear Redis history: {e}")

        # Clear memory history
        memory_key = self._get_memory_key(user_id, hospital_id)
        with self.lock:
            if memory_key in self.general_knowledge_histories:
                del self.general_knowledge_histories[memory_key]

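# Key layout sketch (illustrative values):
#
#     >>> mgr = HybridConversationManager(get_redis_client())
#     >>> mgr._get_redis_key("u1", 42, "s9")
#     'conv_history:u1:42:s9'
#     >>> mgr._get_memory_key("u1", 42)
#     'u1:42'
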
class ContextMapper:
    """Enhanced context mapping using the shared model manager"""

    def __init__(self):
        self.model_manager = ModelManager()
        self.context_cache = {}
        self.similarity_threshold = 0.6

    def get_semantic_similarity(self, text1, text2):
        """Get semantic similarity using the global model manager"""
        return self.model_manager.get_semantic_similarity(text1, text2)

    def extract_key_concepts(self, text):
        """Extract key concepts using NLP techniques"""
        doc = nlp(text)
        concepts = []

        entities = [(ent.text, ent.label_) for ent in doc.ents]
        noun_phrases = [chunk.text for chunk in doc.noun_chunks]
        important_words = [
            token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"]
        ]

        concepts.extend([e[0] for e in entities])
        concepts.extend(noun_phrases)
        concepts.extend(important_words)

        return list(set(concepts))

    def map_conversation_context(
        self, current_query, conversation_history, context_window=3
    ):
        """Map conversation context using enhanced NLP techniques"""
        if not conversation_history:
            return current_query

        recent_context = conversation_history[-context_window:]
        context_concepts = []

        # Extract concepts from recent conversations
        for interaction in recent_context:
            q_concepts = self.extract_key_concepts(interaction["question"])
            a_concepts = self.extract_key_concepts(interaction["answer"])
            context_concepts.extend(q_concepts)
            context_concepts.extend(a_concepts)

        # Extract concepts from the current query
        query_concepts = self.extract_key_concepts(current_query)

        # Find related concepts
        related_concepts = []
        for q_concept in query_concepts:
            for c_concept in context_concepts:
                similarity = self.get_semantic_similarity(q_concept, c_concept)
                if similarity > self.similarity_threshold:
                    related_concepts.append(c_concept)

        # Build the enhanced query
        if related_concepts:
            enhanced_query = (
                f"{current_query} in context of {', '.join(related_concepts)}"
            )
        else:
            enhanced_query = current_query

        return enhanced_query


# Initialize the context mapper
context_mapper = ContextMapper()

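# Sketch of concept extraction (actual output depends on the loaded spaCy
# model, so no exact values are promised):
#
#     concepts = context_mapper.extract_key_concepts("MRI scans at City Hospital")
#     # -> deduplicated entities, noun phrases and content words,
#     #    e.g. something like ["City Hospital", "MRI scans", "scans", ...]
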
async def generate_contextual_query(
    question: str, user_id: str, hospital_id: int, conversation_manager
) -> str:
    """Generate an enhanced contextual query"""
    context_window = conversation_manager.get_context_window(user_id, hospital_id)

    if not context_window:
        return question

    # Enhanced context mapping
    last_interaction = context_window[-1]
    enhanced_context = f"""
    Previous question: {last_interaction['question']}
    Previous answer: {last_interaction['answer']}
    Current question: {question}

    Please generate a detailed search query that combines the context from the previous answer
    with the current question, especially if the current question uses words like 'it', 'this',
    'that', or asks for more details about the previous topic.
    """

    try:
        response = await asyncio.to_thread(
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a context-aware query generator.",
                    },
                    {"role": "user", "content": enhanced_context},
                ],
                temperature=0.3,
                max_tokens=150,
            )
        )
        contextual_query = response.choices[0].message.content.strip()
        logging.info(f"Enhanced contextual query: {contextual_query}")
        return contextual_query
    except Exception as e:
        logging.error(f"Error generating contextual query: {e}")
        return question


def is_follow_up(current_question: str, conversation_history: list) -> bool:
    """Enhanced follow-up detection using NLP techniques"""
    if not conversation_history:
        return False

    last_interaction = conversation_history[-1]

    # Get semantic similarity against the previous interaction
    similarity = context_mapper.get_semantic_similarity(
        current_question, f"{last_interaction['question']} {last_interaction['answer']}"
    )

    # Enhanced referential check
    doc = nlp(current_question.lower())
    has_referential = any(
        token.lemma_
        in ["it", "this", "that", "these", "those", "they", "he", "she", "about", "more"]
        for token in doc
    )

    # Extract concepts with improved entity detection
    current_concepts = set(context_mapper.extract_key_concepts(current_question))
    last_concepts = set(
        context_mapper.extract_key_concepts(
            f"{last_interaction['question']} {last_interaction['answer']}"
        )
    )

    # Calculate the concept overlap (Jaccard similarity)
    concept_overlap = (
        len(current_concepts & last_concepts) / len(current_concepts | last_concepts)
        if current_concepts
        else 0
    )

    # More aggressive follow-up detection
    return (
        similarity > 0.3  # Lowered threshold
        or has_referential
        or concept_overlap > 0.2  # Lowered threshold
        or any(
            word in current_question.lower()
            for word in ["more", "about", "elaborate", "explain"]
        )
    )


async def get_relevant_context(question, hospital_id, doc_id=None):
    try:
        cache_key = f"context:hospital_{hospital_id}"
        if doc_id:
            cache_key += f":doc_{doc_id}"
        cache_key += f":{question.lower().strip()}"

        redis_client = get_redis_client()

        cached_context = redis_client.get(cache_key)
        if cached_context:
            logging.info(f"Cache hit for key: {cache_key}")
            return (
                cached_context.decode("utf-8")
                if isinstance(cached_context, bytes)
                else cached_context
            )

        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return ""

        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": 10,
                "fetch_k": 20,
                "lambda_mult": 0.6,
                # "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
            },
        )

        docs = await asyncio.to_thread(retriever.get_relevant_documents, question)
        if not docs:
            return ""

        sorted_docs = sorted(
            docs,
            key=lambda x: (
                int(x.metadata.get("page_number", 0)),
                int(x.metadata.get("chunk_index", 0)),
            ),
        )

        context_parts = [doc.page_content for doc in sorted_docs]
        context = "\n\n".join(context_parts)

        try:
            redis_client.setex(
                cache_key,
                int(CACHE_TTL["vector_store"].total_seconds()),
                context.encode("utf-8") if isinstance(context, str) else context,
            )
            logging.info(f"Cached context for key: {cache_key}")
        except Exception as cache_error:
            logging.error(f"Failed to cache context: {cache_error}")

        return context
    except Exception as e:
        logging.error(f"Error getting relevant context: {e}")
        return ""

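# Note on the retriever above: "mmr" (maximal marginal relevance) fetches the
# fetch_k=20 nearest chunks, re-ranks them for diversity, and returns k=10;
# lambda_mult=0.6 leans slightly toward relevance (1.0 = pure relevance,
# 0.0 = maximum diversity).
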
def format_conversation_context(conv_history):
    """Format conversation history into a string"""
    if not conv_history:
        return "No previous conversation."
    return "\n".join(
        [
            f"Q: {interaction['question']}\nA: {interaction['answer']}"
            for interaction in conv_history
        ]
    )


def get_icd_context_from_question(question, hospital_id):
    """Extract any valid ICD codes from the question and return context"""
    icd_data = load_icd_entries(hospital_id)
    matches = []
    # findall returns the candidate code strings found in the question
    codes_in_question = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", question.upper())

    seen = set()
    for code in codes_in_question:
        for entry in icd_data:
            if entry["code"] == code and code not in seen:
                matches.append(f"{entry['code']}: {entry['description']}")
                seen.add(code)
    return "\n".join(matches)


def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70):
    """Get fuzzy matches for ICD codes from the question"""
    icd_data = load_icd_entries(hospital_id)
    descriptions = [entry["description"] for entry in icd_data]
    matches = process.extract(
        question, descriptions, limit=top_n, score_cutoff=threshold
    )

    matched_context = []
    for desc, score, _ in matches:
        for entry in icd_data:
            if entry["description"] == desc:
                matched_context.append(f"{entry['code']}: {entry['description']}")
                break

    return "\n".join(matched_context)

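# Fuzzy-match sketch (made-up data): rapidfuzz's process.extract returns
# (choice, score, index) triples at or above the cutoff, e.g.
#
#     matches = process.extract("tuberclosis of lung", descriptions,
#                               limit=5, score_cutoff=70)
#     # -> [("Tuberculosis of lung", <score >= 70>, <index>), ...]
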
async def generate_answer_with_rag(
    question,
    hospital_id,
    client,
    doc_id=None,
    user_id="default",
    conversation_manager=None,
    session_id=None,
):
    """Generate an answer using RAG with improved conversation flow"""
    try:
        # Continue with regular RAG processing if not an ICD code or if no ICD match is found
        html_instruction = """
        IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
        - Use <p> tags for paragraphs
        - Use <h2>, <h3> tags for headings and subheadings
        - Use <ul>, <li> tags for bullet points
        - Use <ol>, <li> tags for numbered lists
        - Use <blockquote> for quoted text
        - Use <strong> for bold text and <em> for emphasis
        """

        table_instruction = """
        - For tables, use proper HTML table structure:
        <table border="1">
            <thead>
                <tr>
                    <th colspan="{total_columns}">{table_title}</th>
                </tr>
                <tr>
                    {table_headers}
                </tr>
            </thead>
            <tbody>
                {table_rows}
            </tbody>
        </table>
        """
        # Get the conversation history first
        conv_history = (
            conversation_manager.get_context_window(user_id, hospital_id, session_id)
            if conversation_manager
            else []
        )

        # Track ICD context across the conversation. is_icd_followup is
        # initialized here so it is defined even when there is no history.
        icd_context = {}
        is_icd_followup = False
        if conv_history:
            # Extract an ICD code from the previous interaction
            last_answer = conv_history[-1].get("answer", "")
            icd_codes = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", last_answer)
            if icd_codes:
                icd_context["last_code"] = icd_codes[0]

            # Check if the current question is about a previously discussed ICD code
            if icd_context.get("last_code"):
                followup_indicators = [
                    "what causes", "what is causing", "why", "how", "symptoms",
                    "treatment", "diagnosis", "causes", "effects", "complications",
                    "risk factors", "prevention", "prognosis", "this", "disease",
                    "that", "it",
                ]
                is_icd_followup = any(
                    indicator in question.lower() for indicator in followup_indicators
                )

            if is_icd_followup:
                # Carry the previous ICD code's context into the current question
                icd_exact_context = get_icd_context_from_question(
                    icd_context["last_code"], hospital_id
                )
                icd_fuzzy_context = get_fuzzy_icd_context(
                    f"{icd_context['last_code']} {question}", hospital_id
                )
            else:
                icd_exact_context = get_icd_context_from_question(question, hospital_id)
                icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)
        else:
            icd_exact_context = get_icd_context_from_question(question, hospital_id)
            icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)

        # Generate the contextual query once and retrieve the relevant context
        # (an earlier duplicate of this call was removed; its result was
        # overwritten before use and cost an extra completion request).
        contextual_query = await generate_contextual_query(
            question, user_id, hospital_id, conversation_manager
        )
        doc_context = await get_relevant_context(contextual_query, hospital_id, doc_id)

        # Combine context, giving priority to ICD information
        context_parts = []
        if is_icd_followup:
            context_parts.append(
                f"## Previous ICD Code Context\nContinuing discussion about: {icd_context['last_code']}"
            )
        if icd_exact_context:
            context_parts.append("## ICD Code Match\n" + icd_exact_context)
        if icd_fuzzy_context:
            context_parts.append("## Related ICD Suggestions\n" + icd_fuzzy_context)
        if doc_context:
            context_parts.append("## Document Context\n" + doc_context)

        context = "\n\n".join(context_parts)

        # Initialize follow-up detection
        is_follow_up = False

        # Check if this is a follow-up question
        if conv_history:
            last_interaction = conv_history[-1]
            last_question = last_interaction["question"].lower()
            last_answer_raw = last_interaction.get("answer", "")
            last_answer = last_answer_raw.lower()
            current_question = question.lower()

            # Meaningful keywords that indicate entity-related follow-ups
            entity_related_keywords = {
                "achievements", "awards", "accomplishments", "work", "contributions",
                "career", "company", "products", "life", "background", "education",
                "role", "experience", "history", "details", "places", "place",
                "information", "facts", "about", "birth", "death", "family",
                "books", "projects", "population",
            }

            # Is the question asking about attributes/achievements of a
            # previously discussed entity?
            has_entity_attribute = any(
                word in current_question.split() for word in entity_related_keywords
            )

            # Extract entities from the last answer to maintain context
            def extract_entities(text):
                # Group consecutive capitalized words into candidate entities
                words = text.split()
                entities = set()
                current_entity = []

                for word in words:
                    if word[0].isupper():
                        current_entity.append(word)
                    elif current_entity:
                        entities.add(" ".join(current_entity))
                        current_entity = []

                if current_entity:
                    entities.add(" ".join(current_entity))
                return entities

            # Use the original-cased answer here: a lowercased string would
            # never contain the capitalized words the extractor looks for.
            last_entities = extract_entities(last_answer_raw)

            # Check for referential words
            referential_words = {
                "it", "this", "that", "these", "those", "they", "their", "he",
                "she", "him", "her", "his", "hers", "them", "there", "such", "its",
            }
            has_referential = any(
                word in referential_words for word in current_question.split()
            )

            # Calculate term overlap with both the question and answer context
            def get_significant_terms(text):
                stop_words = {
                    "what", "when", "where", "who", "why", "how", "is", "are",
                    "was", "were", "be", "been", "the", "a", "an", "in", "on",
                    "at", "to", "for", "of", "with", "by", "about", "as",
                    "tell", "me", "please",
                }
                return set(
                    word
                    for word in text.split()
                    if len(word) > 2 and word.lower() not in stop_words
                )

            current_terms = get_significant_terms(current_question)
            last_terms = get_significant_terms(last_question)
            answer_terms = get_significant_terms(last_answer)

            # Include terms from both the question and the answer in context
            all_prev_terms = last_terms | answer_terms
            term_overlap = len(current_terms & all_prev_terms)
            total_terms = len(current_terms | all_prev_terms)
            term_similarity = term_overlap / total_terms if total_terms > 0 else 0

            # Enhanced follow-up detection combining multiple signals
            is_follow_up = (
                has_referential
                or term_similarity >= 0.2  # Lower threshold when including answer context
                or (
                    has_entity_attribute and bool(last_entities)
                )  # Asking about attributes of a known entity
                or (
                    last_interaction.get("type") == "general"
                    and term_similarity >= 0.15
                )
            )

            logging.info("Follow-up analysis enhanced:")
            logging.info(f"- Referential words: {has_referential}")
            logging.info(f"- Term similarity: {term_similarity:.2f}")
            logging.info(f"- Entity attribute question: {has_entity_attribute}")
            logging.info(f"- Last entities found: {last_entities}")
            logging.info(f"- Is follow-up: {is_follow_up}")

        # For entirely new topics (not follow-ups), use is_general_knowledge_question
        if not is_follow_up:
            is_general = is_general_knowledge_question(question, context, conv_history)
            if is_general:
                confirmation_prompt = (
                    "<p>We reviewed the hospital’s documentation, but this "
                    "particular question does not seem to be covered.</p>"
                )
                logging.info("General knowledge question detected")
                return {
                    "answer": confirmation_prompt,
                    "requires_confirmation": True,
                }, 200

        # # For follow-up questions to general knowledge, handle directly
        # if is_follow_up and conv_history and conv_history[-1].get("type") == "general":
        #     logging.info("Handling follow-up to general knowledge question")
        #     answer = await generate_general_knowledge_answer(
        #         question,
        #         client,
        #         user_id,
        #         hospital_id,
        #         conversation_manager,
        #         is_table_request(question),
        #         session_id=session_id,
        #     )
        #     return {"answer": answer}, 200

        # Generate the RAG answer with the enhanced context
        prompt_template = f"""Based on the following context and conversation history, provide a detailed answer to the question.
        Previous conversation:
        {format_conversation_context(conv_history)}

        Context from documents:
        {context}

        Current question: {question}

        Instructions:
        1. When providing medical codes (ICD, CPT, etc.):
        - Always use the ICD codes listed in the sections titled "ICD Code Match" and "Related ICD Suggestions" from the context.
        - Do not use or invent ICD codes from your own training knowledge unless they appear in the provided context.
        - If multiple codes are relevant, return the one that best matches the user’s question. If unsure, return multiple options in HTML list format.
        - Remove all decimal points (e.g., use 'A150' instead of 'A15.0')
        - Format the response as: '<p>The medical code for [condition] is [code]</p>'
        2. Address the current question while maintaining conversation continuity
        3. Resolve any ambiguous references using conversation history
        4. Format the response in clear HTML

        {html_instruction}
        {table_instruction if is_table_request(question) else ""}
        """

        response = await asyncio.to_thread(
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo-16k",
                messages=[
                    {"role": "system", "content": prompt_template},
                    {"role": "user", "content": question},
                ],
                temperature=0.3,
                max_tokens=1000,
            )
        )

        answer = ensure_html_response(response.choices[0].message.content)
        logging.info(f"Generated RAG answer for question: {question}")

        # Store the interaction in history
        if conversation_manager:
            await conversation_manager.add_rag_interaction(
                user_id, hospital_id, question, answer, session_id
            )

        return {"answer": answer}, 200

    except Exception as e:
        logging.error(f"Error in generate_answer_with_rag: {e}")
        return {"answer": f"<p>Error: {str(e)}</p>"}, 500


# async def generate_general_knowledge_answer(
#     query: str,
#     client,
#     user_id: str,
#     hospital_id: int,
#     conversation_manager,
#     wants_table: bool = False,
#     session_id=None,
# ) -> str:
#     """Generate an answer from general knowledge with enhanced context awareness"""
#     html_instruction = """
#     IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
#     - Use <p> tags for paragraphs
#     - Use <h2>, <h3> tags for headings and subheadings
#     - Use <ul>, <li> tags for bullet points
#     - Use <ol>, <li> tags for numbered lists
#     - Use <blockquote> for quoted text
#     - Use <strong> for bold text and <em> for emphasis
#     """

#     table_instruction = """
#     - For tables, use proper HTML table structure:
#     <table border="1">
#         <thead>
#             <tr>
#                 <th colspan="{total_columns}">{table_title}</th>
#             </tr>
#             <tr>
#                 {table_headers}
#             </tr>
#         </thead>
#         <tbody>
#             {table_rows}
#         </tbody>
#     </table>
#     """
#     # Get conversation history for context
#     conv_history = conversation_manager.get_context_window(
#         user_id, hospital_id, session_id
#     )

#     # Extract entities from the previous answer if available
#     entities = []
#     if conv_history:
#         last_answer = conv_history[-1].get("answer", "")
#         words = last_answer.split()
#         current_entity = []
#         for word in words:
#             if word[0].isupper() and len(word) > 2:
#                 current_entity.append(word)
#             elif current_entity:
#                 entities.append(" ".join(current_entity))
#                 current_entity = []
#         if current_entity:
#             entities.append(" ".join(current_entity))

#     conversation_context = "\n".join(
#         [
#             f"Q: {interaction['question']}\nA: {interaction['answer']}"
#             for interaction in conv_history
#         ]
#     )

#     # Enhanced system prompt with entity awareness
#     system_prompt = f"""You are a helpful assistant providing information from general knowledge.
#     Previous conversation:
#     {conversation_context}

#     {'Previous answer mentioned these entities: ' + ', '.join(entities) if entities else ''}

#     1. If the question seems incomplete or asks about attributes (like population, details, etc.),
#     assume it's referring to the most recently mentioned entity ({entities[0] if entities else 'none'}).
#     2. Present information in a clear, organized manner.
#     3. If you're uncertain about any information, acknowledge that limitation.
#     4. If the question requires specialized expertise, note that the user may want to consult an expert.
#     5. Format the response in clear HTML.

#     {html_instruction}
#     {table_instruction if wants_table else ""}
#     """

#     try:
#         logging.info(f"Generating general knowledge answer for: {query}")
#         response = client.chat.completions.create(
#             model="gpt-3.5-turbo",
#             messages=[
#                 {"role": "system", "content": system_prompt},
#                 {"role": "user", "content": query},
#             ],
#             temperature=0.3,
#             stream=True,
#         )

#         response_text = ""
#         for chunk in response:
#             if chunk.choices[0].delta.content:
#                 response_text += chunk.choices[0].delta.content

#         formatted_response = ensure_html_response(response_text.strip())

#         # Store the interaction in conversation history
#         conversation_manager.add_general_knowledge_interaction(
#             user_id, hospital_id, query, formatted_response, session_id
#         )

#         return formatted_response

#     except Exception as e:
#         logging.error(f"Error generating general knowledge answer: {e}")
#         return f"<p>Error generating response: {str(e)}</p>"


async def load_existing_vector_stores():
    """Load existing Chroma vector stores for each hospital"""
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                await cursor.execute("SELECT DISTINCT id FROM hospitals")
                hospital_ids = [row[0] for row in await cursor.fetchall()]

                for hospital_id in hospital_ids:
                    try:
                        await initialize_or_load_vector_store(hospital_id)
                    except Exception as e:
                        logging.error(
                            f"Failed to load vector store for hospital {hospital_id}: {e}"
                        )
                        continue

            except Exception as e:
                logging.error(f"Error loading vector stores: {e}")


async def get_failed_page(doc_id):
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                await cursor.execute(
                    "SELECT failed_page FROM documents WHERE id = %s", (doc_id,)
                )
                result = await cursor.fetchone()
                return result[0] if result and result[0] else None
            except Exception as e:
                logging.error(f"Database error checking failed_page: {e}")
                return None


async def update_document_status(doc_id, status, failed_page=None):
    """Update document status with enum validation"""
    if isinstance(status, str):
        status = DocumentStatus[status.upper()].value

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                if failed_page:
                    await cursor.execute(
                        "UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s",
                        (status, failed_page, doc_id),
                    )
                else:
                    await cursor.execute(
                        "UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s",
                        (status, doc_id),
                    )
                await conn.commit()
                return True
            except Exception as e:
                logging.error(f"Database update error: {e}")
                return False

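# Status normalization sketch: string inputs are routed through the enum, so
# e.g. DocumentStatus["PROCESSING"].value == "processing"; an unknown status
# string raises a KeyError before any SQL runs.
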
thread_pool = ThreadPoolExecutor(max_workers=10)


def async_to_sync(coroutine):
    """Helper function to run async code in a sync context"""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine)
    finally:
        loop.close()

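# Usage sketch: bridges coroutines into the sync Flask handlers below, e.g.
#
#     hospital_id = async_to_sync(get_hospital_id("ABC123"))  # illustrative code
#
# A fresh event loop per call keeps the pool's worker threads isolated.
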
@app.route("/flask-api", methods=["GET"])
|
||
def health_check():
|
||
"""Health check endpoint"""
|
||
access_logger.info(f"Health check request received from {request.remote_addr}")
|
||
return jsonify({"status": "ok"}), 200
|
||
|
||
|
||
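# Quick manual check (assuming the app is served on localhost:5000):
#
#     curl http://localhost:5000/flask-api
#     # -> {"status": "ok"}
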
@app.route("/flask-api/process-pdf", methods=["POST"])
|
||
def process_pdf():
|
||
access_logger.info(f"PDF processing request received from {request.remote_addr}")
|
||
file_path = None
|
||
try:
|
||
file = request.files.get("pdf")
|
||
hospital_id = request.form.get("hospital_id")
|
||
doc_id = request.form.get("doc_id")
|
||
|
||
logging.info(
|
||
f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}"
|
||
)
|
||
|
||
if not all([file, hospital_id, doc_id]):
|
||
return jsonify({"error": "Missing required parameters"}), 400
|
||
|
||
def process_in_background():
|
||
nonlocal file_path
|
||
try:
|
||
async_to_sync(update_document_status(doc_id, "processing"))
|
||
|
||
# Add progress logging
|
||
logging.info(f"Starting processing of document {doc_id}")
|
||
|
||
filename = f"doc_{doc_id}_{file.filename}"
|
||
file_path = os.path.join(uploads_dir, filename)
|
||
|
||
with open(file_path, "wb") as f:
|
||
file.save(f)
|
||
|
||
logging.info("Extracting PDF contents...")
|
||
content = extract_pdf_contents(file_path, int(hospital_id))
|
||
|
||
logging.info("Inserting content into database...")
|
||
metadata = {"filename": filename}
|
||
result = async_to_sync(
|
||
insert_content_into_db(content, metadata, doc_id)
|
||
)
|
||
|
||
if "error" in result:
|
||
async_to_sync(update_document_status(doc_id, "failed", 1))
|
||
return False
|
||
|
||
logging.info("Creating embeddings and indexing...")
|
||
success = async_to_sync(add_document_to_index(doc_id, hospital_id))
|
||
|
||
if success:
|
||
logging.info("Document processing completed successfully")
|
||
async_to_sync(update_document_status(doc_id, "processed"))
|
||
return True
|
||
else:
|
||
logging.error("Document processing failed during indexing")
|
||
async_to_sync(update_document_status(doc_id, "failed"))
|
||
return False
|
||
|
||
except Exception as e:
|
||
logging.error(f"Processing error: {e}")
|
||
async_to_sync(update_document_status(doc_id, "failed"))
|
||
return False
|
||
finally:
|
||
if file_path and os.path.exists(file_path):
|
||
try:
|
||
os.remove(file_path)
|
||
except Exception as e:
|
||
logging.error(f"Error removing temporary file: {e}")
|
||
|
||
# Execute processing and wait for result
|
||
future = thread_pool.submit(process_in_background)
|
||
success = future.result()
|
||
|
||
if success:
|
||
return jsonify({"message": "Document processed successfully"}), 200
|
||
else:
|
||
return jsonify({"error": "Document processing failed"}), 500
|
||
|
||
except Exception as e:
|
||
logging.error(f"API error: {e}")
|
||
if file_path and os.path.exists(file_path):
|
||
try:
|
||
os.remove(file_path)
|
||
except Exception as file_e:
|
||
logging.error(f"Error removing temporary file: {file_e}")
|
||
return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
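# A hedged client-side sketch of calling /flask-api/process-pdf (illustrative
# only; not part of the app). It assumes the server runs on localhost:5000, as
# configured in __main__ below, and uses the third-party `requests` package
# purely for the example. Field names mirror what process_pdf() reads from
# request.files and request.form.
def _example_process_pdf_client(pdf_path, hospital_id, doc_id):
    import requests  # illustration-only dependency

    with open(pdf_path, "rb") as fh:
        resp = requests.post(
            "http://localhost:5000/flask-api/process-pdf",
            files={"pdf": fh},
            data={"hospital_id": str(hospital_id), "doc_id": str(doc_id)},
        )
    return resp.status_code, resp.json()

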
# Initialize the hybrid conversation manager
redis_client = get_redis_client()
conversation_manager = HybridConversationManager(redis_client)


@app.route("/flask-api/generate-answer", methods=["POST"])
|
||
def rag_answer_api():
|
||
"""Sync API endpoint for RAG-based question answering with conversation history."""
|
||
access_logger.info(f"Generate answer request received from {request.remote_addr}")
|
||
try:
|
||
data = request.json
|
||
question = data.get("question", "").strip().lower()
|
||
hospital_code = data.get("hospital_code")
|
||
doc_id = data.get("doc_id")
|
||
user_id = data.get("user_id", "default")
|
||
session_id = data.get("session_id", None)
|
||
|
||
logging.info(f"Received question from user {user_id}: {question}")
|
||
logging.info(f"Received hospital code: {hospital_code}")
|
||
logging.info(f"Received session_id: {session_id}")
|
||
|
||
# is_confirmation_response = data.get("is_confirmation_response", False)
|
||
original_query = data.get("original_query", "")
|
||
|
||
def process_rag_answer():
|
||
try:
|
||
hospital_id = async_to_sync(get_hospital_id(hospital_code))
|
||
logging.info(f"Resolved hospital ID: {hospital_id}")
|
||
|
||
if not hospital_id:
|
||
return {
|
||
"error": "Invalid or missing 'hospital_code' in request"
|
||
}, 400
|
||
|
||
# if question == "yes" and original_query:
|
||
# # User confirmed they want a general knowledge answer
|
||
# answer = async_to_sync(
|
||
# generate_general_knowledge_answer(
|
||
# original_query,
|
||
# client,
|
||
# user_id,
|
||
# hospital_id,
|
||
# conversation_manager, # Pass the hybrid manager
|
||
# is_table_request(original_query),
|
||
# session_id=session_id,
|
||
# )
|
||
# )
|
||
# return {"answer": answer}, 200
|
||
|
||
if original_query:
|
||
response_message = """
|
||
<p>I can only answer questions based on information found in the hospital documents.</p>
|
||
<p>The question you asked doesn't seem to be covered in the available documents.</p>
|
||
<p>You can try rephrasing your question or asking about a different topic.</p>
|
||
"""
|
||
return {"answer": response_message}, 200
|
||
|
||
else:
|
||
# Regular RAG answer
|
||
return async_to_sync(
|
||
generate_answer_with_rag(
|
||
question=question,
|
||
hospital_id=hospital_id,
|
||
client=client,
|
||
doc_id=doc_id,
|
||
user_id=user_id,
|
||
conversation_manager=conversation_manager, # Pass the hybrid manager
|
||
session_id=session_id,
|
||
)
|
||
)
|
||
except Exception as e:
|
||
logging.error(f"Thread processing error: {str(e)}")
|
||
return {"error": str(e)}, 500
|
||
|
||
if not question:
|
||
return jsonify({"error": "Missing 'question' in request"}), 400
|
||
|
||
future = thread_pool.submit(process_rag_answer)
|
||
result, status_code = future.result()
|
||
|
||
return jsonify(result), status_code
|
||
|
||
except Exception as e:
|
||
logging.error(f"API error: {str(e)}")
|
||
return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
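# A hedged sketch of a client request to /flask-api/generate-answer
# (illustrative only). The JSON field names mirror what rag_answer_api() reads
# from request.json above; the URL assumes the local dev server.
def _example_generate_answer_client(question, hospital_code, session_id=None):
    import requests  # illustration-only dependency

    resp = requests.post(
        "http://localhost:5000/flask-api/generate-answer",
        json={
            "question": question,
            "hospital_code": hospital_code,
            "user_id": "default",
            "session_id": session_id,
        },
    )
    # On success the body is {"answer": "<html>"}; on failure, {"error": "..."}.
    return resp.json()

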
@app.route("/flask-api/delete-document-vectors", methods=["DELETE"])
|
||
def delete_document_vectors_endpoint():
|
||
"""Endpoint to delete document vectors from ChromaDB"""
|
||
try:
|
||
data = request.json
|
||
hospital_id = data.get("hospital_id")
|
||
doc_id = data.get("doc_id")
|
||
|
||
if not all([hospital_id, doc_id]):
|
||
return jsonify({"error": "Missing required parameters"}), 400
|
||
|
||
logging.info(
|
||
f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}"
|
||
)
|
||
|
||
def process_deletion():
|
||
try:
|
||
success = async_to_sync(delete_document_vectors(hospital_id, doc_id))
|
||
if success:
|
||
return {"message": "Document vectors deleted successfully"}, 200
|
||
else:
|
||
return {"error": "Failed to delete document vectors"}, 500
|
||
except Exception as e:
|
||
logging.error(f"Error in vector deletion process: {e}")
|
||
return {"error": str(e)}, 500
|
||
|
||
future = thread_pool.submit(process_deletion)
|
||
result, status_code = future.result()
|
||
|
||
return jsonify(result), status_code
|
||
|
||
except Exception as e:
|
||
logging.error(f"API error: {str(e)}")
|
||
return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
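# A hedged sketch of invoking the deletion endpoint above (illustrative only).
# Note the JSON body on a DELETE request, which the endpoint expects and which
# `requests` supports via the json= keyword.
def _example_delete_vectors_client(hospital_id, doc_id):
    import requests  # illustration-only dependency

    resp = requests.delete(
        "http://localhost:5000/flask-api/delete-document-vectors",
        json={"hospital_id": hospital_id, "doc_id": doc_id},
    )
    return resp.status_code, resp.json()

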
@app.route("/flask-api/get-chroma-content", methods=["GET"])
|
||
def get_chroma_content_endpoint():
|
||
"""API endpoint to get ChromaDB content by hospital_id"""
|
||
try:
|
||
hospital_id = request.args.get("hospital_id")
|
||
limit = int(request.args.get("limit", 30000))
|
||
|
||
if not hospital_id:
|
||
return jsonify({"error": "Missing required parameter: hospital_id"}), 400
|
||
|
||
def process_fetch():
|
||
try:
|
||
result, status_code = async_to_sync(
|
||
get_chroma_content_by_hospital(
|
||
hospital_id=int(hospital_id), limit=limit
|
||
)
|
||
)
|
||
return result, status_code
|
||
except Exception as e:
|
||
logging.error(f"Error in ChromaDB fetch process: {e}")
|
||
return {"error": str(e)}, 500
|
||
|
||
future = thread_pool.submit(process_fetch)
|
||
result, status_code = future.result()
|
||
|
||
return jsonify(result), status_code
|
||
|
||
except Exception as e:
|
||
logging.error(f"API error: {str(e)}")
|
||
return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100):
    """Fetch content from ChromaDB for a specific hospital"""
    try:
        # Initialize vector store
        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return {"error": "Vector store not found"}, 404

        # Get collection
        collection = vector_store._collection

        # Query the collection with hospital_id filter; run the blocking call
        # off the event loop
        results = await asyncio.to_thread(
            lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit)
        )

        if not results or not results["ids"]:
            return {"data": [], "count": 0}, 200

        # Format the response
        formatted_results = []
        for i in range(len(results["ids"])):
            formatted_results.append(
                {
                    "id": results["ids"][i],
                    "content": results["documents"][i],
                    "metadata": results["metadatas"][i],
                }
            )

        return {"data": formatted_results, "count": len(formatted_results)}, 200

    except Exception as e:
        logging.error(f"Error fetching ChromaDB content: {e}")
        return {"error": str(e)}, 500


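# For reference, a successful response from get_chroma_content_by_hospital has
# this shape (a sketch derived from the formatting loop above):
#
#     {
#         "data": [
#             {"id": "<chunk-id>", "content": "<chunk text>", "metadata": {...}},
#             ...
#         ],
#         "count": <number of chunks returned>,
#     }

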
@app.before_request
def before_request():
    request._start_time = time.time()


@app.after_request
def after_request(response):
    if hasattr(request, "_start_time"):
        duration = time.time() - request._start_time
        access_logger.info(
            f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - '
            f"IP: {request.remote_addr}"
        )
    return response


if __name__ == "__main__":
    logger.info("Starting SpurrinAI application")
    logger.info(f"Python version: {sys.version}")
    logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}")

    try:
        model_manager = ModelManager()
        logger.info("Model manager initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize model manager: {e}")
        sys.exit(1)

    # Initialize directories
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(CHROMA_DIR, exist_ok=True)
    logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}")

    # Clear Redis cache
    redis_client = get_redis_client()
    cleared_keys = 0
    for key in redis_client.scan_iter("vector_store_data:*"):
        redis_client.delete(key)
        cleared_keys += 1
    logger.info(f"Cleared {cleared_keys} Redis cache keys")

    # Load vector stores
    logger.info("Loading existing vector stores...")
    async_to_sync(load_existing_vector_stores())
    logger.info("Vector stores loaded successfully")

    # Start application
    logger.info("Starting Flask application on port 5000")
    app.run(port=5000, debug=False)