# """ # SpurrinAI - Intelligent Document Processing and Question Answering System # Copyright (c) 2024 Tech4biz. All rights reserved. # This module implements the main Flask application for the SpurrinAI system, # providing REST APIs for document processing, vector storage, and question answering # using RAG (Retrieval Augmented Generation) architecture. # Author: Tech4biz Development Team # Version: 1.0.0 # Last Updated: 2024-01-19 # """ # # Standard library imports # import os # import re # import sys # import json # import time # import threading # import asyncio # from concurrent.futures import ThreadPoolExecutor # from dataclasses import dataclass # from datetime import timedelta # from enum import Enum # # Third-party imports # import spacy # import redis # import aiomysql # from dotenv import load_dotenv # from flask import Flask, request, jsonify, Response # from flask_cors import CORS # from tqdm import tqdm # from tqdm.asyncio import tqdm as tqdm_async # from langchain_community.document_loaders import PyPDFLoader # from langchain.text_splitter import RecursiveCharacterTextSplitter # from langchain_community.embeddings import OpenAIEmbeddings # from langchain_community.vectorstores import Chroma # from langchain_community.chat_models import ChatOpenAI # from langchain.chains import RetrievalQA # from langchain.prompts import PromptTemplate # from openai import OpenAI # from rapidfuzz import process # from threading import Lock # # Local imports # from model_manager import ModelManager # # Suppress warnings # import warnings # warnings.filterwarnings("ignore") # # Initialize NLTK # import nltk # nltk.download("punkt") # # Configure logging # import logging # import logging.handlers # app = Flask(__name__) # CORS(app) # script_dir = os.path.dirname(os.path.abspath(__file__)) # log_file_path = os.path.join(script_dir, "error.log") # logging.basicConfig(filename=log_file_path, level=logging.INFO) # # Configure logging # def setup_logging(): # log_dir = os.path.join(script_dir, "logs") # os.makedirs(log_dir, exist_ok=True) # main_log = os.path.join(log_dir, "app.log") # error_log = os.path.join(log_dir, "error.log") # access_log = os.path.join(log_dir, "access.log") # perf_log = os.path.join(log_dir, "performance.log") # # Create formatters # detailed_formatter = logging.Formatter( # "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s" # ) # access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") # # Main logger setup # main_handler = logging.handlers.RotatingFileHandler( # main_log, maxBytes=10485760, backupCount=5 # ) # main_handler.setFormatter(detailed_formatter) # main_handler.setLevel(logging.INFO) # # Error logger setup # error_handler = logging.handlers.RotatingFileHandler( # error_log, maxBytes=10485760, backupCount=5 # ) # error_handler.setFormatter(detailed_formatter) # error_handler.setLevel(logging.ERROR) # # Access logger setup # access_handler = logging.handlers.TimedRotatingFileHandler( # access_log, when="midnight", interval=1, backupCount=30 # ) # access_handler.setFormatter(access_formatter) # access_handler.setLevel(logging.INFO) # # Performance logger setup # perf_handler = logging.handlers.RotatingFileHandler( # perf_log, maxBytes=10485760, backupCount=5 # ) # perf_handler.setFormatter(detailed_formatter) # perf_handler.setLevel(logging.INFO) # # Configure root logger # root_logger = logging.getLogger() # root_logger.setLevel(logging.INFO) # root_logger.addHandler(main_handler) # 
# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'flaskuser',
#     'password': 'Flask@123',
#     'database': 'spurrinai',
# }

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'spurrindevuser',
#     'password': 'Admin@123',
#     'database': 'spurrindev',
# }

# DB_CONFIG = {
#     'host': 'localhost',
#     'user': 'root',
#     'password': 'root',
#     'database': 'medqueryai',
#     'port': 3307
# }

# Redis Configuration
REDIS_CONFIG = {
    "host": "localhost",
    "port": 6379,
    "db": 0,
    "decode_responses": True,  # For string operations
}

DB_CONFIG = {
    "host": os.getenv("DB_HOST", "localhost"),
    "user": os.getenv("DB_USER", "testuser"),
    "password": os.getenv("DB_PASSWORD", "Admin@123"),
    "database": os.getenv("DB_NAME", "spurrintest"),
}

# Redis connection pools
redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
redis_binary_pool = redis.ConnectionPool(
    host="localhost", port=6379, db=1, decode_responses=False
)


def get_redis_client(binary=False):
    """Get Redis client from pool"""
    logger.debug(f"Getting Redis client with binary={binary}")
    try:
        pool = redis_binary_pool if binary else redis_pool
        client = redis.Redis(connection_pool=pool)
        logger.debug("Redis client created successfully")
        return client
    except Exception as e:
        logger.error(f"Failed to create Redis client: {e}", exc_info=True)
        raise


def fetch_cached_answer(cache_key):
    logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
    start_time = time.time()
    try:
        redis_client = get_redis_client()
        cached_answer = redis_client.get(cache_key)
        fetch_time = time.time() - start_time
        perf_logger.info(
            f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
        )
        return cached_answer
    except Exception as e:
        logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
        return None


# Cache TTL configurations
CACHE_TTL = {
    "vector_store": timedelta(hours=24),
    "chat_completion": timedelta(hours=1),
    "document_metadata": timedelta(days=7),
}
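
# Illustrative sketch of the write-side counterpart to fetch_cached_answer(), using the
# TTLs defined above. This helper is not defined in this section; the key and answer
# values are placeholders.
#
#   def cache_answer(cache_key, answer):
#       redis_client = get_redis_client()
#       ttl_seconds = int(CACHE_TTL["chat_completion"].total_seconds())
#       redis_client.setex(cache_key, ttl_seconds, answer)
#
#   cache_answer("chat:1:visiting-hours", "<p>Visiting hours are 9am to 7pm.</p>")
#   cached = fetch_cached_answer("chat:1:visiting-hours")
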
DATA_DIR = os.path.join(script_dir, "hospital_data")
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")

uploads_dir = os.path.join(script_dir, "llm-uploads")
if not os.path.exists(uploads_dir):
    os.makedirs(uploads_dir)

nlp = spacy.load("en_core_web_sm")

load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)
embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
)

# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)

hospital_vector_stores = {}
vector_store_lock = threading.Lock()
# Module-level lock guarding writes to the per-hospital ICD JSON files
icd_json_lock = threading.Lock()


@dataclass
class Document:
    doc_id: int
    page_num: int
    content: str


class DocumentStatus(Enum):
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"


async def get_db_pool():
    return await aiomysql.create_pool(
        host=DB_CONFIG["host"],
        user=DB_CONFIG["user"],
        password=DB_CONFIG["password"],
        db=DB_CONFIG["database"],
        autocommit=True,
    )


async def get_hospital_id(hospital_code):
    pool = None
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(
                    "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
                    (hospital_code,),
                )
                result = await cursor.fetchone()
                return result["id"] if result else None
    except Exception as error:
        logging.error(f"Database error: {error}")
        return None
    finally:
        if pool is not None:
            pool.close()
            await pool.wait_closed()


CHUNK_SIZE = 1000
CHUNK_OVERLAP = 50
BATCH_SIZE = 1000

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    # length_function=len,
    # separators=["\n\n", "\n", ". ", " ", ""]
)


# The ICD JSON path is dynamic based on hospital_id
def get_icd_json_path(hospital_id):
    hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
    os.makedirs(hospital_data_dir, exist_ok=True)
    return os.path.join(hospital_data_dir, "icd_data.json")


def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
    """Extract and process ICD codes with optimized processing and optional JSON saving"""
    try:
        # Compile the code pattern once
        pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)

        # Process in chunks for large content
        chunk_size = 50000  # Process 50KB at a time
        icd_data = []
        current_code = None
        current_description = []

        # Split content into manageable chunks
        content_chunks = [
            content[i : i + chunk_size] for i in range(0, len(content), chunk_size)
        ]

        # Process each chunk
        for chunk in content_chunks:
            for line in chunk.splitlines():
                line = line.strip()
                if not line:
                    if current_code and current_description:
                        icd_data.append(
                            {
                                "code": current_code,
                                "description": " ".join(current_description).strip(),
                            }
                        )
                        current_code = None
                        current_description = []
                    continue

                match = pattern.match(line)
                if match:
                    if current_code and current_description:
                        icd_data.append(
                            {
                                "code": current_code,
                                "description": " ".join(current_description).strip(),
                            }
                        )
                    current_code, description = match.groups()
                    current_description = [description.strip()]
                elif current_code:
                    current_description.append(line)

        # Add the final entry if one is still open
        if current_code and current_description:
            icd_data.append(
                {
                    "code": current_code,
                    "description": " ".join(current_description).strip(),
                }
            )

        # Save to hospital-specific JSON if requested
        if save_to_json and icd_data:
            try:
                json_path = get_icd_json_path(hospital_id)
                # Serialize writers with the module-level lock; a lock created inline
                # would be private to this call and provide no mutual exclusion.
                with icd_json_lock:
                    if os.path.exists(json_path):
                        with open(json_path, "r", encoding="utf-8") as f:
                            try:
                                existing_data = json.load(f)
                            except json.JSONDecodeError:
                                existing_data = []
                    else:
                        existing_data = []

                    # Efficient deduplication using a dictionary keyed by code
                    seen_codes = {item["code"]: item for item in existing_data}
                    for item in icd_data:
                        seen_codes[item["code"]] = item
                    unique_data = list(seen_codes.values())

                    # Write atomically using a temporary file
                    temp_path = f"{json_path}.tmp"
                    with open(temp_path, "w", encoding="utf-8") as f:
                        json.dump(unique_data, f, indent=2, ensure_ascii=False)
                    os.replace(temp_path, json_path)

                logging.info(
                    f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
                )
            except Exception as e:
                logging.error(
                    f"Error saving ICD data to JSON for hospital {hospital_id}: {e}"
                )
        return icd_data
    except Exception as e:
        logging.error(f"Error in extract_and_process_icd_data: {e}")
        return []
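
# A quick illustration of the extractor above on an in-memory sample. The hospital_id
# is a placeholder and save_to_json=False keeps the call side-effect free.
#
#   sample = "A00 Cholera\nB20 Human immunodeficiency virus [HIV] disease"
#   extract_and_process_icd_data(sample, hospital_id=1, save_to_json=False)
#   # -> [{'code': 'A00', 'description': 'Cholera'},
#   #     {'code': 'B20', 'description': 'Human immunodeficiency virus [HIV] disease'}]
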
def load_icd_entries(hospital_id):
    """Load ICD entries from the hospital-specific JSON file"""
    json_path = get_icd_json_path(hospital_id)
    try:
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return []
    except Exception as e:
        logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
        return []


async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
    """Process and store ICD codes using the optimized extraction function"""
    try:
        # Extract and save codes with hospital_id
        extract_and_process_icd_data(content, hospital_id, save_to_json=True)
    except Exception as e:
        logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")


async def initialize_icd_vector_store(hospital_id):
    """This function is deprecated. ICD codes are now handled through JSON search."""
    logging.warning(
        "initialize_icd_vector_store is deprecated - using JSON search instead"
    )
    return None


def extract_pdf_contents(pdf_path, hospital_id):
    """Extract PDF contents with optimized chunking and code extraction"""
    try:
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()
        pages_content = []
        for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
            text = page.page_content.strip()
            # Extract ICD codes from the page (the doc_id is associated later)
            icd_codes = extract_and_process_icd_data(text, hospital_id)
            pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})
        return pages_content
    except Exception as e:
        logging.error(f"Error in extract_pdf_contents: {e}")
        raise


async def insert_content_into_db(content, metadata, doc_id):
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
                content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"
                metadata_values = [
                    (doc_id, key[:100], value)
                    for key, value in metadata.items()
                    if value
                ]
                content_values = [
                    (doc_id, page_content["page"], page_content["text"])
                    for page_content in content
                ]
                if metadata_values:
                    await cursor.executemany(metadata_query, metadata_values)
                if content_values:
                    await cursor.executemany(content_query, content_values)
                await conn.commit()
                return {"message": "Success"}
            except Exception as e:
                await conn.rollback()
                return {"error": str(e)}


async def initialize_or_load_vector_store(hospital_id, user_id="default"):
    """Initialize or load a vector store with Redis caching and thread safety"""
    store_key = f"{hospital_id}:{user_id}"
    try:
        # Return an already-loaded store if present - with lock for thread safety
        with vector_store_lock:
            if store_key in hospital_vector_stores:
                return hospital_vector_stores[store_key]

        # Initialize vector store
        redis_client = get_redis_client(binary=True)
        cache_key = f"vector_store_data:{hospital_id}:{user_id}"
        hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")

        if os.path.exists(hospital_dir):
            logging.info(
                f"Loading vector store for hospital {hospital_id} and user {user_id}"
            )
        else:
            logging.info(f"Creating vector store for hospital {hospital_id}")
            os.makedirs(hospital_dir, exist_ok=True)

        vector_store = await asyncio.to_thread(
            lambda: Chroma(
                collection_name=f"hospital_{hospital_id}",
                persist_directory=hospital_dir,
                embedding_function=embeddings,
            )
        )

        # Store with lock for thread safety
        with vector_store_lock:
            hospital_vector_stores[store_key] = vector_store
        return vector_store
    except Exception as e:
        logging.error(f"Error initializing vector store: {e}", exc_info=True)
        raise
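
# Illustrative sketch of querying a hospital's store once it is loaded; the REST
# endpoints that use it are outside this section. The query string, helper name,
# and k value are placeholders.
#
#   async def _example_similarity_query(hospital_id):
#       store = await initialize_or_load_vector_store(hospital_id)
#       docs = await asyncio.to_thread(
#           store.similarity_search,
#           "visiting hours policy",
#           k=4,
#           filter={"hospital_id": str(hospital_id)},
#       )
#       return [doc.page_content for doc in docs]
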
async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
    """Delete all vectors associated with a specific document from ChromaDB"""
    try:
        # Initialize vector store for the hospital
        vector_store = await initialize_or_load_vector_store(hospital_id)

        # Delete vectors with matching doc_id
        await asyncio.to_thread(
            lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
        )

        # Persist changes
        await asyncio.to_thread(vector_store.persist)

        # Clear cached vector-store entries for this hospital in Redis
        redis_client = get_redis_client()
        pattern = f"vector_store_data:{hospital_id}:*"
        for key in redis_client.scan_iter(pattern):
            redis_client.delete(key)

        logging.info(
            f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
        )
        return True
    except Exception as e:
        logging.error(f"Error deleting document vectors: {e}", exc_info=True)
        return False


async def add_document_to_index(doc_id, hospital_id):
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                vector_store = await initialize_or_load_vector_store(hospital_id)
                await cursor.execute(
                    "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
                    (doc_id,),
                )
                rows = await cursor.fetchall()
                total_pages = len(rows)
                logging.info(f"Processing {total_pages} pages for document {doc_id}")

                page_bar = tqdm_async(total=total_pages, desc="Processing pages")

                async def process_page(page_data):
                    page_num, content = page_data
                    try:
                        icd_data = extract_and_process_icd_data(
                            content, hospital_id, save_to_json=False
                        )
                        chunks = text_splitter.split_text(content)
                        await asyncio.sleep(0)  # Yield control
                        return page_num, chunks, icd_data
                    except Exception as e:
                        logging.error(f"Error processing page {page_num}: {e}")
                        return page_num, [], []

                tasks = [asyncio.create_task(process_page(row)) for row in rows]
                results = []
                for coro in asyncio.as_completed(tasks):
                    results.append(await coro)
                    page_bar.update(1)
                page_bar.close()

                # Vector addition progress bar
                all_icd_data = []
                all_chunks = []
                all_metadatas = []
                chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)

                for page_num, chunks, icd_data in results:
                    all_icd_data.extend(icd_data)
                    for i, chunk in enumerate(chunks):
                        all_chunks.append(chunk)
                        all_metadatas.append(
                            {
                                "doc_id": str(doc_id),
                                "hospital_id": str(hospital_id),
                                "page_number": str(page_num),
                                "chunk_index": str(i),
                            }
                        )
                        if len(all_chunks) >= BATCH_SIZE:
                            chunk_add_bar.total += len(all_chunks)
                            chunk_add_bar.refresh()
                            await asyncio.to_thread(
                                vector_store.add_texts,
                                texts=all_chunks,
                                metadatas=all_metadatas,
                            )
                            all_chunks = []
                            all_metadatas = []
                            chunk_add_bar.update(BATCH_SIZE)

                # Final batch
                if all_chunks:
                    chunk_add_bar.total += len(all_chunks)
                    chunk_add_bar.refresh()
                    await asyncio.to_thread(
                        vector_store.add_texts,
                        texts=all_chunks,
                        metadatas=all_metadatas,
                    )
                    chunk_add_bar.update(len(all_chunks))
                chunk_add_bar.close()

                if all_icd_data:
                    logging.info(f"Saving {len(all_icd_data)} ICD codes")
                    extract_and_process_icd_data("", hospital_id, save_to_json=True)

                await asyncio.to_thread(vector_store.persist)
                logging.info(f"Successfully indexed document {doc_id}")
                return True
    except Exception as e:
        logging.error(f"Error adding document: {e}")
        return False
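
# Illustrative sketch of how the ingestion helpers fit together for a single PDF. It
# assumes a document row (doc_id) already exists; the path, ids, and helper name are
# placeholders.
#
#   async def _example_ingest(pdf_path, doc_id, hospital_id):
#       pages = extract_pdf_contents(pdf_path, hospital_id)
#       await insert_content_into_db(pages, {"source": os.path.basename(pdf_path)}, doc_id)
#       await add_document_to_index(doc_id, hospital_id)
#
#   # asyncio.run(_example_ingest("llm-uploads/handbook.pdf", 42, 1))
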
def is_general_knowledge_question(
    query: str, context: str, conversation_context=None
) -> bool:
    """
    Determine whether a question is likely a general knowledge question not covered
    in the documents. Takes conversation history into account to reduce repeated
    confirmations.
    """
    query_lower = query.lower()
    context_lower = context.lower()

    if conversation_context:
        for interaction in conversation_context:
            prev_question = interaction.get("question", "").lower()
            # Skip confirmation when the query overlaps a previous question
            if prev_question and (
                query_lower in prev_question or prev_question in query_lower
            ):
                logging.info(
                    "Question is similar to previous conversation, skipping confirmation"
                )
                return False

    stop_words = {
        "search", "query:", "can", "you", "some", "at", "the", "a", "an", "in",
        "on", "to", "for", "with", "by", "about", "give", "full", "is", "are",
        "was", "were", "define", "what", "how", "why", "when", "where", "year",
        "list", "form", "table", "who", "which", "me", "tell", "explain",
        "describe", "of", "and", "or", "there", "their", "please", "could",
        "would", "various", "different", "type", "types", "kind", "kinds",
        "has", "have", "had", "many", "say", "know",
    }

    key_words = [
        word for word in query_lower.split() if word not in stop_words and len(word) > 2
    ]
    logging.info(f"Key words: {key_words}")

    if not key_words:
        logging.info("No significant keywords found, directing to general knowledge")
        return True

    matches = sum(1 for word in key_words if word in context_lower)
    logging.info(f"Matches: {matches} out of {len(key_words)} keywords")
    match_ratio = matches / len(key_words)
    logging.info(f"Match ratio: {match_ratio}")
    # Treat the question as general knowledge when fewer than 40% of keywords appear in the context
    return match_ratio < 0.4


def is_table_request(query: str) -> bool:
    """
    Determine if the user is requesting a response in tabular format.
    """
    table_keywords = [
        "table", "tabular", "in a table", "in table format", "in tabular format",
        "chart", "data", "comparison", "as a table", "table format",
        "in rows and columns", "in a grid", "breakdown", "spreadsheet",
        "comparison table", "data table", "structured table", "tabular form",
        "table form",
    ]
    query_lower = query.lower()
    return any(keyword in query_lower for keyword in table_keywords)


def ensure_html_response(text: str) -> str:
    """
    Ensure the response is properly formatted in HTML.
    This function handles plain text conversion to HTML.
    """
    has_html_tags = bool(re.search(r"<[^>]+>", text))

    if not has_html_tags:
        paragraphs = text.split("\n\n")
        html_parts = []
        in_ordered_list = False
        in_unordered_list = False

        for para in paragraphs:
            if para.strip():
                if re.match(r"^\s*[\*\-\•]\s", para):
                    # Bullet item: open an unordered list if one is not already open
                    if not in_unordered_list:
                        if in_ordered_list:
                            html_parts.append("</ol>")
                            in_ordered_list = False
                        html_parts.append("<ul>")
                        in_unordered_list = True
                    item = re.sub(r"^\s*[\*\-\•]\s*", "", para)
                    html_parts.append(f"<li>{item}</li>")
                elif re.match(r"^\s*\d+[\.\)]\s", para):
                    # Numbered item: open an ordered list if one is not already open
                    if not in_ordered_list:
                        if in_unordered_list:
                            html_parts.append("</ul>")
                            in_unordered_list = False
                        html_parts.append("<ol>")
                        in_ordered_list = True
                    item = re.sub(r"^\s*\d+[\.\)]\s*", "", para)
                    html_parts.append(f"<li>{item}</li>")
                else:
                    # Plain paragraph: close any open list before emitting it
                    if in_ordered_list:
                        html_parts.append("</ol>")
                        in_ordered_list = False
                    if in_unordered_list:
                        html_parts.append("</ul>")
                        in_unordered_list = False
                    html_parts.append(f"<p>{para}</p>")

        if in_ordered_list:
            html_parts.append("</ol>")
        if in_unordered_list:
            html_parts.append("</ul>")
        return "".join(html_parts)
    else:
        # Text already contains HTML; wrap it in a paragraph if no block-level tag is present
        if not any(
            tag in text for tag in ("<p>", "<ul>", "<ol>", "<li>", "<table>", "<div>")
        ):
            return f"<p>{text}</p>"
        return text