| """ |
| Query Result Caching System for AegisLM SaaS Backend. |
| |
| Production-ready intelligent query caching with Redis backend, |
| cache invalidation, and performance optimization. |
| """ |
|
|
| import asyncio |
| import json |
| import hashlib |
| import pickle |
| from datetime import datetime, timedelta |
| from typing import Any, Optional, Dict, List, Union, Callable |
| from functools import wraps |
| from sqlalchemy import text |
| from sqlalchemy.ext.asyncio import AsyncSession |
| import logging |
|
|
| from .database import get_redis |
| from .config import settings |
|
|
# Module-scoped logger; every cache operation in this file reports through it.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class CacheKey:
    """Cache key generator for queries.

    Keys have the form ``query_cache:<sha256 hex>``; the digest covers the
    whitespace-normalised query text and, when given, the JSON-encoded
    bind parameters.
    """

    @staticmethod
    def _normalize_query(query: str) -> str:
        """Collapse whitespace runs that lie outside single-quoted literals.

        Case and the contents of ``'...'`` literals are preserved: folding
        them would make e.g. ``name = 'Alice'`` and ``name = 'alice'`` share
        one cache entry and return each other's rows.
        """
        import re

        # With a capturing group, re.split keeps each literal ('' is an
        # escaped quote inside a literal) as an odd-indexed element.
        pieces = re.split(r"('(?:[^']|'')*')", query)
        return "".join(
            piece if index % 2 else re.sub(r"\s+", " ", piece)
            for index, piece in enumerate(pieces)
        ).strip()

    @staticmethod
    def generate(query: str, params: Optional[Dict[str, Any]] = None) -> str:
        """Generate the cache key for *query* and its bind *params*.

        Parameters are serialised with sorted keys so dict ordering does not
        matter. Values JSON cannot encode (dates, UUIDs, Decimals, ...) fall
        back to ``repr`` instead of raising ``TypeError``.
        """
        content = CacheKey._normalize_query(query)
        if params:
            # NUL separator keeps query text and parameters from running
            # together into the same byte string.
            content += "\x00" + json.dumps(params, sort_keys=True, default=repr)

        hash_key = hashlib.sha256(content.encode()).hexdigest()
        return f"query_cache:{hash_key}"

    @staticmethod
    def generate_table_dependency_key(table_name: str) -> str:
        """Generate table dependency key for cache invalidation."""
        return f"table_deps:{table_name}"
|
|
|
|
class CacheConfig:
    """Settings that govern how query results are cached.

    Attributes:
        ttl_seconds: Lifetime of a cached entry, in seconds.
        max_size: Maximum number of cached entries to keep.
        enabled: Whether caching is active for this configuration.
        smart_invalidation: Whether table-dependency tracking is recorded so
            entries can be invalidated per table.
    """

    def __init__(
        self,
        ttl_seconds: int = 300,
        max_size: int = 1000,
        enabled: bool = True,
        smart_invalidation: bool = True,
    ):
        self.ttl_seconds, self.max_size = ttl_seconds, max_size
        self.enabled, self.smart_invalidation = enabled, smart_invalidation
|
|
|
|
class QueryCache:
    """Intelligent query result caching system backed by Redis.

    Results of read-only ``SELECT`` statements are pickled and stored under
    keys produced by ``CacheKey.generate``. When smart invalidation is on,
    every cached key is also added to a ``table_deps:<table>`` set for each
    table the query references, so ``invalidate_table`` can drop all entries
    that depend on that table.
    """

    # Lower bound (seconds) on the lifetime of a ``table_deps:*`` set.
    _MIN_DEPENDENCY_TTL_SECONDS = 3600

    # COUNT hint passed to every SCAN round trip.
    _SCAN_COUNT = 100

    # Keys inspected, and keys evicted, once the size limit is reached.
    _EVICTION_SAMPLE_SIZE = 100
    _EVICTION_BATCH_SIZE = 10

    # Number of keys whose MEMORY USAGE is sampled for the stats estimate.
    _MEMORY_SAMPLE_SIZE = 50

    # One SQL identifier: double-quoted (anything but '"') or bare word.
    _IDENT = r'(?:"[^"]+"|\w+)'

    # Functions whose value changes between executions; caching a query that
    # calls them would serve stale or wrong results.
    _NON_DETERMINISTIC_RE = (
        r'\bNOW\s*\('
        r'|\bCURRENT_(?:TIMESTAMP|DATE|TIME)\b'
        r'|\bRANDOM\s*\('
    )

    def __init__(self):
        self.redis_client = None
        self.default_config = CacheConfig(
            ttl_seconds=getattr(settings, 'QUERY_CACHE_TTL', 300),
            max_size=getattr(settings, 'QUERY_CACHE_MAX_SIZE', 1000),
            enabled=getattr(settings, 'QUERY_CACHE_ENABLED', True),
            smart_invalidation=getattr(settings, 'QUERY_CACHE_SMART_INVALIDATION', True)
        )

        # Per-table TTL overrides. Every other setting is inherited from the
        # global default so QUERY_CACHE_ENABLED / *_SMART_INVALIDATION also
        # govern queries that touch these tables.
        table_ttls = {'users': 600, 'evaluations': 180, 'api_keys': 900}
        self.table_configs = {
            table: CacheConfig(
                ttl_seconds=ttl,
                max_size=self.default_config.max_size,
                enabled=self.default_config.enabled,
                smart_invalidation=self.default_config.smart_invalidation,
            )
            for table, ttl in table_ttls.items()
        }

        # Keywords that, as whole words, mark a statement as writing,
        # locking (SELECT ... FOR UPDATE) or transactional: never cached.
        self.excluded_patterns = [
            'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'DROP', 'ALTER',
            'TRUNCATE', 'COMMIT', 'ROLLBACK'
        ]

    async def get_redis(self):
        """Return the Redis client, creating it lazily on first use."""
        if self.redis_client is None:
            self.redis_client = await get_redis()
        return self.redis_client

    def _is_enabled(self, config) -> bool:
        """Return True when both the global default and *config* allow caching."""
        return bool(self.default_config.enabled and config.enabled)

    def should_cache_query(self, query: str) -> bool:
        """Return True if *query* is a deterministic, read-only SELECT.

        Keywords are matched as whole words, so columns such as
        ``created_at`` or ``updated_at`` do not disable caching, while
        ``SELECT ... FOR UPDATE`` still does.
        """
        import re

        query_upper = query.upper()
        if not query_upper.strip().startswith('SELECT'):
            return False

        if self.excluded_patterns:
            keywords = '|'.join(re.escape(p.upper()) for p in self.excluded_patterns)
            if re.search(rf'\b(?:{keywords})\b', query_upper):
                return False

        return re.search(self._NON_DETERMINISTIC_RE, query_upper) is None

    def get_cache_config(self, query: str) -> "CacheConfig":
        """Return the cache configuration that applies to *query*.

        If several referenced tables have an override, the one with the
        shortest TTL wins. This makes the choice deterministic and keeps the
        entry no older than any involved table tolerates.
        """
        overrides = [
            self.table_configs[table]
            for table in self._extract_tables(query)
            if table in self.table_configs
        ]
        if not overrides:
            return self.default_config
        return min(overrides, key=lambda cfg: cfg.ttl_seconds)

    def _normalize_table_name(self, name: str) -> str:
        """Reduce a possibly schema-qualified/quoted name to its bare table part.

        ``public."Users"`` becomes ``users``. Names are lower-cased so that
        ``FROM Users`` and ``invalidate_table('users')`` agree.
        """
        import re

        parts = re.findall(self._IDENT, name)
        return (parts[-1] if parts else name).strip('"').lower()

    def _extract_tables(self, query: str) -> List[str]:
        """Extract normalised table names referenced after FROM / JOIN.

        Handles schema-qualified and double-quoted names, aliases (with or
        without AS) and comma-separated FROM lists. Returns a sorted list.
        """
        import re

        ident = self._IDENT
        qualified = rf'{ident}(?:\s*\.\s*{ident})*'
        ref = rf'{qualified}(?:\s+(?:AS\s+)?\w+)?'
        # Zero-width lookahead: a FROM match must not consume a following
        # JOIN keyword (``ref`` may swallow it as an alias).
        clause_re = rf'\b(?:FROM|JOIN)\s+(?=({ref}(?:\s*,\s*{ref})*))'

        tables = set()
        for clause in re.findall(clause_re, query, re.IGNORECASE):
            for name in re.findall(rf'(?:^|,)\s*({qualified})', clause):
                tables.add(self._normalize_table_name(name))
        return sorted(tables)

    def _dependency_ttl(self) -> int:
        """TTL for ``table_deps:*`` sets: never shorter than any entry TTL."""
        entry_ttls = [cfg.ttl_seconds for cfg in self.table_configs.values()]
        return max(self._MIN_DEPENDENCY_TTL_SECONDS, self.default_config.ttl_seconds, *entry_ttls)

    async def _scan_keys(self, redis_client, pattern: str) -> List[Any]:
        """Return every key matching *pattern*, iterating with SCAN.

        SCAN may report a key more than once, so results are de-duplicated
        (first-seen order preserved).
        """
        seen: Dict[Any, None] = {}
        cursor = 0
        while True:
            cursor, batch = await redis_client.scan(cursor, match=pattern, count=self._SCAN_COUNT)
            seen.update(dict.fromkeys(batch))
            if int(cursor) == 0:
                return list(seen)

    async def get(self, query: str, params: Optional[Dict[str, Any]] = None) -> Optional[Any]:
        """Return the cached result for *query*/*params*, or None.

        Misses, non-cacheable queries, disabled caching and Redis errors all
        yield None; errors are logged rather than raised.
        """
        try:
            if not self.should_cache_query(query):
                return None

            config = self.get_cache_config(query)
            if not self._is_enabled(config):
                return None

            redis_client = await self.get_redis()
            cache_key = CacheKey.generate(query, params)

            cached_data = await redis_client.get(cache_key)
            if not cached_data:
                logger.debug("Cache miss for query: %s...", cache_key[:28])
                return None

            # SECURITY: unpickling can execute code embedded in the payload.
            # This is only safe while nothing untrusted can write
            # query_cache:* keys in this Redis instance.
            result = pickle.loads(cached_data)
            logger.debug("Cache hit for query: %s...", cache_key[:28])
            return result

        except Exception as e:
            logger.error("Failed to get from cache: %s", e)
            return None

    async def set(self, query: str, result: Any, params: Optional[Dict[str, Any]] = None) -> bool:
        """Cache *result* for *query*/*params*; return True if it was stored.

        Errors are logged and reported as False.
        """
        try:
            if not self.should_cache_query(query):
                return False

            config = self.get_cache_config(query)
            if not self._is_enabled(config):
                return False

            redis_client = await self.get_redis()
            cache_key = CacheKey.generate(query, params)

            # Serialise first so an unpicklable result fails before any
            # eviction happens.
            serialized_result = pickle.dumps(result)

            await self._enforce_size_limit(redis_client)
            await redis_client.setex(cache_key, config.ttl_seconds, serialized_result)

            if self.default_config.smart_invalidation and config.smart_invalidation:
                await self._store_table_dependencies(query, cache_key)

            logger.debug("Cached query result: %s...", cache_key[:28])
            return True

        except Exception as e:
            logger.error("Failed to cache result: %s", e)
            return False

    async def invalidate_table(self, table_name: str) -> int:
        """Delete every cache entry recorded as depending on *table_name*.

        The name is normalised like extracted names (schema prefix and quotes
        dropped, lower-cased). Returns how many keys the dependency set
        tracked (some may already have expired), or 0 on error.
        """
        try:
            redis_client = await self.get_redis()
            dep_key = CacheKey.generate_table_dependency_key(self._normalize_table_name(table_name))

            cache_keys = await redis_client.smembers(dep_key)
            if not cache_keys:
                return 0

            # Entries and their dependency set go in a single DEL round trip.
            await redis_client.delete(*cache_keys, dep_key)

            logger.info("Invalidated %d cache entries for table: %s", len(cache_keys), table_name)
            return len(cache_keys)

        except Exception as e:
            logger.error("Failed to invalidate table cache: %s", e)
            return 0

    async def invalidate_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> bool:
        """Delete the cached entry for *query*/*params*; True if one existed."""
        try:
            redis_client = await self.get_redis()
            deleted = await redis_client.delete(CacheKey.generate(query, params))
            return deleted > 0

        except Exception as e:
            logger.error("Failed to invalidate query cache: %s", e)
            return False

    async def clear_all(self) -> bool:
        """Delete all ``query_cache:*`` entries and ``table_deps:*`` sets."""
        try:
            redis_client = await self.get_redis()

            cache_keys = await self._scan_keys(redis_client, "query_cache:*")
            if cache_keys:
                await redis_client.delete(*cache_keys)

            dep_keys = await self._scan_keys(redis_client, "table_deps:*")
            if dep_keys:
                await redis_client.delete(*dep_keys)

            logger.info("Cleared %d cache entries", len(cache_keys))
            return True

        except Exception as e:
            logger.error("Failed to clear cache: %s", e)
            return False

    async def _store_table_dependencies(self, query: str, cache_key: str):
        """Record *cache_key* in the dependency set of every table in *query*."""
        try:
            redis_client = await self.get_redis()
            dep_ttl = self._dependency_ttl()

            for table in self._extract_tables(query):
                dep_key = CacheKey.generate_table_dependency_key(table)
                await redis_client.sadd(dep_key, cache_key)
                # Refreshed on every add with a TTL >= every configured entry
                # TTL, so the set outlives the entries it tracks.
                await redis_client.expire(dep_key, dep_ttl)

        except Exception as e:
            logger.error("Failed to store table dependencies: %s", e)

    async def _enforce_size_limit(self, redis_client):
        """Evict a small batch of entries once the entry count reaches max_size.

        Up to ``_EVICTION_SAMPLE_SIZE`` keys are sampled. The
        ``_EVICTION_BATCH_SIZE`` with the lowest remaining TTL (closest to
        expiry, typically the oldest) are deleted. Errors are logged.
        """
        try:
            cache_keys = await self._scan_keys(redis_client, "query_cache:*")
            if len(cache_keys) < self.default_config.max_size:
                return

            key_ttl_pairs = []
            for key in cache_keys[:self._EVICTION_SAMPLE_SIZE]:
                ttl = await redis_client.ttl(key)
                if ttl != -2:  # -2: key no longer exists (expired since SCAN)
                    key_ttl_pairs.append((key, ttl))

            key_ttl_pairs.sort(key=lambda pair: pair[1])
            keys_to_remove = [key for key, _ in key_ttl_pairs[:self._EVICTION_BATCH_SIZE]]

            if keys_to_remove:
                await redis_client.delete(*keys_to_remove)
                logger.debug("Removed %d old cache entries", len(keys_to_remove))

        except Exception as e:
            logger.error("Failed to enforce size limit: %s", e)

    async def get_cache_stats(self) -> Dict[str, Any]:
        """Return entry counts, an estimated memory footprint and settings.

        Memory is estimated by averaging ``MEMORY USAGE`` over the sampled
        keys that could be measured and scaling to the total count. On error
        a dict with a single ``"error"`` key is returned.
        """
        try:
            redis_client = await self.get_redis()

            cache_keys = await self._scan_keys(redis_client, "query_cache:*")
            dep_keys = await self._scan_keys(redis_client, "table_deps:*")

            measured_bytes = 0
            measured_count = 0
            for key in cache_keys[:self._MEMORY_SAMPLE_SIZE]:
                try:
                    size = await redis_client.memory_usage(key)
                except Exception:
                    # Best effort: skip keys whose size cannot be read.
                    continue
                if size is not None:  # None e.g. when the key expired meanwhile
                    measured_bytes += size
                    measured_count += 1

            if measured_count:
                estimated_total = (measured_bytes / measured_count) * len(cache_keys)
            else:
                estimated_total = 0

            max_size = self.default_config.max_size
            utilization = round(len(cache_keys) / max_size * 100, 2) if max_size > 0 else 0.0

            return {
                "cache_entries": len(cache_keys),
                "table_dependencies": len(dep_keys),
                "estimated_memory_mb": round(estimated_total / (1024 * 1024), 2),
                "max_size": max_size,
                "utilization_percent": utilization,
                "default_ttl_seconds": self.default_config.ttl_seconds,
                "enabled": self.default_config.enabled
            }

        except Exception as e:
            logger.error("Failed to get cache stats: %s", e)
            return {"error": str(e)}
|
|
|
|
| |
# Shared module-level instance; the decorator, context manager and helper
# coroutines below all route through it.
query_cache: QueryCache = QueryCache()
|
|
|
|
| |
| def cached_query(ttl_seconds: Optional[int] = None, enabled: bool = True): |
| """Decorator for automatic query result caching.""" |
| def decorator(func: Callable): |
| @wraps(func) |
| async def wrapper(*args, **kwargs): |
| if not enabled: |
| return await func(*args, **kwargs) |
| |
| |
| |
| query = args[0] if args else kwargs.get('query') |
| params = kwargs.get('params') |
| |
| if not query: |
| return await func(*args, **kwargs) |
| |
| |
| cached_result = await query_cache.get(query, params) |
| if cached_result is not None: |
| return cached_result |
| |
| |
| result = await func(*args, **kwargs) |
| await query_cache.set(query, result, params) |
| |
| return result |
| |
| return wrapper |
| return decorator |
|
|
|
|
| |
class CachedQueryContext:
    """Async context manager around the shared ``query_cache``.

    ``get``/``set`` delegate to ``query_cache``. When ``ttl_seconds`` is
    truthy, entries written through ``set`` are re-expired with that TTL.
    If the ``async with`` body raises, every table in ``tables`` is
    invalidated.
    """

    def __init__(self, ttl_seconds: Optional[int] = None, tables: Optional[List[str]] = None):
        self.ttl_seconds = ttl_seconds
        self.tables = tables or []
        # Kept for callers that read it; set() applies the TTL override.
        self.config = CacheConfig(ttl_seconds=ttl_seconds) if ttl_seconds else query_cache.default_config

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # NOTE(review): tables are invalidated only when the body raised;
        # writes that complete successfully inside the block are not
        # invalidated here -- confirm this is intended.
        if exc_type is not None and self.tables:
            for table in self.tables:
                await query_cache.invalidate_table(table)
        # Returning None lets any exception from the body propagate.

    async def get(self, query: str, params: Optional[Dict[str, Any]] = None) -> Optional[Any]:
        """Return the cached result for *query*/*params*, or None."""
        return await query_cache.get(query, params)

    async def set(self, query: str, result: Any, params: Optional[Dict[str, Any]] = None) -> bool:
        """Cache *result* and apply the ``ttl_seconds`` override, if any.

        Returns what ``query_cache.set`` returned. A failure to apply the
        override is logged and does not change that value.
        """
        stored = await query_cache.set(query, result, params)
        if stored and self.ttl_seconds:
            try:
                await self._apply_ttl_override(query, params)
            except Exception as e:
                logger.error(f"Failed to apply TTL override: {e}")
        return stored

    async def _apply_ttl_override(self, query: str, params: Optional[Dict[str, Any]]) -> None:
        """Re-expire the entry with ``ttl_seconds``; extend its dependency sets."""
        redis_client = await query_cache.get_redis()
        await redis_client.expire(CacheKey.generate(query, params), self.ttl_seconds)
        for table in query_cache._extract_tables(query):
            dep_key = CacheKey.generate_table_dependency_key(table)
            remaining = await redis_client.ttl(dep_key)
            # -2: set absent (nothing tracked); -1: set never expires.
            if 0 <= remaining < self.ttl_seconds:
                await redis_client.expire(dep_key, self.ttl_seconds)
|
|
|
|
| |
async def warm_cache(common_queries: List[Dict[str, Any]]):
    """Pre-populate the cache by executing *common_queries*.

    Each item is a dict with a required ``"query"`` (SQL text) and optional
    ``"params"`` (bind parameters) and ``"ttl"`` (seconds). Statements that
    ``query_cache.should_cache_query`` rejects are skipped without being
    executed, so a write statement in the list cannot modify data. A failing
    item is logged and skipped; the rest are still processed.

    Returns:
        Number of queries whose result was actually stored in the cache.
    """
    try:
        from .database import async_engine
    except ImportError as e:
        logger.error(f"Failed to warm cache: {e}")
        return 0

    warmed_count = 0
    for query_info in common_queries:
        try:
            query = query_info["query"]
            params = query_info.get("params")
            ttl = query_info.get("ttl")

            if not query_cache.should_cache_query(query):
                logger.warning("Skipping non-cacheable warm-up query")
                continue

            async with async_engine.begin() as conn:
                result = await conn.execute(text(query), params or {})
                data = result.fetchall()

            if ttl:
                # query_cache.set() takes no TTL, so the default config is
                # swapped for the call. Caveats: it only affects queries
                # without a per-table override, and concurrent set() calls in
                # other tasks see the swapped config too.
                original_config = query_cache.default_config
                query_cache.default_config = CacheConfig(
                    ttl_seconds=ttl,
                    max_size=original_config.max_size,
                    enabled=original_config.enabled,
                    smart_invalidation=original_config.smart_invalidation,
                )
                try:
                    cached = await query_cache.set(query, data, params)
                finally:
                    query_cache.default_config = original_config
            else:
                cached = await query_cache.set(query, data, params)

            if cached:
                warmed_count += 1

        except Exception as e:
            logger.error(f"Failed to warm cache query: {e}")

    logger.info(f"Warmed cache with {warmed_count} queries")
    return warmed_count
|
|
|
|
| |
async def cache_maintenance_task():
    """Run scheduled cache maintenance.

    Logs current cache statistics. Then it scans every ``table_deps:*`` set
    and gives any set without an expiry (TTL == -1) a 3600-second TTL so it
    cannot linger forever. Errors are logged rather than raised.
    """
    try:
        stats = await query_cache.get_cache_stats()
        # get_cache_stats() reports failures as {"error": ...}; indexing the
        # normal keys would raise KeyError and skip the dependency sweep.
        if "error" in stats:
            logger.warning(f"Cache stats unavailable: {stats['error']}")
        else:
            logger.info(f"Cache stats: {stats['cache_entries']} entries, {stats['estimated_memory_mb']} MB")

        redis_client = await query_cache.get_redis()
        cursor = 0
        while True:
            cursor, keys = await redis_client.scan(cursor, match="table_deps:*", count=100)
            for key in keys:
                if await redis_client.ttl(key) == -1:
                    await redis_client.expire(key, 3600)
            if int(cursor) == 0:
                break

        logger.info("Cache maintenance completed")

    except Exception as e:
        logger.error(f"Cache maintenance failed: {e}")
|
|
|
|
if __name__ == "__main__":
    import sys

    async def _show_stats():
        """Print cache statistics as indented JSON."""
        print(json.dumps(await query_cache.get_cache_stats(), indent=2))

    async def _clear():
        """Remove every cache entry and dependency set."""
        if await query_cache.clear_all():
            print("✅ Cache cleared successfully")
        else:
            print("❌ Failed to clear cache")

    async def _invalidate():
        """Invalidate the table named in argv[2]; exit(1) if it is missing."""
        table_name = sys.argv[2] if len(sys.argv) > 2 else None
        if not table_name:
            print("Error: invalidate requires table name")
            sys.exit(1)
        count = await query_cache.invalidate_table(table_name)
        print(f"Invalidated {count} cache entries for table: {table_name}")

    async def _warm():
        """Warm the cache with a fixed set of count queries."""
        common_queries = [
            {"query": "SELECT COUNT(*) FROM users WHERE is_active = true"},
            {"query": "SELECT COUNT(*) FROM evaluations WHERE status = 'completed'"},
            {"query": "SELECT COUNT(*) FROM api_keys WHERE is_active = true"}
        ]
        warmed = await warm_cache(common_queries)
        print(f"Warmed {warmed} queries in cache")

    async def _usage():
        """Print command-line usage."""
        print("Usage: python query_cache.py <command> [args]")
        print("Commands: stats, clear, invalidate <table>, warm")

    async def main():
        command = sys.argv[1] if len(sys.argv) > 1 else "help"
        handlers = {
            "stats": _show_stats,
            "clear": _clear,
            "invalidate": _invalidate,
            "warm": _warm,
        }
        await handlers.get(command, _usage)()

    asyncio.run(main())
|
|