py-microservice/auth.py

127 lines
3.9 KiB
Python
Raw Normal View History

2023-10-04 20:25:21 +00:00
"""
Authentication module.
Its main export is the `BBPermission` class, used in `permission_classes=[..., BBPermission]`
Further information: [strawberry permissions](https://strawberry.rocks/docs/guides/permissions)
"""
# stdlib imports
2023-10-04 19:03:20 +00:00
from functools import cached_property
import logging
2023-10-04 20:25:21 +00:00
import typing
2023-10-04 19:03:20 +00:00
2023-10-04 20:25:21 +00:00
# dependencies imports
2023-10-04 19:03:20 +00:00
import httpx
2023-10-04 20:25:21 +00:00
2023-10-04 19:03:20 +00:00
from jose import jws, jwt
2023-10-04 20:25:21 +00:00
from strawberry.permission import BasePermission
2023-10-04 19:03:20 +00:00
from strawberry.fastapi import BaseContext
from strawberry.types import Info as _Info
from strawberry.types.info import RootValueType
2023-10-04 20:25:21 +00:00
# app imports
2023-10-04 19:03:20 +00:00
import consts
import config
2023-10-04 20:25:21 +00:00
# Valid signing algorithms.
# This mostly serves to avoid the 'none' exploit; add more algorithms as necessary.
# See README.md for further discussion.
VALID_ALGS = ['HS256', 'RS256']
# claim key under which the JWT carries its list of permissions
PERMISSIONS_KEY = config.str_val('AUTH_PERMISSIONS_KEY')
# prefix prepended to the operation/field name before looking it up in the permissions
OLS_PREFIX = config.str_val('AUTH_OLS_PREFIX')

logger = logging.getLogger(consts.LOG_ROOT)

logger.debug('loading JWKS data')
# Fetch the IdP configuration, extract the JWKS url from there, and load its contents.
# NOTE: this runs once at import time with blocking HTTP calls; the JWKS is not
# refreshed afterwards (TODO(review): confirm key rotation is handled elsewhere).
idp_url = config.str_val('AUTH_IDP_URL')
# tolerate a trailing slash in AUTH_IDP_URL so we never request ".../ /.well-known/..."
config_url = idp_url.rstrip('/') + '/.well-known/openid-configuration'
logger.debug('loading openid configuration from %s', config_url)
# Bound to `openid_config`, NOT `config`: rebinding `config` would shadow the imported
# `config` module for the rest of this module.
_config_response = httpx.get(config_url)
# fail with a clear HTTP error instead of a KeyError/JSON error further down
_config_response.raise_for_status()
openid_config = _config_response.json()
jwks_url = openid_config['jwks_uri']
logger.debug('loading JWKS data from %s', jwks_url)
_jwks_response = httpx.get(jwks_url)
_jwks_response.raise_for_status()
jwks = _jwks_response.json()
class Context(BaseContext):
    """
    Context class that adds JWT data (in bearer token form) taken from the
    request's ``Authorization`` header.
    """

    @cached_property
    def token(self) -> str | None:
        """
        JWT bearer token from an ``Authorization: Bearer <token>`` header.

        Returns ``None`` when there is no request, no (or an empty)
        ``Authorization`` header, the header uses a scheme other than
        ``Bearer``, or the credentials part is empty. Computed once per
        context instance (``cached_property``).
        """
        # no request means we cannot access the Authorization header -> token is None
        if not self.request:
            return None

        header = self.request.headers.get("Authorization")
        if not header:
            return None

        # Header format is "<scheme> <credentials>" and the scheme is
        # case-insensitive (RFC 7235 / RFC 6750). The previous
        # `split("Bearer ")[1]` raised IndexError for any other scheme
        # (e.g. "Basic ...") or for a lowercase "bearer".
        scheme, _, credentials = header.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None
2023-10-04 20:25:21 +00:00
# Define an `Info` type that carries our custom `Context` (and thus the bearer token).
# Strawberry supplies a value of this type automatically wherever it appears in a matching
# signature, e.g. `has_permission` or a resolver decorated with `@strawberry.field`.
2023-10-04 19:03:20 +00:00
# Parametrised alias of strawberry's `Info`: the context type argument is our `Context`,
# so `info.context.token` (used in `BBPermission.has_permission`) is typed correctly.
Info: typing.TypeAlias = _Info[Context, RootValueType]
async def get_context() -> Context:
    """
    Coroutine that builds and returns a brand-new `Context` instance.

    Being async, it can be awaited by a framework hook (presumably passed as
    strawberry's `context_getter` — confirm at the call site).
    """
    new_context = Context()
    return new_context
2023-10-04 20:25:21 +00:00
class BBPermission(BasePermission):
    """
    Generic permission check that checks the operation/field name against a list of
    permissions supplied in the JWT's claims under `config.AUTH_PERMISSIONS_KEY`. The
    permissions are prefixed with `config.AUTH_OLS_PREFIX`.
    See `README.md` for further information.

    Access is granted only if the request carries a bearer token that
    - declares a signing algorithm listed in `VALID_ALGS`,
    - verifies against the IdP's JWKS and passes the time-based claim checks
      (`exp`, `nbf`, `iat`) performed by `jose.jwt.decode`, and
    - lists `f"{OLS_PREFIX}{field_name}"` among its permissions.
    Every failure yields `False`, so strawberry reports `message`.
    """

    message = "Permission denied"

    def has_permission(self, source: typing.Any, info: Info, **kwargs: typing.Any) -> bool:
        """
        Return whether the current request may access ``info.field_name``.

        :param source: parent value of the field being resolved (unused).
        :param info: strawberry info; ``info.context`` is our `Context`.
        :param kwargs: field arguments (unused).
        :returns: ``True`` iff the verified token grants the namespaced permission.
        """
        # all python-jose errors (JWTError, JWSError, JWKError, ...) derive from JOSEError
        from jose.exceptions import JOSEError

        operation = info.field_name
        token = info.context.token
        # no token present -> permission denied
        if token is None:
            logger.info("no token")
            return False

        try:
            headers = jwt.get_unverified_headers(token)
        except JOSEError as exc:
            # malformed token -> permission denied (previously an unhandled exception)
            logger.info("malformed token: %s", exc)
            return False
        logger.debug('headers: %s', headers)

        algorithm = headers.get('alg')
        # invalid algorithm -> permission denied
        if algorithm not in VALID_ALGS:
            logger.info('Invalid signing algorithm: %s', algorithm)
            return False

        try:
            # `jwt.decode` verifies the signature AND the registered time claims
            # (exp/nbf/iat); `jws.verify` only checked the signature, so expired
            # tokens were accepted. It raises on failure rather than returning a
            # falsy value, which the old `if not token_valid` never caught.
            # No audience is configured here, so audience verification is disabled.
            claims = jwt.decode(
                token,
                jwks,
                algorithms=[algorithm],
                options={'verify_aud': False},
            )
        except JOSEError as exc:
            # token verification failed -> permission denied
            logger.info("invalid token: %s", exc)
            return False
        logger.debug('claims: %s', claims)

        # use the configured claim key instead of a hard-coded "basebox/permissions"
        permissions = claims.get(PERMISSIONS_KEY, [])
        if isinstance(permissions, str):
            # `in` on a str is a substring test ("x:get" in "x:getAll"); treat a
            # string claim as a whitespace-separated list of permissions instead
            permissions = permissions.split()
        elif not isinstance(permissions, (list, tuple, set, frozenset)):
            logger.info("unexpected permissions claim type: %s", type(permissions).__name__)
            permissions = []

        # check whether the requested operation is in the list of valid operations
        namespaced_method = f'{OLS_PREFIX}{operation}'
        logger.debug("verify required permission: %s against claims: %s",
                     namespaced_method, permissions)
        return namespaced_method in permissions