126 lines
3.1 KiB
Python
126 lines
3.1 KiB
Python
import uuid
|
|
from datetime import datetime
|
|
from datetime import timedelta
|
|
from datetime import timezone
|
|
from hmac import compare_digest
|
|
from json import JSONEncoder
|
|
from typing import Any
|
|
from typing import Iterable
|
|
from typing import List
|
|
from typing import Type
|
|
from typing import Union
|
|
|
|
import jwt
|
|
|
|
from flask_jwt_extended.exceptions import CSRFError
|
|
from flask_jwt_extended.exceptions import JWTDecodeError
|
|
from flask_jwt_extended.typing import ExpiresDelta
|
|
from flask_jwt_extended.typing import Fresh
|
|
|
|
|
|
def _encode_jwt(
    algorithm: str,
    audience: Union[str, Iterable[str]],
    claim_overrides: dict,
    csrf: bool,
    expires_delta: ExpiresDelta,
    fresh: Fresh,
    header_overrides: dict,
    identity: Any,
    identity_claim_key: str,
    issuer: str,
    json_encoder: Type[JSONEncoder],
    secret: str,
    token_type: str,
    nbf: bool,
) -> str:
    """Build the claim set for ``identity`` and sign it with ``jwt.encode``.

    The payload always contains ``fresh``, ``iat`` (the current UTC time),
    a random ``jti``, ``type`` and the identity stored under
    ``identity_claim_key``. Optional claims are added as follows:

    * ``nbf`` -- set to the issue time when ``nbf`` is true.
    * ``csrf`` -- a random UUID4 string when ``csrf`` is true.
    * ``aud`` / ``iss`` -- when ``audience`` / ``issuer`` are truthy.
    * ``exp`` -- issue time plus ``expires_delta``, unless ``expires_delta``
      is ``False`` (or ``None``), which means "no expiry". A zero-length
      delta still yields an ``exp`` claim (an already-expired token) rather
      than silently producing a token that never expires.

    ``claim_overrides`` is merged last, so it may replace any claim above.

    :param algorithm: signing algorithm name passed to ``jwt.encode``.
    :param fresh: stored as-is, except that a ``timedelta`` is converted to
        the POSIX timestamp at which freshness ends.
    :param header_overrides: extra header fields passed to ``jwt.encode``.
    :param json_encoder: encoder class handed to ``jwt.encode`` for the
        payload.
    :param secret: signing key passed to ``jwt.encode``.
    :return: the encoded token, as returned by ``jwt.encode``.
    """
    now = datetime.now(timezone.utc)

    # A relative freshness window becomes an absolute deadline timestamp.
    if isinstance(fresh, timedelta):
        fresh = datetime.timestamp(now + fresh)

    token_data = {
        "fresh": fresh,
        "iat": now,
        "jti": str(uuid.uuid4()),
        "type": token_type,
        identity_claim_key: identity,
    }

    if nbf:
        token_data["nbf"] = now

    if csrf:
        token_data["csrf"] = str(uuid.uuid4())

    if audience:
        token_data["aud"] = audience

    if issuer:
        token_data["iss"] = issuer

    # Compare against the "no expiry" sentinels explicitly: a truthiness test
    # treats timedelta(0) as false and would drop the exp claim entirely.
    if expires_delta is not False and expires_delta is not None:
        token_data["exp"] = now + expires_delta

    if claim_overrides:
        token_data.update(claim_overrides)

    return jwt.encode(
        token_data,
        secret,
        algorithm,
        json_encoder=json_encoder,  # type: ignore
        headers=header_overrides,
    )
|
|
|
|
|
|
def _decode_jwt(
    algorithms: List,
    allow_expired: bool,
    audience: Union[str, Iterable[str]],
    csrf_value: str,
    encoded_token: str,
    identity_claim_key: str,
    issuer: str,
    leeway: int,
    secret: str,
    verify_aud: bool,
    verify_sub: bool,
) -> dict:
    """Decode and validate ``encoded_token``, then normalise its claims.

    Steps:

    1. ``jwt.decode`` checks the signature and registered claims, with
       audience, subject and (unless ``allow_expired``) expiry checking
       controlled through ``options``.
    2. The identity claim must be present.
    3. Missing ``type``, ``fresh`` and ``jti`` claims are backfilled with
       ``"access"``, ``False`` and ``None``.
    4. If ``csrf_value`` is truthy, it must match the token's ``csrf`` claim
       (double-submit check).

    :param algorithms: algorithms accepted by ``jwt.decode``.
    :param allow_expired: when true, expiry verification is disabled.
    :param csrf_value: CSRF value submitted with the request; falsy skips
        the CSRF check.
    :param leeway: clock-skew allowance passed to ``jwt.decode``.
    :return: the decoded claims dict, with defaults filled in.
    :raises JWTDecodeError: if the identity claim is missing, or if
        ``csrf_value`` is given but the token has no ``csrf`` claim.
    :raises CSRFError: if the submitted CSRF value does not match the claim.

    Exceptions raised by ``jwt.decode`` itself propagate unchanged.
    """
    options = {"verify_aud": verify_aud, "verify_sub": verify_sub}
    if allow_expired:
        options["verify_exp"] = False

    # Signature plus registered-claim validation (exp/nbf/iat, and aud/iss/sub
    # as configured) happens here.
    decoded_token = jwt.decode(
        encoded_token,
        secret,
        algorithms=algorithms,
        audience=audience,
        issuer=issuer,
        leeway=leeway,
        options=options,
    )

    # Make sure that any custom claims we expect in the token are present.
    if identity_claim_key not in decoded_token:
        raise JWTDecodeError(f"Missing claim: {identity_claim_key}")

    # Tokens minted without these claims get the library defaults.
    decoded_token.setdefault("type", "access")
    decoded_token.setdefault("fresh", False)
    decoded_token.setdefault("jti", None)

    if csrf_value:
        if "csrf" not in decoded_token:
            raise JWTDecodeError("Missing claim: csrf")
        if not _csrf_values_match(decoded_token["csrf"], csrf_value):
            raise CSRFError("CSRF double submit tokens do not match")

    return decoded_token


def _csrf_values_match(token_csrf: Any, submitted_csrf: Any) -> bool:
    """Compare two CSRF values in constant time; non-strings never match."""
    # hmac.compare_digest raises TypeError for str arguments containing
    # non-ASCII characters. The submitted value is request-controlled, so
    # compare UTF-8 bytes instead: odd input becomes a mismatch, not a crash.
    token_bytes = _csrf_as_bytes(token_csrf)
    submitted_bytes = _csrf_as_bytes(submitted_csrf)
    if token_bytes is None or submitted_bytes is None:
        return False
    return compare_digest(token_bytes, submitted_bytes)


def _csrf_as_bytes(value: Any) -> Union[bytes, None]:
    """Return ``value`` as bytes for comparison, or None if it is not text."""
    if isinstance(value, str):
        # surrogatepass: never fail to encode, even on lone surrogates.
        return value.encode("utf-8", "surrogatepass")
    if isinstance(value, bytes):
        return value
    return None
|