Update the concurrent cage-occupancy check-in for East China Normal University Phase II; add direct retrieval of user IDs and their corresponding research group information from the database
402 files added
1 file modified
| New file |
| | |
| | | A. HISTORY OF THE SOFTWARE |
| | | ========================== |
| | | |
| | | Python was created in the early 1990s by Guido van Rossum at Stichting |
| | | Mathematisch Centrum (CWI, see https://www.cwi.nl) in the Netherlands |
| | | as a successor of a language called ABC. Guido remains Python's |
| | | principal author, although it includes many contributions from others. |
| | | |
| | | In 1995, Guido continued his work on Python at the Corporation for |
| | | National Research Initiatives (CNRI, see https://www.cnri.reston.va.us) |
| | | in Reston, Virginia where he released several versions of the |
| | | software. |
| | | |
| | | In May 2000, Guido and the Python core development team moved to |
| | | BeOpen.com to form the BeOpen PythonLabs team. In October of the same |
| | | year, the PythonLabs team moved to Digital Creations, which became |
| | | Zope Corporation. In 2001, the Python Software Foundation (PSF, see |
| | | https://www.python.org/psf/) was formed, a non-profit organization |
| | | created specifically to own Python-related Intellectual Property. |
| | | Zope Corporation was a sponsoring member of the PSF. |
| | | |
| | | All Python releases are Open Source (see https://opensource.org for |
| | | the Open Source Definition). Historically, most, but not all, Python |
| | | releases have also been GPL-compatible; the table below summarizes |
| | | the various releases. |
| | | |
| | | Release Derived Year Owner GPL- |
| | | from compatible? (1) |
| | | |
| | | 0.9.0 thru 1.2 1991-1995 CWI yes |
| | | 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes |
| | | 1.6 1.5.2 2000 CNRI no |
| | | 2.0 1.6 2000 BeOpen.com no |
| | | 1.6.1 1.6 2001 CNRI yes (2) |
| | | 2.1 2.0+1.6.1 2001 PSF no |
| | | 2.0.1 2.0+1.6.1 2001 PSF yes |
| | | 2.1.1 2.1+2.0.1 2001 PSF yes |
| | | 2.1.2 2.1.1 2002 PSF yes |
| | | 2.1.3 2.1.2 2002 PSF yes |
| | | 2.2 and above 2.1.1 2001-now PSF yes |
| | | |
| | | Footnotes: |
| | | |
| | | (1) GPL-compatible doesn't mean that we're distributing Python under |
| | | the GPL. All Python licenses, unlike the GPL, let you distribute |
| | | a modified version without making your changes open source. The |
| | | GPL-compatible licenses make it possible to combine Python with |
| | | other software that is released under the GPL; the others don't. |
| | | |
| | | (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, |
| | | because its license has a choice of law clause. According to |
| | | CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 |
| | | is "not incompatible" with the GPL. |
| | | |
| | | Thanks to the many outside volunteers who have worked under Guido's |
| | | direction to make these releases possible. |
| | | |
| | | |
| | | B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON |
| | | =============================================================== |
| | | |
| | | Python software and documentation are licensed under the |
| | | Python Software Foundation License Version 2. |
| | | |
| | | Starting with Python 3.8.6, examples, recipes, and other code in |
| | | the documentation are dual licensed under the PSF License Version 2 |
| | | and the Zero-Clause BSD license. |
| | | |
| | | Some software incorporated into Python is under different licenses. |
| | | The licenses are listed with code falling under that license. |
| | | |
| | | |
| | | PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 |
| | | -------------------------------------------- |
| | | |
| | | 1. This LICENSE AGREEMENT is between the Python Software Foundation |
| | | ("PSF"), and the Individual or Organization ("Licensee") accessing and |
| | | otherwise using this software ("Python") in source or binary form and |
| | | its associated documentation. |
| | | |
| | | 2. Subject to the terms and conditions of this License Agreement, PSF hereby |
| | | grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, |
| | | analyze, test, perform and/or display publicly, prepare derivative works, |
| | | distribute, and otherwise use Python alone or in any derivative version, |
| | | provided, however, that PSF's License Agreement and PSF's notice of copyright, |
| | | i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, |
| | | 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023 Python Software Foundation; |
| | | All Rights Reserved" are retained in Python alone or in any derivative version |
| | | prepared by Licensee. |
| | | |
| | | 3. In the event Licensee prepares a derivative work that is based on |
| | | or incorporates Python or any part thereof, and wants to make |
| | | the derivative work available to others as provided herein, then |
| | | Licensee hereby agrees to include in any such work a brief summary of |
| | | the changes made to Python. |
| | | |
| | | 4. PSF is making Python available to Licensee on an "AS IS" |
| | | basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR |
| | | IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND |
| | | DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS |
| | | FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT |
| | | INFRINGE ANY THIRD PARTY RIGHTS. |
| | | |
| | | 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON |
| | | FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS |
| | | A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, |
| | | OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. |
| | | |
| | | 6. This License Agreement will automatically terminate upon a material |
| | | breach of its terms and conditions. |
| | | |
| | | 7. Nothing in this License Agreement shall be deemed to create any |
| | | relationship of agency, partnership, or joint venture between PSF and |
| | | Licensee. This License Agreement does not grant permission to use PSF |
| | | trademarks or trade name in a trademark sense to endorse or promote |
| | | products or services of Licensee, or any third party. |
| | | |
| | | 8. By copying, installing or otherwise using Python, Licensee |
| | | agrees to be bound by the terms and conditions of this License |
| | | Agreement. |
| | | |
| | | |
| | | BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 |
| | | ------------------------------------------- |
| | | |
| | | BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 |
| | | |
| | | 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an |
| | | office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the |
| | | Individual or Organization ("Licensee") accessing and otherwise using |
| | | this software in source or binary form and its associated |
| | | documentation ("the Software"). |
| | | |
| | | 2. Subject to the terms and conditions of this BeOpen Python License |
| | | Agreement, BeOpen hereby grants Licensee a non-exclusive, |
| | | royalty-free, world-wide license to reproduce, analyze, test, perform |
| | | and/or display publicly, prepare derivative works, distribute, and |
| | | otherwise use the Software alone or in any derivative version, |
| | | provided, however, that the BeOpen Python License is retained in the |
| | | Software, alone or in any derivative version prepared by Licensee. |
| | | |
| | | 3. BeOpen is making the Software available to Licensee on an "AS IS" |
| | | basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR |
| | | IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND |
| | | DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS |
| | | FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT |
| | | INFRINGE ANY THIRD PARTY RIGHTS. |
| | | |
| | | 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE |
| | | SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS |
| | | AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY |
| | | DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. |
| | | |
| | | 5. This License Agreement will automatically terminate upon a material |
| | | breach of its terms and conditions. |
| | | |
| | | 6. This License Agreement shall be governed by and interpreted in all |
| | | respects by the law of the State of California, excluding conflict of |
| | | law provisions. Nothing in this License Agreement shall be deemed to |
| | | create any relationship of agency, partnership, or joint venture |
| | | between BeOpen and Licensee. This License Agreement does not grant |
| | | permission to use BeOpen trademarks or trade names in a trademark |
| | | sense to endorse or promote products or services of Licensee, or any |
| | | third party. As an exception, the "BeOpen Python" logos available at |
| | | http://www.pythonlabs.com/logos.html may be used according to the |
| | | permissions granted on that web page. |
| | | |
| | | 7. By copying, installing or otherwise using the software, Licensee |
| | | agrees to be bound by the terms and conditions of this License |
| | | Agreement. |
| | | |
| | | |
| | | CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 |
| | | --------------------------------------- |
| | | |
| | | 1. This LICENSE AGREEMENT is between the Corporation for National |
| | | Research Initiatives, having an office at 1895 Preston White Drive, |
| | | Reston, VA 20191 ("CNRI"), and the Individual or Organization |
| | | ("Licensee") accessing and otherwise using Python 1.6.1 software in |
| | | source or binary form and its associated documentation. |
| | | |
| | | 2. Subject to the terms and conditions of this License Agreement, CNRI |
| | | hereby grants Licensee a nonexclusive, royalty-free, world-wide |
| | | license to reproduce, analyze, test, perform and/or display publicly, |
| | | prepare derivative works, distribute, and otherwise use Python 1.6.1 |
| | | alone or in any derivative version, provided, however, that CNRI's |
| | | License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) |
| | | 1995-2001 Corporation for National Research Initiatives; All Rights |
| | | Reserved" are retained in Python 1.6.1 alone or in any derivative |
| | | version prepared by Licensee. Alternately, in lieu of CNRI's License |
| | | Agreement, Licensee may substitute the following text (omitting the |
| | | quotes): "Python 1.6.1 is made available subject to the terms and |
| | | conditions in CNRI's License Agreement. This Agreement together with |
| | | Python 1.6.1 may be located on the internet using the following |
| | | unique, persistent identifier (known as a handle): 1895.22/1013. This |
| | | Agreement may also be obtained from a proxy server on the internet |
| | | using the following URL: http://hdl.handle.net/1895.22/1013". |
| | | |
| | | 3. In the event Licensee prepares a derivative work that is based on |
| | | or incorporates Python 1.6.1 or any part thereof, and wants to make |
| | | the derivative work available to others as provided herein, then |
| | | Licensee hereby agrees to include in any such work a brief summary of |
| | | the changes made to Python 1.6.1. |
| | | |
| | | 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" |
| | | basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR |
| | | IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND |
| | | DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS |
| | | FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT |
| | | INFRINGE ANY THIRD PARTY RIGHTS. |
| | | |
| | | 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON |
| | | 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS |
| | | A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, |
| | | OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. |
| | | |
| | | 6. This License Agreement will automatically terminate upon a material |
| | | breach of its terms and conditions. |
| | | |
| | | 7. This License Agreement shall be governed by the federal |
| | | intellectual property law of the United States, including without |
| | | limitation the federal copyright law, and, to the extent such |
| | | U.S. federal law does not apply, by the law of the Commonwealth of |
| | | Virginia, excluding Virginia's conflict of law provisions. |
| | | Notwithstanding the foregoing, with regard to derivative works based |
| | | on Python 1.6.1 that incorporate non-separable material that was |
| | | previously distributed under the GNU General Public License (GPL), the |
| | | law of the Commonwealth of Virginia shall govern this License |
| | | Agreement only as to issues arising under or with respect to |
| | | Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this |
| | | License Agreement shall be deemed to create any relationship of |
| | | agency, partnership, or joint venture between CNRI and Licensee. This |
| | | License Agreement does not grant permission to use CNRI trademarks or |
| | | trade name in a trademark sense to endorse or promote products or |
| | | services of Licensee, or any third party. |
| | | |
| | | 8. By clicking on the "ACCEPT" button where indicated, or by copying, |
| | | installing or otherwise using Python 1.6.1, Licensee agrees to be |
| | | bound by the terms and conditions of this License Agreement. |
| | | |
| | | ACCEPT |
| | | |
| | | |
| | | CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 |
| | | -------------------------------------------------- |
| | | |
| | | Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, |
| | | The Netherlands. All rights reserved. |
| | | |
| | | Permission to use, copy, modify, and distribute this software and its |
| | | documentation for any purpose and without fee is hereby granted, |
| | | provided that the above copyright notice appear in all copies and that |
| | | both that copyright notice and this permission notice appear in |
| | | supporting documentation, and that the name of Stichting Mathematisch |
| | | Centrum or CWI not be used in advertising or publicity pertaining to |
| | | distribution of the software without specific, written prior |
| | | permission. |
| | | |
| | | STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO |
| | | THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND |
| | | FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE |
| | | FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES |
| | | WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN |
| | | ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT |
| | | OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
| | | |
| | | ZERO-CLAUSE BSD LICENSE FOR CODE IN THE PYTHON DOCUMENTATION |
| | | ---------------------------------------------------------------------- |
| | | |
| | | Permission to use, copy, modify, and/or distribute this software for any |
| | | purpose with or without fee is hereby granted. |
| | | |
| | | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH |
| | | REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY |
| | | AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, |
| | | INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM |
| | | LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR |
| | | OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR |
| | | PERFORMANCE OF THIS SOFTWARE. |
| New file |
| | |
| | | Metadata-Version: 2.3 |
| | | Name: aiohappyeyeballs |
| | | Version: 2.6.1 |
| | | Summary: Happy Eyeballs for asyncio |
| | | License: PSF-2.0 |
| | | Author: J. Nick Koston |
| | | Author-email: nick@koston.org |
| | | Requires-Python: >=3.9 |
| | | Classifier: Development Status :: 5 - Production/Stable |
| | | Classifier: Intended Audience :: Developers |
| | | Classifier: Natural Language :: English |
| | | Classifier: Operating System :: OS Independent |
| | | Classifier: Topic :: Software Development :: Libraries |
| | | Classifier: Programming Language :: Python :: 3 |
| | | Classifier: Programming Language :: Python :: 3.9 |
| | | Classifier: Programming Language :: Python :: 3.10 |
| | | Classifier: Programming Language :: Python :: 3.11 |
| | | Classifier: Programming Language :: Python :: 3.12 |
| | | Classifier: Programming Language :: Python :: 3.13 |
| | | Classifier: License :: OSI Approved :: Python Software Foundation License |
| | | Project-URL: Bug Tracker, https://github.com/aio-libs/aiohappyeyeballs/issues |
| | | Project-URL: Changelog, https://github.com/aio-libs/aiohappyeyeballs/blob/main/CHANGELOG.md |
| | | Project-URL: Documentation, https://aiohappyeyeballs.readthedocs.io |
| | | Project-URL: Repository, https://github.com/aio-libs/aiohappyeyeballs |
| | | Description-Content-Type: text/markdown |
| | | |
| | | # aiohappyeyeballs |
| | | |
| | | <p align="center"> |
| | | <a href="https://github.com/aio-libs/aiohappyeyeballs/actions/workflows/ci.yml?query=branch%3Amain"> |
| | | <img src="https://img.shields.io/github/actions/workflow/status/aio-libs/aiohappyeyeballs/ci-cd.yml?branch=main&label=CI&logo=github&style=flat-square" alt="CI Status" > |
| | | </a> |
| | | <a href="https://aiohappyeyeballs.readthedocs.io"> |
| | | <img src="https://img.shields.io/readthedocs/aiohappyeyeballs.svg?logo=read-the-docs&logoColor=fff&style=flat-square" alt="Documentation Status"> |
| | | </a> |
| | | <a href="https://codecov.io/gh/aio-libs/aiohappyeyeballs"> |
| | | <img src="https://img.shields.io/codecov/c/github/aio-libs/aiohappyeyeballs.svg?logo=codecov&logoColor=fff&style=flat-square" alt="Test coverage percentage"> |
| | | </a> |
| | | </p> |
| | | <p align="center"> |
| | | <a href="https://python-poetry.org/"> |
| | | <img src="https://img.shields.io/badge/packaging-poetry-299bd7?style=flat-square&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA4AAAASCAYAAABrXO8xAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAJJSURBVHgBfZLPa1NBEMe/s7tNXoxW1KJQKaUHkXhQvHgW6UHQQ09CBS/6V3hKc/AP8CqCrUcpmop3Cx48eDB4yEECjVQrlZb80CRN8t6OM/teagVxYZi38+Yz853dJbzoMV3MM8cJUcLMSUKIE8AzQ2PieZzFxEJOHMOgMQQ+dUgSAckNXhapU/NMhDSWLs1B24A8sO1xrN4NECkcAC9ASkiIJc6k5TRiUDPhnyMMdhKc+Zx19l6SgyeW76BEONY9exVQMzKExGKwwPsCzza7KGSSWRWEQhyEaDXp6ZHEr416ygbiKYOd7TEWvvcQIeusHYMJGhTwF9y7sGnSwaWyFAiyoxzqW0PM/RjghPxF2pWReAowTEXnDh0xgcLs8l2YQmOrj3N7ByiqEoH0cARs4u78WgAVkoEDIDoOi3AkcLOHU60RIg5wC4ZuTC7FaHKQm8Hq1fQuSOBvX/sodmNJSB5geaF5CPIkUeecdMxieoRO5jz9bheL6/tXjrwCyX/UYBUcjCaWHljx1xiX6z9xEjkYAzbGVnB8pvLmyXm9ep+W8CmsSHQQY77Zx1zboxAV0w7ybMhQmfqdmmw3nEp1I0Z+FGO6M8LZdoyZnuzzBdjISicKRnpxzI9fPb+0oYXsNdyi+d3h9bm9MWYHFtPeIZfLwzmFDKy1ai3p+PDls1Llz4yyFpferxjnyjJDSEy9CaCx5m2cJPerq6Xm34eTrZt3PqxYO1XOwDYZrFlH1fWnpU38Y9HRze3lj0vOujZcXKuuXm3jP+s3KbZVra7y2EAAAAAASUVORK5CYII=" alt="Poetry"> |
| | | </a> |
| | | <a href="https://github.com/astral-sh/ruff"> |
| | | <img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff"> |
| | | </a> |
| | | <a href="https://github.com/pre-commit/pre-commit"> |
| | | <img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white&style=flat-square" alt="pre-commit"> |
| | | </a> |
| | | </p> |
| | | <p align="center"> |
| | | <a href="https://pypi.org/project/aiohappyeyeballs/"> |
| | | <img src="https://img.shields.io/pypi/v/aiohappyeyeballs.svg?logo=python&logoColor=fff&style=flat-square" alt="PyPI Version"> |
| | | </a> |
| | | <img src="https://img.shields.io/pypi/pyversions/aiohappyeyeballs.svg?style=flat-square&logo=python&logoColor=fff" alt="Supported Python versions"> |
| | | <img src="https://img.shields.io/pypi/l/aiohappyeyeballs.svg?style=flat-square" alt="License"> |
| | | </p> |
| | | |
| | | --- |
| | | |
| | | **Documentation**: <a href="https://aiohappyeyeballs.readthedocs.io" target="_blank">https://aiohappyeyeballs.readthedocs.io</a>
| | | 
| | | **Source Code**: <a href="https://github.com/aio-libs/aiohappyeyeballs" target="_blank">https://github.com/aio-libs/aiohappyeyeballs</a>
| | | |
| | | --- |
| | | |
| | | [Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs) |
| | | ([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html)) |
| | | |
| | | ## Use case |
| | | |
| | | This library exists to allow connecting with |
| | | [Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs) |
| | | ([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html)) |
| | | when you |
| | | already have a list of addrinfo and not a DNS name. |
| | | |
| | | The stdlib version of `loop.create_connection()` |
| | | will only work when you pass in an unresolved name which |
| | | is not a good fit when using DNS caching or resolving |
| | | names via another method such as `zeroconf`. |
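Concretely, the pre-resolved list this library consumes has the same shape as `socket.getaddrinfo()` output: a list of `(family, type, proto, canonname, sockaddr)` tuples. It can therefore be built by hand from a DNS cache or a zeroconf result. A minimal sketch, using illustrative documentation addresses:

```python
import socket

# Each entry mirrors one socket.getaddrinfo() result tuple:
# (family, type, proto, canonname, sockaddr).
addr_infos = [
    (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "",
     ("203.0.113.10", 80)),
    (socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_TCP, "",
     ("2001:db8::10", 80, 0, 0)),
]

# A list like this is passed straight to start_connection(addr_infos),
# with no hostname resolution involved.
print(len(addr_infos))
```

With `happy_eyeballs_delay` set, `start_connection` would then race the IPv6 and IPv4 entries against each other.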
| | | |
| | | ## Installation |
| | | |
| | | Install this via pip (or your favourite package manager): |
| | | |
| | | `pip install aiohappyeyeballs` |
| | | |
| | | ## License |
| | | |
| | | [aiohappyeyeballs is licensed under the same terms as cpython itself.](https://github.com/python/cpython/blob/main/LICENSE) |
| | | |
| | | ## Example usage |
| | | |
| | | ```python
| | | addr_infos = await loop.getaddrinfo("example.org", 80)
| | | 
| | | socket = await start_connection(addr_infos)
| | | socket = await start_connection(addr_infos, local_addr_infos=local_addr_infos, happy_eyeballs_delay=0.2)
| | | 
| | | transport, protocol = await loop.create_connection(
| | |     MyProtocol, sock=socket, ...)
| | | 
| | | # Remove the first address for each family from addr_infos
| | | pop_addr_infos_interleave(addr_infos, 1)
| | | 
| | | # Remove all matching addresses from addr_infos
| | | remove_addr_infos(addr_infos, "dead::beef")
| | | 
| | | # Convert a local_addr to local_addr_infos
| | | local_addr_infos = addr_to_addr_infos(("127.0.0.1", 0))
| | | ```
| | | |
| | | ## Credits |
| | | |
| | | This package contains code from cpython and is licensed under the same terms as cpython itself. |
| | | |
| | | This package was created with |
| | | [Copier](https://copier.readthedocs.io/) and the |
| | | [browniebroke/pypackage-template](https://github.com/browniebroke/pypackage-template) |
| | | project template. |
| | | |
| New file |
| | |
| | | aiohappyeyeballs-2.6.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 |
| | | aiohappyeyeballs-2.6.1.dist-info/LICENSE,sha256=Oy-B_iHRgcSZxZolbI4ZaEVdZonSaaqFNzv7avQdo78,13936 |
| | | aiohappyeyeballs-2.6.1.dist-info/METADATA,sha256=NSXlhJwAfi380eEjAo7BQ4P_TVal9xi0qkyZWibMsVM,5915 |
| | | aiohappyeyeballs-2.6.1.dist-info/RECORD,, |
| | | aiohappyeyeballs-2.6.1.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88 |
| | | aiohappyeyeballs/__init__.py,sha256=x7kktHEtaD9quBcWDJPuLeKyjuVAI-Jj14S9B_5hcTs,361 |
| | | aiohappyeyeballs/__pycache__/__init__.cpython-312.pyc,, |
| | | aiohappyeyeballs/__pycache__/_staggered.cpython-312.pyc,, |
| | | aiohappyeyeballs/__pycache__/impl.cpython-312.pyc,, |
| | | aiohappyeyeballs/__pycache__/types.cpython-312.pyc,, |
| | | aiohappyeyeballs/__pycache__/utils.cpython-312.pyc,, |
| | | aiohappyeyeballs/_staggered.py,sha256=edfVowFx-P-ywJjIEF3MdPtEMVODujV6CeMYr65otac,6900 |
| | | aiohappyeyeballs/impl.py,sha256=Dlcm2mTJ28ucrGnxkb_fo9CZzLAkOOBizOt7dreBbXE,9681 |
| | | aiohappyeyeballs/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 |
| | | aiohappyeyeballs/types.py,sha256=YZJIAnyoV4Dz0WFtlaf_OyE4EW7Xus1z7aIfNI6tDDQ,425 |
| | | aiohappyeyeballs/utils.py,sha256=on9GxIR0LhEfZu8P6Twi9hepX9zDanuZM20MWsb3xlQ,3028 |
| New file |
| | |
| | | Wheel-Version: 1.0 |
| | | Generator: poetry-core 2.1.1 |
| | | Root-Is-Purelib: true |
| | | Tag: py3-none-any |
| New file |
| | |
| | | __version__ = "2.6.1" |
| | | |
| | | from .impl import start_connection |
| | | from .types import AddrInfoType, SocketFactoryType |
| | | from .utils import addr_to_addr_infos, pop_addr_infos_interleave, remove_addr_infos |
| | | |
| | | __all__ = ( |
| | | "AddrInfoType", |
| | | "SocketFactoryType", |
| | | "addr_to_addr_infos", |
| | | "pop_addr_infos_interleave", |
| | | "remove_addr_infos", |
| | | "start_connection", |
| | | ) |
| New file |
| | |
| | | import asyncio |
| | | import contextlib |
| | | |
| | | # PY3.9: Import Callable from typing until we drop Python 3.9 support |
| | | # https://github.com/python/cpython/issues/87131 |
| | | from typing import ( |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Awaitable, |
| | | Callable, |
| | | Iterable, |
| | | List, |
| | | Optional, |
| | | Set, |
| | | Tuple, |
| | | TypeVar, |
| | | Union, |
| | | ) |
| | | |
| | | _T = TypeVar("_T") |
| | | |
| | | RE_RAISE_EXCEPTIONS = (SystemExit, KeyboardInterrupt) |
| | | |
| | | |
| | | def _set_result(wait_next: "asyncio.Future[None]") -> None: |
| | | """Set the result of a future if it is not already done.""" |
| | | if not wait_next.done(): |
| | | wait_next.set_result(None) |
| | | |
| | | |
| | | async def _wait_one( |
| | | futures: "Iterable[asyncio.Future[Any]]", |
| | | loop: asyncio.AbstractEventLoop, |
| | | ) -> _T: |
| | | """Wait for the first future to complete.""" |
| | | wait_next = loop.create_future() |
| | | |
| | | def _on_completion(fut: "asyncio.Future[Any]") -> None: |
| | | if not wait_next.done(): |
| | | wait_next.set_result(fut) |
| | | |
| | | for f in futures: |
| | | f.add_done_callback(_on_completion) |
| | | |
| | | try: |
| | | return await wait_next |
| | | finally: |
| | | for f in futures: |
| | | f.remove_done_callback(_on_completion) |
| | | |
| | | |
| | | async def staggered_race( |
| | | coro_fns: Iterable[Callable[[], Awaitable[_T]]], |
| | | delay: Optional[float], |
| | | *, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | ) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]: |
| | | """ |
| | | Run coroutines with staggered start times and take the first to finish. |
| | | |
| | | This method takes an iterable of coroutine functions. The first one is |
| | | started immediately. From then on, whenever the immediately preceding one |
| | | fails (raises an exception), or when *delay* seconds have passed, the next
| | | coroutine is started. This continues until one of the coroutines completes
| | | successfully, in which case all others are cancelled, or until all |
| | | coroutines fail. |
| | | |
| | | The coroutines provided should be well-behaved in the following way: |
| | | |
| | | * They should only ``return`` if completed successfully. |
| | | |
| | | * They should always raise an exception if they did not complete |
| | | successfully. In particular, if they handle cancellation, they should |
| | | probably reraise, like this:: |
| | | |
| | | try: |
| | | # do work |
| | | except asyncio.CancelledError: |
| | | # undo partially completed work |
| | | raise |
| | | |
| | | Args: |
| | | ---- |
| | | coro_fns: an iterable of coroutine functions, i.e. callables that |
| | | return a coroutine object when called. Use ``functools.partial`` or |
| | | lambdas to pass arguments. |
| | | |
| | | delay: amount of time, in seconds, between starting coroutines. If |
| | | ``None``, the coroutines will run sequentially. |
| | | |
| | | loop: the event loop to use. If ``None``, the running loop is used. |
| | | |
| | | Returns: |
| | | ------- |
| | | tuple *(winner_result, winner_index, exceptions)* where |
| | | |
| | | - *winner_result*: the result of the winning coroutine, or ``None`` |
| | | if no coroutines won. |
| | | |
| | | - *winner_index*: the index of the winning coroutine in |
| | | ``coro_fns``, or ``None`` if no coroutines won. Because a winning
| | | coroutine may itself return ``None``, check *winner_index* to
| | | determine definitively whether any coroutine won.
| | | |
| | | - *exceptions*: list of exceptions raised by the coroutines.
| | | ``len(exceptions)`` is equal to the number of coroutines actually |
| | | started, and the order is the same as in ``coro_fns``. The winning |
| | | coroutine's entry is ``None``. |
| | | |
| | | """ |
| | | loop = loop or asyncio.get_running_loop() |
| | | exceptions: List[Optional[BaseException]] = [] |
| | | tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set() |
| | | |
| | | async def run_one_coro( |
| | | coro_fn: Callable[[], Awaitable[_T]], |
| | | this_index: int, |
| | | start_next: "asyncio.Future[None]", |
| | | ) -> Optional[Tuple[_T, int]]: |
| | | """ |
| | | Run a single coroutine. |
| | | |
| | | If the coroutine fails, set the exception in the exceptions list and |
| | | start the next coroutine by setting the result of the start_next. |
| | | |
| | | If the coroutine succeeds, return the result and the index of the |
| | | coroutine in the coro_fns list. |
| | | |
| | | If SystemExit or KeyboardInterrupt is raised, re-raise it. |
| | | """ |
| | | try: |
| | | result = await coro_fn() |
| | | except RE_RAISE_EXCEPTIONS: |
| | | raise |
| | | except BaseException as e: |
| | | exceptions[this_index] = e |
| | | _set_result(start_next) # Kickstart the next coroutine |
| | | return None |
| | | |
| | | return result, this_index |
| | | |
| | | start_next_timer: Optional[asyncio.TimerHandle] = None |
| | | start_next: Optional[asyncio.Future[None]] |
| | | task: asyncio.Task[Optional[Tuple[_T, int]]] |
| | | done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]] |
| | | coro_iter = iter(coro_fns) |
| | | this_index = -1 |
| | | try: |
| | | while True: |
| | | if coro_fn := next(coro_iter, None): |
| | | this_index += 1 |
| | | exceptions.append(None) |
| | | start_next = loop.create_future() |
| | | task = loop.create_task(run_one_coro(coro_fn, this_index, start_next)) |
| | | tasks.add(task) |
| | | start_next_timer = ( |
| | | loop.call_later(delay, _set_result, start_next) if delay else None |
| | | ) |
| | | elif not tasks: |
| | | # We exhausted the coro_fns list and no tasks are running |
| | | # so we have no winner and all coroutines failed. |
| | | break |
| | | |
| | | while tasks or start_next: |
| | | done = await _wait_one( |
| | | (*tasks, start_next) if start_next else tasks, loop |
| | | ) |
| | | if done is start_next: |
| | | # The current task has failed or the timer has expired |
| | | # so we need to start the next task. |
| | | start_next = None |
| | | if start_next_timer: |
| | | start_next_timer.cancel() |
| | | start_next_timer = None |
| | | |
| | | # Break out of the task waiting loop to start the next |
| | | # task. |
| | | break |
| | | |
| | | if TYPE_CHECKING: |
| | | assert isinstance(done, asyncio.Task) |
| | | |
| | | tasks.remove(done) |
| | | if winner := done.result(): |
| | | return *winner, exceptions |
| | | finally: |
| | | # We either have: |
| | | # - a winner |
| | | # - all tasks failed |
| | | # - a KeyboardInterrupt or SystemExit. |
| | | |
| | | # |
| | | # If the timer is still running, cancel it. |
| | | # |
| | | if start_next_timer: |
| | | start_next_timer.cancel() |
| | | |
| | | # |
| | | # If there are any tasks left, cancel them and then
| | | # await them so they fill the exceptions list.
| | | # |
| | | for task in tasks: |
| | | task.cancel() |
| | | with contextlib.suppress(asyncio.CancelledError): |
| | | await task |
| | | |
| | | return None, None, exceptions |
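For orientation, the staggered-start pattern that `staggered_race` implements can be sketched in plain asyncio. This is a simplified illustration, not the vendored implementation; `attempt` and `race` are hypothetical names. Each attempt gets `delay` seconds of head start before the next launches in parallel, a failure triggers the next attempt immediately, and the first success cancels the rest:

```python
import asyncio
from functools import partial


async def attempt(i: int, ok_from: int) -> str:
    """Simulated connection attempt: indices below ok_from fail."""
    await asyncio.sleep(0.01)
    if i < ok_from:
        raise OSError(f"attempt {i} refused")
    return f"socket-{i}"


async def race(coro_fns, delay: float):
    """Start the next attempt when the previous one fails or after
    `delay` seconds, whichever comes first; first success wins."""
    tasks: set = set()
    pending_fns = list(coro_fns)
    while pending_fns or tasks:
        if pending_fns:
            tasks.add(asyncio.ensure_future(pending_fns.pop(0)()))
        while tasks:
            done, tasks = await asyncio.wait(
                tasks,
                timeout=delay if pending_fns else None,
                return_when=asyncio.FIRST_COMPLETED,
            )
            if not done:
                break  # head start expired: launch the next attempt
            for t in done:
                if t.exception() is None:
                    for loser in tasks:
                        loser.cancel()  # first success cancels the rest
                    return t.result()
            if pending_fns:
                break  # a failure also triggers the next attempt
    raise OSError("all attempts failed")


result = asyncio.run(
    race([partial(attempt, i, ok_from=2) for i in range(3)], delay=0.05)
)
print(result)  # socket-2
```

The vendored implementation above additionally records every exception in order and reports the winner's index, which matters when a winning coroutine legitimately returns `None`.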
| New file |
| | |
| | | """Base implementation.""" |
| | | |
| | | import asyncio |
| | | import collections |
| | | import contextlib |
| | | import functools |
| | | import itertools |
| | | import socket |
| | | from typing import List, Optional, Sequence, Set, Union |
| | | |
| | | from . import _staggered |
| | | from .types import AddrInfoType, SocketFactoryType |
| | | |
| | | |
| | | async def start_connection( |
| | | addr_infos: Sequence[AddrInfoType], |
| | | *, |
| | | local_addr_infos: Optional[Sequence[AddrInfoType]] = None, |
| | | happy_eyeballs_delay: Optional[float] = None, |
| | | interleave: Optional[int] = None, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | socket_factory: Optional[SocketFactoryType] = None, |
| | | ) -> socket.socket: |
| | | """ |
| | | Connect to a TCP server. |
| | | |
| | | Create a socket connection to a specified destination. The |
| | | destination is specified as a list of AddrInfoType tuples as |
| | | returned from getaddrinfo(). |
| | | |
| | | Each AddrInfoType tuple contains, in order:
| | | |
| | | * ``family``: the address family, e.g. ``socket.AF_INET`` or |
| | | ``socket.AF_INET6``. |
| | | * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or |
| | | ``socket.SOCK_DGRAM``. |
| | | * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or |
| | | ``socket.IPPROTO_UDP``. |
| | | * ``canonname``: the canonical name of the address, e.g. |
| | | ``"www.python.org"``. |
| | | * ``sockaddr``: the socket address |
| | | |
| | | This method is a coroutine which will try to establish the connection |
| | | in the background. When successful, the coroutine returns a |
| | | socket. |
| | | |
| | | The expected use case is to use this method in conjunction with |
| | | loop.create_connection() to establish a connection to a server:: |
| | | |
| | | socket = await start_connection(addr_infos) |
| | | transport, protocol = await loop.create_connection( |
| | | MyProtocol, sock=socket, ...) |
| | | """ |
| | | if not (current_loop := loop): |
| | | current_loop = asyncio.get_running_loop() |
| | | |
| | | single_addr_info = len(addr_infos) == 1 |
| | | |
| | | if happy_eyeballs_delay is not None and interleave is None: |
| | | # If using happy eyeballs, default to interleave addresses by family |
| | | interleave = 1 |
| | | |
| | | if interleave and not single_addr_info: |
| | | addr_infos = _interleave_addrinfos(addr_infos, interleave) |
| | | |
| | | sock: Optional[socket.socket] = None |
| | | # uvloop can raise RuntimeError instead of OSError |
| | | exceptions: List[List[Union[OSError, RuntimeError]]] = [] |
| | | if happy_eyeballs_delay is None or single_addr_info: |
| | | # not using happy eyeballs |
| | | for addrinfo in addr_infos: |
| | | try: |
| | | sock = await _connect_sock( |
| | | current_loop, |
| | | exceptions, |
| | | addrinfo, |
| | | local_addr_infos, |
| | | None, |
| | | socket_factory, |
| | | ) |
| | | break |
| | | except (RuntimeError, OSError): |
| | | continue |
| | | else: # using happy eyeballs |
| | | open_sockets: Set[socket.socket] = set() |
| | | try: |
| | | sock, _, _ = await _staggered.staggered_race( |
| | | ( |
| | | functools.partial( |
| | | _connect_sock, |
| | | current_loop, |
| | | exceptions, |
| | | addrinfo, |
| | | local_addr_infos, |
| | | open_sockets, |
| | | socket_factory, |
| | | ) |
| | | for addrinfo in addr_infos |
| | | ), |
| | | happy_eyeballs_delay, |
| | | ) |
| | | finally: |
| | | # If we have a winner, staggered_race will |
| | | # cancel the other tasks, however there is a |
| | | # small race window where any of the other tasks |
| | | # can be done before they are cancelled which |
| | | # will leave the socket open. To avoid this problem |
| | | # we pass a set to _connect_sock to keep track of |
| | | # the open sockets and close them here if there |
| | | # are any "runner up" sockets. |
| | | for s in open_sockets: |
| | | if s is not sock: |
| | | with contextlib.suppress(OSError): |
| | | s.close() |
| | | open_sockets = None # type: ignore[assignment] |
| | | |
| | | if sock is None: |
| | | all_exceptions = [exc for sub in exceptions for exc in sub] |
| | | try: |
| | | first_exception = all_exceptions[0] |
| | | if len(all_exceptions) == 1: |
| | | raise first_exception |
| | | else: |
| | | # If they all have the same str(), raise one. |
| | | model = str(first_exception) |
| | | if all(str(exc) == model for exc in all_exceptions): |
| | | raise first_exception |
| | | # Raise a combined exception so the user can see all |
| | | # the various error messages. |
| | | msg = "Multiple exceptions: {}".format( |
| | | ", ".join(str(exc) for exc in all_exceptions) |
| | | ) |
| | | # If the errno is the same for all exceptions, raise |
| | | # an OSError with that errno. |
| | | if isinstance(first_exception, OSError): |
| | | first_errno = first_exception.errno |
| | | if all( |
| | | isinstance(exc, OSError) and exc.errno == first_errno |
| | | for exc in all_exceptions |
| | | ): |
| | | raise OSError(first_errno, msg) |
| | | elif isinstance(first_exception, RuntimeError) and all( |
| | | isinstance(exc, RuntimeError) for exc in all_exceptions |
| | | ): |
| | | raise RuntimeError(msg) |
| | | # We have a mix of OSError and RuntimeError,
| | | # so we have to pick which one to raise;
| | | # we raise OSError for compatibility.
| | | raise OSError(msg) |
| | | finally: |
| | | all_exceptions = None # type: ignore[assignment] |
| | | exceptions = None # type: ignore[assignment] |
| | | |
| | | return sock |
| | | |
| | | |
| | | async def _connect_sock( |
| | | loop: asyncio.AbstractEventLoop, |
| | | exceptions: List[List[Union[OSError, RuntimeError]]], |
| | | addr_info: AddrInfoType, |
| | | local_addr_infos: Optional[Sequence[AddrInfoType]] = None, |
| | | open_sockets: Optional[Set[socket.socket]] = None, |
| | | socket_factory: Optional[SocketFactoryType] = None, |
| | | ) -> socket.socket: |
| | | """ |
| | | Create, bind and connect one socket. |
| | | |
| | | If open_sockets is passed, add the socket to the set of open sockets. |
| | | Any failure caught here will remove the socket from the set and close it. |
| | | |
| | | Callers can use this set to close any sockets that are not the
| | | winner of the staggered tasks, in case there are runner-up
| | | sockets, i.e. multiple winners.
| | | """ |
| | | my_exceptions: List[Union[OSError, RuntimeError]] = [] |
| | | exceptions.append(my_exceptions) |
| | | family, type_, proto, _, address = addr_info |
| | | sock = None |
| | | try: |
| | | if socket_factory is not None: |
| | | sock = socket_factory(addr_info) |
| | | else: |
| | | sock = socket.socket(family=family, type=type_, proto=proto) |
| | | if open_sockets is not None: |
| | | open_sockets.add(sock) |
| | | sock.setblocking(False) |
| | | if local_addr_infos is not None: |
| | | for lfamily, _, _, _, laddr in local_addr_infos: |
| | | # skip local addresses of different family |
| | | if lfamily != family: |
| | | continue |
| | | try: |
| | | sock.bind(laddr) |
| | | break |
| | | except OSError as exc: |
| | | msg = ( |
| | | f"error while attempting to bind on " |
| | | f"address {laddr!r}: " |
| | | f"{(exc.strerror or '').lower()}" |
| | | ) |
| | | exc = OSError(exc.errno, msg) |
| | | my_exceptions.append(exc) |
| | | else: # all bind attempts failed |
| | | if my_exceptions: |
| | | raise my_exceptions.pop() |
| | | else: |
| | | raise OSError(f"no matching local address with {family=} found") |
| | | await loop.sock_connect(sock, address) |
| | | return sock |
| | | except (RuntimeError, OSError) as exc: |
| | | my_exceptions.append(exc) |
| | | if sock is not None: |
| | | if open_sockets is not None: |
| | | open_sockets.remove(sock) |
| | | try: |
| | | sock.close() |
| | | except OSError as e: |
| | | my_exceptions.append(e) |
| | | raise |
| | | raise |
| | | except: |
| | | if sock is not None: |
| | | if open_sockets is not None: |
| | | open_sockets.remove(sock) |
| | | try: |
| | | sock.close() |
| | | except OSError as e: |
| | | my_exceptions.append(e) |
| | | raise |
| | | raise |
| | | finally: |
| | | exceptions = my_exceptions = None # type: ignore[assignment] |
| | | |
| | | |
| | | def _interleave_addrinfos( |
| | | addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1 |
| | | ) -> List[AddrInfoType]: |
| | | """Interleave list of addrinfo tuples by family.""" |
| | | # Group addresses by family |
| | | addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = ( |
| | | collections.OrderedDict() |
| | | ) |
| | | for addr in addrinfos: |
| | | family = addr[0] |
| | | if family not in addrinfos_by_family: |
| | | addrinfos_by_family[family] = [] |
| | | addrinfos_by_family[family].append(addr) |
| | | addrinfos_lists = list(addrinfos_by_family.values()) |
| | | |
| | | reordered: List[AddrInfoType] = [] |
| | | if first_address_family_count > 1: |
| | | reordered.extend(addrinfos_lists[0][: first_address_family_count - 1]) |
| | | del addrinfos_lists[0][: first_address_family_count - 1] |
| | | reordered.extend( |
| | | a |
| | | for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists)) |
| | | if a is not None |
| | | ) |
| | | return reordered |
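The ``zip_longest`` reordering above can be sketched standalone, with string stand-ins in place of real addrinfo tuples (values here are illustrative only):

```python
import itertools

# Two per-family groups, as produced by the grouping step above: three
# "IPv6" entries and two "IPv4" entries. zip_longest alternates across
# the groups, padding the shorter one with None, which is then filtered.
by_family = [["v6-a", "v6-b", "v6-c"], ["v4-a", "v4-b"]]
reordered = [
    a
    for a in itertools.chain.from_iterable(itertools.zip_longest(*by_family))
    if a is not None
]
# reordered == ["v6-a", "v4-a", "v6-b", "v4-b", "v6-c"]
```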
| New file |
| | |
| | | """Types for aiohappyeyeballs.""" |
| | | |
| | | import socket |
| | | |
| | | # PY3.9: Import Callable from typing until we drop Python 3.9 support |
| | | # https://github.com/python/cpython/issues/87131 |
| | | from typing import Callable, Tuple, Union |
| | | |
| | | AddrInfoType = Tuple[ |
| | | Union[int, socket.AddressFamily], |
| | | Union[int, socket.SocketKind], |
| | | int, |
| | | str, |
| | | Tuple, # type: ignore[type-arg] |
| | | ] |
| | | |
| | | SocketFactoryType = Callable[[AddrInfoType], socket.socket] |
| New file |
| | |
| | | """Utility functions for aiohappyeyeballs.""" |
| | | |
| | | import ipaddress |
| | | import socket |
| | | from typing import Dict, List, Optional, Tuple, Union |
| | | |
| | | from .types import AddrInfoType |
| | | |
| | | |
| | | def addr_to_addr_infos( |
| | | addr: Optional[ |
| | | Union[Tuple[str, int, int, int], Tuple[str, int, int], Tuple[str, int]] |
| | | ], |
| | | ) -> Optional[List[AddrInfoType]]: |
| | | """Convert an address tuple to a list of addr_info tuples.""" |
| | | if addr is None: |
| | | return None |
| | | host = addr[0] |
| | | port = addr[1] |
| | | is_ipv6 = ":" in host |
| | | if is_ipv6: |
| | | flowinfo = 0 |
| | | scopeid = 0 |
| | | addr_len = len(addr) |
| | | if addr_len >= 4: |
| | | scopeid = addr[3] # type: ignore[misc] |
| | | if addr_len >= 3: |
| | | flowinfo = addr[2] # type: ignore[misc] |
| | | addr = (host, port, flowinfo, scopeid) |
| | | family = socket.AF_INET6 |
| | | else: |
| | | addr = (host, port) |
| | | family = socket.AF_INET |
| | | return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)] |
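As an illustration of the shapes the helper above produces, constructed by hand here rather than by calling the vendored function:

```python
import socket

# An IPv4 ("127.0.0.1", 80) input maps to a single
# AF_INET/SOCK_STREAM/IPPROTO_TCP entry whose sockaddr is the 2-tuple itself.
ipv4_info = [(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", ("127.0.0.1", 80))]

# An IPv6 host (detected by ":" in the host string) gets its sockaddr
# padded to a 4-tuple (host, port, flowinfo, scopeid).
ipv6_info = [(socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", ("::1", 80, 0, 0))]
```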
| | | |
| | | |
| | | def pop_addr_infos_interleave( |
| | | addr_infos: List[AddrInfoType], interleave: Optional[int] = None |
| | | ) -> None: |
| | | """ |
| | | Pop addr_info from the list of addr_infos by family up to interleave times. |
| | | |
| | | The interleave parameter controls how many addr_infos for
| | | each family should be popped off the top of the list.
| | | """ |
| | | seen: Dict[int, int] = {} |
| | | if interleave is None: |
| | | interleave = 1 |
| | | to_remove: List[AddrInfoType] = [] |
| | | for addr_info in addr_infos: |
| | | family = addr_info[0] |
| | | if family not in seen: |
| | | seen[family] = 0 |
| | | if seen[family] < interleave: |
| | | to_remove.append(addr_info) |
| | | seen[family] += 1 |
| | | for addr_info in to_remove: |
| | | addr_infos.remove(addr_info) |
| | | |
| | | |
| | | def _addr_tuple_to_ip_address( |
| | | addr: Union[Tuple[str, int], Tuple[str, int, int, int]], |
| | | ) -> Union[ |
| | | Tuple[ipaddress.IPv4Address, int], Tuple[ipaddress.IPv6Address, int, int, int] |
| | | ]: |
| | | """Convert an address tuple's host string to an ip_address object."""
| | | return (ipaddress.ip_address(addr[0]), *addr[1:]) |
| | | |
| | | |
| | | def remove_addr_infos( |
| | | addr_infos: List[AddrInfoType], |
| | | addr: Union[Tuple[str, int], Tuple[str, int, int, int]], |
| | | ) -> None: |
| | | """ |
| | | Remove an address from the list of addr_infos. |
| | | |
| | | The addr value is typically the return value of |
| | | sock.getpeername(). |
| | | """ |
| | | bad_addrs_infos: List[AddrInfoType] = [] |
| | | for addr_info in addr_infos: |
| | | if addr_info[-1] == addr: |
| | | bad_addrs_infos.append(addr_info) |
| | | if bad_addrs_infos: |
| | | for bad_addr_info in bad_addrs_infos: |
| | | addr_infos.remove(bad_addr_info) |
| | | return |
| | | # Slow path in case addr is formatted differently |
| | | match_addr = _addr_tuple_to_ip_address(addr) |
| | | for addr_info in addr_infos: |
| | | if match_addr == _addr_tuple_to_ip_address(addr_info[-1]): |
| | | bad_addrs_infos.append(addr_info) |
| | | if bad_addrs_infos: |
| | | for bad_addr_info in bad_addrs_infos: |
| | | addr_infos.remove(bad_addr_info) |
| | | return |
| | | raise ValueError(f"Address {addr} not found in addr_infos") |
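The "slow path" above exists because textually different strings can name the same IP address; raw tuple equality misses a match that parsed-address equality catches. A standalone illustration:

```python
import ipaddress

# The expanded and abbreviated IPv6 loopback forms compare unequal as
# raw strings/tuples, but equal once parsed by ipaddress.ip_address.
raw_equal = ("0:0:0:0:0:0:0:1", 80) == ("::1", 80)
parsed_equal = ipaddress.ip_address("0:0:0:0:0:0:0:1") == ipaddress.ip_address("::1")
```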
| New file |
| | |
| | | Metadata-Version: 2.4 |
| | | Name: aiohttp |
| | | Version: 3.13.3 |
| | | Summary: Async http client/server framework (asyncio) |
| | | Maintainer-email: aiohttp team <team@aiohttp.org> |
| | | License: Apache-2.0 AND MIT |
| | | Project-URL: Homepage, https://github.com/aio-libs/aiohttp |
| | | Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org |
| | | Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org |
| | | Project-URL: CI: GitHub Actions, https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI |
| | | Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/aiohttp |
| | | Project-URL: Docs: Changelog, https://docs.aiohttp.org/en/stable/changes.html |
| | | Project-URL: Docs: RTD, https://docs.aiohttp.org |
| | | Project-URL: GitHub: issues, https://github.com/aio-libs/aiohttp/issues |
| | | Project-URL: GitHub: repo, https://github.com/aio-libs/aiohttp |
| | | Classifier: Development Status :: 5 - Production/Stable |
| | | Classifier: Framework :: AsyncIO |
| | | Classifier: Intended Audience :: Developers |
| | | Classifier: Operating System :: POSIX |
| | | Classifier: Operating System :: MacOS :: MacOS X |
| | | Classifier: Operating System :: Microsoft :: Windows |
| | | Classifier: Programming Language :: Python |
| | | Classifier: Programming Language :: Python :: 3 |
| | | Classifier: Programming Language :: Python :: 3.9 |
| | | Classifier: Programming Language :: Python :: 3.10 |
| | | Classifier: Programming Language :: Python :: 3.11 |
| | | Classifier: Programming Language :: Python :: 3.12 |
| | | Classifier: Programming Language :: Python :: 3.13 |
| | | Classifier: Programming Language :: Python :: 3.14 |
| | | Classifier: Topic :: Internet :: WWW/HTTP |
| | | Requires-Python: >=3.9 |
| | | Description-Content-Type: text/x-rst |
| | | License-File: LICENSE.txt |
| | | License-File: vendor/llhttp/LICENSE |
| | | Requires-Dist: aiohappyeyeballs>=2.5.0 |
| | | Requires-Dist: aiosignal>=1.4.0 |
| | | Requires-Dist: async-timeout<6.0,>=4.0; python_version < "3.11" |
| | | Requires-Dist: attrs>=17.3.0 |
| | | Requires-Dist: frozenlist>=1.1.1 |
| | | Requires-Dist: multidict<7.0,>=4.5 |
| | | Requires-Dist: propcache>=0.2.0 |
| | | Requires-Dist: yarl<2.0,>=1.17.0 |
| | | Provides-Extra: speedups |
| | | Requires-Dist: aiodns>=3.3.0; extra == "speedups" |
| | | Requires-Dist: Brotli>=1.2; platform_python_implementation == "CPython" and extra == "speedups" |
| | | Requires-Dist: brotlicffi>=1.2; platform_python_implementation != "CPython" and extra == "speedups" |
| | | Requires-Dist: backports.zstd; (platform_python_implementation == "CPython" and python_version < "3.14") and extra == "speedups" |
| | | Dynamic: license-file |
| | | |
| | | ================================== |
| | | Async http client/server framework |
| | | ================================== |
| | | |
| | | .. image:: https://raw.githubusercontent.com/aio-libs/aiohttp/master/docs/aiohttp-plain.svg |
| | | :height: 64px |
| | | :width: 64px |
| | | :alt: aiohttp logo |
| | | |
| | | | |
| | | |
| | | .. image:: https://github.com/aio-libs/aiohttp/workflows/CI/badge.svg |
| | | :target: https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI |
| | | :alt: GitHub Actions status for master branch |
| | | |
| | | .. image:: https://codecov.io/gh/aio-libs/aiohttp/branch/master/graph/badge.svg |
| | | :target: https://codecov.io/gh/aio-libs/aiohttp |
| | | :alt: codecov.io status for master branch |
| | | |
| | | .. image:: https://badge.fury.io/py/aiohttp.svg |
| | | :target: https://pypi.org/project/aiohttp |
| | | :alt: Latest PyPI package version |
| | | |
| | | .. image:: https://img.shields.io/pypi/dm/aiohttp |
| | | :target: https://pypistats.org/packages/aiohttp |
| | | :alt: Downloads count |
| | | |
| | | .. image:: https://readthedocs.org/projects/aiohttp/badge/?version=latest |
| | | :target: https://docs.aiohttp.org/ |
| | | :alt: Latest Read The Docs |
| | | |
| | | .. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json |
| | | :target: https://codspeed.io/aio-libs/aiohttp |
| | | :alt: Codspeed.io status for aiohttp |
| | | |
| | | |
| | | Key Features |
| | | ============ |
| | | |
| | | - Supports both client and server side of HTTP protocol. |
| | | - Supports both client and server Web-Sockets out-of-the-box and avoids |
| | | Callback Hell. |
| | | - Provides Web-server with middleware and pluggable routing. |
| | | |
| | | |
| | | Getting started |
| | | =============== |
| | | |
| | | Client |
| | | ------ |
| | | |
| | | To get something from the web: |
| | | |
| | | .. code-block:: python |
| | | |
| | | import aiohttp |
| | | import asyncio |
| | | |
| | | async def main(): |
| | | |
| | | async with aiohttp.ClientSession() as session: |
| | | async with session.get('http://python.org') as response: |
| | | |
| | | print("Status:", response.status) |
| | | print("Content-type:", response.headers['content-type']) |
| | | |
| | | html = await response.text() |
| | | print("Body:", html[:15], "...") |
| | | |
| | | asyncio.run(main()) |
| | | |
| | | This prints: |
| | | |
| | | .. code-block:: |
| | | |
| | | Status: 200 |
| | | Content-type: text/html; charset=utf-8 |
| | | Body: <!doctype html> ... |
| | | |
| | | Coming from `requests <https://requests.readthedocs.io/>`_? Read `why we need so many lines <https://aiohttp.readthedocs.io/en/latest/http_request_lifecycle.html>`_.
| | | |
| | | Server |
| | | ------ |
| | | |
| | | An example using a simple server: |
| | | |
| | | .. code-block:: python |
| | | |
| | | # examples/server_simple.py |
| | | from aiohttp import web |
| | | |
| | | async def handle(request): |
| | | name = request.match_info.get('name', "Anonymous") |
| | | text = "Hello, " + name |
| | | return web.Response(text=text) |
| | | |
| | | async def wshandle(request): |
| | | ws = web.WebSocketResponse() |
| | | await ws.prepare(request) |
| | | |
| | | async for msg in ws: |
| | | if msg.type == web.WSMsgType.text: |
| | | await ws.send_str("Hello, {}".format(msg.data)) |
| | | elif msg.type == web.WSMsgType.binary: |
| | | await ws.send_bytes(msg.data) |
| | | elif msg.type == web.WSMsgType.close: |
| | | break |
| | | |
| | | return ws |
| | | |
| | | |
| | | app = web.Application() |
| | | app.add_routes([web.get('/', handle), |
| | | web.get('/echo', wshandle), |
| | | web.get('/{name}', handle)]) |
| | | |
| | | if __name__ == '__main__': |
| | | web.run_app(app) |
| | | |
| | | |
| | | Documentation |
| | | ============= |
| | | |
| | | https://aiohttp.readthedocs.io/ |
| | | |
| | | |
| | | Demos |
| | | ===== |
| | | |
| | | https://github.com/aio-libs/aiohttp-demos |
| | | |
| | | |
| | | External links |
| | | ============== |
| | | |
| | | * `Third party libraries |
| | | <http://aiohttp.readthedocs.io/en/latest/third_party.html>`_ |
| | | * `Built with aiohttp |
| | | <http://aiohttp.readthedocs.io/en/latest/built_with.html>`_ |
| | | * `Powered by aiohttp |
| | | <http://aiohttp.readthedocs.io/en/latest/powered_by.html>`_ |
| | | |
| | | Feel free to make a Pull Request for adding your link to these pages! |
| | | |
| | | |
| | | Communication channels |
| | | ====================== |
| | | |
| | | *aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions |
| | | |
| | | *Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_ |
| | | |
| | | We support `Stack Overflow |
| | | <https://stackoverflow.com/questions/tagged/aiohttp>`_. |
| | | Please add the *aiohttp* tag to your question there.
| | | |
| | | Requirements |
| | | ============ |
| | | |
| | | - attrs_ |
| | | - multidict_ |
| | | - yarl_ |
| | | - frozenlist_ |
| | | |
| | | Optionally you may install the aiodns_ library (highly recommended for the sake of speed).
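
Installing aiohttp together with its optional speed-ups (including aiodns_) can be done in one step via the ``speedups`` extra, as declared in the packaging metadata above; a sketch:

.. code-block:: shell

   pip install "aiohttp[speedups]"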
| | | |
| | | .. _aiodns: https://pypi.python.org/pypi/aiodns |
| | | .. _attrs: https://github.com/python-attrs/attrs |
| | | .. _multidict: https://pypi.python.org/pypi/multidict |
| | | .. _frozenlist: https://pypi.org/project/frozenlist/ |
| | | .. _yarl: https://pypi.python.org/pypi/yarl |
| | | .. _async-timeout: https://pypi.python.org/pypi/async_timeout |
| | | |
| | | License |
| | | ======= |
| | | |
| | | ``aiohttp`` is offered under the Apache 2 license. |
| | | |
| | | |
| | | Keepsafe |
| | | ======== |
| | | |
| | | The aiohttp community would like to thank Keepsafe |
| | | (https://www.getkeepsafe.com) for its support in the early days of |
| | | the project. |
| | | |
| | | |
| | | Source code |
| | | =========== |
| | | |
| | | The latest developer version is available in a GitHub repository: |
| | | https://github.com/aio-libs/aiohttp |
| | | |
| | | Benchmarks |
| | | ========== |
| | | |
| | | If you are interested in efficiency, the AsyncIO community maintains a |
| | | list of benchmarks on the official wiki: |
| | | https://github.com/python/asyncio/wiki/Benchmarks |
| | | |
| | | -------- |
| | | |
| | | .. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat |
| | | :target: https://matrix.to/#/%23aio-libs:matrix.org |
| | | :alt: Matrix Room — #aio-libs:matrix.org |
| | | |
| | | .. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat |
| | | :target: https://matrix.to/#/%23aio-libs-space:matrix.org |
| | | :alt: Matrix Space — #aio-libs-space:matrix.org |
| | | |
| | | .. image:: https://insights.linuxfoundation.org/api/badge/health-score?project=aiohttp |
| | | :target: https://insights.linuxfoundation.org/project/aiohttp |
| | | :alt: LFX Health Score |
| New file |
| | |
| | | aiohttp-3.13.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 |
| | | aiohttp-3.13.3.dist-info/METADATA,sha256=jkzui8KtHZ32gb8TfFZwIW4-zZ6Sr1eh1R6wYZW79Sg,8407 |
| | | aiohttp-3.13.3.dist-info/RECORD,, |
| | | aiohttp-3.13.3.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 |
| | | aiohttp-3.13.3.dist-info/WHEEL,sha256=8UP9x9puWI0P1V_d7K2oMTBqfeLNm21CTzZ_Ptr0NXU,101 |
| | | aiohttp-3.13.3.dist-info/licenses/LICENSE.txt,sha256=wUk-nxDVnR-6n53ygAjhVX4zz5-6yM4SY6ozk5goA94,601 |
| | | aiohttp-3.13.3.dist-info/licenses/vendor/llhttp/LICENSE,sha256=bd-mKNt20th7iWi6-61g9RxOyIEA3Xu5b5chbYivCAg,1127 |
| | | aiohttp-3.13.3.dist-info/top_level.txt,sha256=iv-JIaacmTl-hSho3QmphcKnbRRYx1st47yjz_178Ro,8 |
| | | aiohttp/.hash/_cparser.pxd.hash,sha256=eJQ2z7M7WoAng7D5ukCXzE3Yx22bLgv1PyOe0YbbQTM,108 |
| | | aiohttp/.hash/_find_header.pxd.hash,sha256=TxG5w4etbVd6sfm5JWbdf5PW6LnuXRQnlMoFBVGKN2E,112 |
| | | aiohttp/.hash/_http_parser.pyx.hash,sha256=NYbk_8ETW0vAtpTcxRVuWVmKJr9CUh2fR8I9emVQck4,112 |
| | | aiohttp/.hash/_http_writer.pyx.hash,sha256=J4W44iDZQwIyZ0rGO5v-_sKIfPtAwqn99EwgaevQmo8,112 |
| | | aiohttp/.hash/hdrs.py.hash,sha256=c2N-IMHz4dvAGL36CUyEw15noHE2AkJTeSBy3IxcCec,103 |
| | | aiohttp/__init__.py,sha256=wTWxnyVGn59VuoFuK1m2_jJ-Cw5Be9ktp7h5Hfvyaas,8580 |
| | | aiohttp/__pycache__/__init__.cpython-312.pyc,, |
| | | aiohttp/__pycache__/_cookie_helpers.cpython-312.pyc,, |
| | | aiohttp/__pycache__/abc.cpython-312.pyc,, |
| | | aiohttp/__pycache__/base_protocol.cpython-312.pyc,, |
| | | aiohttp/__pycache__/client.cpython-312.pyc,, |
| | | aiohttp/__pycache__/client_exceptions.cpython-312.pyc,, |
| | | aiohttp/__pycache__/client_middleware_digest_auth.cpython-312.pyc,, |
| | | aiohttp/__pycache__/client_middlewares.cpython-312.pyc,, |
| | | aiohttp/__pycache__/client_proto.cpython-312.pyc,, |
| | | aiohttp/__pycache__/client_reqrep.cpython-312.pyc,, |
| | | aiohttp/__pycache__/client_ws.cpython-312.pyc,, |
| | | aiohttp/__pycache__/compression_utils.cpython-312.pyc,, |
| | | aiohttp/__pycache__/connector.cpython-312.pyc,, |
| | | aiohttp/__pycache__/cookiejar.cpython-312.pyc,, |
| | | aiohttp/__pycache__/formdata.cpython-312.pyc,, |
| | | aiohttp/__pycache__/hdrs.cpython-312.pyc,, |
| | | aiohttp/__pycache__/helpers.cpython-312.pyc,, |
| | | aiohttp/__pycache__/http.cpython-312.pyc,, |
| | | aiohttp/__pycache__/http_exceptions.cpython-312.pyc,, |
| | | aiohttp/__pycache__/http_parser.cpython-312.pyc,, |
| | | aiohttp/__pycache__/http_websocket.cpython-312.pyc,, |
| | | aiohttp/__pycache__/http_writer.cpython-312.pyc,, |
| | | aiohttp/__pycache__/log.cpython-312.pyc,, |
| | | aiohttp/__pycache__/multipart.cpython-312.pyc,, |
| | | aiohttp/__pycache__/payload.cpython-312.pyc,, |
| | | aiohttp/__pycache__/payload_streamer.cpython-312.pyc,, |
| | | aiohttp/__pycache__/pytest_plugin.cpython-312.pyc,, |
| | | aiohttp/__pycache__/resolver.cpython-312.pyc,, |
| | | aiohttp/__pycache__/streams.cpython-312.pyc,, |
| | | aiohttp/__pycache__/tcp_helpers.cpython-312.pyc,, |
| | | aiohttp/__pycache__/test_utils.cpython-312.pyc,, |
| | | aiohttp/__pycache__/tracing.cpython-312.pyc,, |
| | | aiohttp/__pycache__/typedefs.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_app.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_exceptions.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_fileresponse.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_log.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_middlewares.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_protocol.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_request.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_response.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_routedef.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_runner.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_server.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_urldispatcher.cpython-312.pyc,, |
| | | aiohttp/__pycache__/web_ws.cpython-312.pyc,, |
| | | aiohttp/__pycache__/worker.cpython-312.pyc,, |
| | | aiohttp/_cookie_helpers.py,sha256=x6tVKd6fgqjIFQzQ_z-t_CRl-Pnar7qJh8HUwroSKIA,13997 |
| | | aiohttp/_cparser.pxd,sha256=GP0Y9NqZYQGkJtS81XDzU70e7rRMb34TR7yGMmx5_zs,4453 |
| | | aiohttp/_find_header.pxd,sha256=BFUSmxhemBtblqxzjzH3x03FfxaWlTyuAIOz8YZ5_nM,70 |
| | | aiohttp/_headers.pxi,sha256=1MhCe6Un_KI1tpO85HnDfzVO94BhcirLanAOys5FIHA,2090 |
| | | aiohttp/_http_parser.cp312-win_amd64.pyd,sha256=kVErC3Q1vBoeaoCynkMwWayfaXk4Ju-VaWbOVdGcwB8,248832 |
| | | aiohttp/_http_parser.pyx,sha256=9-jyYF9-4i7ToMV0mvVgQ_rqNa8KGJfhQVY0GGrZuGg,29096 |
| | | aiohttp/_http_writer.cp312-win_amd64.pyd,sha256=e2t5uBtwmasH8kAxdg6QOvalydEl5-m3n46J4WSffiI,47104 |
| | | aiohttp/_http_writer.pyx,sha256=WWdOf19QPqScBkifDhJynqPPOAmwB9sKJAO0Kkor4tE,4826 |
| | | aiohttp/_websocket/.hash/mask.pxd.hash,sha256=TL0gGYyJWxqG8dWwa08B74WGg6-0M6_Breqrff-AiZg,115 |
| | | aiohttp/_websocket/.hash/mask.pyx.hash,sha256=7xo6f01JaOQmaUNij3dQlOgxkEC1edkAIhwpeOvimLI,115 |
| | | aiohttp/_websocket/.hash/reader_c.pxd.hash,sha256=RzhqjHN1HadWDeMHVQvaf-XLlGxF6nm5u-HJHGsx2aE,119 |
| | | aiohttp/_websocket/__init__.py,sha256=R51KWH5kkdtDLb7T-ilztksbfweKCy3t22SgxGtiY-4,45 |
| | | aiohttp/_websocket/__pycache__/__init__.cpython-312.pyc,, |
| | | aiohttp/_websocket/__pycache__/helpers.cpython-312.pyc,, |
| | | aiohttp/_websocket/__pycache__/models.cpython-312.pyc,, |
| | | aiohttp/_websocket/__pycache__/reader.cpython-312.pyc,, |
| | | aiohttp/_websocket/__pycache__/reader_c.cpython-312.pyc,, |
| | | aiohttp/_websocket/__pycache__/reader_py.cpython-312.pyc,, |
| | | aiohttp/_websocket/__pycache__/writer.cpython-312.pyc,, |
| | | aiohttp/_websocket/helpers.py,sha256=amqvDhoAKAi8ptB4qUNuQhkaOn-4JxSh_VLAqytmEfw,5185 |
| | | aiohttp/_websocket/mask.cp312-win_amd64.pyd,sha256=Q7mH9VajqPagYj6NGCurPmwJWcMZU07zN4FEkfUAP_c,36864 |
| | | aiohttp/_websocket/mask.pxd,sha256=41TdSZvhcbYSW_Vrw7bF4r_yoor2njtdaZ3bmvK6-jw,115 |
| | | aiohttp/_websocket/mask.pyx,sha256=Ro7dOOv43HAAqNMz3xyCA11ppcn-vARIvjycStTEYww,1445 |
| | | aiohttp/_websocket/models.py,sha256=Pz8qvnU43VUCNZcY4g03VwTsHOsb_jSN8iG69xMAc_A,2205 |
| | | aiohttp/_websocket/reader.py,sha256=1r0cJ-jdFgbSrC6-jI0zjEA1CppzoUn8u_wiebrVVO0,1061 |
| | | aiohttp/_websocket/reader_c.cp312-win_amd64.pyd,sha256=2gSIJBH5w8xkfbErzqeI_MTILdr4gR4Pc4ytNj_jaD0,147968 |
| | | aiohttp/_websocket/reader_c.pxd,sha256=HNOl4gRWtNBNEYNbK9PGOfFEQwUqJGexBbDKB_20sl0,2735 |
| | | aiohttp/_websocket/reader_c.py,sha256=UKfslJuANla_CQMe7yIJzE8vp7bpzz9TLr-lH87XW6U,19346 |
| | | aiohttp/_websocket/reader_py.py,sha256=UKfslJuANla_CQMe7yIJzE8vp7bpzz9TLr-lH87XW6U,19346 |
| | | aiohttp/_websocket/writer.py,sha256=MpuNvG_t34CaDTAzW5FZJaRME8sL19rZotxSbXz2aas,11523 |
| | | aiohttp/abc.py,sha256=01N6Y63o2bBC8Vi0ZjO6Jw0V9kXZfy3egwzKFW-tv9c,7417 |
| | | aiohttp/base_protocol.py,sha256=8vNIv6QV_SDCW-8tfhlyxSwiBD7dAiMTqJI1GI8RG5s,3125 |
| | | aiohttp/client.py,sha256=KlWhIZt935YpOZcXOOZl3eIRkuO-l0z2BH7arfhGg-A,59992 |
| | | aiohttp/client_exceptions.py,sha256=sJcuvYKaB2nwuSdP7k18y3wc74aU0xAzdJikzzesrPE,11788 |
| | | aiohttp/client_middleware_digest_auth.py,sha256=K4TPt4-rPQ0jjSHx3UFguMN7n31LpCC_o6JA-Hrg_Pc,18107 |
| | | aiohttp/client_middlewares.py,sha256=FEVIXFkQ58n5bhK4BGEqqDCWnDh-GNJmWq20I5Yt6SU,1973 |
| | | aiohttp/client_proto.py,sha256=rfbg8nUsfpCMM_zGpQygiFn8nzSdBI-731rmXVGHwLc,12469 |
| | | aiohttp/client_reqrep.py,sha256=BUrqo2BJbrNazrIJr-ZgMLRTvE2fSON3zPQSq1dfgfU,54927 |
| | | aiohttp/client_ws.py,sha256=9DraHuupuJcT7NOgyeGml8SBr7V5D5ID5-piY1fQMdA,15537 |
| | | aiohttp/compression_utils.py,sha256=w0ECGGLVjtCXdYg-U_9DBn-DASzDPaWEVRx1HlwWslk,12086 |
| | | aiohttp/connector.py,sha256=X2sRe6EAeWiaP6eaK9hWvLtSbdiJfNhK3bWl7XbR_V4,70846 |
| | | aiohttp/cookiejar.py,sha256=C2fVzQGFieFP9mFDTOvfEc6fb5kPS2ijL2tFKAUW7Sw,19444 |
| | | aiohttp/formdata.py,sha256=sz3VaTHVk11z_5G1LaDhUwrONJ8zRAGlZGg3hcCApzA,6563 |
| | | aiohttp/hdrs.py,sha256=7htmhgZyE9HqWbPpxHU0r7kAIdT2kpOXQa1AadDh2W8,5232 |
| | | aiohttp/helpers.py,sha256=1tXIvGSRWJD9wsS7GUVHLfJEsDM_XigurpgjxajkH0g,31615 |
| | | aiohttp/http.py,sha256=DGKcwDbgIMpasv7s2jeKCRuixyj7W-RIrihRFjj0xcY,1914 |
| | | aiohttp/http_exceptions.py,sha256=J3v-1S9S22GfAEtx0pEqp6d4G1Lqi2-gOrdLtuGlEhY,3185 |
| | | aiohttp/http_parser.py,sha256=O5ud4wO80WLFe9kpXU0xGhjczUfrb7BAr0XAP7rBn7E,39263 |
| | | aiohttp/http_websocket.py,sha256=b9kBmxPLPFQP_nu_sMhIMIeqDOm0ug8G4prbrhEMHZ0,878 |
| | | aiohttp/http_writer.py,sha256=jA_aJW7JdH1mihrIYdJcLOHVKQ4Agg3g993v50eITBs,12824 |
| | | aiohttp/log.py,sha256=zYUTvXsMQ9Sz1yNN8kXwd5Qxu49a1FzjZ_wQqriEc8M,333 |
| | | aiohttp/multipart.py,sha256=UvcLOX3lO3ad3nfODsdlyvYWMAZHdUZ-wlZ5w1TbD2E,41634 |
| | | aiohttp/payload.py,sha256=Xbs_2l0wDaThFG-ehNlvzQUkHuBPpc5FxpJnJa3ZPcs,41994 |
| | | aiohttp/payload_streamer.py,sha256=K0iV85iW0vEG3rDkcopruidspynzQvrwW8mJvgPHisg,2289 |
| | | aiohttp/py.typed,sha256=3VVwXUAWVEVX7sDwyYDnW5ZdBC9_Z9AJAFfLCleUW0k,8 |
| | | aiohttp/pytest_plugin.py,sha256=ymhjbYHz2Kf0ZU_4Ly0hAp73dhsgrQIzJDo4Aot3_TI,13345 |
| | | aiohttp/resolver.py,sha256=ePJgZAN5EQY4YuFiuZmVZM6p3UuzJ4qMWM1fu8DJ2Fc,10305 |
| | | aiohttp/streams.py,sha256=J0G4ZJPdRScOPtnaB1ixhQYjLunLk8z70mfN9bc5K_o,24424 |
| | | aiohttp/tcp_helpers.py,sha256=K-hhGh3jd6qCEnHJo8LvFyfJwBjh99UKI7A0aSRVhj4,998 |
| | | aiohttp/test_utils.py,sha256=zFWAb-rPz1fWRUHnrjnfUH7ORlfIgZ2UZbEGe4YTa9I,23790 |
| | | aiohttp/tracing.py,sha256=Kb-N32aMmYqC2Yc82NV6l0mIcavSQst1BHSFj94Apl0,15013 |
| | | aiohttp/typedefs.py,sha256=Sx5v2yUyLu8nbabqtJRWj1M1_uW0IZACu78uYD7LBy0,1726 |
| | | aiohttp/web.py,sha256=BQ96NEuTWikKGN5NnnTHjFLt07GUMWvvn42iFuIS3Mg,18444 |
| | | aiohttp/web_app.py,sha256=WwEEzUg34j81kK2dPFnhlqx_z6nGjnHZDweZJF65pKc,20072 |
| | | aiohttp/web_exceptions.py,sha256=itNRhCMDJFhnMWftr5SyTsoqh-i0n9rzTj0sjcAEUjo,10812 |
| | | aiohttp/web_fileresponse.py,sha256=QIIbcIruCgfYrc8ZDvOgNlZzLbAagwXA9FrNI7NKNPY,16780 |
| | | aiohttp/web_log.py,sha256=G5ugloW9noUxPft0SmVWOXw30MviL6rqZc3XrKN_T1U,8081 |
| | | aiohttp/web_middlewares.py,sha256=mM2-R8eaV2r6Mi9Zc2bDG8QnhE9h0IzPvtDX_fkKR5s,4286 |
| | | aiohttp/web_protocol.py,sha256=gJaDFtYPA-1gz35fwchjLhxrkmXXMOzFMCDHLQ1FHiI,27802 |
| | | aiohttp/web_request.py,sha256=9zqyP32ScMUylQ_ta4tBHpWmoprhSB4jTgj2ixmGK74,30763 |
| | | aiohttp/web_response.py,sha256=WJVumt-P0uMaFSbef_owvOXpq90E4VMl3RvSOWh0nJE,30197 |
| | | aiohttp/web_routedef.py,sha256=XC10f57Q36JmYaaQqrecsyfIxHMepCKaKkBEB7hLzJI,6324 |
| | | aiohttp/web_runner.py,sha256=zyVYVzCgnopiGwnIhKlNZHtLV_IYQ9aC-Vm43j_HRoA,12185 |
| | | aiohttp/web_server.py,sha256=RZSWt_Mj-Lu89bFYsr_T3rjxW2VNN7PHNJ2mvv2qELs,2972 |
| | | aiohttp/web_urldispatcher.py,sha256=4FiNFUWU_jITYl_DnObptuF5c0ShXAEiWyLVmE-GtN0,45595 |
| | | aiohttp/web_ws.py,sha256=VXHGDtfy_jrBByLvuhnL-A_PmpcoT_ZLyYdj_EcL3Hw,23370 |
| | | aiohttp/worker.py,sha256=N_9iyS_tR9U0pf3BRaIH2nzA1pjN1Xfi2gGmRrMhnho,8407 |
| New file |
| | |
| | | Wheel-Version: 1.0 |
| | | Generator: setuptools (80.9.0) |
| | | Root-Is-Purelib: false |
| | | Tag: cp312-cp312-win_amd64 |
| | | |
| New file |
| | |
| | | Copyright aio-libs contributors. |
| | | |
| | | Licensed under the Apache License, Version 2.0 (the "License"); |
| | | you may not use this file except in compliance with the License. |
| | | You may obtain a copy of the License at |
| | | |
| | | http://www.apache.org/licenses/LICENSE-2.0 |
| | | |
| | | Unless required by applicable law or agreed to in writing, software |
| | | distributed under the License is distributed on an "AS IS" BASIS, |
| | | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| | | See the License for the specific language governing permissions and |
| | | limitations under the License. |
| New file |
| | |
| | | This software is licensed under the MIT License. |
| | | |
| | | Copyright Fedor Indutny, 2018. |
| | | |
| | | Permission is hereby granted, free of charge, to any person obtaining a |
| | | copy of this software and associated documentation files (the |
| | | "Software"), to deal in the Software without restriction, including |
| | | without limitation the rights to use, copy, modify, merge, publish, |
| | | distribute, sublicense, and/or sell copies of the Software, and to permit |
| | | persons to whom the Software is furnished to do so, subject to the |
| | | following conditions: |
| | | |
| | | The above copyright notice and this permission notice shall be included |
| | | in all copies or substantial portions of the Software. |
| | | |
| | | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
| | | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
| | | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN |
| | | NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, |
| | | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR |
| | | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE |
| | | USE OR OTHER DEALINGS IN THE SOFTWARE. |
| New file |
| | |
| | | 18fd18f4da996101a426d4bcd570f353bd1eeeb44c6f7e1347bc86326c79ff3b *D:/a/aiohttp/aiohttp/aiohttp/_cparser.pxd |
| New file |
| | |
| | | 0455129b185e981b5b96ac738f31f7c74dc57f1696953cae0083b3f18679fe73 *D:/a/aiohttp/aiohttp/aiohttp/_find_header.pxd |
| New file |
| | |
| | | f7e8f2605f7ee22ed3a0c5749af56043faea35af0a1897e1415634186ad9b868 *D:/a/aiohttp/aiohttp/aiohttp/_http_parser.pyx |
| New file |
| | |
| | | 59674e7f5f503ea49c06489f0e12729ea3cf3809b007db0a2403b42a4a2be2d1 *D:/a/aiohttp/aiohttp/aiohttp/_http_writer.pyx |
| New file |
| | |
| | | ee1b6686067213d1ea59b3e9c47534afb90021d4f692939741ad4069d0e1d96f *D:/a/aiohttp/aiohttp/aiohttp/hdrs.py |
| New file |
| | |
| | | __version__ = "3.13.3" |
| | | |
| | | from typing import TYPE_CHECKING, Tuple |
| | | |
| | | from . import hdrs as hdrs |
| | | from .client import ( |
| | | BaseConnector, |
| | | ClientConnectionError, |
| | | ClientConnectionResetError, |
| | | ClientConnectorCertificateError, |
| | | ClientConnectorDNSError, |
| | | ClientConnectorError, |
| | | ClientConnectorSSLError, |
| | | ClientError, |
| | | ClientHttpProxyError, |
| | | ClientOSError, |
| | | ClientPayloadError, |
| | | ClientProxyConnectionError, |
| | | ClientRequest, |
| | | ClientResponse, |
| | | ClientResponseError, |
| | | ClientSession, |
| | | ClientSSLError, |
| | | ClientTimeout, |
| | | ClientWebSocketResponse, |
| | | ClientWSTimeout, |
| | | ConnectionTimeoutError, |
| | | ContentTypeError, |
| | | Fingerprint, |
| | | InvalidURL, |
| | | InvalidUrlClientError, |
| | | InvalidUrlRedirectClientError, |
| | | NamedPipeConnector, |
| | | NonHttpUrlClientError, |
| | | NonHttpUrlRedirectClientError, |
| | | RedirectClientError, |
| | | RequestInfo, |
| | | ServerConnectionError, |
| | | ServerDisconnectedError, |
| | | ServerFingerprintMismatch, |
| | | ServerTimeoutError, |
| | | SocketTimeoutError, |
| | | TCPConnector, |
| | | TooManyRedirects, |
| | | UnixConnector, |
| | | WSMessageTypeError, |
| | | WSServerHandshakeError, |
| | | request, |
| | | ) |
| | | from .client_middleware_digest_auth import DigestAuthMiddleware |
| | | from .client_middlewares import ClientHandlerType, ClientMiddlewareType |
| | | from .compression_utils import set_zlib_backend |
| | | from .connector import ( |
| | | AddrInfoType as AddrInfoType, |
| | | SocketFactoryType as SocketFactoryType, |
| | | ) |
| | | from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar |
| | | from .formdata import FormData as FormData |
| | | from .helpers import BasicAuth, ChainMapProxy, ETag |
| | | from .http import ( |
| | | HttpVersion as HttpVersion, |
| | | HttpVersion10 as HttpVersion10, |
| | | HttpVersion11 as HttpVersion11, |
| | | WebSocketError as WebSocketError, |
| | | WSCloseCode as WSCloseCode, |
| | | WSMessage as WSMessage, |
| | | WSMsgType as WSMsgType, |
| | | ) |
| | | from .multipart import ( |
| | | BadContentDispositionHeader as BadContentDispositionHeader, |
| | | BadContentDispositionParam as BadContentDispositionParam, |
| | | BodyPartReader as BodyPartReader, |
| | | MultipartReader as MultipartReader, |
| | | MultipartWriter as MultipartWriter, |
| | | content_disposition_filename as content_disposition_filename, |
| | | parse_content_disposition as parse_content_disposition, |
| | | ) |
| | | from .payload import ( |
| | | PAYLOAD_REGISTRY as PAYLOAD_REGISTRY, |
| | | AsyncIterablePayload as AsyncIterablePayload, |
| | | BufferedReaderPayload as BufferedReaderPayload, |
| | | BytesIOPayload as BytesIOPayload, |
| | | BytesPayload as BytesPayload, |
| | | IOBasePayload as IOBasePayload, |
| | | JsonPayload as JsonPayload, |
| | | Payload as Payload, |
| | | StringIOPayload as StringIOPayload, |
| | | StringPayload as StringPayload, |
| | | TextIOPayload as TextIOPayload, |
| | | get_payload as get_payload, |
| | | payload_type as payload_type, |
| | | ) |
| | | from .payload_streamer import streamer as streamer |
| | | from .resolver import ( |
| | | AsyncResolver as AsyncResolver, |
| | | DefaultResolver as DefaultResolver, |
| | | ThreadedResolver as ThreadedResolver, |
| | | ) |
| | | from .streams import ( |
| | | EMPTY_PAYLOAD as EMPTY_PAYLOAD, |
| | | DataQueue as DataQueue, |
| | | EofStream as EofStream, |
| | | FlowControlDataQueue as FlowControlDataQueue, |
| | | StreamReader as StreamReader, |
| | | ) |
| | | from .tracing import ( |
| | | TraceConfig as TraceConfig, |
| | | TraceConnectionCreateEndParams as TraceConnectionCreateEndParams, |
| | | TraceConnectionCreateStartParams as TraceConnectionCreateStartParams, |
| | | TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams, |
| | | TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams, |
| | | TraceConnectionReuseconnParams as TraceConnectionReuseconnParams, |
| | | TraceDnsCacheHitParams as TraceDnsCacheHitParams, |
| | | TraceDnsCacheMissParams as TraceDnsCacheMissParams, |
| | | TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams, |
| | | TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams, |
| | | TraceRequestChunkSentParams as TraceRequestChunkSentParams, |
| | | TraceRequestEndParams as TraceRequestEndParams, |
| | | TraceRequestExceptionParams as TraceRequestExceptionParams, |
| | | TraceRequestHeadersSentParams as TraceRequestHeadersSentParams, |
| | | TraceRequestRedirectParams as TraceRequestRedirectParams, |
| | | TraceRequestStartParams as TraceRequestStartParams, |
| | | TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams, |
| | | ) |
| | | |
| | | if TYPE_CHECKING: |
| | | # At runtime these are lazy-loaded at the bottom of the file. |
| | | from .worker import ( |
| | | GunicornUVLoopWebWorker as GunicornUVLoopWebWorker, |
| | | GunicornWebWorker as GunicornWebWorker, |
| | | ) |
| | | |
| | | __all__: Tuple[str, ...] = ( |
| | | "hdrs", |
| | | # client |
| | | "AddrInfoType", |
| | | "BaseConnector", |
| | | "ClientConnectionError", |
| | | "ClientConnectionResetError", |
| | | "ClientConnectorCertificateError", |
| | | "ClientConnectorDNSError", |
| | | "ClientConnectorError", |
| | | "ClientConnectorSSLError", |
| | | "ClientError", |
| | | "ClientHttpProxyError", |
| | | "ClientOSError", |
| | | "ClientPayloadError", |
| | | "ClientProxyConnectionError", |
| | | "ClientResponse", |
| | | "ClientRequest", |
| | | "ClientResponseError", |
| | | "ClientSSLError", |
| | | "ClientSession", |
| | | "ClientTimeout", |
| | | "ClientWebSocketResponse", |
| | | "ClientWSTimeout", |
| | | "ConnectionTimeoutError", |
| | | "ContentTypeError", |
| | | "Fingerprint", |
| | | "FlowControlDataQueue", |
| | | "InvalidURL", |
| | | "InvalidUrlClientError", |
| | | "InvalidUrlRedirectClientError", |
| | | "NonHttpUrlClientError", |
| | | "NonHttpUrlRedirectClientError", |
| | | "RedirectClientError", |
| | | "RequestInfo", |
| | | "ServerConnectionError", |
| | | "ServerDisconnectedError", |
| | | "ServerFingerprintMismatch", |
| | | "ServerTimeoutError", |
| | | "SocketFactoryType", |
| | | "SocketTimeoutError", |
| | | "TCPConnector", |
| | | "TooManyRedirects", |
| | | "UnixConnector", |
| | | "NamedPipeConnector", |
| | | "WSServerHandshakeError", |
| | | "request", |
| | | # client_middleware |
| | | "ClientMiddlewareType", |
| | | "ClientHandlerType", |
| | | # cookiejar |
| | | "CookieJar", |
| | | "DummyCookieJar", |
| | | # formdata |
| | | "FormData", |
| | | # helpers |
| | | "BasicAuth", |
| | | "ChainMapProxy", |
| | | "DigestAuthMiddleware", |
| | | "ETag", |
| | | "set_zlib_backend", |
| | | # http |
| | | "HttpVersion", |
| | | "HttpVersion10", |
| | | "HttpVersion11", |
| | | "WSMsgType", |
| | | "WSCloseCode", |
| | | "WSMessage", |
| | | "WebSocketError", |
| | | # multipart |
| | | "BadContentDispositionHeader", |
| | | "BadContentDispositionParam", |
| | | "BodyPartReader", |
| | | "MultipartReader", |
| | | "MultipartWriter", |
| | | "content_disposition_filename", |
| | | "parse_content_disposition", |
| | | # payload |
| | | "AsyncIterablePayload", |
| | | "BufferedReaderPayload", |
| | | "BytesIOPayload", |
| | | "BytesPayload", |
| | | "IOBasePayload", |
| | | "JsonPayload", |
| | | "PAYLOAD_REGISTRY", |
| | | "Payload", |
| | | "StringIOPayload", |
| | | "StringPayload", |
| | | "TextIOPayload", |
| | | "get_payload", |
| | | "payload_type", |
| | | # payload_streamer |
| | | "streamer", |
| | | # resolver |
| | | "AsyncResolver", |
| | | "DefaultResolver", |
| | | "ThreadedResolver", |
| | | # streams |
| | | "DataQueue", |
| | | "EMPTY_PAYLOAD", |
| | | "EofStream", |
| | | "StreamReader", |
| | | # tracing |
| | | "TraceConfig", |
| | | "TraceConnectionCreateEndParams", |
| | | "TraceConnectionCreateStartParams", |
| | | "TraceConnectionQueuedEndParams", |
| | | "TraceConnectionQueuedStartParams", |
| | | "TraceConnectionReuseconnParams", |
| | | "TraceDnsCacheHitParams", |
| | | "TraceDnsCacheMissParams", |
| | | "TraceDnsResolveHostEndParams", |
| | | "TraceDnsResolveHostStartParams", |
| | | "TraceRequestChunkSentParams", |
| | | "TraceRequestEndParams", |
| | | "TraceRequestExceptionParams", |
| | | "TraceRequestHeadersSentParams", |
| | | "TraceRequestRedirectParams", |
| | | "TraceRequestStartParams", |
| | | "TraceResponseChunkReceivedParams", |
| | | # workers (imported lazily with __getattr__) |
| | | "GunicornUVLoopWebWorker", |
| | | "GunicornWebWorker", |
| | | "WSMessageTypeError", |
| | | ) |
| | | |
| | | |
| | | def __dir__() -> Tuple[str, ...]: |
| | | return __all__ + ("__doc__",) |
| | | |
| | | |
| | | def __getattr__(name: str) -> object: |
| | | global GunicornUVLoopWebWorker, GunicornWebWorker |
| | | |
| | | # Importing gunicorn takes a long time (>100ms), so only import if actually needed. |
| | | if name in ("GunicornUVLoopWebWorker", "GunicornWebWorker"): |
| | | try: |
| | | from .worker import GunicornUVLoopWebWorker as guv, GunicornWebWorker as gw |
| | | except ImportError: |
| | | return None |
| | | |
| | | GunicornUVLoopWebWorker = guv # type: ignore[misc] |
| | | GunicornWebWorker = gw # type: ignore[misc] |
| | | return guv if name == "GunicornUVLoopWebWorker" else gw |
| | | |
| | | raise AttributeError(f"module {__name__} has no attribute {name}") |
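The `__getattr__` hook above is the PEP 562 lazy-loading pattern: the gunicorn workers are only imported on first attribute access, then cached as module attributes so the hook is bypassed afterwards. A minimal, self-contained sketch of the same pattern (the module name `demo` and the attribute `heavy` are illustrative stand-ins, not part of aiohttp):

```python
import sys
import types

# Build a throwaway module that lazily imports "json" on first access,
# mirroring how aiohttp defers its gunicorn worker imports (PEP 562).
mod = types.ModuleType("demo")

def _module_getattr(name: str) -> object:
    if name == "heavy":
        import json as heavy  # stand-in for an expensive import
        mod.heavy = heavy     # cache so __getattr__ is not hit again
        return heavy
    raise AttributeError(f"module 'demo' has no attribute {name!r}")

# A function named __getattr__ in the module's namespace is consulted
# whenever normal attribute lookup fails (Python 3.7+).
mod.__getattr__ = _module_getattr
sys.modules["demo"] = mod

import demo
import json

assert demo.heavy is json     # lazily imported on first access
assert "heavy" in vars(demo)  # cached afterwards
```

Returning `None` on `ImportError`, as the aiohttp version does, keeps attribute access working on installs without gunicorn instead of raising at import time.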
| New file |
| | |
| | | """ |
| | | Internal cookie handling helpers. |
| | | |
| | | This module contains internal utilities for cookie parsing and manipulation. |
| | | These are not part of the public API and may change without notice. |
| | | """ |
| | | |
| | | import re |
| | | from http.cookies import Morsel |
| | | from typing import List, Optional, Sequence, Tuple, cast |
| | | |
| | | from .log import internal_logger |
| | | |
| | | __all__ = ( |
| | | "parse_set_cookie_headers", |
| | | "parse_cookie_header", |
| | | "preserve_morsel_with_coded_value", |
| | | ) |
| | | |
| | | # Cookie parsing constants |
| | | # Allow more characters in cookie names to handle real-world cookies |
| | | # that don't strictly follow RFC standards (fixes #2683) |
| | | # RFC 6265 defines cookie-name token as per RFC 2616 Section 2.2, |
| | | # but many servers send cookies with characters like {} [] () etc. |
| | | # This makes the cookie parser more tolerant of real-world cookies |
| | | # while still providing some validation to catch obviously malformed names. |
| | | _COOKIE_NAME_RE = re.compile(r"^[!#$%&\'()*+\-./0-9:<=>?@A-Z\[\]^_`a-z{|}~]+$") |
| | | _COOKIE_KNOWN_ATTRS = frozenset( # AKA Morsel._reserved |
| | | ( |
| | | "path", |
| | | "domain", |
| | | "max-age", |
| | | "expires", |
| | | "secure", |
| | | "httponly", |
| | | "samesite", |
| | | "partitioned", |
| | | "version", |
| | | "comment", |
| | | ) |
| | | ) |
| | | _COOKIE_BOOL_ATTRS = frozenset( # AKA Morsel._flags |
| | | ("secure", "httponly", "partitioned") |
| | | ) |
| | | |
| | | # SimpleCookie's pattern for parsing cookies with relaxed validation |
| | | # Based on http.cookies pattern but extended to allow more characters in cookie names |
| | | # to handle real-world cookies (fixes #2683) |
| | | _COOKIE_PATTERN = re.compile( |
| | | r""" |
| | | \s* # Optional whitespace at start of cookie |
| | | (?P<key> # Start of group 'key' |
| | | # aiohttp has extended to include [] for compatibility with real-world cookies |
| | | [\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\[\]]+ # Any word of at least one letter |
| | | ) # End of group 'key' |
| | | ( # Optional group: there may not be a value. |
| | | \s*=\s* # Equal Sign |
| | | (?P<val> # Start of group 'val' |
| | | "(?:[^\\"]|\\.)*" # Any double-quoted string (properly closed) |
| | | | # or |
| | | "[^";]* # Unmatched opening quote (differs from SimpleCookie - issue #7993) |
| | | | # or |
| | | # Special case for "expires" attr - RFC 822, RFC 850, RFC 1036, RFC 1123 |
| | | (\w{3,6}day|\w{3}),\s # Day of the week or abbreviated day (with comma) |
| | | [\w\d\s-]{9,11}\s[\d:]{8}\s # Date and time in specific format |
| | | (GMT|[+-]\d{4}) # Timezone: GMT or RFC 2822 offset like -0000, +0100 |
| | | # NOTE: RFC 2822 timezone support is an aiohttp extension |
| | | # for issue #4493 - SimpleCookie does NOT support this |
| | | | # or |
| | | # ANSI C asctime() format: "Wed Jun 9 10:18:14 2021" |
| | | # NOTE: This is an aiohttp extension for issue #4327 - SimpleCookie does NOT support this format |
| | | \w{3}\s+\w{3}\s+[\s\d]\d\s+\d{2}:\d{2}:\d{2}\s+\d{4} |
| | | | # or |
| | | [\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=\[\]]* # Any word or empty string |
| | | ) # End of group 'val' |
| | | )? # End of optional value group |
| | | \s* # Any number of spaces. |
| | | (\s+|;|$) # Ending either at space, semicolon, or EOS. |
| | | """, |
| | | re.VERBOSE | re.ASCII, |
| | | ) |
| | | |
| | | |
| | | def preserve_morsel_with_coded_value(cookie: Morsel[str]) -> Morsel[str]: |
| | | """ |
| | | Preserve a Morsel's coded_value exactly as received from the server. |
| | | |
| | | This function ensures that cookie encoding is preserved exactly as sent by |
| | | the server, which is critical for compatibility with old servers that have |
| | | strict requirements about cookie formats. |
| | | |
| | | This addresses the issue described in https://github.com/aio-libs/aiohttp/pull/1453 |
| | | where Python's SimpleCookie would re-encode cookies, breaking authentication |
| | | with certain servers. |
| | | |
| | | Args: |
| | | cookie: A Morsel object from SimpleCookie |
| | | |
| | | Returns: |
| | | A Morsel object with preserved coded_value |
| | | |
| | | """ |
| | | mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel())) |
| | | # We use __setstate__ instead of the public set() API because it allows us to |
| | | # bypass validation and set already validated state. This is more stable than |
| | | # setting protected attributes directly and unlikely to change since it would |
| | | # break pickling. |
| | | mrsl_val.__setstate__( # type: ignore[attr-defined] |
| | | {"key": cookie.key, "value": cookie.value, "coded_value": cookie.coded_value} |
| | | ) |
| | | return mrsl_val |
| | | |
| | | |
| | | _unquote_sub = re.compile(r"\\(?:([0-3][0-7][0-7])|(.))").sub |
| | | |
| | | |
| | | def _unquote_replace(m: re.Match[str]) -> str: |
| | | """ |
| | | Replace function for _unquote_sub regex substitution. |
| | | |
| | | Handles escaped characters in cookie values: |
| | | - Octal sequences are converted to their character representation |
| | | - Other escaped characters are unescaped by removing the backslash |
| | | """ |
| | | if m[1]: |
| | | return chr(int(m[1], 8)) |
| | | return m[2] |
| | | |
| | | |
| | | def _unquote(value: str) -> str: |
| | | """ |
| | | Unquote a cookie value. |
| | | |
| | | Vendored from http.cookies._unquote to ensure compatibility. |
| | | |
| | | Note: The original implementation checked for None, but we've removed |
| | | that check since all callers already ensure the value is not None. |
| | | """ |
| | | # If there aren't any doublequotes, |
| | | # then there can't be any special characters. See RFC 2109. |
| | | if len(value) < 2: |
| | | return value |
| | | if value[0] != '"' or value[-1] != '"': |
| | | return value |
| | | |
| | | # We have to assume that we must decode this string. |
| | | # Down to work. |
| | | |
| | | # Remove the "s |
| | | value = value[1:-1] |
| | | |
| | | # Check for special sequences. Examples: |
| | | # \012 --> \n |
| | | # \" --> " |
| | | # |
| | | return _unquote_sub(_unquote_replace, value) |
| | | |
| | | |
| | | def parse_cookie_header(header: str) -> List[Tuple[str, Morsel[str]]]: |
| | | """ |
| | | Parse a Cookie header according to RFC 6265 Section 5.4. |
| | | |
| | | Cookie headers contain only name-value pairs separated by semicolons. |
| | | There are no attributes in Cookie headers - even names that match |
| | | attribute names (like 'path' or 'secure') should be treated as cookies. |
| | | |
| | | This parser uses the same regex-based approach as parse_set_cookie_headers |
| | | to properly handle quoted values that may contain semicolons. When the |
| | | regex fails to match a malformed cookie, it falls back to simple parsing |
| | |     to ensure subsequent cookies are not lost. See
| | |     https://github.com/aio-libs/aiohttp/issues/11632
| | | |
| | | Args: |
| | | header: The Cookie header value to parse |
| | | |
| | | Returns: |
| | | List of (name, Morsel) tuples for compatibility with SimpleCookie.update() |
| | | """ |
| | | if not header: |
| | | return [] |
| | | |
| | | cookies: List[Tuple[str, Morsel[str]]] = [] |
| | | morsel: Morsel[str] |
| | | i = 0 |
| | | n = len(header) |
| | | |
| | | invalid_names = [] |
| | | while i < n: |
| | | # Use the same pattern as parse_set_cookie_headers to find cookies |
| | | match = _COOKIE_PATTERN.match(header, i) |
| | | if not match: |
| | | # Fallback for malformed cookies https://github.com/aio-libs/aiohttp/issues/11632 |
| | | # Find next semicolon to skip or attempt simple key=value parsing |
| | | next_semi = header.find(";", i) |
| | | eq_pos = header.find("=", i) |
| | | |
| | | # Try to extract key=value if '=' comes before ';' |
| | | if eq_pos != -1 and (next_semi == -1 or eq_pos < next_semi): |
| | | end_pos = next_semi if next_semi != -1 else n |
| | | key = header[i:eq_pos].strip() |
| | | value = header[eq_pos + 1 : end_pos].strip() |
| | | |
| | | # Validate the name (same as regex path) |
| | | if not _COOKIE_NAME_RE.match(key): |
| | | invalid_names.append(key) |
| | | else: |
| | | morsel = Morsel() |
| | | morsel.__setstate__( # type: ignore[attr-defined] |
| | | {"key": key, "value": _unquote(value), "coded_value": value} |
| | | ) |
| | | cookies.append((key, morsel)) |
| | | |
| | | # Move to next cookie or end |
| | | i = next_semi + 1 if next_semi != -1 else n |
| | | continue |
| | | |
| | | key = match.group("key") |
| | | value = match.group("val") or "" |
| | | i = match.end(0) |
| | | |
| | | # Validate the name |
| | | if not key or not _COOKIE_NAME_RE.match(key): |
| | | invalid_names.append(key) |
| | | continue |
| | | |
| | | # Create new morsel |
| | | morsel = Morsel() |
| | | # Preserve the original value as coded_value (with quotes if present) |
| | | # We use __setstate__ instead of the public set() API because it allows us to |
| | | # bypass validation and set already validated state. This is more stable than |
| | | # setting protected attributes directly and unlikely to change since it would |
| | | # break pickling. |
| | | morsel.__setstate__( # type: ignore[attr-defined] |
| | | {"key": key, "value": _unquote(value), "coded_value": value} |
| | | ) |
| | | |
| | | cookies.append((key, morsel)) |
| | | |
| | | if invalid_names: |
| | | internal_logger.debug( |
| | | "Cannot load cookie. Illegal cookie names: %r", invalid_names |
| | | ) |
| | | |
| | | return cookies |
| | | |
| | | |
| | | def parse_set_cookie_headers(headers: Sequence[str]) -> List[Tuple[str, Morsel[str]]]: |
| | | """ |
| | | Parse cookie headers using a vendored version of SimpleCookie parsing. |
| | | |
| | | This implementation is based on SimpleCookie.__parse_string to ensure |
| | | compatibility with how SimpleCookie parses cookies, including handling |
| | | of malformed cookies with missing semicolons. |
| | | |
| | | This function is used for both Cookie and Set-Cookie headers in order to be |
| | | forgiving. Ideally we would have followed RFC 6265 Section 5.2 (for Cookie |
| | | headers) and RFC 6265 Section 4.2.1 (for Set-Cookie headers), but the |
| | | real world data makes it impossible since we need to be a bit more forgiving. |
| | | |
| | | NOTE: This implementation differs from SimpleCookie in handling unmatched quotes. |
| | | SimpleCookie will stop parsing when it encounters a cookie value with an unmatched |
| | | quote (e.g., 'cookie="value'), causing subsequent cookies to be silently dropped. |
| | | This implementation handles unmatched quotes more gracefully to prevent cookie loss. |
| | | See https://github.com/aio-libs/aiohttp/issues/7993 |
| | | """ |
| | | parsed_cookies: List[Tuple[str, Morsel[str]]] = [] |
| | | |
| | | for header in headers: |
| | | if not header: |
| | | continue |
| | | |
| | | # Parse cookie string using SimpleCookie's algorithm |
| | | i = 0 |
| | | n = len(header) |
| | | current_morsel: Optional[Morsel[str]] = None |
| | | morsel_seen = False |
| | | |
| | | while 0 <= i < n: |
| | | # Start looking for a cookie |
| | | match = _COOKIE_PATTERN.match(header, i) |
| | | if not match: |
| | | # No more cookies |
| | | break |
| | | |
| | | key, value = match.group("key"), match.group("val") |
| | | i = match.end(0) |
| | | lower_key = key.lower() |
| | | |
| | | if key[0] == "$": |
| | | if not morsel_seen: |
| | | # We ignore attributes which pertain to the cookie |
| | | # mechanism as a whole, such as "$Version". |
| | | continue |
| | | # Process as attribute |
| | | if current_morsel is not None: |
| | | attr_lower_key = lower_key[1:] |
| | | if attr_lower_key in _COOKIE_KNOWN_ATTRS: |
| | | current_morsel[attr_lower_key] = value or "" |
| | | elif lower_key in _COOKIE_KNOWN_ATTRS: |
| | | if not morsel_seen: |
| | | # Invalid cookie string - attribute before cookie |
| | | break |
| | | if lower_key in _COOKIE_BOOL_ATTRS: |
| | | # Boolean attribute with any value should be True |
| | | if current_morsel is not None and current_morsel.isReservedKey(key): |
| | | current_morsel[lower_key] = True |
| | | elif value is None: |
| | | # Invalid cookie string - non-boolean attribute without value |
| | | break |
| | | elif current_morsel is not None: |
| | | # Regular attribute with value |
| | | current_morsel[lower_key] = _unquote(value) |
| | | elif value is not None: |
| | | # This is a cookie name=value pair |
| | | # Validate the name |
| | | if key in _COOKIE_KNOWN_ATTRS or not _COOKIE_NAME_RE.match(key): |
| | | internal_logger.warning( |
| | | "Can not load cookies: Illegal cookie name %r", key |
| | | ) |
| | | current_morsel = None |
| | | else: |
| | | # Create new morsel |
| | | current_morsel = Morsel() |
| | | # Preserve the original value as coded_value (with quotes if present) |
| | | # We use __setstate__ instead of the public set() API because it allows us to |
| | | # bypass validation and set already validated state. This is more stable than |
| | | # setting protected attributes directly and unlikely to change since it would |
| | | # break pickling. |
| | | current_morsel.__setstate__( # type: ignore[attr-defined] |
| | | {"key": key, "value": _unquote(value), "coded_value": value} |
| | | ) |
| | | parsed_cookies.append((key, current_morsel)) |
| | | morsel_seen = True |
| | | else: |
| | | # Invalid cookie string - no value for non-attribute |
| | | break |
| | | |
| | | return parsed_cookies |
| New file |
| | |
| | | from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t |
| | | |
| | | |
| | | cdef extern from "llhttp.h": |
| | | |
| | | struct llhttp__internal_s: |
| | | int32_t _index |
| | | void* _span_pos0 |
| | | void* _span_cb0 |
| | | int32_t error |
| | | const char* reason |
| | | const char* error_pos |
| | | void* data |
| | | void* _current |
| | | uint64_t content_length |
| | | uint8_t type |
| | | uint8_t method |
| | | uint8_t http_major |
| | | uint8_t http_minor |
| | | uint8_t header_state |
| | | uint8_t lenient_flags |
| | | uint8_t upgrade |
| | | uint8_t finish |
| | | uint16_t flags |
| | | uint16_t status_code |
| | | void* settings |
| | | |
| | | ctypedef llhttp__internal_s llhttp__internal_t |
| | | ctypedef llhttp__internal_t llhttp_t |
| | | |
| | | ctypedef int (*llhttp_data_cb)(llhttp_t*, const char *at, size_t length) except -1 |
| | | ctypedef int (*llhttp_cb)(llhttp_t*) except -1 |
| | | |
| | | struct llhttp_settings_s: |
| | | llhttp_cb on_message_begin |
| | | llhttp_data_cb on_url |
| | | llhttp_data_cb on_status |
| | | llhttp_data_cb on_header_field |
| | | llhttp_data_cb on_header_value |
| | | llhttp_cb on_headers_complete |
| | | llhttp_data_cb on_body |
| | | llhttp_cb on_message_complete |
| | | llhttp_cb on_chunk_header |
| | | llhttp_cb on_chunk_complete |
| | | |
| | | llhttp_cb on_url_complete |
| | | llhttp_cb on_status_complete |
| | | llhttp_cb on_header_field_complete |
| | | llhttp_cb on_header_value_complete |
| | | |
| | | ctypedef llhttp_settings_s llhttp_settings_t |
| | | |
| | | enum llhttp_errno: |
| | | HPE_OK, |
| | | HPE_INTERNAL, |
| | | HPE_STRICT, |
| | | HPE_LF_EXPECTED, |
| | | HPE_UNEXPECTED_CONTENT_LENGTH, |
| | | HPE_CLOSED_CONNECTION, |
| | | HPE_INVALID_METHOD, |
| | | HPE_INVALID_URL, |
| | | HPE_INVALID_CONSTANT, |
| | | HPE_INVALID_VERSION, |
| | | HPE_INVALID_HEADER_TOKEN, |
| | | HPE_INVALID_CONTENT_LENGTH, |
| | | HPE_INVALID_CHUNK_SIZE, |
| | | HPE_INVALID_STATUS, |
| | | HPE_INVALID_EOF_STATE, |
| | | HPE_INVALID_TRANSFER_ENCODING, |
| | | HPE_CB_MESSAGE_BEGIN, |
| | | HPE_CB_HEADERS_COMPLETE, |
| | | HPE_CB_MESSAGE_COMPLETE, |
| | | HPE_CB_CHUNK_HEADER, |
| | | HPE_CB_CHUNK_COMPLETE, |
| | | HPE_PAUSED, |
| | | HPE_PAUSED_UPGRADE, |
| | | HPE_USER |
| | | |
| | | ctypedef llhttp_errno llhttp_errno_t |
| | | |
| | | enum llhttp_flags: |
| | | F_CHUNKED, |
| | | F_CONTENT_LENGTH |
| | | |
| | | enum llhttp_type: |
| | | HTTP_REQUEST, |
| | | HTTP_RESPONSE, |
| | | HTTP_BOTH |
| | | |
| | | enum llhttp_method: |
| | | HTTP_DELETE, |
| | | HTTP_GET, |
| | | HTTP_HEAD, |
| | | HTTP_POST, |
| | | HTTP_PUT, |
| | | HTTP_CONNECT, |
| | | HTTP_OPTIONS, |
| | | HTTP_TRACE, |
| | | HTTP_COPY, |
| | | HTTP_LOCK, |
| | | HTTP_MKCOL, |
| | | HTTP_MOVE, |
| | | HTTP_PROPFIND, |
| | | HTTP_PROPPATCH, |
| | | HTTP_SEARCH, |
| | | HTTP_UNLOCK, |
| | | HTTP_BIND, |
| | | HTTP_REBIND, |
| | | HTTP_UNBIND, |
| | | HTTP_ACL, |
| | | HTTP_REPORT, |
| | | HTTP_MKACTIVITY, |
| | | HTTP_CHECKOUT, |
| | | HTTP_MERGE, |
| | | HTTP_MSEARCH, |
| | | HTTP_NOTIFY, |
| | | HTTP_SUBSCRIBE, |
| | | HTTP_UNSUBSCRIBE, |
| | | HTTP_PATCH, |
| | | HTTP_PURGE, |
| | | HTTP_MKCALENDAR, |
| | | HTTP_LINK, |
| | | HTTP_UNLINK, |
| | | HTTP_SOURCE, |
| | | HTTP_PRI, |
| | | HTTP_DESCRIBE, |
| | | HTTP_ANNOUNCE, |
| | | HTTP_SETUP, |
| | | HTTP_PLAY, |
| | | HTTP_PAUSE, |
| | | HTTP_TEARDOWN, |
| | | HTTP_GET_PARAMETER, |
| | | HTTP_SET_PARAMETER, |
| | | HTTP_REDIRECT, |
| | | HTTP_RECORD, |
| | | HTTP_FLUSH |
| | | |
| | | ctypedef llhttp_method llhttp_method_t; |
| | | |
| | | void llhttp_settings_init(llhttp_settings_t* settings) |
| | | void llhttp_init(llhttp_t* parser, llhttp_type type, |
| | | const llhttp_settings_t* settings) |
| | | |
| | | llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len) |
| | | |
| | | int llhttp_should_keep_alive(const llhttp_t* parser) |
| | | |
| | | void llhttp_resume_after_upgrade(llhttp_t* parser) |
| | | |
| | | llhttp_errno_t llhttp_get_errno(const llhttp_t* parser) |
| | | const char* llhttp_get_error_reason(const llhttp_t* parser) |
| | | const char* llhttp_get_error_pos(const llhttp_t* parser) |
| | | |
| | | const char* llhttp_method_name(llhttp_method_t method) |
| | | |
| | | void llhttp_set_lenient_headers(llhttp_t* parser, int enabled) |
| | | void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled) |
| | | void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled) |
| New file |
| | |
| | | cdef extern from "_find_header.h": |
| | | int find_header(char *, int) |
| New file |
| | |
| | | # The file is autogenerated from aiohttp/hdrs.py |
| | | # Run ./tools/gen.py to update it after the origin changing. |
| | | |
| | | from . import hdrs |
| | | cdef tuple headers = ( |
| | | hdrs.ACCEPT, |
| | | hdrs.ACCEPT_CHARSET, |
| | | hdrs.ACCEPT_ENCODING, |
| | | hdrs.ACCEPT_LANGUAGE, |
| | | hdrs.ACCEPT_RANGES, |
| | | hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, |
| | | hdrs.ACCESS_CONTROL_ALLOW_HEADERS, |
| | | hdrs.ACCESS_CONTROL_ALLOW_METHODS, |
| | | hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, |
| | | hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, |
| | | hdrs.ACCESS_CONTROL_MAX_AGE, |
| | | hdrs.ACCESS_CONTROL_REQUEST_HEADERS, |
| | | hdrs.ACCESS_CONTROL_REQUEST_METHOD, |
| | | hdrs.AGE, |
| | | hdrs.ALLOW, |
| | | hdrs.AUTHORIZATION, |
| | | hdrs.CACHE_CONTROL, |
| | | hdrs.CONNECTION, |
| | | hdrs.CONTENT_DISPOSITION, |
| | | hdrs.CONTENT_ENCODING, |
| | | hdrs.CONTENT_LANGUAGE, |
| | | hdrs.CONTENT_LENGTH, |
| | | hdrs.CONTENT_LOCATION, |
| | | hdrs.CONTENT_MD5, |
| | | hdrs.CONTENT_RANGE, |
| | | hdrs.CONTENT_TRANSFER_ENCODING, |
| | | hdrs.CONTENT_TYPE, |
| | | hdrs.COOKIE, |
| | | hdrs.DATE, |
| | | hdrs.DESTINATION, |
| | | hdrs.DIGEST, |
| | | hdrs.ETAG, |
| | | hdrs.EXPECT, |
| | | hdrs.EXPIRES, |
| | | hdrs.FORWARDED, |
| | | hdrs.FROM, |
| | | hdrs.HOST, |
| | | hdrs.IF_MATCH, |
| | | hdrs.IF_MODIFIED_SINCE, |
| | | hdrs.IF_NONE_MATCH, |
| | | hdrs.IF_RANGE, |
| | | hdrs.IF_UNMODIFIED_SINCE, |
| | | hdrs.KEEP_ALIVE, |
| | | hdrs.LAST_EVENT_ID, |
| | | hdrs.LAST_MODIFIED, |
| | | hdrs.LINK, |
| | | hdrs.LOCATION, |
| | | hdrs.MAX_FORWARDS, |
| | | hdrs.ORIGIN, |
| | | hdrs.PRAGMA, |
| | | hdrs.PROXY_AUTHENTICATE, |
| | | hdrs.PROXY_AUTHORIZATION, |
| | | hdrs.RANGE, |
| | | hdrs.REFERER, |
| | | hdrs.RETRY_AFTER, |
| | | hdrs.SEC_WEBSOCKET_ACCEPT, |
| | | hdrs.SEC_WEBSOCKET_EXTENSIONS, |
| | | hdrs.SEC_WEBSOCKET_KEY, |
| | | hdrs.SEC_WEBSOCKET_KEY1, |
| | | hdrs.SEC_WEBSOCKET_PROTOCOL, |
| | | hdrs.SEC_WEBSOCKET_VERSION, |
| | | hdrs.SERVER, |
| | | hdrs.SET_COOKIE, |
| | | hdrs.TE, |
| | | hdrs.TRAILER, |
| | | hdrs.TRANSFER_ENCODING, |
| | | hdrs.URI, |
| | | hdrs.UPGRADE, |
| | | hdrs.USER_AGENT, |
| | | hdrs.VARY, |
| | | hdrs.VIA, |
| | | hdrs.WWW_AUTHENTICATE, |
| | | hdrs.WANT_DIGEST, |
| | | hdrs.WARNING, |
| | | hdrs.X_FORWARDED_FOR, |
| | | hdrs.X_FORWARDED_HOST, |
| | | hdrs.X_FORWARDED_PROTO, |
| | | ) |
| New file |
| | |
| | | # Based on https://github.com/MagicStack/httptools |
| | | # |
| | | |
| | | from cpython cimport ( |
| | | Py_buffer, |
| | | PyBUF_SIMPLE, |
| | | PyBuffer_Release, |
| | | PyBytes_AsString, |
| | | PyBytes_AsStringAndSize, |
| | | PyObject_GetBuffer, |
| | | ) |
| | | from cpython.mem cimport PyMem_Free, PyMem_Malloc |
| | | from libc.limits cimport ULLONG_MAX |
| | | from libc.string cimport memcpy |
| | | |
| | | from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy |
| | | from yarl import URL as _URL |
| | | |
| | | from aiohttp import hdrs |
| | | from aiohttp.helpers import DEBUG, set_exception |
| | | |
| | | from .http_exceptions import ( |
| | | BadHttpMessage, |
| | | BadHttpMethod, |
| | | BadStatusLine, |
| | | ContentLengthError, |
| | | InvalidHeader, |
| | | InvalidURLError, |
| | | LineTooLong, |
| | | PayloadEncodingError, |
| | | TransferEncodingError, |
| | | ) |
| | | from .http_parser import DeflateBuffer as _DeflateBuffer |
| | | from .http_writer import ( |
| | | HttpVersion as _HttpVersion, |
| | | HttpVersion10 as _HttpVersion10, |
| | | HttpVersion11 as _HttpVersion11, |
| | | ) |
| | | from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader |
| | | |
| | | cimport cython |
| | | |
| | | from aiohttp cimport _cparser as cparser |
| | | |
| | | include "_headers.pxi" |
| | | |
| | | from aiohttp cimport _find_header |
| | | |
| | | ALLOWED_UPGRADES = frozenset({"websocket"}) |
| | | DEF DEFAULT_FREELIST_SIZE = 250 |
| | | |
| | | cdef extern from "Python.h": |
| | | int PyByteArray_Resize(object, Py_ssize_t) except -1 |
| | | Py_ssize_t PyByteArray_Size(object) except -1 |
| | | char* PyByteArray_AsString(object) |
| | | |
| | | __all__ = ('HttpRequestParser', 'HttpResponseParser', |
| | | 'RawRequestMessage', 'RawResponseMessage') |
| | | |
| | | cdef object URL = _URL |
| | | cdef object URL_build = URL.build |
| | | cdef object CIMultiDict = _CIMultiDict |
| | | cdef object CIMultiDictProxy = _CIMultiDictProxy |
| | | cdef object HttpVersion = _HttpVersion |
| | | cdef object HttpVersion10 = _HttpVersion10 |
| | | cdef object HttpVersion11 = _HttpVersion11 |
| | | cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1 |
| | | cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING |
| | | cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD |
| | | cdef object StreamReader = _StreamReader |
| | | cdef object DeflateBuffer = _DeflateBuffer |
| | | cdef bytes EMPTY_BYTES = b"" |
| | | |
| | | cdef inline object extend(object buf, const char* at, size_t length): |
| | | cdef Py_ssize_t s |
| | | cdef char* ptr |
| | | s = PyByteArray_Size(buf) |
| | | PyByteArray_Resize(buf, s + length) |
| | | ptr = PyByteArray_AsString(buf) |
| | | memcpy(ptr + s, at, length) |
| | | |
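The `extend()` helper above grows a bytearray in place (`PyByteArray_Resize`) and `memcpy`s the new bytes onto the end. At the Python level the same operation is just slice assignment; a minimal sketch (the name `extend_py` is illustrative, not part of aiohttp):

```python
def extend_py(buf: bytearray, chunk: bytes) -> None:
    # Grow the bytearray and copy the new bytes at the end; this is the
    # Python-level equivalent of PyByteArray_Resize + memcpy in the C helper.
    s = len(buf)
    buf[s:] = chunk
```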
| | | |
| | | DEF METHODS_COUNT = 46; |
| | | |
| | | cdef list _http_method = [] |
| | | |
| | | for i in range(METHODS_COUNT): |
| | | _http_method.append( |
| | | cparser.llhttp_method_name(<cparser.llhttp_method_t> i).decode('ascii')) |
| | | |
| | | |
| | | cdef inline str http_method_str(int i): |
| | | if i < METHODS_COUNT: |
| | | return <str>_http_method[i] |
| | | else: |
| | | return "<unknown>" |
| | | |
| | | cdef inline object find_header(bytes raw_header): |
| | | cdef Py_ssize_t size |
| | | cdef char *buf |
| | | cdef int idx |
| | | PyBytes_AsStringAndSize(raw_header, &buf, &size) |
| | | idx = _find_header.find_header(buf, size) |
| | | if idx == -1: |
| | | return raw_header.decode('utf-8', 'surrogateescape') |
| | | return headers[idx] |
| | | |
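`find_header()` resolves a raw header name through the generated lookup in `_find_header.find_header` to an index into the interned `headers` tuple, falling back to a surrogateescape-decoded string for unknown names. A hypothetical dict-based stand-in for that lookup (the `_INTERNED` table here is a tiny illustrative subset, not the generated one):

```python
# Illustrative subset of the interned header table; the real table is the
# full `headers` tuple generated from aiohttp/hdrs.py.
_INTERNED = {name.lower().encode("ascii"): name
             for name in ("Content-Type", "Content-Length", "Host")}

def find_header_py(raw: bytes) -> str:
    # Known names resolve case-insensitively to the interned canonical
    # string; unknown names fall back to a plain decoded copy, matching
    # the idx == -1 fallback in the C path.
    hit = _INTERNED.get(raw.lower())
    if hit is not None:
        return hit
    return raw.decode("utf-8", "surrogateescape")
```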
| | | |
| | | @cython.freelist(DEFAULT_FREELIST_SIZE) |
| | | cdef class RawRequestMessage: |
| | | cdef readonly str method |
| | | cdef readonly str path |
| | | cdef readonly object version # HttpVersion |
| | | cdef readonly object headers # CIMultiDict |
| | | cdef readonly object raw_headers # tuple |
| | | cdef readonly object should_close |
| | | cdef readonly object compression |
| | | cdef readonly object upgrade |
| | | cdef readonly object chunked |
| | | cdef readonly object url # yarl.URL |
| | | |
| | | def __init__(self, method, path, version, headers, raw_headers, |
| | | should_close, compression, upgrade, chunked, url): |
| | | self.method = method |
| | | self.path = path |
| | | self.version = version |
| | | self.headers = headers |
| | | self.raw_headers = raw_headers |
| | | self.should_close = should_close |
| | | self.compression = compression |
| | | self.upgrade = upgrade |
| | | self.chunked = chunked |
| | | self.url = url |
| | | |
| | | def __repr__(self): |
| | | info = [] |
| | | info.append(("method", self.method)) |
| | | info.append(("path", self.path)) |
| | | info.append(("version", self.version)) |
| | | info.append(("headers", self.headers)) |
| | | info.append(("raw_headers", self.raw_headers)) |
| | | info.append(("should_close", self.should_close)) |
| | | info.append(("compression", self.compression)) |
| | | info.append(("upgrade", self.upgrade)) |
| | | info.append(("chunked", self.chunked)) |
| | | info.append(("url", self.url)) |
| | | sinfo = ', '.join(name + '=' + repr(val) for name, val in info) |
| | | return '<RawRequestMessage(' + sinfo + ')>' |
| | | |
| | | def _replace(self, **dct): |
| | | cdef RawRequestMessage ret |
| | | ret = _new_request_message(self.method, |
| | | self.path, |
| | | self.version, |
| | | self.headers, |
| | | self.raw_headers, |
| | | self.should_close, |
| | | self.compression, |
| | | self.upgrade, |
| | | self.chunked, |
| | | self.url) |
| | | if "method" in dct: |
| | | ret.method = dct["method"] |
| | | if "path" in dct: |
| | | ret.path = dct["path"] |
| | | if "version" in dct: |
| | | ret.version = dct["version"] |
| | | if "headers" in dct: |
| | | ret.headers = dct["headers"] |
| | | if "raw_headers" in dct: |
| | | ret.raw_headers = dct["raw_headers"] |
| | | if "should_close" in dct: |
| | | ret.should_close = dct["should_close"] |
| | | if "compression" in dct: |
| | | ret.compression = dct["compression"] |
| | | if "upgrade" in dct: |
| | | ret.upgrade = dct["upgrade"] |
| | | if "chunked" in dct: |
| | | ret.chunked = dct["chunked"] |
| | | if "url" in dct: |
| | | ret.url = dct["url"] |
| | | return ret |
| | | |
| | | cdef _new_request_message(str method, |
| | | str path, |
| | | object version, |
| | | object headers, |
| | | object raw_headers, |
| | | bint should_close, |
| | | object compression, |
| | | bint upgrade, |
| | | bint chunked, |
| | | object url): |
| | | cdef RawRequestMessage ret |
| | | ret = RawRequestMessage.__new__(RawRequestMessage) |
| | | ret.method = method |
| | | ret.path = path |
| | | ret.version = version |
| | | ret.headers = headers |
| | | ret.raw_headers = raw_headers |
| | | ret.should_close = should_close |
| | | ret.compression = compression |
| | | ret.upgrade = upgrade |
| | | ret.chunked = chunked |
| | | ret.url = url |
| | | return ret |
| | | |
| | | |
| | | @cython.freelist(DEFAULT_FREELIST_SIZE) |
| | | cdef class RawResponseMessage: |
| | | cdef readonly object version # HttpVersion |
| | | cdef readonly int code |
| | | cdef readonly str reason |
| | | cdef readonly object headers # CIMultiDict |
| | | cdef readonly object raw_headers # tuple |
| | | cdef readonly object should_close |
| | | cdef readonly object compression |
| | | cdef readonly object upgrade |
| | | cdef readonly object chunked |
| | | |
| | | def __init__(self, version, code, reason, headers, raw_headers, |
| | | should_close, compression, upgrade, chunked): |
| | | self.version = version |
| | | self.code = code |
| | | self.reason = reason |
| | | self.headers = headers |
| | | self.raw_headers = raw_headers |
| | | self.should_close = should_close |
| | | self.compression = compression |
| | | self.upgrade = upgrade |
| | | self.chunked = chunked |
| | | |
| | | def __repr__(self): |
| | | info = [] |
| | | info.append(("version", self.version)) |
| | | info.append(("code", self.code)) |
| | | info.append(("reason", self.reason)) |
| | | info.append(("headers", self.headers)) |
| | | info.append(("raw_headers", self.raw_headers)) |
| | | info.append(("should_close", self.should_close)) |
| | | info.append(("compression", self.compression)) |
| | | info.append(("upgrade", self.upgrade)) |
| | | info.append(("chunked", self.chunked)) |
| | | sinfo = ', '.join(name + '=' + repr(val) for name, val in info) |
| | | return '<RawResponseMessage(' + sinfo + ')>' |
| | | |
| | | |
| | | cdef _new_response_message(object version, |
| | | int code, |
| | | str reason, |
| | | object headers, |
| | | object raw_headers, |
| | | bint should_close, |
| | | object compression, |
| | | bint upgrade, |
| | | bint chunked): |
| | | cdef RawResponseMessage ret |
| | | ret = RawResponseMessage.__new__(RawResponseMessage) |
| | | ret.version = version |
| | | ret.code = code |
| | | ret.reason = reason |
| | | ret.headers = headers |
| | | ret.raw_headers = raw_headers |
| | | ret.should_close = should_close |
| | | ret.compression = compression |
| | | ret.upgrade = upgrade |
| | | ret.chunked = chunked |
| | | return ret |
| | | |
| | | |
| | | @cython.internal |
| | | cdef class HttpParser: |
| | | |
| | | cdef: |
| | | cparser.llhttp_t* _cparser |
| | | cparser.llhttp_settings_t* _csettings |
| | | |
| | | bytes _raw_name |
| | | object _name |
| | | bytes _raw_value |
| | | bint _has_value |
| | | |
| | | object _protocol |
| | | object _loop |
| | | object _timer |
| | | |
| | | size_t _max_line_size |
| | | size_t _max_field_size |
| | | size_t _max_headers |
| | | bint _response_with_body |
| | | bint _read_until_eof |
| | | |
| | | bint _started |
| | | object _url |
| | | bytearray _buf |
| | | str _path |
| | | str _reason |
| | | list _headers |
| | | list _raw_headers |
| | | bint _upgraded |
| | | list _messages |
| | | object _payload |
| | | bint _payload_error |
| | | object _payload_exception |
| | | object _last_error |
| | | bint _auto_decompress |
| | | int _limit |
| | | |
| | | str _content_encoding |
| | | |
| | | Py_buffer py_buf |
| | | |
| | | def __cinit__(self): |
| | | self._cparser = <cparser.llhttp_t*> \ |
| | | PyMem_Malloc(sizeof(cparser.llhttp_t)) |
| | | if self._cparser is NULL: |
| | | raise MemoryError() |
| | | |
| | | self._csettings = <cparser.llhttp_settings_t*> \ |
| | | PyMem_Malloc(sizeof(cparser.llhttp_settings_t)) |
| | | if self._csettings is NULL: |
| | | raise MemoryError() |
| | | |
| | | def __dealloc__(self): |
| | | PyMem_Free(self._cparser) |
| | | PyMem_Free(self._csettings) |
| | | |
| | | cdef _init( |
| | | self, cparser.llhttp_type mode, |
| | | object protocol, object loop, int limit, |
| | | object timer=None, |
| | | size_t max_line_size=8190, size_t max_headers=32768, |
| | | size_t max_field_size=8190, payload_exception=None, |
| | | bint response_with_body=True, bint read_until_eof=False, |
| | | bint auto_decompress=True, |
| | | ): |
| | | cparser.llhttp_settings_init(self._csettings) |
| | | cparser.llhttp_init(self._cparser, mode, self._csettings) |
| | | self._cparser.data = <void*>self |
| | | self._cparser.content_length = 0 |
| | | |
| | | self._protocol = protocol |
| | | self._loop = loop |
| | | self._timer = timer |
| | | |
| | | self._buf = bytearray() |
| | | self._payload = None |
| | | self._payload_error = 0 |
| | | self._payload_exception = payload_exception |
| | | self._messages = [] |
| | | |
| | | self._raw_name = EMPTY_BYTES |
| | | self._raw_value = EMPTY_BYTES |
| | | self._has_value = False |
| | | |
| | | self._max_line_size = max_line_size |
| | | self._max_headers = max_headers |
| | | self._max_field_size = max_field_size |
| | | self._response_with_body = response_with_body |
| | | self._read_until_eof = read_until_eof |
| | | self._upgraded = False |
| | | self._auto_decompress = auto_decompress |
| | | self._content_encoding = None |
| | | |
| | | self._csettings.on_url = cb_on_url |
| | | self._csettings.on_status = cb_on_status |
| | | self._csettings.on_header_field = cb_on_header_field |
| | | self._csettings.on_header_value = cb_on_header_value |
| | | self._csettings.on_headers_complete = cb_on_headers_complete |
| | | self._csettings.on_body = cb_on_body |
| | | self._csettings.on_message_begin = cb_on_message_begin |
| | | self._csettings.on_message_complete = cb_on_message_complete |
| | | self._csettings.on_chunk_header = cb_on_chunk_header |
| | | self._csettings.on_chunk_complete = cb_on_chunk_complete |
| | | |
| | | self._last_error = None |
| | | self._limit = limit |
| | | |
| | | cdef _process_header(self): |
| | | cdef str value |
| | | if self._raw_name is not EMPTY_BYTES: |
| | | name = find_header(self._raw_name) |
| | | value = self._raw_value.decode('utf-8', 'surrogateescape') |
| | | |
| | | self._headers.append((name, value)) |
| | | |
| | | if name is CONTENT_ENCODING: |
| | | self._content_encoding = value |
| | | |
| | | self._has_value = False |
| | | self._raw_headers.append((self._raw_name, self._raw_value)) |
| | | self._raw_name = EMPTY_BYTES |
| | | self._raw_value = EMPTY_BYTES |
| | | |
| | | cdef _on_header_field(self, char* at, size_t length): |
| | | if self._has_value: |
| | | self._process_header() |
| | | |
| | | if self._raw_name is EMPTY_BYTES: |
| | | self._raw_name = at[:length] |
| | | else: |
| | | self._raw_name += at[:length] |
| | | |
| | | cdef _on_header_value(self, char* at, size_t length): |
| | | if self._raw_value is EMPTY_BYTES: |
| | | self._raw_value = at[:length] |
| | | else: |
| | | self._raw_value += at[:length] |
| | | self._has_value = True |
| | | |
| | | cdef _on_headers_complete(self): |
| | | self._process_header() |
| | | |
| | | should_close = not cparser.llhttp_should_keep_alive(self._cparser) |
| | | upgrade = self._cparser.upgrade |
| | | chunked = self._cparser.flags & cparser.F_CHUNKED |
| | | |
| | | raw_headers = tuple(self._raw_headers) |
| | | headers = CIMultiDictProxy(CIMultiDict(self._headers)) |
| | | |
| | | if self._cparser.type == cparser.HTTP_REQUEST: |
| | | h_upg = headers.get("upgrade", "") |
| | | allowed = upgrade and h_upg.isascii() and h_upg.lower() in ALLOWED_UPGRADES |
| | | if allowed or self._cparser.method == cparser.HTTP_CONNECT: |
| | | self._upgraded = True |
| | | else: |
| | | if upgrade and self._cparser.status_code == 101: |
| | | self._upgraded = True |
| | | |
| | | # the old websocket spec is not supported |
| | | if SEC_WEBSOCKET_KEY1 in headers: |
| | | raise InvalidHeader(SEC_WEBSOCKET_KEY1) |
| | | |
| | | encoding = None |
| | | enc = self._content_encoding |
| | | if enc is not None: |
| | | self._content_encoding = None |
| | | if enc.isascii() and enc.lower() in {"gzip", "deflate", "br", "zstd"}: |
| | | encoding = enc |
| | | |
| | | if self._cparser.type == cparser.HTTP_REQUEST: |
| | | method = http_method_str(self._cparser.method) |
| | | msg = _new_request_message( |
| | | method, self._path, |
| | | self.http_version(), headers, raw_headers, |
| | | should_close, encoding, upgrade, chunked, self._url) |
| | | else: |
| | | msg = _new_response_message( |
| | | self.http_version(), self._cparser.status_code, self._reason, |
| | | headers, raw_headers, should_close, encoding, |
| | | upgrade, chunked) |
| | | |
| | | if ( |
| | | ULLONG_MAX > self._cparser.content_length > 0 or chunked or |
| | | self._cparser.method == cparser.HTTP_CONNECT or |
| | | (self._cparser.status_code >= 199 and |
| | | self._cparser.content_length == 0 and |
| | | self._read_until_eof) |
| | | ): |
| | | payload = StreamReader( |
| | | self._protocol, timer=self._timer, loop=self._loop, |
| | | limit=self._limit) |
| | | else: |
| | | payload = EMPTY_PAYLOAD |
| | | |
| | | self._payload = payload |
| | | if encoding is not None and self._auto_decompress: |
| | | self._payload = DeflateBuffer(payload, encoding) |
| | | |
| | | if not self._response_with_body: |
| | | payload = EMPTY_PAYLOAD |
| | | |
| | | self._messages.append((msg, payload)) |
| | | |
| | | cdef _on_message_complete(self): |
| | | self._payload.feed_eof() |
| | | self._payload = None |
| | | |
| | | cdef _on_chunk_header(self): |
| | | self._payload.begin_http_chunk_receiving() |
| | | |
| | | cdef _on_chunk_complete(self): |
| | | self._payload.end_http_chunk_receiving() |
| | | |
| | | cdef object _on_status_complete(self): |
| | | pass |
| | | |
| | | cdef inline http_version(self): |
| | | cdef cparser.llhttp_t* parser = self._cparser |
| | | |
| | | if parser.http_major == 1: |
| | | if parser.http_minor == 0: |
| | | return HttpVersion10 |
| | | elif parser.http_minor == 1: |
| | | return HttpVersion11 |
| | | |
| | | return HttpVersion(parser.http_major, parser.http_minor) |
| | | |
| | | ### Public API ### |
| | | |
| | | def feed_eof(self): |
| | | cdef bytes desc |
| | | |
| | | if self._payload is not None: |
| | | if self._cparser.flags & cparser.F_CHUNKED: |
| | | raise TransferEncodingError( |
| | | "Not enough data to satisfy transfer length header.") |
| | | elif self._cparser.flags & cparser.F_CONTENT_LENGTH: |
| | | raise ContentLengthError( |
| | | "Not enough data to satisfy content length header.") |
| | | elif cparser.llhttp_get_errno(self._cparser) != cparser.HPE_OK: |
| | | desc = cparser.llhttp_get_error_reason(self._cparser) |
| | | raise PayloadEncodingError(desc.decode('latin-1')) |
| | | else: |
| | | self._payload.feed_eof() |
| | | elif self._started: |
| | | self._on_headers_complete() |
| | | if self._messages: |
| | | return self._messages[-1][0] |
| | | |
| | | def feed_data(self, data): |
| | | cdef: |
| | | size_t data_len |
| | | size_t nb |
| | | cdef cparser.llhttp_errno_t errno |
| | | |
| | | PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE) |
| | | data_len = <size_t>self.py_buf.len |
| | | |
| | | errno = cparser.llhttp_execute( |
| | | self._cparser, |
| | | <char*>self.py_buf.buf, |
| | | data_len) |
| | | |
| | | if errno is cparser.HPE_PAUSED_UPGRADE: |
| | | cparser.llhttp_resume_after_upgrade(self._cparser) |
| | | |
| | | nb = cparser.llhttp_get_error_pos(self._cparser) - <char*>self.py_buf.buf |
| | | |
| | | PyBuffer_Release(&self.py_buf) |
| | | |
| | | if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE): |
| | | if self._payload_error == 0: |
| | | if self._last_error is not None: |
| | | ex = self._last_error |
| | | self._last_error = None |
| | | else: |
| | | after = cparser.llhttp_get_error_pos(self._cparser) |
| | | before = data[:after - <char*>self.py_buf.buf] |
| | | after_b = after.split(b"\r\n", 1)[0] |
| | | before = before.rsplit(b"\r\n", 1)[-1] |
| | | data = before + after_b |
| | | pointer = " " * (len(repr(before))-1) + "^" |
| | | ex = parser_error_from_errno(self._cparser, data, pointer) |
| | | self._payload = None |
| | | raise ex |
| | | |
| | | if self._messages: |
| | | messages = self._messages |
| | | self._messages = [] |
| | | else: |
| | | messages = () |
| | | |
| | | if self._upgraded: |
| | | return messages, True, data[nb:] |
| | | else: |
| | | return messages, False, b"" |
| | | |
| | | def set_upgraded(self, val): |
| | | self._upgraded = val |
| | | |
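The `http_version()` method above returns shared singleton objects for the common HTTP/1.0 and HTTP/1.1 cases and only builds a fresh `HttpVersion` for anything else. A plain-Python sketch of that logic, using a stand-in namedtuple rather than aiohttp's real `HttpVersion`:

```python
from collections import namedtuple

HttpVersion = namedtuple("HttpVersion", ["major", "minor"])
HTTP_1_0 = HttpVersion(1, 0)  # shared singletons for the common cases
HTTP_1_1 = HttpVersion(1, 1)

def http_version(major: int, minor: int) -> HttpVersion:
    # Reuse the singletons so the hot path allocates nothing;
    # uncommon versions get a freshly built tuple.
    if major == 1:
        if minor == 0:
            return HTTP_1_0
        if minor == 1:
            return HTTP_1_1
    return HttpVersion(major, minor)
```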
| | | |
| | | cdef class HttpRequestParser(HttpParser): |
| | | |
| | | def __init__( |
| | | self, protocol, loop, int limit, timer=None, |
| | | size_t max_line_size=8190, size_t max_headers=32768, |
| | | size_t max_field_size=8190, payload_exception=None, |
| | | bint response_with_body=True, bint read_until_eof=False, |
| | | bint auto_decompress=True, |
| | | ): |
| | | self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer, |
| | | max_line_size, max_headers, max_field_size, |
| | | payload_exception, response_with_body, read_until_eof, |
| | | auto_decompress) |
| | | |
| | | cdef object _on_status_complete(self): |
| | | cdef int idx1, idx2 |
| | | if not self._buf: |
| | | return |
| | | self._path = self._buf.decode('utf-8', 'surrogateescape') |
| | | try: |
| | | idx3 = len(self._path) |
| | | if self._cparser.method == cparser.HTTP_CONNECT: |
| | | # authority-form, |
| | | # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3 |
| | | self._url = URL.build(authority=self._path, encoded=True) |
| | | elif idx3 > 1 and self._path[0] == '/': |
| | | # origin-form, |
| | | # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1 |
| | | idx1 = self._path.find("?") |
| | | if idx1 == -1: |
| | | query = "" |
| | | idx2 = self._path.find("#") |
| | | if idx2 == -1: |
| | | path = self._path |
| | | fragment = "" |
| | | else: |
| | | path = self._path[0: idx2] |
| | | fragment = self._path[idx2+1:] |
| | | |
| | | else: |
| | | path = self._path[0:idx1] |
| | | idx1 += 1 |
| | | idx2 = self._path.find("#", idx1+1) |
| | | if idx2 == -1: |
| | | query = self._path[idx1:] |
| | | fragment = "" |
| | | else: |
| | | query = self._path[idx1: idx2] |
| | | fragment = self._path[idx2+1:] |
| | | |
| | | self._url = URL.build( |
| | | path=path, |
| | | query_string=query, |
| | | fragment=fragment, |
| | | encoded=True, |
| | | ) |
| | | else: |
| | | # absolute-form, typically used for proxy requests, |
| | | # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2 |
| | | self._url = URL(self._path, encoded=True) |
| | | finally: |
| | | PyByteArray_Resize(self._buf, 0) |
| | | |
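`_on_status_complete()` above splits an origin-form request target into path, query string, and fragment with two `find()` scans before handing the pieces to `URL.build`. A standalone sketch of that split (the helper name is illustrative, not part of aiohttp):

```python
def split_origin_form(target: str):
    # Split "/path?query#fragment" into its three parts, mirroring the
    # idx1/idx2 scanning in _on_status_complete().
    query = fragment = ""
    qpos = target.find("?")
    if qpos == -1:
        fpos = target.find("#")
        if fpos == -1:
            path = target
        else:
            path, fragment = target[:fpos], target[fpos + 1:]
    else:
        path, rest = target[:qpos], target[qpos + 1:]
        fpos = rest.find("#")
        if fpos == -1:
            query = rest
        else:
            query, fragment = rest[:fpos], rest[fpos + 1:]
    return path, query, fragment
```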
| | | |
| | | cdef class HttpResponseParser(HttpParser): |
| | | |
| | | def __init__( |
| | | self, protocol, loop, int limit, timer=None, |
| | | size_t max_line_size=8190, size_t max_headers=32768, |
| | | size_t max_field_size=8190, payload_exception=None, |
| | | bint response_with_body=True, bint read_until_eof=False, |
| | | bint auto_decompress=True |
| | | ): |
| | | self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer, |
| | | max_line_size, max_headers, max_field_size, |
| | | payload_exception, response_with_body, read_until_eof, |
| | | auto_decompress) |
| | | # Use strict parsing in dev mode, so users are warned about broken servers. |
| | | if not DEBUG: |
| | | cparser.llhttp_set_lenient_headers(self._cparser, 1) |
| | | cparser.llhttp_set_lenient_optional_cr_before_lf(self._cparser, 1) |
| | | cparser.llhttp_set_lenient_spaces_after_chunk_size(self._cparser, 1) |
| | | |
| | | cdef object _on_status_complete(self): |
| | | if self._buf: |
| | | self._reason = self._buf.decode('utf-8', 'surrogateescape') |
| | | PyByteArray_Resize(self._buf, 0) |
| | | else: |
| | | self._reason = self._reason or '' |
| | | |
| | | cdef int cb_on_message_begin(cparser.llhttp_t* parser) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | |
| | | pyparser._started = True |
| | | pyparser._headers = [] |
| | | pyparser._raw_headers = [] |
| | | PyByteArray_Resize(pyparser._buf, 0) |
| | | pyparser._path = None |
| | | pyparser._reason = None |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_url(cparser.llhttp_t* parser, |
| | | const char *at, size_t length) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | try: |
| | | if length > pyparser._max_line_size: |
| | | raise LineTooLong( |
| | | 'Status line is too long', pyparser._max_line_size, length) |
| | | extend(pyparser._buf, at, length) |
| | | except BaseException as ex: |
| | | pyparser._last_error = ex |
| | | return -1 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_status(cparser.llhttp_t* parser, |
| | | const char *at, size_t length) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | cdef str reason |
| | | try: |
| | | if length > pyparser._max_line_size: |
| | | raise LineTooLong( |
| | | 'Status line is too long', pyparser._max_line_size, length) |
| | | extend(pyparser._buf, at, length) |
| | | except BaseException as ex: |
| | | pyparser._last_error = ex |
| | | return -1 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_header_field(cparser.llhttp_t* parser, |
| | | const char *at, size_t length) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | cdef Py_ssize_t size |
| | | try: |
| | | pyparser._on_status_complete() |
| | | size = len(pyparser._raw_name) + length |
| | | if size > pyparser._max_field_size: |
| | | raise LineTooLong( |
| | | 'Header name is too long', pyparser._max_field_size, size) |
| | | pyparser._on_header_field(at, length) |
| | | except BaseException as ex: |
| | | pyparser._last_error = ex |
| | | return -1 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_header_value(cparser.llhttp_t* parser, |
| | | const char *at, size_t length) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | cdef Py_ssize_t size |
| | | try: |
| | | size = len(pyparser._raw_value) + length |
| | | if size > pyparser._max_field_size: |
| | | raise LineTooLong( |
| | | 'Header value is too long', pyparser._max_field_size, size) |
| | | pyparser._on_header_value(at, length) |
| | | except BaseException as ex: |
| | | pyparser._last_error = ex |
| | | return -1 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | try: |
| | | pyparser._on_status_complete() |
| | | pyparser._on_headers_complete() |
| | | except BaseException as exc: |
| | | pyparser._last_error = exc |
| | | return -1 |
| | | else: |
| | | if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT: |
| | | return 2 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_body(cparser.llhttp_t* parser, |
| | | const char *at, size_t length) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | cdef bytes body = at[:length] |
| | | try: |
| | | pyparser._payload.feed_data(body, length) |
| | | except BaseException as underlying_exc: |
| | | reraised_exc = underlying_exc |
| | | if pyparser._payload_exception is not None: |
| | | reraised_exc = pyparser._payload_exception(str(underlying_exc)) |
| | | |
| | | set_exception(pyparser._payload, reraised_exc, underlying_exc) |
| | | |
| | | pyparser._payload_error = 1 |
| | | return -1 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | try: |
| | | pyparser._started = False |
| | | pyparser._on_message_complete() |
| | | except BaseException as exc: |
| | | pyparser._last_error = exc |
| | | return -1 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_chunk_header(cparser.llhttp_t* parser) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | try: |
| | | pyparser._on_chunk_header() |
| | | except BaseException as exc: |
| | | pyparser._last_error = exc |
| | | return -1 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef int cb_on_chunk_complete(cparser.llhttp_t* parser) except -1: |
| | | cdef HttpParser pyparser = <HttpParser>parser.data |
| | | try: |
| | | pyparser._on_chunk_complete() |
| | | except BaseException as exc: |
| | | pyparser._last_error = exc |
| | | return -1 |
| | | else: |
| | | return 0 |
| | | |
| | | |
| | | cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer): |
| | | cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser) |
| | | cdef bytes desc = cparser.llhttp_get_error_reason(parser) |
| | | |
| | | err_msg = "{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer) |
| | | |
| | | if errno in {cparser.HPE_CB_MESSAGE_BEGIN, |
| | | cparser.HPE_CB_HEADERS_COMPLETE, |
| | | cparser.HPE_CB_MESSAGE_COMPLETE, |
| | | cparser.HPE_CB_CHUNK_HEADER, |
| | | cparser.HPE_CB_CHUNK_COMPLETE, |
| | | cparser.HPE_INVALID_CONSTANT, |
| | | cparser.HPE_INVALID_HEADER_TOKEN, |
| | | cparser.HPE_INVALID_CONTENT_LENGTH, |
| | | cparser.HPE_INVALID_CHUNK_SIZE, |
| | | cparser.HPE_INVALID_EOF_STATE, |
| | | cparser.HPE_INVALID_TRANSFER_ENCODING}: |
| | | return BadHttpMessage(err_msg) |
| | | elif errno == cparser.HPE_INVALID_METHOD: |
| | | return BadHttpMethod(error=err_msg) |
| | | elif errno in {cparser.HPE_INVALID_STATUS, |
| | | cparser.HPE_INVALID_VERSION}: |
| | | return BadStatusLine(error=err_msg) |
| | | elif errno == cparser.HPE_INVALID_URL: |
| | | return InvalidURLError(err_msg) |
| | | |
| | | return BadHttpMessage(err_msg) |
| New file |
| | |
| | | from cpython.bytes cimport PyBytes_FromStringAndSize |
| | | from cpython.exc cimport PyErr_NoMemory |
| | | from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc |
| | | from cpython.object cimport PyObject_Str |
| | | from libc.stdint cimport uint8_t, uint64_t |
| | | from libc.string cimport memcpy |
| | | |
| | | from multidict import istr |
| | | |
| | | DEF BUF_SIZE = 16 * 1024 # 16KiB |
| | | |
| | | cdef object _istr = istr |
| | | |
| | | |
| | | # ----------------- writer --------------------------- |
| | | |
| | | cdef struct Writer: |
| | | char *buf |
| | | Py_ssize_t size |
| | | Py_ssize_t pos |
| | | bint heap_allocated |
| | | |
| | | cdef inline void _init_writer(Writer* writer, char *buf): |
| | | writer.buf = buf |
| | | writer.size = BUF_SIZE |
| | | writer.pos = 0 |
| | | writer.heap_allocated = 0 |
| | | |
| | | |
| | | cdef inline void _release_writer(Writer* writer): |
| | | if writer.heap_allocated: |
| | | PyMem_Free(writer.buf) |
| | | |
| | | |
| | | cdef inline int _write_byte(Writer* writer, uint8_t ch): |
| | | cdef char * buf |
| | | cdef Py_ssize_t size |
| | | |
| | | if writer.pos == writer.size: |
| | | # reallocate |
| | | size = writer.size + BUF_SIZE |
| | | if not writer.heap_allocated: |
| | | buf = <char*>PyMem_Malloc(size) |
| | | if buf == NULL: |
| | | PyErr_NoMemory() |
| | | return -1 |
| | | memcpy(buf, writer.buf, writer.size) |
| | | else: |
| | | buf = <char*>PyMem_Realloc(writer.buf, size) |
| | | if buf == NULL: |
| | | PyErr_NoMemory() |
| | | return -1 |
| | | writer.buf = buf |
| | | writer.size = size |
| | | writer.heap_allocated = 1 |
| | | writer.buf[writer.pos] = <char>ch |
| | | writer.pos += 1 |
| | | return 0 |
| | | |
| | | |
| | | cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol): |
| | | cdef uint64_t utf = <uint64_t> symbol |
| | | |
| | | if utf < 0x80: |
| | | return _write_byte(writer, <uint8_t>utf) |
| | | elif utf < 0x800: |
| | | if _write_byte(writer, <uint8_t>(0xc0 | (utf >> 6))) < 0: |
| | | return -1 |
| | | return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f))) |
| | | elif 0xD800 <= utf <= 0xDFFF: |
| | |         # surrogate code point, ignored |
| | | return 0 |
| | | elif utf < 0x10000: |
| | | if _write_byte(writer, <uint8_t>(0xe0 | (utf >> 12))) < 0: |
| | | return -1 |
| | | if _write_byte(writer, <uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0: |
| | | return -1 |
| | | return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f))) |
| | | elif utf > 0x10FFFF: |
| | | # symbol is too large |
| | | return 0 |
| | | else: |
| | | if _write_byte(writer, <uint8_t>(0xf0 | (utf >> 18))) < 0: |
| | | return -1 |
| | | if _write_byte(writer, |
| | | <uint8_t>(0x80 | ((utf >> 12) & 0x3f))) < 0: |
| | | return -1 |
| | | if _write_byte(writer, |
| | | <uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0: |
| | | return -1 |
| | | return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f))) |
| | | |
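The `_write_utf8` branches above hand-roll standard UTF-8 encoding (RFC 3629), silently dropping surrogate code points and out-of-range values rather than raising. A pure-Python sketch of the same byte math (the helper name `utf8_bytes` is ours, not aiohttp's), which can be checked against `str.encode`:

```python
def utf8_bytes(cp: int) -> bytes:
    """Encode one code point with the same branches as _write_utf8 above."""
    if cp < 0x80:                       # 1 byte: 0xxxxxxx
        return bytes([cp])
    if cp < 0x800:                      # 2 bytes: 110xxxxx 10xxxxxx
        return bytes([0xC0 | (cp >> 6), 0x80 | (cp & 0x3F)])
    if 0xD800 <= cp <= 0xDFFF or cp > 0x10FFFF:
        return b""                      # surrogates / out of range: skipped
    if cp < 0x10000:                    # 3 bytes: 1110xxxx 10xxxxxx 10xxxxxx
        return bytes(
            [0xE0 | (cp >> 12), 0x80 | ((cp >> 6) & 0x3F), 0x80 | (cp & 0x3F)]
        )
    # 4 bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
    return bytes(
        [
            0xF0 | (cp >> 18),
            0x80 | ((cp >> 12) & 0x3F),
            0x80 | ((cp >> 6) & 0x3F),
            0x80 | (cp & 0x3F),
        ]
    )
```

For every valid code point this agrees with CPython's encoder; the difference is only in how invalid input is handled (skipped here, an exception in `str.encode`).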
| | | |
| | | cdef inline int _write_str(Writer* writer, str s): |
| | |     cdef Py_UCS4 ch |
| | |     for ch in s: |
| | |         if _write_utf8(writer, ch) < 0: |
| | |             return -1 |
| | |     return 0 |
| | | |
| | | |
| | | cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s): |
| | | cdef Py_UCS4 ch |
| | | cdef str out_str |
| | | if type(s) is str: |
| | | out_str = <str>s |
| | | elif type(s) is _istr: |
| | | out_str = PyObject_Str(s) |
| | | elif not isinstance(s, str): |
| | | raise TypeError("Cannot serialize non-str key {!r}".format(s)) |
| | | else: |
| | | out_str = str(s) |
| | | |
| | | for ch in out_str: |
| | | if ch == 0x0D or ch == 0x0A: |
| | | raise ValueError( |
| | | "Newline or carriage return detected in headers. " |
| | | "Potential header injection attack." |
| | | ) |
| | |         if _write_utf8(writer, ch) < 0: |
| | |             return -1 |
| | |     return 0 |
| | | |
| | | |
| | | # --------------- _serialize_headers ---------------------- |
| | | |
| | | def _serialize_headers(str status_line, headers): |
| | | cdef Writer writer |
| | | cdef object key |
| | | cdef object val |
| | | cdef char buf[BUF_SIZE] |
| | | |
| | | _init_writer(&writer, buf) |
| | | |
| | | try: |
| | | if _write_str(&writer, status_line) < 0: |
| | | raise |
| | | if _write_byte(&writer, b'\r') < 0: |
| | | raise |
| | | if _write_byte(&writer, b'\n') < 0: |
| | | raise |
| | | |
| | | for key, val in headers.items(): |
| | | if _write_str_raise_on_nlcr(&writer, key) < 0: |
| | | raise |
| | | if _write_byte(&writer, b':') < 0: |
| | | raise |
| | | if _write_byte(&writer, b' ') < 0: |
| | | raise |
| | | if _write_str_raise_on_nlcr(&writer, val) < 0: |
| | | raise |
| | | if _write_byte(&writer, b'\r') < 0: |
| | | raise |
| | | if _write_byte(&writer, b'\n') < 0: |
| | | raise |
| | | |
| | | if _write_byte(&writer, b'\r') < 0: |
| | | raise |
| | | if _write_byte(&writer, b'\n') < 0: |
| | | raise |
| | | |
| | | return PyBytes_FromStringAndSize(writer.buf, writer.pos) |
| | | finally: |
| | | _release_writer(&writer) |
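`_serialize_headers` writes `status line CRLF (name ": " value CRLF)* CRLF` and refuses any CR or LF inside header names or values, which blocks response-splitting. A plain-Python sketch of the same wire format and check (helper names are ours; unlike the Cython code, this sketch does not skip unpaired surrogates):

```python
def check_header_part(s: str) -> str:
    """Reject CR/LF the same way _write_str_raise_on_nlcr does."""
    if "\r" in s or "\n" in s:
        raise ValueError(
            "Newline or carriage return detected in headers. "
            "Potential header injection attack."
        )
    return s


def serialize_headers(status_line: str, headers: dict) -> bytes:
    """Pure-Python sketch of _serialize_headers (UTF-8, CRLF-terminated)."""
    out = [status_line, "\r\n"]
    for key, val in headers.items():
        out += [check_header_part(str(key)), ": ",
                check_header_part(str(val)), "\r\n"]
    out.append("\r\n")  # blank line ends the header block
    return "".join(out).encode("utf-8")
```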
| New file |
| | |
| | | e354dd499be171b6125bf56bc3b6c5e2bff2a28af69e3b5d699ddb9af2bafa3c *D:/a/aiohttp/aiohttp/aiohttp/_websocket/mask.pxd |
| New file |
| | |
| | | 468edd38ebf8dc7000a8d333df1c82035d69a5c9febc0448be3c9c4ad4c4630c *D:/a/aiohttp/aiohttp/aiohttp/_websocket/mask.pyx |
| New file |
| | |
| | | 1cd3a5e20456b4d04d11835b2bd3c639f14443052a2467b105b0ca07fdb4b25d *D:/a/aiohttp/aiohttp/aiohttp/_websocket/reader_c.pxd |
| New file |
| | |
| | | """WebSocket protocol versions 13 and 8.""" |
| New file |
| | |
| | | """Helpers for WebSocket protocol versions 13 and 8.""" |
| | | |
| | | import functools |
| | | import re |
| | | from struct import Struct |
| | | from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple |
| | | |
| | | from ..helpers import NO_EXTENSIONS |
| | | from .models import WSHandshakeError |
| | | |
| | | UNPACK_LEN3 = Struct("!Q").unpack_from |
| | | UNPACK_CLOSE_CODE = Struct("!H").unpack |
| | | PACK_LEN1 = Struct("!BB").pack |
| | | PACK_LEN2 = Struct("!BBH").pack |
| | | PACK_LEN3 = Struct("!BBQ").pack |
| | | PACK_CLOSE_CODE = Struct("!H").pack |
| | | PACK_RANDBITS = Struct("!L").pack |
| | | MSG_SIZE: Final[int] = 2**14 |
| | | MASK_LEN: Final[int] = 4 |
| | | |
| | | WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" |
| | | |
| | | |
| | | # Used by _websocket_mask_python |
| | | @functools.lru_cache |
| | | def _xor_table() -> List[bytes]: |
| | | return [bytes(a ^ b for a in range(256)) for b in range(256)] |
| | | |
| | | |
| | | def _websocket_mask_python(mask: bytes, data: bytearray) -> None: |
| | | """Websocket masking function. |
| | | |
| | | `mask` is a `bytes` object of length 4; `data` is a `bytearray` |
| | | object of any length. The contents of `data` are masked with `mask`, |
| | | as specified in section 5.3 of RFC 6455. |
| | | |
| | | Note that this function mutates the `data` argument. |
| | | |
| | | This pure-python implementation may be replaced by an optimized |
| | | version when available. |
| | | |
| | | """ |
| | | assert isinstance(data, bytearray), data |
| | | assert len(mask) == 4, mask |
| | | |
| | | if data: |
| | | _XOR_TABLE = _xor_table() |
| | | a, b, c, d = (_XOR_TABLE[n] for n in mask) |
| | | data[::4] = data[::4].translate(a) |
| | | data[1::4] = data[1::4].translate(b) |
| | | data[2::4] = data[2::4].translate(c) |
| | | data[3::4] = data[3::4].translate(d) |
| | | |
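Masking is a plain XOR with the 4-byte key repeated, so applying it twice restores the original data; the table-translate trick above is just a vectorized form of the per-byte loop from RFC 6455 section 5.3. A naive reference (the function name is ours) demonstrating the round trip:

```python
def websocket_mask_naive(mask: bytes, data: bytearray) -> None:
    """Reference per-byte masking from RFC 6455 section 5.3 (mutates data)."""
    for i in range(len(data)):
        data[i] ^= mask[i % 4]


# XOR is its own inverse: masking twice with the same key is the identity.
buf = bytearray(b"hello websocket")
websocket_mask_naive(b"\x01\x02\x03\x04", buf)
websocket_mask_naive(b"\x01\x02\x03\x04", buf)
assert bytes(buf) == b"hello websocket"
```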
| | | |
| | | if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover |
| | | websocket_mask = _websocket_mask_python |
| | | else: |
| | | try: |
| | | from .mask import _websocket_mask_cython # type: ignore[import-not-found] |
| | | |
| | | websocket_mask = _websocket_mask_cython |
| | | except ImportError: # pragma: no cover |
| | | websocket_mask = _websocket_mask_python |
| | | |
| | | |
| | | _WS_EXT_RE: Final[Pattern[str]] = re.compile( |
| | | r"^(?:;\s*(?:" |
| | | r"(server_no_context_takeover)|" |
| | | r"(client_no_context_takeover)|" |
| | | r"(server_max_window_bits(?:=(\d+))?)|" |
| | | r"(client_max_window_bits(?:=(\d+))?)))*$" |
| | | ) |
| | | |
| | | _WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?") |
| | | |
| | | |
| | | def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]: |
| | | if not extstr: |
| | | return 0, False |
| | | |
| | | compress = 0 |
| | | notakeover = False |
| | | for ext in _WS_EXT_RE_SPLIT.finditer(extstr): |
| | | defext = ext.group(1) |
| | |         # Bare permessage-deflate (no parameters) means the default window size of 15 |
| | | if not defext: |
| | | compress = 15 |
| | | break |
| | | match = _WS_EXT_RE.match(defext) |
| | | if match: |
| | | compress = 15 |
| | | if isserver: |
| | |                 # The server never fails to detect the compress handshake; |
| | |                 # it does not need to send max window bits to the client |
| | | if match.group(4): |
| | | compress = int(match.group(4)) |
| | |                     # Group 3 must match if group 4 matches |
| | |                     # zlib does not support a window size (wbits) of 8 |
| | |                     # If the window size is unsupported, |
| | |                     # CONTINUE to the next extension |
| | | if compress > 15 or compress < 9: |
| | | compress = 0 |
| | | continue |
| | | if match.group(1): |
| | | notakeover = True |
| | | # Ignore regex group 5 & 6 for client_max_window_bits |
| | | break |
| | | else: |
| | | if match.group(6): |
| | | compress = int(match.group(6)) |
| | |                     # Group 5 must match if group 6 matches |
| | |                     # zlib does not support a window size (wbits) of 8 |
| | |                     # If the window size is unsupported, |
| | |                     # FAIL the parse |
| | | if compress > 15 or compress < 9: |
| | | raise WSHandshakeError("Invalid window size") |
| | | if match.group(2): |
| | | notakeover = True |
| | | # Ignore regex group 5 & 6 for client_max_window_bits |
| | | break |
| | |         # Fail on the client side if the extension does not match |
| | |         elif not isserver: |
| | |             raise WSHandshakeError("Extension for deflate not supported: " + ext.group(1)) |
| | | |
| | | return compress, notakeover |
| | | |
| | | |
| | | def ws_ext_gen( |
| | | compress: int = 15, isserver: bool = False, server_notakeover: bool = False |
| | | ) -> str: |
| | |     # client_notakeover=False is not used for the server |
| | |     # zlib does not support a window size (wbits) of 8 |
| | | if compress < 9 or compress > 15: |
| | |         raise ValueError( |
| | |             "Compress wbits must be between 9 and 15; zlib does not support wbits=8" |
| | |         ) |
| | | enabledext = ["permessage-deflate"] |
| | | if not isserver: |
| | | enabledext.append("client_max_window_bits") |
| | | |
| | | if compress < 15: |
| | | enabledext.append("server_max_window_bits=" + str(compress)) |
| | | if server_notakeover: |
| | | enabledext.append("server_no_context_takeover") |
| | | # if client_notakeover: |
| | | # enabledext.append('client_no_context_takeover') |
| | | return "; ".join(enabledext) |
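`ws_ext_parse` leans entirely on the two regexes above: `_WS_EXT_RE_SPLIT` isolates each `permessage-deflate` offer in the `Sec-WebSocket-Extensions` header, and `_WS_EXT_RE` validates its parameter list, with groups 3/4 carrying the server window size and groups 5/6 the client's. A small standalone check of how the groups come out (patterns copied verbatim from above; the header string is an illustrative example):

```python
import re

# The same two patterns used by ws_ext_parse above:
WS_EXT_RE = re.compile(
    r"^(?:;\s*(?:"
    r"(server_no_context_takeover)|"
    r"(client_no_context_takeover)|"
    r"(server_max_window_bits(?:=(\d+))?)|"
    r"(client_max_window_bits(?:=(\d+))?)))*$"
)
WS_EXT_RE_SPLIT = re.compile(r"permessage-deflate([^,]+)?")

hdr = "permessage-deflate; server_max_window_bits=10; client_no_context_takeover"
ext = next(WS_EXT_RE_SPLIT.finditer(hdr))
m = WS_EXT_RE.match(ext.group(1))

# group(4) carries the server window size, group(2) the no-takeover flag
server_wbits = int(m.group(4))        # -> 10
client_notakeover = bool(m.group(2))  # -> True
```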
| New file |
| | |
| | | """Cython declarations for websocket masking.""" |
| | | |
| | | cpdef void _websocket_mask_cython(bytes mask, bytearray data) |
| New file |
| | |
| | | from cpython cimport PyBytes_AsString |
| | | |
| | | |
| | | # from cpython cimport PyByteArray_AsString  # Cython still does not export this |
| | | cdef extern from "Python.h": |
| | | char* PyByteArray_AsString(bytearray ba) except NULL |
| | | |
| | | from libc.stdint cimport uint32_t, uint64_t, uintmax_t |
| | | |
| | | |
| | | cpdef void _websocket_mask_cython(bytes mask, bytearray data): |
| | |     """Note: this function mutates its `data` argument.""" |
| | | cdef: |
| | | Py_ssize_t data_len, i |
| | | # bit operations on signed integers are implementation-specific |
| | | unsigned char * in_buf |
| | | const unsigned char * mask_buf |
| | | uint32_t uint32_msk |
| | | uint64_t uint64_msk |
| | | |
| | | assert len(mask) == 4 |
| | | |
| | | data_len = len(data) |
| | | in_buf = <unsigned char*>PyByteArray_AsString(data) |
| | | mask_buf = <const unsigned char*>PyBytes_AsString(mask) |
| | | uint32_msk = (<uint32_t*>mask_buf)[0] |
| | | |
| | |     # TODO: align the in_buf pointer to achieve even faster speeds |
| | |     # (is that needed in Python? malloc() always aligns to sizeof(long) bytes) |
| | | |
| | | if sizeof(size_t) >= 8: |
| | | uint64_msk = uint32_msk |
| | | uint64_msk = (uint64_msk << 32) | uint32_msk |
| | | |
| | | while data_len >= 8: |
| | | (<uint64_t*>in_buf)[0] ^= uint64_msk |
| | | in_buf += 8 |
| | | data_len -= 8 |
| | | |
| | | |
| | | while data_len >= 4: |
| | | (<uint32_t*>in_buf)[0] ^= uint32_msk |
| | | in_buf += 4 |
| | | data_len -= 4 |
| | | |
| | | for i in range(0, data_len): |
| | | in_buf[i] ^= mask_buf[i] |
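The Cython masker above widens the 4-byte key to 64 bits and XORs a machine word at a time before finishing byte-wise. The same trick can be sketched in pure Python with integers (the function name is ours); it must agree exactly with the naive per-byte XOR:

```python
def mask_wordwise(mask: bytes, data: bytearray) -> None:
    """XOR 8 bytes per step with a widened mask, then finish byte-wise."""
    msk32 = int.from_bytes(mask, "little")
    msk64 = (msk32 << 32) | msk32       # key repeated twice, as in the Cython code
    n = len(data)
    i = 0
    while n - i >= 8:
        chunk = int.from_bytes(data[i:i + 8], "little") ^ msk64
        data[i:i + 8] = chunk.to_bytes(8, "little")
        i += 8
    for j in range(i, n):               # i is a multiple of 8, so j % 4 stays aligned
        data[j] ^= mask[j % 4]
```

Because XOR is byte-wise, the result is independent of the endianness used, as long as `from_bytes` and `to_bytes` agree.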
| New file |
| | |
| | | """Models for WebSocket protocol versions 13 and 8.""" |
| | | |
| | | import json |
| | | from enum import IntEnum |
| | | from typing import Any, Callable, Final, NamedTuple, Optional, cast |
| | | |
| | | WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF]) |
| | | |
| | | |
| | | class WSCloseCode(IntEnum): |
| | | OK = 1000 |
| | | GOING_AWAY = 1001 |
| | | PROTOCOL_ERROR = 1002 |
| | | UNSUPPORTED_DATA = 1003 |
| | | ABNORMAL_CLOSURE = 1006 |
| | | INVALID_TEXT = 1007 |
| | | POLICY_VIOLATION = 1008 |
| | | MESSAGE_TOO_BIG = 1009 |
| | | MANDATORY_EXTENSION = 1010 |
| | | INTERNAL_ERROR = 1011 |
| | | SERVICE_RESTART = 1012 |
| | | TRY_AGAIN_LATER = 1013 |
| | | BAD_GATEWAY = 1014 |
| | | |
| | | |
| | | class WSMsgType(IntEnum): |
| | | # websocket spec types |
| | | CONTINUATION = 0x0 |
| | | TEXT = 0x1 |
| | | BINARY = 0x2 |
| | | PING = 0x9 |
| | | PONG = 0xA |
| | | CLOSE = 0x8 |
| | | |
| | | # aiohttp specific types |
| | | CLOSING = 0x100 |
| | | CLOSED = 0x101 |
| | | ERROR = 0x102 |
| | | |
| | | text = TEXT |
| | | binary = BINARY |
| | | ping = PING |
| | | pong = PONG |
| | | close = CLOSE |
| | | closing = CLOSING |
| | | closed = CLOSED |
| | | error = ERROR |
| | | |
| | | |
| | | class WSMessage(NamedTuple): |
| | | type: WSMsgType |
| | | # To type correctly, this would need some kind of tagged union for each type. |
| | | data: Any |
| | | extra: Optional[str] |
| | | |
| | | def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any: |
| | | """Return parsed JSON data. |
| | | |
| | | .. versionadded:: 0.22 |
| | | """ |
| | | return loads(self.data) |
| | | |
| | | |
| | | # Construct the tuples directly to avoid the overhead of the lambda and |
| | | # argument processing, since NamedTuples are normally constructed with a |
| | | # lambda built at run time |
| | | # https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441 |
| | | WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None)) |
| | | WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None)) |
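The `tuple.__new__` construction above bypasses the generated NamedTuple `__new__` (a run-time-built function that processes keyword arguments and defaults) while still producing a genuine instance of the class. A minimal illustration with a stand-in class (`Msg` is ours, standing in for `WSMessage`):

```python
from typing import Any, NamedTuple, Optional


class Msg(NamedTuple):  # stand-in for WSMessage
    type: int
    data: Any
    extra: Optional[str]


# Skips the generated __new__'s argument processing but still yields a Msg.
fast = tuple.__new__(Msg, (1, "payload", None))
slow = Msg(1, "payload", None)
```

The trade-off is that `tuple.__new__` performs no arity or type checking, which is why the comment above leans on the functional test suite to catch mistakes.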
| | | |
| | | |
| | | class WebSocketError(Exception): |
| | | """WebSocket protocol parser error.""" |
| | | |
| | | def __init__(self, code: int, message: str) -> None: |
| | | self.code = code |
| | | super().__init__(code, message) |
| | | |
| | | def __str__(self) -> str: |
| | | return cast(str, self.args[1]) |
| | | |
| | | |
| | | class WSHandshakeError(Exception): |
| | | """WebSocket protocol handshake error.""" |
| New file |
| | |
| | | """Reader for WebSocket protocol versions 13 and 8.""" |
| | | |
| | | from typing import TYPE_CHECKING |
| | | |
| | | from ..helpers import NO_EXTENSIONS |
| | | |
| | | if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover |
| | | from .reader_py import ( |
| | | WebSocketDataQueue as WebSocketDataQueuePython, |
| | | WebSocketReader as WebSocketReaderPython, |
| | | ) |
| | | |
| | | WebSocketReader = WebSocketReaderPython |
| | | WebSocketDataQueue = WebSocketDataQueuePython |
| | | else: |
| | | try: |
| | | from .reader_c import ( # type: ignore[import-not-found] |
| | | WebSocketDataQueue as WebSocketDataQueueCython, |
| | | WebSocketReader as WebSocketReaderCython, |
| | | ) |
| | | |
| | | WebSocketReader = WebSocketReaderCython |
| | | WebSocketDataQueue = WebSocketDataQueueCython |
| | | except ImportError: # pragma: no cover |
| | | from .reader_py import ( |
| | | WebSocketDataQueue as WebSocketDataQueuePython, |
| | | WebSocketReader as WebSocketReaderPython, |
| | | ) |
| | | |
| | | WebSocketReader = WebSocketReaderPython |
| | | WebSocketDataQueue = WebSocketDataQueuePython |
| New file |
| | |
| | | import cython |
| | | |
| | | from .mask cimport _websocket_mask_cython as websocket_mask |
| | | |
| | | |
| | | cdef unsigned int READ_HEADER |
| | | cdef unsigned int READ_PAYLOAD_LENGTH |
| | | cdef unsigned int READ_PAYLOAD_MASK |
| | | cdef unsigned int READ_PAYLOAD |
| | | |
| | | cdef int OP_CODE_NOT_SET |
| | | cdef int OP_CODE_CONTINUATION |
| | | cdef int OP_CODE_TEXT |
| | | cdef int OP_CODE_BINARY |
| | | cdef int OP_CODE_CLOSE |
| | | cdef int OP_CODE_PING |
| | | cdef int OP_CODE_PONG |
| | | |
| | | cdef int COMPRESSED_NOT_SET |
| | | cdef int COMPRESSED_FALSE |
| | | cdef int COMPRESSED_TRUE |
| | | |
| | | cdef object UNPACK_LEN3 |
| | | cdef object UNPACK_CLOSE_CODE |
| | | cdef object TUPLE_NEW |
| | | |
| | | cdef object WSMsgType |
| | | cdef object WSMessage |
| | | |
| | | cdef object WS_MSG_TYPE_TEXT |
| | | cdef object WS_MSG_TYPE_BINARY |
| | | |
| | | cdef set ALLOWED_CLOSE_CODES |
| | | cdef set MESSAGE_TYPES_WITH_CONTENT |
| | | |
| | | cdef tuple EMPTY_FRAME |
| | | cdef tuple EMPTY_FRAME_ERROR |
| | | |
| | | cdef class WebSocketDataQueue: |
| | | |
| | | cdef unsigned int _size |
| | | cdef public object _protocol |
| | | cdef unsigned int _limit |
| | | cdef object _loop |
| | | cdef bint _eof |
| | | cdef object _waiter |
| | | cdef object _exception |
| | | cdef public object _buffer |
| | | cdef object _get_buffer |
| | | cdef object _put_buffer |
| | | |
| | | cdef void _release_waiter(self) |
| | | |
| | | cpdef void feed_data(self, object data, unsigned int size) |
| | | |
| | | @cython.locals(size="unsigned int") |
| | | cdef _read_from_buffer(self) |
| | | |
| | | cdef class WebSocketReader: |
| | | |
| | | cdef WebSocketDataQueue queue |
| | | cdef unsigned int _max_msg_size |
| | | |
| | | cdef Exception _exc |
| | | cdef bytearray _partial |
| | | cdef unsigned int _state |
| | | |
| | | cdef int _opcode |
| | | cdef bint _frame_fin |
| | | cdef int _frame_opcode |
| | | cdef list _payload_fragments |
| | | cdef Py_ssize_t _frame_payload_len |
| | | |
| | | cdef bytes _tail |
| | | cdef bint _has_mask |
| | | cdef bytes _frame_mask |
| | | cdef Py_ssize_t _payload_bytes_to_read |
| | | cdef unsigned int _payload_len_flag |
| | | cdef int _compressed |
| | | cdef object _decompressobj |
| | | cdef bint _compress |
| | | |
| | | cpdef tuple feed_data(self, object data) |
| | | |
| | | @cython.locals( |
| | | is_continuation=bint, |
| | | fin=bint, |
| | | has_partial=bint, |
| | | payload_merged=bytes, |
| | | ) |
| | | cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except * |
| | | |
| | | @cython.locals( |
| | | start_pos=Py_ssize_t, |
| | | data_len=Py_ssize_t, |
| | | length=Py_ssize_t, |
| | | chunk_size=Py_ssize_t, |
| | | chunk_len=Py_ssize_t, |
| | | data_cstr="const unsigned char *", |
| | | first_byte="unsigned char", |
| | | second_byte="unsigned char", |
| | | f_start_pos=Py_ssize_t, |
| | | f_end_pos=Py_ssize_t, |
| | | has_mask=bint, |
| | | fin=bint, |
| | | had_fragments=Py_ssize_t, |
| | | payload_bytearray=bytearray, |
| | | ) |
| | | cpdef void _feed_data(self, bytes data) except * |
| New file |
| | |
| | | """Reader for WebSocket protocol versions 13 and 8.""" |
| | | |
| | | import asyncio |
| | | import builtins |
| | | from collections import deque |
| | | from typing import Deque, Final, Optional, Set, Tuple, Union |
| | | |
| | | from ..base_protocol import BaseProtocol |
| | | from ..compression_utils import ZLibDecompressor |
| | | from ..helpers import _EXC_SENTINEL, set_exception |
| | | from ..streams import EofStream |
| | | from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask |
| | | from .models import ( |
| | | WS_DEFLATE_TRAILING, |
| | | WebSocketError, |
| | | WSCloseCode, |
| | | WSMessage, |
| | | WSMsgType, |
| | | ) |
| | | |
| | | ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode} |
| | | |
| | | # States for the reader, used to parse the WebSocket frame |
| | | # integer values are used so they can be cythonized |
| | | READ_HEADER = 1 |
| | | READ_PAYLOAD_LENGTH = 2 |
| | | READ_PAYLOAD_MASK = 3 |
| | | READ_PAYLOAD = 4 |
| | | |
| | | WS_MSG_TYPE_BINARY = WSMsgType.BINARY |
| | | WS_MSG_TYPE_TEXT = WSMsgType.TEXT |
| | | |
| | | # WSMsgType values unpacked so they can be cythonized to ints |
| | | OP_CODE_NOT_SET = -1 |
| | | OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value |
| | | OP_CODE_TEXT = WSMsgType.TEXT.value |
| | | OP_CODE_BINARY = WSMsgType.BINARY.value |
| | | OP_CODE_CLOSE = WSMsgType.CLOSE.value |
| | | OP_CODE_PING = WSMsgType.PING.value |
| | | OP_CODE_PONG = WSMsgType.PONG.value |
| | | |
| | | EMPTY_FRAME_ERROR = (True, b"") |
| | | EMPTY_FRAME = (False, b"") |
| | | |
| | | COMPRESSED_NOT_SET = -1 |
| | | COMPRESSED_FALSE = 0 |
| | | COMPRESSED_TRUE = 1 |
| | | |
| | | TUPLE_NEW = tuple.__new__ |
| | | |
| | | cython_int = int  # Typed as int in Python, but Cython will use a signed int in the pxd |
| | | |
| | | |
| | | class WebSocketDataQueue: |
| | | """WebSocketDataQueue resumes and pauses an underlying stream. |
| | | |
| | | It is a destination for WebSocket data. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop |
| | | ) -> None: |
| | | self._size = 0 |
| | | self._protocol = protocol |
| | | self._limit = limit * 2 |
| | | self._loop = loop |
| | | self._eof = False |
| | | self._waiter: Optional[asyncio.Future[None]] = None |
| | | self._exception: Union[BaseException, None] = None |
| | | self._buffer: Deque[Tuple[WSMessage, int]] = deque() |
| | | self._get_buffer = self._buffer.popleft |
| | | self._put_buffer = self._buffer.append |
| | | |
| | | def is_eof(self) -> bool: |
| | | return self._eof |
| | | |
| | | def exception(self) -> Optional[BaseException]: |
| | | return self._exception |
| | | |
| | | def set_exception( |
| | | self, |
| | | exc: BaseException, |
| | | exc_cause: builtins.BaseException = _EXC_SENTINEL, |
| | | ) -> None: |
| | | self._eof = True |
| | | self._exception = exc |
| | | if (waiter := self._waiter) is not None: |
| | | self._waiter = None |
| | | set_exception(waiter, exc, exc_cause) |
| | | |
| | | def _release_waiter(self) -> None: |
| | | if (waiter := self._waiter) is None: |
| | | return |
| | | self._waiter = None |
| | | if not waiter.done(): |
| | | waiter.set_result(None) |
| | | |
| | | def feed_eof(self) -> None: |
| | | self._eof = True |
| | | self._release_waiter() |
| | | self._exception = None # Break cyclic references |
| | | |
| | | def feed_data(self, data: "WSMessage", size: "cython_int") -> None: |
| | | self._size += size |
| | | self._put_buffer((data, size)) |
| | | self._release_waiter() |
| | | if self._size > self._limit and not self._protocol._reading_paused: |
| | | self._protocol.pause_reading() |
| | | |
| | | async def read(self) -> WSMessage: |
| | | if not self._buffer and not self._eof: |
| | | assert not self._waiter |
| | | self._waiter = self._loop.create_future() |
| | | try: |
| | | await self._waiter |
| | | except (asyncio.CancelledError, asyncio.TimeoutError): |
| | | self._waiter = None |
| | | raise |
| | | return self._read_from_buffer() |
| | | |
| | | def _read_from_buffer(self) -> WSMessage: |
| | | if self._buffer: |
| | | data, size = self._get_buffer() |
| | | self._size -= size |
| | | if self._size < self._limit and self._protocol._reading_paused: |
| | | self._protocol.resume_reading() |
| | | return data |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | raise EofStream |
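`WebSocketDataQueue` implements simple watermark flow control: `feed_data` pauses the transport once the buffered size exceeds `2 * limit`, and reads resume it once the buffer drains back below that mark. A stripped-down synchronous sketch of the same logic (class and method names are ours):

```python
from collections import deque


class StubProtocol:
    """Minimal stand-in exposing the pause/resume surface the queue uses."""

    def __init__(self) -> None:
        self._reading_paused = False

    def pause_reading(self) -> None:
        self._reading_paused = True

    def resume_reading(self) -> None:
        self._reading_paused = False


class TinyQueue:
    def __init__(self, protocol: StubProtocol, limit: int) -> None:
        self._size = 0
        self._limit = limit * 2  # same high-water mark as above
        self._protocol = protocol
        self._buffer = deque()

    def feed_data(self, data, size) -> None:
        self._size += size
        self._buffer.append((data, size))
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    def read_nowait(self):
        data, size = self._buffer.popleft()
        self._size -= size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return data
```

Feeding three 3-byte messages with `limit=4` pushes the size to 9, past the high-water mark of 8, so reading is paused; draining one message drops it to 6 and reading resumes.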
| | | |
| | | |
| | | class WebSocketReader: |
| | | def __init__( |
| | | self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True |
| | | ) -> None: |
| | | self.queue = queue |
| | | self._max_msg_size = max_msg_size |
| | | |
| | | self._exc: Optional[Exception] = None |
| | | self._partial = bytearray() |
| | | self._state = READ_HEADER |
| | | |
| | | self._opcode: int = OP_CODE_NOT_SET |
| | | self._frame_fin = False |
| | | self._frame_opcode: int = OP_CODE_NOT_SET |
| | | self._payload_fragments: list[bytes] = [] |
| | | self._frame_payload_len = 0 |
| | | |
| | | self._tail: bytes = b"" |
| | | self._has_mask = False |
| | | self._frame_mask: Optional[bytes] = None |
| | | self._payload_bytes_to_read = 0 |
| | | self._payload_len_flag = 0 |
| | | self._compressed: int = COMPRESSED_NOT_SET |
| | | self._decompressobj: Optional[ZLibDecompressor] = None |
| | | self._compress = compress |
| | | |
| | | def feed_eof(self) -> None: |
| | | self.queue.feed_eof() |
| | | |
| | |     # data can be a bytearray on Windows because the proactor event loop uses |
| | |     # bytearray, and asyncio types this as Union[bytes, bytearray, memoryview], |
| | |     # so we need to coerce data to bytes if it is not already |
| | | def feed_data( |
| | | self, data: Union[bytes, bytearray, memoryview] |
| | | ) -> Tuple[bool, bytes]: |
| | | if type(data) is not bytes: |
| | | data = bytes(data) |
| | | |
| | | if self._exc is not None: |
| | | return True, data |
| | | |
| | | try: |
| | | self._feed_data(data) |
| | | except Exception as exc: |
| | | self._exc = exc |
| | | set_exception(self.queue, exc) |
| | | return EMPTY_FRAME_ERROR |
| | | |
| | | return EMPTY_FRAME |
| | | |
| | | def _handle_frame( |
| | | self, |
| | | fin: bool, |
| | | opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int |
| | | payload: Union[bytes, bytearray], |
| | | compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int |
| | | ) -> None: |
| | | msg: WSMessage |
| | | if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}: |
| | | # Validate continuation frames before processing |
| | | if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | |                     "Continuation frame for non-started message", |
| | | ) |
| | | |
| | | # load text/binary |
| | | if not fin: |
| | | # got partial frame payload |
| | | if opcode != OP_CODE_CONTINUATION: |
| | | self._opcode = opcode |
| | | self._partial += payload |
| | | if self._max_msg_size and len(self._partial) >= self._max_msg_size: |
| | | raise WebSocketError( |
| | | WSCloseCode.MESSAGE_TOO_BIG, |
| | | f"Message size {len(self._partial)} " |
| | | f"exceeds limit {self._max_msg_size}", |
| | | ) |
| | | return |
| | | |
| | | has_partial = bool(self._partial) |
| | | if opcode == OP_CODE_CONTINUATION: |
| | | opcode = self._opcode |
| | | self._opcode = OP_CODE_NOT_SET |
| | |             # the previous frame was not finished; |
| | |             # we should get a continuation opcode |
| | | elif has_partial: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "The opcode in non-fin frame is expected " |
| | | f"to be zero, got {opcode!r}", |
| | | ) |
| | | |
| | | assembled_payload: Union[bytes, bytearray] |
| | | if has_partial: |
| | | assembled_payload = self._partial + payload |
| | | self._partial.clear() |
| | | else: |
| | | assembled_payload = payload |
| | | |
| | | if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: |
| | | raise WebSocketError( |
| | | WSCloseCode.MESSAGE_TOO_BIG, |
| | | f"Message size {len(assembled_payload)} " |
| | | f"exceeds limit {self._max_msg_size}", |
| | | ) |
| | | |
| | |             # Decompression must be done only after the whole message |
| | |             # has been received |
| | | if compressed: |
| | | if not self._decompressobj: |
| | | self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) |
| | | # XXX: It's possible that the zlib backend (isal is known to |
| | | # do this, maybe others too?) will return max_length bytes, |
| | | # but internally buffer more data such that the payload is |
| | | # >max_length, so we return one extra byte and if we're able |
| | | # to do that, then the message is too big. |
| | | payload_merged = self._decompressobj.decompress_sync( |
| | | assembled_payload + WS_DEFLATE_TRAILING, |
| | | ( |
| | | self._max_msg_size + 1 |
| | | if self._max_msg_size |
| | | else self._max_msg_size |
| | | ), |
| | | ) |
| | | if self._max_msg_size and len(payload_merged) > self._max_msg_size: |
| | | raise WebSocketError( |
| | | WSCloseCode.MESSAGE_TOO_BIG, |
| | | f"Decompressed message exceeds size limit {self._max_msg_size}", |
| | | ) |
| | | elif type(assembled_payload) is bytes: |
| | | payload_merged = assembled_payload |
| | | else: |
| | | payload_merged = bytes(assembled_payload) |
| | | |
| | | if opcode == OP_CODE_TEXT: |
| | | try: |
| | | text = payload_merged.decode("utf-8") |
| | | except UnicodeDecodeError as exc: |
| | | raise WebSocketError( |
| | | WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" |
| | | ) from exc |
| | | |
| | | # XXX: The Text and Binary messages here can be a performance |
| | | # bottleneck, so we use tuple.__new__ to improve performance. |
| | | # This is not type safe, but many tests should fail in |
| | | # test_client_ws_functional.py if this is wrong. |
| | | self.queue.feed_data( |
| | | TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")), |
| | | len(payload_merged), |
| | | ) |
| | | else: |
| | | self.queue.feed_data( |
| | | TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")), |
| | | len(payload_merged), |
| | | ) |
| | | elif opcode == OP_CODE_CLOSE: |
| | | if len(payload) >= 2: |
| | | close_code = UNPACK_CLOSE_CODE(payload[:2])[0] |
| | | if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | f"Invalid close code: {close_code}", |
| | | ) |
| | | try: |
| | | close_message = payload[2:].decode("utf-8") |
| | | except UnicodeDecodeError as exc: |
| | | raise WebSocketError( |
| | | WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" |
| | | ) from exc |
| | | msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message)) |
| | | elif payload: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | f"Invalid close frame: {fin} {opcode} {payload!r}", |
| | | ) |
| | | else: |
| | | msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, "")) |
| | | |
| | | self.queue.feed_data(msg, 0) |
| | | elif opcode == OP_CODE_PING: |
| | | msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, "")) |
| | | self.queue.feed_data(msg, len(payload)) |
| | | elif opcode == OP_CODE_PONG: |
| | | msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, "")) |
| | | self.queue.feed_data(msg, len(payload)) |
| | | else: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" |
| | | ) |
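The decompression step in `_handle_frame` appends `WS_DEFLATE_TRAILING` (`00 00 FF FF`) because permessage-deflate senders (RFC 7692) strip that empty flush block from the end of each compressed message. A minimal round trip with `zlib` raw deflate showing why the bytes must be restored before inflating (helper names are ours):

```python
import zlib

WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xFF, 0xFF])


def ws_compress(payload: bytes) -> bytes:
    """Compress one message, then strip the trailing empty flush block."""
    c = zlib.compressobj(wbits=-zlib.MAX_WBITS)  # raw deflate, no zlib header
    body = c.compress(payload) + c.flush(zlib.Z_SYNC_FLUSH)
    assert body.endswith(WS_DEFLATE_TRAILING)    # sync flush ends 00 00 FF FF
    return body[:-4]


def ws_decompress(frame_payload: bytes) -> bytes:
    """Re-append the flush block, as the reader above does, and inflate."""
    d = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
    return d.decompress(frame_payload + WS_DEFLATE_TRAILING)
```

Without the re-appended trailer the deflate stream is truncated mid-block and the decompressor would return incomplete data.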
| | | |
| | | def _feed_data(self, data: bytes) -> None: |
| | | """Return the next frame from the socket.""" |
| | | if self._tail: |
| | | data, self._tail = self._tail + data, b"" |
| | | |
| | | start_pos: int = 0 |
| | | data_len = len(data) |
| | | data_cstr = data |
| | | |
| | | while True: |
| | | # read header |
| | | if self._state == READ_HEADER: |
| | | if data_len - start_pos < 2: |
| | | break |
| | | first_byte = data_cstr[start_pos] |
| | | second_byte = data_cstr[start_pos + 1] |
| | | start_pos += 2 |
| | | |
| | | fin = (first_byte >> 7) & 1 |
| | | rsv1 = (first_byte >> 6) & 1 |
| | | rsv2 = (first_byte >> 5) & 1 |
| | | rsv3 = (first_byte >> 4) & 1 |
| | | opcode = first_byte & 0xF |
| | | |
| | | # frame-fin = %x0 ; more frames of this message follow |
| | | # / %x1 ; final frame of this message |
| | | # frame-rsv1 = %x0 ; |
| | | # 1 bit, MUST be 0 unless negotiated otherwise |
| | | # frame-rsv2 = %x0 ; |
| | | # 1 bit, MUST be 0 unless negotiated otherwise |
| | | # frame-rsv3 = %x0 ; |
| | | # 1 bit, MUST be 0 unless negotiated otherwise |
| | | # |
| | | # Remove rsv1 from this test for deflate development |
| | | if rsv2 or rsv3 or (rsv1 and not self._compress): |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Received frame with non-zero reserved bits", |
| | | ) |
| | | |
| | | if opcode > 0x7 and fin == 0: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Received fragmented control frame", |
| | | ) |
| | | |
| | | has_mask = (second_byte >> 7) & 1 |
| | | length = second_byte & 0x7F |
| | | |
| | | # Control frames MUST have a payload |
| | | # length of 125 bytes or less |
| | | if opcode > 0x7 and length > 125: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Control frame payload cannot be larger than 125 bytes", |
| | | ) |
| | | |
| | |                 # Set the compress status if the last frame was final (FIN) |
| | |                 # OR if this is the first fragment. |
| | |                 # Raise an error on a non-first fragment with rsv1 = 0x1 |
| | | if self._frame_fin or self._compressed == COMPRESSED_NOT_SET: |
| | | self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE |
| | | elif rsv1: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Received frame with non-zero reserved bits", |
| | | ) |
| | | |
| | | self._frame_fin = bool(fin) |
| | | self._frame_opcode = opcode |
| | | self._has_mask = bool(has_mask) |
| | | self._payload_len_flag = length |
| | | self._state = READ_PAYLOAD_LENGTH |
| | | |
| | | # read payload length |
| | | if self._state == READ_PAYLOAD_LENGTH: |
| | | len_flag = self._payload_len_flag |
| | | if len_flag == 126: |
| | | if data_len - start_pos < 2: |
| | | break |
| | | first_byte = data_cstr[start_pos] |
| | | second_byte = data_cstr[start_pos + 1] |
| | | start_pos += 2 |
| | | self._payload_bytes_to_read = first_byte << 8 | second_byte |
| | | elif len_flag > 126: |
| | | if data_len - start_pos < 8: |
| | | break |
| | | self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0] |
| | | start_pos += 8 |
| | | else: |
| | | self._payload_bytes_to_read = len_flag |
| | | |
| | | self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD |
| | | |
| | | # read payload mask |
| | | if self._state == READ_PAYLOAD_MASK: |
| | | if data_len - start_pos < 4: |
| | | break |
| | | self._frame_mask = data_cstr[start_pos : start_pos + 4] |
| | | start_pos += 4 |
| | | self._state = READ_PAYLOAD |
| | | |
| | | if self._state == READ_PAYLOAD: |
| | | chunk_len = data_len - start_pos |
| | | if self._payload_bytes_to_read >= chunk_len: |
| | | f_end_pos = data_len |
| | | self._payload_bytes_to_read -= chunk_len |
| | | else: |
| | | f_end_pos = start_pos + self._payload_bytes_to_read |
| | | self._payload_bytes_to_read = 0 |
| | | |
| | | had_fragments = self._frame_payload_len |
| | | self._frame_payload_len += f_end_pos - start_pos |
| | | f_start_pos = start_pos |
| | | start_pos = f_end_pos |
| | | |
| | | if self._payload_bytes_to_read != 0: |
| | | # If we don't have a complete frame, we need to save the |
| | | # data for the next call to feed_data. |
| | | self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos]) |
| | | break |
| | | |
| | | payload: Union[bytes, bytearray] |
| | | if had_fragments: |
| | | # We have to join the payload fragments to get the payload
| | | self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos]) |
| | | if self._has_mask: |
| | | assert self._frame_mask is not None |
| | | payload_bytearray = bytearray(b"".join(self._payload_fragments)) |
| | | websocket_mask(self._frame_mask, payload_bytearray) |
| | | payload = payload_bytearray |
| | | else: |
| | | payload = b"".join(self._payload_fragments) |
| | | self._payload_fragments.clear() |
| | | elif self._has_mask: |
| | | assert self._frame_mask is not None |
| | | payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment] |
| | | if type(payload_bytearray) is not bytearray: # pragma: no branch |
| | | # Cython does this conversion for us, but
| | | # in Python we must do it ourselves, and we
| | | # always reach this branch in Python
| | | payload_bytearray = bytearray(payload_bytearray) |
| | | websocket_mask(self._frame_mask, payload_bytearray) |
| | | payload = payload_bytearray |
| | | else: |
| | | payload = data_cstr[f_start_pos:f_end_pos] |
| | | |
| | | self._handle_frame( |
| | | self._frame_fin, self._frame_opcode, payload, self._compressed |
| | | ) |
| | | self._frame_payload_len = 0 |
| | | self._state = READ_HEADER |
| | | |
| | | # XXX: Cython needs slices to be bounded, so we can't omit the slice end here. |
| | | self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b"" |
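The `READ_HEADER` branch above extracts FIN, the three RSV bits, the opcode, the mask flag, and the 7-bit length flag from the first two header bytes. A standalone sketch of that bit arithmetic (`parse_frame_header` is an illustrative helper, not part of this module):

```python
def parse_frame_header(header: bytes) -> dict:
    """Parse the first two bytes of a WebSocket frame (RFC 6455 section 5.2)."""
    first_byte, second_byte = header[0], header[1]
    return {
        "fin": (first_byte >> 7) & 1,      # final frame of this message?
        "rsv1": (first_byte >> 6) & 1,     # set for per-message deflate
        "rsv2": (first_byte >> 5) & 1,
        "rsv3": (first_byte >> 4) & 1,
        "opcode": first_byte & 0xF,        # 0x1 = text, 0x2 = binary, ...
        "has_mask": (second_byte >> 7) & 1,
        "length": second_byte & 0x7F,      # 7-bit length flag (125/126/127)
    }

# 0x81 = 1000_0001: FIN set, opcode TEXT; 0x05 = unmasked, payload length 5
fields = parse_frame_header(b"\x81\x05")
```

The reader keeps these fields on `self` (`_frame_fin`, `_frame_opcode`, etc.) because a header may arrive split across multiple `feed_data` calls.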
| New file |
| | |
| | | """Reader for WebSocket protocol versions 13 and 8.""" |
| | | |
| | | import asyncio |
| | | import builtins |
| | | from collections import deque |
| | | from typing import Deque, Final, Optional, Set, Tuple, Union |
| | | |
| | | from ..base_protocol import BaseProtocol |
| | | from ..compression_utils import ZLibDecompressor |
| | | from ..helpers import _EXC_SENTINEL, set_exception |
| | | from ..streams import EofStream |
| | | from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask |
| | | from .models import ( |
| | | WS_DEFLATE_TRAILING, |
| | | WebSocketError, |
| | | WSCloseCode, |
| | | WSMessage, |
| | | WSMsgType, |
| | | ) |
| | | |
| | | ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode} |
| | | |
| | | # States for the reader, used to parse the WebSocket frame |
| | | # integer values are used so they can be cythonized |
| | | READ_HEADER = 1 |
| | | READ_PAYLOAD_LENGTH = 2 |
| | | READ_PAYLOAD_MASK = 3 |
| | | READ_PAYLOAD = 4 |
| | | |
| | | WS_MSG_TYPE_BINARY = WSMsgType.BINARY |
| | | WS_MSG_TYPE_TEXT = WSMsgType.TEXT |
| | | |
| | | # WSMsgType values unpacked so they can be cythonized to ints
| | | OP_CODE_NOT_SET = -1 |
| | | OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value |
| | | OP_CODE_TEXT = WSMsgType.TEXT.value |
| | | OP_CODE_BINARY = WSMsgType.BINARY.value |
| | | OP_CODE_CLOSE = WSMsgType.CLOSE.value |
| | | OP_CODE_PING = WSMsgType.PING.value |
| | | OP_CODE_PONG = WSMsgType.PONG.value |
| | | |
| | | EMPTY_FRAME_ERROR = (True, b"") |
| | | EMPTY_FRAME = (False, b"") |
| | | |
| | | COMPRESSED_NOT_SET = -1 |
| | | COMPRESSED_FALSE = 0 |
| | | COMPRESSED_TRUE = 1 |
| | | |
| | | TUPLE_NEW = tuple.__new__ |
| | | |
| | | cython_int = int # Typed to int in Python, but Cython will use a signed int in the pxd
| | | |
| | | |
| | | class WebSocketDataQueue: |
| | | """WebSocketDataQueue resumes and pauses an underlying stream. |
| | | |
| | | It is a destination for WebSocket data. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop |
| | | ) -> None: |
| | | self._size = 0 |
| | | self._protocol = protocol |
| | | self._limit = limit * 2 |
| | | self._loop = loop |
| | | self._eof = False |
| | | self._waiter: Optional[asyncio.Future[None]] = None |
| | | self._exception: Union[BaseException, None] = None |
| | | self._buffer: Deque[Tuple[WSMessage, int]] = deque() |
| | | self._get_buffer = self._buffer.popleft |
| | | self._put_buffer = self._buffer.append |
| | | |
| | | def is_eof(self) -> bool: |
| | | return self._eof |
| | | |
| | | def exception(self) -> Optional[BaseException]: |
| | | return self._exception |
| | | |
| | | def set_exception( |
| | | self, |
| | | exc: BaseException, |
| | | exc_cause: builtins.BaseException = _EXC_SENTINEL, |
| | | ) -> None: |
| | | self._eof = True |
| | | self._exception = exc |
| | | if (waiter := self._waiter) is not None: |
| | | self._waiter = None |
| | | set_exception(waiter, exc, exc_cause) |
| | | |
| | | def _release_waiter(self) -> None: |
| | | if (waiter := self._waiter) is None: |
| | | return |
| | | self._waiter = None |
| | | if not waiter.done(): |
| | | waiter.set_result(None) |
| | | |
| | | def feed_eof(self) -> None: |
| | | self._eof = True |
| | | self._release_waiter() |
| | | self._exception = None # Break cyclic references |
| | | |
| | | def feed_data(self, data: "WSMessage", size: "cython_int") -> None: |
| | | self._size += size |
| | | self._put_buffer((data, size)) |
| | | self._release_waiter() |
| | | if self._size > self._limit and not self._protocol._reading_paused: |
| | | self._protocol.pause_reading() |
| | | |
| | | async def read(self) -> WSMessage: |
| | | if not self._buffer and not self._eof: |
| | | assert not self._waiter |
| | | self._waiter = self._loop.create_future() |
| | | try: |
| | | await self._waiter |
| | | except (asyncio.CancelledError, asyncio.TimeoutError): |
| | | self._waiter = None |
| | | raise |
| | | return self._read_from_buffer() |
| | | |
| | | def _read_from_buffer(self) -> WSMessage: |
| | | if self._buffer: |
| | | data, size = self._get_buffer() |
| | | self._size -= size |
| | | if self._size < self._limit and self._protocol._reading_paused: |
| | | self._protocol.resume_reading() |
| | | return data |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | raise EofStream |
| | | |
| | | |
| | | class WebSocketReader: |
| | | def __init__( |
| | | self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True |
| | | ) -> None: |
| | | self.queue = queue |
| | | self._max_msg_size = max_msg_size |
| | | |
| | | self._exc: Optional[Exception] = None |
| | | self._partial = bytearray() |
| | | self._state = READ_HEADER |
| | | |
| | | self._opcode: int = OP_CODE_NOT_SET |
| | | self._frame_fin = False |
| | | self._frame_opcode: int = OP_CODE_NOT_SET |
| | | self._payload_fragments: list[bytes] = [] |
| | | self._frame_payload_len = 0 |
| | | |
| | | self._tail: bytes = b"" |
| | | self._has_mask = False |
| | | self._frame_mask: Optional[bytes] = None |
| | | self._payload_bytes_to_read = 0 |
| | | self._payload_len_flag = 0 |
| | | self._compressed: int = COMPRESSED_NOT_SET |
| | | self._decompressobj: Optional[ZLibDecompressor] = None |
| | | self._compress = compress |
| | | |
| | | def feed_eof(self) -> None: |
| | | self.queue.feed_eof() |
| | | |
| | | # data can be bytearray on Windows because proactor event loop uses bytearray |
| | | # and asyncio types this as Union[bytes, bytearray, memoryview], so we
| | | # need to coerce data to bytes if it is not
| | | def feed_data( |
| | | self, data: Union[bytes, bytearray, memoryview] |
| | | ) -> Tuple[bool, bytes]: |
| | | if type(data) is not bytes: |
| | | data = bytes(data) |
| | | |
| | | if self._exc is not None: |
| | | return True, data |
| | | |
| | | try: |
| | | self._feed_data(data) |
| | | except Exception as exc: |
| | | self._exc = exc |
| | | set_exception(self.queue, exc) |
| | | return EMPTY_FRAME_ERROR |
| | | |
| | | return EMPTY_FRAME |
| | | |
| | | def _handle_frame( |
| | | self, |
| | | fin: bool, |
| | | opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int |
| | | payload: Union[bytes, bytearray], |
| | | compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int |
| | | ) -> None: |
| | | msg: WSMessage |
| | | if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}: |
| | | # Validate continuation frames before processing |
| | | if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Continuation frame for non-started message",
| | | ) |
| | | |
| | | # load text/binary |
| | | if not fin: |
| | | # got partial frame payload |
| | | if opcode != OP_CODE_CONTINUATION: |
| | | self._opcode = opcode |
| | | self._partial += payload |
| | | if self._max_msg_size and len(self._partial) >= self._max_msg_size: |
| | | raise WebSocketError( |
| | | WSCloseCode.MESSAGE_TOO_BIG, |
| | | f"Message size {len(self._partial)} " |
| | | f"exceeds limit {self._max_msg_size}", |
| | | ) |
| | | return |
| | | |
| | | has_partial = bool(self._partial) |
| | | if opcode == OP_CODE_CONTINUATION: |
| | | opcode = self._opcode |
| | | self._opcode = OP_CODE_NOT_SET |
| | | # the previous frame was not finished,
| | | # so we expect a continuation opcode
| | | elif has_partial: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "The opcode in non-fin frame is expected " |
| | | f"to be zero, got {opcode!r}", |
| | | ) |
| | | |
| | | assembled_payload: Union[bytes, bytearray] |
| | | if has_partial: |
| | | assembled_payload = self._partial + payload |
| | | self._partial.clear() |
| | | else: |
| | | assembled_payload = payload |
| | | |
| | | if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: |
| | | raise WebSocketError( |
| | | WSCloseCode.MESSAGE_TOO_BIG, |
| | | f"Message size {len(assembled_payload)} " |
| | | f"exceeds limit {self._max_msg_size}", |
| | | ) |
| | | |
| | | # Decompression must be done after all packets
| | | # have been received.
| | | if compressed: |
| | | if not self._decompressobj: |
| | | self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) |
| | | # XXX: It's possible that the zlib backend (isal is known to |
| | | # do this, maybe others too?) will return max_length bytes, |
| | | # but internally buffer more data such that the payload is |
| | | # >max_length, so we return one extra byte and if we're able |
| | | # to do that, then the message is too big. |
| | | payload_merged = self._decompressobj.decompress_sync( |
| | | assembled_payload + WS_DEFLATE_TRAILING, |
| | | ( |
| | | self._max_msg_size + 1 |
| | | if self._max_msg_size |
| | | else self._max_msg_size |
| | | ), |
| | | ) |
| | | if self._max_msg_size and len(payload_merged) > self._max_msg_size: |
| | | raise WebSocketError( |
| | | WSCloseCode.MESSAGE_TOO_BIG, |
| | | f"Decompressed message exceeds size limit {self._max_msg_size}", |
| | | ) |
| | | elif type(assembled_payload) is bytes: |
| | | payload_merged = assembled_payload |
| | | else: |
| | | payload_merged = bytes(assembled_payload) |
| | | |
| | | if opcode == OP_CODE_TEXT: |
| | | try: |
| | | text = payload_merged.decode("utf-8") |
| | | except UnicodeDecodeError as exc: |
| | | raise WebSocketError( |
| | | WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" |
| | | ) from exc |
| | | |
| | | # XXX: The Text and Binary messages here can be a performance |
| | | # bottleneck, so we use tuple.__new__ to improve performance. |
| | | # This is not type safe, but many tests should fail in |
| | | # test_client_ws_functional.py if this is wrong. |
| | | self.queue.feed_data( |
| | | TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")), |
| | | len(payload_merged), |
| | | ) |
| | | else: |
| | | self.queue.feed_data( |
| | | TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")), |
| | | len(payload_merged), |
| | | ) |
| | | elif opcode == OP_CODE_CLOSE: |
| | | if len(payload) >= 2: |
| | | close_code = UNPACK_CLOSE_CODE(payload[:2])[0] |
| | | if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | f"Invalid close code: {close_code}", |
| | | ) |
| | | try: |
| | | close_message = payload[2:].decode("utf-8") |
| | | except UnicodeDecodeError as exc: |
| | | raise WebSocketError( |
| | | WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" |
| | | ) from exc |
| | | msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message)) |
| | | elif payload: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | f"Invalid close frame: {fin} {opcode} {payload!r}", |
| | | ) |
| | | else: |
| | | msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, "")) |
| | | |
| | | self.queue.feed_data(msg, 0) |
| | | elif opcode == OP_CODE_PING: |
| | | msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, "")) |
| | | self.queue.feed_data(msg, len(payload)) |
| | | elif opcode == OP_CODE_PONG: |
| | | msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, "")) |
| | | self.queue.feed_data(msg, len(payload)) |
| | | else: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" |
| | | ) |
| | | |
| | | def _feed_data(self, data: bytes) -> None: |
| | | """Return the next frame from the socket.""" |
| | | if self._tail: |
| | | data, self._tail = self._tail + data, b"" |
| | | |
| | | start_pos: int = 0 |
| | | data_len = len(data) |
| | | data_cstr = data |
| | | |
| | | while True: |
| | | # read header |
| | | if self._state == READ_HEADER: |
| | | if data_len - start_pos < 2: |
| | | break |
| | | first_byte = data_cstr[start_pos] |
| | | second_byte = data_cstr[start_pos + 1] |
| | | start_pos += 2 |
| | | |
| | | fin = (first_byte >> 7) & 1 |
| | | rsv1 = (first_byte >> 6) & 1 |
| | | rsv2 = (first_byte >> 5) & 1 |
| | | rsv3 = (first_byte >> 4) & 1 |
| | | opcode = first_byte & 0xF |
| | | |
| | | # frame-fin = %x0 ; more frames of this message follow |
| | | # / %x1 ; final frame of this message |
| | | # frame-rsv1 = %x0 ; |
| | | # 1 bit, MUST be 0 unless negotiated otherwise |
| | | # frame-rsv2 = %x0 ; |
| | | # 1 bit, MUST be 0 unless negotiated otherwise |
| | | # frame-rsv3 = %x0 ; |
| | | # 1 bit, MUST be 0 unless negotiated otherwise |
| | | # |
| | | # Remove rsv1 from this test for deflate development |
| | | if rsv2 or rsv3 or (rsv1 and not self._compress): |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Received frame with non-zero reserved bits", |
| | | ) |
| | | |
| | | if opcode > 0x7 and fin == 0: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Received fragmented control frame", |
| | | ) |
| | | |
| | | has_mask = (second_byte >> 7) & 1 |
| | | length = second_byte & 0x7F |
| | | |
| | | # Control frames MUST have a payload |
| | | # length of 125 bytes or less |
| | | if opcode > 0x7 and length > 125: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Control frame payload cannot be larger than 125 bytes", |
| | | ) |
| | | |
| | | # Set the compress status if the previous frame was final (FIN)
| | | # OR if this is the first fragment of a message.
| | | # Raise an error if a non-first fragment has rsv1 = 0x1.
| | | if self._frame_fin or self._compressed == COMPRESSED_NOT_SET: |
| | | self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE |
| | | elif rsv1: |
| | | raise WebSocketError( |
| | | WSCloseCode.PROTOCOL_ERROR, |
| | | "Received frame with non-zero reserved bits", |
| | | ) |
| | | |
| | | self._frame_fin = bool(fin) |
| | | self._frame_opcode = opcode |
| | | self._has_mask = bool(has_mask) |
| | | self._payload_len_flag = length |
| | | self._state = READ_PAYLOAD_LENGTH |
| | | |
| | | # read payload length |
| | | if self._state == READ_PAYLOAD_LENGTH: |
| | | len_flag = self._payload_len_flag |
| | | if len_flag == 126: |
| | | if data_len - start_pos < 2: |
| | | break |
| | | first_byte = data_cstr[start_pos] |
| | | second_byte = data_cstr[start_pos + 1] |
| | | start_pos += 2 |
| | | self._payload_bytes_to_read = first_byte << 8 | second_byte |
| | | elif len_flag > 126: |
| | | if data_len - start_pos < 8: |
| | | break |
| | | self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0] |
| | | start_pos += 8 |
| | | else: |
| | | self._payload_bytes_to_read = len_flag |
| | | |
| | | self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD |
| | | |
| | | # read payload mask |
| | | if self._state == READ_PAYLOAD_MASK: |
| | | if data_len - start_pos < 4: |
| | | break |
| | | self._frame_mask = data_cstr[start_pos : start_pos + 4] |
| | | start_pos += 4 |
| | | self._state = READ_PAYLOAD |
| | | |
| | | if self._state == READ_PAYLOAD: |
| | | chunk_len = data_len - start_pos |
| | | if self._payload_bytes_to_read >= chunk_len: |
| | | f_end_pos = data_len |
| | | self._payload_bytes_to_read -= chunk_len |
| | | else: |
| | | f_end_pos = start_pos + self._payload_bytes_to_read |
| | | self._payload_bytes_to_read = 0 |
| | | |
| | | had_fragments = self._frame_payload_len |
| | | self._frame_payload_len += f_end_pos - start_pos |
| | | f_start_pos = start_pos |
| | | start_pos = f_end_pos |
| | | |
| | | if self._payload_bytes_to_read != 0: |
| | | # If we don't have a complete frame, we need to save the |
| | | # data for the next call to feed_data. |
| | | self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos]) |
| | | break |
| | | |
| | | payload: Union[bytes, bytearray] |
| | | if had_fragments: |
| | | # We have to join the payload fragments to get the payload
| | | self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos]) |
| | | if self._has_mask: |
| | | assert self._frame_mask is not None |
| | | payload_bytearray = bytearray(b"".join(self._payload_fragments)) |
| | | websocket_mask(self._frame_mask, payload_bytearray) |
| | | payload = payload_bytearray |
| | | else: |
| | | payload = b"".join(self._payload_fragments) |
| | | self._payload_fragments.clear() |
| | | elif self._has_mask: |
| | | assert self._frame_mask is not None |
| | | payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment] |
| | | if type(payload_bytearray) is not bytearray: # pragma: no branch |
| | | # Cython does this conversion for us, but
| | | # in Python we must do it ourselves, and we
| | | # always reach this branch in Python
| | | payload_bytearray = bytearray(payload_bytearray) |
| | | websocket_mask(self._frame_mask, payload_bytearray) |
| | | payload = payload_bytearray |
| | | else: |
| | | payload = data_cstr[f_start_pos:f_end_pos] |
| | | |
| | | self._handle_frame( |
| | | self._frame_fin, self._frame_opcode, payload, self._compressed |
| | | ) |
| | | self._frame_payload_len = 0 |
| | | self._state = READ_HEADER |
| | | |
| | | # XXX: Cython needs slices to be bounded, so we can't omit the slice end here. |
| | | self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b"" |
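When `_has_mask` is set, the reader above XORs the payload in place against the 4-byte mask taken from the frame header (RFC 6455 section 5.3). A pure-Python sketch of that operation (`xor_mask` is a hypothetical stand-in for the optimized `websocket_mask` helper):

```python
def xor_mask(mask: bytes, payload: bytearray) -> None:
    """Apply the RFC 6455 client mask in place: byte i of the payload is
    XORed with mask[i % 4]. XOR is an involution, so applying the same
    mask twice restores the original payload."""
    for i in range(len(payload)):
        payload[i] ^= mask[i % 4]

data = bytearray(b"hello")
mask = b"\x01\x02\x03\x04"
xor_mask(mask, data)  # masked form, as it appears on the wire
xor_mask(mask, data)  # receiver unmasks with the same 4 bytes
assert data == bytearray(b"hello")
```

This is also why the reader converts `bytes` slices to `bytearray` before unmasking: the operation mutates the buffer in place.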
| New file |
| | |
| | | """WebSocket protocol versions 13 and 8.""" |
| | | |
| | | import asyncio |
| | | import random |
| | | import sys |
| | | from functools import partial |
| | | from typing import Final, Optional, Set, Union |
| | | |
| | | from ..base_protocol import BaseProtocol |
| | | from ..client_exceptions import ClientConnectionResetError |
| | | from ..compression_utils import ZLibBackend, ZLibCompressor |
| | | from .helpers import ( |
| | | MASK_LEN, |
| | | MSG_SIZE, |
| | | PACK_CLOSE_CODE, |
| | | PACK_LEN1, |
| | | PACK_LEN2, |
| | | PACK_LEN3, |
| | | PACK_RANDBITS, |
| | | websocket_mask, |
| | | ) |
| | | from .models import WS_DEFLATE_TRAILING, WSMsgType |
| | | |
| | | DEFAULT_LIMIT: Final[int] = 2**16 |
| | | |
| | | # WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames |
| | | # Control frames (ping, pong, close) are never compressed |
| | | WS_CONTROL_FRAME_OPCODE: Final[int] = 8 |
| | | |
| | | # For websockets, keeping latency low is extremely important as implementations |
| | | # generally expect to be able to send and receive messages quickly. We use a |
| | | # larger chunk size to reduce the number of executor calls and avoid task |
| | | # creation overhead, since both are significant sources of latency when chunks |
| | | # are small. A size of 16KiB was chosen as a balance between avoiding task |
| | | # overhead and not blocking the event loop too long with synchronous compression. |
| | | |
| | | WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024 |
| | | |
| | | |
| | | class WebSocketWriter: |
| | | """WebSocket writer. |
| | | |
| | | The writer is responsible for sending messages to the client. It is |
| | | created by the protocol when a connection is established. The writer |
| | | should avoid implementing any application logic and should only be |
| | | concerned with the low-level details of the WebSocket protocol. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | protocol: BaseProtocol, |
| | | transport: asyncio.Transport, |
| | | *, |
| | | use_mask: bool = False, |
| | | limit: int = DEFAULT_LIMIT, |
| | | random: random.Random = random.Random(), |
| | | compress: int = 0, |
| | | notakeover: bool = False, |
| | | ) -> None: |
| | | """Initialize a WebSocket writer.""" |
| | | self.protocol = protocol |
| | | self.transport = transport |
| | | self.use_mask = use_mask |
| | | self.get_random_bits = partial(random.getrandbits, 32) |
| | | self.compress = compress |
| | | self.notakeover = notakeover |
| | | self._closing = False |
| | | self._limit = limit |
| | | self._output_size = 0 |
| | | self._compressobj: Optional[ZLibCompressor] = None |
| | | self._send_lock = asyncio.Lock() |
| | | self._background_tasks: Set[asyncio.Task[None]] = set() |
| | | |
| | | async def send_frame( |
| | | self, message: bytes, opcode: int, compress: Optional[int] = None |
| | | ) -> None: |
| | | """Send a frame over the websocket with message as its payload.""" |
| | | if self._closing and not (opcode & WSMsgType.CLOSE): |
| | | raise ClientConnectionResetError("Cannot write to closing transport") |
| | | |
| | | if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE: |
| | | # Non-compressed frames don't need lock or shield |
| | | self._write_websocket_frame(message, opcode, 0) |
| | | elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE: |
| | | # Small compressed payloads - compress synchronously in event loop |
| | | # We need the lock even though sync compression has no await points. |
| | | # This prevents small frames from interleaving with large frames that |
| | | # compress in the executor, avoiding compressor state corruption. |
| | | async with self._send_lock: |
| | | self._send_compressed_frame_sync(message, opcode, compress) |
| | | else: |
| | | # Large compressed frames need shield to prevent corruption |
| | | # For large compressed frames, the entire compress+send |
| | | # operation must be atomic. If cancelled after compression but |
| | | # before send, the compressor state would be advanced but data |
| | | # not sent, corrupting subsequent frames. |
| | | # Create a task to shield from cancellation |
| | | # The lock is acquired inside the shielded task so the entire |
| | | # operation (lock + compress + send) completes atomically. |
| | | # Use eager_start on Python 3.12+ to avoid scheduling overhead |
| | | loop = asyncio.get_running_loop() |
| | | coro = self._send_compressed_frame_async_locked(message, opcode, compress) |
| | | if sys.version_info >= (3, 12): |
| | | send_task = asyncio.Task(coro, loop=loop, eager_start=True) |
| | | else: |
| | | send_task = loop.create_task(coro) |
| | | # Keep a strong reference to prevent garbage collection |
| | | self._background_tasks.add(send_task) |
| | | send_task.add_done_callback(self._background_tasks.discard) |
| | | await asyncio.shield(send_task) |
| | | |
| | | # It is safe to return control to the event loop when using compression |
| | | # after this point as we have already sent or buffered all the data. |
| | | # Once we have written output_size up to the limit, we call the |
| | | # drain helper which waits for the transport to be ready to accept |
| | | # more data. This is a flow control mechanism to prevent the buffer |
| | | # from growing too large. The drain helper will return right away |
| | | # if the writer is not paused. |
| | | if self._output_size > self._limit: |
| | | self._output_size = 0 |
| | | if self.protocol._paused: |
| | | await self.protocol._drain_helper() |
| | | |
| | | def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None: |
| | | """ |
| | | Write a websocket frame to the transport. |
| | | |
| | | This method handles frame header construction, masking, and writing to transport. |
| | | It does not handle compression or flow control - those are the responsibility |
| | | of the caller. |
| | | """ |
| | | msg_length = len(message) |
| | | |
| | | use_mask = self.use_mask |
| | | mask_bit = 0x80 if use_mask else 0 |
| | | |
| | | # Depending on the message length, the header is assembled differently. |
| | | # The first byte is reserved for the opcode and the RSV bits. |
| | | first_byte = 0x80 | rsv | opcode |
| | | if msg_length < 126: |
| | | header = PACK_LEN1(first_byte, msg_length | mask_bit) |
| | | header_len = 2 |
| | | elif msg_length < 65536: |
| | | header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length) |
| | | header_len = 4 |
| | | else: |
| | | header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length) |
| | | header_len = 10 |
| | | |
| | | if self.transport.is_closing(): |
| | | raise ClientConnectionResetError("Cannot write to closing transport") |
| | | |
| | | # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3 |
| | | # If we are using a mask, we need to generate it randomly |
| | | # and apply it to the message before sending it. A mask is |
| | | # a 32-bit value that is applied to the message using a |
| | | # bitwise XOR operation. It is used to prevent certain types |
| | | # of attacks on the websocket protocol. The mask is only used |
| | | # when aiohttp is acting as a client. Servers do not use a mask. |
| | | if use_mask: |
| | | mask = PACK_RANDBITS(self.get_random_bits()) |
| | | message = bytearray(message) |
| | | websocket_mask(mask, message) |
| | | self.transport.write(header + mask + message) |
| | | self._output_size += MASK_LEN |
| | | elif msg_length > MSG_SIZE: |
| | | self.transport.write(header) |
| | | self.transport.write(message) |
| | | else: |
| | | self.transport.write(header + message) |
| | | |
| | | self._output_size += header_len + msg_length |
| | | |
| | | def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor: |
| | | """Get or create a compressor object for the given compression level.""" |
| | | if compress: |
| | | # Do not cache the compressor when compression applies to this frame only
| | | return ZLibCompressor( |
| | | level=ZLibBackend.Z_BEST_SPEED, |
| | | wbits=-compress, |
| | | max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, |
| | | ) |
| | | if not self._compressobj: |
| | | self._compressobj = ZLibCompressor( |
| | | level=ZLibBackend.Z_BEST_SPEED, |
| | | wbits=-self.compress, |
| | | max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, |
| | | ) |
| | | return self._compressobj |
| | | |
| | | def _send_compressed_frame_sync( |
| | | self, message: bytes, opcode: int, compress: Optional[int] |
| | | ) -> None: |
| | | """ |
| | | Synchronous send for small compressed frames. |
| | | |
| | | This is used for small payloads that are compressed synchronously in the event loop.
| | | Since there are no await points, this is inherently cancellation-safe. |
| | | """ |
| | | # RSV are the reserved bits in the frame header. They are used to |
| | | # indicate that the frame is using an extension. |
| | | # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 |
| | | compressobj = self._get_compressor(compress) |
| | | # (0x40) RSV1 is set for compressed frames |
| | | # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 |
| | | self._write_websocket_frame( |
| | | ( |
| | | compressobj.compress_sync(message) |
| | | + compressobj.flush( |
| | | ZLibBackend.Z_FULL_FLUSH |
| | | if self.notakeover |
| | | else ZLibBackend.Z_SYNC_FLUSH |
| | | ) |
| | | ).removesuffix(WS_DEFLATE_TRAILING), |
| | | opcode, |
| | | 0x40, |
| | | ) |
| | | |
| | | async def _send_compressed_frame_async_locked( |
| | | self, message: bytes, opcode: int, compress: Optional[int] |
| | | ) -> None: |
| | | """ |
| | | Async send for large compressed frames with lock. |
| | | |
| | | Acquires the lock and compresses large payloads asynchronously in |
| | | the executor. The lock is held for the entire operation to ensure |
| | | the compressor state is not corrupted by concurrent sends. |
| | | |
| | | MUST be run shielded from cancellation. If cancelled after |
| | | compression but before sending, the compressor state would be |
| | | advanced but data not sent, corrupting subsequent frames. |
| | | """ |
| | | async with self._send_lock: |
| | | # RSV are the reserved bits in the frame header. They are used to |
| | | # indicate that the frame is using an extension. |
| | | # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 |
| | | compressobj = self._get_compressor(compress) |
| | | # (0x40) RSV1 is set for compressed frames |
| | | # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 |
| | | self._write_websocket_frame( |
| | | ( |
| | | await compressobj.compress(message) |
| | | + compressobj.flush( |
| | | ZLibBackend.Z_FULL_FLUSH |
| | | if self.notakeover |
| | | else ZLibBackend.Z_SYNC_FLUSH |
| | | ) |
| | | ).removesuffix(WS_DEFLATE_TRAILING), |
| | | opcode, |
| | | 0x40, |
| | | ) |
| | | |
| | | async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None: |
| | | """Close the websocket, sending the specified code and message.""" |
| | | if isinstance(message, str): |
| | | message = message.encode("utf-8") |
| | | try: |
| | | await self.send_frame( |
| | | PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE |
| | | ) |
| | | finally: |
| | | self._closing = True |
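`_write_websocket_frame` above picks a 2-, 4-, or 10-byte header depending on the payload length. A sketch of that choice, assuming `PACK_LEN1`/`PACK_LEN2`/`PACK_LEN3` correspond to the network-order struct formats `!BB`, `!BBH`, and `!BBQ` (an assumption inferred from the header lengths 2, 4, and 10 used in the code, not verified against `helpers.py`):

```python
import struct

def encode_header(first_byte: int, msg_length: int, mask_bit: int = 0) -> bytes:
    """Build a frame header per RFC 6455 section 5.2: payloads under 126
    bytes fit in the 7-bit length field; 126 flags a 16-bit extended
    length; 127 flags a 64-bit extended length."""
    if msg_length < 126:
        return struct.pack("!BB", first_byte, msg_length | mask_bit)
    if msg_length < 65536:
        return struct.pack("!BBH", first_byte, 126 | mask_bit, msg_length)
    return struct.pack("!BBQ", first_byte, 127 | mask_bit, msg_length)

# 0x81 = FIN + TEXT opcode; 0x80 in the second byte marks a masked frame
small = encode_header(0x81, 5)
masked = encode_header(0x81, 5, 0x80)
```

Usage note: the writer sets the mask bit only on the client side; servers never mask, which is why `use_mask` defaults to `False`.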
| New file |
| | |
| | | import asyncio |
| | | import logging |
| | | import socket |
| | | from abc import ABC, abstractmethod |
| | | from collections.abc import Sized |
| | | from http.cookies import BaseCookie, Morsel |
| | | from typing import ( |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Awaitable, |
| | | Callable, |
| | | Dict, |
| | | Generator, |
| | | Iterable, |
| | | List, |
| | | Optional, |
| | | Sequence, |
| | | Tuple, |
| | | TypedDict, |
| | | Union, |
| | | ) |
| | | |
| | | from multidict import CIMultiDict |
| | | from yarl import URL |
| | | |
| | | from ._cookie_helpers import parse_set_cookie_headers |
| | | from .typedefs import LooseCookies |
| | | |
| | | if TYPE_CHECKING: |
| | | from .web_app import Application |
| | | from .web_exceptions import HTTPException |
| | | from .web_request import BaseRequest, Request |
| | | from .web_response import StreamResponse |
| | | else: |
| | | BaseRequest = Request = Application = StreamResponse = None |
| | | HTTPException = None |
| | | |
| | | |
| | | class AbstractRouter(ABC): |
| | | def __init__(self) -> None: |
| | | self._frozen = False |
| | | |
| | | def post_init(self, app: Application) -> None: |
| | | """Post init stage. |
| | | |
| | | Not an abstract method for sake of backward compatibility, |
| | | but if the router wants to be aware of the application |
| | | it can override this. |
| | | """ |
| | | |
| | | @property |
| | | def frozen(self) -> bool: |
| | | return self._frozen |
| | | |
| | | def freeze(self) -> None: |
| | | """Freeze router.""" |
| | | self._frozen = True |
| | | |
| | | @abstractmethod |
| | | async def resolve(self, request: Request) -> "AbstractMatchInfo": |
| | | """Return AbstractMatchInfo for the given request."""
| | | |
| | | |
| | | class AbstractMatchInfo(ABC): |
| | | |
| | | __slots__ = () |
| | | |
| | | @property # pragma: no branch |
| | | @abstractmethod |
| | | def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]: |
| | | """Handler for the matched request."""
| | | |
| | | @property |
| | | @abstractmethod |
| | | def expect_handler( |
| | | self, |
| | | ) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]: |
| | | """Expect handler for 100-continue processing""" |
| | | |
| | | @property # pragma: no branch |
| | | @abstractmethod |
| | | def http_exception(self) -> Optional[HTTPException]: |
| | | """HTTPException instance raised on router's resolving, or None""" |
| | | |
| | | @abstractmethod # pragma: no branch |
| | | def get_info(self) -> Dict[str, Any]: |
| | | """Return a dict with additional info useful for introspection""" |
| | | |
| | | @property # pragma: no branch |
| | | @abstractmethod |
| | | def apps(self) -> Tuple[Application, ...]: |
| | | """Stack of nested applications. |
| | | |
| | | Top level application is left-most element. |
| | | |
| | | """ |
| | | |
| | | @abstractmethod |
| | | def add_app(self, app: Application) -> None: |
| | | """Add application to the nested apps stack.""" |
| | | |
| | | @abstractmethod |
| | | def freeze(self) -> None: |
| | | """Freeze the match info. |
| | | |
| | | The method is called after route resolution. |
| | | |
| | | After the call .add_app() is forbidden. |
| | | |
| | | """ |
| | | |
| | | |
| | | class AbstractView(ABC): |
| | | """Abstract class based view.""" |
| | | |
| | | def __init__(self, request: Request) -> None: |
| | | self._request = request |
| | | |
| | | @property |
| | | def request(self) -> Request: |
| | | """Request instance.""" |
| | | return self._request |
| | | |
| | | @abstractmethod |
| | | def __await__(self) -> Generator[None, None, StreamResponse]: |
| | | """Execute the view handler.""" |
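`AbstractView` makes the view instance itself awaitable: subclasses implement `__await__`, typically by delegating to an async handler. A simplified sketch of a concrete subclass, with aiohttp's `Request`/`StreamResponse` replaced by plain strings for illustration:

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Generator

class View(ABC):
    """Simplified stand-in for AbstractView; request and response are
    plain strings here instead of aiohttp's Request/StreamResponse."""

    def __init__(self, request: str) -> None:
        self._request = request

    @abstractmethod
    def __await__(self) -> Generator[None, None, str]:
        """Execute the view handler."""

class EchoView(View):
    def __await__(self) -> Generator[None, None, str]:
        # Awaiting the view delegates to an async handler; __await__
        # must return the handler coroutine's generator.
        return self._handle().__await__()

    async def _handle(self) -> str:
        return f"handled {self._request}"

async def dispatch() -> str:
    # The view instance itself is awaited, as in class-based routing.
    return await EchoView("GET /ping")

result = asyncio.run(dispatch())
```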
| | | |
| | | |
| | | class ResolveResult(TypedDict): |
| | | """Resolve result. |
| | | |
| | | This is the result returned from an AbstractResolver's |
| | | resolve method. |
| | | |
| | | :param hostname: The hostname that was provided. |
| | | :param host: The IP address that was resolved. |
| | | :param port: The port that was resolved. |
| | | :param family: The address family that was resolved. |
| | | :param proto: The protocol that was resolved. |
| | | :param flags: The flags that were resolved. |
| | | """ |
| | | |
| | | hostname: str |
| | | host: str |
| | | port: int |
| | | family: int |
| | | proto: int |
| | | flags: int |
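Since `ResolveResult` is a `TypedDict`, it is a plain `dict` at runtime; the class only informs static type checkers. A sketch of one entry as a resolver might return it (the address and flag values below are illustrative, not real DNS output):

```python
import socket
from typing import TypedDict

class ResolveResult(TypedDict):
    hostname: str
    host: str
    port: int
    family: int
    proto: int
    flags: int

# Illustrative values; a real resolver fills these from getaddrinfo().
entry: ResolveResult = {
    "hostname": "example.org",
    "host": "93.184.216.34",
    "port": 443,
    "family": socket.AF_INET,
    "proto": socket.IPPROTO_TCP,
    "flags": socket.AI_NUMERICHOST,
}
```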
| | | |
| | | |
| | | class AbstractResolver(ABC): |
| | | """Abstract DNS resolver.""" |
| | | |
| | | @abstractmethod |
| | | async def resolve( |
| | | self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET |
| | | ) -> List[ResolveResult]: |
| | | """Return IP addresses for the given hostname"""
| | | |
| | | @abstractmethod |
| | | async def close(self) -> None: |
| | | """Release resolver""" |
| | | |
| | | |
| | | if TYPE_CHECKING: |
| | | IterableBase = Iterable[Morsel[str]] |
| | | else: |
| | | IterableBase = Iterable |
| | | |
| | | |
| | | ClearCookiePredicate = Callable[["Morsel[str]"], bool] |
| | | |
| | | |
| | | class AbstractCookieJar(Sized, IterableBase): |
| | | """Abstract Cookie Jar.""" |
| | | |
| | | def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: |
| | | self._loop = loop or asyncio.get_running_loop() |
| | | |
| | | @property |
| | | @abstractmethod |
| | | def quote_cookie(self) -> bool: |
| | | """Return True if cookies should be quoted.""" |
| | | |
| | | @abstractmethod |
| | | def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: |
| | | """Clear all cookies if no predicate is passed.""" |
| | | |
| | | @abstractmethod |
| | | def clear_domain(self, domain: str) -> None: |
| | | """Clear all cookies for domain and all subdomains.""" |
| | | |
| | | @abstractmethod |
| | | def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: |
| | | """Update cookies.""" |
| | | |
| | | def update_cookies_from_headers( |
| | | self, headers: Sequence[str], response_url: URL |
| | | ) -> None: |
| | | """Update cookies from raw Set-Cookie headers.""" |
| | | if headers and (cookies_to_update := parse_set_cookie_headers(headers)): |
| | | self.update_cookies(cookies_to_update, response_url) |
| | | |
| | | @abstractmethod |
| | | def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": |
| | | """Return the jar's cookies filtered by their attributes.""" |
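`update_cookies_from_headers` parses raw `Set-Cookie` header strings into cookie objects before handing them to `update_cookies`. The stdlib's `SimpleCookie` can illustrate the parsing step as a simplified stand-in for aiohttp's `parse_set_cookie_headers` (the header values here are made up):

```python
from http.cookies import SimpleCookie

# Each raw Set-Cookie header string is parsed into Morsel objects
# keyed by cookie name; attribute keys are lower-cased.
jar: SimpleCookie = SimpleCookie()
for header in (
    "sid=abc123; Path=/; HttpOnly",
    "theme=dark; Path=/; Max-Age=3600",
):
    jar.load(header)

names = sorted(jar.keys())
```

After loading, `jar["sid"].value` is `"abc123"` and attributes are available per morsel, e.g. `jar["theme"]["max-age"]`.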
| | | |
| | | |
| | | class AbstractStreamWriter(ABC): |
| | | """Abstract stream writer.""" |
| | | |
| | | buffer_size: int = 0 |
| | | output_size: int = 0 |
| | | length: Optional[int] = 0 |
| | | |
| | | @abstractmethod |
| | | async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None: |
| | | """Write chunk into stream.""" |
| | | |
| | | @abstractmethod |
| | | async def write_eof(self, chunk: bytes = b"") -> None: |
| | | """Write last chunk.""" |
| | | |
| | | @abstractmethod |
| | | async def drain(self) -> None: |
| | | """Flush the write buffer.""" |
| | | |
| | | @abstractmethod |
| | | def enable_compression( |
| | | self, encoding: str = "deflate", strategy: Optional[int] = None |
| | | ) -> None: |
| | | """Enable HTTP body compression""" |
| | | |
| | | @abstractmethod |
| | | def enable_chunking(self) -> None: |
| | | """Enable HTTP chunked mode""" |
| | | |
| | | @abstractmethod |
| | | async def write_headers( |
| | | self, status_line: str, headers: "CIMultiDict[str]" |
| | | ) -> None: |
| | | """Write HTTP headers""" |
| | | |
| | | def send_headers(self) -> None: |
| | | """Force sending buffered headers if not already sent. |
| | | |
| | | Required only if write_headers() buffers headers instead of sending immediately. |
| | | For backwards compatibility, this method does nothing by default. |
| | | """ |
| | | |
| | | |
| | | class AbstractAccessLogger(ABC): |
| | | """Abstract writer to access log.""" |
| | | |
| | | __slots__ = ("logger", "log_format") |
| | | |
| | | def __init__(self, logger: logging.Logger, log_format: str) -> None: |
| | | self.logger = logger |
| | | self.log_format = log_format |
| | | |
| | | @abstractmethod |
| | | def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None: |
| | | """Emit log to logger.""" |
| | | |
| | | @property |
| | | def enabled(self) -> bool: |
| | | """Check if logger is enabled.""" |
| | | return True |
| New file |
| | |
| | | import asyncio |
| | | from typing import Optional, cast |
| | | |
| | | from .client_exceptions import ClientConnectionResetError |
| | | from .helpers import set_exception |
| | | from .tcp_helpers import tcp_nodelay |
| | | |
| | | |
| | | class BaseProtocol(asyncio.Protocol): |
| | | __slots__ = ( |
| | | "_loop", |
| | | "_paused", |
| | | "_drain_waiter", |
| | | "_connection_lost", |
| | | "_reading_paused", |
| | | "transport", |
| | | ) |
| | | |
| | | def __init__(self, loop: asyncio.AbstractEventLoop) -> None: |
| | | self._loop: asyncio.AbstractEventLoop = loop |
| | | self._paused = False |
| | | self._drain_waiter: Optional[asyncio.Future[None]] = None |
| | | self._reading_paused = False |
| | | |
| | | self.transport: Optional[asyncio.Transport] = None |
| | | |
| | | @property |
| | | def connected(self) -> bool: |
| | | """Return True if the connection is open.""" |
| | | return self.transport is not None |
| | | |
| | | @property |
| | | def writing_paused(self) -> bool: |
| | | return self._paused |
| | | |
| | | def pause_writing(self) -> None: |
| | | assert not self._paused |
| | | self._paused = True |
| | | |
| | | def resume_writing(self) -> None: |
| | | assert self._paused |
| | | self._paused = False |
| | | |
| | | waiter = self._drain_waiter |
| | | if waiter is not None: |
| | | self._drain_waiter = None |
| | | if not waiter.done(): |
| | | waiter.set_result(None) |
| | | |
| | | def pause_reading(self) -> None: |
| | | if not self._reading_paused and self.transport is not None: |
| | | try: |
| | | self.transport.pause_reading() |
| | | except (AttributeError, NotImplementedError, RuntimeError): |
| | | pass |
| | | self._reading_paused = True |
| | | |
| | | def resume_reading(self) -> None: |
| | | if self._reading_paused and self.transport is not None: |
| | | try: |
| | | self.transport.resume_reading() |
| | | except (AttributeError, NotImplementedError, RuntimeError): |
| | | pass |
| | | self._reading_paused = False |
| | | |
| | | def connection_made(self, transport: asyncio.BaseTransport) -> None: |
| | | tr = cast(asyncio.Transport, transport) |
| | | tcp_nodelay(tr, True) |
| | | self.transport = tr |
| | | |
| | | def connection_lost(self, exc: Optional[BaseException]) -> None: |
| | | # Wake up the writer if currently paused. |
| | | self.transport = None |
| | | if not self._paused: |
| | | return |
| | | waiter = self._drain_waiter |
| | | if waiter is None: |
| | | return |
| | | self._drain_waiter = None |
| | | if waiter.done(): |
| | | return |
| | | if exc is None: |
| | | waiter.set_result(None) |
| | | else: |
| | | set_exception( |
| | | waiter, |
| | | ConnectionError("Connection lost"), |
| | | exc, |
| | | ) |
| | | |
| | | async def _drain_helper(self) -> None: |
| | | if self.transport is None: |
| | | raise ClientConnectionResetError("Connection lost") |
| | | if not self._paused: |
| | | return |
| | | waiter = self._drain_waiter |
| | | if waiter is None: |
| | | waiter = self._loop.create_future() |
| | | self._drain_waiter = waiter |
| | | await asyncio.shield(waiter) |
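`BaseProtocol` implements the classic drain-waiter pattern: when the transport signals backpressure via `pause_writing()`, writers block in `_drain_helper()` on a future that `resume_writing()` later resolves. A self-contained sketch of just that mechanism (class and method names here are illustrative, not aiohttp's API):

```python
import asyncio

class FlowControl:
    """Sketch of the drain-waiter pattern: drain() blocks on a future
    while paused; resume_writing() releases it."""

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._paused = False
        self._waiter: "asyncio.Future[None] | None" = None

    def pause_writing(self) -> None:
        self._paused = True

    def resume_writing(self) -> None:
        self._paused = False
        waiter, self._waiter = self._waiter, None
        if waiter is not None and not waiter.done():
            waiter.set_result(None)

    async def drain(self) -> None:
        if not self._paused:
            return
        if self._waiter is None:
            self._waiter = self._loop.create_future()
        # Shielded, as in _drain_helper, so cancelling one drainer
        # does not cancel the shared waiter.
        await asyncio.shield(self._waiter)

async def main() -> list:
    fc = FlowControl(asyncio.get_running_loop())
    fc.pause_writing()
    events: list = []

    async def writer() -> None:
        events.append("waiting")
        await fc.drain()
        events.append("resumed")

    task = asyncio.ensure_future(writer())
    await asyncio.sleep(0)   # let the writer block on the waiter
    fc.resume_writing()      # releases the drain future
    await task
    return events

events = asyncio.run(main())
```

The writer only proceeds past `drain()` once `resume_writing()` fires, so `events` ends up `["waiting", "resumed"]`.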
| New file |
| | |
| | | """HTTP Client for asyncio.""" |
| | | |
| | | import asyncio |
| | | import base64 |
| | | import hashlib |
| | | import json |
| | | import os |
| | | import sys |
| | | import traceback |
| | | import warnings |
| | | from contextlib import suppress |
| | | from types import TracebackType |
| | | from typing import ( |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Awaitable, |
| | | Callable, |
| | | Coroutine, |
| | | Final, |
| | | FrozenSet, |
| | | Generator, |
| | | Generic, |
| | | Iterable, |
| | | List, |
| | | Mapping, |
| | | Optional, |
| | | Sequence, |
| | | Set, |
| | | Tuple, |
| | | Type, |
| | | TypedDict, |
| | | TypeVar, |
| | | Union, |
| | | ) |
| | | |
| | | import attr |
| | | from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr |
| | | from yarl import URL |
| | | |
| | | from . import hdrs, http, payload |
| | | from ._websocket.reader import WebSocketDataQueue |
| | | from .abc import AbstractCookieJar |
| | | from .client_exceptions import ( |
| | | ClientConnectionError, |
| | | ClientConnectionResetError, |
| | | ClientConnectorCertificateError, |
| | | ClientConnectorDNSError, |
| | | ClientConnectorError, |
| | | ClientConnectorSSLError, |
| | | ClientError, |
| | | ClientHttpProxyError, |
| | | ClientOSError, |
| | | ClientPayloadError, |
| | | ClientProxyConnectionError, |
| | | ClientResponseError, |
| | | ClientSSLError, |
| | | ConnectionTimeoutError, |
| | | ContentTypeError, |
| | | InvalidURL, |
| | | InvalidUrlClientError, |
| | | InvalidUrlRedirectClientError, |
| | | NonHttpUrlClientError, |
| | | NonHttpUrlRedirectClientError, |
| | | RedirectClientError, |
| | | ServerConnectionError, |
| | | ServerDisconnectedError, |
| | | ServerFingerprintMismatch, |
| | | ServerTimeoutError, |
| | | SocketTimeoutError, |
| | | TooManyRedirects, |
| | | WSMessageTypeError, |
| | | WSServerHandshakeError, |
| | | ) |
| | | from .client_middlewares import ClientMiddlewareType, build_client_middlewares |
| | | from .client_reqrep import ( |
| | | ClientRequest as ClientRequest, |
| | | ClientResponse as ClientResponse, |
| | | Fingerprint as Fingerprint, |
| | | RequestInfo as RequestInfo, |
| | | _merge_ssl_params, |
| | | ) |
| | | from .client_ws import ( |
| | | DEFAULT_WS_CLIENT_TIMEOUT, |
| | | ClientWebSocketResponse as ClientWebSocketResponse, |
| | | ClientWSTimeout as ClientWSTimeout, |
| | | ) |
| | | from .connector import ( |
| | | HTTP_AND_EMPTY_SCHEMA_SET, |
| | | BaseConnector as BaseConnector, |
| | | NamedPipeConnector as NamedPipeConnector, |
| | | TCPConnector as TCPConnector, |
| | | UnixConnector as UnixConnector, |
| | | ) |
| | | from .cookiejar import CookieJar |
| | | from .helpers import ( |
| | | _SENTINEL, |
| | | DEBUG, |
| | | EMPTY_BODY_METHODS, |
| | | BasicAuth, |
| | | TimeoutHandle, |
| | | basicauth_from_netrc, |
| | | get_env_proxy_for_url, |
| | | netrc_from_env, |
| | | sentinel, |
| | | strip_auth_from_url, |
| | | ) |
| | | from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter |
| | | from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse |
| | | from .tracing import Trace, TraceConfig |
| | | from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, Query, StrOrURL |
| | | |
| | | __all__ = ( |
| | | # client_exceptions |
| | | "ClientConnectionError", |
| | | "ClientConnectionResetError", |
| | | "ClientConnectorCertificateError", |
| | | "ClientConnectorDNSError", |
| | | "ClientConnectorError", |
| | | "ClientConnectorSSLError", |
| | | "ClientError", |
| | | "ClientHttpProxyError", |
| | | "ClientOSError", |
| | | "ClientPayloadError", |
| | | "ClientProxyConnectionError", |
| | | "ClientResponseError", |
| | | "ClientSSLError", |
| | | "ConnectionTimeoutError", |
| | | "ContentTypeError", |
| | | "InvalidURL", |
| | | "InvalidUrlClientError", |
| | | "RedirectClientError", |
| | | "NonHttpUrlClientError", |
| | | "InvalidUrlRedirectClientError", |
| | | "NonHttpUrlRedirectClientError", |
| | | "ServerConnectionError", |
| | | "ServerDisconnectedError", |
| | | "ServerFingerprintMismatch", |
| | | "ServerTimeoutError", |
| | | "SocketTimeoutError", |
| | | "TooManyRedirects", |
| | | "WSServerHandshakeError", |
| | | # client_reqrep |
| | | "ClientRequest", |
| | | "ClientResponse", |
| | | "Fingerprint", |
| | | "RequestInfo", |
| | | # connector |
| | | "BaseConnector", |
| | | "TCPConnector", |
| | | "UnixConnector", |
| | | "NamedPipeConnector", |
| | | # client_ws |
| | | "ClientWebSocketResponse", |
| | | # client |
| | | "ClientSession", |
| | | "ClientTimeout", |
| | | "ClientWSTimeout", |
| | | "request", |
| | | "WSMessageTypeError", |
| | | ) |
| | | |
| | | |
| | | if TYPE_CHECKING: |
| | | from ssl import SSLContext |
| | | else: |
| | | SSLContext = None |
| | | |
| | | if sys.version_info >= (3, 11) and TYPE_CHECKING: |
| | | from typing import Unpack |
| | | |
| | | |
| | | class _RequestOptions(TypedDict, total=False): |
| | | params: Query |
| | | data: Any |
| | | json: Any |
| | | cookies: Union[LooseCookies, None] |
| | | headers: Union[LooseHeaders, None] |
| | | skip_auto_headers: Union[Iterable[str], None] |
| | | auth: Union[BasicAuth, None] |
| | | allow_redirects: bool |
| | | max_redirects: int |
| | | compress: Union[str, bool, None] |
| | | chunked: Union[bool, None] |
| | | expect100: bool |
| | | raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]] |
| | | read_until_eof: bool |
| | | proxy: Union[StrOrURL, None] |
| | | proxy_auth: Union[BasicAuth, None] |
| | | timeout: "Union[ClientTimeout, _SENTINEL, None]" |
| | | ssl: Union[SSLContext, bool, Fingerprint] |
| | | server_hostname: Union[str, None] |
| | | proxy_headers: Union[LooseHeaders, None] |
| | | trace_request_ctx: Union[Mapping[str, Any], None] |
| | | read_bufsize: Union[int, None] |
| | | auto_decompress: Union[bool, None] |
| | | max_line_size: Union[int, None] |
| | | max_field_size: Union[int, None] |
| | | middlewares: Optional[Sequence[ClientMiddlewareType]] |
| | | |
| | | |
| | | @attr.s(auto_attribs=True, frozen=True, slots=True) |
| | | class ClientTimeout: |
| | | total: Optional[float] = None |
| | | connect: Optional[float] = None |
| | | sock_read: Optional[float] = None |
| | | sock_connect: Optional[float] = None |
| | | ceil_threshold: float = 5 |
| | | |
| | | # pool_queue_timeout: Optional[float] = None |
| | | # dns_resolution_timeout: Optional[float] = None |
| | | # socket_connect_timeout: Optional[float] = None |
| | | # connection_acquiring_timeout: Optional[float] = None |
| | | # new_connection_timeout: Optional[float] = None |
| | | # http_header_timeout: Optional[float] = None |
| | | # response_body_timeout: Optional[float] = None |
| | | |
| | | # to create a timeout specific for a single request, either |
| | | # - create a completely new one to overwrite the default |
| | | # - or use http://www.attrs.org/en/stable/api.html#attr.evolve |
| | | # to overwrite the defaults |
| | | |
| | | |
| | | # 5 Minute default read timeout |
| | | DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60, sock_connect=30) |
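Because `ClientTimeout` is frozen, a per-request timeout is created by copying the default and overriding individual fields, as the comment above suggests via `attr.evolve`. A sketch of the same idea using a frozen dataclass and `dataclasses.replace` as a stand-in for the attrs-based class:

```python
from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class ClientTimeout:
    # Same fields as the attrs-based class; dataclasses.replace plays
    # the role of attr.evolve here.
    total: Optional[float] = None
    connect: Optional[float] = None
    sock_read: Optional[float] = None
    sock_connect: Optional[float] = None
    ceil_threshold: float = 5

DEFAULT_TIMEOUT = ClientTimeout(total=5 * 60, sock_connect=30)

# Per-request timeout: copy the default, overriding only one field.
per_request = replace(DEFAULT_TIMEOUT, total=10)
```

The copy keeps every unspecified field (`sock_connect` stays 30) while the default object remains untouched.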
| | | |
| | | # https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 |
| | | IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) |
| | | |
| | | _RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse) |
| | | _CharsetResolver = Callable[[ClientResponse, bytes], str] |
| | | |
| | | |
| | | class ClientSession: |
| | | """First-class interface for making HTTP requests.""" |
| | | |
| | | ATTRS = frozenset( |
| | | [ |
| | | "_base_url", |
| | | "_base_url_origin", |
| | | "_source_traceback", |
| | | "_connector", |
| | | "_loop", |
| | | "_cookie_jar", |
| | | "_connector_owner", |
| | | "_default_auth", |
| | | "_version", |
| | | "_json_serialize", |
| | | "_requote_redirect_url", |
| | | "_timeout", |
| | | "_raise_for_status", |
| | | "_auto_decompress", |
| | | "_trust_env", |
| | | "_default_headers", |
| | | "_skip_auto_headers", |
| | | "_request_class", |
| | | "_response_class", |
| | | "_ws_response_class", |
| | | "_trace_configs", |
| | | "_read_bufsize", |
| | | "_max_line_size", |
| | | "_max_field_size", |
| | | "_resolve_charset", |
| | | "_default_proxy", |
| | | "_default_proxy_auth", |
| | | "_retry_connection", |
| | | "_middlewares", |
| | | "requote_redirect_url", |
| | | ] |
| | | ) |
| | | |
| | | _source_traceback: Optional[traceback.StackSummary] = None |
| | | _connector: Optional[BaseConnector] = None |
| | | |
| | | def __init__( |
| | | self, |
| | | base_url: Optional[StrOrURL] = None, |
| | | *, |
| | | connector: Optional[BaseConnector] = None, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | cookies: Optional[LooseCookies] = None, |
| | | headers: Optional[LooseHeaders] = None, |
| | | proxy: Optional[StrOrURL] = None, |
| | | proxy_auth: Optional[BasicAuth] = None, |
| | | skip_auto_headers: Optional[Iterable[str]] = None, |
| | | auth: Optional[BasicAuth] = None, |
| | | json_serialize: JSONEncoder = json.dumps, |
| | | request_class: Type[ClientRequest] = ClientRequest, |
| | | response_class: Type[ClientResponse] = ClientResponse, |
| | | ws_response_class: Type[ClientWebSocketResponse] = ClientWebSocketResponse, |
| | | version: HttpVersion = http.HttpVersion11, |
| | | cookie_jar: Optional[AbstractCookieJar] = None, |
| | | connector_owner: bool = True, |
| | | raise_for_status: Union[ |
| | | bool, Callable[[ClientResponse], Awaitable[None]] |
| | | ] = False, |
| | | read_timeout: Union[float, _SENTINEL] = sentinel, |
| | | conn_timeout: Optional[float] = None, |
| | | timeout: Union[object, ClientTimeout] = sentinel, |
| | | auto_decompress: bool = True, |
| | | trust_env: bool = False, |
| | | requote_redirect_url: bool = True, |
| | | trace_configs: Optional[List[TraceConfig]] = None, |
| | | read_bufsize: int = 2**16, |
| | | max_line_size: int = 8190, |
| | | max_field_size: int = 8190, |
| | | fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8", |
| | | middlewares: Sequence[ClientMiddlewareType] = (), |
| | | ssl_shutdown_timeout: Union[_SENTINEL, None, float] = sentinel, |
| | | ) -> None: |
| | | # We initialise _connector to None immediately, as it's referenced in __del__() |
| | | # and could cause issues if an exception occurs during initialisation. |
| | | self._connector: Optional[BaseConnector] = None |
| | | |
| | | if loop is None: |
| | | if connector is not None: |
| | | loop = connector._loop |
| | | |
| | | loop = loop or asyncio.get_running_loop() |
| | | |
| | | if base_url is None or isinstance(base_url, URL): |
| | | self._base_url: Optional[URL] = base_url |
| | | self._base_url_origin = None if base_url is None else base_url.origin() |
| | | else: |
| | | self._base_url = URL(base_url) |
| | | self._base_url_origin = self._base_url.origin() |
| | | assert self._base_url.absolute, "Only absolute URLs are supported" |
| | | if self._base_url is not None and not self._base_url.path.endswith("/"): |
| | | raise ValueError("base_url must have a trailing '/'") |
| | | |
| | | if timeout is sentinel or timeout is None: |
| | | self._timeout = DEFAULT_TIMEOUT |
| | | if read_timeout is not sentinel: |
| | | warnings.warn( |
| | | "read_timeout is deprecated, use timeout argument instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | self._timeout = attr.evolve(self._timeout, total=read_timeout) |
| | | if conn_timeout is not None: |
| | | self._timeout = attr.evolve(self._timeout, connect=conn_timeout) |
| | | warnings.warn( |
| | | "conn_timeout is deprecated, use timeout argument instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | else: |
| | | if not isinstance(timeout, ClientTimeout): |
| | | raise ValueError( |
| | | f"timeout parameter cannot be of {type(timeout)} type, " |
| | | "please use 'timeout=ClientTimeout(...)'", |
| | | ) |
| | | self._timeout = timeout |
| | | if read_timeout is not sentinel: |
| | | raise ValueError( |
| | | "read_timeout and timeout parameters " |
| | | "conflict, please setup " |
| | | "timeout.read" |
| | | ) |
| | | if conn_timeout is not None: |
| | | raise ValueError( |
| | | "conn_timeout and timeout parameters " |
| | | "conflict, please setup " |
| | | "timeout.connect" |
| | | ) |
| | | |
| | | if ssl_shutdown_timeout is not sentinel: |
| | | warnings.warn( |
| | | "The ssl_shutdown_timeout parameter is deprecated and will be removed in aiohttp 4.0", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | |
| | | if connector is None: |
| | | connector = TCPConnector( |
| | | loop=loop, ssl_shutdown_timeout=ssl_shutdown_timeout |
| | | ) |
| | | |
| | | if connector._loop is not loop: |
| | | raise RuntimeError("Session and connector have to use the same event loop")
| | | |
| | | self._loop = loop |
| | | |
| | | if loop.get_debug(): |
| | | self._source_traceback = traceback.extract_stack(sys._getframe(1)) |
| | | |
| | | if cookie_jar is None: |
| | | cookie_jar = CookieJar(loop=loop) |
| | | self._cookie_jar = cookie_jar |
| | | |
| | | if cookies: |
| | | self._cookie_jar.update_cookies(cookies) |
| | | |
| | | self._connector = connector |
| | | self._connector_owner = connector_owner |
| | | self._default_auth = auth |
| | | self._version = version |
| | | self._json_serialize = json_serialize |
| | | self._raise_for_status = raise_for_status |
| | | self._auto_decompress = auto_decompress |
| | | self._trust_env = trust_env |
| | | self._requote_redirect_url = requote_redirect_url |
| | | self._read_bufsize = read_bufsize |
| | | self._max_line_size = max_line_size |
| | | self._max_field_size = max_field_size |
| | | |
| | | # Convert to list of tuples |
| | | if headers: |
| | | real_headers: CIMultiDict[str] = CIMultiDict(headers) |
| | | else: |
| | | real_headers = CIMultiDict() |
| | | self._default_headers: CIMultiDict[str] = real_headers |
| | | if skip_auto_headers is not None: |
| | | self._skip_auto_headers = frozenset(istr(i) for i in skip_auto_headers) |
| | | else: |
| | | self._skip_auto_headers = frozenset() |
| | | |
| | | self._request_class = request_class |
| | | self._response_class = response_class |
| | | self._ws_response_class = ws_response_class |
| | | |
| | | self._trace_configs = trace_configs or [] |
| | | for trace_config in self._trace_configs: |
| | | trace_config.freeze() |
| | | |
| | | self._resolve_charset = fallback_charset_resolver |
| | | |
| | | self._default_proxy = proxy |
| | | self._default_proxy_auth = proxy_auth |
| | | self._retry_connection: bool = True |
| | | self._middlewares = middlewares |
| | | |
| | | def __init_subclass__(cls: Type["ClientSession"]) -> None: |
| | | warnings.warn( |
| | | "Inheriting class {} from ClientSession "
| | | "is discouraged".format(cls.__name__), |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | |
| | | if DEBUG: |
| | | |
| | | def __setattr__(self, name: str, val: Any) -> None: |
| | | if name not in self.ATTRS: |
| | | warnings.warn( |
| | | "Setting custom ClientSession.{} attribute " |
| | | "is discouraged".format(name), |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | super().__setattr__(name, val) |
| | | |
| | | def __del__(self, _warnings: Any = warnings) -> None: |
| | | if not self.closed: |
| | | kwargs = {"source": self} |
| | | _warnings.warn( |
| | | f"Unclosed client session {self!r}", ResourceWarning, **kwargs |
| | | ) |
| | | context = {"client_session": self, "message": "Unclosed client session"} |
| | | if self._source_traceback is not None: |
| | | context["source_traceback"] = self._source_traceback |
| | | self._loop.call_exception_handler(context) |
| | | |
| | | if sys.version_info >= (3, 11) and TYPE_CHECKING: |
| | | |
| | | def request( |
| | | self, |
| | | method: str, |
| | | url: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> "_RequestContextManager": ... |
| | | |
| | | else: |
| | | |
| | | def request( |
| | | self, method: str, url: StrOrURL, **kwargs: Any |
| | | ) -> "_RequestContextManager": |
| | | """Perform HTTP request.""" |
| | | return _RequestContextManager(self._request(method, url, **kwargs)) |
| | | |
| | | def _build_url(self, str_or_url: StrOrURL) -> URL: |
| | | url = URL(str_or_url) |
| | | if self._base_url and not url.absolute: |
| | | return self._base_url.join(url) |
| | | return url |
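`_build_url` joins relative request URLs against `base_url` (which the constructor requires to end in `/`) and passes absolute URLs through unchanged. `yarl`'s `URL.join` follows RFC 3986 relative-reference resolution, the same rules the stdlib's `urljoin` implements, so `urljoin` can illustrate the semantics and why the trailing slash matters:

```python
from urllib.parse import urljoin

# With a trailing slash, the relative path is appended under the base.
with_slash = urljoin("http://example.com/api/v1/", "users/42")

# Without it, RFC 3986 resolution replaces the last path segment,
# which is why ClientSession rejects base_url without a trailing '/'.
without_slash = urljoin("http://example.com/api/v1", "users/42")

# An absolute URL ignores the base entirely, mirroring the
# `not url.absolute` check in _build_url.
absolute = urljoin("http://example.com/api/v1/", "http://other.example/x")
```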
| | | |
| | | async def _request( |
| | | self, |
| | | method: str, |
| | | str_or_url: StrOrURL, |
| | | *, |
| | | params: Query = None, |
| | | data: Any = None, |
| | | json: Any = None, |
| | | cookies: Optional[LooseCookies] = None, |
| | | headers: Optional[LooseHeaders] = None, |
| | | skip_auto_headers: Optional[Iterable[str]] = None, |
| | | auth: Optional[BasicAuth] = None, |
| | | allow_redirects: bool = True, |
| | | max_redirects: int = 10, |
| | | compress: Union[str, bool, None] = None, |
| | | chunked: Optional[bool] = None, |
| | | expect100: bool = False, |
| | | raise_for_status: Union[ |
| | | None, bool, Callable[[ClientResponse], Awaitable[None]] |
| | | ] = None, |
| | | read_until_eof: bool = True, |
| | | proxy: Optional[StrOrURL] = None, |
| | | proxy_auth: Optional[BasicAuth] = None, |
| | | timeout: Union[ClientTimeout, _SENTINEL] = sentinel, |
| | | verify_ssl: Optional[bool] = None, |
| | | fingerprint: Optional[bytes] = None, |
| | | ssl_context: Optional[SSLContext] = None, |
| | | ssl: Union[SSLContext, bool, Fingerprint] = True, |
| | | server_hostname: Optional[str] = None, |
| | | proxy_headers: Optional[LooseHeaders] = None, |
| | | trace_request_ctx: Optional[Mapping[str, Any]] = None, |
| | | read_bufsize: Optional[int] = None, |
| | | auto_decompress: Optional[bool] = None, |
| | | max_line_size: Optional[int] = None, |
| | | max_field_size: Optional[int] = None, |
| | | middlewares: Optional[Sequence[ClientMiddlewareType]] = None, |
| | | ) -> ClientResponse: |
| | | |
| | | # NOTE: timeout clamps existing connect and read timeouts. We cannot |
| | | # set the default to None because we need to detect if the user wants |
| | | # to use the existing timeouts by setting timeout to None. |
| | | |
| | | if self.closed: |
| | | raise RuntimeError("Session is closed") |
| | | |
| | | ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) |
| | | |
| | | if data is not None and json is not None: |
| | | raise ValueError( |
| | | "data and json parameters can not be used at the same time" |
| | | ) |
| | | elif json is not None: |
| | | data = payload.JsonPayload(json, dumps=self._json_serialize) |
| | | |
| | | if not isinstance(chunked, bool) and chunked is not None: |
| | | warnings.warn("Chunk size is deprecated #1615", DeprecationWarning) |
| | | |
| | | redirects = 0 |
| | | history: List[ClientResponse] = [] |
| | | version = self._version |
| | | params = params or {} |
| | | |
| | | # Merge with default headers and transform to CIMultiDict |
| | | headers = self._prepare_headers(headers) |
| | | |
| | | try: |
| | | url = self._build_url(str_or_url) |
| | | except ValueError as e: |
| | | raise InvalidUrlClientError(str_or_url) from e |
| | | |
| | | assert self._connector is not None |
| | | if url.scheme not in self._connector.allowed_protocol_schema_set: |
| | | raise NonHttpUrlClientError(url) |
| | | |
| | | skip_headers: Optional[Iterable[istr]] |
| | | if skip_auto_headers is not None: |
| | | skip_headers = { |
| | | istr(i) for i in skip_auto_headers |
| | | } | self._skip_auto_headers |
| | | elif self._skip_auto_headers: |
| | | skip_headers = self._skip_auto_headers |
| | | else: |
| | | skip_headers = None |
| | | |
| | | if proxy is None: |
| | | proxy = self._default_proxy |
| | | if proxy_auth is None: |
| | | proxy_auth = self._default_proxy_auth |
| | | |
| | | if proxy is None: |
| | | proxy_headers = None |
| | | else: |
| | | proxy_headers = self._prepare_headers(proxy_headers) |
| | | try: |
| | | proxy = URL(proxy) |
| | | except ValueError as e: |
| | | raise InvalidURL(proxy) from e |
| | | |
| | | if timeout is sentinel: |
| | | real_timeout: ClientTimeout = self._timeout |
| | | else: |
| | | if not isinstance(timeout, ClientTimeout): |
| | | real_timeout = ClientTimeout(total=timeout) |
| | | else: |
| | | real_timeout = timeout |
| | | # timeout is cumulative for all request operations |
| | | # (request, redirects, responses, data consuming) |
| | | tm = TimeoutHandle( |
| | | self._loop, real_timeout.total, ceil_threshold=real_timeout.ceil_threshold |
| | | ) |
| | | handle = tm.start() |
| | | |
| | | if read_bufsize is None: |
| | | read_bufsize = self._read_bufsize |
| | | |
| | | if auto_decompress is None: |
| | | auto_decompress = self._auto_decompress |
| | | |
| | | if max_line_size is None: |
| | | max_line_size = self._max_line_size |
| | | |
| | | if max_field_size is None: |
| | | max_field_size = self._max_field_size |
| | | |
| | | traces = [ |
| | | Trace( |
| | | self, |
| | | trace_config, |
| | | trace_config.trace_config_ctx(trace_request_ctx=trace_request_ctx), |
| | | ) |
| | | for trace_config in self._trace_configs |
| | | ] |
| | | |
| | | for trace in traces: |
| | | await trace.send_request_start(method, url.update_query(params), headers) |
| | | |
| | | timer = tm.timer() |
| | | try: |
| | | with timer: |
| | | # https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests |
| | | retry_persistent_connection = ( |
| | | self._retry_connection and method in IDEMPOTENT_METHODS |
| | | ) |
| | | while True: |
| | | url, auth_from_url = strip_auth_from_url(url) |
| | | if not url.raw_host: |
| | |                         # NOTE: Bail out early; otherwise this would surface
| | |                         # NOTE: as `InvalidURL` via `self._request_class()` below.
| | | err_exc_cls = ( |
| | | InvalidUrlRedirectClientError |
| | | if redirects |
| | | else InvalidUrlClientError |
| | | ) |
| | | raise err_exc_cls(url) |
| | | # If `auth` was passed for an already authenticated URL, |
| | | # disallow only if this is the initial URL; this is to avoid issues |
| | | # with sketchy redirects that are not the caller's responsibility |
| | | if not history and (auth and auth_from_url): |
| | | raise ValueError( |
| | | "Cannot combine AUTH argument with " |
| | | "credentials encoded in URL" |
| | | ) |
| | | |
| | | # Override the auth with the one from the URL only if we |
| | | # have no auth, or if we got an auth from a redirect URL |
| | | if auth is None or (history and auth_from_url is not None): |
| | | auth = auth_from_url |
| | | |
| | | if ( |
| | | auth is None |
| | | and self._default_auth |
| | | and ( |
| | | not self._base_url or self._base_url_origin == url.origin() |
| | | ) |
| | | ): |
| | | auth = self._default_auth |
| | | |
| | | # Try netrc if auth is still None and trust_env is enabled. |
| | | if auth is None and self._trust_env and url.host is not None: |
| | | auth = await self._loop.run_in_executor( |
| | | None, self._get_netrc_auth, url.host |
| | | ) |
| | | |
| | |                     # Supporting an explicit Authorization header together
| | |                     # with the auth argument would be ambiguous, so reject it
| | | if ( |
| | | headers is not None |
| | | and auth is not None |
| | | and hdrs.AUTHORIZATION in headers |
| | | ): |
| | | raise ValueError( |
| | | "Cannot combine AUTHORIZATION header " |
| | | "with AUTH argument or credentials " |
| | | "encoded in URL" |
| | | ) |
| | | |
| | | all_cookies = self._cookie_jar.filter_cookies(url) |
| | | |
| | | if cookies is not None: |
| | | tmp_cookie_jar = CookieJar( |
| | | quote_cookie=self._cookie_jar.quote_cookie |
| | | ) |
| | | tmp_cookie_jar.update_cookies(cookies) |
| | | req_cookies = tmp_cookie_jar.filter_cookies(url) |
| | | if req_cookies: |
| | | all_cookies.load(req_cookies) |
| | | |
| | | proxy_: Optional[URL] = None |
| | | if proxy is not None: |
| | | proxy_ = URL(proxy) |
| | | elif self._trust_env: |
| | | with suppress(LookupError): |
| | | proxy_, proxy_auth = await asyncio.to_thread( |
| | | get_env_proxy_for_url, url |
| | | ) |
| | | |
| | | req = self._request_class( |
| | | method, |
| | | url, |
| | | params=params, |
| | | headers=headers, |
| | | skip_auto_headers=skip_headers, |
| | | data=data, |
| | | cookies=all_cookies, |
| | | auth=auth, |
| | | version=version, |
| | | compress=compress, |
| | | chunked=chunked, |
| | | expect100=expect100, |
| | | loop=self._loop, |
| | | response_class=self._response_class, |
| | | proxy=proxy_, |
| | | proxy_auth=proxy_auth, |
| | | timer=timer, |
| | | session=self, |
| | | ssl=ssl if ssl is not None else True, |
| | | server_hostname=server_hostname, |
| | | proxy_headers=proxy_headers, |
| | | traces=traces, |
| | | trust_env=self.trust_env, |
| | | ) |
| | | |
| | | async def _connect_and_send_request( |
| | | req: ClientRequest, |
| | | ) -> ClientResponse: |
| | | # connection timeout |
| | | assert self._connector is not None |
| | | try: |
| | | conn = await self._connector.connect( |
| | | req, traces=traces, timeout=real_timeout |
| | | ) |
| | | except asyncio.TimeoutError as exc: |
| | | raise ConnectionTimeoutError( |
| | | f"Connection timeout to host {req.url}" |
| | | ) from exc |
| | | |
| | | assert conn.protocol is not None |
| | | conn.protocol.set_response_params( |
| | | timer=timer, |
| | | skip_payload=req.method in EMPTY_BODY_METHODS, |
| | | read_until_eof=read_until_eof, |
| | | auto_decompress=auto_decompress, |
| | | read_timeout=real_timeout.sock_read, |
| | | read_bufsize=read_bufsize, |
| | | timeout_ceil_threshold=self._connector._timeout_ceil_threshold, |
| | | max_line_size=max_line_size, |
| | | max_field_size=max_field_size, |
| | | ) |
| | | try: |
| | | resp = await req.send(conn) |
| | | try: |
| | | await resp.start(conn) |
| | | except BaseException: |
| | | resp.close() |
| | | raise |
| | | except BaseException: |
| | | conn.close() |
| | | raise |
| | | return resp |
| | | |
| | | # Apply middleware (if any) - per-request middleware overrides session middleware |
| | | effective_middlewares = ( |
| | | self._middlewares if middlewares is None else middlewares |
| | | ) |
| | | |
| | | if effective_middlewares: |
| | | handler = build_client_middlewares( |
| | | _connect_and_send_request, effective_middlewares |
| | | ) |
| | | else: |
| | | handler = _connect_and_send_request |
| | | |
| | | try: |
| | | resp = await handler(req) |
| | | # Client connector errors should not be retried |
| | | except ( |
| | | ConnectionTimeoutError, |
| | | ClientConnectorError, |
| | | ClientConnectorCertificateError, |
| | | ClientConnectorSSLError, |
| | | ): |
| | | raise |
| | | except (ClientOSError, ServerDisconnectedError): |
| | | if retry_persistent_connection: |
| | | retry_persistent_connection = False |
| | | continue |
| | | raise |
| | | except ClientError: |
| | | raise |
| | | except OSError as exc: |
| | | if exc.errno is None and isinstance(exc, asyncio.TimeoutError): |
| | | raise |
| | | raise ClientOSError(*exc.args) from exc |
| | | |
| | | # Update cookies from raw headers to preserve duplicates |
| | | if resp._raw_cookie_headers: |
| | | self._cookie_jar.update_cookies_from_headers( |
| | | resp._raw_cookie_headers, resp.url |
| | | ) |
| | | |
| | | # redirects |
| | | if resp.status in (301, 302, 303, 307, 308) and allow_redirects: |
| | | |
| | | for trace in traces: |
| | | await trace.send_request_redirect( |
| | | method, url.update_query(params), headers, resp |
| | | ) |
| | | |
| | | redirects += 1 |
| | | history.append(resp) |
| | | if max_redirects and redirects >= max_redirects: |
| | | if req._body is not None: |
| | | await req._body.close() |
| | | resp.close() |
| | | raise TooManyRedirects( |
| | | history[0].request_info, tuple(history) |
| | | ) |
| | | |
| | |                         # For 303, and for 301/302 after a POST, switch to GET,
| | |                         # mimicking historic browser behavior (since codified in the RFC)
| | |                         # https://github.com/kennethreitz/requests/pull/269
| | | if (resp.status == 303 and resp.method != hdrs.METH_HEAD) or ( |
| | | resp.status in (301, 302) and resp.method == hdrs.METH_POST |
| | | ): |
| | | method = hdrs.METH_GET |
| | | data = None |
| | | if headers.get(hdrs.CONTENT_LENGTH): |
| | | headers.pop(hdrs.CONTENT_LENGTH) |
| | | else: |
| | | # For 307/308, always preserve the request body |
| | | # For 301/302 with non-POST methods, preserve the request body |
| | | # https://www.rfc-editor.org/rfc/rfc9110#section-15.4.3-3.1 |
| | | # Use the existing payload to avoid recreating it from a potentially consumed file |
| | | data = req._body |
| | | |
| | | r_url = resp.headers.get(hdrs.LOCATION) or resp.headers.get( |
| | | hdrs.URI |
| | | ) |
| | | if r_url is None: |
| | | # see github.com/aio-libs/aiohttp/issues/2022 |
| | | break |
| | | else: |
| | |                             # reading the body of a valid redirect
| | |                             # response is forbidden
| | | resp.release() |
| | | |
| | | try: |
| | | parsed_redirect_url = URL( |
| | | r_url, encoded=not self._requote_redirect_url |
| | | ) |
| | | except ValueError as e: |
| | | if req._body is not None: |
| | | await req._body.close() |
| | | resp.close() |
| | | raise InvalidUrlRedirectClientError( |
| | | r_url, |
| | | "Server attempted redirecting to a location that does not look like a URL", |
| | | ) from e |
| | | |
| | | scheme = parsed_redirect_url.scheme |
| | | if scheme not in HTTP_AND_EMPTY_SCHEMA_SET: |
| | | if req._body is not None: |
| | | await req._body.close() |
| | | resp.close() |
| | | raise NonHttpUrlRedirectClientError(r_url) |
| | | elif not scheme: |
| | | parsed_redirect_url = url.join(parsed_redirect_url) |
| | | |
| | | try: |
| | | redirect_origin = parsed_redirect_url.origin() |
| | | except ValueError as origin_val_err: |
| | | if req._body is not None: |
| | | await req._body.close() |
| | | resp.close() |
| | | raise InvalidUrlRedirectClientError( |
| | | parsed_redirect_url, |
| | | "Invalid redirect URL origin", |
| | | ) from origin_val_err |
| | | |
| | | if url.origin() != redirect_origin: |
| | | auth = None |
| | | headers.pop(hdrs.AUTHORIZATION, None) |
| | | |
| | | url = parsed_redirect_url |
| | | params = {} |
| | | resp.release() |
| | | continue |
| | | |
| | | break |
| | | |
| | | if req._body is not None: |
| | | await req._body.close() |
| | | # check response status |
| | | if raise_for_status is None: |
| | | raise_for_status = self._raise_for_status |
| | | |
| | | if raise_for_status is None: |
| | | pass |
| | | elif callable(raise_for_status): |
| | | await raise_for_status(resp) |
| | | elif raise_for_status: |
| | | resp.raise_for_status() |
| | | |
| | | # register connection |
| | | if handle is not None: |
| | | if resp.connection is not None: |
| | | resp.connection.add_callback(handle.cancel) |
| | | else: |
| | | handle.cancel() |
| | | |
| | | resp._history = tuple(history) |
| | | |
| | | for trace in traces: |
| | | await trace.send_request_end( |
| | | method, url.update_query(params), headers, resp |
| | | ) |
| | | return resp |
| | | |
| | | except BaseException as e: |
| | | # cleanup timer |
| | | tm.close() |
| | | if handle: |
| | | handle.cancel() |
| | | handle = None |
| | | |
| | | for trace in traces: |
| | | await trace.send_request_exception( |
| | | method, url.update_query(params), headers, e |
| | | ) |
| | | raise |
| | | |
| | | def ws_connect( |
| | | self, |
| | | url: StrOrURL, |
| | | *, |
| | | method: str = hdrs.METH_GET, |
| | | protocols: Iterable[str] = (), |
| | | timeout: Union[ClientWSTimeout, _SENTINEL] = sentinel, |
| | | receive_timeout: Optional[float] = None, |
| | | autoclose: bool = True, |
| | | autoping: bool = True, |
| | | heartbeat: Optional[float] = None, |
| | | auth: Optional[BasicAuth] = None, |
| | | origin: Optional[str] = None, |
| | | params: Query = None, |
| | | headers: Optional[LooseHeaders] = None, |
| | | proxy: Optional[StrOrURL] = None, |
| | | proxy_auth: Optional[BasicAuth] = None, |
| | | ssl: Union[SSLContext, bool, Fingerprint] = True, |
| | | verify_ssl: Optional[bool] = None, |
| | | fingerprint: Optional[bytes] = None, |
| | | ssl_context: Optional[SSLContext] = None, |
| | | server_hostname: Optional[str] = None, |
| | | proxy_headers: Optional[LooseHeaders] = None, |
| | | compress: int = 0, |
| | | max_msg_size: int = 4 * 1024 * 1024, |
| | | ) -> "_WSRequestContextManager": |
| | |         """Initiate a WebSocket connection."""
| | | return _WSRequestContextManager( |
| | | self._ws_connect( |
| | | url, |
| | | method=method, |
| | | protocols=protocols, |
| | | timeout=timeout, |
| | | receive_timeout=receive_timeout, |
| | | autoclose=autoclose, |
| | | autoping=autoping, |
| | | heartbeat=heartbeat, |
| | | auth=auth, |
| | | origin=origin, |
| | | params=params, |
| | | headers=headers, |
| | | proxy=proxy, |
| | | proxy_auth=proxy_auth, |
| | | ssl=ssl, |
| | | verify_ssl=verify_ssl, |
| | | fingerprint=fingerprint, |
| | | ssl_context=ssl_context, |
| | | server_hostname=server_hostname, |
| | | proxy_headers=proxy_headers, |
| | | compress=compress, |
| | | max_msg_size=max_msg_size, |
| | | ) |
| | | ) |
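As a usage sketch (not part of this file), the WebSocket client API above can be driven as follows; this assumes the public aiohttp package, and the URL is a placeholder:

```python
import asyncio

import aiohttp


async def ws_echo(url: str) -> None:
    """Open a WebSocket, send one message, and print the first reply."""
    async with aiohttp.ClientSession() as session:
        # ws_connect() performs the HTTP Upgrade handshake shown above and
        # yields a ClientWebSocketResponse once the server accepts it.
        async with session.ws_connect(url, heartbeat=30.0) as ws:
            await ws.send_str("hello")
            msg = await ws.receive()
            print(msg.type, msg.data)


if __name__ == "__main__":
    # Placeholder endpoint; substitute a real WebSocket server.
    asyncio.run(ws_echo("wss://example.invalid/ws"))
```

The `heartbeat` keyword enables the automatic ping/pong machinery that the `_ws_connect` implementation below wires into the response class.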
| | | |
| | | async def _ws_connect( |
| | | self, |
| | | url: StrOrURL, |
| | | *, |
| | | method: str = hdrs.METH_GET, |
| | | protocols: Iterable[str] = (), |
| | | timeout: Union[ClientWSTimeout, _SENTINEL] = sentinel, |
| | | receive_timeout: Optional[float] = None, |
| | | autoclose: bool = True, |
| | | autoping: bool = True, |
| | | heartbeat: Optional[float] = None, |
| | | auth: Optional[BasicAuth] = None, |
| | | origin: Optional[str] = None, |
| | | params: Query = None, |
| | | headers: Optional[LooseHeaders] = None, |
| | | proxy: Optional[StrOrURL] = None, |
| | | proxy_auth: Optional[BasicAuth] = None, |
| | | ssl: Union[SSLContext, bool, Fingerprint] = True, |
| | | verify_ssl: Optional[bool] = None, |
| | | fingerprint: Optional[bytes] = None, |
| | | ssl_context: Optional[SSLContext] = None, |
| | | server_hostname: Optional[str] = None, |
| | | proxy_headers: Optional[LooseHeaders] = None, |
| | | compress: int = 0, |
| | | max_msg_size: int = 4 * 1024 * 1024, |
| | | ) -> ClientWebSocketResponse: |
| | | if timeout is not sentinel: |
| | | if isinstance(timeout, ClientWSTimeout): |
| | | ws_timeout = timeout |
| | | else: |
| | | warnings.warn( |
| | | "parameter 'timeout' of type 'float' " |
| | | "is deprecated, please use " |
| | | "'timeout=ClientWSTimeout(ws_close=...)'", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | ws_timeout = ClientWSTimeout(ws_close=timeout) |
| | | else: |
| | | ws_timeout = DEFAULT_WS_CLIENT_TIMEOUT |
| | | if receive_timeout is not None: |
| | | warnings.warn( |
| | | "float parameter 'receive_timeout' " |
| | | "is deprecated, please use parameter " |
| | | "'timeout=ClientWSTimeout(ws_receive=...)'", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | ws_timeout = attr.evolve(ws_timeout, ws_receive=receive_timeout) |
| | | |
| | | if headers is None: |
| | | real_headers: CIMultiDict[str] = CIMultiDict() |
| | | else: |
| | | real_headers = CIMultiDict(headers) |
| | | |
| | | default_headers = { |
| | | hdrs.UPGRADE: "websocket", |
| | | hdrs.CONNECTION: "Upgrade", |
| | | hdrs.SEC_WEBSOCKET_VERSION: "13", |
| | | } |
| | | |
| | | for key, value in default_headers.items(): |
| | | real_headers.setdefault(key, value) |
| | | |
| | | sec_key = base64.b64encode(os.urandom(16)) |
| | | real_headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode() |
| | | |
| | | if protocols: |
| | | real_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ",".join(protocols) |
| | | if origin is not None: |
| | | real_headers[hdrs.ORIGIN] = origin |
| | | if compress: |
| | | extstr = ws_ext_gen(compress=compress) |
| | | real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr |
| | | |
| | | # For the sake of backward compatibility, if user passes in None, convert it to True |
| | | if ssl is None: |
| | | warnings.warn( |
| | | "ssl=None is deprecated, please use ssl=True", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | ssl = True |
| | | ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) |
| | | |
| | | # send request |
| | | resp = await self.request( |
| | | method, |
| | | url, |
| | | params=params, |
| | | headers=real_headers, |
| | | read_until_eof=False, |
| | | auth=auth, |
| | | proxy=proxy, |
| | | proxy_auth=proxy_auth, |
| | | ssl=ssl, |
| | | server_hostname=server_hostname, |
| | | proxy_headers=proxy_headers, |
| | | ) |
| | | |
| | | try: |
| | | # check handshake |
| | | if resp.status != 101: |
| | | raise WSServerHandshakeError( |
| | | resp.request_info, |
| | | resp.history, |
| | | message="Invalid response status", |
| | | status=resp.status, |
| | | headers=resp.headers, |
| | | ) |
| | | |
| | | if resp.headers.get(hdrs.UPGRADE, "").lower() != "websocket": |
| | | raise WSServerHandshakeError( |
| | | resp.request_info, |
| | | resp.history, |
| | | message="Invalid upgrade header", |
| | | status=resp.status, |
| | | headers=resp.headers, |
| | | ) |
| | | |
| | | if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade": |
| | | raise WSServerHandshakeError( |
| | | resp.request_info, |
| | | resp.history, |
| | | message="Invalid connection header", |
| | | status=resp.status, |
| | | headers=resp.headers, |
| | | ) |
| | | |
| | | # key calculation |
| | | r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, "") |
| | | match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode() |
| | | if r_key != match: |
| | | raise WSServerHandshakeError( |
| | | resp.request_info, |
| | | resp.history, |
| | | message="Invalid challenge response", |
| | | status=resp.status, |
| | | headers=resp.headers, |
| | | ) |
| | | |
| | | # websocket protocol |
| | | protocol = None |
| | | if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers: |
| | | resp_protocols = [ |
| | | proto.strip() |
| | | for proto in resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",") |
| | | ] |
| | | |
| | | for proto in resp_protocols: |
| | | if proto in protocols: |
| | | protocol = proto |
| | | break |
| | | |
| | | # websocket compress |
| | | notakeover = False |
| | | if compress: |
| | | compress_hdrs = resp.headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS) |
| | | if compress_hdrs: |
| | | try: |
| | | compress, notakeover = ws_ext_parse(compress_hdrs) |
| | | except WSHandshakeError as exc: |
| | | raise WSServerHandshakeError( |
| | | resp.request_info, |
| | | resp.history, |
| | | message=exc.args[0], |
| | | status=resp.status, |
| | | headers=resp.headers, |
| | | ) from exc |
| | | else: |
| | | compress = 0 |
| | | notakeover = False |
| | | |
| | | conn = resp.connection |
| | | assert conn is not None |
| | | conn_proto = conn.protocol |
| | | assert conn_proto is not None |
| | | |
| | |             # For a WS connection, read_timeout must be at least receive_timeout.
| | |             # None means no timeout (i.e. infinite), so None is the largest possible value.
| | | if ws_timeout.ws_receive is None: |
| | |                 # No receive timeout requested: clear the read timeout
| | | conn_proto.read_timeout = None |
| | | elif conn_proto.read_timeout is not None: |
| | | conn_proto.read_timeout = max( |
| | | ws_timeout.ws_receive, conn_proto.read_timeout |
| | | ) |
| | | |
| | | transport = conn.transport |
| | | assert transport is not None |
| | | reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop) |
| | | conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader) |
| | | writer = WebSocketWriter( |
| | | conn_proto, |
| | | transport, |
| | | use_mask=True, |
| | | compress=compress, |
| | | notakeover=notakeover, |
| | | ) |
| | | except BaseException: |
| | | resp.close() |
| | | raise |
| | | else: |
| | | return self._ws_response_class( |
| | | reader, |
| | | writer, |
| | | protocol, |
| | | resp, |
| | | ws_timeout, |
| | | autoclose, |
| | | autoping, |
| | | self._loop, |
| | | heartbeat=heartbeat, |
| | | compress=compress, |
| | | client_notakeover=notakeover, |
| | | ) |
| | | |
| | | def _prepare_headers(self, headers: Optional[LooseHeaders]) -> "CIMultiDict[str]": |
| | |         """Add default headers and transform them to a CIMultiDict."""
| | | # Convert headers to MultiDict |
| | | result = CIMultiDict(self._default_headers) |
| | | if headers: |
| | | if not isinstance(headers, (MultiDictProxy, MultiDict)): |
| | | headers = CIMultiDict(headers) |
| | | added_names: Set[str] = set() |
| | | for key, value in headers.items(): |
| | | if key in added_names: |
| | | result.add(key, value) |
| | | else: |
| | | result[key] = value |
| | | added_names.add(key) |
| | | return result |
| | | |
| | | def _get_netrc_auth(self, host: str) -> Optional[BasicAuth]: |
| | | """ |
| | | Get auth from netrc for the given host. |
| | | |
| | | This method is designed to be called in an executor to avoid |
| | | blocking I/O in the event loop. |
| | | """ |
| | | netrc_obj = netrc_from_env() |
| | | try: |
| | | return basicauth_from_netrc(netrc_obj, host) |
| | | except LookupError: |
| | | return None |
| | | |
| | | if sys.version_info >= (3, 11) and TYPE_CHECKING: |
| | | |
| | | def get( |
| | | self, |
| | | url: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> "_RequestContextManager": ... |
| | | |
| | | def options( |
| | | self, |
| | | url: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> "_RequestContextManager": ... |
| | | |
| | | def head( |
| | | self, |
| | | url: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> "_RequestContextManager": ... |
| | | |
| | | def post( |
| | | self, |
| | | url: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> "_RequestContextManager": ... |
| | | |
| | | def put( |
| | | self, |
| | | url: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> "_RequestContextManager": ... |
| | | |
| | | def patch( |
| | | self, |
| | | url: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> "_RequestContextManager": ... |
| | | |
| | | def delete( |
| | | self, |
| | | url: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> "_RequestContextManager": ... |
| | | |
| | | else: |
| | | |
| | | def get( |
| | | self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any |
| | | ) -> "_RequestContextManager": |
| | | """Perform HTTP GET request.""" |
| | | return _RequestContextManager( |
| | | self._request( |
| | | hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs |
| | | ) |
| | | ) |
| | | |
| | | def options( |
| | | self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any |
| | | ) -> "_RequestContextManager": |
| | | """Perform HTTP OPTIONS request.""" |
| | | return _RequestContextManager( |
| | | self._request( |
| | | hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs |
| | | ) |
| | | ) |
| | | |
| | | def head( |
| | | self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any |
| | | ) -> "_RequestContextManager": |
| | | """Perform HTTP HEAD request.""" |
| | | return _RequestContextManager( |
| | | self._request( |
| | | hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs |
| | | ) |
| | | ) |
| | | |
| | | def post( |
| | | self, url: StrOrURL, *, data: Any = None, **kwargs: Any |
| | | ) -> "_RequestContextManager": |
| | | """Perform HTTP POST request.""" |
| | | return _RequestContextManager( |
| | | self._request(hdrs.METH_POST, url, data=data, **kwargs) |
| | | ) |
| | | |
| | | def put( |
| | | self, url: StrOrURL, *, data: Any = None, **kwargs: Any |
| | | ) -> "_RequestContextManager": |
| | | """Perform HTTP PUT request.""" |
| | | return _RequestContextManager( |
| | | self._request(hdrs.METH_PUT, url, data=data, **kwargs) |
| | | ) |
| | | |
| | | def patch( |
| | | self, url: StrOrURL, *, data: Any = None, **kwargs: Any |
| | | ) -> "_RequestContextManager": |
| | | """Perform HTTP PATCH request.""" |
| | | return _RequestContextManager( |
| | | self._request(hdrs.METH_PATCH, url, data=data, **kwargs) |
| | | ) |
| | | |
| | | def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager": |
| | | """Perform HTTP DELETE request.""" |
| | | return _RequestContextManager( |
| | | self._request(hdrs.METH_DELETE, url, **kwargs) |
| | | ) |
| | | |
| | | async def close(self) -> None: |
| | | """Close underlying connector. |
| | | |
| | | Release all acquired resources. |
| | | """ |
| | | if not self.closed: |
| | | if self._connector is not None and self._connector_owner: |
| | | await self._connector.close() |
| | | self._connector = None |
| | | |
| | | @property |
| | | def closed(self) -> bool: |
| | |         """Whether the client session is closed.
| | | 
| | |         A read-only property.
| | | """ |
| | | return self._connector is None or self._connector.closed |
| | | |
| | | @property |
| | | def connector(self) -> Optional[BaseConnector]: |
| | | """Connector instance used for the session.""" |
| | | return self._connector |
| | | |
| | | @property |
| | | def cookie_jar(self) -> AbstractCookieJar: |
| | | """The session cookies.""" |
| | | return self._cookie_jar |
| | | |
| | | @property |
| | | def version(self) -> Tuple[int, int]: |
| | | """The session HTTP protocol version.""" |
| | | return self._version |
| | | |
| | | @property |
| | | def requote_redirect_url(self) -> bool: |
| | | """Do URL requoting on redirection handling.""" |
| | | return self._requote_redirect_url |
| | | |
| | | @requote_redirect_url.setter |
| | | def requote_redirect_url(self, val: bool) -> None: |
| | | """Do URL requoting on redirection handling.""" |
| | | warnings.warn( |
| | | "session.requote_redirect_url modification is deprecated #2778", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | self._requote_redirect_url = val |
| | | |
| | | @property |
| | | def loop(self) -> asyncio.AbstractEventLoop: |
| | | """Session's loop.""" |
| | | warnings.warn( |
| | | "client.loop property is deprecated", DeprecationWarning, stacklevel=2 |
| | | ) |
| | | return self._loop |
| | | |
| | | @property |
| | | def timeout(self) -> ClientTimeout: |
| | | """Timeout for the session.""" |
| | | return self._timeout |
| | | |
| | | @property |
| | | def headers(self) -> "CIMultiDict[str]": |
| | | """The default headers of the client session.""" |
| | | return self._default_headers |
| | | |
| | | @property |
| | | def skip_auto_headers(self) -> FrozenSet[istr]: |
| | |         """Headers for which autogeneration should be skipped."""
| | | return self._skip_auto_headers |
| | | |
| | | @property |
| | | def auth(self) -> Optional[BasicAuth]: |
| | |         """An object that represents HTTP Basic Authorization."""
| | | return self._default_auth |
| | | |
| | | @property |
| | | def json_serialize(self) -> JSONEncoder: |
| | |         """JSON serializer callable."""
| | | return self._json_serialize |
| | | |
| | | @property |
| | | def connector_owner(self) -> bool: |
| | |         """Whether the connector should be closed on session closing."""
| | | return self._connector_owner |
| | | |
| | | @property |
| | | def raise_for_status( |
| | | self, |
| | | ) -> Union[bool, Callable[[ClientResponse], Awaitable[None]]]: |
| | | """Should `ClientResponse.raise_for_status()` be called for each response.""" |
| | | return self._raise_for_status |
| | | |
| | | @property |
| | | def auto_decompress(self) -> bool: |
| | | """Should the body response be automatically decompressed.""" |
| | | return self._auto_decompress |
| | | |
| | | @property |
| | | def trust_env(self) -> bool: |
| | | """ |
| | |         Whether proxy information from the environment or netrc should be trusted.
| | | 
| | |         Proxy information comes from the HTTP_PROXY / HTTPS_PROXY environment
| | |         variables or the ~/.netrc file, if present.
| | | """ |
| | | return self._trust_env |
| | | |
| | | @property |
| | | def trace_configs(self) -> List[TraceConfig]: |
| | |         """A list of TraceConfig instances used for client tracing."""
| | | return self._trace_configs |
| | | |
| | | def detach(self) -> None: |
| | |         """Detach the connector from the session without closing the connector.
| | | 
| | |         The session is switched to the closed state regardless.
| | | """ |
| | | self._connector = None |
| | | |
| | | def __enter__(self) -> None: |
| | | raise TypeError("Use async with instead") |
| | | |
| | | def __exit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_val: Optional[BaseException], |
| | | exc_tb: Optional[TracebackType], |
| | | ) -> None: |
| | |         # __exit__ must exist to pair with __enter__, but it is never executed
| | | pass # pragma: no cover |
| | | |
| | | async def __aenter__(self) -> "ClientSession": |
| | | return self |
| | | |
| | | async def __aexit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_val: Optional[BaseException], |
| | | exc_tb: Optional[TracebackType], |
| | | ) -> None: |
| | | await self.close() |
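A minimal usage sketch for the session lifecycle above (not part of this file), assuming the public aiohttp package; the URL is a placeholder:

```python
import asyncio

import aiohttp


async def fetch_text(url: str) -> str:
    """Fetch a URL and return its body as text."""
    # "async with" drives the __aenter__/__aexit__ pair shown above,
    # so ClientSession.close() is awaited even if the request raises.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            resp.raise_for_status()
            return await resp.text()


if __name__ == "__main__":
    # Placeholder URL; substitute a real endpoint.
    print(asyncio.run(fetch_text("https://example.invalid/")))
```

Entering the session via `async with` rather than calling `close()` manually is the pattern the `__enter__` override below enforces by rejecting plain `with`.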
| | | |
| | | |
| | | class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]): |
| | | |
| | | __slots__ = ("_coro", "_resp") |
| | | |
| | | def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None: |
| | | self._coro: Coroutine["asyncio.Future[Any]", None, _RetType] = coro |
| | | |
| | | def send(self, arg: None) -> "asyncio.Future[Any]": |
| | | return self._coro.send(arg) |
| | | |
| | | def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]": |
| | | return self._coro.throw(*args, **kwargs) |
| | | |
| | | def close(self) -> None: |
| | | return self._coro.close() |
| | | |
| | | def __await__(self) -> Generator[Any, None, _RetType]: |
| | | ret = self._coro.__await__() |
| | | return ret |
| | | |
| | | def __iter__(self) -> Generator[Any, None, _RetType]: |
| | | return self.__await__() |
| | | |
| | | async def __aenter__(self) -> _RetType: |
| | | self._resp: _RetType = await self._coro |
| | | return await self._resp.__aenter__() |
| | | |
| | | async def __aexit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc: Optional[BaseException], |
| | | tb: Optional[TracebackType], |
| | | ) -> None: |
| | | await self._resp.__aexit__(exc_type, exc, tb) |
| | | |
| | | |
| | | _RequestContextManager = _BaseRequestContextManager[ClientResponse] |
| | | _WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse] |
| | | |
| | | |
| | | class _SessionRequestContextManager: |
| | | |
| | | __slots__ = ("_coro", "_resp", "_session") |
| | | |
| | | def __init__( |
| | | self, |
| | | coro: Coroutine["asyncio.Future[Any]", None, ClientResponse], |
| | | session: ClientSession, |
| | | ) -> None: |
| | | self._coro = coro |
| | | self._resp: Optional[ClientResponse] = None |
| | | self._session = session |
| | | |
| | | async def __aenter__(self) -> ClientResponse: |
| | | try: |
| | | self._resp = await self._coro |
| | | except BaseException: |
| | | await self._session.close() |
| | | raise |
| | | else: |
| | | return self._resp |
| | | |
| | | async def __aexit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc: Optional[BaseException], |
| | | tb: Optional[TracebackType], |
| | | ) -> None: |
| | | assert self._resp is not None |
| | | self._resp.close() |
| | | await self._session.close() |
| | | |
| | | |
| | | if sys.version_info >= (3, 11) and TYPE_CHECKING: |
| | | |
| | | def request( |
| | | method: str, |
| | | url: StrOrURL, |
| | | *, |
| | | version: HttpVersion = http.HttpVersion11, |
| | | connector: Optional[BaseConnector] = None, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> _SessionRequestContextManager: ... |
| | | |
| | | else: |
| | | |
| | | def request( |
| | | method: str, |
| | | url: StrOrURL, |
| | | *, |
| | | version: HttpVersion = http.HttpVersion11, |
| | | connector: Optional[BaseConnector] = None, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | **kwargs: Any, |
| | | ) -> _SessionRequestContextManager: |
| | | """Constructs and sends a request. |
| | | |
| | | Returns response object. |
| | | method - HTTP method |
| | | url - request url |
| | | params - (optional) Dictionary or bytes to be sent in the query |
| | | string of the new request |
| | | data - (optional) Dictionary, bytes, or file-like object to |
| | | send in the body of the request |
| | | json - (optional) Any json compatible python object |
| | | headers - (optional) Dictionary of HTTP Headers to send with |
| | | the request |
| | | cookies - (optional) Dict object to send with the request |
| | |         auth - (optional) BasicAuth named tuple representing HTTP Basic Auth
| | | allow_redirects - (optional) If set to False, do not follow |
| | | redirects |
| | | version - Request HTTP version. |
| | | compress - Set to True if request has to be compressed |
| | | with deflate encoding. |
| | | chunked - Set to chunk size for chunked transfer encoding. |
| | | expect100 - Expect 100-continue response from server. |
| | | connector - BaseConnector sub-class instance to support |
| | | connection pooling. |
| | | read_until_eof - Read response until eof if response |
| | | does not have Content-Length header. |
| | | loop - Optional event loop. |
| | | timeout - Optional ClientTimeout settings structure, 5min |
| | | total timeout by default. |
| | | Usage:: |
| | | >>> import aiohttp |
| | | >>> async with aiohttp.request('GET', 'http://python.org/') as resp: |
| | | ... print(resp) |
| | | ... data = await resp.read() |
| | | <ClientResponse(https://www.python.org/) [200 OK]> |
| | | """ |
| | | connector_owner = False |
| | | if connector is None: |
| | | connector_owner = True |
| | | connector = TCPConnector(loop=loop, force_close=True) |
| | | |
| | | session = ClientSession( |
| | | loop=loop, |
| | | cookies=kwargs.pop("cookies", None), |
| | | version=version, |
| | | timeout=kwargs.pop("timeout", sentinel), |
| | | connector=connector, |
| | | connector_owner=connector_owner, |
| | | ) |
| | | |
| | | return _SessionRequestContextManager( |
| | | session._request(method, url, **kwargs), |
| | | session, |
| | | ) |
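The `connector_owner` flag above is what lets `request()` close a connector it created implicitly while leaving a caller-supplied one alone. A minimal sketch of that ownership pattern, using a hypothetical `Resource` class as a stand-in for a connector/session (no aiohttp required):

```python
import asyncio
from typing import Optional, Tuple


class Resource:
    """Stand-in for a connector/session (hypothetical)."""

    def __init__(self) -> None:
        self.closed = False

    async def close(self) -> None:
        self.closed = True


class OwnedContext:
    """Async context manager that closes the resource only if it owns it."""

    def __init__(self, resource: Optional[Resource] = None) -> None:
        # Mirror the request() logic: create (and own) a resource
        # only when the caller did not supply one.
        self._owner = resource is None
        self._resource = resource if resource is not None else Resource()

    async def __aenter__(self) -> Resource:
        return self._resource

    async def __aexit__(self, *exc_info: object) -> None:
        if self._owner:
            await self._resource.close()


async def main() -> Tuple[bool, bool]:
    shared = Resource()
    async with OwnedContext(shared):
        pass  # caller-supplied: must NOT be closed for them
    async with OwnedContext() as implicit:
        pass  # implicitly created: closed on exit
    return shared.closed, implicit.closed


# asyncio.run(main()) -> (False, True)
```

The same division of responsibility applies when passing a `connector` into `ClientSession` with `connector_owner=False`: whoever created the connector is responsible for closing it.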
| New file |
| | |
| | | """HTTP related errors.""" |
| | | |
| | | import asyncio |
| | | import warnings |
| | | from typing import TYPE_CHECKING, Optional, Tuple, Union |
| | | |
| | | from multidict import MultiMapping |
| | | |
| | | from .typedefs import StrOrURL |
| | | |
| | | if TYPE_CHECKING: |
| | | import ssl |
| | | |
| | | SSLContext = ssl.SSLContext |
| | | else: |
| | | try: |
| | | import ssl |
| | | |
| | | SSLContext = ssl.SSLContext |
| | | except ImportError: # pragma: no cover |
| | | ssl = SSLContext = None # type: ignore[assignment] |
| | | |
| | | if TYPE_CHECKING: |
| | | from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo |
| | | from .http_parser import RawResponseMessage |
| | | else: |
| | | RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None |
| | | |
| | | __all__ = ( |
| | | "ClientError", |
| | | "ClientConnectionError", |
| | | "ClientConnectionResetError", |
| | | "ClientOSError", |
| | | "ClientConnectorError", |
| | | "ClientProxyConnectionError", |
| | | "ClientSSLError", |
| | | "ClientConnectorDNSError", |
| | | "ClientConnectorSSLError", |
| | | "ClientConnectorCertificateError", |
| | | "ConnectionTimeoutError", |
| | | "SocketTimeoutError", |
| | | "ServerConnectionError", |
| | | "ServerTimeoutError", |
| | | "ServerDisconnectedError", |
| | | "ServerFingerprintMismatch", |
| | | "ClientResponseError", |
| | | "ClientHttpProxyError", |
| | | "WSServerHandshakeError", |
| | | "ContentTypeError", |
| | | "ClientPayloadError", |
| | | "InvalidURL", |
| | | "InvalidUrlClientError", |
| | | "RedirectClientError", |
| | | "NonHttpUrlClientError", |
| | | "InvalidUrlRedirectClientError", |
| | | "NonHttpUrlRedirectClientError", |
| | | "WSMessageTypeError", |
| | | ) |
| | | |
| | | |
| | | class ClientError(Exception): |
| | | """Base class for client connection errors.""" |
| | | |
| | | |
| | | class ClientResponseError(ClientError): |
| | | """Base class for exceptions that occur after getting a response. |
| | | |
| | | request_info: An instance of RequestInfo. |
| | | history: A sequence of responses, if redirects occurred. |
| | | status: HTTP status code. |
| | | message: Error message. |
| | | headers: Response headers. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | request_info: RequestInfo, |
| | | history: Tuple[ClientResponse, ...], |
| | | *, |
| | | code: Optional[int] = None, |
| | | status: Optional[int] = None, |
| | | message: str = "", |
| | | headers: Optional[MultiMapping[str]] = None, |
| | | ) -> None: |
| | | self.request_info = request_info |
| | | if code is not None: |
| | | if status is not None: |
| | | raise ValueError( |
| | | "Both code and status arguments are provided; " |
| | | "code is deprecated, use status instead" |
| | | ) |
| | | warnings.warn( |
| | | "code argument is deprecated, use status instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | if status is not None: |
| | | self.status = status |
| | | elif code is not None: |
| | | self.status = code |
| | | else: |
| | | self.status = 0 |
| | | self.message = message |
| | | self.headers = headers |
| | | self.history = history |
| | | self.args = (request_info, history) |
| | | |
| | | def __str__(self) -> str: |
| | | return "{}, message={!r}, url={!r}".format( |
| | | self.status, |
| | | self.message, |
| | | str(self.request_info.real_url), |
| | | ) |
| | | |
| | | def __repr__(self) -> str: |
| | | args = f"{self.request_info!r}, {self.history!r}" |
| | | if self.status != 0: |
| | | args += f", status={self.status!r}" |
| | | if self.message != "": |
| | | args += f", message={self.message!r}" |
| | | if self.headers is not None: |
| | | args += f", headers={self.headers!r}" |
| | | return f"{type(self).__name__}({args})" |
| | | |
| | | @property |
| | | def code(self) -> int: |
| | | warnings.warn( |
| | | "code property is deprecated, use status instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | return self.status |
| | | |
| | | @code.setter |
| | | def code(self, value: int) -> None: |
| | | warnings.warn( |
| | | "code property is deprecated, use status instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | self.status = value |
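The `code`/`status` handling in `__init__` above accepts the legacy keyword while steering callers toward the new one. A condensed, standalone sketch of that migration pattern (the function name is illustrative, not aiohttp's API):

```python
import warnings
from typing import Optional


def resolve_status(code: Optional[int] = None, status: Optional[int] = None) -> int:
    """Prefer `status`; accept deprecated `code`; reject both; default to 0."""
    if code is not None:
        if status is not None:
            raise ValueError(
                "Both code and status arguments are provided; "
                "code is deprecated, use status instead"
            )
        warnings.warn(
            "code argument is deprecated, use status instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return code
    return status if status is not None else 0
```

Raising on the ambiguous both-given case, while merely warning on the deprecated one, gives callers a clear migration path without silently picking a winner.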
| | | |
| | | |
| | | class ContentTypeError(ClientResponseError): |
| | | """ContentType found is not valid.""" |
| | | |
| | | |
| | | class WSServerHandshakeError(ClientResponseError): |
| | | """WebSocket server handshake error."""
| | | |
| | | |
| | | class ClientHttpProxyError(ClientResponseError): |
| | | """HTTP proxy error. |
| | | |
| | | Raised in :class:`aiohttp.connector.TCPConnector` if |
| | | proxy responds with status other than ``200 OK`` |
| | | on ``CONNECT`` request. |
| | | """ |
| | | |
| | | |
| | | class TooManyRedirects(ClientResponseError): |
| | | """Client was redirected too many times.""" |
| | | |
| | | |
| | | class ClientConnectionError(ClientError): |
| | | """Base class for client socket errors.""" |
| | | |
| | | |
| | | class ClientConnectionResetError(ClientConnectionError, ConnectionResetError): |
| | | """Connection reset error."""
| | | |
| | | |
| | | class ClientOSError(ClientConnectionError, OSError): |
| | | """OSError raised on a client connection."""
| | | |
| | | |
| | | class ClientConnectorError(ClientOSError): |
| | | """Client connector error. |
| | | |
| | | Raised in :class:`aiohttp.connector.TCPConnector` if |
| | | a connection can not be established. |
| | | """ |
| | | |
| | | def __init__(self, connection_key: ConnectionKey, os_error: OSError) -> None: |
| | | self._conn_key = connection_key |
| | | self._os_error = os_error |
| | | super().__init__(os_error.errno, os_error.strerror) |
| | | self.args = (connection_key, os_error) |
| | | |
| | | @property |
| | | def os_error(self) -> OSError: |
| | | return self._os_error |
| | | |
| | | @property |
| | | def host(self) -> str: |
| | | return self._conn_key.host |
| | | |
| | | @property |
| | | def port(self) -> Optional[int]: |
| | | return self._conn_key.port |
| | | |
| | | @property |
| | | def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]: |
| | | return self._conn_key.ssl |
| | | |
| | | def __str__(self) -> str: |
| | | return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format( |
| | | self, "default" if self.ssl is True else self.ssl, self.strerror |
| | | ) |
| | | |
| | | # OSError.__reduce__ does too much black magic
| | | __reduce__ = BaseException.__reduce__ |
| | | |
| | | |
| | | class ClientConnectorDNSError(ClientConnectorError): |
| | | """DNS resolution failed during client connection. |
| | | |
| | | Raised in :class:`aiohttp.connector.TCPConnector` if |
| | | DNS resolution fails. |
| | | """ |
| | | |
| | | |
| | | class ClientProxyConnectionError(ClientConnectorError): |
| | | """Proxy connection error. |
| | | |
| | | Raised in :class:`aiohttp.connector.TCPConnector` if |
| | | connection to proxy can not be established. |
| | | """ |
| | | |
| | | |
| | | class UnixClientConnectorError(ClientConnectorError): |
| | | """Unix connector error. |
| | | |
| | | Raised in :py:class:`aiohttp.connector.UnixConnector` |
| | | if connection to unix socket can not be established. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, path: str, connection_key: ConnectionKey, os_error: OSError |
| | | ) -> None: |
| | | self._path = path |
| | | super().__init__(connection_key, os_error) |
| | | |
| | | @property |
| | | def path(self) -> str: |
| | | return self._path |
| | | |
| | | def __str__(self) -> str: |
| | | return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format( |
| | | self, "default" if self.ssl is True else self.ssl, self.strerror |
| | | ) |
| | | |
| | | |
| | | class ServerConnectionError(ClientConnectionError): |
| | | """Server connection errors.""" |
| | | |
| | | |
| | | class ServerDisconnectedError(ServerConnectionError): |
| | | """Server disconnected.""" |
| | | |
| | | def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None: |
| | | if message is None: |
| | | message = "Server disconnected" |
| | | |
| | | self.args = (message,) |
| | | self.message = message |
| | | |
| | | |
| | | class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError): |
| | | """Server timeout error.""" |
| | | |
| | | |
| | | class ConnectionTimeoutError(ServerTimeoutError): |
| | | """Connection timeout error.""" |
| | | |
| | | |
| | | class SocketTimeoutError(ServerTimeoutError): |
| | | """Socket timeout error.""" |
| | | |
| | | |
| | | class ServerFingerprintMismatch(ServerConnectionError): |
| | | """SSL certificate does not match expected fingerprint.""" |
| | | |
| | | def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None: |
| | | self.expected = expected |
| | | self.got = got |
| | | self.host = host |
| | | self.port = port |
| | | self.args = (expected, got, host, port) |
| | | |
| | | def __repr__(self) -> str: |
| | | return "<{} expected={!r} got={!r} host={!r} port={!r}>".format( |
| | | self.__class__.__name__, self.expected, self.got, self.host, self.port |
| | | ) |
| | | |
| | | |
| | | class ClientPayloadError(ClientError): |
| | | """Response payload error.""" |
| | | |
| | | |
| | | class InvalidURL(ClientError, ValueError): |
| | | """Invalid URL. |
| | | |
| | | URL used for fetching is malformed, e.g. it doesn't contain a host
| | | part.
| | | """ |
| | | |
| | | # Derive from ValueError for backward compatibility |
| | | |
| | | def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None: |
| | | # The type of url is not yarl.URL because the exception can be raised |
| | | # on URL(url) call |
| | | self._url = url |
| | | self._description = description |
| | | |
| | | if description: |
| | | super().__init__(url, description) |
| | | else: |
| | | super().__init__(url) |
| | | |
| | | @property |
| | | def url(self) -> StrOrURL: |
| | | return self._url |
| | | |
| | | @property |
| | | def description(self) -> "str | None": |
| | | return self._description |
| | | |
| | | def __repr__(self) -> str: |
| | | return f"<{self.__class__.__name__} {self}>" |
| | | |
| | | def __str__(self) -> str: |
| | | if self._description: |
| | | return f"{self._url} - {self._description}" |
| | | return str(self._url) |
| | | |
| | | |
| | | class InvalidUrlClientError(InvalidURL): |
| | | """Invalid URL client error.""" |
| | | |
| | | |
| | | class RedirectClientError(ClientError): |
| | | """Client redirect error.""" |
| | | |
| | | |
| | | class NonHttpUrlClientError(ClientError): |
| | | """Non http URL client error.""" |
| | | |
| | | |
| | | class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError): |
| | | """Invalid URL redirect client error.""" |
| | | |
| | | |
| | | class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError): |
| | | """Non http URL redirect client error.""" |
| | | |
| | | |
| | | class ClientSSLError(ClientConnectorError): |
| | | """Base error for ssl.*Errors.""" |
| | | |
| | | |
| | | if ssl is not None: |
| | | cert_errors = (ssl.CertificateError,) |
| | | cert_errors_bases = ( |
| | | ClientSSLError, |
| | | ssl.CertificateError, |
| | | ) |
| | | |
| | | ssl_errors = (ssl.SSLError,) |
| | | ssl_error_bases = (ClientSSLError, ssl.SSLError) |
| | | else: # pragma: no cover |
| | | cert_errors = tuple() |
| | | cert_errors_bases = ( |
| | | ClientSSLError, |
| | | ValueError, |
| | | ) |
| | | |
| | | ssl_errors = tuple() |
| | | ssl_error_bases = (ClientSSLError,) |
| | | |
| | | |
| | | class ClientConnectorSSLError(*ssl_error_bases): # type: ignore[misc] |
| | | """Response ssl error.""" |
| | | |
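The `class ClientConnectorSSLError(*ssl_error_bases)` trick above selects base classes at import time, so the exception still subclasses `ssl.SSLError` when the `ssl` module is available and degrades gracefully when it is not. A small sketch of the same technique (class names here are illustrative):

```python
# Build the base-class tuple at import time, then unpack it in the
# class statement. If ssl is missing, fall back to a plain base.
try:
    import ssl

    error_bases = (ValueError, ssl.SSLError)
except ImportError:  # pragma: no cover
    ssl = None  # type: ignore[assignment]
    error_bases = (ValueError,)


class MySSLError(*error_bases):  # type: ignore[misc]
    """Subclasses ssl.SSLError only when the ssl module is available."""
```

Callers can then catch `MySSLError` uniformly, while code that catches `ssl.SSLError` also sees it on platforms where ssl exists.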
| | | |
| | | class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore[misc] |
| | | """Response certificate error.""" |
| | | |
| | | def __init__( |
| | | self, connection_key: ConnectionKey, certificate_error: Exception |
| | | ) -> None: |
| | | self._conn_key = connection_key |
| | | self._certificate_error = certificate_error |
| | | self.args = (connection_key, certificate_error) |
| | | |
| | | @property |
| | | def certificate_error(self) -> Exception: |
| | | return self._certificate_error |
| | | |
| | | @property |
| | | def host(self) -> str: |
| | | return self._conn_key.host |
| | | |
| | | @property |
| | | def port(self) -> Optional[int]: |
| | | return self._conn_key.port |
| | | |
| | | @property |
| | | def ssl(self) -> bool: |
| | | return self._conn_key.is_ssl |
| | | |
| | | def __str__(self) -> str: |
| | | return ( |
| | | "Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} " |
| | | "[{0.certificate_error.__class__.__name__}: " |
| | | "{0.certificate_error.args}]".format(self) |
| | | ) |
| | | |
| | | |
| | | class WSMessageTypeError(TypeError): |
| | | """WebSocket message type is not valid.""" |
| New file |
| | |
| | | """ |
| | | Digest authentication middleware for aiohttp client. |
| | | |
| | | This middleware implements HTTP Digest Authentication according to RFC 7616, |
| | | providing a more secure alternative to Basic Authentication. It supports all |
| | | standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session |
| | | variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options. |
| | | """ |
| | | |
| | | import hashlib |
| | | import os |
| | | import re |
| | | import sys |
| | | import time |
| | | from typing import ( |
| | | Callable, |
| | | Dict, |
| | | Final, |
| | | FrozenSet, |
| | | List, |
| | | Literal, |
| | | Tuple, |
| | | TypedDict, |
| | | Union, |
| | | ) |
| | | |
| | | from yarl import URL |
| | | |
| | | from . import hdrs |
| | | from .client_exceptions import ClientError |
| | | from .client_middlewares import ClientHandlerType |
| | | from .client_reqrep import ClientRequest, ClientResponse |
| | | from .payload import Payload |
| | | |
| | | |
| | | class DigestAuthChallenge(TypedDict, total=False): |
| | | realm: str |
| | | nonce: str |
| | | qop: str |
| | | algorithm: str |
| | | opaque: str |
| | | domain: str |
| | | stale: str |
| | | |
| | | |
| | | DigestFunctions: Dict[str, Callable[[bytes], "hashlib._Hash"]] = { |
| | | "MD5": hashlib.md5, |
| | | "MD5-SESS": hashlib.md5, |
| | | "SHA": hashlib.sha1, |
| | | "SHA-SESS": hashlib.sha1, |
| | | "SHA256": hashlib.sha256, |
| | | "SHA256-SESS": hashlib.sha256, |
| | | "SHA-256": hashlib.sha256, |
| | | "SHA-256-SESS": hashlib.sha256, |
| | | "SHA512": hashlib.sha512, |
| | | "SHA512-SESS": hashlib.sha512, |
| | | "SHA-512": hashlib.sha512, |
| | | "SHA-512-SESS": hashlib.sha512, |
| | | } |
| | | |
| | | |
| | | # Compile the regex pattern once at module level for performance |
| | | _HEADER_PAIRS_PATTERN = re.compile( |
| | | r'(?:^|\s|,\s*)(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))' |
| | | if sys.version_info < (3, 11) |
| | | else r'(?:^|\s|,\s*)((?>\w+))\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))' |
| | | # +------------|--------|--|-|-|--|----|------|----|--||-----|-> Match valid start/sep |
| | | # +--------|--|-|-|--|----|------|----|--||-----|-> alphanumeric key (atomic |
| | | # | | | | | | | | || | group reduces backtracking) |
| | | # +--|-|-|--|----|------|----|--||-----|-> maybe whitespace |
| | | # | | | | | | | || | |
| | | # +-|-|--|----|------|----|--||-----|-> = (delimiter) |
| | | # +-|--|----|------|----|--||-----|-> maybe whitespace |
| | | # | | | | | || | |
| | | # +--|----|------|----|--||-----|-> group quoted or unquoted |
| | | # | | | | || | |
| | | # +----|------|----|--||-----|-> if quoted... |
| | | # +------|----|--||-----|-> anything but " or \ |
| | | # +----|--||-----|-> escaped characters allowed |
| | | # +--||-----|-> or can be empty string |
| | | # || | |
| | | # +|-----|-> if unquoted... |
| | | # +-----|-> anything but , or <space> |
| | | # +-> at least one char req'd |
| | | ) |
| | | |
| | | |
| | | # RFC 7616: Challenge parameters to extract |
| | | CHALLENGE_FIELDS: Final[ |
| | | Tuple[ |
| | | Literal["realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale"], ... |
| | | ] |
| | | ] = ( |
| | | "realm", |
| | | "nonce", |
| | | "qop", |
| | | "algorithm", |
| | | "opaque", |
| | | "domain", |
| | | "stale", |
| | | ) |
| | | |
| | | # Supported digest authentication algorithms |
| | | # Use a tuple of sorted keys for predictable documentation and error messages |
| | | SUPPORTED_ALGORITHMS: Final[Tuple[str, ...]] = tuple(sorted(DigestFunctions.keys())) |
| | | |
| | | # RFC 7616: Fields that require quoting in the Digest auth header |
| | | # These fields must be enclosed in double quotes in the Authorization header. |
| | | # Algorithm, qop, and nc are never quoted per RFC specifications. |
| | | # This frozen set is used by the template-based header construction to |
| | | # automatically determine which fields need quotes. |
| | | QUOTED_AUTH_FIELDS: Final[FrozenSet[str]] = frozenset( |
| | | {"username", "realm", "nonce", "uri", "response", "opaque", "cnonce"} |
| | | ) |
| | | |
| | | |
| | | def escape_quotes(value: str) -> str: |
| | | """Escape double quotes for HTTP header values.""" |
| | | return value.replace('"', '\\"') |
| | | |
| | | |
| | | def unescape_quotes(value: str) -> str: |
| | | """Unescape double quotes in HTTP header values.""" |
| | | return value.replace('\\"', '"') |
| | | |
| | | |
| | | def parse_header_pairs(header: str) -> Dict[str, str]: |
| | | """ |
| | | Parse key-value pairs from WWW-Authenticate or similar HTTP headers. |
| | | |
| | | This function handles the complex format of WWW-Authenticate header values, |
| | | supporting both quoted and unquoted values, proper handling of commas in |
| | | quoted values, and whitespace variations per RFC 7616. |
| | | |
| | | Examples of supported formats: |
| | | - key1="value1", key2=value2 |
| | | - key1 = "value1" , key2="value, with, commas" |
| | | - key1=value1,key2="value2" |
| | | - realm="example.com", nonce="12345", qop="auth" |
| | | |
| | | Args: |
| | | header: The header value string to parse |
| | | |
| | | Returns: |
| | | Dictionary mapping parameter names to their values |
| | | """ |
| | | return { |
| | | stripped_key: unescape_quotes(quoted_val) if quoted_val else unquoted_val |
| | | for key, quoted_val, unquoted_val in _HEADER_PAIRS_PATTERN.findall(header) |
| | | if (stripped_key := key.strip()) |
| | | } |
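The parser above boils down to one regex plus a dict comprehension. A self-contained sketch using the pre-3.11 shape of the pattern (no atomic groups), exercised against the kinds of inputs the docstring lists:

```python
import re
from typing import Dict

# Same shape as the module-level pattern for Python < 3.11.
_PAIRS = re.compile(r'(?:^|\s|,\s*)(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))')


def parse_pairs(header: str) -> Dict[str, str]:
    """Parse key=value pairs, unescaping \\" inside quoted values."""
    return {
        key: (quoted.replace('\\"', '"') if quoted else unquoted)
        for key, quoted, unquoted in _PAIRS.findall(header)
    }


challenge = 'realm="example.com", nonce="12345", qop=auth, note="a, b"'
# The comma inside the quoted value survives; the unquoted qop does not
# absorb the following separator.
```

The quoted-value alternative `"((?:[^"\\]|\\.)*)"` is what keeps commas inside quotes from splitting a value, which a naive `header.split(",")` would get wrong.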
| | | |
| | | |
| | | class DigestAuthMiddleware: |
| | | """ |
| | | HTTP digest authentication middleware for aiohttp client. |
| | | |
| | | This middleware intercepts 401 Unauthorized responses containing a Digest |
| | | authentication challenge, calculates the appropriate digest credentials, |
| | | and automatically retries the request with the proper Authorization header. |
| | | |
| | | Features: |
| | | - Handles all aspects of Digest authentication handshake automatically |
| | | - Supports all standard hash algorithms: |
| | | - MD5, MD5-SESS |
| | | - SHA, SHA-SESS |
| | | - SHA256, SHA256-SESS, SHA-256, SHA-256-SESS |
| | | - SHA512, SHA512-SESS, SHA-512, SHA-512-SESS |
| | | - Supports 'auth' and 'auth-int' quality of protection modes |
| | | - Properly handles quoted strings and parameter parsing |
| | | - Includes replay attack protection with client nonce count tracking |
| | | - Supports preemptive authentication per RFC 7616 Section 3.6 |
| | | |
| | | Standards compliance: |
| | | - RFC 7616: HTTP Digest Access Authentication (primary reference) |
| | | - RFC 2617: HTTP Authentication (deprecated by RFC 7616) |
| | | - RFC 1945: Section 11.1 (username restrictions) |
| | | |
| | | Implementation notes: |
| | | The core digest calculation is inspired by the implementation in |
| | | https://github.com/requests/requests/blob/v2.18.4/requests/auth.py |
| | | with added support for modern digest auth features and error handling. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | login: str, |
| | | password: str, |
| | | preemptive: bool = True, |
| | | ) -> None: |
| | | if login is None: |
| | | raise ValueError("None is not allowed as login value") |
| | | |
| | | if password is None: |
| | | raise ValueError("None is not allowed as password value") |
| | | |
| | | if ":" in login: |
| | | raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)') |
| | | |
| | | self._login_str: Final[str] = login |
| | | self._login_bytes: Final[bytes] = login.encode("utf-8") |
| | | self._password_bytes: Final[bytes] = password.encode("utf-8") |
| | | |
| | | self._last_nonce_bytes = b"" |
| | | self._nonce_count = 0 |
| | | self._challenge: DigestAuthChallenge = {} |
| | | self._preemptive: bool = preemptive |
| | | # Set of URLs defining the protection space |
| | | self._protection_space: List[str] = [] |
| | | |
| | | async def _encode( |
| | | self, method: str, url: URL, body: Union[Payload, Literal[b""]] |
| | | ) -> str: |
| | | """ |
| | | Build digest authorization header for the current challenge. |
| | | |
| | | Args: |
| | | method: The HTTP method (GET, POST, etc.) |
| | | url: The request URL |
| | | body: The request body (used for qop=auth-int) |
| | | |
| | | Returns: |
| | | A fully formatted Digest authorization header string |
| | | |
| | | Raises: |
| | | ClientError: If the challenge is missing required parameters or |
| | | contains unsupported values |
| | | |
| | | """ |
| | | challenge = self._challenge |
| | | if "realm" not in challenge: |
| | | raise ClientError( |
| | | "Malformed Digest auth challenge: Missing 'realm' parameter" |
| | | ) |
| | | |
| | | if "nonce" not in challenge: |
| | | raise ClientError( |
| | | "Malformed Digest auth challenge: Missing 'nonce' parameter" |
| | | ) |
| | | |
| | | # Empty realm values are allowed per RFC 7616 (SHOULD, not MUST, contain host name) |
| | | realm = challenge["realm"] |
| | | nonce = challenge["nonce"] |
| | | |
| | | # Empty nonce values are not allowed as they are security-critical for replay protection |
| | | if not nonce: |
| | | raise ClientError( |
| | | "Security issue: Digest auth challenge contains empty 'nonce' value" |
| | | ) |
| | | |
| | | qop_raw = challenge.get("qop", "") |
| | | # Preserve original algorithm case for response while using uppercase for processing |
| | | algorithm_original = challenge.get("algorithm", "MD5") |
| | | algorithm = algorithm_original.upper() |
| | | opaque = challenge.get("opaque", "") |
| | | |
| | | # Convert string values to bytes once |
| | | nonce_bytes = nonce.encode("utf-8") |
| | | realm_bytes = realm.encode("utf-8") |
| | | path = URL(url).path_qs |
| | | |
| | | # Process QoP |
| | | qop = "" |
| | | qop_bytes = b"" |
| | | if qop_raw: |
| | | valid_qops = {"auth", "auth-int"}.intersection( |
| | | {q.strip() for q in qop_raw.split(",") if q.strip()} |
| | | ) |
| | | if not valid_qops: |
| | | raise ClientError( |
| | | f"Digest auth error: Unsupported Quality of Protection (qop) value(s): {qop_raw}" |
| | | ) |
| | | |
| | | qop = "auth-int" if "auth-int" in valid_qops else "auth" |
| | | qop_bytes = qop.encode("utf-8") |
| | | |
| | | if algorithm not in DigestFunctions: |
| | | raise ClientError( |
| | | f"Digest auth error: Unsupported hash algorithm: {algorithm}. " |
| | | f"Supported algorithms: {', '.join(SUPPORTED_ALGORITHMS)}" |
| | | ) |
| | | hash_fn: Final = DigestFunctions[algorithm] |
| | | |
| | | def H(x: bytes) -> bytes: |
| | | """RFC 7616 Section 3: Hash function H(data) = hex(hash(data)).""" |
| | | return hash_fn(x).hexdigest().encode() |
| | | |
| | | def KD(s: bytes, d: bytes) -> bytes: |
| | | """RFC 7616 Section 3: KD(secret, data) = H(concat(secret, ":", data)).""" |
| | | return H(b":".join((s, d))) |
| | | |
| | | # Calculate A1 and A2 |
| | | A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes)) |
| | | A2 = f"{method.upper()}:{path}".encode() |
| | | if qop == "auth-int": |
| | | if isinstance(body, Payload): # will always be empty bytes unless Payload |
| | | entity_bytes = await body.as_bytes() # Get bytes from Payload |
| | | else: |
| | | entity_bytes = body |
| | | entity_hash = H(entity_bytes) |
| | | A2 = b":".join((A2, entity_hash)) |
| | | |
| | | HA1 = H(A1) |
| | | HA2 = H(A2) |
| | | |
| | | # Nonce count handling |
| | | if nonce_bytes == self._last_nonce_bytes: |
| | | self._nonce_count += 1 |
| | | else: |
| | | self._nonce_count = 1 |
| | | |
| | | self._last_nonce_bytes = nonce_bytes |
| | | ncvalue = f"{self._nonce_count:08x}" |
| | | ncvalue_bytes = ncvalue.encode("utf-8") |
| | | |
| | | # Generate client nonce |
| | | cnonce = hashlib.sha1( |
| | | b"".join( |
| | | [ |
| | | str(self._nonce_count).encode("utf-8"), |
| | | nonce_bytes, |
| | | time.ctime().encode("utf-8"), |
| | | os.urandom(8), |
| | | ] |
| | | ) |
| | | ).hexdigest()[:16] |
| | | cnonce_bytes = cnonce.encode("utf-8") |
| | | |
| | | # Special handling for session-based algorithms |
| | | if algorithm.upper().endswith("-SESS"): |
| | | HA1 = H(b":".join((HA1, nonce_bytes, cnonce_bytes))) |
| | | |
| | | # Calculate the response digest |
| | | if qop: |
| | | noncebit = b":".join( |
| | | (nonce_bytes, ncvalue_bytes, cnonce_bytes, qop_bytes, HA2) |
| | | ) |
| | | response_digest = KD(HA1, noncebit) |
| | | else: |
| | | response_digest = KD(HA1, b":".join((nonce_bytes, HA2))) |
| | | |
| | | # Define a dict mapping of header fields to their values |
| | | # Group fields into always-present, optional, and qop-dependent |
| | | header_fields = { |
| | | # Always present fields |
| | | "username": escape_quotes(self._login_str), |
| | | "realm": escape_quotes(realm), |
| | | "nonce": escape_quotes(nonce), |
| | | "uri": path, |
| | | "response": response_digest.decode(), |
| | | "algorithm": algorithm_original, |
| | | } |
| | | |
| | | # Optional fields |
| | | if opaque: |
| | | header_fields["opaque"] = escape_quotes(opaque) |
| | | |
| | | # QoP-dependent fields |
| | | if qop: |
| | | header_fields["qop"] = qop |
| | | header_fields["nc"] = ncvalue |
| | | header_fields["cnonce"] = cnonce |
| | | |
| | | # Build header using templates for each field type |
| | | pairs: List[str] = [] |
| | | for field, value in header_fields.items(): |
| | | if field in QUOTED_AUTH_FIELDS: |
| | | pairs.append(f'{field}="{value}"') |
| | | else: |
| | | pairs.append(f"{field}={value}") |
| | | |
| | | return f"Digest {', '.join(pairs)}" |
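The `H`/`KD` helpers and the qop branch above implement the classic digest recipe: `response = KD(H(A1), nonce:nc:cnonce:qop:H(A2))`. As a sanity check, this standalone sketch reproduces the worked example from RFC 2617 section 3.5 (user "Mufasa", qop=auth, MD5):

```python
import hashlib


def H(x: bytes) -> bytes:
    """H(data) = hex(hash(data)); MD5 for this example."""
    return hashlib.md5(x).hexdigest().encode()


def KD(secret: bytes, data: bytes) -> bytes:
    """KD(secret, data) = H(secret ":" data)."""
    return H(b":".join((secret, data)))


# Values from the RFC 2617 section 3.5 example.
A1 = b"Mufasa:testrealm@host.com:Circle Of Life"
A2 = b"GET:/dir/index.html"
nonce = b"dcd98b7102dd2f0e8b11d0f600bfb0c093"
nc = b"00000001"
cnonce = b"0a4f113b"
qop = b"auth"

response = KD(H(A1), b":".join((nonce, nc, cnonce, qop, H(A2))))
# response == b"6629fae49393a05397450978507c4ef1"
```

RFC 7616 generalizes the same construction to SHA-256 and friends, which is why the middleware only swaps `hash_fn` rather than the formula.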
| | | |
| | | def _in_protection_space(self, url: URL) -> bool: |
| | | """ |
| | | Check if the given URL is within the current protection space. |
| | | |
| | | According to RFC 7616, a URI is in the protection space if any URI |
| | | in the protection space is a prefix of it (after both have been made absolute). |
| | | """ |
| | | request_str = str(url) |
| | | for space_str in self._protection_space: |
| | | # Check if request starts with space URL |
| | | if not request_str.startswith(space_str): |
| | | continue |
| | | # Exact match or space ends with / (proper directory prefix) |
| | | if len(request_str) == len(space_str) or space_str[-1] == "/": |
| | | return True |
| | | # Check next char is / to ensure proper path boundary |
| | | if request_str[len(space_str)] == "/": |
| | | return True |
| | | return False |
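The boundary checks above prevent `https://example.com/app` from matching `https://example.com/apple`. A sketch of the same prefix rule as a free function (hypothetical helper, plain strings instead of yarl URLs):

```python
from typing import List


def in_protection_space(url: str, spaces: List[str]) -> bool:
    """True if any space URI is a proper path prefix of `url` (RFC 7616)."""
    for space in spaces:
        if not url.startswith(space):
            continue
        # Exact match, or the space already ends in "/" (directory prefix).
        if len(url) == len(space) or space.endswith("/"):
            return True
        # Otherwise the next character must start a new path segment.
        if url[len(space)] == "/":
            return True
    return False
```

Without the segment-boundary check, `startswith` alone would leak credentials to sibling paths that merely share a string prefix.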
| | | |
| | | def _authenticate(self, response: ClientResponse) -> bool: |
| | | """ |
| | | Takes the given response and tries digest-auth, if needed. |
| | | |
| | | Returns True if the original request must be resent.
| | | """ |
| | | if response.status != 401: |
| | | return False |
| | | |
| | | auth_header = response.headers.get("www-authenticate", "") |
| | | if not auth_header: |
| | | return False # No authentication header present |
| | | |
| | | method, sep, headers = auth_header.partition(" ") |
| | | if not sep: |
| | | # No space found in www-authenticate header |
| | | return False # Malformed auth header, missing scheme separator |
| | | |
| | | if method.lower() != "digest": |
| | | # Not a digest auth challenge (could be Basic, Bearer, etc.) |
| | | return False |
| | | |
| | | if not headers: |
| | | # We have a digest scheme but no parameters |
| | | return False # Malformed digest header, missing parameters |
| | | |
| | | # We have a digest auth header with content |
| | | if not (header_pairs := parse_header_pairs(headers)): |
| | | # Failed to parse any key-value pairs |
| | | return False # Malformed digest header, no valid parameters |
| | | |
| | | # Extract challenge parameters |
| | | self._challenge = {} |
| | | for field in CHALLENGE_FIELDS: |
| | | if value := header_pairs.get(field): |
| | | self._challenge[field] = value |
| | | |
| | | # Update protection space based on domain parameter or default to origin |
| | | origin = response.url.origin() |
| | | |
| | | if domain := self._challenge.get("domain"): |
| | | # Parse space-separated list of URIs |
| | | self._protection_space = [] |
| | | for uri in domain.split(): |
| | | # Remove quotes if present |
| | | uri = uri.strip('"') |
| | | if uri.startswith("/"): |
| | | # Path-absolute, relative to origin |
| | | self._protection_space.append(str(origin.join(URL(uri)))) |
| | | else: |
| | | # Absolute URI |
| | | self._protection_space.append(str(URL(uri))) |
| | | else: |
| | | # No domain specified, protection space is entire origin |
| | | self._protection_space = [str(origin)] |
| | | |
| | | # Return True only if we found at least one challenge parameter |
| | | return bool(self._challenge) |
| | | |
| | | async def __call__( |
| | | self, request: ClientRequest, handler: ClientHandlerType |
| | | ) -> ClientResponse: |
| | | """Run the digest auth middleware.""" |
| | | response = None |
| | | for retry_count in range(2): |
| | | # Apply authorization header if: |
| | | # 1. This is a retry after 401 (retry_count > 0), OR |
| | | # 2. Preemptive auth is enabled AND we have a challenge AND the URL is in protection space |
| | | if retry_count > 0 or ( |
| | | self._preemptive |
| | | and self._challenge |
| | | and self._in_protection_space(request.url) |
| | | ): |
| | | request.headers[hdrs.AUTHORIZATION] = await self._encode( |
| | | request.method, request.url, request.body |
| | | ) |
| | | |
| | | # Send the request |
| | | response = await handler(request) |
| | | |
| | | # Check if we need to authenticate |
| | | if not self._authenticate(response): |
| | | break |
| | | |
| | | # At this point, response is guaranteed to be defined |
| | | assert response is not None |
| | | return response |
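The two-iteration loop in `__call__` is a retry-once pattern: attempt the request, and if the response carries a digest challenge, resend exactly once with credentials attached. A stripped-down, runnable sketch of the control flow (`fake_send` is a hypothetical stand-in for the handler):

```python
import asyncio
from typing import Awaitable, Callable, Tuple


async def call_with_retry(
    send: Callable[[bool], Awaitable[Tuple[int, str]]],
) -> Tuple[int, str]:
    """First attempt unauthenticated; retry once with credentials on 401."""
    response = (0, "")
    for retry_count in range(2):
        authed = retry_count > 0  # add the Authorization header only on retry
        response = await send(authed)
        if response[0] != 401:  # no challenge -> done
            break
    return response


async def fake_send(authed: bool) -> Tuple[int, str]:
    # Hypothetical server: rejects until credentials are presented.
    return (200, "ok") if authed else (401, "challenge")


# asyncio.run(call_with_retry(fake_send)) -> (200, 'ok')
```

Bounding the loop at two iterations guarantees a misbehaving server that keeps returning 401 cannot trap the client in an authentication loop.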
| New file |
| | |
| | | """Client middleware support.""" |
| | | |
| | | from collections.abc import Awaitable, Callable, Sequence |
| | | |
| | | from .client_reqrep import ClientRequest, ClientResponse |
| | | |
| | | __all__ = ("ClientMiddlewareType", "ClientHandlerType", "build_client_middlewares") |
| | | |
| | | # Type alias for client request handlers - functions that process requests and return responses |
| | | ClientHandlerType = Callable[[ClientRequest], Awaitable[ClientResponse]] |
| | | |
| | | # Type for client middleware - similar to server but uses ClientRequest/ClientResponse |
| | | ClientMiddlewareType = Callable[ |
| | | [ClientRequest, ClientHandlerType], Awaitable[ClientResponse] |
| | | ] |
| | | |
| | | |
| | | def build_client_middlewares( |
| | | handler: ClientHandlerType, |
| | | middlewares: Sequence[ClientMiddlewareType], |
| | | ) -> ClientHandlerType: |
| | | """ |
| | | Apply middlewares to request handler. |
| | | |
| | | The middlewares are applied in reverse order, so the first middleware |
| | | in the list wraps all subsequent middlewares and the handler. |
| | | |
| | | This implementation avoids using partial/update_wrapper to minimize overhead |
| | | and doesn't cache to avoid holding references to stateful middleware. |
| | | """ |
| | | # Optimize for single middleware case |
| | | if len(middlewares) == 1: |
| | | middleware = middlewares[0] |
| | | |
| | | async def single_middleware_handler(req: ClientRequest) -> ClientResponse: |
| | | return await middleware(req, handler) |
| | | |
| | | return single_middleware_handler |
| | | |
| | | # Build the chain for multiple middlewares |
| | | current_handler = handler |
| | | |
| | | for middleware in reversed(middlewares): |
| | | # Create a new closure that captures the current state |
| | | def make_wrapper( |
| | | mw: ClientMiddlewareType, next_h: ClientHandlerType |
| | | ) -> ClientHandlerType: |
| | | async def wrapped(req: ClientRequest) -> ClientResponse: |
| | | return await mw(req, next_h) |
| | | |
| | | return wrapped |
| | | |
| | | current_handler = make_wrapper(middleware, current_handler) |
| | | |
| | | return current_handler |
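The docstring above says middlewares are applied in reverse order so the first one in the list is outermost. That fold can be sketched with plain string-based callables (stand-in types for illustration, not aiohttp's real `ClientRequest`/`ClientResponse`):

```python
import asyncio
from typing import Awaitable, Callable, List

Handler = Callable[[str], Awaitable[str]]
Middleware = Callable[[str, Handler], Awaitable[str]]


def build(handler: Handler, middlewares: List[Middleware]) -> Handler:
    # Fold right-to-left so middlewares[0] ends up wrapping everything.
    current = handler
    for mw in reversed(middlewares):
        def make_wrapper(mw: Middleware, nxt: Handler) -> Handler:
            async def wrapped(req: str) -> str:
                return await mw(req, nxt)
            return wrapped
        current = make_wrapper(mw, current)
    return current


async def outer(req: str, nxt: Handler) -> str:
    return "outer(" + await nxt(req + "+o") + ")"


async def inner(req: str, nxt: Handler) -> str:
    return "inner(" + await nxt(req + "+i") + ")"


async def handler(req: str) -> str:
    return f"resp:{req}"


# outer sees the request first and the response last:
result = asyncio.run(build(handler, [outer, inner])("req"))
```

The `make_wrapper` indirection matters: a closure created directly in the loop body would capture the loop variables by reference and every layer would call the last middleware.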
| New file |
| | |
| | | import asyncio |
| | | from contextlib import suppress |
| | | from typing import Any, Optional, Tuple, Union |
| | | |
| | | from .base_protocol import BaseProtocol |
| | | from .client_exceptions import ( |
| | | ClientConnectionError, |
| | | ClientOSError, |
| | | ClientPayloadError, |
| | | ServerDisconnectedError, |
| | | SocketTimeoutError, |
| | | ) |
| | | from .helpers import ( |
| | | _EXC_SENTINEL, |
| | | EMPTY_BODY_STATUS_CODES, |
| | | BaseTimerContext, |
| | | set_exception, |
| | | set_result, |
| | | ) |
| | | from .http import HttpResponseParser, RawResponseMessage |
| | | from .http_exceptions import HttpProcessingError |
| | | from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader |
| | | |
| | | |
| | | class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]): |
| | | """Helper class to adapt between Protocol and StreamReader.""" |
| | | |
| | | def __init__(self, loop: asyncio.AbstractEventLoop) -> None: |
| | | BaseProtocol.__init__(self, loop=loop) |
| | | DataQueue.__init__(self, loop) |
| | | |
| | | self._should_close = False |
| | | |
| | | self._payload: Optional[StreamReader] = None |
| | | self._skip_payload = False |
| | | self._payload_parser = None |
| | | |
| | | self._timer = None |
| | | |
| | | self._tail = b"" |
| | | self._upgraded = False |
| | | self._parser: Optional[HttpResponseParser] = None |
| | | |
| | | self._read_timeout: Optional[float] = None |
| | | self._read_timeout_handle: Optional[asyncio.TimerHandle] = None |
| | | |
| | | self._timeout_ceil_threshold: Optional[float] = 5 |
| | | |
| | | self._closed: Union[None, asyncio.Future[None]] = None |
| | | self._connection_lost_called = False |
| | | |
| | | @property |
| | | def closed(self) -> Union[None, asyncio.Future[None]]: |
| | | """Future that is set when the connection is closed. |
| | | |
| | | This property returns a Future that will be completed when the connection |
| | | is closed. The Future is created lazily on first access to avoid creating |
| | | futures that will never be awaited. |
| | | |
| | | Returns: |
| | | - A Future[None] if the connection is still open or was closed after |
| | | this property was accessed |
| | | - None if connection_lost() was already called before this property |
| | | was ever accessed (indicating no one is waiting for the closure) |
| | | """ |
| | | if self._closed is None and not self._connection_lost_called: |
| | | self._closed = self._loop.create_future() |
| | | return self._closed |
| | | |
| | | @property |
| | | def upgraded(self) -> bool: |
| | | return self._upgraded |
| | | |
| | | @property |
| | | def should_close(self) -> bool: |
| | | return bool( |
| | | self._should_close |
| | | or (self._payload is not None and not self._payload.is_eof()) |
| | | or self._upgraded |
| | | or self._exception is not None |
| | | or self._payload_parser is not None |
| | | or self._buffer |
| | | or self._tail |
| | | ) |
| | | |
| | | def force_close(self) -> None: |
| | | self._should_close = True |
| | | |
| | | def close(self) -> None: |
| | | self._exception = None # Break cyclic references |
| | | transport = self.transport |
| | | if transport is not None: |
| | | transport.close() |
| | | self.transport = None |
| | | self._payload = None |
| | | self._drop_timeout() |
| | | |
| | | def abort(self) -> None: |
| | | self._exception = None # Break cyclic references |
| | | transport = self.transport |
| | | if transport is not None: |
| | | transport.abort() |
| | | self.transport = None |
| | | self._payload = None |
| | | self._drop_timeout() |
| | | |
| | | def is_connected(self) -> bool: |
| | | return self.transport is not None and not self.transport.is_closing() |
| | | |
| | | def connection_lost(self, exc: Optional[BaseException]) -> None: |
| | | self._connection_lost_called = True |
| | | self._drop_timeout() |
| | | |
| | | original_connection_error = exc |
| | | reraised_exc = original_connection_error |
| | | |
| | | connection_closed_cleanly = original_connection_error is None |
| | | |
| | | if self._closed is not None: |
| | | # If someone accessed the `closed` property, the |
| | | # future exists and must be resolved: with None on |
| | | # a clean close, or with a ClientConnectionError |
| | | # otherwise. If self._closed is None, the property |
| | | # was never accessed, so nobody is waiting. |
| | | if connection_closed_cleanly: |
| | | set_result(self._closed, None) |
| | | else: |
| | | assert original_connection_error is not None |
| | | set_exception( |
| | | self._closed, |
| | | ClientConnectionError( |
| | | f"Connection lost: {original_connection_error !s}", |
| | | ), |
| | | original_connection_error, |
| | | ) |
| | | |
| | | if self._payload_parser is not None: |
| | | with suppress(Exception): # FIXME: log this somehow? |
| | | self._payload_parser.feed_eof() |
| | | |
| | | uncompleted = None |
| | | if self._parser is not None: |
| | | try: |
| | | uncompleted = self._parser.feed_eof() |
| | | except Exception as underlying_exc: |
| | | if self._payload is not None: |
| | | client_payload_exc_msg = ( |
| | | f"Response payload is not completed: {underlying_exc !r}" |
| | | ) |
| | | if not connection_closed_cleanly: |
| | | client_payload_exc_msg = ( |
| | | f"{client_payload_exc_msg !s}. " |
| | | f"{original_connection_error !r}" |
| | | ) |
| | | set_exception( |
| | | self._payload, |
| | | ClientPayloadError(client_payload_exc_msg), |
| | | underlying_exc, |
| | | ) |
| | | |
| | | if not self.is_eof(): |
| | | if isinstance(original_connection_error, OSError): |
| | | reraised_exc = ClientOSError(*original_connection_error.args) |
| | | if connection_closed_cleanly: |
| | | reraised_exc = ServerDisconnectedError(uncompleted) |
| | | # assigns self._should_close to True as side effect, |
| | | # we do it anyway below |
| | | underlying_non_eof_exc = ( |
| | | _EXC_SENTINEL |
| | | if connection_closed_cleanly |
| | | else original_connection_error |
| | | ) |
| | | assert underlying_non_eof_exc is not None |
| | | assert reraised_exc is not None |
| | | self.set_exception(reraised_exc, underlying_non_eof_exc) |
| | | |
| | | self._should_close = True |
| | | self._parser = None |
| | | self._payload = None |
| | | self._payload_parser = None |
| | | self._reading_paused = False |
| | | |
| | | super().connection_lost(reraised_exc) |
| | | |
| | | def eof_received(self) -> None: |
| | | # should call parser.feed_eof() most likely |
| | | self._drop_timeout() |
| | | |
| | | def pause_reading(self) -> None: |
| | | super().pause_reading() |
| | | self._drop_timeout() |
| | | |
| | | def resume_reading(self) -> None: |
| | | super().resume_reading() |
| | | self._reschedule_timeout() |
| | | |
| | | def set_exception( |
| | | self, |
| | | exc: BaseException, |
| | | exc_cause: BaseException = _EXC_SENTINEL, |
| | | ) -> None: |
| | | self._should_close = True |
| | | self._drop_timeout() |
| | | super().set_exception(exc, exc_cause) |
| | | |
| | | def set_parser(self, parser: Any, payload: Any) -> None: |
| | | # TODO: actual types are: |
| | | # parser: WebSocketReader |
| | | # payload: WebSocketDataQueue |
| | | # but they are not generic enough |
| | | # Need an ABC for both types |
| | | self._payload = payload |
| | | self._payload_parser = parser |
| | | |
| | | self._drop_timeout() |
| | | |
| | | if self._tail: |
| | | data, self._tail = self._tail, b"" |
| | | self.data_received(data) |
| | | |
| | | def set_response_params( |
| | | self, |
| | | *, |
| | | timer: Optional[BaseTimerContext] = None, |
| | | skip_payload: bool = False, |
| | | read_until_eof: bool = False, |
| | | auto_decompress: bool = True, |
| | | read_timeout: Optional[float] = None, |
| | | read_bufsize: int = 2**16, |
| | | timeout_ceil_threshold: float = 5, |
| | | max_line_size: int = 8190, |
| | | max_field_size: int = 8190, |
| | | ) -> None: |
| | | self._skip_payload = skip_payload |
| | | |
| | | self._read_timeout = read_timeout |
| | | |
| | | self._timeout_ceil_threshold = timeout_ceil_threshold |
| | | |
| | | self._parser = HttpResponseParser( |
| | | self, |
| | | self._loop, |
| | | read_bufsize, |
| | | timer=timer, |
| | | payload_exception=ClientPayloadError, |
| | | response_with_body=not skip_payload, |
| | | read_until_eof=read_until_eof, |
| | | auto_decompress=auto_decompress, |
| | | max_line_size=max_line_size, |
| | | max_field_size=max_field_size, |
| | | ) |
| | | |
| | | if self._tail: |
| | | data, self._tail = self._tail, b"" |
| | | self.data_received(data) |
| | | |
| | | def _drop_timeout(self) -> None: |
| | | if self._read_timeout_handle is not None: |
| | | self._read_timeout_handle.cancel() |
| | | self._read_timeout_handle = None |
| | | |
| | | def _reschedule_timeout(self) -> None: |
| | | timeout = self._read_timeout |
| | | if self._read_timeout_handle is not None: |
| | | self._read_timeout_handle.cancel() |
| | | |
| | | if timeout: |
| | | self._read_timeout_handle = self._loop.call_later( |
| | | timeout, self._on_read_timeout |
| | | ) |
| | | else: |
| | | self._read_timeout_handle = None |
| | | |
| | | def start_timeout(self) -> None: |
| | | self._reschedule_timeout() |
| | | |
| | | @property |
| | | def read_timeout(self) -> Optional[float]: |
| | | return self._read_timeout |
| | | |
| | | @read_timeout.setter |
| | | def read_timeout(self, read_timeout: Optional[float]) -> None: |
| | | self._read_timeout = read_timeout |
| | | |
| | | def _on_read_timeout(self) -> None: |
| | | exc = SocketTimeoutError("Timeout on reading data from socket") |
| | | self.set_exception(exc) |
| | | if self._payload is not None: |
| | | set_exception(self._payload, exc) |
| | | |
| | | def data_received(self, data: bytes) -> None: |
| | | self._reschedule_timeout() |
| | | |
| | | if not data: |
| | | return |
| | | |
| | | # custom payload parser - currently always WebSocketReader |
| | | if self._payload_parser is not None: |
| | | eof, tail = self._payload_parser.feed_data(data) |
| | | if eof: |
| | | self._payload = None |
| | | self._payload_parser = None |
| | | |
| | | if tail: |
| | | self.data_received(tail) |
| | | return |
| | | |
| | | if self._upgraded or self._parser is None: |
| | | # i.e. websocket connection, websocket parser is not set yet |
| | | self._tail += data |
| | | return |
| | | |
| | | # parse http messages |
| | | try: |
| | | messages, upgraded, tail = self._parser.feed_data(data) |
| | | except BaseException as underlying_exc: |
| | | if self.transport is not None: |
| | | # connection.release() could be called BEFORE |
| | | # data_received(), the transport is already |
| | | # closed in this case |
| | | self.transport.close() |
| | | # should_close is True after the call |
| | | if isinstance(underlying_exc, HttpProcessingError): |
| | | exc = HttpProcessingError( |
| | | code=underlying_exc.code, |
| | | message=underlying_exc.message, |
| | | headers=underlying_exc.headers, |
| | | ) |
| | | else: |
| | | exc = HttpProcessingError() |
| | | self.set_exception(exc, underlying_exc) |
| | | return |
| | | |
| | | self._upgraded = upgraded |
| | | |
| | | payload: Optional[StreamReader] = None |
| | | for message, payload in messages: |
| | | if message.should_close: |
| | | self._should_close = True |
| | | |
| | | self._payload = payload |
| | | |
| | | if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES: |
| | | self.feed_data((message, EMPTY_PAYLOAD), 0) |
| | | else: |
| | | self.feed_data((message, payload), 0) |
| | | |
| | | if payload is not None: |
| | | # new message(s) was processed |
| | | # register timeout handler unsubscribing |
| | | # either on end-of-stream or immediately for |
| | | # EMPTY_PAYLOAD |
| | | if payload is not EMPTY_PAYLOAD: |
| | | payload.on_eof(self._drop_timeout) |
| | | else: |
| | | self._drop_timeout() |
| | | |
| | | if upgraded and tail: |
| | | self.data_received(tail) |
| New file |
| | |
| | | import asyncio |
| | | import codecs |
| | | import contextlib |
| | | import functools |
| | | import io |
| | | import re |
| | | import sys |
| | | import traceback |
| | | import warnings |
| | | from collections.abc import Mapping |
| | | from hashlib import md5, sha1, sha256 |
| | | from http.cookies import Morsel, SimpleCookie |
| | | from types import MappingProxyType, TracebackType |
| | | from typing import ( |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Callable, |
| | | Dict, |
| | | Iterable, |
| | | List, |
| | | Literal, |
| | | NamedTuple, |
| | | Optional, |
| | | Tuple, |
| | | Type, |
| | | Union, |
| | | ) |
| | | |
| | | import attr |
| | | from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy |
| | | from yarl import URL |
| | | |
| | | from . import hdrs, helpers, http, multipart, payload |
| | | from ._cookie_helpers import ( |
| | | parse_cookie_header, |
| | | parse_set_cookie_headers, |
| | | preserve_morsel_with_coded_value, |
| | | ) |
| | | from .abc import AbstractStreamWriter |
| | | from .client_exceptions import ( |
| | | ClientConnectionError, |
| | | ClientOSError, |
| | | ClientResponseError, |
| | | ContentTypeError, |
| | | InvalidURL, |
| | | ServerFingerprintMismatch, |
| | | ) |
| | | from .compression_utils import HAS_BROTLI, HAS_ZSTD |
| | | from .formdata import FormData |
| | | from .helpers import ( |
| | | _SENTINEL, |
| | | BaseTimerContext, |
| | | BasicAuth, |
| | | HeadersMixin, |
| | | TimerNoop, |
| | | noop, |
| | | reify, |
| | | sentinel, |
| | | set_exception, |
| | | set_result, |
| | | ) |
| | | from .http import ( |
| | | SERVER_SOFTWARE, |
| | | HttpVersion, |
| | | HttpVersion10, |
| | | HttpVersion11, |
| | | StreamWriter, |
| | | ) |
| | | from .streams import StreamReader |
| | | from .typedefs import ( |
| | | DEFAULT_JSON_DECODER, |
| | | JSONDecoder, |
| | | LooseCookies, |
| | | LooseHeaders, |
| | | Query, |
| | | RawHeaders, |
| | | ) |
| | | |
| | | if TYPE_CHECKING: |
| | | import ssl |
| | | from ssl import SSLContext |
| | | else: |
| | | try: |
| | | import ssl |
| | | from ssl import SSLContext |
| | | except ImportError: # pragma: no cover |
| | | ssl = None # type: ignore[assignment] |
| | | SSLContext = object # type: ignore[misc,assignment] |
| | | |
| | | |
| | | __all__ = ("ClientRequest", "ClientResponse", "RequestInfo", "Fingerprint") |
| | | |
| | | |
| | | if TYPE_CHECKING: |
| | | from .client import ClientSession |
| | | from .connector import Connection |
| | | from .tracing import Trace |
| | | |
| | | |
| | | _CONNECTION_CLOSED_EXCEPTION = ClientConnectionError("Connection closed") |
| | | _CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]") |
| | | json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json") |
| | | |
| | | |
| | | def _gen_default_accept_encoding() -> str: |
| | | encodings = [ |
| | | "gzip", |
| | | "deflate", |
| | | ] |
| | | if HAS_BROTLI: |
| | | encodings.append("br") |
| | | if HAS_ZSTD: |
| | | encodings.append("zstd") |
| | | return ", ".join(encodings) |
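`_gen_default_accept_encoding` advertises only the codecs whose decoders could actually be imported, so the default `Accept-Encoding` header never promises a compression the client cannot undo. The pattern in isolation, with explicit flags standing in for aiohttp's `HAS_BROTLI`/`HAS_ZSTD` feature detection:

```python
def gen_accept_encoding(has_brotli: bool, has_zstd: bool) -> str:
    # gzip and deflate are always available via the stdlib zlib module.
    encodings = ["gzip", "deflate"]
    if has_brotli:
        encodings.append("br")
    if has_zstd:
        encodings.append("zstd")
    return ", ".join(encodings)
```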
| | | |
| | | |
| | | @attr.s(auto_attribs=True, frozen=True, slots=True) |
| | | class ContentDisposition: |
| | | type: Optional[str] |
| | | parameters: "MappingProxyType[str, str]" |
| | | filename: Optional[str] |
| | | |
| | | |
| | | class _RequestInfo(NamedTuple): |
| | | url: URL |
| | | method: str |
| | | headers: "CIMultiDictProxy[str]" |
| | | real_url: URL |
| | | |
| | | |
| | | class RequestInfo(_RequestInfo): |
| | | |
| | | def __new__( |
| | | cls, |
| | | url: URL, |
| | | method: str, |
| | | headers: "CIMultiDictProxy[str]", |
| | | real_url: Union[URL, _SENTINEL] = sentinel, |
| | | ) -> "RequestInfo": |
| | | """Create a new RequestInfo instance. |
| | | |
| | | For backwards compatibility, the real_url parameter is optional. |
| | | """ |
| | | return tuple.__new__( |
| | | cls, (url, method, headers, url if real_url is sentinel else real_url) |
| | | ) |
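`RequestInfo.__new__` keeps the old three-argument call signature working by defaulting `real_url` to a sentinel object and substituting `url` when the caller omits it. The idiom, reduced to a simplified stand-in tuple (an explicit sentinel rather than `None`, so callers can still pass `None` deliberately):

```python
from typing import NamedTuple

_sentinel = object()


class Info(NamedTuple):
    url: str
    real_url: str


def make_info(url: str, real_url: object = _sentinel) -> Info:
    # Omitted real_url falls back to url; any explicit value is kept.
    return Info(url, url if real_url is _sentinel else real_url)  # type: ignore[arg-type]
```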
| | | |
| | | |
| | | class Fingerprint: |
| | | HASHFUNC_BY_DIGESTLEN = { |
| | | 16: md5, |
| | | 20: sha1, |
| | | 32: sha256, |
| | | } |
| | | |
| | | def __init__(self, fingerprint: bytes) -> None: |
| | | digestlen = len(fingerprint) |
| | | hashfunc = self.HASHFUNC_BY_DIGESTLEN.get(digestlen) |
| | | if not hashfunc: |
| | | raise ValueError("fingerprint has invalid length") |
| | | elif hashfunc is md5 or hashfunc is sha1: |
| | | raise ValueError("md5 and sha1 are insecure and not supported. Use sha256.") |
| | | self._hashfunc = hashfunc |
| | | self._fingerprint = fingerprint |
| | | |
| | | @property |
| | | def fingerprint(self) -> bytes: |
| | | return self._fingerprint |
| | | |
| | | def check(self, transport: asyncio.Transport) -> None: |
| | | if not transport.get_extra_info("sslcontext"): |
| | | return |
| | | sslobj = transport.get_extra_info("ssl_object") |
| | | cert = sslobj.getpeercert(binary_form=True) |
| | | got = self._hashfunc(cert).digest() |
| | | if got != self._fingerprint: |
| | | host, port, *_ = transport.get_extra_info("peername") |
| | | raise ServerFingerprintMismatch(self._fingerprint, got, host, port) |
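`Fingerprint` infers the hash function from the digest length (16 bytes would be MD5, 20 SHA-1, 32 SHA-256) and rejects the first two as insecure, so only 32-byte SHA-256 fingerprints are accepted in practice. Checking a certificate then reduces to hashing its DER bytes and comparing, sketched here without the TLS transport plumbing:

```python
from hashlib import sha256

# md5 (16 bytes) and sha1 (20 bytes) are deliberately absent: rejected as insecure.
HASHFUNC_BY_DIGESTLEN = {32: sha256}


def check_fingerprint(der_cert: bytes, expected: bytes) -> bool:
    hashfunc = HASHFUNC_BY_DIGESTLEN.get(len(expected))
    if hashfunc is None:
        raise ValueError("fingerprint has invalid length")
    return hashfunc(der_cert).digest() == expected
```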
| | | |
| | | |
| | | if ssl is not None: |
| | | SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None)) |
| | | else: # pragma: no cover |
| | | SSL_ALLOWED_TYPES = (bool, type(None)) |
| | | |
| | | |
| | | def _merge_ssl_params( |
| | | ssl: Union["SSLContext", bool, Fingerprint], |
| | | verify_ssl: Optional[bool], |
| | | ssl_context: Optional["SSLContext"], |
| | | fingerprint: Optional[bytes], |
| | | ) -> Union["SSLContext", bool, Fingerprint]: |
| | | if ssl is None: |
| | | ssl = True # Double check for backwards compatibility |
| | | if verify_ssl is not None and not verify_ssl: |
| | | warnings.warn( |
| | | "verify_ssl is deprecated, use ssl=False instead", |
| | | DeprecationWarning, |
| | | stacklevel=3, |
| | | ) |
| | | if ssl is not True: |
| | | raise ValueError( |
| | | "verify_ssl, ssl_context, fingerprint and ssl " |
| | | "parameters are mutually exclusive" |
| | | ) |
| | | else: |
| | | ssl = False |
| | | if ssl_context is not None: |
| | | warnings.warn( |
| | | "ssl_context is deprecated, use ssl=context instead", |
| | | DeprecationWarning, |
| | | stacklevel=3, |
| | | ) |
| | | if ssl is not True: |
| | | raise ValueError( |
| | | "verify_ssl, ssl_context, fingerprint and ssl " |
| | | "parameters are mutually exclusive" |
| | | ) |
| | | else: |
| | | ssl = ssl_context |
| | | if fingerprint is not None: |
| | | warnings.warn( |
| | | "fingerprint is deprecated, use ssl=Fingerprint(fingerprint) instead", |
| | | DeprecationWarning, |
| | | stacklevel=3, |
| | | ) |
| | | if ssl is not True: |
| | | raise ValueError( |
| | | "verify_ssl, ssl_context, fingerprint and ssl " |
| | | "parameters are mutually exclusive" |
| | | ) |
| | | else: |
| | | ssl = Fingerprint(fingerprint) |
| | | if not isinstance(ssl, SSL_ALLOWED_TYPES): |
| | | raise TypeError( |
| | | "ssl should be SSLContext, bool, Fingerprint or None, " |
| | | "got {!r} instead.".format(ssl) |
| | | ) |
| | | return ssl |
| | | |
| | | |
| | | _SSL_SCHEMES = frozenset(("https", "wss")) |
| | | |
| | | |
| | | # ConnectionKey is a NamedTuple because it is used as a key in a dict |
| | | # and a set in the connector. Since a NamedTuple is a tuple it uses |
| | | # the fast native tuple __hash__ and __eq__ implementation in CPython. |
| | | class ConnectionKey(NamedTuple): |
| | | # the key should contain an information about used proxy / TLS |
| | | # to prevent reusing wrong connections from a pool |
| | | host: str |
| | | port: Optional[int] |
| | | is_ssl: bool |
| | | ssl: Union[SSLContext, bool, Fingerprint] |
| | | proxy: Optional[URL] |
| | | proxy_auth: Optional[BasicAuth] |
| | | proxy_headers_hash: Optional[int] # hash(CIMultiDict) |
| | | |
| | | |
| | | def _is_expected_content_type( |
| | | response_content_type: str, expected_content_type: str |
| | | ) -> bool: |
| | | if expected_content_type == "application/json": |
| | | return json_re.match(response_content_type) is not None |
| | | return expected_content_type in response_content_type |
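`json_re` accepts both plain `application/json` and structured-suffix types such as `application/hal+json` (RFC 6839), while `_is_expected_content_type` falls back to substring matching for every other expected type. The regex can be exercised directly:

```python
import re

# Same pattern as the module-level json_re above.
json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json")


def is_json_content_type(content_type: str) -> bool:
    # re.match anchors at the start, so parameters after ";" don't interfere.
    return json_re.match(content_type) is not None
```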
| | | |
| | | |
| | | def _warn_if_unclosed_payload(payload: payload.Payload, stacklevel: int = 2) -> None: |
| | | """Warn if the payload is not closed. |
| | | |
| | | Callers must check that the body is a Payload before calling this method. |
| | | |
| | | Args: |
| | | payload: The payload to check |
| | | stacklevel: Stack level for the warning (default 2 for direct callers) |
| | | """ |
| | | if not payload.autoclose and not payload.consumed: |
| | | warnings.warn( |
| | | "The previous request body contains unclosed resources. " |
| | | "Use await request.update_body() instead of setting request.body " |
| | | "directly to properly close resources and avoid leaks.", |
| | | ResourceWarning, |
| | | stacklevel=stacklevel, |
| | | ) |
| | | |
| | | |
| | | class ClientResponse(HeadersMixin): |
| | | |
| | | # Some of these attributes are None when created, |
| | | # but will be set by the start() method. |
| | | # As the end user will likely never see the None values, we cheat the types below. |
| | | # from the Status-Line of the response |
| | | version: Optional[HttpVersion] = None # HTTP-Version |
| | | status: int = None # type: ignore[assignment] # Status-Code |
| | | reason: Optional[str] = None # Reason-Phrase |
| | | |
| | | content: StreamReader = None # type: ignore[assignment] # Payload stream |
| | | _body: Optional[bytes] = None |
| | | _headers: CIMultiDictProxy[str] = None # type: ignore[assignment] |
| | | _history: Tuple["ClientResponse", ...] = () |
| | | _raw_headers: RawHeaders = None # type: ignore[assignment] |
| | | |
| | | _connection: Optional["Connection"] = None # current connection |
| | | _cookies: Optional[SimpleCookie] = None |
| | | _raw_cookie_headers: Optional[Tuple[str, ...]] = None |
| | | _continue: Optional["asyncio.Future[bool]"] = None |
| | | _source_traceback: Optional[traceback.StackSummary] = None |
| | | _session: Optional["ClientSession"] = None |
| | | # set up by ClientRequest after ClientResponse object creation |
| | | # post-init stage allows to not change ctor signature |
| | | _closed = True # to allow __del__ for non-initialized properly response |
| | | _released = False |
| | | _in_context = False |
| | | |
| | | _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8" |
| | | |
| | | __writer: Optional["asyncio.Task[None]"] = None |
| | | |
| | | def __init__( |
| | | self, |
| | | method: str, |
| | | url: URL, |
| | | *, |
| | | writer: "Optional[asyncio.Task[None]]", |
| | | continue100: Optional["asyncio.Future[bool]"], |
| | | timer: BaseTimerContext, |
| | | request_info: RequestInfo, |
| | | traces: List["Trace"], |
| | | loop: asyncio.AbstractEventLoop, |
| | | session: "ClientSession", |
| | | ) -> None: |
| | | # URL forbids subclasses, so a simple type check is enough. |
| | | assert type(url) is URL |
| | | |
| | | self.method = method |
| | | |
| | | self._real_url = url |
| | | self._url = url.with_fragment(None) if url.raw_fragment else url |
| | | if writer is not None: |
| | | self._writer = writer |
| | | if continue100 is not None: |
| | | self._continue = continue100 |
| | | self._request_info = request_info |
| | | self._timer = timer if timer is not None else TimerNoop() |
| | | self._cache: Dict[str, Any] = {} |
| | | self._traces = traces |
| | | self._loop = loop |
| | | # Save reference to _resolve_charset, so that get_encoding() will still |
| | | # work after the response has finished reading the body. |
| | | # TODO: Fix session=None in tests (see ClientRequest.__init__). |
| | | if session is not None: |
| | | # store a reference to session #1985 |
| | | self._session = session |
| | | self._resolve_charset = session._resolve_charset |
| | | if loop.get_debug(): |
| | | self._source_traceback = traceback.extract_stack(sys._getframe(1)) |
| | | |
| | | def __reset_writer(self, _: object = None) -> None: |
| | | self.__writer = None |
| | | |
| | | @property |
| | | def _writer(self) -> Optional["asyncio.Task[None]"]: |
| | | """The writer task for streaming data. |
| | | |
| | | _writer is only provided for backwards compatibility |
| | | for subclasses that may need to access it. |
| | | """ |
| | | return self.__writer |
| | | |
| | | @_writer.setter |
| | | def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: |
| | | """Set the writer task for streaming data.""" |
| | | if self.__writer is not None: |
| | | self.__writer.remove_done_callback(self.__reset_writer) |
| | | self.__writer = writer |
| | | if writer is None: |
| | | return |
| | | if writer.done(): |
| | | # The writer is already done, so we can clear it immediately. |
| | | self.__writer = None |
| | | else: |
| | | writer.add_done_callback(self.__reset_writer) |
| | | |
| | | @property |
| | | def cookies(self) -> SimpleCookie: |
| | | if self._cookies is None: |
| | | if self._raw_cookie_headers is not None: |
| | | # Parse cookies for response.cookies (SimpleCookie for backward compatibility) |
| | | cookies = SimpleCookie() |
| | | # Use parse_set_cookie_headers for more lenient parsing that handles |
| | | # malformed cookies better than SimpleCookie.load |
| | | cookies.update(parse_set_cookie_headers(self._raw_cookie_headers)) |
| | | self._cookies = cookies |
| | | else: |
| | | self._cookies = SimpleCookie() |
| | | return self._cookies |
| | | |
| | | @cookies.setter |
| | | def cookies(self, cookies: SimpleCookie) -> None: |
| | | self._cookies = cookies |
| | | # Generate raw cookie headers from the SimpleCookie |
| | | if cookies: |
| | | self._raw_cookie_headers = tuple( |
| | | morsel.OutputString() for morsel in cookies.values() |
| | | ) |
| | | else: |
| | | self._raw_cookie_headers = None |
| | | |
| | | @reify |
| | | def url(self) -> URL: |
| | | return self._url |
| | | |
| | | @reify |
| | | def url_obj(self) -> URL: |
| | | warnings.warn("Deprecated, use .url #1654", DeprecationWarning, stacklevel=2) |
| | | return self._url |
| | | |
| | | @reify |
| | | def real_url(self) -> URL: |
| | | return self._real_url |
| | | |
| | | @reify |
| | | def host(self) -> str: |
| | | assert self._url.host is not None |
| | | return self._url.host |
| | | |
| | | @reify |
| | | def headers(self) -> "CIMultiDictProxy[str]": |
| | | return self._headers |
| | | |
| | | @reify |
| | | def raw_headers(self) -> RawHeaders: |
| | | return self._raw_headers |
| | | |
| | | @reify |
| | | def request_info(self) -> RequestInfo: |
| | | return self._request_info |
| | | |
| | | @reify |
| | | def content_disposition(self) -> Optional[ContentDisposition]: |
| | | raw = self._headers.get(hdrs.CONTENT_DISPOSITION) |
| | | if raw is None: |
| | | return None |
| | | disposition_type, params_dct = multipart.parse_content_disposition(raw) |
| | | params = MappingProxyType(params_dct) |
| | | filename = multipart.content_disposition_filename(params) |
| | | return ContentDisposition(disposition_type, params, filename) |
| | | |
| | | def __del__(self, _warnings: Any = warnings) -> None: |
| | | if self._closed: |
| | | return |
| | | |
| | | if self._connection is not None: |
| | | self._connection.release() |
| | | self._cleanup_writer() |
| | | |
| | | if self._loop.get_debug(): |
| | | kwargs = {"source": self} |
| | | _warnings.warn(f"Unclosed response {self!r}", ResourceWarning, **kwargs) |
| | | context = {"client_response": self, "message": "Unclosed response"} |
| | | if self._source_traceback: |
| | | context["source_traceback"] = self._source_traceback |
| | | self._loop.call_exception_handler(context) |
| | | |
| | | def __repr__(self) -> str: |
| | | out = io.StringIO() |
| | | ascii_encodable_url = str(self.url) |
| | | if self.reason: |
| | | ascii_encodable_reason = self.reason.encode( |
| | | "ascii", "backslashreplace" |
| | | ).decode("ascii") |
| | | else: |
| | | ascii_encodable_reason = "None" |
| | | print( |
| | | "<ClientResponse({}) [{} {}]>".format( |
| | | ascii_encodable_url, self.status, ascii_encodable_reason |
| | | ), |
| | | file=out, |
| | | ) |
| | | print(self.headers, file=out) |
| | | return out.getvalue() |
| | | |
| | | @property |
| | | def connection(self) -> Optional["Connection"]: |
| | | return self._connection |
| | | |
| | | @reify |
| | | def history(self) -> Tuple["ClientResponse", ...]: |
| | | """A sequence of responses, if redirects occurred.""" |
| | | return self._history |
| | | |
| | | @reify |
| | | def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]": |
| | | links_str = ", ".join(self.headers.getall("link", [])) |
| | | |
| | | if not links_str: |
| | | return MultiDictProxy(MultiDict()) |
| | | |
| | | links: MultiDict[MultiDictProxy[Union[str, URL]]] = MultiDict() |
| | | |
| | | for val in re.split(r",(?=\s*<)", links_str): |
| | | match = re.match(r"\s*<(.*)>(.*)", val) |
| | | if match is None: # pragma: no cover |
| | | # the check exists to suppress mypy error |
| | | continue |
| | | url, params_str = match.groups() |
| | | params = params_str.split(";")[1:] |
| | | |
| | | link: MultiDict[Union[str, URL]] = MultiDict() |
| | | |
| | | for param in params: |
| | | match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M) |
| | | if match is None: # pragma: no cover |
| | | # the check exists to suppress mypy error |
| | | continue |
| | | key, _, value, _ = match.groups() |
| | | |
| | | link.add(key, value) |
| | | |
| | | key = link.get("rel", url) |
| | | |
| | | link.add("url", self.url.join(URL(url))) |
| | | |
| | | links.add(str(key), MultiDictProxy(link)) |
| | | |
| | | return MultiDictProxy(links) |
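The `links` property splits the `Link` header on commas that precede a `<`, pulls out the target URI, then parses the `;`-separated parameters and indexes each link by its `rel` (falling back to the URI). A trimmed-down parser using the same regexes, with plain dicts standing in for aiohttp's multidicts and no URL joining:

```python
import re
from typing import Dict


def parse_links(links_str: str) -> Dict[str, Dict[str, str]]:
    links: Dict[str, Dict[str, str]] = {}
    # Split only on commas that start a new "<uri>" entry.
    for val in re.split(r",(?=\s*<)", links_str):
        match = re.match(r"\s*<(.*)>(.*)", val)
        if match is None:
            continue
        url, params_str = match.groups()
        link: Dict[str, str] = {"url": url}
        # Drop the empty leading element before the first ";".
        for param in params_str.split(";")[1:]:
            m = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M)
            if m is None:
                continue
            key, _, value, _ = m.groups()
            link[key] = value
        # Index by rel when present, falling back to the target URL.
        links[link.get("rel", url)] = link
    return links
```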
| | | |
| | | async def start(self, connection: "Connection") -> "ClientResponse": |
| | | """Start response processing.""" |
| | | self._closed = False |
| | | self._protocol = connection.protocol |
| | | self._connection = connection |
| | | |
| | | with self._timer: |
| | | while True: |
| | | # read response |
| | | try: |
| | | protocol = self._protocol |
| | | message, payload = await protocol.read() # type: ignore[union-attr] |
| | | except http.HttpProcessingError as exc: |
| | | raise ClientResponseError( |
| | | self.request_info, |
| | | self.history, |
| | | status=exc.code, |
| | | message=exc.message, |
| | | headers=exc.headers, |
| | | ) from exc |
| | | |
| | | if message.code < 100 or message.code > 199 or message.code == 101: |
| | | break |
| | | |
| | | if self._continue is not None: |
| | | set_result(self._continue, True) |
| | | self._continue = None |
| | | |
| | | # payload eof handler |
| | | payload.on_eof(self._response_eof) |
| | | |
| | | # response status |
| | | self.version = message.version |
| | | self.status = message.code |
| | | self.reason = message.reason |
| | | |
| | | # headers |
| | | self._headers = message.headers # type is CIMultiDictProxy |
| | | self._raw_headers = message.raw_headers # type is Tuple[bytes, bytes] |
| | | |
| | | # payload |
| | | self.content = payload |
| | | |
| | | # cookies |
| | | if cookie_hdrs := self.headers.getall(hdrs.SET_COOKIE, ()): |
| | | # Store raw cookie headers for CookieJar |
| | | self._raw_cookie_headers = tuple(cookie_hdrs) |
| | | return self |
| | | |
| | | def _response_eof(self) -> None: |
| | | if self._closed: |
| | | return |
| | | |
| | | # protocol could be None because connection could be detached |
| | | protocol = self._connection and self._connection.protocol |
| | | if protocol is not None and protocol.upgraded: |
| | | return |
| | | |
| | | self._closed = True |
| | | self._cleanup_writer() |
| | | self._release_connection() |
| | | |
| | | @property |
| | | def closed(self) -> bool: |
| | | return self._closed |
| | | |
| | | def close(self) -> None: |
| | | if not self._released: |
| | | self._notify_content() |
| | | |
| | | self._closed = True |
| | | if self._loop is None or self._loop.is_closed(): |
| | | return |
| | | |
| | | self._cleanup_writer() |
| | | if self._connection is not None: |
| | | self._connection.close() |
| | | self._connection = None |
| | | |
| | | def release(self) -> Any: |
| | | if not self._released: |
| | | self._notify_content() |
| | | |
| | | self._closed = True |
| | | |
| | | self._cleanup_writer() |
| | | self._release_connection() |
| | | return noop() |
| | | |
| | | @property |
| | | def ok(self) -> bool: |
| | | """Returns ``True`` if ``status`` is less than ``400``, ``False`` if not. |
| | | |
| | | This is **not** a check for ``200 OK`` but a check that the response |
| | | status is under 400. |
| | | """ |
| | | return 400 > self.status |
| | | |
| | | def raise_for_status(self) -> None: |
| | | if not self.ok: |
| | |             # reason should always be non-None for a started response
| | | assert self.reason is not None |
| | | |
| | | # If we're in a context we can rely on __aexit__() to release as the |
| | | # exception propagates. |
| | | if not self._in_context: |
| | | self.release() |
| | | |
| | | raise ClientResponseError( |
| | | self.request_info, |
| | | self.history, |
| | | status=self.status, |
| | | message=self.reason, |
| | | headers=self.headers, |
| | | ) |
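| | | 
| | |     # Illustrative usage sketch for raise_for_status() above (not part of
| | |     # the library's code):
| | |     #   async with session.get(url) as resp:
| | |     #       resp.raise_for_status()  # ClientResponseError on 4xx/5xx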
| | | |
| | | def _release_connection(self) -> None: |
| | | if self._connection is not None: |
| | | if self.__writer is None: |
| | | self._connection.release() |
| | | self._connection = None |
| | | else: |
| | | self.__writer.add_done_callback(lambda f: self._release_connection()) |
| | | |
| | | async def _wait_released(self) -> None: |
| | | if self.__writer is not None: |
| | | try: |
| | | await self.__writer |
| | | except asyncio.CancelledError: |
| | | if ( |
| | | sys.version_info >= (3, 11) |
| | | and (task := asyncio.current_task()) |
| | | and task.cancelling() |
| | | ): |
| | | raise |
| | | self._release_connection() |
| | | |
| | | def _cleanup_writer(self) -> None: |
| | | if self.__writer is not None: |
| | | self.__writer.cancel() |
| | | self._session = None |
| | | |
| | | def _notify_content(self) -> None: |
| | | content = self.content |
| | | if content and content.exception() is None: |
| | | set_exception(content, _CONNECTION_CLOSED_EXCEPTION) |
| | | self._released = True |
| | | |
| | | async def wait_for_close(self) -> None: |
| | | if self.__writer is not None: |
| | | try: |
| | | await self.__writer |
| | | except asyncio.CancelledError: |
| | | if ( |
| | | sys.version_info >= (3, 11) |
| | | and (task := asyncio.current_task()) |
| | | and task.cancelling() |
| | | ): |
| | | raise |
| | | self.release() |
| | | |
| | | async def read(self) -> bytes: |
| | | """Read response payload.""" |
| | | if self._body is None: |
| | | try: |
| | | self._body = await self.content.read() |
| | | for trace in self._traces: |
| | | await trace.send_response_chunk_received( |
| | | self.method, self.url, self._body |
| | | ) |
| | | except BaseException: |
| | | self.close() |
| | | raise |
| | | elif self._released: # Response explicitly released |
| | | raise ClientConnectionError("Connection closed") |
| | | |
| | | protocol = self._connection and self._connection.protocol |
| | | if protocol is None or not protocol.upgraded: |
| | | await self._wait_released() # Underlying connection released |
| | | return self._body |
| | | |
| | | def get_encoding(self) -> str: |
| | | ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() |
| | | mimetype = helpers.parse_mimetype(ctype) |
| | | |
| | | encoding = mimetype.parameters.get("charset") |
| | | if encoding: |
| | | with contextlib.suppress(LookupError, ValueError): |
| | | return codecs.lookup(encoding).name |
| | | |
| | | if mimetype.type == "application" and ( |
| | | mimetype.subtype == "json" or mimetype.subtype == "rdap" |
| | | ): |
| | | # RFC 7159 states that the default encoding is UTF-8. |
| | | # RFC 7483 defines application/rdap+json |
| | | return "utf-8" |
| | | |
| | | if self._body is None: |
| | | raise RuntimeError( |
| | | "Cannot compute fallback encoding of a not yet read body" |
| | | ) |
| | | |
| | | return self._resolve_charset(self, self._body) |
| | | |
| | | async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str: |
| | | """Read response payload and decode.""" |
| | | if self._body is None: |
| | | await self.read() |
| | | |
| | | if encoding is None: |
| | | encoding = self.get_encoding() |
| | | |
| | | return self._body.decode(encoding, errors=errors) # type: ignore[union-attr] |
| | | |
| | | async def json( |
| | | self, |
| | | *, |
| | | encoding: Optional[str] = None, |
| | | loads: JSONDecoder = DEFAULT_JSON_DECODER, |
| | | content_type: Optional[str] = "application/json", |
| | | ) -> Any: |
| | |         """Read and decode JSON response."""
| | | if self._body is None: |
| | | await self.read() |
| | | |
| | | if content_type: |
| | | ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() |
| | | if not _is_expected_content_type(ctype, content_type): |
| | | raise ContentTypeError( |
| | | self.request_info, |
| | | self.history, |
| | | status=self.status, |
| | | message=( |
| | | "Attempt to decode JSON with unexpected mimetype: %s" % ctype |
| | | ), |
| | | headers=self.headers, |
| | | ) |
| | | |
| | | stripped = self._body.strip() # type: ignore[union-attr] |
| | | if not stripped: |
| | | return None |
| | | |
| | | if encoding is None: |
| | | encoding = self.get_encoding() |
| | | |
| | | return loads(stripped.decode(encoding)) |
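| | | 
| | |     # Illustrative usage sketch for json() above (not part of the
| | |     # library's code): servers that mislabel JSON payloads (for example
| | |     # as text/plain) can still be decoded by disabling the content-type
| | |     # check:
| | |     #   data = await resp.json(content_type=None)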
| | | |
| | | async def __aenter__(self) -> "ClientResponse": |
| | | self._in_context = True |
| | | return self |
| | | |
| | | async def __aexit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_val: Optional[BaseException], |
| | | exc_tb: Optional[TracebackType], |
| | | ) -> None: |
| | | self._in_context = False |
| | | # similar to _RequestContextManager, we do not need to check |
| | | # for exceptions, response object can close connection |
| | | # if state is broken |
| | | self.release() |
| | | await self.wait_for_close() |
| | | |
| | | |
| | | class ClientRequest: |
| | | GET_METHODS = { |
| | | hdrs.METH_GET, |
| | | hdrs.METH_HEAD, |
| | | hdrs.METH_OPTIONS, |
| | | hdrs.METH_TRACE, |
| | | } |
| | | POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT} |
| | | ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE}) |
| | | |
| | | DEFAULT_HEADERS = { |
| | | hdrs.ACCEPT: "*/*", |
| | | hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), |
| | | } |
| | | |
| | | # Type of body depends on PAYLOAD_REGISTRY, which is dynamic. |
| | | _body: Union[None, payload.Payload] = None |
| | | auth = None |
| | | response = None |
| | | |
| | | __writer: Optional["asyncio.Task[None]"] = None # async task for streaming data |
| | | |
| | | # These class defaults help create_autospec() work correctly. |
| | | # If autospec is improved in future, maybe these can be removed. |
| | | url = URL() |
| | | method = "GET" |
| | | |
| | | _continue = None # waiter future for '100 Continue' response |
| | | |
| | | _skip_auto_headers: Optional["CIMultiDict[None]"] = None |
| | | |
| | |     # N.B.
| | |     # Adding a __del__ method that closes self._writer doesn't make sense,
| | |     # because _writer is an instance method and thus keeps a reference to
| | |     # self. The finalizer will not be called until the writer has finished.
| | | |
| | | def __init__( |
| | | self, |
| | | method: str, |
| | | url: URL, |
| | | *, |
| | | params: Query = None, |
| | | headers: Optional[LooseHeaders] = None, |
| | | skip_auto_headers: Optional[Iterable[str]] = None, |
| | | data: Any = None, |
| | | cookies: Optional[LooseCookies] = None, |
| | | auth: Optional[BasicAuth] = None, |
| | | version: http.HttpVersion = http.HttpVersion11, |
| | | compress: Union[str, bool, None] = None, |
| | | chunked: Optional[bool] = None, |
| | | expect100: bool = False, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | response_class: Optional[Type["ClientResponse"]] = None, |
| | | proxy: Optional[URL] = None, |
| | | proxy_auth: Optional[BasicAuth] = None, |
| | | timer: Optional[BaseTimerContext] = None, |
| | | session: Optional["ClientSession"] = None, |
| | | ssl: Union[SSLContext, bool, Fingerprint] = True, |
| | | proxy_headers: Optional[LooseHeaders] = None, |
| | | traces: Optional[List["Trace"]] = None, |
| | | trust_env: bool = False, |
| | | server_hostname: Optional[str] = None, |
| | | ): |
| | | if loop is None: |
| | | loop = asyncio.get_event_loop() |
| | | if match := _CONTAINS_CONTROL_CHAR_RE.search(method): |
| | | raise ValueError( |
| | | f"Method cannot contain non-token characters {method!r} " |
| | | f"(found at least {match.group()!r})" |
| | | ) |
| | | # URL forbids subclasses, so a simple type check is enough. |
| | | assert type(url) is URL, url |
| | | if proxy is not None: |
| | | assert type(proxy) is URL, proxy |
| | | # FIXME: session is None in tests only, need to fix tests |
| | | # assert session is not None |
| | | if TYPE_CHECKING: |
| | | assert session is not None |
| | | self._session = session |
| | | if params: |
| | | url = url.extend_query(params) |
| | | self.original_url = url |
| | | self.url = url.with_fragment(None) if url.raw_fragment else url |
| | | self.method = method.upper() |
| | | self.chunked = chunked |
| | | self.compress = compress |
| | | self.loop = loop |
| | | self.length = None |
| | | if response_class is None: |
| | | real_response_class = ClientResponse |
| | | else: |
| | | real_response_class = response_class |
| | | self.response_class: Type[ClientResponse] = real_response_class |
| | | self._timer = timer if timer is not None else TimerNoop() |
| | | self._ssl = ssl if ssl is not None else True |
| | | self.server_hostname = server_hostname |
| | | |
| | | if loop.get_debug(): |
| | | self._source_traceback = traceback.extract_stack(sys._getframe(1)) |
| | | |
| | | self.update_version(version) |
| | | self.update_host(url) |
| | | self.update_headers(headers) |
| | | self.update_auto_headers(skip_auto_headers) |
| | | self.update_cookies(cookies) |
| | | self.update_content_encoding(data) |
| | | self.update_auth(auth, trust_env) |
| | | self.update_proxy(proxy, proxy_auth, proxy_headers) |
| | | |
| | | self.update_body_from_data(data) |
| | | if data is not None or self.method not in self.GET_METHODS: |
| | | self.update_transfer_encoding() |
| | | self.update_expect_continue(expect100) |
| | | self._traces = [] if traces is None else traces |
| | | |
| | | def __reset_writer(self, _: object = None) -> None: |
| | | self.__writer = None |
| | | |
| | | def _get_content_length(self) -> Optional[int]: |
| | | """Extract and validate Content-Length header value. |
| | | |
| | | Returns parsed Content-Length value or None if not set. |
| | | Raises ValueError if header exists but cannot be parsed as an integer. |
| | | """ |
| | | if hdrs.CONTENT_LENGTH not in self.headers: |
| | | return None |
| | | |
| | | content_length_hdr = self.headers[hdrs.CONTENT_LENGTH] |
| | | try: |
| | | return int(content_length_hdr) |
| | | except ValueError: |
| | | raise ValueError( |
| | | f"Invalid Content-Length header: {content_length_hdr}" |
| | | ) from None |
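| | | 
| | |     # Illustrative sketch for _get_content_length() above (not part of
| | |     # the library's code): a request built with
| | |     # headers={"Content-Length": "4"} yields 4 here, while a non-integer
| | |     # value such as "abc" raises ValueError.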
| | | |
| | | @property |
| | | def skip_auto_headers(self) -> CIMultiDict[None]: |
| | | return self._skip_auto_headers or CIMultiDict() |
| | | |
| | | @property |
| | | def _writer(self) -> Optional["asyncio.Task[None]"]: |
| | | return self.__writer |
| | | |
| | | @_writer.setter |
| | | def _writer(self, writer: "asyncio.Task[None]") -> None: |
| | | if self.__writer is not None: |
| | | self.__writer.remove_done_callback(self.__reset_writer) |
| | | self.__writer = writer |
| | | writer.add_done_callback(self.__reset_writer) |
| | | |
| | | def is_ssl(self) -> bool: |
| | | return self.url.scheme in _SSL_SCHEMES |
| | | |
| | | @property |
| | | def ssl(self) -> Union["SSLContext", bool, Fingerprint]: |
| | | return self._ssl |
| | | |
| | | @property |
| | | def connection_key(self) -> ConnectionKey: |
| | | if proxy_headers := self.proxy_headers: |
| | | h: Optional[int] = hash(tuple(proxy_headers.items())) |
| | | else: |
| | | h = None |
| | | url = self.url |
| | | return tuple.__new__( |
| | | ConnectionKey, |
| | | ( |
| | | url.raw_host or "", |
| | | url.port, |
| | | url.scheme in _SSL_SCHEMES, |
| | | self._ssl, |
| | | self.proxy, |
| | | self.proxy_auth, |
| | | h, |
| | | ), |
| | | ) |
| | | |
| | | @property |
| | | def host(self) -> str: |
| | | ret = self.url.raw_host |
| | | assert ret is not None |
| | | return ret |
| | | |
| | | @property |
| | | def port(self) -> Optional[int]: |
| | | return self.url.port |
| | | |
| | | @property |
| | | def body(self) -> Union[payload.Payload, Literal[b""]]: |
| | | """Request body.""" |
| | | # empty body is represented as bytes for backwards compatibility |
| | | return self._body or b"" |
| | | |
| | | @body.setter |
| | | def body(self, value: Any) -> None: |
| | | """Set request body with warning for non-autoclose payloads. |
| | | |
| | | WARNING: This setter must be called from within an event loop and is not |
| | | thread-safe. Setting body outside of an event loop may raise RuntimeError |
| | | when closing file-based payloads. |
| | | |
| | | DEPRECATED: Direct assignment to body is deprecated and will be removed |
| | | in a future version. Use await update_body() instead for proper resource |
| | | management. |
| | | """ |
| | | # Close existing payload if present |
| | | if self._body is not None: |
| | | # Warn if the payload needs manual closing |
| | | # stacklevel=3: user code -> body setter -> _warn_if_unclosed_payload |
| | | _warn_if_unclosed_payload(self._body, stacklevel=3) |
| | | # NOTE: In the future, when we remove sync close support, |
| | | # this setter will need to be removed and only the async |
| | | # update_body() method will be available. For now, we call |
| | | # _close() for backwards compatibility. |
| | | self._body._close() |
| | | self._update_body(value) |
| | | |
| | | @property |
| | | def request_info(self) -> RequestInfo: |
| | | headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers) |
| | | # These are created on every request, so we use a NamedTuple |
| | | # for performance reasons. We don't use the RequestInfo.__new__ |
| | | # method because it has a different signature which is provided |
| | | # for backwards compatibility only. |
| | | return tuple.__new__( |
| | | RequestInfo, (self.url, self.method, headers, self.original_url) |
| | | ) |
| | | |
| | | @property |
| | | def session(self) -> "ClientSession": |
| | | """Return the ClientSession instance. |
| | | |
| | | This property provides access to the ClientSession that initiated |
| | | this request, allowing middleware to make additional requests |
| | | using the same session. |
| | | """ |
| | | return self._session |
| | | |
| | | def update_host(self, url: URL) -> None: |
| | | """Update destination host, port and connection type (ssl).""" |
| | | # get host/port |
| | | if not url.raw_host: |
| | | raise InvalidURL(url) |
| | | |
| | | # basic auth info |
| | | if url.raw_user or url.raw_password: |
| | | self.auth = helpers.BasicAuth(url.user or "", url.password or "") |
| | | |
| | | def update_version(self, version: Union[http.HttpVersion, str]) -> None: |
| | | """Convert request version to two elements tuple. |
| | | |
| | |         parse HTTP version '1.1' => (1, 1)
| | | """ |
| | | if isinstance(version, str): |
| | | v = [part.strip() for part in version.split(".", 1)] |
| | | try: |
| | | version = http.HttpVersion(int(v[0]), int(v[1])) |
| | | except ValueError: |
| | | raise ValueError( |
| | | f"Can not parse http version number: {version}" |
| | | ) from None |
| | | self.version = version |
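| | | 
| | |     # Illustrative sketch for update_version() above (not part of the
| | |     # library's code): a string such as "1.1" is normalized to
| | |     # http.HttpVersion(1, 1), so passing version="1.1" and
| | |     # version=http.HttpVersion11 end up equivalent here.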
| | | |
| | | def update_headers(self, headers: Optional[LooseHeaders]) -> None: |
| | | """Update request headers.""" |
| | | self.headers: CIMultiDict[str] = CIMultiDict() |
| | | |
| | | # Build the host header |
| | | host = self.url.host_port_subcomponent |
| | | |
| | |         # host_port_subcomponent is None when the URL is relative,
| | |         # but we know we do not have a relative URL here.
| | | assert host is not None |
| | | self.headers[hdrs.HOST] = host |
| | | |
| | | if not headers: |
| | | return |
| | | |
| | | if isinstance(headers, (dict, MultiDictProxy, MultiDict)): |
| | | headers = headers.items() |
| | | |
| | | for key, value in headers: # type: ignore[misc] |
| | | # A special case for Host header |
| | | if key in hdrs.HOST_ALL: |
| | | self.headers[key] = value |
| | | else: |
| | | self.headers.add(key, value) |
| | | |
| | | def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None: |
| | | if skip_auto_headers is not None: |
| | | self._skip_auto_headers = CIMultiDict( |
| | | (hdr, None) for hdr in sorted(skip_auto_headers) |
| | | ) |
| | | used_headers = self.headers.copy() |
| | | used_headers.extend(self._skip_auto_headers) # type: ignore[arg-type] |
| | | else: |
| | | # Fast path when there are no headers to skip |
| | | # which is the most common case. |
| | | used_headers = self.headers |
| | | |
| | | for hdr, val in self.DEFAULT_HEADERS.items(): |
| | | if hdr not in used_headers: |
| | | self.headers[hdr] = val |
| | | |
| | | if hdrs.USER_AGENT not in used_headers: |
| | | self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE |
| | | |
| | | def update_cookies(self, cookies: Optional[LooseCookies]) -> None: |
| | | """Update request cookies header.""" |
| | | if not cookies: |
| | | return |
| | | |
| | | c = SimpleCookie() |
| | | if hdrs.COOKIE in self.headers: |
| | | # parse_cookie_header for RFC 6265 compliant Cookie header parsing |
| | | c.update(parse_cookie_header(self.headers.get(hdrs.COOKIE, ""))) |
| | | del self.headers[hdrs.COOKIE] |
| | | |
| | | if isinstance(cookies, Mapping): |
| | | iter_cookies = cookies.items() |
| | | else: |
| | | iter_cookies = cookies # type: ignore[assignment] |
| | | for name, value in iter_cookies: |
| | | if isinstance(value, Morsel): |
| | | # Use helper to preserve coded_value exactly as sent by server |
| | | c[name] = preserve_morsel_with_coded_value(value) |
| | | else: |
| | | c[name] = value # type: ignore[assignment] |
| | | |
| | | self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip() |
| | | |
| | | def update_content_encoding(self, data: Any) -> None: |
| | | """Set request content encoding.""" |
| | | if not data: |
| | | # Don't compress an empty body. |
| | | self.compress = None |
| | | return |
| | | |
| | | if self.headers.get(hdrs.CONTENT_ENCODING): |
| | | if self.compress: |
| | | raise ValueError( |
| | | "compress can not be set if Content-Encoding header is set" |
| | | ) |
| | | elif self.compress: |
| | | if not isinstance(self.compress, str): |
| | | self.compress = "deflate" |
| | | self.headers[hdrs.CONTENT_ENCODING] = self.compress |
| | | self.chunked = True # enable chunked, no need to deal with length |
| | | |
| | | def update_transfer_encoding(self) -> None: |
| | | """Analyze transfer-encoding header.""" |
| | | te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower() |
| | | |
| | | if "chunked" in te: |
| | | if self.chunked: |
| | | raise ValueError( |
| | | "chunked can not be set " |
| | | 'if "Transfer-Encoding: chunked" header is set' |
| | | ) |
| | | |
| | | elif self.chunked: |
| | | if hdrs.CONTENT_LENGTH in self.headers: |
| | | raise ValueError( |
| | | "chunked can not be set if Content-Length header is set" |
| | | ) |
| | | |
| | | self.headers[hdrs.TRANSFER_ENCODING] = "chunked" |
| | | |
| | | def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None: |
| | | """Set basic auth.""" |
| | | if auth is None: |
| | | auth = self.auth |
| | | if auth is None: |
| | | return |
| | | |
| | | if not isinstance(auth, helpers.BasicAuth): |
| | | raise TypeError("BasicAuth() tuple is required instead") |
| | | |
| | | self.headers[hdrs.AUTHORIZATION] = auth.encode() |
| | | |
| | | def update_body_from_data(self, body: Any, _stacklevel: int = 3) -> None: |
| | | """Update request body from data.""" |
| | | if self._body is not None: |
| | | _warn_if_unclosed_payload(self._body, stacklevel=_stacklevel) |
| | | |
| | | if body is None: |
| | | self._body = None |
| | | # Set Content-Length to 0 when body is None for methods that expect a body |
| | | if ( |
| | | self.method not in self.GET_METHODS |
| | | and not self.chunked |
| | | and hdrs.CONTENT_LENGTH not in self.headers |
| | | ): |
| | | self.headers[hdrs.CONTENT_LENGTH] = "0" |
| | | return |
| | | |
| | | # FormData |
| | | maybe_payload = body() if isinstance(body, FormData) else body |
| | | |
| | | try: |
| | | body_payload = payload.PAYLOAD_REGISTRY.get(maybe_payload, disposition=None) |
| | | except payload.LookupError: |
| | | body_payload = FormData(maybe_payload)() # type: ignore[arg-type] |
| | | |
| | | self._body = body_payload |
| | | # enable chunked encoding if needed |
| | | if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers: |
| | | if (size := body_payload.size) is not None: |
| | | self.headers[hdrs.CONTENT_LENGTH] = str(size) |
| | | else: |
| | | self.chunked = True |
| | | |
| | | # copy payload headers |
| | | assert body_payload.headers |
| | | headers = self.headers |
| | | skip_headers = self._skip_auto_headers |
| | | for key, value in body_payload.headers.items(): |
| | | if key in headers or (skip_headers is not None and key in skip_headers): |
| | | continue |
| | | headers[key] = value |
| | | |
| | | def _update_body(self, body: Any) -> None: |
| | |         """Update request body after it has already been set."""
| | | # Remove existing Content-Length header since body is changing |
| | | if hdrs.CONTENT_LENGTH in self.headers: |
| | | del self.headers[hdrs.CONTENT_LENGTH] |
| | | |
| | | # Remove existing Transfer-Encoding header to avoid conflicts |
| | | if self.chunked and hdrs.TRANSFER_ENCODING in self.headers: |
| | | del self.headers[hdrs.TRANSFER_ENCODING] |
| | | |
| | | # Now update the body using the existing method |
| | | # Called from _update_body, add 1 to stacklevel from caller |
| | | self.update_body_from_data(body, _stacklevel=4) |
| | | |
| | | # Update transfer encoding headers if needed (same logic as __init__) |
| | | if body is not None or self.method not in self.GET_METHODS: |
| | | self.update_transfer_encoding() |
| | | |
| | | async def update_body(self, body: Any) -> None: |
| | | """ |
| | | Update request body and close previous payload if needed. |
| | | |
| | | This method safely updates the request body by first closing any existing |
| | | payload to prevent resource leaks, then setting the new body. |
| | | |
| | | IMPORTANT: Always use this method instead of setting request.body directly. |
| | | Direct assignment to request.body will leak resources if the previous body |
| | | contains file handles, streams, or other resources that need cleanup. |
| | | |
| | | Args: |
| | | body: The new body content. Can be: |
| | | - bytes/bytearray: Raw binary data |
| | | - str: Text data (will be encoded using charset from Content-Type) |
| | | - FormData: Form data that will be encoded as multipart/form-data |
| | | - Payload: A pre-configured payload object |
| | | - AsyncIterable: An async iterable of bytes chunks |
| | | - File-like object: Will be read and sent as binary data |
| | | - None: Clears the body |
| | | |
| | | Usage: |
| | | # CORRECT: Use update_body |
| | | await request.update_body(b"new request data") |
| | | |
| | | # WRONG: Don't set body directly |
| | | # request.body = b"new request data" # This will leak resources! |
| | | |
| | | # Update with form data |
| | | form_data = FormData() |
| | | form_data.add_field('field', 'value') |
| | | await request.update_body(form_data) |
| | | |
| | | # Clear body |
| | | await request.update_body(None) |
| | | |
| | | Note: |
| | | This method is async because it may need to close file handles or |
| | | other resources associated with the previous payload. Always await |
| | | this method to ensure proper cleanup. |
| | | |
| | | Warning: |
| | | Setting request.body directly is highly discouraged and can lead to: |
| | | - Resource leaks (unclosed file handles, streams) |
| | | - Memory leaks (unreleased buffers) |
| | | - Unexpected behavior with streaming payloads |
| | | |
| | | It is not recommended to change the payload type in middleware. If the |
| | | body was already set (e.g., as bytes), it's best to keep the same type |
| | | rather than converting it (e.g., to str) as this may result in unexpected |
| | | behavior. |
| | | |
| | | See Also: |
| | | - update_body_from_data: Synchronous body update without cleanup |
| | | - body property: Direct body access (STRONGLY DISCOURAGED) |
| | | |
| | | """ |
| | | # Close existing payload if it exists and needs closing |
| | | if self._body is not None: |
| | | await self._body.close() |
| | | self._update_body(body) |
| | | |
| | | def update_expect_continue(self, expect: bool = False) -> None: |
| | | if expect: |
| | | self.headers[hdrs.EXPECT] = "100-continue" |
| | | elif ( |
| | | hdrs.EXPECT in self.headers |
| | | and self.headers[hdrs.EXPECT].lower() == "100-continue" |
| | | ): |
| | | expect = True |
| | | |
| | | if expect: |
| | | self._continue = self.loop.create_future() |
| | | |
| | | def update_proxy( |
| | | self, |
| | | proxy: Optional[URL], |
| | | proxy_auth: Optional[BasicAuth], |
| | | proxy_headers: Optional[LooseHeaders], |
| | | ) -> None: |
| | | self.proxy = proxy |
| | | if proxy is None: |
| | | self.proxy_auth = None |
| | | self.proxy_headers = None |
| | | return |
| | | |
| | | if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth): |
| | | raise ValueError("proxy_auth must be None or BasicAuth() tuple") |
| | | self.proxy_auth = proxy_auth |
| | | |
| | | if proxy_headers is not None and not isinstance( |
| | | proxy_headers, (MultiDict, MultiDictProxy) |
| | | ): |
| | | proxy_headers = CIMultiDict(proxy_headers) |
| | | self.proxy_headers = proxy_headers |
| | | |
| | | async def write_bytes( |
| | | self, |
| | | writer: AbstractStreamWriter, |
| | | conn: "Connection", |
| | | content_length: Optional[int] = None, |
| | | ) -> None: |
| | | """ |
| | | Write the request body to the connection stream. |
| | | |
| | | This method handles writing different types of request bodies: |
| | | 1. Payload objects (using their specialized write_with_length method) |
| | | 2. Bytes/bytearray objects |
| | | 3. Iterable body content |
| | | |
| | | Args: |
| | | writer: The stream writer to write the body to |
| | | conn: The connection being used for this request |
| | | content_length: Optional maximum number of bytes to write from the body |
| | | (None means write the entire body) |
| | | |
| | | The method properly handles: |
| | | - Waiting for 100-Continue responses if required |
| | | - Content length constraints for chunked encoding |
| | | - Error handling for network issues, cancellation, and other exceptions |
| | | - Signaling EOF and timeout management |
| | | |
| | | Raises: |
| | | ClientOSError: When there's an OS-level error writing the body |
| | | ClientConnectionError: When there's a general connection error |
| | | asyncio.CancelledError: When the operation is cancelled |
| | | |
| | | """ |
| | | # 100 response |
| | | if self._continue is not None: |
| | | # Force headers to be sent before waiting for 100-continue |
| | | writer.send_headers() |
| | | await writer.drain() |
| | | await self._continue |
| | | |
| | | protocol = conn.protocol |
| | | assert protocol is not None |
| | | try: |
| | |             # This should be a rare case, but self._body can be set to
| | |             # None while the task is being started or while we wait above
| | |             # for the 100-continue response. The more likely case is that
| | |             # we have an empty payload but 100-continue is still expected.
| | | if self._body is not None: |
| | | await self._body.write_with_length(writer, content_length) |
| | | except OSError as underlying_exc: |
| | | reraised_exc = underlying_exc |
| | | |
| | | # Distinguish between timeout and other OS errors for better error reporting |
| | | exc_is_not_timeout = underlying_exc.errno is not None or not isinstance( |
| | | underlying_exc, asyncio.TimeoutError |
| | | ) |
| | | if exc_is_not_timeout: |
| | | reraised_exc = ClientOSError( |
| | | underlying_exc.errno, |
| | | f"Can not write request body for {self.url !s}", |
| | | ) |
| | | |
| | | set_exception(protocol, reraised_exc, underlying_exc) |
| | | except asyncio.CancelledError: |
| | | # Body hasn't been fully sent, so connection can't be reused |
| | | conn.close() |
| | | raise |
| | | except Exception as underlying_exc: |
| | | set_exception( |
| | | protocol, |
| | | ClientConnectionError( |
| | | "Failed to send bytes into the underlying connection " |
| | | f"{conn !s}: {underlying_exc!r}", |
| | | ), |
| | | underlying_exc, |
| | | ) |
| | | else: |
| | | # Successfully wrote the body, signal EOF and start response timeout |
| | | await writer.write_eof() |
| | | protocol.start_timeout() |
| | | |
| | | async def send(self, conn: "Connection") -> "ClientResponse": |
| | | # Specify request target: |
| | | # - CONNECT request must send authority form URI |
| | | # - not CONNECT proxy must send absolute form URI |
| | | # - most common is origin form URI |
| | | if self.method == hdrs.METH_CONNECT: |
| | | connect_host = self.url.host_subcomponent |
| | | assert connect_host is not None |
| | | path = f"{connect_host}:{self.url.port}" |
| | | elif self.proxy and not self.is_ssl(): |
| | | path = str(self.url) |
| | | else: |
| | | path = self.url.raw_path_qs |
| | | |
| | | protocol = conn.protocol |
| | | assert protocol is not None |
| | | writer = StreamWriter( |
| | | protocol, |
| | | self.loop, |
| | | on_chunk_sent=( |
| | | functools.partial(self._on_chunk_request_sent, self.method, self.url) |
| | | if self._traces |
| | | else None |
| | | ), |
| | | on_headers_sent=( |
| | | functools.partial(self._on_headers_request_sent, self.method, self.url) |
| | | if self._traces |
| | | else None |
| | | ), |
| | | ) |
| | | |
| | | if self.compress: |
| | | writer.enable_compression(self.compress) # type: ignore[arg-type] |
| | | |
| | | if self.chunked is not None: |
| | | writer.enable_chunking() |
| | | |
| | | # set default content-type |
| | | if ( |
| | | self.method in self.POST_METHODS |
| | | and ( |
| | | self._skip_auto_headers is None |
| | | or hdrs.CONTENT_TYPE not in self._skip_auto_headers |
| | | ) |
| | | and hdrs.CONTENT_TYPE not in self.headers |
| | | ): |
| | | self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" |
| | | |
| | | v = self.version |
| | | if hdrs.CONNECTION not in self.headers: |
| | | if conn._connector.force_close: |
| | | if v == HttpVersion11: |
| | | self.headers[hdrs.CONNECTION] = "close" |
| | | elif v == HttpVersion10: |
| | | self.headers[hdrs.CONNECTION] = "keep-alive" |
| | | |
| | | # status + headers |
| | | status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}" |
| | | |
| | | # Buffer headers for potential coalescing with body |
| | | await writer.write_headers(status_line, self.headers) |
| | | |
| | | task: Optional["asyncio.Task[None]"] |
| | | if self._body or self._continue is not None or protocol.writing_paused: |
| | | coro = self.write_bytes(writer, conn, self._get_content_length()) |
| | | if sys.version_info >= (3, 12): |
| | | # Optimization for Python 3.12+: try to write the
| | | # bytes immediately to avoid having to schedule
| | | # the task on the event loop.
| | | task = asyncio.Task(coro, loop=self.loop, eager_start=True) |
| | | else: |
| | | task = self.loop.create_task(coro) |
| | | if task.done(): |
| | | task = None |
| | | else: |
| | | self._writer = task |
| | | else: |
| | | # We have nothing to write because |
| | | # - there is no body |
| | | # - the protocol does not have writing paused |
| | | # - we are not waiting for a 100-continue response |
| | | protocol.start_timeout() |
| | | writer.set_eof() |
| | | task = None |
| | | response_class = self.response_class |
| | | assert response_class is not None |
| | | self.response = response_class( |
| | | self.method, |
| | | self.original_url, |
| | | writer=task, |
| | | continue100=self._continue, |
| | | timer=self._timer, |
| | | request_info=self.request_info, |
| | | traces=self._traces, |
| | | loop=self.loop, |
| | | session=self._session, |
| | | ) |
| | | return self.response |
| | | |
| | | async def close(self) -> None: |
| | | if self.__writer is not None: |
| | | try: |
| | | await self.__writer |
| | | except asyncio.CancelledError: |
| | | if ( |
| | | sys.version_info >= (3, 11) |
| | | and (task := asyncio.current_task()) |
| | | and task.cancelling() |
| | | ): |
| | | raise |
| | | |
| | | def terminate(self) -> None: |
| | | if self.__writer is not None: |
| | | if not self.loop.is_closed(): |
| | | self.__writer.cancel() |
| | | self.__writer.remove_done_callback(self.__reset_writer) |
| | | self.__writer = None |
| | | |
| | | async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None: |
| | | for trace in self._traces: |
| | | await trace.send_request_chunk_sent(method, url, chunk) |
| | | |
| | | async def _on_headers_request_sent( |
| | | self, method: str, url: URL, headers: "CIMultiDict[str]" |
| | | ) -> None: |
| | | for trace in self._traces: |
| | | await trace.send_request_headers(method, url, headers) |
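The `send()` method above only keeps a reference to the body-writing task when it has not already finished; on Python 3.12+ it starts the task eagerly so a coroutine that never suspends completes without an event-loop round trip. A minimal stdlib-only sketch of that pattern (names such as `write_body` and `send` are illustrative stand-ins, not aiohttp API):

```python
import asyncio
import sys


async def write_body() -> str:
    # Stand-in for the real body-writing coroutine; it completes
    # without ever suspending, so an eagerly started task can
    # finish before the event loop schedules anything.
    return "written"


async def send() -> str:
    coro = write_body()
    if sys.version_info >= (3, 12):
        # eager_start runs the coroutine synchronously up to its
        # first await; a coroutine that never suspends is already
        # done when the constructor returns.
        task = asyncio.Task(coro, loop=asyncio.get_running_loop(), eager_start=True)
    else:
        task = asyncio.get_running_loop().create_task(coro)
    # Mirror of the listing: a task that is already done needs no
    # writer bookkeeping, so the caller would drop its reference.
    writer = None if task.done() else task
    assert writer is None or not task.done()
    return await task


result = asyncio.run(send())
```

On 3.12+ the assertion confirms the eager task finished synchronously; on older versions the task is tracked until awaited, matching the two branches in the listing.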
| New file |
| | |
| | | """WebSocket client for asyncio.""" |
| | | |
| | | import asyncio |
| | | import sys |
| | | from types import TracebackType |
| | | from typing import Any, Optional, Type, cast |
| | | |
| | | import attr |
| | | |
| | | from ._websocket.reader import WebSocketDataQueue |
| | | from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError |
| | | from .client_reqrep import ClientResponse |
| | | from .helpers import calculate_timeout_when, set_result |
| | | from .http import ( |
| | | WS_CLOSED_MESSAGE, |
| | | WS_CLOSING_MESSAGE, |
| | | WebSocketError, |
| | | WSCloseCode, |
| | | WSMessage, |
| | | WSMsgType, |
| | | ) |
| | | from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter |
| | | from .streams import EofStream |
| | | from .typedefs import ( |
| | | DEFAULT_JSON_DECODER, |
| | | DEFAULT_JSON_ENCODER, |
| | | JSONDecoder, |
| | | JSONEncoder, |
| | | ) |
| | | |
| | | if sys.version_info >= (3, 11): |
| | | import asyncio as async_timeout |
| | | else: |
| | | import async_timeout |
| | | |
| | | |
| | | @attr.s(frozen=True, slots=True) |
| | | class ClientWSTimeout: |
| | | ws_receive = attr.ib(type=Optional[float], default=None) |
| | | ws_close = attr.ib(type=Optional[float], default=None) |
| | | |
| | | |
| | | DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_receive=None, ws_close=10.0) |
| | | |
| | | |
| | | class ClientWebSocketResponse: |
| | | def __init__( |
| | | self, |
| | | reader: WebSocketDataQueue, |
| | | writer: WebSocketWriter, |
| | | protocol: Optional[str], |
| | | response: ClientResponse, |
| | | timeout: ClientWSTimeout, |
| | | autoclose: bool, |
| | | autoping: bool, |
| | | loop: asyncio.AbstractEventLoop, |
| | | *, |
| | | heartbeat: Optional[float] = None, |
| | | compress: int = 0, |
| | | client_notakeover: bool = False, |
| | | ) -> None: |
| | | self._response = response |
| | | self._conn = response.connection |
| | | |
| | | self._writer = writer |
| | | self._reader = reader |
| | | self._protocol = protocol |
| | | self._closed = False |
| | | self._closing = False |
| | | self._close_code: Optional[int] = None |
| | | self._timeout = timeout |
| | | self._autoclose = autoclose |
| | | self._autoping = autoping |
| | | self._heartbeat = heartbeat |
| | | self._heartbeat_cb: Optional[asyncio.TimerHandle] = None |
| | | self._heartbeat_when: float = 0.0 |
| | | if heartbeat is not None: |
| | | self._pong_heartbeat = heartbeat / 2.0 |
| | | self._pong_response_cb: Optional[asyncio.TimerHandle] = None |
| | | self._loop = loop |
| | | self._waiting: bool = False |
| | | self._close_wait: Optional[asyncio.Future[None]] = None |
| | | self._exception: Optional[BaseException] = None |
| | | self._compress = compress |
| | | self._client_notakeover = client_notakeover |
| | | self._ping_task: Optional[asyncio.Task[None]] = None |
| | | |
| | | self._reset_heartbeat() |
| | | |
| | | def _cancel_heartbeat(self) -> None: |
| | | self._cancel_pong_response_cb() |
| | | if self._heartbeat_cb is not None: |
| | | self._heartbeat_cb.cancel() |
| | | self._heartbeat_cb = None |
| | | if self._ping_task is not None: |
| | | self._ping_task.cancel() |
| | | self._ping_task = None |
| | | |
| | | def _cancel_pong_response_cb(self) -> None: |
| | | if self._pong_response_cb is not None: |
| | | self._pong_response_cb.cancel() |
| | | self._pong_response_cb = None |
| | | |
| | | def _reset_heartbeat(self) -> None: |
| | | if self._heartbeat is None: |
| | | return |
| | | self._cancel_pong_response_cb() |
| | | loop = self._loop |
| | | assert loop is not None |
| | | conn = self._conn |
| | | timeout_ceil_threshold = ( |
| | | conn._connector._timeout_ceil_threshold if conn is not None else 5 |
| | | ) |
| | | now = loop.time() |
| | | when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) |
| | | self._heartbeat_when = when |
| | | if self._heartbeat_cb is None: |
| | | # We do not cancel the previous heartbeat_cb here because |
| | | # it generates a significant amount of TimerHandle churn |
| | | # which causes asyncio to rebuild the heap frequently. |
| | | # Instead _send_heartbeat() will reschedule the next |
| | | # heartbeat if it fires too early. |
| | | self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) |
| | | |
| | | def _send_heartbeat(self) -> None: |
| | | self._heartbeat_cb = None |
| | | loop = self._loop |
| | | now = loop.time() |
| | | if now < self._heartbeat_when: |
| | | # Heartbeat fired too early, reschedule |
| | | self._heartbeat_cb = loop.call_at( |
| | | self._heartbeat_when, self._send_heartbeat |
| | | ) |
| | | return |
| | | |
| | | conn = self._conn |
| | | timeout_ceil_threshold = ( |
| | | conn._connector._timeout_ceil_threshold if conn is not None else 5 |
| | | ) |
| | | when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) |
| | | self._cancel_pong_response_cb() |
| | | self._pong_response_cb = loop.call_at(when, self._pong_not_received) |
| | | |
| | | coro = self._writer.send_frame(b"", WSMsgType.PING) |
| | | if sys.version_info >= (3, 12): |
| | | # Optimization for Python 3.12+: try to send the ping
| | | # immediately to avoid having to schedule
| | | # the task on the event loop.
| | | ping_task = asyncio.Task(coro, loop=loop, eager_start=True) |
| | | else: |
| | | ping_task = loop.create_task(coro) |
| | | |
| | | if not ping_task.done(): |
| | | self._ping_task = ping_task |
| | | ping_task.add_done_callback(self._ping_task_done) |
| | | else: |
| | | self._ping_task_done(ping_task) |
| | | |
| | | def _ping_task_done(self, task: "asyncio.Task[None]") -> None: |
| | | """Callback for when the ping task completes.""" |
| | | if not task.cancelled() and (exc := task.exception()): |
| | | self._handle_ping_pong_exception(exc) |
| | | self._ping_task = None |
| | | |
| | | def _pong_not_received(self) -> None: |
| | | self._handle_ping_pong_exception( |
| | | ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds") |
| | | ) |
| | | |
| | | def _handle_ping_pong_exception(self, exc: BaseException) -> None: |
| | | """Handle exceptions raised during ping/pong processing.""" |
| | | if self._closed: |
| | | return |
| | | self._set_closed() |
| | | self._close_code = WSCloseCode.ABNORMAL_CLOSURE |
| | | self._exception = exc |
| | | self._response.close() |
| | | if self._waiting and not self._closing: |
| | | self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0) |
| | | |
| | | def _set_closed(self) -> None: |
| | | """Set the connection to closed. |
| | | |
| | | Cancel any heartbeat timers and set the closed flag. |
| | | """ |
| | | self._closed = True |
| | | self._cancel_heartbeat() |
| | | |
| | | def _set_closing(self) -> None: |
| | | """Set the connection to closing. |
| | | |
| | | Cancel any heartbeat timers and set the closing flag. |
| | | """ |
| | | self._closing = True |
| | | self._cancel_heartbeat() |
| | | |
| | | @property |
| | | def closed(self) -> bool: |
| | | return self._closed |
| | | |
| | | @property |
| | | def close_code(self) -> Optional[int]: |
| | | return self._close_code |
| | | |
| | | @property |
| | | def protocol(self) -> Optional[str]: |
| | | return self._protocol |
| | | |
| | | @property |
| | | def compress(self) -> int: |
| | | return self._compress |
| | | |
| | | @property |
| | | def client_notakeover(self) -> bool: |
| | | return self._client_notakeover |
| | | |
| | | def get_extra_info(self, name: str, default: Any = None) -> Any: |
| | | """Return extra info from the connection transport."""
| | | conn = self._response.connection |
| | | if conn is None: |
| | | return default |
| | | transport = conn.transport |
| | | if transport is None: |
| | | return default |
| | | return transport.get_extra_info(name, default) |
| | | |
| | | def exception(self) -> Optional[BaseException]: |
| | | return self._exception |
| | | |
| | | async def ping(self, message: bytes = b"") -> None: |
| | | await self._writer.send_frame(message, WSMsgType.PING) |
| | | |
| | | async def pong(self, message: bytes = b"") -> None: |
| | | await self._writer.send_frame(message, WSMsgType.PONG) |
| | | |
| | | async def send_frame( |
| | | self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None |
| | | ) -> None: |
| | | """Send a frame over the websocket.""" |
| | | await self._writer.send_frame(message, opcode, compress) |
| | | |
| | | async def send_str(self, data: str, compress: Optional[int] = None) -> None: |
| | | if not isinstance(data, str): |
| | | raise TypeError("data argument must be str (%r)" % type(data)) |
| | | await self._writer.send_frame( |
| | | data.encode("utf-8"), WSMsgType.TEXT, compress=compress |
| | | ) |
| | | |
| | | async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: |
| | | if not isinstance(data, (bytes, bytearray, memoryview)): |
| | | raise TypeError("data argument must be byte-ish (%r)" % type(data)) |
| | | await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress) |
| | | |
| | | async def send_json( |
| | | self, |
| | | data: Any, |
| | | compress: Optional[int] = None, |
| | | *, |
| | | dumps: JSONEncoder = DEFAULT_JSON_ENCODER, |
| | | ) -> None: |
| | | await self.send_str(dumps(data), compress=compress) |
| | | |
| | | async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool: |
| | | # We need to break the `receive()` cycle first because
| | | # `close()` may be called from a different task.
| | | if self._waiting and not self._closing: |
| | | assert self._loop is not None |
| | | self._close_wait = self._loop.create_future() |
| | | self._set_closing() |
| | | self._reader.feed_data(WS_CLOSING_MESSAGE, 0) |
| | | await self._close_wait |
| | | |
| | | if self._closed: |
| | | return False |
| | | |
| | | self._set_closed() |
| | | try: |
| | | await self._writer.close(code, message) |
| | | except asyncio.CancelledError: |
| | | self._close_code = WSCloseCode.ABNORMAL_CLOSURE |
| | | self._response.close() |
| | | raise |
| | | except Exception as exc: |
| | | self._close_code = WSCloseCode.ABNORMAL_CLOSURE |
| | | self._exception = exc |
| | | self._response.close() |
| | | return True |
| | | |
| | | if self._close_code: |
| | | self._response.close() |
| | | return True |
| | | |
| | | while True: |
| | | try: |
| | | async with async_timeout.timeout(self._timeout.ws_close): |
| | | msg = await self._reader.read() |
| | | except asyncio.CancelledError: |
| | | self._close_code = WSCloseCode.ABNORMAL_CLOSURE |
| | | self._response.close() |
| | | raise |
| | | except Exception as exc: |
| | | self._close_code = WSCloseCode.ABNORMAL_CLOSURE |
| | | self._exception = exc |
| | | self._response.close() |
| | | return True |
| | | |
| | | if msg.type is WSMsgType.CLOSE: |
| | | self._close_code = msg.data |
| | | self._response.close() |
| | | return True |
| | | |
| | | async def receive(self, timeout: Optional[float] = None) -> WSMessage: |
| | | receive_timeout = timeout or self._timeout.ws_receive |
| | | |
| | | while True: |
| | | if self._waiting: |
| | | raise RuntimeError("Concurrent call to receive() is not allowed") |
| | | |
| | | if self._closed: |
| | | return WS_CLOSED_MESSAGE |
| | | elif self._closing: |
| | | await self.close() |
| | | return WS_CLOSED_MESSAGE |
| | | |
| | | try: |
| | | self._waiting = True |
| | | try: |
| | | if receive_timeout: |
| | | # Entering the context manager and creating |
| | | # Timeout() object can take almost 50% of the |
| | | # run time in this loop so we avoid it if |
| | | # there is no read timeout. |
| | | async with async_timeout.timeout(receive_timeout): |
| | | msg = await self._reader.read() |
| | | else: |
| | | msg = await self._reader.read() |
| | | self._reset_heartbeat() |
| | | finally: |
| | | self._waiting = False |
| | | if self._close_wait: |
| | | set_result(self._close_wait, None) |
| | | except (asyncio.CancelledError, asyncio.TimeoutError): |
| | | self._close_code = WSCloseCode.ABNORMAL_CLOSURE |
| | | raise |
| | | except EofStream: |
| | | self._close_code = WSCloseCode.OK |
| | | await self.close() |
| | | return WSMessage(WSMsgType.CLOSED, None, None) |
| | | except ClientError: |
| | | # Likely ServerDisconnectedError when connection is lost |
| | | self._set_closed() |
| | | self._close_code = WSCloseCode.ABNORMAL_CLOSURE |
| | | return WS_CLOSED_MESSAGE |
| | | except WebSocketError as exc: |
| | | self._close_code = exc.code |
| | | await self.close(code=exc.code) |
| | | return WSMessage(WSMsgType.ERROR, exc, None) |
| | | except Exception as exc: |
| | | self._exception = exc |
| | | self._set_closing() |
| | | self._close_code = WSCloseCode.ABNORMAL_CLOSURE |
| | | await self.close() |
| | | return WSMessage(WSMsgType.ERROR, exc, None) |
| | | |
| | | if msg.type not in _INTERNAL_RECEIVE_TYPES: |
| | | # If it's not a close/closing/ping/pong message
| | | # we can return it immediately
| | | return msg |
| | | |
| | | if msg.type is WSMsgType.CLOSE: |
| | | self._set_closing() |
| | | self._close_code = msg.data |
| | | if not self._closed and self._autoclose: |
| | | await self.close() |
| | | elif msg.type is WSMsgType.CLOSING: |
| | | self._set_closing() |
| | | elif msg.type is WSMsgType.PING and self._autoping: |
| | | await self.pong(msg.data) |
| | | continue |
| | | elif msg.type is WSMsgType.PONG and self._autoping: |
| | | continue |
| | | |
| | | return msg |
| | | |
| | | async def receive_str(self, *, timeout: Optional[float] = None) -> str: |
| | | msg = await self.receive(timeout) |
| | | if msg.type is not WSMsgType.TEXT: |
| | | raise WSMessageTypeError( |
| | | f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT" |
| | | ) |
| | | return cast(str, msg.data) |
| | | |
| | | async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: |
| | | msg = await self.receive(timeout) |
| | | if msg.type is not WSMsgType.BINARY: |
| | | raise WSMessageTypeError( |
| | | f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY" |
| | | ) |
| | | return cast(bytes, msg.data) |
| | | |
| | | async def receive_json( |
| | | self, |
| | | *, |
| | | loads: JSONDecoder = DEFAULT_JSON_DECODER, |
| | | timeout: Optional[float] = None, |
| | | ) -> Any: |
| | | data = await self.receive_str(timeout=timeout) |
| | | return loads(data) |
| | | |
| | | def __aiter__(self) -> "ClientWebSocketResponse": |
| | | return self |
| | | |
| | | async def __anext__(self) -> WSMessage: |
| | | msg = await self.receive() |
| | | if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): |
| | | raise StopAsyncIteration |
| | | return msg |
| | | |
| | | async def __aenter__(self) -> "ClientWebSocketResponse": |
| | | return self |
| | | |
| | | async def __aexit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_val: Optional[BaseException], |
| | | exc_tb: Optional[TracebackType], |
| | | ) -> None: |
| | | await self.close() |
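The `receive()` loop above skips the timeout context manager entirely when no receive timeout is configured, because constructing the timeout object accounts for a large share of the loop's cost. A rough stdlib-only sketch of that fast path, substituting `asyncio.wait_for` and an `asyncio.Queue` for the `async_timeout` context and WebSocket reader used in the listing (`read_message` is a hypothetical helper, not aiohttp API):

```python
import asyncio
from typing import Any, Optional


async def read_message(queue: "asyncio.Queue[Any]", timeout: Optional[float]) -> Any:
    # Mirror of the receive() fast path: only pay for timeout
    # machinery when a timeout was actually configured.
    if timeout:
        return await asyncio.wait_for(queue.get(), timeout)
    return await queue.get()


async def main() -> list:
    queue: "asyncio.Queue[str]" = asyncio.Queue()
    await queue.put("hello")
    first = await read_message(queue, None)   # no timeout object created
    await queue.put("world")
    second = await read_message(queue, 1.0)   # bounded wait
    return [first, second]


results = asyncio.run(main())
```

The branch shape is the point: both paths read the same way, but the common no-timeout case allocates nothing extra per message.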
| New file |
| | |
| | | import asyncio |
| | | import sys |
| | | import zlib |
| | | from abc import ABC, abstractmethod |
| | | from concurrent.futures import Executor |
| | | from typing import Any, Final, Optional, Protocol, TypedDict, cast |
| | | |
| | | if sys.version_info >= (3, 12): |
| | | from collections.abc import Buffer |
| | | else: |
| | | from typing import Union |
| | | |
| | | Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] |
| | | |
| | | try: |
| | | try: |
| | | import brotlicffi as brotli |
| | | except ImportError: |
| | | import brotli |
| | | |
| | | HAS_BROTLI = True |
| | | except ImportError: # pragma: no cover |
| | | HAS_BROTLI = False |
| | | |
| | | try: |
| | | if sys.version_info >= (3, 14): |
| | | from compression.zstd import ZstdDecompressor # noqa: I900 |
| | | else: # TODO(PY314): Remove mentions of backports.zstd across codebase |
| | | from backports.zstd import ZstdDecompressor |
| | | |
| | | HAS_ZSTD = True |
| | | except ImportError: |
| | | HAS_ZSTD = False |
| | | |
| | | |
| | | MAX_SYNC_CHUNK_SIZE = 4096 |
| | | DEFAULT_MAX_DECOMPRESS_SIZE = 2**25 # 32MiB |
| | | |
| | | # Unlimited decompression constants - different libraries use different conventions |
| | | ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited |
| | | ZSTD_MAX_LENGTH_UNLIMITED = -1 # zstd uses -1 to mean unlimited |
| | | |
| | | |
| | | class ZLibCompressObjProtocol(Protocol): |
| | | def compress(self, data: Buffer) -> bytes: ... |
| | | def flush(self, mode: int = ..., /) -> bytes: ... |
| | | |
| | | |
| | | class ZLibDecompressObjProtocol(Protocol): |
| | | def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ... |
| | | def flush(self, length: int = ..., /) -> bytes: ... |
| | | |
| | | @property |
| | | def eof(self) -> bool: ... |
| | | |
| | | |
| | | class ZLibBackendProtocol(Protocol): |
| | | MAX_WBITS: int |
| | | Z_FULL_FLUSH: int |
| | | Z_SYNC_FLUSH: int |
| | | Z_BEST_SPEED: int |
| | | Z_FINISH: int |
| | | |
| | | def compressobj( |
| | | self, |
| | | level: int = ..., |
| | | method: int = ..., |
| | | wbits: int = ..., |
| | | memLevel: int = ..., |
| | | strategy: int = ..., |
| | | zdict: Optional[Buffer] = ..., |
| | | ) -> ZLibCompressObjProtocol: ... |
| | | def decompressobj( |
| | | self, wbits: int = ..., zdict: Buffer = ... |
| | | ) -> ZLibDecompressObjProtocol: ... |
| | | |
| | | def compress( |
| | | self, data: Buffer, /, level: int = ..., wbits: int = ... |
| | | ) -> bytes: ... |
| | | def decompress( |
| | | self, data: Buffer, /, wbits: int = ..., bufsize: int = ... |
| | | ) -> bytes: ... |
| | | |
| | | |
| | | class CompressObjArgs(TypedDict, total=False): |
| | | wbits: int |
| | | strategy: int |
| | | level: int |
| | | |
| | | |
| | | class ZLibBackendWrapper: |
| | | def __init__(self, _zlib_backend: ZLibBackendProtocol): |
| | | self._zlib_backend: ZLibBackendProtocol = _zlib_backend |
| | | |
| | | @property |
| | | def name(self) -> str: |
| | | return getattr(self._zlib_backend, "__name__", "undefined") |
| | | |
| | | @property |
| | | def MAX_WBITS(self) -> int: |
| | | return self._zlib_backend.MAX_WBITS |
| | | |
| | | @property |
| | | def Z_FULL_FLUSH(self) -> int: |
| | | return self._zlib_backend.Z_FULL_FLUSH |
| | | |
| | | @property |
| | | def Z_SYNC_FLUSH(self) -> int: |
| | | return self._zlib_backend.Z_SYNC_FLUSH |
| | | |
| | | @property |
| | | def Z_BEST_SPEED(self) -> int: |
| | | return self._zlib_backend.Z_BEST_SPEED |
| | | |
| | | @property |
| | | def Z_FINISH(self) -> int: |
| | | return self._zlib_backend.Z_FINISH |
| | | |
| | | def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol: |
| | | return self._zlib_backend.compressobj(*args, **kwargs) |
| | | |
| | | def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol: |
| | | return self._zlib_backend.decompressobj(*args, **kwargs) |
| | | |
| | | def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes: |
| | | return self._zlib_backend.compress(data, *args, **kwargs) |
| | | |
| | | def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes: |
| | | return self._zlib_backend.decompress(data, *args, **kwargs) |
| | | |
| | | # Everything not explicitly listed in the Protocol we just pass through |
| | | def __getattr__(self, attrname: str) -> Any: |
| | | return getattr(self._zlib_backend, attrname) |
| | | |
| | | |
| | | ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib) |
| | | |
| | | |
| | | def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None: |
| | | ZLibBackend._zlib_backend = new_zlib_backend |
| | | |
| | | |
| | | def encoding_to_mode( |
| | | encoding: Optional[str] = None, |
| | | suppress_deflate_header: bool = False, |
| | | ) -> int: |
| | | if encoding == "gzip": |
| | | return 16 + ZLibBackend.MAX_WBITS |
| | | |
| | | return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS |
| | | |
| | | |
| | | class DecompressionBaseHandler(ABC): |
| | | def __init__( |
| | | self, |
| | | executor: Optional[Executor] = None, |
| | | max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, |
| | | ): |
| | | """Base class for decompression handlers.""" |
| | | self._executor = executor |
| | | self._max_sync_chunk_size = max_sync_chunk_size |
| | | |
| | | @abstractmethod |
| | | def decompress_sync( |
| | | self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED |
| | | ) -> bytes: |
| | | """Decompress the given data.""" |
| | | |
| | | async def decompress( |
| | | self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED |
| | | ) -> bytes: |
| | | """Decompress the given data.""" |
| | | if ( |
| | | self._max_sync_chunk_size is not None |
| | | and len(data) > self._max_sync_chunk_size |
| | | ): |
| | | return await asyncio.get_event_loop().run_in_executor( |
| | | self._executor, self.decompress_sync, data, max_length |
| | | ) |
| | | return self.decompress_sync(data, max_length) |
| | | |
| | | |
| | | class ZLibCompressor: |
| | | def __init__( |
| | | self, |
| | | encoding: Optional[str] = None, |
| | | suppress_deflate_header: bool = False, |
| | | level: Optional[int] = None, |
| | | wbits: Optional[int] = None, |
| | | strategy: Optional[int] = None, |
| | | executor: Optional[Executor] = None, |
| | | max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, |
| | | ): |
| | | self._executor = executor |
| | | self._max_sync_chunk_size = max_sync_chunk_size |
| | | self._mode = ( |
| | | encoding_to_mode(encoding, suppress_deflate_header) |
| | | if wbits is None |
| | | else wbits |
| | | ) |
| | | self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend) |
| | | |
| | | kwargs: CompressObjArgs = {} |
| | | kwargs["wbits"] = self._mode |
| | | if strategy is not None: |
| | | kwargs["strategy"] = strategy |
| | | if level is not None: |
| | | kwargs["level"] = level |
| | | self._compressor = self._zlib_backend.compressobj(**kwargs) |
| | | |
| | | def compress_sync(self, data: bytes) -> bytes: |
| | | return self._compressor.compress(data) |
| | | |
| | | async def compress(self, data: bytes) -> bytes: |
| | | """Compress the data and return the compressed bytes.
| | | 
| | | Note that flush() must be called after the last call to compress().
| | | 
| | | If the data size is larger than max_sync_chunk_size, the compression
| | | will be done in the executor. Otherwise, the compression will be done
| | | in the event loop.
| | | |
| | | **WARNING: This method is NOT cancellation-safe when used with flush().** |
| | | If this operation is cancelled, the compressor state may be corrupted. |
| | | The connection MUST be closed after cancellation to avoid data corruption |
| | | in subsequent compress operations. |
| | | |
| | | For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap |
| | | compress() + flush() + send operations in a shield and lock to ensure atomicity. |
| | | """ |
| | | # For large payloads, offload compression to executor to avoid blocking event loop |
| | | should_use_executor = ( |
| | | self._max_sync_chunk_size is not None |
| | | and len(data) > self._max_sync_chunk_size |
| | | ) |
| | | if should_use_executor: |
| | | return await asyncio.get_running_loop().run_in_executor( |
| | | self._executor, self._compressor.compress, data |
| | | ) |
| | | return self.compress_sync(data) |
| | | |
| | | def flush(self, mode: Optional[int] = None) -> bytes: |
| | | """Flush the compressor synchronously. |
| | | |
| | | **WARNING: This method is NOT cancellation-safe when called after compress().** |
| | | The flush() operation accesses shared compressor state. If compress() was |
| | | cancelled, calling flush() may result in corrupted data. The connection MUST |
| | | be closed after compress() cancellation. |
| | | |
| | | For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap |
| | | compress() + flush() + send operations in a shield and lock to ensure atomicity. |
| | | """ |
| | | return self._compressor.flush( |
| | | mode if mode is not None else self._zlib_backend.Z_FINISH |
| | | ) |
| | | |
| | | |
| | | class ZLibDecompressor(DecompressionBaseHandler): |
| | | def __init__( |
| | | self, |
| | | encoding: Optional[str] = None, |
| | | suppress_deflate_header: bool = False, |
| | | executor: Optional[Executor] = None, |
| | | max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, |
| | | ): |
| | | super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) |
| | | self._mode = encoding_to_mode(encoding, suppress_deflate_header) |
| | | self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend) |
| | | self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode) |
| | | |
| | | def decompress_sync( |
| | | self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED |
| | | ) -> bytes: |
| | | return self._decompressor.decompress(data, max_length) |
| | | |
| | | def flush(self, length: int = 0) -> bytes: |
| | | return ( |
| | | self._decompressor.flush(length) |
| | | if length > 0 |
| | | else self._decompressor.flush() |
| | | ) |
| | | |
| | | @property |
| | | def eof(self) -> bool: |
| | | return self._decompressor.eof |
| | | |
| | | |
| | | class BrotliDecompressor(DecompressionBaseHandler): |
| | | # Supports both 'brotlipy' and 'Brotli' packages |
| | | # since they share an import name. The top branches |
| | | # are for 'brotlipy' and bottom branches for 'Brotli' |
| | | def __init__( |
| | | self, |
| | | executor: Optional[Executor] = None, |
| | | max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, |
| | | ) -> None: |
| | | """Decompress data using the Brotli library.""" |
| | | if not HAS_BROTLI: |
| | | raise RuntimeError( |
| | | "Brotli decompression is not available. "
| | | "Please install the `Brotli` module"
| | | ) |
| | | self._obj = brotli.Decompressor() |
| | | super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) |
| | | |
| | | def decompress_sync( |
| | | self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED |
| | | ) -> bytes: |
| | | """Decompress the given data.""" |
| | | if hasattr(self._obj, "decompress"): |
| | | return cast(bytes, self._obj.decompress(data, max_length)) |
| | | return cast(bytes, self._obj.process(data, max_length)) |
| | | |
| | | def flush(self) -> bytes: |
| | | """Flush the decompressor.""" |
| | | if hasattr(self._obj, "flush"): |
| | | return cast(bytes, self._obj.flush()) |
| | | return b"" |
| | | |
| | | |
| | | class ZSTDDecompressor(DecompressionBaseHandler): |
| | | def __init__( |
| | | self, |
| | | executor: Optional[Executor] = None, |
| | | max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, |
| | | ) -> None: |
| | | if not HAS_ZSTD: |
| | | raise RuntimeError( |
| | | "zstd decompression is not available. "
| | | "Please install the `backports.zstd` module"
| | | ) |
| | | self._obj = ZstdDecompressor() |
| | | super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) |
| | | |
| | | def decompress_sync( |
| | | self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED |
| | | ) -> bytes: |
| | | # zstd uses -1 for unlimited, while zlib uses 0 for unlimited |
| | | # Convert the zlib convention (0=unlimited) to zstd convention (-1=unlimited) |
| | | zstd_max_length = ( |
| | | ZSTD_MAX_LENGTH_UNLIMITED |
| | | if max_length == ZLIB_MAX_LENGTH_UNLIMITED |
| | | else max_length |
| | | ) |
| | | return self._obj.decompress(data, zstd_max_length) |
| | | |
| | | def flush(self) -> bytes: |
| | | return b"" |
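The `encoding_to_mode()` helper above encodes the zlib window-bits conventions: `16 + MAX_WBITS` requests gzip framing, `-MAX_WBITS` a raw deflate stream with no zlib header, and plain `MAX_WBITS` a standard zlib stream. A self-contained round trip against the stdlib `zlib` backend (the helper is re-derived here for illustration, without the pluggable-backend wrapper):

```python
import zlib


def encoding_to_mode(encoding=None, suppress_deflate_header=False):
    # Same mapping as the listing, specialized to the stdlib backend.
    if encoding == "gzip":
        return 16 + zlib.MAX_WBITS
    return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS


payload = b"hello " * 100

# gzip round trip: the compressed stream carries a gzip header.
gz = zlib.compressobj(wbits=encoding_to_mode("gzip"))
gz_data = gz.compress(payload) + gz.flush()
assert gz_data[:2] == b"\x1f\x8b"  # gzip magic bytes
inflated = zlib.decompressobj(wbits=encoding_to_mode("gzip")).decompress(gz_data)
assert inflated == payload

# Raw deflate round trip: no zlib header, the framing used by
# WebSocket permessage-deflate.
raw = zlib.compressobj(wbits=encoding_to_mode(suppress_deflate_header=True))
raw_data = raw.compress(payload) + raw.flush()
deflated = zlib.decompressobj(
    wbits=encoding_to_mode(suppress_deflate_header=True)
).decompress(raw_data)
assert deflated == payload
```

Because the compressor and decompressor must agree on the wbits value, mixing conventions (for example decoding a raw deflate stream with `MAX_WBITS`) raises `zlib.error`, which is why the helper centralizes the mapping.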
| New file |
| | |
| | | import asyncio |
| | | import functools |
| | | import random |
| | | import socket |
| | | import sys |
| | | import traceback |
| | | import warnings |
| | | from collections import OrderedDict, defaultdict, deque |
| | | from contextlib import suppress |
| | | from http import HTTPStatus |
| | | from itertools import chain, cycle, islice |
| | | from time import monotonic |
| | | from types import TracebackType |
| | | from typing import ( |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Awaitable, |
| | | Callable, |
| | | DefaultDict, |
| | | Deque, |
| | | Dict, |
| | | Iterator, |
| | | List, |
| | | Literal, |
| | | Optional, |
| | | Sequence, |
| | | Set, |
| | | Tuple, |
| | | Type, |
| | | Union, |
| | | cast, |
| | | ) |
| | | |
| | | import aiohappyeyeballs |
| | | from aiohappyeyeballs import AddrInfoType, SocketFactoryType |
| | | |
| | | from . import hdrs, helpers |
| | | from .abc import AbstractResolver, ResolveResult |
| | | from .client_exceptions import ( |
| | | ClientConnectionError, |
| | | ClientConnectorCertificateError, |
| | | ClientConnectorDNSError, |
| | | ClientConnectorError, |
| | | ClientConnectorSSLError, |
| | | ClientHttpProxyError, |
| | | ClientProxyConnectionError, |
| | | ServerFingerprintMismatch, |
| | | UnixClientConnectorError, |
| | | cert_errors, |
| | | ssl_errors, |
| | | ) |
| | | from .client_proto import ResponseHandler |
| | | from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params |
| | | from .helpers import ( |
| | | _SENTINEL, |
| | | ceil_timeout, |
| | | is_ip_address, |
| | | noop, |
| | | sentinel, |
| | | set_exception, |
| | | set_result, |
| | | ) |
| | | from .log import client_logger |
| | | from .resolver import DefaultResolver |
| | | |
| | | if sys.version_info >= (3, 12): |
| | | from collections.abc import Buffer |
| | | else: |
| | | Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] |
| | | |
| | | if TYPE_CHECKING: |
| | | import ssl |
| | | |
| | | SSLContext = ssl.SSLContext |
| | | else: |
| | | try: |
| | | import ssl |
| | | |
| | | SSLContext = ssl.SSLContext |
| | | except ImportError: # pragma: no cover |
| | | ssl = None # type: ignore[assignment] |
| | | SSLContext = object # type: ignore[misc,assignment] |
| | | |
| | | EMPTY_SCHEMA_SET = frozenset({""}) |
| | | HTTP_SCHEMA_SET = frozenset({"http", "https"}) |
| | | WS_SCHEMA_SET = frozenset({"ws", "wss"}) |
| | | |
| | | HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET |
| | | HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET |
| | | |
| | | NEEDS_CLEANUP_CLOSED = (3, 13, 0) <= sys.version_info < ( |
| | | 3, |
| | | 13, |
| | | 1, |
| | | ) or sys.version_info < (3, 12, 7) |
| | | # Cleanup closed is no longer needed after https://github.com/python/cpython/pull/118960 |
| | | # which first appeared in Python 3.12.7 and 3.13.1 |
| | | |
| | | |
| | | __all__ = ( |
| | | "BaseConnector", |
| | | "TCPConnector", |
| | | "UnixConnector", |
| | | "NamedPipeConnector", |
| | | "AddrInfoType", |
| | | "SocketFactoryType", |
| | | ) |
| | | |
| | | |
| | | if TYPE_CHECKING: |
| | | from .client import ClientTimeout |
| | | from .client_reqrep import ConnectionKey |
| | | from .tracing import Trace |
| | | |
| | | |
| | | class _DeprecationWaiter: |
| | | __slots__ = ("_awaitable", "_awaited") |
| | | |
| | | def __init__(self, awaitable: Awaitable[Any]) -> None: |
| | | self._awaitable = awaitable |
| | | self._awaited = False |
| | | |
| | | def __await__(self) -> Any: |
| | | self._awaited = True |
| | | return self._awaitable.__await__() |
| | | |
| | | def __del__(self) -> None: |
| | | if not self._awaited: |
| | | warnings.warn( |
| | | "Connector.close() is a coroutine, " |
| | | "please use await connector.close()", |
| | | DeprecationWarning, |
| | | ) |
| | | |
| | | |
| | | async def _wait_for_close(waiters: List[Awaitable[object]]) -> None: |
| | | """Wait for all waiters to finish closing.""" |
| | | results = await asyncio.gather(*waiters, return_exceptions=True) |
| | | for res in results: |
| | | if isinstance(res, Exception): |
| | | client_logger.debug("Error while closing connector: %r", res) |
| | | |
| | | |
| | | class Connection: |
| | | |
| | | _source_traceback = None |
| | | |
| | | def __init__( |
| | | self, |
| | | connector: "BaseConnector", |
| | | key: "ConnectionKey", |
| | | protocol: ResponseHandler, |
| | | loop: asyncio.AbstractEventLoop, |
| | | ) -> None: |
| | | self._key = key |
| | | self._connector = connector |
| | | self._loop = loop |
| | | self._protocol: Optional[ResponseHandler] = protocol |
| | | self._callbacks: List[Callable[[], None]] = [] |
| | | |
| | | if loop.get_debug(): |
| | | self._source_traceback = traceback.extract_stack(sys._getframe(1)) |
| | | |
| | | def __repr__(self) -> str: |
| | | return f"Connection<{self._key}>" |
| | | |
| | | def __del__(self, _warnings: Any = warnings) -> None: |
| | | if self._protocol is not None: |
| | | kwargs = {"source": self} |
| | | _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, **kwargs) |
| | | if self._loop.is_closed(): |
| | | return |
| | | |
| | | self._connector._release(self._key, self._protocol, should_close=True) |
| | | |
| | | context = {"client_connection": self, "message": "Unclosed connection"} |
| | | if self._source_traceback is not None: |
| | | context["source_traceback"] = self._source_traceback |
| | | self._loop.call_exception_handler(context) |
| | | |
| | | def __bool__(self) -> Literal[True]: |
| | | """Force subclasses to not be falsy, to make checks simpler.""" |
| | | return True |
| | | |
| | | @property |
| | | def loop(self) -> asyncio.AbstractEventLoop: |
| | | warnings.warn( |
| | | "connector.loop property is deprecated", DeprecationWarning, stacklevel=2 |
| | | ) |
| | | return self._loop |
| | | |
| | | @property |
| | | def transport(self) -> Optional[asyncio.Transport]: |
| | | if self._protocol is None: |
| | | return None |
| | | return self._protocol.transport |
| | | |
| | | @property |
| | | def protocol(self) -> Optional[ResponseHandler]: |
| | | return self._protocol |
| | | |
| | | def add_callback(self, callback: Callable[[], None]) -> None: |
| | | if callback is not None: |
| | | self._callbacks.append(callback) |
| | | |
| | | def _notify_release(self) -> None: |
| | | callbacks, self._callbacks = self._callbacks[:], [] |
| | | |
| | | for cb in callbacks: |
| | | with suppress(Exception): |
| | | cb() |
| | | |
| | | def close(self) -> None: |
| | | self._notify_release() |
| | | |
| | | if self._protocol is not None: |
| | | self._connector._release(self._key, self._protocol, should_close=True) |
| | | self._protocol = None |
| | | |
| | | def release(self) -> None: |
| | | self._notify_release() |
| | | |
| | | if self._protocol is not None: |
| | | self._connector._release(self._key, self._protocol) |
| | | self._protocol = None |
| | | |
| | | @property |
| | | def closed(self) -> bool: |
| | | return self._protocol is None or not self._protocol.is_connected() |
| | | |
| | | |
| | | class _ConnectTunnelConnection(Connection): |
| | | """Special connection wrapper for CONNECT tunnels that must never be pooled. |
| | | |
| | | This connection wraps the proxy connection that will be upgraded with TLS. |
| | | It must never be released to the pool because: |
| | | 1. Its 'closed' future will never complete, causing session.close() to hang |
| | | 2. It represents an intermediate state, not a reusable connection |
| | | 3. The real connection (with TLS) will be created separately |
| | | """ |
| | | |
| | | def release(self) -> None: |
| | | """Do nothing - don't pool or close the connection. |
| | | |
| | | These connections are an intermediate state during the CONNECT tunnel |
| | | setup and will be cleaned up naturally after the TLS upgrade. If they |
| | | were to be pooled, they would never be properly closed, causing |
| | | session.close() to wait forever for their 'closed' future. |
| | | """ |
| | | |
| | | |
| | | class _TransportPlaceholder: |
| | | """placeholder for BaseConnector.connect function""" |
| | | |
| | | __slots__ = ("closed", "transport") |
| | | |
| | | def __init__(self, closed_future: asyncio.Future[Optional[Exception]]) -> None: |
| | | """Initialize a placeholder for a transport.""" |
| | | self.closed = closed_future |
| | | self.transport = None |
| | | |
| | | def close(self) -> None: |
| | | """Close the placeholder.""" |
| | | |
| | | def abort(self) -> None: |
| | | """Abort the placeholder (does nothing).""" |
| | | |
| | | |
| | | class BaseConnector: |
| | | """Base connector class. |
| | | |
| | | keepalive_timeout - (optional) Keep-alive timeout. |
| | | force_close - Set to True to force close and do reconnect |
| | | after each request (and between redirects). |
| | | limit - The total number of simultaneous connections. |
| | | limit_per_host - Number of simultaneous connections to one host. |
| | | enable_cleanup_closed - Enables clean-up of closed SSL transports.
| | | Disabled by default.
| | | timeout_ceil_threshold - Trigger ceiling of timeout values when |
| | | it's above timeout_ceil_threshold. |
| | | loop - Optional event loop. |
| | | """ |
| | | |
| | | _closed = True # prevent AttributeError in __del__ if the constructor failed
| | | _source_traceback = None |
| | | |
| | | # abort transport after 2 seconds (cleanup broken connections) |
| | | _cleanup_closed_period = 2.0 |
| | | |
| | | allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET |
| | | |
| | | def __init__( |
| | | self, |
| | | *, |
| | | keepalive_timeout: Union[object, None, float] = sentinel, |
| | | force_close: bool = False, |
| | | limit: int = 100, |
| | | limit_per_host: int = 0, |
| | | enable_cleanup_closed: bool = False, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | timeout_ceil_threshold: float = 5, |
| | | ) -> None: |
| | | |
| | | if force_close: |
| | | if keepalive_timeout is not None and keepalive_timeout is not sentinel: |
| | | raise ValueError( |
| | | "keepalive_timeout cannot be set if force_close is True" |
| | | ) |
| | | else: |
| | | if keepalive_timeout is sentinel: |
| | | keepalive_timeout = 15.0 |
| | | |
| | | loop = loop or asyncio.get_running_loop() |
| | | self._timeout_ceil_threshold = timeout_ceil_threshold |
| | | |
| | | self._closed = False |
| | | if loop.get_debug(): |
| | | self._source_traceback = traceback.extract_stack(sys._getframe(1)) |
| | | |
| | | # Connection pool of reusable connections. |
| | | # We use a deque to store connections because it has O(1) popleft() |
| | | # and O(1) append() operations to implement a FIFO queue. |
| | | self._conns: DefaultDict[ |
| | | ConnectionKey, Deque[Tuple[ResponseHandler, float]] |
| | | ] = defaultdict(deque) |
| | | self._limit = limit |
| | | self._limit_per_host = limit_per_host |
| | | self._acquired: Set[ResponseHandler] = set() |
| | | self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = ( |
| | | defaultdict(set) |
| | | ) |
| | | self._keepalive_timeout = cast(float, keepalive_timeout) |
| | | self._force_close = force_close |
| | | |
| | | # {host_key: FIFO list of waiters} |
| | | # The FIFO is implemented with an OrderedDict whose values are
| | | # None because Python does not have an ordered set.
| | | self._waiters: DefaultDict[ |
| | | ConnectionKey, OrderedDict[asyncio.Future[None], None] |
| | | ] = defaultdict(OrderedDict) |
| | | |
| | | self._loop = loop |
| | | self._factory = functools.partial(ResponseHandler, loop=loop) |
| | | |
| | | # start keep-alive connection cleanup task |
| | | self._cleanup_handle: Optional[asyncio.TimerHandle] = None |
| | | |
| | | # start cleanup closed transports task |
| | | self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None |
| | | |
| | | if enable_cleanup_closed and not NEEDS_CLEANUP_CLOSED: |
| | | warnings.warn( |
| | | "enable_cleanup_closed ignored because " |
| | | "https://github.com/python/cpython/pull/118960 is fixed " |
| | | f"in Python version {sys.version_info}", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | enable_cleanup_closed = False |
| | | |
| | | self._cleanup_closed_disabled = not enable_cleanup_closed |
| | | self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = [] |
| | | self._placeholder_future: asyncio.Future[Optional[Exception]] = ( |
| | | loop.create_future() |
| | | ) |
| | | self._placeholder_future.set_result(None) |
| | | self._cleanup_closed() |
| | | |
| | | def __del__(self, _warnings: Any = warnings) -> None: |
| | | if self._closed: |
| | | return |
| | | if not self._conns: |
| | | return |
| | | |
| | | conns = [repr(c) for c in self._conns.values()] |
| | | |
| | | self._close() |
| | | |
| | | kwargs = {"source": self} |
| | | _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs) |
| | | context = { |
| | | "connector": self, |
| | | "connections": conns, |
| | | "message": "Unclosed connector", |
| | | } |
| | | if self._source_traceback is not None: |
| | | context["source_traceback"] = self._source_traceback |
| | | self._loop.call_exception_handler(context) |
| | | |
| | | def __enter__(self) -> "BaseConnector": |
| | | warnings.warn( |
| | | '"with Connector():" is deprecated, ' |
| | | 'use "async with Connector():" instead', |
| | | DeprecationWarning, |
| | | ) |
| | | return self |
| | | |
| | | def __exit__(self, *exc: Any) -> None: |
| | | self._close() |
| | | |
| | | async def __aenter__(self) -> "BaseConnector": |
| | | return self |
| | | |
| | | async def __aexit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]] = None, |
| | | exc_value: Optional[BaseException] = None, |
| | | exc_traceback: Optional[TracebackType] = None, |
| | | ) -> None: |
| | | await self.close() |
| | | |
| | | @property |
| | | def force_close(self) -> bool: |
| | | """Ultimately close connection on releasing if True.""" |
| | | return self._force_close |
| | | |
| | | @property |
| | | def limit(self) -> int: |
| | | """The total number for simultaneous connections. |
| | | |
| | | If limit is 0 the connector has no limit. |
| | | The default limit size is 100. |
| | | """ |
| | | return self._limit |
| | | |
| | | @property |
| | | def limit_per_host(self) -> int: |
| | | """The limit for simultaneous connections to the same endpoint. |
| | | |
| | | Endpoints are the same if they have an equal
| | | (host, port, is_ssl) triple.
| | | """ |
| | | return self._limit_per_host |
| | | |
| | | def _cleanup(self) -> None: |
| | | """Cleanup unused transports.""" |
| | | if self._cleanup_handle: |
| | | self._cleanup_handle.cancel() |
| | | # _cleanup_handle should be unset, otherwise _release() will not |
| | | # recreate it ever! |
| | | self._cleanup_handle = None |
| | | |
| | | now = monotonic() |
| | | timeout = self._keepalive_timeout |
| | | |
| | | if self._conns: |
| | | connections = defaultdict(deque) |
| | | deadline = now - timeout |
| | | for key, conns in self._conns.items(): |
| | | alive: Deque[Tuple[ResponseHandler, float]] = deque() |
| | | for proto, use_time in conns: |
| | | if proto.is_connected() and use_time - deadline >= 0: |
| | | alive.append((proto, use_time)) |
| | | continue |
| | | transport = proto.transport |
| | | proto.close() |
| | | if not self._cleanup_closed_disabled and key.is_ssl: |
| | | self._cleanup_closed_transports.append(transport) |
| | | |
| | | if alive: |
| | | connections[key] = alive |
| | | |
| | | self._conns = connections |
| | | |
| | | if self._conns: |
| | | self._cleanup_handle = helpers.weakref_handle( |
| | | self, |
| | | "_cleanup", |
| | | timeout, |
| | | self._loop, |
| | | timeout_ceil_threshold=self._timeout_ceil_threshold, |
| | | ) |
| | | |
| | | def _cleanup_closed(self) -> None: |
| | | """Double confirmation for transport close. |
| | | |
| | | Some broken SSL servers may leave the socket open without a proper close.
| | | """ |
| | | if self._cleanup_closed_handle: |
| | | self._cleanup_closed_handle.cancel() |
| | | |
| | | for transport in self._cleanup_closed_transports: |
| | | if transport is not None: |
| | | transport.abort() |
| | | |
| | | self._cleanup_closed_transports = [] |
| | | |
| | | if not self._cleanup_closed_disabled: |
| | | self._cleanup_closed_handle = helpers.weakref_handle( |
| | | self, |
| | | "_cleanup_closed", |
| | | self._cleanup_closed_period, |
| | | self._loop, |
| | | timeout_ceil_threshold=self._timeout_ceil_threshold, |
| | | ) |
| | | |
| | | def close(self, *, abort_ssl: bool = False) -> Awaitable[None]: |
| | | """Close all opened transports. |
| | | |
| | | :param abort_ssl: If True, SSL connections will be aborted immediately |
| | | without performing the shutdown handshake. This provides |
| | | faster cleanup at the cost of less graceful disconnection. |
| | | """ |
| | | if not (waiters := self._close(abort_ssl=abort_ssl)): |
| | | # If there are no connections to close, we can return a noop |
| | | # awaitable to avoid scheduling a task on the event loop. |
| | | return _DeprecationWaiter(noop()) |
| | | coro = _wait_for_close(waiters) |
| | | if sys.version_info >= (3, 12): |
| | | # Optimization for Python 3.12, try to close connections |
| | | # immediately to avoid having to schedule the task on the event loop. |
| | | task = asyncio.Task(coro, loop=self._loop, eager_start=True) |
| | | else: |
| | | task = self._loop.create_task(coro) |
| | | return _DeprecationWaiter(task) |
| | | |
| | | def _close(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: |
| | | waiters: List[Awaitable[object]] = [] |
| | | |
| | | if self._closed: |
| | | return waiters |
| | | |
| | | self._closed = True |
| | | |
| | | try: |
| | | if self._loop.is_closed(): |
| | | return waiters |
| | | |
| | | # cancel cleanup task |
| | | if self._cleanup_handle: |
| | | self._cleanup_handle.cancel() |
| | | |
| | | # cancel cleanup close task |
| | | if self._cleanup_closed_handle: |
| | | self._cleanup_closed_handle.cancel() |
| | | |
| | | for data in self._conns.values(): |
| | | for proto, _ in data: |
| | | if ( |
| | | abort_ssl |
| | | and proto.transport |
| | | and proto.transport.get_extra_info("sslcontext") is not None |
| | | ): |
| | | proto.abort() |
| | | else: |
| | | proto.close() |
| | | if closed := proto.closed: |
| | | waiters.append(closed) |
| | | |
| | | for proto in self._acquired: |
| | | if ( |
| | | abort_ssl |
| | | and proto.transport |
| | | and proto.transport.get_extra_info("sslcontext") is not None |
| | | ): |
| | | proto.abort() |
| | | else: |
| | | proto.close() |
| | | if closed := proto.closed: |
| | | waiters.append(closed) |
| | | |
| | | for transport in self._cleanup_closed_transports: |
| | | if transport is not None: |
| | | transport.abort() |
| | | |
| | | return waiters |
| | | |
| | | finally: |
| | | self._conns.clear() |
| | | self._acquired.clear() |
| | | for keyed_waiters in self._waiters.values(): |
| | | for keyed_waiter in keyed_waiters: |
| | | keyed_waiter.cancel() |
| | | self._waiters.clear() |
| | | self._cleanup_handle = None |
| | | self._cleanup_closed_transports.clear() |
| | | self._cleanup_closed_handle = None |
| | | |
| | | @property |
| | | def closed(self) -> bool: |
| | | """Is connector closed. |
| | | |
| | | A readonly property. |
| | | """ |
| | | return self._closed |
| | | |
| | | def _available_connections(self, key: "ConnectionKey") -> int: |
| | | """ |
| | | Return number of available connections. |
| | | |
| | | The limit, limit_per_host and the connection key are taken into account. |
| | | |
| | | A return value of less than 1 means that no connections
| | | are available.
| | | """ |
| | | # check total available connections |
| | | # If there are no limits, this will always return 1 |
| | | total_remain = 1 |
| | | |
| | | if self._limit and (total_remain := self._limit - len(self._acquired)) <= 0: |
| | | return total_remain |
| | | |
| | | # check limit per host |
| | | if host_remain := self._limit_per_host: |
| | | if acquired := self._acquired_per_host.get(key): |
| | | host_remain -= len(acquired) |
| | | if total_remain > host_remain: |
| | | return host_remain |
| | | |
| | | return total_remain |
| | | |
| | | def _update_proxy_auth_header_and_build_proxy_req( |
| | | self, req: ClientRequest |
| | | ) -> ClientRequest: |
| | | """Set Proxy-Authorization header for non-SSL proxy requests and builds the proxy request for SSL proxy requests.""" |
| | | url = req.proxy |
| | | assert url is not None |
| | | headers: Dict[str, str] = {} |
| | | if req.proxy_headers is not None: |
| | | headers = req.proxy_headers # type: ignore[assignment] |
| | | headers[hdrs.HOST] = req.headers[hdrs.HOST] |
| | | proxy_req = ClientRequest( |
| | | hdrs.METH_GET, |
| | | url, |
| | | headers=headers, |
| | | auth=req.proxy_auth, |
| | | loop=self._loop, |
| | | ssl=req.ssl, |
| | | ) |
| | | auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) |
| | | if auth is not None: |
| | | if not req.is_ssl(): |
| | | req.headers[hdrs.PROXY_AUTHORIZATION] = auth |
| | | else: |
| | | proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth |
| | | return proxy_req |
| | | |
| | | async def connect( |
| | | self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" |
| | | ) -> Connection: |
| | | """Get from pool or create new connection.""" |
| | | key = req.connection_key |
| | | if (conn := await self._get(key, traces)) is not None: |
| | | # If we do not have to wait and we can get a connection from the pool |
| | | # we can avoid the timeout ceil logic and directly return the connection |
| | | if req.proxy: |
| | | self._update_proxy_auth_header_and_build_proxy_req(req) |
| | | return conn |
| | | |
| | | async with ceil_timeout(timeout.connect, timeout.ceil_threshold): |
| | | if self._available_connections(key) <= 0: |
| | | await self._wait_for_available_connection(key, traces) |
| | | if (conn := await self._get(key, traces)) is not None: |
| | | if req.proxy: |
| | | self._update_proxy_auth_header_and_build_proxy_req(req) |
| | | return conn |
| | | |
| | | placeholder = cast( |
| | | ResponseHandler, _TransportPlaceholder(self._placeholder_future) |
| | | ) |
| | | self._acquired.add(placeholder) |
| | | if self._limit_per_host: |
| | | self._acquired_per_host[key].add(placeholder) |
| | | |
| | | try: |
| | | # Traces are done inside the try block to ensure that the
| | | # placeholder is still cleaned up if an exception is raised.
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_connection_create_start() |
| | | proto = await self._create_connection(req, traces, timeout) |
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_connection_create_end() |
| | | except BaseException: |
| | | self._release_acquired(key, placeholder) |
| | | raise |
| | | else: |
| | | if self._closed: |
| | | proto.close() |
| | | raise ClientConnectionError("Connector is closed.") |
| | | |
| | | # The connection was successfully created, drop the placeholder |
| | | # and add the real connection to the acquired set. There should |
| | | # be no awaits after the proto is added to the acquired set |
| | | # to ensure that the connection is not left in the acquired set |
| | | # on cancellation. |
| | | self._acquired.remove(placeholder) |
| | | self._acquired.add(proto) |
| | | if self._limit_per_host: |
| | | acquired_per_host = self._acquired_per_host[key] |
| | | acquired_per_host.remove(placeholder) |
| | | acquired_per_host.add(proto) |
| | | return Connection(self, key, proto, self._loop) |
| | | |
| | | async def _wait_for_available_connection( |
| | | self, key: "ConnectionKey", traces: List["Trace"] |
| | | ) -> None: |
| | | """Wait for an available connection slot.""" |
| | | # We loop here because there is a race between |
| | | # the connection limit check and the connection |
| | | # being acquired. If the connection is acquired |
| | | # between the check and the await statement, we |
| | | # need to loop again to check if the connection |
| | | # slot is still available. |
| | | attempts = 0 |
| | | while True: |
| | | fut: asyncio.Future[None] = self._loop.create_future() |
| | | keyed_waiters = self._waiters[key] |
| | | keyed_waiters[fut] = None |
| | | if attempts: |
| | | # If we have waited before, we need to move the waiter |
| | | # to the front of the queue as otherwise we might get |
| | | # starved and hit the timeout. |
| | | keyed_waiters.move_to_end(fut, last=False) |
| | | |
| | | try: |
| | | # Traces happen in the try block to ensure that the
| | | # waiter is still cleaned up if an exception is raised.
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_connection_queued_start() |
| | | await fut |
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_connection_queued_end() |
| | | finally: |
| | | # pop the waiter from the queue if it's still
| | | # there and was not already removed by _release_waiter
| | | keyed_waiters.pop(fut, None) |
| | | if not self._waiters.get(key, True): |
| | | del self._waiters[key] |
| | | |
| | | if self._available_connections(key) > 0: |
| | | break |
| | | attempts += 1 |
| | | |
| | | async def _get( |
| | | self, key: "ConnectionKey", traces: List["Trace"] |
| | | ) -> Optional[Connection]: |
| | | """Get next reusable connection for the key or None. |
| | | |
| | | The connection will be marked as acquired. |
| | | """ |
| | | if (conns := self._conns.get(key)) is None: |
| | | return None |
| | | |
| | | t1 = monotonic() |
| | | while conns: |
| | | proto, t0 = conns.popleft() |
| | | # We will reuse the connection if it's connected and
| | | # the keepalive timeout has not been exceeded
| | | if proto.is_connected() and t1 - t0 <= self._keepalive_timeout: |
| | | if not conns: |
| | | # The very last connection was reclaimed: drop the key |
| | | del self._conns[key] |
| | | self._acquired.add(proto) |
| | | if self._limit_per_host: |
| | | self._acquired_per_host[key].add(proto) |
| | | if traces: |
| | | for trace in traces: |
| | | try: |
| | | await trace.send_connection_reuseconn() |
| | | except BaseException: |
| | | self._release_acquired(key, proto) |
| | | raise |
| | | return Connection(self, key, proto, self._loop) |
| | | |
| | | # Connection cannot be reused, close it |
| | | transport = proto.transport |
| | | proto.close() |
| | | # only for SSL transports |
| | | if not self._cleanup_closed_disabled and key.is_ssl: |
| | | self._cleanup_closed_transports.append(transport) |
| | | |
| | | # No more connections: drop the key |
| | | del self._conns[key] |
| | | return None |
| | | |
| | | def _release_waiter(self) -> None: |
| | | """ |
| | | Iterates over all waiters until one to be released is found. |
| | | |
| | | The one to be released is a waiter that is not finished and
| | | belongs to a host that has available connections.
| | | """ |
| | | if not self._waiters: |
| | | return |
| | | |
| | | # Shuffle the keys so we do not iterate over the waiters
| | | # in the same order on every call.
| | | queues = list(self._waiters) |
| | | random.shuffle(queues) |
| | | |
| | | for key in queues: |
| | | if self._available_connections(key) < 1: |
| | | continue |
| | | |
| | | waiters = self._waiters[key] |
| | | while waiters: |
| | | waiter, _ = waiters.popitem(last=False) |
| | | if not waiter.done(): |
| | | waiter.set_result(None) |
| | | return |
| | | |
| | | def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None: |
| | | """Release acquired connection.""" |
| | | if self._closed: |
| | | # acquired connection is already released on connector closing |
| | | return |
| | | |
| | | self._acquired.discard(proto) |
| | | if self._limit_per_host and (conns := self._acquired_per_host.get(key)): |
| | | conns.discard(proto) |
| | | if not conns: |
| | | del self._acquired_per_host[key] |
| | | self._release_waiter() |
| | | |
| | | def _release( |
| | | self, |
| | | key: "ConnectionKey", |
| | | protocol: ResponseHandler, |
| | | *, |
| | | should_close: bool = False, |
| | | ) -> None: |
| | | if self._closed: |
| | | # acquired connection is already released on connector closing |
| | | return |
| | | |
| | | self._release_acquired(key, protocol) |
| | | |
| | | if self._force_close or should_close or protocol.should_close: |
| | | transport = protocol.transport |
| | | protocol.close() |
| | | |
| | | if key.is_ssl and not self._cleanup_closed_disabled: |
| | | self._cleanup_closed_transports.append(transport) |
| | | return |
| | | |
| | | self._conns[key].append((protocol, monotonic())) |
| | | |
| | | if self._cleanup_handle is None: |
| | | self._cleanup_handle = helpers.weakref_handle( |
| | | self, |
| | | "_cleanup", |
| | | self._keepalive_timeout, |
| | | self._loop, |
| | | timeout_ceil_threshold=self._timeout_ceil_threshold, |
| | | ) |
| | | |
| | | async def _create_connection( |
| | | self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" |
| | | ) -> ResponseHandler: |
| | | raise NotImplementedError() |
| | | |
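BaseConnector keeps its per-host waiter queue as an OrderedDict whose keys are futures and whose values are None, giving an ordered set with O(1) FIFO pop plus the ability to re-queue a starved waiter at the front via `move_to_end(last=False)`. A minimal sketch of that trick, independent of aiohttp (item names are illustrative):

```python
from collections import OrderedDict

# Ordered-set-as-FIFO sketch: the queued items are the keys, the
# values are always None. popitem(last=False) pops the oldest entry,
# and move_to_end(last=False) moves an item to the front of the queue
# so a waiter that has already waited once is not starved.
waiters: "OrderedDict[str, None]" = OrderedDict()

for name in ("first", "second", "third"):
    waiters[name] = None

# A retrying waiter jumps to the front of the queue.
waiters.move_to_end("third", last=False)

order = []
while waiters:
    item, _ = waiters.popitem(last=False)  # FIFO pop
    order.append(item)

print(order)  # ['third', 'first', 'second']
```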
| | | |
| | | class _DNSCacheTable: |
| | | def __init__(self, ttl: Optional[float] = None) -> None: |
| | | self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {} |
| | | self._timestamps: Dict[Tuple[str, int], float] = {} |
| | | self._ttl = ttl |
| | | |
| | | def __contains__(self, host: object) -> bool: |
| | | return host in self._addrs_rr |
| | | |
| | | def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None: |
| | | self._addrs_rr[key] = (cycle(addrs), len(addrs)) |
| | | |
| | | if self._ttl is not None: |
| | | self._timestamps[key] = monotonic() |
| | | |
| | | def remove(self, key: Tuple[str, int]) -> None: |
| | | self._addrs_rr.pop(key, None) |
| | | |
| | | if self._ttl is not None: |
| | | self._timestamps.pop(key, None) |
| | | |
| | | def clear(self) -> None: |
| | | self._addrs_rr.clear() |
| | | self._timestamps.clear() |
| | | |
| | | def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]: |
| | | loop, length = self._addrs_rr[key] |
| | | addrs = list(islice(loop, length)) |
| | | # Consume one more element to shift internal state of `cycle` |
| | | next(loop) |
| | | return addrs |
| | | |
| | | def expired(self, key: Tuple[str, int]) -> bool: |
| | | if self._ttl is None: |
| | | return False |
| | | |
| | | return self._timestamps[key] + self._ttl < monotonic() |
| | | |
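`_DNSCacheTable.next_addrs` rotates resolved addresses round-robin by storing an `itertools.cycle` together with the list length, and consuming one extra element after each read so the next read starts one position later. A self-contained sketch of the same rotation (the addresses are illustrative):

```python
from itertools import cycle, islice

# Round-robin rotation sketch, mirroring _DNSCacheTable.next_addrs:
# keep a cycle() over the addresses plus the list length, take
# `length` items per read, then consume one more element so the next
# read starts one position later.
addrs = ["10.0.0.1", "10.0.0.2", "10.0.0.3"]
rr, length = cycle(addrs), len(addrs)


def next_addrs():
    batch = list(islice(rr, length))
    next(rr)  # shift the cycle's internal state by one
    return batch


print(next_addrs())  # ['10.0.0.1', '10.0.0.2', '10.0.0.3']
print(next_addrs())  # ['10.0.0.2', '10.0.0.3', '10.0.0.1']
```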
| | | |
| | | def _make_ssl_context(verified: bool) -> SSLContext: |
| | | """Create SSL context. |
| | | |
| | | This function is not async-friendly and should be called from a thread
| | | because it will load certificates from disk and do other blocking I/O. |
| | | """ |
| | | if ssl is None: |
| | | # No ssl support |
| | | return None |
| | | if verified: |
| | | sslcontext = ssl.create_default_context() |
| | | else: |
| | | sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) |
| | | sslcontext.options |= ssl.OP_NO_SSLv2 |
| | | sslcontext.options |= ssl.OP_NO_SSLv3 |
| | | sslcontext.check_hostname = False |
| | | sslcontext.verify_mode = ssl.CERT_NONE |
| | | sslcontext.options |= ssl.OP_NO_COMPRESSION |
| | | sslcontext.set_default_verify_paths() |
| | | sslcontext.set_alpn_protocols(("http/1.1",)) |
| | | return sslcontext |
| | | |
| | | |
| | | # The default SSLContext objects are created at import time |
| | | # since they do blocking I/O to load certificates from disk, |
| | | # and imports should always be done before the event loop starts |
| | | # or in a thread. |
| | | _SSL_CONTEXT_VERIFIED = _make_ssl_context(True) |
| | | _SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False) |
| | | |
| | | |
| | | class TCPConnector(BaseConnector): |
| | | """TCP connector. |
| | | |
| | | verify_ssl - Set to True to check ssl certifications. |
| | | fingerprint - Pass the binary sha256 |
| | | digest of the expected certificate in DER format to verify |
| | | that the certificate the server presents matches. See also |
| | | https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning |
| | | resolver - Enable DNS lookups and use this |
| | | resolver |
| | | use_dns_cache - Use memory cache for DNS lookups. |
| | | ttl_dns_cache - Max seconds to cache a DNS entry; None means forever.
| | | family - socket address family |
| | | local_addr - local tuple of (host, port) to bind socket to |
| | | |
| | | keepalive_timeout - (optional) Keep-alive timeout. |
| | | force_close - Set to True to force close and do reconnect |
| | | after each request (and between redirects). |
| | | limit - The total number of simultaneous connections. |
| | | limit_per_host - Number of simultaneous connections to one host. |
| | | enable_cleanup_closed - Enables clean-up of closed SSL transports.
| | | Disabled by default.
| | | happy_eyeballs_delay - This is the “Connection Attempt Delay” |
| | | as defined in RFC 8305. To disable |
| | | the happy eyeballs algorithm, set to None. |
| | | interleave - “First Address Family Count” as defined in RFC 8305 |
| | | loop - Optional event loop. |
| | | socket_factory - A SocketFactoryType function that, if supplied, |
| | | will be used to create sockets given an |
| | | AddrInfoType. |
| | | ssl_shutdown_timeout - DEPRECATED. Will be removed in aiohttp 4.0. |
| | | Grace period for SSL shutdown handshake on TLS |
| | | connections. Default is 0 seconds (immediate abort). |
| | | This parameter allowed for a clean SSL shutdown by |
| | | notifying the remote peer of connection closure, |
| | | while avoiding excessive delays during connector cleanup. |
| | | Note: Only takes effect on Python 3.11+. |
| | | """ |
| | | |
| | | allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) |
| | | |
| | | def __init__( |
| | | self, |
| | | *, |
| | | verify_ssl: bool = True, |
| | | fingerprint: Optional[bytes] = None, |
| | | use_dns_cache: bool = True, |
| | | ttl_dns_cache: Optional[int] = 10, |
| | | family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC, |
| | | ssl_context: Optional[SSLContext] = None, |
| | | ssl: Union[bool, Fingerprint, SSLContext] = True, |
| | | local_addr: Optional[Tuple[str, int]] = None, |
| | | resolver: Optional[AbstractResolver] = None, |
| | | keepalive_timeout: Union[None, float, object] = sentinel, |
| | | force_close: bool = False, |
| | | limit: int = 100, |
| | | limit_per_host: int = 0, |
| | | enable_cleanup_closed: bool = False, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | timeout_ceil_threshold: float = 5, |
| | | happy_eyeballs_delay: Optional[float] = 0.25, |
| | | interleave: Optional[int] = None, |
| | | socket_factory: Optional[SocketFactoryType] = None, |
| | | ssl_shutdown_timeout: Union[_SENTINEL, None, float] = sentinel, |
| | | ): |
| | | super().__init__( |
| | | keepalive_timeout=keepalive_timeout, |
| | | force_close=force_close, |
| | | limit=limit, |
| | | limit_per_host=limit_per_host, |
| | | enable_cleanup_closed=enable_cleanup_closed, |
| | | loop=loop, |
| | | timeout_ceil_threshold=timeout_ceil_threshold, |
| | | ) |
| | | |
| | | self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) |
| | | |
| | | self._resolver: AbstractResolver |
| | | if resolver is None: |
| | | self._resolver = DefaultResolver(loop=self._loop) |
| | | self._resolver_owner = True |
| | | else: |
| | | self._resolver = resolver |
| | | self._resolver_owner = False |
| | | |
| | | self._use_dns_cache = use_dns_cache |
| | | self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) |
| | | self._throttle_dns_futures: Dict[ |
| | | Tuple[str, int], Set["asyncio.Future[None]"] |
| | | ] = {} |
| | | self._family = family |
| | | self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) |
| | | self._happy_eyeballs_delay = happy_eyeballs_delay |
| | | self._interleave = interleave |
| | | self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() |
| | | self._socket_factory = socket_factory |
| | | self._ssl_shutdown_timeout: Optional[float] |
| | | # Handle ssl_shutdown_timeout with warning for Python < 3.11 |
| | | if ssl_shutdown_timeout is sentinel: |
| | | self._ssl_shutdown_timeout = 0 |
| | | else: |
| | | # Deprecation warning for ssl_shutdown_timeout parameter |
| | | warnings.warn( |
| | | "The ssl_shutdown_timeout parameter is deprecated and will be removed in aiohttp 4.0", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | if ( |
| | | sys.version_info < (3, 11) |
| | | and ssl_shutdown_timeout is not None |
| | | and ssl_shutdown_timeout != 0 |
| | | ): |
| | | warnings.warn( |
| | | f"ssl_shutdown_timeout={ssl_shutdown_timeout} is not supported on Python < 3.11 " |
| | | "and will be ignored; only ssl_shutdown_timeout=0 takes effect there.", |
| | | RuntimeWarning, |
| | | stacklevel=2, |
| | | ) |
| | | self._ssl_shutdown_timeout = ssl_shutdown_timeout |
| | | |
| | | def _close(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: |
| | | """Close all ongoing DNS calls.""" |
| | | for fut in chain.from_iterable(self._throttle_dns_futures.values()): |
| | | fut.cancel() |
| | | |
| | | waiters = super()._close(abort_ssl=abort_ssl) |
| | | |
| | | for t in self._resolve_host_tasks: |
| | | t.cancel() |
| | | waiters.append(t) |
| | | |
| | | return waiters |
| | | |
| | | async def close(self, *, abort_ssl: bool = False) -> None: |
| | | """ |
| | | Close all opened transports. |
| | | |
| | | :param abort_ssl: If True, SSL connections will be aborted immediately |
| | | without performing the shutdown handshake. If False (default), |
| | | the behavior is determined by ssl_shutdown_timeout: |
| | | - If ssl_shutdown_timeout=0: connections are aborted |
| | | - If ssl_shutdown_timeout>0: graceful shutdown is performed |
| | | """ |
| | | if self._resolver_owner: |
| | | await self._resolver.close() |
| | | # Use abort_ssl param if explicitly set, otherwise use ssl_shutdown_timeout default |
| | | await super().close(abort_ssl=abort_ssl or self._ssl_shutdown_timeout == 0) |
| | | |
| | | @property |
| | | def family(self) -> int: |
| | | """Socket family like AF_INET.""" |
| | | return self._family |
| | | |
| | | @property |
| | | def use_dns_cache(self) -> bool: |
| | | """True if local DNS caching is enabled.""" |
| | | return self._use_dns_cache |
| | | |
| | | def clear_dns_cache( |
| | | self, host: Optional[str] = None, port: Optional[int] = None |
| | | ) -> None: |
| | | """Remove specified host/port or clear all dns local cache.""" |
| | | if host is not None and port is not None: |
| | | self._cached_hosts.remove((host, port)) |
| | | elif host is not None or port is not None: |
| | | raise ValueError("either both host and port must be given, or neither") |
| | | else: |
| | | self._cached_hosts.clear() |
| | | |
| | | async def _resolve_host( |
| | | self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None |
| | | ) -> List[ResolveResult]: |
| | | """Resolve host and return list of addresses.""" |
| | | if is_ip_address(host): |
| | | return [ |
| | | { |
| | | "hostname": host, |
| | | "host": host, |
| | | "port": port, |
| | | "family": self._family, |
| | | "proto": 0, |
| | | "flags": 0, |
| | | } |
| | | ] |
| | | |
| | | if not self._use_dns_cache: |
| | | |
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_dns_resolvehost_start(host) |
| | | |
| | | res = await self._resolver.resolve(host, port, family=self._family) |
| | | |
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_dns_resolvehost_end(host) |
| | | |
| | | return res |
| | | |
| | | key = (host, port) |
| | | if key in self._cached_hosts and not self._cached_hosts.expired(key): |
| | | # get result early, before any await (#4014) |
| | | result = self._cached_hosts.next_addrs(key) |
| | | |
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_dns_cache_hit(host) |
| | | return result |
| | | |
| | | futures: Set["asyncio.Future[None]"] |
| | | # |
| | | # If multiple connectors are resolving the same host, we wait |
| | | # for the first one to resolve and then use the result for all of them. |
| | | # We use a throttle to ensure that we only resolve the host once |
| | | # and then use the result for all the waiters. |
| | | # |
| | | if key in self._throttle_dns_futures: |
| | | # get futures early, before any await (#4014) |
| | | futures = self._throttle_dns_futures[key] |
| | | future: asyncio.Future[None] = self._loop.create_future() |
| | | futures.add(future) |
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_dns_cache_hit(host) |
| | | try: |
| | | await future |
| | | finally: |
| | | futures.discard(future) |
| | | return self._cached_hosts.next_addrs(key) |
| | | |
| | | # update dict early, before any await (#4014) |
| | | self._throttle_dns_futures[key] = futures = set() |
| | | # In this case we need to create a task so that it can be shielded |
| | | # from cancellation: cancelling this lookup must not cancel the |
| | | # underlying resolver call, or the cancellation would be broadcast |
| | | # to all the waiters across all connections. |
| | | # |
| | | coro = self._resolve_host_with_throttle(key, host, port, futures, traces) |
| | | loop = asyncio.get_running_loop() |
| | | if sys.version_info >= (3, 12): |
| | | # Optimization for Python 3.12, try to send immediately |
| | | resolved_host_task = asyncio.Task(coro, loop=loop, eager_start=True) |
| | | else: |
| | | resolved_host_task = loop.create_task(coro) |
| | | |
| | | if not resolved_host_task.done(): |
| | | self._resolve_host_tasks.add(resolved_host_task) |
| | | resolved_host_task.add_done_callback(self._resolve_host_tasks.discard) |
| | | |
| | | try: |
| | | return await asyncio.shield(resolved_host_task) |
| | | except asyncio.CancelledError: |
| | | |
| | | def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None: |
| | | with suppress(Exception, asyncio.CancelledError): |
| | | fut.result() |
| | | |
| | | resolved_host_task.add_done_callback(drop_exception) |
| | | raise |
| | | |
| | | async def _resolve_host_with_throttle( |
| | | self, |
| | | key: Tuple[str, int], |
| | | host: str, |
| | | port: int, |
| | | futures: Set["asyncio.Future[None]"], |
| | | traces: Optional[Sequence["Trace"]], |
| | | ) -> List[ResolveResult]: |
| | | """Resolve host and set result for all waiters. |
| | | |
| | | This method must be run in a task and shielded from cancellation |
| | | to avoid cancelling the underlying lookup. |
| | | """ |
| | | try: |
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_dns_cache_miss(host) |
| | | |
| | | for trace in traces: |
| | | await trace.send_dns_resolvehost_start(host) |
| | | |
| | | addrs = await self._resolver.resolve(host, port, family=self._family) |
| | | if traces: |
| | | for trace in traces: |
| | | await trace.send_dns_resolvehost_end(host) |
| | | |
| | | self._cached_hosts.add(key, addrs) |
| | | for fut in futures: |
| | | set_result(fut, None) |
| | | except BaseException as e: |
| | | # Any DNS exception is set on the waiters so they raise the same |
| | | # exception. This coro always runs in a task that is shielded from |
| | | # cancellation, so we should never propagate cancellation here. |
| | | for fut in futures: |
| | | set_exception(fut, e) |
| | | raise |
| | | finally: |
| | | self._throttle_dns_futures.pop(key) |
| | | |
| | | return self._cached_hosts.next_addrs(key) |
| | | |
| | | async def _create_connection( |
| | | self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" |
| | | ) -> ResponseHandler: |
| | | """Create connection. |
| | | |
| | | Has same keyword arguments as BaseEventLoop.create_connection. |
| | | """ |
| | | if req.proxy: |
| | | _, proto = await self._create_proxy_connection(req, traces, timeout) |
| | | else: |
| | | _, proto = await self._create_direct_connection(req, traces, timeout) |
| | | |
| | | return proto |
| | | |
| | | def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: |
| | | """Logic to get the correct SSL context |
| | | |
| | | 0. if req.ssl is false, return None |
| | | |
| | | 1. if ssl_context is specified in req, use it |
| | | 2. if _ssl_context is specified in self, use it |
| | | 3. otherwise: |
| | | 1. if verify_ssl is not specified in req, use self.ssl_context |
| | | (will generate a default context according to self.verify_ssl) |
| | | 2. if verify_ssl is True in req, generate a default SSL context |
| | | 3. if verify_ssl is False in req, generate a SSL context that |
| | | won't verify |
| | | """ |
| | | if not req.is_ssl(): |
| | | return None |
| | | |
| | | if ssl is None: # pragma: no cover |
| | | raise RuntimeError("SSL is not supported.") |
| | | sslcontext = req.ssl |
| | | if isinstance(sslcontext, ssl.SSLContext): |
| | | return sslcontext |
| | | if sslcontext is not True: |
| | | # not verified or fingerprinted |
| | | return _SSL_CONTEXT_UNVERIFIED |
| | | sslcontext = self._ssl |
| | | if isinstance(sslcontext, ssl.SSLContext): |
| | | return sslcontext |
| | | if sslcontext is not True: |
| | | # not verified or fingerprinted |
| | | return _SSL_CONTEXT_UNVERIFIED |
| | | return _SSL_CONTEXT_VERIFIED |
| | | |
| | | def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: |
| | | ret = req.ssl |
| | | if isinstance(ret, Fingerprint): |
| | | return ret |
| | | ret = self._ssl |
| | | if isinstance(ret, Fingerprint): |
| | | return ret |
| | | return None |
| | | |
| | | async def _wrap_create_connection( |
| | | self, |
| | | *args: Any, |
| | | addr_infos: List[AddrInfoType], |
| | | req: ClientRequest, |
| | | timeout: "ClientTimeout", |
| | | client_error: Type[Exception] = ClientConnectorError, |
| | | **kwargs: Any, |
| | | ) -> Tuple[asyncio.Transport, ResponseHandler]: |
| | | try: |
| | | async with ceil_timeout( |
| | | timeout.sock_connect, ceil_threshold=timeout.ceil_threshold |
| | | ): |
| | | sock = await aiohappyeyeballs.start_connection( |
| | | addr_infos=addr_infos, |
| | | local_addr_infos=self._local_addr_infos, |
| | | happy_eyeballs_delay=self._happy_eyeballs_delay, |
| | | interleave=self._interleave, |
| | | loop=self._loop, |
| | | socket_factory=self._socket_factory, |
| | | ) |
| | | # Add ssl_shutdown_timeout for Python 3.11+ when SSL is used |
| | | if ( |
| | | kwargs.get("ssl") |
| | | and self._ssl_shutdown_timeout |
| | | and sys.version_info >= (3, 11) |
| | | ): |
| | | kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout |
| | | return await self._loop.create_connection(*args, **kwargs, sock=sock) |
| | | except cert_errors as exc: |
| | | raise ClientConnectorCertificateError(req.connection_key, exc) from exc |
| | | except ssl_errors as exc: |
| | | raise ClientConnectorSSLError(req.connection_key, exc) from exc |
| | | except OSError as exc: |
| | | if exc.errno is None and isinstance(exc, asyncio.TimeoutError): |
| | | raise |
| | | raise client_error(req.connection_key, exc) from exc |
| | | |
| | | async def _wrap_existing_connection( |
| | | self, |
| | | *args: Any, |
| | | req: ClientRequest, |
| | | timeout: "ClientTimeout", |
| | | client_error: Type[Exception] = ClientConnectorError, |
| | | **kwargs: Any, |
| | | ) -> Tuple[asyncio.Transport, ResponseHandler]: |
| | | try: |
| | | async with ceil_timeout( |
| | | timeout.sock_connect, ceil_threshold=timeout.ceil_threshold |
| | | ): |
| | | return await self._loop.create_connection(*args, **kwargs) |
| | | except cert_errors as exc: |
| | | raise ClientConnectorCertificateError(req.connection_key, exc) from exc |
| | | except ssl_errors as exc: |
| | | raise ClientConnectorSSLError(req.connection_key, exc) from exc |
| | | except OSError as exc: |
| | | if exc.errno is None and isinstance(exc, asyncio.TimeoutError): |
| | | raise |
| | | raise client_error(req.connection_key, exc) from exc |
| | | |
| | | def _fail_on_no_start_tls(self, req: "ClientRequest") -> None: |
| | | """Raise a :py:exc:`RuntimeError` on missing ``start_tls()``. |
| | | |
| | | It is necessary for TLS-in-TLS so that it is possible to |
| | | send HTTPS queries through HTTPS proxies. |
| | | |
| | | This doesn't affect regular HTTP requests, though. |
| | | """ |
| | | if not req.is_ssl(): |
| | | return |
| | | |
| | | proxy_url = req.proxy |
| | | assert proxy_url is not None |
| | | if proxy_url.scheme != "https": |
| | | return |
| | | |
| | | self._check_loop_for_start_tls() |
| | | |
| | | def _check_loop_for_start_tls(self) -> None: |
| | | try: |
| | | self._loop.start_tls |
| | | except AttributeError as attr_exc: |
| | | raise RuntimeError( |
| | | "An HTTPS request is being sent through an HTTPS proxy. " |
| | | "This needs support for TLS in TLS but it is not implemented " |
| | | "in your runtime for the stdlib asyncio.\n\n" |
| | | "Please upgrade to Python 3.11 or higher. For more details, " |
| | | "please see:\n" |
| | | "* https://bugs.python.org/issue37179\n" |
| | | "* https://github.com/python/cpython/pull/28073\n" |
| | | "* https://docs.aiohttp.org/en/stable/" |
| | | "client_advanced.html#proxy-support\n" |
| | | "* https://github.com/aio-libs/aiohttp/discussions/6044\n", |
| | | ) from attr_exc |
| | | |
| | | def _loop_supports_start_tls(self) -> bool: |
| | | try: |
| | | self._check_loop_for_start_tls() |
| | | except RuntimeError: |
| | | return False |
| | | else: |
| | | return True |
| | | |
| | | def _warn_about_tls_in_tls( |
| | | self, |
| | | underlying_transport: asyncio.Transport, |
| | | req: ClientRequest, |
| | | ) -> None: |
| | | """Issue a warning if the requested URL has HTTPS scheme.""" |
| | | if req.request_info.url.scheme != "https": |
| | | return |
| | | |
| | | # Check if uvloop is being used, which supports TLS in TLS, |
| | | # otherwise assume that asyncio's native transport is being used. |
| | | if type(underlying_transport).__module__.startswith("uvloop"): |
| | | return |
| | | |
| | | # Support in asyncio was added in Python 3.11 (bpo-44011) |
| | | asyncio_supports_tls_in_tls = sys.version_info >= (3, 11) or getattr( |
| | | underlying_transport, |
| | | "_start_tls_compatible", |
| | | False, |
| | | ) |
| | | |
| | | if asyncio_supports_tls_in_tls: |
| | | return |
| | | |
| | | warnings.warn( |
| | | "An HTTPS request is being sent through an HTTPS proxy. " |
| | | "This support for TLS in TLS is known to be disabled " |
| | | "in the stdlib asyncio (Python <3.11). This is why you'll probably see " |
| | | "an error in the log below.\n\n" |
| | | "It is possible to enable it via monkeypatching. " |
| | | "For more details, see:\n" |
| | | "* https://bugs.python.org/issue37179\n" |
| | | "* https://github.com/python/cpython/pull/28073\n\n" |
| | | "You can temporarily patch this as follows:\n" |
| | | "* https://docs.aiohttp.org/en/stable/client_advanced.html#proxy-support\n" |
| | | "* https://github.com/aio-libs/aiohttp/discussions/6044\n", |
| | | RuntimeWarning, |
| | | source=self, |
| | | # `stacklevel=3`: several of the calls in the stack originate |
| | | # from the methods in this class. |
| | | stacklevel=3, |
| | | ) |
| | | |
| | | async def _start_tls_connection( |
| | | self, |
| | | underlying_transport: asyncio.Transport, |
| | | req: ClientRequest, |
| | | timeout: "ClientTimeout", |
| | | client_error: Type[Exception] = ClientConnectorError, |
| | | ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: |
| | | """Wrap the raw TCP transport with TLS.""" |
| | | tls_proto = self._factory() # Create a brand new proto for TLS |
| | | sslcontext = self._get_ssl_context(req) |
| | | if TYPE_CHECKING: |
| | | # _start_tls_connection is unreachable in the current code path |
| | | # if sslcontext is None. |
| | | assert sslcontext is not None |
| | | |
| | | try: |
| | | async with ceil_timeout( |
| | | timeout.sock_connect, ceil_threshold=timeout.ceil_threshold |
| | | ): |
| | | try: |
| | | # ssl_shutdown_timeout is only available in Python 3.11+ |
| | | if sys.version_info >= (3, 11) and self._ssl_shutdown_timeout: |
| | | tls_transport = await self._loop.start_tls( |
| | | underlying_transport, |
| | | tls_proto, |
| | | sslcontext, |
| | | server_hostname=req.server_hostname or req.host, |
| | | ssl_handshake_timeout=timeout.total, |
| | | ssl_shutdown_timeout=self._ssl_shutdown_timeout, |
| | | ) |
| | | else: |
| | | tls_transport = await self._loop.start_tls( |
| | | underlying_transport, |
| | | tls_proto, |
| | | sslcontext, |
| | | server_hostname=req.server_hostname or req.host, |
| | | ssl_handshake_timeout=timeout.total, |
| | | ) |
| | | except BaseException: |
| | | # We need to close the underlying transport since |
| | | # `start_tls()` probably failed before it had a |
| | | # chance to do this: |
| | | if self._ssl_shutdown_timeout == 0: |
| | | underlying_transport.abort() |
| | | else: |
| | | underlying_transport.close() |
| | | raise |
| | | if isinstance(tls_transport, asyncio.Transport): |
| | | fingerprint = self._get_fingerprint(req) |
| | | if fingerprint: |
| | | try: |
| | | fingerprint.check(tls_transport) |
| | | except ServerFingerprintMismatch: |
| | | tls_transport.close() |
| | | if not self._cleanup_closed_disabled: |
| | | self._cleanup_closed_transports.append(tls_transport) |
| | | raise |
| | | except cert_errors as exc: |
| | | raise ClientConnectorCertificateError(req.connection_key, exc) from exc |
| | | except ssl_errors as exc: |
| | | raise ClientConnectorSSLError(req.connection_key, exc) from exc |
| | | except OSError as exc: |
| | | if exc.errno is None and isinstance(exc, asyncio.TimeoutError): |
| | | raise |
| | | raise client_error(req.connection_key, exc) from exc |
| | | except TypeError as type_err: |
| | | # Example cause looks like this: |
| | | # TypeError: transport <asyncio.sslproto._SSLProtocolTransport |
| | | # object at 0x7f760615e460> is not supported by start_tls() |
| | | |
| | | raise ClientConnectionError( |
| | | "Cannot initialize a TLS-in-TLS connection to host " |
| | | f"{req.host!s}:{req.port:d} through an underlying connection " |
| | | f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} " |
| | | f"[{type_err!s}]" |
| | | ) from type_err |
| | | else: |
| | | if tls_transport is None: |
| | | msg = "Failed to start TLS (possibly caused by closing transport)" |
| | | raise client_error(req.connection_key, OSError(msg)) |
| | | tls_proto.connection_made( |
| | | tls_transport |
| | | ) # Kick the state machine of the new TLS protocol |
| | | |
| | | return tls_transport, tls_proto |
| | | |
| | | def _convert_hosts_to_addr_infos( |
| | | self, hosts: List[ResolveResult] |
| | | ) -> List[AddrInfoType]: |
| | | """Converts the list of hosts to a list of addr_infos. |
| | | |
| | | The list of hosts is the result of a DNS lookup. The list of |
| | | addr_infos is the result of a call to `socket.getaddrinfo()`. |
| | | """ |
| | | addr_infos: List[AddrInfoType] = [] |
| | | for hinfo in hosts: |
| | | host = hinfo["host"] |
| | | is_ipv6 = ":" in host |
| | | family = socket.AF_INET6 if is_ipv6 else socket.AF_INET |
| | | if self._family and self._family != family: |
| | | continue |
| | | addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"]) |
| | | addr_infos.append( |
| | | (family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr) |
| | | ) |
| | | return addr_infos |
| | | |
| | | async def _create_direct_connection( |
| | | self, |
| | | req: ClientRequest, |
| | | traces: List["Trace"], |
| | | timeout: "ClientTimeout", |
| | | *, |
| | | client_error: Type[Exception] = ClientConnectorError, |
| | | ) -> Tuple[asyncio.Transport, ResponseHandler]: |
| | | sslcontext = self._get_ssl_context(req) |
| | | fingerprint = self._get_fingerprint(req) |
| | | |
| | | host = req.url.raw_host |
| | | assert host is not None |
| | | # Replace multiple trailing dots with a single one. |
| | | # A trailing dot is only present for fully-qualified domain names. |
| | | # See https://github.com/aio-libs/aiohttp/pull/7364. |
| | | if host.endswith(".."): |
| | | host = host.rstrip(".") + "." |
| | | port = req.port |
| | | assert port is not None |
| | | try: |
| | | # Cancelling this lookup should not cancel the underlying lookup |
| | | # or else the cancel event will get broadcast to all the waiters |
| | | # across all connections. |
| | | hosts = await self._resolve_host(host, port, traces=traces) |
| | | except OSError as exc: |
| | | if exc.errno is None and isinstance(exc, asyncio.TimeoutError): |
| | | raise |
| | | # in case of proxy it is not ClientProxyConnectionError |
| | | # it is problem of resolving proxy ip itself |
| | | raise ClientConnectorDNSError(req.connection_key, exc) from exc |
| | | |
| | | last_exc: Optional[Exception] = None |
| | | addr_infos = self._convert_hosts_to_addr_infos(hosts) |
| | | while addr_infos: |
| | | # Strip trailing dots, certificates contain FQDN without dots. |
| | | # See https://github.com/aio-libs/aiohttp/issues/3636 |
| | | server_hostname = ( |
| | | (req.server_hostname or host).rstrip(".") if sslcontext else None |
| | | ) |
| | | |
| | | try: |
| | | transp, proto = await self._wrap_create_connection( |
| | | self._factory, |
| | | timeout=timeout, |
| | | ssl=sslcontext, |
| | | addr_infos=addr_infos, |
| | | server_hostname=server_hostname, |
| | | req=req, |
| | | client_error=client_error, |
| | | ) |
| | | except (ClientConnectorError, asyncio.TimeoutError) as exc: |
| | | last_exc = exc |
| | | aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave) |
| | | continue |
| | | |
| | | if req.is_ssl() and fingerprint: |
| | | try: |
| | | fingerprint.check(transp) |
| | | except ServerFingerprintMismatch as exc: |
| | | transp.close() |
| | | if not self._cleanup_closed_disabled: |
| | | self._cleanup_closed_transports.append(transp) |
| | | last_exc = exc |
| | | # Remove the bad peer from the list of addr_infos |
| | | sock: socket.socket = transp.get_extra_info("socket") |
| | | bad_peer = sock.getpeername() |
| | | aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer) |
| | | continue |
| | | |
| | | return transp, proto |
| | | else: |
| | | assert last_exc is not None |
| | | raise last_exc |
| | | |
| | | async def _create_proxy_connection( |
| | | self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" |
| | | ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: |
| | | self._fail_on_no_start_tls(req) |
| | | runtime_has_start_tls = self._loop_supports_start_tls() |
| | | proxy_req = self._update_proxy_auth_header_and_build_proxy_req(req) |
| | | |
| | | # create connection to proxy server |
| | | transport, proto = await self._create_direct_connection( |
| | | proxy_req, [], timeout, client_error=ClientProxyConnectionError |
| | | ) |
| | | |
| | | if req.is_ssl(): |
| | | if runtime_has_start_tls: |
| | | self._warn_about_tls_in_tls(transport, req) |
| | | |
| | | # For HTTPS requests over HTTP proxy |
| | | # we must notify proxy to tunnel connection |
| | | # so we send CONNECT command: |
| | | # CONNECT www.python.org:443 HTTP/1.1 |
| | | # Host: www.python.org |
| | | # |
| | | # next we must do TLS handshake and so on |
| | | # to do this we must wrap raw socket into secure one |
| | | # asyncio handles this perfectly |
| | | proxy_req.method = hdrs.METH_CONNECT |
| | | proxy_req.url = req.url |
| | | key = req.connection_key._replace( |
| | | proxy=None, proxy_auth=None, proxy_headers_hash=None |
| | | ) |
| | | conn = _ConnectTunnelConnection(self, key, proto, self._loop) |
| | | proxy_resp = await proxy_req.send(conn) |
| | | try: |
| | | protocol = conn._protocol |
| | | assert protocol is not None |
| | | |
| | | # read_until_eof=True will ensure the connection isn't closed |
| | | # once the response is received and processed allowing |
| | | # START_TLS to work on the connection below. |
| | | protocol.set_response_params( |
| | | read_until_eof=runtime_has_start_tls, |
| | | timeout_ceil_threshold=self._timeout_ceil_threshold, |
| | | ) |
| | | resp = await proxy_resp.start(conn) |
| | | except BaseException: |
| | | proxy_resp.close() |
| | | conn.close() |
| | | raise |
| | | else: |
| | | conn._protocol = None |
| | | try: |
| | | if resp.status != 200: |
| | | message = resp.reason |
| | | if message is None: |
| | | message = HTTPStatus(resp.status).phrase |
| | | raise ClientHttpProxyError( |
| | | proxy_resp.request_info, |
| | | resp.history, |
| | | status=resp.status, |
| | | message=message, |
| | | headers=resp.headers, |
| | | ) |
| | | if not runtime_has_start_tls: |
| | | rawsock = transport.get_extra_info("socket", default=None) |
| | | if rawsock is None: |
| | | raise RuntimeError( |
| | | "Transport does not expose socket instance" |
| | | ) |
| | | # Duplicate the socket, so now we can close proxy transport |
| | | rawsock = rawsock.dup() |
| | | except BaseException: |
| | | # It shouldn't be closed in `finally` because it's fed to |
| | | # `loop.start_tls()` and the docs say not to touch it after |
| | | # passing there. |
| | | transport.close() |
| | | raise |
| | | finally: |
| | | if not runtime_has_start_tls: |
| | | transport.close() |
| | | |
| | | if not runtime_has_start_tls: |
| | | # HTTP proxy with support for upgrade to HTTPS |
| | | sslcontext = self._get_ssl_context(req) |
| | | return await self._wrap_existing_connection( |
| | | self._factory, |
| | | timeout=timeout, |
| | | ssl=sslcontext, |
| | | sock=rawsock, |
| | | server_hostname=req.host, |
| | | req=req, |
| | | ) |
| | | |
| | | return await self._start_tls_connection( |
| | | # Access the old transport for the last time before it's |
| | | # closed and forgotten forever: |
| | | transport, |
| | | req=req, |
| | | timeout=timeout, |
| | | ) |
| | | finally: |
| | | proxy_resp.close() |
| | | |
| | | return transport, proto |
| | | |
| | | |
| | | class UnixConnector(BaseConnector): |
| | | """Unix socket connector. |
| | | |
| | | path - Unix socket path. |
| | | keepalive_timeout - (optional) Keep-alive timeout. |
| | | force_close - Set to True to force close and do reconnect |
| | | after each request (and between redirects). |
| | | limit - The total number of simultaneous connections. |
| | | limit_per_host - Number of simultaneous connections to one host. |
| | | loop - Optional event loop. |
| | | """ |
| | | |
| | | allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"}) |
| | | |
| | | def __init__( |
| | | self, |
| | | path: str, |
| | | force_close: bool = False, |
| | | keepalive_timeout: Union[object, float, None] = sentinel, |
| | | limit: int = 100, |
| | | limit_per_host: int = 0, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | ) -> None: |
| | | super().__init__( |
| | | force_close=force_close, |
| | | keepalive_timeout=keepalive_timeout, |
| | | limit=limit, |
| | | limit_per_host=limit_per_host, |
| | | loop=loop, |
| | | ) |
| | | self._path = path |
| | | |
| | | @property |
| | | def path(self) -> str: |
| | | """Path to unix socket.""" |
| | | return self._path |
| | | |
| | | async def _create_connection( |
| | | self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" |
| | | ) -> ResponseHandler: |
| | | try: |
| | | async with ceil_timeout( |
| | | timeout.sock_connect, ceil_threshold=timeout.ceil_threshold |
| | | ): |
| | | _, proto = await self._loop.create_unix_connection( |
| | | self._factory, self._path |
| | | ) |
| | | except OSError as exc: |
| | | if exc.errno is None and isinstance(exc, asyncio.TimeoutError): |
| | | raise |
| | | raise UnixClientConnectorError(self.path, req.connection_key, exc) from exc |
| | | |
| | | return proto |
| | | |
| | | |
| | | class NamedPipeConnector(BaseConnector): |
| | | """Named pipe connector. |
| | | |
| | | Only supported by the proactor event loop. |
| | | See also: https://docs.python.org/3/library/asyncio-eventloop.html |
| | | |
| | | path - Windows named pipe path. |
| | | keepalive_timeout - (optional) Keep-alive timeout. |
| | | force_close - Set to True to force close and do reconnect |
| | | after each request (and between redirects). |
| | | limit - The total number of simultaneous connections. |
| | | limit_per_host - Number of simultaneous connections to one host. |
| | | loop - Optional event loop. |
| | | """ |
| | | |
| | | allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"}) |
| | | |
| | | def __init__( |
| | | self, |
| | | path: str, |
| | | force_close: bool = False, |
| | | keepalive_timeout: Union[object, float, None] = sentinel, |
| | | limit: int = 100, |
| | | limit_per_host: int = 0, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | ) -> None: |
| | | super().__init__( |
| | | force_close=force_close, |
| | | keepalive_timeout=keepalive_timeout, |
| | | limit=limit, |
| | | limit_per_host=limit_per_host, |
| | | loop=loop, |
| | | ) |
| | | if not isinstance( |
| | | self._loop, |
| | | asyncio.ProactorEventLoop, # type: ignore[attr-defined] |
| | | ): |
| | | raise RuntimeError(
| | | "Named Pipes are only available in the proactor loop under Windows"
| | | )
| | | self._path = path |
| | | |
| | | @property |
| | | def path(self) -> str: |
| | | """Path to the named pipe.""" |
| | | return self._path |
| | | |
| | | async def _create_connection( |
| | | self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" |
| | | ) -> ResponseHandler: |
| | | try: |
| | | async with ceil_timeout( |
| | | timeout.sock_connect, ceil_threshold=timeout.ceil_threshold |
| | | ): |
| | | _, proto = await self._loop.create_pipe_connection( # type: ignore[attr-defined] |
| | | self._factory, self._path |
| | | ) |
| | | # Yielding to the event loop is required so that
| | | # connection_made() runs and the transport is set;
| | | # otherwise it is still None when client.py's _request
| | | # method hits `assert conn.transport is not None`.
| | | await asyncio.sleep(0) |
| | | # other option is to manually set transport like |
| | | # `proto.transport = trans` |
| | | except OSError as exc: |
| | | if exc.errno is None and isinstance(exc, asyncio.TimeoutError): |
| | | raise |
| | | raise ClientConnectorError(req.connection_key, exc) from exc |
| | | |
| | | return cast(ResponseHandler, proto) |
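The `await asyncio.sleep(0)` above is a scheduling trick: it yields exactly one event-loop iteration so that already-scheduled callbacks (such as the protocol's `connection_made`) run before the protocol is returned. A minimal stdlib-only sketch of the same effect, independent of named pipes:

```python
import asyncio


async def main() -> list:
    loop = asyncio.get_running_loop()
    ran: list = []
    # Schedule a callback for the next loop iteration.
    loop.call_soon(ran.append, "callback")
    # The callback has not run yet: we have not yielded to the loop.
    assert ran == []
    # Yield one loop iteration, letting pending callbacks execute.
    await asyncio.sleep(0)
    assert ran == ["callback"]
    return ran


result = asyncio.run(main())
```

The same reasoning applies in `_create_connection`: without the yield, the transport attribute would still be unset when the caller asserts on it.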
| New file |
| | |
| | | import asyncio |
| | | import calendar |
| | | import contextlib |
| | | import datetime |
| | | import heapq |
| | | import itertools |
| | | import os # noqa |
| | | import pathlib |
| | | import pickle |
| | | import re |
| | | import time |
| | | import warnings |
| | | from collections import defaultdict |
| | | from collections.abc import Mapping |
| | | from http.cookies import BaseCookie, Morsel, SimpleCookie |
| | | from typing import ( |
| | | DefaultDict, |
| | | Dict, |
| | | Iterable, |
| | | Iterator, |
| | | List, |
| | | Optional, |
| | | Set, |
| | | Tuple, |
| | | Union, |
| | | ) |
| | | |
| | | from yarl import URL |
| | | |
| | | from ._cookie_helpers import preserve_morsel_with_coded_value |
| | | from .abc import AbstractCookieJar, ClearCookiePredicate |
| | | from .helpers import is_ip_address |
| | | from .typedefs import LooseCookies, PathLike, StrOrURL |
| | | |
| | | __all__ = ("CookieJar", "DummyCookieJar") |
| | | |
| | | |
| | | CookieItem = Union[str, "Morsel[str]"] |
| | | |
| | | # We cache these string methods here as their use is in performance critical code. |
| | | _FORMAT_PATH = "{}/{}".format |
| | | _FORMAT_DOMAIN_REVERSED = "{1}.{0}".format |
| | | |
| | | # The minimum number of scheduled cookie expirations before we start cleaning up |
| | | # the expiration heap. This is a performance optimization to avoid cleaning up the |
| | | # heap too often when there are only a few scheduled expirations. |
| | | _MIN_SCHEDULED_COOKIE_EXPIRATION = 100 |
| | | _SIMPLE_COOKIE = SimpleCookie() |
| | | |
| | | |
| | | class CookieJar(AbstractCookieJar): |
| | | """Implements cookie storage adhering to RFC 6265.""" |
| | | |
| | | DATE_TOKENS_RE = re.compile( |
| | | r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*" |
| | | r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)" |
| | | ) |
| | | |
| | | DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})") |
| | | |
| | | DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})") |
| | | |
| | | DATE_MONTH_RE = re.compile( |
| | | "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)", |
| | | re.I, |
| | | ) |
| | | |
| | | DATE_YEAR_RE = re.compile(r"(\d{2,4})") |
| | | |
| | | # calendar.timegm() fails for timestamps after datetime.datetime.max |
| | | # Minus one as a loss of precision occurs when timestamp() is called. |
| | | MAX_TIME = ( |
| | | int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1 |
| | | ) |
| | | try: |
| | | calendar.timegm(time.gmtime(MAX_TIME)) |
| | | except (OSError, ValueError): |
| | | # Hit the maximum representable time on Windows |
| | | # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64 |
| | | # Throws ValueError on PyPy 3.9, OSError elsewhere |
| | | MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1)) |
| | | except OverflowError: |
| | | # #4515: datetime.max may not be representable on 32-bit platforms |
| | | MAX_TIME = 2**31 - 1 |
| | | # Precompute MAX_TIME - 1 so later comparisons avoid a subtraction (~3x faster)
| | | SUB_MAX_TIME = MAX_TIME - 1 |
| | | |
| | | def __init__( |
| | | self, |
| | | *, |
| | | unsafe: bool = False, |
| | | quote_cookie: bool = True, |
| | | treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | ) -> None: |
| | | super().__init__(loop=loop) |
| | | self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict( |
| | | SimpleCookie |
| | | ) |
| | | self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = ( |
| | | defaultdict(dict) |
| | | ) |
| | | self._host_only_cookies: Set[Tuple[str, str]] = set() |
| | | self._unsafe = unsafe |
| | | self._quote_cookie = quote_cookie |
| | | if treat_as_secure_origin is None: |
| | | treat_as_secure_origin = [] |
| | | elif isinstance(treat_as_secure_origin, URL): |
| | | treat_as_secure_origin = [treat_as_secure_origin.origin()] |
| | | elif isinstance(treat_as_secure_origin, str): |
| | | treat_as_secure_origin = [URL(treat_as_secure_origin).origin()] |
| | | else: |
| | | treat_as_secure_origin = [ |
| | | URL(url).origin() if isinstance(url, str) else url.origin() |
| | | for url in treat_as_secure_origin |
| | | ] |
| | | self._treat_as_secure_origin = treat_as_secure_origin |
| | | self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = [] |
| | | self._expirations: Dict[Tuple[str, str, str], float] = {} |
| | | |
| | | @property |
| | | def quote_cookie(self) -> bool: |
| | | return self._quote_cookie |
| | | |
| | | def save(self, file_path: PathLike) -> None: |
| | | file_path = pathlib.Path(file_path) |
| | | with file_path.open(mode="wb") as f: |
| | | pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL) |
| | | |
| | | def load(self, file_path: PathLike) -> None: |
| | | file_path = pathlib.Path(file_path) |
| | | with file_path.open(mode="rb") as f: |
| | | self._cookies = pickle.load(f) |
| | | |
| | | def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: |
| | | if predicate is None: |
| | | self._expire_heap.clear() |
| | | self._cookies.clear() |
| | | self._morsel_cache.clear() |
| | | self._host_only_cookies.clear() |
| | | self._expirations.clear() |
| | | return |
| | | |
| | | now = time.time() |
| | | to_del = [ |
| | | key |
| | | for (domain, path), cookie in self._cookies.items() |
| | | for name, morsel in cookie.items() |
| | | if ( |
| | | (key := (domain, path, name)) in self._expirations |
| | | and self._expirations[key] <= now |
| | | ) |
| | | or predicate(morsel) |
| | | ] |
| | | if to_del: |
| | | self._delete_cookies(to_del) |
| | | |
| | | def clear_domain(self, domain: str) -> None: |
| | | self.clear(lambda x: self._is_domain_match(domain, x["domain"])) |
| | | |
| | | def __iter__(self) -> "Iterator[Morsel[str]]": |
| | | self._do_expiration() |
| | | for val in self._cookies.values(): |
| | | yield from val.values() |
| | | |
| | | def __len__(self) -> int: |
| | | """Return number of cookies. |
| | | |
| | | This function does not iterate self to avoid unnecessary expiration |
| | | checks. |
| | | """ |
| | | return sum(len(cookie.values()) for cookie in self._cookies.values()) |
| | | |
| | | def _do_expiration(self) -> None: |
| | | """Remove expired cookies.""" |
| | | if not (expire_heap_len := len(self._expire_heap)): |
| | | return |
| | | |
| | | # If the expiration heap grows larger than the number of expirations
| | | # times two, we clean it up to avoid keeping expired entries in |
| | | # the heap and consuming memory. We guard this with a minimum |
| | | # threshold to avoid cleaning up the heap too often when there are |
| | | # only a few scheduled expirations. |
| | | if ( |
| | | expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION |
| | | and expire_heap_len > len(self._expirations) * 2 |
| | | ): |
| | | # Remove any heap entries whose expiration time does not
| | | # match the one recorded in self._expirations; a mismatch
| | | # means the cookie was re-added to the heap with a
| | | # different expiration time.
| | | self._expire_heap = [ |
| | | entry |
| | | for entry in self._expire_heap |
| | | if self._expirations.get(entry[1]) == entry[0] |
| | | ] |
| | | heapq.heapify(self._expire_heap) |
| | | |
| | | now = time.time() |
| | | to_del: List[Tuple[str, str, str]] = [] |
| | | # Find any expired cookies and add them to the to-delete list |
| | | while self._expire_heap: |
| | | when, cookie_key = self._expire_heap[0] |
| | | if when > now: |
| | | break |
| | | heapq.heappop(self._expire_heap) |
| | | # Only delete the cookie if this heap entry still matches
| | | # the recorded expiration time; a cookie re-added with a
| | | # different time is handled when its newer entry reaches
| | | # the top of the heap.
| | | if self._expirations.get(cookie_key) == when: |
| | | to_del.append(cookie_key) |
| | | |
| | | if to_del: |
| | | self._delete_cookies(to_del) |
| | | |
| | | def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None: |
| | | for domain, path, name in to_del: |
| | | self._host_only_cookies.discard((domain, name)) |
| | | self._cookies[(domain, path)].pop(name, None) |
| | | self._morsel_cache[(domain, path)].pop(name, None) |
| | | self._expirations.pop((domain, path, name), None) |
| | | |
| | | def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None: |
| | | cookie_key = (domain, path, name) |
| | | if self._expirations.get(cookie_key) == when: |
| | | # Avoid adding duplicates to the heap |
| | | return |
| | | heapq.heappush(self._expire_heap, (when, cookie_key)) |
| | | self._expirations[cookie_key] = when |
| | | |
| | | def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: |
| | | """Update cookies.""" |
| | | hostname = response_url.raw_host |
| | | |
| | | if not self._unsafe and is_ip_address(hostname): |
| | | # Don't accept cookies from IPs |
| | | return |
| | | |
| | | if isinstance(cookies, Mapping): |
| | | cookies = cookies.items() |
| | | |
| | | for name, cookie in cookies: |
| | | if not isinstance(cookie, Morsel): |
| | | tmp = SimpleCookie() |
| | | tmp[name] = cookie # type: ignore[assignment] |
| | | cookie = tmp[name] |
| | | |
| | | domain = cookie["domain"] |
| | | |
| | | # ignore domains with trailing dots |
| | | if domain and domain[-1] == ".": |
| | | domain = "" |
| | | del cookie["domain"] |
| | | |
| | | if not domain and hostname is not None: |
| | | # Set the cookie's domain to the response hostname |
| | | # and set its host-only-flag |
| | | self._host_only_cookies.add((hostname, name)) |
| | | domain = cookie["domain"] = hostname |
| | | |
| | | if domain and domain[0] == ".": |
| | | # Remove leading dot |
| | | domain = domain[1:] |
| | | cookie["domain"] = domain |
| | | |
| | | if hostname and not self._is_domain_match(domain, hostname): |
| | | # Setting cookies for different domains is not allowed |
| | | continue |
| | | |
| | | path = cookie["path"] |
| | | if not path or path[0] != "/": |
| | | # Set the cookie's path to the response path |
| | | path = response_url.path |
| | | if not path.startswith("/"): |
| | | path = "/" |
| | | else: |
| | | # Cut everything from the last slash to the end |
| | | path = "/" + path[1 : path.rfind("/")] |
| | | cookie["path"] = path |
| | | path = path.rstrip("/") |
| | | |
| | | if max_age := cookie["max-age"]: |
| | | try: |
| | | delta_seconds = int(max_age) |
| | | max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME) |
| | | self._expire_cookie(max_age_expiration, domain, path, name) |
| | | except ValueError: |
| | | cookie["max-age"] = "" |
| | | |
| | | elif expires := cookie["expires"]: |
| | | if expire_time := self._parse_date(expires): |
| | | self._expire_cookie(expire_time, domain, path, name) |
| | | else: |
| | | cookie["expires"] = "" |
| | | |
| | | key = (domain, path) |
| | | if self._cookies[key].get(name) != cookie: |
| | | # Don't blow away the cache if the same |
| | | # cookie gets set again |
| | | self._cookies[key][name] = cookie |
| | | self._morsel_cache[key].pop(name, None) |
| | | |
| | | self._do_expiration() |
| | | |
| | | def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]": |
| | | """Returns this jar's cookies filtered by their attributes.""" |
| | | # We always use BaseCookie now since all
| | | # cookies set on filtered are fully constructed
| | | # Morsels, not just names and values.
| | | filtered: BaseCookie[str] = BaseCookie() |
| | | if not self._cookies: |
| | | # Skip do_expiration() if there are no cookies. |
| | | return filtered |
| | | self._do_expiration() |
| | | if not self._cookies: |
| | | # Skip rest of function if no non-expired cookies. |
| | | return filtered |
| | | if type(request_url) is not URL: |
| | | warnings.warn( |
| | | "filter_cookies expects yarl.URL instances only," |
| | | f"and will stop working in 4.x, got {type(request_url)}", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | request_url = URL(request_url) |
| | | hostname = request_url.raw_host or "" |
| | | |
| | | is_not_secure = request_url.scheme not in ("https", "wss") |
| | | if is_not_secure and self._treat_as_secure_origin: |
| | | request_origin = URL() |
| | | with contextlib.suppress(ValueError): |
| | | request_origin = request_url.origin() |
| | | is_not_secure = request_origin not in self._treat_as_secure_origin |
| | | |
| | | # Send shared cookie |
| | | key = ("", "") |
| | | for c in self._cookies[key].values(): |
| | | # Check cache first |
| | | if c.key in self._morsel_cache[key]: |
| | | filtered[c.key] = self._morsel_cache[key][c.key] |
| | | continue |
| | | |
| | | # Build and cache the morsel |
| | | mrsl_val = self._build_morsel(c) |
| | | self._morsel_cache[key][c.key] = mrsl_val |
| | | filtered[c.key] = mrsl_val |
| | | |
| | | if is_ip_address(hostname): |
| | | if not self._unsafe: |
| | | return filtered |
| | | domains: Iterable[str] = (hostname,) |
| | | else: |
| | | # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com") |
| | | domains = itertools.accumulate( |
| | | reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED |
| | | ) |
| | | |
| | | # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar") |
| | | paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH) |
| | | # Create every combination of (domain, path) pairs. |
| | | pairs = itertools.product(domains, paths) |
| | | |
| | | path_len = len(request_url.path) |
| | | # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4 |
| | | for p in pairs: |
| | | if p not in self._cookies: |
| | | continue |
| | | for name, cookie in self._cookies[p].items(): |
| | | domain = cookie["domain"] |
| | | |
| | | if (domain, name) in self._host_only_cookies and domain != hostname: |
| | | continue |
| | | |
| | | # Skip edge case when the cookie has a trailing slash but request doesn't. |
| | | if len(cookie["path"]) > path_len: |
| | | continue |
| | | |
| | | if is_not_secure and cookie["secure"]: |
| | | continue |
| | | |
| | | # We already built the Morsel so reuse it here |
| | | if name in self._morsel_cache[p]: |
| | | filtered[name] = self._morsel_cache[p][name] |
| | | continue |
| | | |
| | | # Build and cache the morsel |
| | | mrsl_val = self._build_morsel(cookie) |
| | | self._morsel_cache[p][name] = mrsl_val |
| | | filtered[name] = mrsl_val |
| | | |
| | | return filtered |
| | | |
| | | def _build_morsel(self, cookie: Morsel[str]) -> Morsel[str]: |
| | | """Build a morsel for sending, respecting quote_cookie setting.""" |
| | | if self._quote_cookie and cookie.coded_value and cookie.coded_value[0] == '"': |
| | | return preserve_morsel_with_coded_value(cookie) |
| | | morsel: Morsel[str] = Morsel() |
| | | if self._quote_cookie: |
| | | value, coded_value = _SIMPLE_COOKIE.value_encode(cookie.value) |
| | | else: |
| | | coded_value = value = cookie.value |
| | | # We use __setstate__ instead of the public set() API because it allows us to |
| | | # bypass validation and set already validated state. This is more stable than |
| | | # setting protected attributes directly and unlikely to change since it would |
| | | # break pickling. |
| | | morsel.__setstate__({"key": cookie.key, "value": value, "coded_value": coded_value}) # type: ignore[attr-defined] |
| | | return morsel |
| | | |
| | | @staticmethod |
| | | def _is_domain_match(domain: str, hostname: str) -> bool: |
| | | """Implements domain matching adhering to RFC 6265.""" |
| | | if hostname == domain: |
| | | return True |
| | | |
| | | if not hostname.endswith(domain): |
| | | return False |
| | | |
| | | non_matching = hostname[: -len(domain)] |
| | | |
| | | if not non_matching.endswith("."): |
| | | return False |
| | | |
| | | return not is_ip_address(hostname) |
| | | |
| | | @classmethod |
| | | def _parse_date(cls, date_str: str) -> Optional[int]: |
| | | """Implements date string parsing adhering to RFC 6265.""" |
| | | if not date_str: |
| | | return None |
| | | |
| | | found_time = False |
| | | found_day = False |
| | | found_month = False |
| | | found_year = False |
| | | |
| | | hour = minute = second = 0 |
| | | day = 0 |
| | | month = 0 |
| | | year = 0 |
| | | |
| | | for token_match in cls.DATE_TOKENS_RE.finditer(date_str): |
| | | |
| | | token = token_match.group("token") |
| | | |
| | | if not found_time: |
| | | time_match = cls.DATE_HMS_TIME_RE.match(token) |
| | | if time_match: |
| | | found_time = True |
| | | hour, minute, second = (int(s) for s in time_match.groups()) |
| | | continue |
| | | |
| | | if not found_day: |
| | | day_match = cls.DATE_DAY_OF_MONTH_RE.match(token) |
| | | if day_match: |
| | | found_day = True |
| | | day = int(day_match.group()) |
| | | continue |
| | | |
| | | if not found_month: |
| | | month_match = cls.DATE_MONTH_RE.match(token) |
| | | if month_match: |
| | | found_month = True |
| | | assert month_match.lastindex is not None |
| | | month = month_match.lastindex |
| | | continue |
| | | |
| | | if not found_year: |
| | | year_match = cls.DATE_YEAR_RE.match(token) |
| | | if year_match: |
| | | found_year = True |
| | | year = int(year_match.group()) |
| | | |
| | | if 70 <= year <= 99: |
| | | year += 1900 |
| | | elif 0 <= year <= 69: |
| | | year += 2000 |
| | | |
| | | if not (found_day and found_month and found_year and found_time):
| | | return None |
| | | |
| | | if not 1 <= day <= 31: |
| | | return None |
| | | |
| | | if year < 1601 or hour > 23 or minute > 59 or second > 59: |
| | | return None |
| | | |
| | | return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1)) |
| | | |
| | | |
| | | class DummyCookieJar(AbstractCookieJar): |
| | | """Implements a dummy cookie storage. |
| | | |
| | | It can be used with the ClientSession when no cookie processing is needed. |
| | | |
| | | """ |
| | | |
| | | def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: |
| | | super().__init__(loop=loop) |
| | | |
| | | def __iter__(self) -> "Iterator[Morsel[str]]": |
| | | # An always-empty generator: the loop body never runs,
| | | # but the `yield` makes this function a generator.
| | | while False:
| | | yield None
| | | |
| | | def __len__(self) -> int: |
| | | return 0 |
| | | |
| | | @property |
| | | def quote_cookie(self) -> bool: |
| | | return True |
| | | |
| | | def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: |
| | | pass |
| | | |
| | | def clear_domain(self, domain: str) -> None: |
| | | pass |
| | | |
| | | def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: |
| | | pass |
| | | |
| | | def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": |
| | | return SimpleCookie() |
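The cookie expiration bookkeeping above is a heap with lazy deletion: `_expire_cookie` pushes `(when, key)` entries and records the authoritative time in `_expirations`, while `_do_expiration` pops due entries and skips stale ones whose time no longer matches. A stripped-down, stdlib-only sketch of the pattern (names are illustrative, not the aiohttp API):

```python
import heapq


class ExpiryTracker:
    """Lazy-deletion expiry heap, modeled on CookieJar's bookkeeping."""

    def __init__(self) -> None:
        self._heap: list = []         # (when, key) entries, possibly stale
        self._expirations: dict = {}  # key -> authoritative expiration time

    def schedule(self, when: float, key: str) -> None:
        if self._expirations.get(key) == when:
            return  # avoid pushing duplicate heap entries
        heapq.heappush(self._heap, (when, key))
        self._expirations[key] = when

    def expired(self, now: float) -> list:
        out = []
        while self._heap:
            when, key = self._heap[0]
            if when > now:
                break
            heapq.heappop(self._heap)
            # Skip stale entries: the key was re-scheduled with a new time.
            if self._expirations.get(key) == when:
                out.append(key)
                del self._expirations[key]
        return out


tracker = ExpiryTracker()
tracker.schedule(10.0, "a")
tracker.schedule(20.0, "b")
tracker.schedule(30.0, "a")  # re-schedule "a": the (10.0, "a") entry is now stale
```

Stale entries cost a little heap memory until popped, which is why `_do_expiration` also rebuilds the heap once stale entries outnumber live ones past a threshold.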
| New file |
| | |
| | | import io |
| | | import warnings |
| | | from typing import Any, Iterable, List, Optional |
| | | from urllib.parse import urlencode |
| | | |
| | | from multidict import MultiDict, MultiDictProxy |
| | | |
| | | from . import hdrs, multipart, payload |
| | | from .helpers import guess_filename |
| | | from .payload import Payload |
| | | |
| | | __all__ = ("FormData",) |
| | | |
| | | |
| | | class FormData: |
| | | """Helper class for form body generation. |
| | | |
| | | Supports multipart/form-data and application/x-www-form-urlencoded. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | fields: Iterable[Any] = (), |
| | | quote_fields: bool = True, |
| | | charset: Optional[str] = None, |
| | | *, |
| | | default_to_multipart: bool = False, |
| | | ) -> None: |
| | | self._writer = multipart.MultipartWriter("form-data") |
| | | self._fields: List[Any] = [] |
| | | self._is_multipart = default_to_multipart |
| | | self._quote_fields = quote_fields |
| | | self._charset = charset |
| | | |
| | | if isinstance(fields, dict): |
| | | fields = list(fields.items()) |
| | | elif not isinstance(fields, (list, tuple)): |
| | | fields = (fields,) |
| | | self.add_fields(*fields) |
| | | |
| | | @property |
| | | def is_multipart(self) -> bool: |
| | | return self._is_multipart |
| | | |
| | | def add_field( |
| | | self, |
| | | name: str, |
| | | value: Any, |
| | | *, |
| | | content_type: Optional[str] = None, |
| | | filename: Optional[str] = None, |
| | | content_transfer_encoding: Optional[str] = None, |
| | | ) -> None: |
| | | |
| | | if isinstance(value, io.IOBase): |
| | | self._is_multipart = True |
| | | elif isinstance(value, (bytes, bytearray, memoryview)): |
| | | msg = ( |
| | | "In v4, passing bytes will no longer create a file field. " |
| | | "Please explicitly use the filename parameter or pass a BytesIO object." |
| | | ) |
| | | if filename is None and content_transfer_encoding is None: |
| | | warnings.warn(msg, DeprecationWarning) |
| | | filename = name |
| | | |
| | | type_options: MultiDict[str] = MultiDict({"name": name}) |
| | | if filename is not None and not isinstance(filename, str): |
| | | raise TypeError("filename must be an instance of str. Got: %s" % filename) |
| | | if filename is None and isinstance(value, io.IOBase): |
| | | filename = guess_filename(value, name) |
| | | if filename is not None: |
| | | type_options["filename"] = filename |
| | | self._is_multipart = True |
| | | |
| | | headers = {} |
| | | if content_type is not None: |
| | | if not isinstance(content_type, str): |
| | | raise TypeError( |
| | | "content_type must be an instance of str. Got: %s" % content_type |
| | | ) |
| | | headers[hdrs.CONTENT_TYPE] = content_type |
| | | self._is_multipart = True |
| | | if content_transfer_encoding is not None: |
| | | if not isinstance(content_transfer_encoding, str): |
| | | raise TypeError( |
| | | "content_transfer_encoding must be an instance" |
| | | " of str. Got: %s" % content_transfer_encoding |
| | | ) |
| | | msg = ( |
| | | "content_transfer_encoding is deprecated. " |
| | | "To maintain compatibility with v4 please pass a BytesPayload." |
| | | ) |
| | | warnings.warn(msg, DeprecationWarning) |
| | | self._is_multipart = True |
| | | |
| | | self._fields.append((type_options, headers, value)) |
| | | |
| | | def add_fields(self, *fields: Any) -> None: |
| | | to_add = list(fields) |
| | | |
| | | while to_add: |
| | | rec = to_add.pop(0) |
| | | |
| | | if isinstance(rec, io.IOBase): |
| | | k = guess_filename(rec, "unknown") |
| | | self.add_field(k, rec) # type: ignore[arg-type] |
| | | |
| | | elif isinstance(rec, (MultiDictProxy, MultiDict)): |
| | | to_add.extend(rec.items()) |
| | | |
| | | elif isinstance(rec, (list, tuple)) and len(rec) == 2: |
| | | k, fp = rec |
| | | self.add_field(k, fp) |
| | | |
| | | else: |
| | | raise TypeError( |
| | | "Only io.IOBase, multidict and (name, file) " |
| | | "pairs allowed, use .add_field() for passing " |
| | | "more complex parameters, got {!r}".format(rec) |
| | | ) |
| | | |
| | | def _gen_form_urlencoded(self) -> payload.BytesPayload: |
| | | # form data (x-www-form-urlencoded) |
| | | data = [] |
| | | for type_options, _, value in self._fields: |
| | | data.append((type_options["name"], value)) |
| | | |
| | | charset = self._charset if self._charset is not None else "utf-8" |
| | | |
| | | if charset == "utf-8": |
| | | content_type = "application/x-www-form-urlencoded" |
| | | else: |
| | | content_type = "application/x-www-form-urlencoded; charset=%s" % charset |
| | | |
| | | return payload.BytesPayload( |
| | | urlencode(data, doseq=True, encoding=charset).encode(), |
| | | content_type=content_type, |
| | | ) |
| | | |
| | | def _gen_form_data(self) -> multipart.MultipartWriter: |
| | | """Encode a list of fields using the multipart/form-data MIME format""" |
| | | for dispparams, headers, value in self._fields: |
| | | try: |
| | | if hdrs.CONTENT_TYPE in headers: |
| | | part = payload.get_payload( |
| | | value, |
| | | content_type=headers[hdrs.CONTENT_TYPE], |
| | | headers=headers, |
| | | encoding=self._charset, |
| | | ) |
| | | else: |
| | | part = payload.get_payload( |
| | | value, headers=headers, encoding=self._charset |
| | | ) |
| | | except Exception as exc: |
| | | raise TypeError( |
| | | "Can not serialize value type: %r\n " |
| | | "headers: %r\n value: %r" % (type(value), headers, value) |
| | | ) from exc |
| | | |
| | | if dispparams: |
| | | part.set_content_disposition( |
| | | "form-data", quote_fields=self._quote_fields, **dispparams |
| | | ) |
| | | # FIXME cgi.FieldStorage doesn't like body parts with
| | | # Content-Length which were sent via chunked transfer encoding
| | | assert part.headers is not None |
| | | part.headers.popall(hdrs.CONTENT_LENGTH, None) |
| | | |
| | | self._writer.append_payload(part) |
| | | |
| | | self._fields.clear() |
| | | return self._writer |
| | | |
| | | def __call__(self) -> Payload: |
| | | if self._is_multipart: |
| | | return self._gen_form_data() |
| | | else: |
| | | return self._gen_form_urlencoded() |
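`_gen_form_urlencoded` above delegates the wire encoding to `urllib.parse.urlencode` with `doseq=True`. A small standalone sketch of that serialization step:

```python
from urllib.parse import urlencode

# (name, value) pairs as collected from the form fields
fields = [("name", "alice"), ("tags", ["a", "b"])]

# doseq=True expands sequence values into repeated keys
body = urlencode(fields, doseq=True, encoding="utf-8").encode()
assert body == b"name=alice&tags=a&tags=b"

# Non-UTF-8 charsets are reflected in the Content-Type header,
# mirroring the branch in _gen_form_urlencoded.
content_type = "application/x-www-form-urlencoded; charset=%s" % "latin-1"
```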
| New file |
| | |
| | | """HTTP Headers constants.""" |
| | | |
| | | # After changing the file content call ./tools/gen.py |
| | | # to regenerate the headers parser |
| | | import itertools |
| | | from typing import Final, Set |
| | | |
| | | from multidict import istr |
| | | |
| | | METH_ANY: Final[str] = "*" |
| | | METH_CONNECT: Final[str] = "CONNECT" |
| | | METH_HEAD: Final[str] = "HEAD" |
| | | METH_GET: Final[str] = "GET" |
| | | METH_DELETE: Final[str] = "DELETE" |
| | | METH_OPTIONS: Final[str] = "OPTIONS" |
| | | METH_PATCH: Final[str] = "PATCH" |
| | | METH_POST: Final[str] = "POST" |
| | | METH_PUT: Final[str] = "PUT" |
| | | METH_TRACE: Final[str] = "TRACE" |
| | | |
| | | METH_ALL: Final[Set[str]] = { |
| | | METH_CONNECT, |
| | | METH_HEAD, |
| | | METH_GET, |
| | | METH_DELETE, |
| | | METH_OPTIONS, |
| | | METH_PATCH, |
| | | METH_POST, |
| | | METH_PUT, |
| | | METH_TRACE, |
| | | } |
| | | |
| | | ACCEPT: Final[istr] = istr("Accept") |
| | | ACCEPT_CHARSET: Final[istr] = istr("Accept-Charset") |
| | | ACCEPT_ENCODING: Final[istr] = istr("Accept-Encoding") |
| | | ACCEPT_LANGUAGE: Final[istr] = istr("Accept-Language") |
| | | ACCEPT_RANGES: Final[istr] = istr("Accept-Ranges") |
| | | ACCESS_CONTROL_MAX_AGE: Final[istr] = istr("Access-Control-Max-Age") |
| | | ACCESS_CONTROL_ALLOW_CREDENTIALS: Final[istr] = istr("Access-Control-Allow-Credentials") |
| | | ACCESS_CONTROL_ALLOW_HEADERS: Final[istr] = istr("Access-Control-Allow-Headers") |
| | | ACCESS_CONTROL_ALLOW_METHODS: Final[istr] = istr("Access-Control-Allow-Methods") |
| | | ACCESS_CONTROL_ALLOW_ORIGIN: Final[istr] = istr("Access-Control-Allow-Origin") |
| | | ACCESS_CONTROL_EXPOSE_HEADERS: Final[istr] = istr("Access-Control-Expose-Headers") |
| | | ACCESS_CONTROL_REQUEST_HEADERS: Final[istr] = istr("Access-Control-Request-Headers") |
| | | ACCESS_CONTROL_REQUEST_METHOD: Final[istr] = istr("Access-Control-Request-Method") |
| | | AGE: Final[istr] = istr("Age") |
| | | ALLOW: Final[istr] = istr("Allow") |
| | | AUTHORIZATION: Final[istr] = istr("Authorization") |
| | | CACHE_CONTROL: Final[istr] = istr("Cache-Control") |
| | | CONNECTION: Final[istr] = istr("Connection") |
| | | CONTENT_DISPOSITION: Final[istr] = istr("Content-Disposition") |
| | | CONTENT_ENCODING: Final[istr] = istr("Content-Encoding") |
| | | CONTENT_LANGUAGE: Final[istr] = istr("Content-Language") |
| | | CONTENT_LENGTH: Final[istr] = istr("Content-Length") |
| | | CONTENT_LOCATION: Final[istr] = istr("Content-Location") |
| | | CONTENT_MD5: Final[istr] = istr("Content-MD5") |
| | | CONTENT_RANGE: Final[istr] = istr("Content-Range") |
| | | CONTENT_TRANSFER_ENCODING: Final[istr] = istr("Content-Transfer-Encoding") |
| | | CONTENT_TYPE: Final[istr] = istr("Content-Type") |
| | | COOKIE: Final[istr] = istr("Cookie") |
| | | DATE: Final[istr] = istr("Date") |
| | | DESTINATION: Final[istr] = istr("Destination") |
| | | DIGEST: Final[istr] = istr("Digest") |
| | | ETAG: Final[istr] = istr("Etag") |
| | | EXPECT: Final[istr] = istr("Expect") |
| | | EXPIRES: Final[istr] = istr("Expires") |
| | | FORWARDED: Final[istr] = istr("Forwarded") |
| | | FROM: Final[istr] = istr("From") |
| | | HOST: Final[istr] = istr("Host") |
| | | IF_MATCH: Final[istr] = istr("If-Match") |
| | | IF_MODIFIED_SINCE: Final[istr] = istr("If-Modified-Since") |
| | | IF_NONE_MATCH: Final[istr] = istr("If-None-Match") |
| | | IF_RANGE: Final[istr] = istr("If-Range") |
| | | IF_UNMODIFIED_SINCE: Final[istr] = istr("If-Unmodified-Since") |
| | | KEEP_ALIVE: Final[istr] = istr("Keep-Alive") |
| | | LAST_EVENT_ID: Final[istr] = istr("Last-Event-ID") |
| | | LAST_MODIFIED: Final[istr] = istr("Last-Modified") |
| | | LINK: Final[istr] = istr("Link") |
| | | LOCATION: Final[istr] = istr("Location") |
| | | MAX_FORWARDS: Final[istr] = istr("Max-Forwards") |
| | | ORIGIN: Final[istr] = istr("Origin") |
| | | PRAGMA: Final[istr] = istr("Pragma") |
| | | PROXY_AUTHENTICATE: Final[istr] = istr("Proxy-Authenticate") |
| | | PROXY_AUTHORIZATION: Final[istr] = istr("Proxy-Authorization") |
| | | RANGE: Final[istr] = istr("Range") |
| | | REFERER: Final[istr] = istr("Referer") |
| | | RETRY_AFTER: Final[istr] = istr("Retry-After") |
| | | SEC_WEBSOCKET_ACCEPT: Final[istr] = istr("Sec-WebSocket-Accept") |
| | | SEC_WEBSOCKET_VERSION: Final[istr] = istr("Sec-WebSocket-Version") |
| | | SEC_WEBSOCKET_PROTOCOL: Final[istr] = istr("Sec-WebSocket-Protocol") |
| | | SEC_WEBSOCKET_EXTENSIONS: Final[istr] = istr("Sec-WebSocket-Extensions") |
| | | SEC_WEBSOCKET_KEY: Final[istr] = istr("Sec-WebSocket-Key") |
| | | SEC_WEBSOCKET_KEY1: Final[istr] = istr("Sec-WebSocket-Key1") |
| | | SERVER: Final[istr] = istr("Server") |
| | | SET_COOKIE: Final[istr] = istr("Set-Cookie") |
| | | TE: Final[istr] = istr("TE") |
| | | TRAILER: Final[istr] = istr("Trailer") |
| | | TRANSFER_ENCODING: Final[istr] = istr("Transfer-Encoding") |
| | | UPGRADE: Final[istr] = istr("Upgrade") |
| | | URI: Final[istr] = istr("URI") |
| | | USER_AGENT: Final[istr] = istr("User-Agent") |
| | | VARY: Final[istr] = istr("Vary") |
| | | VIA: Final[istr] = istr("Via") |
| | | WANT_DIGEST: Final[istr] = istr("Want-Digest") |
| | | WARNING: Final[istr] = istr("Warning") |
| | | WWW_AUTHENTICATE: Final[istr] = istr("WWW-Authenticate") |
| | | X_FORWARDED_FOR: Final[istr] = istr("X-Forwarded-For") |
| | | X_FORWARDED_HOST: Final[istr] = istr("X-Forwarded-Host") |
| | | X_FORWARDED_PROTO: Final[istr] = istr("X-Forwarded-Proto") |
| | | |
| | | # These are the upper/lower case variants of the headers/methods |
| | | # Example: {'hOst', 'host', 'HoST', 'HOSt', 'hOsT', 'HosT', 'hoSt', ...} |
| | | METH_HEAD_ALL: Final = frozenset( |
| | | map("".join, itertools.product(*zip(METH_HEAD.upper(), METH_HEAD.lower()))) |
| | | ) |
| | | METH_CONNECT_ALL: Final = frozenset( |
| | | map("".join, itertools.product(*zip(METH_CONNECT.upper(), METH_CONNECT.lower()))) |
| | | ) |
| | | HOST_ALL: Final = frozenset( |
| | | map("".join, itertools.product(*zip(HOST.upper(), HOST.lower()))) |
| | | ) |
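The `METH_HEAD_ALL`, `METH_CONNECT_ALL`, and `HOST_ALL` sets above precompute every upper/lower-case spelling of a token so that later membership tests avoid per-request case folding. A stdlib-only sketch of the same `itertools.product` trick (the helper name here is illustrative, not part of the module):

```python
import itertools


def all_case_variants(word: str) -> frozenset:
    # zip(word.upper(), word.lower()) yields one (upper, lower) pair per
    # character; itertools.product then picks one spelling per position.
    return frozenset(
        map("".join, itertools.product(*zip(word.upper(), word.lower())))
    )


variants = all_case_variants("HEAD")
assert len(variants) == 2 ** 4  # 16 spellings for 4 alphabetic characters
assert {"HEAD", "head", "hEaD"} <= variants
```

Precomputing the frozenset trades a small amount of memory for an O(1) case-insensitive lookup on the hot path.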
| New file |
| | |
| | | """Various helper functions""" |
| | | |
| | | import asyncio |
| | | import base64 |
| | | import binascii |
| | | import contextlib |
| | | import datetime |
| | | import enum |
| | | import functools |
| | | import inspect |
| | | import netrc |
| | | import os |
| | | import platform |
| | | import re |
| | | import sys |
| | | import time |
| | | import weakref |
| | | from collections import namedtuple |
| | | from contextlib import suppress |
| | | from email.message import EmailMessage |
| | | from email.parser import HeaderParser |
| | | from email.policy import HTTP |
| | | from email.utils import parsedate |
| | | from math import ceil |
| | | from pathlib import Path |
| | | from types import MappingProxyType, TracebackType |
| | | from typing import ( |
| | | Any, |
| | | Callable, |
| | | ContextManager, |
| | | Dict, |
| | | Generator, |
| | | Generic, |
| | | Iterable, |
| | | Iterator, |
| | | List, |
| | | Mapping, |
| | | Optional, |
| | | Protocol, |
| | | Tuple, |
| | | Type, |
| | | TypeVar, |
| | | Union, |
| | | get_args, |
| | | overload, |
| | | ) |
| | | from urllib.parse import quote |
| | | from urllib.request import getproxies, proxy_bypass |
| | | |
| | | import attr |
| | | from multidict import MultiDict, MultiDictProxy, MultiMapping |
| | | from propcache.api import under_cached_property as reify |
| | | from yarl import URL |
| | | |
| | | from . import hdrs |
| | | from .log import client_logger |
| | | |
| | | if sys.version_info >= (3, 11): |
| | | import asyncio as async_timeout |
| | | else: |
| | | import async_timeout |
| | | |
| | | __all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify") |
| | | |
| | | IS_MACOS = platform.system() == "Darwin" |
| | | IS_WINDOWS = platform.system() == "Windows" |
| | | |
| | | PY_310 = sys.version_info >= (3, 10) |
| | | PY_311 = sys.version_info >= (3, 11) |
| | | |
| | | |
| | | _T = TypeVar("_T") |
| | | _S = TypeVar("_S") |
| | | |
| | | _SENTINEL = enum.Enum("_SENTINEL", "sentinel") |
| | | sentinel = _SENTINEL.sentinel |
| | | |
| | | NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS")) |
| | | |
| | | # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1 |
| | | EMPTY_BODY_STATUS_CODES = frozenset((204, 304, *range(100, 200))) |
| | | # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1 |
| | | # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2 |
| | | EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL |
| | | |
| | | DEBUG = sys.flags.dev_mode or ( |
| | | not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG")) |
| | | ) |
| | | |
| | | |
| | | CHAR = {chr(i) for i in range(0, 128)} |
| | | CTL = {chr(i) for i in range(0, 32)} | { |
| | | chr(127), |
| | | } |
| | | SEPARATORS = { |
| | | "(", |
| | | ")", |
| | | "<", |
| | | ">", |
| | | "@", |
| | | ",", |
| | | ";", |
| | | ":", |
| | | "\\", |
| | | '"', |
| | | "/", |
| | | "[", |
| | | "]", |
| | | "?", |
| | | "=", |
| | | "{", |
| | | "}", |
| | | " ", |
| | | chr(9), |
| | | } |
| | | TOKEN = CHAR ^ CTL ^ SEPARATORS |
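The symmetric differences above deserve a note: because `chr(9)` is listed in both `CTL` and `SEPARATORS`, the double XOR re-adds the tab character to `TOKEN`, unlike plain set subtraction would. A self-contained reproduction of the set algebra:

```python
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {chr(127)}
SEPARATORS = set('()<>@,;:\\"/[]?={} ') | {chr(9)}
TOKEN = CHAR ^ CTL ^ SEPARATORS

assert "!" in TOKEN and "a" in TOKEN and "0" in TOKEN
assert "(" not in TOKEN and " " not in TOKEN
# Tab sits in both CTL and SEPARATORS, so the double symmetric
# difference re-adds it; plain set subtraction would exclude it:
assert "\t" in TOKEN
assert "\t" not in (CHAR - CTL - SEPARATORS)
```

Symmetric difference keeps an element that appears an odd number of times across the operands, which is why the membership of tab differs between the two formulations.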
| | | |
| | | |
| | | class noop: |
| | | def __await__(self) -> Generator[None, None, None]: |
| | | yield |
| | | |
| | | |
| | | class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])): |
| | | """HTTP basic authentication helper."""
| | | |
| | | def __new__( |
| | | cls, login: str, password: str = "", encoding: str = "latin1" |
| | | ) -> "BasicAuth": |
| | | if login is None: |
| | | raise ValueError("None is not allowed as login value") |
| | | |
| | | if password is None: |
| | | raise ValueError("None is not allowed as password value") |
| | | |
| | | if ":" in login: |
| | | raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)') |
| | | |
| | | return super().__new__(cls, login, password, encoding) |
| | | |
| | | @classmethod |
| | | def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth": |
| | | """Create a BasicAuth object from an Authorization HTTP header.""" |
| | | try: |
| | | auth_type, encoded_credentials = auth_header.split(" ", 1) |
| | | except ValueError: |
| | | raise ValueError("Could not parse authorization header.") |
| | | |
| | | if auth_type.lower() != "basic": |
| | | raise ValueError("Unknown authorization method %s" % auth_type) |
| | | |
| | | try: |
| | | decoded = base64.b64decode( |
| | | encoded_credentials.encode("ascii"), validate=True |
| | | ).decode(encoding) |
| | | except binascii.Error: |
| | | raise ValueError("Invalid base64 encoding.") |
| | | |
| | | try: |
| | | # RFC 2617 HTTP Authentication |
| | | # https://www.ietf.org/rfc/rfc2617.txt |
| | | # the colon must be present, but the username and password may be |
| | | # otherwise blank. |
| | | username, password = decoded.split(":", 1) |
| | | except ValueError: |
| | | raise ValueError("Invalid credentials.") |
| | | |
| | | return cls(username, password, encoding=encoding) |
| | | |
| | | @classmethod |
| | | def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]: |
| | | """Create BasicAuth from url.""" |
| | | if not isinstance(url, URL): |
| | | raise TypeError("url should be yarl.URL instance") |
| | | # Check raw_user and raw_password first as yarl is likely |
| | | # to already have these values parsed from the netloc in the cache. |
| | | if url.raw_user is None and url.raw_password is None: |
| | | return None |
| | | return cls(url.user or "", url.password or "", encoding=encoding) |
| | | |
| | | def encode(self) -> str: |
| | | """Encode credentials.""" |
| | | creds = (f"{self.login}:{self.password}").encode(self.encoding) |
| | | return "Basic %s" % base64.b64encode(creds).decode(self.encoding) |
| | | |
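The `encode`/`decode` pair above round-trips credentials through base64. A stdlib-only sketch of the same scheme that runs without importing aiohttp (the function names are illustrative):

```python
import base64


def encode_basic(login: str, password: str, encoding: str = "latin1") -> str:
    # Mirrors BasicAuth.encode(): join as "login:password", base64-encode,
    # prefix with the "Basic " scheme.
    creds = f"{login}:{password}".encode(encoding)
    return "Basic %s" % base64.b64encode(creds).decode(encoding)


def decode_basic(header: str, encoding: str = "latin1") -> tuple:
    # Mirrors BasicAuth.decode(): split off the scheme, validate the
    # base64 payload, then split on the FIRST colon only, because the
    # password itself may contain colons.
    auth_type, _, encoded = header.partition(" ")
    if auth_type.lower() != "basic":
        raise ValueError("Unknown authorization method %s" % auth_type)
    decoded = base64.b64decode(
        encoded.encode("ascii"), validate=True
    ).decode(encoding)
    username, _, password = decoded.partition(":")
    return username, password


hdr = encode_basic("user", "pa:ss")
assert decode_basic(hdr) == ("user", "pa:ss")
```

Splitting on the first colon is the behavior RFC 7617 requires: user-ids may not contain a colon, but passwords may.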
| | | |
| | | def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: |
| | | """Remove user and password from URL if present and return BasicAuth object.""" |
| | | # Check raw_user and raw_password first as yarl is likely |
| | | # to already have these values parsed from the netloc in the cache. |
| | | if url.raw_user is None and url.raw_password is None: |
| | | return url, None |
| | | return url.with_user(None), BasicAuth(url.user or "", url.password or "") |
| | | |
| | | |
| | | def netrc_from_env() -> Optional[netrc.netrc]: |
| | | """Load netrc from file. |
| | | |
| | | Attempt to load it from the path specified by the env-var |
| | | NETRC or in the default location in the user's home directory. |
| | | |
| | | Returns None if it couldn't be found or fails to parse. |
| | | """ |
| | | netrc_env = os.environ.get("NETRC") |
| | | |
| | | if netrc_env is not None: |
| | | netrc_path = Path(netrc_env) |
| | | else: |
| | | try: |
| | | home_dir = Path.home() |
| | | except RuntimeError as e: # pragma: no cover |
| | | # if pathlib can't resolve home, it may raise a RuntimeError |
| | | client_logger.debug( |
| | | "Could not resolve home directory when " |
| | | "trying to look for .netrc file: %s", |
| | | e, |
| | | ) |
| | | return None |
| | | |
| | | netrc_path = home_dir / ("_netrc" if IS_WINDOWS else ".netrc") |
| | | |
| | | try: |
| | | return netrc.netrc(str(netrc_path)) |
| | | except netrc.NetrcParseError as e: |
| | | client_logger.warning("Could not parse .netrc file: %s", e) |
| | | except OSError as e: |
| | | netrc_exists = False |
| | | with contextlib.suppress(OSError): |
| | | netrc_exists = netrc_path.is_file() |
| | | # we couldn't read the file (doesn't exist, permissions, etc.) |
| | | if netrc_env or netrc_exists: |
| | | # only warn if the environment wanted us to load it, |
| | | # or it appears like the default file does actually exist |
| | | client_logger.warning("Could not read .netrc file: %s", e) |
| | | |
| | | return None |
| | | |
| | | |
| | | @attr.s(auto_attribs=True, frozen=True, slots=True) |
| | | class ProxyInfo: |
| | | proxy: URL |
| | | proxy_auth: Optional[BasicAuth] |
| | | |
| | | |
| | | def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth: |
| | | """ |
| | | Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``. |
| | | |
| | | :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no |
| | | entry is found for the ``host``. |
| | | """ |
| | | if netrc_obj is None: |
| | | raise LookupError("No .netrc file found") |
| | | auth_from_netrc = netrc_obj.authenticators(host) |
| | | |
| | | if auth_from_netrc is None: |
| | | raise LookupError(f"No entry for {host!s} found in the `.netrc` file.") |
| | | login, account, password = auth_from_netrc |
| | | |
| | | # TODO(PY311): username = login or account |
| | | # Up to Python 3.10, account could be None if not specified,
| | | # and login will be an empty string if not specified. From 3.11,
| | | # both login and account will be empty strings if not specified.
| | | username = login if (login or account is None) else account |
| | | |
| | | # TODO(PY311): Remove this, as password will be empty string |
| | | # if not specified |
| | | if password is None: |
| | | password = "" |
| | | |
| | | return BasicAuth(username, password) |
| | | |
| | | |
| | | def proxies_from_env() -> Dict[str, ProxyInfo]: |
| | | proxy_urls = { |
| | | k: URL(v) |
| | | for k, v in getproxies().items() |
| | | if k in ("http", "https", "ws", "wss") |
| | | } |
| | | netrc_obj = netrc_from_env() |
| | | stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()} |
| | | ret = {} |
| | | for proto, val in stripped.items(): |
| | | proxy, auth = val |
| | | if proxy.scheme in ("https", "wss"): |
| | | client_logger.warning( |
| | | "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy |
| | | ) |
| | | continue |
| | | if netrc_obj and auth is None: |
| | | if proxy.host is not None: |
| | | try: |
| | | auth = basicauth_from_netrc(netrc_obj, proxy.host) |
| | | except LookupError: |
| | | auth = None |
| | | ret[proto] = ProxyInfo(proxy, auth) |
| | | return ret |
| | | |
| | | |
| | | def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: |
| | | """Get a permitted proxy for the given URL from the env.""" |
| | | if url.host is not None and proxy_bypass(url.host): |
| | | raise LookupError(f"Proxying is disallowed for `{url.host!r}`") |
| | | |
| | | proxies_in_env = proxies_from_env() |
| | | try: |
| | | proxy_info = proxies_in_env[url.scheme] |
| | | except KeyError: |
| | | raise LookupError(f"No proxies found for `{url!s}` in the env") |
| | | else: |
| | | return proxy_info.proxy, proxy_info.proxy_auth |
| | | |
| | | |
| | | @attr.s(auto_attribs=True, frozen=True, slots=True) |
| | | class MimeType: |
| | | type: str |
| | | subtype: str |
| | | suffix: str |
| | | parameters: "MultiDictProxy[str]" |
| | | |
| | | |
| | | @functools.lru_cache(maxsize=56) |
| | | def parse_mimetype(mimetype: str) -> MimeType: |
| | | """Parses a MIME type into its components. |
| | | |
| | | mimetype is a MIME type string. |
| | | |
| | | Returns a MimeType object. |
| | | |
| | | Example: |
| | | |
| | | >>> parse_mimetype('text/html; charset=utf-8') |
| | | MimeType(type='text', subtype='html', suffix='', |
| | | parameters={'charset': 'utf-8'}) |
| | | |
| | | """ |
| | | if not mimetype: |
| | | return MimeType( |
| | | type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict()) |
| | | ) |
| | | |
| | | parts = mimetype.split(";") |
| | | params: MultiDict[str] = MultiDict() |
| | | for item in parts[1:]: |
| | | if not item: |
| | | continue |
| | | key, _, value = item.partition("=") |
| | | params.add(key.lower().strip(), value.strip(' "')) |
| | | |
| | | fulltype = parts[0].strip().lower() |
| | | if fulltype == "*": |
| | | fulltype = "*/*" |
| | | |
| | | mtype, _, stype = fulltype.partition("/") |
| | | stype, _, suffix = stype.partition("+") |
| | | |
| | | return MimeType( |
| | | type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params) |
| | | ) |
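`parse_mimetype` splits parameters on `;`, the type on `/`, and the structured-syntax suffix on `+`. A simplified stdlib-only sketch of the same parsing that returns a plain dict instead of an immutable `MultiDictProxy` (so repeated parameters collapse, unlike the real helper):

```python
def parse_mimetype_simple(mimetype: str):
    """Plain-dict sketch of the parsing above; names are illustrative."""
    parts = mimetype.split(";")
    params = {}
    for item in parts[1:]:
        if not item:
            continue
        key, _, value = item.partition("=")
        params[key.lower().strip()] = value.strip(' "')

    fulltype = parts[0].strip().lower()
    if fulltype == "*":  # bare "*" is shorthand for "*/*"
        fulltype = "*/*"

    mtype, _, stype = fulltype.partition("/")
    stype, _, suffix = stype.partition("+")  # e.g. "hal+json" -> ("hal", "json")
    return mtype, stype, suffix, params


assert parse_mimetype_simple("application/hal+json; charset=utf-8") == (
    "application", "hal", "json", {"charset": "utf-8"}
)
```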
| | | |
| | | |
| | | class EnsureOctetStream(EmailMessage): |
| | | def __init__(self) -> None: |
| | | super().__init__() |
| | | # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5 |
| | | self.set_default_type("application/octet-stream") |
| | | |
| | | def get_content_type(self) -> str: |
| | | """Re-implementation of Message.get_content_type.
| | |
| | | Returns application/octet-stream in place of text/plain when the
| | | value is malformed.
| | |
| | | The way this class is used guarantees that Content-Type will be
| | | present, so the checks are simplified relative to the base
| | | implementation.
| | | """
| | | value = self.get("content-type", "").lower() |
| | | |
| | | # Based on the implementation of _splitparam in the standard library |
| | | ctype, _, _ = value.partition(";") |
| | | ctype = ctype.strip() |
| | | if ctype.count("/") != 1: |
| | | return self.get_default_type() |
| | | return ctype |
| | | |
| | | |
| | | @functools.lru_cache(maxsize=56) |
| | | def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]: |
| | | """Parse Content-Type header. |
| | | |
| | | Returns a tuple of the parsed content type and a |
| | | MappingProxyType of parameters. The default returned value |
| | | is `application/octet-stream` |
| | | """ |
| | | msg = HeaderParser(EnsureOctetStream, policy=HTTP).parsestr(f"Content-Type: {raw}") |
| | | content_type = msg.get_content_type() |
| | | params = msg.get_params(()) |
| | | content_dict = dict(params[1:]) # First element is content type again |
| | | return content_type, MappingProxyType(content_dict) |
| | | |
| | | |
| | | def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]: |
| | | name = getattr(obj, "name", None) |
| | | if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">": |
| | | return Path(name).name |
| | | return default |
| | | |
| | | |
| | | not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]") |
| | | QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"} |
| | | |
| | | |
| | | def quoted_string(content: str) -> str: |
| | | """Return 7-bit content as quoted-string. |
| | | |
| | | Format content into a quoted-string as defined in RFC5322 for |
| | | Internet Message Format. Notice that this is not the 8-bit HTTP |
| | | format, but the 7-bit email format. Content must be US-ASCII or
| | | a ValueError is raised.
| | | """ |
| | | if not (QCONTENT > set(content)): |
| | | raise ValueError(f"bad content for quoted-string {content!r}") |
| | | return not_qtext_re.sub(lambda x: "\\" + x.group(0), content) |
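The function above first rejects anything outside printable ASCII plus tab, then backslash-escapes the remaining non-qtext characters. A self-contained sketch of the same two-step logic (function name is illustrative):

```python
import re

QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}
# qtext per RFC 5322: "!", "#"-"[", "]"-"~"; everything else is escaped.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")


def quote_7bit(content: str) -> str:
    # Step 1: refuse content that is not printable US-ASCII (plus tab).
    if not (QCONTENT > set(content)):
        raise ValueError(f"bad content for quoted-string {content!r}")
    # Step 2: backslash-escape non-qtext chars (quote, backslash,
    # space, and tab are the ones that survive step 1).
    return not_qtext_re.sub(lambda m: "\\" + m.group(0), content)


assert quote_7bit("plain") == "plain"
assert quote_7bit('a"b') == 'a\\"b'
```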
| | | |
| | | |
| | | def content_disposition_header( |
| | | disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str |
| | | ) -> str: |
| | | """Sets ``Content-Disposition`` header for MIME. |
| | | |
| | | This is the MIME payload Content-Disposition header from RFC 2183
| | | and RFC 7578 section 4.2, not the HTTP Content-Disposition from
| | | RFC 6266.
| | |
| | | disptype is a disposition type: inline, attachment, form-data.
| | | It should be a valid extension token (see RFC 2183).
| | |
| | | quote_fields performs value quoting to 7-bit MIME headers
| | | according to RFC 7578. Set quote_fields to False if the recipient
| | | can take 8-bit file names and field values.
| | | |
| | | _charset specifies the charset to use when quote_fields is True. |
| | | |
| | | params is a dict with disposition params. |
| | | """ |
| | | if not disptype or not (TOKEN > set(disptype)): |
| | | raise ValueError(f"bad content disposition type {disptype!r}") |
| | | |
| | | value = disptype |
| | | if params: |
| | | lparams = [] |
| | | for key, val in params.items(): |
| | | if not key or not (TOKEN > set(key)): |
| | | raise ValueError(f"bad content disposition parameter {key!r}={val!r}") |
| | | if quote_fields: |
| | | if key.lower() == "filename": |
| | | qval = quote(val, "", encoding=_charset) |
| | | lparams.append((key, '"%s"' % qval)) |
| | | else: |
| | | try: |
| | | qval = quoted_string(val) |
| | | except ValueError: |
| | | qval = "".join( |
| | | (_charset, "''", quote(val, "", encoding=_charset)) |
| | | ) |
| | | lparams.append((key + "*", qval)) |
| | | else: |
| | | lparams.append((key, '"%s"' % qval)) |
| | | else: |
| | | qval = val.replace("\\", "\\\\").replace('"', '\\"') |
| | | lparams.append((key, '"%s"' % qval)) |
| | | sparams = "; ".join("=".join(pair) for pair in lparams) |
| | | value = "; ".join((value, sparams)) |
| | | return value |
| | | |
| | | |
| | | def is_ip_address(host: Optional[str]) -> bool: |
| | | """Check if host looks like an IP Address. |
| | | |
| | | This check is only meant as a heuristic to ensure that |
| | | a host is not a domain name. |
| | | """ |
| | | if not host: |
| | | return False |
| | | # For a host to be an ipv4 address, it must be all numeric. |
| | | # The host must contain a colon to be an IPv6 address. |
| | | return ":" in host or host.replace(".", "").isdigit() |
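As the docstring says, this is a heuristic rather than a validator: it distinguishes IP literals from hostnames without the cost of `ipaddress` parsing. A standalone copy showing the edge behavior:

```python
from typing import Optional


def looks_like_ip(host: Optional[str]) -> bool:
    # Same heuristic as is_ip_address(): IPv6 literals contain ":",
    # IPv4 literals are digits and dots only; hostnames contain letters.
    if not host:
        return False
    return ":" in host or host.replace(".", "").isdigit()


assert looks_like_ip("127.0.0.1")
assert looks_like_ip("::1")
assert not looks_like_ip("example.com")
assert not looks_like_ip(None)
assert not looks_like_ip("")
```

Note that strings such as `"1.2.3.4.5"` also pass, which is acceptable here because the goal is only to rule out names that need DNS resolution.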
| | | |
| | | |
| | | _cached_current_datetime: Optional[int] = None |
| | | _cached_formatted_datetime = "" |
| | | |
| | | |
| | | def rfc822_formatted_time() -> str: |
| | | global _cached_current_datetime |
| | | global _cached_formatted_datetime |
| | | |
| | | now = int(time.time()) |
| | | if now != _cached_current_datetime: |
| | | # Weekday and month names for HTTP date/time formatting; |
| | | # always English! |
| | | # Tuples are constants stored in codeobject! |
| | | _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun") |
| | | _monthname = ( |
| | | "", # Dummy so we can use 1-based month numbers |
| | | "Jan", |
| | | "Feb", |
| | | "Mar", |
| | | "Apr", |
| | | "May", |
| | | "Jun", |
| | | "Jul", |
| | | "Aug", |
| | | "Sep", |
| | | "Oct", |
| | | "Nov", |
| | | "Dec", |
| | | ) |
| | | |
| | | year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now) |
| | | _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( |
| | | _weekdayname[wd], |
| | | day, |
| | | _monthname[month], |
| | | year, |
| | | hh, |
| | | mm, |
| | | ss, |
| | | ) |
| | | _cached_current_datetime = now |
| | | return _cached_formatted_datetime |
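The function above caches the formatted string per whole second, since servers may stamp many responses within the same second. A sketch of just the formatting step, without the cache (names are illustrative); the weekday/month names are hard-coded in English because HTTP dates must not be localized:

```python
import time

_WD = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
_MO = ("", "Jan", "Feb", "Mar", "Apr", "May", "Jun",
       "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")


def http_date(ts: int) -> str:
    # IMF-fixdate (RFC 9110 section 5.6.7): English names, always GMT.
    year, month, day, hh, mm, ss, wd, *_ = time.gmtime(ts)
    return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        _WD[wd], day, _MO[month], year, hh, mm, ss
    )


assert http_date(0) == "Thu, 01 Jan 1970 00:00:00 GMT"
```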
| | | |
| | | |
| | | def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None: |
| | | ref, name = info |
| | | ob = ref() |
| | | if ob is not None: |
| | | with suppress(Exception): |
| | | getattr(ob, name)() |
| | | |
| | | |
| | | def weakref_handle( |
| | | ob: object, |
| | | name: str, |
| | | timeout: float, |
| | | loop: asyncio.AbstractEventLoop, |
| | | timeout_ceil_threshold: float = 5, |
| | | ) -> Optional[asyncio.TimerHandle]: |
| | | if timeout is not None and timeout > 0: |
| | | when = loop.time() + timeout |
| | | if timeout >= timeout_ceil_threshold: |
| | | when = ceil(when) |
| | | |
| | | return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name)) |
| | | return None |
| | | |
| | | |
| | | def call_later( |
| | | cb: Callable[[], Any], |
| | | timeout: float, |
| | | loop: asyncio.AbstractEventLoop, |
| | | timeout_ceil_threshold: float = 5, |
| | | ) -> Optional[asyncio.TimerHandle]: |
| | | if timeout is None or timeout <= 0: |
| | | return None |
| | | now = loop.time() |
| | | when = calculate_timeout_when(now, timeout, timeout_ceil_threshold) |
| | | return loop.call_at(when, cb) |
| | | |
| | | |
| | | def calculate_timeout_when( |
| | | loop_time: float, |
| | | timeout: float, |
| | | timeout_ceiling_threshold: float, |
| | | ) -> float: |
| | | """Calculate when to execute a timeout.""" |
| | | when = loop_time + timeout |
| | | if timeout > timeout_ceiling_threshold: |
| | | return ceil(when) |
| | | return when |
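A standalone copy of `calculate_timeout_when` showing the rounding rule: deadlines for long timeouts are rounded up to a whole second so the event loop can coalesce nearby timers, while short timeouts keep full precision.

```python
from math import ceil


def when(loop_time: float, timeout: float, threshold: float = 5.0) -> float:
    # Copy of calculate_timeout_when(): only timeouts longer than the
    # ceiling threshold are rounded up to the next whole second.
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > threshold else deadline


assert when(100.5, 1.0) == 101.5   # short timeout: exact deadline
assert when(100.5, 10.0) == 111    # long timeout: ceil(110.5) == 111
```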
| | | |
| | | |
| | | class TimeoutHandle: |
| | | """Timeout handle""" |
| | | |
| | | __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks") |
| | | |
| | | def __init__( |
| | | self, |
| | | loop: asyncio.AbstractEventLoop, |
| | | timeout: Optional[float], |
| | | ceil_threshold: float = 5, |
| | | ) -> None: |
| | | self._timeout = timeout |
| | | self._loop = loop |
| | | self._ceil_threshold = ceil_threshold |
| | | self._callbacks: List[ |
| | | Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] |
| | | ] = [] |
| | | |
| | | def register( |
| | | self, callback: Callable[..., None], *args: Any, **kwargs: Any |
| | | ) -> None: |
| | | self._callbacks.append((callback, args, kwargs)) |
| | | |
| | | def close(self) -> None: |
| | | self._callbacks.clear() |
| | | |
| | | def start(self) -> Optional[asyncio.TimerHandle]: |
| | | timeout = self._timeout |
| | | if timeout is not None and timeout > 0: |
| | | when = self._loop.time() + timeout |
| | | if timeout >= self._ceil_threshold: |
| | | when = ceil(when) |
| | | return self._loop.call_at(when, self.__call__) |
| | | else: |
| | | return None |
| | | |
| | | def timer(self) -> "BaseTimerContext": |
| | | if self._timeout is not None and self._timeout > 0: |
| | | timer = TimerContext(self._loop) |
| | | self.register(timer.timeout) |
| | | return timer |
| | | else: |
| | | return TimerNoop() |
| | | |
| | | def __call__(self) -> None: |
| | | for cb, args, kwargs in self._callbacks: |
| | | with suppress(Exception): |
| | | cb(*args, **kwargs) |
| | | |
| | | self._callbacks.clear() |
| | | |
| | | |
| | | class BaseTimerContext(ContextManager["BaseTimerContext"]): |
| | | |
| | | __slots__ = () |
| | | |
| | | def assert_timeout(self) -> None: |
| | | """Raise TimeoutError if timeout has been exceeded.""" |
| | | |
| | | |
| | | class TimerNoop(BaseTimerContext): |
| | | |
| | | __slots__ = () |
| | | |
| | | def __enter__(self) -> BaseTimerContext: |
| | | return self |
| | | |
| | | def __exit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_val: Optional[BaseException], |
| | | exc_tb: Optional[TracebackType], |
| | | ) -> None: |
| | | return |
| | | |
| | | |
| | | class TimerContext(BaseTimerContext): |
| | | """Low resolution timeout context manager""" |
| | | |
| | | __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling") |
| | | |
| | | def __init__(self, loop: asyncio.AbstractEventLoop) -> None: |
| | | self._loop = loop |
| | | self._tasks: List[asyncio.Task[Any]] = [] |
| | | self._cancelled = False |
| | | self._cancelling = 0 |
| | | |
| | | def assert_timeout(self) -> None: |
| | | """Raise TimeoutError if timer has already been cancelled.""" |
| | | if self._cancelled: |
| | | raise asyncio.TimeoutError from None |
| | | |
| | | def __enter__(self) -> BaseTimerContext: |
| | | task = asyncio.current_task(loop=self._loop) |
| | | if task is None: |
| | | raise RuntimeError("Timeout context manager should be used inside a task") |
| | | |
| | | if sys.version_info >= (3, 11): |
| | | # Remember if the task was already cancelling |
| | | # so when we __exit__ we can decide if we should |
| | | # raise asyncio.TimeoutError or let the cancellation propagate |
| | | self._cancelling = task.cancelling() |
| | | |
| | | if self._cancelled: |
| | | raise asyncio.TimeoutError from None |
| | | |
| | | self._tasks.append(task) |
| | | return self |
| | | |
| | | def __exit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_val: Optional[BaseException], |
| | | exc_tb: Optional[TracebackType], |
| | | ) -> Optional[bool]: |
| | | enter_task: Optional[asyncio.Task[Any]] = None |
| | | if self._tasks: |
| | | enter_task = self._tasks.pop() |
| | | |
| | | if exc_type is asyncio.CancelledError and self._cancelled: |
| | | assert enter_task is not None |
| | | # The timeout was hit, and the task was cancelled |
| | | # so we need to uncancel the last task that entered the context manager |
| | | # since the cancellation should not leak out of the context manager |
| | | if sys.version_info >= (3, 11): |
| | | # If the task was already cancelling don't raise |
| | | # asyncio.TimeoutError and instead return None |
| | | # to allow the cancellation to propagate |
| | | if enter_task.uncancel() > self._cancelling: |
| | | return None |
| | | raise asyncio.TimeoutError from exc_val |
| | | return None |
| | | |
| | | def timeout(self) -> None: |
| | | if not self._cancelled: |
| | | for task in set(self._tasks): |
| | | task.cancel() |
| | | |
| | | self._cancelled = True |
| | | |
| | | |
| | | def ceil_timeout( |
| | | delay: Optional[float], ceil_threshold: float = 5 |
| | | ) -> async_timeout.Timeout: |
| | | if delay is None or delay <= 0: |
| | | return async_timeout.timeout(None) |
| | | |
| | | loop = asyncio.get_running_loop() |
| | | now = loop.time() |
| | | when = now + delay |
| | | if delay > ceil_threshold: |
| | | when = ceil(when) |
| | | return async_timeout.timeout_at(when) |
| | | |
| | | |
| | | class HeadersMixin: |
| | | """Mixin for handling headers.""" |
| | | |
| | | ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"]) |
| | | |
| | | _headers: MultiMapping[str] |
| | | _content_type: Optional[str] = None |
| | | _content_dict: Optional[Dict[str, str]] = None |
| | | _stored_content_type: Union[str, None, _SENTINEL] = sentinel |
| | | |
| | | def _parse_content_type(self, raw: Optional[str]) -> None: |
| | | self._stored_content_type = raw |
| | | if raw is None: |
| | | # default value according to RFC 2616 |
| | | self._content_type = "application/octet-stream" |
| | | self._content_dict = {} |
| | | else: |
| | | content_type, content_mapping_proxy = parse_content_type(raw) |
| | | self._content_type = content_type |
| | | # _content_dict needs to be mutable so we can update it |
| | | self._content_dict = content_mapping_proxy.copy() |
| | | |
| | | @property |
| | | def content_type(self) -> str: |
| | | """The value of content part for Content-Type HTTP header.""" |
| | | raw = self._headers.get(hdrs.CONTENT_TYPE) |
| | | if self._stored_content_type != raw: |
| | | self._parse_content_type(raw) |
| | | assert self._content_type is not None |
| | | return self._content_type |
| | | |
| | | @property |
| | | def charset(self) -> Optional[str]: |
| | | """The value of charset part for Content-Type HTTP header.""" |
| | | raw = self._headers.get(hdrs.CONTENT_TYPE) |
| | | if self._stored_content_type != raw: |
| | | self._parse_content_type(raw) |
| | | assert self._content_dict is not None |
| | | return self._content_dict.get("charset") |
| | | |
| | | @property |
| | | def content_length(self) -> Optional[int]: |
| | | """The value of Content-Length HTTP header.""" |
| | | content_length = self._headers.get(hdrs.CONTENT_LENGTH) |
| | | return None if content_length is None else int(content_length) |
| | | |
| | | |
| | | def set_result(fut: "asyncio.Future[_T]", result: _T) -> None: |
| | | if not fut.done(): |
| | | fut.set_result(result) |
| | | |
| | | |
| | | _EXC_SENTINEL = BaseException() |
| | | |
| | | |
| | | class ErrorableProtocol(Protocol): |
| | | def set_exception( |
| | | self, |
| | | exc: BaseException, |
| | | exc_cause: BaseException = ..., |
| | | ) -> None: ... # pragma: no cover |
| | | |
| | | |
| | | def set_exception( |
| | | fut: "asyncio.Future[_T] | ErrorableProtocol", |
| | | exc: BaseException, |
| | | exc_cause: BaseException = _EXC_SENTINEL, |
| | | ) -> None: |
| | | """Set future exception. |
| | | |
| | | If the future is marked as complete, this function is a no-op. |
| | | |
| | | :param exc_cause: An exception that is a direct cause of ``exc``. |
| | | Only set if provided. |
| | | """ |
| | | if asyncio.isfuture(fut) and fut.done(): |
| | | return |
| | | |
| | | exc_is_sentinel = exc_cause is _EXC_SENTINEL |
| | | exc_causes_itself = exc is exc_cause |
| | | if not exc_is_sentinel and not exc_causes_itself: |
| | | exc.__cause__ = exc_cause |
| | | |
| | | fut.set_exception(exc) |
| | | |
| | | |
| | | @functools.total_ordering |
| | | class AppKey(Generic[_T]): |
| | | """Keys for static typing support in Application.""" |
| | | |
| | | __slots__ = ("_name", "_t", "__orig_class__") |
| | | |
| | | # This may be set by Python when instantiating with a generic type. We need to |
| | | # support this, in order to support types that are not concrete classes, |
| | | # like Iterable, which can't be passed as the second parameter to __init__. |
| | | __orig_class__: Type[object] |
| | | |
| | | def __init__(self, name: str, t: Optional[Type[_T]] = None): |
| | | # Prefix with module name to help deduplicate key names. |
| | | frame = inspect.currentframe() |
| | | while frame: |
| | | if frame.f_code.co_name == "<module>": |
| | | module: str = frame.f_globals["__name__"] |
| | | break |
| | | frame = frame.f_back |
| | | |
| | | self._name = module + "." + name |
| | | self._t = t |
| | | |
| | | def __lt__(self, other: object) -> bool: |
| | | if isinstance(other, AppKey): |
| | | return self._name < other._name |
| | | return True # Order AppKey above other types. |
| | | |
| | | def __repr__(self) -> str: |
| | | t = self._t |
| | | if t is None: |
| | | with suppress(AttributeError): |
| | | # Set to type arg. |
| | | t = get_args(self.__orig_class__)[0] |
| | | |
| | | if t is None: |
| | | t_repr = "<<Unknown>>" |
| | | elif isinstance(t, type): |
| | | if t.__module__ == "builtins": |
| | | t_repr = t.__qualname__ |
| | | else: |
| | | t_repr = f"{t.__module__}.{t.__qualname__}" |
| | | else: |
| | | t_repr = repr(t) |
| | | return f"<AppKey({self._name}, type={t_repr})>" |
| | | |
| | | |
| | | class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]): |
| | | __slots__ = ("_maps",) |
| | | |
| | | def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None: |
| | | self._maps = tuple(maps) |
| | | |
| | | def __init_subclass__(cls) -> None: |
| | | raise TypeError(
| | | "Inheriting class {} from ChainMapProxy "
| | | "is forbidden".format(cls.__name__)
| | | )
| | | |
| | | @overload # type: ignore[override] |
| | | def __getitem__(self, key: AppKey[_T]) -> _T: ... |
| | | |
| | | @overload |
| | | def __getitem__(self, key: str) -> Any: ... |
| | | |
| | | def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: |
| | | for mapping in self._maps: |
| | | try: |
| | | return mapping[key] |
| | | except KeyError: |
| | | pass |
| | | raise KeyError(key) |
| | | |
| | | @overload # type: ignore[override] |
| | | def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ... |
| | | |
| | | @overload |
| | | def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ... |
| | | |
| | | @overload |
| | | def get(self, key: str, default: Any = ...) -> Any: ... |
| | | |
| | | def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: |
| | | try: |
| | | return self[key] |
| | | except KeyError: |
| | | return default |
| | | |
| | | def __len__(self) -> int: |
| | | # reuses stored hash values if possible |
| | | return len(set().union(*self._maps)) |
| | | |
| | | def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]: |
| | | d: Dict[Union[str, AppKey[Any]], Any] = {} |
| | | for mapping in reversed(self._maps): |
| | | # reuses stored hash values if possible |
| | | d.update(mapping) |
| | | return iter(d) |
| | | |
| | | def __contains__(self, key: object) -> bool: |
| | | return any(key in m for m in self._maps) |
| | | |
| | | def __bool__(self) -> bool: |
| | | return any(self._maps) |
| | | |
| | | def __repr__(self) -> str: |
| | | content = ", ".join(map(repr, self._maps)) |
| | | return f"ChainMapProxy({content})" |
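`ChainMapProxy` resolves keys the same way the stdlib `collections.ChainMap` does: the first mapping that contains a key wins, and later mappings are shadowed. The stdlib class illustrates the lookup semantics directly:

```python
from collections import ChainMap

# First-match-wins lookup, as in ChainMapProxy.__getitem__ above.
app_config = {"debug": True, "name": "subapp"}
parent_config = {"name": "root", "timeout": 30}
merged = ChainMap(app_config, parent_config)

assert merged["name"] == "subapp"   # the first map shadows the parent
assert merged["timeout"] == 30      # missing keys fall through
assert len(merged) == 3             # keys are deduplicated across maps
```

The main differences in the class above are that it is read-only, forbids subclassing, and supports typed `AppKey` keys via overloads.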
| | | |
| | | |
| | | # https://tools.ietf.org/html/rfc7232#section-2.3 |
| | | _ETAGC = r"[!\x23-\x7E\x80-\xff]+" |
| | | _ETAGC_RE = re.compile(_ETAGC) |
| | | _QUOTED_ETAG = rf'(W/)?"({_ETAGC})"' |
| | | QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG) |
| | | LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)") |
| | | |
| | | ETAG_ANY = "*" |
| | | |
| | | |
| | | @attr.s(auto_attribs=True, frozen=True, slots=True) |
| | | class ETag: |
| | | value: str |
| | | is_weak: bool = False |
| | | |
| | | |
| | | def validate_etag_value(value: str) -> None: |
| | | if value != ETAG_ANY and not _ETAGC_RE.fullmatch(value): |
| | | raise ValueError( |
| | | f"Value {value!r} is not a valid etag. Maybe it contains '\"'?" |
| | | ) |
| | | |
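The ETag grammar above (RFC 7232 §2.3) can be exercised in isolation: an entity-tag is a quoted string of `etagc` characters, optionally prefixed with `W/` to mark weak comparison. The patterns below repeat the ones defined above so the sketch is self-contained:

```python
import re

_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
QUOTED_ETAG_RE = re.compile(rf'(W/)?"({_ETAGC})"')

m = QUOTED_ETAG_RE.fullmatch('W/"abc123"')
is_weak, value = m.group(1) is not None, m.group(2)
print(is_weak, value)  # True abc123
```

Note that `"` itself (0x22) falls outside the `!` / `\x23-\x7E` ranges, which is why `validate_etag_value` hints that a stray quote is the likely culprit for invalid values.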
| | | |
| | | def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]: |
| | |     """Parse an HTTP date string into an aware UTC datetime, or None."""
| | | if date_str is not None: |
| | | timetuple = parsedate(date_str) |
| | | if timetuple is not None: |
| | | with suppress(ValueError): |
| | | return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc) |
| | | return None |
| | | |
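The helper above can be reproduced standalone: `email.utils.parsedate` accepts the IMF-fixdate, RFC 850, and asctime formats, and the result is interpreted as UTC per RFC 9110 §5.6.7. A self-contained sketch of the same logic:

```python
import datetime
from contextlib import suppress
from email.utils import parsedate

def parse_http_date(date_str):
    # parsedate() returns a 9-tuple on success, None on failure; only the
    # first six fields (Y, M, D, h, m, s) are needed to build the datetime.
    if date_str is not None:
        timetuple = parsedate(date_str)
        if timetuple is not None:
            with suppress(ValueError):
                return datetime.datetime(
                    *timetuple[:6], tzinfo=datetime.timezone.utc
                )
    return None

print(parse_http_date("Sun, 06 Nov 1994 08:49:37 GMT"))
print(parse_http_date("not a date"))  # None
```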
| | | |
| | | @functools.lru_cache |
| | | def must_be_empty_body(method: str, code: int) -> bool: |
| | |     """Check if the response to a request must have an empty body."""
| | | return ( |
| | | code in EMPTY_BODY_STATUS_CODES |
| | | or method in EMPTY_BODY_METHODS |
| | | or (200 <= code < 300 and method in hdrs.METH_CONNECT_ALL) |
| | | ) |
| | | |
| | | |
| | | def should_remove_content_length(method: str, code: int) -> bool: |
| | | """Check if a Content-Length header should be removed. |
| | | |
| | | This should always be a subset of must_be_empty_body |
| | | """ |
| | | # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8 |
| | | # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4 |
| | | return code in EMPTY_BODY_STATUS_CODES or ( |
| | | 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL |
| | | ) |
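The two helpers above apply the RFC 9110 rules for bodiless responses. The constants they rely on (`EMPTY_BODY_STATUS_CODES`, `EMPTY_BODY_METHODS`, `METH_CONNECT_ALL`) live in `.helpers` and `.hdrs`; the values inlined below are assumptions for illustration, following the RFC:

```python
# Assumed values (for illustration): 1xx, 204, and 304 responses never carry
# a body; HEAD responses never carry a body; a 2xx CONNECT switches to a tunnel.
EMPTY_BODY_STATUS_CODES = frozenset({204, 304}) | frozenset(range(100, 200))
EMPTY_BODY_METHODS = frozenset({"HEAD"})
METH_CONNECT_ALL = frozenset({"CONNECT"})

def must_be_empty_body(method: str, code: int) -> bool:
    return (
        code in EMPTY_BODY_STATUS_CODES
        or method in EMPTY_BODY_METHODS
        or (200 <= code < 300 and method in METH_CONNECT_ALL)
    )

print(must_be_empty_body("HEAD", 200))     # True: HEAD never has a body
print(must_be_empty_body("GET", 304))      # True: 304 has no body
print(must_be_empty_body("CONNECT", 200))  # True: 2xx CONNECT starts a tunnel
print(must_be_empty_body("GET", 200))      # False
```

`should_remove_content_length` is deliberately narrower: a HEAD response has no body, yet its Content-Length header may legitimately describe the body the equivalent GET would return, so the header is kept.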
| New file |
| | |
| | | import sys |
| | | from http import HTTPStatus |
| | | from typing import Mapping, Tuple |
| | | |
| | | from . import __version__ |
| | | from .http_exceptions import HttpProcessingError as HttpProcessingError |
| | | from .http_parser import ( |
| | | HeadersParser as HeadersParser, |
| | | HttpParser as HttpParser, |
| | | HttpRequestParser as HttpRequestParser, |
| | | HttpResponseParser as HttpResponseParser, |
| | | RawRequestMessage as RawRequestMessage, |
| | | RawResponseMessage as RawResponseMessage, |
| | | ) |
| | | from .http_websocket import ( |
| | | WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE, |
| | | WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE, |
| | | WS_KEY as WS_KEY, |
| | | WebSocketError as WebSocketError, |
| | | WebSocketReader as WebSocketReader, |
| | | WebSocketWriter as WebSocketWriter, |
| | | WSCloseCode as WSCloseCode, |
| | | WSMessage as WSMessage, |
| | | WSMsgType as WSMsgType, |
| | | ws_ext_gen as ws_ext_gen, |
| | | ws_ext_parse as ws_ext_parse, |
| | | ) |
| | | from .http_writer import ( |
| | | HttpVersion as HttpVersion, |
| | | HttpVersion10 as HttpVersion10, |
| | | HttpVersion11 as HttpVersion11, |
| | | StreamWriter as StreamWriter, |
| | | ) |
| | | |
| | | __all__ = ( |
| | | "HttpProcessingError", |
| | | "RESPONSES", |
| | | "SERVER_SOFTWARE", |
| | | # .http_writer |
| | | "StreamWriter", |
| | | "HttpVersion", |
| | | "HttpVersion10", |
| | | "HttpVersion11", |
| | | # .http_parser |
| | | "HeadersParser", |
| | | "HttpParser", |
| | | "HttpRequestParser", |
| | | "HttpResponseParser", |
| | | "RawRequestMessage", |
| | | "RawResponseMessage", |
| | | # .http_websocket |
| | | "WS_CLOSED_MESSAGE", |
| | | "WS_CLOSING_MESSAGE", |
| | | "WS_KEY", |
| | | "WebSocketReader", |
| | | "WebSocketWriter", |
| | | "ws_ext_gen", |
| | | "ws_ext_parse", |
| | | "WSMessage", |
| | | "WebSocketError", |
| | | "WSMsgType", |
| | | "WSCloseCode", |
| | | ) |
| | | |
| | | |
| | | SERVER_SOFTWARE: str = "Python/{0[0]}.{0[1]} aiohttp/{1}".format( |
| | | sys.version_info, __version__ |
| | | ) |
| | | |
| | | RESPONSES: Mapping[int, Tuple[str, str]] = { |
| | | v: (v.phrase, v.description) for v in HTTPStatus.__members__.values() |
| | | } |
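Because `HTTPStatus` members are `int` subclasses, the enum members used as keys in `RESPONSES` above compare and hash equal to plain integers, so the table can be indexed either way:

```python
from http import HTTPStatus

# Same construction as above: status code -> (phrase, description).
RESPONSES = {v: (v.phrase, v.description) for v in HTTPStatus.__members__.values()}

print(RESPONSES[404][0])                  # plain int key works
print(RESPONSES[HTTPStatus.NOT_FOUND][0])  # enum key hits the same entry
```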
| New file |
| | |
| | | """Low-level http related exceptions.""" |
| | | |
| | | from textwrap import indent |
| | | from typing import Optional, Union |
| | | |
| | | from .typedefs import _CIMultiDict |
| | | |
| | | __all__ = ("HttpProcessingError",) |
| | | |
| | | |
| | | class HttpProcessingError(Exception): |
| | | """HTTP error. |
| | | |
| | | Shortcut for raising HTTP errors with custom code, message and headers. |
| | | |
| | | code: HTTP Error code. |
| | | message: (optional) Error message. |
| | |     headers: (optional) Headers to be sent in the response.
| | | """ |
| | | |
| | | code = 0 |
| | | message = "" |
| | | headers = None |
| | | |
| | | def __init__( |
| | | self, |
| | | *, |
| | | code: Optional[int] = None, |
| | | message: str = "", |
| | | headers: Optional[_CIMultiDict] = None, |
| | | ) -> None: |
| | | if code is not None: |
| | | self.code = code |
| | | self.headers = headers |
| | | self.message = message |
| | | |
| | | def __str__(self) -> str: |
| | | msg = indent(self.message, " ") |
| | | return f"{self.code}, message:\n{msg}" |
| | | |
| | | def __repr__(self) -> str: |
| | | return f"<{self.__class__.__name__}: {self.code}, message={self.message!r}>" |
| | | |
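The `__str__` above indents the message under the code so multi-line error bodies stay readable. A stripped-down sketch of just that formatting behavior (the two-space indent prefix is an assumption for illustration):

```python
from textwrap import indent

class HttpProcessingError(Exception):
    # Minimal sketch: only the code/message formatting, no headers handling.
    code = 0
    message = ""

    def __init__(self, *, code=None, message=""):
        if code is not None:
            self.code = code
        self.message = message

    def __str__(self):
        return f"{self.code}, message:\n{indent(self.message, '  ')}"

err = HttpProcessingError(code=400, message="bad\nheader")
print(err)
```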
| | | |
| | | class BadHttpMessage(HttpProcessingError): |
| | | |
| | | code = 400 |
| | | message = "Bad Request" |
| | | |
| | | def __init__(self, message: str, *, headers: Optional[_CIMultiDict] = None) -> None: |
| | | super().__init__(message=message, headers=headers) |
| | | self.args = (message,) |
| | | |
| | | |
| | | class HttpBadRequest(BadHttpMessage): |
| | | |
| | | code = 400 |
| | | message = "Bad Request" |
| | | |
| | | |
| | | class PayloadEncodingError(BadHttpMessage): |
| | | """Base class for payload errors""" |
| | | |
| | | |
| | | class ContentEncodingError(PayloadEncodingError): |
| | | """Content encoding error.""" |
| | | |
| | | |
| | | class TransferEncodingError(PayloadEncodingError): |
| | |     """Transfer encoding error."""
| | | |
| | | |
| | | class ContentLengthError(PayloadEncodingError): |
| | | """Not enough data to satisfy content length header.""" |
| | | |
| | | |
| | | class DecompressSizeError(PayloadEncodingError): |
| | | """Decompressed size exceeds the configured limit.""" |
| | | |
| | | |
| | | class LineTooLong(BadHttpMessage): |
| | | def __init__( |
| | | self, line: str, limit: str = "Unknown", actual_size: str = "Unknown" |
| | | ) -> None: |
| | | super().__init__( |
| | | f"Got more than {limit} bytes ({actual_size}) when reading {line}." |
| | | ) |
| | | self.args = (line, limit, actual_size) |
| | | |
| | | |
| | | class InvalidHeader(BadHttpMessage): |
| | | def __init__(self, hdr: Union[bytes, str]) -> None: |
| | | hdr_s = hdr.decode(errors="backslashreplace") if isinstance(hdr, bytes) else hdr |
| | | super().__init__(f"Invalid HTTP header: {hdr!r}") |
| | | self.hdr = hdr_s |
| | | self.args = (hdr,) |
| | | |
| | | |
| | | class BadStatusLine(BadHttpMessage): |
| | | def __init__(self, line: str = "", error: Optional[str] = None) -> None: |
| | | if not isinstance(line, str): |
| | | line = repr(line) |
| | | super().__init__(error or f"Bad status line {line!r}") |
| | | self.args = (line,) |
| | | self.line = line |
| | | |
| | | |
| | | class BadHttpMethod(BadStatusLine): |
| | | """Invalid HTTP method in status line.""" |
| | | |
| | | def __init__(self, line: str = "", error: Optional[str] = None) -> None: |
| | | super().__init__(line, error or f"Bad HTTP method in status line {line!r}") |
| | | |
| | | |
| | | class InvalidURLError(BadHttpMessage): |
| | | pass |
| New file |
| | |
| | | import abc |
| | | import asyncio |
| | | import re |
| | | import string |
| | | from contextlib import suppress |
| | | from enum import IntEnum |
| | | from typing import ( |
| | | Any, |
| | | ClassVar, |
| | | Final, |
| | | Generic, |
| | | List, |
| | | Literal, |
| | | NamedTuple, |
| | | Optional, |
| | | Pattern, |
| | | Set, |
| | | Tuple, |
| | | Type, |
| | | TypeVar, |
| | | Union, |
| | | ) |
| | | |
| | | from multidict import CIMultiDict, CIMultiDictProxy, istr |
| | | from yarl import URL |
| | | |
| | | from . import hdrs |
| | | from .base_protocol import BaseProtocol |
| | | from .compression_utils import ( |
| | | DEFAULT_MAX_DECOMPRESS_SIZE, |
| | | HAS_BROTLI, |
| | | HAS_ZSTD, |
| | | BrotliDecompressor, |
| | | ZLibDecompressor, |
| | | ZSTDDecompressor, |
| | | ) |
| | | from .helpers import ( |
| | | _EXC_SENTINEL, |
| | | DEBUG, |
| | | EMPTY_BODY_METHODS, |
| | | EMPTY_BODY_STATUS_CODES, |
| | | NO_EXTENSIONS, |
| | | BaseTimerContext, |
| | | set_exception, |
| | | ) |
| | | from .http_exceptions import ( |
| | | BadHttpMessage, |
| | | BadHttpMethod, |
| | | BadStatusLine, |
| | | ContentEncodingError, |
| | | ContentLengthError, |
| | | DecompressSizeError, |
| | | InvalidHeader, |
| | | InvalidURLError, |
| | | LineTooLong, |
| | | TransferEncodingError, |
| | | ) |
| | | from .http_writer import HttpVersion, HttpVersion10 |
| | | from .streams import EMPTY_PAYLOAD, StreamReader |
| | | from .typedefs import RawHeaders |
| | | |
| | | __all__ = ( |
| | | "HeadersParser", |
| | | "HttpParser", |
| | | "HttpRequestParser", |
| | | "HttpResponseParser", |
| | | "RawRequestMessage", |
| | | "RawResponseMessage", |
| | | ) |
| | | |
| | | _SEP = Literal[b"\r\n", b"\n"] |
| | | |
| | | ASCIISET: Final[Set[str]] = set(string.printable) |
| | | |
| | | # See https://www.rfc-editor.org/rfc/rfc9110.html#name-overview |
| | | # and https://www.rfc-editor.org/rfc/rfc9110.html#name-tokens |
| | | # |
| | | # method = token |
| | | # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / |
| | | # "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA |
| | | # token = 1*tchar |
| | | _TCHAR_SPECIALS: Final[str] = re.escape("!#$%&'*+-.^_`|~") |
| | | TOKENRE: Final[Pattern[str]] = re.compile(f"[0-9A-Za-z{_TCHAR_SPECIALS}]+") |
| | | VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d)\.(\d)", re.ASCII) |
| | | DIGITS: Final[Pattern[str]] = re.compile(r"\d+", re.ASCII) |
| | | HEXDIGITS: Final[Pattern[bytes]] = re.compile(rb"[0-9a-fA-F]+") |
| | | |
| | | |
| | | class RawRequestMessage(NamedTuple): |
| | | method: str |
| | | path: str |
| | | version: HttpVersion |
| | | headers: "CIMultiDictProxy[str]" |
| | | raw_headers: RawHeaders |
| | | should_close: bool |
| | | compression: Optional[str] |
| | | upgrade: bool |
| | | chunked: bool |
| | | url: URL |
| | | |
| | | |
| | | class RawResponseMessage(NamedTuple): |
| | | version: HttpVersion |
| | | code: int |
| | | reason: str |
| | | headers: CIMultiDictProxy[str] |
| | | raw_headers: RawHeaders |
| | | should_close: bool |
| | | compression: Optional[str] |
| | | upgrade: bool |
| | | chunked: bool |
| | | |
| | | |
| | | _MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage) |
| | | |
| | | |
| | | class ParseState(IntEnum): |
| | | |
| | | PARSE_NONE = 0 |
| | | PARSE_LENGTH = 1 |
| | | PARSE_CHUNKED = 2 |
| | | PARSE_UNTIL_EOF = 3 |
| | | |
| | | |
| | | class ChunkState(IntEnum): |
| | | PARSE_CHUNKED_SIZE = 0 |
| | | PARSE_CHUNKED_CHUNK = 1 |
| | | PARSE_CHUNKED_CHUNK_EOF = 2 |
| | | PARSE_MAYBE_TRAILERS = 3 |
| | | PARSE_TRAILERS = 4 |
| | | |
| | | |
| | | class HeadersParser: |
| | | def __init__( |
| | | self, |
| | | max_line_size: int = 8190, |
| | | max_headers: int = 32768, |
| | | max_field_size: int = 8190, |
| | | lax: bool = False, |
| | | ) -> None: |
| | | self.max_line_size = max_line_size |
| | | self.max_headers = max_headers |
| | | self.max_field_size = max_field_size |
| | | self._lax = lax |
| | | |
| | | def parse_headers( |
| | | self, lines: List[bytes] |
| | | ) -> Tuple["CIMultiDictProxy[str]", RawHeaders]: |
| | | headers: CIMultiDict[str] = CIMultiDict() |
| | |         # note: despite the name, "raw" header values have surrounding OWS stripped
| | | raw_headers = [] |
| | | |
| | | lines_idx = 0 |
| | | line = lines[lines_idx] |
| | | line_count = len(lines) |
| | | |
| | | while line: |
| | | # Parse initial header name : value pair. |
| | | try: |
| | | bname, bvalue = line.split(b":", 1) |
| | | except ValueError: |
| | | raise InvalidHeader(line) from None |
| | | |
| | | if len(bname) == 0: |
| | | raise InvalidHeader(bname) |
| | | |
| | | # https://www.rfc-editor.org/rfc/rfc9112.html#section-5.1-2 |
| | | if {bname[0], bname[-1]} & {32, 9}: # {" ", "\t"} |
| | | raise InvalidHeader(line) |
| | | |
| | | bvalue = bvalue.lstrip(b" \t") |
| | | if len(bname) > self.max_field_size: |
| | | raise LineTooLong( |
| | | "request header name {}".format( |
| | | bname.decode("utf8", "backslashreplace") |
| | | ), |
| | | str(self.max_field_size), |
| | | str(len(bname)), |
| | | ) |
| | | name = bname.decode("utf-8", "surrogateescape") |
| | | if not TOKENRE.fullmatch(name): |
| | | raise InvalidHeader(bname) |
| | | |
| | | header_length = len(bvalue) |
| | | |
| | | # next line |
| | | lines_idx += 1 |
| | | line = lines[lines_idx] |
| | | |
| | | # consume continuation lines |
| | | continuation = self._lax and line and line[0] in (32, 9) # (' ', '\t') |
| | | |
| | | # Deprecated: https://www.rfc-editor.org/rfc/rfc9112.html#name-obsolete-line-folding |
| | | if continuation: |
| | | bvalue_lst = [bvalue] |
| | | while continuation: |
| | | header_length += len(line) |
| | | if header_length > self.max_field_size: |
| | | raise LineTooLong( |
| | | "request header field {}".format( |
| | | bname.decode("utf8", "backslashreplace") |
| | | ), |
| | | str(self.max_field_size), |
| | | str(header_length), |
| | | ) |
| | | bvalue_lst.append(line) |
| | | |
| | | # next line |
| | | lines_idx += 1 |
| | | if lines_idx < line_count: |
| | | line = lines[lines_idx] |
| | | if line: |
| | | continuation = line[0] in (32, 9) # (' ', '\t') |
| | | else: |
| | | line = b"" |
| | | break |
| | | bvalue = b"".join(bvalue_lst) |
| | | else: |
| | | if header_length > self.max_field_size: |
| | | raise LineTooLong( |
| | | "request header field {}".format( |
| | | bname.decode("utf8", "backslashreplace") |
| | | ), |
| | | str(self.max_field_size), |
| | | str(header_length), |
| | | ) |
| | | |
| | | bvalue = bvalue.strip(b" \t") |
| | | value = bvalue.decode("utf-8", "surrogateescape") |
| | | |
| | | # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-5 |
| | | if "\n" in value or "\r" in value or "\x00" in value: |
| | | raise InvalidHeader(bvalue) |
| | | |
| | | headers.add(name, value) |
| | | raw_headers.append((bname, bvalue)) |
| | | |
| | | return (CIMultiDictProxy(headers), tuple(raw_headers)) |
| | | |
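The validation steps in `parse_headers` above reduce to: the field name must be an RFC 9110 token with no surrounding whitespace, and the (OWS-stripped) value may not contain bare CR, LF, or NUL. A condensed standalone validator capturing just those checks (`is_valid_header` is a name invented here for illustration):

```python
import re

# Token characters per RFC 9110: tchar = "!#$%&'*+-.^_`|~" / DIGIT / ALPHA
TOKENRE = re.compile(r"[0-9A-Za-z!#$%&'*+\-.^_`|~]+")

def is_valid_header(bname: bytes, bvalue: bytes) -> bool:
    # Reject empty names and names with leading/trailing SP or HTAB.
    if not bname or bname[0] in (32, 9) or bname[-1] in (32, 9):
        return False
    if not TOKENRE.fullmatch(bname.decode("utf-8", "surrogateescape")):
        return False
    # Strip OWS, then reject control characters smuggled into the value.
    value = bvalue.strip(b" \t").decode("utf-8", "surrogateescape")
    return not any(c in value for c in ("\r", "\n", "\x00"))

print(is_valid_header(b"Content-Type", b"text/html"))  # True
print(is_valid_header(b" Host", b"example.com"))       # False: leading space
print(is_valid_header(b"X-Bad", b"a\x00b"))            # False: NUL in value
```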
| | | |
| | | def _is_supported_upgrade(headers: CIMultiDictProxy[str]) -> bool: |
| | | """Check if the upgrade header is supported.""" |
| | | u = headers.get(hdrs.UPGRADE, "") |
| | | # .lower() can transform non-ascii characters. |
| | | return u.isascii() and u.lower() in {"tcp", "websocket"} |
| | | |
| | | |
| | | class HttpParser(abc.ABC, Generic[_MsgT]): |
| | | lax: ClassVar[bool] = False |
| | | |
| | | def __init__( |
| | | self, |
| | | protocol: Optional[BaseProtocol] = None, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | limit: int = 2**16, |
| | | max_line_size: int = 8190, |
| | | max_headers: int = 32768, |
| | | max_field_size: int = 8190, |
| | | timer: Optional[BaseTimerContext] = None, |
| | | code: Optional[int] = None, |
| | | method: Optional[str] = None, |
| | | payload_exception: Optional[Type[BaseException]] = None, |
| | | response_with_body: bool = True, |
| | | read_until_eof: bool = False, |
| | | auto_decompress: bool = True, |
| | | ) -> None: |
| | | self.protocol = protocol |
| | | self.loop = loop |
| | | self.max_line_size = max_line_size |
| | | self.max_headers = max_headers |
| | | self.max_field_size = max_field_size |
| | | self.timer = timer |
| | | self.code = code |
| | | self.method = method |
| | | self.payload_exception = payload_exception |
| | | self.response_with_body = response_with_body |
| | | self.read_until_eof = read_until_eof |
| | | |
| | | self._lines: List[bytes] = [] |
| | | self._tail = b"" |
| | | self._upgraded = False |
| | | self._payload = None |
| | | self._payload_parser: Optional[HttpPayloadParser] = None |
| | | self._auto_decompress = auto_decompress |
| | | self._limit = limit |
| | | self._headers_parser = HeadersParser( |
| | | max_line_size, max_headers, max_field_size, self.lax |
| | | ) |
| | | |
| | | @abc.abstractmethod |
| | | def parse_message(self, lines: List[bytes]) -> _MsgT: ... |
| | | |
| | | @abc.abstractmethod |
| | | def _is_chunked_te(self, te: str) -> bool: ... |
| | | |
| | | def feed_eof(self) -> Optional[_MsgT]: |
| | | if self._payload_parser is not None: |
| | | self._payload_parser.feed_eof() |
| | | self._payload_parser = None |
| | | else: |
| | | # try to extract partial message |
| | | if self._tail: |
| | | self._lines.append(self._tail) |
| | | |
| | | if self._lines: |
| | |                 if self._lines[-1] != b"\r\n":
| | | self._lines.append(b"") |
| | | with suppress(Exception): |
| | | return self.parse_message(self._lines) |
| | | return None |
| | | |
| | | def feed_data( |
| | | self, |
| | | data: bytes, |
| | | SEP: _SEP = b"\r\n", |
| | | EMPTY: bytes = b"", |
| | | CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, |
| | | METH_CONNECT: str = hdrs.METH_CONNECT, |
| | | SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1, |
| | | ) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]: |
| | | |
| | | messages = [] |
| | | |
| | | if self._tail: |
| | | data, self._tail = self._tail + data, b"" |
| | | |
| | | data_len = len(data) |
| | | start_pos = 0 |
| | | loop = self.loop |
| | | |
| | | should_close = False |
| | | while start_pos < data_len: |
| | | |
| | | # read HTTP message (request/response line + headers), \r\n\r\n |
| | | # and split by lines |
| | | if self._payload_parser is None and not self._upgraded: |
| | | pos = data.find(SEP, start_pos) |
| | | # consume \r\n |
| | | if pos == start_pos and not self._lines: |
| | | start_pos = pos + len(SEP) |
| | | continue |
| | | |
| | | if pos >= start_pos: |
| | | if should_close: |
| | | raise BadHttpMessage("Data after `Connection: close`") |
| | | |
| | | # line found |
| | | line = data[start_pos:pos] |
| | | if SEP == b"\n": # For lax response parsing |
| | | line = line.rstrip(b"\r") |
| | | self._lines.append(line) |
| | | start_pos = pos + len(SEP) |
| | | |
| | | # \r\n\r\n found |
| | | if self._lines[-1] == EMPTY: |
| | | try: |
| | | msg: _MsgT = self.parse_message(self._lines) |
| | | finally: |
| | | self._lines.clear() |
| | | |
| | | def get_content_length() -> Optional[int]: |
| | | # payload length |
| | | length_hdr = msg.headers.get(CONTENT_LENGTH) |
| | | if length_hdr is None: |
| | | return None |
| | | |
| | | # Shouldn't allow +/- or other number formats. |
| | | # https://www.rfc-editor.org/rfc/rfc9110#section-8.6-2 |
| | | # msg.headers is already stripped of leading/trailing wsp |
| | | if not DIGITS.fullmatch(length_hdr): |
| | | raise InvalidHeader(CONTENT_LENGTH) |
| | | |
| | | return int(length_hdr) |
| | | |
| | | length = get_content_length() |
| | | # do not support old websocket spec |
| | | if SEC_WEBSOCKET_KEY1 in msg.headers: |
| | | raise InvalidHeader(SEC_WEBSOCKET_KEY1) |
| | | |
| | | self._upgraded = msg.upgrade and _is_supported_upgrade( |
| | | msg.headers |
| | | ) |
| | | |
| | | method = getattr(msg, "method", self.method) |
| | | # code is only present on responses |
| | | code = getattr(msg, "code", 0) |
| | | |
| | | assert self.protocol is not None |
| | | # calculate payload |
| | | empty_body = code in EMPTY_BODY_STATUS_CODES or bool( |
| | | method and method in EMPTY_BODY_METHODS |
| | | ) |
| | | if not empty_body and ( |
| | | ((length is not None and length > 0) or msg.chunked) |
| | | and not self._upgraded |
| | | ): |
| | | payload = StreamReader( |
| | | self.protocol, |
| | | timer=self.timer, |
| | | loop=loop, |
| | | limit=self._limit, |
| | | ) |
| | | payload_parser = HttpPayloadParser( |
| | | payload, |
| | | length=length, |
| | | chunked=msg.chunked, |
| | | method=method, |
| | | compression=msg.compression, |
| | | code=self.code, |
| | | response_with_body=self.response_with_body, |
| | | auto_decompress=self._auto_decompress, |
| | | lax=self.lax, |
| | | headers_parser=self._headers_parser, |
| | | ) |
| | | if not payload_parser.done: |
| | | self._payload_parser = payload_parser |
| | | elif method == METH_CONNECT: |
| | | assert isinstance(msg, RawRequestMessage) |
| | | payload = StreamReader( |
| | | self.protocol, |
| | | timer=self.timer, |
| | | loop=loop, |
| | | limit=self._limit, |
| | | ) |
| | | self._upgraded = True |
| | | self._payload_parser = HttpPayloadParser( |
| | | payload, |
| | | method=msg.method, |
| | | compression=msg.compression, |
| | | auto_decompress=self._auto_decompress, |
| | | lax=self.lax, |
| | | headers_parser=self._headers_parser, |
| | | ) |
| | | elif not empty_body and length is None and self.read_until_eof: |
| | | payload = StreamReader( |
| | | self.protocol, |
| | | timer=self.timer, |
| | | loop=loop, |
| | | limit=self._limit, |
| | | ) |
| | | payload_parser = HttpPayloadParser( |
| | | payload, |
| | | length=length, |
| | | chunked=msg.chunked, |
| | | method=method, |
| | | compression=msg.compression, |
| | | code=self.code, |
| | | response_with_body=self.response_with_body, |
| | | auto_decompress=self._auto_decompress, |
| | | lax=self.lax, |
| | | headers_parser=self._headers_parser, |
| | | ) |
| | | if not payload_parser.done: |
| | | self._payload_parser = payload_parser |
| | | else: |
| | | payload = EMPTY_PAYLOAD |
| | | |
| | | messages.append((msg, payload)) |
| | | should_close = msg.should_close |
| | | else: |
| | | self._tail = data[start_pos:] |
| | | data = EMPTY |
| | | break |
| | | |
| | | # no parser, just store |
| | | elif self._payload_parser is None and self._upgraded: |
| | | assert not self._lines |
| | | break |
| | | |
| | | # feed payload |
| | | elif data and start_pos < data_len: |
| | | assert not self._lines |
| | | assert self._payload_parser is not None |
| | | try: |
| | | eof, data = self._payload_parser.feed_data(data[start_pos:], SEP) |
| | | except BaseException as underlying_exc: |
| | | reraised_exc = underlying_exc |
| | | if self.payload_exception is not None: |
| | | reraised_exc = self.payload_exception(str(underlying_exc)) |
| | | |
| | | set_exception( |
| | | self._payload_parser.payload, |
| | | reraised_exc, |
| | | underlying_exc, |
| | | ) |
| | | |
| | | eof = True |
| | | data = b"" |
| | | if isinstance( |
| | | underlying_exc, (InvalidHeader, TransferEncodingError) |
| | | ): |
| | | raise |
| | | |
| | | if eof: |
| | | start_pos = 0 |
| | | data_len = len(data) |
| | | self._payload_parser = None |
| | | continue |
| | | else: |
| | | break |
| | | |
| | | if data and start_pos < data_len: |
| | | data = data[start_pos:] |
| | | else: |
| | | data = EMPTY |
| | | |
| | | return messages, self._upgraded, data |
| | | |
| | | def parse_headers( |
| | | self, lines: List[bytes] |
| | | ) -> Tuple[ |
| | | "CIMultiDictProxy[str]", RawHeaders, Optional[bool], Optional[str], bool, bool |
| | | ]: |
| | |         """Parse HTTP message headers from the given lines.
| | |
| | |         Obsolete line folding is honored only in lax mode.
| | |         Returns a tuple of (headers, raw_headers, close_conn,
| | |         encoding, upgrade, chunked).
| | |         """
| | | headers, raw_headers = self._headers_parser.parse_headers(lines) |
| | | close_conn = None |
| | | encoding = None |
| | | upgrade = False |
| | | chunked = False |
| | | |
| | | # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-6 |
| | | # https://www.rfc-editor.org/rfc/rfc9110.html#name-collected-abnf |
| | | singletons = ( |
| | | hdrs.CONTENT_LENGTH, |
| | | hdrs.CONTENT_LOCATION, |
| | | hdrs.CONTENT_RANGE, |
| | | hdrs.CONTENT_TYPE, |
| | | hdrs.ETAG, |
| | | hdrs.HOST, |
| | | hdrs.MAX_FORWARDS, |
| | | hdrs.SERVER, |
| | | hdrs.TRANSFER_ENCODING, |
| | | hdrs.USER_AGENT, |
| | | ) |
| | | bad_hdr = next((h for h in singletons if len(headers.getall(h, ())) > 1), None) |
| | | if bad_hdr is not None: |
| | | raise BadHttpMessage(f"Duplicate '{bad_hdr}' header found.") |
| | | |
| | | # keep-alive |
| | | conn = headers.get(hdrs.CONNECTION) |
| | | if conn: |
| | | v = conn.lower() |
| | | if v == "close": |
| | | close_conn = True |
| | | elif v == "keep-alive": |
| | | close_conn = False |
| | | # https://www.rfc-editor.org/rfc/rfc9110.html#name-101-switching-protocols |
| | | elif v == "upgrade" and headers.get(hdrs.UPGRADE): |
| | | upgrade = True |
| | | |
| | | # encoding |
| | | enc = headers.get(hdrs.CONTENT_ENCODING, "") |
| | | if enc.isascii() and enc.lower() in {"gzip", "deflate", "br", "zstd"}: |
| | | encoding = enc |
| | | |
| | | # chunking |
| | | te = headers.get(hdrs.TRANSFER_ENCODING) |
| | | if te is not None: |
| | | if self._is_chunked_te(te): |
| | | chunked = True |
| | | |
| | | if hdrs.CONTENT_LENGTH in headers: |
| | | raise BadHttpMessage( |
| | | "Transfer-Encoding can't be present with Content-Length", |
| | | ) |
| | | |
| | | return (headers, raw_headers, close_conn, encoding, upgrade, chunked) |
| | | |
| | | def set_upgraded(self, val: bool) -> None: |
| | | """Set connection upgraded (to websocket) mode. |
| | | |
| | | :param bool val: new state. |
| | | """ |
| | | self._upgraded = val |
| | | |
| | | |
| | | class HttpRequestParser(HttpParser[RawRequestMessage]): |
| | |     """Read the request line and headers.
| | |
| | |     .http_exceptions.BadStatusLine (or BadHttpMethod) is raised
| | |     on a malformed request line.
| | |     Returns RawRequestMessage.
| | |     """
| | | |
| | | def parse_message(self, lines: List[bytes]) -> RawRequestMessage: |
| | | # request line |
| | | line = lines[0].decode("utf-8", "surrogateescape") |
| | | try: |
| | | method, path, version = line.split(" ", maxsplit=2) |
| | | except ValueError: |
| | | raise BadHttpMethod(line) from None |
| | | |
| | | if len(path) > self.max_line_size: |
| | | raise LineTooLong( |
| | | "Status line is too long", str(self.max_line_size), str(len(path)) |
| | | ) |
| | | |
| | | # method |
| | | if not TOKENRE.fullmatch(method): |
| | | raise BadHttpMethod(method) |
| | | |
| | | # version |
| | | match = VERSRE.fullmatch(version) |
| | | if match is None: |
| | | raise BadStatusLine(line) |
| | | version_o = HttpVersion(int(match.group(1)), int(match.group(2))) |
| | | |
| | | if method == "CONNECT": |
| | | # authority-form, |
| | | # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3 |
| | | url = URL.build(authority=path, encoded=True) |
| | | elif path.startswith("/"): |
| | | # origin-form, |
| | | # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1 |
| | | path_part, _hash_separator, url_fragment = path.partition("#") |
| | | path_part, _question_mark_separator, qs_part = path_part.partition("?") |
| | | |
| | | # NOTE: `yarl.URL.build()` is used to mimic what the Cython-based |
| | | # NOTE: parser does, otherwise it results into the same |
| | | # NOTE: HTTP Request-Line input producing different |
| | | # NOTE: `yarl.URL()` objects |
| | | url = URL.build( |
| | | path=path_part, |
| | | query_string=qs_part, |
| | | fragment=url_fragment, |
| | | encoded=True, |
| | | ) |
| | | elif path == "*" and method == "OPTIONS": |
| | | # asterisk-form, |
| | | url = URL(path, encoded=True) |
| | | else: |
| | | # absolute-form for proxy maybe, |
| | | # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2 |
| | | url = URL(path, encoded=True) |
| | | if url.scheme == "": |
| | | # not absolute-form |
| | | raise InvalidURLError( |
| | | path.encode(errors="surrogateescape").decode("latin1") |
| | | ) |
| | | |
| | | # read headers |
| | | ( |
| | | headers, |
| | | raw_headers, |
| | | close, |
| | | compression, |
| | | upgrade, |
| | | chunked, |
| | | ) = self.parse_headers(lines[1:]) |
| | | |
| | |         if close is None:  # then the headers weren't set in the request
| | |             if version_o <= HttpVersion10:
| | |                 # HTTP/1.0 closes by default; keep-alive must be requested explicitly.
| | |                 close = True
| | |             else:
| | |                 # HTTP/1.1 keeps alive by default; close must be requested explicitly.
| | |                 close = False
| | | |
| | | return RawRequestMessage( |
| | | method, |
| | | path, |
| | | version_o, |
| | | headers, |
| | | raw_headers, |
| | | close, |
| | | compression, |
| | | upgrade, |
| | | chunked, |
| | | url, |
| | | ) |
| | | |
| | | def _is_chunked_te(self, te: str) -> bool: |
| | | te = te.rsplit(",", maxsplit=1)[-1].strip(" \t") |
| | | # .lower() transforms some non-ascii chars, so must check first. |
| | | if te.isascii() and te.lower() == "chunked": |
| | | return True |
| | | # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3 |
| | | raise BadHttpMessage("Request has invalid `Transfer-Encoding`") |
| | | |
| | | |
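The origin-form branch of `parse_message` above cuts the request target at `#` and then at `?` before building the URL, so the query string and fragment can never be mistaken for part of the path. That splitting order can be sketched on its own (`split_origin_form` is a name invented here for illustration):

```python
def split_origin_form(path: str):
    # Fragment first, then query: mirrors the partition order used above.
    path_part, _, fragment = path.partition("#")
    path_part, _, query = path_part.partition("?")
    return path_part, query, fragment

print(split_origin_form("/index.html?q=1#top"))  # ('/index.html', 'q=1', 'top')
print(split_origin_form("/plain"))               # ('/plain', '', '')
```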
| | | class HttpResponseParser(HttpParser[RawResponseMessage]): |
| | | """Read response status line and headers. |
| | | |
| | | BadStatusLine could be raised in case of any errors in status line. |
| | | Returns RawResponseMessage. |
| | | """ |
| | | |
| | | # Lax mode should only be enabled on response parser. |
| | | lax = not DEBUG |
| | | |
| | | def feed_data( |
| | | self, |
| | | data: bytes, |
| | | SEP: Optional[_SEP] = None, |
| | | *args: Any, |
| | | **kwargs: Any, |
| | | ) -> Tuple[List[Tuple[RawResponseMessage, StreamReader]], bool, bytes]: |
| | | if SEP is None: |
| | | SEP = b"\r\n" if DEBUG else b"\n" |
| | | return super().feed_data(data, SEP, *args, **kwargs) |
| | | |
| | | def parse_message(self, lines: List[bytes]) -> RawResponseMessage: |
| | | line = lines[0].decode("utf-8", "surrogateescape") |
| | | try: |
| | | version, status = line.split(maxsplit=1) |
| | | except ValueError: |
| | | raise BadStatusLine(line) from None |
| | | |
| | | try: |
| | | status, reason = status.split(maxsplit=1) |
| | | except ValueError: |
| | | status = status.strip() |
| | | reason = "" |
| | | |
| | | if len(reason) > self.max_line_size: |
| | | raise LineTooLong( |
| | | "Status line is too long", str(self.max_line_size), str(len(reason)) |
| | | ) |
| | | |
| | | # version |
| | | match = VERSRE.fullmatch(version) |
| | | if match is None: |
| | | raise BadStatusLine(line) |
| | | version_o = HttpVersion(int(match.group(1)), int(match.group(2))) |
| | | |
| | | # The status code is a three-digit ASCII number, no padding |
| | | if len(status) != 3 or not DIGITS.fullmatch(status): |
| | | raise BadStatusLine(line) |
| | | status_i = int(status) |
| | | |
| | | # read headers |
| | | ( |
| | | headers, |
| | | raw_headers, |
| | | close, |
| | | compression, |
| | | upgrade, |
| | | chunked, |
| | | ) = self.parse_headers(lines[1:]) |
| | | |
| | | if close is None: |
| | | if version_o <= HttpVersion10: |
| | | close = True |
| | | # https://www.rfc-editor.org/rfc/rfc9112.html#name-message-body-length |
| | | elif 100 <= status_i < 200 or status_i in {204, 304}: |
| | | close = False |
| | | elif hdrs.CONTENT_LENGTH in headers or hdrs.TRANSFER_ENCODING in headers: |
| | | close = False |
| | | else: |
| | | # https://www.rfc-editor.org/rfc/rfc9112.html#section-6.3-2.8 |
| | | close = True |
| | | |
| | | return RawResponseMessage( |
| | | version_o, |
| | | status_i, |
| | | reason.strip(), |
| | | headers, |
| | | raw_headers, |
| | | close, |
| | | compression, |
| | | upgrade, |
| | | chunked, |
| | | ) |
| | | |
| | | def _is_chunked_te(self, te: str) -> bool: |
| | | # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.2 |
| | | return te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked" |
| | | |
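The chunked test above deliberately inspects only the final transfer coding: per RFC 9112, "chunked" must be the last encoding applied, so `gzip, chunked` qualifies while `chunked, gzip` does not. The response-side check in isolation:

```python
def is_chunked_te(te: str) -> bool:
    # Only the last comma-separated coding matters; comparison is
    # case-insensitive after stripping optional whitespace.
    return te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked"

print(is_chunked_te("gzip, chunked"))  # True
print(is_chunked_te("chunked, gzip"))  # False
```

The request-side variant above additionally raises `BadHttpMessage` when the last coding is anything other than chunked, since an unrecognized final transfer coding on a request is unrecoverable.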
| | | |
| | | class HttpPayloadParser: |
| | | def __init__( |
| | | self, |
| | | payload: StreamReader, |
| | | length: Optional[int] = None, |
| | | chunked: bool = False, |
| | | compression: Optional[str] = None, |
| | | code: Optional[int] = None, |
| | | method: Optional[str] = None, |
| | | response_with_body: bool = True, |
| | | auto_decompress: bool = True, |
| | | lax: bool = False, |
| | | *, |
| | | headers_parser: HeadersParser, |
| | | ) -> None: |
| | | self._length = 0 |
| | | self._type = ParseState.PARSE_UNTIL_EOF |
| | | self._chunk = ChunkState.PARSE_CHUNKED_SIZE |
| | | self._chunk_size = 0 |
| | | self._chunk_tail = b"" |
| | | self._auto_decompress = auto_decompress |
| | | self._lax = lax |
| | | self._headers_parser = headers_parser |
| | | self._trailer_lines: list[bytes] = [] |
| | | self.done = False |
| | | |
| | | # payload decompression wrapper |
| | | if response_with_body and compression and self._auto_decompress: |
| | | real_payload: Union[StreamReader, DeflateBuffer] = DeflateBuffer( |
| | | payload, compression |
| | | ) |
| | | else: |
| | | real_payload = payload |
| | | |
| | | # payload parser |
| | | if not response_with_body: |
| | | # don't parse payload if it's not expected to be received |
| | | self._type = ParseState.PARSE_NONE |
| | | real_payload.feed_eof() |
| | | self.done = True |
| | | elif chunked: |
| | | self._type = ParseState.PARSE_CHUNKED |
| | | elif length is not None: |
| | | self._type = ParseState.PARSE_LENGTH |
| | | self._length = length |
| | | if self._length == 0: |
| | | real_payload.feed_eof() |
| | | self.done = True |
| | | |
| | | self.payload = real_payload |
| | | |
| | | def feed_eof(self) -> None: |
| | | if self._type == ParseState.PARSE_UNTIL_EOF: |
| | | self.payload.feed_eof() |
| | | elif self._type == ParseState.PARSE_LENGTH: |
| | | raise ContentLengthError( |
| | | "Not enough data to satisfy content length header." |
| | | ) |
| | | elif self._type == ParseState.PARSE_CHUNKED: |
| | | raise TransferEncodingError( |
| | | "Not enough data to satisfy transfer length header." |
| | | ) |
| | | |
| | | def feed_data( |
| | | self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";" |
| | | ) -> Tuple[bool, bytes]: |
| | | # Read specified amount of bytes |
| | | if self._type == ParseState.PARSE_LENGTH: |
| | | required = self._length |
| | | chunk_len = len(chunk) |
| | | |
| | | if required >= chunk_len: |
| | | self._length = required - chunk_len |
| | | self.payload.feed_data(chunk, chunk_len) |
| | | if self._length == 0: |
| | | self.payload.feed_eof() |
| | | return True, b"" |
| | | else: |
| | | self._length = 0 |
| | | self.payload.feed_data(chunk[:required], required) |
| | | self.payload.feed_eof() |
| | | return True, chunk[required:] |
| | | |
| | | # Chunked transfer encoding parser |
| | | elif self._type == ParseState.PARSE_CHUNKED: |
| | | if self._chunk_tail: |
| | | chunk = self._chunk_tail + chunk |
| | | self._chunk_tail = b"" |
| | | |
| | | while chunk: |
| | | |
| | | # read next chunk size |
| | | if self._chunk == ChunkState.PARSE_CHUNKED_SIZE: |
| | | pos = chunk.find(SEP) |
| | | if pos >= 0: |
| | | i = chunk.find(CHUNK_EXT, 0, pos) |
| | | if i >= 0: |
| | | size_b = chunk[:i] # strip chunk-extensions |
| | | # Verify no LF in the chunk-extension |
| | | if b"\n" in (ext := chunk[i:pos]): |
| | | exc = TransferEncodingError( |
| | | f"Unexpected LF in chunk-extension: {ext!r}" |
| | | ) |
| | | set_exception(self.payload, exc) |
| | | raise exc |
| | | else: |
| | | size_b = chunk[:pos] |
| | | |
| | | if self._lax: # Allow whitespace in lax mode. |
| | | size_b = size_b.strip() |
| | | |
| | | if not re.fullmatch(HEXDIGITS, size_b): |
| | | exc = TransferEncodingError( |
| | | chunk[:pos].decode("ascii", "surrogateescape") |
| | | ) |
| | | set_exception(self.payload, exc) |
| | | raise exc |
| | | size = int(bytes(size_b), 16) |
| | | |
| | | chunk = chunk[pos + len(SEP) :] |
| | | if size == 0: # eof marker |
| | | self._chunk = ChunkState.PARSE_TRAILERS |
| | | if self._lax and chunk.startswith(b"\r"): |
| | | chunk = chunk[1:] |
| | | else: |
| | | self._chunk = ChunkState.PARSE_CHUNKED_CHUNK |
| | | self._chunk_size = size |
| | | self.payload.begin_http_chunk_receiving() |
| | | else: |
| | | self._chunk_tail = chunk |
| | | return False, b"" |
| | | |
| | | # read chunk and feed buffer |
| | | if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK: |
| | | required = self._chunk_size |
| | | chunk_len = len(chunk) |
| | | |
| | | if required > chunk_len: |
| | | self._chunk_size = required - chunk_len |
| | | self.payload.feed_data(chunk, chunk_len) |
| | | return False, b"" |
| | | else: |
| | | self._chunk_size = 0 |
| | | self.payload.feed_data(chunk[:required], required) |
| | | chunk = chunk[required:] |
| | | self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF |
| | | self.payload.end_http_chunk_receiving() |
| | | |
| | | # toss the CRLF at the end of the chunk |
| | | if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF: |
| | | if self._lax and chunk.startswith(b"\r"): |
| | | chunk = chunk[1:] |
| | | if chunk[: len(SEP)] == SEP: |
| | | chunk = chunk[len(SEP) :] |
| | | self._chunk = ChunkState.PARSE_CHUNKED_SIZE |
| | | else: |
| | | self._chunk_tail = chunk |
| | | return False, b"" |
| | | |
| | | if self._chunk == ChunkState.PARSE_TRAILERS: |
| | | pos = chunk.find(SEP) |
| | | if pos < 0: # No line found |
| | | self._chunk_tail = chunk |
| | | return False, b"" |
| | | |
| | | line = chunk[:pos] |
| | | chunk = chunk[pos + len(SEP) :] |
| | | if SEP == b"\n": # For lax response parsing |
| | | line = line.rstrip(b"\r") |
| | | self._trailer_lines.append(line) |
| | | |
| | | # \r\n\r\n found, end of stream |
| | | if self._trailer_lines[-1] == b"": |
| | | # Headers and trailers are defined the same way, |
| | | # so we reuse the HeadersParser here. |
| | | try: |
| | | trailers, raw_trailers = self._headers_parser.parse_headers( |
| | | self._trailer_lines |
| | | ) |
| | | finally: |
| | | self._trailer_lines.clear() |
| | | self.payload.feed_eof() |
| | | return True, chunk |
| | | |
| | | # Read all bytes until eof |
| | | elif self._type == ParseState.PARSE_UNTIL_EOF: |
| | | self.payload.feed_data(chunk, len(chunk)) |
| | | |
| | | return False, b"" |
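The framing that `feed_data` undoes can be sketched from the other direction. This minimal encoder (an illustration, not part of aiohttp) produces the hex-size line, data, CRLF layout plus the `0\r\n\r\n` terminator that the parser's chunk state machine consumes:

```python
def encode_chunked(parts: list[bytes]) -> bytes:
    # Each chunk: hex size, CRLF, data, CRLF. A zero-size chunk
    # followed by a blank line (no trailers) terminates the body.
    out = bytearray()
    for part in parts:
        out += f"{len(part):x}\r\n".encode("ascii") + part + b"\r\n"
    out += b"0\r\n\r\n"
    return bytes(out)
```

Feeding the output of `encode_chunked` into `HttpPayloadParser.feed_data` walks the states PARSE_CHUNKED_SIZE → PARSE_CHUNKED_CHUNK → PARSE_CHUNKED_CHUNK_EOF and ends in PARSE_TRAILERS.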
| | | |
| | | |
| | | class DeflateBuffer: |
| | | """Decompress incoming data and feed the result into the wrapped stream."""
| | | |
| | | decompressor: Any |
| | | |
| | | def __init__( |
| | | self, |
| | | out: StreamReader, |
| | | encoding: Optional[str], |
| | | max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, |
| | | ) -> None: |
| | | self.out = out |
| | | self.size = 0 |
| | | out.total_compressed_bytes = self.size |
| | | self.encoding = encoding |
| | | self._started_decoding = False |
| | | |
| | | self.decompressor: Union[BrotliDecompressor, ZLibDecompressor, ZSTDDecompressor] |
| | | if encoding == "br": |
| | | if not HAS_BROTLI: # pragma: no cover |
| | | raise ContentEncodingError( |
| | | "Can not decode content-encoding: brotli (br). " |
| | | "Please install `Brotli`" |
| | | ) |
| | | self.decompressor = BrotliDecompressor() |
| | | elif encoding == "zstd": |
| | | if not HAS_ZSTD: |
| | | raise ContentEncodingError( |
| | | "Can not decode content-encoding: zstandard (zstd). " |
| | | "Please install `backports.zstd`" |
| | | ) |
| | | self.decompressor = ZSTDDecompressor() |
| | | else: |
| | | self.decompressor = ZLibDecompressor(encoding=encoding) |
| | | |
| | | self._max_decompress_size = max_decompress_size |
| | | |
| | | def set_exception( |
| | | self, |
| | | exc: BaseException, |
| | | exc_cause: BaseException = _EXC_SENTINEL, |
| | | ) -> None: |
| | | set_exception(self.out, exc, exc_cause) |
| | | |
| | | def feed_data(self, chunk: bytes, size: int) -> None: |
| | | if not size: |
| | | return |
| | | |
| | | self.size += size |
| | | self.out.total_compressed_bytes = self.size |
| | | |
| | | # RFC1950 |
| | | # bits 0..3 = CM = 0b1000 = 8 = "deflate" |
| | | # bits 4..7 = CINFO = 1..7 = window size.
| | | if ( |
| | | not self._started_decoding |
| | | and self.encoding == "deflate" |
| | | and chunk[0] & 0xF != 8 |
| | | ): |
| | | # Change the decoder to decompress incorrectly compressed data |
| | | # Actually we should issue a warning about non-RFC-compliant data. |
| | | self.decompressor = ZLibDecompressor( |
| | | encoding=self.encoding, suppress_deflate_header=True |
| | | ) |
| | | |
| | | try: |
| | | # Decompress with limit + 1 so we can detect if output exceeds limit |
| | | chunk = self.decompressor.decompress_sync( |
| | | chunk, max_length=self._max_decompress_size + 1 |
| | | ) |
| | | except Exception: |
| | | raise ContentEncodingError( |
| | | "Can not decode content-encoding: %s" % self.encoding |
| | | ) |
| | | |
| | | self._started_decoding = True |
| | | |
| | | # Check if decompression limit was exceeded |
| | | if len(chunk) > self._max_decompress_size: |
| | | raise DecompressSizeError( |
| | | "Decompressed data exceeds the configured limit of %d bytes" |
| | | % self._max_decompress_size |
| | | ) |
| | | |
| | | if chunk: |
| | | self.out.feed_data(chunk, len(chunk)) |
| | | |
| | | def feed_eof(self) -> None: |
| | | chunk = self.decompressor.flush() |
| | | |
| | | if chunk or self.size > 0: |
| | | self.out.feed_data(chunk, len(chunk)) |
| | | if self.encoding == "deflate" and not self.decompressor.eof: |
| | | raise ContentEncodingError("deflate") |
| | | |
| | | self.out.feed_eof() |
| | | |
| | | def begin_http_chunk_receiving(self) -> None: |
| | | self.out.begin_http_chunk_receiving() |
| | | |
| | | def end_http_chunk_receiving(self) -> None: |
| | | self.out.end_http_chunk_receiving() |
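Both special cases in `DeflateBuffer.feed_data` can be reproduced with plain `zlib`. In this sketch, `wbits=-zlib.MAX_WBITS` plays the role of `suppress_deflate_header=True` (a server sending raw deflate although "deflate" means RFC 1950 zlib-wrapped data), and `max_length` is the mechanism behind the size-limit check:

```python
import zlib

data = b"A" * 10_000

# Raw deflate with no zlib header, as some non-compliant servers send.
comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
body = comp.compress(data) + comp.flush()
plain = zlib.decompressobj(wbits=-zlib.MAX_WBITS).decompress(body)
assert plain == data

# Decompressing with max_length = limit + 1 caps the output, so any
# result longer than the limit reveals a potential decompression bomb
# without materializing the full payload.
limit = 1024
out = zlib.decompressobj(wbits=-zlib.MAX_WBITS).decompress(body, limit + 1)
assert len(out) == limit + 1  # limit exceeded; the caller raises
```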
| | | |
| | | |
| | | HttpRequestParserPy = HttpRequestParser |
| | | HttpResponseParserPy = HttpResponseParser |
| | | RawRequestMessagePy = RawRequestMessage |
| | | RawResponseMessagePy = RawResponseMessage |
| | | |
| | | try: |
| | | if not NO_EXTENSIONS: |
| | | from ._http_parser import ( # type: ignore[import-not-found,no-redef] |
| | | HttpRequestParser, |
| | | HttpResponseParser, |
| | | RawRequestMessage, |
| | | RawResponseMessage, |
| | | ) |
| | | |
| | | HttpRequestParserC = HttpRequestParser |
| | | HttpResponseParserC = HttpResponseParser |
| | | RawRequestMessageC = RawRequestMessage |
| | | RawResponseMessageC = RawResponseMessage |
| | | except ImportError: # pragma: no cover |
| | | pass |
| New file |
| | |
| | | """WebSocket protocol versions 13 and 8.""" |
| | | |
| | | from ._websocket.helpers import WS_KEY, ws_ext_gen, ws_ext_parse |
| | | from ._websocket.models import ( |
| | | WS_CLOSED_MESSAGE, |
| | | WS_CLOSING_MESSAGE, |
| | | WebSocketError, |
| | | WSCloseCode, |
| | | WSHandshakeError, |
| | | WSMessage, |
| | | WSMsgType, |
| | | ) |
| | | from ._websocket.reader import WebSocketReader |
| | | from ._websocket.writer import WebSocketWriter |
| | | |
| | | # Messages that the WebSocketResponse.receive needs to handle internally |
| | | _INTERNAL_RECEIVE_TYPES = frozenset( |
| | | (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.PING, WSMsgType.PONG) |
| | | ) |
| | | |
| | | |
| | | __all__ = ( |
| | | "WS_CLOSED_MESSAGE", |
| | | "WS_CLOSING_MESSAGE", |
| | | "WS_KEY", |
| | | "WebSocketReader", |
| | | "WebSocketWriter", |
| | | "WSMessage", |
| | | "WebSocketError", |
| | | "WSMsgType", |
| | | "WSCloseCode", |
| | | "ws_ext_gen", |
| | | "ws_ext_parse", |
| | | "WSHandshakeError", |
| | | ) |
| New file |
| | |
| | | """Http related parsers and protocol.""" |
| | | |
| | | import asyncio |
| | | import sys |
| | | from typing import ( # noqa |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Awaitable, |
| | | Callable, |
| | | Iterable, |
| | | List, |
| | | NamedTuple, |
| | | Optional, |
| | | Union, |
| | | ) |
| | | |
| | | from multidict import CIMultiDict |
| | | |
| | | from .abc import AbstractStreamWriter |
| | | from .base_protocol import BaseProtocol |
| | | from .client_exceptions import ClientConnectionResetError |
| | | from .compression_utils import ZLibCompressor |
| | | from .helpers import NO_EXTENSIONS |
| | | |
| | | __all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11") |
| | | |
| | | |
| | | MIN_PAYLOAD_FOR_WRITELINES = 2048 |
| | | IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2) |
| | | IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9) |
| | | SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9 |
| | | # writelines is not safe for use |
| | | # on Python 3.12+ until 3.12.9 |
| | | # on Python 3.13+ until 3.13.2 |
| | | # and on older versions it is not any faster than write
| | | # CVE-2024-12254: https://github.com/python/cpython/pull/127656 |
| | | |
| | | |
| | | class HttpVersion(NamedTuple): |
| | | major: int |
| | | minor: int |
| | | |
| | | |
| | | HttpVersion10 = HttpVersion(1, 0) |
| | | HttpVersion11 = HttpVersion(1, 1) |
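Because `HttpVersion` is a NamedTuple, comparisons such as the response parser's `version_o <= HttpVersion10` keep-alive check are plain lexicographic tuple comparisons on `(major, minor)`:

```python
from typing import NamedTuple

class HttpVersion(NamedTuple):
    major: int
    minor: int

HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)

# (major, minor) compare element-wise, so 0.9 < 1.0 < 1.1
assert HttpVersion(0, 9) < HttpVersion10 < HttpVersion11
assert HttpVersion10 <= HttpVersion(1, 0)  # HTTP/1.0 closes by default
```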
| | | |
| | | |
| | | _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] |
| | | _T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]] |
| | | |
| | | |
| | | class StreamWriter(AbstractStreamWriter): |
| | | |
| | | length: Optional[int] = None |
| | | chunked: bool = False |
| | | _eof: bool = False |
| | | _compress: Optional[ZLibCompressor] = None |
| | | |
| | | def __init__( |
| | | self, |
| | | protocol: BaseProtocol, |
| | | loop: asyncio.AbstractEventLoop, |
| | | on_chunk_sent: _T_OnChunkSent = None, |
| | | on_headers_sent: _T_OnHeadersSent = None, |
| | | ) -> None: |
| | | self._protocol = protocol |
| | | self.loop = loop |
| | | self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent |
| | | self._on_headers_sent: _T_OnHeadersSent = on_headers_sent |
| | | self._headers_buf: Optional[bytes] = None |
| | | self._headers_written: bool = False |
| | | |
| | | @property |
| | | def transport(self) -> Optional[asyncio.Transport]: |
| | | return self._protocol.transport |
| | | |
| | | @property |
| | | def protocol(self) -> BaseProtocol: |
| | | return self._protocol |
| | | |
| | | def enable_chunking(self) -> None: |
| | | self.chunked = True |
| | | |
| | | def enable_compression( |
| | | self, encoding: str = "deflate", strategy: Optional[int] = None |
| | | ) -> None: |
| | | self._compress = ZLibCompressor(encoding=encoding, strategy=strategy) |
| | | |
| | | def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None: |
| | | size = len(chunk) |
| | | self.buffer_size += size |
| | | self.output_size += size |
| | | transport = self._protocol.transport |
| | | if transport is None or transport.is_closing(): |
| | | raise ClientConnectionResetError("Cannot write to closing transport") |
| | | transport.write(chunk) |
| | | |
| | | def _writelines(self, chunks: Iterable[bytes]) -> None: |
| | | size = 0 |
| | | for chunk in chunks: |
| | | size += len(chunk) |
| | | self.buffer_size += size |
| | | self.output_size += size |
| | | transport = self._protocol.transport |
| | | if transport is None or transport.is_closing(): |
| | | raise ClientConnectionResetError("Cannot write to closing transport") |
| | | if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES: |
| | | transport.write(b"".join(chunks)) |
| | | else: |
| | | transport.writelines(chunks) |
| | | |
| | | def _write_chunked_payload( |
| | | self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] |
| | | ) -> None: |
| | | """Write a chunk with proper chunked encoding.""" |
| | | chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii") |
| | | self._writelines((chunk_len_pre, chunk, b"\r\n")) |
| | | |
| | | def _send_headers_with_payload( |
| | | self, |
| | | chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"], |
| | | is_eof: bool, |
| | | ) -> None: |
| | | """Send buffered headers with payload, coalescing into single write.""" |
| | | # Mark headers as written |
| | | self._headers_written = True |
| | | headers_buf = self._headers_buf |
| | | self._headers_buf = None |
| | | |
| | | if TYPE_CHECKING: |
| | | # Safe because callers (write() and write_eof()) only invoke this method |
| | | # after checking that self._headers_buf is truthy |
| | | assert headers_buf is not None |
| | | |
| | | if not self.chunked: |
| | | # Non-chunked: coalesce headers with body |
| | | if chunk: |
| | | self._writelines((headers_buf, chunk)) |
| | | else: |
| | | self._write(headers_buf) |
| | | return |
| | | |
| | | # Coalesce headers with chunked data |
| | | if chunk: |
| | | chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii") |
| | | if is_eof: |
| | | self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n0\r\n\r\n")) |
| | | else: |
| | | self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n")) |
| | | elif is_eof: |
| | | self._writelines((headers_buf, b"0\r\n\r\n")) |
| | | else: |
| | | self._write(headers_buf) |
| | | |
| | | async def write( |
| | | self, |
| | | chunk: Union[bytes, bytearray, memoryview], |
| | | *, |
| | | drain: bool = True, |
| | | LIMIT: int = 0x10000, |
| | | ) -> None: |
| | | """
| | | Write a chunk of data to the stream.
| | | 
| | | write_eof() indicates the end of the stream.
| | | The writer cannot be used after write_eof() has been called.
| | | write() drains the transport once the buffer exceeds LIMIT.
| | | """
| | | if self._on_chunk_sent is not None: |
| | | await self._on_chunk_sent(chunk) |
| | | |
| | | if isinstance(chunk, memoryview): |
| | | if chunk.nbytes != len(chunk): |
| | | # just reshape it |
| | | chunk = chunk.cast("c") |
| | | |
| | | if self._compress is not None: |
| | | chunk = await self._compress.compress(chunk) |
| | | if not chunk: |
| | | return |
| | | |
| | | if self.length is not None: |
| | | chunk_len = len(chunk) |
| | | if self.length >= chunk_len: |
| | | self.length = self.length - chunk_len |
| | | else: |
| | | chunk = chunk[: self.length] |
| | | self.length = 0 |
| | | if not chunk: |
| | | return |
| | | |
| | | # Handle buffered headers for small payload optimization |
| | | if self._headers_buf and not self._headers_written: |
| | | self._send_headers_with_payload(chunk, False) |
| | | if drain and self.buffer_size > LIMIT: |
| | | self.buffer_size = 0 |
| | | await self.drain() |
| | | return |
| | | |
| | | if chunk: |
| | | if self.chunked: |
| | | self._write_chunked_payload(chunk) |
| | | else: |
| | | self._write(chunk) |
| | | |
| | | if drain and self.buffer_size > LIMIT: |
| | | self.buffer_size = 0 |
| | | await self.drain() |
| | | |
| | | async def write_headers( |
| | | self, status_line: str, headers: "CIMultiDict[str]" |
| | | ) -> None: |
| | | """Serialize and buffer headers; they are sent with the first body write."""
| | | if self._on_headers_sent is not None: |
| | | await self._on_headers_sent(headers) |
| | | # status + headers |
| | | buf = _serialize_headers(status_line, headers) |
| | | self._headers_written = False |
| | | self._headers_buf = buf |
| | | |
| | | def send_headers(self) -> None: |
| | | """Force sending buffered headers if not already sent.""" |
| | | if not self._headers_buf or self._headers_written: |
| | | return |
| | | |
| | | self._headers_written = True |
| | | headers_buf = self._headers_buf |
| | | self._headers_buf = None |
| | | |
| | | if TYPE_CHECKING: |
| | | # Safe because we only enter this block when self._headers_buf is truthy |
| | | assert headers_buf is not None |
| | | |
| | | self._write(headers_buf) |
| | | |
| | | def set_eof(self) -> None: |
| | | """Indicate that the message is complete.""" |
| | | if self._eof: |
| | | return |
| | | |
| | | # If headers haven't been sent yet, send them now |
| | | # This handles the case where there's no body at all |
| | | if self._headers_buf and not self._headers_written: |
| | | self._headers_written = True |
| | | headers_buf = self._headers_buf |
| | | self._headers_buf = None |
| | | |
| | | if TYPE_CHECKING: |
| | | # Safe because we only enter this block when self._headers_buf is truthy |
| | | assert headers_buf is not None |
| | | |
| | | # Combine headers and chunked EOF marker in a single write |
| | | if self.chunked: |
| | | self._writelines((headers_buf, b"0\r\n\r\n")) |
| | | else: |
| | | self._write(headers_buf) |
| | | elif self.chunked and self._headers_written: |
| | | # Headers already sent, just send the final chunk marker |
| | | self._write(b"0\r\n\r\n") |
| | | |
| | | self._eof = True |
| | | |
| | | async def write_eof(self, chunk: bytes = b"") -> None: |
| | | if self._eof: |
| | | return |
| | | |
| | | if chunk and self._on_chunk_sent is not None: |
| | | await self._on_chunk_sent(chunk) |
| | | |
| | | # Handle body/compression |
| | | if self._compress: |
| | | chunks: List[bytes] = [] |
| | | chunks_len = 0 |
| | | if chunk and (compressed_chunk := await self._compress.compress(chunk)): |
| | | chunks_len = len(compressed_chunk) |
| | | chunks.append(compressed_chunk) |
| | | |
| | | flush_chunk = self._compress.flush() |
| | | chunks_len += len(flush_chunk) |
| | | chunks.append(flush_chunk) |
| | | assert chunks_len |
| | | |
| | | # Send buffered headers with compressed data if not yet sent |
| | | if self._headers_buf and not self._headers_written: |
| | | self._headers_written = True |
| | | headers_buf = self._headers_buf |
| | | self._headers_buf = None |
| | | |
| | | if self.chunked: |
| | | # Coalesce headers with compressed chunked data |
| | | chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii") |
| | | self._writelines( |
| | | (headers_buf, chunk_len_pre, *chunks, b"\r\n0\r\n\r\n") |
| | | ) |
| | | else: |
| | | # Coalesce headers with compressed data |
| | | self._writelines((headers_buf, *chunks)) |
| | | await self.drain() |
| | | self._eof = True |
| | | return |
| | | |
| | | # Headers already sent, just write compressed data |
| | | if self.chunked: |
| | | chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii") |
| | | self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n")) |
| | | elif len(chunks) > 1: |
| | | self._writelines(chunks) |
| | | else: |
| | | self._write(chunks[0]) |
| | | await self.drain() |
| | | self._eof = True |
| | | return |
| | | |
| | | # No compression - send buffered headers if not yet sent |
| | | if self._headers_buf and not self._headers_written: |
| | | # Use helper to send headers with payload |
| | | self._send_headers_with_payload(chunk, True) |
| | | await self.drain() |
| | | self._eof = True |
| | | return |
| | | |
| | | # Handle remaining body |
| | | if self.chunked: |
| | | if chunk: |
| | | # Write final chunk with EOF marker |
| | | self._writelines( |
| | | (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n0\r\n\r\n") |
| | | ) |
| | | else: |
| | | self._write(b"0\r\n\r\n") |
| | | await self.drain() |
| | | self._eof = True |
| | | return |
| | | |
| | | if chunk: |
| | | self._write(chunk) |
| | | await self.drain() |
| | | |
| | | self._eof = True |
| | | |
| | | async def drain(self) -> None: |
| | | """Flush the write buffer. |
| | | |
| | | The intended use is to write |
| | | |
| | | await w.write(data) |
| | | await w.drain() |
| | | """ |
| | | protocol = self._protocol |
| | | if protocol.transport is not None and protocol._paused: |
| | | await protocol._drain_helper() |
| | | |
| | | |
| | | def _safe_header(string: str) -> str: |
| | | if "\r" in string or "\n" in string: |
| | | raise ValueError( |
| | | "Newline or carriage return detected in headers. " |
| | | "Potential header injection attack." |
| | | ) |
| | | return string |
| | | |
| | | |
| | | def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes: |
| | | headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items()) |
| | | line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n" |
| | | return line.encode("utf-8") |
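The wire format `_py_serialize_headers` emits can be shown with a plain `dict` standing in for `CIMultiDict` (an assumption for self-containment; the real function also preserves duplicate keys). The CR/LF rejection mirrors `_safe_header`:

```python
def safe_header(value: str) -> str:
    # Reject CR/LF to block header-injection attacks.
    if "\r" in value or "\n" in value:
        raise ValueError("Newline or carriage return detected in headers.")
    return value

def serialize_headers(status_line: str, headers: dict) -> bytes:
    lines = (f"{safe_header(k)}: {safe_header(v)}" for k, v in headers.items())
    return (status_line + "\r\n" + "\r\n".join(lines) + "\r\n\r\n").encode("utf-8")

wire = serialize_headers("HTTP/1.1 200 OK", {"Content-Length": "0"})
assert wire == b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
```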
| | | |
| | | |
| | | _serialize_headers = _py_serialize_headers |
| | | |
| | | try: |
| | | import aiohttp._http_writer as _http_writer # type: ignore[import-not-found] |
| | | |
| | | _c_serialize_headers = _http_writer._serialize_headers |
| | | if not NO_EXTENSIONS: |
| | | _serialize_headers = _c_serialize_headers |
| | | except ImportError: |
| | | pass |
| New file |
| | |
| | | import logging |
| | | |
| | | access_logger = logging.getLogger("aiohttp.access") |
| | | client_logger = logging.getLogger("aiohttp.client") |
| | | internal_logger = logging.getLogger("aiohttp.internal") |
| | | server_logger = logging.getLogger("aiohttp.server") |
| | | web_logger = logging.getLogger("aiohttp.web") |
| | | ws_logger = logging.getLogger("aiohttp.websocket") |
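These are ordinary stdlib loggers, so an application can opt into any of them by name. A minimal setup for the client logger (handler, format, and level here are illustrative choices, not defaults):

```python
import logging

client_logger = logging.getLogger("aiohttp.client")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(name)s %(levelname)s %(message)s"))
client_logger.addHandler(handler)
client_logger.setLevel(logging.DEBUG)
```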
| New file |
| | |
| | | import base64 |
| | | import binascii |
| | | import json |
| | | import re |
| | | import sys |
| | | import uuid |
| | | import warnings |
| | | from collections import deque |
| | | from collections.abc import Mapping, Sequence |
| | | from types import TracebackType |
| | | from typing import ( |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Deque, |
| | | Dict, |
| | | Iterator, |
| | | List, |
| | | Optional, |
| | | Tuple, |
| | | Type, |
| | | Union, |
| | | cast, |
| | | ) |
| | | from urllib.parse import parse_qsl, unquote, urlencode |
| | | |
| | | from multidict import CIMultiDict, CIMultiDictProxy |
| | | |
| | | from .abc import AbstractStreamWriter |
| | | from .compression_utils import ( |
| | | DEFAULT_MAX_DECOMPRESS_SIZE, |
| | | ZLibCompressor, |
| | | ZLibDecompressor, |
| | | ) |
| | | from .hdrs import ( |
| | | CONTENT_DISPOSITION, |
| | | CONTENT_ENCODING, |
| | | CONTENT_LENGTH, |
| | | CONTENT_TRANSFER_ENCODING, |
| | | CONTENT_TYPE, |
| | | ) |
| | | from .helpers import CHAR, TOKEN, parse_mimetype, reify |
| | | from .http import HeadersParser |
| | | from .log import internal_logger |
| | | from .payload import ( |
| | | JsonPayload, |
| | | LookupError, |
| | | Order, |
| | | Payload, |
| | | StringPayload, |
| | | get_payload, |
| | | payload_type, |
| | | ) |
| | | from .streams import StreamReader |
| | | |
| | | if sys.version_info >= (3, 11): |
| | | from typing import Self |
| | | else: |
| | | from typing import TypeVar |
| | | |
| | | Self = TypeVar("Self", bound="BodyPartReader") |
| | | |
| | | __all__ = ( |
| | | "MultipartReader", |
| | | "MultipartWriter", |
| | | "BodyPartReader", |
| | | "BadContentDispositionHeader", |
| | | "BadContentDispositionParam", |
| | | "parse_content_disposition", |
| | | "content_disposition_filename", |
| | | ) |
| | | |
| | | |
| | | if TYPE_CHECKING: |
| | | from .client_reqrep import ClientResponse |
| | | |
| | | |
| | | class BadContentDispositionHeader(RuntimeWarning): |
| | | pass |
| | | |
| | | |
| | | class BadContentDispositionParam(RuntimeWarning): |
| | | pass |
| | | |
| | | |
| | | def parse_content_disposition( |
| | | header: Optional[str], |
| | | ) -> Tuple[Optional[str], Dict[str, str]]: |
| | | def is_token(string: str) -> bool: |
| | | return bool(string) and TOKEN >= set(string) |
| | | |
| | | def is_quoted(string: str) -> bool: |
| | | return string[0] == string[-1] == '"' |
| | | |
| | | def is_rfc5987(string: str) -> bool: |
| | | return is_token(string) and string.count("'") == 2 |
| | | |
| | | def is_extended_param(string: str) -> bool: |
| | | return string.endswith("*") |
| | | |
| | | def is_continuous_param(string: str) -> bool: |
| | | pos = string.find("*") + 1 |
| | | if not pos: |
| | | return False |
| | | substring = string[pos:-1] if string.endswith("*") else string[pos:] |
| | | return substring.isdigit() |
| | | |
| | | def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str: |
| | | return re.sub(f"\\\\([{chars}])", "\\1", text) |
| | | |
| | | if not header: |
| | | return None, {} |
| | | |
| | | disptype, *parts = header.split(";") |
| | | if not is_token(disptype): |
| | | warnings.warn(BadContentDispositionHeader(header)) |
| | | return None, {} |
| | | |
| | | params: Dict[str, str] = {} |
| | | while parts: |
| | | item = parts.pop(0) |
| | | |
| | | if not item: # To handle trailing semicolons |
| | | warnings.warn(BadContentDispositionHeader(header)) |
| | | continue |
| | | |
| | | if "=" not in item: |
| | | warnings.warn(BadContentDispositionHeader(header)) |
| | | return None, {} |
| | | |
| | | key, value = item.split("=", 1) |
| | | key = key.lower().strip() |
| | | value = value.lstrip() |
| | | |
| | | if key in params: |
| | | warnings.warn(BadContentDispositionHeader(header)) |
| | | return None, {} |
| | | |
| | | if not is_token(key): |
| | | warnings.warn(BadContentDispositionParam(item)) |
| | | continue |
| | | |
| | | elif is_continuous_param(key): |
| | | if is_quoted(value): |
| | | value = unescape(value[1:-1]) |
| | | elif not is_token(value): |
| | | warnings.warn(BadContentDispositionParam(item)) |
| | | continue |
| | | |
| | | elif is_extended_param(key): |
| | | if is_rfc5987(value): |
| | | encoding, _, value = value.split("'", 2) |
| | | encoding = encoding or "utf-8" |
| | | else: |
| | | warnings.warn(BadContentDispositionParam(item)) |
| | | continue |
| | | |
| | | try: |
| | | value = unquote(value, encoding, "strict") |
| | | except UnicodeDecodeError: # pragma: nocover |
| | | warnings.warn(BadContentDispositionParam(item)) |
| | | continue |
| | | |
| | | else: |
| | | failed = True |
| | | if is_quoted(value): |
| | | failed = False |
| | | value = unescape(value[1:-1].lstrip("\\/")) |
| | | elif is_token(value): |
| | | failed = False |
| | | elif parts: |
| | | # maybe just ; in filename, in any case this is just |
| | | # one case fix, for proper fix we need to redesign parser |
| | | _value = f"{value};{parts[0]}" |
| | | if is_quoted(_value): |
| | | parts.pop(0) |
| | | value = unescape(_value[1:-1].lstrip("\\/")) |
| | | failed = False |
| | | |
| | | if failed: |
| | | warnings.warn(BadContentDispositionHeader(header)) |
| | | return None, {} |
| | | |
| | | params[key] = value |
| | | |
| | | return disptype.lower(), params |
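For comparison, the stdlib `email` package parses the same grammar (MIME headers share RFC 2045 syntax); this is a cross-check with a different parser, not the function above:

```python
from email.message import Message

msg = Message()
msg["Content-Disposition"] = 'form-data; name="field"; filename="a.txt"'
assert msg.get_content_disposition() == "form-data"
assert msg.get_param("name", header="content-disposition") == "field"
assert msg.get_filename() == "a.txt"
```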
| | | |
| | | |
| | | def content_disposition_filename( |
| | | params: Mapping[str, str], name: str = "filename" |
| | | ) -> Optional[str]: |
| | | name_suf = "%s*" % name |
| | | if not params: |
| | | return None |
| | | elif name_suf in params: |
| | | return params[name_suf] |
| | | elif name in params: |
| | | return params[name] |
| | | else: |
| | | parts = [] |
| | | fnparams = sorted( |
| | | (key, value) for key, value in params.items() if key.startswith(name_suf) |
| | | ) |
| | | for num, (key, value) in enumerate(fnparams): |
| | | _, tail = key.split("*", 1) |
| | | if tail.endswith("*"): |
| | | tail = tail[:-1] |
| | | if tail == str(num): |
| | | parts.append(value) |
| | | else: |
| | | break |
| | | if not parts: |
| | | return None |
| | | value = "".join(parts) |
| | | if "'" in value: |
| | | encoding, _, value = value.split("'", 2) |
| | | encoding = encoding or "utf-8" |
| | | return unquote(value, encoding, "strict") |
| | | return value |
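The `filename*` branch above follows RFC 5987's `<charset>'<language>'<pct-encoded>` shape. Decoding one such value by hand with `urllib.parse.unquote` shows the same split-and-unquote step the function performs:

```python
from urllib.parse import unquote

value = "utf-8''%E2%82%AC%20rates.txt"  # charset, language, percent-encoded
encoding, _lang, encoded = value.split("'", 2)
filename = unquote(encoded, encoding or "utf-8", "strict")
assert filename == "\u20ac rates.txt"  # "€ rates.txt"
```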
| | | |
| | | |
| | | class MultipartResponseWrapper: |
| | | """Wrapper around the MultipartReader. |
| | | |
| | | It manages the underlying connection and releases it
| | | once the multipart stream is exhausted.
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | resp: "ClientResponse", |
| | | stream: "MultipartReader", |
| | | ) -> None: |
| | | self.resp = resp |
| | | self.stream = stream |
| | | |
| | | def __aiter__(self) -> "MultipartResponseWrapper": |
| | | return self |
| | | |
| | | async def __anext__( |
| | | self, |
| | | ) -> Union["MultipartReader", "BodyPartReader"]: |
| | | part = await self.next() |
| | | if part is None: |
| | | raise StopAsyncIteration |
| | | return part |
| | | |
| | | def at_eof(self) -> bool: |
| | | """Return True when all response data has been read."""
| | | return self.resp.content.at_eof() |
| | | |
| | | async def next( |
| | | self, |
| | | ) -> Optional[Union["MultipartReader", "BodyPartReader"]]: |
| | | """Emit the next multipart reader object."""
| | | item = await self.stream.next() |
| | | if self.stream.at_eof(): |
| | | await self.release() |
| | | return item |
| | | |
| | | async def release(self) -> None: |
| | | """Release the connection gracefully. |
| | | |
| | | Any remaining content is read and discarded.
| | | """ |
| | | await self.resp.release() |
| | | |
| | | |
| | | class BodyPartReader: |
| | | """Multipart reader for single body part.""" |
| | | |
| | | chunk_size = 8192 |
| | | |
| | | def __init__( |
| | | self, |
| | | boundary: bytes, |
| | | headers: "CIMultiDictProxy[str]", |
| | | content: StreamReader, |
| | | *, |
| | | subtype: str = "mixed", |
| | | default_charset: Optional[str] = None, |
| | | max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, |
| | | ) -> None: |
| | | self.headers = headers |
| | | self._boundary = boundary |
| | | self._boundary_len = len(boundary) + 2 # Boundary + \r\n |
| | | self._content = content |
| | | self._default_charset = default_charset |
| | | self._at_eof = False |
| | | self._is_form_data = subtype == "form-data" |
| | | # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 |
| | | length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None) |
| | | self._length = int(length) if length is not None else None |
| | | self._read_bytes = 0 |
| | | self._unread: Deque[bytes] = deque() |
| | | self._prev_chunk: Optional[bytes] = None |
| | | self._content_eof = 0 |
| | | self._cache: Dict[str, Any] = {} |
| | | self._max_decompress_size = max_decompress_size |
| | | |
| | | def __aiter__(self: Self) -> Self: |
| | | return self |
| | | |
| | | async def __anext__(self) -> bytes: |
| | | part = await self.next() |
| | | if part is None: |
| | | raise StopAsyncIteration |
| | | return part |
| | | |
| | | async def next(self) -> Optional[bytes]: |
| | | item = await self.read() |
| | | if not item: |
| | | return None |
| | | return item |
| | | |
| | | async def read(self, *, decode: bool = False) -> bytes: |
| | | """Reads body part data. |
| | | |
| | | decode: decode the data according to the |
| | | Content-Encoding (or Content-Transfer-Encoding) header. |
| | | If the header is missing, the data is returned untouched. |
| | | """ |
| | | if self._at_eof: |
| | | return b"" |
| | | data = bytearray() |
| | | while not self._at_eof: |
| | | data.extend(await self.read_chunk(self.chunk_size)) |
| | | if decode: |
| | | return await self.decode(data) |
| | | return data |
| | | |
| | | async def read_chunk(self, size: int = chunk_size) -> bytes: |
| | | """Reads body part content chunk of the specified size. |
| | | |
| | | size: chunk size |
| | | """ |
| | | if self._at_eof: |
| | | return b"" |
| | | if self._length: |
| | | chunk = await self._read_chunk_from_length(size) |
| | | else: |
| | | chunk = await self._read_chunk_from_stream(size) |
| | | |
| | | # For base64 data we must read chunks whose length, after stripping |
| | | # any \n or \r characters, is a multiple of 4 so they decode cleanly. |
| | | encoding = self.headers.get(CONTENT_TRANSFER_ENCODING) |
| | | if encoding and encoding.lower() == "base64": |
| | | stripped_chunk = b"".join(chunk.split()) |
| | | remainder = len(stripped_chunk) % 4 |
| | | |
| | | while remainder != 0 and not self.at_eof(): |
| | | over_chunk_size = 4 - remainder |
| | | over_chunk = b"" |
| | | |
| | | if self._prev_chunk: |
| | | over_chunk = self._prev_chunk[:over_chunk_size] |
| | | self._prev_chunk = self._prev_chunk[len(over_chunk) :] |
| | | |
| | | if len(over_chunk) != over_chunk_size: |
| | | over_chunk += await self._content.read(4 - len(over_chunk)) |
| | | |
| | | if not over_chunk: |
| | | self._at_eof = True |
| | | |
| | | stripped_chunk += b"".join(over_chunk.split()) |
| | | chunk += over_chunk |
| | | remainder = len(stripped_chunk) % 4 |
| | | |
| | | self._read_bytes += len(chunk) |
| | | if self._read_bytes == self._length: |
| | | self._at_eof = True |
| | | if self._at_eof and await self._content.readline() != b"\r\n": |
| | | raise ValueError("Reader did not read all the data or it is malformed") |
| | | return chunk |
| | | |
| | | async def _read_chunk_from_length(self, size: int) -> bytes: |
| | | # Reads a body part content chunk of the specified size. |
| | | # The body part must have a Content-Length header with a proper value. |
| | | assert self._length is not None, "Content-Length required for chunked read" |
| | | chunk_size = min(size, self._length - self._read_bytes) |
| | | chunk = await self._content.read(chunk_size) |
| | | if self._content.at_eof(): |
| | | self._at_eof = True |
| | | return chunk |
| | | |
| | | async def _read_chunk_from_stream(self, size: int) -> bytes: |
| | | # Reads a content chunk of a body part with unknown length. |
| | | # The body part does not need a Content-Length header. |
| | | assert ( |
| | | size >= self._boundary_len |
| | | ), "Chunk size must be greater than or equal to boundary length + 2" |
| | | first_chunk = self._prev_chunk is None |
| | | if first_chunk: |
| | | # We need to re-add the CRLF that got removed from headers parsing. |
| | | self._prev_chunk = b"\r\n" + await self._content.read(size) |
| | | |
| | | chunk = b"" |
| | | # content.read() may return less than size, so we need to loop to ensure |
| | | # we have enough data to detect the boundary. |
| | | while len(chunk) < self._boundary_len: |
| | | chunk += await self._content.read(size) |
| | | self._content_eof += int(self._content.at_eof()) |
| | | if self._content_eof > 2: |
| | | raise ValueError("Reading after EOF") |
| | | if self._content_eof: |
| | | break |
| | | if len(chunk) > size: |
| | | self._content.unread_data(chunk[size:]) |
| | | chunk = chunk[:size] |
| | | |
| | | assert self._prev_chunk is not None |
| | | window = self._prev_chunk + chunk |
| | | sub = b"\r\n" + self._boundary |
| | | if first_chunk: |
| | | idx = window.find(sub) |
| | | else: |
| | | idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub))) |
| | | if idx >= 0: |
| | | # pushing boundary back to content |
| | | with warnings.catch_warnings(): |
| | | warnings.filterwarnings("ignore", category=DeprecationWarning) |
| | | self._content.unread_data(window[idx:]) |
| | | self._prev_chunk = self._prev_chunk[:idx] |
| | | chunk = window[len(self._prev_chunk) : idx] |
| | | if not chunk: |
| | | self._at_eof = True |
| | | result = self._prev_chunk[2 if first_chunk else 0 :] # Strip initial CRLF |
| | | self._prev_chunk = chunk |
| | | return result |
| | | |
| | | async def readline(self) -> bytes: |
| | | """Reads the body part line by line.""" |
| | | if self._at_eof: |
| | | return b"" |
| | | |
| | | if self._unread: |
| | | line = self._unread.popleft() |
| | | else: |
| | | line = await self._content.readline() |
| | | |
| | | if line.startswith(self._boundary): |
| | | # the very last boundary may not come with \r\n, |
| | | # so handle both cases with a single rule |
| | | sline = line.rstrip(b"\r\n") |
| | | boundary = self._boundary |
| | | last_boundary = self._boundary + b"--" |
| | | # ensure that we read exactly the boundary, not a line that merely resembles it |
| | | if sline == boundary or sline == last_boundary: |
| | | self._at_eof = True |
| | | self._unread.append(line) |
| | | return b"" |
| | | else: |
| | | next_line = await self._content.readline() |
| | | if next_line.startswith(self._boundary): |
| | | line = line[:-2] # strip CRLF but only once |
| | | self._unread.append(next_line) |
| | | |
| | | return line |
| | | |
| | | async def release(self) -> None: |
| | | """Like read(), but reads all the data to the void.""" |
| | | if self._at_eof: |
| | | return |
| | | while not self._at_eof: |
| | | await self.read_chunk(self.chunk_size) |
| | | |
| | | async def text(self, *, encoding: Optional[str] = None) -> str: |
| | | """Like read(), but assumes that body part contains text data.""" |
| | | data = await self.read(decode=True) |
| | | # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm |
| | | # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send |
| | | encoding = encoding or self.get_charset(default="utf-8") |
| | | return data.decode(encoding) |
| | | |
| | | async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]: |
| | | """Like read(), but assumes that the body part contains JSON data.""" |
| | | data = await self.read(decode=True) |
| | | if not data: |
| | | return None |
| | | encoding = encoding or self.get_charset(default="utf-8") |
| | | return cast(Dict[str, Any], json.loads(data.decode(encoding))) |
| | | |
| | | async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]: |
| | | """Like read(), but assumes that the body part contains form urlencoded data.""" |
| | | data = await self.read(decode=True) |
| | | if not data: |
| | | return [] |
| | | if encoding is not None: |
| | | real_encoding = encoding |
| | | else: |
| | | real_encoding = self.get_charset(default="utf-8") |
| | | try: |
| | | decoded_data = data.rstrip().decode(real_encoding) |
| | | except UnicodeDecodeError: |
| | | raise ValueError("data cannot be decoded with %s encoding" % real_encoding) |
| | | |
| | | return parse_qsl( |
| | | decoded_data, |
| | | keep_blank_values=True, |
| | | encoding=real_encoding, |
| | | ) |
| | | |
| | | def at_eof(self) -> bool: |
| | | """Returns True if the boundary was reached or False otherwise.""" |
| | | return self._at_eof |
| | | |
| | | async def decode(self, data: bytes) -> bytes: |
| | | """Decodes data. |
| | | |
| | | Decoding is done according to the specified Content-Encoding |
| | | or Content-Transfer-Encoding header value. |
| | | """ |
| | | if CONTENT_TRANSFER_ENCODING in self.headers: |
| | | data = self._decode_content_transfer(data) |
| | | # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 |
| | | if not self._is_form_data and CONTENT_ENCODING in self.headers: |
| | | return await self._decode_content(data) |
| | | return data |
| | | |
| | | async def _decode_content(self, data: bytes) -> bytes: |
| | | encoding = self.headers.get(CONTENT_ENCODING, "").lower() |
| | | if encoding == "identity": |
| | | return data |
| | | if encoding in {"deflate", "gzip"}: |
| | | return await ZLibDecompressor( |
| | | encoding=encoding, |
| | | suppress_deflate_header=True, |
| | | ).decompress(data, max_length=self._max_decompress_size) |
| | | |
| | | raise RuntimeError(f"unknown content encoding: {encoding}") |
| | | |
| | | def _decode_content_transfer(self, data: bytes) -> bytes: |
| | | encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower() |
| | | |
| | | if encoding == "base64": |
| | | return base64.b64decode(data) |
| | | elif encoding == "quoted-printable": |
| | | return binascii.a2b_qp(data) |
| | | elif encoding in ("binary", "8bit", "7bit"): |
| | | return data |
| | | else: |
| | | raise RuntimeError(f"unknown content transfer encoding: {encoding}") |
| | | |
| | | def get_charset(self, default: str) -> str: |
| | | """Returns charset parameter from Content-Type header or default.""" |
| | | ctype = self.headers.get(CONTENT_TYPE, "") |
| | | mimetype = parse_mimetype(ctype) |
| | | return mimetype.parameters.get("charset", self._default_charset or default) |
| | | |
| | | @reify |
| | | def name(self) -> Optional[str]: |
| | | """Returns name specified in Content-Disposition header. |
| | | |
| | | If the header is missing or malformed, returns None. |
| | | """ |
| | | _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION)) |
| | | return content_disposition_filename(params, "name") |
| | | |
| | | @reify |
| | | def filename(self) -> Optional[str]: |
| | | """Returns filename specified in Content-Disposition header. |
| | | |
| | | Returns None if the header is missing or malformed. |
| | | """ |
| | | _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION)) |
| | | return content_disposition_filename(params, "filename") |
| | | |
| | | |
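`read_chunk` above pads base64-encoded chunks until their whitespace-stripped length is a multiple of 4, since one base64 quantum is 4 characters (3 bytes). A stdlib sketch of that alignment rule, with a hypothetical `aligned_prefix` helper that is not part of aiohttp:

```python
import base64

# Hypothetical helper mirroring read_chunk's alignment rule: a chunk is
# only safe to decode incrementally once its whitespace-stripped length
# is a multiple of 4 (one base64 quantum = 4 characters = 3 bytes).
def aligned_prefix(chunk: bytes) -> bytes:
    stripped = b"".join(chunk.split())
    usable = len(stripped) - len(stripped) % 4
    return stripped[:usable]

raw = b"SGVsbG8sIG11bHRpcG\r\nFydCB3b3JsZCE="  # base64 with an embedded CRLF
prefix = aligned_prefix(raw)
assert len(prefix) % 4 == 0
print(base64.b64decode(prefix))
```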
| | | @payload_type(BodyPartReader, order=Order.try_first) |
| | | class BodyPartReaderPayload(Payload): |
| | | _value: BodyPartReader |
| | | # _autoclose = False (inherited) - Streaming reader that may have resources |
| | | |
| | | def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None: |
| | | super().__init__(value, *args, **kwargs) |
| | | |
| | | params: Dict[str, str] = {} |
| | | if value.name is not None: |
| | | params["name"] = value.name |
| | | if value.filename is not None: |
| | | params["filename"] = value.filename |
| | | |
| | | if params: |
| | | self.set_content_disposition("attachment", True, **params) |
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | raise TypeError("Unable to decode.") |
| | | |
| | | async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: |
| | | """Raises TypeError as body parts should be consumed via write(). |
| | | |
| | | This is intentional: BodyPartReader payloads are designed for streaming |
| | | large data (potentially gigabytes) and must be consumed only once via |
| | | the write() method to avoid memory exhaustion. They cannot be buffered |
| | | in memory for reuse. |
| | | """ |
| | | raise TypeError("Unable to read body part as bytes. Use write() to consume.") |
| | | |
| | | async def write(self, writer: AbstractStreamWriter) -> None: |
| | | field = self._value |
| | | chunk = await field.read_chunk(size=2**16) |
| | | while chunk: |
| | | await writer.write(await field.decode(chunk)) |
| | | chunk = await field.read_chunk(size=2**16) |
| | | |
| | | |
| | | class MultipartReader: |
| | | """Multipart body reader.""" |
| | | |
| | | #: Response wrapper, used when a multipart reader is constructed from a response. |
| | | response_wrapper_cls = MultipartResponseWrapper |
| | | #: Multipart reader class, used to handle multipart/* body parts. |
| | | #: None points to type(self) |
| | | multipart_reader_cls: Optional[Type["MultipartReader"]] = None |
| | | #: Body part reader class for non multipart/* content types. |
| | | part_reader_cls = BodyPartReader |
| | | |
| | | def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: |
| | | self._mimetype = parse_mimetype(headers[CONTENT_TYPE]) |
| | | assert self._mimetype.type == "multipart", "multipart/* content type expected" |
| | | if "boundary" not in self._mimetype.parameters: |
| | | raise ValueError( |
| | | "boundary missing for Content-Type: %s" % headers[CONTENT_TYPE] |
| | | ) |
| | | |
| | | self.headers = headers |
| | | self._boundary = ("--" + self._get_boundary()).encode() |
| | | self._content = content |
| | | self._default_charset: Optional[str] = None |
| | | self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None |
| | | self._at_eof = False |
| | | self._at_bof = True |
| | | self._unread: List[bytes] = [] |
| | | |
| | | def __aiter__(self: Self) -> Self: |
| | | return self |
| | | |
| | | async def __anext__( |
| | | self, |
| | | ) -> Optional[Union["MultipartReader", BodyPartReader]]: |
| | | part = await self.next() |
| | | if part is None: |
| | | raise StopAsyncIteration |
| | | return part |
| | | |
| | | @classmethod |
| | | def from_response( |
| | | cls, |
| | | response: "ClientResponse", |
| | | ) -> MultipartResponseWrapper: |
| | | """Constructs reader instance from HTTP response. |
| | | |
| | | :param response: :class:`~aiohttp.client.ClientResponse` instance |
| | | """ |
| | | obj = cls.response_wrapper_cls( |
| | | response, cls(response.headers, response.content) |
| | | ) |
| | | return obj |
| | | |
| | | def at_eof(self) -> bool: |
| | | """Returns True if the final boundary was reached, False otherwise.""" |
| | | return self._at_eof |
| | | |
| | | async def next( |
| | | self, |
| | | ) -> Optional[Union["MultipartReader", BodyPartReader]]: |
| | | """Emits the next multipart body part.""" |
| | | # If we're at BOF, we need to skip forward to the first boundary. |
| | | if self._at_eof: |
| | | return None |
| | | await self._maybe_release_last_part() |
| | | if self._at_bof: |
| | | await self._read_until_first_boundary() |
| | | self._at_bof = False |
| | | else: |
| | | await self._read_boundary() |
| | | if self._at_eof: # we just read the last boundary, nothing to do there |
| | | return None |
| | | |
| | | part = await self.fetch_next_part() |
| | | # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6 |
| | | if ( |
| | | self._last_part is None |
| | | and self._mimetype.subtype == "form-data" |
| | | and isinstance(part, BodyPartReader) |
| | | ): |
| | | _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION)) |
| | | if params.get("name") == "_charset_": |
| | | # Longest encoding in https://encoding.spec.whatwg.org/encodings.json |
| | | # is 19 characters, so 32 should be more than enough for any valid encoding. |
| | | charset = await part.read_chunk(32) |
| | | if len(charset) > 31: |
| | | raise RuntimeError("Invalid default charset") |
| | | self._default_charset = charset.strip().decode() |
| | | part = await self.fetch_next_part() |
| | | self._last_part = part |
| | | return self._last_part |
| | | |
| | | async def release(self) -> None: |
| | | """Reads all the body parts to the void until the final boundary.""" |
| | | while not self._at_eof: |
| | | item = await self.next() |
| | | if item is None: |
| | | break |
| | | await item.release() |
| | | |
| | | async def fetch_next_part( |
| | | self, |
| | | ) -> Union["MultipartReader", BodyPartReader]: |
| | | """Returns the next body part reader.""" |
| | | headers = await self._read_headers() |
| | | return self._get_part_reader(headers) |
| | | |
| | | def _get_part_reader( |
| | | self, |
| | | headers: "CIMultiDictProxy[str]", |
| | | ) -> Union["MultipartReader", BodyPartReader]: |
| | | """Dispatches the response by the `Content-Type` header. |
| | | |
| | | Returns a suitable reader instance. |
| | | |
| | | :param dict headers: Response headers |
| | | """ |
| | | ctype = headers.get(CONTENT_TYPE, "") |
| | | mimetype = parse_mimetype(ctype) |
| | | |
| | | if mimetype.type == "multipart": |
| | | if self.multipart_reader_cls is None: |
| | | return type(self)(headers, self._content) |
| | | return self.multipart_reader_cls(headers, self._content) |
| | | else: |
| | | return self.part_reader_cls( |
| | | self._boundary, |
| | | headers, |
| | | self._content, |
| | | subtype=self._mimetype.subtype, |
| | | default_charset=self._default_charset, |
| | | ) |
| | | |
| | | def _get_boundary(self) -> str: |
| | | boundary = self._mimetype.parameters["boundary"] |
| | | if len(boundary) > 70: |
| | | raise ValueError("boundary %r is too long (70 chars max)" % boundary) |
| | | |
| | | return boundary |
| | | |
| | | async def _readline(self) -> bytes: |
| | | if self._unread: |
| | | return self._unread.pop() |
| | | return await self._content.readline() |
| | | |
| | | async def _read_until_first_boundary(self) -> None: |
| | | while True: |
| | | chunk = await self._readline() |
| | | if chunk == b"": |
| | | raise ValueError( |
| | | "Could not find starting boundary %r" % (self._boundary) |
| | | ) |
| | | chunk = chunk.rstrip() |
| | | if chunk == self._boundary: |
| | | return |
| | | elif chunk == self._boundary + b"--": |
| | | self._at_eof = True |
| | | return |
| | | |
| | | async def _read_boundary(self) -> None: |
| | | chunk = (await self._readline()).rstrip() |
| | | if chunk == self._boundary: |
| | | pass |
| | | elif chunk == self._boundary + b"--": |
| | | self._at_eof = True |
| | | epilogue = await self._readline() |
| | | next_line = await self._readline() |
| | | |
| | | # the epilogue is expected and then either the end of input or the |
| | | # parent multipart boundary, if the parent boundary is found then |
| | | # it should be marked as unread and handed to the parent for |
| | | # processing |
| | | if next_line[:2] == b"--": |
| | | self._unread.append(next_line) |
| | | # otherwise the request is likely missing an epilogue and both |
| | | # lines should be passed to the parent for processing |
| | | # (this handles the old behavior gracefully) |
| | | else: |
| | | self._unread.extend([next_line, epilogue]) |
| | | else: |
| | | raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}") |
| | | |
| | | async def _read_headers(self) -> "CIMultiDictProxy[str]": |
| | | lines = [] |
| | | while True: |
| | | chunk = await self._content.readline() |
| | | chunk = chunk.rstrip(b"\r\n") |
| | | lines.append(chunk) |
| | | if not chunk: |
| | | break |
| | | parser = HeadersParser() |
| | | headers, raw_headers = parser.parse_headers(lines) |
| | | return headers |
| | | |
| | | async def _maybe_release_last_part(self) -> None: |
| | | """Ensures that the last read body part is read completely.""" |
| | | if self._last_part is not None: |
| | | if not self._last_part.at_eof(): |
| | | await self._last_part.release() |
| | | self._unread.extend(self._last_part._unread) |
| | | self._last_part = None |
| | | |
| | | |
| | | _Part = Tuple[Payload, str, str] |
| | | |
| | | |
| | | class MultipartWriter(Payload): |
| | | """Multipart body writer.""" |
| | | |
| | | _value: None |
| | | # _consumed = False (inherited) - Can be encoded multiple times |
| | | _autoclose = True # No file handles, just collects parts in memory |
| | | |
| | | def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None: |
| | | boundary = boundary if boundary is not None else uuid.uuid4().hex |
| | | # The underlying Payload API demands a str (utf-8), not bytes, |
| | | # so we need to ensure we don't lose anything during conversion. |
| | | # As a result, require the boundary to contain ASCII characters only. |
| | | |
| | | try: |
| | | self._boundary = boundary.encode("ascii") |
| | | except UnicodeEncodeError: |
| | | raise ValueError("boundary should contain ASCII only chars") from None |
| | | ctype = f"multipart/{subtype}; boundary={self._boundary_value}" |
| | | |
| | | super().__init__(None, content_type=ctype) |
| | | |
| | | self._parts: List[_Part] = [] |
| | | self._is_form_data = subtype == "form-data" |
| | | |
| | | def __enter__(self) -> "MultipartWriter": |
| | | return self |
| | | |
| | | def __exit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_val: Optional[BaseException], |
| | | exc_tb: Optional[TracebackType], |
| | | ) -> None: |
| | | pass |
| | | |
| | | def __iter__(self) -> Iterator[_Part]: |
| | | return iter(self._parts) |
| | | |
| | | def __len__(self) -> int: |
| | | return len(self._parts) |
| | | |
| | | def __bool__(self) -> bool: |
| | | return True |
| | | |
| | | _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z") |
| | | _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]") |
| | | |
| | | @property |
| | | def _boundary_value(self) -> str: |
| | | """Wrap boundary parameter value in quotes, if necessary. |
| | | |
| | | Reads self.boundary and returns a unicode string. |
| | | """ |
| | | # Refer to RFCs 7231, 7230, 5234. |
| | | # |
| | | # parameter = token "=" ( token / quoted-string ) |
| | | # token = 1*tchar |
| | | # quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE |
| | | # qdtext = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text |
| | | # obs-text = %x80-FF |
| | | # quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text ) |
| | | # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" |
| | | # / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" |
| | | # / DIGIT / ALPHA |
| | | # ; any VCHAR, except delimiters |
| | | # VCHAR = %x21-7E |
| | | value = self._boundary |
| | | if re.match(self._valid_tchar_regex, value): |
| | | return value.decode("ascii") # cannot fail |
| | | |
| | | if re.search(self._invalid_qdtext_char_regex, value): |
| | | raise ValueError("boundary value contains invalid characters") |
| | | |
| | | # escape %x5C and %x22 |
| | | quoted_value_content = value.replace(b"\\", b"\\\\") |
| | | quoted_value_content = quoted_value_content.replace(b'"', b'\\"') |
| | | |
| | | return '"' + quoted_value_content.decode("ascii") + '"' |
| | | |
| | | @property |
| | | def boundary(self) -> str: |
| | | return self._boundary.decode("ascii") |
| | | |
| | | def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload: |
| | | if headers is None: |
| | | headers = CIMultiDict() |
| | | |
| | | if isinstance(obj, Payload): |
| | | obj.headers.update(headers) |
| | | return self.append_payload(obj) |
| | | else: |
| | | try: |
| | | payload = get_payload(obj, headers=headers) |
| | | except LookupError: |
| | | raise TypeError("Cannot create payload from %r" % obj) |
| | | else: |
| | | return self.append_payload(payload) |
| | | |
| | | def append_payload(self, payload: Payload) -> Payload: |
| | | """Adds a new body part to multipart writer.""" |
| | | encoding: Optional[str] = None |
| | | te_encoding: Optional[str] = None |
| | | if self._is_form_data: |
| | | # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7 |
| | | # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 |
| | | assert ( |
| | | not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING} |
| | | & payload.headers.keys() |
| | | ) |
| | | # Set default Content-Disposition in case user doesn't create one |
| | | if CONTENT_DISPOSITION not in payload.headers: |
| | | name = f"section-{len(self._parts)}" |
| | | payload.set_content_disposition("form-data", name=name) |
| | | else: |
| | | # compression |
| | | encoding = payload.headers.get(CONTENT_ENCODING, "").lower() |
| | | if encoding and encoding not in ("deflate", "gzip", "identity"): |
| | | raise RuntimeError(f"unknown content encoding: {encoding}") |
| | | if encoding == "identity": |
| | | encoding = None |
| | | |
| | | # te encoding |
| | | te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower() |
| | | if te_encoding not in ("", "base64", "quoted-printable", "binary"): |
| | | raise RuntimeError(f"unknown content transfer encoding: {te_encoding}") |
| | | if te_encoding == "binary": |
| | | te_encoding = None |
| | | |
| | | # size |
| | | size = payload.size |
| | | if size is not None and not (encoding or te_encoding): |
| | | payload.headers[CONTENT_LENGTH] = str(size) |
| | | |
| | | self._parts.append((payload, encoding, te_encoding)) # type: ignore[arg-type] |
| | | return payload |
| | | |
| | | def append_json( |
| | | self, obj: Any, headers: Optional[Mapping[str, str]] = None |
| | | ) -> Payload: |
| | | """Helper to append JSON part.""" |
| | | if headers is None: |
| | | headers = CIMultiDict() |
| | | |
| | | return self.append_payload(JsonPayload(obj, headers=headers)) |
| | | |
| | | def append_form( |
| | | self, |
| | | obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]], |
| | | headers: Optional[Mapping[str, str]] = None, |
| | | ) -> Payload: |
| | | """Helper to append form urlencoded part.""" |
| | | assert isinstance(obj, (Sequence, Mapping)) |
| | | |
| | | if headers is None: |
| | | headers = CIMultiDict() |
| | | |
| | | if isinstance(obj, Mapping): |
| | | obj = list(obj.items()) |
| | | data = urlencode(obj, doseq=True) |
| | | |
| | | return self.append_payload( |
| | | StringPayload( |
| | | data, headers=headers, content_type="application/x-www-form-urlencoded" |
| | | ) |
| | | ) |
| | | |
| | | @property |
| | | def size(self) -> Optional[int]: |
| | | """Size of the payload.""" |
| | | total = 0 |
| | | for part, encoding, te_encoding in self._parts: |
| | | part_size = part.size |
| | | if encoding or te_encoding or part_size is None: |
| | | return None |
| | | |
| | | total += int( |
| | | 2 |
| | | + len(self._boundary) |
| | | + 2 |
| | | + part_size # b'--'+self._boundary+b'\r\n' |
| | | + len(part._binary_headers) |
| | | + 2 # b'\r\n' |
| | | ) |
| | | |
| | | total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n' |
| | | return total |
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | """Return string representation of the multipart data. |
| | | |
| | | WARNING: This method may do blocking I/O if parts contain file payloads. |
| | | It should not be called in the event loop. Use as_bytes().decode() instead. |
| | | """ |
| | | return "".join( |
| | | "--" |
| | | + self.boundary |
| | | + "\r\n" |
| | | + part._binary_headers.decode(encoding, errors) |
| | | + part.decode() |
| | | for part, _e, _te in self._parts |
| | | ) |
| | | |
| | | async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: |
| | | """Return bytes representation of the multipart data. |
| | | |
| | | This method is async-safe and calls as_bytes on underlying payloads. |
| | | """ |
| | | parts: List[bytes] = [] |
| | | |
| | | # Process each part |
| | | for part, _e, _te in self._parts: |
| | | # Add boundary |
| | | parts.append(b"--" + self._boundary + b"\r\n") |
| | | |
| | | # Add headers |
| | | parts.append(part._binary_headers) |
| | | |
| | | # Add payload content using as_bytes for async safety |
| | | part_bytes = await part.as_bytes(encoding, errors) |
| | | parts.append(part_bytes) |
| | | |
| | | # Add trailing CRLF |
| | | parts.append(b"\r\n") |
| | | |
| | | # Add closing boundary |
| | | parts.append(b"--" + self._boundary + b"--\r\n") |
| | | |
| | | return b"".join(parts) |
| | | |
| | | async def write( |
| | | self, writer: AbstractStreamWriter, close_boundary: bool = True |
| | | ) -> None: |
| | | """Write body.""" |
| | | for part, encoding, te_encoding in self._parts: |
| | | if self._is_form_data: |
| | | # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2 |
| | | assert CONTENT_DISPOSITION in part.headers |
| | | assert "name=" in part.headers[CONTENT_DISPOSITION] |
| | | |
| | | await writer.write(b"--" + self._boundary + b"\r\n") |
| | | await writer.write(part._binary_headers) |
| | | |
| | | if encoding or te_encoding: |
| | | w = MultipartPayloadWriter(writer) |
| | | if encoding: |
| | | w.enable_compression(encoding) |
| | | if te_encoding: |
| | | w.enable_encoding(te_encoding) |
| | | await part.write(w) # type: ignore[arg-type] |
| | | await w.write_eof() |
| | | else: |
| | | await part.write(writer) |
| | | |
| | | await writer.write(b"\r\n") |
| | | |
| | | if close_boundary: |
| | | await writer.write(b"--" + self._boundary + b"--\r\n") |
| | | |
| | | async def close(self) -> None: |
| | | """ |
| | | Close all part payloads that need explicit closing. |
| | | |
| | | IMPORTANT: This method must not await anything that might not finish |
| | | immediately, as it may be called during cleanup/cancellation. Schedule |
| | | any long-running operations without awaiting them. |
| | | """ |
| | | if self._consumed: |
| | | return |
| | | self._consumed = True |
| | | |
| | | # Close all parts that need explicit closing |
| | | # We catch and log exceptions to ensure all parts get a chance to close |
| | | # we do not use asyncio.gather() here because we are not allowed |
| | | # to suspend given we may be called during cleanup |
| | | for idx, (part, _, _) in enumerate(self._parts): |
| | | if not part.autoclose and not part.consumed: |
| | | try: |
| | | await part.close() |
| | | except Exception as exc: |
| | | internal_logger.error( |
| | | "Failed to close multipart part %d: %s", idx, exc, exc_info=True |
| | | ) |
| | | |
| | | |
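The per-part byte accounting in `MultipartWriter.size` above (boundary delimiter, headers, body, trailing CRLF, plus the closing delimiter) can be checked by hand with a stdlib-only sketch; the boundary, headers, and body below are made-up example values:

```python
# Each part costs b"--" + boundary + b"\r\n" + headers + body + b"\r\n",
# and the multipart body closes with b"--" + boundary + b"--\r\n".
boundary = b"xyz"  # example boundary, not a generated uuid4 hex
headers = b"Content-Type: text/plain\r\n\r\n"
body = b"hello"

part = b"--" + boundary + b"\r\n" + headers + body + b"\r\n"
closing = b"--" + boundary + b"--\r\n"

expected = (2 + len(boundary) + 2 + len(body) + len(headers) + 2) + (
    2 + len(boundary) + 4
)
assert len(part + closing) == expected
print(expected)  # 51
```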
| | | class MultipartPayloadWriter: |
| | | def __init__(self, writer: AbstractStreamWriter) -> None: |
| | | self._writer = writer |
| | | self._encoding: Optional[str] = None |
| | | self._compress: Optional[ZLibCompressor] = None |
| | | self._encoding_buffer: Optional[bytearray] = None |
| | | |
| | | def enable_encoding(self, encoding: str) -> None: |
| | | if encoding == "base64": |
| | | self._encoding = encoding |
| | | self._encoding_buffer = bytearray() |
| | | elif encoding == "quoted-printable": |
| | | self._encoding = "quoted-printable" |
| | | |
| | | def enable_compression( |
| | | self, encoding: str = "deflate", strategy: Optional[int] = None |
| | | ) -> None: |
| | | self._compress = ZLibCompressor( |
| | | encoding=encoding, |
| | | suppress_deflate_header=True, |
| | | strategy=strategy, |
| | | ) |
| | | |
| | | async def write_eof(self) -> None: |
| | | if self._compress is not None: |
| | | chunk = self._compress.flush() |
| | | if chunk: |
| | | self._compress = None |
| | | await self.write(chunk) |
| | | |
| | | if self._encoding == "base64": |
| | | if self._encoding_buffer: |
| | | await self._writer.write(base64.b64encode(self._encoding_buffer)) |
| | | |
| | | async def write(self, chunk: bytes) -> None: |
| | | if self._compress is not None: |
| | | if chunk: |
| | | chunk = await self._compress.compress(chunk) |
| | | if not chunk: |
| | | return |
| | | |
| | | if self._encoding == "base64": |
| | | buf = self._encoding_buffer |
| | | assert buf is not None |
| | | buf.extend(chunk) |
| | | |
| | | if buf: |
| | | div, mod = divmod(len(buf), 3) |
| | | enc_chunk, self._encoding_buffer = (buf[: div * 3], buf[div * 3 :]) |
| | | if enc_chunk: |
| | | b64chunk = base64.b64encode(enc_chunk) |
| | | await self._writer.write(b64chunk) |
| | | elif self._encoding == "quoted-printable": |
| | | await self._writer.write(binascii.b2a_qp(chunk)) |
| | | else: |
| | | await self._writer.write(chunk) |
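The base64 branch above only emits output on 3-byte boundaries, because base64 encodes input in groups of three bytes; the 0–2 leftover bytes are held in `_encoding_buffer` and flushed (with padding) in `write_eof()`. A minimal standalone sketch of that buffering, independent of aiohttp (the `Base64ChunkEncoder` class and its method names are illustrative, not aiohttp API):

```python
import base64

class Base64ChunkEncoder:
    """Buffer bytes and emit base64 only at 3-byte boundaries."""

    def __init__(self) -> None:
        self._buf = bytearray()

    def feed(self, chunk: bytes) -> bytes:
        # Same divmod-by-3 split as MultipartPayloadWriter.write().
        self._buf.extend(chunk)
        div = len(self._buf) // 3
        ready, self._buf = self._buf[: div * 3], self._buf[div * 3 :]
        return base64.b64encode(bytes(ready))

    def flush(self) -> bytes:
        # Encode whatever remains (0-2 bytes), adding '=' padding.
        tail = base64.b64encode(bytes(self._buf))
        self._buf.clear()
        return tail

enc = Base64ChunkEncoder()
out = enc.feed(b"hell") + enc.feed(b"o wor") + enc.feed(b"ld") + enc.flush()
# Chunked output concatenates to the same result as one-shot encoding.
assert out == base64.b64encode(b"hello world")
```

Because every emitted piece covers a whole number of 3-byte groups, the concatenated stream is byte-for-byte identical to encoding the full payload at once, which is what makes incremental multipart writing safe.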
| New file |
| | |
| | | import asyncio |
| | | import enum |
| | | import io |
| | | import json |
| | | import mimetypes |
| | | import os |
| | | import sys |
| | | import warnings |
| | | from abc import ABC, abstractmethod |
| | | from collections.abc import Iterable |
| | | from itertools import chain |
| | | from typing import ( |
| | | IO, |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Dict, |
| | | Final, |
| | | List, |
| | | Optional, |
| | | Set, |
| | | TextIO, |
| | | Tuple, |
| | | Type, |
| | | Union, |
| | | ) |
| | | |
| | | from multidict import CIMultiDict |
| | | |
| | | from . import hdrs |
| | | from .abc import AbstractStreamWriter |
| | | from .helpers import ( |
| | | _SENTINEL, |
| | | content_disposition_header, |
| | | guess_filename, |
| | | parse_mimetype, |
| | | sentinel, |
| | | ) |
| | | from .streams import StreamReader |
| | | from .typedefs import JSONEncoder, _CIMultiDict |
| | | |
| | | __all__ = ( |
| | | "PAYLOAD_REGISTRY", |
| | | "get_payload", |
| | | "payload_type", |
| | | "Payload", |
| | | "BytesPayload", |
| | | "StringPayload", |
| | | "IOBasePayload", |
| | | "BytesIOPayload", |
| | | "BufferedReaderPayload", |
| | | "TextIOPayload", |
| | | "StringIOPayload", |
| | | "JsonPayload", |
| | | "AsyncIterablePayload", |
| | | ) |
| | | |
| | | TOO_LARGE_BYTES_BODY: Final[int] = 2**20  # 1 MiB
| | | READ_SIZE: Final[int] = 2**16  # 64 KiB
| | | _CLOSE_FUTURES: Set[asyncio.Future[None]] = set() |
| | | |
| | | |
| | | class LookupError(Exception): |
| | | """Raised when no payload factory is found for the given data type.""" |
| | | |
| | | |
| | | class Order(str, enum.Enum): |
| | | normal = "normal" |
| | | try_first = "try_first" |
| | | try_last = "try_last" |
| | | |
| | | |
| | | def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload": |
| | | return PAYLOAD_REGISTRY.get(data, *args, **kwargs) |
| | | |
| | | |
| | | def register_payload( |
| | | factory: Type["Payload"], type: Any, *, order: Order = Order.normal |
| | | ) -> None: |
| | | PAYLOAD_REGISTRY.register(factory, type, order=order) |
| | | |
| | | |
| | | class payload_type: |
| | | def __init__(self, type: Any, *, order: Order = Order.normal) -> None: |
| | | self.type = type |
| | | self.order = order |
| | | |
| | | def __call__(self, factory: Type["Payload"]) -> Type["Payload"]: |
| | | register_payload(factory, self.type, order=self.order) |
| | | return factory |
| | | |
| | | |
| | | PayloadType = Type["Payload"] |
| | | _PayloadRegistryItem = Tuple[PayloadType, Any] |
| | | |
| | | |
| | | class PayloadRegistry: |
| | | """Payload registry. |
| | | |
| | | note: we need zope.interface for more efficient adapter search |
| | | """ |
| | | |
| | | __slots__ = ("_first", "_normal", "_last", "_normal_lookup") |
| | | |
| | | def __init__(self) -> None: |
| | | self._first: List[_PayloadRegistryItem] = [] |
| | | self._normal: List[_PayloadRegistryItem] = [] |
| | | self._last: List[_PayloadRegistryItem] = [] |
| | | self._normal_lookup: Dict[Any, PayloadType] = {} |
| | | |
| | | def get( |
| | | self, |
| | | data: Any, |
| | | *args: Any, |
| | | _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain, |
| | | **kwargs: Any, |
| | | ) -> "Payload": |
| | | if self._first: |
| | | for factory, type_ in self._first: |
| | | if isinstance(data, type_): |
| | | return factory(data, *args, **kwargs) |
| | | # Try the fast lookup first |
| | | if lookup_factory := self._normal_lookup.get(type(data)): |
| | | return lookup_factory(data, *args, **kwargs) |
| | | # Bail out early if it's already a Payload
| | | if isinstance(data, Payload): |
| | | return data |
| | | # Fallback to the slower linear search |
| | | for factory, type_ in _CHAIN(self._normal, self._last): |
| | | if isinstance(data, type_): |
| | | return factory(data, *args, **kwargs) |
| | | raise LookupError() |
| | | |
| | | def register( |
| | | self, factory: PayloadType, type: Any, *, order: Order = Order.normal |
| | | ) -> None: |
| | | if order is Order.try_first: |
| | | self._first.append((factory, type)) |
| | | elif order is Order.normal: |
| | | self._normal.append((factory, type)) |
| | | if isinstance(type, Iterable): |
| | | for t in type: |
| | | self._normal_lookup[t] = factory |
| | | else: |
| | | self._normal_lookup[type] = factory |
| | | elif order is Order.try_last: |
| | | self._last.append((factory, type)) |
| | | else: |
| | | raise ValueError(f"Unsupported order {order!r}") |
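The lookup order in `get()` is: `try_first` factories by `isinstance()`, then an exact-type dictionary fast path for `normal` registrations, then a linear `isinstance()` scan over `normal` and `try_last`. A reduced standalone sketch of that dispatch (the `MiniRegistry` class here is illustrative, not the aiohttp implementation):

```python
from itertools import chain

class MiniRegistry:
    def __init__(self):
        self._first, self._normal, self._last = [], [], []
        self._normal_lookup = {}  # exact type -> factory fast path

    def register(self, factory, type_, *, order="normal"):
        if order == "try_first":
            self._first.append((factory, type_))
        elif order == "normal":
            self._normal.append((factory, type_))
            self._normal_lookup[type_] = factory
        else:  # try_last
            self._last.append((factory, type_))

    def get(self, data):
        # 1. try_first factories, isinstance check
        for factory, type_ in self._first:
            if isinstance(data, type_):
                return factory(data)
        # 2. exact-type fast path (dict lookup, no isinstance scan)
        if factory := self._normal_lookup.get(type(data)):
            return factory(data)
        # 3. slower linear search over normal + try_last
        for factory, type_ in chain(self._normal, self._last):
            if isinstance(data, type_):
                return factory(data)
        raise LookupError(data)

reg = MiniRegistry()
reg.register(lambda d: ("bytes", d), bytes)
reg.register(lambda d: ("text", d), str, order="try_last")
assert reg.get(b"x") == ("bytes", b"x")  # hits the exact-type fast path
assert reg.get("x") == ("text", "x")     # falls through to the try_last scan
```

The fast path only fires on an exact type match, so subclasses of a registered type still resolve through the `isinstance()` scan, preserving the original semantics.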
| | | |
| | | |
| | | class Payload(ABC): |
| | | |
| | | _default_content_type: str = "application/octet-stream" |
| | | _size: Optional[int] = None |
| | | _consumed: bool = False # Default: payload has not been consumed yet |
| | | _autoclose: bool = False # Default: assume resource needs explicit closing |
| | | |
| | | def __init__( |
| | | self, |
| | | value: Any, |
| | | headers: Optional[ |
| | | Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]] |
| | | ] = None, |
| | | content_type: Union[str, None, _SENTINEL] = sentinel, |
| | | filename: Optional[str] = None, |
| | | encoding: Optional[str] = None, |
| | | **kwargs: Any, |
| | | ) -> None: |
| | | self._encoding = encoding |
| | | self._filename = filename |
| | | self._headers: _CIMultiDict = CIMultiDict() |
| | | self._value = value |
| | | if content_type is not sentinel and content_type is not None: |
| | | self._headers[hdrs.CONTENT_TYPE] = content_type |
| | | elif self._filename is not None: |
| | | if sys.version_info >= (3, 13): |
| | | guesser = mimetypes.guess_file_type |
| | | else: |
| | | guesser = mimetypes.guess_type |
| | | content_type = guesser(self._filename)[0] |
| | | if content_type is None: |
| | | content_type = self._default_content_type |
| | | self._headers[hdrs.CONTENT_TYPE] = content_type |
| | | else: |
| | | self._headers[hdrs.CONTENT_TYPE] = self._default_content_type |
| | | if headers: |
| | | self._headers.update(headers) |
| | | |
| | | @property |
| | | def size(self) -> Optional[int]: |
| | | """Size of the payload in bytes. |
| | | |
| | | Returns the number of bytes that will be transmitted when the payload |
| | | is written. For string payloads, this is the size after encoding to bytes, |
| | | not the length of the string. |
| | | """ |
| | | return self._size |
| | | |
| | | @property |
| | | def filename(self) -> Optional[str]: |
| | | """Filename of the payload.""" |
| | | return self._filename |
| | | |
| | | @property |
| | | def headers(self) -> _CIMultiDict: |
| | | """Custom item headers""" |
| | | return self._headers |
| | | |
| | | @property |
| | | def _binary_headers(self) -> bytes: |
| | | return ( |
| | | "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode( |
| | | "utf-8" |
| | | ) |
| | | + b"\r\n" |
| | | ) |
| | | |
| | | @property |
| | | def encoding(self) -> Optional[str]: |
| | | """Payload encoding""" |
| | | return self._encoding |
| | | |
| | | @property |
| | | def content_type(self) -> str: |
| | | """Content type""" |
| | | return self._headers[hdrs.CONTENT_TYPE] |
| | | |
| | | @property |
| | | def consumed(self) -> bool: |
| | | """Whether the payload has been consumed and cannot be reused.""" |
| | | return self._consumed |
| | | |
| | | @property |
| | | def autoclose(self) -> bool: |
| | | """ |
| | | Whether the payload can close itself automatically. |
| | | |
| | | Returns True if the payload has no file handles or resources that need |
| | | explicit closing. If False, callers must await close() to release resources. |
| | | """ |
| | | return self._autoclose |
| | | |
| | | def set_content_disposition( |
| | | self, |
| | | disptype: str, |
| | | quote_fields: bool = True, |
| | | _charset: str = "utf-8", |
| | | **params: Any, |
| | | ) -> None: |
| | | """Sets ``Content-Disposition`` header.""" |
| | | self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header( |
| | | disptype, quote_fields=quote_fields, _charset=_charset, **params |
| | | ) |
| | | |
| | | @abstractmethod |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | """ |
| | | Return string representation of the value. |
| | | |
| | | This is named decode() to allow compatibility with bytes objects. |
| | | """ |
| | | |
| | | @abstractmethod |
| | | async def write(self, writer: AbstractStreamWriter) -> None: |
| | | """ |
| | | Write payload to the writer stream. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | |
| | | This is a legacy method that writes the entire payload without length constraints. |
| | | |
| | | Important: |
| | | For new implementations, use write_with_length() instead of this method. |
| | | This method is maintained for backwards compatibility and will eventually |
| | | delegate to write_with_length(writer, None) in all implementations. |
| | | |
| | | All payload subclasses must override this method for backwards compatibility, |
| | | but new code should use write_with_length for more flexibility and control. |
| | | |
| | | """ |
| | | |
| | | # write_with_length is new in aiohttp 3.12 |
| | | # it should be overridden by subclasses |
| | | async def write_with_length( |
| | | self, writer: AbstractStreamWriter, content_length: Optional[int] |
| | | ) -> None: |
| | | """ |
| | | Write payload with a specific content length constraint. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | content_length: Maximum number of bytes to write (None for unlimited) |
| | | |
| | | This method allows writing payload content with a specific length constraint, |
| | | which is particularly useful for HTTP responses with Content-Length header. |
| | | |
| | | Note: |
| | | This is the base implementation that provides backwards compatibility |
| | | for subclasses that don't override this method. Specific payload types |
| | | should override this method to implement proper length-constrained writing. |
| | | |
| | | """ |
| | | # Backwards compatibility for subclasses that don't override this method |
| | | # and for the default implementation |
| | | await self.write(writer) |
| | | |
| | | async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: |
| | | """ |
| | | Return bytes representation of the value. |
| | | |
| | | This is a convenience method that calls decode() and encodes the result |
| | | to bytes using the specified encoding. |
| | | """ |
| | | # Use instance encoding if available, otherwise use parameter |
| | | actual_encoding = self._encoding or encoding |
| | | return self.decode(actual_encoding, errors).encode(actual_encoding) |
| | | |
| | | def _close(self) -> None: |
| | | """ |
| | | Async-safe synchronous close operations for backwards compatibility.
| | | |
| | | This method exists only for backwards compatibility with code that |
| | | needs to clean up payloads synchronously. In the future, we will |
| | | drop this method and only support the async close() method. |
| | | |
| | | WARNING: This method must be safe to call from within the event loop |
| | | without blocking. Subclasses should not perform any blocking I/O here. |
| | | |
| | | WARNING: This method must be called from within an event loop for |
| | | certain payload types (e.g., IOBasePayload). Calling it outside an |
| | | event loop may raise RuntimeError. |
| | | """ |
| | | # This is a no-op by default, but subclasses can override it |
| | | # for non-blocking cleanup operations. |
| | | |
| | | async def close(self) -> None: |
| | | """ |
| | | Close the payload if it holds any resources. |
| | | |
| | | IMPORTANT: This method must not await anything that might not finish |
| | | immediately, as it may be called during cleanup/cancellation. Schedule |
| | | any long-running operations without awaiting them. |
| | | |
| | | In the future, this will be the only close method supported. |
| | | """ |
| | | self._close() |
| | | |
| | | |
| | | class BytesPayload(Payload): |
| | | _value: bytes |
| | | # _consumed = False (inherited) - Bytes are immutable and can be reused |
| | | _autoclose = True # No file handle, just bytes in memory |
| | | |
| | | def __init__( |
| | | self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any |
| | | ) -> None: |
| | | if "content_type" not in kwargs: |
| | | kwargs["content_type"] = "application/octet-stream" |
| | | |
| | | super().__init__(value, *args, **kwargs) |
| | | |
| | | if isinstance(value, memoryview): |
| | | self._size = value.nbytes |
| | | elif isinstance(value, (bytes, bytearray)): |
| | | self._size = len(value) |
| | | else: |
| | | raise TypeError(f"value argument must be byte-ish, not {type(value)!r}") |
| | | |
| | | if self._size > TOO_LARGE_BYTES_BODY:
| | | warnings.warn(
| | | "Sending a large body directly with raw bytes might"
| | | " lock the event loop. You should probably pass an "
| | | "io.BytesIO object instead",
| | | ResourceWarning,
| | | source=self,
| | | )
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | return self._value.decode(encoding, errors) |
| | | |
| | | async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: |
| | | """ |
| | | Return bytes representation of the value. |
| | | |
| | | This method returns the raw bytes content of the payload. |
| | | It is equivalent to accessing the _value attribute directly. |
| | | """ |
| | | return self._value |
| | | |
| | | async def write(self, writer: AbstractStreamWriter) -> None: |
| | | """ |
| | | Write the entire bytes payload to the writer stream. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | |
| | | This method writes the entire bytes content without any length constraint. |
| | | |
| | | Note: |
| | | For new implementations that need length control, use write_with_length(). |
| | | This method is maintained for backwards compatibility and is equivalent |
| | | to write_with_length(writer, None). |
| | | |
| | | """ |
| | | await writer.write(self._value) |
| | | |
| | | async def write_with_length( |
| | | self, writer: AbstractStreamWriter, content_length: Optional[int] |
| | | ) -> None: |
| | | """ |
| | | Write bytes payload with a specific content length constraint. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | content_length: Maximum number of bytes to write (None for unlimited) |
| | | |
| | | This method writes either the entire byte sequence or a slice of it |
| | | up to the specified content_length. For BytesPayload, this operation |
| | | is performed efficiently using array slicing. |
| | | |
| | | """ |
| | | if content_length is not None: |
| | | await writer.write(self._value[:content_length]) |
| | | else: |
| | | await writer.write(self._value) |
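For in-memory bytes, the length-constrained write above reduces to a prefix slice. A self-contained sketch with a stand-in writer (the `_Collector` class and the free function are illustrative stand-ins, not aiohttp API):

```python
import asyncio
from typing import Optional

class _Collector:
    """Minimal stand-in for AbstractStreamWriter that records written chunks."""

    def __init__(self) -> None:
        self.buf = bytearray()

    async def write(self, chunk: bytes) -> None:
        self.buf.extend(chunk)

async def write_with_length(value: bytes, writer: _Collector,
                            content_length: Optional[int]) -> None:
    # Mirrors BytesPayload.write_with_length: a prefix slice when constrained.
    if content_length is not None:
        await writer.write(value[:content_length])
    else:
        await writer.write(value)

async def _demo() -> bytes:
    w = _Collector()
    await write_with_length(b"hello world", w, 5)  # constrained: first 5 bytes
    await write_with_length(b"!", w, None)         # unconstrained: everything
    return bytes(w.buf)

result = asyncio.run(_demo())
assert result == b"hello!"
```

Slicing past the end of a bytes object is a no-op in Python, so a `content_length` larger than the payload simply writes the whole value.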
| | | |
| | | |
| | | class StringPayload(BytesPayload): |
| | | def __init__( |
| | | self, |
| | | value: str, |
| | | *args: Any, |
| | | encoding: Optional[str] = None, |
| | | content_type: Optional[str] = None, |
| | | **kwargs: Any, |
| | | ) -> None: |
| | | |
| | | if encoding is None: |
| | | if content_type is None: |
| | | real_encoding = "utf-8" |
| | | content_type = "text/plain; charset=utf-8" |
| | | else: |
| | | mimetype = parse_mimetype(content_type) |
| | | real_encoding = mimetype.parameters.get("charset", "utf-8") |
| | | else: |
| | | if content_type is None: |
| | | content_type = "text/plain; charset=%s" % encoding |
| | | real_encoding = encoding |
| | | |
| | | super().__init__( |
| | | value.encode(real_encoding), |
| | | encoding=real_encoding, |
| | | content_type=content_type, |
| | | *args, |
| | | **kwargs, |
| | | ) |
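The charset resolution in `StringPayload.__init__` can be summarized as: an explicit `encoding` wins; otherwise the `charset` parameter of the content type is used; otherwise UTF-8 with a default content type. A standalone sketch of that decision table, with a deliberately simplified charset parse (aiohttp uses `helpers.parse_mimetype`; the `resolve` function here is hypothetical):

```python
from typing import Optional, Tuple

def resolve(encoding: Optional[str],
            content_type: Optional[str]) -> Tuple[str, str]:
    """Return (real_encoding, content_type) per StringPayload's rules."""
    if encoding is None:
        if content_type is None:
            return "utf-8", "text/plain; charset=utf-8"
        # Simplified charset extraction from the content-type parameters.
        charset = "utf-8"
        for param in content_type.split(";")[1:]:
            key, _, value = param.strip().partition("=")
            if key.lower() == "charset" and value:
                charset = value.strip()
        return charset, content_type
    if content_type is None:
        return encoding, f"text/plain; charset={encoding}"
    return encoding, content_type

assert resolve(None, None) == ("utf-8", "text/plain; charset=utf-8")
assert resolve(None, "text/html; charset=latin-1")[0] == "latin-1"
assert resolve("utf-16", None) == ("utf-16", "text/plain; charset=utf-16")
```

The resolved encoding is then used both to encode the string body and as the payload's reported `encoding`, keeping the bytes on the wire consistent with the advertised charset.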
| | | |
| | | |
| | | class StringIOPayload(StringPayload): |
| | | def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None: |
| | | super().__init__(value.read(), *args, **kwargs) |
| | | |
| | | |
| | | class IOBasePayload(Payload): |
| | | _value: io.IOBase |
| | | # _consumed = False (inherited) - File can be re-read from the same position |
| | | _start_position: Optional[int] = None |
| | | # _autoclose = False (inherited) - Has file handle that needs explicit closing |
| | | |
| | | def __init__( |
| | | self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any |
| | | ) -> None: |
| | | if "filename" not in kwargs: |
| | | kwargs["filename"] = guess_filename(value) |
| | | |
| | | super().__init__(value, *args, **kwargs) |
| | | |
| | | if self._filename is not None and disposition is not None: |
| | | if hdrs.CONTENT_DISPOSITION not in self.headers: |
| | | self.set_content_disposition(disposition, filename=self._filename) |
| | | |
| | | def _set_or_restore_start_position(self) -> None: |
| | | """Set or restore the start position of the file-like object.""" |
| | | if self._start_position is None: |
| | | try: |
| | | self._start_position = self._value.tell() |
| | | except (OSError, AttributeError): |
| | | self._consumed = True # Cannot seek, mark as consumed |
| | | return |
| | | try: |
| | | self._value.seek(self._start_position) |
| | | except (OSError, AttributeError): |
| | | # Failed to seek back - mark as consumed since we've already read |
| | | self._consumed = True |
| | | |
| | | def _read_and_available_len( |
| | | self, remaining_content_len: Optional[int] |
| | | ) -> Tuple[Optional[int], bytes]: |
| | | """ |
| | | Read the file-like object and return both its total size and the first chunk. |
| | | |
| | | Args: |
| | | remaining_content_len: Optional limit on how many bytes to read in this operation. |
| | | If None, READ_SIZE will be used as the default chunk size. |
| | | |
| | | Returns: |
| | | A tuple containing: |
| | | - The total size of the remaining unread content (None if size cannot be determined) |
| | | - The first chunk of bytes read from the file object |
| | | |
| | | This method is optimized to perform both size calculation and initial read |
| | | in a single operation, which is executed in a single executor job to minimize |
| | | context switches and file operations when streaming content. |
| | | |
| | | """ |
| | | self._set_or_restore_start_position() |
| | | size = self.size # Call size only once since it does I/O |
| | | return size, self._value.read( |
| | | min(READ_SIZE, size or READ_SIZE, remaining_content_len or READ_SIZE) |
| | | ) |
| | | |
| | | def _read(self, remaining_content_len: Optional[int]) -> bytes: |
| | | """ |
| | | Read a chunk of data from the file-like object. |
| | | |
| | | Args: |
| | | remaining_content_len: Optional maximum number of bytes to read. |
| | | If None, READ_SIZE will be used as the default chunk size. |
| | | |
| | | Returns: |
| | | A chunk of bytes read from the file object, respecting the |
| | | remaining_content_len limit if specified. |
| | | |
| | | This method is used for subsequent reads during streaming after |
| | | the initial _read_and_available_len call has been made. |
| | | |
| | | """ |
| | | return self._value.read(remaining_content_len or READ_SIZE) # type: ignore[no-any-return] |
| | | |
| | | @property |
| | | def size(self) -> Optional[int]: |
| | | """ |
| | | Size of the payload in bytes. |
| | | |
| | | Returns the total size of the payload content from the initial position. |
| | | This ensures consistent Content-Length for requests, including 307/308 redirects |
| | | where the same payload instance is reused. |
| | | |
| | | Returns None if the size cannot be determined (e.g., for unseekable streams). |
| | | """ |
| | | try: |
| | | # Store the start position on first access. |
| | | # This is critical when the same payload instance is reused (e.g., 307/308 |
| | | # redirects). Without storing the initial position, after the payload is |
| | | # read once, the file position would be at EOF, which would cause the |
| | | # size calculation to return 0 (file_size - EOF position). |
| | | # By storing the start position, we ensure the size calculation always |
| | | # returns the correct total size for any subsequent use. |
| | | if self._start_position is None: |
| | | self._start_position = self._value.tell() |
| | | |
| | | # Return the total size from the start position |
| | | # This ensures Content-Length is correct even after reading |
| | | return os.fstat(self._value.fileno()).st_size - self._start_position |
| | | except (AttributeError, OSError): |
| | | return None |
| | | |
| | | async def write(self, writer: AbstractStreamWriter) -> None: |
| | | """ |
| | | Write the entire file-like payload to the writer stream. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | |
| | | This method writes the entire file content without any length constraint. |
| | | It delegates to write_with_length() with no length limit for implementation |
| | | consistency. |
| | | |
| | | Note: |
| | | For new implementations that need length control, use write_with_length() directly. |
| | | This method is maintained for backwards compatibility with existing code. |
| | | |
| | | """ |
| | | await self.write_with_length(writer, None) |
| | | |
| | | async def write_with_length( |
| | | self, writer: AbstractStreamWriter, content_length: Optional[int] |
| | | ) -> None: |
| | | """ |
| | | Write file-like payload with a specific content length constraint. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | content_length: Maximum number of bytes to write (None for unlimited) |
| | | |
| | | This method implements optimized streaming of file content with length constraints: |
| | | |
| | | 1. File reading is performed in a thread pool to avoid blocking the event loop |
| | | 2. Content is read and written in chunks to maintain memory efficiency |
| | | 3. Writing stops when either: |
| | | - All available file content has been written (when size is known) |
| | | - The specified content_length has been reached |
| | | 4. File resources are properly closed even if the operation is cancelled |
| | | |
| | | The implementation carefully handles both known-size and unknown-size payloads, |
| | | as well as constrained and unconstrained content lengths. |
| | | |
| | | """ |
| | | loop = asyncio.get_running_loop() |
| | | total_written_len = 0 |
| | | remaining_content_len = content_length |
| | | |
| | | # Get initial data and available length |
| | | available_len, chunk = await loop.run_in_executor( |
| | | None, self._read_and_available_len, remaining_content_len |
| | | ) |
| | | # Process data chunks until done |
| | | while chunk: |
| | | chunk_len = len(chunk) |
| | | |
| | | # Write data with or without length constraint |
| | | if remaining_content_len is None: |
| | | await writer.write(chunk) |
| | | else: |
| | | await writer.write(chunk[:remaining_content_len]) |
| | | remaining_content_len -= chunk_len |
| | | |
| | | total_written_len += chunk_len |
| | | |
| | | # Check if we're done writing |
| | | if self._should_stop_writing( |
| | | available_len, total_written_len, remaining_content_len |
| | | ): |
| | | return |
| | | |
| | | # Read next chunk |
| | | chunk = await loop.run_in_executor( |
| | | None, |
| | | self._read, |
| | | ( |
| | | min(READ_SIZE, remaining_content_len) |
| | | if remaining_content_len is not None |
| | | else READ_SIZE |
| | | ), |
| | | ) |
| | | |
| | | def _should_stop_writing( |
| | | self, |
| | | available_len: Optional[int], |
| | | total_written_len: int, |
| | | remaining_content_len: Optional[int], |
| | | ) -> bool: |
| | | """ |
| | | Determine if we should stop writing data. |
| | | |
| | | Args: |
| | | available_len: Known size of the payload if available (None if unknown) |
| | | total_written_len: Number of bytes already written |
| | | remaining_content_len: Remaining bytes to be written for content-length limited responses |
| | | |
| | | Returns: |
| | | True if we should stop writing data, based on either: |
| | | - Having written all available data (when size is known) |
| | | - Having written all requested content (when content-length is specified) |
| | | |
| | | """ |
| | | return (available_len is not None and total_written_len >= available_len) or ( |
| | | remaining_content_len is not None and remaining_content_len <= 0 |
| | | ) |
| | | |
| | | def _close(self) -> None: |
| | | """ |
| | | Async-safe synchronous close operations for backwards compatibility.
| | | |
| | | This method exists only for backwards |
| | | compatibility. Use the async close() method instead. |
| | | |
| | | WARNING: This method MUST be called from within an event loop. |
| | | Calling it outside an event loop will raise RuntimeError. |
| | | """ |
| | | # Skip if already consumed |
| | | if self._consumed: |
| | | return |
| | | self._consumed = True # Mark as consumed to prevent further writes |
| | | # Schedule file closing without awaiting to prevent cancellation issues |
| | | loop = asyncio.get_running_loop() |
| | | close_future = loop.run_in_executor(None, self._value.close) |
| | | # Hold a strong reference to the future to prevent it from being |
| | | # garbage collected before it completes. |
| | | _CLOSE_FUTURES.add(close_future) |
| | | close_future.add_done_callback(_CLOSE_FUTURES.remove) |
| | | |
| | | async def close(self) -> None: |
| | | """ |
| | | Close the payload if it holds any resources. |
| | | |
| | | IMPORTANT: This method must not await anything that might not finish |
| | | immediately, as it may be called during cleanup/cancellation. Schedule |
| | | any long-running operations without awaiting them. |
| | | """ |
| | | self._close() |
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | """ |
| | | Return string representation of the value. |
| | | |
| | | WARNING: This method does blocking I/O and should not be called in the event loop. |
| | | """ |
| | | return self._read_all().decode(encoding, errors) |
| | | |
| | | def _read_all(self) -> bytes: |
| | | """Read the entire file-like object and return its content as bytes.""" |
| | | self._set_or_restore_start_position() |
| | | # Use readlines() to ensure we get all content |
| | | return b"".join(self._value.readlines()) |
| | | |
| | | async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: |
| | | """ |
| | | Return bytes representation of the value. |
| | | |
| | | This method reads the entire file content and returns it as bytes. |
| | | It is equivalent to reading the file-like object directly. |
| | | The file reading is performed in an executor to avoid blocking the event loop. |
| | | """ |
| | | loop = asyncio.get_running_loop() |
| | | return await loop.run_in_executor(None, self._read_all) |
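The streaming loop in `write_with_length()` terminates on either of two independent bounds: all known content has been written, or the content-length budget is exhausted. The stop predicate is a pure function, so it can be exercised in isolation (this copies the helper's logic for demonstration; it is not imported from aiohttp):

```python
from typing import Optional

def should_stop_writing(available_len: Optional[int],
                        total_written_len: int,
                        remaining_content_len: Optional[int]) -> bool:
    # Same logic as IOBasePayload._should_stop_writing.
    return (available_len is not None and total_written_len >= available_len) or (
        remaining_content_len is not None and remaining_content_len <= 0
    )

# Known size fully written: stop even with no content-length limit.
assert should_stop_writing(100, 100, None)
# Content-length budget exhausted: stop even if the file has more data.
assert should_stop_writing(None, 50, 0)
# Neither bound hit: keep streaming.
assert not should_stop_writing(None, 50, 10)
```

Keeping both bounds independent is what lets the same loop serve seekable files (known size), unseekable streams (unknown size), and length-capped responses.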
| | | |
| | | |
| | | class TextIOPayload(IOBasePayload): |
| | | _value: io.TextIOBase |
| | | # _autoclose = False (inherited) - Has text file handle that needs explicit closing |
| | | |
| | | def __init__( |
| | | self, |
| | | value: TextIO, |
| | | *args: Any, |
| | | encoding: Optional[str] = None, |
| | | content_type: Optional[str] = None, |
| | | **kwargs: Any, |
| | | ) -> None: |
| | | |
| | | if encoding is None: |
| | | if content_type is None: |
| | | encoding = "utf-8" |
| | | content_type = "text/plain; charset=utf-8" |
| | | else: |
| | | mimetype = parse_mimetype(content_type) |
| | | encoding = mimetype.parameters.get("charset", "utf-8") |
| | | else: |
| | | if content_type is None: |
| | | content_type = "text/plain; charset=%s" % encoding |
| | | |
| | | super().__init__( |
| | | value, |
| | | content_type=content_type, |
| | | encoding=encoding, |
| | | *args, |
| | | **kwargs, |
| | | ) |
| | | |
| | | def _read_and_available_len( |
| | | self, remaining_content_len: Optional[int] |
| | | ) -> Tuple[Optional[int], bytes]: |
| | | """ |
| | | Read the text file-like object and return both its total size and the first chunk. |
| | | |
| | | Args: |
| | | remaining_content_len: Optional limit on how many bytes to read in this operation. |
| | | If None, READ_SIZE will be used as the default chunk size. |
| | | |
| | | Returns: |
| | | A tuple containing: |
| | | - The total size of the remaining unread content (None if size cannot be determined) |
| | | - The first chunk of bytes read from the file object, encoded using the payload's encoding |
| | | |
| | | This method is optimized to perform both size calculation and initial read |
| | | in a single operation, which is executed in a single executor job to minimize |
| | | context switches and file operations when streaming content. |
| | | |
| | | Note: |
| | | TextIOPayload handles encoding of the text content before writing it |
| | | to the stream. If no encoding is specified, UTF-8 is used as the default. |
| | | |
| | | """ |
| | | self._set_or_restore_start_position() |
| | | size = self.size |
| | | chunk = self._value.read( |
| | | min(READ_SIZE, size or READ_SIZE, remaining_content_len or READ_SIZE) |
| | | ) |
| | | return size, chunk.encode(self._encoding) if self._encoding else chunk.encode() |
| | | |
| | | def _read(self, remaining_content_len: Optional[int]) -> bytes: |
| | | """ |
| | | Read a chunk of data from the text file-like object. |
| | | |
| | | Args: |
| | | remaining_content_len: Optional maximum number of bytes to read. |
| | | If None, READ_SIZE will be used as the default chunk size. |
| | | |
| | | Returns: |
| | | A chunk of bytes read from the file object and encoded using the payload's |
| | | encoding. The data is automatically converted from text to bytes. |
| | | |
| | | This method is used for subsequent reads during streaming after |
| | | the initial _read_and_available_len call has been made. It properly |
| | | handles text encoding, converting the text content to bytes using |
| | | the specified encoding (or UTF-8 if none was provided). |
| | | |
| | | """ |
| | | chunk = self._value.read(remaining_content_len or READ_SIZE) |
| | | return chunk.encode(self._encoding) if self._encoding else chunk.encode() |
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | """ |
| | | Return string representation of the value. |
| | | |
| | | WARNING: This method does blocking I/O and should not be called in the event loop. |
| | | """ |
| | | self._set_or_restore_start_position() |
| | | return self._value.read() |
| | | |
| | | async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: |
| | | """ |
| | | Return bytes representation of the value. |
| | | |
| | | This method reads the entire text file content and returns it as bytes. |
| | | It encodes the text content using the specified encoding. |
| | | The file reading is performed in an executor to avoid blocking the event loop. |
| | | """ |
| | | loop = asyncio.get_running_loop() |
| | | |
| | | # Use instance encoding if available, otherwise use parameter |
| | | actual_encoding = self._encoding or encoding |
| | | |
| | | def _read_and_encode() -> bytes: |
| | | self._set_or_restore_start_position() |
| | | # TextIO read() always returns the full content |
| | | return self._value.read().encode(actual_encoding, errors) |
| | | |
| | | return await loop.run_in_executor(None, _read_and_encode) |
| | | |
| | | |
| | | class BytesIOPayload(IOBasePayload): |
| | | _value: io.BytesIO |
| | | _size: int # Always initialized in __init__ |
| | | _autoclose = True # BytesIO is in-memory, safe to auto-close |
| | | |
| | | def __init__(self, value: io.BytesIO, *args: Any, **kwargs: Any) -> None: |
| | | super().__init__(value, *args, **kwargs) |
| | | # Calculate size once during initialization |
| | | self._size = len(self._value.getbuffer()) - self._value.tell() |
| | | |
| | | @property |
| | | def size(self) -> int: |
| | | """Size of the payload in bytes. |
| | | |
| | | Returns the number of bytes in the BytesIO buffer that will be transmitted. |
| | | This is calculated once during initialization for efficiency. |
| | | """ |
| | | return self._size |
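The size calculation used in `__init__` above can be illustrated standalone. `remaining_size` is a hypothetical helper written for this sketch, not part of the class; it mirrors the `len(getbuffer()) - tell()` expression, which counts only the bytes from the current position to the end of the buffer:

```python
import io

def remaining_size(buf: io.BytesIO) -> int:
    # Bytes left from the current position to the end of the buffer,
    # mirroring BytesIOPayload's size calculation.
    return len(buf.getbuffer()) - buf.tell()

buf = io.BytesIO(b"hello world")
buf.read(6)  # advance past "hello "
print(remaining_size(buf))  # 5 -- only "world" remains
```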
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | self._set_or_restore_start_position() |
| | | return self._value.read().decode(encoding, errors) |
| | | |
| | | async def write(self, writer: AbstractStreamWriter) -> None: |
| | | return await self.write_with_length(writer, None) |
| | | |
| | | async def write_with_length( |
| | | self, writer: AbstractStreamWriter, content_length: Optional[int] |
| | | ) -> None: |
| | | """ |
| | | Write BytesIO payload with a specific content length constraint. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | content_length: Maximum number of bytes to write (None for unlimited) |
| | | |
| | | This implementation is specifically optimized for BytesIO objects: |
| | | |
| | | 1. Reads content in chunks to maintain memory efficiency |
| | | 2. Yields control back to the event loop periodically to prevent blocking |
| | | when dealing with large BytesIO objects |
| | | 3. Respects content_length constraints when specified |
| | | 4. Stops writing early once the content_length limit is reached
| | | |
| | | The periodic yielding to the event loop is important for maintaining |
| | | responsiveness when processing large in-memory buffers. |
| | | |
| | | """ |
| | | self._set_or_restore_start_position() |
| | | loop_count = 0 |
| | | remaining_bytes = content_length |
| | | while chunk := self._value.read(READ_SIZE): |
| | | if loop_count > 0: |
| | | # Yield to the event loop between chunks so a large
| | | # BytesIO object does not block it; the first chunk
| | | # is written immediately
| | | await asyncio.sleep(0) |
| | | if remaining_bytes is None: |
| | | await writer.write(chunk) |
| | | else: |
| | | await writer.write(chunk[:remaining_bytes]) |
| | | remaining_bytes -= len(chunk) |
| | | if remaining_bytes <= 0: |
| | | return |
| | | loop_count += 1 |
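The chunked write loop above can be sketched in isolation. This is a simplified stand-in, not aiohttp's implementation: the writer is replaced by a plain `bytearray`, and `READ_SIZE` is shrunk so the chunking is visible. The shape of the loop — truncate the final chunk at the limit, yield to the event loop between iterations — is the same:

```python
import asyncio
import io
from typing import Optional

READ_SIZE = 4  # tiny for demonstration; aiohttp uses a much larger constant

async def write_limited(
    buf: io.BytesIO, out: bytearray, content_length: Optional[int] = None
) -> None:
    # Mirror of BytesIOPayload.write_with_length: read in chunks,
    # truncate at the limit, and yield control between chunks.
    remaining = content_length
    first = True
    while chunk := buf.read(READ_SIZE):
        if not first:
            await asyncio.sleep(0)  # avoid starving the event loop
        first = False
        if remaining is None:
            out += chunk
        else:
            out += chunk[:remaining]
            remaining -= len(chunk)
            if remaining <= 0:
                return

out = bytearray()
asyncio.run(write_limited(io.BytesIO(b"abcdefghij"), out, content_length=6))
print(bytes(out))  # b'abcdef'
```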
| | | |
| | | async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: |
| | | """ |
| | | Return bytes representation of the value. |
| | | |
| | | This method reads the entire BytesIO content and returns it as bytes. |
| | | The start position is restored first, so the full buffer content is returned.
| | | """ |
| | | self._set_or_restore_start_position() |
| | | return self._value.read() |
| | | |
| | | async def close(self) -> None: |
| | | """ |
| | | Close the BytesIO payload. |
| | | |
| | | This does nothing since BytesIO is in-memory and does not require explicit closing. |
| | | """ |
| | | |
| | | |
| | | class BufferedReaderPayload(IOBasePayload): |
| | | _value: io.BufferedIOBase |
| | | # _autoclose = False (inherited) - Has buffered file handle that needs explicit closing |
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | self._set_or_restore_start_position() |
| | | return self._value.read().decode(encoding, errors) |
| | | |
| | | |
| | | class JsonPayload(BytesPayload): |
| | | def __init__( |
| | | self, |
| | | value: Any, |
| | | encoding: str = "utf-8", |
| | | content_type: str = "application/json", |
| | | dumps: JSONEncoder = json.dumps, |
| | | *args: Any, |
| | | **kwargs: Any, |
| | | ) -> None: |
| | | |
| | | super().__init__( |
| | | dumps(value).encode(encoding), |
| | | content_type=content_type, |
| | | encoding=encoding, |
| | | *args, |
| | | **kwargs, |
| | | ) |
| | | |
| | | |
| | | if TYPE_CHECKING: |
| | | from typing import AsyncIterable, AsyncIterator |
| | | |
| | | _AsyncIterator = AsyncIterator[bytes] |
| | | _AsyncIterable = AsyncIterable[bytes] |
| | | else: |
| | | from collections.abc import AsyncIterable, AsyncIterator |
| | | |
| | | _AsyncIterator = AsyncIterator |
| | | _AsyncIterable = AsyncIterable |
| | | |
| | | |
| | | class AsyncIterablePayload(Payload): |
| | | |
| | | _iter: Optional[_AsyncIterator] = None |
| | | _value: _AsyncIterable |
| | | _cached_chunks: Optional[List[bytes]] = None |
| | | # _consumed stays False to allow reuse with cached content |
| | | _autoclose = True # Iterator doesn't need explicit closing |
| | | |
| | | def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None: |
| | | if not isinstance(value, AsyncIterable): |
| | | raise TypeError( |
| | | "value argument must support " |
| | | "collections.abc.AsyncIterable interface, " |
| | | "got {!r}".format(type(value)) |
| | | ) |
| | | |
| | | if "content_type" not in kwargs: |
| | | kwargs["content_type"] = "application/octet-stream" |
| | | |
| | | super().__init__(value, *args, **kwargs) |
| | | |
| | | self._iter = value.__aiter__() |
| | | |
| | | async def write(self, writer: AbstractStreamWriter) -> None: |
| | | """ |
| | | Write the entire async iterable payload to the writer stream. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | |
| | | This method iterates through the async iterable and writes each chunk |
| | | to the writer without any length constraint. |
| | | |
| | | Note: |
| | | For new implementations that need length control, use write_with_length() directly. |
| | | This method is maintained for backwards compatibility with existing code. |
| | | |
| | | """ |
| | | await self.write_with_length(writer, None) |
| | | |
| | | async def write_with_length( |
| | | self, writer: AbstractStreamWriter, content_length: Optional[int] |
| | | ) -> None: |
| | | """ |
| | | Write async iterable payload with a specific content length constraint. |
| | | |
| | | Args: |
| | | writer: An AbstractStreamWriter instance that handles the actual writing |
| | | content_length: Maximum number of bytes to write (None for unlimited) |
| | | |
| | | This implementation handles streaming of async iterable content with length constraints: |
| | | |
| | | 1. If cached chunks are available, writes from them |
| | | 2. Otherwise iterates through the async iterable one chunk at a time |
| | | 3. Respects content_length constraints when specified |
| | | 4. Does NOT generate cache - that's done by as_bytes() |
| | | |
| | | """ |
| | | # If we have cached chunks, use them |
| | | if self._cached_chunks is not None: |
| | | remaining_bytes = content_length |
| | | for chunk in self._cached_chunks: |
| | | if remaining_bytes is None: |
| | | await writer.write(chunk) |
| | | elif remaining_bytes > 0: |
| | | await writer.write(chunk[:remaining_bytes]) |
| | | remaining_bytes -= len(chunk) |
| | | else: |
| | | break |
| | | return |
| | | |
| | | # If iterator is exhausted and we don't have cached chunks, nothing to write |
| | | if self._iter is None: |
| | | return |
| | | |
| | | # Stream from the iterator |
| | | remaining_bytes = content_length |
| | | |
| | | try: |
| | | while True: |
| | | if sys.version_info >= (3, 10): |
| | | chunk = await anext(self._iter) |
| | | else: |
| | | chunk = await self._iter.__anext__() |
| | | if remaining_bytes is None: |
| | | await writer.write(chunk) |
| | | # If we have a content length limit |
| | | elif remaining_bytes > 0: |
| | | await writer.write(chunk[:remaining_bytes]) |
| | | remaining_bytes -= len(chunk) |
| | | # We still want to exhaust the iterator even |
| | | # if we have reached the content length limit |
| | | # since the file handle may not get closed by |
| | | # the iterator if we don't do this |
| | | except StopAsyncIteration: |
| | | # Iterator is exhausted |
| | | self._iter = None |
| | | self._consumed = True # Mark as consumed when streamed without caching |
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | """Decode the payload content as a string if cached chunks are available.""" |
| | | if self._cached_chunks is not None: |
| | | return b"".join(self._cached_chunks).decode(encoding, errors) |
| | | raise TypeError("Unable to decode - content not cached. Call as_bytes() first.") |
| | | |
| | | async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: |
| | | """ |
| | | Return bytes representation of the value. |
| | | |
| | | This method reads the entire async iterable content and returns it as bytes. |
| | | It generates and caches the chunks for future reuse. |
| | | """ |
| | | # If we have cached chunks, return them joined |
| | | if self._cached_chunks is not None: |
| | | return b"".join(self._cached_chunks) |
| | | |
| | | # If iterator is exhausted and no cache, return empty |
| | | if self._iter is None: |
| | | return b"" |
| | | |
| | | # Read all chunks and cache them |
| | | chunks: List[bytes] = [] |
| | | async for chunk in self._iter: |
| | | chunks.append(chunk) |
| | | |
| | | # Iterator is exhausted, cache the chunks |
| | | self._iter = None |
| | | self._cached_chunks = chunks |
| | | # Keep _consumed as False to allow reuse with cached chunks |
| | | |
| | | return b"".join(chunks) |
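The cache-on-first-read pattern used by `as_bytes` above can be reduced to a minimal sketch. `CachingIterable` here is an illustrative class invented for this example, not part of aiohttp; it shows why a second call still succeeds after the underlying iterator is exhausted:

```python
import asyncio
from typing import AsyncIterator, List, Optional

class CachingIterable:
    """Minimal sketch of caching async-iterable chunks for reuse."""

    def __init__(self, it: AsyncIterator[bytes]) -> None:
        self._iter: Optional[AsyncIterator[bytes]] = it
        self._cached: Optional[List[bytes]] = None

    async def as_bytes(self) -> bytes:
        if self._cached is not None:
            return b"".join(self._cached)  # served from cache
        if self._iter is None:
            return b""  # exhausted without caching
        chunks = [chunk async for chunk in self._iter]
        self._iter = None  # iterator is spent; keep only the cache
        self._cached = chunks
        return b"".join(chunks)

async def gen() -> AsyncIterator[bytes]:
    yield b"one-"
    yield b"two"

async def main() -> tuple:
    p = CachingIterable(gen())
    first = await p.as_bytes()
    second = await p.as_bytes()  # generator is exhausted; cache answers
    return first, second

print(asyncio.run(main()))  # (b'one-two', b'one-two')
```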
| | | |
| | | |
| | | class StreamReaderPayload(AsyncIterablePayload): |
| | | def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None: |
| | | super().__init__(value.iter_any(), *args, **kwargs) |
| | | |
| | | |
| | | PAYLOAD_REGISTRY = PayloadRegistry() |
| | | PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview)) |
| | | PAYLOAD_REGISTRY.register(StringPayload, str) |
| | | PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO) |
| | | PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase) |
| | | PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO) |
| | | PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)) |
| | | PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase) |
| | | PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader) |
| | | # try_last gives more specialized async iterables, such as
| | | # multipart.BodyPartReaderPayload, a chance to override the default
| | | PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last) |
| New file |
| | |
| | | """ |
| | | Payload implementation for coroutines as data provider. |
| | | |
| | | As a simple case, you can upload data from file:: |
| | | |
| | | @aiohttp.streamer |
| | | async def file_sender(writer, file_name=None): |
| | | with open(file_name, 'rb') as f: |
| | | chunk = f.read(2**16) |
| | | while chunk: |
| | | await writer.write(chunk) |
| | | |
| | | chunk = f.read(2**16) |
| | | |
| | | Then you can use `file_sender` like this::
| | | |
| | | async with session.post('http://httpbin.org/post', |
| | | data=file_sender(file_name='huge_file')) as resp: |
| | | print(await resp.text()) |
| | | |
| | | .. note:: The coroutine must accept `writer` as its first argument
| | | |
| | | """ |
| | | |
| | | import types |
| | | import warnings |
| | | from typing import Any, Awaitable, Callable, Dict, Tuple |
| | | |
| | | from .abc import AbstractStreamWriter |
| | | from .payload import Payload, payload_type |
| | | |
| | | __all__ = ("streamer",) |
| | | |
| | | |
| | | class _stream_wrapper: |
| | | def __init__( |
| | | self, |
| | | coro: Callable[..., Awaitable[None]], |
| | | args: Tuple[Any, ...], |
| | | kwargs: Dict[str, Any], |
| | | ) -> None: |
| | | self.coro = types.coroutine(coro) |
| | | self.args = args |
| | | self.kwargs = kwargs |
| | | |
| | | async def __call__(self, writer: AbstractStreamWriter) -> None: |
| | | await self.coro(writer, *self.args, **self.kwargs) |
| | | |
| | | |
| | | class streamer: |
| | | def __init__(self, coro: Callable[..., Awaitable[None]]) -> None: |
| | | warnings.warn( |
| | | "@streamer is deprecated, use async generators instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | self.coro = coro |
| | | |
| | | def __call__(self, *args: Any, **kwargs: Any) -> _stream_wrapper: |
| | | return _stream_wrapper(self.coro, args, kwargs) |
| | | |
| | | |
| | | @payload_type(_stream_wrapper) |
| | | class StreamWrapperPayload(Payload): |
| | | async def write(self, writer: AbstractStreamWriter) -> None: |
| | | await self._value(writer) |
| | | |
| | | def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: |
| | | raise TypeError("Unable to decode.") |
| | | |
| | | |
| | | @payload_type(streamer) |
| | | class StreamPayload(StreamWrapperPayload): |
| | | def __init__(self, value: Any, *args: Any, **kwargs: Any) -> None: |
| | | super().__init__(value(), *args, **kwargs) |
| | | |
| | | async def write(self, writer: AbstractStreamWriter) -> None: |
| | | await self._value(writer) |
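The deprecation warning above points at async generators as the replacement for `@streamer`. A minimal sketch of the module docstring's `file_sender` rewritten that way (the file name and the `collect` helper are only for demonstration; aiohttp would consume the generator directly via `AsyncIterablePayload`):

```python
import asyncio
import os
import tempfile
from typing import AsyncIterator

async def file_sender(file_name: str) -> AsyncIterator[bytes]:
    # Async-generator replacement for the deprecated @streamer pattern:
    # yield chunks instead of writing to an explicit writer.
    with open(file_name, "rb") as f:
        while chunk := f.read(2**16):
            yield chunk

async def collect(file_name: str) -> bytes:
    # Demonstration consumer; aiohttp plays this role when you pass
    # data=file_sender(...) to session.post().
    return b"".join([c async for c in file_sender(file_name)])

with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"payload-data")
print(asyncio.run(collect(tmp.name)))  # b'payload-data'
os.unlink(tmp.name)
```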
| New file |
| | |
| | | import asyncio |
| | | import contextlib |
| | | import inspect |
| | | import warnings |
| | | from typing import ( |
| | | Any, |
| | | Awaitable, |
| | | Callable, |
| | | Dict, |
| | | Iterator, |
| | | Optional, |
| | | Protocol, |
| | | Union, |
| | | overload, |
| | | ) |
| | | |
| | | import pytest |
| | | |
| | | from .test_utils import ( |
| | | BaseTestServer, |
| | | RawTestServer, |
| | | TestClient, |
| | | TestServer, |
| | | loop_context, |
| | | setup_test_loop, |
| | | teardown_test_loop, |
| | | unused_port as _unused_port, |
| | | ) |
| | | from .web import Application, BaseRequest, Request |
| | | from .web_protocol import _RequestHandler |
| | | |
| | | try: |
| | | import uvloop |
| | | except ImportError: # pragma: no cover |
| | | uvloop = None # type: ignore[assignment] |
| | | |
| | | |
| | | class AiohttpClient(Protocol): |
| | | @overload |
| | | async def __call__( |
| | | self, |
| | | __param: Application, |
| | | *, |
| | | server_kwargs: Optional[Dict[str, Any]] = None, |
| | | **kwargs: Any, |
| | | ) -> TestClient[Request, Application]: ... |
| | | @overload |
| | | async def __call__( |
| | | self, |
| | | __param: BaseTestServer, |
| | | *, |
| | | server_kwargs: Optional[Dict[str, Any]] = None, |
| | | **kwargs: Any, |
| | | ) -> TestClient[BaseRequest, None]: ... |
| | | |
| | | |
| | | class AiohttpServer(Protocol): |
| | | def __call__( |
| | | self, app: Application, *, port: Optional[int] = None, **kwargs: Any |
| | | ) -> Awaitable[TestServer]: ... |
| | | |
| | | |
| | | class AiohttpRawServer(Protocol): |
| | | def __call__( |
| | | self, handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any |
| | | ) -> Awaitable[RawTestServer]: ... |
| | | |
| | | |
| | | def pytest_addoption(parser): # type: ignore[no-untyped-def] |
| | | parser.addoption( |
| | | "--aiohttp-fast", |
| | | action="store_true", |
| | | default=False, |
| | | help="run tests faster by disabling extra checks", |
| | | ) |
| | | parser.addoption( |
| | | "--aiohttp-loop", |
| | | action="store", |
| | | default="pyloop", |
| | | help="run tests with specific loop: pyloop, uvloop or all", |
| | | ) |
| | | parser.addoption( |
| | | "--aiohttp-enable-loop-debug", |
| | | action="store_true", |
| | | default=False, |
| | | help="enable event loop debug mode", |
| | | ) |
| | | |
| | | |
| | | def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def] |
| | | """Set up pytest fixture. |
| | | |
| | | Allow fixtures to be coroutines. Run coroutine fixtures in an event loop. |
| | | """ |
| | | func = fixturedef.func |
| | | |
| | | if inspect.isasyncgenfunction(func): |
| | | # async generator fixture |
| | | is_async_gen = True |
| | | elif inspect.iscoroutinefunction(func): |
| | | # regular async fixture |
| | | is_async_gen = False |
| | | else: |
| | | # not an async fixture, nothing to do |
| | | return |
| | | |
| | | strip_request = False |
| | | if "request" not in fixturedef.argnames: |
| | | fixturedef.argnames += ("request",) |
| | | strip_request = True |
| | | |
| | | def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] |
| | | request = kwargs["request"] |
| | | if strip_request: |
| | | del kwargs["request"] |
| | | |
| | | # if neither the fixture nor the test use the 'loop' fixture, |
| | | # 'getfixturevalue' will fail because the test is not parameterized |
| | | # (this can be removed someday if 'loop' is no longer parameterized) |
| | | if "loop" not in request.fixturenames: |
| | | raise Exception( |
| | | "Asynchronous fixtures must depend on the 'loop' fixture or " |
| | | "be used in tests depending on it."
| | | ) |
| | | |
| | | _loop = request.getfixturevalue("loop") |
| | | |
| | | if is_async_gen: |
| | | # for async generators, we need to advance the generator once, |
| | | # then advance it again in a finalizer |
| | | gen = func(*args, **kwargs) |
| | | |
| | | def finalizer(): # type: ignore[no-untyped-def] |
| | | try: |
| | | return _loop.run_until_complete(gen.__anext__()) |
| | | except StopAsyncIteration: |
| | | pass |
| | | |
| | | request.addfinalizer(finalizer) |
| | | return _loop.run_until_complete(gen.__anext__()) |
| | | else: |
| | | return _loop.run_until_complete(func(*args, **kwargs)) |
| | | |
| | | fixturedef.func = wrapper |
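The async-generator branch of the wrapper above relies on a two-step protocol: the first `__anext__` runs the fixture's setup and yields the value, and a finalizer drives a second `__anext__` to run the teardown. A hedged standalone sketch of that protocol, without pytest:

```python
import asyncio
from typing import AsyncIterator

async def resource() -> AsyncIterator[str]:
    # Setup happens before the yield...
    yield "ready"
    # ...and teardown code after it runs on the second __anext__.

loop = asyncio.new_event_loop()
gen = resource()
value = loop.run_until_complete(gen.__anext__())  # setup -> "ready"
try:
    loop.run_until_complete(gen.__anext__())      # teardown
except StopAsyncIteration:
    pass  # generator finished normally, as the finalizer above expects
loop.close()
print(value)  # ready
```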
| | | |
| | | |
| | | @pytest.fixture |
| | | def fast(request): # type: ignore[no-untyped-def] |
| | | """--fast config option""" |
| | | return request.config.getoption("--aiohttp-fast") |
| | | |
| | | |
| | | @pytest.fixture |
| | | def loop_debug(request): # type: ignore[no-untyped-def] |
| | | """--enable-loop-debug config option""" |
| | | return request.config.getoption("--aiohttp-enable-loop-debug") |
| | | |
| | | |
| | | @contextlib.contextmanager |
| | | def _runtime_warning_context(): # type: ignore[no-untyped-def] |
| | | """Context manager which checks for RuntimeWarnings. |
| | | |
| | | This exists specifically to |
| | | avoid "coroutine 'X' was never awaited" warnings being missed. |
| | | |
| | | If RuntimeWarnings occur in the context a RuntimeError is raised. |
| | | """ |
| | | with warnings.catch_warnings(record=True) as _warnings: |
| | | yield |
| | | rw = [ |
| | | "{w.filename}:{w.lineno}:{w.message}".format(w=w) |
| | | for w in _warnings |
| | | if w.category == RuntimeWarning |
| | | ] |
| | | if rw: |
| | | raise RuntimeError( |
| | | "{} Runtime Warning{},\n{}".format( |
| | | len(rw), "" if len(rw) == 1 else "s", "\n".join(rw) |
| | | ) |
| | | ) |
| | | |
| | | |
| | | @contextlib.contextmanager |
| | | def _passthrough_loop_context(loop, fast=False): # type: ignore[no-untyped-def] |
| | | """Passthrough loop context. |
| | | |
| | | Sets up and tears down a loop unless one is passed in via the loop
| | | argument, in which case it is passed straight through.
| | | """ |
| | | if loop: |
| | | # loop already exists, pass it straight through |
| | | yield loop |
| | | else: |
| | | # this shadows loop_context's standard behavior |
| | | loop = setup_test_loop() |
| | | yield loop |
| | | teardown_test_loop(loop, fast=fast) |
| | | |
| | | |
| | | def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def] |
| | | """Fix pytest collecting for coroutines.""" |
| | | if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj): |
| | | return list(collector._genfunctions(name, obj)) |
| | | |
| | | |
| | | def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def] |
| | | """Run coroutines in an event loop instead of a normal function call.""" |
| | | fast = pyfuncitem.config.getoption("--aiohttp-fast") |
| | | if inspect.iscoroutinefunction(pyfuncitem.function): |
| | | existing_loop = ( |
| | | pyfuncitem.funcargs.get("proactor_loop") |
| | | or pyfuncitem.funcargs.get("selector_loop") |
| | | or pyfuncitem.funcargs.get("uvloop_loop") |
| | | or pyfuncitem.funcargs.get("loop", None) |
| | | ) |
| | | |
| | | with _runtime_warning_context(): |
| | | with _passthrough_loop_context(existing_loop, fast=fast) as _loop: |
| | | testargs = { |
| | | arg: pyfuncitem.funcargs[arg] |
| | | for arg in pyfuncitem._fixtureinfo.argnames |
| | | } |
| | | _loop.run_until_complete(pyfuncitem.obj(**testargs)) |
| | | |
| | | return True |
| | | |
| | | |
| | | def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def] |
| | | if "loop_factory" not in metafunc.fixturenames: |
| | | return |
| | | |
| | | loops = metafunc.config.option.aiohttp_loop |
| | | avail_factories: dict[str, Callable[[], asyncio.AbstractEventLoop]] |
| | | avail_factories = {"pyloop": asyncio.new_event_loop} |
| | | |
| | | if uvloop is not None: # pragma: no cover |
| | | avail_factories["uvloop"] = uvloop.new_event_loop |
| | | |
| | | if loops == "all": |
| | | loops = "pyloop,uvloop?" |
| | | |
| | | factories = {} # type: ignore[var-annotated] |
| | | for name in loops.split(","): |
| | | required = not name.endswith("?") |
| | | name = name.strip(" ?") |
| | | if name not in avail_factories: # pragma: no cover |
| | | if required: |
| | | raise ValueError( |
| | | "Unknown loop '%s', available loops: %s" |
| | | % (name, list(factories.keys())) |
| | | ) |
| | | else: |
| | | continue |
| | | factories[name] = avail_factories[name] |
| | | metafunc.parametrize( |
| | | "loop_factory", list(factories.values()), ids=list(factories.keys()) |
| | | ) |
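The loop-spec parsing inside `pytest_generate_tests` above can be factored out for illustration. `parse_loop_spec` is a hypothetical helper written for this sketch: a trailing `?` marks a loop as optional, so an unavailable optional loop (e.g. `uvloop` when the package is not installed) is skipped rather than raising:

```python
from typing import Iterable, List

def parse_loop_spec(spec: str, available: Iterable[str]) -> List[str]:
    # Mirrors the parsing above: "name?" is optional, "name" is required.
    selected = []
    for name in spec.split(","):
        required = not name.endswith("?")
        name = name.strip(" ?")
        if name not in available:
            if required:
                raise ValueError(f"Unknown loop {name!r}")
            continue  # optional and unavailable: silently skip
        selected.append(name)
    return selected

# "all" expands to "pyloop,uvloop?" -- uvloop is dropped when absent:
print(parse_loop_spec("pyloop,uvloop?", {"pyloop"}))  # ['pyloop']
```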
| | | |
| | | |
| | | @pytest.fixture |
| | | def loop( |
| | | loop_factory: Callable[[], asyncio.AbstractEventLoop], |
| | | fast: bool, |
| | | loop_debug: bool, |
| | | ) -> Iterator[asyncio.AbstractEventLoop]: |
| | | """Return an instance of the event loop.""" |
| | | with loop_context(loop_factory, fast=fast) as _loop: |
| | | if loop_debug: |
| | | _loop.set_debug(True) # pragma: no cover |
| | | asyncio.set_event_loop(_loop) |
| | | yield _loop |
| | | |
| | | |
| | | @pytest.fixture |
| | | def proactor_loop() -> Iterator[asyncio.AbstractEventLoop]: |
| | | factory = asyncio.ProactorEventLoop # type: ignore[attr-defined] |
| | | |
| | | with loop_context(factory) as _loop: |
| | | asyncio.set_event_loop(_loop) |
| | | yield _loop |
| | | |
| | | |
| | | @pytest.fixture |
| | | def unused_port(aiohttp_unused_port: Callable[[], int]) -> Callable[[], int]: |
| | | warnings.warn( |
| | | "Deprecated, use aiohttp_unused_port fixture instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | return aiohttp_unused_port |
| | | |
| | | |
| | | @pytest.fixture |
| | | def aiohttp_unused_port() -> Callable[[], int]: |
| | | """Return a port that is unused on the current host.""" |
| | | return _unused_port |
| | | |
| | | |
| | | @pytest.fixture |
| | | def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]: |
| | | """Factory to create a TestServer instance, given an app. |
| | | |
| | | aiohttp_server(app, **kwargs) |
| | | """ |
| | | servers = [] |
| | | |
| | | async def go( |
| | | app: Application, |
| | | *, |
| | | host: str = "127.0.0.1", |
| | | port: Optional[int] = None, |
| | | **kwargs: Any, |
| | | ) -> TestServer: |
| | | server = TestServer(app, host=host, port=port) |
| | | await server.start_server(loop=loop, **kwargs) |
| | | servers.append(server) |
| | | return server |
| | | |
| | | yield go |
| | | |
| | | async def finalize() -> None: |
| | | while servers: |
| | | await servers.pop().close() |
| | | |
| | | loop.run_until_complete(finalize()) |
| | | |
| | | |
| | | @pytest.fixture |
| | | def test_server(aiohttp_server): # type: ignore[no-untyped-def] # pragma: no cover |
| | | warnings.warn( |
| | | "Deprecated, use aiohttp_server fixture instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | return aiohttp_server |
| | | |
| | | |
| | | @pytest.fixture |
| | | def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]: |
| | | """Factory to create a RawTestServer instance, given a web handler. |
| | | |
| | | aiohttp_raw_server(handler, **kwargs) |
| | | """ |
| | | servers = [] |
| | | |
| | | async def go( |
| | | handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any |
| | | ) -> RawTestServer: |
| | | server = RawTestServer(handler, port=port) |
| | | await server.start_server(loop=loop, **kwargs) |
| | | servers.append(server) |
| | | return server |
| | | |
| | | yield go |
| | | |
| | | async def finalize() -> None: |
| | | while servers: |
| | | await servers.pop().close() |
| | | |
| | | loop.run_until_complete(finalize()) |
| | | |
| | | |
| | | @pytest.fixture |
| | | def raw_test_server( # type: ignore[no-untyped-def] # pragma: no cover |
| | | aiohttp_raw_server, |
| | | ): |
| | | warnings.warn( |
| | | "Deprecated, use aiohttp_raw_server fixture instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | return aiohttp_raw_server |
| | | |
| | | |
| | | @pytest.fixture |
| | | def aiohttp_client(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpClient]: |
| | | """Factory to create a TestClient instance. |
| | | |
| | | aiohttp_client(app, **kwargs) |
| | | aiohttp_client(server, **kwargs) |
| | | aiohttp_client(raw_server, **kwargs) |
| | | """ |
| | | clients = [] |
| | | |
| | | @overload |
| | | async def go( |
| | | __param: Application, |
| | | *, |
| | | server_kwargs: Optional[Dict[str, Any]] = None, |
| | | **kwargs: Any, |
| | | ) -> TestClient[Request, Application]: ... |
| | | |
| | | @overload |
| | | async def go( |
| | | __param: BaseTestServer, |
| | | *, |
| | | server_kwargs: Optional[Dict[str, Any]] = None, |
| | | **kwargs: Any, |
| | | ) -> TestClient[BaseRequest, None]: ... |
| | | |
| | | async def go( |
| | | __param: Union[Application, BaseTestServer], |
| | | *args: Any, |
| | | server_kwargs: Optional[Dict[str, Any]] = None, |
| | | **kwargs: Any, |
| | | ) -> TestClient[Any, Any]: |
| | | if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type] |
| | | __param, (Application, BaseTestServer) |
| | | ): |
| | | __param = __param(loop, *args, **kwargs) |
| | | kwargs = {} |
| | | else: |
| | | assert not args, "args should be empty" |
| | | |
| | | if isinstance(__param, Application): |
| | | server_kwargs = server_kwargs or {} |
| | | server = TestServer(__param, loop=loop, **server_kwargs) |
| | | client = TestClient(server, loop=loop, **kwargs) |
| | | elif isinstance(__param, BaseTestServer): |
| | | client = TestClient(__param, loop=loop, **kwargs) |
| | | else: |
| | | raise ValueError("Unknown argument type: %r" % type(__param)) |
| | | |
| | | await client.start_server() |
| | | clients.append(client) |
| | | return client |
| | | |
| | | yield go |
| | | |
| | | async def finalize() -> None: |
| | | while clients: |
| | | await clients.pop().close() |
| | | |
| | | loop.run_until_complete(finalize()) |
| | | |
| | | |
| | | @pytest.fixture |
| | | def test_client(aiohttp_client): # type: ignore[no-untyped-def] # pragma: no cover |
| | | warnings.warn( |
| | | "Deprecated, use aiohttp_client fixture instead", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | return aiohttp_client |
| New file |
| | |
| | | import asyncio |
| | | import socket |
| | | import weakref |
| | | from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union |
| | | |
| | | from .abc import AbstractResolver, ResolveResult |
| | | |
| | | __all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver") |
| | | |
| | | |
| | | try: |
| | | import aiodns |
| | | |
| | | aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo") |
| | | except ImportError: # pragma: no cover |
| | | aiodns = None # type: ignore[assignment] |
| | | aiodns_default = False |
| | | |
| | | |
| | | _NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV |
| | | _NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV |
| | | _AI_ADDRCONFIG = socket.AI_ADDRCONFIG |
| | | if hasattr(socket, "AI_MASK"): |
| | | _AI_ADDRCONFIG &= socket.AI_MASK |
| | | |
| | | |
| | | class ThreadedResolver(AbstractResolver): |
| | | """Threaded resolver. |
| | | |
| | | Uses an Executor for synchronous getaddrinfo() calls. |
| | | concurrent.futures.ThreadPoolExecutor is used by default. |
| | | """ |
| | | |
| | | def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: |
| | | self._loop = loop or asyncio.get_running_loop() |
| | | |
| | | async def resolve( |
| | | self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET |
| | | ) -> List[ResolveResult]: |
| | | infos = await self._loop.getaddrinfo( |
| | | host, |
| | | port, |
| | | type=socket.SOCK_STREAM, |
| | | family=family, |
| | | flags=_AI_ADDRCONFIG, |
| | | ) |
| | | |
| | | hosts: List[ResolveResult] = [] |
| | | for family, _, proto, _, address in infos: |
| | | if family == socket.AF_INET6: |
| | | if len(address) < 3: |
| | | # IPv6 is not supported by Python build, |
| | | # or IPv6 is not enabled in the host |
| | | continue |
| | | if address[3]: |
| | | # This is essential for link-local IPv6 addresses. |
| | | # LL IPv6 is a VERY rare case. Strictly speaking, we should use
| | | # getnameinfo() unconditionally, but the fast path matters for performance.
| | | resolved_host, _port = await self._loop.getnameinfo( |
| | | address, _NAME_SOCKET_FLAGS |
| | | ) |
| | | port = int(_port) |
| | | else: |
| | | resolved_host, port = address[:2] |
| | | else: # IPv4 |
| | | assert family == socket.AF_INET |
| | | resolved_host, port = address # type: ignore[misc] |
| | | hosts.append( |
| | | ResolveResult( |
| | | hostname=host, |
| | | host=resolved_host, |
| | | port=port, |
| | | family=family, |
| | | proto=proto, |
| | | flags=_NUMERIC_SOCKET_FLAGS, |
| | | ) |
| | | ) |
| | | |
| | | return hosts |
| | | |
| | | async def close(self) -> None: |
| | | pass |
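The `_NUMERIC_SOCKET_FLAGS` attached to each `ResolveResult` above matter because they make the results safe to feed back into `getaddrinfo`: with `AI_NUMERICHOST | AI_NUMERICSERV`, no DNS lookup happens at all, only parsing of an already-numeric address. A small sketch demonstrating that (works offline, since `127.0.0.1` needs no resolution):

```python
import asyncio
import socket
from typing import List

async def parse_numeric(host: str, port: int) -> List[str]:
    # AI_NUMERICHOST | AI_NUMERICSERV forbids any actual DNS query;
    # getaddrinfo only validates and parses the numeric address.
    loop = asyncio.get_running_loop()
    infos = await loop.getaddrinfo(
        host,
        port,
        type=socket.SOCK_STREAM,
        flags=socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
    )
    # Each entry is (family, type, proto, canonname, sockaddr).
    return [sockaddr[0] for *_, sockaddr in infos]

print(asyncio.run(parse_numeric("127.0.0.1", 80)))
```

Passing a non-numeric name such as `"example.com"` with these flags raises `socket.gaierror` instead of triggering a network lookup.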
| | | |
| | | |
| | | class AsyncResolver(AbstractResolver): |
| | | """Use the `aiodns` package to make asynchronous DNS lookups""" |
| | | |
| | | def __init__( |
| | | self, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | *args: Any, |
| | | **kwargs: Any, |
| | | ) -> None: |
| | | if aiodns is None: |
| | | raise RuntimeError("Resolver requires aiodns library") |
| | | |
| | | self._loop = loop or asyncio.get_running_loop() |
| | | self._manager: Optional[_DNSResolverManager] = None |
| | | # If custom args are provided, create a dedicated resolver instance |
| | | # This means each AsyncResolver with custom args gets its own |
| | | # aiodns.DNSResolver instance |
| | | if args or kwargs: |
| | | self._resolver = aiodns.DNSResolver(*args, **kwargs) |
| | | return |
| | | # Use the shared resolver from the manager for default arguments |
| | | self._manager = _DNSResolverManager() |
| | | self._resolver = self._manager.get_resolver(self, self._loop) |
| | | |
| | | if not hasattr(self._resolver, "gethostbyname"): |
| | | # aiodns 1.1 is not available, fall back to DNSResolver.query
| | | self.resolve = self._resolve_with_query # type: ignore |
| | | |
| | | async def resolve( |
| | | self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET |
| | | ) -> List[ResolveResult]: |
| | | try: |
| | | resp = await self._resolver.getaddrinfo( |
| | | host, |
| | | port=port, |
| | | type=socket.SOCK_STREAM, |
| | | family=family, |
| | | flags=_AI_ADDRCONFIG, |
| | | ) |
| | | except aiodns.error.DNSError as exc: |
| | | msg = exc.args[1] if len(exc.args) >= 2 else "DNS lookup failed"
| | | raise OSError(None, msg) from exc |
| | | hosts: List[ResolveResult] = [] |
| | | for node in resp.nodes: |
| | | address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr |
| | | family = node.family |
| | | if family == socket.AF_INET6: |
| | | if len(address) > 3 and address[3]: |
| | | # This is essential for link-local IPv6 addresses. |
| | | # LL IPv6 is a VERY rare case. Strictly speaking, we should use |
| | | # getnameinfo() unconditionally, but we skip it for performance.
| | | result = await self._resolver.getnameinfo( |
| | | (address[0].decode("ascii"), *address[1:]), |
| | | _NAME_SOCKET_FLAGS, |
| | | ) |
| | | resolved_host = result.node |
| | | else: |
| | | resolved_host = address[0].decode("ascii") |
| | | port = address[1] |
| | | else: # IPv4 |
| | | assert family == socket.AF_INET |
| | | resolved_host = address[0].decode("ascii") |
| | | port = address[1] |
| | | hosts.append( |
| | | ResolveResult( |
| | | hostname=host, |
| | | host=resolved_host, |
| | | port=port, |
| | | family=family, |
| | | proto=0, |
| | | flags=_NUMERIC_SOCKET_FLAGS, |
| | | ) |
| | | ) |
| | | |
| | | if not hosts: |
| | | raise OSError(None, "DNS lookup failed") |
| | | |
| | | return hosts |
| | | |
| | | async def _resolve_with_query( |
| | | self, host: str, port: int = 0, family: int = socket.AF_INET |
| | | ) -> List[Dict[str, Any]]: |
| | | qtype: Final = "AAAA" if family == socket.AF_INET6 else "A" |
| | | |
| | | try: |
| | | resp = await self._resolver.query(host, qtype) |
| | | except aiodns.error.DNSError as exc: |
| | | msg = exc.args[1] if len(exc.args) >= 2 else "DNS lookup failed"
| | | raise OSError(None, msg) from exc |
| | | |
| | | hosts = [] |
| | | for rr in resp: |
| | | hosts.append( |
| | | { |
| | | "hostname": host, |
| | | "host": rr.host, |
| | | "port": port, |
| | | "family": family, |
| | | "proto": 0, |
| | | "flags": socket.AI_NUMERICHOST, |
| | | } |
| | | ) |
| | | |
| | | if not hosts: |
| | | raise OSError(None, "DNS lookup failed") |
| | | |
| | | return hosts |
| | | |
| | | async def close(self) -> None: |
| | | if self._manager: |
| | | # Release the resolver from the manager if using the shared resolver |
| | | self._manager.release_resolver(self, self._loop) |
| | | self._manager = None # Clear reference to manager |
| | | self._resolver = None # type: ignore[assignment] # Clear reference to resolver |
| | | return |
| | | # Otherwise cancel our dedicated resolver |
| | | if self._resolver is not None: |
| | | self._resolver.cancel() |
| | | self._resolver = None # type: ignore[assignment] # Clear reference |
| | | |
| | | |
| | | class _DNSResolverManager: |
| | | """Manager for aiodns.DNSResolver objects. |
| | | |
| | | This class manages shared aiodns.DNSResolver instances |
| | | with no custom arguments across different event loops. |
| | | """ |
| | | |
| | | _instance: Optional["_DNSResolverManager"] = None |
| | | |
| | | def __new__(cls) -> "_DNSResolverManager": |
| | | if cls._instance is None: |
| | | cls._instance = super().__new__(cls) |
| | | cls._instance._init() |
| | | return cls._instance |
| | | |
| | | def _init(self) -> None: |
| | | # Use WeakKeyDictionary to allow event loops to be garbage collected |
| | | self._loop_data: weakref.WeakKeyDictionary[ |
| | | asyncio.AbstractEventLoop, |
| | | tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]], |
| | | ] = weakref.WeakKeyDictionary() |
| | | |
| | | def get_resolver( |
| | | self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop |
| | | ) -> "aiodns.DNSResolver": |
| | | """Get or create the shared aiodns.DNSResolver instance for a specific event loop. |
| | | |
| | | Args: |
| | | client: The AsyncResolver instance requesting the resolver. |
| | | This is required to track resolver usage. |
| | | loop: The event loop to use for the resolver. |
| | | """ |
| | | # Create a new resolver and client set for this loop if it doesn't exist |
| | | if loop not in self._loop_data: |
| | | resolver = aiodns.DNSResolver(loop=loop) |
| | | client_set: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet() |
| | | self._loop_data[loop] = (resolver, client_set) |
| | | else: |
| | | # Get the existing resolver and client set |
| | | resolver, client_set = self._loop_data[loop] |
| | | |
| | | # Register this client with the loop |
| | | client_set.add(client) |
| | | return resolver |
| | | |
| | | def release_resolver( |
| | | self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop |
| | | ) -> None: |
| | | """Release the resolver for an AsyncResolver client when it's closed. |
| | | |
| | | Args: |
| | | client: The AsyncResolver instance to release. |
| | | loop: The event loop the resolver was using. |
| | | """ |
| | | # Remove client from its loop's tracking |
| | | current_loop_data = self._loop_data.get(loop) |
| | | if current_loop_data is None: |
| | | return |
| | | resolver, client_set = current_loop_data |
| | | client_set.discard(client) |
| | | # If no more clients for this loop, cancel and remove its resolver |
| | | if not client_set: |
| | | if resolver is not None: |
| | | resolver.cancel() |
| | | del self._loop_data[loop] |
| | | |
| | | |
| | | _DefaultType = Type[Union[AsyncResolver, ThreadedResolver]] |
| | | DefaultResolver: _DefaultType = AsyncResolver if aiodns_default else ThreadedResolver |
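The `_DNSResolverManager` above shares one `aiodns.DNSResolver` per event loop and reference-counts clients with a `WeakSet`, cancelling the resolver when the last client releases it. A minimal stdlib sketch of that pattern (assumptions: `DummyResolver`, `SharedResolverManager`, and `Client` are stand-ins invented here, not aiohttp or aiodns classes):

```python
import weakref


class DummyResolver:
    """Stand-in for aiodns.DNSResolver; only tracks cancellation."""

    def __init__(self) -> None:
        self.cancelled = False

    def cancel(self) -> None:
        self.cancelled = True


class SharedResolverManager:
    """Singleton mapping each loop to (resolver, weak set of clients)."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._data = {}
        return cls._instance

    def get_resolver(self, client, loop):
        # Create the shared resolver lazily, once per loop.
        if loop not in self._data:
            self._data[loop] = (DummyResolver(), weakref.WeakSet())
        resolver, clients = self._data[loop]
        clients.add(client)
        return resolver

    def release_resolver(self, client, loop):
        if loop not in self._data:
            return
        resolver, clients = self._data[loop]
        clients.discard(client)
        # Last client gone: cancel and drop the shared resolver.
        if not clients:
            resolver.cancel()
            del self._data[loop]


class Client:  # stand-in for AsyncResolver
    pass


mgr = SharedResolverManager()
loop = object()  # stand-in for an event loop
c1, c2 = Client(), Client()
r1 = mgr.get_resolver(c1, loop)
r2 = mgr.get_resolver(c2, loop)
assert r1 is r2          # same loop -> same shared resolver
mgr.release_resolver(c1, loop)
assert not r1.cancelled  # c2 still holds it
mgr.release_resolver(c2, loop)
assert r1.cancelled      # last release cancels the resolver
```

The `WeakSet` means a client that is garbage-collected without calling `close()` still stops pinning the shared resolver.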
| New file |
| | |
| | | import asyncio |
| | | import collections |
| | | import warnings |
| | | from typing import ( |
| | | Awaitable, |
| | | Callable, |
| | | Deque, |
| | | Final, |
| | | Generic, |
| | | List, |
| | | Optional, |
| | | Tuple, |
| | | TypeVar, |
| | | ) |
| | | |
| | | from .base_protocol import BaseProtocol |
| | | from .helpers import ( |
| | | _EXC_SENTINEL, |
| | | BaseTimerContext, |
| | | TimerNoop, |
| | | set_exception, |
| | | set_result, |
| | | ) |
| | | from .log import internal_logger |
| | | |
| | | __all__ = ( |
| | | "EMPTY_PAYLOAD", |
| | | "EofStream", |
| | | "StreamReader", |
| | | "DataQueue", |
| | | ) |
| | | |
| | | _T = TypeVar("_T") |
| | | |
| | | |
| | | class EofStream(Exception): |
| | | """eof stream indication.""" |
| | | |
| | | |
| | | class AsyncStreamIterator(Generic[_T]): |
| | | |
| | | __slots__ = ("read_func",) |
| | | |
| | | def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None: |
| | | self.read_func = read_func |
| | | |
| | | def __aiter__(self) -> "AsyncStreamIterator[_T]": |
| | | return self |
| | | |
| | | async def __anext__(self) -> _T: |
| | | try: |
| | | rv = await self.read_func() |
| | | except EofStream: |
| | | raise StopAsyncIteration |
| | | if rv == b"": |
| | | raise StopAsyncIteration |
| | | return rv |
| | | |
| | | |
| | | class ChunkTupleAsyncStreamIterator: |
| | | |
| | | __slots__ = ("_stream",) |
| | | |
| | | def __init__(self, stream: "StreamReader") -> None: |
| | | self._stream = stream |
| | | |
| | | def __aiter__(self) -> "ChunkTupleAsyncStreamIterator": |
| | | return self |
| | | |
| | | async def __anext__(self) -> Tuple[bytes, bool]: |
| | | rv = await self._stream.readchunk() |
| | | if rv == (b"", False): |
| | | raise StopAsyncIteration |
| | | return rv |
| | | |
| | | |
| | | class AsyncStreamReaderMixin: |
| | | |
| | | __slots__ = () |
| | | |
| | | def __aiter__(self) -> AsyncStreamIterator[bytes]: |
| | | return AsyncStreamIterator(self.readline) # type: ignore[attr-defined] |
| | | |
| | | def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]: |
| | | """Returns an asynchronous iterator that yields chunks of size n.""" |
| | | return AsyncStreamIterator(lambda: self.read(n)) # type: ignore[attr-defined] |
| | | |
| | | def iter_any(self) -> AsyncStreamIterator[bytes]: |
| | | """Yield all available data as soon as it is received.""" |
| | | return AsyncStreamIterator(self.readany) # type: ignore[attr-defined] |
| | | |
| | | def iter_chunks(self) -> ChunkTupleAsyncStreamIterator: |
| | | """Yield chunks of data as they are received by the server. |
| | | |
| | | The yielded objects are tuples |
| | | of (bytes, bool) as returned by the StreamReader.readchunk method. |
| | | """ |
| | | return ChunkTupleAsyncStreamIterator(self) # type: ignore[arg-type] |
| | | |
| | | |
| | | class StreamReader(AsyncStreamReaderMixin): |
| | | """An enhancement of asyncio.StreamReader. |
| | | |
| | | Supports asynchronous iteration by line, chunk or as available:: |
| | | |
| | | async for line in reader: |
| | | ... |
| | | async for chunk in reader.iter_chunked(1024): |
| | | ... |
| | | async for slice in reader.iter_any(): |
| | | ... |
| | | |
| | | """ |
| | | |
| | | __slots__ = ( |
| | | "_protocol", |
| | | "_low_water", |
| | | "_high_water", |
| | | "_low_water_chunks", |
| | | "_high_water_chunks", |
| | | "_loop", |
| | | "_size", |
| | | "_cursor", |
| | | "_http_chunk_splits", |
| | | "_buffer", |
| | | "_buffer_offset", |
| | | "_eof", |
| | | "_waiter", |
| | | "_eof_waiter", |
| | | "_exception", |
| | | "_timer", |
| | | "_eof_callbacks", |
| | | "_eof_counter", |
| | | "total_bytes", |
| | | "total_compressed_bytes", |
| | | ) |
| | | |
| | | def __init__( |
| | | self, |
| | | protocol: BaseProtocol, |
| | | limit: int, |
| | | *, |
| | | timer: Optional[BaseTimerContext] = None, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | ) -> None: |
| | | self._protocol = protocol |
| | | self._low_water = limit |
| | | self._high_water = limit * 2 |
| | | if loop is None: |
| | | loop = asyncio.get_event_loop() |
| | | # Ensure high_water_chunks >= 3 so it's always > low_water_chunks. |
| | | self._high_water_chunks = max(3, limit // 4) |
| | | # Use max(2, ...) because there's always at least 1 chunk split remaining |
| | | # (the current position), so we need low_water >= 2 to allow resume. |
| | | self._low_water_chunks = max(2, self._high_water_chunks // 2) |
| | | self._loop = loop |
| | | self._size = 0 |
| | | self._cursor = 0 |
| | | self._http_chunk_splits: Optional[Deque[int]] = None |
| | | self._buffer: Deque[bytes] = collections.deque() |
| | | self._buffer_offset = 0 |
| | | self._eof = False |
| | | self._waiter: Optional[asyncio.Future[None]] = None |
| | | self._eof_waiter: Optional[asyncio.Future[None]] = None |
| | | self._exception: Optional[BaseException] = None |
| | | self._timer = TimerNoop() if timer is None else timer |
| | | self._eof_callbacks: List[Callable[[], None]] = [] |
| | | self._eof_counter = 0 |
| | | self.total_bytes = 0 |
| | | self.total_compressed_bytes: Optional[int] = None |
| | | |
| | | def __repr__(self) -> str: |
| | | info = [self.__class__.__name__] |
| | | if self._size: |
| | | info.append("%d bytes" % self._size) |
| | | if self._eof: |
| | | info.append("eof") |
| | | if self._low_water != 2**16: # default limit |
| | | info.append("low=%d high=%d" % (self._low_water, self._high_water)) |
| | | if self._waiter: |
| | | info.append("w=%r" % self._waiter) |
| | | if self._exception: |
| | | info.append("e=%r" % self._exception) |
| | | return "<%s>" % " ".join(info) |
| | | |
| | | def get_read_buffer_limits(self) -> Tuple[int, int]: |
| | | return (self._low_water, self._high_water) |
| | | |
| | | def exception(self) -> Optional[BaseException]: |
| | | return self._exception |
| | | |
| | | def set_exception( |
| | | self, |
| | | exc: BaseException, |
| | | exc_cause: BaseException = _EXC_SENTINEL, |
| | | ) -> None: |
| | | self._exception = exc |
| | | self._eof_callbacks.clear() |
| | | |
| | | waiter = self._waiter |
| | | if waiter is not None: |
| | | self._waiter = None |
| | | set_exception(waiter, exc, exc_cause) |
| | | |
| | | waiter = self._eof_waiter |
| | | if waiter is not None: |
| | | self._eof_waiter = None |
| | | set_exception(waiter, exc, exc_cause) |
| | | |
| | | def on_eof(self, callback: Callable[[], None]) -> None: |
| | | if self._eof: |
| | | try: |
| | | callback() |
| | | except Exception: |
| | | internal_logger.exception("Exception in eof callback") |
| | | else: |
| | | self._eof_callbacks.append(callback) |
| | | |
| | | def feed_eof(self) -> None: |
| | | self._eof = True |
| | | |
| | | waiter = self._waiter |
| | | if waiter is not None: |
| | | self._waiter = None |
| | | set_result(waiter, None) |
| | | |
| | | waiter = self._eof_waiter |
| | | if waiter is not None: |
| | | self._eof_waiter = None |
| | | set_result(waiter, None) |
| | | |
| | | if self._protocol._reading_paused: |
| | | self._protocol.resume_reading() |
| | | |
| | | for cb in self._eof_callbacks: |
| | | try: |
| | | cb() |
| | | except Exception: |
| | | internal_logger.exception("Exception in eof callback") |
| | | |
| | | self._eof_callbacks.clear() |
| | | |
| | | def is_eof(self) -> bool: |
| | | """Return True if 'feed_eof' was called.""" |
| | | return self._eof |
| | | |
| | | def at_eof(self) -> bool: |
| | | """Return True if the buffer is empty and 'feed_eof' was called.""" |
| | | return self._eof and not self._buffer |
| | | |
| | | async def wait_eof(self) -> None: |
| | | if self._eof: |
| | | return |
| | | |
| | | assert self._eof_waiter is None |
| | | self._eof_waiter = self._loop.create_future() |
| | | try: |
| | | await self._eof_waiter |
| | | finally: |
| | | self._eof_waiter = None |
| | | |
| | | @property |
| | | def total_raw_bytes(self) -> int: |
| | | if self.total_compressed_bytes is None: |
| | | return self.total_bytes |
| | | return self.total_compressed_bytes |
| | | |
| | | def unread_data(self, data: bytes) -> None: |
| | | """rollback reading some data from stream, inserting it to buffer head.""" |
| | | warnings.warn( |
| | | "unread_data() is deprecated " |
| | | "and will be removed in future releases (#3260)", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | if not data: |
| | | return |
| | | |
| | | if self._buffer_offset: |
| | | self._buffer[0] = self._buffer[0][self._buffer_offset :] |
| | | self._buffer_offset = 0 |
| | | self._size += len(data) |
| | | self._cursor -= len(data) |
| | | self._buffer.appendleft(data) |
| | | self._eof_counter = 0 |
| | | |
| | | # TODO: size is ignored, remove the param later |
| | | def feed_data(self, data: bytes, size: int = 0) -> None: |
| | | assert not self._eof, "feed_data after feed_eof" |
| | | |
| | | if not data: |
| | | return |
| | | |
| | | data_len = len(data) |
| | | self._size += data_len |
| | | self._buffer.append(data) |
| | | self.total_bytes += data_len |
| | | |
| | | waiter = self._waiter |
| | | if waiter is not None: |
| | | self._waiter = None |
| | | set_result(waiter, None) |
| | | |
| | | if self._size > self._high_water and not self._protocol._reading_paused: |
| | | self._protocol.pause_reading() |
| | | |
| | | def begin_http_chunk_receiving(self) -> None: |
| | | if self._http_chunk_splits is None: |
| | | if self.total_bytes: |
| | | raise RuntimeError( |
| | | "Called begin_http_chunk_receiving when some data was already fed" |
| | | ) |
| | | self._http_chunk_splits = collections.deque() |
| | | |
| | | def end_http_chunk_receiving(self) -> None: |
| | | if self._http_chunk_splits is None: |
| | | raise RuntimeError( |
| | | "Called end_chunk_receiving without calling " |
| | | "begin_chunk_receiving first" |
| | | ) |
| | | |
| | | # self._http_chunk_splits contains logical byte offsets from start of |
| | | # the body transfer. Each offset is the offset of the end of a chunk. |
| | | # "Logical" means bytes, accessible for a user. |
| | | # If no chunks containing logical data were received, current position |
| | | # is definitely zero.
| | | pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0 |
| | | |
| | | if self.total_bytes == pos: |
| | | # We should not add empty chunks here. So we check for that. |
| | | # Note, when chunked + gzip is used, we can receive a chunk |
| | | # of compressed data, but that data may not be enough for gzip FSM |
| | | # to yield any uncompressed data. That's why current position may |
| | | # not change after receiving a chunk. |
| | | return |
| | | |
| | | self._http_chunk_splits.append(self.total_bytes) |
| | | |
| | | # If we get too many small chunks before self._high_water is reached, then any |
| | | # .read() call becomes computationally expensive, and could block the event loop |
| | | # for too long, hence an additional self._high_water_chunks here. |
| | | if ( |
| | | len(self._http_chunk_splits) > self._high_water_chunks |
| | | and not self._protocol._reading_paused |
| | | ): |
| | | self._protocol.pause_reading() |
| | | |
| | | # wake up readchunk when end of http chunk received |
| | | waiter = self._waiter |
| | | if waiter is not None: |
| | | self._waiter = None |
| | | set_result(waiter, None) |
| | | |
| | | async def _wait(self, func_name: str) -> None: |
| | | if not self._protocol.connected: |
| | | raise RuntimeError("Connection closed.") |
| | | |
| | | # StreamReader uses a future to link the protocol feed_data() method |
| | | # to a read coroutine. Running two read coroutines at the same time |
| | | # would have an unexpected behaviour: it would not be possible to know
| | | # which coroutine would get the next data. |
| | | if self._waiter is not None: |
| | | raise RuntimeError( |
| | | "%s() called while another coroutine is " |
| | | "already waiting for incoming data" % func_name |
| | | ) |
| | | |
| | | waiter = self._waiter = self._loop.create_future() |
| | | try: |
| | | with self._timer: |
| | | await waiter |
| | | finally: |
| | | self._waiter = None |
| | | |
| | | async def readline(self) -> bytes: |
| | | return await self.readuntil() |
| | | |
| | | async def readuntil(self, separator: bytes = b"\n") -> bytes: |
| | | seplen = len(separator) |
| | | if seplen == 0: |
| | | raise ValueError("Separator should be at least one-byte string") |
| | | |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | |
| | | chunk = b"" |
| | | chunk_size = 0 |
| | | not_enough = True |
| | | |
| | | while not_enough: |
| | | while self._buffer and not_enough: |
| | | offset = self._buffer_offset |
| | | ichar = self._buffer[0].find(separator, offset) + 1 |
| | | # Read from current offset to found separator or to the end. |
| | | data = self._read_nowait_chunk( |
| | | ichar - offset + seplen - 1 if ichar else -1 |
| | | ) |
| | | chunk += data |
| | | chunk_size += len(data) |
| | | if ichar: |
| | | not_enough = False |
| | | |
| | | if chunk_size > self._high_water: |
| | | raise ValueError("Chunk too big") |
| | | |
| | | if self._eof: |
| | | break |
| | | |
| | | if not_enough: |
| | | await self._wait("readuntil") |
| | | |
| | | return chunk |
| | | |
| | | async def read(self, n: int = -1) -> bytes: |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | |
| | | # Migration artifact: with DataQueue you have to catch the EofStream
| | | # exception, so a common pattern is to run payload.read() inside an
| | | # infinite loop. With StreamReader that pattern can become a real
| | | # infinite loop; keep this check for one major release.
| | | if __debug__: |
| | | if self._eof and not self._buffer: |
| | | self._eof_counter = getattr(self, "_eof_counter", 0) + 1 |
| | | if self._eof_counter > 5: |
| | | internal_logger.warning( |
| | | "Multiple access to StreamReader in eof state, " |
| | | "might be infinite loop.", |
| | | stack_info=True, |
| | | ) |
| | | |
| | | if not n: |
| | | return b"" |
| | | |
| | | if n < 0: |
| | | # This used to just loop creating a new waiter hoping to |
| | | # collect everything in self._buffer, but that would |
| | | # deadlock if the subprocess sends more than self.limit |
| | | # bytes. So just call self.readany() until EOF. |
| | | blocks = [] |
| | | while True: |
| | | block = await self.readany() |
| | | if not block: |
| | | break |
| | | blocks.append(block) |
| | | return b"".join(blocks) |
| | | |
| | | # TODO: should be `if` instead of `while` |
| | | # because waiter maybe triggered on chunk end, |
| | | # without feeding any data |
| | | while not self._buffer and not self._eof: |
| | | await self._wait("read") |
| | | |
| | | return self._read_nowait(n) |
| | | |
| | | async def readany(self) -> bytes: |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | |
| | | # TODO: should be `if` instead of `while` |
| | | # because waiter maybe triggered on chunk end, |
| | | # without feeding any data |
| | | while not self._buffer and not self._eof: |
| | | await self._wait("readany") |
| | | |
| | | return self._read_nowait(-1) |
| | | |
| | | async def readchunk(self) -> Tuple[bytes, bool]: |
| | | """Returns a tuple of (data, end_of_http_chunk). |
| | | |
| | | When chunked transfer |
| | | encoding is used, end_of_http_chunk is a boolean indicating if the end |
| | | of the data corresponds to the end of an HTTP chunk; otherwise it is
| | | always False. |
| | | """ |
| | | while True: |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | |
| | | while self._http_chunk_splits: |
| | | pos = self._http_chunk_splits.popleft() |
| | | if pos == self._cursor: |
| | | return (b"", True) |
| | | if pos > self._cursor: |
| | | return (self._read_nowait(pos - self._cursor), True) |
| | | internal_logger.warning( |
| | | "Skipping HTTP chunk end due to data " |
| | | "consumption beyond chunk boundary" |
| | | ) |
| | | |
| | | if self._buffer: |
| | | return (self._read_nowait_chunk(-1), False) |
| | | # return (self._read_nowait(-1), False) |
| | | |
| | | if self._eof: |
| | | # Special case for signifying EOF. |
| | | # (b'', True) is not a final return value actually. |
| | | return (b"", False) |
| | | |
| | | await self._wait("readchunk") |
| | | |
| | | async def readexactly(self, n: int) -> bytes: |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | |
| | | blocks: List[bytes] = [] |
| | | while n > 0: |
| | | block = await self.read(n) |
| | | if not block: |
| | | partial = b"".join(blocks) |
| | | raise asyncio.IncompleteReadError(partial, len(partial) + n) |
| | | blocks.append(block) |
| | | n -= len(block) |
| | | |
| | | return b"".join(blocks) |
| | | |
| | | def read_nowait(self, n: int = -1) -> bytes: |
| | | # default was changed to be consistent with .read(-1) |
| | | # |
| | | # Most users likely don't know about this method, so
| | | # they are not affected.
| | | if self._exception is not None: |
| | | raise self._exception |
| | | |
| | | if self._waiter and not self._waiter.done(): |
| | | raise RuntimeError( |
| | | "Called while some coroutine is waiting for incoming data." |
| | | ) |
| | | |
| | | return self._read_nowait(n) |
| | | |
| | | def _read_nowait_chunk(self, n: int) -> bytes: |
| | | first_buffer = self._buffer[0] |
| | | offset = self._buffer_offset |
| | | if n != -1 and len(first_buffer) - offset > n: |
| | | data = first_buffer[offset : offset + n] |
| | | self._buffer_offset += n |
| | | |
| | | elif offset: |
| | | self._buffer.popleft() |
| | | data = first_buffer[offset:] |
| | | self._buffer_offset = 0 |
| | | |
| | | else: |
| | | data = self._buffer.popleft() |
| | | |
| | | data_len = len(data) |
| | | self._size -= data_len |
| | | self._cursor += data_len |
| | | |
| | | chunk_splits = self._http_chunk_splits |
| | | # Prevent memory leak: drop useless chunk splits |
| | | while chunk_splits and chunk_splits[0] < self._cursor: |
| | | chunk_splits.popleft() |
| | | |
| | | if ( |
| | | self._protocol._reading_paused |
| | | and self._size < self._low_water |
| | | and ( |
| | | self._http_chunk_splits is None |
| | | or len(self._http_chunk_splits) < self._low_water_chunks |
| | | ) |
| | | ): |
| | | self._protocol.resume_reading() |
| | | return data |
| | | |
| | | def _read_nowait(self, n: int) -> bytes: |
| | | """Read not more than n bytes, or whole buffer if n == -1""" |
| | | self._timer.assert_timeout() |
| | | |
| | | chunks = [] |
| | | while self._buffer: |
| | | chunk = self._read_nowait_chunk(n) |
| | | chunks.append(chunk) |
| | | if n != -1: |
| | | n -= len(chunk) |
| | | if n == 0: |
| | | break |
| | | |
| | | return b"".join(chunks) if chunks else b"" |
| | | |
| | | |
| | | class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init] |
| | | |
| | | __slots__ = ("_read_eof_chunk",) |
| | | |
| | | def __init__(self) -> None: |
| | | self._read_eof_chunk = False |
| | | self.total_bytes = 0 |
| | | |
| | | def __repr__(self) -> str: |
| | | return "<%s>" % self.__class__.__name__ |
| | | |
| | | def exception(self) -> Optional[BaseException]: |
| | | return None |
| | | |
| | | def set_exception( |
| | | self, |
| | | exc: BaseException, |
| | | exc_cause: BaseException = _EXC_SENTINEL, |
| | | ) -> None: |
| | | pass |
| | | |
| | | def on_eof(self, callback: Callable[[], None]) -> None: |
| | | try: |
| | | callback() |
| | | except Exception: |
| | | internal_logger.exception("Exception in eof callback") |
| | | |
| | | def feed_eof(self) -> None: |
| | | pass |
| | | |
| | | def is_eof(self) -> bool: |
| | | return True |
| | | |
| | | def at_eof(self) -> bool: |
| | | return True |
| | | |
| | | async def wait_eof(self) -> None: |
| | | return |
| | | |
| | | def feed_data(self, data: bytes, n: int = 0) -> None: |
| | | pass |
| | | |
| | | async def readline(self) -> bytes: |
| | | return b"" |
| | | |
| | | async def read(self, n: int = -1) -> bytes: |
| | | return b"" |
| | | |
| | | # TODO add async def readuntil |
| | | |
| | | async def readany(self) -> bytes: |
| | | return b"" |
| | | |
| | | async def readchunk(self) -> Tuple[bytes, bool]: |
| | | if not self._read_eof_chunk: |
| | | self._read_eof_chunk = True |
| | | return (b"", False) |
| | | |
| | | return (b"", True) |
| | | |
| | | async def readexactly(self, n: int) -> bytes: |
| | | raise asyncio.IncompleteReadError(b"", n) |
| | | |
| | | def read_nowait(self, n: int = -1) -> bytes: |
| | | return b"" |
| | | |
| | | |
| | | EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader() |
| | | |
| | | |
| | | class DataQueue(Generic[_T]): |
| | | """DataQueue is a general-purpose blocking queue with one reader.""" |
| | | |
| | | def __init__(self, loop: asyncio.AbstractEventLoop) -> None: |
| | | self._loop = loop |
| | | self._eof = False |
| | | self._waiter: Optional[asyncio.Future[None]] = None |
| | | self._exception: Optional[BaseException] = None |
| | | self._buffer: Deque[Tuple[_T, int]] = collections.deque() |
| | | |
| | | def __len__(self) -> int: |
| | | return len(self._buffer) |
| | | |
| | | def is_eof(self) -> bool: |
| | | return self._eof |
| | | |
| | | def at_eof(self) -> bool: |
| | | return self._eof and not self._buffer |
| | | |
| | | def exception(self) -> Optional[BaseException]: |
| | | return self._exception |
| | | |
| | | def set_exception( |
| | | self, |
| | | exc: BaseException, |
| | | exc_cause: BaseException = _EXC_SENTINEL, |
| | | ) -> None: |
| | | self._eof = True |
| | | self._exception = exc |
| | | if (waiter := self._waiter) is not None: |
| | | self._waiter = None |
| | | set_exception(waiter, exc, exc_cause) |
| | | |
| | | def feed_data(self, data: _T, size: int = 0) -> None: |
| | | self._buffer.append((data, size)) |
| | | if (waiter := self._waiter) is not None: |
| | | self._waiter = None |
| | | set_result(waiter, None) |
| | | |
| | | def feed_eof(self) -> None: |
| | | self._eof = True |
| | | if (waiter := self._waiter) is not None: |
| | | self._waiter = None |
| | | set_result(waiter, None) |
| | | |
| | | async def read(self) -> _T: |
| | | if not self._buffer and not self._eof: |
| | | assert not self._waiter |
| | | self._waiter = self._loop.create_future() |
| | | try: |
| | | await self._waiter |
| | | except (asyncio.CancelledError, asyncio.TimeoutError): |
| | | self._waiter = None |
| | | raise |
| | | if self._buffer: |
| | | data, _ = self._buffer.popleft() |
| | | return data |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | raise EofStream |
| | | |
| | | def __aiter__(self) -> AsyncStreamIterator[_T]: |
| | | return AsyncStreamIterator(self.read) |
| | | |
| | | |
| | | class FlowControlDataQueue(DataQueue[_T]): |
| | | """FlowControlDataQueue resumes and pauses an underlying stream. |
| | | |
| | | It is a destination for parsed data. |
| | | |
| | | This class is deprecated and will be removed in version 4.0. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop |
| | | ) -> None: |
| | | super().__init__(loop=loop) |
| | | self._size = 0 |
| | | self._protocol = protocol |
| | | self._limit = limit * 2 |
| | | |
| | | def feed_data(self, data: _T, size: int = 0) -> None: |
| | | super().feed_data(data, size) |
| | | self._size += size |
| | | |
| | | if self._size > self._limit and not self._protocol._reading_paused: |
| | | self._protocol.pause_reading() |
| | | |
| | | async def read(self) -> _T: |
| | | if not self._buffer and not self._eof: |
| | | assert not self._waiter |
| | | self._waiter = self._loop.create_future() |
| | | try: |
| | | await self._waiter |
| | | except (asyncio.CancelledError, asyncio.TimeoutError): |
| | | self._waiter = None |
| | | raise |
| | | if self._buffer: |
| | | data, size = self._buffer.popleft() |
| | | self._size -= size |
| | | if self._size < self._limit and self._protocol._reading_paused: |
| | | self._protocol.resume_reading() |
| | | return data |
| | | if self._exception is not None: |
| | | raise self._exception |
| | | raise EofStream |
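Both `StreamReader` and `DataQueue` hand data from the protocol to a reader through a single parked `asyncio.Future`: `read()` creates the waiter, and `feed_data()`/`feed_eof()` resolve it. A self-contained sketch of that waiter handoff (assumption: `SimpleQueue` is a simplified stand-in invented here, not aiohttp's `DataQueue`):

```python
import asyncio
import collections


class EofStream(Exception):
    """Raised when the queue is drained and EOF was fed."""


class SimpleQueue:
    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._eof = False
        self._waiter = None
        self._buffer = collections.deque()

    def feed_data(self, data: bytes) -> None:
        self._buffer.append(data)
        # Wake a parked reader, if any.
        if self._waiter is not None:
            waiter, self._waiter = self._waiter, None
            waiter.set_result(None)

    def feed_eof(self) -> None:
        self._eof = True
        if self._waiter is not None:
            waiter, self._waiter = self._waiter, None
            waiter.set_result(None)

    async def read(self) -> bytes:
        # Park on a future until the feeder produces data or EOF.
        if not self._buffer and not self._eof:
            self._waiter = self._loop.create_future()
            await self._waiter
        if self._buffer:
            return self._buffer.popleft()
        raise EofStream


async def main() -> list:
    loop = asyncio.get_running_loop()
    q = SimpleQueue(loop)
    # Schedule the "protocol" side to feed data, then EOF.
    loop.call_soon(q.feed_data, b"hello")
    loop.call_soon(q.feed_eof)
    out = []
    try:
        while True:
            out.append(await q.read())
    except EofStream:
        pass
    return out


result = asyncio.run(main())
assert result == [b"hello"]
```

This is also why `_wait()` above rejects a second concurrent reader: with only one waiter future, two read coroutines could not know which would receive the next chunk.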
| New file |
| | |
| | | """Helper methods to tune a TCP connection""" |
| | | |
| | | import asyncio |
| | | import socket |
| | | from contextlib import suppress |
| | | from typing import Optional # noqa |
| | | |
| | | __all__ = ("tcp_keepalive", "tcp_nodelay") |
| | | |
| | | |
| | | if hasattr(socket, "SO_KEEPALIVE"): |
| | | |
| | | def tcp_keepalive(transport: asyncio.Transport) -> None: |
| | | sock = transport.get_extra_info("socket") |
| | | if sock is not None: |
| | | sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) |
| | | |
| | | else: |
| | | |
| | | def tcp_keepalive(transport: asyncio.Transport) -> None: # pragma: no cover |
| | | pass |
| | | |
| | | |
| | | def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None: |
| | | sock = transport.get_extra_info("socket") |
| | | |
| | | if sock is None: |
| | | return |
| | | |
| | | if sock.family not in (socket.AF_INET, socket.AF_INET6): |
| | | return |
| | | |
| | | value = bool(value) |
| | | |
| | | # The socket may be closed already; on Windows an OSError is raised
| | | with suppress(OSError): |
| | | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value) |
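The helpers above just call `setsockopt` on the socket extracted from an asyncio transport. The same two options can be set on a plain stdlib socket; a minimal sketch (no transport involved; exact option availability is platform-dependent, hence the `hasattr` check and `suppress`):

```python
import socket
from contextlib import suppress

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
with suppress(OSError):
    # Disable Nagle's algorithm, as tcp_nodelay() does.
    s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    # Enable keepalive probes, as tcp_keepalive() does (where supported).
    if hasattr(socket, "SO_KEEPALIVE"):
        s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

nodelay = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
s.close()
assert nodelay != 0  # option stuck on this platform
```

The `suppress(OSError)` mirrors the helper's handling of an already-closed socket.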
| New file |
| | |
| | | """Utilities shared by tests.""" |
| | | |
| | | import asyncio |
| | | import contextlib |
| | | import gc |
| | | import inspect |
| | | import ipaddress |
| | | import os |
| | | import socket |
| | | import sys |
| | | import warnings |
| | | from abc import ABC, abstractmethod |
| | | from types import TracebackType |
| | | from typing import ( |
| | | TYPE_CHECKING, |
| | | Any, |
| | | Callable, |
| | | Generic, |
| | | Iterator, |
| | | List, |
| | | Optional, |
| | | Type, |
| | | TypeVar, |
| | | cast, |
| | | overload, |
| | | ) |
| | | from unittest import IsolatedAsyncioTestCase, mock |
| | | |
| | | from aiosignal import Signal |
| | | from multidict import CIMultiDict, CIMultiDictProxy |
| | | from yarl import URL |
| | | |
| | | import aiohttp |
| | | from aiohttp.client import ( |
| | | _RequestContextManager, |
| | | _RequestOptions, |
| | | _WSRequestContextManager, |
| | | ) |
| | | |
| | | from . import ClientSession, hdrs |
| | | from .abc import AbstractCookieJar |
| | | from .client_reqrep import ClientResponse |
| | | from .client_ws import ClientWebSocketResponse |
| | | from .helpers import sentinel |
| | | from .http import HttpVersion, RawRequestMessage |
| | | from .streams import EMPTY_PAYLOAD, StreamReader |
| | | from .typedefs import StrOrURL |
| | | from .web import ( |
| | | Application, |
| | | AppRunner, |
| | | BaseRequest, |
| | | BaseRunner, |
| | | Request, |
| | | Server, |
| | | ServerRunner, |
| | | SockSite, |
| | | UrlMappingMatchInfo, |
| | | ) |
| | | from .web_protocol import _RequestHandler |
| | | |
| | | if TYPE_CHECKING: |
| | | from ssl import SSLContext |
| | | else: |
| | | SSLContext = None |
| | | |
| | | if sys.version_info >= (3, 11) and TYPE_CHECKING: |
| | | from typing import Unpack |
| | | |
| | | if sys.version_info >= (3, 11): |
| | | from typing import Self |
| | | else: |
| | | Self = Any |
| | | |
| | | _ApplicationNone = TypeVar("_ApplicationNone", Application, None) |
| | | _Request = TypeVar("_Request", bound=BaseRequest) |
| | | |
| | | REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" |
| | | |
| | | |
| | | def get_unused_port_socket( |
| | | host: str, family: socket.AddressFamily = socket.AF_INET |
| | | ) -> socket.socket: |
| | | return get_port_socket(host, 0, family) |
| | | |
| | | |
| | | def get_port_socket( |
| | | host: str, port: int, family: socket.AddressFamily |
| | | ) -> socket.socket: |
| | | s = socket.socket(family, socket.SOCK_STREAM) |
| | | if REUSE_ADDRESS: |
| | | # SO_REUSEADDR is only set on POSIX, because Windows has
| | | # different semantics for it. Ref:
| | | # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
| | | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| | | s.bind((host, port)) |
| | | return s |
| | | |
| | | |
| | | def unused_port() -> int: |
| | | """Return a port that is unused on the current host.""" |
| | | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
| | | s.bind(("127.0.0.1", 0)) |
| | | return cast(int, s.getsockname()[1]) |
| | | |
| | | |
| | | class BaseTestServer(ABC): |
| | | __test__ = False |
| | | |
| | | def __init__( |
| | | self, |
| | | *, |
| | | scheme: str = "", |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | host: str = "127.0.0.1", |
| | | port: Optional[int] = None, |
| | | skip_url_asserts: bool = False, |
| | | socket_factory: Callable[ |
| | | [str, int, socket.AddressFamily], socket.socket |
| | | ] = get_port_socket, |
| | | **kwargs: Any, |
| | | ) -> None: |
| | | self._loop = loop |
| | | self.runner: Optional[BaseRunner] = None |
| | | self._root: Optional[URL] = None |
| | | self.host = host |
| | | self.port = port |
| | | self._closed = False |
| | | self.scheme = scheme |
| | | self.skip_url_asserts = skip_url_asserts |
| | | self.socket_factory = socket_factory |
| | | |
| | | async def start_server( |
| | | self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any |
| | | ) -> None: |
| | | if self.runner: |
| | | return |
| | | self._loop = loop |
| | | self._ssl = kwargs.pop("ssl", None) |
| | | self.runner = await self._make_runner(handler_cancellation=True, **kwargs) |
| | | await self.runner.setup() |
| | | if not self.port: |
| | | self.port = 0 |
| | | absolute_host = self.host |
| | | try: |
| | | version = ipaddress.ip_address(self.host).version |
| | | except ValueError: |
| | | version = 4 |
| | | if version == 6: |
| | | absolute_host = f"[{self.host}]" |
| | | family = socket.AF_INET6 if version == 6 else socket.AF_INET |
| | | _sock = self.socket_factory(self.host, self.port, family) |
| | | self.host, self.port = _sock.getsockname()[:2] |
| | | site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl) |
| | | await site.start() |
| | | server = site._server |
| | | assert server is not None |
| | | sockets = server.sockets # type: ignore[attr-defined] |
| | | assert sockets is not None |
| | | self.port = sockets[0].getsockname()[1] |
| | | if not self.scheme: |
| | | self.scheme = "https" if self._ssl else "http" |
| | | self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}") |
| | | |
| | | @abstractmethod # pragma: no cover |
| | | async def _make_runner(self, **kwargs: Any) -> BaseRunner: |
| | | pass |
| | | |
| | | def make_url(self, path: StrOrURL) -> URL: |
| | | assert self._root is not None |
| | | url = URL(path) |
| | | if not self.skip_url_asserts: |
| | | assert not url.absolute |
| | | return self._root.join(url) |
| | | else: |
| | | return URL(str(self._root) + str(path)) |
| | | |
| | | @property |
| | | def started(self) -> bool: |
| | | return self.runner is not None |
| | | |
| | | @property |
| | | def closed(self) -> bool: |
| | | return self._closed |
| | | |
| | | @property |
| | | def handler(self) -> Server: |
| | | # For backward compatibility: return the web.Server instance.
| | | runner = self.runner |
| | | assert runner is not None |
| | | assert runner.server is not None |
| | | return runner.server |
| | | |
| | | async def close(self) -> None: |
| | | """Close all fixtures created by the test server.
| | |
| | | After that point, the TestServer is no longer usable.
| | | |
| | | This is an idempotent function: running close multiple times |
| | | will not have any additional effects. |
| | | |
| | | close is also run when the object is garbage collected, and on |
| | | exit when used as a context manager. |
| | | |
| | | """ |
| | | if self.started and not self.closed: |
| | | assert self.runner is not None |
| | | await self.runner.cleanup() |
| | | self._root = None |
| | | self.port = None |
| | | self._closed = True |
| | | |
| | | def __enter__(self) -> None: |
| | | raise TypeError("Use async with instead") |
| | | |
| | | def __exit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_value: Optional[BaseException], |
| | | traceback: Optional[TracebackType], |
| | | ) -> None: |
| | | # __exit__ must exist in pair with __enter__ but is never executed
| | | pass # pragma: no cover |
| | | |
| | | async def __aenter__(self) -> "BaseTestServer": |
| | | await self.start_server(loop=self._loop) |
| | | return self |
| | | |
| | | async def __aexit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc_value: Optional[BaseException], |
| | | traceback: Optional[TracebackType], |
| | | ) -> None: |
| | | await self.close() |
| | | |
| | | |
| | | class TestServer(BaseTestServer): |
| | | def __init__( |
| | | self, |
| | | app: Application, |
| | | *, |
| | | scheme: str = "", |
| | | host: str = "127.0.0.1", |
| | | port: Optional[int] = None, |
| | | **kwargs: Any, |
| | | ): |
| | | self.app = app |
| | | super().__init__(scheme=scheme, host=host, port=port, **kwargs) |
| | | |
| | | async def _make_runner(self, **kwargs: Any) -> BaseRunner: |
| | | return AppRunner(self.app, **kwargs) |
| | | |
| | | |
| | | class RawTestServer(BaseTestServer): |
| | | def __init__( |
| | | self, |
| | | handler: _RequestHandler, |
| | | *, |
| | | scheme: str = "", |
| | | host: str = "127.0.0.1", |
| | | port: Optional[int] = None, |
| | | **kwargs: Any, |
| | | ) -> None: |
| | | self._handler = handler |
| | | super().__init__(scheme=scheme, host=host, port=port, **kwargs) |
| | | |
| | | async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner: |
| | | srv = Server(self._handler, loop=self._loop, debug=debug, **kwargs) |
| | | return ServerRunner(srv, debug=debug, **kwargs) |
| | | |
| | | |
| | | class TestClient(Generic[_Request, _ApplicationNone]): |
| | | """ |
| | | A test client implementation. |
| | | |
| | | Used to write functional tests for aiohttp-based servers.
| | | |
| | | """ |
| | | |
| | | __test__ = False |
| | | |
| | | @overload |
| | | def __init__( |
| | | self: "TestClient[Request, Application]", |
| | | server: TestServer, |
| | | *, |
| | | cookie_jar: Optional[AbstractCookieJar] = None, |
| | | **kwargs: Any, |
| | | ) -> None: ... |
| | | @overload |
| | | def __init__( |
| | | self: "TestClient[_Request, None]", |
| | | server: BaseTestServer, |
| | | *, |
| | | cookie_jar: Optional[AbstractCookieJar] = None, |
| | | **kwargs: Any, |
| | | ) -> None: ... |
| | | def __init__( |
| | | self, |
| | | server: BaseTestServer, |
| | | *, |
| | | cookie_jar: Optional[AbstractCookieJar] = None, |
| | | loop: Optional[asyncio.AbstractEventLoop] = None, |
| | | **kwargs: Any, |
| | | ) -> None: |
| | | if not isinstance(server, BaseTestServer): |
| | | raise TypeError(
| | | "server must be a BaseTestServer instance, found type: %r" % type(server)
| | | )
| | | self._server = server |
| | | self._loop = loop |
| | | if cookie_jar is None: |
| | | cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop) |
| | | self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs) |
| | | self._session._retry_connection = False |
| | | self._closed = False |
| | | self._responses: List[ClientResponse] = [] |
| | | self._websockets: List[ClientWebSocketResponse] = [] |
| | | |
| | | async def start_server(self) -> None: |
| | | await self._server.start_server(loop=self._loop) |
| | | |
| | | @property |
| | | def host(self) -> str: |
| | | return self._server.host |
| | | |
| | | @property |
| | | def port(self) -> Optional[int]: |
| | | return self._server.port |
| | | |
| | | @property |
| | | def server(self) -> BaseTestServer: |
| | | return self._server |
| | | |
| | | @property |
| | | def app(self) -> _ApplicationNone: |
| | | return getattr(self._server, "app", None) # type: ignore[return-value] |
| | | |
| | | @property |
| | | def session(self) -> ClientSession: |
| | | """An internal aiohttp.ClientSession. |
| | | |
| | | Unlike the methods on the TestClient, client session requests
| | | do not automatically include the host in the URL queried, and
| | | require an absolute URL for the resource.
| | | |
| | | """ |
| | | return self._session |
| | | |
| | | def make_url(self, path: StrOrURL) -> URL: |
| | | return self._server.make_url(path) |
| | | |
| | | async def _request( |
| | | self, method: str, path: StrOrURL, **kwargs: Any |
| | | ) -> ClientResponse: |
| | | resp = await self._session.request(method, self.make_url(path), **kwargs) |
| | | # save it to close later |
| | | self._responses.append(resp) |
| | | return resp |
| | | |
| | | if sys.version_info >= (3, 11) and TYPE_CHECKING: |
| | | |
| | | def request( |
| | | self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions] |
| | | ) -> _RequestContextManager: ... |
| | | |
| | | def get( |
| | | self, |
| | | path: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> _RequestContextManager: ... |
| | | |
| | | def options( |
| | | self, |
| | | path: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> _RequestContextManager: ... |
| | | |
| | | def head( |
| | | self, |
| | | path: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> _RequestContextManager: ... |
| | | |
| | | def post( |
| | | self, |
| | | path: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> _RequestContextManager: ... |
| | | |
| | | def put( |
| | | self, |
| | | path: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> _RequestContextManager: ... |
| | | |
| | | def patch( |
| | | self, |
| | | path: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> _RequestContextManager: ... |
| | | |
| | | def delete( |
| | | self, |
| | | path: StrOrURL, |
| | | **kwargs: Unpack[_RequestOptions], |
| | | ) -> _RequestContextManager: ... |
| | | |
| | | else: |
| | | |
| | | def request( |
| | | self, method: str, path: StrOrURL, **kwargs: Any |
| | | ) -> _RequestContextManager: |
| | | """Routes a request to the test HTTP server.
| | | |
| | | The interface is identical to aiohttp.ClientSession.request, |
| | | except the loop kwarg is overridden by the instance used by the |
| | | test server. |
| | | |
| | | """ |
| | | return _RequestContextManager(self._request(method, path, **kwargs)) |
| | | |
| | | def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: |
| | | """Perform an HTTP GET request.""" |
| | | return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) |
| | | |
| | | def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: |
| | | """Perform an HTTP POST request.""" |
| | | return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) |
| | | |
| | | def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: |
| | | """Perform an HTTP OPTIONS request.""" |
| | | return _RequestContextManager( |
| | | self._request(hdrs.METH_OPTIONS, path, **kwargs) |
| | | ) |
| | | |
| | | def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: |
| | | """Perform an HTTP HEAD request.""" |
| | | return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) |
| | | |
| | | def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: |
| | | """Perform an HTTP PUT request.""" |
| | | return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) |
| | | |
| | | def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: |
| | | """Perform an HTTP PATCH request.""" |
| | | return _RequestContextManager( |
| | | self._request(hdrs.METH_PATCH, path, **kwargs) |
| | | ) |
| | | |
| | | def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: |
| | | """Perform an HTTP DELETE request."""
| | | return _RequestContextManager( |
| | | self._request(hdrs.METH_DELETE, path, **kwargs) |
| | | ) |
| | | |
| | | def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: |
| | | """Initiate a websocket connection.
| | |
| | | The API corresponds to aiohttp.ClientSession.ws_connect.
| | | |
| | | """ |
| | | return _WSRequestContextManager(self._ws_connect(path, **kwargs)) |
| | | |
| | | async def _ws_connect( |
| | | self, path: StrOrURL, **kwargs: Any |
| | | ) -> ClientWebSocketResponse: |
| | | ws = await self._session.ws_connect(self.make_url(path), **kwargs) |
| | | self._websockets.append(ws) |
| | | return ws |
| | | |
| | | async def close(self) -> None: |
| | | """Close all fixtures created by the test client. |
| | | |
| | | After that point, the TestClient is no longer usable. |
| | | |
| | | This is an idempotent function: running close multiple times |
| | | will not have any additional effects. |
| | | |
| | | close is also run on exit when used as a(n) (asynchronous) |
| | | context manager. |
| | | |
| | | """ |
| | | if not self._closed: |
| | | for resp in self._responses: |
| | | resp.close() |
| | | for ws in self._websockets: |
| | | await ws.close() |
| | | await self._session.close() |
| | | await self._server.close() |
| | | self._closed = True |
| | | |
| | | def __enter__(self) -> None: |
| | | raise TypeError("Use async with instead") |
| | | |
| | | def __exit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc: Optional[BaseException], |
| | | tb: Optional[TracebackType], |
| | | ) -> None: |
| | | # __exit__ must exist in pair with __enter__ but is never executed
| | | pass # pragma: no cover |
| | | |
| | | async def __aenter__(self) -> Self: |
| | | await self.start_server() |
| | | return self |
| | | |
| | | async def __aexit__( |
| | | self, |
| | | exc_type: Optional[Type[BaseException]], |
| | | exc: Optional[BaseException], |
| | | tb: Optional[TracebackType], |
| | | ) -> None: |
| | | await self.close() |
| | | |
| | | |
| | | class AioHTTPTestCase(IsolatedAsyncioTestCase): |
| | | """A base class for writing unittest-based tests of aiohttp web applications.
| | | |
| | | Provides the following: |
| | | |
| | | * self.client (aiohttp.test_utils.TestClient): an aiohttp test client. |
| | | * self.loop (asyncio.BaseEventLoop): the event loop in which the |
| | | application and server are running. |
| | | * self.app (aiohttp.web.Application): the application returned by |
| | | self.get_application() |
| | | |
| | | Note that the TestClient's methods are asynchronous: you have to
| | | execute functions on the test client using asynchronous methods.
| | | """ |
| | | |
| | | async def get_application(self) -> Application: |
| | | """Get application. |
| | | |
| | | This method should be overridden |
| | | to return the aiohttp.web.Application |
| | | object to test. |
| | | """ |
| | | return self.get_app() |
| | | |
| | | def get_app(self) -> Application: |
| | | """Obsolete method used to construct the web application.
| | | |
| | | Use .get_application() coroutine instead. |
| | | """ |
| | | raise RuntimeError("Did you forget to define get_application()?") |
| | | |
| | | async def asyncSetUp(self) -> None: |
| | | self.loop = asyncio.get_running_loop() |
| | | return await self.setUpAsync() |
| | | |
| | | async def setUpAsync(self) -> None: |
| | | self.app = await self.get_application() |
| | | self.server = await self.get_server(self.app) |
| | | self.client = await self.get_client(self.server) |
| | | |
| | | await self.client.start_server() |
| | | |
| | | async def asyncTearDown(self) -> None: |
| | | return await self.tearDownAsync() |
| | | |
| | | async def tearDownAsync(self) -> None: |
| | | await self.client.close() |
| | | |
| | | async def get_server(self, app: Application) -> TestServer: |
| | | """Return a TestServer instance.""" |
| | | return TestServer(app, loop=self.loop) |
| | | |
| | | async def get_client(self, server: TestServer) -> TestClient[Request, Application]: |
| | | """Return a TestClient instance.""" |
| | | return TestClient(server, loop=self.loop) |
| | | |
| | | |
| | | def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any: |
| | | """ |
| | | A decorator for use with asynchronous AioHTTPTestCase test methods.
| | |
| | | In aiohttp 3.8+, this does nothing.
| | | """ |
| | | warnings.warn( |
| | | "Decorator `@unittest_run_loop` is no longer needed in aiohttp 3.8+", |
| | | DeprecationWarning, |
| | | stacklevel=2, |
| | | ) |
| | | return func |
| | | |
| | | |
| | | _LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop] |
| | | |
| | | |
| | | @contextlib.contextmanager |
| | | def loop_context( |
| | | loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False |
| | | ) -> Iterator[asyncio.AbstractEventLoop]: |
| | | """A context manager that creates an event loop for test purposes.
| | | |
| | | Handles the creation and cleanup of a test loop. |
| | | """ |
| | | loop = setup_test_loop(loop_factory) |
| | | yield loop |
| | | teardown_test_loop(loop, fast=fast) |
| | | |
| | | |
| | | def setup_test_loop( |
| | | loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, |
| | | ) -> asyncio.AbstractEventLoop: |
| | | """Create and return an asyncio.AbstractEventLoop instance.
| | | |
| | | The caller should also call teardown_test_loop, |
| | | once they are done with the loop. |
| | | """ |
| | | loop = loop_factory() |
| | | asyncio.set_event_loop(loop) |
| | | return loop |
| | | |
| | | |
| | | def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None: |
| | | """Tear down and clean up an event loop created by setup_test_loop."""
| | | closed = loop.is_closed() |
| | | if not closed: |
| | | loop.call_soon(loop.stop) |
| | | loop.run_forever() |
| | | loop.close() |
| | | |
| | | if not fast: |
| | | gc.collect() |
| | | |
| | | asyncio.set_event_loop(None) |
| | | |
| | | |
| | | def _create_app_mock() -> mock.MagicMock: |
| | | def get_dict(app: Any, key: str) -> Any: |
| | | return app.__app_dict[key] |
| | | |
| | | def set_dict(app: Any, key: str, value: Any) -> None: |
| | | app.__app_dict[key] = value |
| | | |
| | | app = mock.MagicMock(spec=Application) |
| | | app.__app_dict = {} |
| | | app.__getitem__ = get_dict |
| | | app.__setitem__ = set_dict |
| | | |
| | | app._debug = False |
| | | app.on_response_prepare = Signal(app) |
| | | app.on_response_prepare.freeze() |
| | | return app |
| | | |
| | | |
| | | def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock: |
| | | transport = mock.Mock() |
| | | |
| | | def get_extra_info(key: str) -> Optional[SSLContext]: |
| | | if key == "sslcontext": |
| | | return sslcontext |
| | | else: |
| | | return None |
| | | |
| | | transport.get_extra_info.side_effect = get_extra_info |
| | | return transport |
| | | |
| | | |
| | | def make_mocked_request( |
| | | method: str, |
| | | path: str, |
| | | headers: Any = None, |
| | | *, |
| | | match_info: Any = sentinel, |
| | | version: HttpVersion = HttpVersion(1, 1), |
| | | closing: bool = False, |
| | | app: Any = None, |
| | | writer: Any = sentinel, |
| | | protocol: Any = sentinel, |
| | | transport: Any = sentinel, |
| | | payload: StreamReader = EMPTY_PAYLOAD, |
| | | sslcontext: Optional[SSLContext] = None, |
| | | client_max_size: int = 1024**2, |
| | | loop: Any = ..., |
| | | ) -> Request: |
| | | """Creates a mocked web.Request for testing purposes.
| | |
| | | Useful in unit tests, when spinning up a full web server is overkill
| | | or specific conditions and errors are hard to trigger.
| | | """ |
| | | task = mock.Mock() |
| | | if loop is ...: |
| | | # No loop was passed; try to get the current one if it is
| | | # running, since a real loop is needed to create executor
| | | # jobs when testing with a real executor.
| | | try: |
| | | loop = asyncio.get_running_loop() |
| | | except RuntimeError: |
| | | loop = mock.Mock() |
| | | loop.create_future.return_value = () |
| | | |
| | | if version < HttpVersion(1, 1): |
| | | closing = True |
| | | |
| | | if headers: |
| | | headers = CIMultiDictProxy(CIMultiDict(headers)) |
| | | raw_hdrs = tuple( |
| | | (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items() |
| | | ) |
| | | else: |
| | | headers = CIMultiDictProxy(CIMultiDict()) |
| | | raw_hdrs = () |
| | | |
| | | chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower() |
| | | |
| | | message = RawRequestMessage( |
| | | method, |
| | | path, |
| | | version, |
| | | headers, |
| | | raw_hdrs, |
| | | closing, |
| | | None, |
| | | False, |
| | | chunked, |
| | | URL(path), |
| | | ) |
| | | if app is None: |
| | | app = _create_app_mock() |
| | | |
| | | if transport is sentinel: |
| | | transport = _create_transport(sslcontext) |
| | | |
| | | if protocol is sentinel: |
| | | protocol = mock.Mock() |
| | | protocol.transport = transport |
| | | type(protocol).peername = mock.PropertyMock( |
| | | return_value=transport.get_extra_info("peername") |
| | | ) |
| | | type(protocol).ssl_context = mock.PropertyMock(return_value=sslcontext) |
| | | |
| | | if writer is sentinel: |
| | | writer = mock.Mock() |
| | | writer.write_headers = make_mocked_coro(None) |
| | | writer.write = make_mocked_coro(None) |
| | | writer.write_eof = make_mocked_coro(None) |
| | | writer.drain = make_mocked_coro(None) |
| | | writer.transport = transport |
| | | |
| | | protocol.transport = transport |
| | | protocol.writer = writer |
| | | |
| | | req = Request( |
| | | message, payload, protocol, writer, task, loop, client_max_size=client_max_size |
| | | ) |
| | | |
| | | match_info = UrlMappingMatchInfo( |
| | | {} if match_info is sentinel else match_info, mock.Mock() |
| | | ) |
| | | match_info.add_app(app) |
| | | req._match_info = match_info |
| | | |
| | | return req |
| | | |
| | | |
| | | def make_mocked_coro( |
| | | return_value: Any = sentinel, raise_exception: Any = sentinel |
| | | ) -> Any: |
| | | """Creates a coroutine mock.""" |
| | | |
| | | async def mock_coro(*args: Any, **kwargs: Any) -> Any: |
| | | if raise_exception is not sentinel: |
| | | raise raise_exception |
| | | if not inspect.isawaitable(return_value): |
| | | return return_value |
| | | await return_value |
| | | |
| | | return mock.Mock(wraps=mock_coro) |
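The `make_mocked_coro` helper above wraps a plain async function in a `Mock` so that calls are recorded by the mock while callers still receive a real awaitable that returns (or raises) as configured. A minimal, stdlib-only sketch of the same technique — `mocked_coro` is a hypothetical name for illustration, not part of aiohttp:

```python
import asyncio
import inspect
from unittest import mock


def mocked_coro(return_value=None, raise_exception=None):
    """Return a Mock wrapping a coroutine function (mirrors make_mocked_coro)."""

    async def coro(*args, **kwargs):
        if raise_exception is not None:
            raise raise_exception
        if inspect.isawaitable(return_value):
            return await return_value
        return return_value

    # mock.Mock(wraps=...) records the call on the mock, then delegates
    # to coro, so the caller still gets back a genuine coroutine object.
    return mock.Mock(wraps=coro)


write = mocked_coro(return_value=42)
result = asyncio.run(write(b"data"))
write.assert_called_once_with(b"data")  # the call was recorded
print(result)  # 42
```

This is the same pattern `make_mocked_request` uses to stub `writer.write_headers`, `writer.write`, and friends: tests can both await the stubbed coroutine and assert on how it was called.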
| 测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/tracing.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/typedefs.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_app.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_exceptions.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_fileresponse.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_log.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_middlewares.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_protocol.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_request.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_response.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_routedef.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_runner.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_server.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_urldispatcher.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/web_ws.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiohttp/worker.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiosignal-1.4.0.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiosignal-1.4.0.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiosignal-1.4.0.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiosignal-1.4.0.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiosignal-1.4.0.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiosignal-1.4.0.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiosignal/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/aiosignal/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/__init__.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_cmp.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_cmp.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_compat.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_config.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_funcs.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_make.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_next_gen.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_typing_compat.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_version_info.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/_version_info.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/converters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/converters.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/exceptions.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/exceptions.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/filters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/filters.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/setters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/setters.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/validators.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attr/validators.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs-25.4.0.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs-25.4.0.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs-25.4.0.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs-25.4.0.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs-25.4.0.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs/__init__.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs/converters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs/exceptions.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs/filters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs/setters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/attrs/validators.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi-2026.1.4.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi-2026.1.4.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi-2026.1.4.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi-2026.1.4.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi-2026.1.4.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi-2026.1.4.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi/__main__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi/cacert.pem
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi/core.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/certifi/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer-3.4.4.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer-3.4.4.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer-3.4.4.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer-3.4.4.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer-3.4.4.dist-info/entry_points.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer-3.4.4.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer-3.4.4.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/__main__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/api.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/cd.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/cli/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/cli/__main__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/constant.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/legacy.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/md.cp312-win_amd64.pyd
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/md.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/md__mypyc.cp312-win_amd64.pyd
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/models.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/utils.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/charset_normalizer/version.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama-0.4.6.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama-0.4.6.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama-0.4.6.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama-0.4.6.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama-0.4.6.dist-info/licenses/LICENSE.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/ansi.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/ansitowin32.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/initialise.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/tests/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/tests/ansi_test.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/tests/ansitowin32_test.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/tests/initialise_test.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/tests/isatty_test.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/tests/utils.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/tests/winterm_test.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/win32.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/colorama/winterm.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist-1.8.0.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist-1.8.0.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist-1.8.0.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist-1.8.0.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist-1.8.0.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist-1.8.0.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist/__init__.pyi
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist/_frozenlist.cp312-win_amd64.pyd
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist/_frozenlist.pyx
测试组/脚本/Change_password/venv_build/Lib/site-packages/frozenlist/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna-3.11.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna-3.11.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna-3.11.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna-3.11.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna-3.11.dist-info/licenses/LICENSE.md
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/codec.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/compat.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/core.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/idnadata.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/intranges.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/package_data.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/idna/uts46data.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict-6.7.0.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict-6.7.0.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict-6.7.0.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict-6.7.0.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict-6.7.0.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict-6.7.0.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict/_abc.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict/_compat.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict/_multidict.cp312-win_amd64.pyd
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict/_multidict_py.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/multidict/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache-0.4.1.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache-0.4.1.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache-0.4.1.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache-0.4.1.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache-0.4.1.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache-0.4.1.dist-info/licenses/NOTICE
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache-0.4.1.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache/_helpers.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache/_helpers_c.cp312-win_amd64.pyd
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache/_helpers_c.pyx
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache/_helpers_py.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache/api.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/propcache/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql-1.1.2.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql-1.1.2.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql-1.1.2.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql-1.1.2.dist-info/REQUESTED
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql-1.1.2.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql-1.1.2.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql-1.1.2.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/_auth.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/charset.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/connections.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/constants/CLIENT.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/constants/COMMAND.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/constants/CR.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/constants/ER.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/constants/FIELD_TYPE.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/constants/FLAG.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/constants/SERVER_STATUS.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/constants/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/converters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/cursors.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/err.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/optionfile.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/protocol.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/pymysql/times.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests-2.32.5.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests-2.32.5.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests-2.32.5.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests-2.32.5.dist-info/REQUESTED
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests-2.32.5.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests-2.32.5.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests-2.32.5.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/__version__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/_internal_utils.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/adapters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/api.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/auth.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/certs.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/compat.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/cookies.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/exceptions.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/help.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/hooks.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/models.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/packages.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/sessions.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/status_codes.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/structures.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/requests/utils.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm-4.67.1.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm-4.67.1.dist-info/LICENCE
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm-4.67.1.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm-4.67.1.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm-4.67.1.dist-info/REQUESTED
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm-4.67.1.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm-4.67.1.dist-info/entry_points.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm-4.67.1.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/__main__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/_dist_ver.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/_main.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/_monitor.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/_tqdm.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/_tqdm_gui.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/_tqdm_notebook.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/_tqdm_pandas.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/_utils.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/asyncio.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/auto.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/autonotebook.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/cli.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/completion.sh
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/bells.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/concurrent.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/discord.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/itertools.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/logging.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/slack.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/telegram.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/contrib/utils_worker.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/dask.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/gui.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/keras.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/notebook.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/rich.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/std.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/tk.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/tqdm.1
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/utils.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/tqdm/version.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/typing_extensions-4.15.0.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/typing_extensions-4.15.0.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/typing_extensions-4.15.0.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/typing_extensions-4.15.0.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/typing_extensions-4.15.0.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/typing_extensions.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3-2.6.2.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3-2.6.2.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3-2.6.2.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3-2.6.2.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3-2.6.2.dist-info/licenses/LICENSE.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/_base_connection.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/_collections.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/_request_methods.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/_version.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/connection.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/connectionpool.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/emscripten/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/emscripten/connection.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/emscripten/emscripten_fetch_worker.js
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/emscripten/fetch.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/emscripten/request.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/emscripten/response.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/pyopenssl.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/contrib/socks.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/exceptions.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/fields.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/filepost.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/http2/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/http2/connection.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/http2/probe.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/poolmanager.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/py.typed
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/response.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/connection.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/proxy.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/request.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/response.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/retry.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/ssl_.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/ssl_match_hostname.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/ssltransport.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/timeout.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/url.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/util.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/urllib3/util/wait.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl-1.22.0.dist-info/INSTALLER
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl-1.22.0.dist-info/METADATA
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl-1.22.0.dist-info/RECORD
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl-1.22.0.dist-info/WHEEL
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl-1.22.0.dist-info/licenses/LICENSE
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl-1.22.0.dist-info/licenses/NOTICE
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl-1.22.0.dist-info/top_level.txt
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/__init__.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_parse.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_path.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_query.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_quoters.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_quoting.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_quoting_c.cp312-win_amd64.pyd
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_quoting_c.pyx
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_quoting_py.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/_url.py
测试组/脚本/Change_password/venv_build/Lib/site-packages/yarl/py.typed
测试组/脚本/Change_password/venv_build/Scripts/normalizer.exe
测试组/脚本/Change_password/venv_build/Scripts/tqdm.exe
测试组/脚本/造数脚本2/华东师范大学二期/并发入驻笼位.py