diff --git a/chat.py b/chat.py
index 9482ec6..8a84749 100644
--- a/chat.py
+++ b/chat.py
@@ -130,7 +130,7 @@ def setup_logging():
     return root_logger, access_logger, perf_logger
 
-
+load_dotenv()
 # Initialize loggers
 logger, access_logger, perf_logger = setup_logging()
@@ -149,20 +149,6 @@ logger, access_logger, perf_logger = setup_logging()
 # }
 
 # Redis Configuration
-# REDIS_CONFIG = {
-#     "host": "localhost",
-#     "port": 6379,
-#     "db": 0,
-#     "decode_responses": True,  # For string operations
-# }
-
-# DB_CONFIG = {
-#     "host": os.getenv("DB_HOST", "localhost"),
-#     "user": os.getenv("DB_USER", "spurrinuser"),
-#     "password": os.getenv("DB_PASSWORD", "Admin@123"),
-#     "database": os.getenv("DB_NAME", "spurrin-live"),
-# }
-
 REDIS_CONFIG = {
     "host": "localhost",
     "port": 6379,
@@ -171,10 +157,10 @@ REDIS_CONFIG = {
 }
 
 DB_CONFIG = {
-    "host": os.getenv("DB_HOST", "localhost"),
-    "user": os.getenv("DB_USER", "testuser"),
-    "password": os.getenv("DB_PASSWORD", "Admin@123"),
-    "database": os.getenv("DB_NAME", "spurrintest"),
+    "host": os.getenv("DB_HOST"),
+    "user": os.getenv("DB_USER"),
+    "password": os.getenv("DB_PASSWORD"),
+    "database": os.getenv("DB_NAME"),
 }
 
 # Redis connection pool
@@ -229,7 +215,6 @@ if not os.path.exists(uploads_dir):
 
 nlp = spacy.load("en_core_web_sm")
 
-load_dotenv()
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 client = OpenAI(api_key=OPENAI_API_KEY)
@@ -284,8 +269,8 @@ async def get_hospital_id(hospital_code):
         await pool.wait_closed()
 
 
-CHUNK_SIZE = 1000
-CHUNK_OVERLAP = 50
+CHUNK_SIZE = 4000
+CHUNK_OVERLAP = 150
 BATCH_SIZE = 1000
 
 text_splitter = RecursiveCharacterTextSplitter(
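The chunking change above is the main retrieval lever in this patch: chunks four times larger with triple the overlap mean fewer, more self-contained chunks per document. A minimal sketch of the resulting splitter behaviour (the import path is an assumption, it varies by LangChain version, and the sample text is illustrative):

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=4000,    # characters per chunk, up from 1000
        chunk_overlap=150,  # characters shared between neighbouring chunks, up from 50
    )
    sample_text = " ".join("Sentence %d about hospital policy." % i for i in range(500))
    chunks = splitter.split_text(sample_text)
    # Each chunk is at most ~4000 characters, and consecutive chunks repeat
    # roughly 150 characters, so an answer spanning a boundary is not lost.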
@@ -669,6 +654,108 @@ async def add_document_to_index(doc_id, hospital_id):
     return False
 
 
+def is_general_knowledge_question(
+    query: str, context: str, conversation_context=None
+) -> bool:
+    """
+    Determine if a question is likely a general knowledge question not covered in the documents.
+    Takes conversation history into account to reduce repeated confirmations.
+    """
+    query_lower = query.lower()
+    context_lower = context.lower()
+
+    if conversation_context:
+        for interaction in conversation_context:
+            prev_question = interaction.get("question", "").lower()
+            # Parenthesize the "or" so an empty prev_question can never match
+            if prev_question and (
+                query_lower in prev_question or prev_question in query_lower
+            ):
+                logging.info(
+                    "Question is similar to previous conversation, skipping confirmation"
+                )
+                return False
+
+    stop_words = {
+        "search",
+        "query:",
+        "can",
+        "you",
+        "some",
+        "at",
+        "the",
+        "a",
+        "an",
+        "in",
+        "on",
+        "to",
+        "for",
+        "with",
+        "by",
+        "about",
+        "give",
+        "full",
+        "is",
+        "are",
+        "was",
+        "were",
+        "define",
+        "what",
+        "how",
+        "why",
+        "when",
+        "where",
+        "year",
+        "list",
+        "form",
+        "table",
+        "who",
+        "which",
+        "me",
+        "tell",
+        "explain",
+        "describe",
+        "of",
+        "and",
+        "or",
+        "there",
+        "their",
+        "please",
+        "could",
+        "would",
+        "various",
+        "different",
+        "type",
+        "types",
+        "kind",
+        "kinds",
+        "has",
+        "have",
+        "had",
+        "many",
+        "say",
+    }
+
+    key_words = [
+        word for word in query_lower.split() if word not in stop_words and len(word) > 2
+    ]
+    logging.info(f"Key words: {key_words}")
+
+    if not key_words:
+        logging.info("No significant keywords found, directing to general knowledge")
+        return True
+
+    matches = sum(1 for word in key_words if word in context_lower)
+    logging.info(f"Matches: {matches} out of {len(key_words)} keywords")
+
+    match_ratio = matches / len(key_words)
+    logging.info(f"Match ratio: {match_ratio}")
+
+    return match_ratio < 0.6
+
+
 def is_table_request(query: str) -> bool:
     """
     Determine if the user is requesting a response in tabular format.
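A quick illustration of how the heuristic above routes questions: fewer than 60% of the significant keywords matching the retrieved context sends the question to general knowledge. The context string and queries below are illustrative only; the function does plain lowercase keyword overlap, no embeddings:

    retrieved_context = "The hospital billing policy covers insurance claims and discharge."

    # Key words "billing" and "policy" both occur in the context:
    # match_ratio = 1.0, so this is NOT treated as general knowledge.
    is_general_knowledge_question("explain the billing policy", retrieved_context)  # False

    # Neither "causes" nor "diabetes" occurs in the context:
    # match_ratio = 0.0 < 0.6, so this IS routed to general knowledge.
    is_general_knowledge_question("what causes diabetes", retrieved_context)  # True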
""" def __init__(self, redis_client, ttl=3600, max_history_items=5): self.redis_client = redis_client self.ttl = ttl self.max_history_items = max_history_items + + # For general knowledge questions (in-memory) + self.general_knowledge_histories = {} self.lock = Lock() def _get_redis_key(self, user_id, hospital_id, session_id=None): @@ -785,6 +876,12 @@ class RAGConversationManager: return f"conv_history:{user_id}:{hospital_id}:{session_id}" return f"conv_history:{user_id}:{hospital_id}" + def _get_memory_key(self, user_id, hospital_id, session_id=None): + """Create memory key for general knowledge conversations.""" + if session_id: + return f"{user_id}:{hospital_id}:{session_id}" + return f"{user_id}:{hospital_id}" + async def add_rag_interaction( self, user_id, hospital_id, question, answer, session_id=None ): @@ -814,6 +911,35 @@ class RAGConversationManager: except Exception as e: logging.error(f"Failed to store RAG interaction in Redis: {e}") + def add_general_knowledge_interaction( + self, user_id, hospital_id, question, answer, session_id=None + ): + """Add general knowledge interaction to in-memory store.""" + key = self._get_memory_key(user_id, hospital_id, session_id) + + with self.lock: + if key not in self.general_knowledge_histories: + self.general_knowledge_histories[key] = [] + + self.general_knowledge_histories[key].append( + { + "question": question, + "answer": answer, + "timestamp": time.time(), + "type": "general", # Mark as general knowledge interaction + } + ) + + # Keep only the most recent interactions + if len(self.general_knowledge_histories[key]) > self.max_history_items: + self.general_knowledge_histories[key] = ( + self.general_knowledge_histories[key][-self.max_history_items :] + ) + + logging.info( + f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}" + ) + def get_rag_history(self, user_id, hospital_id, session_id=None): """Get document-based (RAG) conversation history from Redis.""" key = self._get_redis_key(user_id, hospital_id, session_id) @@ -824,21 +950,51 @@ class RAGConversationManager: logging.error(f"Failed to retrieve RAG history from Redis: {e}") return [] + def get_general_knowledge_history(self, user_id, hospital_id, session_id=None): + """Get general knowledge conversation history from memory.""" + key = self._get_memory_key(user_id, hospital_id, session_id) + + with self.lock: + return self.general_knowledge_histories.get(key, []).copy() + + def get_combined_history(self, user_id, hospital_id, session_id=None): + """Get combined conversation history from both sources, sorted by timestamp.""" + rag_history = self.get_rag_history(user_id, hospital_id, session_id) + general_history = self.get_general_knowledge_history( + user_id, hospital_id, session_id + ) + + # Combine histories + combined_history = rag_history + general_history + + # Sort by timestamp (newest first) + combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True) + + # Return most recent N items + return combined_history[: self.max_history_items] + def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2): - """Get the most recent interactions for context.""" - history = self.get_rag_history(user_id, hospital_id, session_id) + """Get the most recent interactions for context from combined history.""" + combined_history = self.get_combined_history(user_id, hospital_id, session_id) # Sort by timestamp (oldest first) for context window - sorted_history = sorted(history, key=lambda x: x.get("timestamp", 0)) + 
@@ -1013,7 +1169,6 @@ def is_follow_up(current_question: str, conversation_history: list) -> bool:
         )
     )
 
-
 async def get_relevant_context(question, hospital_id, doc_id=None):
     try:
         cache_key = f"context:hospital_{hospital_id}"
@@ -1022,15 +1177,10 @@
         cache_key += f":{question.lower().strip()}"
 
         redis_client = get_redis_client()
-
         cached_context = redis_client.get(cache_key)
         if cached_context:
             logging.info(f"Cache hit for key: {cache_key}")
-            return (
-                cached_context.decode("utf-8")
-                if isinstance(cached_context, bytes)
-                else cached_context
-            )
+            return cached_context.decode("utf-8") if isinstance(cached_context, bytes) else cached_context
 
         vector_store = await initialize_or_load_vector_store(hospital_id)
         if not vector_store:
@@ -1039,17 +1189,33 @@
         retriever = vector_store.as_retriever(
             search_type="mmr",
             search_kwargs={
-                "k": 10,
-                "fetch_k": 20,
-                "lambda_mult": 0.6,
-                # "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
+                "k": 5,  # Reduced number of documents for precision
+                "fetch_k": 10,  # Reduced fetch size
+                "lambda_mult": 0.8,  # Favor relevance over diversity in MMR (1.0 = pure relevance)
+                "score_threshold": 0.7,  # Add minimum similarity score
+                "filter": {"doc_id": str(doc_id)} if doc_id else {"hospital_id": str(hospital_id)}
             },
         )
 
         docs = await asyncio.to_thread(retriever.get_relevant_documents, question)
        if not docs:
+            logging.info(f"No relevant documents found for question: {question}")
             return ""
 
+        # # Filter documents by relevance using spaCy similarity
+        # question_doc = nlp(question)
+        # relevant_docs = []
+        # for doc in docs:
+        #     doc_content = nlp(doc.page_content)
+        #     similarity = question_doc.similarity(doc_content)
+        #     if similarity >= 0.7:  # Strict similarity threshold
+        #         relevant_docs.append(doc)
+
+        # if not relevant_docs:
+        #     logging.info(f"Relevant docs after filtering: {relevant_docs}")
+        #     logging.info(f"No sufficiently relevant documents after similarity filtering for: {question}")
+        #     return ""
+
         sorted_docs = sorted(
             docs,
             key=lambda x: (
@@ -1076,7 +1242,6 @@
         logging.error(f"Error getting relevant context: {e}")
         return ""
 
-
 def format_conversation_context(conv_history):
     """Format conversation history into a string"""
     if not conv_history:
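Two caveats on the retriever tuning above, hedged because behaviour differs across LangChain versions: lambda_mult runs from 0 (maximum diversity) to 1 (pure relevance), so raising it to 0.8 narrows rather than diversifies results; and score_threshold is only guaranteed to be honoured by the "similarity_score_threshold" search type, while under "mmr" some vector stores silently ignore it. If a hard relevance floor is the goal, an equivalent sketch would be:

    threshold_retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 5, "score_threshold": 0.7},
    )
    docs = threshold_retriever.get_relevant_documents(question)  # may return fewer than k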
@@ -1122,6 +1287,348 @@ def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70):
     return "\n".join(matched_context)
 
 
+# async def generate_answer_with_rag(
+#     question,
+#     hospital_id,
+#     client,
+#     doc_id=None,
+#     user_id="default",
+#     conversation_manager=None,
+#     session_id=None,
+# ):
+#     """Generate an answer using RAG with improved conversation flow"""
+#     try:
+#         # Continue with regular RAG processing if not an ICD code or if no ICD match found
+#         html_instruction = """
+#         IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content:
+#         - Use <p> tags for paragraphs
+#         - Use <h1>, <h2> tags for headings and subheadings
+#         - Use