""" SpurrinAI - Intelligent Document Processing and Question Answering System Copyright (c) 2024 Tech4biz. All rights reserved. This module implements the main Flask application for the SpurrinAI system, providing REST APIs for document processing, vector storage, and question answering using RAG (Retrieval Augmented Generation) architecture. Author: Tech4biz Development Team Version: 1.0.0 Last Updated: 2024-01-19 """ # Standard library imports import os import re import sys import json import time import threading import asyncio from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from datetime import timedelta from enum import Enum # Third-party imports import spacy import redis import aiomysql from dotenv import load_dotenv from flask import Flask, request, jsonify, Response from flask_cors import CORS from tqdm import tqdm from tqdm.asyncio import tqdm as tqdm_async from langchain_community.document_loaders import PyPDFLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.embeddings import OpenAIEmbeddings from langchain_community.vectorstores import Chroma from langchain_community.chat_models import ChatOpenAI from langchain.chains import RetrievalQA from langchain.prompts import PromptTemplate from openai import OpenAI from rapidfuzz import process from threading import Lock # Local imports from model_manager import ModelManager # Suppress warnings import warnings warnings.filterwarnings("ignore") # Initialize NLTK import nltk nltk.download("punkt") # Configure logging import logging import logging.handlers app = Flask(__name__) CORS(app) script_dir = os.path.dirname(os.path.abspath(__file__)) log_file_path = os.path.join(script_dir, "error.log") logging.basicConfig(filename=log_file_path, level=logging.INFO) # Configure logging def setup_logging(): log_dir = os.path.join(script_dir, "logs") os.makedirs(log_dir, exist_ok=True) main_log = os.path.join(log_dir, "app.log") error_log = os.path.join(log_dir, "error.log") access_log = os.path.join(log_dir, "access.log") perf_log = os.path.join(log_dir, "performance.log") # Create formatters detailed_formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s" ) access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") # Main logger setup main_handler = logging.handlers.RotatingFileHandler( main_log, maxBytes=10485760, backupCount=5 ) main_handler.setFormatter(detailed_formatter) main_handler.setLevel(logging.INFO) # Error logger setup error_handler = logging.handlers.RotatingFileHandler( error_log, maxBytes=10485760, backupCount=5 ) error_handler.setFormatter(detailed_formatter) error_handler.setLevel(logging.ERROR) # Access logger setup access_handler = logging.handlers.TimedRotatingFileHandler( access_log, when="midnight", interval=1, backupCount=30 ) access_handler.setFormatter(access_formatter) access_handler.setLevel(logging.INFO) # Performance logger setup perf_handler = logging.handlers.RotatingFileHandler( perf_log, maxBytes=10485760, backupCount=5 ) perf_handler.setFormatter(detailed_formatter) perf_handler.setLevel(logging.INFO) # Configure root logger root_logger = logging.getLogger() root_logger.setLevel(logging.INFO) root_logger.addHandler(main_handler) root_logger.addHandler(error_handler) # Create specific loggers access_logger = logging.getLogger("access") access_logger.addHandler(access_handler) access_logger.setLevel(logging.INFO) perf_logger = logging.getLogger("performance") perf_logger.addHandler(perf_handler) perf_logger.setLevel(logging.INFO) return root_logger, access_logger, perf_logger load_dotenv() # Initialize loggers logger, access_logger, perf_logger = setup_logging() # DB_CONFIG = { # 'host': 'localhost', # 'user': 'flaskuser', # 'password': 'Flask@123', # 'database': 'spurrinai', # } # DB_CONFIG = { # 'host': 'localhost', # 'user': 'spurrindevuser', # 'password': 'Admin@123', # 'database': 'spurrindev', # } # Redis Configuration REDIS_CONFIG = { "host": "localhost", "port": 6379, "db": 0, "decode_responses": True, # For string operations } DB_CONFIG = { "host": os.getenv("DB_HOST"), "user": os.getenv("DB_USER"), "password": os.getenv("DB_PASSWORD"), "database": os.getenv("DB_NAME"), } # Redis connection pool redis_pool = redis.ConnectionPool(**REDIS_CONFIG) redis_binary_pool = redis.ConnectionPool( host="localhost", port=6379, db=1, decode_responses=False ) def get_redis_client(binary=False): """Get Redis client from pool""" logger.debug(f"Getting Redis client with binary={binary}") try: pool = redis_binary_pool if binary else redis_pool client = redis.Redis(connection_pool=pool) logger.debug("Redis client created successfully") return client except Exception as e: logger.error(f"Failed to create Redis client: {e}", exc_info=True) raise def fetch_cached_answer(cache_key): logger.debug(f"Attempting to fetch cached answer for key: {cache_key}") start_time = time.time() try: redis_client = get_redis_client() cached_answer = redis_client.get(cache_key) fetch_time = time.time() - start_time perf_logger.info( f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}" ) return cached_answer except Exception as e: logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True) return None # Cache TTL configurations CACHE_TTL = { "vector_store": timedelta(hours=24), "chat_completion": timedelta(hours=1), "document_metadata": timedelta(days=7), } DATA_DIR = os.path.join(script_dir, "hospital_data") CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db") uploads_dir = os.path.join(script_dir, "llm-uploads") if not os.path.exists(uploads_dir): os.makedirs(uploads_dir) nlp = spacy.load("en_core_web_sm") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") client = OpenAI(api_key=OPENAI_API_KEY) embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY) llm = ChatOpenAI( model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY ) # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50) hospital_vector_stores = {} vector_store_lock = threading.Lock() @dataclass class Document: doc_id: int page_num: int content: str class DocumentStatus(Enum): PROCESSING = "processing" PROCESSED = "processed" FAILED = "failed" async def get_db_pool(): return await aiomysql.create_pool( host=DB_CONFIG["host"], user=DB_CONFIG["user"], password=DB_CONFIG["password"], db=DB_CONFIG["database"], autocommit=True, ) async def get_hospital_id(hospital_code): try: pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor(aiomysql.DictCursor) as cursor: await cursor.execute( "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1", (hospital_code,), ) result = await cursor.fetchone() return result["id"] if result else None except Exception as error: logging.error(f"Database error: {error}") return None finally: pool.close() await pool.wait_closed() CHUNK_SIZE = 4000 CHUNK_OVERLAP = 150 BATCH_SIZE = 250 text_splitter = RecursiveCharacterTextSplitter( chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, # length_function=len, # separators=["\n\n", "\n", ". ", " ", ""] ) # Update the JSON_PATH to be dynamic based on hospital_id def get_icd_json_path(hospital_id): hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}") os.makedirs(hospital_data_dir, exist_ok=True) return os.path.join(hospital_data_dir, "icd_data.json") def extract_and_process_icd_data(content, hospital_id, save_to_json=True): """Extract and process ICD codes with optimized processing and optional JSON saving""" try: # Initialize pattern compilation once pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE) # Process in chunks for large content chunk_size = 50000 # Process 50KB at a time icd_data = [] current_code = None current_description = [] # Split content into manageable chunks content_chunks = [ content[i : i + chunk_size] for i in range(0, len(content), chunk_size) ] # Process each chunk for chunk in content_chunks: lines = chunk.splitlines() for line in lines: line = line.strip() if not line: if current_code and current_description: icd_data.append( { "code": current_code, "description": " ".join(current_description).strip(), } ) current_code = None current_description = [] continue match = pattern.match(line) if match: if current_code and current_description: icd_data.append( { "code": current_code, "description": " ".join(current_description).strip(), } ) current_code, description = match.groups() current_description = [description.strip()] elif current_code: current_description.append(line) # Add final entry if exists if current_code and current_description: icd_data.append( { "code": current_code, "description": " ".join(current_description).strip(), } ) # Save to hospital-specific JSON if requested if save_to_json and icd_data: try: json_path = get_icd_json_path(hospital_id) # Use a lock for thread safety with threading.Lock(): if os.path.exists(json_path): with open(json_path, "r", encoding="utf-8") as f: try: existing_data = json.load(f) except json.JSONDecodeError: existing_data = [] else: existing_data = [] # Efficient deduplication using dictionary seen_codes = {item["code"]: item for item in existing_data} for item in icd_data: seen_codes[item["code"]] = item unique_data = list(seen_codes.values()) # Write atomically using temporary file temp_path = f"{json_path}.tmp" with open(temp_path, "w", encoding="utf-8") as f: json.dump(unique_data, f, indent=2, ensure_ascii=False) os.replace(temp_path, json_path) logging.info( f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}" ) except Exception as e: logging.error( f"Error saving ICD data to JSON for hospital {hospital_id}: {e}" ) return icd_data except Exception as e: logging.error(f"Error in extract_and_process_icd_data: {e}") return [] def load_icd_entries(hospital_id): """Load ICD entries from hospital-specific JSON file""" json_path = get_icd_json_path(hospital_id) try: if os.path.exists(json_path): with open(json_path, "r", encoding="utf-8") as f: return json.load(f) return [] except Exception as e: logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}") return [] # Update the process_icd_codes function to include hospital_id async def process_icd_codes(content, doc_id, hospital_id, batch_size=256): """Process and store ICD codes using the optimized extraction function""" try: # Extract and save codes with hospital_id extract_and_process_icd_data(content, hospital_id, save_to_json=True) except Exception as e: logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}") async def initialize_icd_vector_store(hospital_id): """This function is deprecated. ICD codes are now handled through JSON search.""" logging.warning( "initialize_icd_vector_store is deprecated - using JSON search instead" ) return None def extract_pdf_contents(pdf_path, hospital_id): """Extract PDF contents with optimized chunking and code extraction""" try: loader = PyPDFLoader(pdf_path) pages = loader.load() pages_content = [] for i, page in enumerate(tqdm(pages, desc="Extracting pages")): text = page.page_content.strip() # Extract ICD codes from the page icd_codes = extract_and_process_icd_data( text, hospital_id ) # We'll set doc_id later pages_content.append({"page": i + 1, "text": text, "codes": icd_codes}) return pages_content except Exception as e: logging.error(f"Error in extract_pdf_contents: {e}") raise async def insert_content_into_db(content, metadata, doc_id): pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: try: metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)" content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)" metadata_values = [ (doc_id, key[:100], value) for key, value in metadata.items() if value ] content_values = [ (doc_id, page_content["page"], page_content["text"]) for page_content in content ] if metadata_values: await cursor.executemany(metadata_query, metadata_values) if content_values: await cursor.executemany(content_query, content_values) await conn.commit() return {"message": "Success"} except Exception as e: await conn.rollback() return {"error": str(e)} async def initialize_or_load_vector_store(hospital_id, user_id="default"): """Initialize or load vector store with Redis caching and thread safety""" store_key = f"{hospital_id}:{user_id}" try: # Check if we already have it loaded - with lock for thread safety with vector_store_lock: if store_key in hospital_vector_stores: return hospital_vector_stores[store_key] # Initialize vector store redis_client = get_redis_client(binary=True) cache_key = f"vector_store_data:{hospital_id}:{user_id}" hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}") if os.path.exists(hospital_dir): logging.info( f"Loading vector store for hospital {hospital_id} and user {user_id}" ) vector_store = await asyncio.to_thread( lambda: Chroma( collection_name=f"hospital_{hospital_id}", persist_directory=hospital_dir, embedding_function=embeddings, ) ) else: logging.info(f"Creating vector store for hospital {hospital_id}") os.makedirs(hospital_dir, exist_ok=True) vector_store = await asyncio.to_thread( lambda: Chroma( collection_name=f"hospital_{hospital_id}", persist_directory=hospital_dir, embedding_function=embeddings, ) ) # Store with lock for thread safety with vector_store_lock: hospital_vector_stores[store_key] = vector_store return vector_store except Exception as e: logging.error(f"Error initializing vector store: {e}", exc_info=True) raise async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool: """Delete all vectors associated with a specific document from ChromaDB""" try: # Initialize vector store for the hospital vector_store = await initialize_or_load_vector_store(hospital_id) # Delete vectors with matching doc_id await asyncio.to_thread( lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)}) ) # Persist changes await asyncio.to_thread(vector_store.persist) # Clear Redis cache for this document redis_client = get_redis_client() pattern = f"vector_store_data:{hospital_id}:*" for key in redis_client.scan_iter(pattern): redis_client.delete(key) logging.info( f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}" ) return True except Exception as e: logging.error(f"Error deleting document vectors: {e}", exc_info=True) return False async def add_document_to_index(doc_id, hospital_id): try: pool = await get_db_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: vector_store = await initialize_or_load_vector_store(hospital_id) await cursor.execute( "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number", (doc_id,), ) rows = await cursor.fetchall() total_pages = len(rows) logging.info(f"Processing {total_pages} pages for document {doc_id}") page_bar = tqdm_async(total=total_pages, desc="Processing pages") async def process_page(page_data): page_num, content = page_data try: icd_data = extract_and_process_icd_data( content, hospital_id, save_to_json=False ) chunks = text_splitter.split_text(content) await asyncio.sleep(0) # Yield control return page_num, chunks, icd_data except Exception as e: logging.error(f"Error processing page {page_num}: {e}") return page_num, [], [] tasks = [asyncio.create_task(process_page(row)) for row in rows] results = [] for coro in asyncio.as_completed(tasks): result = await coro results.append(result) page_bar.update(1) page_bar.close() # Vector addition progress bar all_icd_data = [] all_chunks = [] all_metadatas = [] chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0) for result in results: page_num, chunks, icd_data = result all_icd_data.extend(icd_data) for i, chunk in enumerate(chunks): all_chunks.append(chunk) all_metadatas.append( { "doc_id": str(doc_id), "hospital_id": str(hospital_id), "page_number": str(page_num), "chunk_index": str(i), } ) if len(all_chunks) >= BATCH_SIZE: chunk_add_bar.total += len(all_chunks) chunk_add_bar.refresh() await asyncio.to_thread( vector_store.add_texts, texts=all_chunks, metadatas=all_metadatas, ) all_chunks = [] all_metadatas = [] chunk_add_bar.update(BATCH_SIZE) # Final batch if all_chunks: chunk_add_bar.total += len(all_chunks) chunk_add_bar.refresh() await asyncio.to_thread( vector_store.add_texts, texts=all_chunks, metadatas=all_metadatas, ) chunk_add_bar.update(len(all_chunks)) chunk_add_bar.close() if all_icd_data: logging.info(f"Saving {len(all_icd_data)} ICD codes") extract_and_process_icd_data("", hospital_id, save_to_json=True) await asyncio.to_thread(vector_store.persist) logging.info(f"Successfully indexed document {doc_id}") return True except Exception as e: logging.error(f"Error adding document: {e}") return False def is_general_knowledge_question( query: str, context: str, conversation_context=None ) -> bool: """ Determine if a question is likely a general knowledge question not covered in the documents. Takes conversation history into account to reduce repeated confirmations. """ query_lower = query.lower() context_lower = context.lower() if conversation_context: for interaction in conversation_context: prev_question = interaction.get("question", "").lower() if ( prev_question and query_lower in prev_question or prev_question in query_lower ): logging.info( f"Question is similar to previous conversation, skipping confirmation" ) return False stop_words = { "search", "query:", "can", "you", "some", "at", "the", "a", "an", "in", "on", "at", "to", "for", "with", "by", "about", "give", "full", "is", "are", "was", "were", "define", "what", "how", "why", "when", "where", "year", "list", "form", "table", "who", "which", "me", "tell", "explain", "describe", "of", "and", "or", "there", "their", "please", "could", "would", "various", "different", "type", "types", "kind", "kinds", "has", "have", "had", "many", "say", } key_words = [ word for word in query_lower.split() if word not in stop_words and len(word) > 2 ] logging.info(f"Key words: {key_words}") if not key_words: logging.info("No significant keywords found, directing to general knowledge") return True matches = sum(1 for word in key_words if word in context_lower) logging.info(f"Matches: {matches} out of {len(key_words)} keywords") match_ratio = matches / len(key_words) logging.info(f"Match ratio: {match_ratio}") return match_ratio < 0.6 def is_table_request(query: str) -> bool: """ Determine if the user is requesting a response in tabular format. """ table_keywords = [ "table", "tabular", "in a table", "in table format", "in tabular format", "chart", "data", "comparison", "as a table", "table format", "in rows and columns", "in a grid", "breakdown", "spreadsheet", "comparison table", "data table", "structured table", "tabular form", "table form", ] query_lower = query.lower() return any(keyword in query_lower for keyword in table_keywords) import re def ensure_html_response(text: str) -> str: """ Ensure the response is properly formatted in HTML. This function handles plain text conversion to HTML. """ if "", text)) if not has_html_tags: paragraphs = text.split("\n\n") html_parts = [] in_ordered_list = False in_unordered_list = False for para in paragraphs: if para.strip(): if re.match(r"^\s*[\*\-\•]\s", para): if not in_unordered_list: html_parts.append("") in_unordered_list = False html_parts.append(f"

{para}

") if in_ordered_list: html_parts.append("") if in_unordered_list: html_parts.append("") return "".join(html_parts) else: if not any(tag in text for tag in ("

", "

", "