"""
|
|
Progress Manager for AI Analysis Service
|
|
Handles real-time progress updates via Server-Sent Events (SSE)
|
|
"""
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime
|
|
from typing import Dict, Any, List, Optional
|
|
from collections import defaultdict
|
|
import redis.asyncio as redis
|
|
import os
|
|
|
|
|
|
class AnalysisProgressManager:
    """Manages progress updates for AI analysis operations"""

    def __init__(self, analysis_id: str):
        self.analysis_id = analysis_id
        self.subscribers: List[asyncio.Queue] = []
        self.redis_client: Optional[redis.Redis] = None
        self.progress_key = f"analysis_progress:{analysis_id}"

    async def connect_redis(self):
        """Connect to Redis for progress persistence"""
        try:
            redis_host = os.getenv('REDIS_HOST', 'redis')
            redis_port = int(os.getenv('REDIS_PORT', 6379))
            redis_password = os.getenv('REDIS_PASSWORD', 'redis_secure_2024')
            self.redis_client = redis.Redis(
                host=redis_host,
                port=redis_port,
                password=redis_password,
                decode_responses=True
            )
            # The client connects lazily, so ping to verify the connection
            # and credentials before reporting success.
            await self.redis_client.ping()
            print(f"✅ Redis connected for progress tracking: {self.analysis_id}")
        except Exception as e:
            print(f"⚠️ Redis connection failed: {e}. Progress will not be persisted.")
            self.redis_client = None

    async def disconnect_redis(self):
        """Disconnect from Redis"""
        if self.redis_client:
            await self.redis_client.close()

    def subscribe(self) -> asyncio.Queue:
        """Subscribe to progress updates"""
        queue = asyncio.Queue()
        self.subscribers.append(queue)
        print(f"📡 New subscriber added. Total subscribers: {len(self.subscribers)}")
        return queue

    def unsubscribe(self, queue: asyncio.Queue):
        """Unsubscribe from progress updates"""
        if queue in self.subscribers:
            self.subscribers.remove(queue)
            print(f"📡 Subscriber removed. Remaining subscribers: {len(self.subscribers)}")

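    # NOTE (illustrative sketch, not part of the original code): an SSE endpoint
    # would typically wrap subscribe()/unsubscribe() around an async generator
    # that forwards queued events as SSE frames, roughly:
    #
    #     queue = manager.subscribe()
    #     try:
    #         while True:
    #             event = await queue.get()
    #             yield f"data: {json.dumps(event)}\n\n"
    #     finally:
    #         manager.unsubscribe(queue)
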
    async def emit_event(self, event_type: str, data: Dict[str, Any]):
        """
        Emit a progress event to all subscribers

        Event types:
        - analysis_started: Analysis has begun
        - file_analysis_started: Started analyzing a file
        - file_analysis_completed: Completed analyzing a file
        - repository_analysis_started: Started repository-level analysis
        - repository_analysis_completed: Completed repository-level analysis
        - report_generation_started: Started generating PDF report
        - report_section_completed: Completed a section of the report
        - analysis_completed: Entire analysis is complete
        - analysis_error: An error occurred
        """
        event = {
            "analysis_id": self.analysis_id,
            "event": event_type,
            "data": data,
            "timestamp": datetime.now().isoformat()
        }

        # Store in Redis so progress can be recovered after a reconnect
        if self.redis_client:
            try:
                await self.redis_client.rpush(
                    self.progress_key,
                    json.dumps(event)
                )
                # Keep the progress history for one hour
                await self.redis_client.expire(self.progress_key, 3600)
            except Exception as e:
                print(f"⚠️ Failed to store progress in Redis: {e}")

        # Broadcast to all subscribers
        dead_queues = []
        for queue in self.subscribers:
            try:
                await queue.put(event)
            except Exception as e:
                print(f"⚠️ Failed to send to subscriber: {e}")
                dead_queues.append(queue)

        # Clean up dead queues
        for queue in dead_queues:
            self.unsubscribe(queue)

        print(f"📤 Event emitted: {event_type} - {data.get('message', '')}")

    async def get_progress_history(self) -> List[Dict[str, Any]]:
        """Retrieve progress history from Redis"""
        if not self.redis_client:
            return []

        try:
            events = await self.redis_client.lrange(self.progress_key, 0, -1)
            return [json.loads(event) for event in events]
        except Exception as e:
            print(f"⚠️ Failed to retrieve progress history: {e}")
            return []

    async def clear_progress(self):
        """Clear progress data from Redis"""
        if self.redis_client:
            try:
                await self.redis_client.delete(self.progress_key)
                print(f"🗑️ Progress data cleared for {self.analysis_id}")
            except Exception as e:
                print(f"⚠️ Failed to clear progress: {e}")


class GlobalProgressTracker:
    """Global singleton to track all active analyses"""

    _instance = None
    _managers: Dict[str, AnalysisProgressManager] = {}

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(GlobalProgressTracker, cls).__new__(cls)
            cls._managers = {}
        return cls._instance

    def create_manager(self, analysis_id: str) -> AnalysisProgressManager:
        """Create a new progress manager"""
        if analysis_id not in self._managers:
            self._managers[analysis_id] = AnalysisProgressManager(analysis_id)
            print(f"🆕 Created progress manager: {analysis_id}")
        return self._managers[analysis_id]

    def get_manager(self, analysis_id: str) -> Optional[AnalysisProgressManager]:
        """Get an existing progress manager"""
        return self._managers.get(analysis_id)

    def remove_manager(self, analysis_id: str):
        """Remove a progress manager"""
        if analysis_id in self._managers:
            del self._managers[analysis_id]
            print(f"🗑️ Removed progress manager: {analysis_id}")

    def list_active_analyses(self) -> List[str]:
        """List all active analysis IDs"""
        return list(self._managers.keys())


# Global tracker instance
progress_tracker = GlobalProgressTracker()
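

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original service wiring).
# It exercises the subscribe -> emit -> drain flow with made-up event payloads;
# a real deployment would stream the queued events to a browser over SSE.
async def _example_usage():
    manager = progress_tracker.create_manager("demo-analysis")
    await manager.connect_redis()  # falls back gracefully if Redis is unavailable

    queue = manager.subscribe()

    await manager.emit_event("analysis_started", {"message": "Analysis has begun"})
    await manager.emit_event("analysis_completed", {"message": "Analysis finished"})

    # A real SSE endpoint would await queue.get() in a loop and yield
    # "data: {json}\n\n" chunks; here we simply drain the queue.
    while not queue.empty():
        event = await queue.get()
        print("received:", event["event"], event["data"])

    manager.unsubscribe(queue)
    await manager.disconnect_redis()
    progress_tracker.remove_manager("demo-analysis")


if __name__ == "__main__":
    asyncio.run(_example_usage())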