"""
SpurrinAI - Intelligent Document Processing and Question Answering System

Copyright (c) 2024 Tech4biz. All rights reserved.

This module implements the main Flask application for the SpurrinAI system,
providing REST APIs for document processing, vector storage, and question
answering using a RAG (Retrieval Augmented Generation) architecture.

Author: Tech4biz Development Team
Version: 1.0.0
Last Updated: 2024-01-19
"""

# Standard library imports
import os
import re
import sys
import json
import time
import threading
import asyncio
import logging
import logging.handlers
import warnings
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum

# Third-party imports
import spacy
import redis
import aiomysql
from dotenv import load_dotenv
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from tqdm import tqdm
from tqdm.asyncio import tqdm as tqdm_async
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from openai import OpenAI
from rapidfuzz import process
from threading import Lock

# Local imports
from model_manager import ModelManager

# Suppress warnings raised from this point on (warnings emitted while the
# imports above were executed are not affected).
warnings.filterwarnings("ignore")

# Initialize NLTK: fetches the "punkt" tokenizer data (network access on first run).
import nltk  # noqa: E402  -- intentionally imported after the warnings filter

nltk.download("punkt")

app = Flask(__name__)
CORS(app)

script_dir = os.path.dirname(os.path.abspath(__file__))
log_file_path = os.path.join(script_dir, "error.log")
# NOTE(review): despite its name, this basicConfig file receives every INFO+
# record from the root logger; setup_logging() adds more root handlers below.
logging.basicConfig(filename=log_file_path, level=logging.INFO)

# Rotation policy shared by the size-rotated log files (10485760 bytes = 10 MiB).
_LOG_MAX_BYTES = 10 * 1024 * 1024
_LOG_BACKUP_COUNT = 5


def setup_logging():
    """Configure the application's log files under ``<script_dir>/logs``.

    * root logger: ``app.log`` (INFO+) and ``error.log`` (ERROR+), size-rotated
    * ``access`` logger: ``access.log``, rotated at midnight, 30 files kept
    * ``performance`` logger: ``performance.log``, size-rotated

    Returns:
        tuple: ``(root_logger, access_logger, perf_logger)``.
    """
    log_dir = os.path.join(script_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)
    main_log = os.path.join(log_dir, "app.log")
    error_log = os.path.join(log_dir, "error.log")
    access_log = os.path.join(log_dir, "access.log")
    perf_log = os.path.join(log_dir, "performance.log")

    # Create formatters
    detailed_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
    )
    access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    def _rotating_handler(path, level):
        """Build a size-rotated file handler using the detailed format."""
        handler = logging.handlers.RotatingFileHandler(
            path, maxBytes=_LOG_MAX_BYTES, backupCount=_LOG_BACKUP_COUNT
        )
        handler.setFormatter(detailed_formatter)
        handler.setLevel(level)
        return handler

    main_handler = _rotating_handler(main_log, logging.INFO)
    error_handler = _rotating_handler(error_log, logging.ERROR)
    perf_handler = _rotating_handler(perf_log, logging.INFO)

    access_handler = logging.handlers.TimedRotatingFileHandler(
        access_log, when="midnight", interval=1, backupCount=30
    )
    access_handler.setFormatter(access_formatter)
    access_handler.setLevel(logging.INFO)

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(main_handler)
    root_logger.addHandler(error_handler)

    # Create specific loggers
    access_logger = logging.getLogger("access")
    access_logger.addHandler(access_handler)
    access_logger.setLevel(logging.INFO)

    perf_logger = logging.getLogger("performance")
    perf_logger.addHandler(perf_handler)
    perf_logger.setLevel(logging.INFO)

    return root_logger, access_logger, perf_logger


load_dotenv()

# Initialize loggers
logger, access_logger, perf_logger = setup_logging()

# Database settings are read exclusively from the environment (.env file).
# Never commit database credentials to source control.
DB_CONFIG = {
    "host": os.getenv("DB_HOST"),
    "user": os.getenv("DB_USER"),
    "password": os.getenv("DB_PASSWORD"),
    "database": os.getenv("DB_NAME"),
}

# Redis Configuration (host/port overridable via env; defaults unchanged).
REDIS_CONFIG = {
    "host": os.getenv("REDIS_HOST", "localhost"),
    "port": int(os.getenv("REDIS_PORT", "6379")),
    "db": 0,
    "decode_responses": True,  # For string operations
}

# Redis connection pools: db 0 returns str, db 1 returns raw bytes. Both point
# at the same server described by REDIS_CONFIG.
redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
redis_binary_pool = redis.ConnectionPool(
    host=REDIS_CONFIG["host"],
    port=REDIS_CONFIG["port"],
    db=1,
    decode_responses=False,
)


def get_redis_client(binary=False):
    """Return a Redis client backed by the shared pool.

    Args:
        binary: When True use the bytes pool (db 1); otherwise the
            string-decoding pool (db 0).
    """
    logger.debug(f"Getting Redis client with binary={binary}")
    try:
        pool = redis_binary_pool if binary else redis_pool
        client = redis.Redis(connection_pool=pool)
        logger.debug("Redis client created successfully")
        return client
    except Exception as e:
        logger.error(f"Failed to create Redis client: {e}", exc_info=True)
        raise


def fetch_cached_answer(cache_key):
    """Return the cached value for ``cache_key`` from Redis db 0.

    Returns None on a cache miss or on any Redis error (the error is logged).
    The fetch latency is written to the performance log.
    """
    logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
    start_time = time.time()
    try:
        redis_client = get_redis_client()
        cached_answer = redis_client.get(cache_key)
        fetch_time = time.time() - start_time
        perf_logger.info(
            f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
        )
        return cached_answer
    except Exception as e:
        logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
        return None


# Cache TTL configurations
CACHE_TTL = {
    "vector_store": timedelta(hours=24),
    "chat_completion": timedelta(hours=1),
    "document_metadata": timedelta(days=7),
}

DATA_DIR = os.path.join(script_dir, "hospital_data")
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")
uploads_dir = os.path.join(script_dir, "llm-uploads")
# exist_ok avoids the check-then-create race between worker processes.
os.makedirs(uploads_dir, exist_ok=True)

nlp = spacy.load("en_core_web_sm")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)
embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
)
# Process-local cache of Chroma vector stores keyed by "<hospital_id>:<user_id>",
# guarded by vector_store_lock.
hospital_vector_stores = {}
vector_store_lock = threading.Lock()


@dataclass
class Document:
    """One page of an uploaded document."""

    doc_id: int
    page_num: int
    content: str


class DocumentStatus(Enum):
    """Processing state values for an uploaded document."""

    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"


async def get_db_pool():
    """Create a new aiomysql pool from DB_CONFIG (autocommit enabled).

    The caller owns the returned pool and must close it (see ``_close_pool``).
    """
    return await aiomysql.create_pool(
        host=DB_CONFIG["host"],
        user=DB_CONFIG["user"],
        password=DB_CONFIG["password"],
        db=DB_CONFIG["database"],
        autocommit=True,
    )


async def _close_pool(pool):
    """Close an aiomysql pool (if one was created) and wait for shutdown."""
    if pool is not None:
        pool.close()
        await pool.wait_closed()


async def get_hospital_id(hospital_code):
    """Return ``hospitals.id`` for ``hospital_code``.

    Returns None when no row matches or on any database error (logged).
    """
    pool = None  # bound up-front so the finally clause is safe if pool creation fails
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(
                    "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
                    (hospital_code,),
                )
                result = await cursor.fetchone()
                return result["id"] if result else None
    except Exception as error:
        logging.error(f"Database error: {error}")
        return None
    finally:
        await _close_pool(pool)


CHUNK_SIZE = 4000
CHUNK_OVERLAP = 150
BATCH_SIZE = 250  # number of chunks sent to the vector store per add_texts call

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

# A line beginning with an ICD-style code (capital letter, 2-6 alphanumerics,
# optional trailing letter) followed by whitespace and its description.
_ICD_LINE_RE = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$")
# The same code shape, used to find codes embedded in free text.
_ICD_CODE_RE = re.compile(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b")
# Serializes read-merge-write cycles on the per-hospital ICD JSON files
# within this process (a lock created per call would protect nothing).
_icd_json_lock = threading.Lock()


# Update the JSON_PATH to be dynamic based on hospital_id
def get_icd_json_path(hospital_id):
    """Return the hospital's ICD JSON path, creating its data directory if needed."""
    hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
    os.makedirs(hospital_data_dir, exist_ok=True)
    return os.path.join(hospital_data_dir, "icd_data.json")


def _save_icd_entries(icd_data, hospital_id):
    """Merge ``icd_data`` into the hospital's ICD JSON file.

    Entries are de-duplicated by ``code``; a later entry replaces an earlier
    one with the same code. The file is rewritten atomically via a temp file
    and ``os.replace``. Errors are logged, not raised (best-effort persistence).
    """
    try:
        json_path = get_icd_json_path(hospital_id)
        with _icd_json_lock:
            existing_data = []
            if os.path.exists(json_path):
                with open(json_path, "r", encoding="utf-8") as f:
                    try:
                        existing_data = json.load(f)
                    except json.JSONDecodeError:
                        existing_data = []

            seen_codes = {item["code"]: item for item in existing_data}
            for item in icd_data:
                seen_codes[item["code"]] = item
            unique_data = list(seen_codes.values())

            temp_path = f"{json_path}.tmp"
            with open(temp_path, "w", encoding="utf-8") as f:
                json.dump(unique_data, f, indent=2, ensure_ascii=False)
            os.replace(temp_path, json_path)

        logging.info(
            f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
        )
    except Exception as e:
        logging.error(f"Error saving ICD data to JSON for hospital {hospital_id}: {e}")


def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
    """Parse ICD code entries from ``content``.

    An entry starts on a line matching ``_ICD_LINE_RE``; following non-blank
    lines are appended to its description until a blank line or the next code.

    Args:
        content: Raw text (e.g. one PDF page). None is treated as empty.
        hospital_id: Hospital whose ICD JSON file is updated.
        save_to_json: When True, merge the parsed entries into that file.

    Returns:
        list[dict]: ``{"code": ..., "description": ...}`` entries in input
        order, or ``[]`` on error (logged).
    """
    try:
        icd_data = []
        current_code = None
        current_description = []

        def _flush():
            if current_code and current_description:
                icd_data.append(
                    {
                        "code": current_code,
                        "description": " ".join(current_description).strip(),
                    }
                )

        # Iterate whole lines: the previous fixed-size character chunking could
        # split a line across two chunks and corrupt the parse at the seam.
        for raw_line in (content or "").splitlines():
            line = raw_line.strip()
            match = _ICD_LINE_RE.match(line) if line else None
            if not line or match:
                # A blank line or a new code terminates the current entry.
                _flush()
                current_code, current_description = None, []
                if match:
                    current_code, description = match.groups()
                    current_description = [description.strip()]
            elif current_code:
                current_description.append(line)

        # Add final entry if exists
        _flush()

        if save_to_json and icd_data:
            _save_icd_entries(icd_data, hospital_id)

        return icd_data
    except Exception as e:
        logging.error(f"Error in extract_and_process_icd_data: {e}")
        return []


def load_icd_entries(hospital_id):
    """Load ICD entries from the hospital's JSON file ([] if missing or unreadable)."""
    json_path = get_icd_json_path(hospital_id)
    try:
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return []
    except Exception as e:
        logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
        return []


async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
    """Extract ICD codes from ``content`` and merge them into the hospital JSON.

    ``doc_id`` and ``batch_size`` are accepted for interface compatibility but
    are not used. The parsing and file I/O run in a worker thread so the event
    loop is not blocked.
    """
    try:
        await asyncio.to_thread(
            extract_and_process_icd_data, content, hospital_id, save_to_json=True
        )
    except Exception as e:
        logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")


async def initialize_icd_vector_store(hospital_id):
    """This function is deprecated. ICD codes are now handled through JSON search."""
    logging.warning(
        "initialize_icd_vector_store is deprecated - using JSON search instead"
    )
    return None


def extract_pdf_contents(pdf_path, hospital_id):
    """Load a PDF and return per-page text plus the ICD codes found on each page.

    All codes found are merged into the hospital ICD JSON once, after every
    page has been parsed (same final file as saving page by page, far less I/O).

    Returns:
        list[dict]: ``{"page": <1-based number>, "text": str, "codes": list}``.

    Raises:
        Any loader error, after logging it.
    """
    try:
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()
        pages_content = []
        all_codes = []

        for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
            text = page.page_content.strip()
            icd_codes = extract_and_process_icd_data(
                text, hospital_id, save_to_json=False
            )
            all_codes.extend(icd_codes)
            pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})

        if all_codes:
            _save_icd_entries(all_codes, hospital_id)

        return pages_content
    except Exception as e:
        logging.error(f"Error in extract_pdf_contents: {e}")
        raise


async def insert_content_into_db(content, metadata, doc_id):
    """Insert document metadata and page texts for ``doc_id``.

    Metadata keys are truncated to 100 characters; entries with falsy values
    are skipped.

    Returns:
        ``{"message": "Success"}`` or ``{"error": <message>}`` after a rollback.
    """
    pool = await get_db_pool()
    try:
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                try:
                    metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
                    content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"

                    metadata_values = [
                        (doc_id, key[:100], value)
                        for key, value in metadata.items()
                        if value
                    ]
                    content_values = [
                        (doc_id, page_content["page"], page_content["text"])
                        for page_content in content
                    ]

                    if metadata_values:
                        await cursor.executemany(metadata_query, metadata_values)
                    if content_values:
                        await cursor.executemany(content_query, content_values)

                    await conn.commit()
                    return {"message": "Success"}
                except Exception as e:
                    await conn.rollback()
                    return {"error": str(e)}
    finally:
        # The pool is created per call; close it so connections do not leak.
        await _close_pool(pool)


async def initialize_or_load_vector_store(hospital_id, user_id="default"):
    """Return the Chroma store for ``hospital_id``, creating/loading it once.

    Stores persist under ``CHROMA_DIR/hospital_<id>`` and are cached in
    ``hospital_vector_stores`` under ``"<hospital_id>:<user_id>"``.

    Raises:
        Any construction error, after logging it.
    """
    store_key = f"{hospital_id}:{user_id}"
    try:
        with vector_store_lock:
            cached_store = hospital_vector_stores.get(store_key)
        if cached_store is not None:
            return cached_store

        hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")
        if os.path.exists(hospital_dir):
            logging.info(
                f"Loading vector store for hospital {hospital_id} and user {user_id}"
            )
        else:
            logging.info(f"Creating vector store for hospital {hospital_id}")
            os.makedirs(hospital_dir, exist_ok=True)

        # Loading and creating use the same constructor; Chroma handles both.
        vector_store = await asyncio.to_thread(
            Chroma,
            collection_name=f"hospital_{hospital_id}",
            persist_directory=hospital_dir,
            embedding_function=embeddings,
        )

        with vector_store_lock:
            # If a concurrent caller registered a store meanwhile, keep that one.
            return hospital_vector_stores.setdefault(store_key, vector_store)
    except Exception as e:
        logging.error(f"Error initializing vector store: {e}", exc_info=True)
        raise


def _invalidate_hospital_caches(hospital_id):
    """Delete Redis (db 0) entries derived from a hospital's vector store.

    Covers ``vector_store_data:<hospital_id>:*`` and the retrieval-context
    cache written by ``get_relevant_context`` (``context:hospital_<id>:*``);
    the trailing ``:`` keeps hospital 1 from matching hospital 12.
    """
    redis_client = get_redis_client()
    for pattern in (
        f"vector_store_data:{hospital_id}:*",
        f"context:hospital_{hospital_id}:*",
    ):
        for key in redis_client.scan_iter(pattern):
            redis_client.delete(key)


async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
    """Delete all vectors of ``doc_id`` from the hospital's Chroma collection.

    Also drops cached retrieval contexts so answers stop citing the document.

    Returns:
        True on success, False on any error (logged).
    """
    try:
        vector_store = await initialize_or_load_vector_store(hospital_id)

        await asyncio.to_thread(
            lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
        )
        await asyncio.to_thread(vector_store.persist)

        _invalidate_hospital_caches(hospital_id)

        logging.info(
            f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
        )
        return True
    except Exception as e:
        logging.error(f"Error deleting document vectors: {e}", exc_info=True)
        return False


async def add_document_to_index(doc_id, hospital_id):
    """Chunk a stored document's pages and add them to the hospital's vector store.

    Pages are read from ``document_pages``, split with ``text_splitter`` and
    added in batches of ``BATCH_SIZE`` with doc/hospital/page/chunk metadata.
    ICD codes found on the pages are merged into the hospital ICD JSON, and
    cached retrieval contexts for the hospital are invalidated.

    Returns:
        True on success, False on any error (logged).
    """
    pool = None
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                vector_store = await initialize_or_load_vector_store(hospital_id)
                await cursor.execute(
                    "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
                    (doc_id,),
                )
                rows = await cursor.fetchall()

        total_pages = len(rows)
        logging.info(f"Processing {total_pages} pages for document {doc_id}")

        async def process_page(page_data):
            """Split one page and parse its ICD codes (errors -> empty result)."""
            page_num, content = page_data
            try:
                icd_data = extract_and_process_icd_data(
                    content, hospital_id, save_to_json=False
                )
                chunks = text_splitter.split_text(content)
                await asyncio.sleep(0)  # Yield control
                return page_num, chunks, icd_data
            except Exception as e:
                logging.error(f"Error processing page {page_num}: {e}")
                return page_num, [], []

        page_bar = tqdm_async(total=total_pages, desc="Processing pages")
        tasks = [asyncio.create_task(process_page(row)) for row in rows]
        results = []
        for coro in asyncio.as_completed(tasks):
            results.append(await coro)
            page_bar.update(1)
        page_bar.close()

        chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)

        async def add_batch(texts, metadatas):
            """Embed and store one batch, keeping the progress bar in sync."""
            chunk_add_bar.total += len(texts)
            chunk_add_bar.refresh()
            await asyncio.to_thread(
                vector_store.add_texts, texts=texts, metadatas=metadatas
            )
            chunk_add_bar.update(len(texts))

        all_icd_data = []
        batch_texts = []
        batch_metadatas = []
        for page_num, chunks, icd_data in results:
            all_icd_data.extend(icd_data)
            for i, chunk in enumerate(chunks):
                batch_texts.append(chunk)
                batch_metadatas.append(
                    {
                        "doc_id": str(doc_id),
                        "hospital_id": str(hospital_id),
                        "page_number": str(page_num),
                        "chunk_index": str(i),
                    }
                )
                if len(batch_texts) >= BATCH_SIZE:
                    await add_batch(batch_texts, batch_metadatas)
                    batch_texts, batch_metadatas = [], []

        # Final batch
        if batch_texts:
            await add_batch(batch_texts, batch_metadatas)
        chunk_add_bar.close()

        if all_icd_data:
            logging.info(f"Saving {len(all_icd_data)} ICD codes")
            # Previously this re-parsed "" and therefore saved nothing.
            await asyncio.to_thread(_save_icd_entries, all_icd_data, hospital_id)

        await asyncio.to_thread(vector_store.persist)

        # Cached contexts were computed without this document; drop them.
        try:
            _invalidate_hospital_caches(hospital_id)
        except Exception as cache_error:
            logging.error(f"Failed to invalidate context cache: {cache_error}")

        logging.info(f"Successfully indexed document {doc_id}")
        return True
    except Exception as e:
        logging.error(f"Error adding document: {e}")
        return False
    finally:
        await _close_pool(pool)


# Words ignored when judging whether a query's keywords appear in the context.
_GENERAL_KNOWLEDGE_STOP_WORDS = frozenset(
    {
        "search", "query:", "can", "you", "some", "at", "the", "a", "an", "in",
        "on", "to", "for", "with", "by", "about", "give", "full", "is", "are",
        "was", "were", "define", "what", "how", "why", "when", "where", "year",
        "list", "form", "table", "who", "which", "me", "tell", "explain",
        "describe", "of", "and", "or", "there", "their", "please", "could",
        "would", "various", "different", "type", "types", "kind", "kinds",
        "has", "have", "had", "many", "say",
    }
)


def is_general_knowledge_question(
    query: str, context: str, conversation_context=None
) -> bool:
    """
    Determine if a question is likely a general knowledge question not covered in the documents.

    Returns False when the query overlaps (substring either way) a previous
    question in ``conversation_context``. Otherwise returns True when fewer
    than 60% of the query's keywords (non-stop-words longer than 2 chars)
    occur in ``context``, or when the query has no keywords at all.
    """
    query_lower = query.lower()
    context_lower = context.lower()

    if conversation_context:
        for interaction in conversation_context:
            prev_question = (interaction.get("question") or "").lower()
            # Parenthesized: an empty previous question must never match
            # (previously `"" in query_lower` was always True).
            if prev_question and (
                query_lower in prev_question or prev_question in query_lower
            ):
                logging.info(
                    "Question is similar to previous conversation, skipping confirmation"
                )
                return False

    key_words = [
        word
        for word in query_lower.split()
        if word not in _GENERAL_KNOWLEDGE_STOP_WORDS and len(word) > 2
    ]
    logging.info(f"Key words: {key_words}")

    if not key_words:
        logging.info("No significant keywords found, directing to general knowledge")
        return True

    matches = sum(1 for word in key_words if word in context_lower)
    logging.info(f"Matches: {matches} out of {len(key_words)} keywords")

    match_ratio = matches / len(key_words)
    logging.info(f"Match ratio: {match_ratio}")
    return match_ratio < 0.6


_TABLE_KEYWORDS = (
    "table", "tabular", "in a table", "in table format", "in tabular format",
    "chart", "data", "comparison", "as a table", "table format",
    "in rows and columns", "in a grid", "breakdown", "spreadsheet",
    "comparison table", "data table", "structured table", "tabular form",
    "table form",
)


def is_table_request(query: str) -> bool:
    """
    Determine if the user is requesting a response in tabular format.

    A case-insensitive substring match against ``_TABLE_KEYWORDS``.
    """
    query_lower = query.lower()
    return any(keyword in query_lower for keyword in _TABLE_KEYWORDS)


# Paragraph-level markers recognised by ensure_html_response.
_BULLET_ITEM_RE = re.compile(r"^\s*[\*\-\•]\s")
_NUMBERED_ITEM_RE = re.compile(r"^\s*\d+[.)]\s")
_HTML_TAG_RE = re.compile(r"</?[a-zA-Z][^>]*>")
_BLOCK_LEVEL_TAGS = (
    "<p>", "<p ", "<div", "<ul", "<ol", "<table", "<h1", "<h2", "<h3",
    "<h4", "<blockquote", "<pre",
)


def _list_items(paragraph, marker_re):
    """Split a list paragraph into item texts; unmarked lines continue the previous item."""
    items = []
    for line in paragraph.splitlines():
        if marker_re.match(line):
            items.append(marker_re.sub("", line, count=1).strip())
        elif line.strip() and items:
            items[-1] = f"{items[-1]} {line.strip()}"
    return items


def ensure_html_response(text: str) -> str:
    """
    Ensure the response is properly formatted in HTML.

    * No HTML tags at all: split on blank lines; bullet ("*", "-", "•") and
      numbered paragraphs become <ul>/<ol> lists, everything else <p>.
    * Tags present but no block-level element: wrap each paragraph in <p>.
    * Otherwise the text is returned unchanged.

    NOTE(review): reconstructed from a damaged copy of this function (its
    HTML string literals had been stripped); compare with VCS history.
    """
    if not text:
        return ""

    if not _HTML_TAG_RE.search(text):
        html_parts = []
        open_list = None  # "ul", "ol" or None
        for para in text.split("\n\n"):
            if not para.strip():
                continue
            if _BULLET_ITEM_RE.match(para):
                list_tag, marker_re = "ul", _BULLET_ITEM_RE
            elif _NUMBERED_ITEM_RE.match(para):
                list_tag, marker_re = "ol", _NUMBERED_ITEM_RE
            else:
                list_tag, marker_re = None, None

            # Close/open list containers only when the paragraph kind changes.
            if list_tag != open_list:
                if open_list:
                    html_parts.append(f"</{open_list}>")
                if list_tag:
                    html_parts.append(f"<{list_tag}>")
                open_list = list_tag

            if list_tag:
                html_parts.extend(
                    f"<li>{item}</li>" for item in _list_items(para, marker_re)
                )
            else:
                html_parts.append(f"<p>{para}</p>")

        if open_list:
            html_parts.append(f"</{open_list}>")
        return "".join(html_parts)

    if not any(tag in text for tag in _BLOCK_LEVEL_TAGS):
        return "".join(
            f"<p>{para}</p>" for para in text.split("\n\n") if para.strip()
        )

    return text


class HybridConversationManager:
    """
    Hybrid conversation manager that uses Redis for RAG-based conversations
    and in-memory storage for general knowledge conversations.
    """

    def __init__(self, redis_client, ttl=3600, max_history_items=5):
        """
        Args:
            redis_client: Client used for RAG histories (JSON strings with TTL).
            ttl: Redis expiry in seconds for each stored history.
            max_history_items: Maximum interactions retained per history.
        """
        self.redis_client = redis_client
        self.ttl = ttl
        self.max_history_items = max_history_items

        # For general knowledge questions (in-memory, this process only).
        # NOTE(review): keys are never expired, so memory grows with the
        # number of distinct user/hospital/session combinations.
        self.general_knowledge_histories = {}
        self.lock = Lock()

    def _get_redis_key(self, user_id, hospital_id, session_id=None):
        """Create Redis key for document-based conversations."""
        if session_id:
            return f"conv_history:{user_id}:{hospital_id}:{session_id}"
        return f"conv_history:{user_id}:{hospital_id}"

    def _get_memory_key(self, user_id, hospital_id, session_id=None):
        """Create memory key for general knowledge conversations."""
        if session_id:
            return f"{user_id}:{hospital_id}:{session_id}"
        return f"{user_id}:{hospital_id}"

    async def add_rag_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Append a document-based (RAG) interaction to its Redis history.

        Only the last ``max_history_items`` entries are kept and the key's TTL
        is refreshed. Redis write errors are logged, not raised.
        """
        key = self._get_redis_key(user_id, hospital_id, session_id)
        history = self.get_rag_history(user_id, hospital_id, session_id)

        history.append(
            {
                "question": question,
                "answer": answer,
                "timestamp": time.time(),
                "type": "rag",  # Mark as RAG-based interaction
            }
        )
        history = history[-self.max_history_items :]

        try:
            self.redis_client.setex(key, self.ttl, json.dumps(history))
            logging.info(
                f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}"
            )
        except Exception as e:
            logging.error(f"Failed to store RAG interaction in Redis: {e}")

    def add_general_knowledge_interaction(
        self, user_id, hospital_id, question, answer, session_id=None
    ):
        """Append a general-knowledge interaction to the in-memory history."""
        key = self._get_memory_key(user_id, hospital_id, session_id)

        with self.lock:
            history = self.general_knowledge_histories.setdefault(key, [])
            history.append(
                {
                    "question": question,
                    "answer": answer,
                    "timestamp": time.time(),
                    "type": "general",  # Mark as general knowledge interaction
                }
            )
            # Keep only the most recent interactions
            if len(history) > self.max_history_items:
                self.general_knowledge_histories[key] = history[
                    -self.max_history_items :
                ]

        logging.info(
            f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}"
        )

    def get_rag_history(self, user_id, hospital_id, session_id=None):
        """Return the RAG history from Redis ([] when missing or on error)."""
        key = self._get_redis_key(user_id, hospital_id, session_id)
        try:
            history_data = self.redis_client.get(key)
            return json.loads(history_data) if history_data else []
        except Exception as e:
            logging.error(f"Failed to retrieve RAG history from Redis: {e}")
            return []

    def get_general_knowledge_history(self, user_id, hospital_id, session_id=None):
        """Return a copy of the in-memory general-knowledge history."""
        key = self._get_memory_key(user_id, hospital_id, session_id)
        with self.lock:
            return self.general_knowledge_histories.get(key, []).copy()

    def get_combined_history(self, user_id, hospital_id, session_id=None):
        """Return up to ``max_history_items`` interactions from both stores, newest first."""
        rag_history = self.get_rag_history(user_id, hospital_id, session_id)
        general_history = self.get_general_knowledge_history(
            user_id, hospital_id, session_id
        )
        combined_history = rag_history + general_history
        combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
        return combined_history[: self.max_history_items]

    def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2):
        """Return the ``window_size`` most recent interactions, oldest first."""
        combined_history = self.get_combined_history(user_id, hospital_id, session_id)
        sorted_history = sorted(combined_history, key=lambda x: x.get("timestamp", 0))
        return sorted_history[-window_size:] if sorted_history else []

    def clear_history(self, user_id, hospital_id, session_id=None):
        """Clear conversation history from both stores.

        ``session_id`` (optional, default None) selects a session-scoped
        history; without it the non-session history is cleared as before.
        """
        redis_key = self._get_redis_key(user_id, hospital_id, session_id)
        try:
            self.redis_client.delete(redis_key)
        except Exception as e:
            logging.error(f"Failed to clear Redis history: {e}")

        memory_key = self._get_memory_key(user_id, hospital_id, session_id)
        with self.lock:
            self.general_knowledge_histories.pop(memory_key, None)


class ContextMapper:
    """Enhanced context mapping using shared model manager"""

    def __init__(self):
        self.model_manager = ModelManager()
        self.context_cache = {}
        self.similarity_threshold = 0.6

    def get_semantic_similarity(self, text1, text2):
        """Get semantic similarity using global model manager"""
        return self.model_manager.get_semantic_similarity(text1, text2)

    def extract_key_concepts(self, text):
        """Return the unique entities, noun chunks and noun/proper-noun/verb
        tokens spaCy finds in ``text`` (order not guaranteed)."""
        doc = nlp(text)
        concepts = [ent.text for ent in doc.ents]
        concepts.extend(chunk.text for chunk in doc.noun_chunks)
        concepts.extend(
            token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"]
        )
        return list(set(concepts))

    def map_conversation_context(
        self, current_query, conversation_history, context_window=3
    ):
        """Append related concepts from recent history to ``current_query``.

        A history concept is "related" when its similarity to any concept of
        the query exceeds ``similarity_threshold``. Each related concept is
        listed once, in first-seen order.
        """
        if not conversation_history:
            return current_query

        context_concepts = []
        for interaction in conversation_history[-context_window:]:
            context_concepts.extend(self.extract_key_concepts(interaction["question"]))
            context_concepts.extend(self.extract_key_concepts(interaction["answer"]))
        # De-duplicate so each concept is compared (and reported) once.
        context_concepts = list(dict.fromkeys(context_concepts))

        query_concepts = self.extract_key_concepts(current_query)

        related_concepts = []
        for q_concept in query_concepts:
            for c_concept in context_concepts:
                if (
                    self.get_semantic_similarity(q_concept, c_concept)
                    > self.similarity_threshold
                ):
                    related_concepts.append(c_concept)
        related_concepts = list(dict.fromkeys(related_concepts))

        if related_concepts:
            return f"{current_query} in context of {', '.join(related_concepts)}"
        return current_query


# Initialize the context mapper
context_mapper = ContextMapper()


async def generate_contextual_query(
    question: str,
    user_id: str,
    hospital_id: int,
    conversation_manager,
    session_id=None,
) -> str:
    """Rewrite ``question`` into a standalone search query using the last interaction.

    Returns ``question`` unchanged when there is no conversation manager, no
    history, or the OpenAI call fails (logged). ``session_id`` (optional)
    selects a session-scoped history.
    """
    if conversation_manager is None:
        return question

    context_window = conversation_manager.get_context_window(
        user_id, hospital_id, session_id
    )
    if not context_window:
        return question

    last_interaction = context_window[-1]
    enhanced_context = f"""
    Previous question: {last_interaction['question']}
    Previous answer: {last_interaction['answer']}
    Current question: {question}

    Please generate a detailed search query that combines the context from the previous answer
    with the current question, especially if the current question uses words like 'it', 'this',
    'that', or asks for more details about the previous topic.
    """

    try:
        response = await asyncio.to_thread(
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a context-aware query generator.",
                    },
                    {"role": "user", "content": enhanced_context},
                ],
                temperature=0.3,
                max_tokens=150,
            )
        )
        contextual_query = response.choices[0].message.content.strip()
        logging.info(f"Enhanced contextual query: {contextual_query}")
        return contextual_query
    except Exception as e:
        logging.error(f"Error generating contextual query: {e}")
        return question


_REFERENTIAL_LEMMAS = frozenset(
    {"it", "this", "that", "these", "those", "they", "he", "she", "about", "more"}
)
_ELABORATION_WORDS = ("more", "about", "elaborate", "explain")


def is_follow_up(current_question: str, conversation_history: list) -> bool:
    """Heuristically decide whether the question follows up on the last interaction.

    True if any holds: semantic similarity > 0.3 with the last Q+A, a
    referential lemma in the question, concept Jaccard overlap > 0.2, or an
    elaboration word appears as a substring of the question.
    """
    if not conversation_history:
        return False

    last_interaction = conversation_history[-1]
    last_text = f"{last_interaction['question']} {last_interaction['answer']}"

    similarity = context_mapper.get_semantic_similarity(current_question, last_text)

    doc = nlp(current_question.lower())
    has_referential = any(token.lemma_ in _REFERENTIAL_LEMMAS for token in doc)

    current_concepts = set(context_mapper.extract_key_concepts(current_question))
    last_concepts = set(context_mapper.extract_key_concepts(last_text))
    concept_overlap = (
        len(current_concepts & last_concepts) / len(current_concepts | last_concepts)
        if current_concepts
        else 0
    )

    question_lower = current_question.lower()
    return (
        similarity > 0.3
        or has_referential
        or concept_overlap > 0.2
        or any(word in question_lower for word in _ELABORATION_WORDS)
    )


async def get_relevant_context(question, hospital_id, doc_id=None):
    """Retrieve document context for ``question`` from the hospital's store.

    Results are cached in Redis db 0 under
    ``context:hospital_<id>[:doc_<doc_id>]:<normalized question>`` for
    ``CACHE_TTL["vector_store"]``. Chunks are ordered by page and chunk index
    and joined with blank lines. Returns "" when nothing is found or on error.
    """
    try:
        cache_key = f"context:hospital_{hospital_id}"
        if doc_id:
            cache_key += f":doc_{doc_id}"
        cache_key += f":{question.lower().strip()}"

        redis_client = get_redis_client()
        cached_context = redis_client.get(cache_key)
        if cached_context:
            logging.info(f"Cache hit for key: {cache_key}")
            return (
                cached_context.decode("utf-8")
                if isinstance(cached_context, bytes)
                else cached_context
            )

        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return ""

        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": 5,  # documents returned
                "fetch_k": 10,  # candidates considered by MMR
                # Closer to 1 favours relevance, closer to 0 favours diversity.
                "lambda_mult": 0.8,
                # NOTE(review): likely ignored for search_type="mmr" — confirm.
                "score_threshold": 0.7,
                "filter": {"doc_id": str(doc_id)}
                if doc_id
                else {"hospital_id": str(hospital_id)},
            },
        )

        docs = await asyncio.to_thread(retriever.get_relevant_documents, question)
        if not docs:
            logging.info(f"No relevant documents found for question: {question}")
            return ""

        sorted_docs = sorted(
            docs,
            key=lambda x: (
                int(x.metadata.get("page_number", 0)),
                int(x.metadata.get("chunk_index", 0)),
            ),
        )
        context = "\n\n".join(doc.page_content for doc in sorted_docs)

        try:
            redis_client.setex(
                cache_key,
                int(CACHE_TTL["vector_store"].total_seconds()),
                context.encode("utf-8") if isinstance(context, str) else context,
            )
            logging.info(f"Cached context for key: {cache_key}")
        except Exception as cache_error:
            logging.error(f"Failed to cache context: {cache_error}")

        return context
    except Exception as e:
        logging.error(f"Error getting relevant context: {e}")
        return ""


def format_conversation_context(conv_history):
    """Format conversation history into a string"""
    if not conv_history:
        return "No previous conversation."
    return "\n".join(
        f"Q: {interaction['question']}\nA: {interaction['answer']}"
        for interaction in conv_history
    )


def get_icd_context_from_question(question, hospital_id):
    """Return "CODE: description" lines for known ICD codes mentioned in ``question``.

    Codes are matched case-insensitively, listed once each in order of first
    mention.
    """
    icd_data = load_icd_entries(hospital_id)
    if not icd_data:
        return ""

    # First entry wins for a code, matching the previous linear scan.
    by_code = {}
    for entry in icd_data:
        by_code.setdefault(entry["code"], entry)

    matches = []
    seen = set()
    for code in _ICD_CODE_RE.findall(question.upper()):
        entry = by_code.get(code)
        if entry is not None and code not in seen:
            matches.append(f"{entry['code']}: {entry['description']}")
            seen.add(code)
    return "\n".join(matches)


def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70):
    """Return "CODE: description" lines for ICD entries whose description
    fuzzy-matches ``question`` (rapidfuzz score >= ``threshold``, best ``top_n``)."""
    icd_data = load_icd_entries(hospital_id)
    if not icd_data:
        return ""

    descriptions = [entry["description"] for entry in icd_data]
    matches = process.extract(
        question, descriptions, limit=top_n, score_cutoff=threshold
    )
    # For list choices each match is (choice, score, index); the index points
    # at the exact matched entry even when descriptions repeat.
    return "\n".join(
        f"{icd_data[index]['code']}: {icd_data[index]['description']}"
        for _choice, _score, index in matches
    )


# async def generate_answer_with_rag(
#     question,
#     hospital_id,
#     client,
#     doc_id=None,
#     user_id="default",
#     conversation_manager=None,
#     session_id=None,
# ):
#     """Generate an answer using RAG with improved conversation flow"""
#     try:
#         # Continue with regular RAG processing if not an ICD code or if no ICD match found
#         html_instruction = """
#         IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
#         - Use <p> tags for paragraphs
#         - Use
for quoted text # - Use for bold text and for emphasis # """ # table_instruction = """ # - For tables, use proper HTML table structure: ## #
# """ # # Get conversation history first # conv_history = ( # conversation_manager.get_context_window(user_id, hospital_id, session_id) # if conversation_manager # else [] # ) # # Get contextual query and relevant context first # contextual_query = await generate_contextual_query( # question, user_id, hospital_id, conversation_manager # ) # # Track ICD context across conversation # icd_context = {} # if conv_history: # # Extract ICD code from previous interaction # last_answer = conv_history[-1].get("answer", "") # icd_codes = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", last_answer) # if icd_codes: # icd_context["last_code"] = icd_codes[0] # # Check if current question is about a previously discussed ICD code # is_icd_followup = False # if icd_context.get("last_code"): # followup_indicators = [ # "what causes", # "what is causing", # "why", # "how", # "symptoms", # "treatment", # "diagnosis", # "causes", # "effects", # "complications", # "risk factors", # "prevention", # "prognosis", # "this", # "disease", # "that", # "it", # ] # is_icd_followup = any( # indicator in question.lower() for indicator in followup_indicators # ) # if is_icd_followup: # # Add the previous ICD code context to the current question # icd_exact_context = get_icd_context_from_question( # icd_context["last_code"], hospital_id # ) # icd_fuzzy_context = get_fuzzy_icd_context( # f"{icd_context['last_code']} {question}", hospital_id # ) # else: # icd_exact_context = get_icd_context_from_question(question, hospital_id) # icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id) # else: # icd_exact_context = get_icd_context_from_question(question, hospital_id) # icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id) # # Get contextual query and relevant context # contextual_query = await generate_contextual_query( # question, user_id, hospital_id, conversation_manager # ) # doc_context = await get_relevant_context(contextual_query, hospital_id, doc_id) # # Combine context with 
priority for ICD information # context_parts = [] # if is_icd_followup: # context_parts.append( # f"## Previous ICD Code Context\nContinuing discussion about: {icd_context['last_code']}" # ) # if icd_exact_context: # context_parts.append("## ICD Code Match\n" + icd_exact_context) # if icd_fuzzy_context: # context_parts.append("## Related ICD Suggestions\n" + icd_fuzzy_context) # if doc_context: # context_parts.append("## Document Context\n" + doc_context) # context = "\n\n".join(context_parts) # # Initialize follow-up detection # is_follow_up = False # # Check if this is a follow-up question # if conv_history: # last_interaction = conv_history[-1] # last_question = last_interaction["question"].lower() # last_answer = last_interaction.get("answer", "").lower() # current_question = question.lower() # # Define meaningful keywords that indicate entity-related follow-ups # entity_related_keywords = { # "achievements", # "awards", # "accomplishments", # "work", # "contributions", # "career", # "company", # "products", # "life", # "background", # "education", # "role", # "experience", # "history", # "details", # "places", # "place", # "information", # "facts", # "about", # "birth", # "death", # "family", # "books", # "projects", # "population", # } # # Check if question is asking about attributes/achievements of previously discussed entity # has_entity_attribute = any( # word in current_question.split() for word in entity_related_keywords # ) # # Extract entities from last answer to maintain context # def extract_entities(text): # # Split into words and get potential entities (capitalized words) # words = text.split() # entities = set() # current_entity = [] # for word in words: # if word[0].isupper(): # current_entity.append(word) # elif current_entity: # if len(current_entity) > 0: # entities.add(" ".join(current_entity)) # current_entity = [] # if current_entity: # entities.add(" ".join(current_entity)) # return entities # last_entities = extract_entities(last_answer) 
# # Check for referential words # referential_words = { # "it", # "this", # "that", # "these", # "those", # "they", # "their", # "he", # "she", # "him", # "her", # "his", # "hers", # "them", # "there", # "such", # "its", # } # has_referential = any( # word in referential_words for word in current_question.split() # ) # # Calculate term overlap with both question and answer context # def get_significant_terms(text): # stop_words = { # "what", # "when", # "where", # "who", # "why", # "how", # "is", # "are", # "was", # "were", # "be", # "been", # "the", # "a", # "an", # "in", # "on", # "at", # "to", # "for", # "of", # "with", # "by", # "about", # "as", # "tell", # "me", # "please", # } # return set( # word # for word in text.split() # if len(word) > 2 and word.lower() not in stop_words # ) # current_terms = get_significant_terms(current_question) # last_terms = get_significant_terms(last_question) # answer_terms = get_significant_terms(last_answer) # # Include terms from both question and answer in context # all_prev_terms = last_terms | answer_terms # term_overlap = len(current_terms & all_prev_terms) # total_terms = len(current_terms | all_prev_terms) # term_similarity = term_overlap / total_terms if total_terms > 0 else 0 # # Enhanced follow-up detection combining multiple signals # is_follow_up = ( # has_referential # or term_similarity # >= 0.2 # Lower threshold when including answer context # or ( # has_entity_attribute and bool(last_entities) # ) # Check if asking about attributes of known entity # or ( # last_interaction.get("type") == "general" # and term_similarity >= 0.15 # ) # ) # logging.info(f"Follow-up analysis enhanced:") # logging.info(f"- Referential words: {has_referential}") # logging.info(f"- Term similarity: {term_similarity:.2f}") # logging.info(f"- Entity attribute question: {has_entity_attribute}") # logging.info(f"- Last entities found: {last_entities}") # logging.info(f"- Is follow-up: {is_follow_up}") # # For entirely new topics (not 
follow-ups), use is_general_knowledge_question # if not is_follow_up: # if not context or is_general_knowledge_question(question, context, conv_history): # logging.info("No relevant context or general knowledge question detected") # answer = "# #{table_title} ## {table_headers} # # # # {table_rows} # #No relevant information found in the hospital documents for this query.
" # if conversation_manager: # await conversation_manager.add_rag_interaction( # user_id, hospital_id, question, answer, session_id # ) # return {"answer": answer}, 200 # # Generate RAG answer with enhanced context # prompt_template = f"""Based on the following context and conversation history, provide a detailed answer to the question. # Previous conversation: # {format_conversation_context(conv_history)} # Context from documents: # {context} # Current question: {question} # Instructions: # 1. When providing medical codes (ICD, CPT, etc.): # - Always use the ICD codes listed in the sections titled "ICD Code Match" and "Related ICD Suggestions" from the context. # - Do not use or invent ICD codes from your own training knowledge unless they appear in the provided context. # - If multiple codes are relevant, return the one that best matches the user’s question. If unsure, return multiple options in HTML list format. # - Remove all decimal points (e.g., use 'A150' instead of 'A15.0') # - Format the response as: 'The medical code for [condition] is [code]
' # 2. Address the current question while maintaining conversation continuity # 3. Resolve any ambiguous references using conversation history # 4. Format the response in clear HTML # {html_instruction} # {table_instruction if is_table_request(question) else ""} # """ # response = await asyncio.to_thread( # lambda: client.chat.completions.create( # model="gpt-3.5-turbo-16k", # messages=[ # {"role": "system", "content": prompt_template}, # {"role": "user", "content": question}, # ], # temperature=0.3, # max_tokens=1000, # ) # ) # answer = ensure_html_response(response.choices[0].message.content) # logging.info(f"Generated RAG answer for question: {question}") # # Store interaction in history # if conversation_manager: # await conversation_manager.add_rag_interaction( # user_id, hospital_id, question, answer, session_id # ) # return {"answer": answer}, 200 # except Exception as e: # logging.error(f"Error in generate_answer_with_rag: {e}") # return {"answer": f"Error: {str(e)}
"}, 500 async def generate_answer_with_rag( question, hospital_id, client, doc_id=None, user_id="default", conversation_manager=None, session_id=None, ): """Answer a question for one hospital with retrieval-augmented generation.

Context is assembled from four sources:
- ICD exact matches (get_icd_context_from_question).
- ICD fuzzy description matches (get_fuzzy_icd_context).
- Document chunks retrieved by get_relevant_context for the query produced by generate_contextual_query.
- Recent turns from conversation_manager.get_context_window.

If the previous answer contained an ICD-like code and the question contains one of the follow-up indicator words, that code seeds both ICD lookups. The prompt is sent to the OpenAI chat model gpt-3.5-turbo (temperature 0.1, max_tokens 1000) via asyncio.to_thread, and the reply is passed through ensure_html_response.

Returns a (payload, status) tuple:
- ({"answer": html}, 200) on success.
- The fixed "No relevant information found" answer with status 200, in two cases: when the retrieved document context is empty or whitespace-only; or when the question is not judged a follow-up and either there is no context or is_general_knowledge_question flags it.
- ({"answer": "Error: ..."}, 500) on any exception.

When a manager is given, successful and no-information answers are recorded with conversation_manager.add_rag_interaction.

NOTE(review): is_any_table_request is computed but the prompt only checks is_table_request(question), so a table request detected only in the contextual query gets no table instructions. Confirm intent.
NOTE(review): the no-information exit tests doc_context, not the combined context, so ICD-only matches never reach the model.
NOTE(review): logging.info({context}) logs a one-element set holding the whole context, and the spaCy similarity score is computed only for logging.
NOTE(review): the HTML tags inside the instruction and fallback strings appear stripped in this copy of the file. Verify them against version control.""" try: html_instruction = """ IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content: - Usetags for paragraphs - Use
,
tags for headings and subheadings - Use
,
- tags for bullet points - Use
,
- tags for numbered lists - Use
for quoted text - Use for bold text and for emphasis - If no relevant information is found in the provided context, respond ONLY with:No relevant information found in the hospital documents for this query.
""" table_instruction = """ - For tables, use proper HTML table structure:""" # Get conversation history conv_history = ( conversation_manager.get_context_window(user_id, hospital_id, session_id) if conversation_manager else [] ) # Generate contextual query contextual_query = await generate_contextual_query( question, user_id, hospital_id, conversation_manager ) # Check for table requests in both original question and contextual query is_original_table_request = is_table_request(question) is_contextual_table_request = is_table_request(contextual_query) if contextual_query != question else False is_any_table_request = is_original_table_request or is_contextual_table_request # Retrieve document context with strict relevance doc_context = await get_relevant_context(contextual_query, hospital_id, doc_id) # Handle ICD context icd_context = {} if conv_history: last_answer = conv_history[-1].get("answer", "") icd_codes = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", last_answer) if icd_codes: icd_context["last_code"] = icd_codes[0] is_icd_followup = False if icd_context.get("last_code"): followup_indicators = [ "what causes", "what is causing", "why", "how", "symptoms", "treatment", "diagnosis", "causes", "effects", "complications", "risk factors", "prevention", "prognosis", "this", "disease", "that", "it", ] is_icd_followup = any(indicator in question.lower() for indicator in followup_indicators) if is_icd_followup: icd_exact_context = get_icd_context_from_question(icd_context["last_code"], hospital_id) icd_fuzzy_context = get_fuzzy_icd_context(f"{icd_context['last_code']} {question}", hospital_id) else: icd_exact_context = get_icd_context_from_question(question, hospital_id) icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id) else: icd_exact_context = get_icd_context_from_question(question, hospital_id) icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id) # Combine context with priority for ICD information context_parts = [] if is_icd_followup:
context_parts.append(f"## Previous ICD Code Context\nContinuing discussion about: {icd_context['last_code']}") if icd_exact_context: context_parts.append("## ICD Code Match\n" + icd_exact_context) if icd_fuzzy_context: context_parts.append("## Related ICD Suggestions\n" + icd_fuzzy_context) if doc_context: context_parts.append("## Document Context\n" + doc_context) context = "\n\n".join(context_parts) logging.info(f"Total context length: {len(context.split())} words") logging.info({context}) # Check context length if len(doc_context.split()) == 0: logging.info("A") logging.info(f"Context too short ({len(context.split())} words), returning no information found") answer = "
{table_title} {table_headers} {table_rows}No relevant information found in the hospital documents for this query.
" if conversation_manager: await conversation_manager.add_rag_interaction( user_id, hospital_id, question, answer, session_id ) return {"answer": answer}, 200 # Check follow-up status with stricter criteria is_follow_up = False if conv_history: last_interaction = conv_history[-1] last_question = last_interaction["question"].lower() last_answer = last_interaction.get("answer", "").lower() current_question = question.lower() # Define entity-related keywords entity_related_keywords = { "achievements", "awards", "accomplishments", "work", "contributions", "career", "company", "products", "life", "background", "education", "role", "experience", "history", "details", "places", "place", "information", "facts", "about", "birth", "death", "family", "books", "projects", "population", } has_entity_attribute = any(word in current_question.split() for word in entity_related_keywords) # Extract entities using spaCy for better precision doc_last = nlp(f"{last_question} {last_answer}") doc_current = nlp(current_question) last_entities = {ent.text.lower() for ent in doc_last.ents} current_entities = {ent.text.lower() for ent in doc_current.ents} # Check for referential words referential_words = { "it", "this", "that", "these", "those", "they", "their", "he", "she", "him", "her", "his", "hers", "them", "there", "such", "its", } has_referential = any(word in referential_words for word in current_question.split()) # Calculate term overlap with stricter criteria def get_significant_terms(text): stop_words = { "what", "when", "where", "who", "why", "how", "is", "are", "was", "were", "be", "been", "the", "a", "an", "in", "on", "at", "to", "for", "of", "with", "by", "about", "as", "tell", "me", "please", } return set(word for word in text.split() if len(word) > 2 and word.lower() not in stop_words) current_terms = get_significant_terms(current_question) last_terms = get_significant_terms(last_question) answer_terms = get_significant_terms(last_answer) all_prev_terms = last_terms |
answer_terms term_overlap = len(current_terms & all_prev_terms) total_terms = len(current_terms | all_prev_terms) term_similarity = term_overlap / total_terms if total_terms > 0 else 0 # Use spaCy similarity for follow-up detection similarity = doc_current.similarity(doc_last) is_follow_up = ( has_referential or term_similarity >= 0.4 # Stricter threshold or (has_entity_attribute and bool(last_entities & current_entities)) or (last_interaction.get("type") == "general" and term_similarity >= 0.3) ) logging.info(f"Follow-up analysis:") logging.info(f"- Referential words: {has_referential}") logging.info(f"- Term similarity: {term_similarity:.2f}") logging.info(f"- Entity overlap: {bool(last_entities & current_entities)}") logging.info(f"- SpaCy similarity: {similarity:.2f}") logging.info(f"- Is follow-up: {is_follow_up}") # Check if question lacks relevant context if not is_follow_up: if not context or is_general_knowledge_question(question, context, conv_history): logging.info("B") logging.info("No relevant context or general knowledge question detected") answer = "No relevant information found in the hospital documents for this query.
" if conversation_manager: await conversation_manager.add_rag_interaction( user_id, hospital_id, question, answer, session_id ) return {"answer": answer}, 200 # Generate answer with strict document-based instruction prompt_template = f"""You are a document-based question-answering system. You must ONLY use the provided context and conversation history to answer the question. Do NOT use any external knowledge, assumptions, or definitions beyond the given context, even if the query seems familiar. If the context does not contain sufficient information to directly answer the question, respond ONLY with:No relevant information found in the hospital documents for this query.
Previous conversation: {format_conversation_context(conv_history)} Context from documents: {context} Current question: {question} Instructions: 1. When providing medical codes (ICD, CPT, etc.): - ONLY use the ICD codes listed in the sections titled "ICD Code Match" and "Related ICD Suggestions" from the context. - Do not use or invent ICD codes from your own knowledge. - If multiple codes are relevant, return the one that best matches the user’s question. If unsure, return multiple options in HTML list format. - Remove all decimal points (e.g., use 'A150' instead of 'A15.0'). - Format the response as: 'The medical code for [condition] is [code]
'. 2. Address the current question while maintaining conversation continuity. 3. Resolve any ambiguous references using conversation history. 4. Format the response in clear HTML. 5. Strictly adhere and provide a detailed answer only from the {context}.No extra knowledge or assumptions. 6. Every answer must be detailed and only from the provided context. 7. Answer should be 400-500 words long. {html_instruction} {table_instruction if is_table_request(question) else ""} """ response = await asyncio.to_thread( lambda: client.chat.completions.create( model="gpt-3.5-turbo", messages=[ {"role": "system", "content": prompt_template}, {"role": "user", "content": question}, ], temperature=0.1, # Lower temperature for strict adherence max_tokens=1000, ) ) answer = ensure_html_response(response.choices[0].message.content) logging.info(f"Generated RAG answer for question: {question}") # Store interaction in history if conversation_manager: await conversation_manager.add_rag_interaction( user_id, hospital_id, question, answer, session_id ) return {"answer": answer}, 200 except Exception as e: logging.error(f"Error in generate_answer_with_rag: {e}") return {"answer": f"Error: {str(e)}
"}, 500 async def load_existing_vector_stores(): """Load existing Chroma vector stores for each hospital""" pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: try: await cursor.execute("SELECT DISTINCT id FROM hospitals") hospital_ids = [row[0] for row in await cursor.fetchall()] for hospital_id in hospital_ids: try: await initialize_or_load_vector_store(hospital_id) except Exception as e: logging.error( f"Failed to load vector store for hospital {hospital_id}: {e}" ) continue except Exception as e: logging.error(f"Error loading vector stores: {e}") async def get_failed_page(doc_id): pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: try: await cursor.execute( "SELECT failed_page FROM documents WHERE id = %s", (doc_id,) ) result = await cursor.fetchone() return result[0] if result and result[0] else None except Exception as e: logging.error(f"Database error checking failed_page: {e}") return None async def update_document_status(doc_id, status, failed_page=None): """Update document status with enum validation""" if isinstance(status, str): status = DocumentStatus[status.upper()].value pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: try: if failed_page: await cursor.execute( "UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s", (status, failed_page, doc_id), ) else: await cursor.execute( "UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s", (status, doc_id), ) await conn.commit() return True except Exception as e: logging.error(f"Database update error: {e}") return False thread_pool = ThreadPoolExecutor(max_workers=10) def async_to_sync(coroutine): """Helper function to run async code in sync context""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete(coroutine) finally: loop.close() @app.route("/flask-api/", methods=["GET"]) 
def health_check():
    """Liveness endpoint: log the caller and always return {"status": "ok"} with 200."""
    access_logger.info(f"Health check request received from {request.remote_addr}")
    return jsonify({"status": "ok"}), 200


@app.route("/flask-api/process-pdf", methods=["POST"])
def process_pdf():
    """Extract, store and index an uploaded PDF for a hospital.

    Form fields: ``pdf`` (file), ``hospital_id`` and ``doc_id``. The pipeline
    runs on ``thread_pool``, but this handler blocks on its result, so the
    response reports the final outcome. The document's processed_status is set
    to "processing" and then to "processed" or "failed".

    Returns:
        (json, 200) on success, (json, 400) when a field is missing, and
        (json, 500) when any pipeline step fails.
    """
    access_logger.info(f"PDF processing request received from {request.remote_addr}")
    file_path = None
    try:
        file = request.files.get("pdf")
        hospital_id = request.form.get("hospital_id")
        doc_id = request.form.get("doc_id")
        logging.info(
            f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}"
        )
        if not all([file, hospital_id, doc_id]):
            return jsonify({"error": "Missing required parameters"}), 400

        def process_in_background():
            """Save, extract, insert and index the upload; return True on success."""
            nonlocal file_path
            try:
                async_to_sync(update_document_status(doc_id, "processing"))
                logging.info(f"Starting processing of document {doc_id}")

                # The upload name is client-controlled. Keep only its final
                # path component so the temporary file stays inside uploads_dir.
                safe_name = os.path.basename(str(file.filename))
                filename = f"doc_{doc_id}_{safe_name}"
                file_path = os.path.join(uploads_dir, filename)
                with open(file_path, "wb") as f:
                    file.save(f)

                logging.info("Extracting PDF contents...")
                content = extract_pdf_contents(file_path, int(hospital_id))

                logging.info("Inserting content into database...")
                metadata = {"filename": filename}
                result = async_to_sync(
                    insert_content_into_db(content, metadata, doc_id)
                )
                if "error" in result:
                    # NOTE(review): page 1 is recorded no matter where insertion failed.
                    async_to_sync(update_document_status(doc_id, "failed", 1))
                    return False

                logging.info("Creating embeddings and indexing...")
                if async_to_sync(add_document_to_index(doc_id, hospital_id)):
                    logging.info("Document processing completed successfully")
                    async_to_sync(update_document_status(doc_id, "processed"))
                    return True

                logging.error("Document processing failed during indexing")
                async_to_sync(update_document_status(doc_id, "failed"))
                return False
            except Exception as e:
                logging.error(f"Processing error: {e}")
                async_to_sync(update_document_status(doc_id, "failed"))
                return False
            finally:
                # Always remove the temporary copy of the upload.
                if file_path and os.path.exists(file_path):
                    try:
                        os.remove(file_path)
                    except Exception as e:
                        logging.error(f"Error removing temporary file: {e}")

        success = thread_pool.submit(process_in_background).result()
        if success:
            return jsonify({"message": "Document processed successfully"}), 200
        return jsonify({"error": "Document processing failed"}), 500
    except Exception as e:
        logging.error(f"API error: {e}")
        if file_path and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except Exception as file_e:
                logging.error(f"Error removing temporary file: {file_e}")
        return jsonify({"error": str(e)}), 500


# Initialize the hybrid conversation manager
redis_client = get_redis_client()
conversation_manager = HybridConversationManager(redis_client)


@app.route("/flask-api/generate-answer", methods=["POST"])
def rag_answer_api():
    """Sync API endpoint for RAG-based question answering with conversation history.

    JSON body: ``question`` (required), ``hospital_code``, and optionally
    ``doc_id``, ``user_id`` (default "default"), ``session_id`` and
    ``original_query``. The question is stripped and lower-cased.
    - When ``original_query`` is non-empty, a fixed message is returned
      without running retrieval.
    - Otherwise the answer comes from generate_answer_with_rag.

    Returns 400 for a missing, empty or non-JSON body, a missing question, or
    an unresolvable hospital_code; 500 on unexpected errors.
    """
    access_logger.info(f"Generate answer request received from {request.remote_addr}")
    try:
        # silent=True turns a missing or malformed body into None instead of
        # raising. Anything that is not a JSON object is treated as empty, so
        # the request is rejected with 400 below rather than failing with 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        question = str(data.get("question") or "").strip().lower()
        hospital_code = data.get("hospital_code")
        doc_id = data.get("doc_id")
        user_id = data.get("user_id", "default")
        session_id = data.get("session_id", None)
        logging.info(f"Received question from user {user_id}: {question}")
        logging.info(f"Received hospital code: {hospital_code}")
        logging.info(f"Received session_id: {session_id}")
        original_query = data.get("original_query", "")

        if not question:
            return jsonify({"error": "Missing 'question' in request"}), 400

        def process_rag_answer():
            """Resolve the hospital and build the (payload, status) response."""
            try:
                hospital_id = async_to_sync(get_hospital_id(hospital_code))
                logging.info(f"Resolved hospital ID: {hospital_id}")
                if not hospital_id:
                    return {
                        "error": "Invalid or missing 'hospital_code' in request"
                    }, 400

                if original_query:
                    response_message = (
                        "I can only answer questions based on information found in the hospital documents.\n"
                        "The question you asked doesn't seem to be covered in the available documents.\n"
                        "You can try rephrasing your question or asking about a different topic.\n"
                    )
                    return {"answer": response_message}, 200

                return async_to_sync(
                    generate_answer_with_rag(
                        question=question,
                        hospital_id=hospital_id,
                        client=client,
                        doc_id=doc_id,
                        user_id=user_id,
                        conversation_manager=conversation_manager,
                        session_id=session_id,
                    )
                )
            except Exception as e:
                logging.error(f"Thread processing error: {str(e)}")
                return {"error": str(e)}, 500

        result, status_code = thread_pool.submit(process_rag_answer).result()
        return jsonify(result), status_code
    except Exception as e:
        logging.error(f"API error: {str(e)}")
        return jsonify({"error": str(e)}), 500


@app.route("/flask-api/delete-document-vectors", methods=["DELETE"])
def delete_document_vectors_endpoint():
    """Endpoint to delete document vectors from ChromaDB

    JSON body: ``hospital_id`` and ``doc_id`` (both required). Returns 200
    when delete_document_vectors reports success, 400 when a field (or the
    whole body) is missing, and 500 otherwise.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        hospital_id = data.get("hospital_id")
        doc_id = data.get("doc_id")

        if not all([hospital_id, doc_id]):
            return jsonify({"error": "Missing required parameters"}), 400

        logging.info(
            f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}"
        )

        def process_deletion():
            """Run the async deletion and map its boolean result to a response."""
            try:
                if async_to_sync(delete_document_vectors(hospital_id, doc_id)):
                    return {"message": "Document vectors deleted successfully"}, 200
                return {"error": "Failed to delete document vectors"}, 500
            except Exception as e:
                logging.error(f"Error in vector deletion process: {e}")
                return {"error": str(e)}, 500

        result, status_code = thread_pool.submit(process_deletion).result()
        return jsonify(result), status_code
    except Exception as e:
        logging.error(f"API error: {str(e)}")
        return jsonify({"error": str(e)}), 500


@app.route("/flask-api/get-chroma-content", methods=["GET"])
def get_chroma_content_endpoint():
    """API endpoint to get ChromaDB content by hospital_id

    Query parameters: ``hospital_id`` (required, integer) and ``limit``
    (integer, default 30000). Missing or non-integer values are rejected with
    400 instead of surfacing as a 500.
    """
    try:
        raw_hospital_id = request.args.get("hospital_id")
        if not raw_hospital_id:
            return jsonify({"error": "Missing required parameter: hospital_id"}), 400
        try:
            hospital_id = int(raw_hospital_id)
            limit = int(request.args.get("limit", 30000))
        except ValueError:
            return jsonify({"error": "hospital_id and limit must be integers"}), 400

        def process_fetch():
            """Run the async fetch; returns a (payload, status) tuple."""
            try:
                return async_to_sync(
                    get_chroma_content_by_hospital(hospital_id=hospital_id, limit=limit)
                )
            except Exception as e:
                logging.error(f"Error in ChromaDB fetch process: {e}")
                return {"error": str(e)}, 500

        result, status_code = thread_pool.submit(process_fetch).result()
        return jsonify(result), status_code
    except Exception as e:
        logging.error(f"API error: {str(e)}")
        return jsonify({"error": str(e)}), 500


async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100):
    """Fetch up to *limit* stored chunks whose metadata hospital_id matches.

    Returns:
        One of the following (payload, status) tuples:
        - ({"data": [{"id", "content", "metadata"}, ...], "count": n}, 200);
        - ({"data": [], "count": 0}, 200) when nothing is stored;
        - ({"error": "Vector store not found"}, 404) when no store is
          available;
        - ({"error": str(exc)}, 500) on any failure.
    """
    try:
        vector_store = await initialize_or_load_vector_store(hospital_id)
        if not vector_store:
            return {"error": "Vector store not found"}, 404

        # The LangChain wrapper is bypassed to filter on metadata directly.
        collection = vector_store._collection
        results = await asyncio.to_thread(
            lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit)
        )
        if not results or not results["ids"]:
            return {"data": [], "count": 0}, 200

        formatted_results = [
            {"id": chunk_id, "content": content, "metadata": metadata}
            for chunk_id, content, metadata in zip(
                results["ids"], results["documents"], results["metadatas"]
            )
        ]
        return {"data": formatted_results, "count": len(formatted_results)}, 200
    except Exception as e:
        logging.error(f"Error fetching ChromaDB content: {e}")
        return {"error": str(e)}, 500


@app.before_request
def before_request():
    """Stamp the request with a monotonic start time for after_request."""
    request._start_time = time.perf_counter()


@app.after_request
def after_request(response):
    """Write one access-log line: method, path, status, duration and client IP."""
    start = getattr(request, "_start_time", None)
    if start is not None:
        # perf_counter is monotonic, so wall-clock adjustments cannot skew this.
        duration = time.perf_counter() - start
        access_logger.info(
            f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - '
            f"IP: {request.remote_addr}"
        )
    return response


if __name__ == "__main__":
    logger.info("Starting SpurrinAI application")
    logger.info(f"Python version: {sys.version}")
    logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}")

    try:
        model_manager = ModelManager()
        logger.info("Model manager initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize model manager: {e}")
        sys.exit(1)

    # Initialize directories
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(CHROMA_DIR, exist_ok=True)
    logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}")

    # Remove Redis keys matching vector_store_data:* before loading the stores.
    redis_client = get_redis_client()
    cleared_keys = 0
    for key in redis_client.scan_iter("vector_store_data:*"):
        redis_client.delete(key)
        cleared_keys += 1
    logger.info(f"Cleared {cleared_keys} Redis cache keys")

    # Load vector stores
    logger.info("Loading existing vector stores...")
    async_to_sync(load_existing_vector_stores())
    logger.info("Vector stores loaded successfully")

    # Start application
    logger.info("Starting Flask application on port 5000")
    app.run(port=5000, debug=False)