84 lines
2.3 KiB
Python
84 lines
2.3 KiB
Python
from functools import cached_property
|
|
import logging
|
|
|
|
import httpx
|
|
from jose import jws, jwt
|
|
import strawberry
|
|
from strawberry.fastapi import BaseContext
|
|
from strawberry.types import Info as _Info
|
|
from strawberry.types.info import RootValueType
|
|
|
|
import consts
|
|
import config
|
|
|
|
logger = logging.getLogger(consts.LOG_ROOT)

# Allow-list of JWS signing algorithms accepted in incoming tokens. Its main
# job is to shut out the unsigned 'none' algorithm; extend it if the IdP
# signs with something else.
# NOTE(review): HS256 relies on a shared symmetric key; if the IdP's JWKS only
# publishes asymmetric keys, consider allowing RS256 alone.
VALID_ALGS = ['HS256', 'RS256']

# Settings read once, at import time, through config.str_val.
# PERMISSIONS_KEY: presumably the name of the token claim listing permissions — confirm.
PERMISSIONS_KEY = config.str_val('AUTH_PERMISSIONS_KEY')
# Prefix prepended to a method name before it is looked up in the permissions.
OLS_PREFIX = config.str_val('AUTH_OLS_PREFIX')
|
|
|
|
logger.debug('loading JWKS data')

# Fetch the IdP's OpenID Connect discovery document, read the JWKS URL from
# it, and download the key set used to verify token signatures. This runs
# once, at import time; any HTTP or parsing failure aborts the import.
idp_url = config.str_val('AUTH_IDP_URL')
# rstrip avoids a '//' path segment when the configured URL ends with a slash.
config_url = idp_url.rstrip('/') + '/.well-known/openid-configuration'

logger.debug('loading openid configuration from %s', config_url)
# Bound to `oidc_config`, not `config`: rebinding `config` here would shadow
# the imported config module for everything that runs after this point.
oidc_config_response = httpx.get(config_url)
# Fail fast with a clear HTTP error rather than a confusing JSON/KeyError later.
oidc_config_response.raise_for_status()
oidc_config = oidc_config_response.json()

jwks_url = oidc_config['jwks_uri']
logger.debug('loading JWKS data from %s', jwks_url)
jwks_response = httpx.get(jwks_url)
jwks_response.raise_for_status()
jwks = jwks_response.json()
|
|
|
|
|
|
class Context(BaseContext):
    """Strawberry/FastAPI request context that exposes the caller's bearer token."""

    @cached_property
    def token(self) -> str | None:
        """Return the bearer token from the ``Authorization`` header, if any.

        Returns ``None`` when the context has no request attached, when the
        header is missing or empty, or when it does not use the ``Bearer``
        scheme. The value is computed once and cached on this instance.
        """
        if not self.request:
            return None

        header = self.request.headers.get("Authorization", None)
        if not header:
            return None

        # Expect "<scheme> <credentials>". Splitting on "Bearer " and indexing
        # [1] raised IndexError for any other scheme (e.g. "Basic ...") or for a
        # lowercase "bearer"; HTTP auth schemes are case-insensitive.
        scheme, _, credentials = header.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None


# strawberry's Info type parametrised with this module's Context, for
# annotating resolver ``info`` arguments.
Info = _Info[Context, RootValueType]
|
|
|
|
|
|
async def get_context() -> Context:
    """Build and return a new :class:`Context`.

    A fresh instance is created on every call, so nothing (including the
    cached ``token``) is shared between calls. Presumably registered as the
    strawberry router's context getter — confirm at the call site.
    """
    context = Context()
    return context
|
|
|
|
|
|
class SecurityException(Exception):
    """Exception type for security-related failures.

    Nothing in this module raises it; presumably raised by callers of the
    permission check — confirm.
    """
|
|
|
|
|
|
def validate_permissions(token: str | None, method_name: str) -> bool:
    """Check whether ``token`` grants permission to call ``method_name``.

    The token is verified against the IdP key set (``jwks``) with an
    allow-listed algorithm before any of its claims are trusted. Access is
    granted when ``OLS_PREFIX + method_name`` appears in the claim named by
    ``PERMISSIONS_KEY``.

    Args:
        token: Compact-serialised JWT (without the ``Bearer`` prefix), or ``None``.
        method_name: Name of the operation being authorised.

    Returns:
        ``True`` only if the token is well-formed, signed with an allowed
        algorithm by a key in ``jwks``, passes the time-based claim checks
        (exp/nbf/iat), and lists the namespaced method in its permissions
        claim; ``False`` otherwise. Token errors are logged and reported as
        ``False`` instead of propagating as jose exceptions.
    """
    # Local import keeps this fix self-contained; jose is already a dependency.
    from jose.exceptions import JOSEError

    if not token:
        logger.info("no token")
        return False

    try:
        headers = jwt.get_unverified_headers(token)
    except JOSEError as err:
        logger.info("invalid token: %s", err)
        return False
    logger.debug('headers: %s', headers)

    # Reject unexpected algorithms (notably 'none') before touching the key set.
    algorithm = headers.get('alg')
    if algorithm not in VALID_ALGS:
        logger.info('Invalid signing algorithm: %s', algorithm)
        return False

    try:
        # jwt.decode verifies the signature *and* the exp/nbf/iat claims and
        # returns only verified claims. jws.verify checked the signature alone
        # (expired tokens passed) and raised on failure rather than returning a
        # falsy value, so a forged token escaped as an exception.
        claims = jwt.decode(
            token,
            jwks,
            algorithms=[algorithm],
            # No expected audience is configured in this module, and python-jose
            # rejects any token carrying at_hash unless the access token is also
            # supplied — so both checks are disabled.
            options={'verify_aud': False, 'verify_at_hash': False},
        )
    except JOSEError as err:
        logger.info("invalid token: %s", err)
        return False
    logger.debug('claims: %s', claims)

    # Use the configured claim name instead of a hard-coded one.
    permissions = claims.get(PERMISSIONS_KEY, [])
    # A string claim would turn `in` into a substring test ("admin_all"
    # contains "admin"); treat it as a single exact permission instead.
    if isinstance(permissions, str):
        permissions = [permissions]
    elif not isinstance(permissions, list):
        logger.info('unexpected permissions claim type: %s', type(permissions).__name__)
        return False

    namespaced_method = f'{OLS_PREFIX}{method_name}'
    logger.debug("verify %s <> %s", namespaced_method, permissions)
    return namespaced_method in permissions
|