saas-market-analysis-dubai/apps/core/middleware.py
2025-09-17 03:04:22 +05:30

205 lines
6.7 KiB
Python

"""
Custom middleware for API rate limiting, API usage logging, and audit logging.
"""
import ipaddress
import json
import logging
import time

from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.http import JsonResponse
from django.utils.deprecation import MiddlewareMixin

from .models import APIUsage, AuditLog
# Project user model, as returned by django.contrib.auth.get_user_model().
# NOTE(review): `User` is not referenced by the middleware classes in this
# module; confirm nothing imports it from here before removing it.
User = get_user_model()
class APIRateLimitMiddleware(MiddlewareMixin):
    """
    Rate limiting middleware for API endpoints.

    Applies fixed-window per-minute, per-hour and per-day quotas to
    authenticated, non-staff users on paths under ``/api/``. Counters are
    kept in the Django cache. When a quota is exhausted, a 429 JSON response
    (with a ``Retry-After`` header) is returned instead of calling the view.
    """

    # Fallback quotas used when no APIRateLimit row matches the subscription
    # type, or when the lookup fails.
    DEFAULT_RATE_LIMITS = {
        'per_minute': 60,
        'per_hour': 1000,
        'per_day': 10000,
    }

    # (key in the limits dict, window length in seconds), smallest first.
    # Each window length divides the next one, so bucket boundaries line up.
    # Once a larger window is exhausted, extra counts in smaller windows
    # cannot let a request through.
    RATE_WINDOWS = (
        ('per_minute', 60),
        ('per_hour', 3600),
        ('per_day', 86400),
    )

    def process_request(self, request):
        """Return a 429 JsonResponse if the caller is over quota, else None."""
        # Only apply to API endpoints
        if not request.path.startswith('/api/'):
            return None
        # Read the user defensively: ``request.user`` does not exist unless
        # an authentication middleware ran before this one.
        # NOTE(review): if JWT / API-key authentication is done by DRF
        # authentication classes (i.e. inside the view), ``request.user`` is
        # still anonymous here and those clients are never rate limited.
        # Confirm.
        user = getattr(request, 'user', None)
        if not user or not user.is_authenticated:
            return None
        # Skip rate limiting for admin users. getattr() because custom user
        # models are not required to define ``is_staff``.
        if getattr(user, 'is_staff', False):
            return None
        # Get rate limits for user's subscription
        subscription_type = getattr(user, 'subscription_type', 'free')
        rate_limits = self.get_rate_limits(subscription_type)
        if not rate_limits:
            return None
        if self.check_rate_limit(user, request, rate_limits):
            return None
        response = JsonResponse({
            'error': 'Rate limit exceeded',
            'message': 'Too many requests. Please try again later.',
            'retry_after': 60
        }, status=429)
        # Standard header for 429 responses, mirroring ``retry_after`` above.
        response['Retry-After'] = '60'
        return response

    def get_rate_limits(self, subscription_type):
        """
        Return the quota dict for ``subscription_type``.

        The dict has the keys 'per_minute', 'per_hour' and 'per_day'.
        Returns a copy of DEFAULT_RATE_LIMITS when no APIRateLimit row exists.
        Also returns the defaults when the lookup fails for any other reason;
        that failure is logged. A rate-limit misconfiguration therefore never
        takes the API down.
        """
        from .models import APIRateLimit
        try:
            rate_limit = APIRateLimit.objects.get(subscription_type=subscription_type)
        except APIRateLimit.DoesNotExist:
            # Default limits if not configured
            return dict(self.DEFAULT_RATE_LIMITS)
        except Exception:
            # Best-effort fallback; unlike the former bare ``except:``, this
            # no longer swallows KeyboardInterrupt/SystemExit and leaves a trace.
            logging.getLogger(__name__).exception(
                "Failed to load rate limits for subscription %r; using defaults",
                subscription_type,
            )
            return dict(self.DEFAULT_RATE_LIMITS)
        return {
            'per_minute': rate_limit.requests_per_minute,
            'per_hour': rate_limit.requests_per_hour,
            'per_day': rate_limit.requests_per_day,
        }

    def check_rate_limit(self, user, request, limits):
        """
        Charge this request to each window; return False once one is exceeded.

        Windows are processed smallest first. Processing stops at the first
        exhausted window, so a request rejected by the per-minute quota is not
        charged to the hour/day quotas. At most ``limits[...]`` requests are
        allowed per window, which matches the previous ``count >= limit``
        check.
        """
        now = int(time.time())
        user_id = str(user.id)
        for limit_name, window in self.RATE_WINDOWS:
            # The window length is part of the key, so buckets of different
            # sizes can never share a cache entry.
            key = f'rate_limit:{user_id}:{window}:{now // window}'
            if self._increment(key, window) > limits[limit_name]:
                return False
        return True

    @staticmethod
    def _increment(key, timeout):
        """Increment the cache counter at ``key``, creating it at 1 if absent."""
        # add() only writes when the key is missing, so concurrent first
        # requests cannot both start the bucket at 1. incr() is atomic on
        # backends that support native increments (e.g. memcached); others
        # fall back to a read-then-write in Django.
        if cache.add(key, 1, timeout):
            return 1
        try:
            return cache.incr(key)
        except ValueError:
            # The key expired between add() and incr(); start a fresh bucket.
            cache.set(key, 1, timeout)
            return 1
class APILoggingMiddleware(MiddlewareMixin):
    """
    Logging middleware for API requests.

    Records one APIUsage row per authenticated request under ``/api/``,
    except the schema/docs endpoints. Each row holds the method, status code,
    latency in milliseconds, client IP and a truncated User-Agent. A failure
    to write the row is logged and never changes the response.
    """

    # Documentation/schema endpoints that are not recorded.
    SKIP_PATHS = ('/api/schema/', '/api/docs/', '/api/redoc/')

    def process_request(self, request):
        """Stamp the request with its start time; never short-circuits."""
        # Wall-clock start kept for any existing code that reads it.
        request.start_time = time.time()
        # Monotonic start used for latency: wall-clock time can jump (NTP,
        # manual changes) and yield negative durations.
        request._api_log_start = time.monotonic()
        return None

    def process_response(self, request, response):
        """Write an APIUsage row for the request and return ``response`` unchanged."""
        path = request.path
        # Only log API endpoints, and skip the documentation endpoints.
        if not path.startswith('/api/') or path.startswith(self.SKIP_PATHS):
            return response
        user = getattr(request, 'user', None)
        if not user or not user.is_authenticated:
            return response
        try:
            APIUsage.objects.create(
                user=user,
                endpoint=path,
                method=request.method,
                status_code=response.status_code,
                response_time_ms=self._elapsed_ms(request),
                ip_address=self.get_client_ip(request),
                user_agent=request.META.get('HTTP_USER_AGENT', '')[:500]
            )
        except Exception as e:
            # Best-effort: keep the traceback, never break the request.
            logging.getLogger(__name__).exception("Failed to log API usage: %s", e)
        return response

    @staticmethod
    def _elapsed_ms(request):
        """Return whole milliseconds since process_request ran, or 0 if it did not."""
        start = getattr(request, '_api_log_start', None)
        if start is None:
            return 0
        return int((time.monotonic() - start) * 1000)

    def get_client_ip(self, request):
        """
        Return the client IP address for ``request``.

        Uses the first X-Forwarded-For entry when it parses as an IP address.
        Otherwise returns REMOTE_ADDR, which may be None.
        """
        forwarded = request.META.get('HTTP_X_FORWARDED_FOR')
        if forwarded:
            candidate = forwarded.split(',')[0].strip()
            # NOTE(review): the leftmost X-Forwarded-For entry is supplied by
            # the client unless the fronting proxy overwrites the header.
            # Confirm the deployment before trusting it for security decisions.
            try:
                ipaddress.ip_address(candidate)
            except ValueError:
                # Values such as "unknown" or "1.2.3.4:8080" are not IPs; the
                # ip_address column (presumably GenericIPAddressField, confirm)
                # could reject them and lose the row. Use the socket peer.
                pass
            else:
                return candidate
        return request.META.get('REMOTE_ADDR')
class AuditLoggingMiddleware(MiddlewareMixin):
    """
    Audit logging middleware for security and compliance.

    Creates an AuditLog row for every POST, PUT, PATCH or DELETE request made
    by an authenticated user. The row is written from ``process_request``,
    so it records the attempt; the response status is not captured here.
    A failure to write the row is logged and never blocks the request.
    """

    # HTTP methods treated as sensitive (state-changing) actions.
    SENSITIVE_METHODS = frozenset({'POST', 'PUT', 'PATCH', 'DELETE'})

    def process_request(self, request):
        """Write an AuditLog row for sensitive authenticated requests; always return None."""
        # Only log authenticated requests. getattr() also tolerates an
        # explicit ``request.user = None``.
        user = getattr(request, 'user', None)
        if user is None or not user.is_authenticated:
            return None
        if request.method not in self.SENSITIVE_METHODS:
            return None
        try:
            AuditLog.objects.create(
                user=user,
                action=f'{request.method} {request.path}',
                resource_type='api_request',
                details={
                    'method': request.method,
                    'path': request.path,
                    # NOTE(review): query parameters are stored verbatim; if
                    # any endpoint accepts secrets in the query string, they
                    # end up in the audit log. Confirm.
                    'query_params': dict(request.GET),
                    'content_type': request.content_type,
                },
                ip_address=self.get_client_ip(request),
                user_agent=request.META.get('HTTP_USER_AGENT', '')[:500]
            )
        except Exception as e:
            # Best-effort: keep the traceback, never break the request.
            logging.getLogger(__name__).exception("Failed to create audit log: %s", e)
        return None

    def get_client_ip(self, request):
        """
        Return the client IP address for ``request``.

        Uses the first X-Forwarded-For entry when it parses as an IP address.
        Otherwise returns REMOTE_ADDR, which may be None.
        """
        forwarded = request.META.get('HTTP_X_FORWARDED_FOR')
        if forwarded:
            candidate = forwarded.split(',')[0].strip()
            # NOTE(review): the leftmost X-Forwarded-For entry is supplied by
            # the client unless the fronting proxy overwrites the header.
            # Confirm the deployment before trusting it for security decisions.
            try:
                ipaddress.ip_address(candidate)
            except ValueError:
                # Values such as "unknown" or "1.2.3.4:8080" are not IPs; the
                # ip_address column (presumably GenericIPAddressField, confirm)
                # could reject them and lose the audit row. Use the socket peer.
                pass
            else:
                return candidate
        return request.META.get('REMOTE_ADDR')