""" SpurrinAI - Intelligent Document Processing and Question Answering System Copyright (c) 2024 Tech4biz. All rights reserved. This module implements the main Flask application for the SpurrinAI system, providing REST APIs for document processing, vector storage, and question answering using RAG (Retrieval Augmented Generation) architecture. Author: Tech4biz Development Team Version: 1.0.0 Last Updated: 2024-01-19 """ # Standard library imports import os import re import sys import json import time import threading import asyncio from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from datetime import timedelta from enum import Enum # Third-party imports import spacy import redis import aiomysql from dotenv import load_dotenv from flask import Flask, request, jsonify, Response from flask_cors import CORS from tqdm import tqdm from tqdm.asyncio import tqdm as tqdm_async from langchain_community.document_loaders import PyPDFLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.embeddings import OpenAIEmbeddings from langchain_community.vectorstores import Chroma from langchain_community.chat_models import ChatOpenAI from langchain.chains import RetrievalQA from langchain.prompts import PromptTemplate from openai import OpenAI from rapidfuzz import process from threading import Lock # Local imports from model_manager import ModelManager # Suppress warnings import warnings warnings.filterwarnings("ignore") # Initialize NLTK import nltk nltk.download("punkt") # Configure logging import logging import logging.handlers app = Flask(__name__) CORS(app) script_dir = os.path.dirname(os.path.abspath(__file__)) log_file_path = os.path.join(script_dir, "error.log") logging.basicConfig(filename=log_file_path, level=logging.INFO) # Configure logging def setup_logging(): log_dir = os.path.join(script_dir, "logs") os.makedirs(log_dir, exist_ok=True) main_log = os.path.join(log_dir, "app.log") error_log = os.path.join(log_dir, "error.log") access_log = os.path.join(log_dir, "access.log") perf_log = os.path.join(log_dir, "performance.log") # Create formatters detailed_formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s" ) access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") # Main logger setup main_handler = logging.handlers.RotatingFileHandler( main_log, maxBytes=10485760, backupCount=5 ) main_handler.setFormatter(detailed_formatter) main_handler.setLevel(logging.INFO) # Error logger setup error_handler = logging.handlers.RotatingFileHandler( error_log, maxBytes=10485760, backupCount=5 ) error_handler.setFormatter(detailed_formatter) error_handler.setLevel(logging.ERROR) # Access logger setup access_handler = logging.handlers.TimedRotatingFileHandler( access_log, when="midnight", interval=1, backupCount=30 ) access_handler.setFormatter(access_formatter) access_handler.setLevel(logging.INFO) # Performance logger setup perf_handler = logging.handlers.RotatingFileHandler( perf_log, maxBytes=10485760, backupCount=5 ) perf_handler.setFormatter(detailed_formatter) perf_handler.setLevel(logging.INFO) # Configure root logger root_logger = logging.getLogger() root_logger.setLevel(logging.INFO) root_logger.addHandler(main_handler) root_logger.addHandler(error_handler) # Create specific loggers access_logger = logging.getLogger("access") access_logger.addHandler(access_handler) access_logger.setLevel(logging.INFO) perf_logger = logging.getLogger("performance") perf_logger.addHandler(perf_handler) perf_logger.setLevel(logging.INFO) return root_logger, access_logger, perf_logger # Initialize loggers logger, access_logger, perf_logger = setup_logging() # DB_CONFIG = { # 'host': 'localhost', # 'user': 'flaskuser', # 'password': 'Flask@123', # 'database': 'spurrinai', # } # DB_CONFIG = { # 'host': 'localhost', # 'user': 'spurrindevuser', # 'password': 'Admin@123', # 'database': 'spurrindev', # } # DB_CONFIG = { # 'host': 'localhost', # 'user': 'root', # 'password': 'root', # 'database': 'medqueryai', # 'port': 3307 # } # Redis Configuration REDIS_CONFIG = { "host": "localhost", "port": 6379, "db": 0, "decode_responses": True, # For string operations } DB_CONFIG = { "host": os.getenv("DB_HOST", "localhost"), "user": os.getenv("DB_USER", "testuser"), "password": os.getenv("DB_PASSWORD", "Admin@123"), "database": os.getenv("DB_NAME", "spurrintest"), } # Redis connection pool redis_pool = redis.ConnectionPool(**REDIS_CONFIG) redis_binary_pool = redis.ConnectionPool( host="localhost", port=6379, db=1, decode_responses=False ) def get_redis_client(binary=False): """Get Redis client from pool""" logger.debug(f"Getting Redis client with binary={binary}") try: pool = redis_binary_pool if binary else redis_pool client = redis.Redis(connection_pool=pool) logger.debug("Redis client created successfully") return client except Exception as e: logger.error(f"Failed to create Redis client: {e}", exc_info=True) raise def fetch_cached_answer(cache_key): logger.debug(f"Attempting to fetch cached answer for key: {cache_key}") start_time = time.time() try: redis_client = get_redis_client() cached_answer = redis_client.get(cache_key) fetch_time = time.time() - start_time perf_logger.info( f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}" ) return cached_answer except Exception as e: logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True) return None # Cache TTL configurations CACHE_TTL = { "vector_store": timedelta(hours=24), "chat_completion": timedelta(hours=1), "document_metadata": timedelta(days=7), } DATA_DIR = os.path.join(script_dir, "hospital_data") CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db") uploads_dir = os.path.join(script_dir, "llm-uploads") if not os.path.exists(uploads_dir): os.makedirs(uploads_dir) nlp = spacy.load("en_core_web_sm") load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") client = OpenAI(api_key=OPENAI_API_KEY) embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY) llm = ChatOpenAI( model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY ) # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50) hospital_vector_stores = {} vector_store_lock = threading.Lock() @dataclass class Document: doc_id: int page_num: int content: str class DocumentStatus(Enum): PROCESSING = "processing" PROCESSED = "processed" FAILED = "failed" async def get_db_pool(): return await aiomysql.create_pool( host=DB_CONFIG["host"], user=DB_CONFIG["user"], password=DB_CONFIG["password"], db=DB_CONFIG["database"], autocommit=True, ) async def get_hospital_id(hospital_code): try: pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor(aiomysql.DictCursor) as cursor: await cursor.execute( "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1", (hospital_code,), ) result = await cursor.fetchone() return result["id"] if result else None except Exception as error: logging.error(f"Database error: {error}") return None finally: pool.close() await pool.wait_closed() CHUNK_SIZE = 1000 CHUNK_OVERLAP = 50 BATCH_SIZE = 1000 text_splitter = RecursiveCharacterTextSplitter( chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, # length_function=len, # separators=["\n\n", "\n", ". ", " ", ""] ) # Update the JSON_PATH to be dynamic based on hospital_id def get_icd_json_path(hospital_id): hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}") os.makedirs(hospital_data_dir, exist_ok=True) return os.path.join(hospital_data_dir, "icd_data.json") def extract_and_process_icd_data(content, hospital_id, save_to_json=True): """Extract and process ICD codes with optimized processing and optional JSON saving""" try: # Initialize pattern compilation once pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE) # Process in chunks for large content chunk_size = 50000 # Process 50KB at a time icd_data = [] current_code = None current_description = [] # Split content into manageable chunks content_chunks = [ content[i : i + chunk_size] for i in range(0, len(content), chunk_size) ] # Process each chunk for chunk in content_chunks: lines = chunk.splitlines() for line in lines: line = line.strip() if not line: if current_code and current_description: icd_data.append( { "code": current_code, "description": " ".join(current_description).strip(), } ) current_code = None current_description = [] continue match = pattern.match(line) if match: if current_code and current_description: icd_data.append( { "code": current_code, "description": " ".join(current_description).strip(), } ) current_code, description = match.groups() current_description = [description.strip()] elif current_code: current_description.append(line) # Add final entry if exists if current_code and current_description: icd_data.append( { "code": current_code, "description": " ".join(current_description).strip(), } ) # Save to hospital-specific JSON if requested if save_to_json and icd_data: try: json_path = get_icd_json_path(hospital_id) # Use a lock for thread safety with threading.Lock(): if os.path.exists(json_path): with open(json_path, "r", encoding="utf-8") as f: try: existing_data = json.load(f) except json.JSONDecodeError: existing_data = [] else: existing_data = [] # Efficient deduplication using dictionary seen_codes = {item["code"]: item for item in existing_data} for item in icd_data: seen_codes[item["code"]] = item unique_data = list(seen_codes.values()) # Write atomically using temporary file temp_path = f"{json_path}.tmp" with open(temp_path, "w", encoding="utf-8") as f: json.dump(unique_data, f, indent=2, ensure_ascii=False) os.replace(temp_path, json_path) logging.info( f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}" ) except Exception as e: logging.error( f"Error saving ICD data to JSON for hospital {hospital_id}: {e}" ) return icd_data except Exception as e: logging.error(f"Error in extract_and_process_icd_data: {e}") return [] def load_icd_entries(hospital_id): """Load ICD entries from hospital-specific JSON file""" json_path = get_icd_json_path(hospital_id) try: if os.path.exists(json_path): with open(json_path, "r", encoding="utf-8") as f: return json.load(f) return [] except Exception as e: logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}") return [] # Update the process_icd_codes function to include hospital_id async def process_icd_codes(content, doc_id, hospital_id, batch_size=256): """Process and store ICD codes using the optimized extraction function""" try: # Extract and save codes with hospital_id extract_and_process_icd_data(content, hospital_id, save_to_json=True) except Exception as e: logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}") async def initialize_icd_vector_store(hospital_id): """This function is deprecated. ICD codes are now handled through JSON search.""" logging.warning( "initialize_icd_vector_store is deprecated - using JSON search instead" ) return None def extract_pdf_contents(pdf_path, hospital_id): """Extract PDF contents with optimized chunking and code extraction""" try: loader = PyPDFLoader(pdf_path) pages = loader.load() pages_content = [] for i, page in enumerate(tqdm(pages, desc="Extracting pages")): text = page.page_content.strip() # Extract ICD codes from the page icd_codes = extract_and_process_icd_data( text, hospital_id ) # We'll set doc_id later pages_content.append({"page": i + 1, "text": text, "codes": icd_codes}) return pages_content except Exception as e: logging.error(f"Error in extract_pdf_contents: {e}") raise async def insert_content_into_db(content, metadata, doc_id): pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: try: metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)" content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)" metadata_values = [ (doc_id, key[:100], value) for key, value in metadata.items() if value ] content_values = [ (doc_id, page_content["page"], page_content["text"]) for page_content in content ] if metadata_values: await cursor.executemany(metadata_query, metadata_values) if content_values: await cursor.executemany(content_query, content_values) await conn.commit() return {"message": "Success"} except Exception as e: await conn.rollback() return {"error": str(e)} async def initialize_or_load_vector_store(hospital_id, user_id="default"): """Initialize or load vector store with Redis caching and thread safety""" store_key = f"{hospital_id}:{user_id}" try: # Check if we already have it loaded - with lock for thread safety with vector_store_lock: if store_key in hospital_vector_stores: return hospital_vector_stores[store_key] # Initialize vector store redis_client = get_redis_client(binary=True) cache_key = f"vector_store_data:{hospital_id}:{user_id}" hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}") if os.path.exists(hospital_dir): logging.info( f"Loading vector store for hospital {hospital_id} and user {user_id}" ) vector_store = await asyncio.to_thread( lambda: Chroma( collection_name=f"hospital_{hospital_id}", persist_directory=hospital_dir, embedding_function=embeddings, ) ) else: logging.info(f"Creating vector store for hospital {hospital_id}") os.makedirs(hospital_dir, exist_ok=True) vector_store = await asyncio.to_thread( lambda: Chroma( collection_name=f"hospital_{hospital_id}", persist_directory=hospital_dir, embedding_function=embeddings, ) ) # Store with lock for thread safety with vector_store_lock: hospital_vector_stores[store_key] = vector_store return vector_store except Exception as e: logging.error(f"Error initializing vector store: {e}", exc_info=True) raise async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool: """Delete all vectors associated with a specific document from ChromaDB""" try: # Initialize vector store for the hospital vector_store = await initialize_or_load_vector_store(hospital_id) # Delete vectors with matching doc_id await asyncio.to_thread( lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)}) ) # Persist changes await asyncio.to_thread(vector_store.persist) # Clear Redis cache for this document redis_client = get_redis_client() pattern = f"vector_store_data:{hospital_id}:*" for key in redis_client.scan_iter(pattern): redis_client.delete(key) logging.info( f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}" ) return True except Exception as e: logging.error(f"Error deleting document vectors: {e}", exc_info=True) return False async def add_document_to_index(doc_id, hospital_id): try: pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: vector_store = await initialize_or_load_vector_store(hospital_id) await cursor.execute( "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number", (doc_id,), ) rows = await cursor.fetchall() total_pages = len(rows) logging.info(f"Processing {total_pages} pages for document {doc_id}") page_bar = tqdm_async(total=total_pages, desc="Processing pages") async def process_page(page_data): page_num, content = page_data try: icd_data = extract_and_process_icd_data( content, hospital_id, save_to_json=False ) chunks = text_splitter.split_text(content) await asyncio.sleep(0) # Yield control return page_num, chunks, icd_data except Exception as e: logging.error(f"Error processing page {page_num}: {e}") return page_num, [], [] tasks = [asyncio.create_task(process_page(row)) for row in rows] results = [] for coro in asyncio.as_completed(tasks): result = await coro results.append(result) page_bar.update(1) page_bar.close() # Vector addition progress bar all_icd_data = [] all_chunks = [] all_metadatas = [] chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0) for result in results: page_num, chunks, icd_data = result all_icd_data.extend(icd_data) for i, chunk in enumerate(chunks): all_chunks.append(chunk) all_metadatas.append( { "doc_id": str(doc_id), "hospital_id": str(hospital_id), "page_number": str(page_num), "chunk_index": str(i), } ) if len(all_chunks) >= BATCH_SIZE: chunk_add_bar.total += len(all_chunks) chunk_add_bar.refresh() await asyncio.to_thread( vector_store.add_texts, texts=all_chunks, metadatas=all_metadatas, ) all_chunks = [] all_metadatas = [] chunk_add_bar.update(BATCH_SIZE) # Final batch if all_chunks: chunk_add_bar.total += len(all_chunks) chunk_add_bar.refresh() await asyncio.to_thread( vector_store.add_texts, texts=all_chunks, metadatas=all_metadatas, ) chunk_add_bar.update(len(all_chunks)) chunk_add_bar.close() if all_icd_data: logging.info(f"Saving {len(all_icd_data)} ICD codes") extract_and_process_icd_data("", hospital_id, save_to_json=True) await asyncio.to_thread(vector_store.persist) logging.info(f"Successfully indexed document {doc_id}") return True except Exception as e: logging.error(f"Error adding document: {e}") return False def is_general_knowledge_question( query: str, context: str, conversation_context=None ) -> bool: """ Determine if a question is likely a general knowledge question not covered in the documents. Takes conversation history into account to reduce repeated confirmations. """ query_lower = query.lower() context_lower = context.lower() if conversation_context: for interaction in conversation_context: prev_question = interaction.get("question", "").lower() if ( prev_question and query_lower in prev_question or prev_question in query_lower ): logging.info( f"Question is similar to previous conversation, skipping confirmation" ) return False stop_words = { "search", "query:", "can", "you", "some", "at", "the", "a", "an", "in", "on", "at", "to", "for", "with", "by", "about", "give", "full", "is", "are", "was", "were", "define", "what", "how", "why", "when", "where", "year", "list", "form", "table", "who", "which", "me", "tell", "explain", "describe", "of", "and", "or", "there", "their", "please", "could", "would", "various", "different", "type", "types", "kind", "kinds", "has", "have", "had", "many", "say", "know", } key_words = [ word for word in query_lower.split() if word not in stop_words and len(word) > 2 ] logging.info(f"Key words: {key_words}") if not key_words: logging.info("No significant keywords found, directing to general knowledge") return True matches = sum(1 for word in key_words if word in context_lower) logging.info(f"Matches: {matches} out of {len(key_words)} keywords") match_ratio = matches / len(key_words) logging.info(f"Match ratio: {match_ratio}") return match_ratio < 0.4 def is_table_request(query: str) -> bool: """ Determine if the user is requesting a response in tabular format. """ table_keywords = [ "table", "tabular", "in a table", "in table format", "in tabular format", "chart", "data", "comparison", "as a table", "table format", "in rows and columns", "in a grid", "breakdown", "spreadsheet", "comparison table", "data table", "structured table", "tabular form", "table form", ] query_lower = query.lower() return any(keyword in query_lower for keyword in table_keywords) import re def ensure_html_response(text: str) -> str: """ Ensure the response is properly formatted in HTML. This function handles plain text conversion to HTML. """ if "", text)) if not has_html_tags: paragraphs = text.split("\n\n") html_parts = [] in_ordered_list = False in_unordered_list = False for para in paragraphs: if para.strip(): if re.match(r"^\s*[\*\-\•]\s", para): if not in_unordered_list: html_parts.append("
{para}
") if in_ordered_list: html_parts.append("") if in_unordered_list: html_parts.append("") return "".join(html_parts) else: if not any(tag in text for tag in ("", "
{para}
" for para in paragraphs if para.strip()] return "".join(html_parts) return text class HybridConversationManager: """ Hybrid conversation manager that uses Redis for RAG-based conversations and in-memory storage for general knowledge conversations. """ def __init__(self, redis_client, ttl=3600, max_history_items=5): self.redis_client = redis_client self.ttl = ttl self.max_history_items = max_history_items # For general knowledge questions (in-memory) self.general_knowledge_histories = {} self.lock = Lock() def _get_redis_key(self, user_id, hospital_id, session_id=None): """Create Redis key for document-based conversations.""" if session_id: return f"conv_history:{user_id}:{hospital_id}:{session_id}" return f"conv_history:{user_id}:{hospital_id}" def _get_memory_key(self, user_id, hospital_id, session_id=None): """Create memory key for general knowledge conversations.""" if session_id: return f"{user_id}:{hospital_id}:{session_id}" return f"{user_id}:{hospital_id}" async def add_rag_interaction( self, user_id, hospital_id, question, answer, session_id=None ): """Add document-based (RAG) interaction to Redis.""" key = self._get_redis_key(user_id, hospital_id, session_id) history = self.get_rag_history(user_id, hospital_id, session_id) # Add new interaction history.append( { "question": question, "answer": answer, "timestamp": time.time(), "type": "rag", # Mark as RAG-based interaction } ) # Keep only last N interactions history = history[-self.max_history_items :] # Store updated history try: self.redis_client.setex(key, self.ttl, json.dumps(history)) logging.info( f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}" ) except Exception as e: logging.error(f"Failed to store RAG interaction in Redis: {e}") def add_general_knowledge_interaction( self, user_id, hospital_id, question, answer, session_id=None ): """Add general knowledge interaction to in-memory store.""" key = self._get_memory_key(user_id, hospital_id, session_id) with self.lock: if key not in self.general_knowledge_histories: self.general_knowledge_histories[key] = [] self.general_knowledge_histories[key].append( { "question": question, "answer": answer, "timestamp": time.time(), "type": "general", # Mark as general knowledge interaction } ) # Keep only the most recent interactions if len(self.general_knowledge_histories[key]) > self.max_history_items: self.general_knowledge_histories[key] = ( self.general_knowledge_histories[key][-self.max_history_items :] ) logging.info( f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}" ) def get_rag_history(self, user_id, hospital_id, session_id=None): """Get document-based (RAG) conversation history from Redis.""" key = self._get_redis_key(user_id, hospital_id, session_id) try: history_data = self.redis_client.get(key) return json.loads(history_data) if history_data else [] except Exception as e: logging.error(f"Failed to retrieve RAG history from Redis: {e}") return [] def get_general_knowledge_history(self, user_id, hospital_id, session_id=None): """Get general knowledge conversation history from memory.""" key = self._get_memory_key(user_id, hospital_id, session_id) with self.lock: return self.general_knowledge_histories.get(key, []).copy() def get_combined_history(self, user_id, hospital_id, session_id=None): """Get combined conversation history from both sources, sorted by timestamp.""" rag_history = self.get_rag_history(user_id, hospital_id, session_id) general_history = self.get_general_knowledge_history( user_id, hospital_id, session_id ) # Combine histories combined_history = rag_history + general_history # Sort by timestamp (newest first) combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True) # Return most recent N items return combined_history[: self.max_history_items] def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2): """Get the most recent interactions for context from combined history.""" combined_history = self.get_combined_history(user_id, hospital_id, session_id) # Sort by timestamp (oldest first) for context window sorted_history = sorted(combined_history, key=lambda x: x.get("timestamp", 0)) return sorted_history[-window_size:] if sorted_history else [] def clear_history(self, user_id, hospital_id): """Clear conversation history from both stores.""" # Clear Redis history redis_key = self._get_redis_key(user_id, hospital_id) try: self.redis_client.delete(redis_key) except Exception as e: logging.error(f"Failed to clear Redis history: {e}") # Clear memory history memory_key = self._get_memory_key(user_id, hospital_id) with self.lock: if memory_key in self.general_knowledge_histories: del self.general_knowledge_histories[memory_key] class ContextMapper: """Enhanced context mapping using shared model manager""" def __init__(self): self.model_manager = ModelManager() self.context_cache = {} self.similarity_threshold = 0.6 def get_semantic_similarity(self, text1, text2): """Get semantic similarity using global model manager""" return self.model_manager.get_semantic_similarity(text1, text2) def extract_key_concepts(self, text): """Extract key concepts using NLP techniques""" doc = nlp(text) concepts = [] entities = [(ent.text, ent.label_) for ent in doc.ents] noun_phrases = [chunk.text for chunk in doc.noun_chunks] important_words = [ token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"] ] concepts.extend([e[0] for e in entities]) concepts.extend(noun_phrases) concepts.extend(important_words) return list(set(concepts)) def map_conversation_context( self, current_query, conversation_history, context_window=3 ): """Map conversation context using enhanced NLP techniques""" if not conversation_history: return current_query recent_context = conversation_history[-context_window:] context_concepts = [] # Extract concepts from recent conversations for interaction in recent_context: q_concepts = self.extract_key_concepts(interaction["question"]) a_concepts = self.extract_key_concepts(interaction["answer"]) context_concepts.extend(q_concepts) context_concepts.extend(a_concepts) # Extract concepts from current query query_concepts = self.extract_key_concepts(current_query) # Find related concepts related_concepts = [] for q_concept in query_concepts: for c_concept in context_concepts: similarity = self.get_semantic_similarity(q_concept, c_concept) if similarity > self.similarity_threshold: related_concepts.append(c_concept) # Build enhanced query if related_concepts: enhanced_query = ( f"{current_query} in context of {', '.join(related_concepts)}" ) else: enhanced_query = current_query return enhanced_query # Initialize the context mapper context_mapper = ContextMapper() async def generate_contextual_query( question: str, user_id: str, hospital_id: int, conversation_manager ) -> str: """Generate enhanced contextual query""" context_window = conversation_manager.get_context_window(user_id, hospital_id) if not context_window: return question # Enhanced context mapping last_interaction = context_window[-1] enhanced_context = f""" Previous question: {last_interaction['question']} Previous answer: {last_interaction['answer']} Current question: {question} Please generate a detailed search query that combines the context from the previous answer with the current question, especially if the current question uses words like 'it', 'this', 'that', or asks for more details about the previous topic. """ try: response = await asyncio.to_thread( lambda: client.chat.completions.create( model="gpt-3.5-turbo", messages=[ { "role": "system", "content": "You are a context-aware query generator.", }, {"role": "user", "content": enhanced_context}, ], temperature=0.3, max_tokens=150, ) ) contextual_query = response.choices[0].message.content.strip() logging.info(f"Enhanced contextual query: {contextual_query}") return contextual_query except Exception as e: logging.error(f"Error generating contextual query: {e}") return question def is_follow_up(current_question: str, conversation_history: list) -> bool: """Enhanced follow-up detection using NLP techniques""" if not conversation_history: return False last_interaction = conversation_history[-1] # Get semantic similarity with higher threshold similarity = context_mapper.get_semantic_similarity( current_question, f"{last_interaction['question']} {last_interaction['answer']}" ) # Enhanced referential check doc = nlp(current_question.lower()) has_referential = any( token.lemma_ in [ "it", "this", "that", "these", "those", "they", "he", "she", "about", "more", ] for token in doc ) # Extract concepts with improved entity detection current_concepts = set(context_mapper.extract_key_concepts(current_question)) last_concepts = set( context_mapper.extract_key_concepts( f"{last_interaction['question']} {last_interaction['answer']}" ) ) # Calculate enhanced concept overlap concept_overlap = ( len(current_concepts & last_concepts) / len(current_concepts | last_concepts) if current_concepts else 0 ) # More aggressive follow-up detection return ( similarity > 0.3 # Lowered threshold or has_referential or concept_overlap > 0.2 # Lowered threshold or any( word in current_question.lower() for word in ["more", "about", "elaborate", "explain"] ) ) async def get_relevant_context(question, hospital_id, doc_id=None): try: cache_key = f"context:hospital_{hospital_id}" if doc_id: cache_key += f":doc_{doc_id}" cache_key += f":{question.lower().strip()}" redis_client = get_redis_client() cached_context = redis_client.get(cache_key) if cached_context: logging.info(f"Cache hit for key: {cache_key}") return ( cached_context.decode("utf-8") if isinstance(cached_context, bytes) else cached_context ) vector_store = await initialize_or_load_vector_store(hospital_id) if not vector_store: return "" retriever = vector_store.as_retriever( search_type="mmr", search_kwargs={ "k": 10, "fetch_k": 20, "lambda_mult": 0.6, # "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)} }, ) docs = await asyncio.to_thread(retriever.get_relevant_documents, question) if not docs: return "" sorted_docs = sorted( docs, key=lambda x: ( int(x.metadata.get("page_number", 0)), int(x.metadata.get("chunk_index", 0)), ), ) context_parts = [doc.page_content for doc in sorted_docs] context = "\n\n".join(context_parts) try: redis_client.setex( cache_key, int(CACHE_TTL["vector_store"].total_seconds()), context.encode("utf-8") if isinstance(context, str) else context, ) logging.info(f"Cached context for key: {cache_key}") except Exception as cache_error: logging.error(f"Failed to cache context: {cache_error}") return context except Exception as e: logging.error(f"Error getting relevant context: {e}") return "" def format_conversation_context(conv_history): """Format conversation history into a string""" if not conv_history: return "No previous conversation." return "\n".join( [ f"Q: {interaction['question']}\nA: {interaction['answer']}" for interaction in conv_history ] ) def get_icd_context_from_question(question, hospital_id): """Extract any valid ICD codes from the question and return context""" icd_data = load_icd_entries(hospital_id) matches = [] code_pattern = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", question.upper()) seen = set() for code in code_pattern: for entry in icd_data: if entry["code"] == code and code not in seen: matches.append(f"{entry['code']}: {entry['description']}") seen.add(code) return "\n".join(matches) def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70): """Get fuzzy matches for ICD codes from the question""" icd_data = load_icd_entries(hospital_id) descriptions = [entry["description"] for entry in icd_data] matches = process.extract( question, descriptions, limit=top_n, score_cutoff=threshold ) matched_context = [] for desc, score, _ in matches: for entry in icd_data: if entry["description"] == desc: matched_context.append(f"{entry['code']}: {entry['description']}") break return "\n".join(matched_context) async def generate_answer_with_rag( question, hospital_id, client, doc_id=None, user_id="default", conversation_manager=None, session_id=None, ): """Generate an answer using RAG with improved conversation flow""" try: # Continue with regular RAG processing if not an ICD code or if no ICD match found html_instruction = """ IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content: - Usetags for paragraphs - Use
for quoted text - Use for bold text and for emphasis """ table_instruction = """ - For tables, use proper HTML table structure:""" # Get conversation history first conv_history = ( conversation_manager.get_context_window(user_id, hospital_id, session_id) if conversation_manager else [] ) # Get contextual query and relevant context first contextual_query = await generate_contextual_query( question, user_id, hospital_id, conversation_manager ) # Track ICD context across conversation icd_context = {} if conv_history: # Extract ICD code from previous interaction last_answer = conv_history[-1].get("answer", "") icd_codes = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", last_answer) if icd_codes: icd_context["last_code"] = icd_codes[0] # Check if current question is about a previously discussed ICD code is_icd_followup = False if icd_context.get("last_code"): followup_indicators = [ "what causes", "what is causing", "why", "how", "symptoms", "treatment", "diagnosis", "causes", "effects", "complications", "risk factors", "prevention", "prognosis", "this", "disease", "that", "it", ] is_icd_followup = any( indicator in question.lower() for indicator in followup_indicators ) if is_icd_followup: # Add the previous ICD code context to the current question icd_exact_context = get_icd_context_from_question( icd_context["last_code"], hospital_id ) icd_fuzzy_context = get_fuzzy_icd_context( f"{icd_context['last_code']} {question}", hospital_id ) else: icd_exact_context = get_icd_context_from_question(question, hospital_id) icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id) else: icd_exact_context = get_icd_context_from_question(question, hospital_id) icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id) # Get contextual query and relevant context contextual_query = await generate_contextual_query( question, user_id, hospital_id, conversation_manager ) doc_context = await get_relevant_context(contextual_query, hospital_id, doc_id) # Combine context with priority for ICD information context_parts = [] if is_icd_followup: context_parts.append( f"## Previous ICD Code Context\nContinuing discussion about: {icd_context['last_code']}" ) if icd_exact_context: context_parts.append("## ICD Code Match\n" + icd_exact_context) if icd_fuzzy_context: context_parts.append("## Related ICD Suggestions\n" + icd_fuzzy_context) if doc_context: context_parts.append("## Document Context\n" + doc_context) context = "\n\n".join(context_parts) # Initialize follow-up detection is_follow_up = False # Check if this is a follow-up question if conv_history: last_interaction = conv_history[-1] last_question = last_interaction["question"].lower() last_answer = last_interaction.get("answer", "").lower() current_question = question.lower() # Define meaningful keywords that indicate entity-related follow-ups entity_related_keywords = { "achievements", "awards", "accomplishments", "work", "contributions", "career", "company", "products", "life", "background", "education", "role", "experience", "history", "details", "places", "place", "information", "facts", "about", "birth", "death", "family", "books", "projects", "population", } # Check if question is asking about attributes/achievements of previously discussed entity has_entity_attribute = any( word in current_question.split() for word in entity_related_keywords ) # Extract entities from last answer to maintain context def extract_entities(text): # Split into words and get potential entities (capitalized words) words = text.split() entities = set() current_entity = [] for word in words: if word[0].isupper(): current_entity.append(word) elif current_entity: if len(current_entity) > 0: entities.add(" ".join(current_entity)) current_entity = [] if current_entity: entities.add(" ".join(current_entity)) return entities last_entities = extract_entities(last_answer) # Check for referential words referential_words = { "it", "this", "that", "these", "those", "they", "their", "he", "she", "him", "her", "his", "hers", "them", "there", "such", "its", } has_referential = any( word in referential_words for word in current_question.split() ) # Calculate term overlap with both question and answer context def get_significant_terms(text): stop_words = { "what", "when", "where", "who", "why", "how", "is", "are", "was", "were", "be", "been", "the", "a", "an", "in", "on", "at", "to", "for", "of", "with", "by", "about", "as", "tell", "me", "please", } return set( word for word in text.split() if len(word) > 2 and word.lower() not in stop_words ) current_terms = get_significant_terms(current_question) last_terms = get_significant_terms(last_question) answer_terms = get_significant_terms(last_answer) # Include terms from both question and answer in context all_prev_terms = last_terms | answer_terms term_overlap = len(current_terms & all_prev_terms) total_terms = len(current_terms | all_prev_terms) term_similarity = term_overlap / total_terms if total_terms > 0 else 0 # Enhanced follow-up detection combining multiple signals is_follow_up = ( has_referential or term_similarity >= 0.2 # Lower threshold when including answer context or ( has_entity_attribute and bool(last_entities) ) # Check if asking about attributes of known entity or ( last_interaction.get("type") == "general" and term_similarity >= 0.15 ) ) logging.info(f"Follow-up analysis enhanced:") logging.info(f"- Referential words: {has_referential}") logging.info(f"- Term similarity: {term_similarity:.2f}") logging.info(f"- Entity attribute question: {has_entity_attribute}") logging.info(f"- Last entities found: {last_entities}") logging.info(f"- Is follow-up: {is_follow_up}") # For entirely new topics (not follow-ups), use is_general_knowledge_question if not is_follow_up: is_general = is_general_knowledge_question(question, context, conv_history) if is_general: confirmation_prompt = f"""
{table_title} {table_headers} {table_rows}Reviewed the hospital’s documentation, but this particular question does not seem to be covered.
""" logging.info("General knowledge question detected") return { "answer": confirmation_prompt, "requires_confirmation": True, }, 200 prompt_template = f"""Based on the following context and conversation history, provide a detailed answer to the question. Previous conversation: {format_conversation_context(conv_history)} Context from documents: {context} Current question: {question} Instructions: 1. When providing medical codes (ICD, CPT, etc.): - Always use the ICD codes listed in the sections titled "ICD Code Match" and "Related ICD Suggestions" from the context. - Do not use or invent ICD codes from your own training knowledge unless they appear in the provided context. - If multiple codes are relevant, return the one that best matches the user’s question. If unsure, return multiple options in HTML list format. - Remove all decimal points (e.g., use 'A150' instead of 'A15.0') - Format the response as: 'The medical code for [condition] is [code]
' 2. Address the current question while maintaining conversation continuity 3. Resolve any ambiguous references using conversation history 4. Format the response in clear HTML. 5. Strictly it should answer only from the {context} provided, do not invent or assume and give the information and it should not be general knowledge also. Purely 100% RAG-based response and only from documents. {html_instruction} {table_instruction if is_table_request(question) else ""} """ response = await asyncio.to_thread( lambda: client.chat.completions.create( model="gpt-3.5-turbo-16k", messages=[ {"role": "system", "content": prompt_template}, {"role": "user", "content": question}, ], temperature=0.2, max_tokens=1000, ) ) answer = ensure_html_response(response.choices[0].message.content) logging.info(f"Generated RAG answer for question: {question}") # Store interaction in history if conversation_manager: await conversation_manager.add_rag_interaction( user_id, hospital_id, question, answer, session_id ) return {"answer": answer}, 200 except Exception as e: logging.error(f"Error in generate_answer_with_rag: {e}") return {"answer": f"Error: {str(e)}
"}, 500 async def load_existing_vector_stores(): """Load existing Chroma vector stores for each hospital""" pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: try: await cursor.execute("SELECT DISTINCT id FROM hospitals") hospital_ids = [row[0] for row in await cursor.fetchall()] for hospital_id in hospital_ids: try: await initialize_or_load_vector_store(hospital_id) except Exception as e: logging.error( f"Failed to load vector store for hospital {hospital_id}: {e}" ) continue except Exception as e: logging.error(f"Error loading vector stores: {e}") async def get_failed_page(doc_id): pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: try: await cursor.execute( "SELECT failed_page FROM documents WHERE id = %s", (doc_id,) ) result = await cursor.fetchone() return result[0] if result and result[0] else None except Exception as e: logging.error(f"Database error checking failed_page: {e}") return None async def update_document_status(doc_id, status, failed_page=None): """Update document status with enum validation""" if isinstance(status, str): status = DocumentStatus[status.upper()].value pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: try: if failed_page: await cursor.execute( "UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s", (status, failed_page, doc_id), ) else: await cursor.execute( "UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s", (status, doc_id), ) await conn.commit() return True except Exception as e: logging.error(f"Database update error: {e}") return False thread_pool = ThreadPoolExecutor(max_workers=10) def async_to_sync(coroutine): """Helper function to run async code in sync context""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete(coroutine) finally: loop.close() @app.route("/flask-api", methods=["GET"]) def health_check(): """Health check endpoint""" access_logger.info(f"Health check request received from {request.remote_addr}") return jsonify({"status": "ok"}), 200 @app.route("/flask-api/process-pdf", methods=["POST"]) def process_pdf(): access_logger.info(f"PDF processing request received from {request.remote_addr}") file_path = None try: file = request.files.get("pdf") hospital_id = request.form.get("hospital_id") doc_id = request.form.get("doc_id") logging.info( f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}" ) if not all([file, hospital_id, doc_id]): return jsonify({"error": "Missing required parameters"}), 400 def process_in_background(): nonlocal file_path try: async_to_sync(update_document_status(doc_id, "processing")) # Add progress logging logging.info(f"Starting processing of document {doc_id}") filename = f"doc_{doc_id}_{file.filename}" file_path = os.path.join(uploads_dir, filename) with open(file_path, "wb") as f: file.save(f) logging.info("Extracting PDF contents...") content = extract_pdf_contents(file_path, int(hospital_id)) logging.info("Inserting content into database...") metadata = {"filename": filename} result = async_to_sync( insert_content_into_db(content, metadata, doc_id) ) if "error" in result: async_to_sync(update_document_status(doc_id, "failed", 1)) return False logging.info("Creating embeddings and indexing...") success = async_to_sync(add_document_to_index(doc_id, hospital_id)) if success: logging.info("Document processing completed successfully") async_to_sync(update_document_status(doc_id, "processed")) return True else: logging.error("Document processing failed during indexing") async_to_sync(update_document_status(doc_id, "failed")) return False except Exception as e: logging.error(f"Processing error: {e}") async_to_sync(update_document_status(doc_id, "failed")) return False finally: if file_path and os.path.exists(file_path): try: os.remove(file_path) except Exception as e: logging.error(f"Error removing temporary file: {e}") # Execute processing and wait for result future = thread_pool.submit(process_in_background) success = future.result() if success: return jsonify({"message": "Document processed successfully"}), 200 else: return jsonify({"error": "Document processing failed"}), 500 except Exception as e: logging.error(f"API error: {e}") if file_path and os.path.exists(file_path): try: os.remove(file_path) except Exception as file_e: logging.error(f"Error removing temporary file: {file_e}") return jsonify({"error": str(e)}), 500 # Initialize the hybrid conversation manager redis_client = get_redis_client() conversation_manager = HybridConversationManager(redis_client) @app.route("/flask-api/generate-answer", methods=["POST"]) def rag_answer_api(): """Sync API endpoint for RAG-based question answering with conversation history.""" access_logger.info(f"Generate answer request received from {request.remote_addr}") try: data = request.json question = data.get("question", "").strip().lower() hospital_code = data.get("hospital_code") doc_id = data.get("doc_id") user_id = data.get("user_id", "default") session_id = data.get("session_id", None) logging.info(f"Received question from user {user_id}: {question}") logging.info(f"Received hospital code: {hospital_code}") logging.info(f"Received session_id: {session_id}") # is_confirmation_response = data.get("is_confirmation_response", False) original_query = data.get("original_query", "") def process_rag_answer(): try: hospital_id = async_to_sync(get_hospital_id(hospital_code)) logging.info(f"Resolved hospital ID: {hospital_id}") if not hospital_id: return { "error": "Invalid or missing 'hospital_code' in request" }, 400 # if question == "yes" and original_query: # # User confirmed they want a general knowledge answer # answer = async_to_sync( # generate_general_knowledge_answer( # original_query, # client, # user_id, # hospital_id, # conversation_manager, # Pass the hybrid manager # is_table_request(original_query), # session_id=session_id, # ) # ) # return {"answer": answer}, 200 if original_query: response_message = """I can only answer questions based on information found in the hospital documents.
The question you asked doesn't seem to be covered in the available documents.
You can try rephrasing your question or asking about a different topic.
""" return {"answer": response_message}, 200 else: # Regular RAG answer return async_to_sync( generate_answer_with_rag( question=question, hospital_id=hospital_id, client=client, doc_id=doc_id, user_id=user_id, conversation_manager=conversation_manager, # Pass the hybrid manager session_id=session_id, ) ) except Exception as e: logging.error(f"Thread processing error: {str(e)}") return {"error": str(e)}, 500 if not question: return jsonify({"error": "Missing 'question' in request"}), 400 future = thread_pool.submit(process_rag_answer) result, status_code = future.result() return jsonify(result), status_code except Exception as e: logging.error(f"API error: {str(e)}") return jsonify({"error": str(e)}), 500 @app.route("/flask-api/delete-document-vectors", methods=["DELETE"]) def delete_document_vectors_endpoint(): """Endpoint to delete document vectors from ChromaDB""" try: data = request.json hospital_id = data.get("hospital_id") doc_id = data.get("doc_id") if not all([hospital_id, doc_id]): return jsonify({"error": "Missing required parameters"}), 400 logging.info( f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}" ) def process_deletion(): try: success = async_to_sync(delete_document_vectors(hospital_id, doc_id)) if success: return {"message": "Document vectors deleted successfully"}, 200 else: return {"error": "Failed to delete document vectors"}, 500 except Exception as e: logging.error(f"Error in vector deletion process: {e}") return {"error": str(e)}, 500 future = thread_pool.submit(process_deletion) result, status_code = future.result() return jsonify(result), status_code except Exception as e: logging.error(f"API error: {str(e)}") return jsonify({"error": str(e)}), 500 @app.route("/flask-api/get-chroma-content", methods=["GET"]) def get_chroma_content_endpoint(): """API endpoint to get ChromaDB content by hospital_id""" try: hospital_id = request.args.get("hospital_id") limit = int(request.args.get("limit", 30000)) if not hospital_id: return jsonify({"error": "Missing required parameter: hospital_id"}), 400 def process_fetch(): try: result, status_code = async_to_sync( get_chroma_content_by_hospital( hospital_id=int(hospital_id), limit=limit ) ) return result, status_code except Exception as e: logging.error(f"Error in ChromaDB fetch process: {e}") return {"error": str(e)}, 500 future = thread_pool.submit(process_fetch) result, status_code = future.result() return jsonify(result), status_code except Exception as e: logging.error(f"API error: {str(e)}") return jsonify({"error": str(e)}), 500 async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100): """Fetch content from ChromaDB for a specific hospital""" try: # Initialize vector store vector_store = await initialize_or_load_vector_store(hospital_id) if not vector_store: return {"error": "Vector store not found"}, 404 # Get collection collection = vector_store._collection # Query the collection with hospital_id filter results = await asyncio.to_thread( lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit) ) if not results or not results["ids"]: return {"data": [], "count": 0}, 200 # Format the response formatted_results = [] for i in range(len(results["ids"])): formatted_results.append( { "id": results["ids"][i], "content": results["documents"][i], "metadata": results["metadatas"][i], } ) return {"data": formatted_results, "count": len(formatted_results)}, 200 except Exception as e: logging.error(f"Error fetching ChromaDB content: {e}") return {"error": str(e)}, 500 @app.before_request def before_request(): request._start_time = time.time() @app.after_request def after_request(response): if hasattr(request, "_start_time"): duration = time.time() - request._start_time access_logger.info( f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - ' f"IP: {request.remote_addr}" ) return response if __name__ == "__main__": logger.info("Starting SpurrinAI application") logger.info(f"Python version: {sys.version}") logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}") try: model_manager = ModelManager() logger.info("Model manager initialized successfully") except Exception as e: logger.error(f"Failed to initialize model manager: {e}") sys.exit(1) # Initialize directories os.makedirs(DATA_DIR, exist_ok=True) os.makedirs(CHROMA_DIR, exist_ok=True) logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}") # Clear Redis cache redis_client = get_redis_client() cleared_keys = 0 for key in redis_client.scan_iter("vector_store_data:*"): redis_client.delete(key) cleared_keys += 1 logger.info(f"Cleared {cleared_keys} Redis cache keys") # Load vector stores logger.info("Loading existing vector stores...") async_to_sync(load_existing_vector_stores()) logger.info("Vector stores loaded successfully") # Start application logger.info("Starting Flask application on port 5000") app.run(port=5000, debug=False)