hyb
2026-01-07 c7f60dc7e9a36596f0e0d1787bd0cca4e9b57bcb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
 
# mypy: disable-error-code="attr-defined"
# pylint: disable=protected-access
 
"""Utilities."""
 
__all__ = ["to_thread", "open_connection"]
 
import asyncio
import contextvars
import functools
 
try:
    import ssl
except ImportError:
    ssl = None
 
from typing import TYPE_CHECKING, Any, Callable, Tuple
 
if TYPE_CHECKING:
    from mysql.connector.aio.abstracts import MySQLConnectionAbstract
 
    __all__.append("StreamWriter")
 
 
class StreamReaderProtocol(asyncio.StreamReaderProtocol):
    """Extends asyncio.streams.StreamReaderProtocol for adding start_tls().
 
    The ``start_tls()`` is based on ``asyncio.streams.StreamWriter`` introduced
    in Python 3.11. It provides the same functionality for older Python versions.
    """
 
    def _replace_writer(self, writer: asyncio.StreamWriter) -> None:
        """Replace stream writer.
 
        Args:
            writer: Stream Writer.
        """
        transport = writer.transport
        self._stream_writer = writer
        self._transport = transport
        self._over_ssl = transport.get_extra_info("sslcontext") is not None
 
 
class StreamWriter(asyncio.streams.StreamWriter):
    """Extends asyncio.streams.StreamWriter for adding start_tls().
 
    The ``start_tls()`` is based on ``asyncio.streams.StreamWriter`` introduced
    in Python 3.11. It provides the same functionality for older Python versions.
    """
 
    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        *,
        server_hostname: str = None,
        ssl_handshake_timeout: int = None,
    ) -> None:
        """Upgrade an existing stream-based connection to TLS.
 
        Args:
            ssl_context: Configured SSL context.
            server_hostname: Server host name.
            ssl_handshake_timeout: SSL handshake timeout.
        """
        server_side = self._protocol._client_connected_cb is not None
        protocol = self._protocol
        await self.drain()
        new_transport = await self._loop.start_tls(
            # pylint: disable=access-member-before-definition
            self._transport,  # type: ignore[has-type]
            protocol,
            ssl_context,
            server_side=server_side,
            server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
        )
        self._transport = (  # pylint: disable=attribute-defined-outside-init
            new_transport
        )
        protocol._replace_writer(self)
 
 
async def open_connection(
    host: str = None, port: int = None, *, limit: int = 2**16, **kwds: Any
) -> Tuple[asyncio.StreamReader, StreamWriter]:
    """A wrapper for create_connection() returning a (reader, writer) pair.
 
    This function is based on ``asyncio.streams.open_connection`` and adds a custom
    stream reader.
 
    MySQL expects TLS negotiation to happen in the middle of a TCP connection, not at
    the start.
    This function in conjunction with ``_StreamReaderProtocol`` and ``_StreamWriter``
    allows the TLS negotiation on an existing connection.
 
    Args:
        host: Server host name.
        port: Server port.
        limit: The buffer size limit used by the returned ``StreamReader`` instance.
               By default the limit is set to 64 KiB.
 
    Returns:
        tuple: Returns a pair of reader and writer objects that are instances of
               ``StreamReader`` and ``StreamWriter`` classes.
    """
    loop = asyncio.get_running_loop()
    reader = asyncio.streams.StreamReader(limit=limit, loop=loop)
    protocol = StreamReaderProtocol(reader, loop=loop)
    transport, _ = await loop.create_connection(lambda: protocol, host, port, **kwds)
    writer = StreamWriter(transport, protocol, reader, loop)
    return reader, writer
 
 
async def to_thread(func: Callable, *args: Any, **kwargs: Any) -> asyncio.Future:
    """Asynchronously run function ``func`` in a separate thread.
 
    This function is based on ``asyncio.to_thread()`` introduced in Python 3.9, which
    provides the same functionality for older Python versions.
 
    Returns:
        coroutine: A coroutine that can be awaited to get the eventual result of
                   ``func``.
    """
    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(None, func_call)