# forked from rohit/spurrin-backend
# """
|
||
# SpurrinAI - Intelligent Document Processing and Question Answering System
|
||
# Copyright (c) 2024 Tech4biz. All rights reserved.
|
||
|
||
# This module implements the main Flask application for the SpurrinAI system,
|
||
# providing REST APIs for document processing, vector storage, and question answering
|
||
# using RAG (Retrieval Augmented Generation) architecture.
|
||
|
||
# Author: Tech4biz Development Team
|
||
# Version: 1.0.0
|
||
# Last Updated: 2024-01-19
|
||
# """
|
||
|
||
# Standard library imports
import os
import re
import sys
import json
import time
import threading
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum

# Third-party imports
import spacy
import redis
import aiomysql
from dotenv import load_dotenv
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from tqdm import tqdm
from tqdm.asyncio import tqdm as tqdm_async
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from openai import OpenAI
from rapidfuzz import process
from threading import Lock

# Local imports
from model_manager import ModelManager

# Suppress warnings
import warnings

warnings.filterwarnings("ignore")

# Initialize NLTK
import nltk

nltk.download("punkt")

# Configure logging
import logging
import logging.handlers

app = Flask(__name__)
CORS(app)

script_dir = os.path.dirname(os.path.abspath(__file__))
log_file_path = os.path.join(script_dir, "error.log")
logging.basicConfig(filename=log_file_path, level=logging.INFO)

# Configure logging
def setup_logging():
    log_dir = os.path.join(script_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    main_log = os.path.join(log_dir, "app.log")
    error_log = os.path.join(log_dir, "error.log")
    access_log = os.path.join(log_dir, "access.log")
    perf_log = os.path.join(log_dir, "performance.log")

    # Create formatters
    detailed_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
    )
    access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Main logger setup
    main_handler = logging.handlers.RotatingFileHandler(
        main_log, maxBytes=10485760, backupCount=5
    )
    main_handler.setFormatter(detailed_formatter)
    main_handler.setLevel(logging.INFO)

    # Error logger setup
    error_handler = logging.handlers.RotatingFileHandler(
        error_log, maxBytes=10485760, backupCount=5
    )
    error_handler.setFormatter(detailed_formatter)
    error_handler.setLevel(logging.ERROR)

    # Access logger setup
    access_handler = logging.handlers.TimedRotatingFileHandler(
        access_log, when="midnight", interval=1, backupCount=30
    )
    access_handler.setFormatter(access_formatter)
    access_handler.setLevel(logging.INFO)

    # Performance logger setup
    perf_handler = logging.handlers.RotatingFileHandler(
        perf_log, maxBytes=10485760, backupCount=5
    )
    perf_handler.setFormatter(detailed_formatter)
    perf_handler.setLevel(logging.INFO)

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(main_handler)
    root_logger.addHandler(error_handler)

    # Create specific loggers
    access_logger = logging.getLogger("access")
    access_logger.addHandler(access_handler)
    access_logger.setLevel(logging.INFO)

    perf_logger = logging.getLogger("performance")
    perf_logger.addHandler(perf_handler)
    perf_logger.setLevel(logging.INFO)

    return root_logger, access_logger, perf_logger


# Initialize loggers
logger, access_logger, perf_logger = setup_logging()

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'flaskuser',
#     'password': 'Flask@123',
#     'database': 'spurrinai',
# }

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'spurrindevuser',
#     'password': 'Admin@123',
#     'database': 'spurrindev',
# }

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'root',
#     'password': 'root',
#     'database': 'medqueryai',
#     'port': 3307
# }

# Redis Configuration
REDIS_CONFIG = {
    "host": "localhost",
    "port": 6379,
    "db": 0,
    "decode_responses": True,  # For string operations
}

DB_CONFIG = {
    "host": os.getenv("DB_HOST", "localhost"),
    "user": os.getenv("DB_USER", "testuser"),
    "password": os.getenv("DB_PASSWORD", "Admin@123"),
    "database": os.getenv("DB_NAME", "spurrintest"),
}

# Redis connection pools (db 0 for strings, db 1 for binary payloads)
redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
redis_binary_pool = redis.ConnectionPool(
    host="localhost", port=6379, db=1, decode_responses=False
)

def get_redis_client(binary=False):
    """Get Redis client from pool"""
    logger.debug(f"Getting Redis client with binary={binary}")
    try:
        pool = redis_binary_pool if binary else redis_pool
        client = redis.Redis(connection_pool=pool)
        logger.debug("Redis client created successfully")
        return client
    except Exception as e:
        logger.error(f"Failed to create Redis client: {e}", exc_info=True)
        raise

def fetch_cached_answer(cache_key):
    logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
    start_time = time.time()
    try:
        redis_client = get_redis_client()
        cached_answer = redis_client.get(cache_key)
        fetch_time = time.time() - start_time
        perf_logger.info(
            f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
        )
        return cached_answer
    except Exception as e:
        logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
        return None

# Cache TTL configurations
CACHE_TTL = {
    "vector_store": timedelta(hours=24),
    "chat_completion": timedelta(hours=1),
    "document_metadata": timedelta(days=7),
}

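# fetch_cached_answer has no matching writer in this section. A minimal sketch
# of what the write path could look like, assuming the same string client and
# the chat_completion TTL above (illustrative only, not part of the original
# module):
def store_cached_answer(cache_key, answer):
    """Hypothetical helper: cache an answer string under the chat_completion TTL."""
    try:
        redis_client = get_redis_client()
        redis_client.setex(
            cache_key, int(CACHE_TTL["chat_completion"].total_seconds()), answer
        )
    except Exception as e:
        logger.error(f"Redis store error for key {cache_key}: {e}", exc_info=True)
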
# DATA_DIR = os.path.join(script_dir, "hospital_data")
|
||
# CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")
|
||
# uploads_dir = os.path.join(script_dir, "llm-uploads")
|
||
|
||
# if not os.path.exists(uploads_dir):
|
||
# os.makedirs(uploads_dir)
|
||
|
||
# nlp = spacy.load("en_core_web_sm")
|
||
|
||
# load_dotenv()
|
||
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||
|
||
# client = OpenAI(api_key=OPENAI_API_KEY)
|
||
# embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
|
||
# llm = ChatOpenAI(
|
||
# model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
|
||
# )
|
||
# # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
|
||
# hospital_vector_stores = {}
|
||
# vector_store_lock = threading.Lock()
|
||
|
||
|
||
@dataclass
class Document:
    doc_id: int
    page_num: int
    content: str

class DocumentStatus(Enum):
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"

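# Example: DocumentStatus["PROCESSING"].value == "processing"; this is the
# string that update_document_status() below writes to documents.processed_status.
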
async def get_db_pool():
    return await aiomysql.create_pool(
        host=DB_CONFIG["host"],
        user=DB_CONFIG["user"],
        password=DB_CONFIG["password"],
        db=DB_CONFIG["database"],
        autocommit=True,
    )

async def get_hospital_id(hospital_code):
    pool = None
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(
                    "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
                    (hospital_code,),
                )
                result = await cursor.fetchone()
                return result["id"] if result else None
    except Exception as error:
        logging.error(f"Database error: {error}")
        return None
    finally:
        # pool stays None if create_pool itself failed, so guard the cleanup
        if pool is not None:
            pool.close()
            await pool.wait_closed()

CHUNK_SIZE = 1000
CHUNK_OVERLAP = 50
BATCH_SIZE = 1000

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    # length_function=len,
    # separators=["\n\n", "\n", ". ", " ", ""]
)

# Update the JSON_PATH to be dynamic based on hospital_id
def get_icd_json_path(hospital_id):
    hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
    os.makedirs(hospital_data_dir, exist_ok=True)
    return os.path.join(hospital_data_dir, "icd_data.json")

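# Resulting layout (illustrative): <script_dir>/hospital_data/hospital_7/icd_data.json
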
def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
    """Extract and process ICD codes with optimized processing and optional JSON saving"""
    try:
        # Compile the pattern once
        pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)

        icd_data = []
        current_code = None
        current_description = []

        # Iterate line by line over the whole content. Slicing the text into
        # fixed-size character chunks and splitting each chunk separately can
        # cut a line in half at a chunk boundary and corrupt an entry.
        for line in content.splitlines():
            line = line.strip()
            if not line:
                if current_code and current_description:
                    icd_data.append(
                        {
                            "code": current_code,
                            "description": " ".join(current_description).strip(),
                        }
                    )
                    current_code = None
                    current_description = []
                continue

            match = pattern.match(line)
            if match:
                if current_code and current_description:
                    icd_data.append(
                        {
                            "code": current_code,
                            "description": " ".join(current_description).strip(),
                        }
                    )
                current_code, description = match.groups()
                current_description = [description.strip()]
            elif current_code:
                current_description.append(line)

        # Add final entry if one is still open
        if current_code and current_description:
            icd_data.append(
                {
                    "code": current_code,
                    "description": " ".join(current_description).strip(),
                }
            )

        # Save to hospital-specific JSON if requested
        if save_to_json and icd_data:
            try:
                json_path = get_icd_json_path(hospital_id)

                # Use the shared module-level lock for thread safety
                with icd_json_lock:
                    if os.path.exists(json_path):
                        with open(json_path, "r", encoding="utf-8") as f:
                            try:
                                existing_data = json.load(f)
                            except json.JSONDecodeError:
                                existing_data = []
                    else:
                        existing_data = []

                    # Efficient deduplication using a dictionary keyed by code
                    seen_codes = {item["code"]: item for item in existing_data}
                    for item in icd_data:
                        seen_codes[item["code"]] = item

                    unique_data = list(seen_codes.values())

                    # Write atomically via a temporary file
                    temp_path = f"{json_path}.tmp"
                    with open(temp_path, "w", encoding="utf-8") as f:
                        json.dump(unique_data, f, indent=2, ensure_ascii=False)
                    os.replace(temp_path, json_path)

                logging.info(
                    f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
                )

            except Exception as e:
                logging.error(
                    f"Error saving ICD data to JSON for hospital {hospital_id}: {e}"
                )

        return icd_data

    except Exception as e:
        logging.error(f"Error in extract_and_process_icd_data: {e}")
        return []

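# Worked example (illustrative input, not real data):
#   extract_and_process_icd_data("A150 Tuberculosis of lung\nB20 HIV disease", 7)
#   -> [{"code": "A150", "description": "Tuberculosis of lung"},
#       {"code": "B20", "description": "HIV disease"}]
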
def load_icd_entries(hospital_id):
    """Load ICD entries from hospital-specific JSON file"""
    json_path = get_icd_json_path(hospital_id)
    try:
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return []
    except Exception as e:
        logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
        return []

# Update the process_icd_codes function to include hospital_id
async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
    """Process and store ICD codes using the optimized extraction function"""
    try:
        # Extract and save codes with hospital_id
        extract_and_process_icd_data(content, hospital_id, save_to_json=True)
    except Exception as e:
        logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")

async def initialize_icd_vector_store(hospital_id):
    """This function is deprecated. ICD codes are now handled through JSON search."""
    logging.warning(
        "initialize_icd_vector_store is deprecated - using JSON search instead"
    )
    return None

def extract_pdf_contents(pdf_path, hospital_id):
    """Extract PDF contents with optimized chunking and code extraction"""
    try:
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()
        pages_content = []

        for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
            text = page.page_content.strip()

            # Extract ICD codes from the page
            icd_codes = extract_and_process_icd_data(
                text, hospital_id
            )  # We'll set doc_id later

            pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})

        return pages_content

    except Exception as e:
        logging.error(f"Error in extract_pdf_contents: {e}")
        raise

async def insert_content_into_db(content, metadata, doc_id):
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
                content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"

                metadata_values = [
                    (doc_id, key[:100], value)
                    for key, value in metadata.items()
                    if value
                ]
                content_values = [
                    (doc_id, page_content["page"], page_content["text"])
                    for page_content in content
                ]

                if metadata_values:
                    await cursor.executemany(metadata_query, metadata_values)
                if content_values:
                    await cursor.executemany(content_query, content_values)

                await conn.commit()
                return {"message": "Success"}
            except Exception as e:
                await conn.rollback()
                return {"error": str(e)}

# async def initialize_or_load_vector_store(hospital_id, user_id="default"):
|
||
# """Initialize or load vector store with Redis caching and thread safety"""
|
||
# store_key = f"{hospital_id}:{user_id}"
|
||
|
||
# try:
|
||
# # Check if we already have it loaded - with lock for thread safety
|
||
# with vector_store_lock:
|
||
# if store_key in hospital_vector_stores:
|
||
# return hospital_vector_stores[store_key]
|
||
|
||
# # Initialize vector store
|
||
# redis_client = get_redis_client(binary=True)
|
||
# cache_key = f"vector_store_data:{hospital_id}:{user_id}"
|
||
# hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")
|
||
|
||
# if os.path.exists(hospital_dir):
|
||
# logging.info(
|
||
# f"Loading vector store for hospital {hospital_id} and user {user_id}"
|
||
# )
|
||
# vector_store = await asyncio.to_thread(
|
||
# lambda: Chroma(
|
||
# collection_name=f"hospital_{hospital_id}",
|
||
# persist_directory=hospital_dir,
|
||
# embedding_function=embeddings,
|
||
# )
|
||
# )
|
||
# else:
|
||
# logging.info(f"Creating vector store for hospital {hospital_id}")
|
||
# os.makedirs(hospital_dir, exist_ok=True)
|
||
# vector_store = await asyncio.to_thread(
|
||
# lambda: Chroma(
|
||
# collection_name=f"hospital_{hospital_id}",
|
||
# persist_directory=hospital_dir,
|
||
# embedding_function=embeddings,
|
||
# )
|
||
# )
|
||
|
||
# # Store with lock for thread safety
|
||
# with vector_store_lock:
|
||
# hospital_vector_stores[store_key] = vector_store
|
||
|
||
# return vector_store
|
||
# except Exception as e:
|
||
# logging.error(f"Error initializing vector store: {e}", exc_info=True)
|
||
# raise
|
||
|
||
|
||
async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
    """Delete all vectors associated with a specific document from ChromaDB"""
    try:
        # Initialize vector store for the hospital
        vector_store = await initialize_or_load_vector_store(hospital_id)

        # Delete vectors with matching doc_id
        await asyncio.to_thread(
            lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
        )

        # Persist changes
        await asyncio.to_thread(vector_store.persist)

        # Clear Redis cache for this document
        redis_client = get_redis_client()
        pattern = f"vector_store_data:{hospital_id}:*"
        for key in redis_client.scan_iter(pattern):
            redis_client.delete(key)

        logging.info(
            f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
        )
        return True

    except Exception as e:
        logging.error(f"Error deleting document vectors: {e}", exc_info=True)
        return False

async def add_document_to_index(doc_id, hospital_id):
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                vector_store = await initialize_or_load_vector_store(hospital_id)

                await cursor.execute(
                    "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
                    (doc_id,),
                )
                rows = await cursor.fetchall()

                total_pages = len(rows)
                logging.info(f"Processing {total_pages} pages for document {doc_id}")
                page_bar = tqdm_async(total=total_pages, desc="Processing pages")

                async def process_page(page_data):
                    page_num, content = page_data
                    try:
                        icd_data = extract_and_process_icd_data(
                            content, hospital_id, save_to_json=False
                        )
                        chunks = text_splitter.split_text(content)
                        await asyncio.sleep(0)  # Yield control
                        return page_num, chunks, icd_data
                    except Exception as e:
                        logging.error(f"Error processing page {page_num}: {e}")
                        return page_num, [], []

                tasks = [asyncio.create_task(process_page(row)) for row in rows]
                results = []

                for coro in asyncio.as_completed(tasks):
                    result = await coro
                    results.append(result)
                    page_bar.update(1)

                page_bar.close()

                # Vector addition progress bar
                all_icd_data = []
                all_chunks = []
                all_metadatas = []

                chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)

                for result in results:
                    page_num, chunks, icd_data = result
                    all_icd_data.extend(icd_data)

                    for i, chunk in enumerate(chunks):
                        all_chunks.append(chunk)
                        all_metadatas.append(
                            {
                                "doc_id": str(doc_id),
                                "hospital_id": str(hospital_id),
                                "page_number": str(page_num),
                                "chunk_index": str(i),
                            }
                        )

                    if len(all_chunks) >= BATCH_SIZE:
                        # The flushed batch can exceed BATCH_SIZE, so advance
                        # the bar by the actual count
                        batch_count = len(all_chunks)
                        chunk_add_bar.total += batch_count
                        chunk_add_bar.refresh()
                        await asyncio.to_thread(
                            vector_store.add_texts,
                            texts=all_chunks,
                            metadatas=all_metadatas,
                        )
                        all_chunks = []
                        all_metadatas = []
                        chunk_add_bar.update(batch_count)

                # Final batch
                if all_chunks:
                    chunk_add_bar.total += len(all_chunks)
                    chunk_add_bar.refresh()
                    await asyncio.to_thread(
                        vector_store.add_texts,
                        texts=all_chunks,
                        metadatas=all_metadatas,
                    )
                    chunk_add_bar.update(len(all_chunks))

                chunk_add_bar.close()

                if all_icd_data:
                    logging.info(f"Saving {len(all_icd_data)} ICD codes")
                    # Merge the codes collected from the per-page passes (which
                    # ran with save_to_json=False) into the hospital's JSON store
                    with icd_json_lock:
                        json_path = get_icd_json_path(hospital_id)
                        existing_data = []
                        if os.path.exists(json_path):
                            with open(json_path, "r", encoding="utf-8") as f:
                                try:
                                    existing_data = json.load(f)
                                except json.JSONDecodeError:
                                    existing_data = []
                        seen_codes = {item["code"]: item for item in existing_data}
                        for item in all_icd_data:
                            seen_codes[item["code"]] = item
                        temp_path = f"{json_path}.tmp"
                        with open(temp_path, "w", encoding="utf-8") as f:
                            json.dump(
                                list(seen_codes.values()), f, indent=2, ensure_ascii=False
                            )
                        os.replace(temp_path, json_path)

                await asyncio.to_thread(vector_store.persist)
                logging.info(f"Successfully indexed document {doc_id}")
                return True

    except Exception as e:
        logging.error(f"Error adding document: {e}")
        return False

def is_general_knowledge_question(
    query: str, context: str, conversation_context=None
) -> bool:
    """
    Determine if a question is likely a general knowledge question not covered in the documents.
    Takes conversation history into account to reduce repeated confirmations.
    """
    query_lower = query.lower()
    context_lower = context.lower()

    if conversation_context:
        for interaction in conversation_context:
            prev_question = interaction.get("question", "").lower()
            # Parenthesized so an empty prev_question can never match
            # ("" in query_lower is always True)
            if prev_question and (
                query_lower in prev_question or prev_question in query_lower
            ):
                logging.info(
                    "Question is similar to previous conversation, skipping confirmation"
                )
                return False

    stop_words = {
        "search", "query:", "can", "you", "some", "at", "the", "a", "an",
        "in", "on", "to", "for", "with", "by", "about", "give", "full",
        "is", "are", "was", "were", "define", "what", "how", "why", "when",
        "where", "year", "list", "form", "table", "who", "which", "me",
        "tell", "explain", "describe", "of", "and", "or", "there", "their",
        "please", "could", "would", "various", "different", "type", "types",
        "kind", "kinds", "has", "have", "had", "many", "say", "know",
    }

    key_words = [
        word for word in query_lower.split() if word not in stop_words and len(word) > 2
    ]
    logging.info(f"Key words: {key_words}")

    if not key_words:
        logging.info("No significant keywords found, directing to general knowledge")
        return True

    matches = sum(1 for word in key_words if word in context_lower)
    logging.info(f"Matches: {matches} out of {len(key_words)} keywords")

    match_ratio = matches / len(key_words)
    logging.info(f"Match ratio: {match_ratio}")

    return match_ratio < 0.4

def is_table_request(query: str) -> bool:
    """
    Determine if the user is requesting a response in tabular format.
    """
    table_keywords = [
        "table", "tabular", "in a table", "in table format",
        "in tabular format", "chart", "data", "comparison", "as a table",
        "table format", "in rows and columns", "in a grid", "breakdown",
        "spreadsheet", "comparison table", "data table", "structured table",
        "tabular form", "table form",
    ]

    query_lower = query.lower()
    return any(keyword in query_lower for keyword in table_keywords)

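# Example: is_table_request("Show ward charges in table format") -> True
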
def ensure_html_response(text: str) -> str:
    """
    Ensure the response is properly formatted in HTML.
    This function handles plain text conversion to HTML.
    """
    if "<html" in text.lower() or "<body" in text.lower():
        return text

    has_html_tags = bool(re.search(r"<[a-z]+.*?>", text))

    if not has_html_tags:
        paragraphs = text.split("\n\n")
        html_parts = []
        in_ordered_list = False
        in_unordered_list = False

        for para in paragraphs:
            if para.strip():
                if re.match(r"^\s*[\*\-\•]\s", para):
                    # Close an open ordered list before switching to bullets
                    if in_ordered_list:
                        html_parts.append("</ol>")
                        in_ordered_list = False
                    if not in_unordered_list:
                        html_parts.append("<ul>")
                        in_unordered_list = True

                    lines = para.split("\n")
                    for line in lines:
                        if line.strip():
                            item = re.sub(r"^\s*[\*\-\•]\s*", "", line)
                            html_parts.append(f"<li>{item}</li>")

                elif re.match(r"^\s*\d+\.\s", para):
                    # Close an open unordered list before switching to numbers
                    if in_unordered_list:
                        html_parts.append("</ul>")
                        in_unordered_list = False
                    if not in_ordered_list:
                        html_parts.append("<ol>")
                        in_ordered_list = True

                    lines = para.split("\n")
                    for line in lines:
                        match = re.match(r"^\s*\d+\.\s*(.*)", line)
                        if match:
                            html_parts.append(f"<li>{match.group(1)}</li>")

                else:  # Close any open lists before adding a new paragraph
                    if in_ordered_list:
                        html_parts.append("</ol>")
                        in_ordered_list = False
                    if in_unordered_list:
                        html_parts.append("</ul>")
                        in_unordered_list = False

                    html_parts.append(f"<p>{para}</p>")

        if in_ordered_list:
            html_parts.append("</ol>")
        if in_unordered_list:
            html_parts.append("</ul>")

        return "".join(html_parts)

    else:
        if not any(tag in text for tag in ("<p>", "<div>", "<ul>", "<ol>")):
            paragraphs = text.split("\n\n")
            html_parts = [f"<p>{para}</p>" for para in paragraphs if para.strip()]
            return "".join(html_parts)

        return text

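# Example (illustrative):
#   ensure_html_response("1. First\n2. Second")
#   -> "<ol><li>First</li><li>Second</li></ol>"
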
class HybridConversationManager:
    """
    Hybrid conversation manager that uses Redis for RAG-based conversations
    and in-memory storage for general knowledge conversations.
    """

    def __init__(self, redis_client, ttl=3600, max_history_items=5):
        self.redis_client = redis_client
        self.ttl = ttl
        self.max_history_items = max_history_items

        # For general knowledge questions (in-memory)
        self.general_knowledge_histories = {}
        self.lock = Lock()

    def _get_redis_key(self, user_id, hospital_id, session_id=None):
        """Create Redis key for document-based conversations."""
        if session_id:
            return f"conv_history:{user_id}:{hospital_id}:{session_id}"
        return f"conv_history:{user_id}:{hospital_id}"

    def _get_memory_key(self, user_id, hospital_id, session_id=None):
        """Create memory key for general knowledge conversations."""
        if session_id:
            return f"{user_id}:{hospital_id}:{session_id}"
        return f"{user_id}:{hospital_id}"

    async def add_rag_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Add document-based (RAG) interaction to Redis."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        history = self.get_rag_history(user_id, hospital_id, session_id)

        # Add new interaction
        history.append(
            {
                "question": question,
                "answer": answer,
                "timestamp": time.time(),
                "type": "rag",  # Mark as RAG-based interaction
            }
        )

        # Keep only last N interactions
        history = history[-self.max_history_items :]

        # Store updated history
        try:
            self.redis_client.setex(key, self.ttl, json.dumps(history))
            logging.info(
                f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}"
            )
        except Exception as e:
            logging.error(f"Failed to store RAG interaction in Redis: {e}")

    def add_general_knowledge_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Add general knowledge interaction to in-memory store."""
        key = self._get_memory_key(user_id, hospital_id, session_id)

        with self.lock:
            if key not in self.general_knowledge_histories:
                self.general_knowledge_histories[key] = []

            self.general_knowledge_histories[key].append(
                {
                    "question": question,
                    "answer": answer,
                    "timestamp": time.time(),
                    "type": "general",  # Mark as general knowledge interaction
                }
            )

            # Keep only the most recent interactions
            if len(self.general_knowledge_histories[key]) > self.max_history_items:
                self.general_knowledge_histories[key] = (
                    self.general_knowledge_histories[key][-self.max_history_items :]
                )

        logging.info(
            f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}"
        )

    def get_rag_history(self, user_id, hospital_id, session_id=None):
        """Get document-based (RAG) conversation history from Redis."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        try:
            history_data = self.redis_client.get(key)
            return json.loads(history_data) if history_data else []
        except Exception as e:
            logging.error(f"Failed to retrieve RAG history from Redis: {e}")
            return []

    def get_general_knowledge_history(self, user_id, hospital_id, session_id=None):
        """Get general knowledge conversation history from memory."""
        key = self._get_memory_key(user_id, hospital_id, session_id)

        with self.lock:
            return self.general_knowledge_histories.get(key, []).copy()

    def get_combined_history(self, user_id, hospital_id, session_id=None):
        """Get combined conversation history from both sources, sorted by timestamp."""
        rag_history = self.get_rag_history(user_id, hospital_id, session_id)
        general_history = self.get_general_knowledge_history(
            user_id, hospital_id, session_id
        )

        # Combine histories
        combined_history = rag_history + general_history

        # Sort by timestamp (newest first)
        combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True)

        # Return most recent N items
        return combined_history[: self.max_history_items]

    def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2):
        """Get the most recent interactions for context from combined history."""
        combined_history = self.get_combined_history(user_id, hospital_id, session_id)
        # Sort by timestamp (oldest first) for context window
        sorted_history = sorted(combined_history, key=lambda x: x.get("timestamp", 0))
        return sorted_history[-window_size:] if sorted_history else []

    def clear_history(self, user_id, hospital_id):
        """Clear conversation history from both stores."""
        # Clear Redis history
        redis_key = self._get_redis_key(user_id, hospital_id)
        try:
            self.redis_client.delete(redis_key)
        except Exception as e:
            logging.error(f"Failed to clear Redis history: {e}")

        # Clear memory history
        memory_key = self._get_memory_key(user_id, hospital_id)
        with self.lock:
            if memory_key in self.general_knowledge_histories:
                del self.general_knowledge_histories[memory_key]

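# Usage sketch (illustrative values): RAG answers are persisted to Redis with a
# TTL, general knowledge answers stay in process memory, and
# get_context_window() merges the two streams by timestamp:
#   manager = HybridConversationManager(get_redis_client())
#   await manager.add_rag_interaction("u1", 7, "What is A150?", "<p>...</p>")
#   manager.add_general_knowledge_interaction("u1", 7, "Who discovered TB?", "<p>...</p>")
#   recent = manager.get_context_window("u1", 7, window_size=2)
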
class ContextMapper:
    """Enhanced context mapping using shared model manager"""

    def __init__(self):
        self.model_manager = ModelManager()
        self.context_cache = {}
        self.similarity_threshold = 0.6

    def get_semantic_similarity(self, text1, text2):
        """Get semantic similarity using global model manager"""
        return self.model_manager.get_semantic_similarity(text1, text2)

    def extract_key_concepts(self, text):
        """Extract key concepts using NLP techniques"""
        doc = nlp(text)
        concepts = []

        entities = [(ent.text, ent.label_) for ent in doc.ents]
        noun_phrases = [chunk.text for chunk in doc.noun_chunks]
        important_words = [
            token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"]
        ]

        concepts.extend([e[0] for e in entities])
        concepts.extend(noun_phrases)
        concepts.extend(important_words)

        return list(set(concepts))

    def map_conversation_context(
        self, current_query, conversation_history, context_window=3
    ):
        """Map conversation context using enhanced NLP techniques"""
        if not conversation_history:
            return current_query

        recent_context = conversation_history[-context_window:]
        context_concepts = []

        # Extract concepts from recent conversations
        for interaction in recent_context:
            q_concepts = self.extract_key_concepts(interaction["question"])
            a_concepts = self.extract_key_concepts(interaction["answer"])
            context_concepts.extend(q_concepts)
            context_concepts.extend(a_concepts)

        # Extract concepts from current query
        query_concepts = self.extract_key_concepts(current_query)

        # Find related concepts
        related_concepts = []
        for q_concept in query_concepts:
            for c_concept in context_concepts:
                similarity = self.get_semantic_similarity(q_concept, c_concept)
                if similarity > self.similarity_threshold:
                    related_concepts.append(c_concept)

        # Build enhanced query
        if related_concepts:
            enhanced_query = (
                f"{current_query} in context of {', '.join(related_concepts)}"
            )
        else:
            enhanced_query = current_query

        return enhanced_query


# Initialize the context mapper
context_mapper = ContextMapper()

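# Example (illustrative): after an exchange about "ICU visiting hours",
# context_mapper.map_conversation_context("when does it open?", history) may
# return "when does it open? in context of ICU visiting hours".
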
async def generate_contextual_query(
    question: str, user_id: str, hospital_id: int, conversation_manager
) -> str:
    """Generate enhanced contextual query"""
    context_window = conversation_manager.get_context_window(user_id, hospital_id)

    if not context_window:
        return question

    # Enhanced context mapping
    last_interaction = context_window[-1]
    enhanced_context = f"""
    Previous question: {last_interaction['question']}
    Previous answer: {last_interaction['answer']}
    Current question: {question}

    Please generate a detailed search query that combines the context from the previous answer
    with the current question, especially if the current question uses words like 'it', 'this',
    'that', or asks for more details about the previous topic.
    """

    try:
        response = await asyncio.to_thread(
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a context-aware query generator.",
                    },
                    {"role": "user", "content": enhanced_context},
                ],
                temperature=0.3,
                max_tokens=150,
            )
        )
        contextual_query = response.choices[0].message.content.strip()
        logging.info(f"Enhanced contextual query: {contextual_query}")
        return contextual_query
    except Exception as e:
        logging.error(f"Error generating contextual query: {e}")
        return question

def is_follow_up(current_question: str, conversation_history: list) -> bool:
    """Enhanced follow-up detection using NLP techniques"""
    if not conversation_history:
        return False

    last_interaction = conversation_history[-1]

    # Get semantic similarity with higher threshold
    similarity = context_mapper.get_semantic_similarity(
        current_question, f"{last_interaction['question']} {last_interaction['answer']}"
    )

    # Enhanced referential check
    doc = nlp(current_question.lower())
    has_referential = any(
        token.lemma_
        in [
            "it", "this", "that", "these", "those",
            "they", "he", "she", "about", "more",
        ]
        for token in doc
    )

    # Extract concepts with improved entity detection
    current_concepts = set(context_mapper.extract_key_concepts(current_question))
    last_concepts = set(
        context_mapper.extract_key_concepts(
            f"{last_interaction['question']} {last_interaction['answer']}"
        )
    )

    # Calculate enhanced concept overlap
    concept_overlap = (
        len(current_concepts & last_concepts) / len(current_concepts | last_concepts)
        if current_concepts
        else 0
    )

    # More aggressive follow-up detection
    return (
        similarity > 0.3  # Lowered threshold
        or has_referential
        or concept_overlap > 0.2  # Lowered threshold
        or any(
            word in current_question.lower()
            for word in ["more", "about", "elaborate", "explain"]
        )
    )

async def get_relevant_context(question, hospital_id, doc_id=None):
    try:
        cache_key = f"context:hospital_{hospital_id}"
        if doc_id:
            cache_key += f":doc_{doc_id}"
        cache_key += f":{question.lower().strip()}"

        redis_client = get_redis_client()

        cached_context = redis_client.get(cache_key)
        if cached_context:
            logging.info(f"Cache hit for key: {cache_key}")
            return (
                cached_context.decode("utf-8")
                if isinstance(cached_context, bytes)
                else cached_context
            )

        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return ""

        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": 10,
                "fetch_k": 20,
                "lambda_mult": 0.6,
                # "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
            },
        )

        docs = await asyncio.to_thread(retriever.get_relevant_documents, question)
        if not docs:
            return ""

        sorted_docs = sorted(
            docs,
            key=lambda x: (
                int(x.metadata.get("page_number", 0)),
                int(x.metadata.get("chunk_index", 0)),
            ),
        )

        context_parts = [doc.page_content for doc in sorted_docs]
        context = "\n\n".join(context_parts)

        try:
            redis_client.setex(
                cache_key,
                int(CACHE_TTL["vector_store"].total_seconds()),
                context.encode("utf-8") if isinstance(context, str) else context,
            )
            logging.info(f"Cached context for key: {cache_key}")
        except Exception as cache_error:
            logging.error(f"Failed to cache context: {cache_error}")

        return context
    except Exception as e:
        logging.error(f"Error getting relevant context: {e}")
        return ""

def format_conversation_context(conv_history):
    """Format conversation history into a string"""
    if not conv_history:
        return "No previous conversation."
    return "\n".join(
        [
            f"Q: {interaction['question']}\nA: {interaction['answer']}"
            for interaction in conv_history
        ]
    )

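# Example: [{"question": "Q1", "answer": "A1"}] -> "Q: Q1\nA: A1"
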
def get_icd_context_from_question(question, hospital_id):
    """Extract any valid ICD codes from the question and return context"""
    icd_data = load_icd_entries(hospital_id)
    matches = []
    code_pattern = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", question.upper())

    seen = set()
    for code in code_pattern:
        for entry in icd_data:
            if entry["code"] == code and code not in seen:
                matches.append(f"{entry['code']}: {entry['description']}")
                seen.add(code)
    return "\n".join(matches)

def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70):
    """Get fuzzy matches for ICD codes from the question"""
    icd_data = load_icd_entries(hospital_id)
    descriptions = [entry["description"] for entry in icd_data]
    matches = process.extract(
        question, descriptions, limit=top_n, score_cutoff=threshold
    )

    # rapidfuzz returns (choice, score, index) tuples; the index points back
    # into descriptions, which parallels icd_data
    matched_context = []
    for desc, score, idx in matches:
        entry = icd_data[idx]
        matched_context.append(f"{entry['code']}: {entry['description']}")

    return "\n".join(matched_context)

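# Example (illustrative, given the JSON sketched earlier):
#   get_fuzzy_icd_context("lung tuberculosis", 7)
#   might return "A150: Tuberculosis of lung"
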
async def generate_answer_with_rag(
    question,
    hospital_id,
    client,
    doc_id=None,
    user_id="default",
    conversation_manager=None,
    session_id=None,
):
    """Generate an answer using RAG with improved conversation flow"""
    try:
        # Continue with regular RAG processing if not an ICD code or if no ICD match found
        html_instruction = """
        IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
        - Use <p> tags for paragraphs
        - Use <h2>, <h3> tags for headings and subheadings
        - Use <ul>, <li> tags for bullet points
        - Use <ol>, <li> tags for numbered lists
        - Use <blockquote> for quoted text
        - Use <strong> for bold text and <em> for emphasis
        """

        table_instruction = """
        - For tables, use proper HTML table structure:
        <table border="1">
            <thead>
                <tr>
                    <th colspan="{total_columns}">{table_title}</th>
                </tr>
                <tr>
                    {table_headers}
                </tr>
            </thead>
            <tbody>
                {table_rows}
            </tbody>
        </table>
        """
        # Get conversation history first
        conv_history = (
            conversation_manager.get_context_window(user_id, hospital_id, session_id)
            if conversation_manager
            else []
        )

        # Track ICD context across the conversation; is_icd_followup must be
        # bound even when there is no history
        icd_context = {}
        is_icd_followup = False
        if conv_history:
            # Extract ICD code from previous interaction
            last_answer = conv_history[-1].get("answer", "")
            icd_codes = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", last_answer)
            if icd_codes:
                icd_context["last_code"] = icd_codes[0]

            # Check if current question is about a previously discussed ICD code
            if icd_context.get("last_code"):
                followup_indicators = [
                    "what causes",
                    "what is causing",
                    "why",
                    "how",
                    "symptoms",
                    "treatment",
                    "diagnosis",
                    "causes",
                    "effects",
                    "complications",
                    "risk factors",
                    "prevention",
                    "prognosis",
                    "this",
                    "disease",
                    "that",
                    "it",
                ]
                is_icd_followup = any(
                    indicator in question.lower() for indicator in followup_indicators
                )

        if is_icd_followup:
            # Add the previous ICD code context to the current question
            icd_exact_context = get_icd_context_from_question(
                icd_context["last_code"], hospital_id
            )
            icd_fuzzy_context = get_fuzzy_icd_context(
                f"{icd_context['last_code']} {question}", hospital_id
            )
        else:
            icd_exact_context = get_icd_context_from_question(question, hospital_id)
            icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)

        # Get contextual query and relevant context
        contextual_query = await generate_contextual_query(
            question, user_id, hospital_id, conversation_manager
        )
        doc_context = await get_relevant_context(contextual_query, hospital_id, doc_id)

        # Combine context with priority for ICD information
        context_parts = []
        if is_icd_followup:
            context_parts.append(
                f"## Previous ICD Code Context\nContinuing discussion about: {icd_context['last_code']}"
            )
        if icd_exact_context:
            context_parts.append("## ICD Code Match\n" + icd_exact_context)
        if icd_fuzzy_context:
            context_parts.append("## Related ICD Suggestions\n" + icd_fuzzy_context)
        if doc_context:
            context_parts.append("## Document Context\n" + doc_context)

        context = "\n\n".join(context_parts)

        # Initialize follow-up detection
        is_follow_up = False

        # Check if this is a follow-up question
        if conv_history:
            last_interaction = conv_history[-1]
            last_question = last_interaction["question"].lower()
            last_answer = last_interaction.get("answer", "").lower()
            current_question = question.lower()

            # Meaningful keywords that indicate entity-related follow-ups
            entity_related_keywords = {
                "achievements", "awards", "accomplishments", "work",
                "contributions", "career", "company", "products", "life",
                "background", "education", "role", "experience", "history",
                "details", "places", "place", "information", "facts",
                "about", "birth", "death", "family", "books", "projects",
                "population",
            }

            # Check if question is asking about attributes/achievements of a
            # previously discussed entity
            has_entity_attribute = any(
                word in current_question.split() for word in entity_related_keywords
            )

            # Extract entities from last answer to maintain context
            def extract_entities(text):
                # Split into words and collect potential entities (runs of
                # capitalized words)
                words = text.split()
                entities = set()
                current_entity = []

                for word in words:
                    if word[0].isupper():
                        current_entity.append(word)
                    elif current_entity:
                        entities.add(" ".join(current_entity))
                        current_entity = []

                if current_entity:
                    entities.add(" ".join(current_entity))
                return entities

            # Use the original-cased answer here: capitalization is the entity
            # signal, so the lowercased copy would never yield entities
            last_entities = extract_entities(last_interaction.get("answer", ""))

            # Check for referential words
            referential_words = {
                "it", "this", "that", "these", "those", "they", "their",
                "he", "she", "him", "her", "his", "hers", "them", "there",
                "such", "its",
            }
            has_referential = any(
                word in referential_words for word in current_question.split()
            )

            # Calculate term overlap with both question and answer context
            def get_significant_terms(text):
                stop_words = {
                    "what", "when", "where", "who", "why", "how", "is",
                    "are", "was", "were", "be", "been", "the", "a", "an",
                    "in", "on", "at", "to", "for", "of", "with", "by",
                    "about", "as", "tell", "me", "please",
                }
                return set(
                    word
                    for word in text.split()
                    if len(word) > 2 and word.lower() not in stop_words
                )

            current_terms = get_significant_terms(current_question)
            last_terms = get_significant_terms(last_question)
            answer_terms = get_significant_terms(last_answer)

            # Include terms from both question and answer in context
            all_prev_terms = last_terms | answer_terms
            term_overlap = len(current_terms & all_prev_terms)
            total_terms = len(current_terms | all_prev_terms)
            term_similarity = term_overlap / total_terms if total_terms > 0 else 0

            # Enhanced follow-up detection combining multiple signals
            is_follow_up = (
                has_referential
                # Lower threshold when including answer context
                or term_similarity >= 0.2
                # Asking about attributes of a known entity
                or (has_entity_attribute and bool(last_entities))
                or (
                    last_interaction.get("type") == "general"
                    and term_similarity >= 0.15
                )
            )

            logging.info("Follow-up analysis enhanced:")
            logging.info(f"- Referential words: {has_referential}")
            logging.info(f"- Term similarity: {term_similarity:.2f}")
            logging.info(f"- Entity attribute question: {has_entity_attribute}")
            logging.info(f"- Last entities found: {last_entities}")
            logging.info(f"- Is follow-up: {is_follow_up}")

        # For entirely new topics (not follow-ups), use is_general_knowledge_question
        if not is_follow_up:
            is_general = is_general_knowledge_question(question, context, conv_history)
            if is_general:
                confirmation_prompt = """
                <p>I reviewed the hospital’s documentation, but this particular question does not seem to be covered.</p>
                """
                logging.info("General knowledge question detected")
                return {
                    "answer": confirmation_prompt,
                    "requires_confirmation": True,
                }, 200

# prompt_template = f"""Based on the following context and conversation history, provide a detailed answer to the question.
|
||
# Previous conversation:
|
||
# {format_conversation_context(conv_history)}
|
||
|
||
# Context from documents:
|
||
# {context}
|
||
|
||
# Current question: {question}
|
||
|
||
# Instructions:
|
||
# 1. When providing medical codes (ICD, CPT, etc.):
|
||
# - Always use the ICD codes listed in the sections titled "ICD Code Match" and "Related ICD Suggestions" from the context.
|
||
# - Do not use or invent ICD codes from your own training knowledge unless they appear in the provided context.
|
||
# - If multiple codes are relevant, return the one that best matches the user’s question. If unsure, return multiple options in HTML list format.
|
||
# - Remove all decimal points (e.g., use 'A150' instead of 'A15.0')
|
||
# - Format the response as: '<p>The medical code for [condition] is [code]</p>'
|
||
# 2. Address the current question while maintaining conversation continuity
|
||
# 3. Resolve any ambiguous references using conversation history
|
||
# 4. Format the response in clear HTML.
|
||
# 5. Strictly it should answer only from the {context} provided, do not invent or assume and give the information and it should not be general knowledge also. Purely 100% RAG-based response and only from documents.
|
||
|
||
# {html_instruction}
|
||
# {table_instruction if is_table_request(question) else ""}
|
||
# """
|
||
|
||
        response = await asyncio.to_thread(
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo-16k",
                messages=[
                    {"role": "system", "content": prompt_template},
                    {"role": "user", "content": question},
                ],
                temperature=0.2,
                max_tokens=1000,
            )
        )

        answer = ensure_html_response(response.choices[0].message.content)
        logging.info(f"Generated RAG answer for question: {question}")

        # Store interaction in history
        if conversation_manager:
            await conversation_manager.add_rag_interaction(
                user_id, hospital_id, question, answer, session_id
            )

        return {"answer": answer}, 200

    except Exception as e:
        logging.error(f"Error in generate_answer_with_rag: {e}")
        return {"answer": f"<p>Error: {str(e)}</p>"}, 500

async def load_existing_vector_stores():
    """Load existing Chroma vector stores for each hospital"""
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                await cursor.execute("SELECT DISTINCT id FROM hospitals")
                hospital_ids = [row[0] for row in await cursor.fetchall()]

                for hospital_id in hospital_ids:
                    try:
                        await initialize_or_load_vector_store(hospital_id)
                    except Exception as e:
                        logging.error(
                            f"Failed to load vector store for hospital {hospital_id}: {e}"
                        )
                        continue

            except Exception as e:
                logging.error(f"Error loading vector stores: {e}")

async def get_failed_page(doc_id):
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                await cursor.execute(
                    "SELECT failed_page FROM documents WHERE id = %s", (doc_id,)
                )
                result = await cursor.fetchone()
                return result[0] if result and result[0] else None
            except Exception as e:
                logging.error(f"Database error checking failed_page: {e}")
                return None

async def update_document_status(doc_id, status, failed_page=None):
    """Update document status with enum validation"""
    if isinstance(status, str):
        status = DocumentStatus[status.upper()].value

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                # Explicit None check so a falsy page value would still be recorded
                if failed_page is not None:
                    await cursor.execute(
                        "UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s",
                        (status, failed_page, doc_id),
                    )
                else:
                    await cursor.execute(
                        "UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s",
                        (status, doc_id),
                    )
                await conn.commit()
                return True
            except Exception as e:
                logging.error(f"Database update error: {e}")
                return False

thread_pool = ThreadPoolExecutor(max_workers=10)


def async_to_sync(coroutine):
    """Helper function to run async code in sync context"""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine)
    finally:
        loop.close()

# @app.route("/flask-api", methods=["GET"])
|
||
# def health_check():
|
||
# """Health check endpoint"""
|
||
# access_logger.info(f"Health check request received from {request.remote_addr}")
|
||
# return jsonify({"status": "ok"}), 200
|
||
|
||
|
||
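# Example request (hypothetical host and port; the run configuration is outside
# this section):
#   curl http://localhost:5000/flask-api
#   -> {"status": "ok"}
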
# @app.route("/flask-api/process-pdf", methods=["POST"])
|
||
# def process_pdf():
|
||
# access_logger.info(f"PDF processing request received from {request.remote_addr}")
|
||
# file_path = None
|
||
# try:
|
||
# file = request.files.get("pdf")
|
||
# hospital_id = request.form.get("hospital_id")
|
||
# doc_id = request.form.get("doc_id")
|
||
|
||
# logging.info(
|
||
# f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}"
|
||
# )
|
||
|
||
# if not all([file, hospital_id, doc_id]):
|
||
# return jsonify({"error": "Missing required parameters"}), 400
|
||
|
||
# def process_in_background():
|
||
# nonlocal file_path
|
||
# try:
|
||
# async_to_sync(update_document_status(doc_id, "processing"))
|
||
|
||
# # Add progress logging
|
||
# logging.info(f"Starting processing of document {doc_id}")
|
||
|
||
# filename = f"doc_{doc_id}_{file.filename}"
|
||
# file_path = os.path.join(uploads_dir, filename)
|
||
|
||
# with open(file_path, "wb") as f:
|
||
# file.save(f)
|
||
|
||
# logging.info("Extracting PDF contents...")
|
||
# content = extract_pdf_contents(file_path, int(hospital_id))
|
||
|
||
# logging.info("Inserting content into database...")
|
||
# metadata = {"filename": filename}
|
||
# result = async_to_sync(
|
||
# insert_content_into_db(content, metadata, doc_id)
|
||
# )
|
||
|
||
# if "error" in result:
|
||
# async_to_sync(update_document_status(doc_id, "failed", 1))
|
||
# return False
|
||
|
||
# logging.info("Creating embeddings and indexing...")
|
||
# success = async_to_sync(add_document_to_index(doc_id, hospital_id))
|
||
|
||
# if success:
|
||
# logging.info("Document processing completed successfully")
|
||
# async_to_sync(update_document_status(doc_id, "processed"))
|
||
# return True
|
||
# else:
|
||
# logging.error("Document processing failed during indexing")
|
||
# async_to_sync(update_document_status(doc_id, "failed"))
|
||
# return False
|
||
|
||
# except Exception as e:
|
||
# logging.error(f"Processing error: {e}")
|
||
# async_to_sync(update_document_status(doc_id, "failed"))
|
||
# return False
|
||
# finally:
|
||
# if file_path and os.path.exists(file_path):
|
||
# try:
|
||
# os.remove(file_path)
|
||
# except Exception as e:
|
||
# logging.error(f"Error removing temporary file: {e}")
|
||
|
||
# # Execute processing and wait for result
|
||
# future = thread_pool.submit(process_in_background)
|
||
# success = future.result()
|
||
|
||
# if success:
|
||
# return jsonify({"message": "Document processed successfully"}), 200
|
||
# else:
|
||
# return jsonify({"error": "Document processing failed"}), 500
|
||
|
||
# except Exception as e:
|
||
# logging.error(f"API error: {e}")
|
||
# if file_path and os.path.exists(file_path):
|
||
# try:
|
||
# os.remove(file_path)
|
||
# except Exception as file_e:
|
||
# logging.error(f"Error removing temporary file: {file_e}")
|
||
# return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
# # Initialize the hybrid conversation manager
|
||
# redis_client = get_redis_client()
|
||
# conversation_manager = HybridConversationManager(redis_client)
|
||
|
||
|
||
# @app.route("/flask-api/generate-answer", methods=["POST"])
|
||
# def rag_answer_api():
|
||
# """Sync API endpoint for RAG-based question answering with conversation history."""
|
||
# access_logger.info(f"Generate answer request received from {request.remote_addr}")
|
||
# try:
|
||
# data = request.json
|
||
# question = data.get("question", "").strip().lower()
|
||
# hospital_code = data.get("hospital_code")
|
||
# doc_id = data.get("doc_id")
|
||
# user_id = data.get("user_id", "default")
|
||
# session_id = data.get("session_id", None)
|
||
|
||
# logging.info(f"Received question from user {user_id}: {question}")
|
||
# logging.info(f"Received hospital code: {hospital_code}")
|
||
# logging.info(f"Received session_id: {session_id}")
|
||
|
||
# # is_confirmation_response = data.get("is_confirmation_response", False)
|
||
# original_query = data.get("original_query", "")
|
||
|
||
# def process_rag_answer():
|
||
# try:
|
||
# hospital_id = async_to_sync(get_hospital_id(hospital_code))
|
||
# logging.info(f"Resolved hospital ID: {hospital_id}")
|
||
|
||
# if not hospital_id:
|
||
# return {
|
||
# "error": "Invalid or missing 'hospital_code' in request"
|
||
# }, 400
|
||
|
||
# # if question == "yes" and original_query:
|
||
# # # User confirmed they want a general knowledge answer
|
||
# # answer = async_to_sync(
|
||
# # generate_general_knowledge_answer(
|
||
# # original_query,
|
||
# # client,
|
||
# # user_id,
|
||
# # hospital_id,
|
||
# # conversation_manager, # Pass the hybrid manager
|
||
# # is_table_request(original_query),
|
||
# # session_id=session_id,
|
||
# # )
|
||
# # )
|
||
# # return {"answer": answer}, 200
|
||
|
||
# if original_query:
|
||
# response_message = """
|
||
# <p>I can only answer questions based on information found in the hospital documents.</p>
|
||
# <p>The question you asked doesn't seem to be covered in the available documents.</p>
|
||
# <p>You can try rephrasing your question or asking about a different topic.</p>
|
||
# """
|
||
# return {"answer": response_message}, 200
|
||
|
||
# else:
|
||
# # Regular RAG answer
|
||
# return async_to_sync(
|
||
# generate_answer_with_rag(
|
||
# question=question,
|
||
# hospital_id=hospital_id,
|
||
# client=client,
|
||
# doc_id=doc_id,
|
||
# user_id=user_id,
|
||
# conversation_manager=conversation_manager, # Pass the hybrid manager
|
||
# session_id=session_id,
|
||
# )
|
||
# )
|
||
# except Exception as e:
|
||
# logging.error(f"Thread processing error: {str(e)}")
|
||
# return {"error": str(e)}, 500
|
||
|
||
# if not question:
|
||
# return jsonify({"error": "Missing 'question' in request"}), 400
|
||
|
||
# future = thread_pool.submit(process_rag_answer)
|
||
# result, status_code = future.result()
|
||
|
||
# return jsonify(result), status_code
|
||
|
||
# except Exception as e:
|
||
# logging.error(f"API error: {str(e)}")
|
||
# return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
# @app.route("/flask-api/delete-document-vectors", methods=["DELETE"])
|
||
# def delete_document_vectors_endpoint():
|
||
# """Endpoint to delete document vectors from ChromaDB"""
|
||
# try:
|
||
# data = request.json
|
||
# hospital_id = data.get("hospital_id")
|
||
# doc_id = data.get("doc_id")
|
||
|
||
# if not all([hospital_id, doc_id]):
|
||
# return jsonify({"error": "Missing required parameters"}), 400
|
||
|
||
# logging.info(
|
||
# f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}"
|
||
# )
|
||
|
||
# def process_deletion():
|
||
# try:
|
||
# success = async_to_sync(delete_document_vectors(hospital_id, doc_id))
|
||
# if success:
|
||
# return {"message": "Document vectors deleted successfully"}, 200
|
||
# else:
|
||
# return {"error": "Failed to delete document vectors"}, 500
|
||
# except Exception as e:
|
||
# logging.error(f"Error in vector deletion process: {e}")
|
||
# return {"error": str(e)}, 500
|
||
|
||
# future = thread_pool.submit(process_deletion)
|
||
# result, status_code = future.result()
|
||
|
||
# return jsonify(result), status_code
|
||
|
||
# except Exception as e:
|
||
# logging.error(f"API error: {str(e)}")
|
||
# return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
# @app.route("/flask-api/get-chroma-content", methods=["GET"])
|
||
# def get_chroma_content_endpoint():
|
||
# """API endpoint to get ChromaDB content by hospital_id"""
|
||
# try:
|
||
# hospital_id = request.args.get("hospital_id")
|
||
# limit = int(request.args.get("limit", 30000))
|
||
|
||
# if not hospital_id:
|
||
# return jsonify({"error": "Missing required parameter: hospital_id"}), 400
|
||
|
||
# def process_fetch():
|
||
# try:
|
||
# result, status_code = async_to_sync(
|
||
# get_chroma_content_by_hospital(
|
||
# hospital_id=int(hospital_id), limit=limit
|
||
# )
|
||
# )
|
||
# return result, status_code
|
||
# except Exception as e:
|
||
# logging.error(f"Error in ChromaDB fetch process: {e}")
|
||
# return {"error": str(e)}, 500
|
||
|
||
# future = thread_pool.submit(process_fetch)
|
||
# result, status_code = future.result()
|
||
|
||
# return jsonify(result), status_code
|
||
|
||
# except Exception as e:
|
||
# logging.error(f"API error: {str(e)}")
|
||
# return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
# async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100):
|
||
# """Fetch content from ChromaDB for a specific hospital"""
|
||
# try:
|
||
# # Initialize vector store
|
||
# vector_store = await initialize_or_load_vector_store(hospital_id)
|
||
# if not vector_store:
|
||
# return {"error": "Vector store not found"}, 404
|
||
|
||
# # Get collection
|
||
# collection = vector_store._collection
|
||
|
||
# # Query the collection with hospital_id filter
|
||
# results = await asyncio.to_thread(
|
||
# lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit)
|
||
# )
|
||
|
||
# if not results or not results["ids"]:
|
||
# return {"data": [], "count": 0}, 200
|
||
|
||
# # Format the response
|
||
# formatted_results = []
|
||
# for i in range(len(results["ids"])):
|
||
# formatted_results.append(
|
||
# {
|
||
# "id": results["ids"][i],
|
||
# "content": results["documents"][i],
|
||
# "metadata": results["metadatas"][i],
|
||
# }
|
||
# )
|
||
|
||
# return {"data": formatted_results, "count": len(formatted_results)}, 200
|
||
|
||
# except Exception as e:
|
||
# logging.error(f"Error fetching ChromaDB content: {e}")
|
||
# return {"error": str(e)}, 500
|
||
|
||
|
||
# @app.before_request
|
||
# def before_request():
|
||
# request._start_time = time.time()
|
||
|
||
|
||
# @app.after_request
|
||
# def after_request(response):
|
||
# if hasattr(request, "_start_time"):
|
||
# duration = time.time() - request._start_time
|
||
# access_logger.info(
|
||
# f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - '
|
||
# f"IP: {request.remote_addr}"
|
||
# )
|
||
# return response
|
||
|
||
|
||
# if __name__ == "__main__":
|
||
# logger.info("Starting SpurrinAI application")
|
||
# logger.info(f"Python version: {sys.version}")
|
||
# logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}")
|
||
|
||
# try:
|
||
# model_manager = ModelManager()
|
||
# logger.info("Model manager initialized successfully")
|
||
# except Exception as e:
|
||
# logger.error(f"Failed to initialize model manager: {e}")
|
||
# sys.exit(1)
|
||
|
||
# # Initialize directories
|
||
# os.makedirs(DATA_DIR, exist_ok=True)
|
||
# os.makedirs(CHROMA_DIR, exist_ok=True)
|
||
# logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}")
|
||
|
||
# # Clear Redis cache
|
||
# redis_client = get_redis_client()
|
||
# cleared_keys = 0
|
||
# for key in redis_client.scan_iter("vector_store_data:*"):
|
||
# redis_client.delete(key)
|
||
# cleared_keys += 1
|
||
# logger.info(f"Cleared {cleared_keys} Redis cache keys")
|
||
|
||
# # Load vector stores
|
||
# logger.info("Loading existing vector stores...")
|
||
# async_to_sync(load_existing_vector_stores())
|
||
# logger.info("Vector stores loaded successfully")
|
||
|
||
# # Start application
|
||
# logger.info("Starting Flask application on port 5000")
|
||
# app.run(port=5000, debug=False)
|
||
|
||
"""
|
||
SpurrinAI - Intelligent Document Processing and Question Answering System
|
||
Copyright (c) 2024 Tech4biz. All rights reserved.
|
||
|
||
This module implements the main Flask application for the SpurrinAI system,
|
||
providing REST APIs for document processing, vector storage, and question answering
|
||
using RAG (Retrieval Augmented Generation) architecture.
|
||
|
||
Author: Tech4biz Development Team
|
||
Version: 1.0.0
|
||
Last Updated: 2024-01-19
|
||
"""
|
||
|
||
# Standard library imports
|
||
import os
|
||
import re
|
||
import sys
|
||
import json
|
||
import time
|
||
import threading
|
||
import asyncio
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from dataclasses import dataclass
|
||
from datetime import timedelta
|
||
from enum import Enum
|
||
|
||
# Third-party imports
|
||
import spacy
|
||
import redis
|
||
import aiomysql
|
||
from dotenv import load_dotenv
|
||
from flask import Flask, request, jsonify, Response
|
||
from flask_cors import CORS
|
||
from tqdm import tqdm
|
||
from tqdm.asyncio import tqdm as tqdm_async
|
||
from langchain_community.document_loaders import PyPDFLoader
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||
from langchain_community.embeddings import OpenAIEmbeddings
|
||
from langchain_community.vectorstores import Chroma
|
||
from langchain_community.chat_models import ChatOpenAI
|
||
from langchain.chains import RetrievalQA
|
||
from langchain.prompts import PromptTemplate
|
||
from openai import OpenAI
|
||
from rapidfuzz import process
|
||
from threading import Lock
|
||
|
||
# Local imports
|
||
from model_manager import ModelManager
|
||
|
||
# Suppress warnings
|
||
import warnings
|
||
|
||
warnings.filterwarnings("ignore")
|
||
|
||
# Initialize NLTK
|
||
import nltk
|
||
|
||
nltk.download("punkt")
|
||
|
||
# Configure logging
|
||
import logging
|
||
import logging.handlers
|
||
|
||
app = Flask(__name__)
|
||
CORS(app)
|
||
|
||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||
log_file_path = os.path.join(script_dir, "error.log")
|
||
logging.basicConfig(filename=log_file_path, level=logging.INFO)
|
||
|
||
|
||
# Configure logging
|
||
def setup_logging():
|
||
log_dir = os.path.join(script_dir, "logs")
|
||
os.makedirs(log_dir, exist_ok=True)
|
||
|
||
main_log = os.path.join(log_dir, "app.log")
|
||
error_log = os.path.join(log_dir, "error.log")
|
||
access_log = os.path.join(log_dir, "access.log")
|
||
perf_log = os.path.join(log_dir, "performance.log")
|
||
|
||
# Create formatters
|
||
detailed_formatter = logging.Formatter(
|
||
"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
|
||
)
|
||
access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||
|
||
# Main logger setup
|
||
main_handler = logging.handlers.RotatingFileHandler(
|
||
main_log, maxBytes=10485760, backupCount=5
|
||
)
|
||
main_handler.setFormatter(detailed_formatter)
|
||
main_handler.setLevel(logging.INFO)
|
||
|
||
# Error logger setup
|
||
error_handler = logging.handlers.RotatingFileHandler(
|
||
error_log, maxBytes=10485760, backupCount=5
|
||
)
|
||
error_handler.setFormatter(detailed_formatter)
|
||
error_handler.setLevel(logging.ERROR)
|
||
|
||
# Access logger setup
|
||
access_handler = logging.handlers.TimedRotatingFileHandler(
|
||
access_log, when="midnight", interval=1, backupCount=30
|
||
)
|
||
access_handler.setFormatter(access_formatter)
|
||
access_handler.setLevel(logging.INFO)
|
||
|
||
# Performance logger setup
|
||
perf_handler = logging.handlers.RotatingFileHandler(
|
||
perf_log, maxBytes=10485760, backupCount=5
|
||
)
|
||
perf_handler.setFormatter(detailed_formatter)
|
||
perf_handler.setLevel(logging.INFO)
|
||
|
||
# Configure root logger
|
||
root_logger = logging.getLogger()
|
||
root_logger.setLevel(logging.INFO)
|
||
root_logger.addHandler(main_handler)
|
||
root_logger.addHandler(error_handler)
|
||
|
||
# Create specific loggers
|
||
access_logger = logging.getLogger("access")
|
||
access_logger.addHandler(access_handler)
|
||
access_logger.setLevel(logging.INFO)
|
||
|
||
perf_logger = logging.getLogger("performance")
|
||
perf_logger.addHandler(perf_handler)
|
||
perf_logger.setLevel(logging.INFO)
|
||
|
||
return root_logger, access_logger, perf_logger
|
||
|
||
|
||
# Initialize loggers
|
||
logger, access_logger, perf_logger = setup_logging()
|
||

# Redis Configuration
REDIS_CONFIG = {
    "host": "localhost",
    "port": 6379,
    "db": 0,
    "decode_responses": True,  # For string operations
}

DB_CONFIG = {
    "host": os.getenv("DB_HOST", "localhost"),
    "user": os.getenv("DB_USER", "testuser"),
    "password": os.getenv("DB_PASSWORD", "Admin@123"),
    "database": os.getenv("DB_NAME", "spurrintest"),
}
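
# The DB settings above read from the environment via load_dotenv(). A minimal
# sketch of the .env entries this module expects (values are illustrative
# placeholders, not real credentials):
#
#     DB_HOST=localhost
#     DB_USER=testuser
#     DB_PASSWORD=<your-password>
#     DB_NAME=spurrintest
#     OPENAI_API_KEY=sk-...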

# Redis connection pool
redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
redis_binary_pool = redis.ConnectionPool(
    host="localhost", port=6379, db=1, decode_responses=False
)


def get_redis_client(binary=False):
    """Get Redis client from pool"""
    logger.debug(f"Getting Redis client with binary={binary}")
    try:
        pool = redis_binary_pool if binary else redis_pool
        client = redis.Redis(connection_pool=pool)
        logger.debug("Redis client created successfully")
        return client
    except Exception as e:
        logger.error(f"Failed to create Redis client: {e}", exc_info=True)
        raise


def fetch_cached_answer(cache_key):
    logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
    start_time = time.time()
    try:
        redis_client = get_redis_client()
        cached_answer = redis_client.get(cache_key)
        fetch_time = time.time() - start_time
        perf_logger.info(
            f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
        )
        return cached_answer
    except Exception as e:
        logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
        return None


# Cache TTL configurations
CACHE_TTL = {
    "vector_store": timedelta(hours=24),
    "chat_completion": timedelta(hours=1),
    "document_metadata": timedelta(days=7),
}
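
# CACHE_TTL stores timedelta objects; call sites convert them to integer
# seconds before calling setex, e.g. (illustrative, mirroring
# get_relevant_context below):
#
#     ttl_seconds = int(CACHE_TTL["vector_store"].total_seconds())  # 86400
#     redis_client.setex(cache_key, ttl_seconds, value)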

DATA_DIR = os.path.join(script_dir, "hospital_data")
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")
uploads_dir = os.path.join(script_dir, "llm-uploads")

if not os.path.exists(uploads_dir):
    os.makedirs(uploads_dir)

nlp = spacy.load("en_core_web_sm")

load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

client = OpenAI(api_key=OPENAI_API_KEY)
embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
)
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
hospital_vector_stores = {}
vector_store_lock = threading.Lock()
# Module-level lock so concurrent writers to the per-hospital ICD JSON file
# actually serialize (a Lock created inside a function guards nothing).
icd_json_lock = threading.Lock()


@dataclass
class Document:
    doc_id: int
    page_num: int
    content: str


class DocumentStatus(Enum):
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"
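
# Status strings from callers are normalized through the enum, e.g.
# (illustrative): DocumentStatus["processing".upper()].value == "processing".
# An unknown status such as "done" raises KeyError, which is the intended
# validation in update_document_status() below.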

async def get_db_pool():
    return await aiomysql.create_pool(
        host=DB_CONFIG["host"],
        user=DB_CONFIG["user"],
        password=DB_CONFIG["password"],
        db=DB_CONFIG["database"],
        autocommit=True,
    )


async def get_hospital_id(hospital_code):
    pool = None  # defined up front so the finally block is safe if pool creation fails
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(
                    "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
                    (hospital_code,),
                )
                result = await cursor.fetchone()
                return result["id"] if result else None
    except Exception as error:
        logging.error(f"Database error: {error}")
        return None
    finally:
        if pool is not None:
            pool.close()
            await pool.wait_closed()


CHUNK_SIZE = 1000
CHUNK_OVERLAP = 50
BATCH_SIZE = 1000

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    # length_function=len,
    # separators=["\n\n", "\n", ". ", " ", ""]
)
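
# Rough sketch of the splitter's behavior (sizes are approximate, since the
# splitter prefers natural boundaries): a 2,500-character page yields about
# three chunks of up to 1,000 characters, with roughly 50 characters repeated
# between consecutive chunks:
#
#     chunks = text_splitter.split_text(page_text)
#     # len(chunks[0]) <= 1000; chunks[1] begins with the tail of chunks[0]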


# Update the JSON_PATH to be dynamic based on hospital_id
def get_icd_json_path(hospital_id):
    hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
    os.makedirs(hospital_data_dir, exist_ok=True)
    return os.path.join(hospital_data_dir, "icd_data.json")
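
# Illustrative result: get_icd_json_path(12) creates (if needed) and returns
# "<DATA_DIR>/hospital_12/icd_data.json".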


def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
    """Extract and process ICD codes with optimized processing and optional JSON saving"""
    try:
        # Compile the code pattern once
        pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)

        icd_data = []
        current_code = None
        current_description = []

        # Iterate over lines directly. (Slicing the raw string into fixed-size
        # chunks, as before, could split a line across a chunk boundary and
        # corrupt codes or descriptions.)
        for line in content.splitlines():
            line = line.strip()
            if not line:
                if current_code and current_description:
                    icd_data.append(
                        {
                            "code": current_code,
                            "description": " ".join(current_description).strip(),
                        }
                    )
                current_code = None
                current_description = []
                continue

            match = pattern.match(line)
            if match:
                if current_code and current_description:
                    icd_data.append(
                        {
                            "code": current_code,
                            "description": " ".join(current_description).strip(),
                        }
                    )
                current_code, description = match.groups()
                current_description = [description.strip()]
            elif current_code:
                current_description.append(line)

        # Add the final entry if one is still open
        if current_code and current_description:
            icd_data.append(
                {
                    "code": current_code,
                    "description": " ".join(current_description).strip(),
                }
            )

        # Save to hospital-specific JSON if requested
        if save_to_json and icd_data:
            try:
                json_path = get_icd_json_path(hospital_id)

                # Use the module-level lock for thread safety
                with icd_json_lock:
                    if os.path.exists(json_path):
                        with open(json_path, "r", encoding="utf-8") as f:
                            try:
                                existing_data = json.load(f)
                            except json.JSONDecodeError:
                                existing_data = []
                    else:
                        existing_data = []

                    # Efficient deduplication using a dictionary keyed by code
                    seen_codes = {item["code"]: item for item in existing_data}
                    for item in icd_data:
                        seen_codes[item["code"]] = item

                    unique_data = list(seen_codes.values())

                    # Write atomically using a temporary file
                    temp_path = f"{json_path}.tmp"
                    with open(temp_path, "w", encoding="utf-8") as f:
                        json.dump(unique_data, f, indent=2, ensure_ascii=False)
                    os.replace(temp_path, json_path)

                logging.info(
                    f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
                )

            except Exception as e:
                logging.error(
                    f"Error saving ICD data to JSON for hospital {hospital_id}: {e}"
                )

        return icd_data

    except Exception as e:
        logging.error(f"Error in extract_and_process_icd_data: {e}")
        return []
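
# Illustrative sketch of the extractor on a hypothetical two-code snippet
# (the sample text below is made up):
#
#     sample = "A150 Tuberculosis of lung\n    confirmed by culture\n\nB20 HIV disease"
#     extract_and_process_icd_data(sample, hospital_id=1, save_to_json=False)
#     # -> [{"code": "A150", "description": "Tuberculosis of lung confirmed by culture"},
#     #     {"code": "B20", "description": "HIV disease"}]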


def load_icd_entries(hospital_id):
    """Load ICD entries from hospital-specific JSON file"""
    json_path = get_icd_json_path(hospital_id)
    try:
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return []
    except Exception as e:
        logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
        return []


# Update the process_icd_codes function to include hospital_id
async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
    """Process and store ICD codes using the optimized extraction function.

    Note: doc_id and batch_size are currently unused; they are kept for
    call-site compatibility.
    """
    try:
        # Extract and save codes with hospital_id
        extract_and_process_icd_data(content, hospital_id, save_to_json=True)
    except Exception as e:
        logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")


async def initialize_icd_vector_store(hospital_id):
    """This function is deprecated. ICD codes are now handled through JSON search."""
    logging.warning(
        "initialize_icd_vector_store is deprecated - using JSON search instead"
    )
    return None


def extract_pdf_contents(pdf_path, hospital_id):
    """Extract PDF contents with optimized chunking and code extraction"""
    try:
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()
        pages_content = []

        for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
            text = page.page_content.strip()

            # Extract ICD codes from the page (doc_id is attached later,
            # during indexing)
            icd_codes = extract_and_process_icd_data(text, hospital_id)

            pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})

        return pages_content

    except Exception as e:
        logging.error(f"Error in extract_pdf_contents: {e}")
        raise


async def insert_content_into_db(content, metadata, doc_id):
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
                content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"

                metadata_values = [
                    (doc_id, key[:100], value)
                    for key, value in metadata.items()
                    if value
                ]
                content_values = [
                    (doc_id, page_content["page"], page_content["text"])
                    for page_content in content
                ]

                if metadata_values:
                    await cursor.executemany(metadata_query, metadata_values)
                if content_values:
                    await cursor.executemany(content_query, content_values)

                await conn.commit()
                return {"message": "Success"}
            except Exception as e:
                await conn.rollback()
                return {"error": str(e)}


async def initialize_or_load_vector_store(hospital_id, user_id="default"):
    """Initialize or load vector store with Redis caching and thread safety"""
    store_key = f"{hospital_id}:{user_id}"

    try:
        # Check if we already have it loaded - with lock for thread safety
        with vector_store_lock:
            if store_key in hospital_vector_stores:
                return hospital_vector_stores[store_key]

        # Initialize vector store
        redis_client = get_redis_client(binary=True)
        cache_key = f"vector_store_data:{hospital_id}:{user_id}"
        hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")

        if os.path.exists(hospital_dir):
            logging.info(
                f"Loading vector store for hospital {hospital_id} and user {user_id}"
            )
            vector_store = await asyncio.to_thread(
                lambda: Chroma(
                    collection_name=f"hospital_{hospital_id}",
                    persist_directory=hospital_dir,
                    embedding_function=embeddings,
                )
            )
        else:
            logging.info(f"Creating vector store for hospital {hospital_id}")
            os.makedirs(hospital_dir, exist_ok=True)
            vector_store = await asyncio.to_thread(
                lambda: Chroma(
                    collection_name=f"hospital_{hospital_id}",
                    persist_directory=hospital_dir,
                    embedding_function=embeddings,
                )
            )

        # Store with lock for thread safety
        with vector_store_lock:
            hospital_vector_stores[store_key] = vector_store

        return vector_store
    except Exception as e:
        logging.error(f"Error initializing vector store: {e}", exc_info=True)
        raise


async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
    """Delete all vectors associated with a specific document from ChromaDB"""
    try:
        # Initialize vector store for the hospital
        vector_store = await initialize_or_load_vector_store(hospital_id)

        # Delete vectors with matching doc_id
        await asyncio.to_thread(
            lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
        )

        # Persist changes
        await asyncio.to_thread(vector_store.persist)

        # Clear Redis cache for this document
        redis_client = get_redis_client()
        pattern = f"vector_store_data:{hospital_id}:*"
        for key in redis_client.scan_iter(pattern):
            redis_client.delete(key)

        logging.info(
            f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
        )
        return True

    except Exception as e:
        logging.error(f"Error deleting document vectors: {e}", exc_info=True)
        return False


async def add_document_to_index(doc_id, hospital_id):
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                vector_store = await initialize_or_load_vector_store(hospital_id)

                await cursor.execute(
                    "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
                    (doc_id,),
                )
                rows = await cursor.fetchall()

                total_pages = len(rows)
                logging.info(f"Processing {total_pages} pages for document {doc_id}")
                page_bar = tqdm_async(total=total_pages, desc="Processing pages")

                async def process_page(page_data):
                    page_num, content = page_data
                    try:
                        # Save incrementally; deduplication in the helper keeps the
                        # per-hospital JSON unique. (Passing save_to_json=False here
                        # and later "saving" an empty string was a no-op, so codes
                        # were never persisted.)
                        icd_data = extract_and_process_icd_data(
                            content, hospital_id, save_to_json=True
                        )
                        chunks = text_splitter.split_text(content)
                        await asyncio.sleep(0)  # Yield control
                        return page_num, chunks, icd_data
                    except Exception as e:
                        logging.error(f"Error processing page {page_num}: {e}")
                        return page_num, [], []

                tasks = [asyncio.create_task(process_page(row)) for row in rows]
                results = []

                for coro in asyncio.as_completed(tasks):
                    result = await coro
                    results.append(result)
                    page_bar.update(1)

                page_bar.close()

                # Vector addition progress bar
                all_icd_data = []
                all_chunks = []
                all_metadatas = []

                chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)

                for result in results:
                    page_num, chunks, icd_data = result
                    all_icd_data.extend(icd_data)

                    for i, chunk in enumerate(chunks):
                        all_chunks.append(chunk)
                        all_metadatas.append(
                            {
                                "doc_id": str(doc_id),
                                "hospital_id": str(hospital_id),
                                "page_number": str(page_num),
                                "chunk_index": str(i),
                            }
                        )

                    if len(all_chunks) >= BATCH_SIZE:
                        batch_len = len(all_chunks)  # may exceed BATCH_SIZE
                        chunk_add_bar.total += batch_len
                        chunk_add_bar.refresh()
                        await asyncio.to_thread(
                            vector_store.add_texts,
                            texts=all_chunks,
                            metadatas=all_metadatas,
                        )
                        all_chunks = []
                        all_metadatas = []
                        chunk_add_bar.update(batch_len)

                # Final batch
                if all_chunks:
                    chunk_add_bar.total += len(all_chunks)
                    chunk_add_bar.refresh()
                    await asyncio.to_thread(
                        vector_store.add_texts,
                        texts=all_chunks,
                        metadatas=all_metadatas,
                    )
                    chunk_add_bar.update(len(all_chunks))

                chunk_add_bar.close()

                if all_icd_data:
                    logging.info(
                        f"Collected {len(all_icd_data)} ICD codes for document {doc_id}"
                    )

                await asyncio.to_thread(vector_store.persist)
                logging.info(f"Successfully indexed document {doc_id}")
                return True

    except Exception as e:
        logging.error(f"Error adding document: {e}")
        return False


def is_general_knowledge_question(
    query: str, context: str, conversation_context=None
) -> bool:
    """
    Determine if a question is likely a general knowledge question not covered in the documents.
    Takes conversation history into account to reduce repeated confirmations.
    """
    query_lower = query.lower()
    context_lower = context.lower()

    if conversation_context:
        for interaction in conversation_context:
            prev_question = interaction.get("question", "").lower()
            # Parenthesized so an empty prev_question can never match
            # ("" in query_lower is always True).
            if prev_question and (
                query_lower in prev_question or prev_question in query_lower
            ):
                logging.info(
                    "Question is similar to previous conversation, skipping confirmation"
                )
                return False

    stop_words = {
        "search", "query:", "can", "you", "some", "at", "the", "a", "an", "in",
        "on", "to", "for", "with", "by", "about", "give", "full", "is", "are",
        "was", "were", "define", "what", "how", "why", "when", "where", "year",
        "list", "form", "table", "who", "which", "me", "tell", "explain",
        "describe", "of", "and", "or", "there", "their", "please", "could",
        "would", "various", "different", "type", "types", "kind", "kinds",
        "has", "have", "had", "many", "say", "know",
    }

    key_words = [
        word for word in query_lower.split() if word not in stop_words and len(word) > 2
    ]
    logging.info(f"Key words: {key_words}")

    if not key_words:
        logging.info("No significant keywords found, directing to general knowledge")
        return True

    matches = sum(1 for word in key_words if word in context_lower)
    logging.info(f"Matches: {matches} out of {len(key_words)} keywords")

    match_ratio = matches / len(key_words)
    logging.info(f"Match ratio: {match_ratio}")

    return match_ratio < 0.4


def is_table_request(query: str) -> bool:
    """
    Determine if the user is requesting a response in tabular format.
    """
    table_keywords = [
        "table", "tabular", "in a table", "in table format", "in tabular format",
        "chart", "data", "comparison", "as a table", "table format",
        "in rows and columns", "in a grid", "breakdown", "spreadsheet",
        "comparison table", "data table", "structured table", "tabular form",
        "table form",
    ]

    query_lower = query.lower()
    return any(keyword in query_lower for keyword in table_keywords)
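
# Doctest-style sketch (illustrative):
#
#     >>> is_table_request("Show the ward charges in a table")
#     True
#     >>> is_table_request("What are the visiting hours?")
#     False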


def ensure_html_response(text: str) -> str:
    """
    Ensure the response is properly formatted in HTML.
    This function handles plain text conversion to HTML.
    """
    if "<html" in text.lower() or "<body" in text.lower():
        return text

    has_html_tags = bool(re.search(r"<[a-z]+.*?>", text))

    if not has_html_tags:
        paragraphs = text.split("\n\n")
        html_parts = []
        in_ordered_list = False
        in_unordered_list = False

        for para in paragraphs:
            if para.strip():
                if re.match(r"^\s*[\*\-\•]\s", para):
                    if in_ordered_list:  # Close an open ordered list before switching
                        html_parts.append("</ol>")
                        in_ordered_list = False
                    if not in_unordered_list:
                        html_parts.append("<ul>")
                        in_unordered_list = True

                    lines = para.split("\n")
                    for line in lines:
                        if line.strip():
                            item = re.sub(r"^\s*[\*\-\•]\s*", "", line)
                            html_parts.append(f"<li>{item}</li>")

                elif re.match(r"^\s*\d+\.\s", para):
                    if in_unordered_list:  # Close an open unordered list before switching
                        html_parts.append("</ul>")
                        in_unordered_list = False
                    if not in_ordered_list:
                        html_parts.append("<ol>")
                        in_ordered_list = True

                    lines = para.split("\n")
                    for line in lines:
                        match = re.match(r"^\s*\d+\.\s*(.*)", line)
                        if match:
                            html_parts.append(f"<li>{match.group(1)}</li>")

                else:  # Close any open lists before adding a new paragraph
                    if in_ordered_list:
                        html_parts.append("</ol>")
                        in_ordered_list = False
                    if in_unordered_list:
                        html_parts.append("</ul>")
                        in_unordered_list = False

                    html_parts.append(f"<p>{para}</p>")

        if in_ordered_list:
            html_parts.append("</ol>")
        if in_unordered_list:
            html_parts.append("</ul>")

        return "".join(html_parts)

    else:
        if not any(tag in text for tag in ("<p>", "<div>", "<ul>", "<ol>")):
            paragraphs = text.split("\n\n")
            html_parts = [f"<p>{para}</p>" for para in paragraphs if para.strip()]
            return "".join(html_parts)

    return text
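
# Illustrative sketch of the conversion:
#
#     ensure_html_response("Intro line.\n\n* first\n* second")
#     # -> "<p>Intro line.</p><ul><li>first</li><li>second</li></ul>"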


class HybridConversationManager:
    """
    Hybrid conversation manager that uses Redis for RAG-based conversations
    and in-memory storage for general knowledge conversations.
    """

    def __init__(self, redis_client, ttl=3600, max_history_items=5):
        self.redis_client = redis_client
        self.ttl = ttl
        self.max_history_items = max_history_items

        # For general knowledge questions (in-memory)
        self.general_knowledge_histories = {}
        self.lock = Lock()

    def _get_redis_key(self, user_id, hospital_id, session_id=None):
        """Create Redis key for document-based conversations."""
        if session_id:
            return f"conv_history:{user_id}:{hospital_id}:{session_id}"
        return f"conv_history:{user_id}:{hospital_id}"

    def _get_memory_key(self, user_id, hospital_id, session_id=None):
        """Create memory key for general knowledge conversations."""
        if session_id:
            return f"{user_id}:{hospital_id}:{session_id}"
        return f"{user_id}:{hospital_id}"

    async def add_rag_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Add document-based (RAG) interaction to Redis."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        history = self.get_rag_history(user_id, hospital_id, session_id)

        # Add new interaction
        history.append(
            {
                "question": question,
                "answer": answer,
                "timestamp": time.time(),
                "type": "rag",  # Mark as RAG-based interaction
            }
        )

        # Keep only last N interactions
        history = history[-self.max_history_items :]

        # Store updated history
        try:
            self.redis_client.setex(key, self.ttl, json.dumps(history))
            logging.info(
                f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}"
            )
        except Exception as e:
            logging.error(f"Failed to store RAG interaction in Redis: {e}")

    def add_general_knowledge_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Add general knowledge interaction to in-memory store."""
        key = self._get_memory_key(user_id, hospital_id, session_id)

        with self.lock:
            if key not in self.general_knowledge_histories:
                self.general_knowledge_histories[key] = []

            self.general_knowledge_histories[key].append(
                {
                    "question": question,
                    "answer": answer,
                    "timestamp": time.time(),
                    "type": "general",  # Mark as general knowledge interaction
                }
            )

            # Keep only the most recent interactions
            if len(self.general_knowledge_histories[key]) > self.max_history_items:
                self.general_knowledge_histories[key] = (
                    self.general_knowledge_histories[key][-self.max_history_items :]
                )

            logging.info(
                f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}"
            )

    def get_rag_history(self, user_id, hospital_id, session_id=None):
        """Get document-based (RAG) conversation history from Redis."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        try:
            history_data = self.redis_client.get(key)
            return json.loads(history_data) if history_data else []
        except Exception as e:
            logging.error(f"Failed to retrieve RAG history from Redis: {e}")
            return []

    def get_general_knowledge_history(self, user_id, hospital_id, session_id=None):
        """Get general knowledge conversation history from memory."""
        key = self._get_memory_key(user_id, hospital_id, session_id)

        with self.lock:
            return self.general_knowledge_histories.get(key, []).copy()

    def get_combined_history(self, user_id, hospital_id, session_id=None):
        """Get combined conversation history from both sources, sorted by timestamp."""
        rag_history = self.get_rag_history(user_id, hospital_id, session_id)
        general_history = self.get_general_knowledge_history(
            user_id, hospital_id, session_id
        )

        # Combine histories
        combined_history = rag_history + general_history

        # Sort by timestamp (newest first)
        combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True)

        # Return most recent N items
        return combined_history[: self.max_history_items]

    def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2):
        """Get the most recent interactions for context from combined history."""
        combined_history = self.get_combined_history(user_id, hospital_id, session_id)
        # Sort by timestamp (oldest first) for context window
        sorted_history = sorted(combined_history, key=lambda x: x.get("timestamp", 0))
        return sorted_history[-window_size:] if sorted_history else []

    def clear_history(self, user_id, hospital_id):
        """Clear conversation history from both stores."""
        # Clear Redis history
        redis_key = self._get_redis_key(user_id, hospital_id)
        try:
            self.redis_client.delete(redis_key)
        except Exception as e:
            logging.error(f"Failed to clear Redis history: {e}")

        # Clear memory history
        memory_key = self._get_memory_key(user_id, hospital_id)
        with self.lock:
            if memory_key in self.general_knowledge_histories:
                del self.general_knowledge_histories[memory_key]
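
# Minimal usage sketch (IDs are hypothetical). RAG turns persist in Redis with
# a TTL; general-knowledge turns stay in process memory. Note that
# add_rag_interaction is a coroutine, while the general-knowledge variant is
# plain sync:
#
#     mgr = HybridConversationManager(get_redis_client())
#     mgr.add_general_knowledge_interaction(
#         "user-1", 12, "What is an ICD code?", "<p>...</p>", session_id="s-1"
#     )
#     recent = mgr.get_context_window("user-1", 12, session_id="s-1", window_size=2)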


class ContextMapper:
    """Enhanced context mapping using shared model manager"""

    def __init__(self):
        self.model_manager = ModelManager()
        self.context_cache = {}
        self.similarity_threshold = 0.6

    def get_semantic_similarity(self, text1, text2):
        """Get semantic similarity using global model manager"""
        return self.model_manager.get_semantic_similarity(text1, text2)

    def extract_key_concepts(self, text):
        """Extract key concepts using NLP techniques"""
        doc = nlp(text)
        concepts = []

        entities = [(ent.text, ent.label_) for ent in doc.ents]
        noun_phrases = [chunk.text for chunk in doc.noun_chunks]
        important_words = [
            token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"]
        ]

        concepts.extend([e[0] for e in entities])
        concepts.extend(noun_phrases)
        concepts.extend(important_words)

        return list(set(concepts))

    def map_conversation_context(
        self, current_query, conversation_history, context_window=3
    ):
        """Map conversation context using enhanced NLP techniques"""
        if not conversation_history:
            return current_query

        recent_context = conversation_history[-context_window:]
        context_concepts = []

        # Extract concepts from recent conversations
        for interaction in recent_context:
            q_concepts = self.extract_key_concepts(interaction["question"])
            a_concepts = self.extract_key_concepts(interaction["answer"])
            context_concepts.extend(q_concepts)
            context_concepts.extend(a_concepts)

        # Extract concepts from current query
        query_concepts = self.extract_key_concepts(current_query)

        # Find related concepts
        related_concepts = []
        for q_concept in query_concepts:
            for c_concept in context_concepts:
                similarity = self.get_semantic_similarity(q_concept, c_concept)
                if similarity > self.similarity_threshold:
                    related_concepts.append(c_concept)

        # Build enhanced query
        if related_concepts:
            enhanced_query = (
                f"{current_query} in context of {', '.join(related_concepts)}"
            )
        else:
            enhanced_query = current_query

        return enhanced_query


# Initialize the context mapper
context_mapper = ContextMapper()


async def generate_contextual_query(
    question: str, user_id: str, hospital_id: int, conversation_manager
) -> str:
    """Generate enhanced contextual query"""
    context_window = conversation_manager.get_context_window(user_id, hospital_id)

    if not context_window:
        return question

    # Enhanced context mapping
    last_interaction = context_window[-1]
    enhanced_context = f"""
    Previous question: {last_interaction['question']}
    Previous answer: {last_interaction['answer']}
    Current question: {question}

    Please generate a detailed search query that combines the context from the previous answer
    with the current question, especially if the current question uses words like 'it', 'this',
    'that', or asks for more details about the previous topic.
    """

    try:
        response = await asyncio.to_thread(
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a context-aware query generator.",
                    },
                    {"role": "user", "content": enhanced_context},
                ],
                temperature=0.3,
                max_tokens=150,
            )
        )
        contextual_query = response.choices[0].message.content.strip()
        logging.info(f"Enhanced contextual query: {contextual_query}")
        return contextual_query
    except Exception as e:
        logging.error(f"Error generating contextual query: {e}")
        return question


def is_follow_up(current_question: str, conversation_history: list) -> bool:
    """Enhanced follow-up detection using NLP techniques"""
    if not conversation_history:
        return False

    last_interaction = conversation_history[-1]

    # Get semantic similarity with higher threshold
    similarity = context_mapper.get_semantic_similarity(
        current_question, f"{last_interaction['question']} {last_interaction['answer']}"
    )

    # Enhanced referential check
    doc = nlp(current_question.lower())
    has_referential = any(
        token.lemma_
        in ["it", "this", "that", "these", "those", "they", "he", "she", "about", "more"]
        for token in doc
    )

    # Extract concepts with improved entity detection
    current_concepts = set(context_mapper.extract_key_concepts(current_question))
    last_concepts = set(
        context_mapper.extract_key_concepts(
            f"{last_interaction['question']} {last_interaction['answer']}"
        )
    )

    # Calculate enhanced concept overlap
    concept_overlap = (
        len(current_concepts & last_concepts) / len(current_concepts | last_concepts)
        if current_concepts
        else 0
    )

    # More aggressive follow-up detection
    return (
        similarity > 0.3  # Lowered threshold
        or has_referential
        or concept_overlap > 0.2  # Lowered threshold
        or any(
            word in current_question.lower()
            for word in ["more", "about", "elaborate", "explain"]
        )
    )
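
# Illustrative sketch: given a prior turn about visiting hours, a short
# referential follow-up such as "tell me more about that" returns True via the
# referential-word check, while an unrelated question usually falls below the
# similarity and overlap thresholds:
#
#     history = [{"question": "What are the visiting hours?",
#                 "answer": "Visiting hours are 9am-5pm."}]
#     is_follow_up("tell me more about that", history)   # True
#     is_follow_up("list the ICU tariff", history)       # likely False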


async def get_relevant_context(question, hospital_id, doc_id=None):
    try:
        cache_key = f"context:hospital_{hospital_id}"
        if doc_id:
            cache_key += f":doc_{doc_id}"
        cache_key += f":{question.lower().strip()}"

        redis_client = get_redis_client()

        cached_context = redis_client.get(cache_key)
        if cached_context:
            logging.info(f"Cache hit for key: {cache_key}")
            return (
                cached_context.decode("utf-8")
                if isinstance(cached_context, bytes)
                else cached_context
            )

        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return ""

        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": 10,
                "fetch_k": 20,
                "lambda_mult": 0.6,
                # "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
            },
        )

        docs = await asyncio.to_thread(retriever.get_relevant_documents, question)
        if not docs:
            return ""

        sorted_docs = sorted(
            docs,
            key=lambda x: (
                int(x.metadata.get("page_number", 0)),
                int(x.metadata.get("chunk_index", 0)),
            ),
        )

        context_parts = [doc.page_content for doc in sorted_docs]
        context = "\n\n".join(context_parts)

        try:
            redis_client.setex(
                cache_key,
                int(CACHE_TTL["vector_store"].total_seconds()),
                context.encode("utf-8") if isinstance(context, str) else context,
            )
            logging.info(f"Cached context for key: {cache_key}")
        except Exception as cache_error:
            logging.error(f"Failed to cache context: {cache_error}")

        return context
    except Exception as e:
        logging.error(f"Error getting relevant context: {e}")
        return ""


def format_conversation_context(conv_history):
    """Format conversation history into a string"""
    if not conv_history:
        return "No previous conversation."
    return "\n".join(
        [
            f"Q: {interaction['question']}\nA: {interaction['answer']}"
            for interaction in conv_history
        ]
    )
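
# Illustrative sketch:
#
#     format_conversation_context([{"question": "Q1?", "answer": "A1."}])
#     # -> "Q: Q1?\nA: A1."
#     format_conversation_context([])
#     # -> "No previous conversation."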


def get_icd_context_from_question(question, hospital_id):
    """Extract any valid ICD codes from the question and return context"""
    icd_data = load_icd_entries(hospital_id)
    matches = []
    # Candidate code strings found in the question (uppercased first)
    candidate_codes = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", question.upper())

    seen = set()
    for code in candidate_codes:
        for entry in icd_data:
            if entry["code"] == code and code not in seen:
                matches.append(f"{entry['code']}: {entry['description']}")
                seen.add(code)
    return "\n".join(matches)


def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70):
    """Get fuzzy matches for ICD codes from the question"""
    icd_data = load_icd_entries(hospital_id)
    descriptions = [entry["description"] for entry in icd_data]
    matches = process.extract(
        question, descriptions, limit=top_n, score_cutoff=threshold
    )

    matched_context = []
    for desc, score, _ in matches:
        for entry in icd_data:
            if entry["description"] == desc:
                matched_context.append(f"{entry['code']}: {entry['description']}")
                break

    return "\n".join(matched_context)


async def generate_answer_with_rag(question, hospital_id, client, doc_id=None):
    """Generate answer using strict RAG approach - only using document content"""
    try:
        html_instruction = """
        IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
        - Use <p> tags for paragraphs
        - Use <h2>, <h3> tags for headings and subheadings
        - Use <ul>, <li> tags for bullet points
        - Use <ol>, <li> tags for numbered lists
        - Use <blockquote> for quoted text
        - Use <strong> for bold text and <em> for emphasis
        """

        table_instruction = """
        - For tables, use proper HTML table structure:
        <table border="1">
            <thead>
                <tr>
                    <th colspan="{total_columns}">{table_title}</th>
                </tr>
                <tr>
                    {table_headers}
                </tr>
            </thead>
            <tbody>
                {table_rows}
            </tbody>
        </table>
        """
        # First, check for ICD codes in the question
        icd_exact_context = get_icd_context_from_question(question, hospital_id)
        icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)

        # Get document context (this coroutine must be awaited, otherwise
        # vector_store would be an un-awaited coroutine object)
        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return {"answer": "<p>No document content available.</p>"}, 404

        # Create template focusing only on document content
        prompt_template = PromptTemplate(
            template="""Based ONLY on the provided document context and ICD codes, generate an answer to the question.
            If the information is not found in the context, explicitly state that.
            Do not use any external knowledge or assumptions.

            {html_instruction}
            {table_instruction}

            ICD Code Matches:
            {icd_exact_context}

            Related ICD Codes:
            {icd_fuzzy_context}

            Context from documents: {context}
            Question: {question}

            Instructions:
            1. Only use information explicitly present in the context and ICD codes
            2. Do not make assumptions or use external knowledge
            3. If an ICD code is found, include it in your response
            4. If the answer is not in the context, say "This information is not found in the available documents"
            5. Format response in clear HTML

            Answer:""",
            input_variables=["context", "question"],
            partial_variables={
                "html_instruction": html_instruction,
                "table_instruction": table_instruction if is_table_request(question) else "",
                "icd_exact_context": icd_exact_context if icd_exact_context else "No exact ICD code matches found.",
                "icd_fuzzy_context": icd_fuzzy_context if icd_fuzzy_context else "No related ICD codes found.",
            },
        )

        retriever = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={
                "k": 6,
                "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)},
            },
        )

        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt_template},
            return_source_documents=True,
        )

        # Run the blocking chain call off the event loop
        result = await asyncio.to_thread(qa_chain, {"query": question})
        formatted_answer = ensure_html_response(result["result"])

        if "not found in" in formatted_answer.lower():
            return {"answer": formatted_answer}, 404

        return {"answer": formatted_answer}, 200

    except Exception as e:
        return {"answer": f"<p>Error: {str(e)}</p>"}, 500


async def update_document_status(doc_id, status, failed_page=None):
    """Update document status with enum validation"""
    if isinstance(status, str):
        status = DocumentStatus[status.upper()].value

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                if failed_page:
                    await cursor.execute(
                        "UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s",
                        (status, failed_page, doc_id),
                    )
                else:
                    await cursor.execute(
                        "UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s",
                        (status, doc_id),
                    )
                await conn.commit()
                return True
            except Exception as e:
                logging.error(f"Database update error: {e}")
                return False


async def load_existing_vector_stores():
    """Load existing Chroma vector stores for each hospital"""
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                await cursor.execute("SELECT DISTINCT id FROM hospitals")
                hospital_ids = [row[0] for row in await cursor.fetchall()]

                for hospital_id in hospital_ids:
                    try:
                        await initialize_or_load_vector_store(hospital_id)
                    except Exception as e:
                        logging.error(
                            f"Failed to load vector store for hospital {hospital_id}: {e}"
                        )
                        continue

            except Exception as e:
                logging.error(f"Error loading vector stores: {e}")
|
||
|
||
@app.route('/flask-api/generate-answer', methods=['POST'])
|
||
def generate_answer():
|
||
"""Generate answer using RAG approach."""
|
||
try:
|
||
data = request.json
|
||
question = data.get('question')
|
||
hospital_code = data.get('hospital_code')
|
||
hospital_id = async_to_sync(get_hospital_id(hospital_code))
|
||
doc_id = data.get('doc_id')
|
||
html_instruction = """
|
||
IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
|
||
- Use <p> tags for paragraphs
|
||
- Use <h2>, <h3> tags for headings and subheadings
|
||
- Use <ul>, <li> tags for bullet points
|
||
- Use <ol>, <li> tags for numbered lists
|
||
- Use <blockquote> for quoted text
|
||
- Use <strong> for bold text and <em> for emphasis
|
||
"""
|
||
|
||
table_instruction = """
|
||
- For tables, use proper HTML table structure:
|
||
<table border="1">
|
||
<thead>
|
||
<tr>
|
||
<th colspan="{total_columns}">{table_title}</th>
|
||
</tr>
|
||
<tr>
|
||
{table_headers}
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
{table_rows}
|
||
</tbody>
|
||
</table>
|
||
"""
|
||
|
||
# Validate required parameters
|
||
if not question:
|
||
return jsonify({"error": "Missing 'question' in request"}), 400
|
||
if not hospital_id:
|
||
return jsonify({"error": "Missing 'hospital_code' in request"}), 400
|
||
|
||
# Get ICD context
|
||
icd_exact_context = get_icd_context_from_question(question, hospital_id)
|
||
icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)
|
||
|
||
vector_store = async_to_sync(initialize_or_load_vector_store(hospital_id))
|
||
if not vector_store:
|
||
return jsonify({"answer": "<p>No document content available.</p>"}), 404
|
||
|
||
# Create template focusing only on document content
|
||
prompt_template = PromptTemplate(
|
||
template="""Based ONLY on the provided document context and ICD codes, generate an answer.
|
||
Do not use any external knowledge or assumptions.
|
||
|
||
{html_instruction}
|
||
{table_instruction}
|
||
|
||
ICD Code Matches:
|
||
{icd_exact_context}
|
||
|
||
Related ICD Codes:
|
||
{icd_fuzzy_context}
|
||
|
||
Context: {context}
|
||
Question: {question}
|
||
|
||
Instructions:
|
||
1. Only use information explicitly present in the context and ICD codes
|
||
2. Do not make assumptions or use external knowledge
|
||
3. If an ICD code is found, include it in your response
|
||
4. If the answer is not in the context, say "This information is not found in the available documents"
|
||
5. Format response in clear HTML
|
||
|
||
Answer:""",
|
||
input_variables=["context", "question"],
|
||
partial_variables={
|
||
"html_instruction": html_instruction,
|
||
"table_instruction": table_instruction if is_table_request(question) else "",
|
||
"icd_exact_context": icd_exact_context if icd_exact_context else "No exact ICD code matches found.",
|
||
"icd_fuzzy_context": icd_fuzzy_context if icd_fuzzy_context else "No related ICD codes found."
|
||
}
|
||
)

        retriever = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={
                "k": 6,
                "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)},
            },
        )
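        # The metadata filter narrows retrieval to a single document when doc_id
        # is supplied, otherwise to everything indexed for this hospital. Both
        # values are presumably stored as strings at indexing time, which is why
        # they are cast with str() here.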

        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt_template},
            return_source_documents=True,
        )

        result = qa_chain({"query": question})
        answer = ensure_html_response(result["result"])

        if "not found in" in answer.lower():
            return jsonify({
                "answer": "<p>This information is not found in the available documents.</p>"
            }), 404

        return jsonify({"answer": answer}), 200

    except Exception as e:
        logging.error(f"Error generating answer: {e}")
        return jsonify({"error": str(e)}), 500
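
# Illustrative client call for the endpoint above (assumes the service runs on
# localhost:5000; the hospital code "H001" is a made-up example):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:5000/flask-api/generate-answer",
#         json={"question": "What is the ICD code for asthma?",
#               "hospital_code": "H001"},
#     )
#     print(resp.json()["answer"])  # HTML-formatted answer string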


thread_pool = ThreadPoolExecutor(max_workers=10)


def async_to_sync(coroutine):
    """Helper function to run async code in sync context"""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine)
    finally:
        loop.close()
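
# Each async_to_sync call creates and tears down a fresh event loop. That is
# safe from the worker threads used below (they have no running loop), but it
# adds per-call overhead; asyncio.run(coroutine) behaves similarly on Python
# 3.7+. Example: hospital_id = async_to_sync(get_hospital_id("H001")).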


@app.route("/flask-api", methods=["GET"])
def health_check():
    """Health check endpoint"""
    access_logger.info(f"Health check request received from {request.remote_addr}")
    return jsonify({"status": "ok"}), 200


@app.route("/flask-api/process-pdf", methods=["POST"])
def process_pdf():
    access_logger.info(f"PDF processing request received from {request.remote_addr}")
    file_path = None
    try:
        file = request.files.get("pdf")
        hospital_id = request.form.get("hospital_id")
        doc_id = request.form.get("doc_id")

        logging.info(
            f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}"
        )

        if not all([file, hospital_id, doc_id]):
            return jsonify({"error": "Missing required parameters"}), 400

        def process_in_background():
            nonlocal file_path
            try:
                async_to_sync(update_document_status(doc_id, "processing"))

                # Add progress logging
                logging.info(f"Starting processing of document {doc_id}")

                filename = f"doc_{doc_id}_{file.filename}"
                file_path = os.path.join(uploads_dir, filename)

                with open(file_path, "wb") as f:
                    file.save(f)

                logging.info("Extracting PDF contents...")
                content = extract_pdf_contents(file_path, int(hospital_id))

                logging.info("Inserting content into database...")
                metadata = {"filename": filename}
                result = async_to_sync(
                    insert_content_into_db(content, metadata, doc_id)
                )

                if "error" in result:
                    async_to_sync(update_document_status(doc_id, "failed", 1))
                    return False

                logging.info("Creating embeddings and indexing...")
                success = async_to_sync(add_document_to_index(doc_id, hospital_id))

                if success:
                    logging.info("Document processing completed successfully")
                    async_to_sync(update_document_status(doc_id, "processed"))
                    return True
                else:
                    logging.error("Document processing failed during indexing")
                    async_to_sync(update_document_status(doc_id, "failed"))
                    return False

            except Exception as e:
                logging.error(f"Processing error: {e}")
                async_to_sync(update_document_status(doc_id, "failed"))
                return False
            finally:
                if file_path and os.path.exists(file_path):
                    try:
                        os.remove(file_path)
                    except Exception as e:
                        logging.error(f"Error removing temporary file: {e}")

        # Execute processing and wait for result
        future = thread_pool.submit(process_in_background)
        success = future.result()
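        # Note: future.result() blocks until the worker finishes, so despite the
        # helper's name the request is synchronous from the client's point of
        # view; this also keeps the werkzeug FileStorage object valid while the
        # worker thread reads it.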

        if success:
            return jsonify({"message": "Document processed successfully"}), 200
        else:
            return jsonify({"error": "Document processing failed"}), 500

    except Exception as e:
        logging.error(f"API error: {e}")
        if file_path and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except Exception as file_e:
                logging.error(f"Error removing temporary file: {file_e}")
        return jsonify({"error": str(e)}), 500
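
# Illustrative upload call (assumes a local instance; report.pdf and the ids
# are made-up examples):
#
#     import requests
#     with open("report.pdf", "rb") as fh:
#         resp = requests.post(
#             "http://localhost:5000/flask-api/process-pdf",
#             files={"pdf": fh},
#             data={"hospital_id": "1", "doc_id": "123"},
#         )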


# Initialize the hybrid conversation manager
redis_client = get_redis_client()
conversation_manager = HybridConversationManager(redis_client)


@app.route("/flask-api/delete-document-vectors", methods=["DELETE"])
def delete_document_vectors_endpoint():
    """Endpoint to delete document vectors from ChromaDB"""
    try:
        # Tolerate a missing or non-JSON body so validation returns a clean 400.
        data = request.get_json(silent=True) or {}
        hospital_id = data.get("hospital_id")
        doc_id = data.get("doc_id")

        if not all([hospital_id, doc_id]):
            return jsonify({"error": "Missing required parameters"}), 400

        logging.info(
            f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}"
        )

        def process_deletion():
            try:
                success = async_to_sync(delete_document_vectors(hospital_id, doc_id))
                if success:
                    return {"message": "Document vectors deleted successfully"}, 200
                else:
                    return {"error": "Failed to delete document vectors"}, 500
            except Exception as e:
                logging.error(f"Error in vector deletion process: {e}")
                return {"error": str(e)}, 500

        future = thread_pool.submit(process_deletion)
        result, status_code = future.result()

        return jsonify(result), status_code

    except Exception as e:
        logging.error(f"API error: {str(e)}")
        return jsonify({"error": str(e)}), 500
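
# Illustrative deletion call (values are made-up examples):
#
#     import requests
#     resp = requests.delete(
#         "http://localhost:5000/flask-api/delete-document-vectors",
#         json={"hospital_id": 1, "doc_id": 123},
#     )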


@app.route("/flask-api/get-chroma-content", methods=["GET"])
def get_chroma_content_endpoint():
    """API endpoint to get ChromaDB content by hospital_id"""
    try:
        hospital_id = request.args.get("hospital_id")
        limit = int(request.args.get("limit", 30000))

        if not hospital_id:
            return jsonify({"error": "Missing required parameter: hospital_id"}), 400

        def process_fetch():
            try:
                result, status_code = async_to_sync(
                    get_chroma_content_by_hospital(
                        hospital_id=int(hospital_id), limit=limit
                    )
                )
                return result, status_code
            except Exception as e:
                logging.error(f"Error in ChromaDB fetch process: {e}")
                return {"error": str(e)}, 500

        future = thread_pool.submit(process_fetch)
        result, status_code = future.result()

        return jsonify(result), status_code

    except Exception as e:
        logging.error(f"API error: {str(e)}")
        return jsonify({"error": str(e)}), 500
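
# Illustrative fetch call (assumes a local instance):
#
#     import requests
#     resp = requests.get(
#         "http://localhost:5000/flask-api/get-chroma-content",
#         params={"hospital_id": 1, "limit": 100},
#     )
#     print(resp.json()["count"])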


async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100):
    """Fetch content from ChromaDB for a specific hospital"""
    try:
        # Initialize vector store
        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return {"error": "Vector store not found"}, 404

        # Get collection
        collection = vector_store._collection

        # Query the collection with hospital_id filter
        results = await asyncio.to_thread(
            lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit)
        )
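        # collection.get() is a blocking Chroma call; asyncio.to_thread moves it
        # off the event loop so other coroutines are not stalled while it runs.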

        if not results or not results["ids"]:
            return {"data": [], "count": 0}, 200

        # Format the response
        formatted_results = []
        for i in range(len(results["ids"])):
            formatted_results.append(
                {
                    "id": results["ids"][i],
                    "content": results["documents"][i],
                    "metadata": results["metadatas"][i],
                }
            )

        return {"data": formatted_results, "count": len(formatted_results)}, 200

    except Exception as e:
        logging.error(f"Error fetching ChromaDB content: {e}")
        return {"error": str(e)}, 500


@app.before_request
def before_request():
    request._start_time = time.time()


@app.after_request
def after_request(response):
    if hasattr(request, "_start_time"):
        duration = time.time() - request._start_time
        access_logger.info(
            f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - '
            f"IP: {request.remote_addr}"
        )
    return response


if __name__ == "__main__":
    logger.info("Starting SpurrinAI application")
    logger.info(f"Python version: {sys.version}")
    logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}")

    try:
        model_manager = ModelManager()
        logger.info("Model manager initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize model manager: {e}")
        sys.exit(1)

    # Initialize directories
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(CHROMA_DIR, exist_ok=True)
    logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}")

    # Clear Redis cache
    redis_client = get_redis_client()
    cleared_keys = 0
    for key in redis_client.scan_iter("vector_store_data:*"):
        redis_client.delete(key)
        cleared_keys += 1
    logger.info(f"Cleared {cleared_keys} Redis cache keys")

    # Load vector stores
    logger.info("Loading existing vector stores...")
    async_to_sync(load_existing_vector_stores())
    logger.info("Vector stores loaded successfully")

    # Start application
    logger.info("Starting Flask application on port 5000")
    app.run(port=5000, debug=False)
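
# Note: app.run() starts Flask's development server. A production deployment
# would typically sit behind a WSGI server instead, e.g. (assuming this module
# is named app.py):
#
#     gunicorn --workers 4 --bind 0.0.0.0:5000 app:app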