1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from abc import ABC, abstractmethod
from datetime import datetime, timezone
 
from redis.auth.err import InvalidTokenSchemaErr
 
 
class TokenInterface(ABC):
    """Abstract contract shared by the token classes in this module.

    Every method below is abstract, so a subclass has to implement all of
    them before it can be instantiated.
    """

    @abstractmethod
    def is_expired(self) -> bool:
        """Return ``True`` once the token is no longer valid."""

    @abstractmethod
    def ttl(self) -> float:
        """Return the remaining lifetime of the token.

        The implementations in this module report milliseconds and use
        ``-1`` for a token that never expires.
        """

    @abstractmethod
    def try_get(self, key: str) -> str:
        """Look up ``key`` among the data attached to the token.

        The implementations in this module return ``None`` when the key is
        absent.
        """

    @abstractmethod
    def get_value(self) -> str:
        """Return the raw token value."""

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp in milliseconds."""

    @abstractmethod
    def get_received_at_ms(self) -> float:
        """Return the reception timestamp in milliseconds."""
 
 
class TokenResponse:
    """Thin wrapper around a token that exposes its overall lifetime."""

    def __init__(self, token: TokenInterface):
        """Store ``token`` for later retrieval."""
        self._token = token

    def get_token(self) -> TokenInterface:
        """Return the wrapped token object unchanged."""
        return self._token

    def get_ttl_ms(self) -> float:
        """Return the token's expiry timestamp minus its reception timestamp.

        Both values come from the wrapped token and are expressed in
        milliseconds, so the result is a duration in milliseconds.
        """
        wrapped = self._token
        expires_at_ms = wrapped.get_expires_at_ms()
        received_at_ms = wrapped.get_received_at_ms()
        return expires_at_ms - received_at_ms
 
 
class SimpleToken(TokenInterface):
    """Token whose value, timestamps and claims are supplied by the caller.

    An ``expires_at_ms`` of ``-1`` marks a token that never expires.
    """

    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        """Store the given token data as-is; nothing is validated here."""
        self.value = value
        self.expires_at = expires_at_ms
        self.received_at = received_at_ms
        self.claims = claims

    def ttl(self) -> float:
        """Return the milliseconds left until expiry, or ``-1`` if it never expires.

        The result is zero or negative once the expiry time has passed.
        """
        deadline_ms = self.expires_at
        if deadline_ms != -1:
            now_ms = datetime.now(timezone.utc).timestamp() * 1000
            return deadline_ms - now_ms

        return -1

    def is_expired(self) -> bool:
        """Return ``True`` when an expiring token has no lifetime left."""
        # A never-expiring token short-circuits before the clock is read.
        return self.expires_at != -1 and self.ttl() <= 0

    def try_get(self, key: str) -> str:
        """Return the claim stored under ``key``, or ``None`` when absent."""
        claims = self.claims
        return claims.get(key)

    def get_value(self) -> str:
        """Return the raw token value."""
        return self.value

    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp in milliseconds (``-1`` if none)."""
        return self.expires_at

    def get_received_at_ms(self) -> float:
        """Return the reception timestamp in milliseconds."""
        return self.received_at
 
 
class JWToken(TokenInterface):
    """Token backed by the claims decoded from a JSON Web Token string.

    The ``exp`` claim is mandatory and is read as seconds; all timestamps
    returned by this class are converted to milliseconds. An ``exp`` of
    ``-1`` marks a token that never expires.
    """

    # Claims that must be present in the decoded payload.
    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        """Decode ``token`` and check that the required claims are present.

        Raises:
            ImportError: if the PyJWT library is not installed.
            InvalidTokenSchemaErr: if a required claim is missing.
        """
        # PyJWT is imported lazily so the module loads without it.
        try:
            import jwt
        except ImportError as ie:
            raise ImportError(
                f"The PyJWT library is required for {self.__class__.__name__}.",
            ) from ie
        self._value = token
        # The payload is decoded with signature verification switched off, so
        # the claims must not be treated as authenticated data.
        # NOTE(review): presumably the token is validated elsewhere - confirm.
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[jwt.get_unverified_header(self._value).get("alg")],
        )
        self._validate_token()

    @staticmethod
    def _now_ms() -> float:
        """Return the current UTC time as milliseconds since the epoch."""
        return datetime.now(timezone.utc).timestamp() * 1000

    def is_expired(self) -> bool:
        """Return ``True`` when the ``exp`` claim is at or before the current time."""
        exp = self._decoded["exp"]
        if exp == -1:
            return False

        return exp * 1000 <= self._now_ms()

    def ttl(self) -> float:
        """Return the milliseconds left until expiry, or ``-1`` if it never expires.

        The result is zero or negative once the expiry time has passed.
        """
        exp = self._decoded["exp"]
        if exp == -1:
            return -1

        return exp * 1000 - self._now_ms()

    def try_get(self, key: str) -> str:
        """Return the decoded claim stored under ``key``, or ``None`` when absent."""
        return self._decoded.get(key)

    def get_value(self) -> str:
        """Return the encoded token string this object was built from."""
        return self._value

    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp in milliseconds, or ``-1.0`` if none.

        The ``-1`` "never expires" marker is passed through unscaled so that
        it matches what ``is_expired`` and ``ttl`` recognise, instead of being
        turned into ``-1000.0``.
        """
        exp = self._decoded["exp"]
        if exp == -1:
            return -1.0

        return float(exp * 1000)

    def get_received_at_ms(self) -> float:
        """Return the current time in milliseconds.

        NOTE(review): this is evaluated on every call rather than recorded
        when the token is created, so successive calls return different
        values - confirm that callers expect this.
        """
        return self._now_ms()

    def _validate_token(self):
        """Raise ``InvalidTokenSchemaErr`` with the missing claims, if any."""
        missing = self.REQUIRED_FIELDS.difference(self._decoded)
        if missing:
            raise InvalidTokenSchemaErr(missing)