hyb
2026-01-30 44480e71b27aa9d4cb8441f50c873f1b110e9691
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Tuple, Union
 
from redis.asyncio import Redis
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
from redis.backoff import NoBackoff
from redis.http.http_client import HttpClient
from redis.multidb.exception import UnhealthyDatabaseException
from redis.retry import Retry
 
# Module-level defaults for health checking.
# NOTE(review): the first three constants are not referenced in the visible
# part of this file; judging by their names they are the default probe count,
# check interval and inter-probe delay (presumably seconds) -- confirm against
# the code that consumes them.
DEFAULT_HEALTH_CHECK_PROBES = 3
DEFAULT_HEALTH_CHECK_INTERVAL = 5
DEFAULT_HEALTH_CHECK_DELAY = 0.5
# Default for LagAwareHealthCheck's ``lag_aware_tolerance``; that value is sent
# to the REST API as ``availability_lag_tolerance_ms``, i.e. milliseconds.
DEFAULT_LAG_AWARE_TOLERANCE = 5000

# Logger shared by the health checks in this module.
logger = logging.getLogger(__name__)
 
 
class HealthCheck(ABC):
    """Interface for a single asynchronous health probe of a database."""

    @abstractmethod
    async def check_health(self, database) -> bool:
        """Probe ``database`` and report its health status.

        Args:
            database: The database object to probe.

        Returns:
            A truthy value when the database is considered healthy,
            a falsy one otherwise.
        """
 
 
class HealthCheckPolicy(ABC):
    """
    Interface describing how a set of health checks is executed and how
    their probe results are combined into one health verdict.
    """

    @property
    @abstractmethod
    def health_check_probes(self) -> int:
        """Return how many probes are executed per health check."""

    @property
    @abstractmethod
    def health_check_delay(self) -> float:
        """Return the delay applied between consecutive probes."""

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Run ``health_checks`` against ``database``.

        Args:
            health_checks: Health checks to execute.
            database: The database being evaluated.

        Returns:
            The resulting health status of the database.
        """
 
 
class AbstractHealthCheckPolicy(HealthCheckPolicy):
    """
    Base class for policies: validates and stores the probe count and the
    delay between probes, leaving ``execute`` to subclasses.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        """Store the policy settings.

        Args:
            health_check_probes: Number of probes per health check.
            health_check_delay: Delay between probes.

        Raises:
            ValueError: If ``health_check_probes`` is less than 1.
        """
        # Reject invalid probe counts before any state is stored.
        if health_check_probes < 1:
            raise ValueError("health_check_probes must be greater than 0")

        self._health_check_probes, self._health_check_delay = (
            health_check_probes,
            health_check_delay,
        )

    @property
    def health_check_probes(self) -> int:
        """Number of probes executed per health check."""
        return self._health_check_probes

    @property
    def health_check_delay(self) -> float:
        """Delay between consecutive probes."""
        return self._health_check_delay

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Execute the health checks; implemented by concrete policies."""
 
 
class HealthyAllPolicy(AbstractHealthCheckPolicy):
    """
    Policy under which the database is healthy only when every probe of
    every health check succeeds.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        """Forward the probe count and delay to the base policy."""
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Probe each health check ``health_check_probes`` times.

        Args:
            health_checks: Health checks to run, in order.
            database: The database being evaluated.

        Returns:
            False as soon as one probe reports unhealthy; True when all
            probes of all checks passed (also for an empty list).

        Raises:
            UnhealthyDatabaseException: As soon as a probe raises; the
                original error is passed along to the exception.
        """
        probe_count = self.health_check_probes
        final_probe = probe_count - 1

        for check in health_checks:
            for probe_no in range(probe_count):
                try:
                    passed = bool(await check.check_health(database))
                except Exception as err:
                    raise UnhealthyDatabaseException("Unhealthy database", database, err)
                else:
                    if not passed:
                        return False

                # Pause between probes, but not after the last one.
                if probe_no < final_probe:
                    await asyncio.sleep(self._health_check_delay)

        return True
 
 
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
    """
    Policy under which the database is healthy when, for each health check,
    a majority of its probes succeed.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        """Forward the probe count and delay to the base policy."""
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Probe each health check, tolerating a minority of failed probes.

        A probe counts as failed when it returns a falsy value or raises.
        Once the failures of one check reach half of the probe count
        (rounded up), the check is lost.

        Args:
            health_checks: Health checks to run, in order.
            database: The database being evaluated.

        Returns:
            False when the deciding failure was a falsy probe result;
            True when no check exhausted its failure budget.

        Raises:
            UnhealthyDatabaseException: When the deciding failure was an
                exception raised by the probe.
        """
        probe_count = self.health_check_probes
        # n/2 for even n and (n+1)/2 for odd n, expressed as one integer formula.
        failure_budget = (probe_count + 1) // 2
        final_probe = probe_count - 1

        for check in health_checks:
            # Every health check starts with a fresh failure budget.
            failures_left = failure_budget

            for probe_no in range(probe_count):
                try:
                    passed = bool(await check.check_health(database))
                except Exception as err:
                    failures_left -= 1
                    if failures_left <= 0:
                        raise UnhealthyDatabaseException(
                            "Unhealthy database", database, err
                        )
                else:
                    if not passed:
                        failures_left -= 1
                        if failures_left <= 0:
                            return False

                # Pause between probes, but not after the last one.
                if probe_no < final_probe:
                    await asyncio.sleep(self._health_check_delay)

        return True
 
 
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if at least one health check probe is successful.

    The rule is applied per health check: every health check in the list
    needs at least one successful probe for the database to be healthy.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        """Forward the probe count and delay to the base policy."""
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Probe each health check until one probe succeeds.

        Args:
            health_checks: Health checks to run, in order.
            database: The database being evaluated.

        Returns:
            True when every health check had at least one successful probe.
            False when a health check had no successful probe and none of
            its probes raised, or when ``health_checks`` is empty.

        Raises:
            UnhealthyDatabaseException: When a health check had no
                successful probe and at least one of its probes raised;
                the exception wraps the most recent probe error.
        """
        # Kept outside the loop so that an empty list still yields False.
        is_healthy = False

        for health_check in health_checks:
            # Reset per health check. Without this, a success recorded by an
            # earlier check would mask a later check whose probes all raised.
            is_healthy = False
            exception = None

            for attempt in range(self.health_check_probes):
                try:
                    if await health_check.check_health(database):
                        is_healthy = True
                        break
                except Exception as e:
                    # Remember the failure; a later probe may still succeed.
                    exception = UnhealthyDatabaseException(
                        "Unhealthy database", database, e
                    )

                # Pause between probes, but not after the last one.
                if attempt < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)

            if not is_healthy:
                if exception is not None:
                    raise exception
                return False

        return is_healthy
 
 
class HealthCheckPolicies(Enum):
    """Available health check policies; each member's value is the policy class."""

    # All probes must succeed.
    HEALTHY_ALL = HealthyAllPolicy
    # A majority of probes must succeed.
    HEALTHY_MAJORITY = HealthyMajorityPolicy
    # At least one probe must succeed.
    HEALTHY_ANY = HealthyAnyPolicy
 
 
# Default policy member: HEALTHY_ALL, whose value is the HealthyAllPolicy class.
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
 
 
class PingHealthCheck(HealthCheck):
    """
    Health check that issues the PING command against the database client.
    """

    async def check_health(self, database) -> bool:
        """PING the database and report whether it answered.

        Args:
            database: Object whose ``client`` is either a ``Redis`` instance
                or a client exposing ``get_nodes()``.

        Returns:
            For a ``Redis`` client, the result of the PING command. For any
            other client, True only if every node returned by
            ``get_nodes()`` gave a truthy PING reply.
        """
        client = database.client

        if not isinstance(client, Redis):
            # Cluster case: one node with a falsy reply fails the whole check.
            for cluster_node in client.get_nodes():
                reply = await cluster_node.redis_connection.execute_command("PING")
                if not reply:
                    return False
            return True

        return await client.execute_command("PING")
 
 
class LagAwareHealthCheck(HealthCheck):
    """
    Health check available for Redis Enterprise deployments.
    Verify via REST API that the database is healthy based on different lags.
    """

    def __init__(
        self,
        rest_api_port: int = 9443,
        lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
        timeout: float = DEFAULT_TIMEOUT,
        auth_basic: Optional[Tuple[str, str]] = None,
        verify_tls: bool = True,
        # TLS verification (server) options
        ca_file: Optional[str] = None,
        ca_path: Optional[str] = None,
        ca_data: Optional[Union[str, bytes]] = None,
        # Mutual TLS (client cert) options
        client_cert_file: Optional[str] = None,
        client_key_file: Optional[str] = None,
        client_key_password: Optional[str] = None,
    ):
        """
        Initialize LagAwareHealthCheck with the specified parameters.

        Args:
            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerance in lag between databases in MS
                (default: DEFAULT_LAG_AWARE_TOLERANCE, i.e. 5000)
            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
            auth_basic: Tuple of (username, password) for basic authentication
            verify_tls: Whether to verify TLS certificates (default: True)
            ca_file: Path to CA certificate file for TLS verification
            ca_path: Path to CA certificates directory for TLS verification
            ca_data: CA certificate data as string or bytes
            client_cert_file: Path to client certificate file for mutual TLS
            client_key_file: Path to client private key file for mutual TLS
            client_key_password: Password for encrypted client private key
        """
        # The HTTP client is configured with zero retries.
        self._http_client = AsyncHTTPClientWrapper(
            HttpClient(
                timeout=timeout,
                auth_basic=auth_basic,
                retry=Retry(NoBackoff(), retries=0),
                verify_tls=verify_tls,
                ca_file=ca_file,
                ca_path=ca_path,
                ca_data=ca_data,
                client_cert_file=client_cert_file,
                client_key_file=client_key_file,
                client_key_password=client_key_password,
            )
        )
        self._rest_api_port = rest_api_port
        self._lag_aware_tolerance = lag_aware_tolerance

    @staticmethod
    def _find_matching_bdb(bdbs, db_host):
        """Return the first bdb with an endpoint matching ``db_host``, or None.

        An endpoint matches when its ``dns_name`` equals ``db_host`` or when
        one of the entries in its ``addr`` equals ``db_host`` (covers the
        case where the host was configured as a public IP).
        """
        for bdb in bdbs:
            for endpoint in bdb["endpoints"]:
                if endpoint["dns_name"] == db_host:
                    return bdb
                if any(addr == db_host for addr in endpoint["addr"]):
                    return bdb
        return None

    async def check_health(self, database) -> bool:
        """Check the database's availability, including lag, via the REST API.

        Args:
            database: Object exposing ``health_check_url`` and ``client``.

        Returns:
            True when the availability request completes without raising.

        Raises:
            ValueError: If ``database.health_check_url`` is None, or if no
                bdb returned by ``/v1/bdbs`` matches the database host.
        """
        if database.health_check_url is None:
            raise ValueError(
                "Database health check url is not set. Please check DatabaseConfig for the current database."
            )

        if isinstance(database.client, Redis):
            db_host = database.client.get_connection_kwargs()["host"]
        else:
            db_host = database.client.startup_nodes[0].host

        base_url = f"{database.health_check_url}:{self._rest_api_port}"
        # NOTE(review): base_url is stored on the shared HTTP client. If one
        # instance checks several databases concurrently, the requests below
        # could presumably be sent to another database's URL -- confirm.
        self._http_client.client.base_url = base_url

        # Find bdb matching to the current database host
        bdbs = await self._http_client.get("/v1/bdbs")
        matching_bdb = self._find_matching_bdb(bdbs, db_host)

        if matching_bdb is None:
            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
            raise ValueError("Could not find a matching bdb")

        url = (
            f"/v1/bdbs/{matching_bdb['uid']}/availability"
            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
        )
        await self._http_client.get(url, expect_json=False)

        # Status checked in an http client, otherwise HttpError will be raised
        return True