1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import time
from abc import ABC, abstractmethod
 
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.multidb.database import Databases, SyncDatabase
from redis.multidb.exception import (
    NoValidDatabaseException,
    TemporaryUnavailableException,
)
 
# Default number of failover attempts DefaultFailoverStrategyExecutor counts
# before it stops raising TemporaryUnavailableException and propagates the
# strategy's NoValidDatabaseException instead.
DEFAULT_FAILOVER_ATTEMPTS = 10
# Default delay, in seconds, between two counted failover attempts.
DEFAULT_FAILOVER_DELAY = 12
 
 
class FailoverStrategy(ABC):
    """Interface for policies that choose which database to communicate with."""

    @abstractmethod
    def database(self) -> SyncDatabase:
        """Return the database picked by this strategy."""
        ...

    @abstractmethod
    def set_databases(self, databases: Databases) -> None:
        """Replace the collection of databases this strategy chooses from."""
        ...
 
 
class FailoverStrategyExecutor(ABC):
    """
    Interface for objects that run a FailoverStrategy using a configured
    number of attempts and a delay between them.
    """

    @property
    @abstractmethod
    def failover_attempts(self) -> int:
        """How many failover attempts are made."""
        ...

    @property
    @abstractmethod
    def failover_delay(self) -> float:
        """How long to wait between consecutive failover attempts."""
        ...

    @property
    @abstractmethod
    def strategy(self) -> FailoverStrategy:
        """The FailoverStrategy this executor runs."""
        ...

    @abstractmethod
    def execute(self) -> SyncDatabase:
        """Run the strategy and return the database it selects."""
        ...
 
 
class WeightBasedFailoverStrategy(FailoverStrategy):
    """
    Failover strategy that returns the first database, in the iteration order
    of its weighted collection, whose circuit breaker is CLOSED.
    """

    def __init__(self) -> None:
        # Empty until set_databases() supplies the real collection.
        self._databases = WeightedList()

    def database(self) -> SyncDatabase:
        """
        Return the first database whose circuit is CLOSED.

        Entries are visited in the order the collection yields them
        (presumably highest weight first for a WeightedList - TODO confirm).

        :raises NoValidDatabaseException: if no database has a CLOSED circuit.
        """
        for candidate, _weight in self._databases:
            if candidate.circuit.state != CBState.CLOSED:
                continue
            return candidate

        raise NoValidDatabaseException("No valid database available for communication")

    def set_databases(self, databases: Databases) -> None:
        """Replace the databases this strategy selects from."""
        self._databases = databases
 
 
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
    """
    Executes a given failover strategy, tolerating a bounded period without
    an available database.

    While the strategy raises NoValidDatabaseException, callers receive
    TemporaryUnavailableException so they can retry. One failover attempt is
    counted on the first failure, then at most once per ``failover_delay``
    seconds on a fixed schedule anchored at that first failure. Once more than
    ``failover_attempts`` attempts have been counted, the strategy's
    NoValidDatabaseException is propagated and the internal state is reset.
    A successful selection also resets the state.
    """

    def __init__(
        self,
        strategy: FailoverStrategy,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
    ):
        """
        :param strategy: Strategy used to select the database.
        :param failover_attempts: Number of attempts to count before giving up.
        :param failover_delay: Seconds between two counted attempts.
        """
        self._strategy = strategy
        self._failover_attempts = failover_attempts
        self._failover_delay = failover_delay
        # time.monotonic() value at which the next attempt will be counted;
        # only meaningful while _failover_counter > 0.
        self._next_attempt_ts: float = 0
        # Attempts counted since the last success/reset; 0 means no failover
        # is in progress.
        self._failover_counter: int = 0

    @property
    def failover_attempts(self) -> int:
        return self._failover_attempts

    @property
    def failover_delay(self) -> float:
        return self._failover_delay

    @property
    def strategy(self) -> FailoverStrategy:
        return self._strategy

    def execute(self) -> SyncDatabase:
        """
        Ask the strategy for a database.

        :returns: The database selected by the strategy.
        :raises TemporaryUnavailableException: while the strategy has no valid
            database but the attempt budget is not yet exhausted.
        :raises NoValidDatabaseException: once more than ``failover_attempts``
            attempts have been counted.
        """
        try:
            database = self._strategy.database()
        except NoValidDatabaseException as e:
            self._register_failure()

            if self._failover_counter > self._failover_attempts:
                self._reset()
                # Bare raise re-raises the strategy's exception unchanged.
                raise

            raise TemporaryUnavailableException(
                "No database connections currently available. "
                "This is a temporary condition - please retry the operation."
            ) from e

        self._reset()
        return database

    def _register_failure(self) -> None:
        """Count a failover attempt if one is due (first failure, or schedule reached)."""
        # Monotonic clock: interval measurement must not be affected by
        # wall-clock adjustments (NTP steps, manual changes).
        now = time.monotonic()

        if self._failover_counter == 0:
            # First failure: start the attempt schedule.
            self._next_attempt_ts = now + self._failover_delay
            self._failover_counter = 1
        elif now >= self._next_attempt_ts:
            # Advance by a fixed step (not from `now`) so the overall deadline
            # stays anchored at the first failure regardless of call pattern.
            self._next_attempt_ts += self._failover_delay
            self._failover_counter += 1

    def _reset(self) -> None:
        """Clear failover bookkeeping (no failover in progress)."""
        self._next_attempt_ts = 0
        self._failover_counter = 0