205 lines
6.7 KiB
Python
205 lines
6.7 KiB
Python
"""
Custom middleware for API rate limiting, API usage logging, and audit logging.
"""
|
|
import json
import logging
import time

from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.http import JsonResponse
from django.utils.deprecation import MiddlewareMixin

from .models import APIUsage, AuditLog
|
|
|
|
# Project user model, resolved once at import time via Django's get_user_model().
# NOTE(review): not referenced by the middleware classes visible in this
# module — confirm whether anything imports it from here before removing.
User = get_user_model()
|
|
|
|
|
|
class APIRateLimitMiddleware(MiddlewareMixin):
    """
    Rate limiting middleware for API endpoints.

    Applies to paths under ``/api/`` for authenticated, non-staff users.
    Requests are counted per user in fixed per-minute, per-hour and per-day
    windows stored in the Django cache; once any window's limit has been
    reached the request is rejected with HTTP 429.
    """

    # (limit name, window length in seconds). Limit names are the keys of the
    # dict returned by get_rate_limits().
    RATE_LIMIT_WINDOWS = (
        ('per_minute', 60),
        ('per_hour', 3600),
        ('per_day', 86400),
    )

    # Used when no APIRateLimit row exists for a subscription type, or the
    # lookup fails.
    DEFAULT_RATE_LIMITS = {
        'per_minute': 60,
        'per_hour': 1000,
        'per_day': 10000,
    }

    # Back-off hint (seconds) sent to rejected clients in the body and header.
    RETRY_AFTER_SECONDS = 60

    def process_request(self, request):
        """
        Reject the request with HTTP 429 if the user is over a rate limit.

        Returns:
            ``None`` to let the request continue, or a ``JsonResponse`` with
            status 429 (and a ``Retry-After`` header) when a limit is reached.
        """
        # Only apply to API endpoints
        if not request.path.startswith('/api/'):
            return None

        # request.user is only attached by authentication middleware, so read
        # it defensively before touching any of its attributes.
        user = getattr(request, 'user', None)
        if not user or not user.is_authenticated:
            return None

        # Skip rate limiting for admin users
        if user.is_staff:
            return None

        # Get rate limits for user's subscription
        subscription_type = getattr(user, 'subscription_type', 'free')
        rate_limits = self.get_rate_limits(subscription_type)
        if not rate_limits:
            return None

        if not self.check_rate_limit(user, request, rate_limits):
            response = JsonResponse({
                'error': 'Rate limit exceeded',
                'message': 'Too many requests. Please try again later.',
                'retry_after': self.RETRY_AFTER_SECONDS,
            }, status=429)
            # Standard HTTP header so generic clients can back off as well.
            response['Retry-After'] = str(self.RETRY_AFTER_SECONDS)
            return response

        return None

    def get_rate_limits(self, subscription_type):
        """
        Get rate limits for subscription type.

        Returns:
            A dict with ``per_minute``, ``per_hour`` and ``per_day`` request
            counts read from the ``APIRateLimit`` row for
            ``subscription_type``, or a copy of ``DEFAULT_RATE_LIMITS`` when
            that row is missing or cannot be loaded.
        """
        try:
            from .models import APIRateLimit
            rate_limit = APIRateLimit.objects.get(subscription_type=subscription_type)
        except Exception:
            # Default limits if not configured (or the lookup itself fails):
            # rate limiting must never take the API down. Narrower than the
            # former bare except, which also swallowed KeyboardInterrupt and
            # SystemExit.
            logging.getLogger(__name__).debug(
                "Using default rate limits for subscription %r",
                subscription_type,
                exc_info=True,
            )
            return dict(self.DEFAULT_RATE_LIMITS)
        else:
            return {
                'per_minute': rate_limit.requests_per_minute,
                'per_hour': rate_limit.requests_per_hour,
                'per_day': rate_limit.requests_per_day,
            }

    def check_rate_limit(self, user, request, limits):
        """
        Check if request is within rate limits, and count it if it is.

        Every window in ``RATE_LIMIT_WINDOWS`` is checked first; only when all
        of them still have room is the request counted against each window,
        so a rejected request does not consume quota. Counters are read
        before they are incremented, so heavy concurrency can let a user
        slightly exceed a limit, but increments themselves are never lost.

        Args:
            user: The authenticated user; only ``user.id`` is used.
            request: The current request (unused here; kept for overrides).
            limits: Mapping of ``per_minute``/``per_hour``/``per_day`` to the
                maximum number of requests allowed in that window.

        Returns:
            ``True`` if the request is allowed, ``False`` if any limit has
            already been reached.
        """
        now = int(time.time())
        user_id = str(user.id)

        # (cache key, limit, window seconds) per window. The window name is
        # part of the key so counters of different window sizes can never
        # share a cache entry.
        windows = [
            (f'rate_limit:{user_id}:{name}:{now // seconds}', limits[name], seconds)
            for name, seconds in self.RATE_LIMIT_WINDOWS
        ]

        counts = cache.get_many([key for key, _, _ in windows])
        if any(counts.get(key, 0) >= limit for key, limit, _ in windows):
            return False

        for key, _, seconds in windows:
            self._increment_counter(key, seconds)
        return True

    @staticmethod
    def _increment_counter(key, timeout):
        """Add one to the cache counter at ``key``, creating it if absent."""
        # add() only stores when the key does not exist yet; otherwise incr()
        # updates the value in place, so concurrent requests no longer
        # overwrite each other's counts the way get() + set() did.
        if cache.add(key, 1, timeout):
            return
        try:
            cache.incr(key)
        except ValueError:
            # The key expired between add() and incr(); start a new count.
            cache.set(key, 1, timeout)
|
|
|
|
|
|
class APILoggingMiddleware(MiddlewareMixin):
    """
    Logging middleware for API requests.

    Records one ``APIUsage`` row per response to an authenticated request
    under ``/api/`` (schema/documentation endpoints excluded), including the
    status code and the elapsed time in milliseconds.
    """

    # Path prefixes under /api/ that are never logged.
    SKIP_PATH_PREFIXES = ('/api/schema/', '/api/docs/', '/api/redoc/')

    def process_request(self, request):
        """Stamp the request with its start time for response-time tracking."""
        request.start_time = time.time()
        return None

    def process_response(self, request, response):
        """
        Record an ``APIUsage`` row for the request and return ``response``.

        Failures while writing the row are logged and never affect the
        response returned to the client.
        """
        # Only log API endpoints, and skip the schema/documentation pages.
        path = request.path
        if not path.startswith('/api/') or path.startswith(self.SKIP_PATH_PREFIXES):
            return response

        # Calculate response time. Defensive: fall back to 0 if no start time
        # was recorded. time.time() is wall-clock and can step backwards, so
        # clamp at 0 rather than store a negative duration.
        start_time = getattr(request, 'start_time', None)
        if start_time is not None:
            response_time = max(0, int((time.time() - start_time) * 1000))
        else:
            response_time = 0

        user = getattr(request, 'user', None)
        if not user or not user.is_authenticated:
            return response

        try:
            APIUsage.objects.create(
                user=user,
                endpoint=path,
                method=request.method,
                status_code=response.status_code,
                response_time_ms=response_time,
                ip_address=self.get_client_ip(request),
                user_agent=request.META.get('HTTP_USER_AGENT', '')[:500],
            )
        except Exception as e:
            # Log error but don't break the request; exception() also keeps
            # the traceback, which the previous error() call discarded.
            logging.getLogger(__name__).exception("Failed to log API usage: %s", e)

        return response

    def get_client_ip(self, request):
        """
        Get client IP address.

        Uses the first (left-most) ``X-Forwarded-For`` entry, stripped of
        whitespace, and falls back to ``REMOTE_ADDR`` when that entry is
        absent or empty.
        """
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # proxy overwrites it, so this value may be spoofed; confirm the
        # deployment's proxy setup before relying on it for security.
        forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '')
        first_hop = forwarded_for.split(',')[0].strip()
        return first_hop or request.META.get('REMOTE_ADDR')
|
|
|
|
|
|
class AuditLoggingMiddleware(MiddlewareMixin):
    """
    Audit logging middleware for security and compliance.

    Writes an ``AuditLog`` row for every POST, PUT, PATCH or DELETE request
    made by an authenticated user. The row is written in ``process_request``,
    before the view runs, so it records the attempt whether or not the view
    later succeeds.
    """

    # HTTP methods treated as sensitive (state-changing) actions.
    SENSITIVE_METHODS = frozenset({'POST', 'PUT', 'PATCH', 'DELETE'})

    def process_request(self, request):
        """
        Create an ``AuditLog`` row for a sensitive authenticated request.

        Always returns ``None`` so the request continues; failures while
        writing the row are logged and otherwise ignored.
        """
        # Check the cheap method test first so non-sensitive requests never
        # trigger loading request.user.
        if request.method not in self.SENSITIVE_METHODS:
            return None

        # Only log authenticated requests; request.user may be absent if no
        # authentication middleware attached it.
        user = getattr(request, 'user', None)
        if not user or not user.is_authenticated:
            return None

        try:
            AuditLog.objects.create(
                user=user,
                action=f'{request.method} {request.path}',
                resource_type='api_request',
                details={
                    'method': request.method,
                    'path': request.path,
                    # dict() of a QueryDict maps each key to its list of values.
                    'query_params': dict(request.GET),
                    'content_type': request.content_type,
                },
                ip_address=self.get_client_ip(request),
                user_agent=request.META.get('HTTP_USER_AGENT', '')[:500],
            )
        except Exception as e:
            # Log error but don't break the request; exception() also keeps
            # the traceback, which the previous error() call discarded.
            logging.getLogger(__name__).exception("Failed to create audit log: %s", e)

        return None

    def get_client_ip(self, request):
        """
        Get client IP address.

        Uses the first (left-most) ``X-Forwarded-For`` entry, stripped of
        whitespace, and falls back to ``REMOTE_ADDR`` when that entry is
        absent or empty.
        """
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # proxy overwrites it, so this value may be spoofed; confirm the
        # deployment's proxy setup before relying on it for audit purposes.
        forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '')
        first_hop = forwarded_for.split(',')[0].strip()
        return first_hop or request.META.get('REMOTE_ADDR')
|
|
|