# Source: aegislm/security/auth_middleware.py
# Uploaded by ACA050 ("Upload 57 files"), commit f2c6053 (verified)
"""
Authentication Middleware for AegisLM
Provides FastAPI dependencies, helpers, and middleware for authentication (JWT bearer tokens and API keys) and authorization (role and permission checks).
"""
import hashlib
import uuid
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.db.models import User, Tenant, APIKey
from backend.db.session import get_db_session
from security.rbac import RBACContext, Role
from security.tenant_scope import TenantScope, set_tenant_scope
# Security scheme: FastAPI's HTTP Bearer credential extractor, consumed by
# get_current_user via Depends(security). auto_error=True is HTTPBearer's
# default and is spelled out here: a missing/non-Bearer Authorization header is
# rejected with an HTTP error before the dependency body runs.
# NOTE(review): this global shadows the top-level `security` package name in
# this module; `from security.x import ...` statements are unaffected because
# they resolve through the import system, not this variable.
security = HTTPBearer(auto_error=True)
# JWT configuration - loaded from secret manager
def _get_jwt_secret() -> str:
"""Get JWT secret from secret manager."""
from security.secret_manager import get_jwt_secret
return get_jwt_secret()
def _get_jwt_algorithm() -> str:
    """Return the configured JWT algorithm name from the secret manager."""
    from security.secret_manager import get_secret_manager

    manager = get_secret_manager()
    return manager.get_jwt_algorithm()
def _get_jwt_expiration_hours() -> int:
    """Return the default JWT lifetime, in hours, from the secret manager.

    Used by create_jwt_token when the caller passes no explicit expires_delta.
    """
    from security.secret_manager import get_secret_manager

    manager = get_secret_manager()
    return manager.get_jwt_expiration_hours()
@dataclass
class AuthenticatedUser:
    """
    Identity of an authenticated caller.

    Built from a ``User`` row for JWT callers, or synthesised from an
    ``APIKey`` row for API-key callers (see ``verify_api_key``).
    """

    # JWT callers: the User's id. API clients: the APIKey row's id.
    user_id: uuid.UUID
    # Tenant this identity is scoped to.
    tenant_id: uuid.UUID
    # User email, or "api:<owner>" for API clients.
    email: str
    # Role used for authorization checks.
    role: Role
    # True only for identities produced from an X-API-Key header.
    is_api_client: bool = False
def hash_api_key(api_key: str) -> str:
    """Return the hex SHA-256 digest of *api_key* (UTF-8 encoded).

    Keys are stored and compared by this digest, never in raw form.
    """
    digest = hashlib.sha256()
    digest.update(api_key.encode("utf-8"))
    return digest.hexdigest()
def create_jwt_token(
    user_id: uuid.UUID,
    tenant_id: uuid.UUID,
    email: str,
    role: str,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create a signed JWT for a user.

    Args:
        user_id: User ID (stored in the ``sub`` claim as a string)
        tenant_id: Tenant ID (stored in the ``tenant_id`` claim as a string)
        email: User email
        role: User role
        expires_delta: Token lifetime. Defaults to the configured JWT
            expiration (in hours) from the secret manager.

    Returns:
        JWT token string signed with the configured secret and algorithm
    """
    if expires_delta is None:
        expires_delta = timedelta(hours=_get_jwt_expiration_hours())
    # Read the clock once so "iat" and "exp" are exactly expires_delta apart.
    # Use an aware UTC datetime: datetime.utcnow() is deprecated since 3.12,
    # and PyJWT converts aware datetimes to the same UTC epoch seconds.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "tenant_id": str(tenant_id),
        "email": email,
        "role": role,
        "exp": issued_at + expires_delta,
        "iat": issued_at,
    }
    return jwt.encode(payload, _get_jwt_secret(), algorithm=_get_jwt_algorithm())
def decode_jwt_token(token: str) -> dict:
    """
    Verify a JWT's signature and expiry and return its claims.

    The token is checked against the configured secret, and only the
    configured algorithm is accepted.

    Args:
        token: JWT token string

    Returns:
        Decoded token payload

    Raises:
        HTTPException: 401 with detail "Token has expired" for an expired
            token, or "Invalid token" for any other validation failure
    """
    try:
        return jwt.decode(
            token,
            _get_jwt_secret(),
            algorithms=[_get_jwt_algorithm()],
        )
    except jwt.InvalidTokenError as exc:
        # ExpiredSignatureError is a subclass of InvalidTokenError; it gets its
        # own message so clients can tell "refresh" apart from "bad token".
        expired = isinstance(exc, jwt.ExpiredSignatureError)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired" if expired else "Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield a database session for FastAPI dependency injection.

    This is an async generator (it ``yield``s), so its return type is an
    async iterator of sessions, not a bare ``AsyncSession``. It re-yields each
    session produced by ``get_db_session``.
    """
    async for session in get_db_session():
        yield session
async def get_current_user_from_token(
    token: str,
) -> tuple[uuid.UUID, uuid.UUID, str, Role]:
    """
    Validate a JWT and extract the caller's identity WITHOUT database access.

    Only the token signature, expiry and claim shapes are checked; the
    database is never queried here, so invalid tokens fail fast.

    Args:
        token: Raw JWT string (without the "Bearer " prefix)

    Returns:
        Tuple of (user_id, tenant_id, email, role) taken from the token claims

    Raises:
        HTTPException: 401 if the token is invalid or expired, or if the
            ``sub``, ``tenant_id``, ``email`` or ``role`` claim is missing or
            malformed
    """
    # Step 1: Validate JWT signature and expiration (NO DB ACCESS).
    payload = decode_jwt_token(token)

    # Step 2: Extract claims. A correctly signed token with missing or
    # malformed claims must still yield 401, not an unhandled
    # KeyError/ValueError surfacing as a 500. uuid.UUID raises TypeError for
    # None and AttributeError for non-string values.
    try:
        user_id = uuid.UUID(payload["sub"])
        tenant_id = uuid.UUID(payload["tenant_id"])
        email = payload["email"]
        role = Role(payload["role"])
    except (KeyError, TypeError, ValueError, AttributeError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    return user_id, tenant_id, email, role
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> AuthenticatedUser:
    """
    FastAPI dependency resolving the authenticated user for a bearer token.

    Ordering matters: the JWT is validated first with no database access, and
    the User lookup only runs for a token that passed validation. This way
    unauthenticated requests get a 401 before any DB work.

    The returned identity uses the role stored on the ``User`` row, not the
    ``role`` claim carried in the token.

    Raises:
        HTTPException: 401 for an invalid/expired token, or when no active
            user matches the token's ``sub`` and ``tenant_id``.
    """
    # Token-only validation (fails fast, no DB).
    user_id, tenant_id, _email, _role = await get_current_user_from_token(
        credentials.credentials
    )

    # DB lookup, scoped to the token's tenant and restricted to active users.
    lookup = select(User).where(
        User.id == user_id,
        User.tenant_id == tenant_id,
        User.active == True,  # noqa: E712 - SQLAlchemy builds SQL from ==
    )
    db_user = (await db.execute(lookup)).scalar_one_or_none()

    if db_user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return AuthenticatedUser(
        user_id=db_user.id,
        tenant_id=db_user.tenant_id,
        email=db_user.email,
        role=Role(db_user.role),
        is_api_client=False,
    )
async def get_current_user_optional(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> Optional[AuthenticatedUser]:
    """
    FastAPI dependency to get the current authenticated user, optionally.

    Tries a JWT bearer token first, then the X-API-Key header. Returns None
    if no valid authentication is provided.

    This dependency previously duplicated ``AuthMiddleware.authenticate_request``
    line for line. It now delegates to that method, so both entry points
    accept exactly the same credentials.
    """
    return await AuthMiddleware.authenticate_request(request, db)
async def verify_api_key(
    db: AsyncSession,
    api_key: str,
) -> Optional[AuthenticatedUser]:
    """
    Verify an API key and return the associated API-client identity.

    The raw key is hashed with ``hash_api_key`` and matched against active
    ``APIKey`` rows. The owning tenant must exist and be active. Only a key
    that passes both checks has its ``last_used`` timestamp updated and
    committed.

    Args:
        db: Database session (committed only on successful verification)
        api_key: Raw API key to verify

    Returns:
        AuthenticatedUser with role ``Role.API_CLIENT`` if valid, None otherwise
    """
    key_hash = hash_api_key(api_key)
    key_query = select(APIKey).where(
        APIKey.key_hash == key_hash,
        APIKey.active == True,  # noqa: E712 - SQLAlchemy builds SQL from ==
    )
    api_key_obj = (await db.execute(key_query)).scalar_one_or_none()
    if api_key_obj is None:
        return None

    # Check the owning tenant BEFORE recording usage, so keys belonging to a
    # missing or inactive tenant are rejected without touching last_used.
    tenant_query = select(Tenant).where(Tenant.id == api_key_obj.tenant_id)
    tenant = (await db.execute(tenant_query)).scalar_one_or_none()
    if tenant is None or not tenant.active:
        return None

    # Build the result before committing. Depending on the session's
    # expire_on_commit setting (not visible here), reading ORM attributes
    # after commit may trigger a lazy reload.
    identity = AuthenticatedUser(
        user_id=api_key_obj.id,  # API key ID stands in as user_id for API clients
        tenant_id=api_key_obj.tenant_id,
        email=f"api:{api_key_obj.owner}",
        role=Role.API_CLIENT,
        is_api_client=True,
    )

    # Record usage. This stays a naive UTC timestamp via utcnow().
    # NOTE(review): switch to an aware timestamp only if the column is tz-aware.
    api_key_obj.last_used = datetime.utcnow()
    await db.commit()
    return identity
async def get_current_tenant(
    user: AuthenticatedUser = Depends(get_current_user),
) -> uuid.UUID:
    """FastAPI dependency returning the authenticated caller's tenant ID.

    Because it depends on get_current_user, requests that fail bearer
    authentication are rejected before this runs.
    """
    tenant_id = user.tenant_id
    return tenant_id
class AuthMiddleware:
    """
    Request-authentication helpers for FastAPI.

    Despite the name, this is not ASGI middleware (it defines no ``__call__``).
    It groups static helpers that authenticate a request from a JWT bearer
    token or an ``X-API-Key`` header.
    """

    @staticmethod
    def hash_api_key(api_key: str) -> str:
        """Hash an API key for storage/comparison.

        Delegates to the module-level ``hash_api_key`` so there is a single
        hashing implementation.
        """
        # Class scope is not searched from inside a method body, so this name
        # resolves to the module-level function, not to this staticmethod.
        return hash_api_key(api_key)

    @staticmethod
    async def verify_api_key(db: AsyncSession, api_key: str) -> Optional[AuthenticatedUser]:
        """
        Verify an API key and return the associated user.

        Delegates to the module-level ``verify_api_key``. Class scope is not
        searched from a method body, so this is not recursive.
        """
        return await verify_api_key(db, api_key)

    @staticmethod
    def _bearer_token(request: Request) -> Optional[str]:
        """Return the token from an ``Authorization: Bearer <token>`` header, or None."""
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        # The auth scheme is case-insensitive (RFC 7235). Split it the same way
        # FastAPI's HTTPBearer does, so this path and get_current_user accept
        # the same headers.
        scheme, _, token = auth_header.partition(" ")
        if scheme.lower() != "bearer" or not token:
            return None
        return token

    @staticmethod
    def _token_identity(token: str) -> Optional[tuple[uuid.UUID, uuid.UUID]]:
        """Validate *token* and return ``(user_id, tenant_id)``, or None if unusable."""
        try:
            payload = decode_jwt_token(token)
            return uuid.UUID(payload["sub"]), uuid.UUID(payload["tenant_id"])
        except Exception:
            # The token is untrusted input. Any failure here (bad or expired
            # signature -> HTTPException, missing claim -> KeyError, malformed
            # UUID -> ValueError/TypeError/AttributeError) just means "no bearer
            # identity". The caller then falls back to API-key auth.
            return None

    @staticmethod
    async def authenticate_request(
        request: Request,
        db: AsyncSession,
    ) -> Optional[AuthenticatedUser]:
        """
        Authenticate a request using either JWT or API key.

        A valid bearer token (scheme matched case-insensitively) for an
        active user wins. Otherwise the X-API-Key header is tried. Returns
        None when neither authenticates.

        Only token parsing is best-effort. Database errors are not caught
        here, so an outage surfaces as an error instead of silently
        downgrading the caller to anonymous.
        """
        token = AuthMiddleware._bearer_token(request)
        identity = AuthMiddleware._token_identity(token) if token is not None else None
        if identity is not None:
            user_id, tenant_id = identity
            query = select(User).where(
                User.id == user_id,
                User.tenant_id == tenant_id,
                User.active == True,  # noqa: E712 - SQLAlchemy builds SQL from ==
            )
            result = await db.execute(query)
            user = result.scalar_one_or_none()
            if user is not None:
                return AuthenticatedUser(
                    user_id=user.id,
                    tenant_id=user.tenant_id,
                    email=user.email,
                    role=Role(user.role),
                    is_api_client=False,
                )

        # Check for API key
        api_key = request.headers.get("X-API-Key")
        if api_key:
            return await verify_api_key(db, api_key)
        return None
class TenantContextMiddleware:
    """
    ASGI middleware intended to establish tenant context for each request.

    Currently a pass-through: every scope, HTTP or otherwise, is forwarded to
    the wrapped application unchanged. Within this module, the
    ``setup_tenant_context`` dependency is what actually sets a TenantScope.
    """

    def __init__(self, app):
        # Downstream ASGI application that every call is forwarded to.
        self.app = app

    async def __call__(self, scope, receive, send):
        is_http = scope["type"] == "http"
        if is_http:
            # TODO: Extract tenant from request and set context.
            # This would typically be done after authentication.
            pass
        await self.app(scope, receive, send)
def require_role(required_role: Role):
    """
    Build a FastAPI dependency that requires the caller to have *required_role*.

    This factory must be a plain ``def``. ``Depends()`` needs the returned
    checker callable, but an ``async def`` factory would hand it an
    un-awaited coroutine object instead, breaking the documented usage.

    Usage:
        @router.get("/admin/users")
        async def list_users(user: AuthenticatedUser = Depends(require_role(Role.ADMIN))):
            ...

    The check is exact equality on ``user.role``. Any other role gets a 403.
    """

    async def role_checker(user: AuthenticatedUser = Depends(get_current_user)):
        if user.role != required_role:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires {required_role.value} role",
            )
        return user

    return role_checker
def require_permission(permission: str):
    """
    Build a FastAPI dependency that requires the caller's role to grant *permission*.

    This factory must be a plain ``def``. ``Depends()`` needs the returned
    checker callable, but an ``async def`` factory would hand it an
    un-awaited coroutine object instead, breaking the documented usage.

    Usage:
        @router.post("/jobs")
        async def create_job(
            user: AuthenticatedUser = Depends(require_permission("create_job"))
        ):
            ...

    The checker responds with:
    - 400 if *permission* is not a valid ``Permission`` value;
    - 403 if ``RBAC.has_permission`` denies it for the user's role.
    """

    async def permission_checker(user: AuthenticatedUser = Depends(get_current_user)):
        # Deferred import, documented as avoiding a circular import.
        # NOTE(review): security.rbac is also imported at module top, so the
        # cycle may no longer apply. Confirm before hoisting.
        from security.rbac import RBAC, Permission

        # Convert string permission to enum
        try:
            perm = Permission(permission)
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid permission: {permission}",
            ) from None

        # Check if user has permission
        if not RBAC.has_permission(user.role, perm):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires {permission} permission",
            )
        return user

    return permission_checker
def create_rbac_context(user: AuthenticatedUser) -> RBACContext:
    """
    Build an RBACContext carrying the user's ID, tenant ID and role.

    The result can be stored in request state for easy access.
    """
    context_fields = {
        "user_id": user.user_id,
        "tenant_id": user.tenant_id,
        "role": user.role,
    }
    return RBACContext(**context_fields)
async def setup_tenant_context(
    user: AuthenticatedUser = Depends(get_current_user),
) -> TenantScope:
    """
    FastAPI dependency that installs the caller's tenant scope for this request.

    Creates a TenantScope for ``user.tenant_id``, registers it with
    ``set_tenant_scope`` so later database queries can be scoped to the
    tenant, and returns it.
    """
    tenant_scope = TenantScope(tenant_id=user.tenant_id)
    set_tenant_scope(tenant_scope)
    return tenant_scope