From 03aab50781fc476d3dbdc30cc7027e94663a7b06 Mon Sep 17 00:00:00 2001 From: vriti Date: Mon, 9 Jun 2025 19:24:02 +0530 Subject: [PATCH] deleted certificates folders, automation test --- .gitignore | 3 +- chat copy.py | 2161 -------------------------------------------------- 2 files changed, 2 insertions(+), 2162 deletions(-) delete mode 100644 chat copy.py diff --git a/.gitignore b/.gitignore index 9455c0b..76e9b59 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ /logs /error.log /uploads -/llm-uploads \ No newline at end of file +/llm-uploads +/certificates \ No newline at end of file diff --git a/chat copy.py b/chat copy.py deleted file mode 100644 index 9af82fc..0000000 --- a/chat copy.py +++ /dev/null @@ -1,2161 +0,0 @@ -""" -SpurrinAI - Intelligent Document Processing and Question Answering System -Copyright (c) 2024 Tech4biz. All rights reserved. - -This module implements the main Flask application for the SpurrinAI system, -providing REST APIs for document processing, vector storage, and question answering -using RAG (Retrieval Augmented Generation) architecture. - -Author: Tech4biz Development Team -Version: 1.0.0 -Last Updated: 2024-01-19 -""" - -# Standard library imports -import os -import re -import sys -import json -import time -import threading -import asyncio -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass -from datetime import timedelta -from enum import Enum - -# Third-party imports -import spacy -import redis -import aiomysql -from dotenv import load_dotenv -from flask import Flask, request, jsonify, Response -from flask_cors import CORS -from tqdm import tqdm -from tqdm.asyncio import tqdm as tqdm_async -from langchain_community.document_loaders import PyPDFLoader -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_community.embeddings import OpenAIEmbeddings -from langchain_community.vectorstores import Chroma -from langchain_community.chat_models import ChatOpenAI -from langchain.chains import RetrievalQA -from langchain.prompts import PromptTemplate -from openai import OpenAI -from rapidfuzz import process -from threading import Lock - -# Local imports -from model_manager import ModelManager - -# Suppress warnings -import warnings - -warnings.filterwarnings("ignore") - -# Initialize NLTK -import nltk - -nltk.download("punkt") - -# Configure logging -import logging -import logging.handlers - -app = Flask(__name__) -CORS(app) - -script_dir = os.path.dirname(os.path.abspath(__file__)) -log_file_path = os.path.join(script_dir, "error.log") -logging.basicConfig(filename=log_file_path, level=logging.INFO) - - -# Configure logging -def setup_logging(): - log_dir = os.path.join(script_dir, "logs") - os.makedirs(log_dir, exist_ok=True) - - main_log = os.path.join(log_dir, "app.log") - error_log = os.path.join(log_dir, "error.log") - access_log = os.path.join(log_dir, "access.log") - perf_log = os.path.join(log_dir, "performance.log") - - # Create formatters - detailed_formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s" - ) - access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - - # Main logger setup - main_handler = logging.handlers.RotatingFileHandler( - main_log, maxBytes=10485760, backupCount=5 - ) - main_handler.setFormatter(detailed_formatter) - main_handler.setLevel(logging.INFO) - - # Error logger setup - error_handler = logging.handlers.RotatingFileHandler( - error_log, maxBytes=10485760, 
backupCount=5 - ) - error_handler.setFormatter(detailed_formatter) - error_handler.setLevel(logging.ERROR) - - # Access logger setup - access_handler = logging.handlers.TimedRotatingFileHandler( - access_log, when="midnight", interval=1, backupCount=30 - ) - access_handler.setFormatter(access_formatter) - access_handler.setLevel(logging.INFO) - - # Performance logger setup - perf_handler = logging.handlers.RotatingFileHandler( - perf_log, maxBytes=10485760, backupCount=5 - ) - perf_handler.setFormatter(detailed_formatter) - perf_handler.setLevel(logging.INFO) - - # Configure root logger - root_logger = logging.getLogger() - root_logger.setLevel(logging.INFO) - root_logger.addHandler(main_handler) - root_logger.addHandler(error_handler) - - # Create specific loggers - access_logger = logging.getLogger("access") - access_logger.addHandler(access_handler) - access_logger.setLevel(logging.INFO) - - perf_logger = logging.getLogger("performance") - perf_logger.addHandler(perf_handler) - perf_logger.setLevel(logging.INFO) - - return root_logger, access_logger, perf_logger - - -# Initialize loggers -logger, access_logger, perf_logger = setup_logging() - -# DB_CONFIG = { -# 'host': 'localhost', -# 'user': 'flaskuser', -# 'password': 'Flask@123', -# 'database': 'spurrinai', -# } - -# DB_CONFIG = { -# 'host': 'localhost', -# 'user': 'spurrindevuser', -# 'password': 'Admin@123', -# 'database': 'spurrindev', -# } - -# Redis Configuration -REDIS_CONFIG = { - "host": "localhost", - "port": 6379, - "db": 0, - "decode_responses": True,  # For string operations -} - -DB_CONFIG = { - "host": os.getenv("DB_HOST", "localhost"), - "user": os.getenv("DB_USER", "testuser"), - "password": os.getenv("DB_PASSWORD", "Admin@123"), - "database": os.getenv("DB_NAME", "spurrintest"), -} - -# Redis connection pool -redis_pool = redis.ConnectionPool(**REDIS_CONFIG) -redis_binary_pool = redis.ConnectionPool( - host="localhost", port=6379, db=1, decode_responses=False -) - - -def get_redis_client(binary=False): - """Get Redis client from pool""" - logger.debug(f"Getting Redis client with binary={binary}") - try: - pool = redis_binary_pool if binary else redis_pool - client = redis.Redis(connection_pool=pool) - logger.debug("Redis client created successfully") - return client - except Exception as e: - logger.error(f"Failed to create Redis client: {e}", exc_info=True) - raise - - -def fetch_cached_answer(cache_key): - logger.debug(f"Attempting to fetch cached answer for key: {cache_key}") - start_time = time.time() - try: - redis_client = get_redis_client() - cached_answer = redis_client.get(cache_key) - fetch_time = time.time() - start_time - perf_logger.info( - f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}" - ) - return cached_answer - except Exception as e: - logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True) - return None - - -# Cache TTL configurations -CACHE_TTL = { - "vector_store": timedelta(hours=24), - "chat_completion": timedelta(hours=1), - "document_metadata": timedelta(days=7), -} - -DATA_DIR = os.path.join(script_dir, "hospital_data") -CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db") -uploads_dir = os.path.join(script_dir, "llm-uploads") - -if not os.path.exists(uploads_dir): - os.makedirs(uploads_dir) - -nlp = spacy.load("en_core_web_sm") - -load_dotenv() -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") - -client = OpenAI(api_key=OPENAI_API_KEY) -embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY) -llm = ChatOpenAI( - model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY -)
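# A minimal illustrative sketch of how the cache pieces above compose; the
# helper name `cache_chat_answer` is hypothetical and not part of this file.
#
# def cache_chat_answer(cache_key: str, answer: str) -> None:
#     """Store an answer under `cache_key` with the chat-completion TTL."""
#     redis_client = get_redis_client()
#     ttl_seconds = int(CACHE_TTL["chat_completion"].total_seconds())  # 3600
#     redis_client.setex(cache_key, ttl_seconds, answer)
#
# Reads then go through fetch_cached_answer(cache_key).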
-# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50) -hospital_vector_stores = {} -vector_store_lock = threading.Lock() -icd_json_lock = threading.Lock()  # shared lock: a fresh threading.Lock() per call would not serialize writers - - -@dataclass -class Document: - doc_id: int - page_num: int - content: str - - -class DocumentStatus(Enum): - PROCESSING = "processing" - PROCESSED = "processed" - FAILED = "failed" - - -async def get_db_pool(): - return await aiomysql.create_pool( - host=DB_CONFIG["host"], - user=DB_CONFIG["user"], - password=DB_CONFIG["password"], - db=DB_CONFIG["database"], - autocommit=True, - ) - - -async def get_hospital_id(hospital_code): - pool = None - try: - pool = await get_db_pool() - async with pool.acquire() as conn: - async with conn.cursor(aiomysql.DictCursor) as cursor: - await cursor.execute( - "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1", - (hospital_code,), - ) - result = await cursor.fetchone() - return result["id"] if result else None - except Exception as error: - logging.error(f"Database error: {error}") - return None - finally: - if pool is not None:  # guard: get_db_pool() may have raised before assignment - pool.close() - await pool.wait_closed() - - -CHUNK_SIZE = 1000 -CHUNK_OVERLAP = 50 -BATCH_SIZE = 1000 - -text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CHUNK_SIZE, - chunk_overlap=CHUNK_OVERLAP, - # length_function=len, - # separators=["\n\n", "\n", ". ", " ", ""] -) - - -# Update the JSON_PATH to be dynamic based on hospital_id -def get_icd_json_path(hospital_id): - hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}") - os.makedirs(hospital_data_dir, exist_ok=True) - return os.path.join(hospital_data_dir, "icd_data.json") - - -def extract_and_process_icd_data(content, hospital_id, save_to_json=True): - """Extract and process ICD codes with optimized processing and optional JSON saving""" - try: - # Initialize pattern compilation once - pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE) - - # Process in chunks for large content - chunk_size = 50000  # Process 50KB at a time - icd_data = [] - - current_code = None - current_description = [] - - # Split content into manageable chunks - content_chunks = [ - content[i : i + chunk_size] for i in range(0, len(content), chunk_size) - ] - - # Process each chunk - for chunk in content_chunks: - lines = chunk.splitlines() - - for line in lines: - line = line.strip() - if not line: - if current_code and current_description: - icd_data.append( - { - "code": current_code, - "description": " ".join(current_description).strip(), - } - ) - current_code = None - current_description = [] - continue - - match = pattern.match(line) - if match: - if current_code and current_description: - icd_data.append( - { - "code": current_code, - "description": " ".join(current_description).strip(), - } - ) - current_code, description = match.groups() - current_description = [description.strip()] - elif current_code: - current_description.append(line) - - # Add final entry if exists - if current_code and current_description: - icd_data.append( - { - "code": current_code, - "description": " ".join(current_description).strip(), - } - ) - - # Save to hospital-specific JSON if requested - if save_to_json and icd_data: - try: - json_path = get_icd_json_path(hospital_id) - - # Use a lock for thread safety - with icd_json_lock: - if os.path.exists(json_path): - with open(json_path, "r", encoding="utf-8") as f: - try: - existing_data = json.load(f) - except json.JSONDecodeError: - existing_data = [] - else: - existing_data = [] - - # Efficient deduplication using dictionary - seen_codes = {item["code"]: item for item in existing_data}
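# Illustration of the dedup semantics below (hypothetical data): with
#   existing_data = [{"code": "A150", "description": "old wording"}]
#   icd_data      = [{"code": "A150", "description": "new wording"}]
# the loop leaves seen_codes["A150"]["description"] == "new wording",
# i.e. freshly extracted entries overwrite previously saved ones.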
{item["code"]: item for item in existing_data} - for item in icd_data: - seen_codes[item["code"]] = item - - unique_data = list(seen_codes.values()) - - # Write atomically using temporary file - temp_path = f"{json_path}.tmp" - with open(temp_path, "w", encoding="utf-8") as f: - json.dump(unique_data, f, indent=2, ensure_ascii=False) - os.replace(temp_path, json_path) - - logging.info( - f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}" - ) - - except Exception as e: - logging.error( - f"Error saving ICD data to JSON for hospital {hospital_id}: {e}" - ) - - return icd_data - - except Exception as e: - logging.error(f"Error in extract_and_process_icd_data: {e}") - return [] - - -def load_icd_entries(hospital_id): - """Load ICD entries from hospital-specific JSON file""" - json_path = get_icd_json_path(hospital_id) - try: - if os.path.exists(json_path): - with open(json_path, "r", encoding="utf-8") as f: - return json.load(f) - return [] - except Exception as e: - logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}") - return [] - - -# Update the process_icd_codes function to include hospital_id -async def process_icd_codes(content, doc_id, hospital_id, batch_size=256): - """Process and store ICD codes using the optimized extraction function""" - try: - # Extract and save codes with hospital_id - extract_and_process_icd_data(content, hospital_id, save_to_json=True) - except Exception as e: - logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}") - - -async def initialize_icd_vector_store(hospital_id): - """This function is deprecated. ICD codes are now handled through JSON search.""" - logging.warning( - "initialize_icd_vector_store is deprecated - using JSON search instead" - ) - return None - - -def extract_pdf_contents(pdf_path, hospital_id): - """Extract PDF contents with optimized chunking and code extraction""" - try: - loader = PyPDFLoader(pdf_path) - pages = loader.load() - pages_content = [] - - for i, page in enumerate(tqdm(pages, desc="Extracting pages")): - text = page.page_content.strip() - - # Extract ICD codes from the page - icd_codes = extract_and_process_icd_data( - text, hospital_id - ) # We'll set doc_id later - - pages_content.append({"page": i + 1, "text": text, "codes": icd_codes}) - - return pages_content - - except Exception as e: - logging.error(f"Error in extract_pdf_contents: {e}") - raise - - -async def insert_content_into_db(content, metadata, doc_id): - pool = await get_db_pool() - async with pool.acquire() as conn: - async with conn.cursor() as cursor: - try: - metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)" - content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)" - - metadata_values = [ - (doc_id, key[:100], value) - for key, value in metadata.items() - if value - ] - content_values = [ - (doc_id, page_content["page"], page_content["text"]) - for page_content in content - ] - - if metadata_values: - await cursor.executemany(metadata_query, metadata_values) - if content_values: - await cursor.executemany(content_query, content_values) - - await conn.commit() - return {"message": "Success"} - except Exception as e: - await conn.rollback() - return {"error": str(e)} - - -async def initialize_or_load_vector_store(hospital_id, user_id="default"): - """Initialize or load vector store with Redis caching and thread safety""" - store_key = f"{hospital_id}:{user_id}" - - try: - # 
Check if we already have it loaded - with lock for thread safety - with vector_store_lock: - if store_key in hospital_vector_stores: - return hospital_vector_stores[store_key] - - # Initialize vector store - redis_client = get_redis_client(binary=True) - cache_key = f"vector_store_data:{hospital_id}:{user_id}" - hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}") - - if os.path.exists(hospital_dir): - logging.info( - f"Loading vector store for hospital {hospital_id} and user {user_id}" - ) - vector_store = await asyncio.to_thread( - lambda: Chroma( - collection_name=f"hospital_{hospital_id}", - persist_directory=hospital_dir, - embedding_function=embeddings, - ) - ) - else: - logging.info(f"Creating vector store for hospital {hospital_id}") - os.makedirs(hospital_dir, exist_ok=True) - vector_store = await asyncio.to_thread( - lambda: Chroma( - collection_name=f"hospital_{hospital_id}", - persist_directory=hospital_dir, - embedding_function=embeddings, - ) - ) - - # Store with lock for thread safety - with vector_store_lock: - hospital_vector_stores[store_key] = vector_store - - return vector_store - except Exception as e: - logging.error(f"Error initializing vector store: {e}", exc_info=True) - raise - - -async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool: - """Delete all vectors associated with a specific document from ChromaDB""" - try: - # Initialize vector store for the hospital - vector_store = await initialize_or_load_vector_store(hospital_id) - - # Delete vectors with matching doc_id - await asyncio.to_thread( - lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)}) - ) - - # Persist changes - await asyncio.to_thread(vector_store.persist) - - # Clear Redis cache for this document - redis_client = get_redis_client() - pattern = f"vector_store_data:{hospital_id}:*" - for key in redis_client.scan_iter(pattern): - redis_client.delete(key) - - logging.info( - f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}" - ) - return True - - except Exception as e: - logging.error(f"Error deleting document vectors: {e}", exc_info=True) - return False - - -async def add_document_to_index(doc_id, hospital_id): - try: - pool = await get_db_pool() - async with pool.acquire() as conn: - async with conn.cursor() as cursor: - vector_store = await initialize_or_load_vector_store(hospital_id) - - await cursor.execute( - "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number", - (doc_id,), - ) - rows = await cursor.fetchall() - - total_pages = len(rows) - logging.info(f"Processing {total_pages} pages for document {doc_id}") - page_bar = tqdm_async(total=total_pages, desc="Processing pages") - - async def process_page(page_data): - page_num, content = page_data - try: - icd_data = extract_and_process_icd_data( - content, hospital_id, save_to_json=False - ) - chunks = text_splitter.split_text(content) - await asyncio.sleep(0) # Yield control - return page_num, chunks, icd_data - except Exception as e: - logging.error(f"Error processing page {page_num}: {e}") - return page_num, [], [] - - tasks = [asyncio.create_task(process_page(row)) for row in rows] - results = [] - - for coro in asyncio.as_completed(tasks): - result = await coro - results.append(result) - page_bar.update(1) - - page_bar.close() - - # Vector addition progress bar - all_icd_data = [] - all_chunks = [] - all_metadatas = [] - - chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0) - - for result in results: - 
page_num, chunks, icd_data = result - all_icd_data.extend(icd_data) - - for i, chunk in enumerate(chunks): - all_chunks.append(chunk) - all_metadatas.append( - { - "doc_id": str(doc_id), - "hospital_id": str(hospital_id), - "page_number": str(page_num), - "chunk_index": str(i), - } - ) - - if len(all_chunks) >= BATCH_SIZE: - chunk_add_bar.total += len(all_chunks) - chunk_add_bar.refresh() - await asyncio.to_thread( - vector_store.add_texts, - texts=all_chunks, - metadatas=all_metadatas, - ) - all_chunks = [] - all_metadatas = [] - chunk_add_bar.update(BATCH_SIZE) - - # Final batch - if all_chunks: - chunk_add_bar.total += len(all_chunks) - chunk_add_bar.refresh() - await asyncio.to_thread( - vector_store.add_texts, - texts=all_chunks, - metadatas=all_metadatas, - ) - chunk_add_bar.update(len(all_chunks)) - - chunk_add_bar.close() - - if all_icd_data: - logging.info(f"Saving {len(all_icd_data)} ICD codes") - extract_and_process_icd_data("", hospital_id, save_to_json=True) - - await asyncio.to_thread(vector_store.persist) - logging.info(f"Successfully indexed document {doc_id}") - return True - - except Exception as e: - logging.error(f"Error adding document: {e}") - return False - - -def is_general_knowledge_question( - query: str, context: str, conversation_context=None -) -> bool: - """ - Determine if a question is likely a general knowledge question not covered in the documents. - Takes conversation history into account to reduce repeated confirmations. - """ - query_lower = query.lower() - context_lower = context.lower() - - if conversation_context: - for interaction in conversation_context: - prev_question = interaction.get("question", "").lower() - if ( - prev_question - and query_lower in prev_question - or prev_question in query_lower - ): - logging.info( - f"Question is similar to previous conversation, skipping confirmation" - ) - return False - - stop_words = { - "search", - "query:", - "can", - "you", - "some", - "at", - "the", - "a", - "an", - "in", - "on", - "at", - "to", - "for", - "with", - "by", - "about", - "give", - "full", - "is", - "are", - "was", - "were", - "define", - "what", - "how", - "why", - "when", - "where", - "year", - "list", - "form", - "table", - "who", - "which", - "me", - "tell", - "explain", - "describe", - "of", - "and", - "or", - "there", - "their", - "please", - "could", - "would", - "various", - "different", - "type", - "types", - "kind", - "kinds", - "has", - "have", - "had", - "many", - "say", - } - - key_words = [ - word for word in query_lower.split() if word not in stop_words and len(word) > 2 - ] - logging.info(f"Key words: {key_words}") - - if not key_words: - logging.info("No significant keywords found, directing to general knowledge") - return True - - matches = sum(1 for word in key_words if word in context_lower) - logging.info(f"Matches: {matches} out of {len(key_words)} keywords") - - match_ratio = matches / len(key_words) - logging.info(f"Match ratio: {match_ratio}") - - return match_ratio < 0.4 - - -def is_table_request(query: str) -> bool: - """ - Determine if the user is requesting a response in tabular format. 
- """ - table_keywords = [ - "table", - "tabular", - "in a table", - "in table format", - "in tabular format", - "chart", - "data", - "comparison", - "as a table", - "table format", - "in rows and columns", - "in a grid", - "breakdown", - "spreadsheet", - "comparison table", - "data table", - "structured table", - "tabular form", - "table form", - ] - - query_lower = query.lower() - return any(keyword in query_lower for keyword in table_keywords) - - -import re - - -def ensure_html_response(text: str) -> str: - """ - Ensure the response is properly formatted in HTML. - This function handles plain text conversion to HTML. - """ - if "", text)) - - if not has_html_tags: - paragraphs = text.split("\n\n") - html_parts = [] - in_ordered_list = False - in_unordered_list = False - - for para in paragraphs: - if para.strip(): - if re.match(r"^\s*[\*\-\•]\s", para): - if not in_unordered_list: - html_parts.append("") - in_unordered_list = False - - html_parts.append(f"

{para}

") - - if in_ordered_list: - html_parts.append("") - if in_unordered_list: - html_parts.append("") - - return "".join(html_parts) - - else: - if not any(tag in text for tag in ("

", "

", "
    ", "
      ")): - paragraphs = text.split("\n\n") - html_parts = [f"

      {para}

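# Expected behavior under the reconstruction above (illustrative, hypothetical
# inputs): plain text gains <p> wrappers, tagged text passes through.
#   ensure_html_response("Hello\n\nWorld")  -> "<p>Hello</p><p>World</p>"
#   ensure_html_response("<p>Hi</p>")       -> "<p>Hi</p>"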
      " for para in paragraphs if para.strip()] - return "".join(html_parts) - - return text - - -class HybridConversationManager: - """ - Hybrid conversation manager that uses Redis for RAG-based conversations - and in-memory storage for general knowledge conversations. - """ - - def __init__(self, redis_client, ttl=3600, max_history_items=5): - self.redis_client = redis_client - self.ttl = ttl - self.max_history_items = max_history_items - - # For general knowledge questions (in-memory) - self.general_knowledge_histories = {} - self.lock = Lock() - - def _get_redis_key(self, user_id, hospital_id, session_id=None): - """Create Redis key for document-based conversations.""" - if session_id: - return f"conv_history:{user_id}:{hospital_id}:{session_id}" - return f"conv_history:{user_id}:{hospital_id}" - - def _get_memory_key(self, user_id, hospital_id, session_id=None): - """Create memory key for general knowledge conversations.""" - if session_id: - return f"{user_id}:{hospital_id}:{session_id}" - return f"{user_id}:{hospital_id}" - - async def add_rag_interaction( - self, user_id, hospital_id, question, answer, session_id=None - ): - """Add document-based (RAG) interaction to Redis.""" - key = self._get_redis_key(user_id, hospital_id, session_id) - history = self.get_rag_history(user_id, hospital_id, session_id) - - # Add new interaction - history.append( - { - "question": question, - "answer": answer, - "timestamp": time.time(), - "type": "rag", # Mark as RAG-based interaction - } - ) - - # Keep only last N interactions - history = history[-self.max_history_items :] - - # Store updated history - try: - self.redis_client.setex(key, self.ttl, json.dumps(history)) - logging.info( - f"Stored RAG interaction in Redis for {user_id}:{hospital_id}:{session_id}" - ) - except Exception as e: - logging.error(f"Failed to store RAG interaction in Redis: {e}") - - def add_general_knowledge_interaction( - self, user_id, hospital_id, question, answer, session_id=None - ): - """Add general knowledge interaction to in-memory store.""" - key = self._get_memory_key(user_id, hospital_id, session_id) - - with self.lock: - if key not in self.general_knowledge_histories: - self.general_knowledge_histories[key] = [] - - self.general_knowledge_histories[key].append( - { - "question": question, - "answer": answer, - "timestamp": time.time(), - "type": "general", # Mark as general knowledge interaction - } - ) - - # Keep only the most recent interactions - if len(self.general_knowledge_histories[key]) > self.max_history_items: - self.general_knowledge_histories[key] = ( - self.general_knowledge_histories[key][-self.max_history_items :] - ) - - logging.info( - f"Stored general knowledge interaction in memory for {user_id}:{hospital_id}:{session_id}" - ) - - def get_rag_history(self, user_id, hospital_id, session_id=None): - """Get document-based (RAG) conversation history from Redis.""" - key = self._get_redis_key(user_id, hospital_id, session_id) - try: - history_data = self.redis_client.get(key) - return json.loads(history_data) if history_data else [] - except Exception as e: - logging.error(f"Failed to retrieve RAG history from Redis: {e}") - return [] - - def get_general_knowledge_history(self, user_id, hospital_id, session_id=None): - """Get general knowledge conversation history from memory.""" - key = self._get_memory_key(user_id, hospital_id, session_id) - - with self.lock: - return self.general_knowledge_histories.get(key, []).copy() - - def get_combined_history(self, user_id, hospital_id, session_id=None): 
- """Get combined conversation history from both sources, sorted by timestamp.""" - rag_history = self.get_rag_history(user_id, hospital_id, session_id) - general_history = self.get_general_knowledge_history( - user_id, hospital_id, session_id - ) - - # Combine histories - combined_history = rag_history + general_history - - # Sort by timestamp (newest first) - combined_history.sort(key=lambda x: x.get("timestamp", 0), reverse=True) - - # Return most recent N items - return combined_history[: self.max_history_items] - - def get_context_window(self, user_id, hospital_id, session_id=None, window_size=2): - """Get the most recent interactions for context from combined history.""" - combined_history = self.get_combined_history(user_id, hospital_id, session_id) - # Sort by timestamp (oldest first) for context window - sorted_history = sorted(combined_history, key=lambda x: x.get("timestamp", 0)) - return sorted_history[-window_size:] if sorted_history else [] - - def clear_history(self, user_id, hospital_id): - """Clear conversation history from both stores.""" - # Clear Redis history - redis_key = self._get_redis_key(user_id, hospital_id) - try: - self.redis_client.delete(redis_key) - except Exception as e: - logging.error(f"Failed to clear Redis history: {e}") - - # Clear memory history - memory_key = self._get_memory_key(user_id, hospital_id) - with self.lock: - if memory_key in self.general_knowledge_histories: - del self.general_knowledge_histories[memory_key] - - -class ContextMapper: - """Enhanced context mapping using shared model manager""" - - def __init__(self): - self.model_manager = ModelManager() - self.context_cache = {} - self.similarity_threshold = 0.6 - - def get_semantic_similarity(self, text1, text2): - """Get semantic similarity using global model manager""" - return self.model_manager.get_semantic_similarity(text1, text2) - - def extract_key_concepts(self, text): - """Extract key concepts using NLP techniques""" - doc = nlp(text) - concepts = [] - - entities = [(ent.text, ent.label_) for ent in doc.ents] - noun_phrases = [chunk.text for chunk in doc.noun_chunks] - important_words = [ - token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "VERB"] - ] - - concepts.extend([e[0] for e in entities]) - concepts.extend(noun_phrases) - concepts.extend(important_words) - - return list(set(concepts)) - - def map_conversation_context( - self, current_query, conversation_history, context_window=3 - ): - """Map conversation context using enhanced NLP techniques""" - if not conversation_history: - return current_query - - recent_context = conversation_history[-context_window:] - context_concepts = [] - - # Extract concepts from recent conversations - for interaction in recent_context: - q_concepts = self.extract_key_concepts(interaction["question"]) - a_concepts = self.extract_key_concepts(interaction["answer"]) - context_concepts.extend(q_concepts) - context_concepts.extend(a_concepts) - - # Extract concepts from current query - query_concepts = self.extract_key_concepts(current_query) - - # Find related concepts - related_concepts = [] - for q_concept in query_concepts: - for c_concept in context_concepts: - similarity = self.get_semantic_similarity(q_concept, c_concept) - if similarity > self.similarity_threshold: - related_concepts.append(c_concept) - - # Build enhanced query - if related_concepts: - enhanced_query = ( - f"{current_query} in context of {', '.join(related_concepts)}" - ) - else: - enhanced_query = current_query - - return enhanced_query - - -# Initialize 
the context mapper -context_mapper = ContextMapper() - - -async def generate_contextual_query( - question: str, user_id: str, hospital_id: int, conversation_manager -) -> str: - """Generate enhanced contextual query""" - context_window = conversation_manager.get_context_window(user_id, hospital_id) - - if not context_window: - return question - - # Enhanced context mapping - last_interaction = context_window[-1] - enhanced_context = f""" - Previous question: {last_interaction['question']} - Previous answer: {last_interaction['answer']} - Current question: {question} - - Please generate a detailed search query that combines the context from the previous answer - with the current question, especially if the current question uses words like 'it', 'this', - 'that', or asks for more details about the previous topic. - """ - - try: - response = await asyncio.to_thread( - lambda: client.chat.completions.create( - model="gpt-3.5-turbo", - messages=[ - { - "role": "system", - "content": "You are a context-aware query generator.", - }, - {"role": "user", "content": enhanced_context}, - ], - temperature=0.3, - max_tokens=150, - ) - ) - contextual_query = response.choices[0].message.content.strip() - logging.info(f"Enhanced contextual query: {contextual_query}") - return contextual_query - except Exception as e: - logging.error(f"Error generating contextual query: {e}") - return question - - -def is_follow_up(current_question: str, conversation_history: list) -> bool: - """Enhanced follow-up detection using NLP techniques""" - if not conversation_history: - return False - - last_interaction = conversation_history[-1] - - # Get semantic similarity with higher threshold - similarity = context_mapper.get_semantic_similarity( - current_question, f"{last_interaction['question']} {last_interaction['answer']}" - ) - - # Enhanced referential check - doc = nlp(current_question.lower()) - has_referential = any( - token.lemma_ - in [ - "it", - "this", - "that", - "these", - "those", - "they", - "he", - "she", - "about", - "more", - ] - for token in doc - ) - - # Extract concepts with improved entity detection - current_concepts = set(context_mapper.extract_key_concepts(current_question)) - last_concepts = set( - context_mapper.extract_key_concepts( - f"{last_interaction['question']} {last_interaction['answer']}" - ) - ) - - # Calculate enhanced concept overlap - concept_overlap = ( - len(current_concepts & last_concepts) / len(current_concepts | last_concepts) - if current_concepts - else 0 - ) - - # More aggressive follow-up detection - return ( - similarity > 0.3 # Lowered threshold - or has_referential - or concept_overlap > 0.2 # Lowered threshold - or any( - word in current_question.lower() - for word in ["more", "about", "elaborate", "explain"] - ) - ) - - -async def get_relevant_context(question, hospital_id, doc_id=None): - try: - cache_key = f"context:hospital_{hospital_id}" - if doc_id: - cache_key += f":doc_{doc_id}" - cache_key += f":{question.lower().strip()}" - - redis_client = get_redis_client() - - cached_context = redis_client.get(cache_key) - if cached_context: - logging.info(f"Cache hit for key: {cache_key}") - return ( - cached_context.decode("utf-8") - if isinstance(cached_context, bytes) - else cached_context - ) - - vector_store = await initialize_or_load_vector_store(hospital_id) - if not vector_store: - return "" - - retriever = vector_store.as_retriever( - search_type="mmr", - search_kwargs={ - "k": 10, - "fetch_k": 20, - "lambda_mult": 0.6, - # "filter": {"doc_id": str(doc_id)} if 
doc_id else {"hospital_id": str(hospital_id)} - }, - ) - - docs = await asyncio.to_thread(retriever.get_relevant_documents, question) - if not docs: - return "" - - sorted_docs = sorted( - docs, - key=lambda x: ( - int(x.metadata.get("page_number", 0)), - int(x.metadata.get("chunk_index", 0)), - ), - ) - - context_parts = [doc.page_content for doc in sorted_docs] - context = "\n\n".join(context_parts) - - try: - redis_client.setex( - cache_key, - int(CACHE_TTL["vector_store"].total_seconds()), - context.encode("utf-8") if isinstance(context, str) else context, - ) - logging.info(f"Cached context for key: {cache_key}") - except Exception as cache_error: - logging.error(f"Failed to cache context: {cache_error}") - - return context - except Exception as e: - logging.error(f"Error getting relevant context: {e}") - return "" - - -def format_conversation_context(conv_history): - """Format conversation history into a string""" - if not conv_history: - return "No previous conversation." - return "\n".join( - [ - f"Q: {interaction['question']}\nA: {interaction['answer']}" - for interaction in conv_history - ] - ) - - -def get_icd_context_from_question(question, hospital_id): - """Extract any valid ICD codes from the question and return context""" - icd_data = load_icd_entries(hospital_id) - matches = [] - code_pattern = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", question.upper()) - - seen = set() - for code in code_pattern: - for entry in icd_data: - if entry["code"] == code and code not in seen: - matches.append(f"{entry['code']}: {entry['description']}") - seen.add(code) - return "\n".join(matches) - - -def get_fuzzy_icd_context(question, hospital_id, top_n=5, threshold=70): - """Get fuzzy matches for ICD codes from the question""" - icd_data = load_icd_entries(hospital_id) - descriptions = [entry["description"] for entry in icd_data] - matches = process.extract( - question, descriptions, limit=top_n, score_cutoff=threshold - ) - - matched_context = [] - for desc, score, _ in matches: - for entry in icd_data: - if entry["description"] == desc: - matched_context.append(f"{entry['code']}: {entry['description']}") - break - - return "\n".join(matched_context) - - -async def generate_answer_with_rag( - question, - hospital_id, - client, - doc_id=None, - user_id="default", - conversation_manager=None, - session_id=None, -): - """Generate an answer using RAG with improved conversation flow""" - try: - # Continue with regular RAG processing if not an ICD code or if no ICD match found - html_instruction = """ - IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content: - - Use
<p> tags for paragraphs - - Use <h1>, <h2> tags for headings and subheadings - - Use <ul>, <li> tags for bullet points - - Use <ol>, <li> tags for numbered lists - - Use <blockquote> for quoted text - - Use <strong> for bold text and <em> for emphasis - """ - - table_instruction = """ - - For tables, use proper HTML table structure: - - <table> - <caption>{table_title}</caption> - <thead> - <tr> - {table_headers} - </tr> - </thead> - <tbody> - {table_rows} - </tbody> - </table>
          - """ - # Get conversation history first - conv_history = ( - conversation_manager.get_context_window(user_id, hospital_id, session_id) - if conversation_manager - else [] - ) - - # Get contextual query and relevant context first - contextual_query = await generate_contextual_query( - question, user_id, hospital_id, conversation_manager - ) - # Track ICD context across conversation - icd_context = {} - if conv_history: - # Extract ICD code from previous interaction - last_answer = conv_history[-1].get("answer", "") - icd_codes = re.findall(r"\b([A-Z][0-9A-Z]{2,6}[A-Z]?)\b", last_answer) - if icd_codes: - icd_context["last_code"] = icd_codes[0] - - # Check if current question is about a previously discussed ICD code - is_icd_followup = False - if icd_context.get("last_code"): - followup_indicators = [ - "what causes", - "what is causing", - "why", - "how", - "symptoms", - "treatment", - "diagnosis", - "causes", - "effects", - "complications", - "risk factors", - "prevention", - "prognosis", - "this", - "disease", - "that", - "it", - ] - is_icd_followup = any( - indicator in question.lower() for indicator in followup_indicators - ) - - if is_icd_followup: - # Add the previous ICD code context to the current question - icd_exact_context = get_icd_context_from_question( - icd_context["last_code"], hospital_id - ) - icd_fuzzy_context = get_fuzzy_icd_context( - f"{icd_context['last_code']} {question}", hospital_id - ) - else: - icd_exact_context = get_icd_context_from_question(question, hospital_id) - icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id) - else: - icd_exact_context = get_icd_context_from_question(question, hospital_id) - icd_fuzzy_context = get_fuzzy_icd_context(question, hospital_id) - - # Get contextual query and relevant context - contextual_query = await generate_contextual_query( - question, user_id, hospital_id, conversation_manager - ) - doc_context = await get_relevant_context(contextual_query, hospital_id, doc_id) - - # Combine context with priority for ICD information - context_parts = [] - if is_icd_followup: - context_parts.append( - f"## Previous ICD Code Context\nContinuing discussion about: {icd_context['last_code']}" - ) - if icd_exact_context: - context_parts.append("## ICD Code Match\n" + icd_exact_context) - if icd_fuzzy_context: - context_parts.append("## Related ICD Suggestions\n" + icd_fuzzy_context) - if doc_context: - context_parts.append("## Document Context\n" + doc_context) - - context = "\n\n".join(context_parts) - - # Initialize follow-up detection - is_follow_up = False - - # Check if this is a follow-up question - if conv_history: - last_interaction = conv_history[-1] - last_question = last_interaction["question"].lower() - last_answer = last_interaction.get("answer", "").lower() - current_question = question.lower() - - # Define meaningful keywords that indicate entity-related follow-ups - entity_related_keywords = { - "achievements", - "awards", - "accomplishments", - "work", - "contributions", - "career", - "company", - "products", - "life", - "background", - "education", - "role", - "experience", - "history", - "details", - "places", - "place", - "information", - "facts", - "about", - "birth", - "death", - "family", - "books", - "projects", - "population", - } - - # Check if question is asking about attributes/achievements of previously discussed entity - has_entity_attribute = any( - word in current_question.split() for word in entity_related_keywords - ) - - # Extract entities from last answer to maintain context - def 
extract_entities(text): - # Split into words and get potential entities (capitalized words) - words = text.split() - entities = set() - current_entity = [] - - for word in words: - if word[0].isupper(): - current_entity.append(word) - elif current_entity: - if len(current_entity) > 0: - entities.add(" ".join(current_entity)) - current_entity = [] - - if current_entity: - entities.add(" ".join(current_entity)) - return entities - - last_entities = extract_entities(last_answer) - - # Check for referential words - referential_words = { - "it", - "this", - "that", - "these", - "those", - "they", - "their", - "he", - "she", - "him", - "her", - "his", - "hers", - "them", - "there", - "such", - "its", - } - has_referential = any( - word in referential_words for word in current_question.split() - ) - - # Calculate term overlap with both question and answer context - def get_significant_terms(text): - stop_words = { - "what", - "when", - "where", - "who", - "why", - "how", - "is", - "are", - "was", - "were", - "be", - "been", - "the", - "a", - "an", - "in", - "on", - "at", - "to", - "for", - "of", - "with", - "by", - "about", - "as", - "tell", - "me", - "please", - } - return set( - word - for word in text.split() - if len(word) > 2 and word.lower() not in stop_words - ) - - current_terms = get_significant_terms(current_question) - last_terms = get_significant_terms(last_question) - answer_terms = get_significant_terms(last_answer) - - # Include terms from both question and answer in context - all_prev_terms = last_terms | answer_terms - term_overlap = len(current_terms & all_prev_terms) - total_terms = len(current_terms | all_prev_terms) - term_similarity = term_overlap / total_terms if total_terms > 0 else 0 - - # Enhanced follow-up detection combining multiple signals - is_follow_up = ( - has_referential - or term_similarity - >= 0.2 # Lower threshold when including answer context - or ( - has_entity_attribute and bool(last_entities) - ) # Check if asking about attributes of known entity - or ( - last_interaction.get("type") == "general" - and term_similarity >= 0.15 - ) - ) - - logging.info(f"Follow-up analysis enhanced:") - logging.info(f"- Referential words: {has_referential}") - logging.info(f"- Term similarity: {term_similarity:.2f}") - logging.info(f"- Entity attribute question: {has_entity_attribute}") - logging.info(f"- Last entities found: {last_entities}") - logging.info(f"- Is follow-up: {is_follow_up}") - - # For entirely new topics (not follow-ups), use is_general_knowledge_question - if not is_follow_up: - is_general = is_general_knowledge_question(question, context, conv_history) - if is_general: - confirmation_prompt = f""" -
<p>Reviewed the hospital’s documentation, but this particular question does not seem to be covered.</p>
          - """ - logging.info("General knowledge question detected") - return { - "answer": confirmation_prompt, - "requires_confirmation": True, - }, 200 - - # # For follow-up questions to general knowledge, handle directly - # if is_follow_up and conv_history and conv_history[-1].get("type") == "general": - # logging.info("Handling follow-up to general knowledge question") - # answer = await generate_general_knowledge_answer( - # question, - # client, - # user_id, - # hospital_id, - # conversation_manager, - # is_table_request(question), - # session_id=session_id, - # ) - # return {"answer": answer}, 200 - - # Generate RAG answer with enhanced context - prompt_template = f"""Based on the following context and conversation history, provide a detailed answer to the question. - Previous conversation: - {format_conversation_context(conv_history)} - - Context from documents: - {context} - - Current question: {question} - - Instructions: - 1. When providing medical codes (ICD, CPT, etc.): - - Always use the ICD codes listed in the sections titled "ICD Code Match" and "Related ICD Suggestions" from the context. - - Do not use or invent ICD codes from your own training knowledge unless they appear in the provided context. - - If multiple codes are relevant, return the one that best matches the user’s question. If unsure, return multiple options in HTML list format. - - Remove all decimal points (e.g., use 'A150' instead of 'A15.0') - - Format the response as: '
<p>The medical code for [condition] is [code]</p>
          ' - 2. Address the current question while maintaining conversation continuity - 3. Resolve any ambiguous references using conversation history - 4. Format the response in clear HTML - - {html_instruction} - {table_instruction if is_table_request(question) else ""} - """ - - response = await asyncio.to_thread( - lambda: client.chat.completions.create( - model="gpt-3.5-turbo-16k", - messages=[ - {"role": "system", "content": prompt_template}, - {"role": "user", "content": question}, - ], - temperature=0.3, - max_tokens=1000, - ) - ) - - answer = ensure_html_response(response.choices[0].message.content) - logging.info(f"Generated RAG answer for question: {question}") - - # Store interaction in history - if conversation_manager: - await conversation_manager.add_rag_interaction( - user_id, hospital_id, question, answer, session_id - ) - - return {"answer": answer}, 200 - - except Exception as e: - logging.error(f"Error in generate_answer_with_rag: {e}") - return {"answer": f"
<p>Error: {str(e)}</p>
          "}, 500 - - -# async def generate_general_knowledge_answer( -# query: str, -# client, -# user_id: str, -# hospital_id: int, -# conversation_manager, -# wants_table: bool = False, -# session_id=None, -# ) -> str: -# """Generate an answer from general knowledge with enhanced context awareness""" -# html_instruction = """ -# IMPORTANT: Format your ENTIRE response as HTML. Use appropriate HTML tags for all content: -# - Use
<p> tags for paragraphs -# - Use <h1>, <h2> tags for headings and subheadings -# - Use <ul>, <li> tags for bullet points -# - Use <ol>, <li> tags for numbered lists -# - Use <blockquote> for quoted text -# - Use <strong> for bold text and <em> for emphasis -# """ - -# table_instruction = """ -# - For tables, use proper HTML table structure: -# <table> -# <caption>{table_title}</caption> -# <thead> -# <tr> -# {table_headers} -# </tr> -# </thead> -# <tbody> -# {table_rows} -# </tbody> -# </table>
              -# """ -# # Get conversation history for context -# conv_history = conversation_manager.get_context_window( -# user_id, hospital_id, session_id -# ) - -# # Extract entities from previous answer if available -# entities = [] -# if conv_history: -# last_answer = conv_history[-1].get("answer", "") -# words = last_answer.split() -# current_entity = [] -# for word in words: -# if word[0].isupper() and len(word) > 2: -# current_entity.append(word) -# elif current_entity: -# entities.append(" ".join(current_entity)) -# current_entity = [] -# if current_entity: -# entities.append(" ".join(current_entity)) - -# conversation_context = "\n".join( -# [ -# f"Q: {interaction['question']}\nA: {interaction['answer']}" -# for interaction in conv_history -# ] -# ) - -# # Enhanced system prompt with entity awareness -# system_prompt = f"""You are a helpful assistant providing information from general knowledge. -# Previous conversation: -# {conversation_context} - -# {'Previous answer mentioned these entities: ' + ', '.join(entities) if entities else ''} - -# 1. If the question seems incomplete or asks about attributes (like population, details, etc.), -# assume it's referring to the most recently mentioned entity ({entities[0] if entities else 'none'}). -# 2. Present information in a clear, organized manner. -# 3. If you're uncertain about any information, acknowledge that limitation. -# 4. If the question requires specialized expertise, note that the user may want to consult an expert. -# 5. Format the response in clear HTML. - -# {html_instruction} -# {table_instruction if wants_table else ""} -# """ - -# try: -# logging.info(f"Generating general knowledge answer for: {query}") -# response = client.chat.completions.create( -# model="gpt-3.5-turbo", -# messages=[ -# {"role": "system", "content": system_prompt}, -# {"role": "user", "content": query}, -# ], -# temperature=0.3, -# stream=True, -# ) - -# response_text = "" -# for chunk in response: -# if chunk.choices[0].delta.content: -# response_text += chunk.choices[0].delta.content - -# formatted_response = ensure_html_response(response_text.strip()) - -# # Store the interaction in conversation history -# conversation_manager.add_general_knowledge_interaction( -# user_id, hospital_id, query, formatted_response, session_id -# ) - -# return formatted_response - -# except Exception as e: -# logging.error(f"Error generating general knowledge answer: {e}") -# return f"
<p>Error generating response: {str(e)}</p>
              " - - -async def load_existing_vector_stores(): - """Load existing Chroma vector stores for each hospital""" - pool = await get_db_pool() - async with pool.acquire() as conn: - async with conn.cursor() as cursor: - try: - await cursor.execute("SELECT DISTINCT id FROM hospitals") - hospital_ids = [row[0] for row in await cursor.fetchall()] - - for hospital_id in hospital_ids: - try: - await initialize_or_load_vector_store(hospital_id) - except Exception as e: - logging.error( - f"Failed to load vector store for hospital {hospital_id}: {e}" - ) - continue - - except Exception as e: - logging.error(f"Error loading vector stores: {e}") - - -async def get_failed_page(doc_id): - pool = await get_db_pool() - async with pool.acquire() as conn: - async with conn.cursor() as cursor: - try: - await cursor.execute( - "SELECT failed_page FROM documents WHERE id = %s", (doc_id,) - ) - result = await cursor.fetchone() - return result[0] if result and result[0] else None - except Exception as e: - logging.error(f"Database error checking failed_page: {e}") - return None - - -async def update_document_status(doc_id, status, failed_page=None): - """Update document status with enum validation""" - if isinstance(status, str): - status = DocumentStatus[status.upper()].value - - pool = await get_db_pool() - async with pool.acquire() as conn: - async with conn.cursor() as cursor: - try: - if failed_page: - await cursor.execute( - "UPDATE documents SET processed_status = %s, failed_page = %s WHERE id = %s", - (status, failed_page, doc_id), - ) - else: - await cursor.execute( - "UPDATE documents SET processed_status = %s, failed_page = NULL WHERE id = %s", - (status, doc_id), - ) - await conn.commit() - return True - except Exception as e: - logging.error(f"Database update error: {e}") - return False - - -thread_pool = ThreadPoolExecutor(max_workers=10) - - -def async_to_sync(coroutine): - """Helper function to run async code in sync context""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete(coroutine) - finally: - loop.close() - - -@app.route("/flask-api", methods=["GET"]) -def health_check(): - """Health check endpoint""" - access_logger.info(f"Health check request received from {request.remote_addr}") - return jsonify({"status": "ok"}), 200 - - -@app.route("/flask-api/process-pdf", methods=["POST"]) -def process_pdf(): - access_logger.info(f"PDF processing request received from {request.remote_addr}") - file_path = None - try: - file = request.files.get("pdf") - hospital_id = request.form.get("hospital_id") - doc_id = request.form.get("doc_id") - - logging.info( - f"Received PDF processing request for hospital {hospital_id}, doc_id {doc_id}" - ) - - if not all([file, hospital_id, doc_id]): - return jsonify({"error": "Missing required parameters"}), 400 - - def process_in_background(): - nonlocal file_path - try: - async_to_sync(update_document_status(doc_id, "processing")) - - # Add progress logging - logging.info(f"Starting processing of document {doc_id}") - - filename = f"doc_{doc_id}_{file.filename}" - file_path = os.path.join(uploads_dir, filename) - - with open(file_path, "wb") as f: - file.save(f) - - logging.info("Extracting PDF contents...") - content = extract_pdf_contents(file_path, int(hospital_id)) - - logging.info("Inserting content into database...") - metadata = {"filename": filename} - result = async_to_sync( - insert_content_into_db(content, metadata, doc_id) - ) - - if "error" in result: - 
async_to_sync(update_document_status(doc_id, "failed", 1)) - return False - - logging.info("Creating embeddings and indexing...") - success = async_to_sync(add_document_to_index(doc_id, hospital_id)) - - if success: - logging.info("Document processing completed successfully") - async_to_sync(update_document_status(doc_id, "processed")) - return True - else: - logging.error("Document processing failed during indexing") - async_to_sync(update_document_status(doc_id, "failed")) - return False - - except Exception as e: - logging.error(f"Processing error: {e}") - async_to_sync(update_document_status(doc_id, "failed")) - return False - finally: - if file_path and os.path.exists(file_path): - try: - os.remove(file_path) - except Exception as e: - logging.error(f"Error removing temporary file: {e}") - - # Execute processing and wait for result - future = thread_pool.submit(process_in_background) - success = future.result() - - if success: - return jsonify({"message": "Document processed successfully"}), 200 - else: - return jsonify({"error": "Document processing failed"}), 500 - - except Exception as e: - logging.error(f"API error: {e}") - if file_path and os.path.exists(file_path): - try: - os.remove(file_path) - except Exception as file_e: - logging.error(f"Error removing temporary file: {file_e}") - return jsonify({"error": str(e)}), 500 - - -# Initialize the hybrid conversation manager -redis_client = get_redis_client() -conversation_manager = HybridConversationManager(redis_client) - - -@app.route("/flask-api/generate-answer", methods=["POST"]) -def rag_answer_api(): - """Sync API endpoint for RAG-based question answering with conversation history.""" - access_logger.info(f"Generate answer request received from {request.remote_addr}") - try: - data = request.json - question = data.get("question", "").strip().lower() - hospital_code = data.get("hospital_code") - doc_id = data.get("doc_id") - user_id = data.get("user_id", "default") - session_id = data.get("session_id", None) - - logging.info(f"Received question from user {user_id}: {question}") - logging.info(f"Received hospital code: {hospital_code}") - logging.info(f"Received session_id: {session_id}") - - # is_confirmation_response = data.get("is_confirmation_response", False) - original_query = data.get("original_query", "") - - def process_rag_answer(): - try: - hospital_id = async_to_sync(get_hospital_id(hospital_code)) - logging.info(f"Resolved hospital ID: {hospital_id}") - - if not hospital_id: - return { - "error": "Invalid or missing 'hospital_code' in request" - }, 400 - - # if question == "yes" and original_query: - # # User confirmed they want a general knowledge answer - # answer = async_to_sync( - # generate_general_knowledge_answer( - # original_query, - # client, - # user_id, - # hospital_id, - # conversation_manager, # Pass the hybrid manager - # is_table_request(original_query), - # session_id=session_id, - # ) - # ) - # return {"answer": answer}, 200 - - if original_query: - response_message = """ -
<p>I can only answer questions based on information found in the hospital documents.</p> - <p>The question you asked doesn't seem to be covered in the available documents.</p> - <p>You can try rephrasing your question or asking about a different topic.</p>
              - """ - return {"answer": response_message}, 200 - - else: - # Regular RAG answer - return async_to_sync( - generate_answer_with_rag( - question=question, - hospital_id=hospital_id, - client=client, - doc_id=doc_id, - user_id=user_id, - conversation_manager=conversation_manager, # Pass the hybrid manager - session_id=session_id, - ) - ) - except Exception as e: - logging.error(f"Thread processing error: {str(e)}") - return {"error": str(e)}, 500 - - if not question: - return jsonify({"error": "Missing 'question' in request"}), 400 - - future = thread_pool.submit(process_rag_answer) - result, status_code = future.result() - - return jsonify(result), status_code - - except Exception as e: - logging.error(f"API error: {str(e)}") - return jsonify({"error": str(e)}), 500 - - -@app.route("/flask-api/delete-document-vectors", methods=["DELETE"]) -def delete_document_vectors_endpoint(): - """Endpoint to delete document vectors from ChromaDB""" - try: - data = request.json - hospital_id = data.get("hospital_id") - doc_id = data.get("doc_id") - - if not all([hospital_id, doc_id]): - return jsonify({"error": "Missing required parameters"}), 400 - - logging.info( - f"Received request to delete vectors for document {doc_id} from hospital {hospital_id}" - ) - - def process_deletion(): - try: - success = async_to_sync(delete_document_vectors(hospital_id, doc_id)) - if success: - return {"message": "Document vectors deleted successfully"}, 200 - else: - return {"error": "Failed to delete document vectors"}, 500 - except Exception as e: - logging.error(f"Error in vector deletion process: {e}") - return {"error": str(e)}, 500 - - future = thread_pool.submit(process_deletion) - result, status_code = future.result() - - return jsonify(result), status_code - - except Exception as e: - logging.error(f"API error: {str(e)}") - return jsonify({"error": str(e)}), 500 - - -@app.route("/flask-api/get-chroma-content", methods=["GET"]) -def get_chroma_content_endpoint(): - """API endpoint to get ChromaDB content by hospital_id""" - try: - hospital_id = request.args.get("hospital_id") - limit = int(request.args.get("limit", 30000)) - - if not hospital_id: - return jsonify({"error": "Missing required parameter: hospital_id"}), 400 - - def process_fetch(): - try: - result, status_code = async_to_sync( - get_chroma_content_by_hospital( - hospital_id=int(hospital_id), limit=limit - ) - ) - return result, status_code - except Exception as e: - logging.error(f"Error in ChromaDB fetch process: {e}") - return {"error": str(e)}, 500 - - future = thread_pool.submit(process_fetch) - result, status_code = future.result() - - return jsonify(result), status_code - - except Exception as e: - logging.error(f"API error: {str(e)}") - return jsonify({"error": str(e)}), 500 - - -async def get_chroma_content_by_hospital(hospital_id: int, limit: int = 100): - """Fetch content from ChromaDB for a specific hospital""" - try: - # Initialize vector store - vector_store = await initialize_or_load_vector_store(hospital_id) - if not vector_store: - return {"error": "Vector store not found"}, 404 - - # Get collection - collection = vector_store._collection - - # Query the collection with hospital_id filter - results = await asyncio.to_thread( - lambda: collection.get(where={"hospital_id": str(hospital_id)}, limit=limit) - ) - - if not results or not results["ids"]: - return {"data": [], "count": 0}, 200 - - # Format the response - formatted_results = [] - for i in range(len(results["ids"])): - formatted_results.append( - { - "id": 
results["ids"][i], - "content": results["documents"][i], - "metadata": results["metadatas"][i], - } - ) - - return {"data": formatted_results, "count": len(formatted_results)}, 200 - - except Exception as e: - logging.error(f"Error fetching ChromaDB content: {e}") - return {"error": str(e)}, 500 - - -@app.before_request -def before_request(): - request._start_time = time.time() - - -@app.after_request -def after_request(response): - if hasattr(request, "_start_time"): - duration = time.time() - request._start_time - access_logger.info( - f'"{request.method} {request.path}" {response.status_code} - Duration: {duration:.3f}s - ' - f"IP: {request.remote_addr}" - ) - return response - - -if __name__ == "__main__": - logger.info("Starting SpurrinAI application") - logger.info(f"Python version: {sys.version}") - logger.info(f"Environment: {os.getenv('FLASK_ENV', 'production')}") - - try: - model_manager = ModelManager() - logger.info("Model manager initialized successfully") - except Exception as e: - logger.error(f"Failed to initialize model manager: {e}") - sys.exit(1) - - # Initialize directories - os.makedirs(DATA_DIR, exist_ok=True) - os.makedirs(CHROMA_DIR, exist_ok=True) - logger.info(f"Initialized directories: {DATA_DIR}, {CHROMA_DIR}") - - # Clear Redis cache - redis_client = get_redis_client() - cleared_keys = 0 - for key in redis_client.scan_iter("vector_store_data:*"): - redis_client.delete(key) - cleared_keys += 1 - logger.info(f"Cleared {cleared_keys} Redis cache keys") - - # Load vector stores - logger.info("Loading existing vector stores...") - async_to_sync(load_existing_vector_stores()) - logger.info("Vector stores loaded successfully") - - # Start application - logger.info("Starting Flask application on port 5000") - app.run(port=5000, debug=False) \ No newline at end of file