| """
|
| Authentication Middleware for AegisLM
|
|
|
| Provides FastAPI middleware for authentication and authorization.
|
| """
|
|
|
import hashlib
import uuid
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.db.models import User, Tenant, APIKey
from backend.db.session import get_db_session
from security.rbac import RBACContext, Role
from security.tenant_scope import TenantScope, set_tenant_scope
|
|
|
|
|
|
|
# Bearer-token extractor; used as the ``credentials`` dependency of
# ``get_current_user``.
security = HTTPBearer()
|
|
|
|
|
def _get_jwt_secret() -> str:
    """Return the JWT signing secret provided by the secret manager."""
    # Deferred import: the secret-manager module is not imported until this
    # helper actually runs.
    from security.secret_manager import get_jwt_secret as fetch_jwt_secret

    secret = fetch_jwt_secret()
    return secret
|
|
|
def _get_jwt_algorithm() -> str:
    """Return the JWT signing-algorithm name configured in the secret manager."""
    from security.secret_manager import get_secret_manager as secret_manager_factory

    algorithm = secret_manager_factory().get_jwt_algorithm()
    return algorithm
|
|
|
def _get_jwt_expiration_hours() -> int:
    """Return the configured JWT lifetime, in hours, from the secret manager."""
    from security.secret_manager import get_secret_manager

    manager = get_secret_manager()
    hours = manager.get_jwt_expiration_hours()
    return hours
|
|
|
|
|
@dataclass
class AuthenticatedUser:
    """
    Represents an authenticated user.

    Built both for JWT-authenticated users and for API-key clients;
    ``is_api_client`` distinguishes the two.
    """

    # For JWT users this is the User row's id. For API-key clients
    # (built by verify_api_key) it is the APIKey row's id instead.
    user_id: uuid.UUID

    # Tenant the caller belongs to.
    tenant_id: uuid.UUID

    # User's email; verify_api_key sets this to "api:<owner>" for API clients.
    email: str

    # Role used by the authorization checks (require_role / require_permission).
    role: Role

    # True only for principals built by verify_api_key (API-key auth).
    is_api_client: bool = False
|
|
|
|
|
def hash_api_key(api_key: str) -> str:
    """Return the hex SHA-256 digest of *api_key* (UTF-8 encoded), for storage/comparison."""
    digest = hashlib.sha256()
    digest.update(api_key.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
|
def create_jwt_token(
    user_id: uuid.UUID,
    tenant_id: uuid.UUID,
    email: str,
    role: str,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create a signed JWT for a user.

    Args:
        user_id: User ID (stored as the ``sub`` claim).
        tenant_id: Tenant ID (stored as the ``tenant_id`` claim).
        email: User email.
        role: User role.
        expires_delta: Token lifetime. Defaults to the expiration (in hours)
            configured in the secret manager.

    Returns:
        Encoded JWT string carrying ``sub``, ``tenant_id``, ``email``,
        ``role``, ``exp`` and ``iat`` claims, signed with the secret and
        algorithm from the secret manager.
    """
    from datetime import timezone

    if expires_delta is None:
        expires_delta = timedelta(hours=_get_jwt_expiration_hours())

    # Read the clock once so that exp == iat + expires_delta exactly, and use
    # an aware UTC datetime (datetime.utcnow() is deprecated since Python 3.12).
    issued_at = datetime.now(timezone.utc)

    payload = {
        "sub": str(user_id),
        "tenant_id": str(tenant_id),
        "email": email,
        "role": role,
        "exp": issued_at + expires_delta,
        "iat": issued_at,
    }

    return jwt.encode(payload, _get_jwt_secret(), algorithm=_get_jwt_algorithm())
|
|
|
|
|
def decode_jwt_token(token: str) -> dict:
    """
    Decode and validate a JWT token.

    Verifies the signature against the secret-manager secret and algorithm,
    and requires an ``exp`` claim, which is then checked for expiry.

    Args:
        token: JWT token string

    Returns:
        Decoded token payload

    Raises:
        HTTPException: 401 "Token has expired" if the token is expired;
            401 "Invalid token" for any other validation failure, including
            a bad signature or a missing ``exp`` claim.
    """
    try:
        return jwt.decode(
            token,
            _get_jwt_secret(),
            algorithms=[_get_jwt_algorithm()],
            # Without a required "exp", a validly signed token that lacks an
            # expiry would be accepted indefinitely.
            options={"require": ["exp"]},
        )
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
|
|
|
|
|
async def get_db() -> AsyncIterator[AsyncSession]:
    """
    FastAPI dependency yielding a database session.

    Re-yields each session produced by ``get_db_session()``. Because this is
    an async generator, its return type is an async iterator of sessions,
    not a bare ``AsyncSession``.
    """
    async for session in get_db_session():
        yield session
|
|
|
|
|
async def get_current_user_from_token(
    token: str,
) -> tuple[uuid.UUID, uuid.UUID, str, Role]:
    """
    Validate a JWT and extract the user identity WITHOUT database access.

    Only the token's signature and expiration are checked (via
    ``decode_jwt_token``); no query is issued.

    Args:
        token: Raw JWT string.

    Returns:
        Tuple of (user_id, tenant_id, email, role), taken from the token's
        ``sub``, ``tenant_id``, ``email`` and ``role`` claims.

    Raises:
        HTTPException: 401 if the token is invalid or expired, or if its
            claims are missing or malformed (e.g. a non-UUID ``sub`` or a
            ``role`` value that is not a member of ``Role``).
    """
    payload = decode_jwt_token(token)

    # A correctly signed token can still carry unusable claims (for example a
    # role value that no longer exists in Role). Report that as 401 rather
    # than letting KeyError/ValueError escape as an unhandled server error.
    try:
        user_id = uuid.UUID(str(payload["sub"]))
        tenant_id = uuid.UUID(str(payload["tenant_id"]))
        email = payload["email"]
        role = Role(payload["role"])
    except (KeyError, TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    return user_id, tenant_id, email, role
|
|
|
|
|
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> AuthenticatedUser:
    """
    FastAPI dependency resolving the authenticated user from a Bearer JWT.

    The token is validated first by ``get_current_user_from_token`` (no
    database access), so a token rejected by ``decode_jwt_token`` produces a
    401 before any query runs. Only then is the user loaded: it must match
    the token's user and tenant IDs and be active. The returned role is the
    one stored on the user row, not the role claim in the token.

    Raises:
        HTTPException: 401 if no matching active user exists (or if token
            validation fails).
    """
    user_id, tenant_id, _email, _role = await get_current_user_from_token(
        credentials.credentials
    )

    stmt = select(User).where(
        User.id == user_id,
        User.tenant_id == tenant_id,
        User.active == True,  # noqa: E712 - SQL comparison, not a Python truth test
    )
    row = (await db.execute(stmt)).scalar_one_or_none()

    if row is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return AuthenticatedUser(
        user_id=row.id,
        tenant_id=row.tenant_id,
        email=row.email,
        role=Role(row.role),
        is_api_client=False,
    )
|
|
|
|
|
async def get_current_user_optional(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> Optional[AuthenticatedUser]:
    """
    FastAPI dependency to get the current authenticated user, optionally.

    Tries a Bearer JWT from the ``Authorization`` header first, then falls
    back to the ``X-API-Key`` header. Returns None if neither authenticates.
    """
    # Delegating to AuthMiddleware.authenticate_request means the optional
    # dependency and the middleware helper share a single implementation of
    # the resolution rules and cannot drift apart.
    return await AuthMiddleware.authenticate_request(request, db)
|
|
|
|
|
async def verify_api_key(
    db: AsyncSession,
    api_key: str,
) -> Optional[AuthenticatedUser]:
    """
    Verify an API key and return the principal it authenticates.

    The raw key is hashed with ``hash_api_key`` and matched against active
    ``APIKey`` rows. The key's tenant must also exist and be active. Only
    when both checks pass is the key's ``last_used`` timestamp updated and
    the session committed. Note that ``db.commit()`` commits everything
    pending on the session, not just this update.

    Args:
        db: Database session.
        api_key: Raw API key as presented by the client.

    Returns:
        An ``AuthenticatedUser`` with role ``Role.API_CLIENT``. Its
        ``user_id`` is the API key's own id and its ``email`` is
        ``"api:<owner>"``. Returns None if the key is unknown or inactive,
        or if its tenant is missing or inactive.
    """
    key_hash = hash_api_key(api_key)

    result = await db.execute(
        select(APIKey).where(
            APIKey.key_hash == key_hash,
            APIKey.active == True,  # noqa: E712 - SQL comparison, not a Python truth test
        )
    )
    api_key_obj = result.scalar_one_or_none()
    if api_key_obj is None:
        return None

    # Check the tenant before recording usage. A key rejected because of its
    # tenant is then neither marked as used nor committed.
    result = await db.execute(select(Tenant).where(Tenant.id == api_key_obj.tenant_id))
    tenant = result.scalar_one_or_none()
    if tenant is None or not tenant.active:
        return None

    # Build the principal before committing. With SQLAlchemy's default
    # expire_on_commit=True, reading attributes after commit() would need an
    # implicit refresh, which AsyncSession cannot perform.
    principal = AuthenticatedUser(
        user_id=api_key_obj.id,
        tenant_id=api_key_obj.tenant_id,
        email=f"api:{api_key_obj.owner}",
        role=Role.API_CLIENT,
        is_api_client=True,
    )

    # NOTE(review): naive UTC timestamp; presumably the column is a
    # timezone-naive DateTime. Confirm before switching to aware datetimes.
    api_key_obj.last_used = datetime.utcnow()
    await db.commit()

    return principal
|
|
|
|
|
async def get_current_tenant(
    user: AuthenticatedUser = Depends(get_current_user),
) -> uuid.UUID:
    """FastAPI dependency returning the tenant ID of the authenticated user."""
    tenant_id = user.tenant_id
    return tenant_id
|
|
|
|
|
class AuthMiddleware:
    """
    Request-authentication helpers for FastAPI.

    Despite the name, this is not ASGI middleware. It exposes static helpers
    that resolve the caller from an ``Authorization: Bearer <JWT>`` header
    or, failing that, from an ``X-API-Key`` header. It does not set any
    tenant context itself.
    """

    @staticmethod
    def hash_api_key(api_key: str) -> str:
        """Hash an API key for storage/comparison (hex SHA-256 digest)."""
        # Class scope is not searched from inside a method body, so this name
        # resolves to the module-level ``hash_api_key``. Delegating keeps the
        # two hashing paths identical.
        return hash_api_key(api_key)

    @staticmethod
    async def verify_api_key(db: AsyncSession, api_key: str) -> Optional[AuthenticatedUser]:
        """
        Verify an API key and return the associated user.

        Delegates to the module-level ``verify_api_key``; the bare name
        resolves to the module function, not to this method.
        """
        return await verify_api_key(db, api_key)

    @staticmethod
    async def _user_from_bearer_token(
        db: AsyncSession,
        token: str,
    ) -> Optional[AuthenticatedUser]:
        """Return the active user a JWT identifies, or None if it does not authenticate."""
        # Only token-level problems mean "not authenticated": an invalid or
        # expired JWT (HTTPException from decode_jwt_token), or missing or
        # malformed claims. Database errors propagate instead of silently
        # downgrading the caller to anonymous.
        try:
            payload = decode_jwt_token(token)
            user_id = uuid.UUID(str(payload["sub"]))
            tenant_id = uuid.UUID(str(payload["tenant_id"]))
        except (HTTPException, KeyError, TypeError, ValueError):
            return None

        query = select(User).where(
            User.id == user_id,
            User.tenant_id == tenant_id,
            User.active == True,  # noqa: E712 - SQL comparison, not a Python truth test
        )
        result = await db.execute(query)
        user = result.scalar_one_or_none()
        if user is None:
            return None

        try:
            role = Role(user.role)
        except ValueError:
            # A stored role that is not a Role member is treated as
            # unauthenticated rather than raising.
            return None

        return AuthenticatedUser(
            user_id=user.id,
            tenant_id=user.tenant_id,
            email=user.email,
            role=role,
            is_api_client=False,
        )

    @staticmethod
    async def authenticate_request(
        request: Request,
        db: AsyncSession,
    ) -> Optional[AuthenticatedUser]:
        """
        Authenticate a request using either JWT or API key.

        A Bearer token in the ``Authorization`` header is tried first. If it
        is absent or does not authenticate, the ``X-API-Key`` header is
        checked.

        Returns:
            The authenticated principal, or None if neither mechanism
            authenticates the request.
        """
        auth_header = request.headers.get("Authorization")
        if auth_header:
            # HTTP auth-scheme names are case-insensitive (RFC 7235, section 2.1).
            scheme, _, token = auth_header.partition(" ")
            if scheme.lower() == "bearer" and token:
                user = await AuthMiddleware._user_from_bearer_token(db, token)
                if user is not None:
                    return user

        api_key = request.headers.get("X-API-Key")
        if api_key:
            return await verify_api_key(db, api_key)

        return None
|
|
|
|
|
class TenantContextMiddleware:
    """
    ASGI middleware hook for per-request tenant context.

    NOTE(review): as written, this middleware forwards every scope (HTTP and
    non-HTTP alike) to the wrapped app unchanged. It does not inspect the
    request or set a tenant scope itself. In this module, tenant scoping is
    performed by the ``setup_tenant_context`` dependency, which calls
    ``set_tenant_scope``. Either implement the HTTP path here or rely on
    that dependency; confirm the intended design.
    """

    def __init__(self, app):
        # The wrapped ASGI application.
        self.app = app

    async def __call__(self, scope, receive, send):
        # Every scope type is forwarded as-is; no tenant-specific handling is
        # performed for HTTP requests yet, so a separate HTTP branch is not needed.
        await self.app(scope, receive, send)
|
|
|
|
|
def require_role(required_role: Role):
    """
    Build a FastAPI dependency that requires a specific role.

    Usage:
        @router.get("/admin/users")
        async def list_users(user: AuthenticatedUser = Depends(require_role(Role.ADMIN))):
            ...

    Args:
        required_role: The role the user must have. This is an exact match;
            no role hierarchy is applied.

    Returns:
        An async dependency that resolves the user via ``get_current_user``
        and returns it, or raises HTTPException 403 if ``user.role`` differs
        from ``required_role``.
    """
    # This factory must be a plain ``def``. Depends() needs the checker
    # callable itself, whereas calling an ``async def`` factory only returns
    # a coroutine object.

    async def role_checker(user: AuthenticatedUser = Depends(get_current_user)) -> AuthenticatedUser:
        if user.role != required_role:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires {required_role.value} role"
            )
        return user

    return role_checker
|
|
|
|
|
def require_permission(permission: str):
    """
    Build a FastAPI dependency that requires a specific permission.

    Usage:
        @router.post("/jobs")
        async def create_job(
            user: AuthenticatedUser = Depends(require_permission("create_job"))
        ):
            ...

    Args:
        permission: Value of a ``security.rbac.Permission`` member.

    Returns:
        An async dependency that resolves the user via ``get_current_user``
        and returns it if ``RBAC.has_permission`` grants ``permission`` to
        the user's role. When invoked per request, it raises HTTPException
        400 if ``permission`` is not a valid Permission value, or 403 if the
        role lacks it.
    """
    # This factory must be a plain ``def``. Depends() needs the checker
    # callable itself, whereas calling an ``async def`` factory only returns
    # a coroutine object.

    async def permission_checker(user: AuthenticatedUser = Depends(get_current_user)) -> AuthenticatedUser:
        from security.rbac import RBAC, Permission

        try:
            perm = Permission(permission)
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid permission: {permission}"
            )

        if not RBAC.has_permission(user.role, perm):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires {permission} permission"
            )

        return user

    return permission_checker
|
|
|
|
|
def create_rbac_context(user: AuthenticatedUser) -> RBACContext:
    """
    Build an RBACContext from an authenticated user.

    Copies the user's ID, tenant ID and role. The result can be stored in
    request state for easy access.
    """
    ctx_kwargs = {
        "user_id": user.user_id,
        "tenant_id": user.tenant_id,
        "role": user.role,
    }
    return RBACContext(**ctx_kwargs)
|
|
|
|
|
async def setup_tenant_context(
    user: AuthenticatedUser = Depends(get_current_user),
) -> TenantScope:
    """
    FastAPI dependency that installs the tenant scope for the current request.

    Builds a TenantScope for the authenticated user's tenant, registers it
    with ``set_tenant_scope``, and returns it. NOTE(review): how later
    queries pick up this scope is defined in ``security.tenant_scope``;
    confirm there.
    """
    tenant_scope = TenantScope(tenant_id=user.tenant_id)
    set_tenant_scope(tenant_scope)
    return tenant_scope
|
|
|