spurrin-backend/chat copy 5.py
2025-06-09 11:11:52 +05:30

3857 lines
134 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# """
# SpurrinAI - Intelligent Document Processing and Question Answering System
# Copyright (c) 2024 Tech4biz. All rights reserved.
# This module implements the main Flask application for the SpurrinAI system,
# providing REST APIs for document processing, vector storage, and question answering
# using RAG (Retrieval Augmented Generation) architecture.
# Author: Tech4biz Development Team
# Version: 1.0.0
# Last Updated: 2024-01-19
# """
# # Standard library imports
# import os
# import re
# import sys
# import json
# import time
# import threading
# import asyncio
# from concurrent.futures import ThreadPoolExecutor
# from dataclasses import dataclass
# from datetime import timedelta
# from enum import Enum
# # Third-party imports
# import spacy
# import redis
# import aiomysql
# from dotenv import load_dotenv
# from flask import Flask, request, jsonify, Response
# from flask_cors import CORS
# from tqdm import tqdm
# from tqdm.asyncio import tqdm as tqdm_async
# from langchain_community.document_loaders import PyPDFLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain_community.embeddings import OpenAIEmbeddings
# from langchain_community.vectorstores import Chroma
# from langchain_community.chat_models import ChatOpenAI
# from langchain.chains import RetrievalQA
# from langchain.prompts import PromptTemplate
# from openai import OpenAI
# from rapidfuzz import process
# from threading import Lock
# # Local imports
# from model_manager import ModelManager
# # Suppress warnings
# import warnings
# warnings.filterwarnings("ignore")
# # Initialize NLTK
# import nltk
# nltk.download("punkt")
# # Configure logging
# import logging
# import logging.handlers
# app = Flask(__name__)
# CORS(app)
# script_dir = os.path.dirname(os.path.abspath(__file__))
# log_file_path = os.path.join(script_dir, "error.log")
# logging.basicConfig(filename=log_file_path, level=logging.INFO)
# # Configure logging
# def setup_logging():
# log_dir = os.path.join(script_dir, "logs")
# os.makedirs(log_dir, exist_ok=True)
# main_log = os.path.join(log_dir, "app.log")
# error_log = os.path.join(log_dir, "error.log")
# access_log = os.path.join(log_dir, "access.log")
# perf_log = os.path.join(log_dir, "performance.log")
# # Create formatters
# detailed_formatter = logging.Formatter(
# "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
# )
# access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
# # Main logger setup
# main_handler = logging.handlers.RotatingFileHandler(
# main_log, maxBytes=10485760, backupCount=5
# )
# main_handler.setFormatter(detailed_formatter)
# main_handler.setLevel(logging.INFO)
# # Error logger setup
# error_handler = logging.handlers.RotatingFileHandler(
# error_log, maxBytes=10485760, backupCount=5
# )
# error_handler.setFormatter(detailed_formatter)
# error_handler.setLevel(logging.ERROR)
# # Access logger setup
# access_handler = logging.handlers.TimedRotatingFileHandler(
# access_log, when="midnight", interval=1, backupCount=30
# )
# access_handler.setFormatter(access_formatter)
# access_handler.setLevel(logging.INFO)
# # Performance logger setup
# perf_handler = logging.handlers.RotatingFileHandler(
# perf_log, maxBytes=10485760, backupCount=5
# )
# perf_handler.setFormatter(detailed_formatter)
# perf_handler.setLevel(logging.INFO)
# # Configure root logger
# root_logger = logging.getLogger()
# root_logger.setLevel(logging.INFO)
# root_logger.addHandler(main_handler)
# root_logger.addHandler(error_handler)
# # Create specific loggers
# access_logger = logging.getLogger("access")
# access_logger.addHandler(access_handler)
# access_logger.setLevel(logging.INFO)
# perf_logger = logging.getLogger("performance")
# perf_logger.addHandler(perf_handler)
# perf_logger.setLevel(logging.INFO)
# return root_logger, access_logger, perf_logger
# # Initialize loggers
# logger, access_logger, perf_logger = setup_logging()
# # DB_CONFIG = {
# # 'host': 'localhost',
# # 'user': 'flaskuser',
# # 'password': 'Flask@123',
# # 'database': 'spurrinai',
# # }
# # DB_CONFIG = {
# # 'host': 'localhost',
# # 'user': 'spurrindevuser',
# # 'password': 'Admin@123',
# # 'database': 'spurrindev',
# # }
# # DB_CONFIG = {
# # 'host': 'localhost',
# # 'user': 'root',
# # 'password': 'root',
# # 'database': 'medqueryai',
# # 'port': 3307
# # }
# # Redis Configuration
# REDIS_CONFIG = {
# "host": "localhost",
# "port": 6379,
# "db": 0,
# "decode_responses": True, # For string operations
# }
# DB_CONFIG = {
# "host": os.getenv("DB_HOST", "localhost"),
# "user": os.getenv("DB_USER", "testuser"),
# "password": os.getenv("DB_PASSWORD", "Admin@123"),
# "database": os.getenv("DB_NAME", "spurrintest"),
# }
# # Redis connection pool
# redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
# redis_binary_pool = redis.ConnectionPool(
# host="localhost", port=6379, db=1, decode_responses=False
# )
# def get_redis_client(binary=False):
# """Get Redis client from pool"""
# logger.debug(f"Getting Redis client with binary={binary}")
# try:
# pool = redis_binary_pool if binary else redis_pool
# client = redis.Redis(connection_pool=pool)
# logger.debug("Redis client created successfully")
# return client
# except Exception as e:
# logger.error(f"Failed to create Redis client: {e}", exc_info=True)
# raise
# def fetch_cached_answer(cache_key):
# logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
# start_time = time.time()
# try:
# redis_client = get_redis_client()
# cached_answer = redis_client.get(cache_key)
# fetch_time = time.time() - start_time
# perf_logger.info(
# f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
# )
# return cached_answer
# except Exception as e:
# logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
# return None
# # Cache TTL configurations
# CACHE_TTL = {
# "vector_store": timedelta(hours=24),
# "chat_completion": timedelta(hours=1),
# "document_metadata": timedelta(days=7),
# }
# DATA_DIR = os.path.join(script_dir, "hospital_data")
# CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")
# uploads_dir = os.path.join(script_dir, "llm-uploads")
# if not os.path.exists(uploads_dir):
# os.makedirs(uploads_dir)
# nlp = spacy.load("en_core_web_sm")
# load_dotenv()
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# client = OpenAI(api_key=OPENAI_API_KEY)
# embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
# llm = ChatOpenAI(
# model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
# )
# # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
# hospital_vector_stores = {}
# vector_store_lock = threading.Lock()
# @dataclass
# class Document:
# doc_id: int
# page_num: int
# content: str
# class DocumentStatus(Enum):
# PROCESSING = "processing"
# PROCESSED = "processed"
# FAILED = "failed"
# async def get_db_pool():
# return await aiomysql.create_pool(
# host=DB_CONFIG["host"],
# user=DB_CONFIG["user"],
# password=DB_CONFIG["password"],
# db=DB_CONFIG["database"],
# autocommit=True,
# )
# async def get_hospital_id(hospital_code):
# try:
# pool = await get_db_pool()
# async with pool.acquire() as conn:
# async with conn.cursor(aiomysql.DictCursor) as cursor:
# await cursor.execute(
# "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
# (hospital_code,),
# )
# result = await cursor.fetchone()
# return result["id"] if result else None
# except Exception as error:
# logging.error(f"Database error: {error}")
# return None
# finally:
# pool.close()
# await pool.wait_closed()
# CHUNK_SIZE = 1000
# CHUNK_OVERLAP = 50
# BATCH_SIZE = 1000
# text_splitter = RecursiveCharacterTextSplitter(
# chunk_size=CHUNK_SIZE,
# chunk_overlap=CHUNK_OVERLAP,
# # length_function=len,
# # separators=["\n\n", "\n", ". ", " ", ""]
# )
# # Update the JSON_PATH to be dynamic based on hospital_id
# def get_icd_json_path(hospital_id):
# hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
# os.makedirs(hospital_data_dir, exist_ok=True)
# return os.path.join(hospital_data_dir, "icd_data.json")
# def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
# """Extract and process ICD codes with optimized processing and optional JSON saving"""
# try:
# # Initialize pattern compilation once
# pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)
# # Process in chunks for large content
# chunk_size = 50000 # Process 50KB at a time
# icd_data = []
# current_code = None
# current_description = []
# # Split content into manageable chunks
# content_chunks = [
# content[i : i + chunk_size] for i in range(0, len(content), chunk_size)
# ]
# # Process each chunk
# for chunk in content_chunks:
# lines = chunk.splitlines()
# for line in lines:
# line = line.strip()
# if not line:
# if current_code and current_description:
# icd_data.append(
# {
# "code": current_code,
# "description": " ".join(current_description).strip(),
# }
# )
# current_code = None
# current_description = []
# continue
# match = pattern.match(line)
# if match:
# if current_code and current_description:
# icd_data.append(
# {
# "code": current_code,
# "description": " ".join(current_description).strip(),
# }
# )
# current_code, description = match.groups()
# current_description = [description.strip()]
# elif current_code:
# current_description.append(line)
# # Add final entry if exists
# if current_code and current_description:
# icd_data.append(
# {
# "code": current_code,
# "description": " ".join(current_description).strip(),
# }
# )
# # Save to hospital-specific JSON if requested
# if save_to_json and icd_data:
# try:
# json_path = get_icd_json_path(hospital_id)
# # Use a lock for thread safety
# with threading.Lock():
# if os.path.exists(json_path):
# with open(json_path, "r", encoding="utf-8") as f:
# try:
# existing_data = json.load(f)
# except json.JSONDecodeError:
# existing_data = []
# else:
# existing_data = []
# # Efficient deduplication using dictionary
# seen_codes = {item["code"]: item for item in existing_data}
# for item in icd_data:
# seen_codes[item["code"]] = item
# unique_data = list(seen_codes.values())
# # Write atomically using temporary file
# temp_path = f"{json_path}.tmp"
# with open(temp_path, "w", encoding="utf-8") as f:
# json.dump(unique_data, f, indent=2, ensure_ascii=False)
# os.replace(temp_path, json_path)
# logging.info(
# f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
# )
# except Exception as e:
# logging.error(
# f"Error saving ICD data to JSON for hospital {hospital_id}: {e}"
# )
# return icd_data
# except Exception as e:
# logging.error(f"Error in extract_and_process_icd_data: {e}")
# return []
# def load_icd_entries(hospital_id):
# """Load ICD entries from hospital-specific JSON file"""
# json_path = get_icd_json_path(hospital_id)
# try:
# if os.path.exists(json_path):
# with open(json_path, "r", encoding="utf-8") as f:
# return json.load(f)
# return []
# except Exception as e:
# logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
# return []
# # Update the process_icd_codes function to include hospital_id
# async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
# """Process and store ICD codes using the optimized extraction function"""
# try:
# # Extract and save codes with hospital_id
# extract_and_process_icd_data(content, hospital_id, save_to_json=True)
# except Exception as e:
# logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")
# async def initialize_icd_vector_store(hospital_id):
# """This function is deprecated. ICD codes are now handled through JSON search."""
# logging.warning(
# "initialize_icd_vector_store is deprecated - using JSON search instead"
# )
# return None
# def extract_pdf_contents(pdf_path, hospital_id):
# """Extract PDF contents with optimized chunking and code extraction"""
# try:
# loader = PyPDFLoader(pdf_path)
# pages = loader.load()
# pages_content = []
# for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
# text = page.page_content.strip()
# # Extract ICD codes from the page
# icd_codes = extract_and_process_icd_data(
# text, hospital_id
# ) # We'll set doc_id later
# pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})
# return pages_content
# except Exception as e:
# logging.error(f"Error in extract_pdf_contents: {e}")
# raise
# async def insert_content_into_db(content, metadata, doc_id):
# pool = await get_db_pool()
# async with pool.acquire() as conn:
# async with conn.cursor() as cursor:
# try:
# metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
# content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"
# metadata_values = [
# (doc_id, key[:100], value)
# for key, value in metadata.items()
# if value
# ]
# content_values = [
# (doc_id, page_content["page"], page_content["text"])
# for page_content in content
# ]
# if metadata_values:
# await cursor.executemany(metadata_query, metadata_values)
# if content_values:
# await cursor.executemany(content_query, content_values)
# await conn.commit()
# return {"message": "Success"}
# except Exception as e:
# await conn.rollback()
# return {"error": str(e)}
# async def initialize_or_load_vector_store(hospital_id, user_id="default"):
# """Initialize or load vector store with Redis caching and thread safety"""
# store_key = f"{hospital_id}:{user_id}"
# try:
# # Check if we already have it loaded - with lock for thread safety
# with vector_store_lock:
# if store_key in hospital_vector_stores:
# return hospital_vector_stores[store_key]
# # Initialize vector store
# redis_client = get_redis_client(binary=True)
# cache_key = f"vector_store_data:{hospital_id}:{user_id}"
# hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")
# if os.path.exists(hospital_dir):
# logging.info(
# f"Loading vector store for hospital {hospital_id} and user {user_id}"
# )
# vector_store = await asyncio.to_thread(
# lambda: Chroma(
# collection_name=f"hospital_{hospital_id}",
# persist_directory=hospital_dir,
# embedding_function=embeddings,
# )
# )
# else:
# logging.info(f"Creating vector store for hospital {hospital_id}")
# os.makedirs(hospital_dir, exist_ok=True)
# vector_store = await asyncio.to_thread(
# lambda: Chroma(
# collection_name=f"hospital_{hospital_id}",
# persist_directory=hospital_dir,
# embedding_function=embeddings,
# )
# )
# # Store with lock for thread safety
# with vector_store_lock:
# hospital_vector_stores[store_key] = vector_store
# return vector_store
# except Exception as e:
# logging.error(f"Error initializing vector store: {e}", exc_info=True)
# raise
# async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
# """Delete all vectors associated with a specific document from ChromaDB"""
# try:
# # Initialize vector store for the hospital
# vector_store = await initialize_or_load_vector_store(hospital_id)
# # Delete vectors with matching doc_id
# await asyncio.to_thread(
# lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
# )
# # Persist changes
# await asyncio.to_thread(vector_store.persist)
# # Clear Redis cache for this document
# redis_client = get_redis_client()
# pattern = f"vector_store_data:{hospital_id}:*"
# for key in redis_client.scan_iter(pattern):
# redis_client.delete(key)
# logging.info(
# f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
# )
# return True
# except Exception as e:
# logging.error(f"Error deleting document vectors: {e}", exc_info=True)
# return False
# async def add_document_to_index(doc_id, hospital_id):
# try:
# pool = await get_db_pool()
# async with pool.acquire() as conn:
# async with conn.cursor() as cursor:
# vector_store = await initialize_or_load_vector_store(hospital_id)
# await cursor.execute(
# "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
# (doc_id,),
# )
# rows = await cursor.fetchall()
# total_pages = len(rows)
# logging.info(f"Processing {total_pages} pages for document {doc_id}")
# page_bar = tqdm_async(total=total_pages, desc="Processing pages")
# async def process_page(page_data):
# page_num, content = page_data
# try:
# icd_data = extract_and_process_icd_data(
# content, hospital_id, save_to_json=False
# )
# chunks = text_splitter.split_text(content)
# await asyncio.sleep(0) # Yield control
# return page_num, chunks, icd_data
# except Exception as e:
# logging.error(f"Error processing page {page_num}: {e}")
# return page_num, [], []
# tasks = [asyncio.create_task(process_page(row)) for row in rows]
# results = []
# for coro in asyncio.as_completed(tasks):
# result = await coro
# results.append(result)
# page_bar.update(1)
# page_bar.close()
# # Vector addition progress bar
# all_icd_data = []
# all_chunks = []
# all_metadatas = []
# chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)
# for result in results:
# page_num, chunks, icd_data = result
# all_icd_data.extend(icd_data)
# for i, chunk in enumerate(chunks):
# all_chunks.append(chunk)
# all_metadatas.append(
# {
# "doc_id": str(doc_id),
# "hospital_id": str(hospital_id),
# "page_number": str(page_num),
# "chunk_index": str(i),
# }
# )
# if len(all_chunks) >= BATCH_SIZE:
# chunk_add_bar.total += len(all_chunks)
# chunk_add_bar.refresh()
# await asyncio.to_thread(
# vector_store.add_texts,
# texts=all_chunks,
# metadatas=all_metadatas,
# )
# all_chunks = []
# all_metadatas = []
# chunk_add_bar.update(BATCH_SIZE)
# # Final batch
# if all_chunks:
# chunk_add_bar.total += len(all_chunks)
# chunk_add_bar.refresh()
# await asyncio.to_thread(
# vector_store.add_texts,
# texts=all_chunks,
# metadatas=all_metadatas,
# )
# chunk_add_bar.update(len(all_chunks))
# chunk_add_bar.close()
# if all_icd_data:
# logging.info(f"Saving {len(all_icd_data)} ICD codes")
# extract_and_process_icd_data("", hospital_id, save_to_json=True)
# await asyncio.to_thread(vector_store.persist)
# logging.info(f"Successfully indexed document {doc_id}")
# return True
# except Exception as e:
# logging.error(f"Error adding document: {e}")
# return False
# def is_general_knowledge_question(
# query: str, context: str, conversation_context=None
# ) -> bool:
# """
# Determine if a question is likely a general knowledge question not covered in the documents.
# Takes conversation history into account to reduce repeated confirmations.
# """
# query_lower = query.lower()
# context_lower = context.lower()
# if conversation_context:
# for interaction in conversation_context:
# prev_question = interaction.get("question", "").lower()
# if (
# prev_question
# and query_lower in prev_question
# or prev_question in query_lower
# ):
# logging.info(
# f"Question is similar to previous conversation, skipping confirmation"
# )
# return False
# stop_words = {
# "search",
# "query:",
# "can",
# "you",
# "some",
# "at",
# "the",
# "a",
# "an",
# "in",
# "on",
# "at",
# "to",
# "for",
# "with",
# "by",
# "about",
# "give",
# "full",
# "is",
# "are",
# "was",
# "were",
# "define",
# "what",
# "how",
# "why",
# "when",
# "where",
# "year",
# "list",
# "form",
# "table",
# "who",
# "which",
# "me",
# "tell",
# "explain",
# "describe",
# "of",
# "and",
# "or",
# "there",
# "their",
# "please",
# "could",
# "would",
# "various",
# "different",
# "type",
# "types",
# "kind",
# "kinds",
# "has",
# "have",
# "had",
# "many",
# "say",
# "know",
# }
# key_words = [
# word for word in query_lower.split() if word not in stop_words and len(word) > 2
# ]
# logging.info(f"Key words: {key_words}")
# if not key_words:
# logging.info("No significant keywords found, directing to general knowledge")
# return True
# matches = sum(1 for word in key_words if word in context_lower)
# logging.info(f"Matches: {matches} out of {len(key_words)} keywords")
# match_ratio = matches / len(key_words)
# logging.info(f"Match ratio: {match_ratio}")
# return match_ratio < 0.4
# def is_table_request(query: str) -> bool:
# """
# Determine if the user is requesting a response in tabular format.
# """
# table_keywords = [
# "table",
# "tabular",
# "in a table",
# "in table format",
# "in tabular format",
# "chart",
# "data",
# "comparison",
# "as a table",
# "table format",
# "in rows and columns",
# "in a grid",
# "breakdown",
# "spreadsheet",
# "comparison table",
# "data table",
# "structured table",
# "tabular form",
# "table form",
# ]
# query_lower = query.lower()
# return any(keyword in query_lower for keyword in table_keywords)
# import re
# def ensure_html_response(text: str) -> str:
# """
# Ensure the response is properly formatted in HTML.
# This function handles plain text conversion to HTML.
# """
# if "<html" in text.lower() or "<body" in text.lower():
# return text
# has_html_tags = bool(re.search(r"<[a-z]+.*?>", text))
# if not has_html_tags:
# paragraphs = text.split("\n\n")
# html_parts = []
# in_ordered_list = False
# in_unordered_list = False
# for para in paragraphs:
# if para.strip():
# if re.match(r"^\s*[\*\-\•]\s", para):
# if not in_unordered_list:
# html_parts.append("<ul>")
# in_unordered_list = True
# lines = para.split("\n")
# for line in lines:
# if line.strip():
# item = re.sub(r"^\s*[\*\-\•]\s*", "", line)
# html_parts.append(f"<li>{item}</li>")
# elif re.match(r"^\s*\d+\.\s", para):
# if not in_ordered_list:
# html_parts.append("<ol>")
# in_ordered_list = True
# lines = para.split("\n")
# for line in lines:
# match = re.match(r"^\s*\d+\.\s*(.*)", line)
# if match:
# html_parts.append(f"<li>{match.group(1)}</li>")
# else: # Close any open lists before adding a new paragraph
# if in_ordered_list:
# html_parts.append("</ol>")
# in_ordered_list = False
# if in_unordered_list:
# html_parts.append("</ul>")
# in_unordered_list = False
# html_parts.append(f"<p>{para}</p>")
# if in_ordered_list:
# html_parts.append("</ol>")
# if in_unordered_list:
# html_parts.append("</ul>")
# return "".join(html_parts)
# else:
# if not any(tag in text for tag in ("<p>", "<div>", "<ul>", "<ol>")):
# paragraphs = text.split("\n\n")
# html_parts = [f"<p>{para}</p>" for para in paragraphs if para.strip()]
# return "".join(html_parts)
# return text
# class HybridConversationManager:
# """
# Hybrid conversation manager that uses Redis for RAG-based conversations
# and in-memory storage for general knowledge conversations.
# """
# def __init__(self, redis_client, ttl=3600, max_history_items=5):
# self.redis_client = redis_client
# self.ttl = ttl
# self.max_history_items = max_history_items
# # For general knowledge questions (in-memory)
# self.general_knowledge_histories = {}
# self.lock = Lock()
# def _get_redis_key(self, user_id, hospital_id, session_id=None):
# """Create Redis key for document-based conversations."""
# if session_id:
# return f"conv_history:{user_id}:{hospital_id}:{session_id}"
# return f"conv_history:{user_id}:{hospital_id}"
# def _get_memory_key(self, user_id, hospital_id, session_id=None):
# """Create memory key for general knowledge conversations."""
# if session_id:
# return f"{user_id}:{hospital_id}:{session_id}"
# return f"{user_id}:{hospital_id}"
# async def add_rag_interaction(
# self, user_id, hospital_id, question, answer, session_id=None
# ):
# """Add document-based (RAG) interaction to Redis."""
# key = self._get_redis_key(user_id, hospital_id, session_id)
# history = self.get_rag_history(user_id, hospital_id, session_id)
# # Add new interaction
# history.append(
# {
# "question": question,
# "answer": answer,
# "timestamp": time.time(),
# "type": "rag", # Mark as RAG-based interaction
# }
# )
# # Keep only last N interactions
# history = history[-self.max_history_items :]
# # Store updated history
# try:
# self.redis_client.setex(key, self.ttl, json.dumps(history))
# logging.info(
# f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}"
# )
# except Exception as e:
# logging.error(f"Failed to store RAG interaction in Redis: {e}")
# def add_general_knowledge_interaction(
# self, user_id, hospital_id, question, answer, session_id=None
# ):
# """Add general knowledge interaction to in-memory store."""
# key = self._get_memory_key(user_id, hospital_id, session_id)
# with self.lock:
# if key not in self.general_knowledge_histories:
# self.general_knowledge_histories[key] = []
# self.general_knowledge_histories[key].append(
# {
# "question": question,
# "answer": answer,
# "timestamp": time.time(),
# "type": "general", # Mark as general knowledge interaction
# }
# )
# # Keep only the most recent interactions
# if len(self.general_knowledge_histories[key]) > self.max_history_items:
# self.general_knowledge_histories[key] = (
# self.general_knowledge_histories[key][-self.max_history_items :]
# )
# logging.info(
# f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}"
# )
# def get_rag_history(self, user_id, hospital_id, session_id=None):
# """Get document-based (RAG) conversation history from Redis."""
# key = self._get_redis_key(user_id, hospital_id, session_id)
# try:
# history_data = self.redis_client.get(key)
# return json.loads(history_data) if history_data else []
# except Exception as e:
# logging.error(f"Failed to retrieve RAG history from Redis: {e}")
# return []
# def get_general_knowledge_history(self, user_id, hospital_id, session_id=None):
# """Get general knowledge conversation history from memory."""
# key = self._get_memory_key(user_id, hospital_id, session_id)
# with self.lock:
# return self.general_knowledge_histories.get(key, []).copy()
# def get_combined_history(self, user_id, hospital_id, session_id=None):
# """Get combined conversation history from both sources, sorted by timestamp."""
# rag_history = self.get_rag_history(user_id, hospital_id, session_id)
# general_history = self.get_general_knowledge_history(
# user_id, hospital_id, session_id
# )
# # Combine histories
# combined_history = rag_history + general_history
# # Sort by timestamp (newest first)
# combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
# # Return most recent N items
# return combined_history[: self.max_history_items]
# def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2):
# """Get the most recent interactions for context from combined history."""
# combined_history = self.get_combined_history(user_id, hospital_id, session_id)
# # Sort by timestamp (oldest first) for context window
# sorted_history = sorted(combined_history, key=lambda x: x.get("timestamp", 0))
# return sorted_history[-window_size:] if sorted_history else []
# def clear_history(self, user_id, hospital_id):
# """Clear conversation history from both stores."""
# # Clear Redis history
# redis_key = self._get_redis_key(user_id, hospital_id)
# try:
# self.redis_client.delete(redis_key)
# except Exception as e:
# logging.error(f"Failed to clear Redis history: {e}")
# # Clear memory history
# memory_key = self._get_memory_key(user_id, hospital_id)
# with self.lock:
# if memory_key in self.general_knowledge_histories:
# del self.general_knowledge_histories[memory_key]
# class ContextMapper:
# """Enhanced context mapping using shared model manager"""
# def __init__(self):
# self.model_manager = ModelManager()
# self.context_cache = {}
# self.similarity_threshold = 0.6
# def get_semantic_similarity(self, text1, text2):
# """Get semantic similarity using global model manager"""
# return self.model_manager.get_semantic_similarity(text1, text2)
# def extract_key_concepts(self, text):
# """Extract key concepts using NLP techniques"""
# doc = nlp(text)
# concepts = []
# entities = [(ent.text, ent.label_) for ent in doc.ents]
# noun_phrases = [chunk.text for chunk in doc.noun_chunks]
# important_words = [
# token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"]
# ]
# concepts.extend([e[0] for e in entities])
# concepts.extend(noun_phrases)
# concepts.extend(important_words)
# return list(set(concepts))
# def map_conversation_context(
# self, current_query, conversation_history, context_window=3
# ):
# """Map conversation context using enhanced NLP techniques"""
# if not conversation_history:
# return current_query
# recent_context = conversation_history[-context_window:]
# context_concepts = []
# # Extract concepts from recent conversations
# for interaction in recent_context:
# q_concepts = self.extract_key_concepts(interaction["question"])
# a_concepts = self.extract_key_concepts(interaction["answer"])
# context_concepts.extend(q_concepts)
# context_concepts.extend(a_concepts)
# # Extract concepts from current query
# query_concepts = self.extract_key_concepts(current_query)
# # Find related concepts
# related_concepts = []
# for q_concept in query_concepts:
# for c_concept in context_concepts:
# similarity = self.get_semantic_similarity(q_concept, c_concept)
# if similarity > self.similarity_threshold:
# related_concepts.append(c_concept)
# # Build enhanced query
# if related_concepts:
# enhanced_query = (
# f"{current_query} in context of {', '.join(related_concepts)}"
# )
# else:
# enhanced_query = current_query
# return enhanced_query
# # Initialize the context mapper
# context_mapper = ContextMapper()
# async def generate_contextual_query(
# question: str, user_id: str, hospital_id: int, conversation_manager
# ) -> str:
# """Generate enhanced contextual query"""
# context_window = conversation_manager.get_context_window(user_id, hospital_id)
# if not context_window:
# return question
# # Enhanced context mapping
# last_interaction = context_window[-1]
# enhanced_context = f"""
# Previous question: {last_interaction['question']}
# Previous answer: {last_interaction['answer']}
# Current question: {question}
# Please generate a detailed search query that combines the context from the previous answer
# with the current question, especially if the current question uses words like 'it', 'this',
# 'that', or asks for more details about the previous topic.
# """
# try:
# response = await asyncio.to_thread(
# lambda: client.chat.completions.create(
# model="gpt-3.5-turbo",
# messages=[
# {
# "role": "system",
# "content": "You are a context-aware query generator.",
# },
# {"role": "user", "content": enhanced_context},
# ],
# temperature=0.3,
# max_tokens=150,
# )
# )
# contextual_query = response.choices[0].message.content.strip()
# logging.info(f"Enhanced contextual query: {contextual_query}")
# return contextual_query
# except Exception as e:
# logging.error(f"Error generating contextual query: {e}")
# return question
# def is_follow_up(current_question: str, conversation_history: list) -> bool:
# """Enhanced follow-up detection using NLP techniques"""
# if not conversation_history:
# return False
# last_interaction = conversation_history[-1]
# # Get semantic similarity with higher threshold
# similarity = context_mapper.get_semantic_similarity(
# current_question, f"{last_interaction['question']} {last_interaction['answer']}"
# )
# # Enhanced referential check
# doc = nlp(current_question.lower())
# has_referential = any(
# token.lemma_
# in [
# "it",
# "this",
# "that",
# "these",
# "those",
# "they",
# "he",
# "she",
# "about",
# "more",
# ]
# for token in doc
# )
# # Extract concepts with improved entity detection
# current_concepts = set(context_mapper.extract_key_concepts(current_question))
# last_concepts = set(
# context_mapper.extract_key_concepts(
# f"{last_interaction['question']} {last_interaction['answer']}"
# )
# )
# # Calculate enhanced concept overlap
# concept_overlap = (
# len(current_concepts & last_concepts) / len(current_concepts | last_concepts)
# if current_concepts
# else 0
# )
# # More aggressive follow-up detection
# return (
# similarity > 0.3 # Lowered threshold
# or has_referential
# or concept_overlap > 0.2 # Lowered threshold
# or any(
# word in current_question.lower()
# for word in ["more", "about", "elaborate", "explain"]
# )
# )
# async def get_relevant_context(question, hospital_id, doc_id=None):
# try:
# cache_key = f"context:hospital_{hospital_id}"
# if doc_id:
# cache_key += f":doc_{doc_id}"
# cache_key += f":{question.lower().strip()}"
# redis_client = get_redis_client()
# cached_context = redis_client.get(cache_key)
# if cached_context:
# logging.info(f"Cache hit for key: {cache_key}")
# return (
# cached_context.decode("utf-8")
# if isinstance(cached_context, bytes)
# else cached_context
# )
# vector_store = await initialize_or_load_vector_store(hospital_id)
# if not vector_store:
# return ""
# retriever = vector_store.as_retriever(
# search_type="mmr",
# search_kwargs={
# "k": 10,
# "fetch_k": 20,
# "lambda_mult": 0.6,
# # "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
# },
# )
# docs = await asyncio.to_thread(retriever.get_relevant_documents, question)
# if not docs:
# return ""
# sorted_docs = sorted(
# docs,
# key=lambda x: (
# int(x.metadata.get("page_number", 0)),
# int(x.metadata.get("chunk_index", 0)),
# ),
# )
# context_parts = [doc.page_content for doc in sorted_docs]
# context = "\n\n".join(context_parts)
# try:
# redis_client.setex(
# cache_key,
# int(CACHE_TTL["vector_store"].total_seconds()),
# context.encode("utf-8") if isinstance(context, str) else context,
# )
# logging.info(f"Cached context for key: {cache_key}")
# except Exception as cache_error:
# logging.error(f"Failed to cache context: {cache_error}")
# return context
# except Exception as e:
# logging.error(f"Error getting relevant context: {e}")
# return ""
# def format_conversation_context(conv_history):
# """Format conversation history into a string"""
# if not conv_history:
# return "No previous conversation."
# return "\n".join(
# [
# f"Q: {interaction['question']}\nA: {interaction['answer']}"
# for interaction in conv_history
# ]
# )
# def get_icd_context_from_question(question, hospital_id):
# """Extract any valid ICD codes from the question and return context"""
# icd_data = load_icd_entries(hospital_id)
# matches = []
# code_pattern = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", question.upper())
# seen = set()
# for code in code_pattern:
# for entry in icd_data:
# if entry["code"] == code and code not in seen:
# matches.append(f"{entry['code']}: {entry['description']}")
# seen.add(code)
# return "\n".join(matches)
# def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70):
# """Get fuzzy matches for ICD codes from the question"""
# icd_data = load_icd_entries(hospital_id)
# descriptions = [entry["description"] for entry in icd_data]
# matches = process.extract(
# question, descriptions, limit=top_n, score_cutoff=threshold
# )
# matched_context = []
# for desc, score, _ in matches:
# for entry in icd_data:
# if entry["description"] == desc:
# matched_context.append(f"{entry['code']}: {entry['description']}")
# break
# return "\n".join(matched_context)
# async def generate_answer_with_rag(
# question,
# hospital_id,
# client,
# doc_id=None,
# user_id="default",
# conversation_manager=None,
# session_id=None,
# ):
# """Generate an answer using RAG with improved conversation flow"""
# try:
# # Continue with regular RAG processing if not an ICD code or if no ICD match found
# html_instruction = """
# IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
# - Use <p> tags for paragraphs
# - Use <h2>, <h3> tags for headings and subheadings
# - Use <ul>, <li> tags for bullet points
# - Use <ol>, <li> tags for numbered lists
# - Use <blockquote> for quoted text
# - Use <strong> for bold text and <em> for emphasis
# """
# table_instruction = """
# - For tables, use proper HTML table structure:
# <table border="1">
# <thead>
# <tr>
# <th colspan="{total_columns}">{table_title}</th>
# </tr>
# <tr>
# {table_headers}
# </tr>
# </thead>
# <tbody>
# {table_rows}
# </tbody>
# </table>
# """
# # Get conversation history first
# conv_history = (
# conversation_manager.get_context_window(user_id, hospital_id, session_id)
# if conversation_manager
# else []
# )
# # Get contextual query and relevant context first
# contextual_query = await generate_contextual_query(
# question, user_id, hospital_id, conversation_manager
# )
# # Track ICD context across conversation
# icd_context = {}
# if conv_history:
# # Extract ICD code from previous interaction
# last_answer = conv_history[-1].get("answer", "")
# icd_codes = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", last_answer)
# if icd_codes:
# icd_context["last_code"] = icd_codes[0]
# # Check if current question is about a previously discussed ICD code
# is_icd_followup = False
# if icd_context.get("last_code"):
# followup_indicators = [
# "what causes",
# "what is causing",
# "why",
# "how",
# "symptoms",
# "treatment",
# "diagnosis",
# "causes",
# "effects",
# "complications",
# "risk factors",
# "prevention",
# "prognosis",
# "this",
# "disease",
# "that",
# "it",
# ]
# is_icd_followup = any(
# indicator in question.lower() for indicator in followup_indicators
# )
# if is_icd_followup:
# # Add the previous ICD code context to the current question
# icd_exact_context = get_icd_context_from_question(
# icd_context["last_code"], hospital_id
# )
# icd_fuzzy_context = get_fuzzy_icd_context(
# f"{icd_context['last_code']} {question}", hospital_id
# )
# else:
# icd_exact_context = get_icd_context_from_question(question, hospital_id)
# icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)
# else:
# icd_exact_context = get_icd_context_from_question(question, hospital_id)
# icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)
# # Get contextual query and relevant context
# contextual_query = await generate_contextual_query(
# question, user_id, hospital_id, conversation_manager
# )
# doc_context = await get_relevant_context(contextual_query, hospital_id, doc_id)
# # Combine context with priority for ICD information
# context_parts = []
# if is_icd_followup:
# context_parts.append(
# f"## Previous ICD Code Context\nContinuing discussion about: {icd_context['last_code']}"
# )
# if icd_exact_context:
# context_parts.append("## ICD Code Match\n" + icd_exact_context)
# if icd_fuzzy_context:
# context_parts.append("## Related ICD Suggestions\n" + icd_fuzzy_context)
# if doc_context:
# context_parts.append("## Document Context\n" + doc_context)
# context = "\n\n".join(context_parts)
# # Initialize follow-up detection
# is_follow_up = False
# # Check if this is a follow-up question
# if conv_history:
# last_interaction = conv_history[-1]
# last_question = last_interaction["question"].lower()
# last_answer = last_interaction.get("answer", "").lower()
# current_question = question.lower()
# # Define meaningful keywords that indicate entity-related follow-ups
# entity_related_keywords = {
# "achievements",
# "awards",
# "accomplishments",
# "work",
# "contributions",
# "career",
# "company",
# "products",
# "life",
# "background",
# "education",
# "role",
# "experience",
# "history",
# "details",
# "places",
# "place",
# "information",
# "facts",
# "about",
# "birth",
# "death",
# "family",
# "books",
# "projects",
# "population",
# }
# # Check if question is asking about attributes/achievements of previously discussed entity
# has_entity_attribute = any(
# word in current_question.split() for word in entity_related_keywords
# )
# # Extract entities from last answer to maintain context
# def extract_entities(text):
# # Split into words and get potential entities (capitalized words)
# words = text.split()
# entities = set()
# current_entity = []
# for word in words:
# if word[0].isupper():
# current_entity.append(word)
# elif current_entity:
# if len(current_entity) > 0:
# entities.add(" ".join(current_entity))
# current_entity = []
# if current_entity:
# entities.add(" ".join(current_entity))
# return entities
# last_entities = extract_entities(last_answer)
# # Check for referential words
# referential_words = {
# "it",
# "this",
# "that",
# "these",
# "those",
# "they",
# "their",
# "he",
# "she",
# "him",
# "her",
# "his",
# "hers",
# "them",
# "there",
# "such",
# "its",
# }
# has_referential = any(
# word in referential_words for word in current_question.split()
# )
# # Calculate term overlap with both question and answer context
# def get_significant_terms(text):
# stop_words = {
# "what",
# "when",
# "where",
# "who",
# "why",
# "how",
# "is",
# "are",
# "was",
# "were",
# "be",
# "been",
# "the",
# "a",
# "an",
# "in",
# "on",
# "at",
# "to",
# "for",
# "of",
# "with",
# "by",
# "about",
# "as",
# "tell",
# "me",
# "please",
# }
# return set(
# word
# for word in text.split()
# if len(word) > 2 and word.lower() not in stop_words
# )
# current_terms = get_significant_terms(current_question)
# last_terms = get_significant_terms(last_question)
# answer_terms = get_significant_terms(last_answer)
# # Include terms from both question and answer in context
# all_prev_terms = last_terms | answer_terms
# term_overlap = len(current_terms & all_prev_terms)
# total_terms = len(current_terms | all_prev_terms)
# term_similarity = term_overlap / total_terms if total_terms > 0 else 0
# # Enhanced follow-up detection combining multiple signals
# is_follow_up = (
# has_referential
# or term_similarity
# >= 0.2 # Lower threshold when including answer context
# or (
# has_entity_attribute and bool(last_entities)
# ) # Check if asking about attributes of known entity
# or (
# last_interaction.get("type") == "general"
# and term_similarity >= 0.15
# )
# )
# logging.info(f"Follow-up analysis enhanced:")
# logging.info(f"- Referential words: {has_referential}")
# logging.info(f"- Term similarity: {term_similarity:.2f}")
# logging.info(f"- Entity attribute question: {has_entity_attribute}")
# logging.info(f"- Last entities found: {last_entities}")
# logging.info(f"- Is follow-up: {is_follow_up}")
# # For entirely new topics (not follow-ups), use is_general_knowledge_question
# if not is_follow_up:
# is_general = is_general_knowledge_question(question, context, conv_history)
# if is_general:
# confirmation_prompt = f"""
# <p>Reviewed the hospitals documentation, but this particular question does not seem to be covered.</p>
# """
# logging.info("General knowledge question detected")
# return {
# "answer": confirmation_prompt,
# "requires_confirmation": True,
# }, 200
# prompt_template = f"""Based on the following context and conversation history, provide a detailed answer to the question.
# Previous conversation:
# {format_conversation_context(conv_history)}
# Context from documents:
# {context}
# Current question: {question}
# Instructions:
# 1. When providing medical codes (ICD, CPT, etc.):
# - Always use the ICD codes listed in the sections titled "ICD Code Match" and "Related ICD Suggestions" from the context.
# - Do not use or invent ICD codes from your own training knowledge unless they appear in the provided context.
# - If multiple codes are relevant, return the one that best matches the users question. If unsure, return multiple options in HTML list format.
# - Remove all decimal points (e.g., use 'A150' instead of 'A15.0')
# - Format the response as: '<p>The medical code for [condition] is [code]</p>'
# 2. Address the current question while maintaining conversation continuity
# 3. Resolve any ambiguous references using conversation history
# 4. Format the response in clear HTML.
# 5. Strictly it should answer only from the {context} provided, do not invent or assume and give the information and it should not be general knowledge also. Purely 100% RAG-based response and only from documents.
# {html_instruction}
# {table_instruction if is_table_request(question) else ""}
# """
# response = await asyncio.to_thread(
# lambda: client.chat.completions.create(
# model="gpt-3.5-turbo-16k",
# messages=[
# {"role": "system", "content": prompt_template},
# {"role": "user", "content": question},
# ],
# temperature=0.2,
# max_tokens=1000,
# )
# )
# answer = ensure_html_response(response.choices[0].message.content)
# logging.info(f"Generated RAG answer for question: {question}")
# # Store interaction in history
# if conversation_manager:
# await conversation_manager.add_rag_interaction(
# user_id, hospital_id, question, answer, session_id
# )
# return {"answer": answer}, 200
# except Exception as e:
# logging.error(f"Error in generate_answer_with_rag: {e}")
# return {"answer": f"<p>Error: {str(e)}</p>"}, 500
# async def load_existing_vector_stores():
# """Load existing Chroma vector stores for each hospital"""
# pool = await get_db_pool()
# async with pool.acquire() as conn:
# async with conn.cursor() as cursor:
# try:
# await cursor.execute("SELECT DISTINCT id FROM hospitals")
# hospital_ids = [row[0] for row in await cursor.fetchall()]
# for hospital_id in hospital_ids:
# try:
# await initialize_or_load_vector_store(hospital_id)
# except Exception as e:
# logging.error(
# f"Failed to load vector store for hospital {hospital_id}: {e}"
# )
# continue
# except Exception as e:
# logging.error(f"Error loading vector stores: {e}")
# async def get_failed_page(doc_id):
# pool = await get_db_pool()
# async with pool.acquire() as conn:
# async with conn.cursor() as cursor:
# try:
# await cursor.execute(
# "SELECT failed_page FROM documents WHERE id = %s", (doc_id,)
# )
# result = await cursor.fetchone()
# return result[0] if result and result[0] else None
# except Exception as e:
# logging.error(f"Database error checking failed_page: {e}")
# return None
# async def update_document_status(doc_id, status, failed_page=None):
# """Update document status with enum validation"""
# if isinstance(status, str):
# status = DocumentStatus[status.upper()].value
# pool = await get_db_pool()
# async with pool.acquire() as conn:
# async with conn.cursor() as cursor:
# try:
# if failed_page:
# await cursor.execute(
# "UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s",
# (status, failed_page, doc_id),
# )
# else:
# await cursor.execute(
# "UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s",
# (status, doc_id),
# )
# await conn.commit()
# return True
# except Exception as e:
# logging.error(f"Database update error: {e}")
# return False
# thread_pool = ThreadPoolExecutor(max_workers=10)
# def async_to_sync(coroutine):
# """Helper function to run async code in sync context"""
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# return loop.run_until_complete(coroutine)
# finally:
# loop.close()
# @app.route("/flask-api", methods=["GET"])
# def health_check():
# """Health check endpoint"""
# access_logger.info(f"Health check request received from {request.remote_addr}")
# return jsonify({"status": "ok"}), 200
# @app.route("/flask-api/process-pdf", methods=["POST"])
# def process_pdf():
# access_logger.info(f"PDF processing request received from {request.remote_addr}")
# file_path = None
# try:
# file = request.files.get("pdf")
# hospital_id = request.form.get("hospital_id")
# doc_id = request.form.get("doc_id")
# logging.info(
# f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}"
# )
# if not all([file, hospital_id, doc_id]):
# return jsonify({"error": "Missing required parameters"}), 400
# def process_in_background():
# nonlocal file_path
# try:
# async_to_sync(update_document_status(doc_id, "processing"))
# # Add progress logging
# logging.info(f"Starting processing of document {doc_id}")
# filename = f"doc_{doc_id}_{file.filename}"
# file_path = os.path.join(uploads_dir, filename)
# with open(file_path, "wb") as f:
# file.save(f)
# logging.info("Extracting PDF contents...")
# content = extract_pdf_contents(file_path, int(hospital_id))
# logging.info("Inserting content into database...")
# metadata = {"filename": filename}
# result = async_to_sync(
# insert_content_into_db(content, metadata, doc_id)
# )
# if "error" in result:
# async_to_sync(update_document_status(doc_id, "failed", 1))
# return False
# logging.info("Creating embeddings and indexing...")
# success = async_to_sync(add_document_to_index(doc_id, hospital_id))
# if success:
# logging.info("Document processing completed successfully")
# async_to_sync(update_document_status(doc_id, "processed"))
# return True
# else:
# logging.error("Document processing failed during indexing")
# async_to_sync(update_document_status(doc_id, "failed"))
# return False
# except Exception as e:
# logging.error(f"Processing error: {e}")
# async_to_sync(update_document_status(doc_id, "failed"))
# return False
# finally:
# if file_path and os.path.exists(file_path):
# try:
# os.remove(file_path)
# except Exception as e:
# logging.error(f"Error removing temporary file: {e}")
# # Execute processing and wait for result
# future = thread_pool.submit(process_in_background)
# success = future.result()
# if success:
# return jsonify({"message": "Document processed successfully"}), 200
# else:
# return jsonify({"error": "Document processing failed"}), 500
# except Exception as e:
# logging.error(f"API error: {e}")
# if file_path and os.path.exists(file_path):
# try:
# os.remove(file_path)
# except Exception as file_e:
# logging.error(f"Error removing temporary file: {file_e}")
# return jsonify({"error": str(e)}), 500
# # Initialize the hybrid conversation manager
# redis_client = get_redis_client()
# conversation_manager = HybridConversationManager(redis_client)
# @app.route("/flask-api/generate-answer", methods=["POST"])
# def rag_answer_api():
# """Sync API endpoint for RAG-based question answering with conversation history."""
# access_logger.info(f"Generate answer request received from {request.remote_addr}")
# try:
# data = request.json
# question = data.get("question", "").strip().lower()
# hospital_code = data.get("hospital_code")
# doc_id = data.get("doc_id")
# user_id = data.get("user_id", "default")
# session_id = data.get("session_id", None)
# logging.info(f"Received question from user {user_id}: {question}")
# logging.info(f"Received hospital code: {hospital_code}")
# logging.info(f"Received session_id: {session_id}")
# # is_confirmation_response = data.get("is_confirmation_response", False)
# original_query = data.get("original_query", "")
# def process_rag_answer():
# try:
# hospital_id = async_to_sync(get_hospital_id(hospital_code))
# logging.info(f"Resolved hospital ID: {hospital_id}")
# if not hospital_id:
# return {
# "error": "Invalid or missing 'hospital_code' in request"
# }, 400
# # if question == "yes" and original_query:
# # # User confirmed they want a general knowledge answer
# # answer = async_to_sync(
# # generate_general_knowledge_answer(
# # original_query,
# # client,
# # user_id,
# # hospital_id,
# # conversation_manager, # Pass the hybrid manager
# # is_table_request(original_query),
# # session_id=session_id,
# # )
# # )
# # return {"answer": answer}, 200
# if original_query:
# response_message = """
# <p>I can only answer questions based on information found in the hospital documents.</p>
# <p>The question you asked doesn't seem to be covered in the available documents.</p>
# <p>You can try rephrasing your question or asking about a different topic.</p>
# """
# return {"answer": response_message}, 200
# else:
# # Regular RAG answer
# return async_to_sync(
# generate_answer_with_rag(
# question=question,
# hospital_id=hospital_id,
# client=client,
# doc_id=doc_id,
# user_id=user_id,
# conversation_manager=conversation_manager, # Pass the hybrid manager
# session_id=session_id,
# )
# )
# except Exception as e:
# logging.error(f"Thread processing error: {str(e)}")
# return {"error": str(e)}, 500
# if not question:
# return jsonify({"error": "Missing 'question' in request"}), 400
# future = thread_pool.submit(process_rag_answer)
# result, status_code = future.result()
# return jsonify(result), status_code
# except Exception as e:
# logging.error(f"API error: {str(e)}")
# return jsonify({"error": str(e)}), 500
# @app.route("/flask-api/delete-document-vectors", methods=["DELETE"])
# def delete_document_vectors_endpoint():
# """Endpoint to delete document vectors from ChromaDB"""
# try:
# data = request.json
# hospital_id = data.get("hospital_id")
# doc_id = data.get("doc_id")
# if not all([hospital_id, doc_id]):
# return jsonify({"error": "Missing required parameters"}), 400
# logging.info(
# f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}"
# )
# def process_deletion():
# try:
# success = async_to_sync(delete_document_vectors(hospital_id, doc_id))
# if success:
# return {"message": "Document vectors deleted successfully"}, 200
# else:
# return {"error": "Failed to delete document vectors"}, 500
# except Exception as e:
# logging.error(f"Error in vector deletion process: {e}")
# return {"error": str(e)}, 500
# future = thread_pool.submit(process_deletion)
# result, status_code = future.result()
# return jsonify(result), status_code
# except Exception as e:
# logging.error(f"API error: {str(e)}")
# return jsonify({"error": str(e)}), 500
# @app.route("/flask-api/get-chroma-content", methods=["GET"])
# def get_chroma_content_endpoint():
# """API endpoint to get ChromaDB content by hospital_id"""
# try:
# hospital_id = request.args.get("hospital_id")
# limit = int(request.args.get("limit", 30000))
# if not hospital_id:
# return jsonify({"error": "Missing required parameter: hospital_id"}), 400
# def process_fetch():
# try:
# result, status_code = async_to_sync(
# get_chroma_content_by_hospital(
# hospital_id=int(hospital_id), limit=limit
# )
# )
# return result, status_code
# except Exception as e:
# logging.error(f"Error in ChromaDB fetch process: {e}")
# return {"error": str(e)}, 500
# future = thread_pool.submit(process_fetch)
# result, status_code = future.result()
# return jsonify(result), status_code
# except Exception as e:
# logging.error(f"API error: {str(e)}")
# return jsonify({"error": str(e)}), 500
# async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100):
# """Fetch content from ChromaDB for a specific hospital"""
# try:
# # Initialize vector store
# vector_store = await initialize_or_load_vector_store(hospital_id)
# if not vector_store:
# return {"error": "Vector store not found"}, 404
# # Get collection
# collection = vector_store._collection
# # Query the collection with hospital_id filter
# results = await asyncio.to_thread(
# lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit)
# )
# if not results or not results["ids"]:
# return {"data": [], "count": 0}, 200
# # Format the response
# formatted_results = []
# for i in range(len(results["ids"])):
# formatted_results.append(
# {
# "id": results["ids"][i],
# "content": results["documents"][i],
# "metadata": results["metadatas"][i],
# }
# )
# return {"data": formatted_results, "count": len(formatted_results)}, 200
# except Exception as e:
# logging.error(f"Error fetching ChromaDB content: {e}")
# return {"error": str(e)}, 500
# @app.before_request
# def before_request():
# request._start_time = time.time()
# @app.after_request
# def after_request(response):
# if hasattr(request, "_start_time"):
# duration = time.time() - request._start_time
# access_logger.info(
# f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - '
# f"IP: {request.remote_addr}"
# )
# return response
# if __name__ == "__main__":
# logger.info("Starting SpurrinAI application")
# logger.info(f"Python version: {sys.version}")
# logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}")
# try:
# model_manager = ModelManager()
# logger.info("Model manager initialized successfully")
# except Exception as e:
# logger.error(f"Failed to initialize model manager: {e}")
# sys.exit(1)
# # Initialize directories
# os.makedirs(DATA_DIR, exist_ok=True)
# os.makedirs(CHROMA_DIR, exist_ok=True)
# logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}")
# # Clear Redis cache
# redis_client = get_redis_client()
# cleared_keys = 0
# for key in redis_client.scan_iter("vector_store_data:*"):
# redis_client.delete(key)
# cleared_keys += 1
# logger.info(f"Cleared {cleared_keys} Redis cache keys")
# # Load vector stores
# logger.info("Loading existing vector stores...")
# async_to_sync(load_existing_vector_stores())
# logger.info("Vector stores loaded successfully")
# # Start application
# logger.info("Starting Flask application on port 5000")
# app.run(port=5000, debug=False)
"""
SpurrinAI - Intelligent Document Processing and Question Answering System
Copyright (c) 2024 Tech4biz. All rights reserved.
This module implements the main Flask application for the SpurrinAI system,
providing REST APIs for document processing, vector storage, and question answering
using RAG (Retrieval Augmented Generation) architecture.
Author: Tech4biz Development Team
Version: 1.0.0
Last Updated: 2024-01-19
"""
# Standard library imports
import os
import re
import sys
import json
import time
import threading
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
# Third-party imports
import spacy
import redis
import aiomysql
from dotenv import load_dotenv
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from tqdm import tqdm
from tqdm.asyncio import tqdm as tqdm_async
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from openai import OpenAI
from rapidfuzz import process
from threading import Lock
# Local imports
from model_manager import ModelManager
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")
# Initialize NLTK
import nltk
nltk.download("punkt")
# Configure logging
import logging
import logging.handlers
app = Flask(__name__)
CORS(app)
script_dir = os.path.dirname(os.path.abspath(__file__))
log_file_path = os.path.join(script_dir, "error.log")
logging.basicConfig(filename=log_file_path, level=logging.INFO)
# Configure logging
def setup_logging():
log_dir = os.path.join(script_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
main_log = os.path.join(log_dir, "app.log")
error_log = os.path.join(log_dir, "error.log")
access_log = os.path.join(log_dir, "access.log")
perf_log = os.path.join(log_dir, "performance.log")
# Create formatters
detailed_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
)
access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
# Main logger setup
main_handler = logging.handlers.RotatingFileHandler(
main_log, maxBytes=10485760, backupCount=5
)
main_handler.setFormatter(detailed_formatter)
main_handler.setLevel(logging.INFO)
# Error logger setup
error_handler = logging.handlers.RotatingFileHandler(
error_log, maxBytes=10485760, backupCount=5
)
error_handler.setFormatter(detailed_formatter)
error_handler.setLevel(logging.ERROR)
# Access logger setup
access_handler = logging.handlers.TimedRotatingFileHandler(
access_log, when="midnight", interval=1, backupCount=30
)
access_handler.setFormatter(access_formatter)
access_handler.setLevel(logging.INFO)
# Performance logger setup
perf_handler = logging.handlers.RotatingFileHandler(
perf_log, maxBytes=10485760, backupCount=5
)
perf_handler.setFormatter(detailed_formatter)
perf_handler.setLevel(logging.INFO)
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.addHandler(main_handler)
root_logger.addHandler(error_handler)
# Create specific loggers
access_logger = logging.getLogger("access")
access_logger.addHandler(access_handler)
access_logger.setLevel(logging.INFO)
perf_logger = logging.getLogger("performance")
perf_logger.addHandler(perf_handler)
perf_logger.setLevel(logging.INFO)
return root_logger, access_logger, perf_logger
# Initialize loggers
logger, access_logger, perf_logger = setup_logging()
# DB_CONFIG = {
# 'host': 'localhost',
# 'user': 'flaskuser',
# 'password': 'Flask@123',
# 'database': 'spurrinai',
# }
# DB_CONFIG = {
# 'host': 'localhost',
# 'user': 'spurrindevuser',
# 'password': 'Admin@123',
# 'database': 'spurrindev',
# }
# DB_CONFIG = {
# 'host': 'localhost',
# 'user': 'root',
# 'password': 'root',
# 'database': 'medqueryai',
# 'port': 3307
# }
# Redis Configuration
REDIS_CONFIG = {
"host": "localhost",
"port": 6379,
"db": 0,
"decode_responses": True, # For string operations
}
DB_CONFIG = {
"host": os.getenv("DB_HOST", "localhost"),
"user": os.getenv("DB_USER", "testuser"),
"password": os.getenv("DB_PASSWORD", "Admin@123"),
"database": os.getenv("DB_NAME", "spurrintest"),
}
# Redis connection pool
redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
redis_binary_pool = redis.ConnectionPool(
host="localhost", port=6379, db=1, decode_responses=False
)
def get_redis_client(binary=False):
"""Get Redis client from pool"""
logger.debug(f"Getting Redis client with binary={binary}")
try:
pool = redis_binary_pool if binary else redis_pool
client = redis.Redis(connection_pool=pool)
logger.debug("Redis client created successfully")
return client
except Exception as e:
logger.error(f"Failed to create Redis client: {e}", exc_info=True)
raise
def fetch_cached_answer(cache_key):
logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
start_time = time.time()
try:
redis_client = get_redis_client()
cached_answer = redis_client.get(cache_key)
fetch_time = time.time() - start_time
perf_logger.info(
f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
)
return cached_answer
except Exception as e:
logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
return None
# Cache TTL configurations
CACHE_TTL = {
"vector_store": timedelta(hours=24),
"chat_completion": timedelta(hours=1),
"document_metadata": timedelta(days=7),
}
DATA_DIR = os.path.join(script_dir, "hospital_data")
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")
uploads_dir = os.path.join(script_dir, "llm-uploads")
if not os.path.exists(uploads_dir):
os.makedirs(uploads_dir)
nlp = spacy.load("en_core_web_sm")
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)
embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
llm = ChatOpenAI(
model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
)
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
hospital_vector_stores = {}
vector_store_lock = threading.Lock()
@dataclass
class Document:
doc_id: int
page_num: int
content: str
class DocumentStatus(Enum):
PROCESSING = "processing"
PROCESSED = "processed"
FAILED = "failed"
async def get_db_pool():
return await aiomysql.create_pool(
host=DB_CONFIG["host"],
user=DB_CONFIG["user"],
password=DB_CONFIG["password"],
db=DB_CONFIG["database"],
autocommit=True,
)
async def get_hospital_id(hospital_code):
try:
pool = await get_db_pool()
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute(
"SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
(hospital_code,),
)
result = await cursor.fetchone()
return result["id"] if result else None
except Exception as error:
logging.error(f"Database error: {error}")
return None
finally:
pool.close()
await pool.wait_closed()
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 50
BATCH_SIZE = 1000
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
# length_function=len,
# separators=["\n\n", "\n", ". ", " ", ""]
)
# Update the JSON_PATH to be dynamic based on hospital_id
def get_icd_json_path(hospital_id):
hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
os.makedirs(hospital_data_dir, exist_ok=True)
return os.path.join(hospital_data_dir, "icd_data.json")
def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
"""Extract and process ICD codes with optimized processing and optional JSON saving"""
try:
# Initialize pattern compilation once
pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)
# Process in chunks for large content
chunk_size = 50000 # Process 50KB at a time
icd_data = []
current_code = None
current_description = []
# Split content into manageable chunks
content_chunks = [
content[i : i + chunk_size] for i in range(0, len(content), chunk_size)
]
# Process each chunk
for chunk in content_chunks:
lines = chunk.splitlines()
for line in lines:
line = line.strip()
if not line:
if current_code and current_description:
icd_data.append(
{
"code": current_code,
"description": " ".join(current_description).strip(),
}
)
current_code = None
current_description = []
continue
match = pattern.match(line)
if match:
if current_code and current_description:
icd_data.append(
{
"code": current_code,
"description": " ".join(current_description).strip(),
}
)
current_code, description = match.groups()
current_description = [description.strip()]
elif current_code:
current_description.append(line)
# Add final entry if exists
if current_code and current_description:
icd_data.append(
{
"code": current_code,
"description": " ".join(current_description).strip(),
}
)
# Save to hospital-specific JSON if requested
if save_to_json and icd_data:
try:
json_path = get_icd_json_path(hospital_id)
# Use a lock for thread safety
with threading.Lock():
if os.path.exists(json_path):
with open(json_path, "r", encoding="utf-8") as f:
try:
existing_data = json.load(f)
except json.JSONDecodeError:
existing_data = []
else:
existing_data = []
# Efficient deduplication using dictionary
seen_codes = {item["code"]: item for item in existing_data}
for item in icd_data:
seen_codes[item["code"]] = item
unique_data = list(seen_codes.values())
# Write atomically using temporary file
temp_path = f"{json_path}.tmp"
with open(temp_path, "w", encoding="utf-8") as f:
json.dump(unique_data, f, indent=2, ensure_ascii=False)
os.replace(temp_path, json_path)
logging.info(
f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
)
except Exception as e:
logging.error(
f"Error saving ICD data to JSON for hospital {hospital_id}: {e}"
)
return icd_data
except Exception as e:
logging.error(f"Error in extract_and_process_icd_data: {e}")
return []
def load_icd_entries(hospital_id):
"""Load ICD entries from hospital-specific JSON file"""
json_path = get_icd_json_path(hospital_id)
try:
if os.path.exists(json_path):
with open(json_path, "r", encoding="utf-8") as f:
return json.load(f)
return []
except Exception as e:
logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
return []
# Update the process_icd_codes function to include hospital_id
async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
"""Process and store ICD codes using the optimized extraction function"""
try:
# Extract and save codes with hospital_id
extract_and_process_icd_data(content, hospital_id, save_to_json=True)
except Exception as e:
logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")
async def initialize_icd_vector_store(hospital_id):
"""This function is deprecated. ICD codes are now handled through JSON search."""
logging.warning(
"initialize_icd_vector_store is deprecated - using JSON search instead"
)
return None
def extract_pdf_contents(pdf_path, hospital_id):
"""Extract PDF contents with optimized chunking and code extraction"""
try:
loader = PyPDFLoader(pdf_path)
pages = loader.load()
pages_content = []
for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
text = page.page_content.strip()
# Extract ICD codes from the page
icd_codes = extract_and_process_icd_data(
text, hospital_id
) # We'll set doc_id later
pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})
return pages_content
except Exception as e:
logging.error(f"Error in extract_pdf_contents: {e}")
raise
async def insert_content_into_db(content, metadata, doc_id):
pool = await get_db_pool()
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
try:
metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"
metadata_values = [
(doc_id, key[:100], value)
for key, value in metadata.items()
if value
]
content_values = [
(doc_id, page_content["page"], page_content["text"])
for page_content in content
]
if metadata_values:
await cursor.executemany(metadata_query, metadata_values)
if content_values:
await cursor.executemany(content_query, content_values)
await conn.commit()
return {"message": "Success"}
except Exception as e:
await conn.rollback()
return {"error": str(e)}
async def initialize_or_load_vector_store(hospital_id, user_id="default"):
"""Initialize or load vector store with Redis caching and thread safety"""
store_key = f"{hospital_id}:{user_id}"
try:
# Check if we already have it loaded - with lock for thread safety
with vector_store_lock:
if store_key in hospital_vector_stores:
return hospital_vector_stores[store_key]
# Initialize vector store
redis_client = get_redis_client(binary=True)
cache_key = f"vector_store_data:{hospital_id}:{user_id}"
hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")
if os.path.exists(hospital_dir):
logging.info(
f"Loading vector store for hospital {hospital_id} and user {user_id}"
)
vector_store = await asyncio.to_thread(
lambda: Chroma(
collection_name=f"hospital_{hospital_id}",
persist_directory=hospital_dir,
embedding_function=embeddings,
)
)
else:
logging.info(f"Creating vector store for hospital {hospital_id}")
os.makedirs(hospital_dir, exist_ok=True)
vector_store = await asyncio.to_thread(
lambda: Chroma(
collection_name=f"hospital_{hospital_id}",
persist_directory=hospital_dir,
embedding_function=embeddings,
)
)
# Store with lock for thread safety
with vector_store_lock:
hospital_vector_stores[store_key] = vector_store
return vector_store
except Exception as e:
logging.error(f"Error initializing vector store: {e}", exc_info=True)
raise
async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
"""Delete all vectors associated with a specific document from ChromaDB"""
try:
# Initialize vector store for the hospital
vector_store = await initialize_or_load_vector_store(hospital_id)
# Delete vectors with matching doc_id
await asyncio.to_thread(
lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
)
# Persist changes
await asyncio.to_thread(vector_store.persist)
# Clear Redis cache for this document
redis_client = get_redis_client()
pattern = f"vector_store_data:{hospital_id}:*"
for key in redis_client.scan_iter(pattern):
redis_client.delete(key)
logging.info(
f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
)
return True
except Exception as e:
logging.error(f"Error deleting document vectors: {e}", exc_info=True)
return False
async def add_document_to_index(doc_id, hospital_id):
try:
pool = await get_db_pool()
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
vector_store = await initialize_or_load_vector_store(hospital_id)
await cursor.execute(
"SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
(doc_id,),
)
rows = await cursor.fetchall()
total_pages = len(rows)
logging.info(f"Processing {total_pages} pages for document {doc_id}")
page_bar = tqdm_async(total=total_pages, desc="Processing pages")
async def process_page(page_data):
page_num, content = page_data
try:
icd_data = extract_and_process_icd_data(
content, hospital_id, save_to_json=False
)
chunks = text_splitter.split_text(content)
await asyncio.sleep(0) # Yield control
return page_num, chunks, icd_data
except Exception as e:
logging.error(f"Error processing page {page_num}: {e}")
return page_num, [], []
tasks = [asyncio.create_task(process_page(row)) for row in rows]
results = []
for coro in asyncio.as_completed(tasks):
result = await coro
results.append(result)
page_bar.update(1)
page_bar.close()
# Vector addition progress bar
all_icd_data = []
all_chunks = []
all_metadatas = []
chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)
for result in results:
page_num, chunks, icd_data = result
all_icd_data.extend(icd_data)
for i, chunk in enumerate(chunks):
all_chunks.append(chunk)
all_metadatas.append(
{
"doc_id": str(doc_id),
"hospital_id": str(hospital_id),
"page_number": str(page_num),
"chunk_index": str(i),
}
)
if len(all_chunks) >= BATCH_SIZE:
chunk_add_bar.total += len(all_chunks)
chunk_add_bar.refresh()
await asyncio.to_thread(
vector_store.add_texts,
texts=all_chunks,
metadatas=all_metadatas,
)
all_chunks = []
all_metadatas = []
chunk_add_bar.update(BATCH_SIZE)
# Final batch
if all_chunks:
chunk_add_bar.total += len(all_chunks)
chunk_add_bar.refresh()
await asyncio.to_thread(
vector_store.add_texts,
texts=all_chunks,
metadatas=all_metadatas,
)
chunk_add_bar.update(len(all_chunks))
chunk_add_bar.close()
if all_icd_data:
logging.info(f"Saving {len(all_icd_data)} ICD codes")
extract_and_process_icd_data("", hospital_id, save_to_json=True)
await asyncio.to_thread(vector_store.persist)
logging.info(f"Successfully indexed document {doc_id}")
return True
except Exception as e:
logging.error(f"Error adding document: {e}")
return False
def is_general_knowledge_question(
query: str, context: str, conversation_context=None
) -> bool:
"""
Determine if a question is likely a general knowledge question not covered in the documents.
Takes conversation history into account to reduce repeated confirmations.
"""
query_lower = query.lower()
context_lower = context.lower()
if conversation_context:
for interaction in conversation_context:
prev_question = interaction.get("question", "").lower()
if (
prev_question
and query_lower in prev_question
or prev_question in query_lower
):
logging.info(
f"Question is similar to previous conversation, skipping confirmation"
)
return False
stop_words = {
"search",
"query:",
"can",
"you",
"some",
"at",
"the",
"a",
"an",
"in",
"on",
"at",
"to",
"for",
"with",
"by",
"about",
"give",
"full",
"is",
"are",
"was",
"were",
"define",
"what",
"how",
"why",
"when",
"where",
"year",
"list",
"form",
"table",
"who",
"which",
"me",
"tell",
"explain",
"describe",
"of",
"and",
"or",
"there",
"their",
"please",
"could",
"would",
"various",
"different",
"type",
"types",
"kind",
"kinds",
"has",
"have",
"had",
"many",
"say",
"know",
}
key_words = [
word for word in query_lower.split() if word not in stop_words and len(word) > 2
]
logging.info(f"Key words: {key_words}")
if not key_words:
logging.info("No significant keywords found, directing to general knowledge")
return True
matches = sum(1 for word in key_words if word in context_lower)
logging.info(f"Matches: {matches} out of {len(key_words)} keywords")
match_ratio = matches / len(key_words)
logging.info(f"Match ratio: {match_ratio}")
return match_ratio < 0.4
def is_table_request(query: str) -> bool:
"""
Determine if the user is requesting a response in tabular format.
"""
table_keywords = [
"table",
"tabular",
"in a table",
"in table format",
"in tabular format",
"chart",
"data",
"comparison",
"as a table",
"table format",
"in rows and columns",
"in a grid",
"breakdown",
"spreadsheet",
"comparison table",
"data table",
"structured table",
"tabular form",
"table form",
]
query_lower = query.lower()
return any(keyword in query_lower for keyword in table_keywords)
import re
def ensure_html_response(text: str) -> str:
"""
Ensure the response is properly formatted in HTML.
This function handles plain text conversion to HTML.
"""
if "<html" in text.lower() or "<body" in text.lower():
return text
has_html_tags = bool(re.search(r"<[a-z]+.*?>", text))
if not has_html_tags:
paragraphs = text.split("\n\n")
html_parts = []
in_ordered_list = False
in_unordered_list = False
for para in paragraphs:
if para.strip():
if re.match(r"^\s*[\*\-\•]\s", para):
if not in_unordered_list:
html_parts.append("<ul>")
in_unordered_list = True
lines = para.split("\n")
for line in lines:
if line.strip():
item = re.sub(r"^\s*[\*\-\•]\s*", "", line)
html_parts.append(f"<li>{item}</li>")
elif re.match(r"^\s*\d+\.\s", para):
if not in_ordered_list:
html_parts.append("<ol>")
in_ordered_list = True
lines = para.split("\n")
for line in lines:
match = re.match(r"^\s*\d+\.\s*(.*)", line)
if match:
html_parts.append(f"<li>{match.group(1)}</li>")
else: # Close any open lists before adding a new paragraph
if in_ordered_list:
html_parts.append("</ol>")
in_ordered_list = False
if in_unordered_list:
html_parts.append("</ul>")
in_unordered_list = False
html_parts.append(f"<p>{para}</p>")
if in_ordered_list:
html_parts.append("</ol>")
if in_unordered_list:
html_parts.append("</ul>")
return "".join(html_parts)
else:
if not any(tag in text for tag in ("<p>", "<div>", "<ul>", "<ol>")):
paragraphs = text.split("\n\n")
html_parts = [f"<p>{para}</p>" for para in paragraphs if para.strip()]
return "".join(html_parts)
return text
class HybridConversationManager:
"""
Hybrid conversation manager that uses Redis for RAG-based conversations
and in-memory storage for general knowledge conversations.
"""
def __init__(self, redis_client, ttl=3600, max_history_items=5):
self.redis_client = redis_client
self.ttl = ttl
self.max_history_items = max_history_items
# For general knowledge questions (in-memory)
self.general_knowledge_histories = {}
self.lock = Lock()
def _get_redis_key(self, user_id, hospital_id, session_id=None):
"""Create Redis key for document-based conversations."""
if session_id:
return f"conv_history:{user_id}:{hospital_id}:{session_id}"
return f"conv_history:{user_id}:{hospital_id}"
def _get_memory_key(self, user_id, hospital_id, session_id=None):
"""Create memory key for general knowledge conversations."""
if session_id:
return f"{user_id}:{hospital_id}:{session_id}"
return f"{user_id}:{hospital_id}"
async def add_rag_interaction(
self, user_id, hospital_id, question, answer, session_id=None
):
"""Add document-based (RAG) interaction to Redis."""
key = self._get_redis_key(user_id, hospital_id, session_id)
history = self.get_rag_history(user_id, hospital_id, session_id)
# Add new interaction
history.append(
{
"question": question,
"answer": answer,
"timestamp": time.time(),
"type": "rag", # Mark as RAG-based interaction
}
)
# Keep only last N interactions
history = history[-self.max_history_items :]
# Store updated history
try:
self.redis_client.setex(key, self.ttl, json.dumps(history))
logging.info(
f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}"
)
except Exception as e:
logging.error(f"Failed to store RAG interaction in Redis: {e}")
def add_general_knowledge_interaction(
self, user_id, hospital_id, question, answer, session_id=None
):
"""Add general knowledge interaction to in-memory store."""
key = self._get_memory_key(user_id, hospital_id, session_id)
with self.lock:
if key not in self.general_knowledge_histories:
self.general_knowledge_histories[key] = []
self.general_knowledge_histories[key].append(
{
"question": question,
"answer": answer,
"timestamp": time.time(),
"type": "general", # Mark as general knowledge interaction
}
)
# Keep only the most recent interactions
if len(self.general_knowledge_histories[key]) > self.max_history_items:
self.general_knowledge_histories[key] = (
self.general_knowledge_histories[key][-self.max_history_items :]
)
logging.info(
f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}"
)
def get_rag_history(self, user_id, hospital_id, session_id=None):
"""Get document-based (RAG) conversation history from Redis."""
key = self._get_redis_key(user_id, hospital_id, session_id)
try:
history_data = self.redis_client.get(key)
return json.loads(history_data) if history_data else []
except Exception as e:
logging.error(f"Failed to retrieve RAG history from Redis: {e}")
return []
def get_general_knowledge_history(self, user_id, hospital_id, session_id=None):
"""Get general knowledge conversation history from memory."""
key = self._get_memory_key(user_id, hospital_id, session_id)
with self.lock:
return self.general_knowledge_histories.get(key, []).copy()
def get_combined_history(self, user_id, hospital_id, session_id=None):
"""Get combined conversation history from both sources, sorted by timestamp."""
rag_history = self.get_rag_history(user_id, hospital_id, session_id)
general_history = self.get_general_knowledge_history(
user_id, hospital_id, session_id
)
# Combine histories
combined_history = rag_history + general_history
# Sort by timestamp (newest first)
combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
# Return most recent N items
return combined_history[: self.max_history_items]
def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2):
"""Get the most recent interactions for context from combined history."""
combined_history = self.get_combined_history(user_id, hospital_id, session_id)
# Sort by timestamp (oldest first) for context window
sorted_history = sorted(combined_history, key=lambda x: x.get("timestamp", 0))
return sorted_history[-window_size:] if sorted_history else []
def clear_history(self, user_id, hospital_id):
"""Clear conversation history from both stores."""
# Clear Redis history
redis_key = self._get_redis_key(user_id, hospital_id)
try:
self.redis_client.delete(redis_key)
except Exception as e:
logging.error(f"Failed to clear Redis history: {e}")
# Clear memory history
memory_key = self._get_memory_key(user_id, hospital_id)
with self.lock:
if memory_key in self.general_knowledge_histories:
del self.general_knowledge_histories[memory_key]
class ContextMapper:
"""Enhanced context mapping using shared model manager"""
def __init__(self):
self.model_manager = ModelManager()
self.context_cache = {}
self.similarity_threshold = 0.6
def get_semantic_similarity(self, text1, text2):
"""Get semantic similarity using global model manager"""
return self.model_manager.get_semantic_similarity(text1, text2)
def extract_key_concepts(self, text):
"""Extract key concepts using NLP techniques"""
doc = nlp(text)
concepts = []
entities = [(ent.text, ent.label_) for ent in doc.ents]
noun_phrases = [chunk.text for chunk in doc.noun_chunks]
important_words = [
token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"]
]
concepts.extend([e[0] for e in entities])
concepts.extend(noun_phrases)
concepts.extend(important_words)
return list(set(concepts))
def map_conversation_context(
self, current_query, conversation_history, context_window=3
):
"""Map conversation context using enhanced NLP techniques"""
if not conversation_history:
return current_query
recent_context = conversation_history[-context_window:]
context_concepts = []
# Extract concepts from recent conversations
for interaction in recent_context:
q_concepts = self.extract_key_concepts(interaction["question"])
a_concepts = self.extract_key_concepts(interaction["answer"])
context_concepts.extend(q_concepts)
context_concepts.extend(a_concepts)
# Extract concepts from current query
query_concepts = self.extract_key_concepts(current_query)
# Find related concepts
related_concepts = []
for q_concept in query_concepts:
for c_concept in context_concepts:
similarity = self.get_semantic_similarity(q_concept, c_concept)
if similarity > self.similarity_threshold:
related_concepts.append(c_concept)
# Build enhanced query
if related_concepts:
enhanced_query = (
f"{current_query} in context of {', '.join(related_concepts)}"
)
else:
enhanced_query = current_query
return enhanced_query
# Initialize the context mapper
context_mapper = ContextMapper()
async def generate_contextual_query(
question: str, user_id: str, hospital_id: int, conversation_manager
) -> str:
"""Generate enhanced contextual query"""
context_window = conversation_manager.get_context_window(user_id, hospital_id)
if not context_window:
return question
# Enhanced context mapping
last_interaction = context_window[-1]
enhanced_context = f"""
Previous question: {last_interaction['question']}
Previous answer: {last_interaction['answer']}
Current question: {question}
Please generate a detailed search query that combines the context from the previous answer
with the current question, especially if the current question uses words like 'it', 'this',
'that', or asks for more details about the previous topic.
"""
try:
response = await asyncio.to_thread(
lambda: client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": "You are a context-aware query generator.",
},
{"role": "user", "content": enhanced_context},
],
temperature=0.3,
max_tokens=150,
)
)
contextual_query = response.choices[0].message.content.strip()
logging.info(f"Enhanced contextual query: {contextual_query}")
return contextual_query
except Exception as e:
logging.error(f"Error generating contextual query: {e}")
return question
def is_follow_up(current_question: str, conversation_history: list) -> bool:
"""Enhanced follow-up detection using NLP techniques"""
if not conversation_history:
return False
last_interaction = conversation_history[-1]
# Get semantic similarity with higher threshold
similarity = context_mapper.get_semantic_similarity(
current_question, f"{last_interaction['question']} {last_interaction['answer']}"
)
# Enhanced referential check
doc = nlp(current_question.lower())
has_referential = any(
token.lemma_
in [
"it",
"this",
"that",
"these",
"those",
"they",
"he",
"she",
"about",
"more",
]
for token in doc
)
# Extract concepts with improved entity detection
current_concepts = set(context_mapper.extract_key_concepts(current_question))
last_concepts = set(
context_mapper.extract_key_concepts(
f"{last_interaction['question']} {last_interaction['answer']}"
)
)
# Calculate enhanced concept overlap
concept_overlap = (
len(current_concepts & last_concepts) / len(current_concepts | last_concepts)
if current_concepts
else 0
)
# More aggressive follow-up detection
return (
similarity > 0.3 # Lowered threshold
or has_referential
or concept_overlap > 0.2 # Lowered threshold
or any(
word in current_question.lower()
for word in ["more", "about", "elaborate", "explain"]
)
)
async def get_relevant_context(question, hospital_id, doc_id=None):
try:
cache_key = f"context:hospital_{hospital_id}"
if doc_id:
cache_key += f":doc_{doc_id}"
cache_key += f":{question.lower().strip()}"
redis_client = get_redis_client()
cached_context = redis_client.get(cache_key)
if cached_context:
logging.info(f"Cache hit for key: {cache_key}")
return (
cached_context.decode("utf-8")
if isinstance(cached_context, bytes)
else cached_context
)
vector_store = await initialize_or_load_vector_store(hospital_id)
if not vector_store:
return ""
retriever = vector_store.as_retriever(
search_type="mmr",
search_kwargs={
"k": 10,
"fetch_k": 20,
"lambda_mult": 0.6,
# "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
},
)
docs = await asyncio.to_thread(retriever.get_relevant_documents, question)
if not docs:
return ""
sorted_docs = sorted(
docs,
key=lambda x: (
int(x.metadata.get("page_number", 0)),
int(x.metadata.get("chunk_index", 0)),
),
)
context_parts = [doc.page_content for doc in sorted_docs]
context = "\n\n".join(context_parts)
try:
redis_client.setex(
cache_key,
int(CACHE_TTL["vector_store"].total_seconds()),
context.encode("utf-8") if isinstance(context, str) else context,
)
logging.info(f"Cached context for key: {cache_key}")
except Exception as cache_error:
logging.error(f"Failed to cache context: {cache_error}")
return context
except Exception as e:
logging.error(f"Error getting relevant context: {e}")
return ""
def format_conversation_context(conv_history):
"""Format conversation history into a string"""
if not conv_history:
return "No previous conversation."
return "\n".join(
[
f"Q: {interaction['question']}\nA: {interaction['answer']}"
for interaction in conv_history
]
)
def get_icd_context_from_question(question, hospital_id):
"""Extract any valid ICD codes from the question and return context"""
icd_data = load_icd_entries(hospital_id)
matches = []
code_pattern = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", question.upper())
seen = set()
for code in code_pattern:
for entry in icd_data:
if entry["code"] == code and code not in seen:
matches.append(f"{entry['code']}: {entry['description']}")
seen.add(code)
return "\n".join(matches)
def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70):
"""Get fuzzy matches for ICD codes from the question"""
icd_data = load_icd_entries(hospital_id)
descriptions = [entry["description"] for entry in icd_data]
matches = process.extract(
question, descriptions, limit=top_n, score_cutoff=threshold
)
matched_context = []
for desc, score, _ in matches:
for entry in icd_data:
if entry["description"] == desc:
matched_context.append(f"{entry['code']}: {entry['description']}")
break
return "\n".join(matched_context)
async def generate_answer_with_rag(question, hospital_id, client, doc_id=None):
"""Generate answer using strict RAG approach - only using document content"""
try:
html_instruction = """
IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
- Use <p> tags for paragraphs
- Use <h2>, <h3> tags for headings and subheadings
- Use <ul>, <li> tags for bullet points
- Use <ol>, <li> tags for numbered lists
- Use <blockquote> for quoted text
- Use <strong> for bold text and <em> for emphasis
"""
table_instruction = """
- For tables, use proper HTML table structure:
<table border="1">
<thead>
<tr>
<th colspan="{total_columns}">{table_title}</th>
</tr>
<tr>
{table_headers}
</tr>
</thead>
<tbody>
{table_rows}
</tbody>
</table>
"""
# First, check for ICD codes in the question
icd_exact_context = get_icd_context_from_question(question, hospital_id)
icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)
# Get document context
vector_store = initialize_or_load_vector_store(hospital_id)
if not vector_store:
return {"answer": "<p>No document content available.</p>"}, 404
# Create template focusing only on document content
prompt_template = PromptTemplate(
template="""Based ONLY on the provided document context and ICD codes, generate an answer to the question.
If the information is not found in the context, explicitly state that.
Do not use any external knowledge or assumptions.
{html_instruction}
{table_instruction}
ICD Code Matches:
{icd_exact_context}
Related ICD Codes:
{icd_fuzzy_context}
Context from documents: {context}
Question: {question}
Instructions:
1. Only use information explicitly present in the context and ICD codes
2. Do not make assumptions or use external knowledge
3. If an ICD code is found, include it in your response
4. If the answer is not in the context, say "This information is not found in the available documents"
5. Format response in clear HTML
Answer:""",
input_variables=["context", "question"],
partial_variables={
"html_instruction": html_instruction,
"table_instruction": table_instruction if is_table_request(question) else "",
"icd_exact_context": icd_exact_context if icd_exact_context else "No exact ICD code matches found.",
"icd_fuzzy_context": icd_fuzzy_context if icd_fuzzy_context else "No related ICD codes found."
}
)
retriever = vector_store.as_retriever(
search_type="similarity",
search_kwargs={
"k": 6,
"filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
}
)
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
chain_type_kwargs={"prompt": prompt_template},
return_source_documents=True
)
result = qa_chain({"query": question})
formatted_answer = ensure_html_response(result["result"])
if "not found in" in formatted_answer.lower():
return {"answer": formatted_answer}, 404
return {"answer": formatted_answer}, 200
except Exception as e:
return {"answer": f"<p>Error: {str(e)}</p>"}, 500
async def update_document_status(doc_id, status, failed_page=None):
"""Update document status with enum validation"""
if isinstance(status, str):
status = DocumentStatus[status.upper()].value
pool = await get_db_pool()
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
try:
if failed_page:
await cursor.execute(
"UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s",
(status, failed_page, doc_id),
)
else:
await cursor.execute(
"UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s",
(status, doc_id),
)
await conn.commit()
return True
except Exception as e:
logging.error(f"Database update error: {e}")
return False
async def load_existing_vector_stores():
"""Load existing Chroma vector stores for each hospital"""
pool = await get_db_pool()
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
try:
await cursor.execute("SELECT DISTINCT id FROM hospitals")
hospital_ids = [row[0] for row in await cursor.fetchall()]
for hospital_id in hospital_ids:
try:
await initialize_or_load_vector_store(hospital_id)
except Exception as e:
logging.error(
f"Failed to load vector store for hospital {hospital_id}: {e}"
)
continue
except Exception as e:
logging.error(f"Error loading vector stores: {e}")
@app.route('/flask-api/generate-answer', methods=['POST'])
def generate_answer():
"""Generate answer using RAG approach."""
try:
data = request.json
question = data.get('question')
hospital_code = data.get('hospital_code')
hospital_id = async_to_sync(get_hospital_id(hospital_code))
doc_id = data.get('doc_id')
html_instruction = """
IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
- Use <p> tags for paragraphs
- Use <h2>, <h3> tags for headings and subheadings
- Use <ul>, <li> tags for bullet points
- Use <ol>, <li> tags for numbered lists
- Use <blockquote> for quoted text
- Use <strong> for bold text and <em> for emphasis
"""
table_instruction = """
- For tables, use proper HTML table structure:
<table border="1">
<thead>
<tr>
<th colspan="{total_columns}">{table_title}</th>
</tr>
<tr>
{table_headers}
</tr>
</thead>
<tbody>
{table_rows}
</tbody>
</table>
"""
# Validate required parameters
if not question:
return jsonify({"error": "Missing 'question' in request"}), 400
if not hospital_id:
return jsonify({"error": "Missing 'hospital_code' in request"}), 400
# Get ICD context
icd_exact_context = get_icd_context_from_question(question, hospital_id)
icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id)
vector_store = async_to_sync(initialize_or_load_vector_store(hospital_id))
if not vector_store:
return jsonify({"answer": "<p>No document content available.</p>"}), 404
# Create template focusing only on document content
prompt_template = PromptTemplate(
template="""Based ONLY on the provided document context and ICD codes, generate an answer.
Do not use any external knowledge or assumptions.
{html_instruction}
{table_instruction}
ICD Code Matches:
{icd_exact_context}
Related ICD Codes:
{icd_fuzzy_context}
Context: {context}
Question: {question}
Instructions:
1. Only use information explicitly present in the context and ICD codes
2. Do not make assumptions or use external knowledge
3. If an ICD code is found, include it in your response
4. If the answer is not in the context, say "This information is not found in the available documents"
5. Format response in clear HTML
Answer:""",
input_variables=["context", "question"],
partial_variables={
"html_instruction": html_instruction,
"table_instruction": table_instruction if is_table_request(question) else "",
"icd_exact_context": icd_exact_context if icd_exact_context else "No exact ICD code matches found.",
"icd_fuzzy_context": icd_fuzzy_context if icd_fuzzy_context else "No related ICD codes found."
}
)
retriever = vector_store.as_retriever(
search_type="similarity",
search_kwargs={
"k": 6,
"filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
}
)
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
chain_type_kwargs={"prompt": prompt_template},
return_source_documents=True
)
result = qa_chain({"query": question})
answer = ensure_html_response(result["result"])
if "not found in" in answer.lower():
return jsonify({
"answer": "<p>This information is not found in the available documents.</p>"
}), 404
return jsonify({"answer": answer}), 200
except Exception as e:
logging.error(f"Error generating answer: {e}")
return jsonify({"error": str(e)}), 500
thread_pool = ThreadPoolExecutor(max_workers=10)
def async_to_sync(coroutine):
"""Helper function to run async code in sync context"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coroutine)
finally:
loop.close()
@app.route("/flask-api", methods=["GET"])
def health_check():
"""Health check endpoint"""
access_logger.info(f"Health check request received from {request.remote_addr}")
return jsonify({"status": "ok"}), 200
@app.route("/flask-api/process-pdf", methods=["POST"])
def process_pdf():
access_logger.info(f"PDF processing request received from {request.remote_addr}")
file_path = None
try:
file = request.files.get("pdf")
hospital_id = request.form.get("hospital_id")
doc_id = request.form.get("doc_id")
logging.info(
f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}"
)
if not all([file, hospital_id, doc_id]):
return jsonify({"error": "Missing required parameters"}), 400
def process_in_background():
nonlocal file_path
try:
async_to_sync(update_document_status(doc_id, "processing"))
# Add progress logging
logging.info(f"Starting processing of document {doc_id}")
filename = f"doc_{doc_id}_{file.filename}"
file_path = os.path.join(uploads_dir, filename)
with open(file_path, "wb") as f:
file.save(f)
logging.info("Extracting PDF contents...")
content = extract_pdf_contents(file_path, int(hospital_id))
logging.info("Inserting content into database...")
metadata = {"filename": filename}
result = async_to_sync(
insert_content_into_db(content, metadata, doc_id)
)
if "error" in result:
async_to_sync(update_document_status(doc_id, "failed", 1))
return False
logging.info("Creating embeddings and indexing...")
success = async_to_sync(add_document_to_index(doc_id, hospital_id))
if success:
logging.info("Document processing completed successfully")
async_to_sync(update_document_status(doc_id, "processed"))
return True
else:
logging.error("Document processing failed during indexing")
async_to_sync(update_document_status(doc_id, "failed"))
return False
except Exception as e:
logging.error(f"Processing error: {e}")
async_to_sync(update_document_status(doc_id, "failed"))
return False
finally:
if file_path and os.path.exists(file_path):
try:
os.remove(file_path)
except Exception as e:
logging.error(f"Error removing temporary file: {e}")
# Execute processing and wait for result
future = thread_pool.submit(process_in_background)
success = future.result()
if success:
return jsonify({"message": "Document processed successfully"}), 200
else:
return jsonify({"error": "Document processing failed"}), 500
except Exception as e:
logging.error(f"API error: {e}")
if file_path and os.path.exists(file_path):
try:
os.remove(file_path)
except Exception as file_e:
logging.error(f"Error removing temporary file: {file_e}")
return jsonify({"error": str(e)}), 500
# Initialize the hybrid conversation manager
redis_client = get_redis_client()
conversation_manager = HybridConversationManager(redis_client)
@app.route("/flask-api/delete-document-vectors", methods=["DELETE"])
def delete_document_vectors_endpoint():
"""Endpoint to delete document vectors from ChromaDB"""
try:
data = request.json
hospital_id = data.get("hospital_id")
doc_id = data.get("doc_id")
if not all([hospital_id, doc_id]):
return jsonify({"error": "Missing required parameters"}), 400
logging.info(
f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}"
)
def process_deletion():
try:
success = async_to_sync(delete_document_vectors(hospital_id, doc_id))
if success:
return {"message": "Document vectors deleted successfully"}, 200
else:
return {"error": "Failed to delete document vectors"}, 500
except Exception as e:
logging.error(f"Error in vector deletion process: {e}")
return {"error": str(e)}, 500
future = thread_pool.submit(process_deletion)
result, status_code = future.result()
return jsonify(result), status_code
except Exception as e:
logging.error(f"API error: {str(e)}")
return jsonify({"error": str(e)}), 500
@app.route("/flask-api/get-chroma-content", methods=["GET"])
def get_chroma_content_endpoint():
"""API endpoint to get ChromaDB content by hospital_id"""
try:
hospital_id = request.args.get("hospital_id")
limit = int(request.args.get("limit", 30000))
if not hospital_id:
return jsonify({"error": "Missing required parameter: hospital_id"}), 400
def process_fetch():
try:
result, status_code = async_to_sync(
get_chroma_content_by_hospital(
hospital_id=int(hospital_id), limit=limit
)
)
return result, status_code
except Exception as e:
logging.error(f"Error in ChromaDB fetch process: {e}")
return {"error": str(e)}, 500
future = thread_pool.submit(process_fetch)
result, status_code = future.result()
return jsonify(result), status_code
except Exception as e:
logging.error(f"API error: {str(e)}")
return jsonify({"error": str(e)}), 500
async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100):
"""Fetch content from ChromaDB for a specific hospital"""
try:
# Initialize vector store
vector_store = await initialize_or_load_vector_store(hospital_id)
if not vector_store:
return {"error": "Vector store not found"}, 404
# Get collection
collection = vector_store._collection
# Query the collection with hospital_id filter
results = await asyncio.to_thread(
lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit)
)
if not results or not results["ids"]:
return {"data": [], "count": 0}, 200
# Format the response
formatted_results = []
for i in range(len(results["ids"])):
formatted_results.append(
{
"id": results["ids"][i],
"content": results["documents"][i],
"metadata": results["metadatas"][i],
}
)
return {"data": formatted_results, "count": len(formatted_results)}, 200
except Exception as e:
logging.error(f"Error fetching ChromaDB content: {e}")
return {"error": str(e)}, 500
@app.before_request
def before_request():
request._start_time = time.time()
@app.after_request
def after_request(response):
if hasattr(request, "_start_time"):
duration = time.time() - request._start_time
access_logger.info(
f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - '
f"IP: {request.remote_addr}"
)
return response
if __name__ == "__main__":
logger.info("Starting SpurrinAI application")
logger.info(f"Python version: {sys.version}")
logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}")
try:
model_manager = ModelManager()
logger.info("Model manager initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize model manager: {e}")
sys.exit(1)
# Initialize directories
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)
logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}")
# Clear Redis cache
redis_client = get_redis_client()
cleared_keys = 0
for key in redis_client.scan_iter("vector_store_data:*"):
redis_client.delete(key)
cleared_keys += 1
logger.info(f"Cleared {cleared_keys} Redis cache keys")
# Load vector stores
logger.info("Loading existing vector stores...")
async_to_sync(load_existing_vector_stores())
logger.info("Vector stores loaded successfully")
# Start application
logger.info("Starting Flask application on port 5000")
app.run(port=5000, debug=False)