diff --git a/.gitignore b/.gitignore
index 9455c0b..76e9b59 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@
 /logs
 /error.log
 /uploads
-/llm-uploads
\ No newline at end of file
+/llm-uploads
+/certificates
\ No newline at end of file
diff --git a/chat copy.py b/chat copy.py
deleted file mode 100644
index 9af82fc..0000000
--- a/chat copy.py
+++ /dev/null
@@ -1,2161 +0,0 @@
-"""
-SpurrinAI - Intelligent Document Processing and Question Answering System
-Copyright (c) 2024 Tech4biz. All rights reserved.
-
-This module implements the main Flask application for the SpurrinAI system,
-providing REST APIs for document processing, vector storage, and question answering
-using RAG (Retrieval Augmented Generation) architecture.
-
-Author: Tech4biz Development Team
-Version: 1.0.0
-Last Updated: 2024-01-19
-"""
-
-# Standard library imports
-import os
-import re
-import sys
-import json
-import time
-import threading
-import asyncio
-from concurrent.futures import ThreadPoolExecutor
-from dataclasses import dataclass
-from datetime import timedelta
-from enum import Enum
-
-# Third-party imports
-import spacy
-import redis
-import aiomysql
-from dotenv import load_dotenv
-from flask import Flask, request, jsonify, Response
-from flask_cors import CORS
-from tqdm import tqdm
-from tqdm.asyncio import tqdm as tqdm_async
-from langchain_community.document_loaders import PyPDFLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.embeddings import OpenAIEmbeddings
-from langchain_community.vectorstores import Chroma
-from langchain_community.chat_models import ChatOpenAI
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from openai import OpenAI
-from rapidfuzz import process
-from threading import Lock
-
-# Local imports
-from model_manager import ModelManager
-
-# Suppress warnings
-import warnings
-
-warnings.filterwarnings("ignore")
-
-# Initialize NLTK
-import nltk
-
-nltk.download("punkt")
-
-# Configure logging
-import logging
-import logging.handlers
-
-app = Flask(__name__)
-CORS(app)
-
-script_dir = os.path.dirname(os.path.abspath(__file__))
-log_file_path = os.path.join(script_dir, "error.log")
-logging.basicConfig(filename=log_file_path, level=logging.INFO)
-
-
-# Configure logging
-def setup_logging():
-    log_dir = os.path.join(script_dir, "logs")
-    os.makedirs(log_dir, exist_ok=True)
-
-    main_log = os.path.join(log_dir, "app.log")
-    error_log = os.path.join(log_dir, "error.log")
-    access_log = os.path.join(log_dir, "access.log")
-    perf_log = os.path.join(log_dir, "performance.log")
-
-    # Create formatters
-    detailed_formatter = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
-    )
-    access_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-
-    # Main logger setup
-    main_handler = logging.handlers.RotatingFileHandler(
-        main_log, maxBytes=10485760, backupCount=5
-    )
-    main_handler.setFormatter(detailed_formatter)
-    main_handler.setLevel(logging.INFO)
-
-    # Error logger setup
-    error_handler = logging.handlers.RotatingFileHandler(
-        error_log, maxBytes=10485760, backupCount=5
-    )
-    error_handler.setFormatter(detailed_formatter)
-    error_handler.setLevel(logging.ERROR)
-
-    # Access logger setup
-    access_handler = logging.handlers.TimedRotatingFileHandler(
-        access_log, when="midnight", interval=1, backupCount=30
-    )
-    access_handler.setFormatter(access_formatter)
-    access_handler.setLevel(logging.INFO)
-
-    # Performance logger setup
-    perf_handler = logging.handlers.RotatingFileHandler(
-        perf_log, maxBytes=10485760, backupCount=5
-    )
-    perf_handler.setFormatter(detailed_formatter)
-    perf_handler.setLevel(logging.INFO)
-
-    # Configure root logger
-    root_logger = logging.getLogger()
-    root_logger.setLevel(logging.INFO)
-    root_logger.addHandler(main_handler)
-    root_logger.addHandler(error_handler)
-
-    # Create specific loggers
-    access_logger = logging.getLogger("access")
-    access_logger.addHandler(access_handler)
-    access_logger.setLevel(logging.INFO)
-
-    perf_logger = logging.getLogger("performance")
-    perf_logger.addHandler(perf_handler)
-    perf_logger.setLevel(logging.INFO)
-
-    return root_logger, access_logger, perf_logger
-
-
-# Initialize loggers
-logger, access_logger, perf_logger = setup_logging()
-
-# DB_CONFIG = {
-#     'host': 'localhost',
-#     'user': 'flaskuser',
-#     'password': 'Flask@123',
-#     'database': 'spurrinai',
-# }
-
-# DB_CONFIG = {
-#     'host': 'localhost',
-#     'user': 'spurrindevuser',
-#     'password': 'Admin@123',
-#     'database': 'spurrindev',
-# }
-
-# Redis Configuration
-REDIS_CONFIG = {
-    "host": "localhost",
-    "port": 6379,
-    "db": 0,
-    "decode_responses": True,  # For string operations
-}
-
-DB_CONFIG = {
-    "host": os.getenv("DB_HOST", "localhost"),
-    "user": os.getenv("DB_USER", "testuser"),
-    "password": os.getenv("DB_PASSWORD", "Admin@123"),
-    "database": os.getenv("DB_NAME", "spurrintest"),
-}
-
-# Redis connection pool
-redis_pool = redis.ConnectionPool(**REDIS_CONFIG)
-redis_binary_pool = redis.ConnectionPool(
-    host="localhost", port=6379, db=1, decode_responses=False
-)
-
-
-def get_redis_client(binary=False):
-    """Get Redis client from pool"""
-    logger.debug(f"Getting Redis client with binary={binary}")
-    try:
-        pool = redis_binary_pool if binary else redis_pool
-        client = redis.Redis(connection_pool=pool)
-        logger.debug("Redis client created successfully")
-        return client
-    except Exception as e:
-        logger.error(f"Failed to create Redis client: {e}", exc_info=True)
-        raise
-
-
-def fetch_cached_answer(cache_key):
-    logger.debug(f"Attempting to fetch cached answer for key: {cache_key}")
-    start_time = time.time()
-    try:
-        redis_client = get_redis_client()
-        cached_answer = redis_client.get(cache_key)
-        fetch_time = time.time() - start_time
-        perf_logger.info(
-            f"Redis fetch completed in {fetch_time:.3f} seconds for key: {cache_key}"
-        )
-        return cached_answer
-    except Exception as e:
-        logger.error(f"Redis fetch error for key {cache_key}: {e}", exc_info=True)
-        return None
-
-
-# Cache TTL configurations
-CACHE_TTL = {
-    "vector_store": timedelta(hours=24),
-    "chat_completion": timedelta(hours=1),
-    "document_metadata": timedelta(days=7),
-}
-
-DATA_DIR = os.path.join(script_dir, "hospital_data")
-CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")
-uploads_dir = os.path.join(script_dir, "llm-uploads")
-
-if not os.path.exists(uploads_dir):
-    os.makedirs(uploads_dir)
-
-nlp = spacy.load("en_core_web_sm")
-
-load_dotenv()
-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-
-client = OpenAI(api_key=OPENAI_API_KEY)
-embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
-llm = ChatOpenAI(
-    model_name="gpt-3.5-turbo", streaming=True, temperature=0.2, api_key=OPENAI_API_KEY
-)
-# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
-hospital_vector_stores = {}
-vector_store_lock = threading.Lock()
-
-
-@dataclass
-class Document:
-    doc_id: int
-    page_num: int
-    content: str
-
-
-class DocumentStatus(Enum):
-    PROCESSING = "processing"
-    PROCESSED = "processed"
-    FAILED = "failed"
-
-
-async def get_db_pool():
-    return await aiomysql.create_pool(
-        host=DB_CONFIG["host"],
-        user=DB_CONFIG["user"],
-        password=DB_CONFIG["password"],
-        db=DB_CONFIG["database"],
-        autocommit=True,
-    )
-
-
-async def get_hospital_id(hospital_code):
-    try:
-        pool = await get_db_pool()
-        async with pool.acquire() as conn:
-            async with conn.cursor(aiomysql.DictCursor) as cursor:
-                await cursor.execute(
-                    "SELECT id FROM hospitals WHERE hospital_code = %s LIMIT 1",
-                    (hospital_code,),
-                )
-                result = await cursor.fetchone()
-                return result["id"] if result else None
-    except Exception as error:
-        logging.error(f"Database error: {error}")
-        return None
-    finally:
-        pool.close()
-        await pool.wait_closed()
-
-
-CHUNK_SIZE = 1000
-CHUNK_OVERLAP = 50
-BATCH_SIZE = 1000
-
-text_splitter = RecursiveCharacterTextSplitter(
-    chunk_size=CHUNK_SIZE,
-    chunk_overlap=CHUNK_OVERLAP,
-    # length_function=len,
-    # separators=["\n\n", "\n", ". ", " ", ""]
-)
-
-
-# Update the JSON_PATH to be dynamic based on hospital_id
-def get_icd_json_path(hospital_id):
-    hospital_data_dir = os.path.join(DATA_DIR, f"hospital_{hospital_id}")
-    os.makedirs(hospital_data_dir, exist_ok=True)
-    return os.path.join(hospital_data_dir, "icd_data.json")
-
-
-def extract_and_process_icd_data(content, hospital_id, save_to_json=True):
-    """Extract and process ICD codes with optimized processing and optional JSON saving"""
-    try:
-        # Initialize pattern compilation once
-        pattern = re.compile(r"^\s*([A-Z][0-9A-Z]{2,6}[A-Z]?)\s+(.*)$", re.MULTILINE)
-
-        # Process in chunks for large content
-        chunk_size = 50000  # Process 50KB at a time
-        icd_data = []
-
-        current_code = None
-        current_description = []
-
-        # Split content into manageable chunks
-        content_chunks = [
-            content[i : i + chunk_size] for i in range(0, len(content), chunk_size)
-        ]
-
-        # Process each chunk
-        for chunk in content_chunks:
-            lines = chunk.splitlines()
-
-            for line in lines:
-                line = line.strip()
-                if not line:
-                    if current_code and current_description:
-                        icd_data.append(
-                            {
-                                "code": current_code,
-                                "description": " ".join(current_description).strip(),
-                            }
-                        )
-                        current_code = None
-                        current_description = []
-                    continue
-
-                match = pattern.match(line)
-                if match:
-                    if current_code and current_description:
-                        icd_data.append(
-                            {
-                                "code": current_code,
-                                "description": " ".join(current_description).strip(),
-                            }
-                        )
-                    current_code, description = match.groups()
-                    current_description = [description.strip()]
-                elif current_code:
-                    current_description.append(line)
-
-        # Add final entry if exists
-        if current_code and current_description:
-            icd_data.append(
-                {
-                    "code": current_code,
-                    "description": " ".join(current_description).strip(),
-                }
-            )
-
-        # Save to hospital-specific JSON if requested
-        if save_to_json and icd_data:
-            try:
-                json_path = get_icd_json_path(hospital_id)
-
-                # Use a lock for thread safety
-                with threading.Lock():
-                    if os.path.exists(json_path):
-                        with open(json_path, "r", encoding="utf-8") as f:
-                            try:
-                                existing_data = json.load(f)
-                            except json.JSONDecodeError:
-                                existing_data = []
-                    else:
-                        existing_data = []
-
-                    # Efficient deduplication using dictionary
-                    seen_codes = {item["code"]: item for item in existing_data}
-                    for item in icd_data:
-                        seen_codes[item["code"]] = item
-
-                    unique_data = list(seen_codes.values())
-
-                    # Write atomically using temporary file
-                    temp_path = f"{json_path}.tmp"
-                    with open(temp_path, "w", encoding="utf-8") as f:
-                        json.dump(unique_data, f, indent=2, ensure_ascii=False)
-                    os.replace(temp_path, json_path)
-
-                logging.info(
-                    f"Successfully saved {len(unique_data)} unique ICD codes to JSON for hospital {hospital_id}"
-                )
-
-            except Exception as e:
-                logging.error(
-                    f"Error saving ICD data to JSON for hospital {hospital_id}: {e}"
-                )
-
-        return icd_data
-
-    except Exception as e:
-        logging.error(f"Error in extract_and_process_icd_data: {e}")
-        return []
-
-
-def load_icd_entries(hospital_id):
-    """Load ICD entries from hospital-specific JSON file"""
-    json_path = get_icd_json_path(hospital_id)
-    try:
-        if os.path.exists(json_path):
-            with open(json_path, "r", encoding="utf-8") as f:
-                return json.load(f)
-        return []
-    except Exception as e:
-        logging.error(f"Error loading ICD entries for hospital {hospital_id}: {e}")
-        return []
-
-
-# Update the process_icd_codes function to include hospital_id
-async def process_icd_codes(content, doc_id, hospital_id, batch_size=256):
-    """Process and store ICD codes using the optimized extraction function"""
-    try:
-        # Extract and save codes with hospital_id
-        extract_and_process_icd_data(content, hospital_id, save_to_json=True)
-    except Exception as e:
-        logging.error(f"Error processing ICD codes for hospital {hospital_id}: {e}")
-
-
-async def initialize_icd_vector_store(hospital_id):
-    """This function is deprecated. ICD codes are now handled through JSON search."""
-    logging.warning(
-        "initialize_icd_vector_store is deprecated - using JSON search instead"
-    )
-    return None
-
-
-def extract_pdf_contents(pdf_path, hospital_id):
-    """Extract PDF contents with optimized chunking and code extraction"""
-    try:
-        loader = PyPDFLoader(pdf_path)
-        pages = loader.load()
-        pages_content = []
-
-        for i, page in enumerate(tqdm(pages, desc="Extracting pages")):
-            text = page.page_content.strip()
-
-            # Extract ICD codes from the page
-            icd_codes = extract_and_process_icd_data(
-                text, hospital_id
-            )  # We'll set doc_id later
-
-            pages_content.append({"page": i + 1, "text": text, "codes": icd_codes})
-
-        return pages_content
-
-    except Exception as e:
-        logging.error(f"Error in extract_pdf_contents: {e}")
-        raise
-
-
-async def insert_content_into_db(content, metadata, doc_id):
-    pool = await get_db_pool()
-    async with pool.acquire() as conn:
-        async with conn.cursor() as cursor:
-            try:
-                metadata_query = "INSERT INTO document_metadata (document_id, key_name, value_name) VALUES (%s, %s, %s)"
-                content_query = "INSERT INTO document_pages (document_id, page_number, content) VALUES (%s, %s, %s)"
-
-                metadata_values = [
-                    (doc_id, key[:100], value)
-                    for key, value in metadata.items()
-                    if value
-                ]
-                content_values = [
-                    (doc_id, page_content["page"], page_content["text"])
-                    for page_content in content
-                ]
-
-                if metadata_values:
-                    await cursor.executemany(metadata_query, metadata_values)
-                if content_values:
-                    await cursor.executemany(content_query, content_values)
-
-                await conn.commit()
-                return {"message": "Success"}
-            except Exception as e:
-                await conn.rollback()
-                return {"error": str(e)}
-
-
-async def initialize_or_load_vector_store(hospital_id, user_id="default"):
-    """Initialize or load vector store with Redis caching and thread safety"""
-    store_key = f"{hospital_id}:{user_id}"
-
-    try:
-        # Check if we already have it loaded - with lock for thread safety
-        with vector_store_lock:
-            if store_key in hospital_vector_stores:
-                return hospital_vector_stores[store_key]
-
-        # Initialize vector store
-        redis_client = get_redis_client(binary=True)
-        cache_key = f"vector_store_data:{hospital_id}:{user_id}"
-        hospital_dir = os.path.join(CHROMA_DIR, f"hospital_{hospital_id}")
-
-        if os.path.exists(hospital_dir):
-            logging.info(
-                f"Loading vector store for hospital {hospital_id} and user {user_id}"
-            )
-            vector_store = await asyncio.to_thread(
-                lambda: Chroma(
-                    collection_name=f"hospital_{hospital_id}",
-                    persist_directory=hospital_dir,
-                    embedding_function=embeddings,
-                )
-            )
-        else:
-            logging.info(f"Creating vector store for hospital {hospital_id}")
-            os.makedirs(hospital_dir, exist_ok=True)
-            vector_store = await asyncio.to_thread(
-                lambda: Chroma(
-                    collection_name=f"hospital_{hospital_id}",
-                    persist_directory=hospital_dir,
-                    embedding_function=embeddings,
-                )
-            )
-
-        # Store with lock for thread safety
-        with vector_store_lock:
-            hospital_vector_stores[store_key] = vector_store
-
-        return vector_store
-    except Exception as e:
-        logging.error(f"Error initializing vector store: {e}", exc_info=True)
-        raise
-
-
-async def delete_document_vectors(hospital_id: int, doc_id: str) -> bool:
-    """Delete all vectors associated with a specific document from ChromaDB"""
-    try:
-        # Initialize vector store for the hospital
-        vector_store = await initialize_or_load_vector_store(hospital_id)
-
-        # Delete vectors with matching doc_id
-        await asyncio.to_thread(
-            lambda: vector_store._collection.delete(where={"doc_id": str(doc_id)})
-        )
-
-        # Persist changes
-        await asyncio.to_thread(vector_store.persist)
-
-        # Clear Redis cache for this document
-        redis_client = get_redis_client()
-        pattern = f"vector_store_data:{hospital_id}:*"
-        for key in redis_client.scan_iter(pattern):
-            redis_client.delete(key)
-
-        logging.info(
-            f"Successfully deleted vectors for document {doc_id} from hospital {hospital_id}"
-        )
-        return True
-
-    except Exception as e:
-        logging.error(f"Error deleting document vectors: {e}", exc_info=True)
-        return False
-
-
-async def add_document_to_index(doc_id, hospital_id):
-    try:
-        pool = await get_db_pool()
-        async with pool.acquire() as conn:
-            async with conn.cursor() as cursor:
-                vector_store = await initialize_or_load_vector_store(hospital_id)
-
-                await cursor.execute(
-                    "SELECT page_number, content FROM document_pages WHERE document_id = %s ORDER BY page_number",
-                    (doc_id,),
-                )
-                rows = await cursor.fetchall()
-
-                total_pages = len(rows)
-                logging.info(f"Processing {total_pages} pages for document {doc_id}")
-                page_bar = tqdm_async(total=total_pages, desc="Processing pages")
-
-                async def process_page(page_data):
-                    page_num, content = page_data
-                    try:
-                        icd_data = extract_and_process_icd_data(
-                            content, hospital_id, save_to_json=False
-                        )
-                        chunks = text_splitter.split_text(content)
-                        await asyncio.sleep(0)  # Yield control
-                        return page_num, chunks, icd_data
-                    except Exception as e:
-                        logging.error(f"Error processing page {page_num}: {e}")
-                        return page_num, [], []
-
-                tasks = [asyncio.create_task(process_page(row)) for row in rows]
-                results = []
-
-                for coro in asyncio.as_completed(tasks):
-                    result = await coro
-                    results.append(result)
-                    page_bar.update(1)
-
-                page_bar.close()
-
-                # Vector addition progress bar
-                all_icd_data = []
-                all_chunks = []
-                all_metadatas = []
-
-                chunk_add_bar = tqdm_async(desc="Vectorizing chunks", total=0)
-
-                for result in results:
-                    page_num, chunks, icd_data = result
-                    all_icd_data.extend(icd_data)
-
-                    for i, chunk in enumerate(chunks):
-                        all_chunks.append(chunk)
-                        all_metadatas.append(
-                            {
-                                "doc_id": str(doc_id),
-                                "hospital_id": str(hospital_id),
-                                "page_number": str(page_num),
-                                "chunk_index": str(i),
-                            }
-                        )
-
-                        if len(all_chunks) >= BATCH_SIZE:
-                            chunk_add_bar.total += len(all_chunks)
-                            chunk_add_bar.refresh()
-                            await asyncio.to_thread(
-                                vector_store.add_texts,
-                                texts=all_chunks,
-                                metadatas=all_metadatas,
-                            )
-                            all_chunks = []
-                            all_metadatas = []
-                            chunk_add_bar.update(BATCH_SIZE)
-
-                # Final batch
-                if all_chunks:
-                    chunk_add_bar.total += len(all_chunks)
-                    chunk_add_bar.refresh()
-                    await asyncio.to_thread(
-                        vector_store.add_texts,
-                        texts=all_chunks,
-                        metadatas=all_metadatas,
-                    )
-                    chunk_add_bar.update(len(all_chunks))
-
-                chunk_add_bar.close()
-
-                if all_icd_data:
-                    logging.info(f"Saving {len(all_icd_data)} ICD codes")
-                    extract_and_process_icd_data("", hospital_id, save_to_json=True)
-
-                await asyncio.to_thread(vector_store.persist)
-                logging.info(f"Successfully indexed document {doc_id}")
-                return True
-
-    except Exception as e:
-        logging.error(f"Error adding document: {e}")
-        return False
-
-
-def is_general_knowledge_question(
-    query: str, context: str, conversation_context=None
-) -> bool:
-    """
-    Determine if a question is likely a general knowledge question not covered in the documents.
-    Takes conversation history into account to reduce repeated confirmations.
-    """
-    query_lower = query.lower()
-    context_lower = context.lower()
-
-    if conversation_context:
-        for interaction in conversation_context:
-            prev_question = interaction.get("question", "").lower()
-            if (
-                prev_question
-                and query_lower in prev_question
-                or prev_question in query_lower
-            ):
-                logging.info(
-                    f"Question is similar to previous conversation, skipping confirmation"
-                )
-                return False
-
-    stop_words = {
-        "search",
-        "query:",
-        "can",
-        "you",
-        "some",
-        "at",
-        "the",
-        "a",
-        "an",
-        "in",
-        "on",
-        "at",
-        "to",
-        "for",
-        "with",
-        "by",
-        "about",
-        "give",
-        "full",
-        "is",
-        "are",
-        "was",
-        "were",
-        "define",
-        "what",
-        "how",
-        "why",
-        "when",
-        "where",
-        "year",
-        "list",
-        "form",
-        "table",
-        "who",
-        "which",
-        "me",
-        "tell",
-        "explain",
-        "describe",
-        "of",
-        "and",
-        "or",
-        "there",
-        "their",
-        "please",
-        "could",
-        "would",
-        "various",
-        "different",
-        "type",
-        "types",
-        "kind",
-        "kinds",
-        "has",
-        "have",
-        "had",
-        "many",
-        "say",
-    }
-
-    key_words = [
-        word for word in query_lower.split() if word not in stop_words and len(word) > 2
-    ]
-    logging.info(f"Key words: {key_words}")
-
-    if not key_words:
-        logging.info("No significant keywords found, directing to general knowledge")
-        return True
-
-    matches = sum(1 for word in key_words if word in context_lower)
-    logging.info(f"Matches: {matches} out of {len(key_words)} keywords")
-
-    match_ratio = matches / len(key_words)
-    logging.info(f"Match ratio: {match_ratio}")
-
-    return match_ratio < 0.4
-
-
-def is_table_request(query: str) -> bool:
-    """
-    Determine if the user is requesting a response in tabular format.
-    """
-    table_keywords = [
-        "table",
-        "tabular",
-        "in a table",
-        "in table format",
-        "in tabular format",
-        "chart",
-        "data",
-        "comparison",
-        "as a table",
-        "table format",
-        "in rows and columns",
-        "in a grid",
-        "breakdown",
-        "spreadsheet",
-        "comparison table",
-        "data table",
-        "structured table",
-        "tabular form",
-        "table form",
-    ]
-
-    query_lower = query.lower()
-    return any(keyword in query_lower for keyword in table_keywords)
-
-
-import re
-
-
-def ensure_html_response(text: str) -> str:
-    """
-    Ensure the response is properly formatted in HTML.
-    This function handles plain text conversion to HTML.
- """ - if "", text)) - - if not has_html_tags: - paragraphs = text.split("\n\n") - html_parts = [] - in_ordered_list = False - in_unordered_list = False - - for para in paragraphs: - if para.strip(): - if re.match(r"^\s*[\*\-\•]\s", para): - if not in_unordered_list: - html_parts.append("") - in_unordered_list = False - - html_parts.append(f"

{para}

") - - if in_ordered_list: - html_parts.append("") - if in_unordered_list: - html_parts.append("") - - return "".join(html_parts) - - else: - if not any(tag in text for tag in ("

", "

", "