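"""Model manager for semantic-similarity utilities.

Loads `all-MiniLM-L6-v2` and `bert-base-uncased` once per process, caches
embeddings, and exposes a blended sentence/BERT similarity score.
"""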
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import torch
from sklearn.metrics.pairwise import cosine_similarity
import logging
import atexit


class ModelManager:
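    """Singleton wrapper around a SentenceTransformer and a BERT encoder.

    The first instantiation loads both models onto the available device and
    registers `cleanup` via `atexit`; later calls to `ModelManager()` return
    the same instance. An in-memory `embedding_cache` (bounded by
    `max_cache_size`) stores per-text embeddings and pairwise similarity
    scores.
    """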
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if not ModelManager._initialized:
            logging.info("Initializing ModelManager - Loading models...")
            self.load_models()
            ModelManager._initialized = True
            atexit.register(self.cleanup)

    def load_models(self):
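        """Load the sentence-transformer and BERT models onto the selected device.

        Both models are put into evaluation mode and the embedding cache is
        initialised; any loading error is logged and re-raised.
        """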
        try:
            # Load models with specific device placement
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            logging.info(f"Using device: {self.device}")

            # Enable model caching
            torch.hub.set_dir('./model_cache')

            # Load models with batch preparation
            self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
            self.sentence_model.to(self.device)
            self.sentence_model.eval()  # Set to evaluation mode

            self.bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.bert_model = AutoModel.from_pretrained('bert-base-uncased')
            self.bert_model.to(self.device)
            self.bert_model.eval()  # Set to evaluation mode

            # Initialize embedding cache with batch support
            self.embedding_cache = {}
            self.max_cache_size = 10000
            self.batch_size = 32  # Optimized batch size

            logging.info("Models loaded successfully with batch optimization")

        except Exception as e:
            logging.error(f"Error loading models: {e}")
            raise

    def get_bert_embeddings(self, texts):
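        """Return mean-pooled BERT embeddings for a string or list of strings.

        Texts are processed in batches of `self.batch_size`; per-text results
        are cached by `hash(text)` and reused on later calls. The result is a
        tensor of shape (num_texts, hidden_size).
        """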
        if isinstance(texts, str):
            texts = [texts]

        # Process in batches
        all_embeddings = []
        for i in range(0, len(texts), self.batch_size):
            batch_texts = texts[i:i + self.batch_size]

            # Check the cache for each text in the batch
            batch_embeddings = []
            uncached_texts = []
            uncached_indices = []

            for idx, text in enumerate(batch_texts):
                cache_key = f"bert_{hash(text)}"
                if cache_key in self.embedding_cache:
                    batch_embeddings.append(self.embedding_cache[cache_key])
                else:
                    uncached_texts.append(text)
                    uncached_indices.append(idx)

            if uncached_texts:
                inputs = self.bert_tokenizer(uncached_texts, return_tensors="pt", padding=True, truncation=True).to(self.device)
                with torch.no_grad():
                    outputs = self.bert_model(**inputs)
                    new_embeddings = outputs.last_hidden_state.mean(dim=1)

                # Cache new embeddings and splice them back into their
                # original positions within the batch
                for idx, text in enumerate(uncached_texts):
                    cache_key = f"bert_{hash(text)}"
                    if len(self.embedding_cache) < self.max_cache_size:
                        self.embedding_cache[cache_key] = new_embeddings[idx]
                    batch_embeddings.insert(uncached_indices[idx], new_embeddings[idx])

            all_embeddings.extend(batch_embeddings)

        return torch.stack(all_embeddings) if len(all_embeddings) > 1 else all_embeddings[0].unsqueeze(0)

    def get_semantic_similarity(self, text1, text2):
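        """Return a blended semantic similarity score for two texts.

        The score combines sentence-transformer cosine similarity (weight 0.8)
        with mean-pooled BERT cosine similarity (weight 0.2), boosted by 20%
        (capped at 1.0) when the lower-cased texts share any token. Results
        are cached keyed on the original, un-normalised inputs.
        """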
        # Check the cache first
        cache_key = f"sim_{hash(text1)}_{hash(text2)}"
        if cache_key in self.embedding_cache:
            return self.embedding_cache[cache_key]

        # Preprocess texts for better matching
        text1 = text1.lower().strip()
        text2 = text2.lower().strip()

        # Compute both embeddings without tracking gradients
        with torch.no_grad():
            # Sentence-transformer similarity (carries most of the weight)
            emb1 = self.sentence_model.encode([text1], batch_size=1, convert_to_numpy=True)
            emb2 = self.sentence_model.encode([text2], batch_size=1, convert_to_numpy=True)
            sent_sim = cosine_similarity(emb1, emb2)[0][0]

            # BERT similarity for deeper semantic understanding
            bert_emb1 = self.get_bert_embeddings(text1).cpu().numpy()
            bert_emb2 = self.get_bert_embeddings(text2).cpu().numpy()
            bert_sim = cosine_similarity(bert_emb1, bert_emb2)[0][0]

        # Weights adjusted for better follow-up detection
        similarity = 0.8 * sent_sim + 0.2 * bert_sim

        # Boost the score when the two texts share any token
        if any(word in text2.split() for word in text1.split()):
            similarity = min(1.0, similarity * 1.2)

        # Cache the result
        if len(self.embedding_cache) < self.max_cache_size:
            self.embedding_cache[cache_key] = similarity

        return similarity

    def cleanup(self):
        """Cleanup models and free memory."""
        logging.info("Cleaning up models...")
        try:
            del self.sentence_model
            del self.bert_model
            del self.bert_tokenizer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            self.embedding_cache.clear()
            logging.info("Models cleaned up successfully")
        except Exception as e:
            logging.error(f"Error during cleanup: {e}")
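

# Example usage (illustrative sketch, not part of the original module): the
# input strings below are made up; this assumes the two models can be
# downloaded locally and simply exercises the singleton behaviour and the
# similarity API defined above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager_a = ModelManager()
    manager_b = ModelManager()
    assert manager_a is manager_b  # __new__ returns the same singleton instance

    score = manager_a.get_semantic_similarity(
        "How do I reset my password?",
        "I forgot my password, how can I change it?",
    )
    print(f"similarity: {score:.3f}")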