from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import torch
from sklearn.metrics.pairwise import cosine_similarity
import logging
import atexit


class ModelManager:
    """Singleton that loads the embedding models once and serves similarity queries."""

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if not ModelManager._initialized:
            logging.info("Initializing ModelManager - Loading models...")
            self.load_models()
            ModelManager._initialized = True
            atexit.register(self.cleanup)

    def load_models(self):
        try:
            # Place models on GPU when available, otherwise CPU
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            logging.info(f"Using device: {self.device}")

            # Cache torch.hub downloads locally (Hugging Face models use their own cache directory)
            torch.hub.set_dir('./model_cache')

            # Load the sentence-transformer and BERT models in evaluation mode
            self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
            self.sentence_model.to(self.device)
            self.sentence_model.eval()

            self.bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.bert_model = AutoModel.from_pretrained('bert-base-uncased')
            self.bert_model.to(self.device)
            self.bert_model.eval()

            # Initialize the embedding cache and batching parameters
            self.embedding_cache = {}
            self.max_cache_size = 10000
            self.batch_size = 32

            logging.info("Models loaded successfully with batch optimization")
        except Exception as e:
            logging.error(f"Error loading models: {e}")
            raise

    def get_bert_embeddings(self, texts):
        """Return mean-pooled BERT embeddings for a text or list of texts, batched and cached."""
        if isinstance(texts, str):
            texts = [texts]

        all_embeddings = []
        for i in range(0, len(texts), self.batch_size):
            batch_texts = texts[i:i + self.batch_size]

            # Split the batch into cached and uncached texts
            batch_embeddings = []
            uncached_texts = []
            uncached_indices = []
            for idx, text in enumerate(batch_texts):
                cache_key = f"bert_{hash(text)}"
                if cache_key in self.embedding_cache:
                    batch_embeddings.append(self.embedding_cache[cache_key])
                else:
                    uncached_texts.append(text)
                    uncached_indices.append(idx)

            if uncached_texts:
                inputs = self.bert_tokenizer(
                    uncached_texts, return_tensors="pt", padding=True, truncation=True
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.bert_model(**inputs)
                    new_embeddings = outputs.last_hidden_state.mean(dim=1)

                # Cache the new embeddings and restore them to their original batch positions
                for idx, text in enumerate(uncached_texts):
                    cache_key = f"bert_{hash(text)}"
                    if len(self.embedding_cache) < self.max_cache_size:
                        self.embedding_cache[cache_key] = new_embeddings[idx]
                    batch_embeddings.insert(uncached_indices[idx], new_embeddings[idx])

            all_embeddings.extend(batch_embeddings)

        return torch.stack(all_embeddings) if len(all_embeddings) > 1 else all_embeddings[0].unsqueeze(0)

    def get_semantic_similarity(self, text1, text2):
        """Blend sentence-transformer and BERT cosine similarities for two texts."""
        # Return a cached score if these texts have already been compared
        cache_key = f"sim_{hash(text1)}_{hash(text2)}"
        if cache_key in self.embedding_cache:
            return self.embedding_cache[cache_key]

        # Normalize casing and whitespace for more stable matching
        text1 = text1.lower().strip()
        text2 = text2.lower().strip()

        with torch.no_grad():
            # Sentence-transformer similarity
            emb1 = self.sentence_model.encode([text1], batch_size=1, convert_to_numpy=True)
            emb2 = self.sentence_model.encode([text2], batch_size=1, convert_to_numpy=True)
            sent_sim = cosine_similarity(emb1, emb2)[0][0]

            # BERT similarity for deeper semantic comparison
            bert_emb1 = self.get_bert_embeddings(text1).cpu().numpy()
            bert_emb2 = self.get_bert_embeddings(text2).cpu().numpy()
            bert_sim = cosine_similarity(bert_emb1, bert_emb2)[0][0]

        # Weighted combination, tuned toward the sentence-transformer score for follow-up detection
        similarity = 0.8 * sent_sim + 0.2 * bert_sim

        # Boost the score when the texts share any surface-level word
        if any(word in text2.split() for word in text1.split()):
            similarity = min(1.0, similarity * 1.2)

        # Cache the result
        if len(self.embedding_cache) < self.max_cache_size:
            self.embedding_cache[cache_key] = similarity

        return similarity

    def cleanup(self):
        """Release the models and free GPU memory."""
        logging.info("Cleaning up models...")
        try:
            del self.sentence_model
            del self.bert_model
            del self.bert_tokenizer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            self.embedding_cache.clear()
            logging.info("Models cleaned up successfully")
        except Exception as e:
            logging.error(f"Error during cleanup: {e}")
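

# --- Usage sketch (illustrative; the example texts and logging setup below are assumptions,
# not part of the original module) ---
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager_a = ModelManager()
    manager_b = ModelManager()
    assert manager_a is manager_b  # __new__ always returns the same singleton instance

    score = manager_a.get_semantic_similarity(
        "How do I reset my password?",
        "What are the steps to change my password?",
    )
    print(f"Semantic similarity: {score:.3f}")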