hyb
2026-01-09 4cb426cb3ae31e772a09d4ade5b2f0242aaeefa0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
import struct
from enum import IntEnum
 
from types import ModuleType
from typing import Optional
 
from .KDF import _HKDF_extract, _HKDF_expand
from .DH import key_agreement, import_x25519_public_key, import_x448_public_key
from Crypto.Util.strxor import strxor
from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey
from Crypto.Hash import SHA256, SHA384, SHA512
from Crypto.Cipher import AES, ChaCha20_Poly1305
 
 
class MODE(IntEnum):
    """HPKE modes"""
    BASE = 0x00
    PSK = 0x01
    AUTH = 0x02
    AUTH_PSK = 0x03
 
 
class AEAD(IntEnum):
    """Authenticated Encryption with Associated Data (AEAD) Functions"""
    AES128_GCM = 0x0001
    AES256_GCM = 0x0002
    CHACHA20_POLY1305 = 0x0003
 
 
class DeserializeError(ValueError):
    pass
 
class MessageLimitReachedError(ValueError):
    pass
 
# CURVE to (KEM ID, KDF ID, HASH)
_Curve_Config = {
  "NIST P-256": (0x0010, 0x0001, SHA256),
  "NIST P-384": (0x0011, 0x0002, SHA384),
  "NIST P-521": (0x0012, 0x0003, SHA512),
  "Curve25519": (0x0020, 0x0001, SHA256),
  "Curve448":   (0x0021, 0x0003, SHA512),
}
 
 
def _labeled_extract(salt: bytes,
                     label: bytes,
                     ikm: bytes,
                     suite_id: bytes,
                     hashmod: ModuleType):
    labeled_ikm = b"HPKE-v1" + suite_id + label + ikm
    return _HKDF_extract(salt, labeled_ikm, hashmod)
 
 
def _labeled_expand(prk: bytes,
                    label: bytes,
                    info: bytes,
                    L: int,
                    suite_id: bytes,
                    hashmod: ModuleType):
    labeled_info = struct.pack('>H', L) + b"HPKE-v1" + suite_id + \
                   label + info
    return _HKDF_expand(prk, labeled_info, L, hashmod)
 
 
def _extract_and_expand(dh: bytes,
                        kem_context: bytes,
                        suite_id: bytes,
                        hashmod: ModuleType):
    Nsecret = hashmod.digest_size
 
    eae_prk = _labeled_extract(b"",
                               b"eae_prk",
                               dh,
                               suite_id,
                               hashmod)
 
    shared_secret = _labeled_expand(eae_prk,
                                    b"shared_secret",
                                    kem_context,
                                    Nsecret,
                                    suite_id,
                                    hashmod)
    return shared_secret
 
 
class HPKE_Cipher:
 
    def __init__(self,
                 receiver_key: EccKey,
                 enc: Optional[bytes],
                 sender_key: Optional[EccKey],
                 psk_pair: tuple[bytes, bytes],
                 info: bytes,
                 aead_id: AEAD,
                 mode: MODE):
 
        self.enc: bytes = b'' if enc is None else enc
        """The encapsulated session key."""
 
        self._verify_psk_inputs(mode, psk_pair)
 
        self._curve = receiver_key.curve
        self._aead_id = aead_id
        self._mode = mode
 
        try:
            self._kem_id, \
             self._kdf_id, \
             self._hashmod = _Curve_Config[self._curve]
        except KeyError as ke:
            raise ValueError("Curve {} is not supported by HPKE".format(self._curve)) from ke
 
        self._Nk = 16 if self._aead_id == AEAD.AES128_GCM else 32
        self._Nn = 12
        self._Nt = 16
        self._Nh = self._hashmod.digest_size
 
        self._encrypt = not receiver_key.has_private()
 
        if self._encrypt:
            # SetupBaseS (encryption)
            if enc is not None:
                raise ValueError("Parameter 'enc' cannot be an input  when sealing")
            shared_secret, self.enc = self._encap(receiver_key,
                                                  self._kem_id,
                                                  self._hashmod,
                                                  sender_key)
        else:
            # SetupBaseR (decryption)
            if enc is None:
                raise ValueError("Parameter 'enc' required when unsealing")
            shared_secret = self._decap(enc,
                                        receiver_key,
                                        self._kem_id,
                                        self._hashmod,
                                        sender_key)
 
        self._sequence = 0
        self._max_sequence = (1 << (8 * self._Nn)) - 1
 
        self._key, \
            self._base_nonce, \
            self._export_secret = self._key_schedule(shared_secret,
                                                     info,
                                                     *psk_pair)
 
    @staticmethod
    def _encap(receiver_key: EccKey,
               kem_id: int,
               hashmod: ModuleType,
               sender_key: Optional[EccKey] = None,
               eph_key: Optional[EccKey] = None):
 
        assert (sender_key is None) or sender_key.has_private()
        assert (eph_key is None) or eph_key.has_private()
 
        if eph_key is None:
            eph_key = ECC.generate(curve=receiver_key.curve)
        enc = eph_key.public_key().export_key(format='raw')
 
        pkRm = receiver_key.public_key().export_key(format='raw')
        kem_context = enc + pkRm
        extra_param = {}
        if sender_key:
            kem_context += sender_key.public_key().export_key(format='raw')
            extra_param = {'static_priv': sender_key}
 
        suite_id = b"KEM" + struct.pack('>H', kem_id)
 
        def kdf(dh,
                kem_context=kem_context,
                suite_id=suite_id,
                hashmod=hashmod):
            return _extract_and_expand(dh, kem_context, suite_id, hashmod)
 
        shared_secret = key_agreement(eph_priv=eph_key,
                                      static_pub=receiver_key,
                                      kdf=kdf,
                                      **extra_param)
        return shared_secret, enc
 
    @staticmethod
    def _decap(enc: bytes,
               receiver_key: EccKey,
               kem_id: int,
               hashmod: ModuleType,
               sender_key: Optional[EccKey] = None):
 
        assert receiver_key.has_private()
 
        try:
            if receiver_key.curve == 'Curve25519':
                pkE = import_x25519_public_key(enc)
            elif receiver_key.curve == 'Curve448':
                pkE = import_x448_public_key(enc)
            else:
                pkE = ECC.import_key(enc, curve_name=receiver_key.curve)
        except ValueError as ve:
            raise DeserializeError("'enc' is not a valid encapsulated HPKE key") from ve
 
        pkRm = receiver_key.public_key().export_key(format='raw')
        kem_context = enc + pkRm
        extra_param = {}
        if sender_key:
            kem_context += sender_key.public_key().export_key(format='raw')
            extra_param = {'static_pub': sender_key}
 
        suite_id = b"KEM" + struct.pack('>H', kem_id)
 
        def kdf(dh,
                kem_context=kem_context,
                suite_id=suite_id,
                hashmod=hashmod):
            return _extract_and_expand(dh, kem_context, suite_id, hashmod)
 
        shared_secret = key_agreement(eph_pub=pkE,
                                      static_priv=receiver_key,
                                      kdf=kdf,
                                      **extra_param)
        return shared_secret
 
    @staticmethod
    def _verify_psk_inputs(mode: MODE, psk_pair: tuple[bytes, bytes]):
        psk_id, psk = psk_pair
 
        if (psk == b'') ^ (psk_id == b''):
            raise ValueError("Inconsistent PSK inputs")
 
        if (psk == b''):
            if mode in (MODE.PSK, MODE.AUTH_PSK):
                raise ValueError(f"PSK is required with mode {mode.name}")
        else:
            if len(psk) < 32:
                raise ValueError("PSK must be at least 32 byte long")
            if mode in (MODE.BASE, MODE.AUTH):
                raise ValueError("PSK is not compatible with this mode")
 
    def _key_schedule(self,
                      shared_secret: bytes,
                      info: bytes,
                      psk_id: bytes,
                      psk: bytes):
 
        suite_id = b"HPKE" + struct.pack('>HHH',
                                         self._kem_id,
                                         self._kdf_id,
                                         self._aead_id)
 
        psk_id_hash = _labeled_extract(b'',
                                       b'psk_id_hash',
                                       psk_id,
                                       suite_id,
                                       self._hashmod)
 
        info_hash = _labeled_extract(b'',
                                     b'info_hash',
                                     info,
                                     suite_id,
                                     self._hashmod)
 
        key_schedule_context = self._mode.to_bytes(1, 'big') + psk_id_hash + info_hash
 
        secret = _labeled_extract(shared_secret,
                                  b'secret',
                                  psk,
                                  suite_id,
                                  self._hashmod)
 
        key = _labeled_expand(secret,
                              b'key',
                              key_schedule_context,
                              self._Nk,
                              suite_id,
                              self._hashmod)
 
        base_nonce = _labeled_expand(secret,
                                     b'base_nonce',
                                     key_schedule_context,
                                     self._Nn,
                                     suite_id,
                                     self._hashmod)
 
        exporter_secret = _labeled_expand(secret,
                                          b'exp',
                                          key_schedule_context,
                                          self._Nh,
                                          suite_id,
                                          self._hashmod)
 
        return key, base_nonce, exporter_secret
 
    def _new_cipher(self):
        nonce = strxor(self._base_nonce, self._sequence.to_bytes(self._Nn, 'big'))
        if self._aead_id in (AEAD.AES128_GCM, AEAD.AES256_GCM):
            cipher = AES.new(self._key, AES.MODE_GCM, nonce=nonce, mac_len=self._Nt)
        elif self._aead_id == AEAD.CHACHA20_POLY1305:
            cipher = ChaCha20_Poly1305.new(key=self._key, nonce=nonce)
        else:
            raise ValueError(f"Unknown AEAD cipher ID {self._aead_id:#x}")
        if self._sequence >= self._max_sequence:
            raise MessageLimitReachedError()
        self._sequence += 1
        return cipher
 
    def seal(self, plaintext: bytes, auth_data: Optional[bytes] = None):
        """Encrypt and authenticate a message.
 
        This method can be invoked multiple times
        to seal an ordered sequence of messages.
 
        Arguments:
          plaintext: bytes
            The message to seal.
          auth_data: bytes
            Optional. Additional Authenticated data (AAD) that is not encrypted
            but that will be also covered by the authentication tag.
 
        Returns:
           The ciphertext concatenated with the authentication tag.
        """
 
        if not self._encrypt:
            raise ValueError("This cipher can only be used to seal")
        cipher = self._new_cipher()
        if auth_data:
            cipher.update(auth_data)
        ct, tag = cipher.encrypt_and_digest(plaintext)
        return ct + tag
 
    def unseal(self, ciphertext: bytes, auth_data: Optional[bytes] = None):
        """Decrypt a message and validate its authenticity.
 
        This method can be invoked multiple times
        to unseal an ordered sequence of messages.
 
        Arguments:
          cipertext: bytes
            The message to unseal.
          auth_data: bytes
            Optional. Additional Authenticated data (AAD) that
            was also covered by the authentication tag.
 
        Returns:
           The original plaintext.
 
        Raises: ValueError
           If the ciphertext (in combination with the AAD) is not valid.
 
           But if it is the first time you call ``unseal()`` this
           exception may also mean that any of the parameters or keys
           used to establish the session is wrong or that one is missing.
        """
 
        if self._encrypt:
            raise ValueError("This cipher can only be used to unseal")
        if len(ciphertext) < self._Nt:
            raise ValueError("Ciphertext is too small")
        cipher = self._new_cipher()
        if auth_data:
            cipher.update(auth_data)
 
        try:
            pt = cipher.decrypt_and_verify(ciphertext[:-self._Nt],
                                           ciphertext[-self._Nt:])
        except ValueError:
            if self._sequence == 1:
                raise ValueError("Incorrect HPKE keys/parameters or invalid message (wrong MAC tag)")
            raise ValueError("Invalid message (wrong MAC tag)")
        return pt
 
 
def new(*, receiver_key: EccKey,
        aead_id: AEAD,
        enc: Optional[bytes] = None,
        sender_key: Optional[EccKey] = None,
        psk: Optional[tuple[bytes, bytes]] = None,
        info: Optional[bytes] = None) -> HPKE_Cipher:
    """Create an HPKE context which can be used:
 
    - by the sender to seal (encrypt) a message or
    - by the receiver to unseal (decrypt) it.
 
    As a minimum, the two parties agree on the receiver's asymmetric key
    (of which the sender will only know the public half).
 
    Additionally, for authentication purposes, they may also agree on:
 
    * the sender's asymmetric key (of which the receiver will only know the public half)
 
    * a shared secret (e.g., a symmetric key derived from a password)
 
    Args:
      receiver_key:
        The ECC key of the receiver.
        It must be on one of the following curves: ``NIST P-256``,
        ``NIST P-384``, ``NIST P-521``, ``X25519`` or ``X448``.
 
        If this is a **public** key, the HPKE context can only be used to
        **seal** (**encrypt**).
 
        If this is a **private** key, the HPKE context can only be used to
        **unseal** (**decrypt**).
 
      aead_id:
        The HPKE identifier of the symmetric cipher.
        The possible values are:
 
        * ``HPKE.AEAD.AES128_GCM``
        * ``HPKE.AEAD.AES256_GCM``
        * ``HPKE.AEAD.CHACHA20_POLY1305``
 
      enc:
        The encapsulated session key (i.e., the KEM shared secret).
 
        The receiver must always specify this parameter.
 
        The sender must always omit this parameter.
 
      sender_key:
        The ECC key of the sender.
        It must be on the same curve as the ``receiver_key``.
        If the ``receiver_key`` is a public key, ``sender_key`` must be a
        private key, and vice versa.
 
      psk:
        A Pre-Shared Key (PSK) as a 2-tuple of non-empty
        byte strings: the identifier and the actual secret value.
        Sender and receiver must use the same PSK (or none).
 
        The secret value must be at least 32 bytes long,
        but it  must not be a low-entropy password
        (use a KDF like PBKDF2 or scrypt to derive a secret
        from a password).
 
      info:
        A non-secret parameter that contributes
        to the generation of all session keys.
        Sender and receive must use the same **info** parameter (or none).
 
    Returns:
        An object that can be used for
        sealing (if ``receiver_key`` is a public key) or
        unsealing (if ``receiver_key`` is a private key).
        In the latter case,
        correctness of all the keys and parameters will only
        be assessed with the first call to ``unseal()``.
    """
 
    if aead_id not in AEAD:
        raise ValueError(f"Unknown AEAD cipher ID {aead_id:#x}")
 
    curve = receiver_key.curve
    if curve not in ('NIST P-256', 'NIST P-384', 'NIST P-521',
                     'Curve25519', 'Curve448'):
        raise ValueError(f"Unsupported curve {curve}")
 
    if sender_key:
        count_private_keys = int(receiver_key.has_private()) + \
                             int(sender_key.has_private())
        if count_private_keys != 1:
            raise ValueError("Exactly 1 private key required")
        if sender_key.curve != curve:
            raise ValueError("Sender key uses {} but recipient key {}".
                             format(sender_key.curve, curve))
        mode = MODE.AUTH if psk is None else MODE.AUTH_PSK
    else:
        mode = MODE.BASE if psk is None else MODE.PSK
 
    if psk is None:
        psk = b'', b''
 
    if info is None:
        info = b''
 
    return HPKE_Cipher(receiver_key,
                       enc,
                       sender_key,
                       psk,
                       info,
                       aead_id,
                       mode)