elyte5star

MSAL/Google access token verification

Jul 22nd, 2025 (edited)
726
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 11.01 KB | Software | 0 0
  1. from typing import cast, Any
  2. from starlette.requests import Request
  3. from modules.settings.configuration import ApiConfig
  4. from fastapi.security.utils import get_authorization_scheme_param
  5. from fastapi.openapi.models import (
  6.     OAuthFlows as OAuthFlowsModel,
  7. )
  8. from fastapi.security.base import SecurityBase
  9. from fastapi.openapi.models import OAuthFlowAuthorizationCode
  10. from starlette.status import (
  11.     HTTP_401_UNAUTHORIZED,
  12.     HTTP_403_FORBIDDEN,
  13.     HTTP_404_NOT_FOUND,
  14.     HTTP_400_BAD_REQUEST,
  15. )
  16. from fastapi.exceptions import HTTPException
  17. from jose import JWTError, jwt
  18. from httpx import AsyncClient, HTTPError, Response
  19. from fastapi.openapi.models import OAuth2 as OAuth2Model
  20. from modules.utils.misc import date_time_now_utc, time_delta
  21. from fastapi.security import SecurityScopes
  22. from datetime import datetime
  23.  
  24.  
  25.  
  26.  
  27.  
  28. SCHEME_NAME = "OAuthorization2CodePKCEBearer"
  29. DESC = "Authorization code with PKCE "
  30.  
  31.  
  32. class OAuth2CodeBearer(SecurityBase):
  33.  
  34.     def __init__(
  35.         self,
  36.         authorization_url: str,
  37.         token_url: str,
  38.         auth_method: str,
  39.         scopes: dict[str, str],
  40.         flows: OAuthFlowsModel | dict[str, dict[str, Any]] | None = None,
  41.         scheme_name: str | None = SCHEME_NAME,
  42.         description: str | None = DESC,
  43.         refresh_url: str | None = None,
  44.     ):
  45.         self.auth_method = auth_method
  46.  
  47.         # ADD MORE OAUTHFLOWS AS NEEDED
  48.  
  49.         if not flows:
  50.             flows = OAuthFlowsModel(
  51.                 authorizationCode=OAuthFlowAuthorizationCode(
  52.                     authorizationUrl=authorization_url,
  53.                     tokenUrl=token_url,
  54.                     refreshUrl=refresh_url,
  55.                     scopes=scopes,
  56.                 ),
  57.             )
  58.         self.model = OAuth2Model(
  59.             flows=cast(OAuthFlowsModel, flows), description=description
  60.         )
  61.         self.scheme_name = (
  62.             f"{auth_method.capitalize()}{scheme_name}" or self.__class__.__name__
  63.         )
  64.  
  65.         self.auth_method = auth_method
  66.         # A cache for Microsoft public keys {'LOCAL': [], 'MSAL': []}
  67.         self.public_keys_cache: dict[str, list] = {
  68.             method: [] for method in cfg.auth_methods
  69.         }
  70.         self.next_ext_api_call_time: datetime | None = None
  71.  
  72.     async def __call__(
  73.         self, security_scopes: SecurityScopes, request: Request
  74.     ) -> dict[str, Any] | None:
  75.         authorization = request.headers.get("Authorization", None)
  76.         scheme, token = get_authorization_scheme_param(authorization)
  77.         if not (authorization and scheme and token):
  78.             raise HTTPException(
  79.                 status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
  80.             )
  81.         if scheme.lower() != "bearer":
  82.             raise HTTPException(
  83.                 status_code=HTTP_401_UNAUTHORIZED,
  84.                 detail="Invalid authentication credentials",
  85.             )
  86.  
  87.         if self.auth_method == "MSAL":
  88.             verified_claims = await self.verify_msal_jwt(
  89.                 token, security_scopes.scopes, self.auth_method
  90.             )
  91.         else:
  92.             verified_claims = await self.verify_google_jwt(
  93.                 token,
  94.                 security_scopes.scopes,
  95.             )
  96.         return verified_claims
  97.  
  98.     async def verify_google_jwt(
  99.         self,
  100.         access_token: str,
  101.         required_scopes: list[str],
  102.     ) -> dict:
  103.         if not access_token:
  104.             raise HTTPException(
  105.                 status_code=HTTP_401_UNAUTHORIZED,
  106.                 detail="Authorization token missing or invalid",
  107.             )
  108.         try:
  109.             TOKEN_INFO_URL = cfg.google_token_info_url
  110.             PARAMS = {"access_token": access_token}
  111.             async with AsyncClient(timeout=10) as client:
  112.                 cfg.logger.debug(f"Fetching token info from {TOKEN_INFO_URL}")
  113.                 response: Response = await client.get(
  114.                     TOKEN_INFO_URL,
  115.                     params=PARAMS,
  116.                 )
  117.                 response.raise_for_status()
  118.                 token_info: dict[str, Any] = response.json()
  119.                
  120.             token_info["scp"] = token_info.pop("scope")
  121.  
  122.             # check scope
  123.             self.validate_scope(token_info, required_scopes)
  124.  
  125.             # check audience
  126.             if token_info["aud"] not in cfg.google_client_id:
  127.                 raise ValueError("Could not verify audience.")
  128.  
  129.             return token_info
  130.         except HTTPError as e:
  131.             raise HTTPException(
  132.                 status_code=HTTP_400_BAD_REQUEST,
  133.                 detail="Invalid or expired token",
  134.             )
  135.         except ValueError as e:
  136.             cfg.logger.error(f"Could not verify audience: {e}")
  137.             raise HTTPException(
  138.                 status_code=HTTP_401_UNAUTHORIZED,
  139.                 detail="Could not verify audience",
  140.             )
  141.             return None
  142.         except Exception as e:
  143.             cfg.logger.error(f"Internal server error: {str(e)}")
  144.             raise HTTPException(
  145.                 status_code=HTTP_401_UNAUTHORIZED,
  146.                 detail="Token error: Unable to parse authentication",
  147.             )
  148.  
  149.     # Validate Azure Entra ID token using Azure AD Public Keys
  150.     async def verify_msal_jwt(
  151.         self, access_token: str, required_scopes: list[str], auth_method: str
  152.     ) -> dict:
  153.         """
  154.        This verifies:
  155.  
  156.        # Scopes
  157.  
  158.        # Signature using Azure AD’s public key
  159.  
  160.        # Expiration (exp)
  161.  
  162.        # Issuer (iss)
  163.  
  164.        # Audience (aud)
  165.  
  166.        """
  167.         if not access_token:
  168.             raise HTTPException(
  169.                 status_code=HTTP_401_UNAUTHORIZED,
  170.                 detail="Authorization token missing or invalid",
  171.             )
  172.         try:
  173.             unverified_claims: dict[str, Any] = jwt.get_unverified_claims(
  174.                 access_token,
  175.             )
  176.  
  177.             self.validate_scope(unverified_claims, required_scopes)
  178.  
  179.             # Get Microsoft's public keys
  180.             public_keys = await self.get_public_keys(
  181.                 cfg.msal_jwks_url,
  182.                 auth_method,
  183.             )
  184.             # Decode JWT Header to get the key ID (kid)
  185.             token_headers: dict[str, Any] = jwt.get_unverified_header(
  186.                 access_token,
  187.             )
  188.  
  189.             token_kid = token_headers.get("kid")
  190.  
  191.             rsa_key = next(
  192.                 (key for key in public_keys if key.get("kid") == token_kid), None
  193.             )
  194.             if rsa_key is None:
  195.                 raise HTTPException(
  196.                     status_code=HTTP_401_UNAUTHORIZED,
  197.                     detail="Invalid header error: Unable to find appropriate key",
  198.                 )
  199.             cfg.logger.debug(f"Loading public key: {rsa_key}")
  200.             claims = jwt.decode(
  201.                 access_token,
  202.                 key=rsa_key,
  203.                 algorithms=["RS256"],
  204.                 audience=cfg.msal_client_id,
  205.                 issuer=cfg.msal_issuer,
  206.             )
  207.  
  208.             return claims
  209.         except HTTPError as e:
  210.             cfg.logger.error(f"HTTP Exception for {e.request.url} - {e}")
  211.             raise HTTPException(
  212.                 status_code=HTTP_404_NOT_FOUND,
  213.                 detail=f"HTTP Exception for {e.request.url} - {e}",
  214.             )
  215.         except JWTError:
  216.             cfg.logger.error("Invalid token or expired token.")
  217.             raise HTTPException(
  218.                 status_code=HTTP_401_UNAUTHORIZED,
  219.                 detail="Invalid token or expired token.",
  220.             )
  221.         except Exception as e:
  222.             cfg.logger.error(f"Internal server error: {str(e)}")
  223.             raise HTTPException(
  224.                 status_code=HTTP_401_UNAUTHORIZED,
  225.                 detail="Token error: Unable to parse authentication",
  226.             )
  227.  
  228.     # check if guest user is allowed?
  229.     def validate_scope(self, unverified_claims: dict, required_scopes: list[str]):
  230.         if not required_scopes:
  231.             raise HTTPException(
  232.                 status_code=HTTP_403_FORBIDDEN,
  233.                 detail="No required scope specified",
  234.             )
  235.         # To small letters
  236.         required_scopes = [s.lower() for s in required_scopes]
  237.  
  238.         has_valid_scope = False
  239.  
  240.         if (
  241.             unverified_claims.get("scp") is None
  242.             and unverified_claims.get("roles") is None
  243.         ):
  244.             raise HTTPException(
  245.                 status_code=HTTP_403_FORBIDDEN,
  246.                 detail="No scope or app permission (role) claim was found in the bearer token",
  247.             )
  248.  
  249.         is_app_permission = (
  250.             True if unverified_claims.get("roles") is not None else False
  251.         )
  252.  
  253.         if is_app_permission:
  254.             roles = unverified_claims.get("roles", [])
  255.             if not roles:
  256.                 raise HTTPException(
  257.                     status_code=HTTP_403_FORBIDDEN,
  258.                     detail="No scope or app permission (role) claim was found in the bearer token",
  259.                 )
  260.             else:
  261.                 roles = [s.lower() for s in roles]
  262.                 matches = set(required_scopes).intersection(set(roles))
  263.                 if len(matches) > 0:
  264.                     has_valid_scope = True
  265.         else:
  266.             if unverified_claims.get("scp"):
  267.                 # the scp claim is a space delimited string
  268.                 token_scopes = unverified_claims["scp"].lower().split()
  269.                 matches = set(required_scopes).intersection(set(token_scopes))
  270.                 if len(matches) > 0:
  271.                     has_valid_scope = True
  272.             else:
  273.                 raise HTTPException(
  274.                     status_code=HTTP_403_FORBIDDEN,
  275.                     detail="No scope or app permission (role) claim was found in the bearer token",
  276.                 )
  277.         if is_app_permission and not has_valid_scope:
  278.             raise HTTPException(
  279.                 status_code=HTTP_403_FORBIDDEN, detail="Not enough permissions"
  280.             )
  281.         elif not has_valid_scope:
  282.             raise HTTPException(
  283.                 status_code=HTTP_403_FORBIDDEN, detail="Not enough permissions"
  284.             )
  285.  
  286.     async def get_public_keys(
  287.         self, jwks_uri: str, auth_method: str, params: dict | None = None
  288.     ) -> list:
  289.         make_api_call = (
  290.             self.next_ext_api_call_time is None
  291.             or date_time_now_utc() > self.next_ext_api_call_time
  292.         )
  293.         if not self.public_keys_cache[auth_method] or make_api_call:
  294.             async with AsyncClient(timeout=10) as client:
  295.                 cfg.logger.debug(f"Fetching public keys from {jwks_uri}")
  296.                 response: Response = await client.get(jwks_uri, params=params)
  297.                 response.raise_for_status()  # Raises an error for non-200 responses
  298.                 self.public_keys_cache[auth_method] = response.json().get("keys", [])
  299.                 self.next_ext_api_call_time = date_time_now_utc() + time_delta(
  300.                     minutes=60
  301.                 )  # Fetch keys every 1hr
  302.         return self.public_keys_cache[auth_method]
  303.  
Advertisement
Add Comment
Please, Sign In to add comment