# spurrin-cleaned-backend-dev/model_manager.py
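"""Singleton manager for the shared embedding models.

Loads a sentence-transformer (all-MiniLM-L6-v2) and BERT (bert-base-uncased)
once per process, caches computed embeddings, and exposes a weighted
semantic-similarity score between two texts.
"""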
import atexit
import logging

import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel


class ModelManager:
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if not ModelManager._initialized:
            logging.info("Initializing ModelManager - Loading models...")
            self.load_models()
            ModelManager._initialized = True
            atexit.register(self.cleanup)
    def load_models(self):
        try:
            # Load models with explicit device placement
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            logging.info(f"Using device: {self.device}")

            # Point torch.hub's cache at a local directory
            # (Hugging Face downloads use their own cache location)
            torch.hub.set_dir('./model_cache')

            # Sentence-transformer model for fast sentence-level embeddings
            self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
            self.sentence_model.to(self.device)
            self.sentence_model.eval()  # Set to evaluation mode

            # BERT model and tokenizer for token-level contextual embeddings
            self.bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.bert_model = AutoModel.from_pretrained('bert-base-uncased')
            self.bert_model.to(self.device)
            self.bert_model.eval()  # Set to evaluation mode

            # Embedding cache and batching parameters
            self.embedding_cache = {}
            self.max_cache_size = 10000
            self.batch_size = 32

            logging.info("Models loaded successfully with batch optimization")
        except Exception as e:
            logging.error(f"Error loading models: {e}")
            raise
    def get_bert_embeddings(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        # Process texts in batches, reusing cached embeddings where possible
        all_embeddings = []
        for i in range(0, len(texts), self.batch_size):
            batch_texts = texts[i:i + self.batch_size]

            # Split the batch into cached and uncached texts
            batch_embeddings = []
            uncached_texts = []
            uncached_indices = []
            for idx, text in enumerate(batch_texts):
                cache_key = f"bert_{hash(text)}"
                if cache_key in self.embedding_cache:
                    batch_embeddings.append(self.embedding_cache[cache_key])
                else:
                    uncached_texts.append(text)
                    uncached_indices.append(idx)

            if uncached_texts:
                inputs = self.bert_tokenizer(
                    uncached_texts, return_tensors="pt", padding=True, truncation=True
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.bert_model(**inputs)
                # Mean-pool token embeddings into one vector per text
                new_embeddings = outputs.last_hidden_state.mean(dim=1)

                # Cache new embeddings and splice them back into batch order
                for idx, text in enumerate(uncached_texts):
                    cache_key = f"bert_{hash(text)}"
                    if len(self.embedding_cache) < self.max_cache_size:
                        self.embedding_cache[cache_key] = new_embeddings[idx]
                    batch_embeddings.insert(uncached_indices[idx], new_embeddings[idx])

            all_embeddings.extend(batch_embeddings)

        # torch.stack handles the single-text case too, returning shape (n, hidden)
        return torch.stack(all_embeddings)
    def get_semantic_similarity(self, text1, text2):
        # Return a cached score if this pair has been seen before
        cache_key = f"sim_{hash(text1)}_{hash(text2)}"
        if cache_key in self.embedding_cache:
            return self.embedding_cache[cache_key]

        # Normalize the texts for more consistent matching
        text1 = text1.lower().strip()
        text2 = text2.lower().strip()

        with torch.no_grad():
            # Sentence-transformer similarity (weighted more heavily below)
            emb1 = self.sentence_model.encode([text1], batch_size=1, convert_to_numpy=True)
            emb2 = self.sentence_model.encode([text2], batch_size=1, convert_to_numpy=True)
            sent_sim = cosine_similarity(emb1, emb2)[0][0]

            # BERT similarity for deeper semantic comparison
            bert_emb1 = self.get_bert_embeddings(text1).cpu().numpy()
            bert_emb2 = self.get_bert_embeddings(text2).cpu().numpy()
            bert_sim = cosine_similarity(bert_emb1, bert_emb2)[0][0]

        # Weighted combination favoring the sentence-transformer score
        similarity = 0.8 * sent_sim + 0.2 * bert_sim

        # Boost the score when the texts share any word (helps follow-up detection)
        if any(word in text2.split() for word in text1.split()):
            similarity = min(1.0, similarity * 1.2)

        # Cache the result
        if len(self.embedding_cache) < self.max_cache_size:
            self.embedding_cache[cache_key] = similarity
        return similarity
    def cleanup(self):
        """Cleanup models and free memory"""
        logging.info("Cleaning up models...")
        try:
            del self.sentence_model
            del self.bert_model
            del self.bert_tokenizer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            self.embedding_cache.clear()
            logging.info("Models cleaned up successfully")
        except Exception as e:
            logging.error(f"Error during cleanup: {e}")