1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
 
"""Outlier/anomaly detection utilities for MySQL Connector/Python.
 
Provides a scikit-learn compatible wrapper using HeatWave to score anomalies.
"""
from typing import Optional, Union
 
import numpy as np
import pandas as pd
from sklearn.base import OutlierMixin
 
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
 
from mysql.connector.abstracts import MySQLConnectionAbstract
 
EPS = 1e-5
 
 
def _get_logits(prob: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """
    Compute logit (logodds) for a probability, clipping to avoid numerical overflow.
 
    Args:
        prob: Scalar or array of probability values in (0,1).
 
    Returns:
        logit-transformed probabilities.
    """
    result = np.clip(prob, EPS, 1 - EPS)
    return np.log(result / (1 - result))
 
 
class MyAnomalyDetector(MyBaseMLModel, OutlierMixin):
    """
    MySQL HeatWave scikit-learn compatible anomaly/outlier detector.
 
    Flags samples as outliers when the probability of being an anomaly
    exceeds a user-tunable threshold.
    Includes helpers to obtain decision scores and anomaly probabilities
    for ranking.
 
    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL DB connection.
        model_name (str, optional): Custom model name in the database.
        fit_extra_options (dict, optional): Extra options for fitting.
        score_extra_options (dict, optional): Extra options for scoring/prediction.
 
    Attributes:
        boundary: Decision threshold boundary in logit space. Derived from
            trained model's catalog info
 
    Methods:
        predict(X): Predict outlier/inlier labels.
        score_samples(X): Compute anomaly (normal class) logit scores.
        decision_function(X): Compute signed score above/below threshold for ranking.
    """
 
    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        score_extra_options: Optional[dict] = None,
    ):
        """
        Initialize an anomaly detector instance with threshold and extra options.
 
        Args:
            db_connection: Active MySQL DB connection.
            model_name: Optional model name in DB.
            fit_extra_options: Optional extra fit options.
            score_extra_options: Optional extra scoring options.
 
        Raises:
            ValueError: If outlier_threshold is not in (0,1).
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.ANOMALY_DETECTION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        self.score_extra_options = copy_dict(score_extra_options)
        self.boundary: Optional[float] = None
 
    def predict(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Predict outlier/inlier binary labels for input samples.
 
        Args:
            X: Samples to predict on.
 
        Returns:
            ndarray: Values are -1 for outliers, +1 for inliers, as per scikit-learn convention.
 
        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return np.where(self.decision_function(X) < 0.0, -1, 1)
 
    def decision_function(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Compute signed distance to the outlier threshold.
 
        Args:
            X: Samples to predict on.
 
        Returns:
            ndarray: Score > 0 means inlier, < 0 means outlier; |value| gives margin.
 
        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If the provided model info does not provide threshold
        """
        sample_scores = self.score_samples(X)
 
        if self.boundary is None:
            model_info = self.get_model_info()
            if model_info is None:
                raise ValueError("Model does not exist in catalog.")
 
            threshold = model_info["model_metadata"]["training_params"].get(
                "anomaly_detection_threshold", None
            )
            if threshold is None:
                raise ValueError(
                    "Trained model is outdated and does not support threshold. "
                    "Try retraining or using an existing, trained model with MyModel."
                )
 
            # scikit-learn uses large positive values as inlier
            # and negative as outlier, so we need to flip our threshold
            self.boundary = _get_logits(1.0 - threshold)
 
        return sample_scores - self.boundary
 
    def score_samples(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Compute normal probability logit score for each sample.
        Used for ranking, thresholding.
 
        Args:
            X: Samples to score.
 
        Returns:
            ndarray: Logit scores based on "normal" class probability.
 
        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.score_extra_options)
 
        return _get_logits(
            result["ml_results"]
            .apply(lambda x: x["probabilities"]["normal"])
            .to_numpy()
        )