"""
Redis caching service for the Virtual Board Member AI System.
"""

import hashlib
import inspect
import json
import logging
import pickle
from datetime import timedelta
from functools import wraps
from typing import Any, Dict, List, Optional, Union

import redis.asyncio as redis

from app.core.config import settings

logger = logging.getLogger(__name__)

# Sentinel returned by CacheService._deserialize when a payload can be decoded
# neither as JSON nor as pickle. A dedicated object is needed because ``None``
# is itself a valid cached value (JSON ``null``).
_UNDECODABLE = object()


class CacheService:
    """Redis caching service with tenant-aware caching.

    Every entry is stored under ``cache:{tenant_id}:{key}``. Values are
    serialized as JSON when possible and fall back to pickle otherwise.
    The Redis connection is created lazily on first use, and every public
    method reports failure through its return value instead of raising.
    """

    def __init__(self):
        # Created lazily by _ensure_client() the first time it is needed.
        self.redis_client = None

    async def _init_redis(self):
        """Initialize Redis connection.

        On any failure the error is logged and ``self.redis_client`` is left
        as ``None``.
        """
        try:
            self.redis_client = redis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=False,  # Keep as bytes for pickle support
            )
            await self.redis_client.ping()
            logger.info("Redis connection established for cache service")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.redis_client = None

    async def _ensure_client(self) -> bool:
        """Connect on first use; return True when a client is available."""
        if not self.redis_client:
            await self._init_redis()
        return bool(self.redis_client)

    @staticmethod
    def _full_key(tenant_id: str, key: str) -> str:
        """Return the tenant-scoped Redis key for *key*."""
        return f"cache:{tenant_id}:{key}"

    @staticmethod
    def _serialize(value: Any) -> bytes:
        """Encode *value* as JSON bytes, falling back to pickle."""
        try:
            return json.dumps(value).encode()
        except (TypeError, ValueError):
            return pickle.dumps(value)

    @staticmethod
    def _deserialize(data: bytes, label: str) -> Any:
        """Decode *data* as JSON, then pickle; return _UNDECODABLE on failure.

        *label* is only used in the warning logged when decoding fails.
        """
        try:
            return json.loads(data.decode())
        except (json.JSONDecodeError, UnicodeDecodeError):
            pass
        try:
            # NOTE(security): pickle.loads can execute arbitrary code. This is
            # only acceptable while nothing untrusted can write to this Redis.
            return pickle.loads(data)
        except Exception:
            # Corrupt pickles raise far more than UnpicklingError (EOFError,
            # AttributeError, ImportError, ...); treat them all as a miss.
            logger.warning(f"Failed to deserialize cache data for key: {label}")
            return _UNDECODABLE

    def _generate_key(self, prefix: str, tenant_id: str, *args, **kwargs) -> str:
        """Generate cache key with tenant isolation.

        The prefix, tenant id, positional arguments and sorted keyword
        arguments are joined with ``:`` and hashed, so the returned key is a
        32-character MD5 hex digest (the prefix is not recoverable from it).
        """
        # str() keeps non-string ids (e.g. UUID objects) from breaking join();
        # for string inputs the resulting key is unchanged.
        key_parts = [str(prefix), str(tenant_id)]
        key_parts.extend(str(arg) for arg in args)
        # Sort kwargs for consistent key generation
        key_parts.extend(f"{k}:{v}" for k, v in sorted(kwargs.items()))
        key_string = ":".join(key_parts)
        return hashlib.md5(key_string.encode()).hexdigest()

    async def get(self, key: str, tenant_id: str) -> Optional[Any]:
        """Get value from cache.

        Returns the cached value, or ``None`` when the key is missing, the
        payload cannot be decoded, or Redis is unavailable.
        """
        if not await self._ensure_client():
            return None
        try:
            full_key = self._full_key(tenant_id, key)
            data = await self.redis_client.get(full_key)
            if not data:
                return None
            value = self._deserialize(data, full_key)
            return None if value is _UNDECODABLE else value
        except Exception as e:
            logger.error(f"Cache get error: {e}")
            return None

    async def set(self, key: str, value: Any, tenant_id: str, expire: Optional[int] = None) -> bool:
        """Set value in cache with optional expiration.

        *expire* is a time-to-live in seconds; a falsy value stores the entry
        without expiration. Returns True on success, False otherwise.
        """
        if not await self._ensure_client():
            return False
        try:
            full_key = self._full_key(tenant_id, key)
            data = self._serialize(value)
            if expire:
                await self.redis_client.setex(full_key, expire, data)
            else:
                await self.redis_client.set(full_key, data)
            return True
        except Exception as e:
            logger.error(f"Cache set error: {e}")
            return False

    async def delete(self, key: str, tenant_id: str) -> bool:
        """Delete value from cache; return True if a key was removed."""
        if not await self._ensure_client():
            return False
        try:
            result = await self.redis_client.delete(self._full_key(tenant_id, key))
            return result > 0
        except Exception as e:
            logger.error(f"Cache delete error: {e}")
            return False

    async def delete_pattern(self, pattern: str, tenant_id: str) -> int:
        """Delete all keys matching pattern for a tenant.

        *pattern* is a Redis glob applied below the tenant prefix. Returns the
        number of keys deleted (0 on error or when nothing matched).
        """
        if not await self._ensure_client():
            return 0
        try:
            full_pattern = self._full_key(tenant_id, pattern)
            # NOTE(review): KEYS walks the whole keyspace and blocks the
            # server while doing so; consider SCAN for large deployments.
            keys = await self.redis_client.keys(full_pattern)
            if not keys:
                return 0
            result = await self.redis_client.delete(*keys)
            logger.info(f"Deleted {result} cache keys matching pattern: {full_pattern}")
            return result
        except Exception as e:
            logger.error(f"Cache delete pattern error: {e}")
            return 0

    async def clear_tenant_cache(self, tenant_id: str) -> int:
        """Clear all cache entries for a specific tenant."""
        return await self.delete_pattern("*", tenant_id)

    async def get_many(self, keys: List[str], tenant_id: str) -> Dict[str, Any]:
        """Get multiple values from cache.

        Returns a dict containing only the keys that were found and could be
        decoded; an undecodable entry is skipped without discarding the rest.
        """
        if not keys:
            # MGET with no keys is a Redis protocol error.
            return {}
        if not await self._ensure_client():
            return {}
        try:
            full_keys = [self._full_key(tenant_id, key) for key in keys]
            values = await self.redis_client.mget(full_keys)
            result = {}
            for key, raw in zip(keys, values):
                if raw is None:
                    continue
                value = self._deserialize(raw, key)
                if value is not _UNDECODABLE:
                    result[key] = value
            return result
        except Exception as e:
            logger.error(f"Cache get_many error: {e}")
            return {}

    async def set_many(self, data: Dict[str, Any], tenant_id: str, expire: Optional[int] = None) -> bool:
        """Set multiple values in cache.

        All writes are queued on one pipeline and sent together. *expire* is
        a time-to-live in seconds applied to every entry when truthy.
        """
        if not await self._ensure_client():
            return False
        try:
            pipeline = self.redis_client.pipeline()
            for key, value in data.items():
                full_key = self._full_key(tenant_id, key)
                serialized_value = self._serialize(value)
                if expire:
                    pipeline.setex(full_key, expire, serialized_value)
                else:
                    pipeline.set(full_key, serialized_value)
            await pipeline.execute()
            return True
        except Exception as e:
            logger.error(f"Cache set_many error: {e}")
            return False

    async def increment(self, key: str, tenant_id: str, amount: int = 1) -> Optional[int]:
        """Increment a counter in cache.

        Returns the new counter value, or ``None`` on failure.
        """
        if not await self._ensure_client():
            return None
        try:
            return await self.redis_client.incrby(self._full_key(tenant_id, key), amount)
        except Exception as e:
            logger.error(f"Cache increment error: {e}")
            return None

    async def expire(self, key: str, tenant_id: str, seconds: int) -> bool:
        """Set expiration for a cache key; return True if it was applied."""
        if not await self._ensure_client():
            return False
        try:
            result = await self.redis_client.expire(self._full_key(tenant_id, key), seconds)
            return bool(result)
        except Exception as e:
            logger.error(f"Cache expire error: {e}")
            return False


# Global cache service instance
cache_service = CacheService()


def _accepts_tenant_id(func) -> bool:
    """Return True when *func* declares a ``tenant_id`` parameter by name."""
    try:
        param = inspect.signature(func).parameters.get("tenant_id")
    except (TypeError, ValueError):
        # Some callables expose no introspectable signature.
        return False
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def cache_result(prefix: str, expire: Optional[int] = 3600):
    """Decorator to cache function results with tenant isolation.

    The tenant is taken from the ``tenant_id`` keyword argument or, failing
    that, from ``args[0].tenant_id``. Without a tenant the call bypasses the
    cache entirely. A ``tenant_id`` keyword is forwarded to the wrapped
    function when that function declares such a parameter.

    A result of ``None`` is indistinguishable from a cache miss, so functions
    returning ``None`` are re-executed on every call.
    """
    def decorator(func):
        forwards_tenant = _accepts_tenant_id(func)

        @wraps(func)
        async def wrapper(*args, tenant_id: str = None, **kwargs):
            # ``tenant_id`` is captured by the wrapper's own keyword-only
            # parameter, so it must be handed back explicitly or the wrapped
            # function would never receive it.
            call_kwargs = kwargs
            if tenant_id is not None and forwards_tenant:
                call_kwargs = {**kwargs, "tenant_id": tenant_id}

            if not tenant_id and args and hasattr(args[0], "tenant_id"):
                tenant_id = args[0].tenant_id
            if not tenant_id:
                # If no tenant_id, skip caching
                return await func(*args, **call_kwargs)

            # NOTE(review): positional args are stringified into the key; an
            # object without a stable __str__ (default repr embeds its memory
            # address) yields a different key per instance.
            cache_key = cache_service._generate_key(prefix, tenant_id, *args, **kwargs)

            cached_result = await cache_service.get(cache_key, tenant_id)
            if cached_result is not None:
                logger.debug(f"Cache hit for key: {cache_key}")
                return cached_result

            result = await func(*args, **call_kwargs)
            await cache_service.set(cache_key, result, tenant_id, expire)
            logger.debug(f"Cache miss, stored result for key: {cache_key}")
            return result
        return wrapper
    return decorator


def invalidate_cache(prefix: str, pattern: str = "*"):
    """Decorator to invalidate cache entries after function execution.

    After the wrapped coroutine returns, every key of the tenant matching
    *pattern* is deleted. The tenant is resolved the same way as in
    ``cache_result``: the ``tenant_id`` keyword, else ``args[0].tenant_id``.

    NOTE(review): *prefix* is currently unused. Keys produced by
    ``_generate_key`` are opaque hashes, so entries cannot be selected by
    prefix; only patterns over the stored key text can match.
    """
    def decorator(func):
        forwards_tenant = _accepts_tenant_id(func)

        @wraps(func)
        async def wrapper(*args, tenant_id: str = None, **kwargs):
            call_kwargs = kwargs
            if tenant_id is not None and forwards_tenant:
                call_kwargs = {**kwargs, "tenant_id": tenant_id}

            result = await func(*args, **call_kwargs)

            if not tenant_id and args and hasattr(args[0], "tenant_id"):
                tenant_id = args[0].tenant_id
            if tenant_id:
                await cache_service.delete_pattern(pattern, tenant_id)
                logger.debug(f"Invalidated cache for tenant {tenant_id}, pattern: {pattern}")
            return result
        return wrapper
    return decorator