From cd531b08b6588aa94b7979ae9a448f10b32247d8 Mon Sep 17 00:00:00 2001 From: Lars Scheibling Date: Fri, 1 Jul 2022 15:17:18 +0000 Subject: [PATCH 1/6] Started adding fieldsets Bugfix, concatenated code --- .gitignore | 3 +- requirements-test.txt | 3 +- src/sshkey_tools/cert.py | 6 +- src/sshkey_tools/fields.py | 811 +++++++++++++++++-------------------- src/sshkey_tools/utils.py | 51 ++- tests/test_certificates.py | 12 +- tests/test_utils.py | 197 +++++++++ 7 files changed, 622 insertions(+), 461 deletions(-) diff --git a/.gitignore b/.gitignore index 5c054b9..a1e902f 100644 --- a/.gitignore +++ b/.gitignore @@ -145,4 +145,5 @@ temptest.py oldsrc/* tempfolder report.html -test_certificate \ No newline at end of file +test_certificate +testing.py \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index b4811ed..36efcbe 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -6,4 +6,5 @@ paramiko coverage black pytest-cov -faker \ No newline at end of file +faker +cprint \ No newline at end of file diff --git a/src/sshkey_tools/cert.py b/src/sshkey_tools/cert.py index f427a06..d66e17e 100644 --- a/src/sshkey_tools/cert.py +++ b/src/sshkey_tools/cert.py @@ -25,10 +25,10 @@ CERTIFICATE_FIELDS = { "serial": _FIELD.SerialField, "cert_type": _FIELD.CertificateTypeField, - "key_id": _FIELD.KeyIDField, + "key_id": _FIELD.KeyIdField, "principals": _FIELD.PrincipalsField, - "valid_after": _FIELD.ValidityStartField, - "valid_before": _FIELD.ValidityEndField, + "valid_after": _FIELD.ValidAfterField, + "valid_before": _FIELD.ValidBeforeField, "critical_options": _FIELD.CriticalOptionsField, "extensions": _FIELD.ExtensionsField, } diff --git a/src/sshkey_tools/fields.py b/src/sshkey_tools/fields.py index 886ea9e..3a6d85f 100644 --- a/src/sshkey_tools/fields.py +++ b/src/sshkey_tools/fields.py @@ -2,9 +2,12 @@ Field types for SSH Certificates """ # pylint: disable=invalid-name,too-many-lines,arguments-differ +import re from enum import Enum +from types import NoneType, MethodType from typing import Union, Tuple -from datetime import datetime +from dataclasses import dataclass +from datetime import datetime, timedelta from struct import pack, unpack from base64 import b64encode from cryptography.hazmat.primitives.asymmetric.utils import ( @@ -29,14 +32,18 @@ from .utils import ( long_to_bytes, bytes_to_long, - generate_secure_nonce, - ensure_string, + generate_secure_nonce, + random_keyid, + random_serial, + ensure_string, + ensure_bytestring, concat_to_string ) MAX_INT32 = 2**32 MAX_INT64 = 2**64 +NEWLINE = '\n' ECDSA_CURVE_MAP = { "secp256r1": "nistp256", @@ -65,7 +72,6 @@ b"ed25519": "ED25519SignatureField", } - class CERT_TYPE(Enum): """ Certificate types, User certificate/Host certificate @@ -74,45 +80,92 @@ class CERT_TYPE(Enum): USER = 1 HOST = 2 - class CertificateField: """ The base class for certificate fields """ + IS_SET = None + DEFAULT = None + REQUIRED = False + DATA_TYPE = NoneType - is_set = None - - def __init__(self, value, name=None): - self.name = name + def __init__(self, value): self.value = value self.exception = None - self.is_set = True + self.IS_SET = True + self.name = self.get_name() def __str__(self): return f"{self.name}: {self.value}" - @staticmethod - def encode(value) -> bytes: + def __bytes__(self) -> bytes: + return self.encode(self.value) + + @classmethod + def get_name(cls) -> str: + return "_".join( + re.findall('[A-Z][^A-Z]*', cls.__class__.__name__)[:-1] + ).lower() + + @classmethod + def __validate_type(cls, value, do_raise: bool = False) -> Union[bool, Exception]: """ - Returns the encoded value of the field + Validate the data type of the value against the class data type """ + if not isinstance(value, cls.DATA_TYPE): + ex = _EX.InvalidDataException( + f"Invalid data type for {cls.get_name()} (expected {cls.DATA_TYPE}, got {type(value)})" + ) + + if do_raise: + raise ex + + return ex - @staticmethod - def decode(data: bytes) -> tuple: + return True + + def __validate_required(self) -> Union[bool, Exception]: """ - Returns the decoded value of the field + Validates if the field is set when required """ + if self.DEFAULT == self.value == None: + return _EX.InvalidFieldDataException( + f"{self.get_name()} is a required field" + ) + return True - def __bytes__(self) -> bytes: - return self.encode(self.value) + def __validate_value(self) -> Union[bool, Exception]: + """ + Validates the contents of the field + Meant to be overridden by child classes + """ + return True # pylint: disable=no-self-use - def validate(self) -> Union[bool, Exception]: + def validate(self) -> bool: """ Validates the field """ - return True + self.exception = ( + self.__validate_type(self.value), + self.__validate_required(), + self.__validate_value() + ) + + return self.exception == [True, True, True] + + @staticmethod + def decode(cls, data: bytes) -> tuple: + """ + Returns the decoded value of the field + """ + @classmethod + def encode(cls, value) -> bytes: + """ + Returns the encoded value of the field + """ + @classmethod def from_decode(cls, data: bytes) -> Tuple["CertificateField", bytes]: """ @@ -123,15 +176,30 @@ def from_decode(cls, data: bytes) -> Tuple["CertificateField", bytes]: """ value, data = cls.decode(data) return cls(value), data + + @classmethod + def factory(cls) -> 'CertificateField': + """ + Factory to create field with default value if set, otherwise empty + Returns: + CertificateField: A new CertificateField subclass instance + """ + if callable(cls.DEFAULT): + return cls(cls.DEFAULT()) + if cls.DEFAULT is None: + return cls + + return cls(cls.DEFAULT) class BooleanField(CertificateField): """ - Field representing a boolean value (True/False) + Field representing a boolean value (True/False) or (1/0) """ + DATA_TYPE = (bool, int) - @staticmethod - def encode(value: bool) -> bytes: + @classmethod + def encode(cls, value: Union[int, bool]) -> bytes: """ Encodes a boolean value to a byte string @@ -141,11 +209,7 @@ def encode(value: bool) -> bytes: Returns: bytes: Packed byte representing the boolean """ - if not isinstance(value, bool): - raise _EX.InvalidFieldDataException( - f"Expected bool, got {value.__class__.__name__}" - ) - + cls.__validate_type(value, True) return pack("B", 1 if value else 0) @staticmethod @@ -157,26 +221,24 @@ def decode(data: bytes) -> Tuple[bool, bytes]: data (bytes): The byte string starting with an encoded boolean """ return bool(unpack("B", data[:1])[0]), data[1:] - - def validate(self) -> Union[bool, Exception]: + + def __validate_value(self) -> Union[bool, Exception]: """ - Validate the field data + Validates the contents of the field """ - if self.value not in [True, False]: - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not a boolean" - ) - - return True - + return True if self.value in (True, False, 1, 0) else _EX.InvalidFieldDataException( + f"{self.get_name()} must be a boolean (True/1 or False/0)" + ) class BytestringField(CertificateField): """ Field representing a bytestring value """ + DATA_TYPE = (bytes, str) + DEFAULT = b"" - @staticmethod - def encode(value: bytes) -> bytes: + @classmethod + def encode(cls, value: bytes) -> bytes: """ Encodes a string or bytestring into a packed byte string @@ -187,12 +249,8 @@ def encode(value: bytes) -> bytes: Returns: bytes: Packed byte string containing the source data """ - if not isinstance(value, bytes): - raise _EX.InvalidFieldDataException( - f"Expected bytes, got {value.__class__.__name__}" - ) - - return pack(">I", len(value)) + value + cls.__validate_type(value, True) + return pack(">I", len(value)) + ensure_bytestring(value) @staticmethod def decode(data: bytes) -> Tuple[bytes, bytes]: @@ -207,27 +265,17 @@ def decode(data: bytes) -> Tuple[bytes, bytes]: string and remainder of the data """ length = unpack(">I", data[:4])[0] + 4 - return data[4:length], data[length:] - - def validate(self) -> Union[bool, Exception]: - """ - Validate the field data - """ - if not isinstance(self.value, bytes): - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not a bytestring" - ) - - return True - + return ensure_bytestring(data[4:length]), data[length:] class StringField(BytestringField): """ Field representing a string value """ + DATA_TYPE = (str, bytes) + DEFAULT = "" - @staticmethod - def encode(value: str, encoding: str = "utf-8"): + @classmethod + def encode(cls, value: str, encoding: str = "utf-8"): """ Encodes a string or bytestring into a packed byte string @@ -238,11 +286,8 @@ def encode(value: str, encoding: str = "utf-8"): Returns: bytes: Packed byte string containing the source data """ - if not isinstance(value, str): - raise _EX.InvalidFieldDataException( - f"Expected str, got {value.__class__.__name__}" - ) - return BytestringField.encode(value.encode(encoding)) + cls.__validate_type(value, True) + return BytestringField.encode(ensure_bytestring(encoding)) @staticmethod def decode(data: bytes, encoding: str = "utf-8") -> Tuple[str, bytes]: @@ -258,27 +303,17 @@ def decode(data: bytes, encoding: str = "utf-8") -> Tuple[str, bytes]: """ value, data = BytestringField.decode(data) - return value.decode(encoding), data - - def validate(self) -> Union[bool, Exception]: - """ - Validate the field data - """ - if not isinstance(self.value, str): - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not a string" - ) - - return True - + return value.decode(encoding), data class Integer32Field(CertificateField): """ Certificate field representing a 32-bit integer """ + DATA_TYPE = int + DEFAULT = 0 - @staticmethod - def encode(value: int) -> bytes: + @classmethod + def encode(cls, value: int) -> bytes: """Encodes a 32-bit integer value to a packed byte string Args: @@ -287,11 +322,7 @@ def encode(value: int) -> bytes: Returns: bytes: Packed byte string containing integer """ - if not isinstance(value, int): - raise _EX.InvalidFieldDataException( - f"Expected int, got {value.__class__.__name__}" - ) - + cls.__validate_type(value, True) return pack(">I", value) @staticmethod @@ -306,30 +337,23 @@ def decode(data: bytes) -> Tuple[int, bytes]: """ return int(unpack(">I", data[:4])[0]), data[4:] - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validate the field data + Validates the contents of the field """ - if not isinstance(self.value, int): - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not an integer" - ) - - if self.value > MAX_INT32: - return _EX.IntegerOverflowException( - f"Passed value {self.value} is too large for a 32-bit integer" - ) - - return True - + return True if self.value < MAX_INT32 else _EX.InvalidFieldDataException( + f"{self.get_name()} must be a 32-bit integer" + ) class Integer64Field(CertificateField): """ Certificate field representing a 64-bit integer """ + DATA_TYPE = int + DEFAULT = 0 - @staticmethod - def encode(value: int) -> bytes: + @classmethod + def encode(cls, value: int) -> bytes: """Encodes a 64-bit integer value to a packed byte string Args: @@ -338,11 +362,7 @@ def encode(value: int) -> bytes: Returns: bytes: Packed byte string containing integer """ - if not isinstance(value, int): - raise _EX.InvalidFieldDataException( - f"Expected int, got {value.__class__.__name__}" - ) - + cls.__validate_type(value, True) return pack(">Q", value) @staticmethod @@ -357,21 +377,13 @@ def decode(data: bytes) -> Tuple[int, bytes]: """ return int(unpack(">Q", data[:8])[0]), data[8:] - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validate the field data + Validates the contents of the field """ - if not isinstance(self.value, int): - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not an integer" - ) - - if self.value > MAX_INT64: - return _EX.IntegerOverflowException( - f"Passed value {self.value} is too large for a 64-bit integer" - ) - - return True + return True if self.value < MAX_INT32 else _EX.InvalidFieldDataException( + f"{self.get_name()} must be a 32-bit integer" + ) class DateTimeField(Integer64Field): @@ -379,9 +391,11 @@ class DateTimeField(Integer64Field): Certificate field representing a datetime value. The value is saved as a 64-bit integer (unix timestamp) """ + DATA_TYPE = (datetime, int) + DEFAULT = datetime.now - @staticmethod - def encode(value: Union[datetime, int]) -> bytes: + @classmethod + def encode(cls, value: Union[datetime, int]) -> bytes: """Encodes a datetime object to a byte string Args: @@ -390,10 +404,7 @@ def encode(value: Union[datetime, int]) -> bytes: Returns: bytes: Packed byte string containing datetime timestamp """ - if not isinstance(value, (datetime, int)): - raise _EX.InvalidFieldDataException( - f"Expected datetime, got {value.__class__.__name__}" - ) + cls.__validate_type(value, True) if isinstance(value, datetime): value = int(value.timestamp()) @@ -411,30 +422,18 @@ def decode(data: bytes) -> datetime: tuple: Tuple with datetime and remainder of data """ timestamp, data = Integer64Field.decode(data) - return datetime.fromtimestamp(timestamp), data - def validate(self) -> Union[bool, Exception]: - """ - Validate the field data - """ - if not isinstance(self.value, (datetime, int)): - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not a datetime object" - ) - - return True - - class MpIntegerField(BytestringField): """ Certificate field representing a multiple precision integer, an integer too large to fit in 64 bits. """ + DATA_TYPE = int + DEFAULT = 0 - @staticmethod - # pylint: disable=arguments-differ - def encode(value: int) -> bytes: + @classmethod + def encode(cls, value: int) -> bytes: """ Encodes a multiprecision integer (integer larger than 64bit) into a packed byte string @@ -445,11 +444,7 @@ def encode(value: int) -> bytes: Returns: bytes: Packed byte string containing integer """ - if not isinstance(value, int): - raise _EX.InvalidFieldDataException( - f"Expected int, got {value.__class__.__name__}" - ) - + cls.__validate_type(value, True) return BytestringField.encode(long_to_bytes(value)) @staticmethod @@ -465,25 +460,15 @@ def decode(data: bytes) -> Tuple[int, bytes]: mpint, data = BytestringField.decode(data) return bytes_to_long(mpint), data - def validate(self) -> Union[bool, Exception]: - """ - Validate the field data - """ - if not isinstance(self.value, int): - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not an integer" - ) - - return True - - class ListField(CertificateField): """ Certificate field representing a list or tuple of strings """ + DATA_TYPE = (list, set, tuple) + DEFAULT = [] - @staticmethod - def encode(value: Union[list, tuple, set]) -> bytes: + @classmethod + def encode(cls, value: Union[list, tuple, set]) -> bytes: """Encodes a list or tuple to a byte string Args: @@ -493,10 +478,7 @@ def encode(value: Union[list, tuple, set]) -> bytes: Returns: bytes: Packed byte string containing the source data """ - if not isinstance(value, (list, tuple, set)): - raise _EX.InvalidFieldDataException( - f"Expected (list, tuple, set), got {value.__class__.__name__}" - ) + cls.__validate_type(value, True) try: if sum([not isinstance(item, (str, bytes)) for item in value]) > 0: @@ -526,31 +508,27 @@ def decode(data: bytes) -> Tuple[list, bytes]: return ensure_string(decoded), data - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validate the field data + Validates the contents of the field """ - if not isinstance(self.value, (list, tuple)): - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not a list/tuple" - ) - - if sum([not isinstance(item, (str, bytes)) for item in self.value]) > 0: + if not all((isinstance(val, (str, bytes)) for val in self.value)): return _EX.InvalidFieldDataException( "Expected list or tuple containing strings or bytes" ) return True - class KeyValueField(CertificateField): """ Certificate field representing a list or integer in python, separated in byte-form by null-bytes. """ + DATA_TYPE = (list, tuple, set, dict) + DEFAULT = {} - @staticmethod - def encode(value: Union[list, tuple, dict, set]) -> bytes: + @classmethod + def encode(cls, value: Union[list, tuple, dict, set]) -> bytes: """ Encodes a dict, set, list or tuple into a key-value byte string. If a set, list or tuple is provided, the items are considered keys @@ -562,10 +540,7 @@ def encode(value: Union[list, tuple, dict, set]) -> bytes: Returns: bytes: Packed byte string containing the source data """ - if not isinstance(value, (list, tuple, dict, set)): - raise _EX.InvalidFieldDataException( - f"Expected (list, tuple, dict, set), got {value.__class__.__name__}" - ) + cls.__validate_type(value, True) if not isinstance(value, dict): value = {item: "" for item in value} @@ -576,7 +551,7 @@ def encode(value: Union[list, tuple, dict, set]) -> bytes: list_data += StringField.encode(key) item = ( - StringField.encode(item) + StringField.encode("") if item in ["", b""] else ListField.encode( [item] if isinstance(item, (str, bytes)) else item @@ -608,94 +583,78 @@ def decode(data: bytes) -> Tuple[dict, bytes]: decoded[key] = value + decoded = ensure_string(decoded) + if "".join(decoded.values()) == "": return list(decoded.keys()), data - return ensure_string(decoded), data + return decoded, data - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validate the field data + Validates the contents of the field """ - if not isinstance(self.value, (list, tuple, dict, set)): + testvals = ( + self.value if not isinstance(self.value, dict) + else list(self.value.keys()) + list(self.value.values()) + ) + + if not all((isinstance(val, (str, bytes)) for val in testvals)): return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not a list/tuple/dict/set" - ) - - if isinstance(self.value, (dict)): - if ( - sum([not isinstance(item, (str, bytes)) for item in self.value.keys()]) - > 0 - ): - return _EX.InvalidFieldDataException( - "Expected a dict with string or byte keys" + "Expected dict, list, tuple, set with string or byte keys and values" ) - if ( - isinstance(self.value, (list, tuple, set)) - and sum([not isinstance(item, (str, bytes)) for item in self.value]) > 0 - ): - return _EX.InvalidFieldDataException( - "Expected list, tuple or set containing strings or bytes" - ) - return True - class PubkeyTypeField(StringField): """ Contains the certificate type, which is based on the public key type the certificate is created for, e.g. 'ssh-ed25519-cert-v01@openssh.com' for an ED25519 key """ - - def __init__(self, value: str): - super().__init__( - value=value, - name="pubkey_type", - ) - - def validate(self) -> Union[bool, Exception]: + DEFAULT = None + DATA_TYPE = (str, bytes) + ALLOWED_VALUES = ( + "ssh-rsa-cert-v01@openssh.com", + "rsa-sha2-256-cert-v01@openssh.com", + "rsa-sha2-512-cert-v01@openssh.com", + "ssh-dss-cert-v01@openssh.com", + "ecdsa-sha2-nistp256-cert-v01@openssh.com", + "ecdsa-sha2-nistp384-cert-v01@openssh.com", + "ecdsa-sha2-nistp521-cert-v01@openssh.com", + "ssh-ed25519-cert-v01@openssh.com", + ) + + def __validate_value(self) -> Union[bool, Exception]: """ - Validate the field data + Validates the contents of the field """ - if self.value not in ( - "ssh-rsa-cert-v01@openssh.com", - "rsa-sha2-256-cert-v01@openssh.com", - "rsa-sha2-512-cert-v01@openssh.com", - "ssh-dss-cert-v01@openssh.com", - "ecdsa-sha2-nistp256-cert-v01@openssh.com", - "ecdsa-sha2-nistp384-cert-v01@openssh.com", - "ecdsa-sha2-nistp521-cert-v01@openssh.com", - "ssh-ed25519-cert-v01@openssh.com", - ): - return _EX.InvalidDataException(f"Invalid pubkey type: {self.value}") - + if ensure_string(self.value) not in self.ALLOWED_VALUES: + return _EX.InvalidFieldDataException( + "Expected one of the following values: {}".format( + NEWLINE.join(self.ALLOWED_VALUES) + ) + ) + return True - class NonceField(StringField): """ Contains the nonce for the certificate, randomly generated this protects the integrity of the private key, especially for ecdsa. """ + DEFAULT = generate_secure_nonce + DATA_TYPE = (str, bytes) - def __init__(self, value: str = None): - super().__init__( - value=value if value is not None else generate_secure_nonce(), name="nonce" - ) - - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validate the field data + Validates the contents of the field """ if len(self.value) < 32: - self.exception = _EX.InsecureNonceException( - "Nonce should be at least 32 bytes long to be secure. " - + "This is especially important for ECDSA" + return _EX.InvalidFieldDataException( + "Expected a nonce of at least 32 bytes" ) - return False return True @@ -705,9 +664,8 @@ class PublicKeyField(CertificateField): Contains the subject (User or Host) public key for whom/which the certificate is created. """ - - def __init__(self, value: PublicKey): - super().__init__(value=value, name="public_key") + DEFAULT = None + DATA_TYPE = PublicKey def __str__(self) -> str: return " ".join( @@ -717,8 +675,8 @@ def __str__(self) -> str: ] ) - @staticmethod - def encode(value: PublicKey) -> bytes: + @classmethod + def encode(cls, value: PublicKey) -> bytes: """ Encode the certificate field to a byte string @@ -728,11 +686,7 @@ def encode(value: PublicKey) -> bytes: Returns: bytes: A byte string with the encoded public key """ - if not isinstance(value, PublicKey): - raise _EX.InvalidFieldDataException( - f"Expected PublicKey, got {value.__class__.__name__}" - ) - + cls.__validate_type(value, True) return BytestringField.decode(value.raw_bytes())[1] @staticmethod @@ -762,6 +716,8 @@ class RSAPubkeyField(PublicKeyField): """ Holds the RSA Public Key for RSA Certificates """ + DEFAULT = None + DATA_TYPE = RSAPublicKey @staticmethod def decode(data: bytes) -> Tuple[RSAPublicKey, bytes]: @@ -780,22 +736,12 @@ def decode(data: bytes) -> Tuple[RSAPublicKey, bytes]: return RSAPublicKey.from_numbers(e=e, n=n), data - def validate(self) -> Union[bool, Exception]: - """ - Validates that the field data is a valid RSA Public Key - """ - if not isinstance(self.value, RSAPublicKey): - return _EX.InvalidFieldDataException( - "This public key class is not valid for use in a certificate" - ) - - return True - - class DSAPubkeyField(PublicKeyField): """ Holds the DSA Public Key for DSA Certificates """ + DEFAULT = None + DATA_TYPE = DSAPublicKey @staticmethod def decode(data: bytes) -> Tuple[DSAPublicKey, bytes]: @@ -816,22 +762,12 @@ def decode(data: bytes) -> Tuple[DSAPublicKey, bytes]: return DSAPublicKey.from_numbers(p=p, q=q, g=g, y=y), data - def validate(self) -> Union[bool, Exception]: - """ - Validates that the field data is a valid DSA Public Key - """ - if not isinstance(self.value, DSAPublicKey): - return _EX.InvalidFieldDataException( - "This public key class is not valid for use in a certificate" - ) - - return True - - class ECDSAPubkeyField(PublicKeyField): """ Holds the ECDSA Public Key for ECDSA Certificates """ + DEFAULT = None + DATA_TYPE = ECDSAPublicKey @staticmethod def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]: @@ -863,22 +799,12 @@ def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]: data, ) - def validate(self) -> Union[bool, Exception]: - """ - Validates that the field data is a valid ECDSA Public Key - """ - if not isinstance(self.value, ECDSAPublicKey): - return _EX.InvalidFieldDataException( - "This public key class is not valid for use in a certificate" - ) - - return True - - class ED25519PubkeyField(PublicKeyField): """ Holds the ED25519 Public Key for ED25519 Certificates """ + DEFAULT = None + DATA_TYPE = ED25519PublicKey @staticmethod def decode(data: bytes) -> Tuple[ED25519PublicKey, bytes]: @@ -896,27 +822,14 @@ def decode(data: bytes) -> Tuple[ED25519PublicKey, bytes]: return ED25519PublicKey.from_raw_bytes(pubkey), data - def validate(self) -> Union[bool, Exception]: - """ - Validates that the field data is a valid ED25519 Public Key - """ - if not isinstance(self.value, ED25519PublicKey): - return _EX.InvalidFieldDataException( - "This public key class is not valid for use in a certificate" - ) - - return True - class SerialField(Integer64Field): """ Contains the numeric serial number of the certificate, maximum is (2**64)-1 """ - - def __init__(self, value: int): - super().__init__(value=value, name="serial") - + DEFAULT = random_serial + DATA_TYPE = int class CertificateTypeField(Integer32Field): """ @@ -924,14 +837,17 @@ class CertificateTypeField(Integer32Field): User certificate: CERT_TYPE.USER/1 Host certificate: CERT_TYPE.HOST/2 """ + DEFAULT = CERT_TYPE.USER + DATA_TYPE = Union[CERT_TYPE, int] + ALLOWED_VALUES = ( + CERT_TYPE.USER, + CERT_TYPE.HOST, + 1, + 2 + ) - def __init__(self, value: Union[CERT_TYPE, int]): - super().__init__( - value=value.value if isinstance(value, CERT_TYPE) else value, name="type" - ) - - @staticmethod - def encode(value: Union[CERT_TYPE, int]) -> bytes: + @classmethod + def encode(cls, value: Union[CERT_TYPE, int]) -> bytes: """ Encode the certificate type field to a byte string @@ -941,81 +857,69 @@ def encode(value: Union[CERT_TYPE, int]) -> bytes: Returns: bytes: A byte string with the encoded public key """ - if not isinstance(value, (CERT_TYPE, int)): - raise _EX.InvalidFieldDataException( - f"Expected (CERT_TYPE, int), got {value.__class__.__name__}" - ) + cls.__validate_type(value, True) if isinstance(value, CERT_TYPE): value = value.value return Integer32Field.encode(value) - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validates that the field contains a valid type + Validates the contents of the field """ - if not isinstance(self.value, (CERT_TYPE, int)): - return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not an integer" + if self.value not in self.ALLOWED_VALUES: + return _EX.InvalidCertificateFieldException( + f"The certificate type is invalid (expected {','.join(self.ALLOWED_VALUES)})" ) - - if not isinstance(self.value, CERT_TYPE) and (self.value > 2 or self.value < 1): - return _EX.InvalidDataException( - "The certificate type is invalid (1: User, 2: Host)" - ) - + return True - -class KeyIDField(StringField): +class KeyIdField(StringField): """ Contains the key identifier (subject) of the certificate, alphanumeric string """ - - def __init__(self, value: str): - super().__init__(value=value, name="key_id") - - def validate(self) -> Union[bool, Exception]: - """ - Validates that the field is set and not empty - """ - if self.value in [None, False, "", " "]: - return _EX.InvalidDataException("You need to provide a Key ID") - - return super().validate() - + DEFAULT = random_keyid + DATA_TYPE = Union[str, bytes] class PrincipalsField(ListField): """ Contains a list of principals for the certificate, - e.g. SERVERHOSTNAME01 or all-web-servers + e.g. SERVERHOSTNAME01 or all-web-servers. + If no principals are added, the certificate is valid + only for servers that have no allowed principals specified """ + DEFAFULT = [] + DATA_TYPE = Union[list, set, tuple] - def __init__(self, value: Union[list, tuple]): - super().__init__(value=list(value), name="principals") - - -class ValidityStartField(DateTimeField): +class ValidAfterField(DateTimeField): """ Contains the start of the validity period for the certificate, represented by a datetime object """ + DEFAULT = datetime.now() + DATA_TYPE = Union[datetime, int] - def __init__(self, value: datetime): - super().__init__(value=value, name="valid_after") - - -class ValidityEndField(DateTimeField): +class ValidBeforeField(DateTimeField): """ Contains the end of the validity period for the certificate, represented by a datetime object """ - - def __init__(self, value: datetime): - super().__init__(value=value, name="valid_before") - + DEFAULT = datetime.now() + timedelta(minutes=10) + DATA_TYPE = Union[datetime, int] + + def __validate_value(self) -> Union[bool, Exception]: + """ + Validates the contents of the field + """ + val = val if isinstance(val, datetime) else datetime.fromtimestamp(val) + if val < datetime.now(): + return _EX.InvalidCertificateFieldException( + f'The certificate validity period is invalid (expected a future datetime object)' + ) + + return True class CriticalOptionsField(KeyValueField): """ @@ -1033,34 +937,27 @@ class CriticalOptionsField(KeyValueField): If set to true, the user must verify their identity if using a hardware token """ + DEFAULT = [] + DATA_TYPE = Union[list, set, tuple, dict] + ALLOWED_VALUES = ( + "force-command", + "source-address", + "verify-required" + ) - def __init__(self, value: Union[list, tuple, dict]): - super().__init__(value=value, name="critical_options") - - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validate that the field contains a valid list of options + Validates the contents of the field """ - valid_opts = ("force-command", "source-address", "verify-required") - - if not isinstance(self.value, (list, tuple, dict, set)): - return _EX.InvalidFieldDataException( - "You need to provide a list, tuple or set of strings or a dict" - ) - - if not all( - elem in valid_opts - for elem in ( - self.value.keys() if isinstance(self.value, dict) else self.value - ) - ): - return _EX.InvalidFieldDataException( - "You have provided invalid data to the critical options field" - ) + for elem in self.value if not isinstance(self.value, dict) else list(self.value.keys()): + if elem not in self.ALLOWED_VALUES: + return _EX.InvalidCertificateFieldException( + f"Critical option not recognized ({elem}){NEWLINE}" + + f"Valid options are {', '.join(self.ALLOWED_VALUES)}" + ) return True - class ExtensionsField(KeyValueField): """ Contains a list of extensions for the certificate, @@ -1088,26 +985,27 @@ class ExtensionsField(KeyValueField): Permits the user to use the user rc file """ + DEFAULT = [] + DATA_TYPE = Union[list, set, tuple, dict] + ALLOWED_VALUES = ( + "no-touch-required", + "permit-X11-forwarding", + "permit-agent-forwarding", + "permit-port-forwarding", + "permit-pty", + "permit-user-rc" + ) - def __init__(self, value: Union[list, tuple]): - super().__init__(value=value, name="extensions") - - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validates that the options provided are valid + Validates the contents of the field """ - valid_opts = ( - "no-touch-required", - "permit-X11-forwarding", - "permit-agent-forwarding", - "permit-port-forwarding", - "permit-pty", - "permit-user-rc", - ) - for item in self.value: - if item not in valid_opts: - return _EX.InvalidDataException(f"The extension '{item}' is invalid") + if item not in self.ALLOWED_VALUES: + return _EX.InvalidDataException( + f"Invalid extension '{item}'{NEWLINE}" + + f"Allowed values are: {NEWLINE.join(self.ALLOWED_VALUES)}" + ) return True @@ -1117,30 +1015,24 @@ class ReservedField(StringField): This field is reserved for future use, and doesn't contain any actual data, just an empty string. """ + DEFAULT = "" + DATA_TYPE = str - def __init__(self, value: str = ""): - super().__init__(value=value, name="reserved") - - def validate(self) -> Union[bool, Exception]: + def __validate_value(self) -> Union[bool, Exception]: """ - Validate that the field only contains an empty string + Validates the contents of the field """ - if self.value == "": - return True - - return _EX.InvalidDataException( - "The reserved field needs to be an empty string" + return True if self.value == "" else _EX.InvalidDataException( + f"The reserved field is not empty" ) - class CAPublicKeyField(BytestringField): """ Contains the public key of the certificate authority that is used to sign the certificate. """ - - def __init__(self, value: PublicKey): - super().__init__(value=value, name="ca_public_key") + DEFAULT = None + DATA_TYPE = Union[str, bytes] def __str__(self) -> str: return " ".join( @@ -1203,15 +1095,16 @@ class SignatureField(CertificateField): """ Creates and contains the signature of the certificate """ + DEFAULT = None + DATA_TYPE = bytes # pylint: disable=super-init-not-called def __init__(self, private_key: PrivateKey = None, signature: bytes = None): - self.name = "signature" self.private_key = private_key self.is_signed = False self.value = signature - if signature is not None and ensure_string(signature) not in ("", " "): + if signature is not None and ensure_bytestring(signature) not in ("", " "): self.is_signed = True @staticmethod @@ -1219,7 +1112,6 @@ def from_object(private_key: PrivateKey): """ Load a private key from a PrivateKey object - Args: private_key (PrivateKey): Private key to use for signing @@ -1281,6 +1173,8 @@ class RSASignatureField(SignatureField): """ Creates and contains the RSA signature from an RSA Private Key """ + DEFAULT = None + DATA_TYPE = bytes def __init__( self, @@ -1291,11 +1185,11 @@ def __init__( super().__init__(private_key, signature) self.hash_alg = hash_alg - @staticmethod + @classmethod # pylint: disable=arguments-renamed - def encode(signature: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256) -> bytes: + def encode(cls, value: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256) -> bytes: """ - Encodes the signature to a byte string + Encodes the value to a byte string Args: signature (bytes): The signature bytes to encode @@ -1305,13 +1199,10 @@ def encode(signature: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256) -> bytes: Returns: bytes: The encoded byte string """ - if not isinstance(signature, bytes): - raise _EX.InvalidFieldDataException( - f"Expected bytes, got {signature.__class__.__name__}" - ) + cls.__validate_type(value, True) return BytestringField.encode( - StringField.encode(hash_alg.value[0]) + BytestringField.encode(signature) + StringField.encode(hash_alg.value[0]) + BytestringField.encode(value) ) @staticmethod @@ -1379,15 +1270,16 @@ class DSASignatureField(SignatureField): """ Creates and contains the DSA signature from an DSA Private Key """ + DEFAULT = None + DATA_TYPE = bytes def __init__( self, private_key: DSAPrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) - @staticmethod - # pylint: disable=arguments-renamed - def encode(signature: bytes): + @classmethod + def encode(cls, value: bytes): """ Encodes the signature to a byte string @@ -1397,12 +1289,9 @@ def encode(signature: bytes): Returns: bytes: The encoded byte string """ - if not isinstance(signature, bytes): - raise _EX.InvalidFieldDataException( - f"Expected bytes, got {signature.__class__.__name__}" - ) + cls.__validate_type(value, True) - r, s = decode_dss_signature(signature) + r, s = decode_dss_signature(value) return BytestringField.encode( StringField.encode("ssh-dss") @@ -1460,6 +1349,8 @@ class ECDSASignatureField(SignatureField): """ Creates and contains the ECDSA signature from an ECDSA Private Key """ + DEFAULT = None + DATA_TYPE = bytes def __init__( self, @@ -1475,9 +1366,8 @@ def __init__( self.curve = curve_name - @staticmethod - # pylint: disable=arguments-renamed - def encode(signature: bytes, curve_name: str = None) -> bytes: + @classmethod + def encode(cls, value: bytes, curve_name: str = None) -> bytes: """ Encodes the signature to a byte string @@ -1489,12 +1379,9 @@ def encode(signature: bytes, curve_name: str = None) -> bytes: Returns: bytes: The encoded byte string """ - if not isinstance(signature, bytes): - raise _EX.InvalidFieldDataException( - f"Expected bytes, got {signature.__class__.__name__}" - ) + cls.__validate_type(value, True) - r, s = decode_dss_signature(signature) + r, s = decode_dss_signature(value) return BytestringField.encode( StringField.encode(curve_name) @@ -1562,15 +1449,16 @@ class ED25519SignatureField(SignatureField): """ Creates and contains the ED25519 signature from an ED25519 Private Key """ + DEFAULT = None + DATA_TYPE = bytes def __init__( self, private_key: ED25519PrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) - @staticmethod - # pylint: disable=arguments-renamed - def encode(signature: bytes) -> None: + @classmethod + def value(cls, value: bytes) -> None: """ Encodes the signature to a byte string @@ -1580,13 +1468,10 @@ def encode(signature: bytes) -> None: Returns: bytes: The encoded byte string """ - if not isinstance(signature, bytes): - raise _EX.InvalidFieldDataException( - f"Expected bytes, got {signature.__class__.__name__}" - ) + cls.__validate_type(value, True) return BytestringField.encode( - StringField.encode("ssh-ed25519") + BytestringField.encode(signature) + StringField.encode("ssh-ed25519") + BytestringField.encode(value) ) @staticmethod @@ -1632,3 +1517,55 @@ def sign(self, data: bytes, **kwargs) -> None: """ self.value = self.private_key.sign(data) self.is_signed = True + +# ADD: Default class for each field +# ADD: Typing class for each field + +@dataclass +class Fieldset: + def __setattr__(self, name, value): + field = getattr(self, name, None) + if callable(field) and not isinstance(field, CertificateField): + if field.__name__ == "factory": + super().__setattr__(name, field()) + self.__setattr__(name, value) + return + + if isinstance(field, type) and getattr(value, '__name__', '') != 'factory': + super().__setattr__(name, field(value)) + return + + if getattr(value, '__name__', '') != 'factory': + field.value = value + super().__setattr__(name, field) + + def get(self, name: str): + field = getattr(self, name, None) + if field: + if isinstance(field, type): + return field.DEFAULT + return field.value + raise _EX.InvalidCertificateFieldException(f"Unknown field {name}") + +@dataclass +class CertificateHeaders(Fieldset): + public_key: PublicKeyField = PublicKeyField.factory + pubkey_type: PubkeyTypeField = PubkeyTypeField.factory + nonce: NonceField = NonceField.factory + +@dataclass +class CertificateFooter(Fieldset): + reserved: ReservedField = ReservedField.factory + ca_pubkey: CAPublicKeyField = CAPublicKeyField.factory + signature: SignatureField = SignatureField.factory + +@dataclass +class CertificateFields(Fieldset): + serial: SerialField = SerialField.factory + cert_type: CertificateTypeField = CertificateTypeField.factory + key_id: KeyIdField = KeyIdField.factory + principals: PrincipalsField = PrincipalsField.factory + valid_after: ValidAfterField = ValidAfterField.factory + valid_before: ValidBeforeField = ValidBeforeField.factory + critical_options: CriticalOptionsField = CriticalOptionsField.factory + extensions: ExtensionsField = ExtensionsField.factory \ No newline at end of file diff --git a/src/sshkey_tools/utils.py b/src/sshkey_tools/utils.py index 3da0a68..2ced2d1 100644 --- a/src/sshkey_tools/utils.py +++ b/src/sshkey_tools/utils.py @@ -2,15 +2,19 @@ Utilities for handling keys and certificates """ import sys +from types import NoneType from typing import Union, List, Dict from secrets import randbits +from random import randint +from uuid import uuid4 from base64 import b64encode import hashlib as hl def ensure_string( - obj: Union[str, bytes, list, tuple, set, dict], - encoding: str = 'utf-8' -) -> Union[str, List[str], Dict[str, str]]: + obj: Union[str, bytes, list, tuple, set, dict, NoneType], + encoding: str = 'utf-8', + required: bool = False +) -> Union[str, List[str], Dict[str, str], NoneType]: """Ensure the provided value is or contains a string/strings Args: @@ -20,8 +24,10 @@ def ensure_string( Returns: Union[str, List[str], Dict[str, str]]: Returns a string, list of strings or dictionary with strings """ - if isinstance(obj, (str, bytes)): - return obj.decode(encoding) if isinstance(obj, bytes) else obj + if (obj is None and not required) or isinstance(obj, str): + return obj + elif isinstance(obj, bytes): + return obj.decode(encoding) elif isinstance(obj, (list, tuple, set)): return [ensure_string(o, encoding) for o in obj] elif isinstance(obj, dict): @@ -30,9 +36,10 @@ def ensure_string( raise TypeError(f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}.") def ensure_bytestring( - obj: Union[str, bytes, list, tuple, set, dict], - encoding: str = 'utf-8' -) -> Union[str, List[str], Dict[str, str]]: + obj: Union[str, bytes, list, tuple, set, dict, NoneType], + encoding: str = 'utf-8', + required: bool = None +) -> Union[str, List[str], Dict[str, str], NoneType]: """Ensure the provided value is or contains a bytestring/bytestrings Args: @@ -42,8 +49,10 @@ def ensure_bytestring( Returns: Union[str, List[str], Dict[str, str]]: Returns a bytestring, list of bytestrings or dictionary with bytestrings """ - if isinstance(obj, (str, bytes)): - return obj.encode(encoding) if isinstance(obj, str) else obj + if (obj is None and not required) or isinstance(obj, bytes): + return obj + elif isinstance(obj, str): + return obj.encode(encoding) elif isinstance(obj, (list, tuple, set)): return [ensure_bytestring(o, encoding) for o in obj] elif isinstance(obj, dict): @@ -61,9 +70,9 @@ def concat_to_string(*strs, encoding: str = 'utf-8') -> str: Returns: str: Concatenated string """ - return ''.join(ensure_string(st, encoding) for st in strs) + return ''.join(st if st is not None else "" for st in ensure_string(strs, encoding)) -def concat_to_bytes(*strs, encoding: str = 'utf-8') -> bytes: +def concat_to_bytestring(*strs, encoding: str = 'utf-8') -> bytes: """Concatenates a list of strings or bytestrings to a single bytestring. Args: @@ -73,7 +82,23 @@ def concat_to_bytes(*strs, encoding: str = 'utf-8') -> bytes: Returns: bytes: Concatenated bytestring """ - return ''.join(ensure_bytestring(st, encoding) for st in strs) + return b"".join(st if st is not None else b"" for st in ensure_bytestring(strs, encoding=encoding)) + +def random_keyid() -> str: + """Generates a random Key ID + + Returns: + str: Random keyid + """ + return uuid4() + +def random_serial() -> str: + """ Generates a random serial number + + Returns: + int: Random serial + """ + return randint(0, 2**64-1) def long_to_bytes( source_int: int, force_length: int = None, byteorder: str = "big" diff --git a/tests/test_certificates.py b/tests/test_certificates.py index 5ad7205..16a661c 100644 --- a/tests/test_certificates.py +++ b/tests/test_certificates.py @@ -432,12 +432,12 @@ def test_invalid_certificate_field(self): def test_key_id_field(self): self.assertRandomResponse( - _FIELD.KeyIDField, + _FIELD.KeyIdField, random_function=lambda : self.faker.pystr(8, 128) ) def test_invalid_key_id_field(self): - field = _FIELD.KeyIDField('') + field = _FIELD.KeyIdField('') self.assertIsInstance( field.validate(), @@ -474,12 +474,12 @@ def test_invalid_principals_field(self): def test_validity_start_field(self): self.assertRandomResponse( - _FIELD.ValidityStartField, + _FIELD.ValidAfterField, random_function=lambda : self.faker.date_time() ) def test_invalid_validity_start_field(self): - field = _FIELD.ValidityStartField(ValueError) + field = _FIELD.ValidAfterField(ValueError) self.assertIsInstance( field.validate(), @@ -491,12 +491,12 @@ def test_invalid_validity_start_field(self): def test_validity_end_field(self): self.assertRandomResponse( - _FIELD.ValidityEndField, + _FIELD.ValidBeforeField, random_function=lambda : self.faker.date_time() ) def test_invalid_validity_end_field(self): - field = _FIELD.ValidityEndField(ValueError) + field = _FIELD.ValidBeforeField(ValueError) self.assertIsInstance( field.validate(), diff --git a/tests/test_utils.py b/tests/test_utils.py index 6421ef2..d45bedb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ import unittest +import faker from random import randint import src.sshkey_tools.utils as utils @@ -29,6 +30,202 @@ ( b'XKHCjGvvtJgPdDkGSGDyPkZDBif^BZ', '7e:e4:29:df:b4:77:2a:6f:5d:eb:9f:25:8e:bd:45:b6', 'BmF8Pt4E/z0M8/rMy6mXkUpVlDG9Zje5+KA3dBVIR7c', 'I9lVFWPx6YkwnsZRMf21TPFvquV59+ng11F3EFFhDLKHrK/l6cdGQu1K0idWQQfRK44d77z00TK/aKmPmgZ2kg' ), ] +class TestStringBytestringConversion(unittest.TestCase): + def setUp(self): + self.faker = faker.Faker() + + def test_ensure_string(self): + self.assertEqual( + utils.ensure_string(""), + utils.ensure_string(b"") + ) + + self.assertEqual( + None, + utils.ensure_string(None) + ) + + for _ in range(100): + val = self.faker.pystr() + + self.assertEqual( + utils.ensure_string(val), + utils.ensure_string(val.encode('utf-8')) + ) + + + + lst = list([self.faker.pystr() for _ in range(10)]) + lst_byt = list([x.encode('utf-8') if randint(0, 1) == 1 else x for x in lst]) + + self.assertIsInstance( + utils.ensure_string(lst), + list + ) + self.assertEqual( + lst, + utils.ensure_string(lst), + utils.ensure_string(lst_byt) + ) + + tpl = tuple(self.faker.pystr() for _ in range(10)) + tpl_byt = tuple(x.encode('utf-8') if randint(0, 1) == 1 else x for x in tpl) + + self.assertIsInstance( + utils.ensure_string(tpl), + list + ) + + self.assertEqual( + list(tpl), + utils.ensure_string(tpl), + utils.ensure_string(tpl_byt) + ) + + st = set(self.faker.pystr() for _ in range(10)) + st_byt = set(x.encode('utf-8') if randint(0, 1) == 1 else x for x in st) + + self.assertIsInstance( + utils.ensure_string(st), + list + ) + self.assertEqual( + list(st), + utils.ensure_string(st), + utils.ensure_string(st_byt) + ) + + dct = dict({self.faker.pystr(): self.faker.pystr() for _ in range(10)}) + dct_byt = dict({x.encode('utf-8') if randint(0, 1) == 1 else x: y.encode('utf-8') if randint(0, 1) == 1 else y for x, y in dct.items()}) + + self.assertIsInstance( + utils.ensure_string(dct), + dict + ) + self.assertEqual( + dct, + utils.ensure_string(dct), + utils.ensure_string(dct_byt) + ) + + def test_ensure_bytestring(self): + self.assertEqual( + utils.ensure_bytestring(""), + utils.ensure_bytestring(b"") + ) + + self.assertEqual( + None, + utils.ensure_bytestring(None) + ) + + for _ in range(100): + val = self.faker.pystr().encode('utf-8') + + self.assertEqual( + utils.ensure_bytestring(val), + utils.ensure_bytestring(val.decode('utf-8')) + ) + + + + lst = list([self.faker.pystr().encode('utf-8') for _ in range(10)]) + lst_byt = list([x.decode('utf-8') if randint(0, 1) == 1 else x for x in lst]) + + self.assertIsInstance( + utils.ensure_bytestring(lst), + list + ) + self.assertEqual( + lst, + utils.ensure_bytestring(lst), + utils.ensure_bytestring(lst_byt) + ) + + tpl = tuple(self.faker.pystr().encode('utf-8') for _ in range(10)) + tpl_byt = tuple(x.decode('utf-8') if randint(0, 1) == 1 else x for x in lst) + + self.assertIsInstance( + utils.ensure_bytestring(tpl), + list + ) + + self.assertEqual( + list(tpl), + utils.ensure_bytestring(tpl), + utils.ensure_bytestring(tpl_byt) + ) + + st = set(self.faker.pystr().encode('utf-8') for _ in range(10)) + st_byt = set(x.decode('utf-8') if randint(0, 1) == 1 else x for x in lst) + + self.assertIsInstance( + utils.ensure_bytestring(st), + list + ) + self.assertEqual( + list(st), + utils.ensure_bytestring(st), + utils.ensure_bytestring(st_byt) + ) + + dct = dict({self.faker.pystr().encode('utf-8'): self.faker.pystr().encode('utf-8') for _ in range(10)}) + dct_byt = dict({x.decode('utf-8') if randint(0, 1) == 1 else x: y.decode('utf-8') if randint(0, 1) == 1 else y for x, y in dct.items()}) + + self.assertIsInstance( + utils.ensure_bytestring(dct), + dict + ) + self.assertEqual( + dct, + utils.ensure_bytestring(dct), + utils.ensure_bytestring(dct_byt) + ) + + def test_concat_to_string(self): + self.assertEqual( + utils.concat_to_string(""), + utils.concat_to_string(b"") + ) + self.assertEqual( + utils.concat_to_string( + None + ), + "" + ) + + for _ in range(100): + strs = [self.faker.pystr() for _ in range(randint(10, 100))] + strs_byt = [x.encode('utf-8') if randint(0, 1) == 1 else x for x in strs] + + self.assertEqual( + utils.concat_to_string(*strs), + utils.concat_to_string(*strs_byt), + "".join(strs) + ) + + def test_concat_to_bytestring(self): + self.assertEqual( + utils.concat_to_bytestring(b""), + utils.concat_to_bytestring("") + ) + self.assertEqual( + utils.concat_to_bytestring( + None + ), + b"" + ) + + for _ in range(100): + byts = [self.faker.pystr().encode('utf-8') for _ in range(randint(10, 100))] + byts_str = [x.decode('utf-8') if randint(0, 1) == 1 else x for x in byts] + + self.assertEqual( + utils.concat_to_bytestring(*byts), + utils.concat_to_bytestring(*byts_str), + b"".join(byts) + ) + class TestLongConversion(unittest.TestCase): def test_expected_deflation(self): """ From b9cf45f245ae6a49af45f6919d7a0591bfa12b60 Mon Sep 17 00:00:00 2001 From: Lars Scheibling Date: Fri, 8 Jul 2022 15:25:58 +0000 Subject: [PATCH 2/6] Certificate fieldsets insteas of dicts for header, fields, footer --- .gitignore | 4 +- README.md | 46 +-- docs/cert.html | 168 ++++---- docs/fields.html | 168 ++++---- docs/index.html | 46 +-- docs/keys.html | 450 ++++++++++----------- src/sshkey_tools/cert.py | 777 +++++++++++------------------------- src/sshkey_tools/certold.py | 588 +++++++++++++++++++++++++++ src/sshkey_tools/fields.py | 370 +++++++++-------- src/sshkey_tools/keys.py | 84 ++-- src/sshkey_tools/utils.py | 5 +- testcert | 1 + tests/test_certificates.py | 327 ++++++++------- tests/test_keypairs.py | 208 +++++----- 14 files changed, 1811 insertions(+), 1431 deletions(-) create mode 100644 src/sshkey_tools/certold.py create mode 100644 testcert diff --git a/.gitignore b/.gitignore index a1e902f..ef8570c 100644 --- a/.gitignore +++ b/.gitignore @@ -146,4 +146,6 @@ oldsrc/* tempfolder report.html test_certificate -testing.py \ No newline at end of file +testing.py +.idea +core \ No newline at end of file diff --git a/README.md b/README.md index 9be5041..57ba685 100644 --- a/README.md +++ b/README.md @@ -23,34 +23,34 @@ pip3 install ./ ### Generate keys ```python from sshkey_tools.keys import ( - RSAPrivateKey, - DSAPrivateKey, - ECDSAPrivateKey, - ED25519PrivateKey, + RsaPrivateKey, + DsaPrivateKey, + EcdsaPrivateKey, + Ed25519PrivateKey, EcdsaCurves ) # RSA # By default, RSA is generated with a 4096-bit keysize -rsa_private = RSAPrivateKey.generate() +rsa_private = RsaPrivateKey.generate() # You can also specify the key size -rsa_private = RSAPrivateKey.generate(bits) +rsa_private = RsaPrivateKey.generate(bits) # DSA # Since OpenSSH only supports 1024-bit keys, this is the default -dsa_private = DSAPrivateKey.generate() +dsa_private = DsaPrivateKey.generate() # ECDSA # The default curve is P521 -ecdsa_private = ECDSAPrivateKey.generate() +ecdsa_private = EcdsaPrivateKey.generate() # You can also manually specify a curve -ecdsa_private = ECDSAPrivateKey.generate(EcdsaCurves.P256) +ecdsa_private = EcdsaPrivateKey.generate(EcdsaCurves.P256) # ED25519 # The ED25519 keys are always a fixed size -ed25519_private = ED25519PrivateKey.generate() +ed25519_private = Ed25519PrivateKey.generate() # Public keys # The public key for any given private key is in the public_key parameter @@ -58,29 +58,29 @@ rsa_pub = rsa_private.public_key ``` ### Load keys -You can load keys either directly with the specific key classes (RSAPrivateKey, DSAPrivateKey, etc.) or the general PrivateKey class +You can load keys either directly with the specific key classes (RsaPrivateKey, DsaPrivateKey, etc.) or the general PrivateKey class ```python from sshkey_tools.keys import ( PrivateKey, PublicKey, - RSAPrivateKey, - RSAPublicKey + RsaPrivateKey, + RsaPublicKey ) # Load a private key with a specific class -rsa_private = RSAPrivateKey.from_file('path/to/rsa_key') +rsa_private = RsaPrivateKey.from_file('path/to/rsa_key') # Load a private key with the general class rsa_private = PrivateKey.from_file('path/to/rsa_key') print(type(rsa_private)) -"" +"" # Public keys can be loaded in the same way -rsa_pub = RSAPublicKey.from_file('path/to/rsa_key.pub') +rsa_pub = RsaPublicKey.from_file('path/to/rsa_key.pub') rsa_pub = PublicKey.from_file('path/to/rsa_key.pub') print(type(rsa_private)) -"" +"" # Public key objects are automatically created for any given private key # negating the need to load them separately @@ -98,12 +98,12 @@ with open('path/to/rsa_key', 'rb') as file: rsa_private = PrivateKey.from_bytes(file.read()) # RSA, DSA and ECDSA keys can be loaded from the public/private numbers and/or parameters -rsa_public = RSAPublicKey.from_numbers( +rsa_public = RsaPublicKey.from_numbers( e=65537, n=12.........811 ) -rsa_private = RSAPrivateKey.from_numbers( +rsa_private = RsaPrivateKey.from_numbers( e=65537, n=12......811, d=17......122 @@ -132,7 +132,7 @@ from cryptography.hazmat.primitives import ( hashes as crypto_hashes ) from cryptography.hazmat.primitives.asymmetric import padding as crypto_padding -from sshkey_tools.keys import PublicKey, RSAPrivateKey, RsaAlgs +from sshkey_tools.keys import PublicKey, RsaPrivateKey, RsaAlgs from sshkey_tools.cert import SSHCertificate from sshkey_tools.exceptions import SignatureNotPossibleException @@ -240,7 +240,7 @@ ca_pubkey.verify( RsaAlgs.SHA256.value[1] ) -# pyca/cryptography RSAPrivateKey +# pyca/cryptography RsaPrivateKey with open('path/to/ca_pubkey', 'rb') as file: crypto_ca_pubkey = crypto_serialization.load_ssh_public_key(file.read()) @@ -264,13 +264,13 @@ or the base64-decoded byte data of the certificate ```python from sshkey_tools.keys import PublicKey, PrivateKey -from sshkey_tools.cert import SSHCertificate, RSACertificate +from sshkey_tools.cert import SSHCertificate, RsaCertificate # Load an existing certificate certificate = SSHCertificate.from_file('path/to/user_key-cert.pub') # or -certificate = RSACertificate.from_file('path/to/user_key-cert.pub') +certificate = RsaCertificate.from_file('path/to/user_key-cert.pub') # Verify the certificate with a CA public key ca_pubkey = PublicKey.from_file('path/to/ca_key.pub') diff --git a/docs/cert.html b/docs/cert.html index 0df50ff..ce04276 100644 --- a/docs/cert.html +++ b/docs/cert.html @@ -57,10 +57,10 @@

Raises

from .keys import ( PublicKey, PrivateKey, - RSAPublicKey, - DSAPublicKey, - ECDSAPublicKey, - ED25519PublicKey, + RsaPublicKey, + DsaPublicKey, + EcdsaPublicKey, + Ed25519PublicKey, ) from . import fields as _FIELD from . import exceptions as _EX @@ -79,24 +79,24 @@

Raises

} CERT_TYPES = { - "ssh-rsa-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), - "rsa-sha2-256-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), - "rsa-sha2-512-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), - "ssh-dss-cert-v01@openssh.com": ("DSACertificate", "_FIELD.DSAPubkeyField"), + "ssh-rsa-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RSAPubkeyField"), + "rsa-sha2-256-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RSAPubkeyField"), + "rsa-sha2-512-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RSAPubkeyField"), + "ssh-dss-cert-v01@openssh.com": ("DsaCertificate", "_FIELD.DSAPubkeyField"), "ecdsa-sha2-nistp256-cert-v01@openssh.com": ( - "ECDSACertificate", + "EcdsaCertificate", "_FIELD.ECDSAPubkeyField", ), "ecdsa-sha2-nistp384-cert-v01@openssh.com": ( - "ECDSACertificate", + "EcdsaCertificate", "_FIELD.ECDSAPubkeyField", ), "ecdsa-sha2-nistp521-cert-v01@openssh.com": ( - "ECDSACertificate", + "EcdsaCertificate", "_FIELD.ECDSAPubkeyField", ), "ssh-ed25519-cert-v01@openssh.com": ( - "ED25519Certificate", + "Ed25519Certificate", "_FIELD.ED25519PubkeyField", ), } @@ -540,14 +540,14 @@

Raises

file.write(self.to_string(comment, encoding)) -class RSACertificate(SSHCertificate): +class RsaCertificate(SSHCertificate): """ Specific class for RSA Certificates. Inherits from SSHCertificate """ def __init__( self, - subject_pubkey: RSAPublicKey, + subject_pubkey: RsaPublicKey, ca_privkey: PrivateKey = None, rsa_alg: RsaAlgs = RsaAlgs.SHA512, **kwargs, @@ -567,25 +567,25 @@

Raises

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - RSACertificate: The decoded certificate + RsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.RSAPubkeyField) -class DSACertificate(SSHCertificate): +class DsaCertificate(SSHCertificate): """ Specific class for DSA/DSS Certificates. Inherits from SSHCertificate """ def __init__( - self, subject_pubkey: DSAPublicKey, ca_privkey: PrivateKey = None, **kwargs + self, subject_pubkey: DsaPublicKey, ca_privkey: PrivateKey = None, **kwargs ): super().__init__(subject_pubkey, ca_privkey, **kwargs) self.set_type("ssh-dss-cert-v01@openssh.com") @classmethod # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "DSACertificate": + def decode(cls, cert_bytes: bytes) -> "DsaCertificate": """ Decode an existing DSA Certificate @@ -593,18 +593,18 @@

Raises

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - DSACertificate: The decoded certificate + DsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.DSAPubkeyField) -class ECDSACertificate(SSHCertificate): +class EcdsaCertificate(SSHCertificate): """ Specific class for ECDSA Certificates. Inherits from SSHCertificate """ def __init__( - self, subject_pubkey: ECDSAPublicKey, ca_privkey: PrivateKey = None, **kwargs + self, subject_pubkey: EcdsaPublicKey, ca_privkey: PrivateKey = None, **kwargs ): super().__init__(subject_pubkey, ca_privkey, **kwargs) self.set_type( @@ -613,7 +613,7 @@

Raises

@classmethod # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "ECDSACertificate": + def decode(cls, cert_bytes: bytes) -> "EcdsaCertificate": """ Decode an existing ECDSA Certificate @@ -621,25 +621,25 @@

Raises

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - ECDSACertificate: The decoded certificate + EcdsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.ECDSAPubkeyField) -class ED25519Certificate(SSHCertificate): +class Ed25519Certificate(SSHCertificate): """ Specific class for ED25519 Certificates. Inherits from SSHCertificate """ def __init__( - self, subject_pubkey: ED25519PublicKey, ca_privkey: PrivateKey = None, **kwargs + self, subject_pubkey: Ed25519PublicKey, ca_privkey: PrivateKey = None, **kwargs ): super().__init__(subject_pubkey, ca_privkey, **kwargs) self.set_type("ssh-ed25519-cert-v01@openssh.com") @classmethod # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "ED25519Certificate": + def decode(cls, cert_bytes: bytes) -> "Ed25519Certificate": """ Decode an existing ED25519 Certificate @@ -647,7 +647,7 @@

Raises

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - ED25519Certificate: The decoded certificate + Ed25519Certificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.ED25519PubkeyField) @@ -661,9 +661,9 @@

Raises

Classes

-
-class DSACertificate -(subject_pubkey: DSAPublicKey, ca_privkey: PrivateKey = None, **kwargs) +
+class DsaCertificate +(subject_pubkey: DsaPublicKey, ca_privkey: PrivateKey = None, **kwargs)

Specific class for DSA/DSS Certificates. Inherits from SSHCertificate

@@ -671,20 +671,20 @@

Classes

Expand source code -
class DSACertificate(SSHCertificate):
+
class DsaCertificate(SSHCertificate):
     """
     Specific class for DSA/DSS Certificates. Inherits from SSHCertificate
     """
 
     def __init__(
-        self, subject_pubkey: DSAPublicKey, ca_privkey: PrivateKey = None, **kwargs
+        self, subject_pubkey: DsaPublicKey, ca_privkey: PrivateKey = None, **kwargs
     ):
         super().__init__(subject_pubkey, ca_privkey, **kwargs)
         self.set_type("ssh-dss-cert-v01@openssh.com")
 
     @classmethod
     # pylint: disable=arguments-differ
-    def decode(cls, cert_bytes: bytes) -> "DSACertificate":
+    def decode(cls, cert_bytes: bytes) -> "DsaCertificate":
         """
         Decode an existing DSA Certificate
 
@@ -692,7 +692,7 @@ 

Classes

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - DSACertificate: The decoded certificate + DsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.DSAPubkeyField)
@@ -702,8 +702,8 @@

Ancestors

Static methods

-
-def decode(cert_bytes: bytes) ‑> DSACertificate +
+def decode(cert_bytes: bytes) ‑> DsaCertificate

Decode an existing DSA Certificate

@@ -714,7 +714,7 @@

Args

Returns

-
DSACertificate
+
DsaCertificate
The decoded certificate
@@ -723,7 +723,7 @@

Returns

@classmethod
 # pylint: disable=arguments-differ
-def decode(cls, cert_bytes: bytes) -> "DSACertificate":
+def decode(cls, cert_bytes: bytes) -> "DsaCertificate":
     """
     Decode an existing DSA Certificate
 
@@ -731,7 +731,7 @@ 

Returns

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - DSACertificate: The decoded certificate + DsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.DSAPubkeyField)
@@ -760,9 +760,9 @@

Inherited members

-
-class ECDSACertificate -(subject_pubkey: ECDSAPublicKey, ca_privkey: PrivateKey = None, **kwargs) +
+class EcdsaCertificate +(subject_pubkey: EcdsaPublicKey, ca_privkey: PrivateKey = None, **kwargs)

Specific class for ECDSA Certificates. Inherits from SSHCertificate

@@ -770,13 +770,13 @@

Inherited members

Expand source code -
class ECDSACertificate(SSHCertificate):
+
class EcdsaCertificate(SSHCertificate):
     """
     Specific class for ECDSA Certificates. Inherits from SSHCertificate
     """
 
     def __init__(
-        self, subject_pubkey: ECDSAPublicKey, ca_privkey: PrivateKey = None, **kwargs
+        self, subject_pubkey: EcdsaPublicKey, ca_privkey: PrivateKey = None, **kwargs
     ):
         super().__init__(subject_pubkey, ca_privkey, **kwargs)
         self.set_type(
@@ -785,7 +785,7 @@ 

Inherited members

@classmethod # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "ECDSACertificate": + def decode(cls, cert_bytes: bytes) -> "EcdsaCertificate": """ Decode an existing ECDSA Certificate @@ -793,7 +793,7 @@

Inherited members

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - ECDSACertificate: The decoded certificate + EcdsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.ECDSAPubkeyField)
@@ -803,8 +803,8 @@

Ancestors

Static methods

-
-def decode(cert_bytes: bytes) ‑> ECDSACertificate +
+def decode(cert_bytes: bytes) ‑> EcdsaCertificate

Decode an existing ECDSA Certificate

@@ -815,7 +815,7 @@

Args

Returns

-
ECDSACertificate
+
EcdsaCertificate
The decoded certificate
@@ -824,7 +824,7 @@

Returns

@classmethod
 # pylint: disable=arguments-differ
-def decode(cls, cert_bytes: bytes) -> "ECDSACertificate":
+def decode(cls, cert_bytes: bytes) -> "EcdsaCertificate":
     """
     Decode an existing ECDSA Certificate
 
@@ -832,7 +832,7 @@ 

Returns

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - ECDSACertificate: The decoded certificate + EcdsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.ECDSAPubkeyField)
@@ -861,9 +861,9 @@

Inherited members

-
-class ED25519Certificate -(subject_pubkey: ED25519PublicKey, ca_privkey: PrivateKey = None, **kwargs) +
+class Ed25519Certificate +(subject_pubkey: Ed25519PublicKey, ca_privkey: PrivateKey = None, **kwargs)

Specific class for ED25519 Certificates. Inherits from SSHCertificate

@@ -871,20 +871,20 @@

Inherited members

Expand source code -
class ED25519Certificate(SSHCertificate):
+
class Ed25519Certificate(SSHCertificate):
     """
     Specific class for ED25519 Certificates. Inherits from SSHCertificate
     """
 
     def __init__(
-        self, subject_pubkey: ED25519PublicKey, ca_privkey: PrivateKey = None, **kwargs
+        self, subject_pubkey: Ed25519PublicKey, ca_privkey: PrivateKey = None, **kwargs
     ):
         super().__init__(subject_pubkey, ca_privkey, **kwargs)
         self.set_type("ssh-ed25519-cert-v01@openssh.com")
 
     @classmethod
     # pylint: disable=arguments-differ
-    def decode(cls, cert_bytes: bytes) -> "ED25519Certificate":
+    def decode(cls, cert_bytes: bytes) -> "Ed25519Certificate":
         """
         Decode an existing ED25519 Certificate
 
@@ -892,7 +892,7 @@ 

Inherited members

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - ED25519Certificate: The decoded certificate + Ed25519Certificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.ED25519PubkeyField)
@@ -902,8 +902,8 @@

Ancestors

Static methods

-
-def decode(cert_bytes: bytes) ‑> ED25519Certificate +
+def decode(cert_bytes: bytes) ‑> Ed25519Certificate

Decode an existing ED25519 Certificate

@@ -914,7 +914,7 @@

Args

Returns

-
ED25519Certificate
+
Ed25519Certificate
The decoded certificate
@@ -923,7 +923,7 @@

Returns

@classmethod
 # pylint: disable=arguments-differ
-def decode(cls, cert_bytes: bytes) -> "ED25519Certificate":
+def decode(cls, cert_bytes: bytes) -> "Ed25519Certificate":
     """
     Decode an existing ED25519 Certificate
 
@@ -931,7 +931,7 @@ 

Returns

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - ED25519Certificate: The decoded certificate + Ed25519Certificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.ED25519PubkeyField)
@@ -960,9 +960,9 @@

Inherited members

-
-class RSACertificate -(subject_pubkey: RSAPublicKey, ca_privkey: PrivateKey = None, rsa_alg: RsaAlgs = RsaAlgs.SHA512, **kwargs) +
+class RsaCertificate +(subject_pubkey: RsaPublicKey, ca_privkey: PrivateKey = None, rsa_alg: RsaAlgs = RsaAlgs.SHA512, **kwargs)

Specific class for RSA Certificates. Inherits from SSHCertificate

@@ -970,14 +970,14 @@

Inherited members

Expand source code -
class RSACertificate(SSHCertificate):
+
class RsaCertificate(SSHCertificate):
     """
     Specific class for RSA Certificates. Inherits from SSHCertificate
     """
 
     def __init__(
         self,
-        subject_pubkey: RSAPublicKey,
+        subject_pubkey: RsaPublicKey,
         ca_privkey: PrivateKey = None,
         rsa_alg: RsaAlgs = RsaAlgs.SHA512,
         **kwargs,
@@ -997,7 +997,7 @@ 

Inherited members

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - RSACertificate: The decoded certificate + RsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.RSAPubkeyField)
@@ -1007,7 +1007,7 @@

Ancestors

Static methods

-
+
def decode(cert_bytes: bytes) ‑> SSHCertificate
@@ -1019,7 +1019,7 @@

Args

Returns

-
RSACertificate
+
RsaCertificate
The decoded certificate
@@ -1036,7 +1036,7 @@

Returns

cert_bytes (bytes): The base64-decoded bytes for the certificate Returns: - RSACertificate: The decoded certificate + RsaCertificate: The decoded certificate """ return super().decode(cert_bytes, _FIELD.RSAPubkeyField)
@@ -1516,10 +1516,10 @@

Inherited members

Subclasses

Static methods

@@ -2190,27 +2190,27 @@

Index

  • Classes

    • -

      DSACertificate

      +

      DsaCertificate

    • -

      ECDSACertificate

      +

      EcdsaCertificate

    • -

      ED25519Certificate

      +

      Ed25519Certificate

    • -

      RSACertificate

      +

      RsaCertificate

    • diff --git a/docs/fields.html b/docs/fields.html index 6a2afd8..ef52535 100644 --- a/docs/fields.html +++ b/docs/fields.html @@ -45,14 +45,14 @@

      Module sshkey_tools.fields

      RsaAlgs, PrivateKey, PublicKey, - RSAPublicKey, - RSAPrivateKey, - DSAPublicKey, - DSAPrivateKey, - ECDSAPublicKey, - ECDSAPrivateKey, - ED25519PublicKey, - ED25519PrivateKey, + RsaPublicKey, + RsaPrivateKey, + DsaPublicKey, + DsaPrivateKey, + EcdsaPublicKey, + EcdsaPrivateKey, + Ed25519PublicKey, + Ed25519PrivateKey, ) from .utils import long_to_bytes, bytes_to_long, generate_secure_nonce @@ -68,17 +68,17 @@

      Module sshkey_tools.fields

      } SUBJECT_PUBKEY_MAP = { - RSAPublicKey: "RSAPubkeyField", - DSAPublicKey: "DSAPubkeyField", - ECDSAPublicKey: "ECDSAPubkeyField", - ED25519PublicKey: "ED25519PubkeyField", + RsaPublicKey: "RSAPubkeyField", + DsaPublicKey: "DSAPubkeyField", + EcdsaPublicKey: "ECDSAPubkeyField", + Ed25519PublicKey: "ED25519PubkeyField", } CA_SIGNATURE_MAP = { - RSAPrivateKey: "RSASignatureField", - DSAPrivateKey: "DSASignatureField", - ECDSAPrivateKey: "ECDSASignatureField", - ED25519PrivateKey: "ED25519SignatureField", + RsaPrivateKey: "RSASignatureField", + DsaPrivateKey: "DSASignatureField", + EcdsaPrivateKey: "ECDSASignatureField", + Ed25519PrivateKey: "ED25519SignatureField", } SIGNATURE_TYPE_MAP = { @@ -746,7 +746,7 @@

      Module sshkey_tools.fields

      Encode the certificate field to a byte string Args: - value (RSAPublicKey): The public key to encode + value (RsaPublicKey): The public key to encode Returns: bytes: A byte string with the encoded public key @@ -787,7 +787,7 @@

      Module sshkey_tools.fields

      """ @staticmethod - def decode(data: bytes) -> Tuple[RSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[RsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -796,18 +796,18 @@

      Module sshkey_tools.fields

      data (bytes): The byte string starting with the encoded key Returns: - Tuple[RSAPublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[RsaPublicKey, bytes]: The PublicKey field and remainder of the data """ e, data = MpIntegerField.decode(data) n, data = MpIntegerField.decode(data) - return RSAPublicKey.from_numbers(e=e, n=n), data + return RsaPublicKey.from_numbers(e=e, n=n), data def validate(self) -> Union[bool, Exception]: """ Validates that the field data is a valid RSA Public Key """ - if not isinstance(self.value, RSAPublicKey): + if not isinstance(self.value, RsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -821,7 +821,7 @@

      Module sshkey_tools.fields

      """ @staticmethod - def decode(data: bytes) -> Tuple[DSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[DsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -830,20 +830,20 @@

      Module sshkey_tools.fields

      data (bytes): The byte string starting with the encoded key Returns: - Tuple[RSAPublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[RsaPublicKey, bytes]: The PublicKey field and remainder of the data """ p, data = MpIntegerField.decode(data) q, data = MpIntegerField.decode(data) g, data = MpIntegerField.decode(data) y, data = MpIntegerField.decode(data) - return DSAPublicKey.from_numbers(p=p, q=q, g=g, y=y), data + return DsaPublicKey.from_numbers(p=p, q=q, g=g, y=y), data def validate(self) -> Union[bool, Exception]: """ Validates that the field data is a valid DSA Public Key """ - if not isinstance(self.value, DSAPublicKey): + if not isinstance(self.value, DsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -857,7 +857,7 @@

      Module sshkey_tools.fields

      """ @staticmethod - def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[EcdsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -874,7 +874,7 @@

      Module sshkey_tools.fields

      key_type = "ecdsa-sha2-" + curve return ( - ECDSAPublicKey.from_string( + EcdsaPublicKey.from_string( key_type + " " + b64encode( @@ -890,7 +890,7 @@

      Module sshkey_tools.fields

      """ Validates that the field data is a valid ECDSA Public Key """ - if not isinstance(self.value, ECDSAPublicKey): + if not isinstance(self.value, EcdsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -904,7 +904,7 @@

      Module sshkey_tools.fields

      """ @staticmethod - def decode(data: bytes) -> Tuple[ED25519PublicKey, bytes]: + def decode(data: bytes) -> Tuple[Ed25519PublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -913,17 +913,17 @@

      Module sshkey_tools.fields

      data (bytes): The byte string starting with the encoded key Returns: - Tuple[ED25519PublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[Ed25519PublicKey, bytes]: The PublicKey field and remainder of the data """ pubkey, data = BytestringField.decode(data) - return ED25519PublicKey.from_raw_bytes(pubkey), data + return Ed25519PublicKey.from_raw_bytes(pubkey), data def validate(self) -> Union[bool, Exception]: """ Validates that the field data is a valid ED25519 Public Key """ - if not isinstance(self.value, ED25519PublicKey): + if not isinstance(self.value, Ed25519PublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -1304,7 +1304,7 @@

      Module sshkey_tools.fields

      def __init__( self, - private_key: RSAPrivateKey = None, + private_key: RsaPrivateKey = None, hash_alg: RsaAlgs = RsaAlgs.SHA512, signature: bytes = None, ): @@ -1401,7 +1401,7 @@

      Module sshkey_tools.fields

      """ def __init__( - self, private_key: DSAPrivateKey = None, signature: bytes = None + self, private_key: DsaPrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @@ -1483,7 +1483,7 @@

      Module sshkey_tools.fields

      def __init__( self, - private_key: ECDSAPrivateKey = None, + private_key: EcdsaPrivateKey = None, signature: bytes = None, curve_name: str = None, ) -> None: @@ -1584,7 +1584,7 @@

      Module sshkey_tools.fields

      """ def __init__( - self, private_key: ED25519PrivateKey = None, signature: bytes = None + self, private_key: Ed25519PrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @@ -2668,7 +2668,7 @@

      Inherited members

      """ @staticmethod - def decode(data: bytes) -> Tuple[DSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[DsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -2677,20 +2677,20 @@

      Inherited members

      data (bytes): The byte string starting with the encoded key Returns: - Tuple[RSAPublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[RsaPublicKey, bytes]: The PublicKey field and remainder of the data """ p, data = MpIntegerField.decode(data) q, data = MpIntegerField.decode(data) g, data = MpIntegerField.decode(data) y, data = MpIntegerField.decode(data) - return DSAPublicKey.from_numbers(p=p, q=q, g=g, y=y), data + return DsaPublicKey.from_numbers(p=p, q=q, g=g, y=y), data def validate(self) -> Union[bool, Exception]: """ Validates that the field data is a valid DSA Public Key """ - if not isinstance(self.value, DSAPublicKey): + if not isinstance(self.value, DsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -2705,7 +2705,7 @@

      Ancestors

      Static methods

      -def decode(data: bytes) ‑> Tuple[DSAPublicKey, bytes] +def decode(data: bytes) ‑> Tuple[DsaPublicKey, bytes]

      Decode the certificate field from a byte string @@ -2717,7 +2717,7 @@

      Args

      Returns

      -
      Tuple[RSAPublicKey, bytes]
      +
      Tuple[RsaPublicKey, bytes]
      The PublicKey field and remainder of the data
      @@ -2725,7 +2725,7 @@

      Returns

      Expand source code
      @staticmethod
      -def decode(data: bytes) -> Tuple[DSAPublicKey, bytes]:
      +def decode(data: bytes) -> Tuple[DsaPublicKey, bytes]:
           """
           Decode the certificate field from a byte string
           starting with the encoded public key
      @@ -2734,14 +2734,14 @@ 

      Returns

      data (bytes): The byte string starting with the encoded key Returns: - Tuple[RSAPublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[RsaPublicKey, bytes]: The PublicKey field and remainder of the data """ p, data = MpIntegerField.decode(data) q, data = MpIntegerField.decode(data) g, data = MpIntegerField.decode(data) y, data = MpIntegerField.decode(data) - return DSAPublicKey.from_numbers(p=p, q=q, g=g, y=y), data
      + return DsaPublicKey.from_numbers(p=p, q=q, g=g, y=y), data
  • @@ -2760,7 +2760,7 @@

    Methods

    """ Validates that the field data is a valid DSA Public Key """ - if not isinstance(self.value, DSAPublicKey): + if not isinstance(self.value, DsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -2782,7 +2782,7 @@

    Inherited members

    class DSASignatureField -(private_key: DSAPrivateKey = None, signature: bytes = None) +(private_key: DsaPrivateKey = None, signature: bytes = None)

    Creates and contains the DSA signature from an DSA Private Key

    @@ -2796,7 +2796,7 @@

    Inherited members

    """ def __init__( - self, private_key: DSAPrivateKey = None, signature: bytes = None + self, private_key: DsaPrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @@ -3212,7 +3212,7 @@

    Inherited members

    """ @staticmethod - def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[EcdsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -3229,7 +3229,7 @@

    Inherited members

    key_type = "ecdsa-sha2-" + curve return ( - ECDSAPublicKey.from_string( + EcdsaPublicKey.from_string( key_type + " " + b64encode( @@ -3245,7 +3245,7 @@

    Inherited members

    """ Validates that the field data is a valid ECDSA Public Key """ - if not isinstance(self.value, ECDSAPublicKey): + if not isinstance(self.value, EcdsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -3260,7 +3260,7 @@

    Ancestors

    Static methods

    -def decode(data: bytes) ‑> Tuple[ECDSAPublicKey, bytes] +def decode(data: bytes) ‑> Tuple[EcdsaPublicKey, bytes]

    Decode the certificate field from a byte string @@ -3280,7 +3280,7 @@

    Returns

    Expand source code
    @staticmethod
    -def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]:
    +def decode(data: bytes) -> Tuple[EcdsaPublicKey, bytes]:
         """
         Decode the certificate field from a byte string
         starting with the encoded public key
    @@ -3297,7 +3297,7 @@ 

    Returns

    key_type = "ecdsa-sha2-" + curve return ( - ECDSAPublicKey.from_string( + EcdsaPublicKey.from_string( key_type + " " + b64encode( @@ -3326,7 +3326,7 @@

    Methods

    """ Validates that the field data is a valid ECDSA Public Key """ - if not isinstance(self.value, ECDSAPublicKey): + if not isinstance(self.value, EcdsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -3348,7 +3348,7 @@

    Inherited members

    class ECDSASignatureField -(private_key: ECDSAPrivateKey = None, signature: bytes = None, curve_name: str = None) +(private_key: EcdsaPrivateKey = None, signature: bytes = None, curve_name: str = None)

    Creates and contains the ECDSA signature from an ECDSA Private Key

    @@ -3363,7 +3363,7 @@

    Inherited members

    def __init__( self, - private_key: ECDSAPrivateKey = None, + private_key: EcdsaPrivateKey = None, signature: bytes = None, curve_name: str = None, ) -> None: @@ -3652,7 +3652,7 @@

    Inherited members

    """ @staticmethod - def decode(data: bytes) -> Tuple[ED25519PublicKey, bytes]: + def decode(data: bytes) -> Tuple[Ed25519PublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -3661,17 +3661,17 @@

    Inherited members

    data (bytes): The byte string starting with the encoded key Returns: - Tuple[ED25519PublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[Ed25519PublicKey, bytes]: The PublicKey field and remainder of the data """ pubkey, data = BytestringField.decode(data) - return ED25519PublicKey.from_raw_bytes(pubkey), data + return Ed25519PublicKey.from_raw_bytes(pubkey), data def validate(self) -> Union[bool, Exception]: """ Validates that the field data is a valid ED25519 Public Key """ - if not isinstance(self.value, ED25519PublicKey): + if not isinstance(self.value, Ed25519PublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -3686,7 +3686,7 @@

    Ancestors

    Static methods

    -def decode(data: bytes) ‑> Tuple[ED25519PublicKey, bytes] +def decode(data: bytes) ‑> Tuple[Ed25519PublicKey, bytes]

    Decode the certificate field from a byte string @@ -3698,7 +3698,7 @@

    Args

    Returns

    -
    Tuple[ED25519PublicKey, bytes]
    +
    Tuple[Ed25519PublicKey, bytes]
    The PublicKey field and remainder of the data
    @@ -3706,7 +3706,7 @@

    Returns

    Expand source code
    @staticmethod
    -def decode(data: bytes) -> Tuple[ED25519PublicKey, bytes]:
    +def decode(data: bytes) -> Tuple[Ed25519PublicKey, bytes]:
         """
         Decode the certificate field from a byte string
         starting with the encoded public key
    @@ -3715,11 +3715,11 @@ 

    Returns

    data (bytes): The byte string starting with the encoded key Returns: - Tuple[ED25519PublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[Ed25519PublicKey, bytes]: The PublicKey field and remainder of the data """ pubkey, data = BytestringField.decode(data) - return ED25519PublicKey.from_raw_bytes(pubkey), data
    + return Ed25519PublicKey.from_raw_bytes(pubkey), data
    @@ -3738,7 +3738,7 @@

    Methods

    """ Validates that the field data is a valid ED25519 Public Key """ - if not isinstance(self.value, ED25519PublicKey): + if not isinstance(self.value, Ed25519PublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -3760,7 +3760,7 @@

    Inherited members

    class ED25519SignatureField -(private_key: ED25519PrivateKey = None, signature: bytes = None) +(private_key: Ed25519PrivateKey = None, signature: bytes = None)

    Creates and contains the ED25519 signature from an ED25519 Private Key

    @@ -3774,7 +3774,7 @@

    Inherited members

    """ def __init__( - self, private_key: ED25519PrivateKey = None, signature: bytes = None + self, private_key: Ed25519PrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @@ -5389,7 +5389,7 @@

    Inherited members

    Encode the certificate field to a byte string Args: - value (RSAPublicKey): The public key to encode + value (RsaPublicKey): The public key to encode Returns: bytes: A byte string with the encoded public key @@ -5443,7 +5443,7 @@

    Static methods

    Encode the certificate field to a byte string

    Args

    -
    value : RSAPublicKey
    +
    value : RsaPublicKey
    The public key to encode

    Returns

    @@ -5461,7 +5461,7 @@

    Returns

    Encode the certificate field to a byte string Args: - value (RSAPublicKey): The public key to encode + value (RsaPublicKey): The public key to encode Returns: bytes: A byte string with the encoded public key @@ -5552,7 +5552,7 @@

    Inherited members

    """ @staticmethod - def decode(data: bytes) -> Tuple[RSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[RsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -5561,18 +5561,18 @@

    Inherited members

    data (bytes): The byte string starting with the encoded key Returns: - Tuple[RSAPublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[RsaPublicKey, bytes]: The PublicKey field and remainder of the data """ e, data = MpIntegerField.decode(data) n, data = MpIntegerField.decode(data) - return RSAPublicKey.from_numbers(e=e, n=n), data + return RsaPublicKey.from_numbers(e=e, n=n), data def validate(self) -> Union[bool, Exception]: """ Validates that the field data is a valid RSA Public Key """ - if not isinstance(self.value, RSAPublicKey): + if not isinstance(self.value, RsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -5587,7 +5587,7 @@

    Ancestors

    Static methods

    -def decode(data: bytes) ‑> Tuple[RSAPublicKey, bytes] +def decode(data: bytes) ‑> Tuple[RsaPublicKey, bytes]

    Decode the certificate field from a byte string @@ -5599,7 +5599,7 @@

    Args

    Returns

    -
    Tuple[RSAPublicKey, bytes]
    +
    Tuple[RsaPublicKey, bytes]
    The PublicKey field and remainder of the data
    @@ -5607,7 +5607,7 @@

    Returns

    Expand source code
    @staticmethod
    -def decode(data: bytes) -> Tuple[RSAPublicKey, bytes]:
    +def decode(data: bytes) -> Tuple[RsaPublicKey, bytes]:
         """
         Decode the certificate field from a byte string
         starting with the encoded public key
    @@ -5616,12 +5616,12 @@ 

    Returns

    data (bytes): The byte string starting with the encoded key Returns: - Tuple[RSAPublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[RsaPublicKey, bytes]: The PublicKey field and remainder of the data """ e, data = MpIntegerField.decode(data) n, data = MpIntegerField.decode(data) - return RSAPublicKey.from_numbers(e=e, n=n), data
    + return RsaPublicKey.from_numbers(e=e, n=n), data
    @@ -5640,7 +5640,7 @@

    Methods

    """ Validates that the field data is a valid RSA Public Key """ - if not isinstance(self.value, RSAPublicKey): + if not isinstance(self.value, RsaPublicKey): return _EX.InvalidFieldDataException( "This public key class is not valid for use in a certificate" ) @@ -5662,7 +5662,7 @@

    Inherited members

    class RSASignatureField -(private_key: RSAPrivateKey = None, hash_alg: RsaAlgs = RsaAlgs.SHA512, signature: bytes = None) +(private_key: RsaPrivateKey = None, hash_alg: RsaAlgs = RsaAlgs.SHA512, signature: bytes = None)

    Creates and contains the RSA signature from an RSA Private Key

    @@ -5677,7 +5677,7 @@

    Inherited members

    def __init__( self, - private_key: RSAPrivateKey = None, + private_key: RsaPrivateKey = None, hash_alg: RsaAlgs = RsaAlgs.SHA512, signature: bytes = None, ): diff --git a/docs/index.html b/docs/index.html index 7089526..ac0c918 100644 --- a/docs/index.html +++ b/docs/index.html @@ -42,62 +42,62 @@

    Basic usage

    SSH Keypairs

    Generate keys

    from sshkey_tools.keys import (
    -    RSAPrivateKey,
    -    DSAPrivateKey,
    -    ECDSAPrivateKey,
    -    ED25519PrivateKey,
    +    RsaPrivateKey,
    +    DsaPrivateKey,
    +    EcdsaPrivateKey,
    +    Ed25519PrivateKey,
         EcdsaCurves
     )
     
     # RSA
     # By default, RSA is generated with a 4096-bit keysize
    -rsa_private = RSAPrivateKey.generate()
    +rsa_private = RsaPrivateKey.generate()
     
     # You can also specify the key size
    -rsa_private = RSAPrivateKey.generate(bits)
    +rsa_private = RsaPrivateKey.generate(bits)
     
     # DSA
     # Since OpenSSH only supports 1024-bit keys, this is the default
    -dsa_private = DSAPrivateKey.generate()
    +dsa_private = DsaPrivateKey.generate()
     
     # ECDSA
     # The default curve is P521
    -ecdsa_private = ECDSAPrivateKey.generate()
    +ecdsa_private = EcdsaPrivateKey.generate()
     
     # You can also manually specify a curve
    -ecdsa_private = ECDSAPrivateKey.generate(EcdsaCurves.P256)
    +ecdsa_private = EcdsaPrivateKey.generate(EcdsaCurves.P256)
     
     # ED25519
     # The ED25519 keys are always a fixed size
    -ed25519_private = ED25519PrivateKey.generate()
    +ed25519_private = Ed25519PrivateKey.generate()
     
     # Public keys
     # The public key for any given private key is in the public_key parameter
     rsa_pub = rsa_private.public_key
     

    Load keys

    -

    You can load keys either directly with the specific key classes (RSAPrivateKey, DSAPrivateKey, etc.) or the general PrivateKey class

    +

    You can load keys either directly with the specific key classes (RsaPrivateKey, DsaPrivateKey, etc.) or the general PrivateKey class

    from sshkey_tools.keys import (
         PrivateKey,
         PublicKey,
    -    RSAPrivateKey,
    -    RSAPublicKey
    +    RsaPrivateKey,
    +    RsaPublicKey
     )
     
     # Load a private key with a specific class
    -rsa_private = RSAPrivateKey.from_file('path/to/rsa_key')
    +rsa_private = RsaPrivateKey.from_file('path/to/rsa_key')
     
     # Load a private key with the general class
     rsa_private = PrivateKey.from_file('path/to/rsa_key')
     print(type(rsa_private))
    -"<class 'sshkey_tools.keys.RSAPrivateKey'>"
    +"<class 'sshkey_tools.keys.RsaPrivateKey'>"
     
     # Public keys can be loaded in the same way
    -rsa_pub = RSAPublicKey.from_file('path/to/rsa_key.pub')
    +rsa_pub = RsaPublicKey.from_file('path/to/rsa_key.pub')
     rsa_pub = PublicKey.from_file('path/to/rsa_key.pub')
     
     print(type(rsa_private))
    -"<class 'sshkey_tools.keys.RSAPrivateKey'>"
    +"<class 'sshkey_tools.keys.RsaPrivateKey'>"
     
     # Public key objects are automatically created for any given private key
     # negating the need to load them separately
    @@ -115,12 +115,12 @@ 

    Load keys

    rsa_private = PrivateKey.from_bytes(file.read()) # RSA, DSA and ECDSA keys can be loaded from the public/private numbers and/or parameters -rsa_public = RSAPublicKey.from_numbers( +rsa_public = RsaPublicKey.from_numbers( e=65537, n=12.........811 ) -rsa_private = RSAPrivateKey.from_numbers( +rsa_private = RsaPrivateKey.from_numbers( e=65537, n=12......811, d=17......122 @@ -146,7 +146,7 @@

    Certificate creation

    hashes as crypto_hashes ) from cryptography.hazmat.primitives.asymmetric import padding as crypto_padding -from sshkey_tools.keys import PublicKey, RSAPrivateKey, RsaAlgs +from sshkey_tools.keys import PublicKey, RsaPrivateKey, RsaAlgs from sshkey_tools.cert import SSHCertificate from sshkey_tools.exceptions import SignatureNotPossibleException @@ -254,7 +254,7 @@

    Certificate creation

    RsaAlgs.SHA256.value[1] ) -# pyca/cryptography RSAPrivateKey +# pyca/cryptography RsaPrivateKey with open('path/to/ca_pubkey', 'rb') as file: crypto_ca_pubkey = crypto_serialization.load_ssh_public_key(file.read()) @@ -275,13 +275,13 @@

    Load an existing certificate

    Certificates can be loaded from file, a string/bytestring with file contents or the base64-decoded byte data of the certificate

    from sshkey_tools.keys import PublicKey, PrivateKey
    -from sshkey_tools.cert import SSHCertificate, RSACertificate
    +from sshkey_tools.cert import SSHCertificate, RsaCertificate
     
     # Load an existing certificate
     certificate = SSHCertificate.from_file('path/to/user_key-cert.pub')
     
     # or
    -certificate = RSACertificate.from_file('path/to/user_key-cert.pub')
    +certificate = RsaCertificate.from_file('path/to/user_key-cert.pub')
     
     # Verify the certificate with a CA public key
     ca_pubkey = PublicKey.from_file('path/to/ca_key.pub')
    diff --git a/docs/keys.html b/docs/keys.html
    index 24c64ab..125dbba 100644
    --- a/docs/keys.html
    +++ b/docs/keys.html
    @@ -34,8 +34,8 @@ 

    Module sshkey_tools.keys

    from enum import Enum from base64 import b64decode from struct import unpack -from cryptography.hazmat.backends.openssl.rsa import _RSAPublicKey, _RSAPrivateKey -from cryptography.hazmat.backends.openssl.dsa import _DSAPublicKey, _DSAPrivateKey +from cryptography.hazmat.backends.openssl.rsa import _RsaPublicKey, _RsaPrivateKey +from cryptography.hazmat.backends.openssl.dsa import _DsaPublicKey, _DsaPrivateKey from cryptography.hazmat.backends.openssl.ed25519 import ( _Ed25519PublicKey, _Ed25519PrivateKey, @@ -66,17 +66,17 @@

    Module sshkey_tools.keys

    ) PUBKEY_MAP = { - _RSAPublicKey: "RSAPublicKey", - _DSAPublicKey: "DSAPublicKey", - _EllipticCurvePublicKey: "ECDSAPublicKey", - _Ed25519PublicKey: "ED25519PublicKey", + _RsaPublicKey: "RsaPublicKey", + _DsaPublicKey: "DsaPublicKey", + _EllipticCurvePublicKey: "EcdsaPublicKey", + _Ed25519PublicKey: "Ed25519PublicKey", } PRIVKEY_MAP = { - _RSAPrivateKey: "RSAPrivateKey", - _DSAPrivateKey: "DSAPrivateKey", - _EllipticCurvePrivateKey: "ECDSAPrivateKey", - _Ed25519PrivateKey: "ED25519PrivateKey", + _RsaPrivateKey: "RsaPrivateKey", + _DsaPrivateKey: "DsaPrivateKey", + _EllipticCurvePrivateKey: "EcdsaPrivateKey", + _Ed25519PrivateKey: "Ed25519PrivateKey", } ECDSA_HASHES = { @@ -86,15 +86,15 @@

    Module sshkey_tools.keys

    } PubkeyClasses = Union[ - _RSA.RSAPublicKey, - _DSA.DSAPublicKey, + _RSA.RsaPublicKey, + _DSA.DsaPublicKey, _ECDSA.EllipticCurvePublicKey, _ED25519.Ed25519PublicKey, ] PrivkeyClasses = Union[ - _RSA.RSAPrivateKey, - _DSA.DSAPrivateKey, + _RSA.RsaPrivateKey, + _DSA.DsaPrivateKey, _ECDSA.EllipticCurvePrivateKey, _ED25519.Ed25519PrivateKey, ] @@ -446,14 +446,14 @@

    Module sshkey_tools.keys

    key_file.write(self.to_string(password, encoding)) -class RSAPublicKey(PublicKey): +class RsaPublicKey(PublicKey): """ Class for holding RSA public keys """ def __init__( self, - key: _RSA.RSAPublicKey, + key: _RSA.RsaPublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, serialized: bytes = None, @@ -468,7 +468,7 @@

    Module sshkey_tools.keys

    @classmethod # pylint: disable=invalid-name - def from_numbers(cls, e: int, n: int) -> "RSAPublicKey": + def from_numbers(cls, e: int, n: int) -> "RsaPublicKey": """ Loads an RSA Public Key from the public numbers e and n @@ -477,7 +477,7 @@

    Module sshkey_tools.keys

    n (int): n-value Returns: - RSAPublicKey: _description_ + RsaPublicKey: _description_ """ return cls(key=_RSA.RSAPublicNumbers(e, n).public_key()) @@ -505,15 +505,15 @@

    Module sshkey_tools.keys

    ) from InvalidSignature -class RSAPrivateKey(PrivateKey): +class RsaPrivateKey(PrivateKey): """ Class for holding RSA private keys """ - def __init__(self, key: _RSA.RSAPrivateKey): + def __init__(self, key: _RSA.RsaPrivateKey): super().__init__( key=key, - public_key=RSAPublicKey(key.public_key()), + public_key=RsaPublicKey(key.public_key()), private_numbers=key.private_numbers(), ) @@ -529,7 +529,7 @@

    Module sshkey_tools.keys

    dmp1: int = None, dmq1: int = None, iqmp: int = None, - ) -> "RSAPrivateKey": + ) -> "RsaPrivateKey": """ Load an RSA private key from numbers @@ -552,7 +552,7 @@

    Module sshkey_tools.keys

    Automatically generates if not provided Returns: - RSAPrivateKey: An instance of RSAPrivateKey + RsaPrivateKey: An instance of RsaPrivateKey """ if None in (p, q): p, q = _RSA.rsa_recover_prime_factors(n, e, d) @@ -576,7 +576,7 @@

    Module sshkey_tools.keys

    @classmethod def generate( cls, key_size: int = 4096, public_exponent: int = 65537 - ) -> "RSAPrivateKey": + ) -> "RsaPrivateKey": """ Generates a new RSA private key @@ -585,7 +585,7 @@

    Module sshkey_tools.keys

    public_exponent (int, optional): The public exponent to use. Defaults to 65537. Returns: - RSAPrivateKey: Instance of RSAPrivateKey + RsaPrivateKey: Instance of RsaPrivateKey """ return cls.from_class( _RSA.generate_private_key( @@ -608,14 +608,14 @@

    Module sshkey_tools.keys

    return self.key.sign(data, _PADDING.PKCS1v15(), hash_alg.value[1]()) -class DSAPublicKey(PublicKey): +class DsaPublicKey(PublicKey): """ Class for holding DSA public keys """ def __init__( self, - key: _DSA.DSAPublicKey, + key: _DSA.DsaPublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, serialized: bytes = None, @@ -631,7 +631,7 @@

    Module sshkey_tools.keys

    @classmethod # pylint: disable=invalid-name - def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DSAPublicKey": + def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DsaPublicKey": """ Create a DSA public key from public numbers and parameters @@ -642,7 +642,7 @@

    Module sshkey_tools.keys

    y (int): The public number Y Returns: - DSAPublicKey: An instance of DSAPublicKey + DsaPublicKey: An instance of DsaPublicKey """ return cls( key=_DSA.DSAPublicNumbers( @@ -669,23 +669,23 @@

    Module sshkey_tools.keys

    ) from InvalidSignature -class DSAPrivateKey(PrivateKey): +class DsaPrivateKey(PrivateKey): """ Class for holding DSA private keys """ - def __init__(self, key: _DSA.DSAPrivateKey): + def __init__(self, key: _DSA.DsaPrivateKey): super().__init__( key=key, - public_key=DSAPublicKey(key.public_key()), + public_key=DsaPublicKey(key.public_key()), private_numbers=key.private_numbers(), ) @classmethod # pylint: disable=invalid-name,too-many-arguments - def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DSAPrivateKey": + def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DsaPrivateKey": """ - Creates a new DSAPrivateKey object from parameters and public/private numbers + Creates a new DsaPrivateKey object from parameters and public/private numbers Args: p (int): P parameter, the prime modulus @@ -707,13 +707,13 @@

    Module sshkey_tools.keys

    ) @classmethod - def generate(cls) -> "DSAPrivateKey": + def generate(cls) -> "DsaPrivateKey": """ Generate a new DSA private key Key size is fixed since OpenSSH only supports 1024-bit DSA keys Returns: - DSAPrivateKey: An instance of DSAPrivateKey + DsaPrivateKey: An instance of DsaPrivateKey """ return cls.from_class(_DSA.generate_private_key(key_size=1024)) @@ -730,7 +730,7 @@

    Module sshkey_tools.keys

    return self.key.sign(data, _HASHES.SHA1()) -class ECDSAPublicKey(PublicKey): +class EcdsaPublicKey(PublicKey): """ Class for holding ECDSA public keys """ @@ -754,7 +754,7 @@

    Module sshkey_tools.keys

    # pylint: disable=invalid-name def from_numbers( cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int - ) -> "ECDSAPublicKey": + ) -> "EcdsaPublicKey": """ Create an ECDSA public key from public numbers and parameters @@ -764,7 +764,7 @@

    Module sshkey_tools.keys

    y (int): The affine Y component of the public point Returns: - ECDSAPublicKey: An instance of ECDSAPublicKey + EcdsaPublicKey: An instance of EcdsaPublicKey """ if not isinstance(curve, _ECDSA.EllipticCurve) and curve not in ECDSA_HASHES: raise _EX.InvalidCurveException( @@ -802,7 +802,7 @@

    Module sshkey_tools.keys

    ) from InvalidSignature -class ECDSAPrivateKey(PrivateKey): +class EcdsaPrivateKey(PrivateKey): """ Class for holding ECDSA private keys """ @@ -810,7 +810,7 @@

    Module sshkey_tools.keys

    def __init__(self, key: _ECDSA.EllipticCurvePrivateKey): super().__init__( key=key, - public_key=ECDSAPublicKey(key.public_key()), + public_key=EcdsaPublicKey(key.public_key()), private_numbers=key.private_numbers(), ) @@ -820,7 +820,7 @@

    Module sshkey_tools.keys

    cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int, private_value: int ): """ - Creates a new ECDSAPrivateKey object from parameters and public/private numbers + Creates a new EcdsaPrivateKey object from parameters and public/private numbers Args: curve Union[str, _ECDSA.EllipticCurve]: Curve used by the key @@ -859,7 +859,7 @@

    Module sshkey_tools.keys

    curve (EcdsaCurves): Which curve to use. Default secp521r1 Returns: - ECDSAPrivateKey: An instance of ECDSAPrivateKey + EcdsaPrivateKey: An instance of EcdsaPrivateKey """ return cls.from_class(_ECDSA.generate_private_key(curve=curve.value)) @@ -877,7 +877,7 @@

    Module sshkey_tools.keys

    return self.key.sign(data, _ECDSA.ECDSA(curve_hash)) -class ED25519PublicKey(PublicKey): +class Ed25519PublicKey(PublicKey): """ Class for holding ED25519 public keys """ @@ -894,7 +894,7 @@

    Module sshkey_tools.keys

    ) @classmethod - def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PublicKey": + def from_raw_bytes(cls, raw_bytes: bytes) -> "Ed25519PublicKey": """ Load an ED25519 public key from raw bytes @@ -902,7 +902,7 @@

    Module sshkey_tools.keys

    raw_bytes (bytes): The raw bytes of the key Returns: - ED25519PublicKey: Instance of ED25519PublicKey + Ed25519PublicKey: Instance of Ed25519PublicKey """ if b"ssh-ed25519" in raw_bytes: id_length = unpack(">I", raw_bytes[:4])[0] + 8 @@ -931,16 +931,16 @@

    Module sshkey_tools.keys

    ) from InvalidSignature -class ED25519PrivateKey(PrivateKey): +class Ed25519PrivateKey(PrivateKey): """ Class for holding ED25519 private keys """ def __init__(self, key: _ED25519.Ed25519PrivateKey): - super().__init__(key=key, public_key=ED25519PublicKey(key.public_key())) + super().__init__(key=key, public_key=Ed25519PublicKey(key.public_key())) @classmethod - def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PrivateKey": + def from_raw_bytes(cls, raw_bytes: bytes) -> "Ed25519PrivateKey": """ Load an ED25519 private key from raw bytes @@ -948,19 +948,19 @@

    Module sshkey_tools.keys

    raw_bytes (bytes): The raw bytes of the key Returns: - ED25519PrivateKey: Instance of ED25519PrivateKey + Ed25519PrivateKey: Instance of Ed25519PrivateKey """ return cls.from_class( _ED25519.Ed25519PrivateKey.from_private_bytes(data=raw_bytes) ) @classmethod - def generate(cls) -> "ED25519PrivateKey": + def generate(cls) -> "Ed25519PrivateKey": """ Generates a new ED25519 Private Key Returns: - ED25519PrivateKey: Instance of ED25519PrivateKey + Ed25519PrivateKey: Instance of Ed25519PrivateKey """ return cls.from_class(_ED25519.Ed25519PrivateKey.generate()) @@ -999,9 +999,9 @@

    Module sshkey_tools.keys

    Classes

    -
    -class DSAPrivateKey -(key: cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey) +
    +class DsaPrivateKey +(key: cryptography.hazmat.primitives.asymmetric.dsa.DsaPrivateKey)

    Class for holding DSA private keys

    @@ -1009,23 +1009,23 @@

    Classes

    Expand source code -
    class DSAPrivateKey(PrivateKey):
    +
    class DsaPrivateKey(PrivateKey):
         """
         Class for holding DSA private keys
         """
     
    -    def __init__(self, key: _DSA.DSAPrivateKey):
    +    def __init__(self, key: _DSA.DsaPrivateKey):
             super().__init__(
                 key=key,
    -            public_key=DSAPublicKey(key.public_key()),
    +            public_key=DsaPublicKey(key.public_key()),
                 private_numbers=key.private_numbers(),
             )
     
         @classmethod
         # pylint: disable=invalid-name,too-many-arguments
    -    def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DSAPrivateKey":
    +    def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DsaPrivateKey":
             """
    -        Creates a new DSAPrivateKey object from parameters and public/private numbers
    +        Creates a new DsaPrivateKey object from parameters and public/private numbers
     
             Args:
                 p (int): P parameter, the prime modulus
    @@ -1047,13 +1047,13 @@ 

    Classes

    ) @classmethod - def generate(cls) -> "DSAPrivateKey": + def generate(cls) -> "DsaPrivateKey": """ Generate a new DSA private key Key size is fixed since OpenSSH only supports 1024-bit DSA keys Returns: - DSAPrivateKey: An instance of DSAPrivateKey + DsaPrivateKey: An instance of DsaPrivateKey """ return cls.from_class(_DSA.generate_private_key(key_size=1024)) @@ -1075,11 +1075,11 @@

    Ancestors

    Static methods

    -
    -def from_numbers(p: int, q: int, g: int, y: int, x: int) ‑> DSAPrivateKey +
    +def from_numbers(p: int, q: int, g: int, y: int, x: int) ‑> DsaPrivateKey
    -

    Creates a new DSAPrivateKey object from parameters and public/private numbers

    +

    Creates a new DsaPrivateKey object from parameters and public/private numbers

    Args

    p : int
    @@ -1104,9 +1104,9 @@

    Returns

    @classmethod
     # pylint: disable=invalid-name,too-many-arguments
    -def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DSAPrivateKey":
    +def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DsaPrivateKey":
         """
    -    Creates a new DSAPrivateKey object from parameters and public/private numbers
    +    Creates a new DsaPrivateKey object from parameters and public/private numbers
     
         Args:
             p (int): P parameter, the prime modulus
    @@ -1128,29 +1128,29 @@ 

    Returns

    )
    -
    -def generate() ‑> DSAPrivateKey +
    +def generate() ‑> DsaPrivateKey

    Generate a new DSA private key Key size is fixed since OpenSSH only supports 1024-bit DSA keys

    Returns

    -
    DSAPrivateKey
    -
    An instance of DSAPrivateKey
    +
    DsaPrivateKey
    +
    An instance of DsaPrivateKey
    Expand source code
    @classmethod
    -def generate(cls) -> "DSAPrivateKey":
    +def generate(cls) -> "DsaPrivateKey":
         """
         Generate a new DSA private key
         Key size is fixed since OpenSSH only supports 1024-bit DSA keys
     
         Returns:
    -        DSAPrivateKey: An instance of DSAPrivateKey
    +        DsaPrivateKey: An instance of DsaPrivateKey
         """
         return cls.from_class(_DSA.generate_private_key(key_size=1024))
    @@ -1158,7 +1158,7 @@

    Returns

    Methods

    -
    +
    def sign(self, data: bytes)
    @@ -1206,9 +1206,9 @@

    Inherited members

    -
    -class DSAPublicKey -(key: cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, serialized: bytes = None) +
    +class DsaPublicKey +(key: cryptography.hazmat.primitives.asymmetric.dsa.DsaPublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, serialized: bytes = None)

    Class for holding DSA public keys

    @@ -1216,14 +1216,14 @@

    Inherited members

    Expand source code -
    class DSAPublicKey(PublicKey):
    +
    class DsaPublicKey(PublicKey):
         """
         Class for holding DSA public keys
         """
     
         def __init__(
             self,
    -        key: _DSA.DSAPublicKey,
    +        key: _DSA.DsaPublicKey,
             comment: Union[str, bytes] = None,
             key_type: Union[str, bytes] = None,
             serialized: bytes = None,
    @@ -1239,7 +1239,7 @@ 

    Inherited members

    @classmethod # pylint: disable=invalid-name - def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DSAPublicKey": + def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DsaPublicKey": """ Create a DSA public key from public numbers and parameters @@ -1250,7 +1250,7 @@

    Inherited members

    y (int): The public number Y Returns: - DSAPublicKey: An instance of DSAPublicKey + DsaPublicKey: An instance of DsaPublicKey """ return cls( key=_DSA.DSAPublicNumbers( @@ -1282,8 +1282,8 @@

    Ancestors

    Static methods

    -
    -def from_numbers(p: int, q: int, g: int, y: int) ‑> DSAPublicKey +
    +def from_numbers(p: int, q: int, g: int, y: int) ‑> DsaPublicKey

    Create a DSA public key from public numbers and parameters

    @@ -1300,8 +1300,8 @@

    Args

    Returns

    -
    DSAPublicKey
    -
    An instance of DSAPublicKey
    +
    DsaPublicKey
    +
    An instance of DsaPublicKey
    @@ -1309,7 +1309,7 @@

    Returns

    @classmethod
     # pylint: disable=invalid-name
    -def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DSAPublicKey":
    +def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DsaPublicKey":
         """
         Create a DSA public key from public numbers and parameters
     
    @@ -1320,7 +1320,7 @@ 

    Returns

    y (int): The public number Y Returns: - DSAPublicKey: An instance of DSAPublicKey + DsaPublicKey: An instance of DsaPublicKey """ return cls( key=_DSA.DSAPublicNumbers( @@ -1332,7 +1332,7 @@

    Returns

    Methods

    -
    +
    def verify(self, data: bytes, signature: bytes) ‑> None
    @@ -1386,8 +1386,8 @@

    Inherited members

    -
    -class ECDSAPrivateKey +
    +class EcdsaPrivateKey (key: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey)
    @@ -1396,7 +1396,7 @@

    Inherited members

    Expand source code -
    class ECDSAPrivateKey(PrivateKey):
    +
    class EcdsaPrivateKey(PrivateKey):
         """
         Class for holding ECDSA private keys
         """
    @@ -1404,7 +1404,7 @@ 

    Inherited members

    def __init__(self, key: _ECDSA.EllipticCurvePrivateKey): super().__init__( key=key, - public_key=ECDSAPublicKey(key.public_key()), + public_key=EcdsaPublicKey(key.public_key()), private_numbers=key.private_numbers(), ) @@ -1414,7 +1414,7 @@

    Inherited members

    cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int, private_value: int ): """ - Creates a new ECDSAPrivateKey object from parameters and public/private numbers + Creates a new EcdsaPrivateKey object from parameters and public/private numbers Args: curve Union[str, _ECDSA.EllipticCurve]: Curve used by the key @@ -1453,7 +1453,7 @@

    Inherited members

    curve (EcdsaCurves): Which curve to use. Default secp521r1 Returns: - ECDSAPrivateKey: An instance of ECDSAPrivateKey + EcdsaPrivateKey: An instance of EcdsaPrivateKey """ return cls.from_class(_ECDSA.generate_private_key(curve=curve.value)) @@ -1476,11 +1476,11 @@

    Ancestors

    Static methods

    -
    +
    def from_numbers(curve: Union[str, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve], x: int, y: int, private_value: int)
    -

    Creates a new ECDSAPrivateKey object from parameters and public/private numbers

    +

    Creates a new EcdsaPrivateKey object from parameters and public/private numbers

    Args

    curve Union[str, _ECDSA.EllipticCurve]: Curve used by the key
    @@ -1506,7 +1506,7 @@

    Returns

    cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int, private_value: int ): """ - Creates a new ECDSAPrivateKey object from parameters and public/private numbers + Creates a new EcdsaPrivateKey object from parameters and public/private numbers Args: curve Union[str, _ECDSA.EllipticCurve]: Curve used by the key @@ -1537,7 +1537,7 @@

    Returns

    )
    -
    +
    def generate(curve: EcdsaCurves = EcdsaCurves.P521)
    @@ -1549,8 +1549,8 @@

    Args

    Returns

    -
    ECDSAPrivateKey
    -
    An instance of ECDSAPrivateKey
    +
    EcdsaPrivateKey
    +
    An instance of EcdsaPrivateKey
    @@ -1565,7 +1565,7 @@

    Returns

    curve (EcdsaCurves): Which curve to use. Default secp521r1 Returns: - ECDSAPrivateKey: An instance of ECDSAPrivateKey + EcdsaPrivateKey: An instance of EcdsaPrivateKey """ return cls.from_class(_ECDSA.generate_private_key(curve=curve.value))
    @@ -1573,7 +1573,7 @@

    Returns

    Methods

    -
    +
    def sign(self, data: bytes)
    @@ -1622,8 +1622,8 @@

    Inherited members

    -
    -class ECDSAPublicKey +
    +class EcdsaPublicKey (key: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, serialized: bytes = None)
    @@ -1632,7 +1632,7 @@

    Inherited members

    Expand source code -
    class ECDSAPublicKey(PublicKey):
    +
    class EcdsaPublicKey(PublicKey):
         """
         Class for holding ECDSA public keys
         """
    @@ -1656,7 +1656,7 @@ 

    Inherited members

    # pylint: disable=invalid-name def from_numbers( cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int - ) -> "ECDSAPublicKey": + ) -> "EcdsaPublicKey": """ Create an ECDSA public key from public numbers and parameters @@ -1666,7 +1666,7 @@

    Inherited members

    y (int): The affine Y component of the public point Returns: - ECDSAPublicKey: An instance of ECDSAPublicKey + EcdsaPublicKey: An instance of EcdsaPublicKey """ if not isinstance(curve, _ECDSA.EllipticCurve) and curve not in ECDSA_HASHES: raise _EX.InvalidCurveException( @@ -1709,8 +1709,8 @@

    Ancestors

    Static methods

    -
    -def from_numbers(curve: Union[str, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve], x: int, y: int) ‑> ECDSAPublicKey +
    +def from_numbers(curve: Union[str, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve], x: int, y: int) ‑> EcdsaPublicKey

    Create an ECDSA public key from public numbers and parameters

    @@ -1724,8 +1724,8 @@

    Args

    Returns

    -
    ECDSAPublicKey
    -
    An instance of ECDSAPublicKey
    +
    EcdsaPublicKey
    +
    An instance of EcdsaPublicKey
    @@ -1735,7 +1735,7 @@

    Returns

    # pylint: disable=invalid-name def from_numbers( cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int -) -> "ECDSAPublicKey": +) -> "EcdsaPublicKey": """ Create an ECDSA public key from public numbers and parameters @@ -1745,7 +1745,7 @@

    Returns

    y (int): The affine Y component of the public point Returns: - ECDSAPublicKey: An instance of ECDSAPublicKey + EcdsaPublicKey: An instance of EcdsaPublicKey """ if not isinstance(curve, _ECDSA.EllipticCurve) and curve not in ECDSA_HASHES: raise _EX.InvalidCurveException( @@ -1767,7 +1767,7 @@

    Returns

    Methods

    -
    +
    def verify(self, data: bytes, signature: bytes) ‑> None
    @@ -1822,8 +1822,8 @@

    Inherited members

    -
    -class ED25519PrivateKey +
    +class Ed25519PrivateKey (key: cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey)
    @@ -1832,16 +1832,16 @@

    Inherited members

    Expand source code -
    class ED25519PrivateKey(PrivateKey):
    +
    class Ed25519PrivateKey(PrivateKey):
         """
         Class for holding ED25519 private keys
         """
     
         def __init__(self, key: _ED25519.Ed25519PrivateKey):
    -        super().__init__(key=key, public_key=ED25519PublicKey(key.public_key()))
    +        super().__init__(key=key, public_key=Ed25519PublicKey(key.public_key()))
     
         @classmethod
    -    def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PrivateKey":
    +    def from_raw_bytes(cls, raw_bytes: bytes) -> "Ed25519PrivateKey":
             """
             Load an ED25519 private key from raw bytes
     
    @@ -1849,19 +1849,19 @@ 

    Inherited members

    raw_bytes (bytes): The raw bytes of the key Returns: - ED25519PrivateKey: Instance of ED25519PrivateKey + Ed25519PrivateKey: Instance of Ed25519PrivateKey """ return cls.from_class( _ED25519.Ed25519PrivateKey.from_private_bytes(data=raw_bytes) ) @classmethod - def generate(cls) -> "ED25519PrivateKey": + def generate(cls) -> "Ed25519PrivateKey": """ Generates a new ED25519 Private Key Returns: - ED25519PrivateKey: Instance of ED25519PrivateKey + Ed25519PrivateKey: Instance of Ed25519PrivateKey """ return cls.from_class(_ED25519.Ed25519PrivateKey.generate()) @@ -1896,8 +1896,8 @@

    Ancestors

    Static methods

    -
    -def from_raw_bytes(raw_bytes: bytes) ‑> ED25519PrivateKey +
    +def from_raw_bytes(raw_bytes: bytes) ‑> Ed25519PrivateKey

    Load an ED25519 private key from raw bytes

    @@ -1908,15 +1908,15 @@

    Args

    Returns

    -
    ED25519PrivateKey
    -
    Instance of ED25519PrivateKey
    +
    Ed25519PrivateKey
    +
    Instance of Ed25519PrivateKey
    Expand source code
    @classmethod
    -def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PrivateKey":
    +def from_raw_bytes(cls, raw_bytes: bytes) -> "Ed25519PrivateKey":
         """
         Load an ED25519 private key from raw bytes
     
    @@ -1924,34 +1924,34 @@ 

    Returns

    raw_bytes (bytes): The raw bytes of the key Returns: - ED25519PrivateKey: Instance of ED25519PrivateKey + Ed25519PrivateKey: Instance of Ed25519PrivateKey """ return cls.from_class( _ED25519.Ed25519PrivateKey.from_private_bytes(data=raw_bytes) )
    -
    -def generate() ‑> ED25519PrivateKey +
    +def generate() ‑> Ed25519PrivateKey

    Generates a new ED25519 Private Key

    Returns

    -
    ED25519PrivateKey
    -
    Instance of ED25519PrivateKey
    +
    Ed25519PrivateKey
    +
    Instance of Ed25519PrivateKey
    Expand source code
    @classmethod
    -def generate(cls) -> "ED25519PrivateKey":
    +def generate(cls) -> "Ed25519PrivateKey":
         """
         Generates a new ED25519 Private Key
     
         Returns:
    -        ED25519PrivateKey: Instance of ED25519PrivateKey
    +        Ed25519PrivateKey: Instance of Ed25519PrivateKey
         """
         return cls.from_class(_ED25519.Ed25519PrivateKey.generate())
    @@ -1959,7 +1959,7 @@

    Returns

    Methods

    -
    +
    def raw_bytes(self) ‑> bytes
    @@ -1987,7 +1987,7 @@

    Returns

    )
    -
    +
    def sign(self, data: bytes)
    @@ -2035,8 +2035,8 @@

    Inherited members

    -
    -class ED25519PublicKey +
    +class Ed25519PublicKey (key: cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, serialized: bytes = None)
    @@ -2045,7 +2045,7 @@

    Inherited members

    Expand source code -
    class ED25519PublicKey(PublicKey):
    +
    class Ed25519PublicKey(PublicKey):
         """
         Class for holding ED25519 public keys
         """
    @@ -2062,7 +2062,7 @@ 

    Inherited members

    ) @classmethod - def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PublicKey": + def from_raw_bytes(cls, raw_bytes: bytes) -> "Ed25519PublicKey": """ Load an ED25519 public key from raw bytes @@ -2070,7 +2070,7 @@

    Inherited members

    raw_bytes (bytes): The raw bytes of the key Returns: - ED25519PublicKey: Instance of ED25519PublicKey + Ed25519PublicKey: Instance of Ed25519PublicKey """ if b"ssh-ed25519" in raw_bytes: id_length = unpack(">I", raw_bytes[:4])[0] + 8 @@ -2104,8 +2104,8 @@

    Ancestors

    Static methods

    -
    -def from_raw_bytes(raw_bytes: bytes) ‑> ED25519PublicKey +
    +def from_raw_bytes(raw_bytes: bytes) ‑> Ed25519PublicKey

    Load an ED25519 public key from raw bytes

    @@ -2116,15 +2116,15 @@

    Args

    Returns

    -
    ED25519PublicKey
    -
    Instance of ED25519PublicKey
    +
    Ed25519PublicKey
    +
    Instance of Ed25519PublicKey
    Expand source code
    @classmethod
    -def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PublicKey":
    +def from_raw_bytes(cls, raw_bytes: bytes) -> "Ed25519PublicKey":
         """
         Load an ED25519 public key from raw bytes
     
    @@ -2132,7 +2132,7 @@ 

    Returns

    raw_bytes (bytes): The raw bytes of the key Returns: - ED25519PublicKey: Instance of ED25519PublicKey + Ed25519PublicKey: Instance of Ed25519PublicKey """ if b"ssh-ed25519" in raw_bytes: id_length = unpack(">I", raw_bytes[:4])[0] + 8 @@ -2146,7 +2146,7 @@

    Returns

    Methods

    -
    +
    def verify(self, data: bytes, signature: bytes) ‑> None
    @@ -2399,7 +2399,7 @@

    Returns

    class PrivateKey -(key: Union[cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey], public_key: PublicKey, **kwargs) +(key: Union[cryptography.hazmat.primitives.asymmetric.rsa.RsaPrivateKey, cryptography.hazmat.primitives.asymmetric.dsa.DsaPrivateKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey], public_key: PublicKey, **kwargs)

    Class for handling SSH Private keys

    @@ -2560,15 +2560,15 @@

    Returns

    Subclasses

    Static methods

    -def from_class(key_class: Union[cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey]) ‑> PrivateKey +def from_class(key_class: Union[cryptography.hazmat.primitives.asymmetric.rsa.RsaPrivateKey, cryptography.hazmat.primitives.asymmetric.dsa.DsaPrivateKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey]) ‑> PrivateKey

    Import an SSH Private key from a cryptography key class

    @@ -2865,7 +2865,7 @@

    Returns

    class PublicKey -(key: Union[cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey] = None, comment: Union[str, bytes] = None, **kwargs) +(key: Union[cryptography.hazmat.primitives.asymmetric.rsa.RsaPrivateKey, cryptography.hazmat.primitives.asymmetric.dsa.DsaPrivateKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey] = None, comment: Union[str, bytes] = None, **kwargs)

    Class for handling SSH public keys

    @@ -3024,15 +3024,15 @@

    Returns

    Subclasses

    Static methods

    -def from_class(key_class: Union[cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey], comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None) ‑> PublicKey +def from_class(key_class: Union[cryptography.hazmat.primitives.asymmetric.rsa.RsaPublicKey, cryptography.hazmat.primitives.asymmetric.dsa.DsaPublicKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey], comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None) ‑> PublicKey

    Creates a new SSH Public key from a cryptography class

    @@ -3318,9 +3318,9 @@

    Returns

    -
    -class RSAPrivateKey -(key: cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey) +
    +class RsaPrivateKey +(key: cryptography.hazmat.primitives.asymmetric.rsa.RsaPrivateKey)

    Class for holding RSA private keys

    @@ -3328,15 +3328,15 @@

    Returns

    Expand source code -
    class RSAPrivateKey(PrivateKey):
    +
    class RsaPrivateKey(PrivateKey):
         """
         Class for holding RSA private keys
         """
     
    -    def __init__(self, key: _RSA.RSAPrivateKey):
    +    def __init__(self, key: _RSA.RsaPrivateKey):
             super().__init__(
                 key=key,
    -            public_key=RSAPublicKey(key.public_key()),
    +            public_key=RsaPublicKey(key.public_key()),
                 private_numbers=key.private_numbers(),
             )
     
    @@ -3352,7 +3352,7 @@ 

    Returns

    dmp1: int = None, dmq1: int = None, iqmp: int = None, - ) -> "RSAPrivateKey": + ) -> "RsaPrivateKey": """ Load an RSA private key from numbers @@ -3375,7 +3375,7 @@

    Returns

    Automatically generates if not provided Returns: - RSAPrivateKey: An instance of RSAPrivateKey + RsaPrivateKey: An instance of RsaPrivateKey """ if None in (p, q): p, q = _RSA.rsa_recover_prime_factors(n, e, d) @@ -3399,7 +3399,7 @@

    Returns

    @classmethod def generate( cls, key_size: int = 4096, public_exponent: int = 65537 - ) -> "RSAPrivateKey": + ) -> "RsaPrivateKey": """ Generates a new RSA private key @@ -3408,7 +3408,7 @@

    Returns

    public_exponent (int, optional): The public exponent to use. Defaults to 65537. Returns: - RSAPrivateKey: Instance of RSAPrivateKey + RsaPrivateKey: Instance of RsaPrivateKey """ return cls.from_class( _RSA.generate_private_key( @@ -3436,8 +3436,8 @@

    Ancestors

    Static methods

    -
    -def from_numbers(e: int, n: int, d: int, p: int = None, q: int = None, dmp1: int = None, dmq1: int = None, iqmp: int = None) ‑> RSAPrivateKey +
    +def from_numbers(e: int, n: int, d: int, p: int = None, q: int = None, dmp1: int = None, dmq1: int = None, iqmp: int = None) ‑> RsaPrivateKey

    Load an RSA private key from numbers

    @@ -3470,8 +3470,8 @@

    Args

    Returns

    -
    RSAPrivateKey
    -
    An instance of RSAPrivateKey
    +
    RsaPrivateKey
    +
    An instance of RsaPrivateKey
    @@ -3489,7 +3489,7 @@

    Returns

    dmp1: int = None, dmq1: int = None, iqmp: int = None, -) -> "RSAPrivateKey": +) -> "RsaPrivateKey": """ Load an RSA private key from numbers @@ -3512,7 +3512,7 @@

    Returns

    Automatically generates if not provided Returns: - RSAPrivateKey: An instance of RSAPrivateKey + RsaPrivateKey: An instance of RsaPrivateKey """ if None in (p, q): p, q = _RSA.rsa_recover_prime_factors(n, e, d) @@ -3534,8 +3534,8 @@

    Returns

    )
    -
    -def generate(key_size: int = 4096, public_exponent: int = 65537) ‑> RSAPrivateKey +
    +def generate(key_size: int = 4096, public_exponent: int = 65537) ‑> RsaPrivateKey

    Generates a new RSA private key

    @@ -3548,8 +3548,8 @@

    Args

    Returns

    -
    RSAPrivateKey
    -
    Instance of RSAPrivateKey
    +
    RsaPrivateKey
    +
    Instance of RsaPrivateKey
    @@ -3558,7 +3558,7 @@

    Returns

    @classmethod
     def generate(
         cls, key_size: int = 4096, public_exponent: int = 65537
    -) -> "RSAPrivateKey":
    +) -> "RsaPrivateKey":
         """
         Generates a new RSA private key
     
    @@ -3567,7 +3567,7 @@ 

    Returns

    public_exponent (int, optional): The public exponent to use. Defaults to 65537. Returns: - RSAPrivateKey: Instance of RSAPrivateKey + RsaPrivateKey: Instance of RsaPrivateKey """ return cls.from_class( _RSA.generate_private_key( @@ -3579,7 +3579,7 @@

    Returns

    Methods

    -
    +
    def sign(self, data: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA512) ‑> bytes
    @@ -3632,9 +3632,9 @@

    Inherited members

    -
    -class RSAPublicKey -(key: cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, serialized: bytes = None) +
    +class RsaPublicKey +(key: cryptography.hazmat.primitives.asymmetric.rsa.RsaPublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, serialized: bytes = None)

    Class for holding RSA public keys

    @@ -3642,14 +3642,14 @@

    Inherited members

    Expand source code -
    class RSAPublicKey(PublicKey):
    +
    class RsaPublicKey(PublicKey):
         """
         Class for holding RSA public keys
         """
     
         def __init__(
             self,
    -        key: _RSA.RSAPublicKey,
    +        key: _RSA.RsaPublicKey,
             comment: Union[str, bytes] = None,
             key_type: Union[str, bytes] = None,
             serialized: bytes = None,
    @@ -3664,7 +3664,7 @@ 

    Inherited members

    @classmethod # pylint: disable=invalid-name - def from_numbers(cls, e: int, n: int) -> "RSAPublicKey": + def from_numbers(cls, e: int, n: int) -> "RsaPublicKey": """ Loads an RSA Public Key from the public numbers e and n @@ -3673,7 +3673,7 @@

    Inherited members

    n (int): n-value Returns: - RSAPublicKey: _description_ + RsaPublicKey: _description_ """ return cls(key=_RSA.RSAPublicNumbers(e, n).public_key()) @@ -3706,8 +3706,8 @@

    Ancestors

    Static methods

    -
    -def from_numbers(e: int, n: int) ‑> RSAPublicKey +
    +def from_numbers(e: int, n: int) ‑> RsaPublicKey

    Loads an RSA Public Key from the public numbers e and n

    @@ -3720,7 +3720,7 @@

    Args

    Returns

    -
    RSAPublicKey
    +
    RsaPublicKey
    description
    @@ -3729,7 +3729,7 @@

    Returns

    @classmethod
     # pylint: disable=invalid-name
    -def from_numbers(cls, e: int, n: int) -> "RSAPublicKey":
    +def from_numbers(cls, e: int, n: int) -> "RsaPublicKey":
         """
         Loads an RSA Public Key from the public numbers e and n
     
    @@ -3738,7 +3738,7 @@ 

    Returns

    n (int): n-value Returns: - RSAPublicKey: _description_ + RsaPublicKey: _description_ """ return cls(key=_RSA.RSAPublicNumbers(e, n).public_key())
    @@ -3746,7 +3746,7 @@

    Returns

    Methods

    -
    +
    def verify(self, data: bytes, signature: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA512) ‑> None
    @@ -3872,49 +3872,49 @@

    Index

  • Classes

    • -

      DSAPrivateKey

      +

      DsaPrivateKey

    • -

      DSAPublicKey

      +

      DsaPublicKey

    • -

      ECDSAPrivateKey

      +

      EcdsaPrivateKey

    • -

      ECDSAPublicKey

      +

      EcdsaPublicKey

    • -

      ED25519PrivateKey

      +

      Ed25519PrivateKey

    • -

      ED25519PublicKey

      +

      Ed25519PublicKey

    • @@ -3959,18 +3959,18 @@

      RSAPrivateKey

      +

      RsaPrivateKey

    • -

      RSAPublicKey

      +

      RsaPublicKey

    • diff --git a/src/sshkey_tools/cert.py b/src/sshkey_tools/cert.py index d66e17e..2f775f3 100644 --- a/src/sshkey_tools/cert.py +++ b/src/sshkey_tools/cert.py @@ -8,603 +8,312 @@ """ from base64 import b64encode, b64decode +from dataclasses import dataclass from typing import Union +from enum import Enum from .keys import ( PublicKey, PrivateKey, - RSAPublicKey, - DSAPublicKey, - ECDSAPublicKey, - ED25519PublicKey, + RsaPublicKey, + DsaPublicKey, + EcdsaPublicKey, + Ed25519PublicKey, ) from . import fields as _FIELD from . import exceptions as _EX from .keys import RsaAlgs -from .utils import join_dicts, concat_to_string, ensure_string, ensure_bytestring - -CERTIFICATE_FIELDS = { - "serial": _FIELD.SerialField, - "cert_type": _FIELD.CertificateTypeField, - "key_id": _FIELD.KeyIdField, - "principals": _FIELD.PrincipalsField, - "valid_after": _FIELD.ValidAfterField, - "valid_before": _FIELD.ValidBeforeField, - "critical_options": _FIELD.CriticalOptionsField, - "extensions": _FIELD.ExtensionsField, -} +from .utils import join_dicts, concat_to_string, concat_to_bytestring, ensure_string, ensure_bytestring CERT_TYPES = { - "ssh-rsa-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), - "rsa-sha2-256-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), - "rsa-sha2-512-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), - "ssh-dss-cert-v01@openssh.com": ("DSACertificate", "_FIELD.DSAPubkeyField"), + "ssh-rsa-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RsaPubkeyField"), + "rsa-sha2-256-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RsaPubkeyField"), + "rsa-sha2-512-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RsaPubkeyField"), + "ssh-dss-cert-v01@openssh.com": ("DsaCertificate", "_FIELD.DsaPubkeyField"), "ecdsa-sha2-nistp256-cert-v01@openssh.com": ( - "ECDSACertificate", - "_FIELD.ECDSAPubkeyField", + "EcdsaCertificate", + "_FIELD.EcdsaPubkeyField", ), "ecdsa-sha2-nistp384-cert-v01@openssh.com": ( - "ECDSACertificate", - "_FIELD.ECDSAPubkeyField", + "EcdsaCertificate", + "_FIELD.EcdsaPubkeyField", ), "ecdsa-sha2-nistp521-cert-v01@openssh.com": ( - "ECDSACertificate", - "_FIELD.ECDSAPubkeyField", + "EcdsaCertificate", + "_FIELD.EcdsaPubkeyField", ), "ssh-ed25519-cert-v01@openssh.com": ( - "ED25519Certificate", - "_FIELD.ED25519PubkeyField", + "Ed25519Certificate", + "_FIELD.Ed25519PubkeyField", ), } +@dataclass +class Fieldset: + def __setattr__(self, name, value): + field = getattr(self, name, None) + + if callable(field) and not isinstance(field, _FIELD.CertificateField): + if field.__name__ == "factory": + super().__setattr__(name, field()) + self.__setattr__(name, value) + return + + if isinstance(field, type) and getattr(value, '__name__', '') != 'factory': + super().__setattr__(name, field(value)) + return + + if getattr(value, '__name__', '') != 'factory': + field.value = value + super().__setattr__(name, field) + + def replace_field(self, name: str, value: Union[_FIELD.CertificateField, type]): + super(Fieldset, self).__setattr__(name, value) + + def get(self, name: str, default=None): + field = getattr(self, name, default) + if field: + if isinstance(field, type): + return field.DEFAULT + return field.value + return field + + def getattrs(self) -> tuple: + return tuple(k for k in self.__dict__.keys() if not k.startswith('_')) + + def validate(self): + ex = [] + for key in self.getattrs(): + if not getattr(self, key).validate(): + list([ + ex.append(f"{type(x)}: {str(x)}") for x in getattr(self, key).exception + if isinstance(x, Exception) + ]) + + return True if len(ex) == 0 else ex + +@dataclass +class CertificateHeader(Fieldset): + public_key: _FIELD.PublicKeyField = _FIELD.PublicKeyField.factory + pubkey_type: _FIELD.PubkeyTypeField = _FIELD.PubkeyTypeField.factory + nonce: _FIELD.NonceField = _FIELD.NonceField.factory + + def __bytes__(self): + return concat_to_bytestring( + bytes(self.pubkey_type), + bytes(self.nonce), + bytes(self.public_key) + ) + +@dataclass +class CertificateFields(Fieldset): + serial: _FIELD.SerialField = _FIELD.SerialField.factory + cert_type: _FIELD.CertificateTypeField = _FIELD.CertificateTypeField.factory + key_id: _FIELD.KeyIdField = _FIELD.KeyIdField.factory + principals: _FIELD.PrincipalsField = _FIELD.PrincipalsField.factory + valid_after: _FIELD.ValidAfterField = _FIELD.ValidAfterField.factory + valid_before: _FIELD.ValidBeforeField = _FIELD.ValidBeforeField.factory + critical_options: _FIELD.CriticalOptionsField = _FIELD.CriticalOptionsField.factory + extensions: _FIELD.ExtensionsField = _FIELD.ExtensionsField.factory + + def __bytes__(self): + return concat_to_bytestring( + bytes(self.serial), + bytes(self.cert_type), + bytes(self.key_id), + bytes(self.principals), + bytes(self.valid_after), + bytes(self.valid_before), + bytes(self.critical_options), + bytes(self.extensions) + ) + +@dataclass +class CertificateFooter(Fieldset): + reserved: _FIELD.ReservedField = _FIELD.ReservedField.factory + ca_pubkey: _FIELD.CAPublicKeyField = _FIELD.CAPublicKeyField.factory + signature: _FIELD.SignatureField = _FIELD.SignatureField.factory + + def __bytes__(self): + return concat_to_bytestring( + bytes(self.reserved), + bytes(self.ca_pubkey) + ) + class SSHCertificate: """ General class for SSH Certificates, used for loading and parsing. To create new certificates, use the respective keytype classes or the from_public_key classmethod """ - + DEFAULT_KEY_TYPE = 'none@openssh.com' def __init__( self, subject_pubkey: PublicKey = None, ca_privkey: PrivateKey = None, - decoded: dict = None, - **kwargs, - ) -> None: + fields: CertificateFields = CertificateFields, + header: CertificateHeader = CertificateHeader, + footer: CertificateFooter = CertificateFooter + ): if self.__class__.__name__ == "SSHCertificate": raise _EX.InvalidClassCallException( "You cannot instantiate SSHCertificate directly. Use \n" - + "one of the child classes, or call via decode, \n" + + "one of the child classes, or call via decode, create \n" + "or one of the from_-classmethods" ) - - if decoded is not None: - self.signature = decoded.pop("signature") - self.signature_pubkey = decoded.pop("ca_pubkey") - - self.header = { - "pubkey_type": decoded.pop("pubkey_type"), - "nonce": decoded.pop("nonce"), - "public_key": decoded.pop("public_key"), - } - - self.fields = decoded - - return - - if subject_pubkey is None: - raise _EX.SSHCertificateException("The subject public key is required") - - self.header = { - "pubkey_type": _FIELD.PubkeyTypeField, - "nonce": _FIELD.NonceField(), - "public_key": _FIELD.PublicKeyField.from_object(subject_pubkey), - } - - if ca_privkey is not None: - self.signature = _FIELD.SignatureField.from_object(ca_privkey) - self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( - ca_privkey.public_key - ) - - self.fields = dict(CERTIFICATE_FIELDS) - self.set_opts(**kwargs) - - def __str__(self): - ls_space = " "*32 - - principals = "\n" + "\n".join( - [ ls_space + principal for principal in ensure_string(self.fields["principals"].value) ] - if len(self.fields["principals"].value) > 0 - else "None" - ) - - critical = "\n" + "\n".join( - [ls_space + cr_opt for cr_opt in ensure_string(self.fields["critical_options"].value)] - if not isinstance(self.fields["critical_options"].value, dict) - else [f'{ls_space}{cr_opt}={self.fields["critical_options"].value[cr_opt]}' for cr_opt in ensure_string(self.fields["critical_options"].value)] - ) - extensions = "\n" + "\n".join( - [ ls_space + ext for ext in ensure_string(self.fields["extensions"].value) ] - if len(self.fields["extensions"].value) > 0 - else "None" - ) + self.fields = fields() if isinstance(fields, type) else fields + self.header = header() if isinstance(header, type) else header + self.footer = footer() if isinstance(footer, type) else footer - signature_val = ( - b64encode(self.signature.value).decode("utf-8") - if isinstance(self.signature.value, bytes) - else "Not signed" - ) - - return f""" - Certificate: - Pubkey Type: {self.header['pubkey_type'].value} - Public Key: {str(self.header['public_key'])} - CA Public Key: {str(self.signature_pubkey)} - Nonce: {self.header['nonce'].value} - Certificate Type: {'User' if self.fields['cert_type'].value == 1 else 'Host'} - Valid After: {self.fields['valid_after'].value.strftime('%Y-%m-%d %H:%M:%S')} - Valid Until: {self.fields['valid_before'].value.strftime('%Y-%m-%d %H:%M:%S')} - Principals: {principals} - Critical options: {critical} - Extensions: {extensions} - Signature: {signature_val} - """ - - @staticmethod - def decode( - cert_bytes: bytes, pubkey_class: _FIELD.PublicKeyField = None - ) -> "SSHCertificate": - """ - Decode an existing certificate and import it into a new object - - Args: - cert_bytes (bytes): The certificate bytes, base64 decoded middle part of the certificate - pubkey_field (_FIELD.PublicKeyField): Instance of the PublicKeyField class, only needs - to be set if it can't be detected automatically - - Raises: - _EX.InvalidCertificateFormatException: Invalid or unknown certificate format - - Returns: - SSHCertificate: SSHCertificate child class - """ - if pubkey_class is None: - cert_type = _FIELD.StringField.decode(cert_bytes)[0].encode("utf-8") - pubkey_class = CERT_TYPES.get(cert_type, False) - - if pubkey_class is False: - raise _EX.InvalidCertificateFormatException( - "Could not determine certificate type, please use one " - + "of the specific classes or specify the pubkey_class" + if isinstance(header, type) and subject_pubkey is not None: + self.header.pubkey_type = self.DEFAULT_KEY_TYPE + self.header.replace_field( + 'public_key', + _FIELD.PublicKeyField.from_object(subject_pubkey) ) - - decode_fields = join_dicts( - { - "pubkey_type": _FIELD.PubkeyTypeField, - "nonce": _FIELD.NonceField, - "public_key": pubkey_class, - }, - CERTIFICATE_FIELDS, - { - "reserved": _FIELD.ReservedField, - "ca_pubkey": _FIELD.CAPublicKeyField, - "signature": _FIELD.SignatureField, - }, - ) - - cert = {} - - for item in decode_fields.keys(): - cert[item], cert_bytes = decode_fields[item].from_decode(cert_bytes) - - if cert_bytes != b"": + + if isinstance(footer, type) and ca_privkey is not None: + self.footer.ca_pubkey = ca_privkey.public_key + self.footer.replace_field( + 'signature', + _FIELD.SignatureField.from_object(ca_privkey) + ) + + self.__post_init__() + + def __post_init__(self): + """Extensible function for post-initialization for child classes""" + + def __bytes__(self): + if not self.footer.signature.is_signed: raise _EX.InvalidCertificateFormatException( - "The certificate has additional data after everything has been extracted" + "Failed exporting certificate: Certificate is not signed" ) - - pubkey_type = ensure_string(cert["pubkey_type"].value) - - cert_type = CERT_TYPES[pubkey_type] - cert.pop("reserved") - return globals()[cert_type[0]]( - subject_pubkey=cert["public_key"].value, decoded=cert + + return concat_to_bytestring( + bytes(self.header), + bytes(self.fields), + bytes(self.footer), + bytes(self.footer.signature) ) @classmethod - def from_public_class( - cls, public_key: PublicKey, ca_privkey: PrivateKey = None, **kwargs - ) -> "SSHCertificate": - """ - Creates a new certificate from a supplied public key - - Args: - public_key (PublicKey): The public key for which to create a certificate - - Returns: - SSHCertificate: SSHCertificate child class - """ - return globals()[ - public_key.__class__.__name__.replace("PublicKey", "Certificate") - ](public_key, ca_privkey, **kwargs) - - @classmethod - def from_bytes(cls, cert_bytes: bytes): - """ - Loads an existing certificate from the byte value. - - Args: - cert_bytes (bytes): Certificate bytes, base64 decoded middle part of the certificate - - Returns: - SSHCertificate: SSHCertificate child class - """ - cert_type, _ = _FIELD.StringField.decode(cert_bytes) - target_class = CERT_TYPES[cert_type] - return globals()[target_class[0]].decode(cert_bytes) - - @classmethod - def from_string(cls, cert_str: Union[str, bytes], encoding: str = "utf-8"): - """ - Loads an existing certificate from a string in the format - [certificate-type] [base64-encoded-certificate] [optional-comment] - - Args: - cert_str (str): The string containing the certificate - encoding (str, optional): The encoding of the string. Defaults to 'utf-8'. - - Returns: - SSHCertificate: SSHCertificate child class - """ - cert_str = ensure_bytestring(cert_str) - - certificate = b64decode(cert_str.split(b" ")[1]) - return cls.from_bytes(cert_bytes=certificate) - - @classmethod - def from_file(cls, path: str, encoding: str = "utf-8"): - """ - Loads an existing certificate from a file - - Args: - path (str): The path to the certificate file - encoding (str, optional): Encoding of the file. Defaults to 'utf-8'. - - Returns: - SSHCertificate: SSHCertificate child class - """ - return cls.from_string(open(path, "r", encoding=encoding).read()) - - def set_ca(self, ca_privkey: PrivateKey): - """ - Set the CA Private Key for signing the certificate - - Args: - ca_privkey (PrivateKey): The CA private key - """ - self.signature = _FIELD.SignatureField.from_object(ca_privkey) - self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( - ca_privkey.public_key + def create( + cls, + subject_pubkey: PublicKey = None, + ca_privkey: PrivateKey = None, + fields: CertificateFields = CertificateFields + ): + cert_class = subject_pubkey.__class__.__name__.replace("PublicKey", "Certificate") + return globals()[cert_class]( + subject_pubkey=subject_pubkey, + ca_privkey=ca_privkey, + fields=fields ) - - def set_type(self, pubkey_type: str): - """ - Set the type of the public key if not already set automatically - The child classes will set this automatically - - Args: - pubkey_type (str): Public key type, e.g. ssh-rsa-cert-v01@openssh.com - """ - if not getattr(self.header["pubkey_type"], "value", False): - self.header["pubkey_type"] = self.header["pubkey_type"](pubkey_type) - - def set_opt(self, key: str, value): - """ - Add information to a field in the certificate - - Args: - key (str): The key to set - value (mixed): The new value for the field - - Raises: - _EX.InvalidCertificateFieldException: Invalid field - """ - if key not in self.fields: - raise _EX.InvalidCertificateFieldException( - f"{key} is not a valid certificate field" + + @classmethod + def decode(cls,cert_data: Union[str, bytes]): + pass + + def get(self, field: str): + if field in ( + self.header.getattrs() + + self.fields.getattrs() + + self.footer.getattrs() + ): + return ( + self.fields.get(field, False) or + self.header.get(field, False) or + self.footer.get(field, False) ) - try: - if self.fields[key].value not in [None, False, "", [], ()]: - self.fields[key].value = value - except AttributeError: - self.fields[key] = self.fields[key](value) + raise _EX.InvalidCertificateFieldException(f"Unknown field {field}") + + def set(self, field: str, value): + if self.fields.get(field, False): + setattr(self.fields, field, value) + return - def set_opts(self, **kwargs): - """ - Set multiple options at once - """ - for key, value in kwargs.items(): - self.set_opt(key, value) + if self.header.get(field, False): + setattr(self.header, field, value) + return - def get_opt(self, key: str): - """ - Get the value of a field in the certificate - - Args: - key (str): The key to get - - Raises: - _EX.InvalidCertificateFieldException: Invalid field - """ - if key not in self.fields: - raise _EX.InvalidCertificateFieldException( - f"{key} is not a valid certificate field" - ) - - return getattr(self.fields[key], "value", None) - - # pylint: disable=used-before-assignment + if self.footer.get(field, False): + setattr(self.footer, field, value) + return + + raise _EX.InvalidCertificateFieldException(f"Unknown field {field}") + def can_sign(self) -> bool: - """ - Determine if the certificate is ready to be signed - - Raises: - ...: Exception from the respective field with error - _EX.NoPrivateKeyException: Private key is missing from class - - Returns: - bool: True/False if the certificate can be signed - """ - exceptions = [] - for field in self.fields.values(): - try: - valid = field.validate() - except TypeError: - valid = _EX.SignatureNotPossibleException( - f"The field {field} is missing a value" + valid_header = self.header.validate() + valid_fields = self.fields.validate() + check_keys = ( + True if isinstance(self.get('ca_pubkey'), PublicKey) and + isinstance(self.footer.signature.private_key, PrivateKey) + else [ + _EX.SignatureNotPossibleException('No CA Public/Private key is loaded') + ] + ) + + if (valid_header, valid_fields, check_keys) != (True, True, True): + raise _EX.SignatureNotPossibleException( + "\n".join( + valid_header if valid_header != True else [] + + valid_fields if valid_fields != True else [] + + check_keys if check_keys != True else [] ) - finally: - if isinstance(valid, Exception): - exceptions.append(valid) - - if ( - getattr(self, "signature", False) is False - or getattr(self, "signature_pubkey", False) is False - ): - exceptions.append( - _EX.SignatureNotPossibleException("No CA private key is set") ) - - if len(exceptions) > 0: - raise _EX.SignatureNotPossibleException(exceptions) - - if self.signature.can_sign() is True: - return True - - raise _EX.SignatureNotPossibleException( - "The certificate cannot be signed, the CA private key is not loaded" - ) - - def get_signable_data(self) -> bytes: + + return True + + def get_signable(self) -> bytes: """ - Gets the signable byte string from the certificate fields - - Returns: - bytes: The data in the certificate which is signed + Retrieves the signable data for the certificate in byte form """ - return ( - b"".join( - [ - bytes(x) - for x in tuple(self.header.values()) + tuple(self.fields.values()) - ] - ) - + bytes(_FIELD.ReservedField()) - + bytes(self.signature_pubkey) + return concat_to_bytestring( + bytes(self.header), + bytes(self.fields), + bytes(self.footer) ) - def sign(self, **signing_args): - """ - Sign the certificate - - Returns: - SSHCertificate: The signed certificate class - """ + def sign(self) -> bool: if self.can_sign(): - self.signature.sign(data=self.get_signable_data(), **signing_args) - - return self - - def verify(self, ca_pubkey: PublicKey = None) -> bool: - """ - Verifies a signature against a given public key. - - If no public key is provided, the signature is checked against - the public/private key provided to the class on creation - or decoding. - - Not providing the public key for the CA with an imported - certificate means the verification will succeed even if an - attacker has replaced the signature and public key for signing. - - If the certificate wasn't created and signed on the same occasion - as the validity check, you should always provide a public key for - verificiation. - - Returns: - bool: If the certificate signature is valid - """ - - if ca_pubkey is None: - ca_pubkey = self.signature_pubkey.value - - cert_data = self.get_signable_data() - signature = self.signature.value - - return ca_pubkey.verify(cert_data, signature) - - def to_bytes(self) -> bytes: - """ - Export the signed certificate in byte-format - - Raises: - _EX.NotSignedException: The certificate has not been signed yet - - Returns: - bytes: The certificate bytes - """ - if self.signature.is_signed is True: - return self.get_signable_data() + bytes(self.signature) - - raise _EX.NotSignedException("The certificate has not been signed") - - def to_string( - self, comment: Union[str, bytes] = None, encoding: str = "utf-8" - ) -> str: - """ - Export the signed certificate to a string, ready to be written to file - - Args: - comment (Union[str, bytes], optional): Comment to add to the string. Defaults to None. - encoding (str, optional): Encoding to use for the string. Defaults to 'utf-8'. + self.footer.signature.sign( + data=self.get_signable() + ) - Returns: - str: Certificate string - """ + return True + + def to_string(self, comment: str = '', encoding: str = 'utf-8'): return concat_to_string( - self.header["pubkey_type"].value, + self.header.get('pubkey_type'), " ", - b64encode(self.to_bytes()), + b64encode(bytes(self)), " ", comment if comment else "", encoding=encoding ) - - def to_file( - self, path: str, comment: Union[str, bytes] = None, encoding: str = "utf-8" - ): - """ - Saves the certificate to a file - - Args: - path (str): The path of the file to save to - comment (Union[str, bytes], optional): Comment to add to the certificate end. - Defaults to None. - encoding (str, optional): Encoding for the file. Defaults to 'utf-8'. - """ - with open(path, "w", encoding=encoding) as file: - file.write(self.to_string(comment, encoding)) - - -class RSACertificate(SSHCertificate): - """ - Specific class for RSA Certificates. Inherits from SSHCertificate - """ - - def __init__( - self, - subject_pubkey: RSAPublicKey, - ca_privkey: PrivateKey = None, - rsa_alg: RsaAlgs = RsaAlgs.SHA512, - **kwargs, - ): - - super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.rsa_alg = rsa_alg - self.set_type(f"{rsa_alg.value[0]}-cert-v01@openssh.com") - - @classmethod - # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "SSHCertificate": - """ - Decode an existing RSA Certificate - - Args: - cert_bytes (bytes): The base64-decoded bytes for the certificate - - Returns: - RSACertificate: The decoded certificate - """ - return super().decode(cert_bytes, _FIELD.RSAPubkeyField) - - -class DSACertificate(SSHCertificate): - """ - Specific class for DSA/DSS Certificates. Inherits from SSHCertificate - """ - - def __init__( - self, subject_pubkey: DSAPublicKey, ca_privkey: PrivateKey = None, **kwargs - ): - super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.set_type("ssh-dss-cert-v01@openssh.com") - - @classmethod - # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "DSACertificate": - """ - Decode an existing DSA Certificate - - Args: - cert_bytes (bytes): The base64-decoded bytes for the certificate - - Returns: - DSACertificate: The decoded certificate - """ - return super().decode(cert_bytes, _FIELD.DSAPubkeyField) - - -class ECDSACertificate(SSHCertificate): - """ - Specific class for ECDSA Certificates. Inherits from SSHCertificate - """ - - def __init__( - self, subject_pubkey: ECDSAPublicKey, ca_privkey: PrivateKey = None, **kwargs - ): - super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.set_type( - f"ecdsa-sha2-nistp{subject_pubkey.key.curve.key_size}-cert-v01@openssh.com" + + def to_file(self, filename: str): + with open(filename, 'w') as f: + f.write(self.to_string()) + +class RsaCertificate(SSHCertificate): + DEFAULT_KEY_TYPE = 'rsa-sha2-512-cert-v01@openssh.com' + +class DsaCertificate(SSHCertificate): + DEFAULT_KEY_TYPE = 'ssh-dss-cert-v01@openssh.com' + +class EcdsaCertificate(SSHCertificate): + DEFAULT_KEY_TYPE = 'ecdsa-sha2-nistp[curve_size]-cert-v01@openssh.com' + + def __post_init__(self): + """Set the key name from the public key curve size""" + self.header.pubkey_type = self.header.get("pubkey_type").replace( + "[curve_size]", + str(self.header.public_key.value.key.curve.key_size) ) - @classmethod - # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "ECDSACertificate": - """ - Decode an existing ECDSA Certificate - - Args: - cert_bytes (bytes): The base64-decoded bytes for the certificate - - Returns: - ECDSACertificate: The decoded certificate - """ - return super().decode(cert_bytes, _FIELD.ECDSAPubkeyField) - - -class ED25519Certificate(SSHCertificate): - """ - Specific class for ED25519 Certificates. Inherits from SSHCertificate - """ - - def __init__( - self, subject_pubkey: ED25519PublicKey, ca_privkey: PrivateKey = None, **kwargs - ): - super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.set_type("ssh-ed25519-cert-v01@openssh.com") - - @classmethod - # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "ED25519Certificate": - """ - Decode an existing ED25519 Certificate - - Args: - cert_bytes (bytes): The base64-decoded bytes for the certificate - - Returns: - ED25519Certificate: The decoded certificate - """ - return super().decode(cert_bytes, _FIELD.ED25519PubkeyField) +class Ed25519Certificate(SSHCertificate): + DEFAULT_KEY_TYPE = 'ssh-ed25519-cert-v01@openssh.com' \ No newline at end of file diff --git a/src/sshkey_tools/certold.py b/src/sshkey_tools/certold.py new file mode 100644 index 0000000..81f20ec --- /dev/null +++ b/src/sshkey_tools/certold.py @@ -0,0 +1,588 @@ +"""Contains classes for OpenSSH Certificates, generation, parsing and signing + Raises: + _EX.SSHCertificateException: General error in certificate + _EX.InvalidCertificateFormatException: An error with the format of the certificate + _EX.InvalidCertificateFieldException: An invalid field has been added to the certificate + _EX.NoPrivateKeyException: The certificate contains no private key + _EX.NotSignedException: The certificate is not signed and cannot be exported + +""" +from base64 import b64encode, b64decode +from typing import Union +from .keys import ( + PublicKey, + PrivateKey, + RsaPublicKey, + DsaPublicKey, + EcdsaPublicKey, + Ed25519PublicKey, +) +from . import fields as _FIELD +from . import exceptions as _EX +from .keys import RsaAlgs +from .utils import join_dicts, concat_to_string, ensure_string, ensure_bytestring + +CERTIFICATE_FIELDS = { + "serial": _FIELD.SerialField, + "cert_type": _FIELD.CertificateTypeField, + "key_id": _FIELD.KeyIdField, + "principals": _FIELD.PrincipalsField, + "valid_after": _FIELD.ValidAfterField, + "valid_before": _FIELD.ValidBeforeField, + "critical_options": _FIELD.CriticalOptionsField, + "extensions": _FIELD.ExtensionsField, +} + +class SSHCertificate: + """ + General class for SSH Certificates, used for loading and parsing. + To create new certificates, use the respective keytype classes + or the from_public_key classmethod + """ + + def __init__( + self, + subject_pubkey: PublicKey = None, + ca_privkey: PrivateKey = None, + decoded: dict = None, + **kwargs, + ) -> None: + if self.__class__.__name__ == "SSHCertificate": + raise _EX.InvalidClassCallException( + "You cannot instantiate SSHCertificate directly. Use \n" + + "one of the child classes, or call via decode, \n" + + "or one of the from_-classmethods" + ) + + if decoded is not None: + self.signature = decoded.pop("signature") + self.signature_pubkey = decoded.pop("ca_pubkey") + + self.header = { + "pubkey_type": decoded.pop("pubkey_type"), + "nonce": decoded.pop("nonce"), + "public_key": decoded.pop("public_key"), + } + + self.fields = decoded + + return + + if subject_pubkey is None: + raise _EX.SSHCertificateException("The subject public key is required") + + self.header = { + "pubkey_type": _FIELD.PubkeyTypeField, + "nonce": _FIELD.NonceField(), + "public_key": _FIELD.PublicKeyField.from_object(subject_pubkey), + } + + if ca_privkey is not None: + self.signature = _FIELD.SignatureField.from_object(ca_privkey) + self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( + ca_privkey.public_key + ) + + self.fields = dict(CERTIFICATE_FIELDS) + self.set_opts(**kwargs) + + def __str__(self): + ls_space = " "*32 + + principals = "\n" + "\n".join( + [ ls_space + principal for principal in ensure_string(self.fields["principals"].value) ] + if len(self.fields["principals"].value) > 0 + else "None" + ) + + critical = "\n" + "\n".join( + [ls_space + cr_opt for cr_opt in ensure_string(self.fields["critical_options"].value)] + if not isinstance(self.fields["critical_options"].value, dict) + else [f'{ls_space}{cr_opt}={self.fields["critical_options"].value[cr_opt]}' for cr_opt in ensure_string(self.fields["critical_options"].value)] + ) + + extensions = "\n" + "\n".join( + [ ls_space + ext for ext in ensure_string(self.fields["extensions"].value) ] + if len(self.fields["extensions"].value) > 0 + else "None" + ) + + signature_val = ( + b64encode(self.signature.value).decode("utf-8") + if isinstance(self.signature.value, bytes) + else "Not signed" + ) + + return f""" + Certificate: + Pubkey Type: {self.header['pubkey_type'].value} + Public Key: {str(self.header['public_key'])} + CA Public Key: {str(self.signature_pubkey)} + Nonce: {self.header['nonce'].value} + Certificate Type: {'User' if self.fields['cert_type'].value == 1 else 'Host'} + Valid After: {self.fields['valid_after'].value.strftime('%Y-%m-%d %H:%M:%S')} + Valid Until: {self.fields['valid_before'].value.strftime('%Y-%m-%d %H:%M:%S')} + Principals: {principals} + Critical options: {critical} + Extensions: {extensions} + Signature: {signature_val} + """ + + @staticmethod + def decode( + cert_bytes: bytes, pubkey_class: _FIELD.PublicKeyField = None + ) -> "SSHCertificate": + """ + Decode an existing certificate and import it into a new object + + Args: + cert_bytes (bytes): The certificate bytes, base64 decoded middle part of the certificate + pubkey_field (_FIELD.PublicKeyField): Instance of the PublicKeyField class, only needs + to be set if it can't be detected automatically + + Raises: + _EX.InvalidCertificateFormatException: Invalid or unknown certificate format + + Returns: + SSHCertificate: SSHCertificate child class + """ + if pubkey_class is None: + cert_type = _FIELD.StringField.decode(cert_bytes)[0].encode("utf-8") + pubkey_class = CERT_TYPES.get(cert_type, False) + + if pubkey_class is False: + raise _EX.InvalidCertificateFormatException( + "Could not determine certificate type, please use one " + + "of the specific classes or specify the pubkey_class" + ) + + decode_fields = join_dicts( + { + "pubkey_type": _FIELD.PubkeyTypeField, + "nonce": _FIELD.NonceField, + "public_key": pubkey_class, + }, + CERTIFICATE_FIELDS, + { + "reserved": _FIELD.ReservedField, + "ca_pubkey": _FIELD.CAPublicKeyField, + "signature": _FIELD.SignatureField, + }, + ) + + cert = {} + + for item in decode_fields.keys(): + cert[item], cert_bytes = decode_fields[item].from_decode(cert_bytes) + + if cert_bytes != b"": + raise _EX.InvalidCertificateFormatException( + "The certificate has additional data after everything has been extracted" + ) + + pubkey_type = ensure_string(cert["pubkey_type"].value) + + cert_type = CERT_TYPES[pubkey_type] + cert.pop("reserved") + return globals()[cert_type[0]]( + subject_pubkey=cert["public_key"].value, decoded=cert + ) + + @classmethod + def from_public_class( + cls, public_key: PublicKey, ca_privkey: PrivateKey = None, **kwargs + ) -> "SSHCertificate": + """ + Creates a new certificate from a supplied public key + + Args: + public_key (PublicKey): The public key for which to create a certificate + + Returns: + SSHCertificate: SSHCertificate child class + """ + return globals()[ + public_key.__class__.__name__.replace("PublicKey", "Certificate") + ](public_key, ca_privkey, **kwargs) + + @classmethod + def from_bytes(cls, cert_bytes: bytes): + """ + Loads an existing certificate from the byte value. + + Args: + cert_bytes (bytes): Certificate bytes, base64 decoded middle part of the certificate + + Returns: + SSHCertificate: SSHCertificate child class + """ + cert_type, _ = _FIELD.StringField.decode(cert_bytes) + target_class = CERT_TYPES[cert_type] + return globals()[target_class[0]].decode(cert_bytes) + + @classmethod + def from_string(cls, cert_str: Union[str, bytes], encoding: str = "utf-8"): + """ + Loads an existing certificate from a string in the format + [certificate-type] [base64-encoded-certificate] [optional-comment] + + Args: + cert_str (str): The string containing the certificate + encoding (str, optional): The encoding of the string. Defaults to 'utf-8'. + + Returns: + SSHCertificate: SSHCertificate child class + """ + cert_str = ensure_bytestring(cert_str) + + certificate = b64decode(cert_str.split(b" ")[1]) + return cls.from_bytes(cert_bytes=certificate) + + @classmethod + def from_file(cls, path: str, encoding: str = "utf-8"): + """ + Loads an existing certificate from a file + + Args: + path (str): The path to the certificate file + encoding (str, optional): Encoding of the file. Defaults to 'utf-8'. + + Returns: + SSHCertificate: SSHCertificate child class + """ + return cls.from_string(open(path, "r", encoding=encoding).read()) + + def set_ca(self, ca_privkey: PrivateKey): + """ + Set the CA Private Key for signing the certificate + + Args: + ca_privkey (PrivateKey): The CA private key + """ + self.signature = _FIELD.SignatureField.from_object(ca_privkey) + self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( + ca_privkey.public_key + ) + + def set_type(self, pubkey_type: str): + """ + Set the type of the public key if not already set automatically + The child classes will set this automatically + + Args: + pubkey_type (str): Public key type, e.g. ssh-rsa-cert-v01@openssh.com + """ + if not getattr(self.header["pubkey_type"], "value", False): + self.header["pubkey_type"] = self.header["pubkey_type"](pubkey_type) + + def set_opt(self, key: str, value): + """ + Add information to a field in the certificate + + Args: + key (str): The key to set + value (mixed): The new value for the field + + Raises: + _EX.InvalidCertificateFieldException: Invalid field + """ + if key not in self.fields: + raise _EX.InvalidCertificateFieldException( + f"{key} is not a valid certificate field" + ) + + try: + if self.fields[key].value not in [None, False, "", [], ()]: + self.fields[key].value = value + except AttributeError: + self.fields[key] = self.fields[key](value) + + def set_opts(self, **kwargs): + """ + Set multiple options at once + """ + for key, value in kwargs.items(): + self.set_opt(key, value) + + def get_opt(self, key: str): + """ + Get the value of a field in the certificate + + Args: + key (str): The key to get + + Raises: + _EX.InvalidCertificateFieldException: Invalid field + """ + if key not in self.fields: + raise _EX.InvalidCertificateFieldException( + f"{key} is not a valid certificate field" + ) + + return getattr(self.fields[key], "value", None) + + # pylint: disable=used-before-assignment + def can_sign(self) -> bool: + """ + Determine if the certificate is ready to be signed + + Raises: + ...: Exception from the respective field with error + _EX.NoPrivateKeyException: Private key is missing from class + + Returns: + bool: True/False if the certificate can be signed + """ + self.header['nonce'].validate() + + exceptions = [] + for field in self.fields.values(): + try: + valid = field.validate() + except TypeError: + valid = _EX.SignatureNotPossibleException( + f"The field {field} is missing a value" + ) + finally: + if isinstance(valid, Exception): + exceptions.append(valid) + + if ( + getattr(self, "signature", False) is False + or getattr(self, "signature_pubkey", False) is False + ): + exceptions.append( + _EX.SignatureNotPossibleException("No CA private key is set") + ) + + if len(exceptions) > 0: + raise _EX.SignatureNotPossibleException(exceptions) + + if self.signature.can_sign() is True: + return True + + raise _EX.SignatureNotPossibleException( + "The certificate cannot be signed, the CA private key is not loaded" + ) + + def get_signable_data(self) -> bytes: + """ + Gets the signable byte string from the certificate fields + + Returns: + bytes: The data in the certificate which is signed + """ + return ( + b"".join( + [ + bytes(x) + for x in tuple(self.header.values()) + tuple(self.fields.values()) + ] + ) + + bytes(_FIELD.ReservedField()) + + bytes(self.signature_pubkey) + ) + + def sign(self, **signing_args): + """ + Sign the certificate + + Returns: + SSHCertificate: The signed certificate class + """ + if self.can_sign(): + self.signature.sign(data=self.get_signable_data(), **signing_args) + + return self + + def verify(self, ca_pubkey: PublicKey = None) -> bool: + """ + Verifies a signature against a given public key. + + If no public key is provided, the signature is checked against + the public/private key provided to the class on creation + or decoding. + + Not providing the public key for the CA with an imported + certificate means the verification will succeed even if an + attacker has replaced the signature and public key for signing. + + If the certificate wasn't created and signed on the same occasion + as the validity check, you should always provide a public key for + verificiation. + + Returns: + bool: If the certificate signature is valid + """ + + if ca_pubkey is None: + ca_pubkey = self.signature_pubkey.value + + cert_data = self.get_signable_data() + signature = self.signature.value + + return ca_pubkey.verify(cert_data, signature) + + def to_bytes(self) -> bytes: + """ + Export the signed certificate in byte-format + + Raises: + _EX.NotSignedException: The certificate has not been signed yet + + Returns: + bytes: The certificate bytes + """ + if self.signature.is_signed is True: + return self.get_signable_data() + bytes(self.signature) + + raise _EX.NotSignedException("The certificate has not been signed") + + def to_string( + self, comment: Union[str, bytes] = None, encoding: str = "utf-8" + ) -> str: + """ + Export the signed certificate to a string, ready to be written to file + + Args: + comment (Union[str, bytes], optional): Comment to add to the string. Defaults to None. + encoding (str, optional): Encoding to use for the string. Defaults to 'utf-8'. + + Returns: + str: Certificate string + """ + return concat_to_string( + self.header["pubkey_type"].value, + " ", + b64encode(self.to_bytes()), + " ", + comment if comment else "", + encoding=encoding + ) + + def to_file( + self, path: str, comment: Union[str, bytes] = None, encoding: str = "utf-8" + ): + """ + Saves the certificate to a file + + Args: + path (str): The path of the file to save to + comment (Union[str, bytes], optional): Comment to add to the certificate end. + Defaults to None. + encoding (str, optional): Encoding for the file. Defaults to 'utf-8'. + """ + with open(path, "w", encoding=encoding) as file: + file.write(self.to_string(comment, encoding)) + + +class RsaCertificate(SSHCertificate): + """ + Specific class for RSA Certificates. Inherits from SSHCertificate + """ + + def __init__( + self, + subject_pubkey: RsaPublicKey, + ca_privkey: PrivateKey = None, + rsa_alg: RsaAlgs = RsaAlgs.SHA512, + **kwargs, + ): + + super().__init__(subject_pubkey, ca_privkey, **kwargs) + self.rsa_alg = rsa_alg + self.set_type(f"{rsa_alg.value[0]}-cert-v01@openssh.com") + + @classmethod + # pylint: disable=arguments-differ + def decode(cls, cert_bytes: bytes) -> "SSHCertificate": + """ + Decode an existing RSA Certificate + + Args: + cert_bytes (bytes): The base64-decoded bytes for the certificate + + Returns: + RsaCertificate: The decoded certificate + """ + return super().decode(cert_bytes, _FIELD.RsaPubkeyField) + + +class DsaCertificate(SSHCertificate): + """ + Specific class for DSA/DSS Certificates. Inherits from SSHCertificate + """ + + def __init__( + self, subject_pubkey: DsaPublicKey, ca_privkey: PrivateKey = None, **kwargs + ): + super().__init__(subject_pubkey, ca_privkey, **kwargs) + self.set_type("ssh-dss-cert-v01@openssh.com") + + @classmethod + # pylint: disable=arguments-differ + def decode(cls, cert_bytes: bytes) -> "DsaCertificate": + """ + Decode an existing DSA Certificate + + Args: + cert_bytes (bytes): The base64-decoded bytes for the certificate + + Returns: + DsaCertificate: The decoded certificate + """ + return super().decode(cert_bytes, _FIELD.DsaPubkeyField) + + +class EcdsaCertificate(SSHCertificate): + """ + Specific class for ECDSA Certificates. Inherits from SSHCertificate + """ + + def __init__( + self, subject_pubkey: EcdsaPublicKey, ca_privkey: PrivateKey = None, **kwargs + ): + super().__init__(subject_pubkey, ca_privkey, **kwargs) + self.set_type( + f"ecdsa-sha2-nistp{subject_pubkey.key.curve.key_size}-cert-v01@openssh.com" + ) + + @classmethod + # pylint: disable=arguments-differ + def decode(cls, cert_bytes: bytes) -> "EcdsaCertificate": + """ + Decode an existing ECDSA Certificate + + Args: + cert_bytes (bytes): The base64-decoded bytes for the certificate + + Returns: + EcdsaCertificate: The decoded certificate + """ + return super().decode(cert_bytes, _FIELD.EcdsaPubkeyField) + + +class Ed25519Certificate(SSHCertificate): + """ + Specific class for ED25519 Certificates. Inherits from SSHCertificate + """ + + def __init__( + self, subject_pubkey: Ed25519PublicKey, ca_privkey: PrivateKey = None, **kwargs + ): + super().__init__(subject_pubkey, ca_privkey, **kwargs) + self.set_type("ssh-ed25519-cert-v01@openssh.com") + + @classmethod + # pylint: disable=arguments-differ + def decode(cls, cert_bytes: bytes) -> "Ed25519Certificate": + """ + Decode an existing ED25519 Certificate + + Args: + cert_bytes (bytes): The base64-decoded bytes for the certificate + + Returns: + Ed25519Certificate: The decoded certificate + """ + return super().decode(cert_bytes, _FIELD.Ed25519PubkeyField) diff --git a/src/sshkey_tools/fields.py b/src/sshkey_tools/fields.py index 3a6d85f..7d0e254 100644 --- a/src/sshkey_tools/fields.py +++ b/src/sshkey_tools/fields.py @@ -4,7 +4,7 @@ # pylint: disable=invalid-name,too-many-lines,arguments-differ import re from enum import Enum -from types import NoneType, MethodType +from types import MethodType from typing import Union, Tuple from dataclasses import dataclass from datetime import datetime, timedelta @@ -19,19 +19,19 @@ RsaAlgs, PrivateKey, PublicKey, - RSAPublicKey, - RSAPrivateKey, - DSAPublicKey, - DSAPrivateKey, - ECDSAPublicKey, - ECDSAPrivateKey, - ED25519PublicKey, - ED25519PrivateKey, + RsaPublicKey, + RsaPrivateKey, + DsaPublicKey, + DsaPrivateKey, + EcdsaPublicKey, + EcdsaPrivateKey, + Ed25519PublicKey, + Ed25519PrivateKey, ) from .utils import ( - long_to_bytes, - bytes_to_long, + long_to_bytes, + bytes_to_long, generate_secure_nonce, random_keyid, random_serial, @@ -40,7 +40,7 @@ concat_to_string ) - +NoneType = type(None) MAX_INT32 = 2**32 MAX_INT64 = 2**64 NEWLINE = '\n' @@ -52,24 +52,24 @@ } SUBJECT_PUBKEY_MAP = { - RSAPublicKey: "RSAPubkeyField", - DSAPublicKey: "DSAPubkeyField", - ECDSAPublicKey: "ECDSAPubkeyField", - ED25519PublicKey: "ED25519PubkeyField", + RsaPublicKey: "RsaPubkeyField", + DsaPublicKey: "DsaPubkeyField", + EcdsaPublicKey: "EcdsaPubkeyField", + Ed25519PublicKey: "Ed25519PubkeyField", } CA_SIGNATURE_MAP = { - RSAPrivateKey: "RSASignatureField", - DSAPrivateKey: "DSASignatureField", - ECDSAPrivateKey: "ECDSASignatureField", - ED25519PrivateKey: "ED25519SignatureField", + RsaPrivateKey: "RsaSignatureField", + DsaPrivateKey: "DsaSignatureField", + EcdsaPrivateKey: "EcdsaSignatureField", + Ed25519PrivateKey: "Ed25519SignatureField", } SIGNATURE_TYPE_MAP = { - b"rsa": "RSASignatureField", - b"dss": "DSASignatureField", - b"ecdsa": "ECDSASignatureField", - b"ed25519": "ED25519SignatureField", + b"rsa": "RsaSignatureField", + b"dss": "DsaSignatureField", + b"ecdsa": "EcdsaSignatureField", + b"ed25519": "Ed25519SignatureField", } class CERT_TYPE(Enum): @@ -89,7 +89,7 @@ class CertificateField: REQUIRED = False DATA_TYPE = NoneType - def __init__(self, value): + def __init__(self, value = None): self.value = value self.exception = None self.IS_SET = True @@ -104,11 +104,11 @@ def __bytes__(self) -> bytes: @classmethod def get_name(cls) -> str: return "_".join( - re.findall('[A-Z][^A-Z]*', cls.__class__.__name__)[:-1] + re.findall('[A-Z][^A-Z]*', cls.__name__)[:-1] ).lower() @classmethod - def __validate_type(cls, value, do_raise: bool = False) -> Union[bool, Exception]: + def __validate_type__(cls, value, do_raise: bool = False) -> Union[bool, Exception]: """ Validate the data type of the value against the class data type """ @@ -124,7 +124,7 @@ def __validate_type(cls, value, do_raise: bool = False) -> Union[bool, Exception return True - def __validate_required(self) -> Union[bool, Exception]: + def __validate_required__(self) -> Union[bool, Exception]: """ Validates if the field is set when required """ @@ -134,7 +134,7 @@ def __validate_required(self) -> Union[bool, Exception]: ) return True - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field Meant to be overridden by child classes @@ -144,15 +144,18 @@ def __validate_value(self) -> Union[bool, Exception]: # pylint: disable=no-self-use def validate(self) -> bool: """ - Validates the field + Validates all field contents and types """ + if isinstance(self.value, NoneType) and self.DEFAULT != None: + self.value = self.DEFAULT() if callable(self.DEFAULT) else self.DEFAULT + self.exception = ( - self.__validate_type(self.value), - self.__validate_required(), - self.__validate_value() + self.__validate_type__(self.value), + self.__validate_required__(), + self.__validate_value__() ) - return self.exception == [True, True, True] + return self.exception == (True, True, True) @staticmethod def decode(cls, data: bytes) -> tuple: @@ -209,7 +212,7 @@ def encode(cls, value: Union[int, bool]) -> bytes: Returns: bytes: Packed byte representing the boolean """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) return pack("B", 1 if value else 0) @staticmethod @@ -222,7 +225,7 @@ def decode(data: bytes) -> Tuple[bool, bytes]: """ return bool(unpack("B", data[:1])[0]), data[1:] - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ @@ -249,7 +252,7 @@ def encode(cls, value: bytes) -> bytes: Returns: bytes: Packed byte string containing the source data """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) return pack(">I", len(value)) + ensure_bytestring(value) @staticmethod @@ -285,9 +288,9 @@ def encode(cls, value: str, encoding: str = "utf-8"): Returns: bytes: Packed byte string containing the source data - """ - cls.__validate_type(value, True) - return BytestringField.encode(ensure_bytestring(encoding)) + """ + cls.__validate_type__(value, True) + return BytestringField.encode(ensure_bytestring(value, encoding)) @staticmethod def decode(data: bytes, encoding: str = "utf-8") -> Tuple[str, bytes]: @@ -322,7 +325,7 @@ def encode(cls, value: int) -> bytes: Returns: bytes: Packed byte string containing integer """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) return pack(">I", value) @staticmethod @@ -337,13 +340,17 @@ def decode(data: bytes) -> Tuple[int, bytes]: """ return int(unpack(">I", data[:4])[0]), data[4:] - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ - return True if self.value < MAX_INT32 else _EX.InvalidFieldDataException( + if self.value < MAX_INT32: + return True + + return _EX.InvalidFieldDataException( f"{self.get_name()} must be a 32-bit integer" ) + class Integer64Field(CertificateField): """ @@ -362,7 +369,7 @@ def encode(cls, value: int) -> bytes: Returns: bytes: Packed byte string containing integer """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) return pack(">Q", value) @staticmethod @@ -377,12 +384,15 @@ def decode(data: bytes) -> Tuple[int, bytes]: """ return int(unpack(">Q", data[:8])[0]), data[8:] - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ - return True if self.value < MAX_INT32 else _EX.InvalidFieldDataException( - f"{self.get_name()} must be a 32-bit integer" + if self.value < MAX_INT64: + return True + + return _EX.InvalidFieldDataException( + f"{self.get_name()} must be a 64-bit integer" ) @@ -404,7 +414,7 @@ def encode(cls, value: Union[datetime, int]) -> bytes: Returns: bytes: Packed byte string containing datetime timestamp """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) if isinstance(value, datetime): value = int(value.timestamp()) @@ -423,6 +433,19 @@ def decode(data: bytes) -> datetime: """ timestamp, data = Integer64Field.decode(data) return datetime.fromtimestamp(timestamp), data + + def __validate_value__(self) -> Union[bool, Exception]: + """ + Validates the contents of the field + """ + check = self.value if isinstance(self.value, int) else self.value.timestamp() + + if check < MAX_INT64: + return True + + return _EX.InvalidFieldDataException( + f"{self.get_name()} must be a 64-bit integer or datetime object" + ) class MpIntegerField(BytestringField): """ @@ -444,7 +467,7 @@ def encode(cls, value: int) -> bytes: Returns: bytes: Packed byte string containing integer """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) return BytestringField.encode(long_to_bytes(value)) @staticmethod @@ -478,7 +501,7 @@ def encode(cls, value: Union[list, tuple, set]) -> bytes: Returns: bytes: Packed byte string containing the source data """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) try: if sum([not isinstance(item, (str, bytes)) for item in value]) > 0: @@ -508,15 +531,14 @@ def decode(data: bytes) -> Tuple[list, bytes]: return ensure_string(decoded), data - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ - if not all((isinstance(val, (str, bytes)) for val in self.value)): + if hasattr(self.value, '__iter__') and not all((isinstance(val, (str, bytes)) for val in self.value)): return _EX.InvalidFieldDataException( "Expected list or tuple containing strings or bytes" ) - return True class KeyValueField(CertificateField): @@ -540,7 +562,7 @@ def encode(cls, value: Union[list, tuple, dict, set]) -> bytes: Returns: bytes: Packed byte string containing the source data """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) if not isinstance(value, dict): value = {item: "" for item in value} @@ -590,7 +612,7 @@ def decode(data: bytes) -> Tuple[dict, bytes]: return decoded, data - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ @@ -599,7 +621,7 @@ def __validate_value(self) -> Union[bool, Exception]: else list(self.value.keys()) + list(self.value.values()) ) - if not all((isinstance(val, (str, bytes)) for val in testvals)): + if hasattr(self.value, '__iter__') and not all((isinstance(val, (str, bytes)) for val in testvals)): return _EX.InvalidFieldDataException( "Expected dict, list, tuple, set with string or byte keys and values" ) @@ -625,7 +647,7 @@ class PubkeyTypeField(StringField): "ssh-ed25519-cert-v01@openssh.com", ) - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ @@ -646,12 +668,12 @@ class NonceField(StringField): """ DEFAULT = generate_secure_nonce DATA_TYPE = (str, bytes) - - def __validate_value(self) -> Union[bool, Exception]: + + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ - if len(self.value) < 32: + if hasattr(self.value, '__count__') and len(self.value) < 32: return _EX.InvalidFieldDataException( "Expected a nonce of at least 32 bytes" ) @@ -681,12 +703,12 @@ def encode(cls, value: PublicKey) -> bytes: Encode the certificate field to a byte string Args: - value (RSAPublicKey): The public key to encode + value (RsaPublicKey): The public key to encode Returns: bytes: A byte string with the encoded public key """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) return BytestringField.decode(value.raw_bytes())[1] @staticmethod @@ -712,15 +734,15 @@ class or childclass raise _EX.InvalidKeyException("The public key is invalid") from KeyError -class RSAPubkeyField(PublicKeyField): +class RsaPubkeyField(PublicKeyField): """ Holds the RSA Public Key for RSA Certificates """ DEFAULT = None - DATA_TYPE = RSAPublicKey + DATA_TYPE = RsaPublicKey @staticmethod - def decode(data: bytes) -> Tuple[RSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[RsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -729,22 +751,22 @@ def decode(data: bytes) -> Tuple[RSAPublicKey, bytes]: data (bytes): The byte string starting with the encoded key Returns: - Tuple[RSAPublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[RsaPublicKey, bytes]: The PublicKey field and remainder of the data """ e, data = MpIntegerField.decode(data) n, data = MpIntegerField.decode(data) - return RSAPublicKey.from_numbers(e=e, n=n), data + return RsaPublicKey.from_numbers(e=e, n=n), data -class DSAPubkeyField(PublicKeyField): +class DsaPubkeyField(PublicKeyField): """ Holds the DSA Public Key for DSA Certificates """ DEFAULT = None - DATA_TYPE = DSAPublicKey + DATA_TYPE = DsaPublicKey @staticmethod - def decode(data: bytes) -> Tuple[DSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[DsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -753,24 +775,24 @@ def decode(data: bytes) -> Tuple[DSAPublicKey, bytes]: data (bytes): The byte string starting with the encoded key Returns: - Tuple[RSAPublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[RsaPublicKey, bytes]: The PublicKey field and remainder of the data """ p, data = MpIntegerField.decode(data) q, data = MpIntegerField.decode(data) g, data = MpIntegerField.decode(data) y, data = MpIntegerField.decode(data) - return DSAPublicKey.from_numbers(p=p, q=q, g=g, y=y), data + return DsaPublicKey.from_numbers(p=p, q=q, g=g, y=y), data -class ECDSAPubkeyField(PublicKeyField): +class EcdsaPubkeyField(PublicKeyField): """ Holds the ECDSA Public Key for ECDSA Certificates """ DEFAULT = None - DATA_TYPE = ECDSAPublicKey + DATA_TYPE = EcdsaPublicKey @staticmethod - def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]: + def decode(data: bytes) -> Tuple[EcdsaPublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -787,7 +809,7 @@ def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]: key_type = "ecdsa-sha2-" + curve return ( - ECDSAPublicKey.from_string( + EcdsaPublicKey.from_string( key_type + " " + b64encode( @@ -799,15 +821,15 @@ def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]: data, ) -class ED25519PubkeyField(PublicKeyField): +class Ed25519PubkeyField(PublicKeyField): """ Holds the ED25519 Public Key for ED25519 Certificates """ DEFAULT = None - DATA_TYPE = ED25519PublicKey + DATA_TYPE = Ed25519PublicKey @staticmethod - def decode(data: bytes) -> Tuple[ED25519PublicKey, bytes]: + def decode(data: bytes) -> Tuple[Ed25519PublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -816,11 +838,11 @@ def decode(data: bytes) -> Tuple[ED25519PublicKey, bytes]: data (bytes): The byte string starting with the encoded key Returns: - Tuple[ED25519PublicKey, bytes]: The PublicKey field and remainder of the data + Tuple[Ed25519PublicKey, bytes]: The PublicKey field and remainder of the data """ pubkey, data = BytestringField.decode(data) - return ED25519PublicKey.from_raw_bytes(pubkey), data + return Ed25519PublicKey.from_raw_bytes(pubkey), data class SerialField(Integer64Field): @@ -838,7 +860,7 @@ class CertificateTypeField(Integer32Field): Host certificate: CERT_TYPE.HOST/2 """ DEFAULT = CERT_TYPE.USER - DATA_TYPE = Union[CERT_TYPE, int] + DATA_TYPE = (CERT_TYPE, int) ALLOWED_VALUES = ( CERT_TYPE.USER, CERT_TYPE.HOST, @@ -857,20 +879,20 @@ def encode(cls, value: Union[CERT_TYPE, int]) -> bytes: Returns: bytes: A byte string with the encoded public key """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) if isinstance(value, CERT_TYPE): value = value.value return Integer32Field.encode(value) - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ if self.value not in self.ALLOWED_VALUES: return _EX.InvalidCertificateFieldException( - f"The certificate type is invalid (expected {','.join(self.ALLOWED_VALUES)})" + f"The certificate type is invalid (expected int(1,2) or CERT_TYPE.X)" ) return True @@ -881,7 +903,7 @@ class KeyIdField(StringField): alphanumeric string """ DEFAULT = random_keyid - DATA_TYPE = Union[str, bytes] + DATA_TYPE = (str, bytes) class PrincipalsField(ListField): """ @@ -891,7 +913,7 @@ class PrincipalsField(ListField): only for servers that have no allowed principals specified """ DEFAFULT = [] - DATA_TYPE = Union[list, set, tuple] + DATA_TYPE = (list, set, tuple) class ValidAfterField(DateTimeField): """ @@ -899,7 +921,7 @@ class ValidAfterField(DateTimeField): represented by a datetime object """ DEFAULT = datetime.now() - DATA_TYPE = Union[datetime, int] + DATA_TYPE = (datetime, int) class ValidBeforeField(DateTimeField): """ @@ -907,16 +929,21 @@ class ValidBeforeField(DateTimeField): represented by a datetime object """ DEFAULT = datetime.now() + timedelta(minutes=10) - DATA_TYPE = Union[datetime, int] + DATA_TYPE = (datetime, int) - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field + Additional checks over standard datetime field are + done to ensure no already expired certificates are + created """ - val = val if isinstance(val, datetime) else datetime.fromtimestamp(val) - if val < datetime.now(): + super().__validate_value__() + check = self.value if isinstance(self.value, datetime) else datetime.fromtimestamp(self.value) + + if check < datetime.now(): return _EX.InvalidCertificateFieldException( - f'The certificate validity period is invalid (expected a future datetime object)' + f'The certificate validity period is invalid (expected a future datetime object or timestamp)' ) return True @@ -938,14 +965,14 @@ class CriticalOptionsField(KeyValueField): if using a hardware token """ DEFAULT = [] - DATA_TYPE = Union[list, set, tuple, dict] + DATA_TYPE = (list, set, tuple, dict) ALLOWED_VALUES = ( "force-command", "source-address", "verify-required" ) - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ @@ -986,7 +1013,7 @@ class ExtensionsField(KeyValueField): """ DEFAULT = [] - DATA_TYPE = Union[list, set, tuple, dict] + DATA_TYPE = (list, set, tuple, dict) ALLOWED_VALUES = ( "no-touch-required", "permit-X11-forwarding", @@ -996,7 +1023,7 @@ class ExtensionsField(KeyValueField): "permit-user-rc" ) - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ @@ -1018,7 +1045,7 @@ class ReservedField(StringField): DEFAULT = "" DATA_TYPE = str - def __validate_value(self) -> Union[bool, Exception]: + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ @@ -1032,7 +1059,7 @@ class CAPublicKeyField(BytestringField): that is used to sign the certificate. """ DEFAULT = None - DATA_TYPE = Union[str, bytes] + DATA_TYPE = (str, bytes) def __str__(self) -> str: return " ".join( @@ -1164,12 +1191,15 @@ def sign(self, data: bytes) -> None: """ Placeholder signing function """ - + raise _EX.InvalidClassCallException("The base class has no sign function") + def __bytes__(self) -> None: return self.encode(self.value) + + -class RSASignatureField(SignatureField): +class RsaSignatureField(SignatureField): """ Creates and contains the RSA signature from an RSA Private Key """ @@ -1178,7 +1208,7 @@ class RSASignatureField(SignatureField): def __init__( self, - private_key: RSAPrivateKey = None, + private_key: RsaPrivateKey = None, hash_alg: RsaAlgs = RsaAlgs.SHA512, signature: bytes = None, ): @@ -1187,7 +1217,7 @@ def __init__( @classmethod # pylint: disable=arguments-renamed - def encode(cls, value: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256) -> bytes: + def encode(cls, value: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA512) -> bytes: """ Encodes the value to a byte string @@ -1199,7 +1229,7 @@ def encode(cls, value: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256) -> bytes: Returns: bytes: The encoded byte string """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) return BytestringField.encode( StringField.encode(hash_alg.value[0]) + BytestringField.encode(value) @@ -1224,9 +1254,9 @@ def decode(data: bytes) -> Tuple[Tuple[bytes, bytes], bytes]: return (sig_type, signature), data @classmethod - def from_decode(cls, data: bytes) -> Tuple["RSASignatureField", bytes]: + def from_decode(cls, data: bytes) -> Tuple["RsaSignatureField", bytes]: """ - Generates an RSASignatureField class from the encoded signature + Generates an RsaSignatureField class from the encoded signature Args: data (bytes): The bytestring containing the encoded signature @@ -1235,7 +1265,7 @@ def from_decode(cls, data: bytes) -> Tuple["RSASignatureField", bytes]: _EX.InvalidDataException: Invalid data Returns: - Tuple[RSASignatureField, bytes]: RSA Signature field and remainder of data + Tuple[RsaSignatureField, bytes]: RSA Signature field and remainder of data """ signature, data = cls.decode(data) @@ -1248,7 +1278,7 @@ def from_decode(cls, data: bytes) -> Tuple["RSASignatureField", bytes]: data, ) - def sign(self, data: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256, **kwargs) -> None: + def sign(self, data: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA512, **kwargs) -> None: """ Signs the provided data with the provided private key @@ -1266,7 +1296,7 @@ def __bytes__(self): return self.encode(self.value, self.hash_alg) -class DSASignatureField(SignatureField): +class DsaSignatureField(SignatureField): """ Creates and contains the DSA signature from an DSA Private Key """ @@ -1274,7 +1304,7 @@ class DSASignatureField(SignatureField): DATA_TYPE = bytes def __init__( - self, private_key: DSAPrivateKey = None, signature: bytes = None + self, private_key: DsaPrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @@ -1289,7 +1319,7 @@ def encode(cls, value: bytes): Returns: bytes: The encoded byte string """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) r, s = decode_dss_signature(value) @@ -1320,7 +1350,7 @@ def decode(data: bytes) -> Tuple[bytes, bytes]: return signature, data @classmethod - def from_decode(cls, data: bytes) -> Tuple["DSASignatureField", bytes]: + def from_decode(cls, data: bytes) -> Tuple["DsaSignatureField", bytes]: """ Creates a signature field class from the encoded signature @@ -1328,7 +1358,7 @@ def from_decode(cls, data: bytes) -> Tuple["DSASignatureField", bytes]: data (bytes): The bytestring starting with the Signature Returns: - Tuple[ DSASignatureField, bytes ]: signature, remainder of the data + Tuple[ DsaSignatureField, bytes ]: signature, remainder of the data """ signature, data = cls.decode(data) @@ -1345,7 +1375,7 @@ def sign(self, data: bytes, **kwargs) -> None: self.is_signed = True -class ECDSASignatureField(SignatureField): +class EcdsaSignatureField(SignatureField): """ Creates and contains the ECDSA signature from an ECDSA Private Key """ @@ -1354,7 +1384,7 @@ class ECDSASignatureField(SignatureField): def __init__( self, - private_key: ECDSAPrivateKey = None, + private_key: EcdsaPrivateKey = None, signature: bytes = None, curve_name: str = None, ) -> None: @@ -1379,7 +1409,7 @@ def encode(cls, value: bytes, curve_name: str = None) -> bytes: Returns: bytes: The encoded byte string """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) r, s = decode_dss_signature(value) @@ -1414,7 +1444,7 @@ def decode(data: bytes) -> Tuple[Tuple[bytes, bytes], bytes]: return (curve, signature), data @classmethod - def from_decode(cls, data: bytes) -> Tuple["ECDSASignatureField", bytes]: + def from_decode(cls, data: bytes) -> Tuple["EcdsaSignatureField", bytes]: """ Creates a signature field class from the encoded signature @@ -1422,7 +1452,7 @@ def from_decode(cls, data: bytes) -> Tuple["ECDSASignatureField", bytes]: data (bytes): The bytestring starting with the Signature Returns: - Tuple[ ECDSASignatureField , bytes ]: signature, remainder of the data + Tuple[ EcdsaSignatureField , bytes ]: signature, remainder of the data """ signature, data = cls.decode(data) @@ -1445,7 +1475,7 @@ def __bytes__(self): return self.encode(self.value, self.curve) -class ED25519SignatureField(SignatureField): +class Ed25519SignatureField(SignatureField): """ Creates and contains the ED25519 signature from an ED25519 Private Key """ @@ -1453,12 +1483,12 @@ class ED25519SignatureField(SignatureField): DATA_TYPE = bytes def __init__( - self, private_key: ED25519PrivateKey = None, signature: bytes = None + self, private_key: Ed25519PrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @classmethod - def value(cls, value: bytes) -> None: + def encode(cls, value: bytes) -> None: """ Encodes the signature to a byte string @@ -1468,7 +1498,7 @@ def value(cls, value: bytes) -> None: Returns: bytes: The encoded byte string """ - cls.__validate_type(value, True) + cls.__validate_type__(value, True) return BytestringField.encode( StringField.encode("ssh-ed25519") + BytestringField.encode(value) @@ -1492,7 +1522,7 @@ def decode(data: bytes) -> Tuple[bytes, bytes]: return signature, data @classmethod - def from_decode(cls, data: bytes) -> Tuple["ED25519SignatureField", bytes]: + def from_decode(cls, data: bytes) -> Tuple["Ed25519SignatureField", bytes]: """ Creates a signature field class from the encoded signature @@ -1500,7 +1530,7 @@ def from_decode(cls, data: bytes) -> Tuple["ED25519SignatureField", bytes]: data (bytes): The bytestring starting with the Signature Returns: - Tuple[ ED25519SignatureField , bytes ]: signature, remainder of the data + Tuple[ Ed25519SignatureField , bytes ]: signature, remainder of the data """ signature, data = cls.decode(data) @@ -1521,51 +1551,51 @@ def sign(self, data: bytes, **kwargs) -> None: # ADD: Default class for each field # ADD: Typing class for each field -@dataclass -class Fieldset: - def __setattr__(self, name, value): - field = getattr(self, name, None) - if callable(field) and not isinstance(field, CertificateField): - if field.__name__ == "factory": - super().__setattr__(name, field()) - self.__setattr__(name, value) - return - - if isinstance(field, type) and getattr(value, '__name__', '') != 'factory': - super().__setattr__(name, field(value)) - return +# @dataclass +# class Fieldset: +# def __setattr__(self, name, value): +# field = getattr(self, name, None) +# if callable(field) and not isinstance(field, CertificateField): +# if field.__name__ == "factory": +# super().__setattr__(name, field()) +# self.__setattr__(name, value) +# return + +# if isinstance(field, type) and getattr(value, '__name__', '') != 'factory': +# super().__setattr__(name, field(value)) +# return - if getattr(value, '__name__', '') != 'factory': - field.value = value - super().__setattr__(name, field) +# if getattr(value, '__name__', '') != 'factory': +# field.value = value +# super().__setattr__(name, field) - def get(self, name: str): - field = getattr(self, name, None) - if field: - if isinstance(field, type): - return field.DEFAULT - return field.value - raise _EX.InvalidCertificateFieldException(f"Unknown field {name}") - -@dataclass -class CertificateHeaders(Fieldset): - public_key: PublicKeyField = PublicKeyField.factory - pubkey_type: PubkeyTypeField = PubkeyTypeField.factory - nonce: NonceField = NonceField.factory - -@dataclass -class CertificateFooter(Fieldset): - reserved: ReservedField = ReservedField.factory - ca_pubkey: CAPublicKeyField = CAPublicKeyField.factory - signature: SignatureField = SignatureField.factory - -@dataclass -class CertificateFields(Fieldset): - serial: SerialField = SerialField.factory - cert_type: CertificateTypeField = CertificateTypeField.factory - key_id: KeyIdField = KeyIdField.factory - principals: PrincipalsField = PrincipalsField.factory - valid_after: ValidAfterField = ValidAfterField.factory - valid_before: ValidBeforeField = ValidBeforeField.factory - critical_options: CriticalOptionsField = CriticalOptionsField.factory - extensions: ExtensionsField = ExtensionsField.factory \ No newline at end of file +# def get(self, name: str): +# field = getattr(self, name, None) +# if field: +# if isinstance(field, type): +# return field.DEFAULT +# return field.value +# raise _EX.InvalidCertificateFieldException(f"Unknown field {name}") + +# @dataclass +# class CertificateHeader(Fieldset): +# public_key: PublicKeyField = PublicKeyField.factory +# pubkey_type: PubkeyTypeField = PubkeyTypeField.factory +# nonce: NonceField = NonceField.factory + +# @dataclass +# class CertificateFooter(Fieldset): +# reserved: ReservedField = ReservedField.factory +# ca_pubkey: CAPublicKeyField = CAPublicKeyField.factory +# signature: SignatureField = SignatureField.factory + +# @dataclass +# class CertificateBody(Fieldset): +# serial: SerialField = SerialField.factory +# cert_type: CertificateTypeField = CertificateTypeField.factory +# key_id: KeyIdField = KeyIdField.factory +# principals: PrincipalsField = PrincipalsField.factory +# valid_after: ValidAfterField = ValidAfterField.factory +# valid_before: ValidBeforeField = ValidBeforeField.factory +# critical_options: CriticalOptionsField = CriticalOptionsField.factory +# extensions: ExtensionsField = ExtensionsField.factory \ No newline at end of file diff --git a/src/sshkey_tools/keys.py b/src/sshkey_tools/keys.py index dc6e0aa..534f420 100644 --- a/src/sshkey_tools/keys.py +++ b/src/sshkey_tools/keys.py @@ -39,17 +39,17 @@ ) PUBKEY_MAP = { - _RSAPublicKey: "RSAPublicKey", - _DSAPublicKey: "DSAPublicKey", - _EllipticCurvePublicKey: "ECDSAPublicKey", - _Ed25519PublicKey: "ED25519PublicKey", + _RSAPublicKey: "RsaPublicKey", + _DSAPublicKey: "DsaPublicKey", + _EllipticCurvePublicKey: "EcdsaPublicKey", + _Ed25519PublicKey: "Ed25519PublicKey", } PRIVKEY_MAP = { - _RSAPrivateKey: "RSAPrivateKey", - _DSAPrivateKey: "DSAPrivateKey", - _EllipticCurvePrivateKey: "ECDSAPrivateKey", - _Ed25519PrivateKey: "ED25519PrivateKey", + _RSAPrivateKey: "RsaPrivateKey", + _DSAPrivateKey: "DsaPrivateKey", + _EllipticCurvePrivateKey: "EcdsaPrivateKey", + _Ed25519PrivateKey: "Ed25519PrivateKey", } ECDSA_HASHES = { @@ -407,7 +407,7 @@ def to_file( key_file.write(self.to_string(password, encoding)) -class RSAPublicKey(PublicKey): +class RsaPublicKey(PublicKey): """ Class for holding RSA public keys """ @@ -429,7 +429,7 @@ def __init__( @classmethod # pylint: disable=invalid-name - def from_numbers(cls, e: int, n: int) -> "RSAPublicKey": + def from_numbers(cls, e: int, n: int) -> "RsaPublicKey": """ Loads an RSA Public Key from the public numbers e and n @@ -438,7 +438,7 @@ def from_numbers(cls, e: int, n: int) -> "RSAPublicKey": n (int): n-value Returns: - RSAPublicKey: _description_ + RsaPublicKey: _description_ """ return cls(key=_RSA.RSAPublicNumbers(e, n).public_key()) @@ -466,7 +466,7 @@ def verify( ) from InvalidSignature -class RSAPrivateKey(PrivateKey): +class RsaPrivateKey(PrivateKey): """ Class for holding RSA private keys """ @@ -474,7 +474,7 @@ class RSAPrivateKey(PrivateKey): def __init__(self, key: _RSA.RSAPrivateKey): super().__init__( key=key, - public_key=RSAPublicKey(key.public_key()), + public_key=RsaPublicKey(key.public_key()), private_numbers=key.private_numbers(), ) @@ -490,7 +490,7 @@ def from_numbers( dmp1: int = None, dmq1: int = None, iqmp: int = None, - ) -> "RSAPrivateKey": + ) -> "RsaPrivateKey": """ Load an RSA private key from numbers @@ -513,7 +513,7 @@ def from_numbers( Automatically generates if not provided Returns: - RSAPrivateKey: An instance of RSAPrivateKey + RsaPrivateKey: An instance of RsaPrivateKey """ if None in (p, q): p, q = _RSA.rsa_recover_prime_factors(n, e, d) @@ -537,7 +537,7 @@ def from_numbers( @classmethod def generate( cls, key_size: int = 4096, public_exponent: int = 65537 - ) -> "RSAPrivateKey": + ) -> "RsaPrivateKey": """ Generates a new RSA private key @@ -546,7 +546,7 @@ def generate( public_exponent (int, optional): The public exponent to use. Defaults to 65537. Returns: - RSAPrivateKey: Instance of RSAPrivateKey + RsaPrivateKey: Instance of RsaPrivateKey """ return cls.from_class( _RSA.generate_private_key( @@ -569,7 +569,7 @@ def sign(self, data: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA512) -> bytes: return self.key.sign(data, _PADDING.PKCS1v15(), hash_alg.value[1]()) -class DSAPublicKey(PublicKey): +class DsaPublicKey(PublicKey): """ Class for holding DSA public keys """ @@ -592,7 +592,7 @@ def __init__( @classmethod # pylint: disable=invalid-name - def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DSAPublicKey": + def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DsaPublicKey": """ Create a DSA public key from public numbers and parameters @@ -603,7 +603,7 @@ def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DSAPublicKey": y (int): The public number Y Returns: - DSAPublicKey: An instance of DSAPublicKey + DsaPublicKey: An instance of DsaPublicKey """ return cls( key=_DSA.DSAPublicNumbers( @@ -630,7 +630,7 @@ def verify(self, data: bytes, signature: bytes) -> None: ) from InvalidSignature -class DSAPrivateKey(PrivateKey): +class DsaPrivateKey(PrivateKey): """ Class for holding DSA private keys """ @@ -638,15 +638,15 @@ class DSAPrivateKey(PrivateKey): def __init__(self, key: _DSA.DSAPrivateKey): super().__init__( key=key, - public_key=DSAPublicKey(key.public_key()), + public_key=DsaPublicKey(key.public_key()), private_numbers=key.private_numbers(), ) @classmethod # pylint: disable=invalid-name,too-many-arguments - def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DSAPrivateKey": + def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DsaPrivateKey": """ - Creates a new DSAPrivateKey object from parameters and public/private numbers + Creates a new DsaPrivateKey object from parameters and public/private numbers Args: p (int): P parameter, the prime modulus @@ -668,13 +668,13 @@ def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DSAPrivateKey" ) @classmethod - def generate(cls) -> "DSAPrivateKey": + def generate(cls) -> "DsaPrivateKey": """ Generate a new DSA private key Key size is fixed since OpenSSH only supports 1024-bit DSA keys Returns: - DSAPrivateKey: An instance of DSAPrivateKey + DsaPrivateKey: An instance of DsaPrivateKey """ return cls.from_class(_DSA.generate_private_key(key_size=1024)) @@ -691,7 +691,7 @@ def sign(self, data: bytes): return self.key.sign(data, _HASHES.SHA1()) -class ECDSAPublicKey(PublicKey): +class EcdsaPublicKey(PublicKey): """ Class for holding ECDSA public keys """ @@ -715,7 +715,7 @@ def __init__( # pylint: disable=invalid-name def from_numbers( cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int - ) -> "ECDSAPublicKey": + ) -> "EcdsaPublicKey": """ Create an ECDSA public key from public numbers and parameters @@ -725,7 +725,7 @@ def from_numbers( y (int): The affine Y component of the public point Returns: - ECDSAPublicKey: An instance of ECDSAPublicKey + EcdsaPublicKey: An instance of EcdsaPublicKey """ if not isinstance(curve, _ECDSA.EllipticCurve) and curve not in ECDSA_HASHES: raise _EX.InvalidCurveException( @@ -763,7 +763,7 @@ def verify(self, data: bytes, signature: bytes) -> None: ) from InvalidSignature -class ECDSAPrivateKey(PrivateKey): +class EcdsaPrivateKey(PrivateKey): """ Class for holding ECDSA private keys """ @@ -771,7 +771,7 @@ class ECDSAPrivateKey(PrivateKey): def __init__(self, key: _ECDSA.EllipticCurvePrivateKey): super().__init__( key=key, - public_key=ECDSAPublicKey(key.public_key()), + public_key=EcdsaPublicKey(key.public_key()), private_numbers=key.private_numbers(), ) @@ -781,7 +781,7 @@ def from_numbers( cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int, private_value: int ): """ - Creates a new ECDSAPrivateKey object from parameters and public/private numbers + Creates a new EcdsaPrivateKey object from parameters and public/private numbers Args: curve Union[str, _ECDSA.EllipticCurve]: Curve used by the key @@ -820,7 +820,7 @@ def generate(cls, curve: EcdsaCurves = EcdsaCurves.P521): curve (EcdsaCurves): Which curve to use. Default secp521r1 Returns: - ECDSAPrivateKey: An instance of ECDSAPrivateKey + EcdsaPrivateKey: An instance of EcdsaPrivateKey """ return cls.from_class(_ECDSA.generate_private_key(curve=curve.value)) @@ -838,7 +838,7 @@ def sign(self, data: bytes): return self.key.sign(data, _ECDSA.ECDSA(curve_hash)) -class ED25519PublicKey(PublicKey): +class Ed25519PublicKey(PublicKey): """ Class for holding ED25519 public keys """ @@ -855,7 +855,7 @@ def __init__( ) @classmethod - def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PublicKey": + def from_raw_bytes(cls, raw_bytes: bytes) -> "Ed25519PublicKey": """ Load an ED25519 public key from raw bytes @@ -863,7 +863,7 @@ def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PublicKey": raw_bytes (bytes): The raw bytes of the key Returns: - ED25519PublicKey: Instance of ED25519PublicKey + Ed25519PublicKey: Instance of Ed25519PublicKey """ if b"ssh-ed25519" in raw_bytes: id_length = unpack(">I", raw_bytes[:4])[0] + 8 @@ -892,16 +892,16 @@ def verify(self, data: bytes, signature: bytes) -> None: ) from InvalidSignature -class ED25519PrivateKey(PrivateKey): +class Ed25519PrivateKey(PrivateKey): """ Class for holding ED25519 private keys """ def __init__(self, key: _ED25519.Ed25519PrivateKey): - super().__init__(key=key, public_key=ED25519PublicKey(key.public_key())) + super().__init__(key=key, public_key=Ed25519PublicKey(key.public_key())) @classmethod - def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PrivateKey": + def from_raw_bytes(cls, raw_bytes: bytes) -> "Ed25519PrivateKey": """ Load an ED25519 private key from raw bytes @@ -909,19 +909,19 @@ def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PrivateKey": raw_bytes (bytes): The raw bytes of the key Returns: - ED25519PrivateKey: Instance of ED25519PrivateKey + Ed25519PrivateKey: Instance of Ed25519PrivateKey """ return cls.from_class( _ED25519.Ed25519PrivateKey.from_private_bytes(data=raw_bytes) ) @classmethod - def generate(cls) -> "ED25519PrivateKey": + def generate(cls) -> "Ed25519PrivateKey": """ Generates a new ED25519 Private Key Returns: - ED25519PrivateKey: Instance of ED25519PrivateKey + Ed25519PrivateKey: Instance of Ed25519PrivateKey """ return cls.from_class(_ED25519.Ed25519PrivateKey.generate()) diff --git a/src/sshkey_tools/utils.py b/src/sshkey_tools/utils.py index 2ced2d1..34c477e 100644 --- a/src/sshkey_tools/utils.py +++ b/src/sshkey_tools/utils.py @@ -2,7 +2,6 @@ Utilities for handling keys and certificates """ import sys -from types import NoneType from typing import Union, List, Dict from secrets import randbits from random import randint @@ -10,6 +9,8 @@ from base64 import b64encode import hashlib as hl +NoneType = type(None) + def ensure_string( obj: Union[str, bytes, list, tuple, set, dict, NoneType], encoding: str = 'utf-8', @@ -90,7 +91,7 @@ def random_keyid() -> str: Returns: str: Random keyid """ - return uuid4() + return str(uuid4()) def random_serial() -> str: """ Generates a random serial number diff --git a/testcert b/testcert new file mode 100644 index 0000000..7ad0607 --- /dev/null +++ b/testcert @@ -0,0 +1 @@ +ssh-ed25519-cert-v01@openssh.com AAAAIHNzaC1lZDI1NTE5LWNlcnQtdjAxQG9wZW5zc2guY29tAAAAJjc0MzU0MDAxNzgzNDU3MzQxNTU4MDE1NTE3Njg3Mjk0ODY0MzcwAAAAIMxLXknwP4nAHkxKEbmw5CjRhWoGeUTZgYtQXB4tI7SLI11oLXZuTzkAAAACAAAAKml2c09iemxXY2VxR3BLS3NtSlRwZWxqb0RuYVpibW9PSXlPUndiYVBWUAAAAqEAAABheUNZYU1zSEFhbVptSXVYeEtRQ01mbmJBRWVkdk9scFFQVGplendMQ0tLRHpLY0JBU01zSGd6a1V4Z0FGeVF3cFZyZVVoYXdZa09kQ0VyR2pvUWZKWkp2SU1NS1laYXF2QgAAAElWT0h3akRMTVhkUE9lRW9obEZ0dlhPUkR1VHJyTEl1TUZOV2plVnhKTHJ0Ymdud3l6QnVGWEZaZU9PZkFBeVRHamlSVHZIUXpKAAAATHp6RU5uZUt0V09aemxGTUtoU2JySkxjbllyWVN0UlR3RGhDQ3hSS3VxbGRSR3FMRGdkR0NoamtYRWlQdmtnektNd3NWRlBOZ2dnbkEAAABlRHJnd1ZTV013eU9qR1RVVmdTZENuU0RDT21DVVlqY0VHV0t1YlVSekVaWUJTVG1RRkhPU3ZQc2RvYUhTeEFsS2N1TXNQdnl3T2d6aGtPamtTbXpvVXVsV1NvekdqTURvcFNHeWsAAABTeGd0ZmN0YVNwekJtWWVaTkZYemh6YW54cHZNU1JWdUR2Q0JUbGFSRU9QT1JEdVpSd0hBUUdOY2xPU1R3RW9kcVRoQlpYTnVqbFV3d1dQWVBEYkUAAAAxaWtYWWpqcmNmSE9xRnhGQXpZZVN1bmlsRlpuWG9PS0JHd2hzSWNra0RMbXhXTEV6bAAAAEVjVU5laU9FaXNDZ0pBbkRuRVNYc1hlaEFvUEtYZUpWZUdxeklvSG9uUEVyU1RyeUxudlhmUkVka0NDTXlFR1VFRFJucEoAAAAwTlNuUm1aQXlxSFlyTkpBZ21ZbU9LckVuR2RjdWtmUWlHRWZNZ0lqTUNjd1NZTkdJAAAAKUVPd2dqcG1NQUZwam5QdUJVaVVqUkpjUmtOUnJoRFFGbVNRUG9oSHlNAAAAAGlpCcwAAAAAe0tVjAAAAGsAAAANZm9yY2UtY29tbWFuZAAAABEAAAANc2Z0cC1pbnRlcm5hbAAAAA5zb3VyY2UtYWRkcmVzcwAAABgAAAAUMS4yLjMuNC84LDUuNi43LjgvMTYAAAAPdmVyaWZ5LXJlcXVpcmVkAAAAAAAAADwAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAAAAAAMwAAAAtzc2gtZWQyNTUxOQAAACDMS15J8D+JwB5MShG5sOQo0YVqBnlE2YGLUFweLSO0iwAAAFMAAAALc3NoLWVkMjU1MTkAAABAkm6FT2GepxZWvdzBGUt5UQbX5O+ppLp17uQeR4JeaW+ddhpUP6HVLnuyuQZ3p3Ck0jWj5Bbk0nnm5Yo8D2AoDA== \ No newline at end of file diff --git a/tests/test_certificates.py b/tests/test_certificates.py index 16a661c..4b3f22a 100644 --- a/tests/test_certificates.py +++ b/tests/test_certificates.py @@ -13,7 +13,7 @@ import src.sshkey_tools.keys as _KEY import src.sshkey_tools.fields as _FIELD -import src.sshkey_tools.cert as _CERT +import sshkey_tools.certold as _CERT import src.sshkey_tools.exceptions as _EX CERTIFICATE_TYPES = ['rsa', 'dsa', 'ecdsa', 'ed25519'] @@ -21,10 +21,10 @@ class TestCertificateFields(unittest.TestCase): def setUp(self): self.faker = faker.Faker() - self.rsa_key = _KEY.RSAPrivateKey.generate(1024) - self.dsa_key = _KEY.DSAPrivateKey.generate() - self.ecdsa_key = _KEY.ECDSAPrivateKey.generate() - self.ed25519_key = _KEY.ED25519PrivateKey.generate() + self.rsa_key = _KEY.RsaPrivateKey.generate(1024) + self.dsa_key = _KEY.DsaPrivateKey.generate() + self.ecdsa_key = _KEY.EcdsaPrivateKey.generate() + self.ed25519_key = _KEY.Ed25519PrivateKey.generate() def assertRandomResponse(self, field_class, values = None, random_function = None): if values is None: @@ -37,7 +37,6 @@ def assertRandomResponse(self, field_class, values = None, random_function = Non bytestring += field_class.encode(value) field = field_class(value) - self.assertTrue(field.validate()) fields.append(field) self.assertEqual( @@ -60,6 +59,13 @@ def assertExpectedResponse(self, field_class, input, expected_output): field_class.encode(input), expected_output ) + + def assertFieldContainsException(self, field, exception): + for item in field.exception: + if isinstance(item, exception): + return True + + return False def test_boolean_field(self): self.assertRandomResponse( @@ -69,13 +75,14 @@ def test_boolean_field(self): def test_invalid_boolean_field(self): field = _FIELD.BooleanField("SomeInvalidData") - - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + field.validate() + + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_bytestring_field(self): @@ -85,15 +92,16 @@ def test_bytestring_field(self): ) def test_invalid_bytestring_field(self): - field = _FIELD.BytestringField('Hello') + field = _FIELD.BytestringField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): - field.encode('String') + with self.assertRaises(_EX.InvalidDataException): + field.encode(ValueError) def test_string_field(self): self.assertRandomResponse( @@ -102,15 +110,16 @@ def test_string_field(self): ) def test_invalid_string_field(self): - field = _FIELD.StringField(b'Hello') + field = _FIELD.StringField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): - field.encode(b'String') + with self.assertRaises(_EX.InvalidDataException): + field.encode(ValueError) def test_integer32_field(self): self.assertRandomResponse( @@ -120,20 +129,22 @@ def test_integer32_field(self): def test_invalid_integer32_field(self): field = _FIELD.Integer32Field(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.Integer32Field(_FIELD.MAX_INT32 + 1) + field.validate() - self.assertIsInstance( - field.validate(), + self.assertFieldContainsException( + field, _EX.IntegerOverflowException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_integer64_field(self): @@ -144,20 +155,22 @@ def test_integer64_field(self): def test_invalid_integer64_field(self): field = _FIELD.Integer64Field(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.Integer64Field(_FIELD.MAX_INT64 + 1) + field.validate() - self.assertIsInstance( - field.validate(), + self.assertFieldContainsException( + field, _EX.IntegerOverflowException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_datetime_field(self): @@ -168,13 +181,14 @@ def test_datetime_field(self): def test_invalid_datetime_field(self): field = _FIELD.DateTimeField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_mp_integer_field(self): @@ -185,20 +199,22 @@ def test_mp_integer_field(self): def test_invalid_mp_integer_field(self): field = _FIELD.MpIntegerField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.MpIntegerField('InvalidData') + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_list_field(self): @@ -210,20 +226,22 @@ def test_list_field(self): def test_invalid_list_field(self): field = _FIELD.ListField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.ListField([ValueError, ValueError, ValueError]) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_key_value_field(self): @@ -237,20 +255,22 @@ def test_key_value_field(self): def test_invalid_key_value_field(self): field = _FIELD.KeyValueField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - + field = _FIELD.KeyValueField([ValueError, ValueError, ValueError]) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_pubkey_type_field(self): @@ -274,19 +294,25 @@ def test_pubkey_type_field(self): def test_invalid_pubkey_type_field(self): field = _FIELD.PubkeyTypeField('HelloWorld') + field.validate() - self.assertIsInstance( - field.validate(), + self.assertFieldContainsException( + field, _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_nonce_field(self): - randomized = _FIELD.NonceField() + randomized = _FIELD.NonceField(_FIELD.NonceField.DEFAULT()) + randomized.validate() - self.assertTrue(randomized.validate()) + self.assertEqual( + randomized.exception, + (True, True, True) + ) + specific = ( 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz', @@ -305,10 +331,10 @@ def test_pubkey_class_assignment(self): ecdsa_field = _FIELD.PublicKeyField.from_object(self.ecdsa_key.public_key) ed25519_field = _FIELD.PublicKeyField.from_object(self.ed25519_key.public_key) - self.assertIsInstance(rsa_field, _FIELD.RSAPubkeyField) - self.assertIsInstance(dsa_field, _FIELD.DSAPubkeyField) - self.assertIsInstance(ecdsa_field, _FIELD.ECDSAPubkeyField) - self.assertIsInstance(ed25519_field, _FIELD.ED25519PubkeyField) + self.assertIsInstance(rsa_field, _FIELD.RsaPubkeyField) + self.assertIsInstance(dsa_field, _FIELD.DsaPubkeyField) + self.assertIsInstance(ecdsa_field, _FIELD.EcdsaPubkeyField) + self.assertIsInstance(ed25519_field, _FIELD.Ed25519PubkeyField) self.assertTrue(rsa_field.validate()) self.assertTrue(dsa_field.validate()) @@ -337,23 +363,23 @@ def assertPubkeyOutput(self, key_class, *opts): def test_rsa_pubkey_output(self): self.assertPubkeyOutput( - _KEY.RSAPrivateKey, + _KEY.RsaPrivateKey, 1024 ) def test_dsa_pubkey_output(self): self.assertPubkeyOutput( - _KEY.DSAPrivateKey + _KEY.DsaPrivateKey ) def test_ecdsa_pubkey_output(self): self.assertPubkeyOutput( - _KEY.ECDSAPrivateKey + _KEY.EcdsaPrivateKey ) def test_ed25519_pubkey_output(self): self.assertPubkeyOutput( - _KEY.ED25519PrivateKey + _KEY.Ed25519PrivateKey ) def test_serial_field(self): @@ -364,20 +390,22 @@ def test_serial_field(self): def test_invalid_serial_field(self): field = _FIELD.SerialField('abcdefg') + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.SerialField(random.randint(2**65, 2**66)) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.IntegerOverflowException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_certificate_type_field(self): @@ -407,27 +435,32 @@ def test_certificate_type_field(self): def test_invalid_certificate_field(self): field = _FIELD.CertificateTypeField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) + field = _FIELD.CertificateTypeField(3) + field.validate() - self.assertIsInstance( - field.validate(), + self.assertFieldContainsException( + field, _EX.InvalidDataException ) + field = _FIELD.CertificateTypeField(0) + field.validate() - self.assertIsInstance( - field.validate(), + self.assertFieldContainsException( + field, _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_key_id_field(self): @@ -438,13 +471,15 @@ def test_key_id_field(self): def test_invalid_key_id_field(self): field = _FIELD.KeyIdField('') + field.validate() - self.assertIsInstance( - field.validate(), + self.assertFieldContainsException( + field, _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_principals_field(self): @@ -463,13 +498,15 @@ def test_principals_field(self): def test_invalid_principals_field(self): field = _FIELD.PrincipalsField([ValueError, ValueError, ValueError]) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_validity_start_field(self): @@ -480,13 +517,15 @@ def test_validity_start_field(self): def test_invalid_validity_start_field(self): field = _FIELD.ValidAfterField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_validity_end_field(self): @@ -497,13 +536,14 @@ def test_validity_end_field(self): def test_invalid_validity_end_field(self): field = _FIELD.ValidBeforeField(ValueError) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_critical_options_field(self): @@ -543,26 +583,30 @@ def test_critical_options_field(self): def test_invalid_critical_options_field(self): field = _FIELD.CriticalOptionsField([ValueError, 'permit-pty', 'unpermit']) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.CriticalOptionsField('InvalidData') - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + field.validate() + + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.CriticalOptionsField(['no-touch-required', 'InvalidOption']) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_extensions_field(self): @@ -582,26 +626,30 @@ def test_extensions_field(self): def test_invalid_extensions_field(self): field = _FIELD.CriticalOptionsField([ValueError, 'permit-pty', b'unpermit']) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.CriticalOptionsField('InvalidData') - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + field.validate() + + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) field = _FIELD.CriticalOptionsField(['no-touch-required', 'InvalidOption']) + field.validate() - self.assertIsInstance( - field.validate(), - _EX.InvalidFieldDataException + self.assertFieldContainsException( + field, + _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def test_reserved_field(self): @@ -613,13 +661,14 @@ def test_reserved_field(self): def test_invalid_reserved_field(self): field = _FIELD.ReservedField('InvalidData') + field.validate() - self.assertIsInstance( - field.validate(), + self.assertFieldContainsException( + field, _EX.InvalidDataException ) - with self.assertRaises(_EX.InvalidFieldDataException): + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) def assertCAPubkeyField(self, type): @@ -645,15 +694,15 @@ def setUp(self): self.faker = faker.Faker() - self.rsa_ca = _KEY.RSAPrivateKey.generate(1024) - self.dsa_ca = _KEY.DSAPrivateKey.generate() - self.ecdsa_ca = _KEY.ECDSAPrivateKey.generate() - self.ed25519_ca = _KEY.ED25519PrivateKey.generate() + self.rsa_ca = _KEY.RsaPrivateKey.generate(1024) + self.dsa_ca = _KEY.DsaPrivateKey.generate() + self.ecdsa_ca = _KEY.EcdsaPrivateKey.generate() + self.ed25519_ca = _KEY.Ed25519PrivateKey.generate() - self.rsa_user = _KEY.RSAPrivateKey.generate(1024).public_key - self.dsa_user = _KEY.DSAPrivateKey.generate().public_key - self.ecdsa_user = _KEY.ECDSAPrivateKey.generate().public_key - self.ed25519_user = _KEY.ED25519PrivateKey.generate().public_key + self.rsa_user = _KEY.RsaPrivateKey.generate(1024).public_key + self.dsa_user = _KEY.DsaPrivateKey.generate().public_key + self.ecdsa_user = _KEY.EcdsaPrivateKey.generate().public_key + self.ed25519_user = _KEY.Ed25519PrivateKey.generate().public_key self.cert_opts = { @@ -688,10 +737,10 @@ def test_cert_type_assignment(self): ecdsa_cert = _CERT.SSHCertificate.from_public_class(self.ecdsa_user) ed25519_cert = _CERT.SSHCertificate.from_public_class(self.ed25519_user) - self.assertIsInstance(rsa_cert, _CERT.RSACertificate) - self.assertIsInstance(dsa_cert, _CERT.DSACertificate) - self.assertIsInstance(ecdsa_cert, _CERT.ECDSACertificate) - self.assertIsInstance(ed25519_cert, _CERT.ED25519Certificate) + self.assertIsInstance(rsa_cert, _CERT.RsaCertificate) + self.assertIsInstance(dsa_cert, _CERT.DsaCertificate) + self.assertIsInstance(ecdsa_cert, _CERT.EcdsaCertificate) + self.assertIsInstance(ed25519_cert, _CERT.Ed25519Certificate) def assertCertificateCreated(self, sub_type, ca_type): diff --git a/tests/test_keypairs.py b/tests/test_keypairs.py index 90393aa..168572b 100644 --- a/tests/test_keypairs.py +++ b/tests/test_keypairs.py @@ -10,16 +10,16 @@ from src.sshkey_tools.keys import ( PrivkeyClasses, PrivateKey, - RSAPrivateKey, - DSAPrivateKey, - ECDSAPrivateKey, - ED25519PrivateKey, + RsaPrivateKey, + DsaPrivateKey, + EcdsaPrivateKey, + Ed25519PrivateKey, PubkeyClasses, PublicKey, - RSAPublicKey, - DSAPublicKey, - ECDSAPublicKey, - ED25519PublicKey, + RsaPublicKey, + DsaPublicKey, + EcdsaPublicKey, + Ed25519PublicKey, EcdsaCurves ) @@ -34,10 +34,10 @@ class KeypairMethods(unittest.TestCase): def generateClasses(self): - self.rsa_key = RSAPrivateKey.generate(2048) - self.dsa_key = DSAPrivateKey.generate() - self.ecdsa_key = ECDSAPrivateKey.generate(EcdsaCurves.P256) - self.ed25519_key = ED25519PrivateKey.generate() + self.rsa_key = RsaPrivateKey.generate(2048) + self.dsa_key = DsaPrivateKey.generate() + self.ecdsa_key = EcdsaPrivateKey.generate(EcdsaCurves.P256) + self.ed25519_key = Ed25519PrivateKey.generate() def generateFiles(self, folder): self.folder = folder @@ -114,17 +114,17 @@ class TestKeypairMethods(KeypairMethods): def test_fail_assertions(self): with self.assertRaises(AssertionError): self.assertEqualPrivateKeys( - RSAPrivateKey, - RSAPublicKey, - RSAPrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password'), - DSAPrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') + RsaPrivateKey, + RsaPublicKey, + RsaPrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password'), + DsaPrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') ) with self.assertRaises(AssertionError): self.assertEqualPublicKeys( - RSAPublicKey, - RSAPublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub'), - DSAPublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') + RsaPublicKey, + RsaPublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub'), + DsaPublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') ) with self.assertRaises(AssertionError): @@ -191,34 +191,34 @@ def test_rsa(self): ] for bits in key_bits: - key = RSAPrivateKey.generate(bits) + key = RsaPrivateKey.generate(bits) - assert isinstance(key, RSAPrivateKey) + assert isinstance(key, RsaPrivateKey) assert isinstance(key, PrivateKey) - assert isinstance(key.key, _RSA.RSAPrivateKey) + assert isinstance(key.key, _RSA.RsaPrivateKey) assert isinstance(key.private_numbers, _RSA.RSAPrivateNumbers) - assert isinstance(key.public_key, RSAPublicKey) + assert isinstance(key.public_key, RsaPublicKey) assert isinstance(key.public_key, PublicKey) - assert isinstance(key.public_key.key, _RSA.RSAPublicKey) + assert isinstance(key.public_key.key, _RSA.RsaPublicKey) assert isinstance(key.public_key.public_numbers, _RSA.RSAPublicNumbers) def test_rsa_incorrect_keysize(self): with self.assertRaises(ValueError): - RSAPrivateKey.generate(256) + RsaPrivateKey.generate(256) def test_dsa(self): - key = DSAPrivateKey.generate() + key = DsaPrivateKey.generate() - assert isinstance(key, DSAPrivateKey) + assert isinstance(key, DsaPrivateKey) assert isinstance(key, PrivateKey) - assert isinstance(key.key, _DSA.DSAPrivateKey) + assert isinstance(key.key, _DSA.DsaPrivateKey) assert isinstance(key.private_numbers, _DSA.DSAPrivateNumbers) - assert isinstance(key.public_key, DSAPublicKey) + assert isinstance(key.public_key, DsaPublicKey) assert isinstance(key.public_key, PublicKey) - assert isinstance(key.public_key.key, _DSA.DSAPublicKey) + assert isinstance(key.public_key.key, _DSA.DsaPublicKey) assert isinstance(key.public_key.public_numbers, _DSA.DSAPublicNumbers) assert isinstance(key.public_key.parameters, _DSA.DSAParameterNumbers) @@ -230,15 +230,15 @@ def test_ecdsa(self): ] for curve in curves: - key = ECDSAPrivateKey.generate(curve) + key = EcdsaPrivateKey.generate(curve) - assert isinstance(key, ECDSAPrivateKey) + assert isinstance(key, EcdsaPrivateKey) assert isinstance(key, PrivateKey) assert isinstance(key.key, _EC.EllipticCurvePrivateKey) assert isinstance(key.private_numbers, _EC.EllipticCurvePrivateNumbers) - assert isinstance(key.public_key, ECDSAPublicKey) + assert isinstance(key.public_key, EcdsaPublicKey) assert isinstance(key.public_key, PublicKey) assert isinstance(key.public_key.key, _EC.EllipticCurvePublicKey) assert isinstance(key.public_key.public_numbers, _EC.EllipticCurvePublicNumbers) @@ -246,16 +246,16 @@ def test_ecdsa(self): def test_ecdsa_not_a_curve(self): with self.assertRaises(AttributeError): - ECDSAPrivateKey.generate('p256') + EcdsaPrivateKey.generate('p256') def test_ed25519(self): - key = ED25519PrivateKey.generate() + key = Ed25519PrivateKey.generate() - assert isinstance(key, ED25519PrivateKey) + assert isinstance(key, Ed25519PrivateKey) assert isinstance(key, PrivateKey) assert isinstance(key.key, _ED25519.Ed25519PrivateKey) - assert isinstance(key.public_key, ED25519PublicKey) + assert isinstance(key.public_key, Ed25519PublicKey) assert isinstance(key.public_key, PublicKey) assert isinstance(key.public_key.key, _ED25519.Ed25519PublicKey) @@ -275,14 +275,14 @@ def test_encoding(self): from_file_pub = PublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') self.assertEqualPrivateKeys( - RSAPrivateKey, - RSAPublicKey, + RsaPrivateKey, + RsaPublicKey, from_string, from_file ) self.assertEqualPublicKeys( - RSAPublicKey, + RsaPublicKey, from_string_pub, from_file_pub ) @@ -290,10 +290,10 @@ def test_encoding(self): def test_rsa_files(self): parent = PrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') - child = RSAPrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') + child = RsaPrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') parent_pub = PublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') - child_pub = RSAPublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') + child_pub = RsaPublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') parent.to_file(f'tests/{self.folder}/rsa_key_saved_parent', 'password') child.to_file(f'tests/{self.folder}/rsa_key_saved_child') @@ -303,20 +303,20 @@ def test_rsa_files(self): self.assertEqualPrivateKeys( - RSAPrivateKey, - RSAPublicKey, + RsaPrivateKey, + RsaPublicKey, parent, child ) self.assertEqualPublicKeys( - RSAPublicKey, + RsaPublicKey, parent_pub, child_pub ) self.assertEqualPublicKeys( - RSAPublicKey, + RsaPublicKey, parent.public_key, child_pub ) @@ -347,10 +347,10 @@ def test_rsa_files(self): def test_dsa_files(self): parent = PrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') - child = DSAPrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') + child = DsaPrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') parent_pub = PublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') - child_pub = DSAPublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') + child_pub = DsaPublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') parent.to_file(f'tests/{self.folder}/dsa_key_saved_parent') child.to_file(f'tests/{self.folder}/dsa_key_saved_child') @@ -359,20 +359,20 @@ def test_dsa_files(self): child_pub.to_file(f'tests/{self.folder}/dsa_key_saved_child.pub') self.assertEqualPrivateKeys( - DSAPrivateKey, - DSAPublicKey, + DsaPrivateKey, + DsaPublicKey, parent, child ) self.assertEqualPublicKeys( - DSAPublicKey, + DsaPublicKey, parent_pub, child_pub ) self.assertEqualPublicKeys( - DSAPublicKey, + DsaPublicKey, parent.public_key, child_pub ) @@ -403,10 +403,10 @@ def test_dsa_files(self): def test_ecdsa_files(self): parent = PrivateKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen') - child = ECDSAPrivateKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen') + child = EcdsaPrivateKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen') parent_pub = PublicKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen.pub') - child_pub = ECDSAPublicKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen.pub') + child_pub = EcdsaPublicKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen.pub') parent.to_file(f'tests/{self.folder}/ecdsa_key_saved_parent') child.to_file(f'tests/{self.folder}/ecdsa_key_saved_child') @@ -415,20 +415,20 @@ def test_ecdsa_files(self): child_pub.to_file(f'tests/{self.folder}/ecdsa_key_saved_child.pub') self.assertEqualPrivateKeys( - ECDSAPrivateKey, - ECDSAPublicKey, + EcdsaPrivateKey, + EcdsaPublicKey, parent, child ) self.assertEqualPublicKeys( - ECDSAPublicKey, + EcdsaPublicKey, parent_pub, child_pub ) self.assertEqualPublicKeys( - ECDSAPublicKey, + EcdsaPublicKey, parent.public_key, child_pub ) @@ -459,10 +459,10 @@ def test_ecdsa_files(self): def test_ed25519_files(self): parent = PrivateKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen') - child = ED25519PrivateKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen') + child = Ed25519PrivateKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen') parent_pub = PublicKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen.pub') - child_pub = ED25519PublicKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen.pub') + child_pub = Ed25519PublicKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen.pub') parent.to_file(f'tests/{self.folder}/ed25519_key_saved_parent') child.to_file(f'tests/{self.folder}/ed25519_key_saved_child') @@ -471,20 +471,20 @@ def test_ed25519_files(self): child_pub.to_file(f'tests/{self.folder}/ed25519_key_saved_child.pub') self.assertEqualPrivateKeys( - ED25519PrivateKey, - ED25519PublicKey, + Ed25519PrivateKey, + Ed25519PublicKey, parent, child ) self.assertEqualPublicKeys( - ED25519PublicKey, + Ed25519PublicKey, parent_pub, child_pub ) self.assertEqualPublicKeys( - ED25519PublicKey, + Ed25519PublicKey, parent.public_key, child_pub ) @@ -541,44 +541,44 @@ def test_invalid_key_exception(self): def test_rsa_from_class(self): parent = PrivateKey.from_class(self.rsa_key) - child = RSAPrivateKey.from_class(self.rsa_key) + child = RsaPrivateKey.from_class(self.rsa_key) self.assertEqualPrivateKeys( - RSAPrivateKey, - RSAPublicKey, + RsaPrivateKey, + RsaPublicKey, parent, child ) def test_dsa_from_class(self): parent = PrivateKey.from_class(self.dsa_key) - child = DSAPrivateKey.from_class(self.dsa_key) + child = DsaPrivateKey.from_class(self.dsa_key) self.assertEqualPrivateKeys( - DSAPrivateKey, - DSAPublicKey, + DsaPrivateKey, + DsaPublicKey, parent, child ) def test_ecdsa_from_class(self): parent = PrivateKey.from_class(self.ecdsa_key) - child = ECDSAPrivateKey.from_class(self.ecdsa_key) + child = EcdsaPrivateKey.from_class(self.ecdsa_key) self.assertEqualPrivateKeys( - ECDSAPrivateKey, - ECDSAPublicKey, + EcdsaPrivateKey, + EcdsaPublicKey, parent, child ) def test_ed25519_from_class(self): parent = PrivateKey.from_class(self.ed25519_key) - child = ED25519PrivateKey.from_class(self.ed25519_key) + child = Ed25519PrivateKey.from_class(self.ed25519_key) self.assertEqualPrivateKeys( - ED25519PrivateKey, - ED25519PublicKey, + Ed25519PrivateKey, + Ed25519PublicKey, parent, child ) @@ -591,24 +591,24 @@ def tearDown(self): pass def test_rsa_from_numbers(self): - from_numbers = RSAPrivateKey.from_numbers( + from_numbers = RsaPrivateKey.from_numbers( n=self.rsa_key.public_key.public_numbers.n, e=self.rsa_key.public_key.public_numbers.e, d=self.rsa_key.private_numbers.d ) - from_numbers_pub = RSAPublicKey.from_numbers( + from_numbers_pub = RsaPublicKey.from_numbers( n=self.rsa_key.public_key.public_numbers.n, e=self.rsa_key.public_key.public_numbers.e ) self.assertEqualPublicKeys( - RSAPublicKey, + RsaPublicKey, from_numbers_pub, from_numbers.public_key ) - self.assertIsInstance(from_numbers, RSAPrivateKey) + self.assertIsInstance(from_numbers, RsaPrivateKey) self.assertEqual( self.rsa_key.public_key.public_numbers.n, @@ -626,7 +626,7 @@ def test_rsa_from_numbers(self): ) def test_dsa_from_numbers(self): - from_numbers = DSAPrivateKey.from_numbers( + from_numbers = DsaPrivateKey.from_numbers( p=self.dsa_key.public_key.parameters.p, q=self.dsa_key.public_key.parameters.q, g=self.dsa_key.public_key.parameters.g, @@ -634,7 +634,7 @@ def test_dsa_from_numbers(self): x=self.dsa_key.private_numbers.x ) - from_numbers_pub = DSAPublicKey.from_numbers( + from_numbers_pub = DsaPublicKey.from_numbers( p=self.dsa_key.public_key.parameters.p, q=self.dsa_key.public_key.parameters.q, g=self.dsa_key.public_key.parameters.g, @@ -642,46 +642,46 @@ def test_dsa_from_numbers(self): ) self.assertEqualPrivateKeys( - DSAPrivateKey, - DSAPublicKey, + DsaPrivateKey, + DsaPublicKey, self.dsa_key, from_numbers ) self.assertEqualPublicKeys( - DSAPublicKey, + DsaPublicKey, from_numbers_pub, self.dsa_key.public_key ) def test_ecdsa_from_numbers(self): - from_numbers = ECDSAPrivateKey.from_numbers( + from_numbers = EcdsaPrivateKey.from_numbers( curve=self.ecdsa_key.public_key.key.curve, x=self.ecdsa_key.public_key.public_numbers.x, y=self.ecdsa_key.public_key.public_numbers.y, private_value=self.ecdsa_key.private_numbers.private_value ) - from_numbers_pub = ECDSAPublicKey.from_numbers( + from_numbers_pub = EcdsaPublicKey.from_numbers( curve=self.ecdsa_key.public_key.key.curve, x=self.ecdsa_key.public_key.public_numbers.x, y=self.ecdsa_key.public_key.public_numbers.y ) self.assertEqualPrivateKeys( - ECDSAPrivateKey, - ECDSAPublicKey, + EcdsaPrivateKey, + EcdsaPublicKey, self.ecdsa_key, from_numbers ) self.assertEqualPublicKeys( - ECDSAPublicKey, + EcdsaPublicKey, from_numbers_pub, self.ecdsa_key.public_key ) - from_numbers = ECDSAPrivateKey.from_numbers( + from_numbers = EcdsaPrivateKey.from_numbers( curve=self.ecdsa_key.public_key.key.curve.name, x=self.ecdsa_key.public_key.public_numbers.x, y=self.ecdsa_key.public_key.public_numbers.y, @@ -689,30 +689,30 @@ def test_ecdsa_from_numbers(self): ) self.assertEqualPrivateKeys( - ECDSAPrivateKey, - ECDSAPublicKey, + EcdsaPrivateKey, + EcdsaPublicKey, self.ecdsa_key, from_numbers ) def test_ed25519_from_raw_bytes(self): - from_raw = ED25519PrivateKey.from_raw_bytes( + from_raw = Ed25519PrivateKey.from_raw_bytes( self.ed25519_key.raw_bytes() ) - from_raw_pub = ED25519PublicKey.from_raw_bytes( + from_raw_pub = Ed25519PublicKey.from_raw_bytes( self.ed25519_key.public_key.raw_bytes() ) self.assertEqualPrivateKeys( - ED25519PrivateKey, - ED25519PublicKey, + Ed25519PrivateKey, + Ed25519PublicKey, self.ed25519_key, from_raw, [] ) self.assertEqualPublicKeys( - ED25519PublicKey, + Ed25519PublicKey, self.ed25519_key.public_key, from_raw_pub ) @@ -724,7 +724,7 @@ def setUp(self): self.generateFiles('TestFingerprint') def test_rsa_fingerprint(self): - key = RSAPrivateKey.from_file( + key = RsaPrivateKey.from_file( f'tests/{self.folder}/rsa_key_sshkeygen', 'password' ) @@ -734,7 +734,7 @@ def test_rsa_fingerprint(self): self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) def test_dsa_fingerprint(self): - key = DSAPrivateKey.from_file( + key = DsaPrivateKey.from_file( f'tests/{self.folder}/dsa_key_sshkeygen', ) @@ -743,7 +743,7 @@ def test_dsa_fingerprint(self): self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) def test_ecdsa_fingerprint(self): - key = ECDSAPrivateKey.from_file( + key = EcdsaPrivateKey.from_file( f'tests/{self.folder}/ecdsa_key_sshkeygen', ) @@ -752,7 +752,7 @@ def test_ecdsa_fingerprint(self): self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) def test_ed25519_fingerprint(self): - key = ED25519PrivateKey.from_file( + key = Ed25519PrivateKey.from_file( f'tests/{self.folder}/ed25519_key_sshkeygen', ) @@ -852,14 +852,14 @@ def test_invalid_private_key(self): def test_invalid_ecdsa_curve(self): with self.assertRaises(_EX.InvalidCurveException): - key = ECDSAPublicKey.from_numbers( + key = EcdsaPublicKey.from_numbers( 'abc123', x=self.ecdsa_key.public_key.public_numbers.x, y=self.ecdsa_key.public_key.public_numbers.y ) with self.assertRaises(_EX.InvalidCurveException): - key = ECDSAPrivateKey.from_numbers( + key = EcdsaPrivateKey.from_numbers( 'abc123', x=self.ecdsa_key.public_key.public_numbers.x, y=self.ecdsa_key.public_key.public_numbers.y, From fd1ea68965547f13309d16d4fc36465cce86ba5e Mon Sep 17 00:00:00 2001 From: Lars Scheibling Date: Mon, 18 Jul 2022 11:55:12 +0000 Subject: [PATCH 3/6] Updated tests for certificates --- src/sshkey_tools/fields.py | 88 ++++++++++++++++---------------------- tests/test_certificates.py | 37 ++++++++-------- tests/test_keypairs.py | 8 ++-- 3 files changed, 59 insertions(+), 74 deletions(-) diff --git a/src/sshkey_tools/fields.py b/src/sshkey_tools/fields.py index 7d0e254..273e3f8 100644 --- a/src/sshkey_tools/fields.py +++ b/src/sshkey_tools/fields.py @@ -344,6 +344,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + if self.value < MAX_INT32: return True @@ -388,6 +391,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + if self.value < MAX_INT64: return True @@ -438,6 +444,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + check = self.value if isinstance(self.value, int) else self.value.timestamp() if check < MAX_INT64: @@ -535,6 +544,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + if hasattr(self.value, '__iter__') and not all((isinstance(val, (str, bytes)) for val in self.value)): return _EX.InvalidFieldDataException( "Expected list or tuple containing strings or bytes" @@ -616,6 +628,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + testvals = ( self.value if not isinstance(self.value, dict) else list(self.value.keys()) + list(self.value.values()) @@ -651,6 +666,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + if ensure_string(self.value) not in self.ALLOWED_VALUES: return _EX.InvalidFieldDataException( "Expected one of the following values: {}".format( @@ -673,6 +691,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + if hasattr(self.value, '__count__') and len(self.value) < 32: return _EX.InvalidFieldDataException( "Expected a nonce of at least 32 bytes" @@ -890,6 +911,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + if self.value not in self.ALLOWED_VALUES: return _EX.InvalidCertificateFieldException( f"The certificate type is invalid (expected int(1,2) or CERT_TYPE.X)" @@ -938,6 +962,9 @@ def __validate_value__(self) -> Union[bool, Exception]: done to ensure no already expired certificates are created """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + super().__validate_value__() check = self.value if isinstance(self.value, datetime) else datetime.fromtimestamp(self.value) @@ -976,6 +1003,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + for elem in self.value if not isinstance(self.value, dict) else list(self.value.keys()): if elem not in self.ALLOWED_VALUES: return _EX.InvalidCertificateFieldException( @@ -1027,6 +1057,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + for item in self.value: if item not in self.ALLOWED_VALUES: return _EX.InvalidDataException( @@ -1049,6 +1082,9 @@ def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ + if isinstance(self.__validate_type__(self.value), Exception): + return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + return True if self.value == "" else _EX.InvalidDataException( f"The reserved field is not empty" ) @@ -1547,55 +1583,3 @@ def sign(self, data: bytes, **kwargs) -> None: """ self.value = self.private_key.sign(data) self.is_signed = True - -# ADD: Default class for each field -# ADD: Typing class for each field - -# @dataclass -# class Fieldset: -# def __setattr__(self, name, value): -# field = getattr(self, name, None) -# if callable(field) and not isinstance(field, CertificateField): -# if field.__name__ == "factory": -# super().__setattr__(name, field()) -# self.__setattr__(name, value) -# return - -# if isinstance(field, type) and getattr(value, '__name__', '') != 'factory': -# super().__setattr__(name, field(value)) -# return - -# if getattr(value, '__name__', '') != 'factory': -# field.value = value -# super().__setattr__(name, field) - -# def get(self, name: str): -# field = getattr(self, name, None) -# if field: -# if isinstance(field, type): -# return field.DEFAULT -# return field.value -# raise _EX.InvalidCertificateFieldException(f"Unknown field {name}") - -# @dataclass -# class CertificateHeader(Fieldset): -# public_key: PublicKeyField = PublicKeyField.factory -# pubkey_type: PubkeyTypeField = PubkeyTypeField.factory -# nonce: NonceField = NonceField.factory - -# @dataclass -# class CertificateFooter(Fieldset): -# reserved: ReservedField = ReservedField.factory -# ca_pubkey: CAPublicKeyField = CAPublicKeyField.factory -# signature: SignatureField = SignatureField.factory - -# @dataclass -# class CertificateBody(Fieldset): -# serial: SerialField = SerialField.factory -# cert_type: CertificateTypeField = CertificateTypeField.factory -# key_id: KeyIdField = KeyIdField.factory -# principals: PrincipalsField = PrincipalsField.factory -# valid_after: ValidAfterField = ValidAfterField.factory -# valid_before: ValidBeforeField = ValidBeforeField.factory -# critical_options: CriticalOptionsField = CriticalOptionsField.factory -# extensions: ExtensionsField = ExtensionsField.factory \ No newline at end of file diff --git a/tests/test_certificates.py b/tests/test_certificates.py index 4b3f22a..c04d51f 100644 --- a/tests/test_certificates.py +++ b/tests/test_certificates.py @@ -13,7 +13,7 @@ import src.sshkey_tools.keys as _KEY import src.sshkey_tools.fields as _FIELD -import sshkey_tools.certold as _CERT +import src.sshkey_tools.cert as _CERT import src.sshkey_tools.exceptions as _EX CERTIFICATE_TYPES = ['rsa', 'dsa', 'ecdsa', 'ed25519'] @@ -705,37 +705,37 @@ def setUp(self): self.ed25519_user = _KEY.Ed25519PrivateKey.generate().public_key - self.cert_opts = { - 'serial': 1234567890, - 'cert_type': _FIELD.CERT_TYPE.USER, - 'key_id': 'KeyIdentifier', - 'principals': [ + self.cert_fields = _CERT.CertificateFields( + serial=1234567890, + cert_type=_FIELD.CERT_TYPE.USER, + key_id='KeyIdentifier', + principals=[ 'pr_a', 'pr_b', 'pr_c' ], - 'valid_after': 1968491468, - 'valid_before': 1968534668, - 'critical_options': { + valid_after=1968491468, + valid_before=1968534668, + critical_options={ 'force-command': 'sftp-internal', 'source-address': '1.2.3.4/8,5.6.7.8/16', 'verify-required': '' }, - 'extensions': [ + extensions=[ 'permit-agent-forwarding', 'permit-X11-forwarding' ] - } + ) def tearDown(self): shutil.rmtree(f'tests/certificates') os.mkdir(f'tests/certificates') def test_cert_type_assignment(self): - rsa_cert = _CERT.SSHCertificate.from_public_class(self.rsa_user) - dsa_cert = _CERT.SSHCertificate.from_public_class(self.dsa_user) - ecdsa_cert = _CERT.SSHCertificate.from_public_class(self.ecdsa_user) - ed25519_cert = _CERT.SSHCertificate.from_public_class(self.ed25519_user) + rsa_cert = _CERT.SSHCertificate.create(self.rsa_user) + dsa_cert = _CERT.SSHCertificate.create(self.dsa_user) + ecdsa_cert = _CERT.SSHCertificate.create(self.ecdsa_user) + ed25519_cert = _CERT.SSHCertificate.create(self.ed25519_user) self.assertIsInstance(rsa_cert, _CERT.RsaCertificate) self.assertIsInstance(dsa_cert, _CERT.DsaCertificate) @@ -747,10 +747,10 @@ def assertCertificateCreated(self, sub_type, ca_type): sub_pubkey = getattr(self, f'{sub_type}_user') ca_privkey = getattr(self, f'{ca_type}_ca') - certificate = _CERT.SSHCertificate.from_public_class( - public_key=sub_pubkey, + certificate = _CERT.SSHCertificate.create( + subject_pubkey=sub_pubkey, ca_privkey=ca_privkey, - **self.cert_opts + fields=self.cert_fields ) self.assertTrue(certificate.can_sign()) @@ -762,6 +762,7 @@ def assertCertificateCreated(self, sub_type, ca_type): os.system(f'ssh-keygen -Lf tests/certificates/{sub_type}_{ca_type}-cert.pub') ) + reloaded_cert = _CERT.SSHCertificate.from_file(f'tests/certificates/{sub_type}_{ca_type}-cert.pub') self.assertEqual( diff --git a/tests/test_keypairs.py b/tests/test_keypairs.py index 168572b..14c1880 100644 --- a/tests/test_keypairs.py +++ b/tests/test_keypairs.py @@ -195,12 +195,12 @@ def test_rsa(self): assert isinstance(key, RsaPrivateKey) assert isinstance(key, PrivateKey) - assert isinstance(key.key, _RSA.RsaPrivateKey) + assert isinstance(key.key, _RSA.RSAPrivateKey) assert isinstance(key.private_numbers, _RSA.RSAPrivateNumbers) assert isinstance(key.public_key, RsaPublicKey) assert isinstance(key.public_key, PublicKey) - assert isinstance(key.public_key.key, _RSA.RsaPublicKey) + assert isinstance(key.public_key.key, _RSA.RSAPublicKey) assert isinstance(key.public_key.public_numbers, _RSA.RSAPublicNumbers) def test_rsa_incorrect_keysize(self): @@ -213,12 +213,12 @@ def test_dsa(self): assert isinstance(key, DsaPrivateKey) assert isinstance(key, PrivateKey) - assert isinstance(key.key, _DSA.DsaPrivateKey) + assert isinstance(key.key, _DSA.DSAPrivateKey) assert isinstance(key.private_numbers, _DSA.DSAPrivateNumbers) assert isinstance(key.public_key, DsaPublicKey) assert isinstance(key.public_key, PublicKey) - assert isinstance(key.public_key.key, _DSA.DsaPublicKey) + assert isinstance(key.public_key.key, _DSA.DSAPublicKey) assert isinstance(key.public_key.public_numbers, _DSA.DSAPublicNumbers) assert isinstance(key.public_key.parameters, _DSA.DSAParameterNumbers) From 52bc67e2f698bb7e69bb15580c25507d1fd903ae Mon Sep 17 00:00:00 2001 From: Lars Scheibling Date: Mon, 18 Jul 2022 15:27:45 +0000 Subject: [PATCH 4/6] Finished re-write of field definitions. Todo: Linting --- .gitpod.yml | 2 +- README.md | 44 +- docs/fields.html | 2 + requirements-test.txt | 3 +- requirements.txt | 3 +- src/sshkey_tools/cert.py | 418 +++++++++++++------ src/sshkey_tools/certold.py | 588 -------------------------- src/sshkey_tools/fields.py | 342 ++++++++++------ src/sshkey_tools/keys.py | 70 ++-- src/sshkey_tools/utils.py | 64 ++- tests/test_certificates.py | 793 +++++++++++++++--------------------- tests/test_keypairs.py | 754 ++++++++++++++-------------------- tests/test_utils.py | 406 +++++++++--------- 13 files changed, 1461 insertions(+), 2028 deletions(-) delete mode 100644 src/sshkey_tools/certold.py diff --git a/.gitpod.yml b/.gitpod.yml index 0c589bf..20e68fa 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -3,4 +3,4 @@ tasks: - init: | sudo apt-get update && sudo apt-get -y upgrade && \ sudo apt -y install openssh-server && \ - pip3 install -r requirements.txt \ No newline at end of file + pip3 install -r requirements-test.txt diff --git a/README.md b/README.md index 57ba685..90a7aee 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,19 @@ # sshkey-tools + Python and CLI tools for managing OpenSSH keypairs and certificates # Installation + ## With pip + ```bash pip3 install sshkey-tools # or pip3 install -e git+https://github.com/scheiblingco/sshkey-tools.git ``` + ## From source + ```bash git clone https://github.com/scheiblingco/sshkey-tools cd sshkey-tools @@ -16,11 +21,15 @@ pip3 install ./ ``` # Documentation + [scheiblingco.github.io/sshkey-tools/](https://scheiblingco.github.io/sshkey-tools/) # Basic usage + ## SSH Keypairs + ### Generate keys + ```python from sshkey_tools.keys import ( RsaPrivateKey, @@ -58,7 +67,9 @@ rsa_pub = rsa_private.public_key ``` ### Load keys + You can load keys either directly with the specific key classes (RsaPrivateKey, DsaPrivateKey, etc.) or the general PrivateKey class + ```python from sshkey_tools.keys import ( PrivateKey, @@ -111,20 +122,24 @@ rsa_private = RsaPrivateKey.from_numbers( ``` ## SSH Certificates + ### Attributes -|Attribute|Type|Key|Example Value|Description| -|---|---|---|---|---| -|Certificate Type|Integer (1/2)|cert_type|1|The type of certificate, 1 for User and 2 for Host. Can also be defined as sshkey_tools.fields.CERT_TYPE.USER or sshkey_tools.fields.CERT_TYPE.HOST| -|Serial|Integer|serial|11223344|The serial number for the certificate, a 64-bit integer| -|Key ID|String|key_id|someuser@somehost|The key identifier, can be set to any string, for example username, email or other unique identifier| -|Principals|List|principals|['zone-webservers', 'server-01']|The principals for which the certificate is valid, this needs to correspond to the allowed principals on the OpenSSH Server-side. Only valid for User certificates| -|Valid After|Integer|valid_after|datetime.now()|The datetime object or unix timestamp for when the certificate validity starts| -|Valid Before|Integer|valid_before|datetime.now() + timedelta(hours=12)|The datetime object or unix timestamp for when the certificate validity ends| -|Critical Options|Dict|critical_options|{'source-address': '1.2.3.4/8'}|Options set on the certificate that the OpenSSH server cannot choose to ignore (critical). Only valid on user certificates. Valid options are force-command (for limiting the user to a certain shell, e.g. sftp-internal), source-address (to limit the source IPs the user can connect from) and verify-required (to require the user to touch a hardware key before usage)| -|Extensions|Dict/Set/List/Tuple|extensions|{'permit-X11-forwarding', 'permit-port-forwarding'}|Extensions that the certificate holder is allowed to use. Valid options are no-touch-required, permit-X11-forwarding, permit-agent-forwarding, permit-port-forwarding, permit-pty, permit-user-rc| + +| Attribute | Type | Key | Example Value | Description | +| ---------------- | ------------------- | ---------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Certificate Type | Integer (1/2) | cert_type | 1 | The type of certificate, 1 for User and 2 for Host. Can also be defined as sshkey_tools.fields.CERT_TYPE.USER or sshkey_tools.fields.CERT_TYPE.HOST | +| Serial | Integer | serial | 11223344 | The serial number for the certificate, a 64-bit integer | +| Key ID | String | key_id | someuser@somehost | The key identifier, can be set to any string, for example username, email or other unique identifier | +| Principals | List | principals | ['zone-webservers', 'server-01'] | The principals for which the certificate is valid, this needs to correspond to the allowed principals on the OpenSSH Server-side. Only valid for User certificates | +| Valid After | Integer | valid_after | datetime.now() | The datetime object or unix timestamp for when the certificate validity starts | +| Valid Before | Integer | valid_before | datetime.now() + timedelta(hours=12) | The datetime object or unix timestamp for when the certificate validity ends | +| Critical Options | Dict | critical_options | {'source-address': '1.2.3.4/8'} | Options set on the certificate that the OpenSSH server cannot choose to ignore (critical). Only valid on user certificates. Valid options are force-command (for limiting the user to a certain shell, e.g. sftp-internal), source-address (to limit the source IPs the user can connect from) and verify-required (to require the user to touch a hardware key before usage) | +| Extensions | Dict/Set/List/Tuple | extensions | {'permit-X11-forwarding', 'permit-port-forwarding'} | Extensions that the certificate holder is allowed to use. Valid options are no-touch-required, permit-X11-forwarding, permit-agent-forwarding, permit-port-forwarding, permit-pty, permit-user-rc | ### Certificate creation + The basis for a certificate is the public key for the subject (User/Host), and bases the format of the certificate on that. + ```python from datetime import datetime, timedelta from cryptography.hazmat.primitives import ( @@ -156,7 +171,7 @@ cert_opts = { 'permit-pty', 'permit-user-rc', 'permit-port-forwarding' - ] + ] } @@ -207,13 +222,13 @@ except SignatureNotPossibleException: # Sign the certificate certificate.sign() -# For certificates signed by an RSA key, you can choose the hashing algorithm +# For certificates signed by an RSA key, you can choose the hashing algorithm # to be used for creating the hash of the certificate data before signing certificate.sign( hash_alg=RsaAlgs.SHA512 ) -# If you want to verify the signature after creation, +# If you want to verify the signature after creation, # you can do so with the verify()-method # # Please note that a public key should always be provided @@ -259,6 +274,7 @@ cert_bytes = certificate.to_bytes() ``` ### Load an existing certificate + Certificates can be loaded from file, a string/bytestring with file contents or the base64-decoded byte data of the certificate @@ -285,4 +301,4 @@ ca_privkey = PrivateKey.from_file('path/to/ca_privkey') certificate.set_ca(ca_privkey) certificate.sign() certificate.to_file('path/to/user_key-cert2.pub') -``` \ No newline at end of file +``` diff --git a/docs/fields.html b/docs/fields.html index ef52535..8632d76 100644 --- a/docs/fields.html +++ b/docs/fields.html @@ -1584,6 +1584,7 @@

      Module sshkey_tools.fields

      """ def __init__( + self, private_key: Ed25519PrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @@ -3774,6 +3775,7 @@

      Inherited members

      """ def __init__( + self, private_key: Ed25519PrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) diff --git a/requirements-test.txt b/requirements-test.txt index 36efcbe..c35452a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,4 +7,5 @@ coverage black pytest-cov faker -cprint \ No newline at end of file +cprint +PrettyTable \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 64ee331..4a1cb1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ click cryptography bcrypt -enum34 \ No newline at end of file +enum34 +PrettyTable \ No newline at end of file diff --git a/src/sshkey_tools/cert.py b/src/sshkey_tools/cert.py index 2f775f3..602ab19 100644 --- a/src/sshkey_tools/cert.py +++ b/src/sshkey_tools/cert.py @@ -1,112 +1,178 @@ -"""Contains classes for OpenSSH Certificates, generation, parsing and signing - Raises: - _EX.SSHCertificateException: General error in certificate - _EX.InvalidCertificateFormatException: An error with the format of the certificate - _EX.InvalidCertificateFieldException: An invalid field has been added to the certificate - _EX.NoPrivateKeyException: The certificate contains no private key - _EX.NotSignedException: The certificate is not signed and cannot be exported - +# pylint: disable=super-with-arguments +""" +Contains classes for OpenSSH Certificates, generation, parsing and signing +Raises: + _EX.SSHCertificateException: General error in certificate + _EX.InvalidCertificateFormatException: An error with the format of the certificate + _EX.InvalidCertificateFieldException: An invalid field has been added to the certificate + _EX.NoPrivateKeyException: The certificate contains no private key + _EX.NotSignedException: The certificate is not signed and cannot be exported """ -from base64 import b64encode, b64decode +from base64 import b64decode, b64encode from dataclasses import dataclass -from typing import Union -from enum import Enum -from .keys import ( - PublicKey, - PrivateKey, - RsaPublicKey, - DsaPublicKey, - EcdsaPublicKey, - Ed25519PublicKey, -) -from . import fields as _FIELD +from typing import Tuple, Union + +from prettytable import PrettyTable + from . import exceptions as _EX -from .keys import RsaAlgs -from .utils import join_dicts, concat_to_string, concat_to_bytestring, ensure_string, ensure_bytestring +from . import fields as _FIELD +from .keys import PrivateKey, PublicKey +from .utils import concat_to_bytestring, concat_to_string, ensure_bytestring CERT_TYPES = { - "ssh-rsa-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RsaPubkeyField"), - "rsa-sha2-256-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RsaPubkeyField"), - "rsa-sha2-512-cert-v01@openssh.com": ("RsaCertificate", "_FIELD.RsaPubkeyField"), - "ssh-dss-cert-v01@openssh.com": ("DsaCertificate", "_FIELD.DsaPubkeyField"), + "ssh-rsa-cert-v01@openssh.com": ("RsaCertificate", "RsaPubkeyField"), + "rsa-sha2-256-cert-v01@openssh.com": ("RsaCertificate", "RsaPubkeyField"), + "rsa-sha2-512-cert-v01@openssh.com": ("RsaCertificate", "RsaPubkeyField"), + "ssh-dss-cert-v01@openssh.com": ("DsaCertificate", "DsaPubkeyField"), "ecdsa-sha2-nistp256-cert-v01@openssh.com": ( "EcdsaCertificate", - "_FIELD.EcdsaPubkeyField", + "EcdsaPubkeyField", ), "ecdsa-sha2-nistp384-cert-v01@openssh.com": ( "EcdsaCertificate", - "_FIELD.EcdsaPubkeyField", + "EcdsaPubkeyField", ), "ecdsa-sha2-nistp521-cert-v01@openssh.com": ( "EcdsaCertificate", - "_FIELD.EcdsaPubkeyField", + "EcdsaPubkeyField", ), "ssh-ed25519-cert-v01@openssh.com": ( "Ed25519Certificate", - "_FIELD.Ed25519PubkeyField", + "Ed25519PubkeyField", ), } @dataclass class Fieldset: + """Set of fields for SSHCertificate class""" + + DECODE_ORDER = [] + + def __table__(self): + return [getattr(self, item).__table__() for item in self.getattrs()] + def __setattr__(self, name, value): field = getattr(self, name, None) - + + if isinstance(value, _FIELD.CertificateField): + self.replace_field(name, value) + return + if callable(field) and not isinstance(field, _FIELD.CertificateField): if field.__name__ == "factory": super().__setattr__(name, field()) self.__setattr__(name, value) return - if isinstance(field, type) and getattr(value, '__name__', '') != 'factory': + if isinstance(field, type) and getattr(value, "__name__", "") != "factory": super().__setattr__(name, field(value)) return - - if getattr(value, '__name__', '') != 'factory': + + if getattr(value, "__name__", "") != "factory": field.value = value super().__setattr__(name, field) - + def replace_field(self, name: str, value: Union[_FIELD.CertificateField, type]): + """Completely replace field instead of just setting value (original __setattr__ behaviour) + + Args: + name (str): The field to replace + value (Union[_FIELD.CertificateField, type]): The CertificateField + subclass or instance to replace with + """ super(Fieldset, self).__setattr__(name, value) - + def get(self, name: str, default=None): + """Get field contents + + Args: + name (str): Field name + default (_type_, optional): The default value to return in case the + field is not set. Defaults to None. + + Returns: + mixed: The contents of the field + """ field = getattr(self, name, default) if field: if isinstance(field, type): return field.DEFAULT return field.value return field - + def getattrs(self) -> tuple: - return tuple(k for k in self.__dict__.keys() if not k.startswith('_')) - + """Get all class attributes + + Returns: + tuple: All public class attributes + """ + # pylint: disable=consider-iterating-dictionary + return tuple(att for att in self.__dict__.keys() if not att.startswith("_")) + def validate(self): + """Validate all fields to ensure the data is correct + + Returns: + bool: True if valid, else exception + """ ex = [] for key in self.getattrs(): if not getattr(self, key).validate(): - list([ - ex.append(f"{type(x)}: {str(x)}") for x in getattr(self, key).exception - if isinstance(x, Exception) - ]) - + list( + ex.append(f"{type(x)}: {str(x)}") + for x in getattr(self, key).exception + if isinstance(x, Exception) + ) + return True if len(ex) == 0 else ex + @classmethod + def decode(cls, data: bytes) -> Tuple["Fieldset", bytes]: + """Decode the certificate field data from a stream of bytes + + Returns: + Tuple[Fieldset, bytes]: A tuple with the fieldset (Header, Fields or Footer) + and the remaining bytes. + """ + cl_instance = cls() + for item in cls.DECODE_ORDER: + decoded, data = getattr(cl_instance, item).from_decode(data) + setattr(item, decoded) + + return cl_instance, data + + @dataclass class CertificateHeader(Fieldset): + """Header fields for the certificate""" public_key: _FIELD.PublicKeyField = _FIELD.PublicKeyField.factory pubkey_type: _FIELD.PubkeyTypeField = _FIELD.PubkeyTypeField.factory nonce: _FIELD.NonceField = _FIELD.NonceField.factory + DECODE_ORDER = ["pubkey_type", "nonce"] + def __bytes__(self): return concat_to_bytestring( - bytes(self.pubkey_type), - bytes(self.nonce), - bytes(self.public_key) + bytes(self.pubkey_type), bytes(self.nonce), bytes(self.public_key) ) - + + @classmethod + def decode(cls, data: bytes) -> Tuple["CertificateHeader", bytes]: + cl_instance, data = super().decode(data) + + target_class = CERT_TYPES[cl_instance.get('pubkey_type')] + + public_key, data = getattr(_FIELD, target_class[1]).from_decode(data) + cl_instance.public_key = public_key + + return cl_instance, data + + @dataclass +# pylint: disable=too-many-instance-attributes class CertificateFields(Fieldset): + """Information fields for the certificate""" serial: _FIELD.SerialField = _FIELD.SerialField.factory cert_type: _FIELD.CertificateTypeField = _FIELD.CertificateTypeField.factory key_id: _FIELD.KeyIdField = _FIELD.KeyIdField.factory @@ -115,7 +181,18 @@ class CertificateFields(Fieldset): valid_before: _FIELD.ValidBeforeField = _FIELD.ValidBeforeField.factory critical_options: _FIELD.CriticalOptionsField = _FIELD.CriticalOptionsField.factory extensions: _FIELD.ExtensionsField = _FIELD.ExtensionsField.factory - + + DECODE_ORDER = [ + "serial", + "cert_type", + "key_id", + "principals", + "valid_after", + "valid_before", + "critical_options", + "extensions", + ] + def __bytes__(self): return concat_to_bytestring( bytes(self.serial), @@ -125,20 +202,22 @@ def __bytes__(self): bytes(self.valid_after), bytes(self.valid_before), bytes(self.critical_options), - bytes(self.extensions) + bytes(self.extensions), ) + @dataclass class CertificateFooter(Fieldset): + """Footer fields and signature for the certificate""" reserved: _FIELD.ReservedField = _FIELD.ReservedField.factory ca_pubkey: _FIELD.CAPublicKeyField = _FIELD.CAPublicKeyField.factory signature: _FIELD.SignatureField = _FIELD.SignatureField.factory - + + DECODE_ORDER = ["reserved", "ca_pubkey", "signature"] + def __bytes__(self): - return concat_to_bytestring( - bytes(self.reserved), - bytes(self.ca_pubkey) - ) + return concat_to_bytestring(bytes(self.reserved), bytes(self.ca_pubkey)) + class SSHCertificate: """ @@ -146,14 +225,16 @@ class SSHCertificate: To create new certificates, use the respective keytype classes or the from_public_key classmethod """ - DEFAULT_KEY_TYPE = 'none@openssh.com' + + DEFAULT_KEY_TYPE = "none@openssh.com" + # pylint: disable=too-many-arguments def __init__( self, subject_pubkey: PublicKey = None, ca_privkey: PrivateKey = None, fields: CertificateFields = CertificateFields, header: CertificateHeader = CertificateHeader, - footer: CertificateFooter = CertificateFooter + footer: CertificateFooter = CertificateFooter, ): if self.__class__.__name__ == "SSHCertificate": raise _EX.InvalidClassCallException( @@ -161,75 +242,140 @@ def __init__( + "one of the child classes, or call via decode, create \n" + "or one of the from_-classmethods" ) - + self.fields = fields() if isinstance(fields, type) else fields self.header = header() if isinstance(header, type) else header self.footer = footer() if isinstance(footer, type) else footer - + if isinstance(header, type) and subject_pubkey is not None: self.header.pubkey_type = self.DEFAULT_KEY_TYPE self.header.replace_field( - 'public_key', - _FIELD.PublicKeyField.from_object(subject_pubkey) + "public_key", _FIELD.PublicKeyField.from_object(subject_pubkey) ) - + if isinstance(footer, type) and ca_privkey is not None: self.footer.ca_pubkey = ca_privkey.public_key self.footer.replace_field( - 'signature', - _FIELD.SignatureField.from_object(ca_privkey) + "signature", _FIELD.SignatureField.from_object(ca_privkey) ) - + self.__post_init__() - + def __post_init__(self): """Extensible function for post-initialization for child classes""" - + def __bytes__(self): if not self.footer.signature.is_signed: raise _EX.InvalidCertificateFormatException( "Failed exporting certificate: Certificate is not signed" ) - + return concat_to_bytestring( bytes(self.header), bytes(self.fields), bytes(self.footer), - bytes(self.footer.signature) + bytes(self.footer.signature), ) + def __str__(self) -> str: + table = PrettyTable(["Field", "Value"]) + + for item in (self.header, self.fields, self.footer): + for row in item.__table__(): + table.add_row(row) + + return str(table) + @classmethod def create( cls, subject_pubkey: PublicKey = None, ca_privkey: PrivateKey = None, - fields: CertificateFields = CertificateFields + fields: CertificateFields = CertificateFields, ): - cert_class = subject_pubkey.__class__.__name__.replace("PublicKey", "Certificate") + cert_class = subject_pubkey.__class__.__name__.replace( + "PublicKey", "Certificate" + ) return globals()[cert_class]( - subject_pubkey=subject_pubkey, - ca_privkey=ca_privkey, - fields=fields + subject_pubkey=subject_pubkey, ca_privkey=ca_privkey, fields=fields ) - + @classmethod - def decode(cls,cert_data: Union[str, bytes]): - pass - + def decode(cls, data: bytes) -> "SSHCertificate": + """ + Decode an existing certificate and import it into a new object + + Args: + data (bytes): The certificate bytes, base64 decoded middle part of the certificate + + Returns: + SSHCertificate: SSHCertificate child class + """ + cert_header, data = CertificateHeader.decode(data) + cert_fields, data = CertificateFields.decode(data) + cert_footer, data = CertificateFooter.decode(data) + + return cls(header=cert_header, fields=cert_fields, footer=cert_footer) + + @classmethod + def from_bytes(cls, cert_bytes: bytes): + """ + Loads an existing certificate from the byte value. + + Args: + cert_bytes (bytes): Certificate bytes, base64 decoded middle part of the certificate + + Returns: + SSHCertificate: SSHCertificate child class + """ + cert_type, _ = _FIELD.StringField.decode(cert_bytes) + target_class = CERT_TYPES[cert_type] + return globals()[target_class[0]].decode(cert_bytes) + + @classmethod + def from_string(cls, cert_str: Union[str, bytes], encoding: str = "utf-8"): + """ + Loads an existing certificate from a string in the format + [certificate-type] [base64-encoded-certificate] [optional-comment] + + Args: + cert_str (str): The string containing the certificate + encoding (str, optional): The encoding of the string. Defaults to 'utf-8'. + + Returns: + SSHCertificate: SSHCertificate child class + """ + cert_str = ensure_bytestring(cert_str, encoding) + + certificate = b64decode(cert_str.split(b" ")[1]) + return cls.from_bytes(cert_bytes=certificate) + + @classmethod + def from_file(cls, path: str, encoding: str = "utf-8"): + """ + Loads an existing certificate from a file + + Args: + path (str): The path to the certificate file + encoding (str, optional): Encoding of the file. Defaults to 'utf-8'. + + Returns: + SSHCertificate: SSHCertificate child class + """ + return cls.from_string(open(path, "r", encoding=encoding).read()) + def get(self, field: str): if field in ( - self.header.getattrs() + - self.fields.getattrs() + - self.footer.getattrs() + self.header.getattrs() + self.fields.getattrs() + self.footer.getattrs() ): return ( - self.fields.get(field, False) or - self.header.get(field, False) or - self.footer.get(field, False) + self.fields.get(field, False) + or self.header.get(field, False) + or self.footer.get(field, False) ) raise _EX.InvalidCertificateFieldException(f"Unknown field {field}") - + def set(self, field: str, value): if self.fields.get(field, False): setattr(self.fields, field, value) @@ -238,82 +384,114 @@ def set(self, field: str, value): if self.header.get(field, False): setattr(self.header, field, value) return - + if self.footer.get(field, False): setattr(self.footer, field, value) return - + raise _EX.InvalidCertificateFieldException(f"Unknown field {field}") - + def can_sign(self) -> bool: valid_header = self.header.validate() valid_fields = self.fields.validate() check_keys = ( - True if isinstance(self.get('ca_pubkey'), PublicKey) and - isinstance(self.footer.signature.private_key, PrivateKey) + True + if isinstance(self.get("ca_pubkey"), PublicKey) + and isinstance(self.footer.signature.private_key, PrivateKey) else [ - _EX.SignatureNotPossibleException('No CA Public/Private key is loaded') + _EX.SignatureNotPossibleException("No CA Public/Private key is loaded") ] ) - + if (valid_header, valid_fields, check_keys) != (True, True, True): raise _EX.SignatureNotPossibleException( "\n".join( - valid_header if valid_header != True else [] + - valid_fields if valid_fields != True else [] + - check_keys if check_keys != True else [] + valid_header + if isinstance(valid_header, Exception) + else [] + valid_fields + if isinstance(valid_fields, Exception) + else [] + check_keys + if isinstance(check_keys, Exception) + else [] ) ) - + return True - + def get_signable(self) -> bytes: """ Retrieves the signable data for the certificate in byte form """ return concat_to_bytestring( - bytes(self.header), - bytes(self.fields), - bytes(self.footer) + bytes(self.header), bytes(self.fields), bytes(self.footer) ) def sign(self) -> bool: + """Sign the certificate + + Raises: + _EX.NotSignedException: The certificate could not be signed + + Returns: + bool: Whether successful + """ if self.can_sign(): - self.footer.signature.sign( - data=self.get_signable() - ) + self.footer.signature.sign(data=self.get_signable()) return True - - def to_string(self, comment: str = '', encoding: str = 'utf-8'): + raise _EX.NotSignedException("There was an error while signing the certificate") + + def to_string(self, comment: str = "", encoding: str = "utf-8"): + """Export the certificate to a string + + Args: + comment (str, optional): Comment to append to the certificate. Defaults to "". + encoding (str, optional): Which encoding to use for the string. Defaults to "utf-8". + + Returns: + str: The certificate data, base64-encoded and in string format + """ return concat_to_string( - self.header.get('pubkey_type'), + self.header.get("pubkey_type"), " ", b64encode(bytes(self)), " ", comment if comment else "", - encoding=encoding + encoding=encoding, ) - - def to_file(self, filename: str): - with open(filename, 'w') as f: - f.write(self.to_string()) - + + def to_file(self, filename: str, encoding: str = "utf-8"): + """Export certificate to file + + Args: + filename (str): The filename to write to + encoding (str, optional): The encoding to use for the file/string. Defaults to "utf-8". + """ + with open(filename, "w", encoding=encoding) as file: + file.write(self.to_string()) + + class RsaCertificate(SSHCertificate): - DEFAULT_KEY_TYPE = 'rsa-sha2-512-cert-v01@openssh.com' - + """The RSA Certificate class""" + DEFAULT_KEY_TYPE = "rsa-sha2-512-cert-v01@openssh.com" + + class DsaCertificate(SSHCertificate): - DEFAULT_KEY_TYPE = 'ssh-dss-cert-v01@openssh.com' + """The DSA Certificate class""" + DEFAULT_KEY_TYPE = "ssh-dss-cert-v01@openssh.com" + class EcdsaCertificate(SSHCertificate): - DEFAULT_KEY_TYPE = 'ecdsa-sha2-nistp[curve_size]-cert-v01@openssh.com' - + """The ECDSA certificate class""" + DEFAULT_KEY_TYPE = "ecdsa-sha2-nistp[curve_size]-cert-v01@openssh.com" + def __post_init__(self): """Set the key name from the public key curve size""" self.header.pubkey_type = self.header.get("pubkey_type").replace( - "[curve_size]", - str(self.header.public_key.value.key.curve.key_size) + "[curve_size]", str(self.header.public_key.value.key.curve.key_size) ) + class Ed25519Certificate(SSHCertificate): - DEFAULT_KEY_TYPE = 'ssh-ed25519-cert-v01@openssh.com' \ No newline at end of file + """The ED25519 certificate class""" + DEFAULT_KEY_TYPE = "ssh-ed25519-cert-v01@openssh.com" diff --git a/src/sshkey_tools/certold.py b/src/sshkey_tools/certold.py deleted file mode 100644 index 81f20ec..0000000 --- a/src/sshkey_tools/certold.py +++ /dev/null @@ -1,588 +0,0 @@ -"""Contains classes for OpenSSH Certificates, generation, parsing and signing - Raises: - _EX.SSHCertificateException: General error in certificate - _EX.InvalidCertificateFormatException: An error with the format of the certificate - _EX.InvalidCertificateFieldException: An invalid field has been added to the certificate - _EX.NoPrivateKeyException: The certificate contains no private key - _EX.NotSignedException: The certificate is not signed and cannot be exported - -""" -from base64 import b64encode, b64decode -from typing import Union -from .keys import ( - PublicKey, - PrivateKey, - RsaPublicKey, - DsaPublicKey, - EcdsaPublicKey, - Ed25519PublicKey, -) -from . import fields as _FIELD -from . import exceptions as _EX -from .keys import RsaAlgs -from .utils import join_dicts, concat_to_string, ensure_string, ensure_bytestring - -CERTIFICATE_FIELDS = { - "serial": _FIELD.SerialField, - "cert_type": _FIELD.CertificateTypeField, - "key_id": _FIELD.KeyIdField, - "principals": _FIELD.PrincipalsField, - "valid_after": _FIELD.ValidAfterField, - "valid_before": _FIELD.ValidBeforeField, - "critical_options": _FIELD.CriticalOptionsField, - "extensions": _FIELD.ExtensionsField, -} - -class SSHCertificate: - """ - General class for SSH Certificates, used for loading and parsing. - To create new certificates, use the respective keytype classes - or the from_public_key classmethod - """ - - def __init__( - self, - subject_pubkey: PublicKey = None, - ca_privkey: PrivateKey = None, - decoded: dict = None, - **kwargs, - ) -> None: - if self.__class__.__name__ == "SSHCertificate": - raise _EX.InvalidClassCallException( - "You cannot instantiate SSHCertificate directly. Use \n" - + "one of the child classes, or call via decode, \n" - + "or one of the from_-classmethods" - ) - - if decoded is not None: - self.signature = decoded.pop("signature") - self.signature_pubkey = decoded.pop("ca_pubkey") - - self.header = { - "pubkey_type": decoded.pop("pubkey_type"), - "nonce": decoded.pop("nonce"), - "public_key": decoded.pop("public_key"), - } - - self.fields = decoded - - return - - if subject_pubkey is None: - raise _EX.SSHCertificateException("The subject public key is required") - - self.header = { - "pubkey_type": _FIELD.PubkeyTypeField, - "nonce": _FIELD.NonceField(), - "public_key": _FIELD.PublicKeyField.from_object(subject_pubkey), - } - - if ca_privkey is not None: - self.signature = _FIELD.SignatureField.from_object(ca_privkey) - self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( - ca_privkey.public_key - ) - - self.fields = dict(CERTIFICATE_FIELDS) - self.set_opts(**kwargs) - - def __str__(self): - ls_space = " "*32 - - principals = "\n" + "\n".join( - [ ls_space + principal for principal in ensure_string(self.fields["principals"].value) ] - if len(self.fields["principals"].value) > 0 - else "None" - ) - - critical = "\n" + "\n".join( - [ls_space + cr_opt for cr_opt in ensure_string(self.fields["critical_options"].value)] - if not isinstance(self.fields["critical_options"].value, dict) - else [f'{ls_space}{cr_opt}={self.fields["critical_options"].value[cr_opt]}' for cr_opt in ensure_string(self.fields["critical_options"].value)] - ) - - extensions = "\n" + "\n".join( - [ ls_space + ext for ext in ensure_string(self.fields["extensions"].value) ] - if len(self.fields["extensions"].value) > 0 - else "None" - ) - - signature_val = ( - b64encode(self.signature.value).decode("utf-8") - if isinstance(self.signature.value, bytes) - else "Not signed" - ) - - return f""" - Certificate: - Pubkey Type: {self.header['pubkey_type'].value} - Public Key: {str(self.header['public_key'])} - CA Public Key: {str(self.signature_pubkey)} - Nonce: {self.header['nonce'].value} - Certificate Type: {'User' if self.fields['cert_type'].value == 1 else 'Host'} - Valid After: {self.fields['valid_after'].value.strftime('%Y-%m-%d %H:%M:%S')} - Valid Until: {self.fields['valid_before'].value.strftime('%Y-%m-%d %H:%M:%S')} - Principals: {principals} - Critical options: {critical} - Extensions: {extensions} - Signature: {signature_val} - """ - - @staticmethod - def decode( - cert_bytes: bytes, pubkey_class: _FIELD.PublicKeyField = None - ) -> "SSHCertificate": - """ - Decode an existing certificate and import it into a new object - - Args: - cert_bytes (bytes): The certificate bytes, base64 decoded middle part of the certificate - pubkey_field (_FIELD.PublicKeyField): Instance of the PublicKeyField class, only needs - to be set if it can't be detected automatically - - Raises: - _EX.InvalidCertificateFormatException: Invalid or unknown certificate format - - Returns: - SSHCertificate: SSHCertificate child class - """ - if pubkey_class is None: - cert_type = _FIELD.StringField.decode(cert_bytes)[0].encode("utf-8") - pubkey_class = CERT_TYPES.get(cert_type, False) - - if pubkey_class is False: - raise _EX.InvalidCertificateFormatException( - "Could not determine certificate type, please use one " - + "of the specific classes or specify the pubkey_class" - ) - - decode_fields = join_dicts( - { - "pubkey_type": _FIELD.PubkeyTypeField, - "nonce": _FIELD.NonceField, - "public_key": pubkey_class, - }, - CERTIFICATE_FIELDS, - { - "reserved": _FIELD.ReservedField, - "ca_pubkey": _FIELD.CAPublicKeyField, - "signature": _FIELD.SignatureField, - }, - ) - - cert = {} - - for item in decode_fields.keys(): - cert[item], cert_bytes = decode_fields[item].from_decode(cert_bytes) - - if cert_bytes != b"": - raise _EX.InvalidCertificateFormatException( - "The certificate has additional data after everything has been extracted" - ) - - pubkey_type = ensure_string(cert["pubkey_type"].value) - - cert_type = CERT_TYPES[pubkey_type] - cert.pop("reserved") - return globals()[cert_type[0]]( - subject_pubkey=cert["public_key"].value, decoded=cert - ) - - @classmethod - def from_public_class( - cls, public_key: PublicKey, ca_privkey: PrivateKey = None, **kwargs - ) -> "SSHCertificate": - """ - Creates a new certificate from a supplied public key - - Args: - public_key (PublicKey): The public key for which to create a certificate - - Returns: - SSHCertificate: SSHCertificate child class - """ - return globals()[ - public_key.__class__.__name__.replace("PublicKey", "Certificate") - ](public_key, ca_privkey, **kwargs) - - @classmethod - def from_bytes(cls, cert_bytes: bytes): - """ - Loads an existing certificate from the byte value. - - Args: - cert_bytes (bytes): Certificate bytes, base64 decoded middle part of the certificate - - Returns: - SSHCertificate: SSHCertificate child class - """ - cert_type, _ = _FIELD.StringField.decode(cert_bytes) - target_class = CERT_TYPES[cert_type] - return globals()[target_class[0]].decode(cert_bytes) - - @classmethod - def from_string(cls, cert_str: Union[str, bytes], encoding: str = "utf-8"): - """ - Loads an existing certificate from a string in the format - [certificate-type] [base64-encoded-certificate] [optional-comment] - - Args: - cert_str (str): The string containing the certificate - encoding (str, optional): The encoding of the string. Defaults to 'utf-8'. - - Returns: - SSHCertificate: SSHCertificate child class - """ - cert_str = ensure_bytestring(cert_str) - - certificate = b64decode(cert_str.split(b" ")[1]) - return cls.from_bytes(cert_bytes=certificate) - - @classmethod - def from_file(cls, path: str, encoding: str = "utf-8"): - """ - Loads an existing certificate from a file - - Args: - path (str): The path to the certificate file - encoding (str, optional): Encoding of the file. Defaults to 'utf-8'. - - Returns: - SSHCertificate: SSHCertificate child class - """ - return cls.from_string(open(path, "r", encoding=encoding).read()) - - def set_ca(self, ca_privkey: PrivateKey): - """ - Set the CA Private Key for signing the certificate - - Args: - ca_privkey (PrivateKey): The CA private key - """ - self.signature = _FIELD.SignatureField.from_object(ca_privkey) - self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( - ca_privkey.public_key - ) - - def set_type(self, pubkey_type: str): - """ - Set the type of the public key if not already set automatically - The child classes will set this automatically - - Args: - pubkey_type (str): Public key type, e.g. ssh-rsa-cert-v01@openssh.com - """ - if not getattr(self.header["pubkey_type"], "value", False): - self.header["pubkey_type"] = self.header["pubkey_type"](pubkey_type) - - def set_opt(self, key: str, value): - """ - Add information to a field in the certificate - - Args: - key (str): The key to set - value (mixed): The new value for the field - - Raises: - _EX.InvalidCertificateFieldException: Invalid field - """ - if key not in self.fields: - raise _EX.InvalidCertificateFieldException( - f"{key} is not a valid certificate field" - ) - - try: - if self.fields[key].value not in [None, False, "", [], ()]: - self.fields[key].value = value - except AttributeError: - self.fields[key] = self.fields[key](value) - - def set_opts(self, **kwargs): - """ - Set multiple options at once - """ - for key, value in kwargs.items(): - self.set_opt(key, value) - - def get_opt(self, key: str): - """ - Get the value of a field in the certificate - - Args: - key (str): The key to get - - Raises: - _EX.InvalidCertificateFieldException: Invalid field - """ - if key not in self.fields: - raise _EX.InvalidCertificateFieldException( - f"{key} is not a valid certificate field" - ) - - return getattr(self.fields[key], "value", None) - - # pylint: disable=used-before-assignment - def can_sign(self) -> bool: - """ - Determine if the certificate is ready to be signed - - Raises: - ...: Exception from the respective field with error - _EX.NoPrivateKeyException: Private key is missing from class - - Returns: - bool: True/False if the certificate can be signed - """ - self.header['nonce'].validate() - - exceptions = [] - for field in self.fields.values(): - try: - valid = field.validate() - except TypeError: - valid = _EX.SignatureNotPossibleException( - f"The field {field} is missing a value" - ) - finally: - if isinstance(valid, Exception): - exceptions.append(valid) - - if ( - getattr(self, "signature", False) is False - or getattr(self, "signature_pubkey", False) is False - ): - exceptions.append( - _EX.SignatureNotPossibleException("No CA private key is set") - ) - - if len(exceptions) > 0: - raise _EX.SignatureNotPossibleException(exceptions) - - if self.signature.can_sign() is True: - return True - - raise _EX.SignatureNotPossibleException( - "The certificate cannot be signed, the CA private key is not loaded" - ) - - def get_signable_data(self) -> bytes: - """ - Gets the signable byte string from the certificate fields - - Returns: - bytes: The data in the certificate which is signed - """ - return ( - b"".join( - [ - bytes(x) - for x in tuple(self.header.values()) + tuple(self.fields.values()) - ] - ) - + bytes(_FIELD.ReservedField()) - + bytes(self.signature_pubkey) - ) - - def sign(self, **signing_args): - """ - Sign the certificate - - Returns: - SSHCertificate: The signed certificate class - """ - if self.can_sign(): - self.signature.sign(data=self.get_signable_data(), **signing_args) - - return self - - def verify(self, ca_pubkey: PublicKey = None) -> bool: - """ - Verifies a signature against a given public key. - - If no public key is provided, the signature is checked against - the public/private key provided to the class on creation - or decoding. - - Not providing the public key for the CA with an imported - certificate means the verification will succeed even if an - attacker has replaced the signature and public key for signing. - - If the certificate wasn't created and signed on the same occasion - as the validity check, you should always provide a public key for - verificiation. - - Returns: - bool: If the certificate signature is valid - """ - - if ca_pubkey is None: - ca_pubkey = self.signature_pubkey.value - - cert_data = self.get_signable_data() - signature = self.signature.value - - return ca_pubkey.verify(cert_data, signature) - - def to_bytes(self) -> bytes: - """ - Export the signed certificate in byte-format - - Raises: - _EX.NotSignedException: The certificate has not been signed yet - - Returns: - bytes: The certificate bytes - """ - if self.signature.is_signed is True: - return self.get_signable_data() + bytes(self.signature) - - raise _EX.NotSignedException("The certificate has not been signed") - - def to_string( - self, comment: Union[str, bytes] = None, encoding: str = "utf-8" - ) -> str: - """ - Export the signed certificate to a string, ready to be written to file - - Args: - comment (Union[str, bytes], optional): Comment to add to the string. Defaults to None. - encoding (str, optional): Encoding to use for the string. Defaults to 'utf-8'. - - Returns: - str: Certificate string - """ - return concat_to_string( - self.header["pubkey_type"].value, - " ", - b64encode(self.to_bytes()), - " ", - comment if comment else "", - encoding=encoding - ) - - def to_file( - self, path: str, comment: Union[str, bytes] = None, encoding: str = "utf-8" - ): - """ - Saves the certificate to a file - - Args: - path (str): The path of the file to save to - comment (Union[str, bytes], optional): Comment to add to the certificate end. - Defaults to None. - encoding (str, optional): Encoding for the file. Defaults to 'utf-8'. - """ - with open(path, "w", encoding=encoding) as file: - file.write(self.to_string(comment, encoding)) - - -class RsaCertificate(SSHCertificate): - """ - Specific class for RSA Certificates. Inherits from SSHCertificate - """ - - def __init__( - self, - subject_pubkey: RsaPublicKey, - ca_privkey: PrivateKey = None, - rsa_alg: RsaAlgs = RsaAlgs.SHA512, - **kwargs, - ): - - super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.rsa_alg = rsa_alg - self.set_type(f"{rsa_alg.value[0]}-cert-v01@openssh.com") - - @classmethod - # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "SSHCertificate": - """ - Decode an existing RSA Certificate - - Args: - cert_bytes (bytes): The base64-decoded bytes for the certificate - - Returns: - RsaCertificate: The decoded certificate - """ - return super().decode(cert_bytes, _FIELD.RsaPubkeyField) - - -class DsaCertificate(SSHCertificate): - """ - Specific class for DSA/DSS Certificates. Inherits from SSHCertificate - """ - - def __init__( - self, subject_pubkey: DsaPublicKey, ca_privkey: PrivateKey = None, **kwargs - ): - super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.set_type("ssh-dss-cert-v01@openssh.com") - - @classmethod - # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "DsaCertificate": - """ - Decode an existing DSA Certificate - - Args: - cert_bytes (bytes): The base64-decoded bytes for the certificate - - Returns: - DsaCertificate: The decoded certificate - """ - return super().decode(cert_bytes, _FIELD.DsaPubkeyField) - - -class EcdsaCertificate(SSHCertificate): - """ - Specific class for ECDSA Certificates. Inherits from SSHCertificate - """ - - def __init__( - self, subject_pubkey: EcdsaPublicKey, ca_privkey: PrivateKey = None, **kwargs - ): - super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.set_type( - f"ecdsa-sha2-nistp{subject_pubkey.key.curve.key_size}-cert-v01@openssh.com" - ) - - @classmethod - # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "EcdsaCertificate": - """ - Decode an existing ECDSA Certificate - - Args: - cert_bytes (bytes): The base64-decoded bytes for the certificate - - Returns: - EcdsaCertificate: The decoded certificate - """ - return super().decode(cert_bytes, _FIELD.EcdsaPubkeyField) - - -class Ed25519Certificate(SSHCertificate): - """ - Specific class for ED25519 Certificates. Inherits from SSHCertificate - """ - - def __init__( - self, subject_pubkey: Ed25519PublicKey, ca_privkey: PrivateKey = None, **kwargs - ): - super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.set_type("ssh-ed25519-cert-v01@openssh.com") - - @classmethod - # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> "Ed25519Certificate": - """ - Decode an existing ED25519 Certificate - - Args: - cert_bytes (bytes): The base64-decoded bytes for the certificate - - Returns: - Ed25519Certificate: The decoded certificate - """ - return super().decode(cert_bytes, _FIELD.Ed25519PubkeyField) diff --git a/src/sshkey_tools/fields.py b/src/sshkey_tools/fields.py index 273e3f8..a23af5e 100644 --- a/src/sshkey_tools/fields.py +++ b/src/sshkey_tools/fields.py @@ -3,47 +3,46 @@ """ # pylint: disable=invalid-name,too-many-lines,arguments-differ import re -from enum import Enum -from types import MethodType -from typing import Union, Tuple -from dataclasses import dataclass +from base64 import b64encode from datetime import datetime, timedelta +from enum import Enum from struct import pack, unpack -from base64 import b64encode +from typing import Tuple, Union + from cryptography.hazmat.primitives.asymmetric.utils import ( decode_dss_signature, encode_dss_signature, ) + from . import exceptions as _EX from .keys import ( - RsaAlgs, - PrivateKey, - PublicKey, - RsaPublicKey, - RsaPrivateKey, - DsaPublicKey, DsaPrivateKey, - EcdsaPublicKey, + DsaPublicKey, EcdsaPrivateKey, - Ed25519PublicKey, + EcdsaPublicKey, Ed25519PrivateKey, + Ed25519PublicKey, + PrivateKey, + PublicKey, + RsaAlgs, + RsaPrivateKey, + RsaPublicKey, ) - from .utils import ( - long_to_bytes, bytes_to_long, + concat_to_string, + ensure_bytestring, + ensure_string, generate_secure_nonce, + long_to_bytes, random_keyid, random_serial, - ensure_string, - ensure_bytestring, - concat_to_string ) NoneType = type(None) MAX_INT32 = 2**32 MAX_INT64 = 2**64 -NEWLINE = '\n' +NEWLINE = "\n" ECDSA_CURVE_MAP = { "secp256r1": "nistp256", @@ -72,6 +71,7 @@ b"ed25519": "Ed25519SignatureField", } + class CERT_TYPE(Enum): """ Certificate types, User certificate/Host certificate @@ -80,21 +80,26 @@ class CERT_TYPE(Enum): USER = 1 HOST = 2 + class CertificateField: """ The base class for certificate fields """ + IS_SET = None DEFAULT = None REQUIRED = False DATA_TYPE = NoneType - def __init__(self, value = None): + def __init__(self, value=None): self.value = value self.exception = None self.IS_SET = True self.name = self.get_name() + def __table__(self): + return (str(self.name), str(self.value)) + def __str__(self): return f"{self.name}: {self.value}" @@ -103,9 +108,7 @@ def __bytes__(self) -> bytes: @classmethod def get_name(cls) -> str: - return "_".join( - re.findall('[A-Z][^A-Z]*', cls.__name__)[:-1] - ).lower() + return "_".join(re.findall("[A-Z][^A-Z]*", cls.__name__)[:-1]).lower() @classmethod def __validate_type__(cls, value, do_raise: bool = False) -> Union[bool, Exception]: @@ -116,19 +119,19 @@ def __validate_type__(cls, value, do_raise: bool = False) -> Union[bool, Excepti ex = _EX.InvalidDataException( f"Invalid data type for {cls.get_name()} (expected {cls.DATA_TYPE}, got {type(value)})" ) - + if do_raise: raise ex - + return ex - return True + return True def __validate_required__(self) -> Union[bool, Exception]: """ Validates if the field is set when required """ - if self.DEFAULT == self.value == None: + if self.DEFAULT == self.value is None: return _EX.InvalidFieldDataException( f"{self.get_name()} is a required field" ) @@ -146,19 +149,19 @@ def validate(self) -> bool: """ Validates all field contents and types """ - if isinstance(self.value, NoneType) and self.DEFAULT != None: + if isinstance(self.value, NoneType) and self.DEFAULT is not None: self.value = self.DEFAULT() if callable(self.DEFAULT) else self.DEFAULT - + self.exception = ( self.__validate_type__(self.value), self.__validate_required__(), - self.__validate_value__() + self.__validate_value__(), ) - + return self.exception == (True, True, True) @staticmethod - def decode(cls, data: bytes) -> tuple: + def decode(data: bytes) -> tuple: """ Returns the decoded value of the field """ @@ -168,7 +171,7 @@ def encode(cls, value) -> bytes: """ Returns the encoded value of the field """ - + @classmethod def from_decode(cls, data: bytes) -> Tuple["CertificateField", bytes]: """ @@ -179,26 +182,32 @@ def from_decode(cls, data: bytes) -> Tuple["CertificateField", bytes]: """ value, data = cls.decode(data) return cls(value), data - + @classmethod - def factory(cls) -> 'CertificateField': + def factory(cls, blank: bool = False) -> "CertificateField": """ Factory to create field with default value if set, otherwise empty + Args: + blank (bool): Return a blank class (for decoding) + Returns: CertificateField: A new CertificateField subclass instance """ + if cls.DEFAULT is None or blank: + return cls + if callable(cls.DEFAULT): return cls(cls.DEFAULT()) - if cls.DEFAULT is None: - return cls - + return cls(cls.DEFAULT) + class BooleanField(CertificateField): """ Field representing a boolean value (True/False) or (1/0) """ + DATA_TYPE = (bool, int) @classmethod @@ -224,19 +233,25 @@ def decode(data: bytes) -> Tuple[bool, bytes]: data (bytes): The byte string starting with an encoded boolean """ return bool(unpack("B", data[:1])[0]), data[1:] - + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ - return True if self.value in (True, False, 1, 0) else _EX.InvalidFieldDataException( - f"{self.get_name()} must be a boolean (True/1 or False/0)" + return ( + True + if self.value in (True, False, 1, 0) + else _EX.InvalidFieldDataException( + f"{self.get_name()} must be a boolean (True/1 or False/0)" + ) ) + class BytestringField(CertificateField): """ Field representing a bytestring value """ + DATA_TYPE = (bytes, str) DEFAULT = b"" @@ -270,10 +285,12 @@ def decode(data: bytes) -> Tuple[bytes, bytes]: length = unpack(">I", data[:4])[0] + 4 return ensure_bytestring(data[4:length]), data[length:] + class StringField(BytestringField): """ Field representing a string value """ + DATA_TYPE = (str, bytes) DEFAULT = "" @@ -288,7 +305,7 @@ def encode(cls, value: str, encoding: str = "utf-8"): Returns: bytes: Packed byte string containing the source data - """ + """ cls.__validate_type__(value, True) return BytestringField.encode(ensure_bytestring(value, encoding)) @@ -306,12 +323,14 @@ def decode(data: bytes, encoding: str = "utf-8") -> Tuple[str, bytes]: """ value, data = BytestringField.decode(data) - return value.decode(encoding), data + return value.decode(encoding), data + class Integer32Field(CertificateField): """ Certificate field representing a 32-bit integer """ + DATA_TYPE = int DEFAULT = 0 @@ -345,20 +364,23 @@ def __validate_value__(self) -> Union[bool, Exception]: Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + if self.value < MAX_INT32: return True - + return _EX.InvalidFieldDataException( f"{self.get_name()} must be a 32-bit integer" ) - + class Integer64Field(CertificateField): """ Certificate field representing a 64-bit integer """ + DATA_TYPE = int DEFAULT = 0 @@ -392,11 +414,13 @@ def __validate_value__(self) -> Union[bool, Exception]: Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + if self.value < MAX_INT64: return True - + return _EX.InvalidFieldDataException( f"{self.get_name()} must be a 64-bit integer" ) @@ -407,6 +431,7 @@ class DateTimeField(Integer64Field): Certificate field representing a datetime value. The value is saved as a 64-bit integer (unix timestamp) """ + DATA_TYPE = (datetime, int) DEFAULT = datetime.now @@ -439,28 +464,32 @@ def decode(data: bytes) -> datetime: """ timestamp, data = Integer64Field.decode(data) return datetime.fromtimestamp(timestamp), data - + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) check = self.value if isinstance(self.value, int) else self.value.timestamp() - + if check < MAX_INT64: return True - + return _EX.InvalidFieldDataException( f"{self.get_name()} must be a 64-bit integer or datetime object" ) + class MpIntegerField(BytestringField): """ Certificate field representing a multiple precision integer, an integer too large to fit in 64 bits. """ + DATA_TYPE = int DEFAULT = 0 @@ -492,10 +521,12 @@ def decode(data: bytes) -> Tuple[int, bytes]: mpint, data = BytestringField.decode(data) return bytes_to_long(mpint), data + class ListField(CertificateField): """ Certificate field representing a list or tuple of strings """ + DATA_TYPE = (list, set, tuple) DEFAULT = [] @@ -545,19 +576,25 @@ def __validate_value__(self) -> Union[bool, Exception]: Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - - if hasattr(self.value, '__iter__') and not all((isinstance(val, (str, bytes)) for val in self.value)): + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + + if hasattr(self.value, "__iter__") and not all( + (isinstance(val, (str, bytes)) for val in self.value) + ): return _EX.InvalidFieldDataException( "Expected list or tuple containing strings or bytes" ) return True + class KeyValueField(CertificateField): """ Certificate field representing a list or integer in python, separated in byte-form by null-bytes. """ + DATA_TYPE = (list, tuple, set, dict) DEFAULT = {} @@ -629,26 +666,33 @@ def __validate_value__(self) -> Union[bool, Exception]: Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + testvals = ( - self.value if not isinstance(self.value, dict) + self.value + if not isinstance(self.value, dict) else list(self.value.keys()) + list(self.value.values()) ) - - if hasattr(self.value, '__iter__') and not all((isinstance(val, (str, bytes)) for val in testvals)): + + if hasattr(self.value, "__iter__") and not all( + (isinstance(val, (str, bytes)) for val in testvals) + ): return _EX.InvalidFieldDataException( - "Expected dict, list, tuple, set with string or byte keys and values" - ) + "Expected dict, list, tuple, set with string or byte keys and values" + ) return True + class PubkeyTypeField(StringField): """ Contains the certificate type, which is based on the public key type the certificate is created for, e.g. 'ssh-ed25519-cert-v01@openssh.com' for an ED25519 key """ + DEFAULT = None DATA_TYPE = (str, bytes) ALLOWED_VALUES = ( @@ -667,34 +711,40 @@ def __validate_value__(self) -> Union[bool, Exception]: Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + if ensure_string(self.value) not in self.ALLOWED_VALUES: return _EX.InvalidFieldDataException( "Expected one of the following values: {}".format( NEWLINE.join(self.ALLOWED_VALUES) ) ) - + return True + class NonceField(StringField): """ Contains the nonce for the certificate, randomly generated this protects the integrity of the private key, especially for ecdsa. """ + DEFAULT = generate_secure_nonce DATA_TYPE = (str, bytes) - + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - - if hasattr(self.value, '__count__') and len(self.value) < 32: + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + + if hasattr(self.value, "__count__") and len(self.value) < 32: return _EX.InvalidFieldDataException( "Expected a nonce of at least 32 bytes" ) @@ -707,9 +757,13 @@ class PublicKeyField(CertificateField): Contains the subject (User or Host) public key for whom/which the certificate is created. """ + DEFAULT = None DATA_TYPE = PublicKey + def __table__(self) -> tuple: + return [str(self.name), str(self.value.get_fingerprint())] + def __str__(self) -> str: return " ".join( [ @@ -759,6 +813,7 @@ class RsaPubkeyField(PublicKeyField): """ Holds the RSA Public Key for RSA Certificates """ + DEFAULT = None DATA_TYPE = RsaPublicKey @@ -779,10 +834,12 @@ def decode(data: bytes) -> Tuple[RsaPublicKey, bytes]: return RsaPublicKey.from_numbers(e=e, n=n), data + class DsaPubkeyField(PublicKeyField): """ Holds the DSA Public Key for DSA Certificates """ + DEFAULT = None DATA_TYPE = DsaPublicKey @@ -805,10 +862,12 @@ def decode(data: bytes) -> Tuple[DsaPublicKey, bytes]: return DsaPublicKey.from_numbers(p=p, q=q, g=g, y=y), data + class EcdsaPubkeyField(PublicKeyField): """ Holds the ECDSA Public Key for ECDSA Certificates """ + DEFAULT = None DATA_TYPE = EcdsaPublicKey @@ -842,10 +901,12 @@ def decode(data: bytes) -> Tuple[EcdsaPublicKey, bytes]: data, ) + class Ed25519PubkeyField(PublicKeyField): """ Holds the ED25519 Public Key for ED25519 Certificates """ + DEFAULT = None DATA_TYPE = Ed25519PublicKey @@ -871,23 +932,21 @@ class SerialField(Integer64Field): Contains the numeric serial number of the certificate, maximum is (2**64)-1 """ + DEFAULT = random_serial DATA_TYPE = int + class CertificateTypeField(Integer32Field): """ Contains the certificate type User certificate: CERT_TYPE.USER/1 Host certificate: CERT_TYPE.HOST/2 """ + DEFAULT = CERT_TYPE.USER DATA_TYPE = (CERT_TYPE, int) - ALLOWED_VALUES = ( - CERT_TYPE.USER, - CERT_TYPE.HOST, - 1, - 2 - ) + ALLOWED_VALUES = (CERT_TYPE.USER, CERT_TYPE.HOST, 1, 2) @classmethod def encode(cls, value: Union[CERT_TYPE, int]) -> bytes: @@ -912,23 +971,28 @@ def __validate_value__(self) -> Union[bool, Exception]: Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + if self.value not in self.ALLOWED_VALUES: return _EX.InvalidCertificateFieldException( - f"The certificate type is invalid (expected int(1,2) or CERT_TYPE.X)" + "The certificate type is invalid (expected int(1,2) or CERT_TYPE.X)" ) - + return True + class KeyIdField(StringField): """ Contains the key identifier (subject) of the certificate, alphanumeric string """ + DEFAULT = random_keyid DATA_TYPE = (str, bytes) + class PrincipalsField(ListField): """ Contains a list of principals for the certificate, @@ -936,25 +1000,30 @@ class PrincipalsField(ListField): If no principals are added, the certificate is valid only for servers that have no allowed principals specified """ + DEFAFULT = [] DATA_TYPE = (list, set, tuple) + class ValidAfterField(DateTimeField): """ Contains the start of the validity period for the certificate, represented by a datetime object """ + DEFAULT = datetime.now() DATA_TYPE = (datetime, int) + class ValidBeforeField(DateTimeField): """ Contains the end of the validity period for the certificate, represented by a datetime object """ + DEFAULT = datetime.now() + timedelta(minutes=10) DATA_TYPE = (datetime, int) - + def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field @@ -963,18 +1032,25 @@ def __validate_value__(self) -> Union[bool, Exception]: created """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + super().__validate_value__() - check = self.value if isinstance(self.value, datetime) else datetime.fromtimestamp(self.value) - + check = ( + self.value + if isinstance(self.value, datetime) + else datetime.fromtimestamp(self.value) + ) + if check < datetime.now(): return _EX.InvalidCertificateFieldException( - f'The certificate validity period is invalid (expected a future datetime object or timestamp)' + "The certificate validity period is invalid (expected a future datetime object or timestamp)" ) - + return True + class CriticalOptionsField(KeyValueField): """ Contains the critical options part of the certificate (optional). @@ -991,30 +1067,32 @@ class CriticalOptionsField(KeyValueField): If set to true, the user must verify their identity if using a hardware token """ + DEFAULT = [] DATA_TYPE = (list, set, tuple, dict) - ALLOWED_VALUES = ( - "force-command", - "source-address", - "verify-required" - ) + ALLOWED_VALUES = ("force-command", "source-address", "verify-required") def __validate_value__(self) -> Union[bool, Exception]: """ Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - - for elem in self.value if not isinstance(self.value, dict) else list(self.value.keys()): + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + + for elem in ( + self.value if not isinstance(self.value, dict) else list(self.value.keys()) + ): if elem not in self.ALLOWED_VALUES: return _EX.InvalidCertificateFieldException( - f"Critical option not recognized ({elem}){NEWLINE}" + - f"Valid options are {', '.join(self.ALLOWED_VALUES)}" + f"Critical option not recognized ({elem}){NEWLINE}" + + f"Valid options are {', '.join(self.ALLOWED_VALUES)}" ) return True + class ExtensionsField(KeyValueField): """ Contains a list of extensions for the certificate, @@ -1042,6 +1120,7 @@ class ExtensionsField(KeyValueField): Permits the user to use the user rc file """ + DEFAULT = [] DATA_TYPE = (list, set, tuple, dict) ALLOWED_VALUES = ( @@ -1050,7 +1129,7 @@ class ExtensionsField(KeyValueField): "permit-agent-forwarding", "permit-port-forwarding", "permit-pty", - "permit-user-rc" + "permit-user-rc", ) def __validate_value__(self) -> Union[bool, Exception]: @@ -1058,13 +1137,15 @@ def __validate_value__(self) -> Union[bool, Exception]: Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + for item in self.value: if item not in self.ALLOWED_VALUES: return _EX.InvalidDataException( - f"Invalid extension '{item}'{NEWLINE}" + - f"Allowed values are: {NEWLINE.join(self.ALLOWED_VALUES)}" + f"Invalid extension '{item}'{NEWLINE}" + + f"Allowed values are: {NEWLINE.join(self.ALLOWED_VALUES)}" ) return True @@ -1075,6 +1156,7 @@ class ReservedField(StringField): This field is reserved for future use, and doesn't contain any actual data, just an empty string. """ + DEFAULT = "" DATA_TYPE = str @@ -1083,17 +1165,23 @@ def __validate_value__(self) -> Union[bool, Exception]: Validates the contents of the field """ if isinstance(self.__validate_type__(self.value), Exception): - return _EX.InvalidFieldDataException(f"{self.get_name()} Could not validate value, invalid type") - - return True if self.value == "" else _EX.InvalidDataException( - f"The reserved field is not empty" + return _EX.InvalidFieldDataException( + f"{self.get_name()} Could not validate value, invalid type" + ) + + return ( + True + if self.value == "" + else _EX.InvalidDataException("The reserved field is not empty") ) + class CAPublicKeyField(BytestringField): """ Contains the public key of the certificate authority that is used to sign the certificate. """ + DEFAULT = None DATA_TYPE = (str, bytes) @@ -1109,6 +1197,12 @@ def __str__(self) -> str: ] ) + def __bytes__(self) -> bytes: + return self.encode(self.value.raw_bytes()) + + def __table__(self) -> tuple: + return ("CA Public Key", self.value.get_fingerprint()) + def validate(self) -> Union[bool, Exception]: """ Validates the contents of the field @@ -1124,7 +1218,7 @@ def validate(self) -> Union[bool, Exception]: return True @staticmethod - def decode(data) -> Tuple[PublicKey, bytes]: + def decode(data: bytes) -> Tuple[PublicKey, bytes]: """ Decode the certificate field from a byte string starting with the encoded public key @@ -1139,13 +1233,12 @@ def decode(data) -> Tuple[PublicKey, bytes]: pubkey_type = StringField.decode(pubkey)[0] return ( - PublicKey.from_string(concat_to_string(pubkey_type, " ", b64encode(pubkey))), + PublicKey.from_string( + concat_to_string(pubkey_type, " ", b64encode(pubkey)) + ), data, ) - def __bytes__(self) -> bytes: - return self.encode(self.value.raw_bytes()) - @classmethod def from_object(cls, public_key: PublicKey) -> "CAPublicKeyField": """ @@ -1158,6 +1251,7 @@ class SignatureField(CertificateField): """ Creates and contains the signature of the certificate """ + DEFAULT = None DATA_TYPE = bytes @@ -1166,10 +1260,17 @@ def __init__(self, private_key: PrivateKey = None, signature: bytes = None): self.private_key = private_key self.is_signed = False self.value = signature - + if signature is not None and ensure_bytestring(signature) not in ("", " "): self.is_signed = True + def __table__(self) -> tuple: + msg = "No signature" + if self.is_signed: + msg = f"Signed with private key {self.private_key.get_fingerprint()}" + + return ("Signature", msg) + @staticmethod def from_object(private_key: PrivateKey): """ @@ -1228,17 +1329,16 @@ def sign(self, data: bytes) -> None: Placeholder signing function """ raise _EX.InvalidClassCallException("The base class has no sign function") - + def __bytes__(self) -> None: return self.encode(self.value) - - class RsaSignatureField(SignatureField): """ Creates and contains the RSA signature from an RSA Private Key """ + DEFAULT = None DATA_TYPE = bytes @@ -1309,7 +1409,7 @@ def from_decode(cls, data: bytes) -> Tuple["RsaSignatureField", bytes]: cls( private_key=None, hash_alg=[alg for alg in RsaAlgs if alg.value[0] == signature[0]][0], - signature=signature[1] + signature=signature[1], ), data, ) @@ -1336,6 +1436,7 @@ class DsaSignatureField(SignatureField): """ Creates and contains the DSA signature from an DSA Private Key """ + DEFAULT = None DATA_TYPE = bytes @@ -1415,6 +1516,7 @@ class EcdsaSignatureField(SignatureField): """ Creates and contains the ECDSA signature from an ECDSA Private Key """ + DEFAULT = None DATA_TYPE = bytes @@ -1515,11 +1617,15 @@ class Ed25519SignatureField(SignatureField): """ Creates and contains the ED25519 signature from an ED25519 Private Key """ + DEFAULT = None DATA_TYPE = bytes def __init__( - self, private_key: Ed25519PrivateKey = None, signature: bytes = None + self, + # trunk-ignore(gitleaks/generic-api-key) + private_key: Ed25519PrivateKey = None, + signature: bytes = None, ) -> None: super().__init__(private_key, signature) diff --git a/src/sshkey_tools/keys.py b/src/sshkey_tools/keys.py index 534f420..bd521d0 100644 --- a/src/sshkey_tools/keys.py +++ b/src/sshkey_tools/keys.py @@ -1,42 +1,35 @@ """ Classes for handling SSH public/private keys """ -from typing import Union -from enum import Enum from base64 import b64decode +from enum import Enum from struct import unpack -from cryptography.hazmat.backends.openssl.rsa import _RSAPublicKey, _RSAPrivateKey -from cryptography.hazmat.backends.openssl.dsa import _DSAPublicKey, _DSAPrivateKey -from cryptography.hazmat.backends.openssl.ed25519 import ( - _Ed25519PublicKey, - _Ed25519PrivateKey, -) +from typing import Union + +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.backends.openssl.dsa import _DSAPrivateKey, _DSAPublicKey from cryptography.hazmat.backends.openssl.ec import ( - _EllipticCurvePublicKey, _EllipticCurvePrivateKey, + _EllipticCurvePublicKey, ) -from cryptography.hazmat.primitives import ( - serialization as _SERIALIZATION, - hashes as _HASHES, -) -from cryptography.hazmat.primitives.asymmetric import ( - rsa as _RSA, - dsa as _DSA, - ec as _ECDSA, - ed25519 as _ED25519, - padding as _PADDING, +from cryptography.hazmat.backends.openssl.ed25519 import ( + _Ed25519PrivateKey, + _Ed25519PublicKey, ) - -from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.backends.openssl.rsa import _RSAPrivateKey, _RSAPublicKey +from cryptography.hazmat.primitives import hashes as _HASHES +from cryptography.hazmat.primitives import serialization as _SERIALIZATION +from cryptography.hazmat.primitives.asymmetric import dsa as _DSA +from cryptography.hazmat.primitives.asymmetric import ec as _ECDSA +from cryptography.hazmat.primitives.asymmetric import ed25519 as _ED25519 +from cryptography.hazmat.primitives.asymmetric import padding as _PADDING +from cryptography.hazmat.primitives.asymmetric import rsa as _RSA from . import exceptions as _EX -from .utils import ( - md5_fingerprint as _FP_MD5, - sha256_fingerprint as _FP_SHA256, - sha512_fingerprint as _FP_SHA512, - ensure_string, - ensure_bytestring -) +from .utils import ensure_bytestring, ensure_string +from .utils import md5_fingerprint as _FP_MD5 +from .utils import sha256_fingerprint as _FP_SHA256 +from .utils import sha512_fingerprint as _FP_SHA512 PUBKEY_MAP = { _RSAPublicKey: "RsaPublicKey", @@ -49,6 +42,7 @@ _RSAPrivateKey: "RsaPrivateKey", _DSAPrivateKey: "DsaPrivateKey", _EllipticCurvePrivateKey: "EcdsaPrivateKey", + # trunk-ignore(gitleaks/generic-api-key) _Ed25519PrivateKey: "Ed25519PrivateKey", } @@ -178,7 +172,7 @@ def from_string( Returns: PublicKey: Any of the PublicKey child classes - """ + """ split = ensure_bytestring(data).split(b" ") comment = None if len(split) > 2: @@ -205,6 +199,17 @@ def from_file(cls, path: str) -> "PublicKey": return cls.from_string(data) + @classmethod + def from_bytes(cls, data: bytes) -> "PublicKey": + for key_class in PUBKEY_MAP.values(): + try: + key = globals()[key_class].from_raw_bytes(data) + return key + except Exception: + pass + + raise _EX.InvalidKeyException("Invalid public key") + def get_fingerprint( self, hash_method: FingerprintHashes = FingerprintHashes.SHA256 ) -> str: @@ -246,7 +251,10 @@ def to_string(self, encoding: str = "utf-8") -> str: encoding(str, optional): The encoding of the file. Defaults to 'utf-8'. """ return " ".join( - [ ensure_string(self.serialize()), ensure_string(getattr(self, 'comment', '')) ] + [ + ensure_string(self.serialize()), + ensure_string(getattr(self, "comment", "")), + ] ) def to_file(self, path: str, encoding: str = "utf-8") -> None: @@ -365,7 +373,7 @@ def to_bytes(self, password: Union[str, bytes] = None) -> bytes: bytes: The private key in PEM format """ password = ensure_bytestring(password) - + encryption = _SERIALIZATION.NoEncryption() if password is not None: encryption = self.export_opts["encryption"](password) diff --git a/src/sshkey_tools/utils.py b/src/sshkey_tools/utils.py index 34c477e..5a2803e 100644 --- a/src/sshkey_tools/utils.py +++ b/src/sshkey_tools/utils.py @@ -1,20 +1,21 @@ """ Utilities for handling keys and certificates """ +import hashlib as hl import sys -from typing import Union, List, Dict -from secrets import randbits +from base64 import b64encode from random import randint +from secrets import randbits +from typing import Dict, List, Union from uuid import uuid4 -from base64 import b64encode -import hashlib as hl NoneType = type(None) + def ensure_string( - obj: Union[str, bytes, list, tuple, set, dict, NoneType], - encoding: str = 'utf-8', - required: bool = False + obj: Union[str, bytes, list, tuple, set, dict, NoneType], + encoding: str = "utf-8", + required: bool = False, ) -> Union[str, List[str], Dict[str, str], NoneType]: """Ensure the provided value is or contains a string/strings @@ -32,14 +33,20 @@ def ensure_string( elif isinstance(obj, (list, tuple, set)): return [ensure_string(o, encoding) for o in obj] elif isinstance(obj, dict): - return {ensure_string(k, encoding): ensure_string(v, encoding) for k, v in obj.items()} + return { + ensure_string(k, encoding): ensure_string(v, encoding) + for k, v in obj.items() + } else: - raise TypeError(f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}.") + raise TypeError( + f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}." + ) + def ensure_bytestring( - obj: Union[str, bytes, list, tuple, set, dict, NoneType], - encoding: str = 'utf-8', - required: bool = None + obj: Union[str, bytes, list, tuple, set, dict, NoneType], + encoding: str = "utf-8", + required: bool = None, ) -> Union[str, List[str], Dict[str, str], NoneType]: """Ensure the provided value is or contains a bytestring/bytestrings @@ -57,11 +64,17 @@ def ensure_bytestring( elif isinstance(obj, (list, tuple, set)): return [ensure_bytestring(o, encoding) for o in obj] elif isinstance(obj, dict): - return {ensure_bytestring(k, encoding): ensure_bytestring(v, encoding) for k, v in obj.items()} + return { + ensure_bytestring(k, encoding): ensure_bytestring(v, encoding) + for k, v in obj.items() + } else: - raise TypeError(f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}.") + raise TypeError( + f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}." + ) + -def concat_to_string(*strs, encoding: str = 'utf-8') -> str: +def concat_to_string(*strs, encoding: str = "utf-8") -> str: """Concatenates a list of strings or bytestrings to a single string. Args: @@ -71,9 +84,10 @@ def concat_to_string(*strs, encoding: str = 'utf-8') -> str: Returns: str: Concatenated string """ - return ''.join(st if st is not None else "" for st in ensure_string(strs, encoding)) - -def concat_to_bytestring(*strs, encoding: str = 'utf-8') -> bytes: + return "".join(st if st is not None else "" for st in ensure_string(strs, encoding)) + + +def concat_to_bytestring(*strs, encoding: str = "utf-8") -> bytes: """Concatenates a list of strings or bytestrings to a single bytestring. Args: @@ -83,7 +97,11 @@ def concat_to_bytestring(*strs, encoding: str = 'utf-8') -> bytes: Returns: bytes: Concatenated bytestring """ - return b"".join(st if st is not None else b"" for st in ensure_bytestring(strs, encoding=encoding)) + return b"".join( + st if st is not None else b"" + for st in ensure_bytestring(strs, encoding=encoding) + ) + def random_keyid() -> str: """Generates a random Key ID @@ -93,13 +111,15 @@ def random_keyid() -> str: """ return str(uuid4()) + def random_serial() -> str: - """ Generates a random serial number - + """Generates a random serial number + Returns: int: Random serial """ - return randint(0, 2**64-1) + return randint(0, 2**64 - 1) + def long_to_bytes( source_int: int, force_length: int = None, byteorder: str = "big" diff --git a/tests/test_certificates.py b/tests/test_certificates.py index c04d51f..682c0a3 100644 --- a/tests/test_certificates.py +++ b/tests/test_certificates.py @@ -5,18 +5,19 @@ # Generated by ssh-keygen and decoded by script, created by script and verified by ssh-keygen import os +import random import shutil import unittest -import random -import datetime + import faker -import src.sshkey_tools.keys as _KEY -import src.sshkey_tools.fields as _FIELD import src.sshkey_tools.cert as _CERT import src.sshkey_tools.exceptions as _EX +import src.sshkey_tools.fields as _FIELD +import src.sshkey_tools.keys as _KEY + +CERTIFICATE_TYPES = ["rsa", "dsa", "ecdsa", "ed25519"] -CERTIFICATE_TYPES = ['rsa', 'dsa', 'ecdsa', 'ed25519'] class TestCertificateFields(unittest.TestCase): def setUp(self): @@ -25,759 +26,605 @@ def setUp(self): self.dsa_key = _KEY.DsaPrivateKey.generate() self.ecdsa_key = _KEY.EcdsaPrivateKey.generate() self.ed25519_key = _KEY.Ed25519PrivateKey.generate() - - def assertRandomResponse(self, field_class, values = None, random_function = None): + + def assertRandomResponse(self, field_class, values=None, random_function=None): if values is None: values = [random_function() for _ in range(100)] fields = [] - - bytestring = b'' + + bytestring = b"" for value in values: bytestring += field_class.encode(value) - + field = field_class(value) fields.append(field) - - self.assertEqual( - bytestring, - b''.join(bytes(x) for x in fields) - ) - + + self.assertEqual(bytestring, b"".join(bytes(x) for x in fields)) + decoded = [] - while bytestring != b'': + while bytestring != b"": decode, bytestring = field_class.decode(bytestring) decoded.append(decode) - - self.assertEqual( - decoded, - values - ) - + + self.assertEqual(decoded, values) + def assertExpectedResponse(self, field_class, input, expected_output): - self.assertEqual( - field_class.encode(input), - expected_output - ) - + self.assertEqual(field_class.encode(input), expected_output) + def assertFieldContainsException(self, field, exception): for item in field.exception: if isinstance(item, exception): return True return False - + def test_boolean_field(self): self.assertRandomResponse( - _FIELD.BooleanField, - random_function=lambda : self.faker.pybool() + _FIELD.BooleanField, random_function=lambda: self.faker.pybool() ) - + def test_invalid_boolean_field(self): field = _FIELD.BooleanField("SomeInvalidData") field.validate() - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_bytestring_field(self): self.assertRandomResponse( _FIELD.BytestringField, - random_function=lambda : self.faker.pystr(1, 100).encode('utf-8') + random_function=lambda: self.faker.pystr(1, 100).encode("utf-8"), ) - + def test_invalid_bytestring_field(self): field = _FIELD.BytestringField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_string_field(self): self.assertRandomResponse( - _FIELD.StringField, - random_function=lambda : self.faker.pystr(1, 100) + _FIELD.StringField, random_function=lambda: self.faker.pystr(1, 100) ) - + def test_invalid_string_field(self): field = _FIELD.StringField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_integer32_field(self): self.assertRandomResponse( _FIELD.Integer32Field, - random_function=lambda : random.randint(2**2, 2**32) + random_function=lambda: random.randint(2**2, 2**32), ) - + def test_invalid_integer32_field(self): field = _FIELD.Integer32Field(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + field = _FIELD.Integer32Field(_FIELD.MAX_INT32 + 1) field.validate() - - self.assertFieldContainsException( - field, - _EX.IntegerOverflowException - ) - + + self.assertFieldContainsException(field, _EX.IntegerOverflowException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_integer64_field(self): self.assertRandomResponse( _FIELD.Integer64Field, - random_function=lambda : random.randint(2**32, 2**64) + random_function=lambda: random.randint(2**32, 2**64), ) - + def test_invalid_integer64_field(self): field = _FIELD.Integer64Field(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + field = _FIELD.Integer64Field(_FIELD.MAX_INT64 + 1) field.validate() - - self.assertFieldContainsException( - field, - _EX.IntegerOverflowException - ) - + + self.assertFieldContainsException(field, _EX.IntegerOverflowException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_datetime_field(self): self.assertRandomResponse( - _FIELD.DateTimeField, - random_function=lambda : self.faker.date_time() + _FIELD.DateTimeField, random_function=lambda: self.faker.date_time() ) - + def test_invalid_datetime_field(self): field = _FIELD.DateTimeField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_mp_integer_field(self): self.assertRandomResponse( _FIELD.MpIntegerField, - random_function=lambda : random.randint(2**128, 2**512) + random_function=lambda: random.randint(2**128, 2**512), ) - + def test_invalid_mp_integer_field(self): field = _FIELD.MpIntegerField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - field = _FIELD.MpIntegerField('InvalidData') + + self.assertFieldContainsException(field, _EX.InvalidDataException) + + field = _FIELD.MpIntegerField("InvalidData") field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_list_field(self): self.assertRandomResponse( _FIELD.ListField, - random_function=lambda : [self.faker.pystr(0, 100) for _ in range(10)] + random_function=lambda: [self.faker.pystr(0, 100) for _ in range(10)], ) - - + def test_invalid_list_field(self): field = _FIELD.ListField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) + + self.assertFieldContainsException(field, _EX.InvalidDataException) field = _FIELD.ListField([ValueError, ValueError, ValueError]) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_key_value_field(self): self.assertRandomResponse( _FIELD.KeyValueField, - random_function=lambda : { - self.faker.pystr(1, 10): self.faker.pystr(1, 100) - for _ in range(10) - } + random_function=lambda: { + self.faker.pystr(1, 10): self.faker.pystr(1, 100) for _ in range(10) + }, ) - + def test_invalid_key_value_field(self): field = _FIELD.KeyValueField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + field = _FIELD.KeyValueField([ValueError, ValueError, ValueError]) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_pubkey_type_field(self): allowed_values = ( - ('ssh-rsa-cert-v01@openssh.com', b'\x00\x00\x00\x1cssh-rsa-cert-v01@openssh.com'), - ('rsa-sha2-256-cert-v01@openssh.com', b'\x00\x00\x00!rsa-sha2-256-cert-v01@openssh.com'), - ('rsa-sha2-512-cert-v01@openssh.com', b'\x00\x00\x00!rsa-sha2-512-cert-v01@openssh.com'), - ('ssh-dss-cert-v01@openssh.com', b'\x00\x00\x00\x1cssh-dss-cert-v01@openssh.com'), - ('ecdsa-sha2-nistp256-cert-v01@openssh.com', b'\x00\x00\x00(ecdsa-sha2-nistp256-cert-v01@openssh.com'), - ('ecdsa-sha2-nistp384-cert-v01@openssh.com', b'\x00\x00\x00(ecdsa-sha2-nistp384-cert-v01@openssh.com'), - ('ecdsa-sha2-nistp521-cert-v01@openssh.com', b'\x00\x00\x00(ecdsa-sha2-nistp521-cert-v01@openssh.com'), - ('ssh-ed25519-cert-v01@openssh.com', b'\x00\x00\x00 ssh-ed25519-cert-v01@openssh.com') + ( + "ssh-rsa-cert-v01@openssh.com", + b"\x00\x00\x00\x1cssh-rsa-cert-v01@openssh.com", + ), + ( + "rsa-sha2-256-cert-v01@openssh.com", + b"\x00\x00\x00!rsa-sha2-256-cert-v01@openssh.com", + ), + ( + "rsa-sha2-512-cert-v01@openssh.com", + b"\x00\x00\x00!rsa-sha2-512-cert-v01@openssh.com", + ), + ( + "ssh-dss-cert-v01@openssh.com", + b"\x00\x00\x00\x1cssh-dss-cert-v01@openssh.com", + ), + ( + "ecdsa-sha2-nistp256-cert-v01@openssh.com", + b"\x00\x00\x00(ecdsa-sha2-nistp256-cert-v01@openssh.com", + ), + ( + "ecdsa-sha2-nistp384-cert-v01@openssh.com", + b"\x00\x00\x00(ecdsa-sha2-nistp384-cert-v01@openssh.com", + ), + ( + "ecdsa-sha2-nistp521-cert-v01@openssh.com", + b"\x00\x00\x00(ecdsa-sha2-nistp521-cert-v01@openssh.com", + ), + ( + "ssh-ed25519-cert-v01@openssh.com", + b"\x00\x00\x00 ssh-ed25519-cert-v01@openssh.com", + ), ) - + for value in allowed_values: - self.assertExpectedResponse( - _FIELD.PubkeyTypeField, - value[0], - value[1] - ) - + self.assertExpectedResponse(_FIELD.PubkeyTypeField, value[0], value[1]) + def test_invalid_pubkey_type_field(self): - field = _FIELD.PubkeyTypeField('HelloWorld') + field = _FIELD.PubkeyTypeField("HelloWorld") field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_nonce_field(self): randomized = _FIELD.NonceField(_FIELD.NonceField.DEFAULT()) randomized.validate() - - self.assertEqual( - randomized.exception, - (True, True, True) - ) - + self.assertEqual(randomized.exception, (True, True, True)) + specific = ( - 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz', - b'\x00\x00\x004abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz' - ) - - self.assertExpectedResponse( - _FIELD.NonceField, - specific[0], - specific[1] + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", + b"\x00\x00\x004abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", ) - + + self.assertExpectedResponse(_FIELD.NonceField, specific[0], specific[1]) + def test_pubkey_class_assignment(self): rsa_field = _FIELD.PublicKeyField.from_object(self.rsa_key.public_key) dsa_field = _FIELD.PublicKeyField.from_object(self.dsa_key.public_key) ecdsa_field = _FIELD.PublicKeyField.from_object(self.ecdsa_key.public_key) ed25519_field = _FIELD.PublicKeyField.from_object(self.ed25519_key.public_key) - + self.assertIsInstance(rsa_field, _FIELD.RsaPubkeyField) self.assertIsInstance(dsa_field, _FIELD.DsaPubkeyField) self.assertIsInstance(ecdsa_field, _FIELD.EcdsaPubkeyField) self.assertIsInstance(ed25519_field, _FIELD.Ed25519PubkeyField) - + self.assertTrue(rsa_field.validate()) self.assertTrue(dsa_field.validate()) self.assertTrue(ecdsa_field.validate()) self.assertTrue(ed25519_field.validate()) - + def assertPubkeyOutput(self, key_class, *opts): - key = getattr(self, key_class.__name__.replace('PrivateKey', '').lower() + '_key').public_key + key = getattr( + self, key_class.__name__.replace("PrivateKey", "").lower() + "_key" + ).public_key raw_start = key.raw_bytes() field = _FIELD.PublicKeyField.from_object(key) byte_data = bytes(field) decoded = field.decode(byte_data)[0] - + self.assertTrue(field.validate()) - - self.assertEqual( - raw_start, - decoded.raw_bytes() - ) - + + self.assertEqual(raw_start, decoded.raw_bytes()) + self.assertEqual( f'{key.__class__.__name__.replace("PublicKey", "")} {key.get_fingerprint()}', - str(field) + str(field), ) - - + def test_rsa_pubkey_output(self): - self.assertPubkeyOutput( - _KEY.RsaPrivateKey, - 1024 - ) - + self.assertPubkeyOutput(_KEY.RsaPrivateKey, 1024) + def test_dsa_pubkey_output(self): - self.assertPubkeyOutput( - _KEY.DsaPrivateKey - ) - + self.assertPubkeyOutput(_KEY.DsaPrivateKey) + def test_ecdsa_pubkey_output(self): - self.assertPubkeyOutput( - _KEY.EcdsaPrivateKey - ) - + self.assertPubkeyOutput(_KEY.EcdsaPrivateKey) + def test_ed25519_pubkey_output(self): - self.assertPubkeyOutput( - _KEY.Ed25519PrivateKey - ) - + self.assertPubkeyOutput(_KEY.Ed25519PrivateKey) + def test_serial_field(self): self.assertRandomResponse( - _FIELD.SerialField, - random_function=lambda : random.randint(0, 2**64 - 1) + _FIELD.SerialField, random_function=lambda: random.randint(0, 2**64 - 1) ) - + def test_invalid_serial_field(self): - field = _FIELD.SerialField('abcdefg') + field = _FIELD.SerialField("abcdefg") field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + field = _FIELD.SerialField(random.randint(2**65, 2**66)) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_certificate_type_field(self): + self.assertExpectedResponse(_FIELD.CertificateTypeField, 1, b"\x00\x00\x00\x01") + + self.assertExpectedResponse(_FIELD.CertificateTypeField, 2, b"\x00\x00\x00\x02") + self.assertExpectedResponse( - _FIELD.CertificateTypeField, - 1, - b'\x00\x00\x00\x01' - ) - - self.assertExpectedResponse( - _FIELD.CertificateTypeField, - 2, - b'\x00\x00\x00\x02' - ) - - self.assertExpectedResponse( - _FIELD.CertificateTypeField, - _FIELD.CERT_TYPE.USER, - b'\x00\x00\x00\x01' + _FIELD.CertificateTypeField, _FIELD.CERT_TYPE.USER, b"\x00\x00\x00\x01" ) - + self.assertExpectedResponse( - _FIELD.CertificateTypeField, - _FIELD.CERT_TYPE.HOST, - b'\x00\x00\x00\x02' + _FIELD.CertificateTypeField, _FIELD.CERT_TYPE.HOST, b"\x00\x00\x00\x02" ) - + def test_invalid_certificate_field(self): field = _FIELD.CertificateTypeField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + field = _FIELD.CertificateTypeField(3) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + field = _FIELD.CertificateTypeField(0) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_key_id_field(self): self.assertRandomResponse( - _FIELD.KeyIdField, - random_function=lambda : self.faker.pystr(8, 128) + _FIELD.KeyIdField, random_function=lambda: self.faker.pystr(8, 128) ) - + def test_invalid_key_id_field(self): - field = _FIELD.KeyIdField('') + field = _FIELD.KeyIdField("") field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_principals_field(self): self.assertRandomResponse( _FIELD.PrincipalsField, - random_function=lambda : [self.faker.pystr(8, 128) for _ in range(20)] + random_function=lambda: [self.faker.pystr(8, 128) for _ in range(20)], ) - + comparison = [self.faker.pystr(8, 128) for _ in range(20)] - + self.assertExpectedResponse( - _FIELD.PrincipalsField, - comparison, - _FIELD.ListField.encode(comparison) + _FIELD.PrincipalsField, comparison, _FIELD.ListField.encode(comparison) ) - + def test_invalid_principals_field(self): field = _FIELD.PrincipalsField([ValueError, ValueError, ValueError]) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_validity_start_field(self): self.assertRandomResponse( - _FIELD.ValidAfterField, - random_function=lambda : self.faker.date_time() + _FIELD.ValidAfterField, random_function=lambda: self.faker.date_time() ) - + def test_invalid_validity_start_field(self): field = _FIELD.ValidAfterField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_validity_end_field(self): self.assertRandomResponse( - _FIELD.ValidBeforeField, - random_function=lambda : self.faker.date_time() + _FIELD.ValidBeforeField, random_function=lambda: self.faker.date_time() ) - + def test_invalid_validity_end_field(self): field = _FIELD.ValidBeforeField(ValueError) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_critical_options_field(self): valid_opts_dict = { - 'force-command': 'sftp-internal', - 'source-address': '1.2.3.4/8,5.6.7.8/16', - 'verify-required': '' + "force-command": "sftp-internal", + "source-address": "1.2.3.4/8,5.6.7.8/16", + "verify-required": "", } - - valid_opts_list = [ - 'force-command', - 'source-address', - 'verify-required' - ] - + + valid_opts_list = ["force-command", "source-address", "verify-required"] + verify_dict = _FIELD.CriticalOptionsField(valid_opts_dict) verify_list = _FIELD.CriticalOptionsField(valid_opts_list) - + encoded_dict = _FIELD.CriticalOptionsField.encode(valid_opts_dict) encoded_list = _FIELD.CriticalOptionsField.encode(valid_opts_list) - + decoded_dict = _FIELD.CriticalOptionsField.decode(encoded_dict)[0] decoded_list = _FIELD.CriticalOptionsField.decode(encoded_list)[0] - + self.assertTrue(verify_dict.validate()) self.assertTrue(verify_list.validate()) - - self.assertEqual( - valid_opts_dict, - decoded_dict - ) - - self.assertEqual( - valid_opts_list, - decoded_list - ) - + + self.assertEqual(valid_opts_dict, decoded_dict) + + self.assertEqual(valid_opts_list, decoded_list) + def test_invalid_critical_options_field(self): - field = _FIELD.CriticalOptionsField([ValueError, 'permit-pty', 'unpermit']) + field = _FIELD.CriticalOptionsField([ValueError, "permit-pty", "unpermit"]) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - field = _FIELD.CriticalOptionsField('InvalidData') + + self.assertFieldContainsException(field, _EX.InvalidDataException) + + field = _FIELD.CriticalOptionsField("InvalidData") field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - field = _FIELD.CriticalOptionsField(['no-touch-required', 'InvalidOption']) + + self.assertFieldContainsException(field, _EX.InvalidDataException) + + field = _FIELD.CriticalOptionsField(["no-touch-required", "InvalidOption"]) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_extensions_field(self): valid_values = [ - 'no-touch-required', - 'permit-X11-forwarding', - 'permit-agent-forwarding', - 'permit-port-forwarding', - 'permit-pty', - 'permit-user-rc' + "no-touch-required", + "permit-X11-forwarding", + "permit-agent-forwarding", + "permit-port-forwarding", + "permit-pty", + "permit-user-rc", ] - + self.assertRandomResponse( _FIELD.ExtensionsField, - values=[valid_values[0:random.randint(1,len(valid_values))] for _ in range(10)] + values=[ + valid_values[0 : random.randint(1, len(valid_values))] + for _ in range(10) + ], ) - + def test_invalid_extensions_field(self): - field = _FIELD.CriticalOptionsField([ValueError, 'permit-pty', b'unpermit']) + field = _FIELD.CriticalOptionsField([ValueError, "permit-pty", b"unpermit"]) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - field = _FIELD.CriticalOptionsField('InvalidData') + + self.assertFieldContainsException(field, _EX.InvalidDataException) + + field = _FIELD.CriticalOptionsField("InvalidData") field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - - field = _FIELD.CriticalOptionsField(['no-touch-required', 'InvalidOption']) + + self.assertFieldContainsException(field, _EX.InvalidDataException) + + field = _FIELD.CriticalOptionsField(["no-touch-required", "InvalidOption"]) field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def test_reserved_field(self): - self.assertExpectedResponse( - _FIELD.ReservedField, - '', - b'\x00\x00\x00\x00' - ) + self.assertExpectedResponse(_FIELD.ReservedField, "", b"\x00\x00\x00\x00") def test_invalid_reserved_field(self): - field = _FIELD.ReservedField('InvalidData') + field = _FIELD.ReservedField("InvalidData") field.validate() - - self.assertFieldContainsException( - field, - _EX.InvalidDataException - ) - + + self.assertFieldContainsException(field, _EX.InvalidDataException) + with self.assertRaises(_EX.InvalidDataException): field.encode(ValueError) - + def assertCAPubkeyField(self, type): - key = getattr(self, f'{type}_key').public_key + key = getattr(self, f"{type}_key").public_key field = _FIELD.CAPublicKeyField.from_object(key) encoded = bytes(field) decoded = field.decode(encoded)[0] - - self.assertEqual( - key.raw_bytes(), - decoded.raw_bytes() - ) - - self.assertEqual( - f'{type.upper()} {key.get_fingerprint()}', - str(field) - ) - + + self.assertEqual(key.raw_bytes(), decoded.raw_bytes()) + + self.assertEqual(f"{type.upper()} {key.get_fingerprint()}", str(field)) + + class TestCertificates(unittest.TestCase): def setUp(self): - if not os.path.isdir('tests/certificates'): - os.mkdir('tests/certificates') - + if not os.path.isdir("tests/certificates"): + os.mkdir("tests/certificates") + self.faker = faker.Faker() self.rsa_ca = _KEY.RsaPrivateKey.generate(1024) self.dsa_ca = _KEY.DsaPrivateKey.generate() self.ecdsa_ca = _KEY.EcdsaPrivateKey.generate() self.ed25519_ca = _KEY.Ed25519PrivateKey.generate() - + self.rsa_user = _KEY.RsaPrivateKey.generate(1024).public_key self.dsa_user = _KEY.DsaPrivateKey.generate().public_key self.ecdsa_user = _KEY.EcdsaPrivateKey.generate().public_key self.ed25519_user = _KEY.Ed25519PrivateKey.generate().public_key - - + self.cert_fields = _CERT.CertificateFields( serial=1234567890, cert_type=_FIELD.CERT_TYPE.USER, - key_id='KeyIdentifier', - principals=[ - 'pr_a', - 'pr_b', - 'pr_c' - ], + key_id="KeyIdentifier", + principals=["pr_a", "pr_b", "pr_c"], valid_after=1968491468, valid_before=1968534668, critical_options={ - 'force-command': 'sftp-internal', - 'source-address': '1.2.3.4/8,5.6.7.8/16', - 'verify-required': '' + "force-command": "sftp-internal", + "source-address": "1.2.3.4/8,5.6.7.8/16", + "verify-required": "", }, - extensions=[ - 'permit-agent-forwarding', - 'permit-X11-forwarding' - ] + extensions=["permit-agent-forwarding", "permit-X11-forwarding"], ) - + def tearDown(self): - shutil.rmtree(f'tests/certificates') - os.mkdir(f'tests/certificates') - + shutil.rmtree("tests/certificates") + os.mkdir("tests/certificates") + def test_cert_type_assignment(self): rsa_cert = _CERT.SSHCertificate.create(self.rsa_user) dsa_cert = _CERT.SSHCertificate.create(self.dsa_user) ecdsa_cert = _CERT.SSHCertificate.create(self.ecdsa_user) ed25519_cert = _CERT.SSHCertificate.create(self.ed25519_user) - + self.assertIsInstance(rsa_cert, _CERT.RsaCertificate) self.assertIsInstance(dsa_cert, _CERT.DsaCertificate) self.assertIsInstance(ecdsa_cert, _CERT.EcdsaCertificate) self.assertIsInstance(ed25519_cert, _CERT.Ed25519Certificate) - - + def assertCertificateCreated(self, sub_type, ca_type): - sub_pubkey = getattr(self, f'{sub_type}_user') - ca_privkey = getattr(self, f'{ca_type}_ca') - + sub_pubkey = getattr(self, f"{sub_type}_user") + ca_privkey = getattr(self, f"{ca_type}_ca") + certificate = _CERT.SSHCertificate.create( - subject_pubkey=sub_pubkey, - ca_privkey=ca_privkey, - fields=self.cert_fields + subject_pubkey=sub_pubkey, ca_privkey=ca_privkey, fields=self.cert_fields ) - + self.assertTrue(certificate.can_sign()) certificate.sign() - certificate.to_file(f'tests/certificates/{sub_type}_{ca_type}-cert.pub') - + certificate.to_file(f"tests/certificates/{sub_type}_{ca_type}-cert.pub") + self.assertEqual( 0, - os.system(f'ssh-keygen -Lf tests/certificates/{sub_type}_{ca_type}-cert.pub') + os.system( + f"ssh-keygen -Lf tests/certificates/{sub_type}_{ca_type}-cert.pub" + ), ) - - - reloaded_cert = _CERT.SSHCertificate.from_file(f'tests/certificates/{sub_type}_{ca_type}-cert.pub') - - self.assertEqual( - certificate.get_signable_data(), - reloaded_cert.get_signable_data() + + reloaded_cert = _CERT.SSHCertificate.from_file( + f"tests/certificates/{sub_type}_{ca_type}-cert.pub" ) + self.assertEqual(certificate.get_signable(), reloaded_cert.get_signable()) + def test_certificate_creation(self): for ca_type in CERTIFICATE_TYPES: for user_type in CERTIFICATE_TYPES: - print("Testing certificate creation for {} CA and {} user".format(ca_type, user_type)) - self.assertCertificateCreated( - user_type, - ca_type + print( + "Testing certificate creation for {} CA and {} user".format( + ca_type, user_type + ) ) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + self.assertCertificateCreated(user_type, ca_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_keypairs.py b/tests/test_keypairs.py index 14c1880..31f4267 100644 --- a/tests/test_keypairs.py +++ b/tests/test_keypairs.py @@ -5,32 +5,27 @@ import os import shutil import unittest + +from cryptography.hazmat.primitives.asymmetric import dsa as _DSA +from cryptography.hazmat.primitives.asymmetric import ec as _EC +from cryptography.hazmat.primitives.asymmetric import ed25519 as _ED25519 +from cryptography.hazmat.primitives.asymmetric import rsa as _RSA + import src.sshkey_tools.exceptions as _EX -from base64 import b64encode from src.sshkey_tools.keys import ( - PrivkeyClasses, - PrivateKey, - RsaPrivateKey, DsaPrivateKey, + DsaPublicKey, + EcdsaCurves, EcdsaPrivateKey, + EcdsaPublicKey, Ed25519PrivateKey, - PubkeyClasses, + Ed25519PublicKey, + PrivateKey, PublicKey, + RsaPrivateKey, RsaPublicKey, - DsaPublicKey, - EcdsaPublicKey, - Ed25519PublicKey, - EcdsaCurves ) -from cryptography.hazmat.primitives.asymmetric import ( - rsa as _RSA, - dsa as _DSA, - ec as _EC, - ed25519 as _ED25519 -) - -from cryptography.exceptions import InvalidSignature class KeypairMethods(unittest.TestCase): def generateClasses(self): @@ -42,31 +37,33 @@ def generateClasses(self): def generateFiles(self, folder): self.folder = folder try: - os.mkdir(f'tests/{folder}') + os.mkdir(f"tests/{folder}") except FileExistsError: - shutil.rmtree(f'tests/{folder}') - os.mkdir(f'tests/{folder}') - - os.system(f'ssh-keygen -t rsa -b 2048 -f tests/{folder}/rsa_key_sshkeygen -N "password" > /dev/null 2>&1') - os.system(f'ssh-keygen -t dsa -b 1024 -f tests/{folder}/dsa_key_sshkeygen -N "" > /dev/null 2>&1') - os.system(f'ssh-keygen -t ecdsa -b 256 -f tests/{folder}/ecdsa_key_sshkeygen -N "" > /dev/null 2>&1') - os.system(f'ssh-keygen -t ed25519 -f tests/{folder}/ed25519_key_sshkeygen -N "" > /dev/null 2>&1') + shutil.rmtree(f"tests/{folder}") + os.mkdir(f"tests/{folder}") + os.system( + f'ssh-keygen -t rsa -b 2048 -f tests/{folder}/rsa_key_sshkeygen -N "password" > /dev/null 2>&1' + ) + os.system( + f'ssh-keygen -t dsa -b 1024 -f tests/{folder}/dsa_key_sshkeygen -N "" > /dev/null 2>&1' + ) + os.system( + f'ssh-keygen -t ecdsa -b 256 -f tests/{folder}/ecdsa_key_sshkeygen -N "" > /dev/null 2>&1' + ) + os.system( + f'ssh-keygen -t ed25519 -f tests/{folder}/ed25519_key_sshkeygen -N "" > /dev/null 2>&1' + ) def setUp(self): self.generateClasses() - self.generateFiles('KeypairMethods') + self.generateFiles("KeypairMethods") def tearDown(self): - shutil.rmtree(f'tests/{self.folder}') + shutil.rmtree(f"tests/{self.folder}") def assertEqualPrivateKeys( - self, - priv_class, - pub_class, - a, - b, - privkey_attr = ['private_numbers'] + self, priv_class, pub_class, a, b, privkey_attr=["private_numbers"] ): self.assertIsInstance(a, priv_class) self.assertIsInstance(b, priv_class) @@ -77,103 +74,79 @@ def assertEqualPrivateKeys( except AssertionError: print("Hold") - self.assertEqualPublicKeys( - pub_class, - a.public_key, - b.public_key - ) + self.assertEqualPublicKeys(pub_class, a.public_key, b.public_key) - def assertEqualPublicKeys( - self, - keyclass, - a, - b - ): + def assertEqualPublicKeys(self, keyclass, a, b): self.assertIsInstance(a, keyclass) self.assertIsInstance(b, keyclass) self.assertEqual(a.raw_bytes(), b.raw_bytes()) - def assertEqualKeyFingerprint( - self, - file_a, - file_b - ): - self.assertEqual(0, + def assertEqualKeyFingerprint(self, file_a, file_b): + self.assertEqual( + 0, os.system( - f'''bash -c " + f"""bash -c " diff \ <( ssh-keygen -lf {file_a}) \ <( ssh-keygen -lf {file_b}) \ " - ''' - ) + """ + ), ) + class TestKeypairMethods(KeypairMethods): def test_fail_assertions(self): with self.assertRaises(AssertionError): self.assertEqualPrivateKeys( RsaPrivateKey, RsaPublicKey, - RsaPrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password'), - DsaPrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') + RsaPrivateKey.from_file( + f"tests/{self.folder}/rsa_key_sshkeygen", "password" + ), + DsaPrivateKey.from_file(f"tests/{self.folder}/dsa_key_sshkeygen"), ) with self.assertRaises(AssertionError): self.assertEqualPublicKeys( RsaPublicKey, - RsaPublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub'), - DsaPublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') + RsaPublicKey.from_file(f"tests/{self.folder}/rsa_key_sshkeygen.pub"), + DsaPublicKey.from_file(f"tests/{self.folder}/dsa_key_sshkeygen.pub"), ) with self.assertRaises(AssertionError): self.assertEqualKeyFingerprint( - f'tests/{self.folder}/rsa_key_sshkeygen', - f'tests/{self.folder}/dsa_key_sshkeygen' + f"tests/{self.folder}/rsa_key_sshkeygen", + f"tests/{self.folder}/dsa_key_sshkeygen", ) def test_successful_assertions(self): - self.assertTrue( - os.path.isfile( - f'tests/{self.folder}/rsa_key_sshkeygen' - ) - ) + self.assertTrue(os.path.isfile(f"tests/{self.folder}/rsa_key_sshkeygen")) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/rsa_key_sshkeygen', - f'tests/{self.folder}/rsa_key_sshkeygen.pub' + f"tests/{self.folder}/rsa_key_sshkeygen", + f"tests/{self.folder}/rsa_key_sshkeygen.pub", ) - self.assertTrue( - os.path.isfile( - f'tests/{self.folder}/dsa_key_sshkeygen' - ) - ) + self.assertTrue(os.path.isfile(f"tests/{self.folder}/dsa_key_sshkeygen")) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/dsa_key_sshkeygen', - f'tests/{self.folder}/dsa_key_sshkeygen.pub' + f"tests/{self.folder}/dsa_key_sshkeygen", + f"tests/{self.folder}/dsa_key_sshkeygen.pub", ) - self.assertTrue( - os.path.isfile( - f'tests/{self.folder}/ecdsa_key_sshkeygen' - ) - ) + self.assertTrue(os.path.isfile(f"tests/{self.folder}/ecdsa_key_sshkeygen")) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ecdsa_key_sshkeygen', - f'tests/{self.folder}/ecdsa_key_sshkeygen.pub' + f"tests/{self.folder}/ecdsa_key_sshkeygen", + f"tests/{self.folder}/ecdsa_key_sshkeygen.pub", ) - self.assertTrue( - os.path.isfile( - f'tests/{self.folder}/ed25519_key_sshkeygen' - ) - ) + self.assertTrue(os.path.isfile(f"tests/{self.folder}/ed25519_key_sshkeygen")) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ed25519_key_sshkeygen', - f'tests/{self.folder}/ed25519_key_sshkeygen.pub' + f"tests/{self.folder}/ed25519_key_sshkeygen", + f"tests/{self.folder}/ed25519_key_sshkeygen.pub", ) + class TestKeyGeneration(KeypairMethods): def setUp(self): pass @@ -182,13 +155,7 @@ def tearDown(self): pass def test_rsa(self): - key_bits = [ - 512, - 1024, - 2048, - 4096, - 8192 - ] + key_bits = [512, 1024, 2048, 4096, 8192] for bits in key_bits: key = RsaPrivateKey.generate(bits) @@ -223,16 +190,11 @@ def test_dsa(self): assert isinstance(key.public_key.parameters, _DSA.DSAParameterNumbers) def test_ecdsa(self): - curves = [ - EcdsaCurves.P256, - EcdsaCurves.P384, - EcdsaCurves.P521 - ] + curves = [EcdsaCurves.P256, EcdsaCurves.P384, EcdsaCurves.P521] for curve in curves: key = EcdsaPrivateKey.generate(curve) - assert isinstance(key, EcdsaPrivateKey) assert isinstance(key, PrivateKey) assert isinstance(key.key, _EC.EllipticCurvePrivateKey) @@ -241,12 +203,13 @@ def test_ecdsa(self): assert isinstance(key.public_key, EcdsaPublicKey) assert isinstance(key.public_key, PublicKey) assert isinstance(key.public_key.key, _EC.EllipticCurvePublicKey) - assert isinstance(key.public_key.public_numbers, _EC.EllipticCurvePublicNumbers) - + assert isinstance( + key.public_key.public_numbers, _EC.EllipticCurvePublicNumbers + ) def test_ecdsa_not_a_curve(self): with self.assertRaises(AttributeError): - EcdsaPrivateKey.generate('p256') + EcdsaPrivateKey.generate("p256") def test_ed25519(self): key = Ed25519PrivateKey.generate() @@ -259,329 +222,259 @@ def test_ed25519(self): assert isinstance(key.public_key, PublicKey) assert isinstance(key.public_key.key, _ED25519.Ed25519PublicKey) + class TestToFromFiles(KeypairMethods): def setUp(self): self.generateClasses() - self.generateFiles('TestToFromFiles') - + self.generateFiles("TestToFromFiles") + def test_encoding(self): - with open(f'tests/{self.folder}/rsa_key_sshkeygen', 'r', encoding='utf-8') as file: - from_string = PrivateKey.from_string(file.read(), 'password', 'utf-8') - - with open(f'tests/{self.folder}/rsa_key_sshkeygen.pub', 'r', encoding='utf-8') as file: - from_string_pub = PublicKey.from_string(file.read(), 'utf-8') - - from_file = PrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') - from_file_pub = PublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') - - self.assertEqualPrivateKeys( - RsaPrivateKey, - RsaPublicKey, - from_string, - from_file + with open( + f"tests/{self.folder}/rsa_key_sshkeygen", "r", encoding="utf-8" + ) as file: + from_string = PrivateKey.from_string(file.read(), "password", "utf-8") + + with open( + f"tests/{self.folder}/rsa_key_sshkeygen.pub", "r", encoding="utf-8" + ) as file: + from_string_pub = PublicKey.from_string(file.read(), "utf-8") + + from_file = PrivateKey.from_file( + f"tests/{self.folder}/rsa_key_sshkeygen", "password" ) - - self.assertEqualPublicKeys( - RsaPublicKey, - from_string_pub, - from_file_pub + from_file_pub = PublicKey.from_file( + f"tests/{self.folder}/rsa_key_sshkeygen.pub" ) - - def test_rsa_files(self): - parent = PrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') - child = RsaPrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') + self.assertEqualPrivateKeys(RsaPrivateKey, RsaPublicKey, from_string, from_file) - parent_pub = PublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') - child_pub = RsaPublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') + self.assertEqualPublicKeys(RsaPublicKey, from_string_pub, from_file_pub) - parent.to_file(f'tests/{self.folder}/rsa_key_saved_parent', 'password') - child.to_file(f'tests/{self.folder}/rsa_key_saved_child') + def test_rsa_files(self): + parent = PrivateKey.from_file( + f"tests/{self.folder}/rsa_key_sshkeygen", "password" + ) + child = RsaPrivateKey.from_file( + f"tests/{self.folder}/rsa_key_sshkeygen", "password" + ) - parent_pub.to_file(f'tests/{self.folder}/rsa_key_saved_parent.pub') - child_pub.to_file(f'tests/{self.folder}/rsa_key_saved_child.pub') + parent_pub = PublicKey.from_file(f"tests/{self.folder}/rsa_key_sshkeygen.pub") + child_pub = RsaPublicKey.from_file(f"tests/{self.folder}/rsa_key_sshkeygen.pub") + parent.to_file(f"tests/{self.folder}/rsa_key_saved_parent", "password") + child.to_file(f"tests/{self.folder}/rsa_key_saved_child") - self.assertEqualPrivateKeys( - RsaPrivateKey, - RsaPublicKey, - parent, - child - ) + parent_pub.to_file(f"tests/{self.folder}/rsa_key_saved_parent.pub") + child_pub.to_file(f"tests/{self.folder}/rsa_key_saved_child.pub") - self.assertEqualPublicKeys( - RsaPublicKey, - parent_pub, - child_pub - ) + self.assertEqualPrivateKeys(RsaPrivateKey, RsaPublicKey, parent, child) - self.assertEqualPublicKeys( - RsaPublicKey, - parent.public_key, - child_pub - ) + self.assertEqualPublicKeys(RsaPublicKey, parent_pub, child_pub) + + self.assertEqualPublicKeys(RsaPublicKey, parent.public_key, child_pub) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/rsa_key_sshkeygen', - f'tests/{self.folder}/rsa_key_saved_parent' + f"tests/{self.folder}/rsa_key_sshkeygen", + f"tests/{self.folder}/rsa_key_saved_parent", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/rsa_key_sshkeygen.pub', - f'tests/{self.folder}/rsa_key_saved_parent.pub' + f"tests/{self.folder}/rsa_key_sshkeygen.pub", + f"tests/{self.folder}/rsa_key_saved_parent.pub", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/rsa_key_saved_parent', - f'tests/{self.folder}/rsa_key_saved_child' + f"tests/{self.folder}/rsa_key_saved_parent", + f"tests/{self.folder}/rsa_key_saved_child", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/rsa_key_saved_parent.pub', - f'tests/{self.folder}/rsa_key_sshkeygen.pub' + f"tests/{self.folder}/rsa_key_saved_parent.pub", + f"tests/{self.folder}/rsa_key_sshkeygen.pub", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/rsa_key_saved_parent.pub', - f'tests/{self.folder}/rsa_key_sshkeygen.pub' + f"tests/{self.folder}/rsa_key_saved_parent.pub", + f"tests/{self.folder}/rsa_key_sshkeygen.pub", ) def test_dsa_files(self): - parent = PrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') - child = DsaPrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') + parent = PrivateKey.from_file(f"tests/{self.folder}/dsa_key_sshkeygen") + child = DsaPrivateKey.from_file(f"tests/{self.folder}/dsa_key_sshkeygen") - parent_pub = PublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') - child_pub = DsaPublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') + parent_pub = PublicKey.from_file(f"tests/{self.folder}/dsa_key_sshkeygen.pub") + child_pub = DsaPublicKey.from_file(f"tests/{self.folder}/dsa_key_sshkeygen.pub") - parent.to_file(f'tests/{self.folder}/dsa_key_saved_parent') - child.to_file(f'tests/{self.folder}/dsa_key_saved_child') + parent.to_file(f"tests/{self.folder}/dsa_key_saved_parent") + child.to_file(f"tests/{self.folder}/dsa_key_saved_child") - parent_pub.to_file(f'tests/{self.folder}/dsa_key_saved_parent.pub') - child_pub.to_file(f'tests/{self.folder}/dsa_key_saved_child.pub') + parent_pub.to_file(f"tests/{self.folder}/dsa_key_saved_parent.pub") + child_pub.to_file(f"tests/{self.folder}/dsa_key_saved_child.pub") - self.assertEqualPrivateKeys( - DsaPrivateKey, - DsaPublicKey, - parent, - child - ) + self.assertEqualPrivateKeys(DsaPrivateKey, DsaPublicKey, parent, child) - self.assertEqualPublicKeys( - DsaPublicKey, - parent_pub, - child_pub - ) + self.assertEqualPublicKeys(DsaPublicKey, parent_pub, child_pub) - self.assertEqualPublicKeys( - DsaPublicKey, - parent.public_key, - child_pub - ) + self.assertEqualPublicKeys(DsaPublicKey, parent.public_key, child_pub) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/dsa_key_sshkeygen', - f'tests/{self.folder}/dsa_key_saved_parent' + f"tests/{self.folder}/dsa_key_sshkeygen", + f"tests/{self.folder}/dsa_key_saved_parent", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/dsa_key_sshkeygen.pub', - f'tests/{self.folder}/dsa_key_saved_parent.pub' + f"tests/{self.folder}/dsa_key_sshkeygen.pub", + f"tests/{self.folder}/dsa_key_saved_parent.pub", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/dsa_key_saved_parent', - f'tests/{self.folder}/dsa_key_saved_child' + f"tests/{self.folder}/dsa_key_saved_parent", + f"tests/{self.folder}/dsa_key_saved_child", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/dsa_key_saved_parent.pub', - f'tests/{self.folder}/dsa_key_sshkeygen.pub' + f"tests/{self.folder}/dsa_key_saved_parent.pub", + f"tests/{self.folder}/dsa_key_sshkeygen.pub", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/dsa_key_saved_parent.pub', - f'tests/{self.folder}/dsa_key_sshkeygen.pub' + f"tests/{self.folder}/dsa_key_saved_parent.pub", + f"tests/{self.folder}/dsa_key_sshkeygen.pub", ) def test_ecdsa_files(self): - parent = PrivateKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen') - child = EcdsaPrivateKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen') + parent = PrivateKey.from_file(f"tests/{self.folder}/ecdsa_key_sshkeygen") + child = EcdsaPrivateKey.from_file(f"tests/{self.folder}/ecdsa_key_sshkeygen") - parent_pub = PublicKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen.pub') - child_pub = EcdsaPublicKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen.pub') + parent_pub = PublicKey.from_file(f"tests/{self.folder}/ecdsa_key_sshkeygen.pub") + child_pub = EcdsaPublicKey.from_file( + f"tests/{self.folder}/ecdsa_key_sshkeygen.pub" + ) - parent.to_file(f'tests/{self.folder}/ecdsa_key_saved_parent') - child.to_file(f'tests/{self.folder}/ecdsa_key_saved_child') + parent.to_file(f"tests/{self.folder}/ecdsa_key_saved_parent") + child.to_file(f"tests/{self.folder}/ecdsa_key_saved_child") - parent_pub.to_file(f'tests/{self.folder}/ecdsa_key_saved_parent.pub') - child_pub.to_file(f'tests/{self.folder}/ecdsa_key_saved_child.pub') + parent_pub.to_file(f"tests/{self.folder}/ecdsa_key_saved_parent.pub") + child_pub.to_file(f"tests/{self.folder}/ecdsa_key_saved_child.pub") - self.assertEqualPrivateKeys( - EcdsaPrivateKey, - EcdsaPublicKey, - parent, - child - ) + self.assertEqualPrivateKeys(EcdsaPrivateKey, EcdsaPublicKey, parent, child) - self.assertEqualPublicKeys( - EcdsaPublicKey, - parent_pub, - child_pub - ) + self.assertEqualPublicKeys(EcdsaPublicKey, parent_pub, child_pub) - self.assertEqualPublicKeys( - EcdsaPublicKey, - parent.public_key, - child_pub - ) + self.assertEqualPublicKeys(EcdsaPublicKey, parent.public_key, child_pub) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ecdsa_key_sshkeygen', - f'tests/{self.folder}/ecdsa_key_saved_parent' + f"tests/{self.folder}/ecdsa_key_sshkeygen", + f"tests/{self.folder}/ecdsa_key_saved_parent", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ecdsa_key_sshkeygen.pub', - f'tests/{self.folder}/ecdsa_key_saved_parent.pub' + f"tests/{self.folder}/ecdsa_key_sshkeygen.pub", + f"tests/{self.folder}/ecdsa_key_saved_parent.pub", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ecdsa_key_saved_parent', - f'tests/{self.folder}/ecdsa_key_saved_child' + f"tests/{self.folder}/ecdsa_key_saved_parent", + f"tests/{self.folder}/ecdsa_key_saved_child", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ecdsa_key_saved_parent.pub', - f'tests/{self.folder}/ecdsa_key_sshkeygen.pub' + f"tests/{self.folder}/ecdsa_key_saved_parent.pub", + f"tests/{self.folder}/ecdsa_key_sshkeygen.pub", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ecdsa_key_saved_parent.pub', - f'tests/{self.folder}/ecdsa_key_sshkeygen.pub' + f"tests/{self.folder}/ecdsa_key_saved_parent.pub", + f"tests/{self.folder}/ecdsa_key_sshkeygen.pub", ) def test_ed25519_files(self): - parent = PrivateKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen') - child = Ed25519PrivateKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen') + parent = PrivateKey.from_file(f"tests/{self.folder}/ed25519_key_sshkeygen") + child = Ed25519PrivateKey.from_file( + f"tests/{self.folder}/ed25519_key_sshkeygen" + ) - parent_pub = PublicKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen.pub') - child_pub = Ed25519PublicKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen.pub') + parent_pub = PublicKey.from_file( + f"tests/{self.folder}/ed25519_key_sshkeygen.pub" + ) + child_pub = Ed25519PublicKey.from_file( + f"tests/{self.folder}/ed25519_key_sshkeygen.pub" + ) - parent.to_file(f'tests/{self.folder}/ed25519_key_saved_parent') - child.to_file(f'tests/{self.folder}/ed25519_key_saved_child') + parent.to_file(f"tests/{self.folder}/ed25519_key_saved_parent") + child.to_file(f"tests/{self.folder}/ed25519_key_saved_child") - parent_pub.to_file(f'tests/{self.folder}/ed25519_key_saved_parent.pub') - child_pub.to_file(f'tests/{self.folder}/ed25519_key_saved_child.pub') + parent_pub.to_file(f"tests/{self.folder}/ed25519_key_saved_parent.pub") + child_pub.to_file(f"tests/{self.folder}/ed25519_key_saved_child.pub") - self.assertEqualPrivateKeys( - Ed25519PrivateKey, - Ed25519PublicKey, - parent, - child - ) + self.assertEqualPrivateKeys(Ed25519PrivateKey, Ed25519PublicKey, parent, child) - self.assertEqualPublicKeys( - Ed25519PublicKey, - parent_pub, - child_pub - ) + self.assertEqualPublicKeys(Ed25519PublicKey, parent_pub, child_pub) - self.assertEqualPublicKeys( - Ed25519PublicKey, - parent.public_key, - child_pub - ) + self.assertEqualPublicKeys(Ed25519PublicKey, parent.public_key, child_pub) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ed25519_key_sshkeygen', - f'tests/{self.folder}/ed25519_key_saved_parent' + f"tests/{self.folder}/ed25519_key_sshkeygen", + f"tests/{self.folder}/ed25519_key_saved_parent", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ed25519_key_sshkeygen.pub', - f'tests/{self.folder}/ed25519_key_saved_parent.pub' + f"tests/{self.folder}/ed25519_key_sshkeygen.pub", + f"tests/{self.folder}/ed25519_key_saved_parent.pub", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ed25519_key_saved_parent', - f'tests/{self.folder}/ed25519_key_saved_child' + f"tests/{self.folder}/ed25519_key_saved_parent", + f"tests/{self.folder}/ed25519_key_saved_child", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ed25519_key_saved_parent.pub', - f'tests/{self.folder}/ed25519_key_sshkeygen.pub' + f"tests/{self.folder}/ed25519_key_saved_parent.pub", + f"tests/{self.folder}/ed25519_key_sshkeygen.pub", ) self.assertEqualKeyFingerprint( - f'tests/{self.folder}/ed25519_key_saved_parent.pub', - f'tests/{self.folder}/ed25519_key_sshkeygen.pub' + f"tests/{self.folder}/ed25519_key_saved_parent.pub", + f"tests/{self.folder}/ed25519_key_sshkeygen.pub", ) + class TestFromClass(KeypairMethods): def setUp(self): - self.rsa_key = _RSA.generate_private_key( - public_exponent=65537, - key_size=2048 - ) - self.dsa_key = _DSA.generate_private_key( - key_size=1024 - ) - self.ecdsa_key = _EC.generate_private_key( - curve=_EC.SECP384R1() - ) + self.rsa_key = _RSA.generate_private_key(public_exponent=65537, key_size=2048) + self.dsa_key = _DSA.generate_private_key(key_size=1024) + self.ecdsa_key = _EC.generate_private_key(curve=_EC.SECP384R1()) self.ed25519_key = _ED25519.Ed25519PrivateKey.generate() def tearDown(self): pass - + def test_invalid_key_exception(self): with self.assertRaises(_EX.InvalidKeyException): PublicKey.from_class( - key_class=self.rsa_key, - key_type='invalid-key-type', - comment='Comment' + key_class=self.rsa_key, key_type="invalid-key-type", comment="Comment" ) - def test_rsa_from_class(self): parent = PrivateKey.from_class(self.rsa_key) child = RsaPrivateKey.from_class(self.rsa_key) - self.assertEqualPrivateKeys( - RsaPrivateKey, - RsaPublicKey, - parent, - child - ) + self.assertEqualPrivateKeys(RsaPrivateKey, RsaPublicKey, parent, child) def test_dsa_from_class(self): parent = PrivateKey.from_class(self.dsa_key) child = DsaPrivateKey.from_class(self.dsa_key) - self.assertEqualPrivateKeys( - DsaPrivateKey, - DsaPublicKey, - parent, - child - ) + self.assertEqualPrivateKeys(DsaPrivateKey, DsaPublicKey, parent, child) def test_ecdsa_from_class(self): parent = PrivateKey.from_class(self.ecdsa_key) child = EcdsaPrivateKey.from_class(self.ecdsa_key) - self.assertEqualPrivateKeys( - EcdsaPrivateKey, - EcdsaPublicKey, - parent, - child - ) + self.assertEqualPrivateKeys(EcdsaPrivateKey, EcdsaPublicKey, parent, child) def test_ed25519_from_class(self): parent = PrivateKey.from_class(self.ed25519_key) child = Ed25519PrivateKey.from_class(self.ed25519_key) - self.assertEqualPrivateKeys( - Ed25519PrivateKey, - Ed25519PublicKey, - parent, - child - ) + self.assertEqualPrivateKeys(Ed25519PrivateKey, Ed25519PublicKey, parent, child) + class TestFromComponents(KeypairMethods): def setUp(self): @@ -594,64 +487,54 @@ def test_rsa_from_numbers(self): from_numbers = RsaPrivateKey.from_numbers( n=self.rsa_key.public_key.public_numbers.n, e=self.rsa_key.public_key.public_numbers.e, - d=self.rsa_key.private_numbers.d + d=self.rsa_key.private_numbers.d, ) - + from_numbers_pub = RsaPublicKey.from_numbers( n=self.rsa_key.public_key.public_numbers.n, - e=self.rsa_key.public_key.public_numbers.e + e=self.rsa_key.public_key.public_numbers.e, ) self.assertEqualPublicKeys( - RsaPublicKey, - from_numbers_pub, - from_numbers.public_key + RsaPublicKey, from_numbers_pub, from_numbers.public_key ) self.assertIsInstance(from_numbers, RsaPrivateKey) - + self.assertEqual( self.rsa_key.public_key.public_numbers.n, - from_numbers.public_key.public_numbers.n + from_numbers.public_key.public_numbers.n, ) - + self.assertEqual( self.rsa_key.public_key.public_numbers.e, - from_numbers.public_key.public_numbers.e - ) - - self.assertEqual( - self.rsa_key.private_numbers.d, - from_numbers.private_numbers.d + from_numbers.public_key.public_numbers.e, ) + self.assertEqual(self.rsa_key.private_numbers.d, from_numbers.private_numbers.d) + def test_dsa_from_numbers(self): from_numbers = DsaPrivateKey.from_numbers( p=self.dsa_key.public_key.parameters.p, q=self.dsa_key.public_key.parameters.q, g=self.dsa_key.public_key.parameters.g, y=self.dsa_key.public_key.public_numbers.y, - x=self.dsa_key.private_numbers.x + x=self.dsa_key.private_numbers.x, ) - + from_numbers_pub = DsaPublicKey.from_numbers( p=self.dsa_key.public_key.parameters.p, q=self.dsa_key.public_key.parameters.q, g=self.dsa_key.public_key.parameters.g, - y=self.dsa_key.public_key.public_numbers.y + y=self.dsa_key.public_key.public_numbers.y, ) self.assertEqualPrivateKeys( - DsaPrivateKey, - DsaPublicKey, - self.dsa_key, - from_numbers + DsaPrivateKey, DsaPublicKey, self.dsa_key, from_numbers ) - + self.assertEqualPublicKeys( - DsaPublicKey, - from_numbers_pub, - self.dsa_key.public_key + DsaPublicKey, from_numbers_pub, self.dsa_key.public_key ) def test_ecdsa_from_numbers(self): @@ -659,214 +542,177 @@ def test_ecdsa_from_numbers(self): curve=self.ecdsa_key.public_key.key.curve, x=self.ecdsa_key.public_key.public_numbers.x, y=self.ecdsa_key.public_key.public_numbers.y, - private_value=self.ecdsa_key.private_numbers.private_value + private_value=self.ecdsa_key.private_numbers.private_value, ) - + from_numbers_pub = EcdsaPublicKey.from_numbers( curve=self.ecdsa_key.public_key.key.curve, x=self.ecdsa_key.public_key.public_numbers.x, - y=self.ecdsa_key.public_key.public_numbers.y + y=self.ecdsa_key.public_key.public_numbers.y, ) self.assertEqualPrivateKeys( - EcdsaPrivateKey, - EcdsaPublicKey, - self.ecdsa_key, - from_numbers + EcdsaPrivateKey, EcdsaPublicKey, self.ecdsa_key, from_numbers ) - + self.assertEqualPublicKeys( - EcdsaPublicKey, - from_numbers_pub, - self.ecdsa_key.public_key + EcdsaPublicKey, from_numbers_pub, self.ecdsa_key.public_key ) from_numbers = EcdsaPrivateKey.from_numbers( curve=self.ecdsa_key.public_key.key.curve.name, x=self.ecdsa_key.public_key.public_numbers.x, y=self.ecdsa_key.public_key.public_numbers.y, - private_value=self.ecdsa_key.private_numbers.private_value + private_value=self.ecdsa_key.private_numbers.private_value, ) self.assertEqualPrivateKeys( - EcdsaPrivateKey, - EcdsaPublicKey, - self.ecdsa_key, - from_numbers + EcdsaPrivateKey, EcdsaPublicKey, self.ecdsa_key, from_numbers ) - + def test_ed25519_from_raw_bytes(self): - from_raw = Ed25519PrivateKey.from_raw_bytes( - self.ed25519_key.raw_bytes() - ) + from_raw = Ed25519PrivateKey.from_raw_bytes(self.ed25519_key.raw_bytes()) from_raw_pub = Ed25519PublicKey.from_raw_bytes( self.ed25519_key.public_key.raw_bytes() ) - + self.assertEqualPrivateKeys( - Ed25519PrivateKey, - Ed25519PublicKey, - self.ed25519_key, - from_raw, - [] + Ed25519PrivateKey, Ed25519PublicKey, self.ed25519_key, from_raw, [] ) - + self.assertEqualPublicKeys( - Ed25519PublicKey, - self.ed25519_key.public_key, - from_raw_pub + Ed25519PublicKey, self.ed25519_key.public_key, from_raw_pub ) - + class TestFingerprint(KeypairMethods): - def setUp(self): - self.generateFiles('TestFingerprint') + self.generateFiles("TestFingerprint") def test_rsa_fingerprint(self): key = RsaPrivateKey.from_file( - f'tests/{self.folder}/rsa_key_sshkeygen', - 'password' + f"tests/{self.folder}/rsa_key_sshkeygen", "password" ) - - sshkey_fingerprint = os.popen(f'ssh-keygen -lf tests/{self.folder}/rsa_key_sshkeygen').read().split(' ')[1] - + + sshkey_fingerprint = ( + os.popen(f"ssh-keygen -lf tests/{self.folder}/rsa_key_sshkeygen") + .read() + .split(" ")[1] + ) + self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) - + def test_dsa_fingerprint(self): key = DsaPrivateKey.from_file( - f'tests/{self.folder}/dsa_key_sshkeygen', + f"tests/{self.folder}/dsa_key_sshkeygen", + ) + + sshkey_fingerprint = ( + os.popen(f"ssh-keygen -lf tests/{self.folder}/dsa_key_sshkeygen") + .read() + .split(" ")[1] ) - - sshkey_fingerprint = os.popen(f'ssh-keygen -lf tests/{self.folder}/dsa_key_sshkeygen').read().split(' ')[1] - + self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) def test_ecdsa_fingerprint(self): key = EcdsaPrivateKey.from_file( - f'tests/{self.folder}/ecdsa_key_sshkeygen', + f"tests/{self.folder}/ecdsa_key_sshkeygen", ) - - sshkey_fingerprint = os.popen(f'ssh-keygen -lf tests/{self.folder}/ecdsa_key_sshkeygen').read().split(' ')[1] - + + sshkey_fingerprint = ( + os.popen(f"ssh-keygen -lf tests/{self.folder}/ecdsa_key_sshkeygen") + .read() + .split(" ")[1] + ) + self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) - + def test_ed25519_fingerprint(self): key = Ed25519PrivateKey.from_file( - f'tests/{self.folder}/ed25519_key_sshkeygen', + f"tests/{self.folder}/ed25519_key_sshkeygen", + ) + + sshkey_fingerprint = ( + os.popen(f"ssh-keygen -lf tests/{self.folder}/ed25519_key_sshkeygen") + .read() + .split(" ")[1] ) - - sshkey_fingerprint = os.popen(f'ssh-keygen -lf tests/{self.folder}/ed25519_key_sshkeygen').read().split(' ')[1] - + self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) + class TestSignatures(KeypairMethods): def setUp(self): self.generateClasses() - + def tearDown(self): pass - + def test_rsa_signature(self): - data = b"\x00"+os.urandom(32)+b"\x00" + data = b"\x00" + os.urandom(32) + b"\x00" signature = self.rsa_key.sign(data) - - - self.assertIsNone( - self.rsa_key.public_key.verify( - data, - signature - ) - ) - + + self.assertIsNone(self.rsa_key.public_key.verify(data, signature)) + with self.assertRaises(_EX.InvalidSignatureException): - self.rsa_key.public_key.verify( - data, - signature+b'\x00' - ) - + self.rsa_key.public_key.verify(data, signature + b"\x00") + def test_dsa_signature(self): - data = b"\x00"+os.urandom(32)+b"\x00" + data = b"\x00" + os.urandom(32) + b"\x00" signature = self.dsa_key.sign(data) - - - self.assertIsNone( - self.dsa_key.public_key.verify( - data, - signature - ) - ) - + + self.assertIsNone(self.dsa_key.public_key.verify(data, signature)) + with self.assertRaises(_EX.InvalidSignatureException): - self.dsa_key.public_key.verify( - data, - signature+b'\x00' - ) - + self.dsa_key.public_key.verify(data, signature + b"\x00") + def test_ecdsa_signature(self): - data = b"\x00"+os.urandom(32)+b"\x00" + data = b"\x00" + os.urandom(32) + b"\x00" signature = self.ecdsa_key.sign(data) - - - self.assertIsNone( - self.ecdsa_key.public_key.verify( - data, - signature - ) - ) - + + self.assertIsNone(self.ecdsa_key.public_key.verify(data, signature)) + with self.assertRaises(_EX.InvalidSignatureException): - self.ecdsa_key.public_key.verify( - data, - signature+b'\x00' - ) - + self.ecdsa_key.public_key.verify(data, signature + b"\x00") + def test_ed25519_signature(self): - data = b"\x00"+os.urandom(32)+b"\x00" + data = b"\x00" + os.urandom(32) + b"\x00" signature = self.ed25519_key.sign(data) - - - self.assertIsNone( - self.ed25519_key.public_key.verify( - data, - signature - ) - ) - + + self.assertIsNone(self.ed25519_key.public_key.verify(data, signature)) + with self.assertRaises(_EX.InvalidSignatureException): - self.ed25519_key.public_key.verify( - data, - signature+b'\x00' - ) + self.ed25519_key.public_key.verify(data, signature + b"\x00") + class TestExceptions(KeypairMethods): def setUp(self): self.generateClasses() - + def tearDown(self): pass - + def test_invalid_private_key(self): with self.assertRaises(_EX.InvalidKeyException): - key = PrivateKey.from_class(KeypairMethods) - + _ = PrivateKey.from_class(KeypairMethods) + def test_invalid_ecdsa_curve(self): with self.assertRaises(_EX.InvalidCurveException): - key = EcdsaPublicKey.from_numbers( - 'abc123', + _ = EcdsaPublicKey.from_numbers( + "abc123", x=self.ecdsa_key.public_key.public_numbers.x, - y=self.ecdsa_key.public_key.public_numbers.y + y=self.ecdsa_key.public_key.public_numbers.y, ) - + with self.assertRaises(_EX.InvalidCurveException): - key = EcdsaPrivateKey.from_numbers( - 'abc123', + _ = EcdsaPrivateKey.from_numbers( + "abc123", x=self.ecdsa_key.public_key.public_numbers.x, y=self.ecdsa_key.public_key.public_numbers.y, - private_value=self.ecdsa_key.private_numbers.private_value + private_value=self.ecdsa_key.private_numbers.private_value, ) -if __name__ == '__main__': - unittest.main() - +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index d45bedb..2e9023f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,231 +1,233 @@ import unittest -import faker from random import randint -import src.sshkey_tools.utils as utils +import faker from paramiko.util import deflate_long, inflate_long +import src.sshkey_tools.utils as utils + EXPECTED_LONG_CONVERSIONS = [ - (0, b'\x00'), - (1, b'\x01'), - (11638779394004435200 ,b'\x00\xa1\x85@1\xa6\xc5A\x00'), - (15203631582360839337 ,b'\x00\xd2\xfe&_2\x87$\xa9'), - (15302225898842444598 ,b'\x00\xd4\\ma]IS6'), - (15945599391219780268 ,b'\x00\xddJ&)\xb4W\n\xac'), - (15242635186864927689 ,b'\x00\xd3\x88\xb7\xf1\x89\xf43\xc9'), - (17517368859399259630 ,b'\x00\xf3\x1a2*\xa7\xfb\x85\xee'), - (11464229064000469348 ,b'\x00\x9f\x19\x1f\x97\xf7t\xf9d'), + (0, b"\x00"), + (1, b"\x01"), + (11638779394004435200, b"\x00\xa1\x85@1\xa6\xc5A\x00"), + (15203631582360839337, b"\x00\xd2\xfe&_2\x87$\xa9"), + (15302225898842444598, b"\x00\xd4\\ma]IS6"), + (15945599391219780268, b"\x00\xddJ&)\xb4W\n\xac"), + (15242635186864927689, b"\x00\xd3\x88\xb7\xf1\x89\xf43\xc9"), + (17517368859399259630, b"\x00\xf3\x1a2*\xa7\xfb\x85\xee"), + (11464229064000469348, b"\x00\x9f\x19\x1f\x97\xf7t\xf9d"), ] EXPECTED_HASHES = [ - ( b'HYoNIlxkwde]GX]qBNdtH\^ZYGyRWf', '15:6d:0c:c7:cb:9f:9a:85:dd:5e:f4:ee:b3:2b:9d:3f', 'tchAVFXDgszqtTYzokVmUfdQOHi5UvSsPVQzFsHnfXg', 'mof7eHzmuTZzEkPnjAUbaFABqZl8YUanE3Ips+jy5B9aiRtA1D8fIJgmmfb4V/T0AZz08Gu4AsMOKfybrjOUAA' ), - ( b'sseJmSLI_RgSRciYac\M`BjxkziCFD', 'cc:a4:4e:cc:1a:a4:3b:e6:26:f2:1b:a4:70:b6:0a:5a', 'XYWSEK4LHNhkVSMrMTVBg73r4Nu0ElpFuQ1efJrp5Ks', '/vZr65Oj0vyIIaC+iVCFiXNnmf7Ntg3njkBwJZddOxYaZ3mvs3Ra2OOh/VN1bMDbKaZik9BOkzDTjIXNBMDgUg' ), - ( b'AGGPP]bLy[SKwkNGfgjkwGw_vPa\]h', '25:de:7e:8f:65:da:5c:1d:c2:aa:79:8b:58:72:69:12', 'f6C8TqeIHwAIt0PvvjnR23uBzLcgr6MfTr+150u2+XE', 'cLVxsX2k3HEX6/e+S3NROdAzEIM45gO/stYLyZiJrQ5n14fBGgZX25bwJFcAtB/FNZjrIwmG1m3jRJ+CgXgBBQ' ), - ( b'OTUwhulOPckus_M]EyuxXWz^URykVP', 'df:d8:12:e2:bb:ad:70:0a:0a:e2:ba:b2:2c:82:0e:f7', 'pWW/P8tP5qOFeEYeoDMxzJlTHowbs6VwkTzHeQWrg8g', 'gF5ZmdFWxYVcDMUMyC4hIgFBXtgrxv8VZbw809UiqgTzpxOEyY1mVupma9joUGRHr3IjbDJdr+Uiq4TuZwVTjw' ), - ( b'JPvFmb^HoOw]KPwYAgVlWAhtZt]YSb', '5e:6e:12:09:50:b3:b4:e7:8f:f9:a0:d0:a6:40:e2:4c', 'hsrRKQ8vR7oDIxYAALz48kyg0BZNs61S09KCrgNJMyk', 'AmKjHbklKxWgk1EcUwtJenIYfvGC4bTfYwXnkOGvfFu2nvClgrTJix8MgZ+RtQq/3usweE2CIHgD+d3FH052kQ' ), - ( b'lLzVwDidpzTCNUbAmYgxaCV^ASbdy[', 'c5:e9:35:ea:f9:32:e6:2b:68:c2:18:8a:9f:93:ba:01', 'METh/occInzgqIvcRrJlyQa8E6nr095BXlrn9izySpA', 'G3e2YkowIzM2k5VyKxSk7O8+TJlGzEYE8pL7SPS1Ts4OOg37q06jinNXeRVSDt3ICih3KNtJPXHnFV3M41xDSw' ), - ( b'`NSyrnrn[ZntLugeydWWiaSTlVTNuU', '1a:cb:13:fb:e9:a1:ea:84:55:fd:a1:c4:7a:05:7b:d0', 'yQOxfh54aRBnH9tPxYsvDF/TLTy15nRze0AuftJWRcg', 'u9xyMyKaldgr3EZ9WFUkpGVh3C+lqLci6D3fD9lq3k2vtqK90CzTwHDNA/OHBnsc/ukWPe+kxaGn5CtHROjtLg' ), - ( b'JCUdozAoqggVFeD`dAXo]ElrDOrgVV', '42:5c:7b:4b:44:ac:f4:21:e8:fb:17:fb:b4:46:62:d4', 'uuGOfXTNMvcwfS6nol/cJ5ijVw0DVBA+4rpt18PsQjo', 'j6JOK8j8cjZnddm3IjjjnmHQobIewIelGOeSGMa9WEcKd8nkvTD17coSYdObQ9/X+kbU+9nPDSaRjjA3eQW6QA' ), - ( b']Beog]hAviJSgTZlbTDcytqftqaDof', 'e0:d3:0a:36:75:31:eb:e8:20:73:25:c3:2e:67:aa:54', 'FRreFuNrKbZ1lkJAhtaQSyqCHkrkRoZ6oKc7JNyLTAE', 'MBrk8xn3kkXmV+KefP7Lg0plxd+rI0dZ+QExCL3NlSfC54Y8j6GENtmheUFknHwaLpiwgKSRtwPN6ZP6EfaK0A' ), - ( b'XKHCjGvvtJgPdDkGSGDyPkZDBif^BZ', '7e:e4:29:df:b4:77:2a:6f:5d:eb:9f:25:8e:bd:45:b6', 'BmF8Pt4E/z0M8/rMy6mXkUpVlDG9Zje5+KA3dBVIR7c', 'I9lVFWPx6YkwnsZRMf21TPFvquV59+ng11F3EFFhDLKHrK/l6cdGQu1K0idWQQfRK44d77z00TK/aKmPmgZ2kg' ), + ( + # trunk-ignore(flake8/W605) + b"HYoNIlxkwde]GX]qBNdtH\^ZYGyRWf", + "15:6d:0c:c7:cb:9f:9a:85:dd:5e:f4:ee:b3:2b:9d:3f", + "tchAVFXDgszqtTYzokVmUfdQOHi5UvSsPVQzFsHnfXg", + "mof7eHzmuTZzEkPnjAUbaFABqZl8YUanE3Ips+jy5B9aiRtA1D8fIJgmmfb4V/T0AZz08Gu4AsMOKfybrjOUAA", + ), + ( + # trunk-ignore(flake8/W605) + b"sseJmSLI_RgSRciYac\M`BjxkziCFD", + "cc:a4:4e:cc:1a:a4:3b:e6:26:f2:1b:a4:70:b6:0a:5a", + "XYWSEK4LHNhkVSMrMTVBg73r4Nu0ElpFuQ1efJrp5Ks", + "/vZr65Oj0vyIIaC+iVCFiXNnmf7Ntg3njkBwJZddOxYaZ3mvs3Ra2OOh/VN1bMDbKaZik9BOkzDTjIXNBMDgUg", + ), + ( + # trunk-ignore(flake8/W605) + b"AGGPP]bLy[SKwkNGfgjkwGw_vPa\]h", + "25:de:7e:8f:65:da:5c:1d:c2:aa:79:8b:58:72:69:12", + "f6C8TqeIHwAIt0PvvjnR23uBzLcgr6MfTr+150u2+XE", + "cLVxsX2k3HEX6/e+S3NROdAzEIM45gO/stYLyZiJrQ5n14fBGgZX25bwJFcAtB/FNZjrIwmG1m3jRJ+CgXgBBQ", + ), + ( + b"OTUwhulOPckus_M]EyuxXWz^URykVP", + "df:d8:12:e2:bb:ad:70:0a:0a:e2:ba:b2:2c:82:0e:f7", + "pWW/P8tP5qOFeEYeoDMxzJlTHowbs6VwkTzHeQWrg8g", + "gF5ZmdFWxYVcDMUMyC4hIgFBXtgrxv8VZbw809UiqgTzpxOEyY1mVupma9joUGRHr3IjbDJdr+Uiq4TuZwVTjw", + ), + ( + b"JPvFmb^HoOw]KPwYAgVlWAhtZt]YSb", + "5e:6e:12:09:50:b3:b4:e7:8f:f9:a0:d0:a6:40:e2:4c", + "hsrRKQ8vR7oDIxYAALz48kyg0BZNs61S09KCrgNJMyk", + "AmKjHbklKxWgk1EcUwtJenIYfvGC4bTfYwXnkOGvfFu2nvClgrTJix8MgZ+RtQq/3usweE2CIHgD+d3FH052kQ", + ), + ( + b"lLzVwDidpzTCNUbAmYgxaCV^ASbdy[", + "c5:e9:35:ea:f9:32:e6:2b:68:c2:18:8a:9f:93:ba:01", + "METh/occInzgqIvcRrJlyQa8E6nr095BXlrn9izySpA", + "G3e2YkowIzM2k5VyKxSk7O8+TJlGzEYE8pL7SPS1Ts4OOg37q06jinNXeRVSDt3ICih3KNtJPXHnFV3M41xDSw", + ), + ( + b"`NSyrnrn[ZntLugeydWWiaSTlVTNuU", + "1a:cb:13:fb:e9:a1:ea:84:55:fd:a1:c4:7a:05:7b:d0", + "yQOxfh54aRBnH9tPxYsvDF/TLTy15nRze0AuftJWRcg", + "u9xyMyKaldgr3EZ9WFUkpGVh3C+lqLci6D3fD9lq3k2vtqK90CzTwHDNA/OHBnsc/ukWPe+kxaGn5CtHROjtLg", + ), + ( + b"JCUdozAoqggVFeD`dAXo]ElrDOrgVV", + "42:5c:7b:4b:44:ac:f4:21:e8:fb:17:fb:b4:46:62:d4", + "uuGOfXTNMvcwfS6nol/cJ5ijVw0DVBA+4rpt18PsQjo", + "j6JOK8j8cjZnddm3IjjjnmHQobIewIelGOeSGMa9WEcKd8nkvTD17coSYdObQ9/X+kbU+9nPDSaRjjA3eQW6QA", + ), + ( + b"]Beog]hAviJSgTZlbTDcytqftqaDof", + "e0:d3:0a:36:75:31:eb:e8:20:73:25:c3:2e:67:aa:54", + "FRreFuNrKbZ1lkJAhtaQSyqCHkrkRoZ6oKc7JNyLTAE", + "MBrk8xn3kkXmV+KefP7Lg0plxd+rI0dZ+QExCL3NlSfC54Y8j6GENtmheUFknHwaLpiwgKSRtwPN6ZP6EfaK0A", + ), + ( + b"XKHCjGvvtJgPdDkGSGDyPkZDBif^BZ", + "7e:e4:29:df:b4:77:2a:6f:5d:eb:9f:25:8e:bd:45:b6", + "BmF8Pt4E/z0M8/rMy6mXkUpVlDG9Zje5+KA3dBVIR7c", + "I9lVFWPx6YkwnsZRMf21TPFvquV59+ng11F3EFFhDLKHrK/l6cdGQu1K0idWQQfRK44d77z00TK/aKmPmgZ2kg", + ), ] + class TestStringBytestringConversion(unittest.TestCase): def setUp(self): self.faker = faker.Faker() def test_ensure_string(self): - self.assertEqual( - utils.ensure_string(""), - utils.ensure_string(b"") - ) - - self.assertEqual( - None, - utils.ensure_string(None) - ) + self.assertEqual(utils.ensure_string(""), utils.ensure_string(b"")) + + self.assertEqual(None, utils.ensure_string(None)) for _ in range(100): val = self.faker.pystr() - + self.assertEqual( - utils.ensure_string(val), - utils.ensure_string(val.encode('utf-8')) + utils.ensure_string(val), utils.ensure_string(val.encode("utf-8")) ) - - lst = list([self.faker.pystr() for _ in range(10)]) - lst_byt = list([x.encode('utf-8') if randint(0, 1) == 1 else x for x in lst]) - - self.assertIsInstance( - utils.ensure_string(lst), - list - ) - self.assertEqual( - lst, - utils.ensure_string(lst), - utils.ensure_string(lst_byt) - ) - + lst_byt = list([x.encode("utf-8") if randint(0, 1) == 1 else x for x in lst]) + + self.assertIsInstance(utils.ensure_string(lst), list) + self.assertEqual(lst, utils.ensure_string(lst), utils.ensure_string(lst_byt)) + tpl = tuple(self.faker.pystr() for _ in range(10)) - tpl_byt = tuple(x.encode('utf-8') if randint(0, 1) == 1 else x for x in tpl) - - self.assertIsInstance( - utils.ensure_string(tpl), - list - ) - + tpl_byt = tuple(x.encode("utf-8") if randint(0, 1) == 1 else x for x in tpl) + + self.assertIsInstance(utils.ensure_string(tpl), list) + self.assertEqual( - list(tpl), - utils.ensure_string(tpl), - utils.ensure_string(tpl_byt) + list(tpl), utils.ensure_string(tpl), utils.ensure_string(tpl_byt) ) - + st = set(self.faker.pystr() for _ in range(10)) - st_byt = set(x.encode('utf-8') if randint(0, 1) == 1 else x for x in st) - - self.assertIsInstance( - utils.ensure_string(st), - list - ) - self.assertEqual( - list(st), - utils.ensure_string(st), - utils.ensure_string(st_byt) - ) - + st_byt = set(x.encode("utf-8") if randint(0, 1) == 1 else x for x in st) + + self.assertIsInstance(utils.ensure_string(st), list) + self.assertEqual(list(st), utils.ensure_string(st), utils.ensure_string(st_byt)) + dct = dict({self.faker.pystr(): self.faker.pystr() for _ in range(10)}) - dct_byt = dict({x.encode('utf-8') if randint(0, 1) == 1 else x: y.encode('utf-8') if randint(0, 1) == 1 else y for x, y in dct.items()}) - - self.assertIsInstance( - utils.ensure_string(dct), - dict - ) - self.assertEqual( - dct, - utils.ensure_string(dct), - utils.ensure_string(dct_byt) + dct_byt = dict( + { + x.encode("utf-8") + if randint(0, 1) == 1 + else x: y.encode("utf-8") + if randint(0, 1) == 1 + else y + for x, y in dct.items() + } ) - + + self.assertIsInstance(utils.ensure_string(dct), dict) + self.assertEqual(dct, utils.ensure_string(dct), utils.ensure_string(dct_byt)) + def test_ensure_bytestring(self): - self.assertEqual( - utils.ensure_bytestring(""), - utils.ensure_bytestring(b"") - ) - - self.assertEqual( - None, - utils.ensure_bytestring(None) - ) + self.assertEqual(utils.ensure_bytestring(""), utils.ensure_bytestring(b"")) + + self.assertEqual(None, utils.ensure_bytestring(None)) for _ in range(100): - val = self.faker.pystr().encode('utf-8') - + val = self.faker.pystr().encode("utf-8") + self.assertEqual( utils.ensure_bytestring(val), - utils.ensure_bytestring(val.decode('utf-8')) + utils.ensure_bytestring(val.decode("utf-8")), ) - - - - lst = list([self.faker.pystr().encode('utf-8') for _ in range(10)]) - lst_byt = list([x.decode('utf-8') if randint(0, 1) == 1 else x for x in lst]) - - self.assertIsInstance( - utils.ensure_bytestring(lst), - list - ) + + lst = list([self.faker.pystr().encode("utf-8") for _ in range(10)]) + lst_byt = list([x.decode("utf-8") if randint(0, 1) == 1 else x for x in lst]) + + self.assertIsInstance(utils.ensure_bytestring(lst), list) self.assertEqual( - lst, - utils.ensure_bytestring(lst), - utils.ensure_bytestring(lst_byt) - ) - - tpl = tuple(self.faker.pystr().encode('utf-8') for _ in range(10)) - tpl_byt = tuple(x.decode('utf-8') if randint(0, 1) == 1 else x for x in lst) - - self.assertIsInstance( - utils.ensure_bytestring(tpl), - list + lst, utils.ensure_bytestring(lst), utils.ensure_bytestring(lst_byt) ) - + + tpl = tuple(self.faker.pystr().encode("utf-8") for _ in range(10)) + tpl_byt = tuple(x.decode("utf-8") if randint(0, 1) == 1 else x for x in lst) + + self.assertIsInstance(utils.ensure_bytestring(tpl), list) + self.assertEqual( - list(tpl), - utils.ensure_bytestring(tpl), - utils.ensure_bytestring(tpl_byt) - ) - - st = set(self.faker.pystr().encode('utf-8') for _ in range(10)) - st_byt = set(x.decode('utf-8') if randint(0, 1) == 1 else x for x in lst) - - self.assertIsInstance( - utils.ensure_bytestring(st), - list + list(tpl), utils.ensure_bytestring(tpl), utils.ensure_bytestring(tpl_byt) ) + + st = set(self.faker.pystr().encode("utf-8") for _ in range(10)) + st_byt = set(x.decode("utf-8") if randint(0, 1) == 1 else x for x in lst) + + self.assertIsInstance(utils.ensure_bytestring(st), list) self.assertEqual( - list(st), - utils.ensure_bytestring(st), - utils.ensure_bytestring(st_byt) - ) - - dct = dict({self.faker.pystr().encode('utf-8'): self.faker.pystr().encode('utf-8') for _ in range(10)}) - dct_byt = dict({x.decode('utf-8') if randint(0, 1) == 1 else x: y.decode('utf-8') if randint(0, 1) == 1 else y for x, y in dct.items()}) - - self.assertIsInstance( - utils.ensure_bytestring(dct), - dict + list(st), utils.ensure_bytestring(st), utils.ensure_bytestring(st_byt) ) - self.assertEqual( - dct, - utils.ensure_bytestring(dct), - utils.ensure_bytestring(dct_byt) + + dct = dict( + { + self.faker.pystr().encode("utf-8"): self.faker.pystr().encode("utf-8") + for _ in range(10) + } ) - - def test_concat_to_string(self): - self.assertEqual( - utils.concat_to_string(""), - utils.concat_to_string(b"") + dct_byt = dict( + { + x.decode("utf-8") + if randint(0, 1) == 1 + else x: y.decode("utf-8") + if randint(0, 1) == 1 + else y + for x, y in dct.items() + } ) + + self.assertIsInstance(utils.ensure_bytestring(dct), dict) self.assertEqual( - utils.concat_to_string( - None - ), - "" + dct, utils.ensure_bytestring(dct), utils.ensure_bytestring(dct_byt) ) - + + def test_concat_to_string(self): + self.assertEqual(utils.concat_to_string(""), utils.concat_to_string(b"")) + self.assertEqual(utils.concat_to_string(None), "") + for _ in range(100): strs = [self.faker.pystr() for _ in range(randint(10, 100))] - strs_byt = [x.encode('utf-8') if randint(0, 1) == 1 else x for x in strs] - + strs_byt = [x.encode("utf-8") if randint(0, 1) == 1 else x for x in strs] + self.assertEqual( utils.concat_to_string(*strs), utils.concat_to_string(*strs_byt), - "".join(strs) + "".join(strs), ) - + def test_concat_to_bytestring(self): self.assertEqual( - utils.concat_to_bytestring(b""), - utils.concat_to_bytestring("") + utils.concat_to_bytestring(b""), utils.concat_to_bytestring("") ) - self.assertEqual( - utils.concat_to_bytestring( - None - ), - b"" - ) - + self.assertEqual(utils.concat_to_bytestring(None), b"") + for _ in range(100): - byts = [self.faker.pystr().encode('utf-8') for _ in range(randint(10, 100))] - byts_str = [x.decode('utf-8') if randint(0, 1) == 1 else x for x in byts] - + byts = [self.faker.pystr().encode("utf-8") for _ in range(randint(10, 100))] + byts_str = [x.decode("utf-8") if randint(0, 1) == 1 else x for x in byts] + self.assertEqual( utils.concat_to_bytestring(*byts), utils.concat_to_bytestring(*byts_str), - b"".join(byts) + b"".join(byts), ) + class TestLongConversion(unittest.TestCase): def test_expected_deflation(self): """ @@ -236,16 +238,16 @@ def test_expected_deflation(self): self.assertEqual( utils.long_to_bytes(before), after, - f"Failed to convert {before} to a byte string " + - "(expected: {after}, got: {long_to_bytes(before)})" + f"Failed to convert {before} to a byte string " + + "(expected: {after}, got: {long_to_bytes(before)})", ) self.assertEqual( deflate_long(before), utils.long_to_bytes(before), - f"The comparative function failed to deliver the same result as built-in" + "The comparative function failed to deliver the same result as built-in", ) - + def test_expected_inflation(self): """ Ensure the built-in function handles inflation as expected @@ -255,16 +257,16 @@ def test_expected_inflation(self): self.assertEqual( utils.bytes_to_long(before), after, - f"Failed to convert {before} to a byte string " + - f"(expected: {after}, got: {utils.bytes_to_long(before)})" + f"Failed to convert {before} to a byte string " + + f"(expected: {after}, got: {utils.bytes_to_long(before)})", ) - + self.assertEqual( inflate_long(before), utils.bytes_to_long(before), - f"The comparative function failed to deliver the same result as built-in" + "The comparative function failed to deliver the same result as built-in", ) - + def test_expected_exception(self): """ Ensure appropriate exceptions are thrown when the input is invalid @@ -273,8 +275,8 @@ def test_expected_exception(self): utils.long_to_bytes(-1) with self.assertRaises(TypeError): - utils.long_to_bytes('one') - + utils.long_to_bytes("one") + def test_random_values(self): """ Extend testing with random results, comparing to the established function. @@ -282,36 +284,29 @@ def test_random_values(self): start_length = 16 for _ in range(15): print(start_length) - + for _ in range(10): - value = randint(2**start_length-1, 2**start_length) - + value = randint(2**start_length - 1, 2**start_length) + builtin = utils.long_to_bytes(value) compare = deflate_long(value) - - self.assertEqual( - builtin, - compare - ) - - self.assertEqual( - utils.bytes_to_long(compare), - inflate_long(builtin) - ) - - start_length = start_length*2 - + + self.assertEqual(builtin, compare) + + self.assertEqual(utils.bytes_to_long(compare), inflate_long(builtin)) + + start_length = start_length * 2 + + class TestNonceGeneration(unittest.TestCase): def test_nonce_generation(self): """ Ensure the nonce is generated correctly """ for _ in range(10): - self.assertIsInstance( - utils.generate_secure_nonce(), - str - ) - + self.assertIsInstance(utils.generate_secure_nonce(), str) + + class TestHashGeneration(unittest.TestCase): def test_hashing_functions(self): """ @@ -324,14 +319,15 @@ def test_hashing_functions(self): sha256_3 = utils.sha256_fingerprint(bytestring, True) sha512_2 = utils.sha512_fingerprint(bytestring, False) sha512_3 = utils.sha512_fingerprint(bytestring, True) - + self.assertEqual(md5, md5_2) self.assertEqual(sha256, sha256_2) self.assertEqual(sha512, sha512_2) - + self.assertEqual(f"MD5:{md5}", md5_3) self.assertEqual(f"SHA256:{sha256}", sha256_3) self.assertEqual(f"SHA512:{sha512}", sha512_3) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + + +if __name__ == "__main__": + unittest.main() From 7d4de1faf095883c344b17947080797595e6c165 Mon Sep 17 00:00:00 2001 From: Lars Scheibling Date: Wed, 20 Jul 2022 15:02:59 +0000 Subject: [PATCH 5/6] Minor bugfixes Updated documentation --- .github/workflows/publish.yml | 1 + .pylintrc | 4 +- README.md | 464 ++++++++++++++++++---------------- src/sshkey_tools/cert.py | 134 ++++++++-- src/sshkey_tools/fields.py | 26 +- src/sshkey_tools/keys.py | 21 +- src/sshkey_tools/utils.py | 33 +-- testcert | 1 - tests/test_keypairs.py | 32 +-- 9 files changed, 427 insertions(+), 289 deletions(-) delete mode 100644 testcert diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e5a1d5e..2677b16 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -31,6 +31,7 @@ jobs: - name: Rebuild documentation run: | + pip3 install pdoc3 pdoc --html src/sshkey_tools/ --force --output-dir docs mv docs/sshkey_tools/* docs/ rm -r docs/sshkey_tools \ No newline at end of file diff --git a/.pylintrc b/.pylintrc index ff33277..3f55642 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,5 +1,5 @@ [MASTER] -load-plugins=pylint_report +; load-plugins=pylint_report [REPORTS] -output-format=pylint_report.CustomJsonReporter \ No newline at end of file +; output-format=pylint_report.CustomJsonReporter \ No newline at end of file diff --git a/README.md b/README.md index 90a7aee..f5ec8ab 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,34 @@ # sshkey-tools -Python and CLI tools for managing OpenSSH keypairs and certificates +Python package for managing OpenSSH keypairs and certificates ([protocol.CERTKEYS](https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.certkeys)). Supported functionality includes: + +[TOC] + +# Features +### SSH Keys +- Supports RSA, DSA, ECDSA and ED25519 keys +- Import existing keys from file, string, byte data or [pyca/cryptography](https://github.com/pyca/cryptography) class +- Generate new keys +- Get public key from private keys +- Sign bytestrings with private keys +- Export to file, string or bytes +- Generate fingerprint + +### OpenSSH Certificates +- Supports RSA, DSA, ECDSA and ED25519 certificates +- Import existing certificates from file, string or bytes +- Verify certificate signature against internal or separate public key +- Create new certificates from CA private key and subject public key +- Create new certificates using old certificate as template +- Sign certificates +- Export certificates to file, string or bytes + +# Roadmap +- [x] Rewrite certificate field functionality for simpler usage +- [ ] Re-add functionality for changing RSA hash method +- [ ] Add CLI functionality +- [ ] Convert to/from putty format (keys only) + # Installation @@ -21,16 +49,11 @@ pip3 install ./ ``` # Documentation +You can find the full documentation at [scheiblingco.github.io/sshkey-tools/](https://scheiblingco.github.io/sshkey-tools/) -[scheiblingco.github.io/sshkey-tools/](https://scheiblingco.github.io/sshkey-tools/) - -# Basic usage - -## SSH Keypairs - -### Generate keys - +## SSH Keypairs (generating, loading, exporting) ```python +# Import the certificate classes from sshkey_tools.keys import ( RsaPrivateKey, DsaPrivateKey, @@ -38,267 +61,270 @@ from sshkey_tools.keys import ( Ed25519PrivateKey, EcdsaCurves ) +# +## Generating keys +# -# RSA -# By default, RSA is generated with a 4096-bit keysize -rsa_private = RsaPrivateKey.generate() +# For all keys except ED25519, the key size/curve can be manually specified +# Generate RSA (default is 4096 bits) +rsa_priv = RsaPrivateKey.generate() +rsa_priv = RsaPrivateKey.generate(2048) -# You can also specify the key size -rsa_private = RsaPrivateKey.generate(bits) +# Generate DSA keys (since SSH only supports 1024-bit keys, this is the default) +dsa_priv = DsaPrivateKey.generate() -# DSA -# Since OpenSSH only supports 1024-bit keys, this is the default -dsa_private = DsaPrivateKey.generate() +# Generate ECDSA keys (The default curve is P521) +ecdsa_priv = EcdsaPrivateKey.generate() +ecdsa_priv = EcdsaPrivateKey.generate(EcdsaCurves.P256) -# ECDSA -# The default curve is P521 -ecdsa_private = EcdsaPrivateKey.generate() +# Generate ED25519 keys (fixed key size) +ed25519_priv = Ed25519PrivateKey.generate() -# You can also manually specify a curve -ecdsa_private = EcdsaPrivateKey.generate(EcdsaCurves.P256) +# +## Loading keys +# -# ED25519 -# The ED25519 keys are always a fixed size -ed25519_private = Ed25519PrivateKey.generate() +# Keys can be loaded either via the specific class: +rsa_priv = RsaPrivateKey.from_file("/path/to/key", "OptionalSecurePassword") -# Public keys -# The public key for any given private key is in the public_key parameter -rsa_pub = rsa_private.public_key -``` +# or via the general class, in case the type is not known in advance +rsa_priv = PrivateKey.from_file("/path/to/key", "OptionalSecurePassword") -### Load keys +# The import functions are .from_file(), .from_string() and .from_class() and are valid for both PublicKey and PrivateKey-classes +rsa_priv = PrivateKey.from_string("-----BEGIN OPENSSH PRIVATE KEY...........END -----", "OptionalSecurePassword") +rsa_priv = PrivateKey.from_class(pyca_cryptography_class) -You can load keys either directly with the specific key classes (RsaPrivateKey, DsaPrivateKey, etc.) or the general PrivateKey class +# The different keys can also be loaded from their numbers, e.g. RSA Pubkey: +rsa_priv = PublicKey.from_numbers(65537, 123123123....1) -```python -from sshkey_tools.keys import ( - PrivateKey, - PublicKey, - RsaPrivateKey, - RsaPublicKey -) +# +## Key functionality +# -# Load a private key with a specific class -rsa_private = RsaPrivateKey.from_file('path/to/rsa_key') +# The public key for any loaded or generated private key is available in the .public_key attribute +ed25519_pub = ed25519_priv.public_key -# Load a private key with the general class -rsa_private = PrivateKey.from_file('path/to/rsa_key') -print(type(rsa_private)) -"" +# The private keys can be exported using to_bytes, to_string or to_file +rsa_priv.to_bytes("OptionalSecurePassword") +rsa_priv.to_string("OptionalSecurePassword", "utf-8") +rsa_priv.to_file("/path/to/file", "OptionalSecurePassword", "utf-8") -# Public keys can be loaded in the same way -rsa_pub = RsaPublicKey.from_file('path/to/rsa_key.pub') -rsa_pub = PublicKey.from_file('path/to/rsa_key.pub') +# The public keys also have .to_string() and .to_file(), but .to_bytes() is divided into .serialize() and .raw_bytes() +# The comment can be set before export by changing the public_key.comment-attribute +rsa_priv.public_key.comment = "Comment@Comment" -print(type(rsa_private)) -"" +# This will return the serialized public key as found in an OpenSSH keyfile +rsa_priv.public_key.serialize() +b"ssh-rsa AAAA......... Comment@Comment" -# Public key objects are automatically created for any given private key -# negating the need to load them separately -rsa_pub = rsa_private.public_key +# This will return the raw bytes of the key (base64-decoded middle portion) +rsa_priv.public_key.raw_bytes() +b"\0xc\0a\........" +``` -# Load a key from a pyca/cryptography class privkey_pyca/pubkey_pyca -rsa_private = PrivateKey.from_class(privkey_pyca) -rsa_public = PublicKey.from_class(pubkey_pyca) +## SSH Key Signatures +The loaded private key objects can be used to sign bytestrings, and the public keys can be used to verify signatures on those +```python +from sshkey_tools.keys import RsaPrivateKey, RsaPublicKey -# You can also load private and public keys from strings or bytes (file contents) -with open('path/to/rsa_key', 'r', 'utf-8') as file: - rsa_private = PrivateKey.from_string(file.read()) +signable_data = b'This is a message that will be signed' -with open('path/to/rsa_key', 'rb') as file: - rsa_private = PrivateKey.from_bytes(file.read()) +privkey = RsaPrivateKey.generate() +pubkey = RsaPrivateKey.public_key -# RSA, DSA and ECDSA keys can be loaded from the public/private numbers and/or parameters -rsa_public = RsaPublicKey.from_numbers( - e=65537, - n=12.........811 -) +# Sign the data +signature = privkey.sign(signable_data) -rsa_private = RsaPrivateKey.from_numbers( - e=65537, - n=12......811, - d=17......122 -) +# Verify the signature (Throws exception if invalid) +pubkey.verify(signable_data, signature) ``` -## SSH Certificates - -### Attributes - -| Attribute | Type | Key | Example Value | Description | -| ---------------- | ------------------- | ---------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Certificate Type | Integer (1/2) | cert_type | 1 | The type of certificate, 1 for User and 2 for Host. Can also be defined as sshkey_tools.fields.CERT_TYPE.USER or sshkey_tools.fields.CERT_TYPE.HOST | -| Serial | Integer | serial | 11223344 | The serial number for the certificate, a 64-bit integer | -| Key ID | String | key_id | someuser@somehost | The key identifier, can be set to any string, for example username, email or other unique identifier | -| Principals | List | principals | ['zone-webservers', 'server-01'] | The principals for which the certificate is valid, this needs to correspond to the allowed principals on the OpenSSH Server-side. Only valid for User certificates | -| Valid After | Integer | valid_after | datetime.now() | The datetime object or unix timestamp for when the certificate validity starts | -| Valid Before | Integer | valid_before | datetime.now() + timedelta(hours=12) | The datetime object or unix timestamp for when the certificate validity ends | -| Critical Options | Dict | critical_options | {'source-address': '1.2.3.4/8'} | Options set on the certificate that the OpenSSH server cannot choose to ignore (critical). Only valid on user certificates. Valid options are force-command (for limiting the user to a certain shell, e.g. sftp-internal), source-address (to limit the source IPs the user can connect from) and verify-required (to require the user to touch a hardware key before usage) | -| Extensions | Dict/Set/List/Tuple | extensions | {'permit-X11-forwarding', 'permit-port-forwarding'} | Extensions that the certificate holder is allowed to use. Valid options are no-touch-required, permit-X11-forwarding, permit-agent-forwarding, permit-port-forwarding, permit-pty, permit-user-rc | +## OpenSSH Certificates +### Introduction +Certificates are a way to handle access management/PAM for OpenSSH with the ability to dynamically grant access during a specific time, to specific servers and/or with specific attributes. There are a couple of upsides to using certificates instead of public/private keys, mainly: + +- Additional Security: Certificate authentication for OpenSSH is built as an extension of public key authentication, enabling additional features on top of key-based access control. +- Short-term access: The user has to request a certificate for their keypair, which together with the private key grants access to the server. Without the certificate the user can't connect to the server - giving you control over how, when and from where the user can connect. +- Hostkey Verification: Certificiates can be issued for the OpenSSH Server, adding the CA public key to the clients enables you to establish servers as trusted without the hostkey warning. +- RBAC: Control which servers or users (principals) a keypair has access to, and specify the required principals for access to certain functionality on the server side. +- Logging: Key ID and Serial fields for tracking of issued certificates +- CRL: Revoke certificates prematurely if they are compromised + +### Structure +The original OpenSSH certificate format is a block of parameters, encoded and packed to a bytestring. In this package, the fields have been divided into three parts. For a more detailed information about the format, see [PROTOCOL.certkeys](https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.certkeys). + +### Certificate Header +|Attribute|Type(Length)|Key|Example Value|Description| +|---|---|---|---|---| +|Public Key/Certificate type|string(fixed)|pubkey_type|ssh-rsa-sha2-512-cert-v01@openssh.com|The private key (and certificate) type, derived from the public key for which the certificate is created (Automatically set upon creation)| +|Subject public key|bytestring(variable)|public_key|\x00\x00\x00..........|The public key for which the certificate is created (Automatically set upon creation)| +|Nonce|string|nonce(variable, typically 16 or 32 bytes)|abcdefghijklmnopqrstuvwxyz|A random string included to make attacks that depend on inducing collisions in the signature hash infeasible. (Default is automatically set, can be changed with Certificate.header.nonce = "abcdefg..."| + +### Certificate Fields +|Attribute|Type(Length)|Key|Example Value|Description| +|---|---|---|---|---| +|Serial|Integer(64-bit)|serial|1234567890|An optional certificate serial number set by the CA to provide an abbreviated way to refer to certificates from that CA. If a CA does not wish to number its certificates, it must set this field to zero.| +|Certificate type|Integer(1 or 2)|cert_type|1|The type of the certificate, 1 for user certificates, 2 for host certificates| +|Key ID|string(variable)|key_id|someuser@somehost|Free-form text field that is filled in by the CA at the time of signing; the intention is that the contents of this field are used to identify the identity principal in log messages.| +|Valid Principals|List(string(variable))|principals|['some-user', 'some-group', production-webservers']|These principals list the names for which this certificate is valid hostnames for SSH_CERT_TYPE_HOST certificates and usernames for SH_CERT_TYPE_USER certificates. As a special case, a zero-length "valid principals" field means the certificate is valid for any principal of the specified type.| +|Valid After|Timestamp|valid_after|datetime.now()|Timestamp for the start of the validity period for the certificate| +|Valid Before|Timestamp|valid_before|datetime.now()+timedelta(hours=8) or 1658322031|Timestamp for the end of the validity period for the certificate. Needs to be larger than valid_after| +|Critical Options|Dict(string, string)|critical_options|[]|Zero or more of the available critical options (see below)| +|Extensions|Dict(string, string)/List/Tuple/Set|extensions|[]|Zero or more of the available extensions (see below)| + + +#### Critical Options +|Name|Format|Description| +|---|---|---| +|force-command|string|Specifies a command that is executed (replacing any the user specified on the ssh command-line) whenever this key is used for authentication.| +|source-address|string|Comma-separated list of source addresses from which this certificate is accepted for authentication. Addresses are specified in CIDR format (nn.nn.nn.nn/nn or hhhh::hhhh/nn). If this option is not present, then certificates may be presented from any source address.| +|verify-required|empty|Flag indicating that signatures made with this certificate must assert FIDO user verification (e.g. PIN or biometric). This option only makes sense for the U2F/FIDO security key types that support this feature in their signature formats.| + +#### Extensions +|Name|Format|Description| +|---|---|---| +|no-touch-required|empty|Flag indicating that signatures made with this certificate need not assert FIDO user presence. This option only makes sense for the U2F/FIDO security key types that support this feature in their signature formats.| +|permit-X11-forwarding|empty|Flag indicating that X11 forwarding should be permitted. X11 forwarding will be refused if this option is absent.| +|permit-agent-forwarding|empty|Flag indicating that agent forwarding should be allowed. Agent forwarding must not be permitted unless this option is present.| +|permit-port-forwarding|empty|Flag indicating that port-forwarding should be allowed. If this option is not present, then no port forwarding will be allowed.| +|permit-pty|empty|Flag indicating that PTY allocation should be permitted. In the absence of this option PTY allocation will be disabled.| +|permit-user-rc|empty|Flag indicating that execution of ~/.ssh/rc should be permitted. Execution of this script will not be permitted if this option is not present.| + + +### Certificate Body +|Attribute|Type(Length)|Key|Example Value|Description| +|---|---|---|---|---| +|Reserved|string(0)|reserved|""|Reserved for future use, must be empty (automatically set upon signing)| +|CA Public Key|bytestring(variable)|ca_pubkey|\x00\x00\x00..........|The public key of the CA that issued this certificate (automatically set upon signing)| +|Signature|bytestring(variable)|signature|\x00\x00\x00..........|The signature of the certificate, created by the CA (automatically set upon signing)| + +## Creating, signing and verifying certificates +```python +# Every certificate needs two parts, the subject (user or host) public key and the CA Private key +from sshkey_tools.cert import SSHCertificate, CertificateFields, Ed25519Certificate +from sshkey_tools.keys import Ed25519PrivateKey +from datetime import datetime, timedelta -### Certificate creation +subject_pubkey = Ed25519PrivateKey.generate().public_key +ca_privkey = Ed25519PrivateKey.generate() -The basis for a certificate is the public key for the subject (User/Host), and bases the format of the certificate on that. +# There are multiple ways to create a certificate, either by creating the certificate body field object first and then creating the certificate, or creating the certificate and setting the fields one by one -```python -from datetime import datetime, timedelta -from cryptography.hazmat.primitives import ( - serialization as crypto_serialization, - hashes as crypto_hashes -) -from cryptography.hazmat.primitives.asymmetric import padding as crypto_padding -from sshkey_tools.keys import PublicKey, RsaPrivateKey, RsaAlgs -from sshkey_tools.cert import SSHCertificate -from sshkey_tools.exceptions import SignatureNotPossibleException - -user_pubkey = PublicKey.from_file('path/to/user_key.pub') -ca_privkey = PrivateKey.from_file('path/to/ca_key') - -# You can create a certificate with a dict of pre-set options -cert_opts = { - 'cert_type': 1, - 'serial': 12345, - 'key_id': "my.user@mycompany.com", - 'principals': [ - 'webservers-dev', - 'webservers-prod', - 'servername01' +# Create certificate body fields +cert_fields = CertificateFields( + serial=1234567890, + cert_type=1, + key_id="someuser@somehost", + principals=["some-user", "some-group", "production-webservers"], + valid_after=datetime.now(), + valid_before=datetime.now() + timedelta(hours=8), + critical_options=[], + extensions=[ + "permit-pty", + "permit-X11-forwarding", + "permit-agent-forwarding", ], - 'valid_after': datetime.now(), - 'valid_before': datetime.now() + timedelta(hours=12), - 'critical_options': {}, - 'extensions': [ - 'permit-pty', - 'permit-user-rc', - 'permit-port-forwarding' - ] - -} - -# Create a signable certificate from a PublicKey class -certificate = SSHCertificate.from_public_class( - user_pubkey, - ca_privkey, - **cert_opts ) -# You can also create the certificate in steps -certificate = SSHCertificate.from_public_class( - user_pubkey +# Create certificate from existing fields +certificate = SSHCertificate( + subject_pubkey=subject_pubkey, + ca_privkey=ca_privkey, + fields=cert_fields, ) -# Set the CA private key used to sign the certificate -certificate.set_ca(ca_privkey) - -# Set or update the options one-by-one -for key, value in cert_opts.items(): - certificate.set_opt(key, value) - -# Via a dict -certificate.set_opts(**cert_opts) +# Start with a blank certificate by calling the general class +certificate = SSHCertificate.create( + subject_pubkey=subject_pubkey, + ca_privkey=ca_privkey +) -# Or via parameters -certificate.set_opts( - cert_type=1, - serial=12345, - key_id='my.user@mycompany.com', - principals=['zone-webservers'], - valid_after=datetime.now(), - valid_before=datetime.now() + timedelta(hours=12), - critical_options={}, - extensions={} +# You can also call the specialized classes directly, for the general class the .create-function needs to be used +certificate = Ed25519Certificate( + subject_pubkey=subject_pubkey, + ca_privkey=ca_privkey ) +# Manually set the fields +certificate.fields.serial = 1234567890 +certificate.fields.cert_type = 1 +certificate.fields.key_id = "someuser@somehost" +certificate.fields.principals = ["some-user", "some-group", "production-webservers"] +certificate.fields.valid_after = datetime.now() +certificate.fields.valid_before = datetime.now() + timedelta(hours=8) +certificate.fields.critical_options = [] +certificate.fields.extensions = [ + "allow-pty", + "permit-X11-forwarding", + "permit-agent-forwarding", +] + # Check if the certificate is ready to be signed -# Will return True or an exception certificate.can_sign() -# Catch exceptions -try: - certificate.can_sign() -except SignatureNotPossibleException: - ... - # Sign the certificate certificate.sign() -# For certificates signed by an RSA key, you can choose the hashing algorithm -# to be used for creating the hash of the certificate data before signing -certificate.sign( - hash_alg=RsaAlgs.SHA512 -) - -# If you want to verify the signature after creation, -# you can do so with the verify()-method -# -# Please note that a public key should always be provided -# to this function if the certificate was not just created, -# since an attacker very well could have replaced CA public key -# and signature with their own -# -# The method will return None if successful, and InvalidSignatureException -# if the signature does not match the data +# Verify the certificate against the included public key (insecure, but useful for testing) certificate.verify() -certificate.verify(ca_privkey.public_key) - - -# If you prefer to verify manually, you can use the CA public key object -# from sshkey_tools or the key object from pyca/cryptography - -# PublicKey -ca_pubkey = PublicKey.from_file('path/to/ca_pubkey') +# Verify the certificate against a public key that is not included in the certificate +certificate.verify(ca_privkey.public_key) -ca_pubkey.verify( - certificate.get_signable_data(), - certificate.signature.value, - RsaAlgs.SHA256.value[1] -) +# Raise an exception if the certificate is invalid +certificate.verify(ca_privkey.public_key, True) -# pyca/cryptography RsaPrivateKey -with open('path/to/ca_pubkey', 'rb') as file: - crypto_ca_pubkey = crypto_serialization.load_ssh_public_key(file.read()) +# Export the certificate to file/string +certificate.to_file('filename-cert.pub') +cert_str = certificate.to_string() -crypto_ca_pubkey.verify( - certificate.get_signable_data(), - certificate.signature.value, - crypto_padding.PKCS1v15(), - crypto_hashes.SHA256() -) - -# You now have an OpenSSH Certificate -# Export it to file, string or bytes -certificate.to_file('path/to/user_key-cert.pub') -cert_string = certificate.to_string() -cert_bytes = certificate.to_bytes() ``` - -### Load an existing certificate - -Certificates can be loaded from file, a string/bytestring with file contents -or the base64-decoded byte data of the certificate - +## Loading, re-creating and verifying existing certificates ```python +from sshkey_tools.cert import SSHCertificate, CertificateFields, Ed25519Certificate from sshkey_tools.keys import PublicKey, PrivateKey -from sshkey_tools.cert import SSHCertificate, RsaCertificate +from datetime import datetime, timedelta -# Load an existing certificate -certificate = SSHCertificate.from_file('path/to/user_key-cert.pub') +# Load a certificate from file or string +# This will return the correct certificate type based on the contents of the certificate +certificate = SSHCertificate.from_file('filename-cert.pub') +certificate = SSHCertificate.from_string(cert_str) -# or -certificate = RsaCertificate.from_file('path/to/user_key-cert.pub') +type(certificate) # sshkey_tools.cert.Ed25519Certificate -# Verify the certificate with a CA public key -ca_pubkey = PublicKey.from_file('path/to/ca_key.pub') -certificate.verify(ca_pubkey) +# Verify the certificate signature against the included public key (insecure, but useful for testing) +certificate.verify() -# Create a new certificate with duplicate values from existing certificate -# You can use existing or previously issued certificates as templates -# for creating new ones -certificate = SSHCertificate.from_file('path/to/user_key-cert.pub') -ca_privkey = PrivateKey.from_file('path/to/ca_privkey') +# Verify the certificate signature against a public key +pubkey = PublicKey.from_file('filename-pubkey.pub') +certificate.verify(pubkey) -certificate.set_ca(ca_privkey) +# Raise an exception if the certificate is invalid +certificate.verify(pubkey, True) + +# Use the loaded certificate as a template to create a new one +new_ca = PrivateKey.from_file('filename-ca') +certificate.replace_ca(new_ca) certificate.sign() -certificate.to_file('path/to/user_key-cert2.pub') + ``` + +## Changelog +### 0.9 +- Adjustments to certificate field handling for easier usage/syntax autocompletion +- Updated testing +- Removed method for changing RSA hash method (now default SHA512) + +### 0.8.2 +- Fixed bug where an RSA certificate would send the RSA alg to the sign() function of another key type + +### 0.8.1 +- Changed versioning for out-of-github installation/packaging +- Moved documentation to HTML (PDOC3) +- Added verification of certificate signature +- Added option to choose RSA hashing algorithm for signing +- Removed test files +- Added documentation deployment CD for GH pages + +### 0.8 +- Initial public release diff --git a/src/sshkey_tools/cert.py b/src/sshkey_tools/cert.py index 602ab19..b21c4b8 100644 --- a/src/sshkey_tools/cert.py +++ b/src/sshkey_tools/cert.py @@ -120,9 +120,9 @@ def validate(self): for key in self.getattrs(): if not getattr(self, key).validate(): list( - ex.append(f"{type(x)}: {str(x)}") - for x in getattr(self, key).exception - if isinstance(x, Exception) + ex.append(f"{type(x)}: {str(x)}") + for x in getattr(self, key).exception + if isinstance(x, Exception) ) return True if len(ex) == 0 else ex @@ -138,7 +138,7 @@ def decode(cls, data: bytes) -> Tuple["Fieldset", bytes]: cl_instance = cls() for item in cls.DECODE_ORDER: decoded, data = getattr(cl_instance, item).from_decode(data) - setattr(item, decoded) + setattr(cl_instance, item, decoded) return cl_instance, data @@ -146,6 +146,7 @@ def decode(cls, data: bytes) -> Tuple["Fieldset", bytes]: @dataclass class CertificateHeader(Fieldset): """Header fields for the certificate""" + public_key: _FIELD.PublicKeyField = _FIELD.PublicKeyField.factory pubkey_type: _FIELD.PubkeyTypeField = _FIELD.PubkeyTypeField.factory nonce: _FIELD.NonceField = _FIELD.NonceField.factory @@ -161,7 +162,7 @@ def __bytes__(self): def decode(cls, data: bytes) -> Tuple["CertificateHeader", bytes]: cl_instance, data = super().decode(data) - target_class = CERT_TYPES[cl_instance.get('pubkey_type')] + target_class = CERT_TYPES[cl_instance.get("pubkey_type")] public_key, data = getattr(_FIELD, target_class[1]).from_decode(data) cl_instance.public_key = public_key @@ -173,6 +174,7 @@ def decode(cls, data: bytes) -> Tuple["CertificateHeader", bytes]: # pylint: disable=too-many-instance-attributes class CertificateFields(Fieldset): """Information fields for the certificate""" + serial: _FIELD.SerialField = _FIELD.SerialField.factory cert_type: _FIELD.CertificateTypeField = _FIELD.CertificateTypeField.factory key_id: _FIELD.KeyIdField = _FIELD.KeyIdField.factory @@ -209,6 +211,7 @@ def __bytes__(self): @dataclass class CertificateFooter(Fieldset): """Footer fields and signature for the certificate""" + reserved: _FIELD.ReservedField = _FIELD.ReservedField.factory ca_pubkey: _FIELD.CAPublicKeyField = _FIELD.CAPublicKeyField.factory signature: _FIELD.SignatureField = _FIELD.SignatureField.factory @@ -292,12 +295,34 @@ def create( subject_pubkey: PublicKey = None, ca_privkey: PrivateKey = None, fields: CertificateFields = CertificateFields, - ): + header: CertificateHeader = CertificateHeader, + footer: CertificateFooter = CertificateFooter, + ) -> "SSHCertificate": + """ + Creates a new certificate from the given parameters. + + Args: + subject_pubkey (PublicKey, optional): The subject public key. Defaults to None. + ca_privkey (PrivateKey, optional): The CA private key. Defaults to None. + fields (CertificateFields, optional): The CertificateFields object containing the + body fields. Defaults to blank CertificateFields. + header (CertificateHeader, optional): The certificate header. + Defaults to new CertificateHeader. + footer (CertificateFooter, optional): The certificate footer. + Defaults to new CertificateFooter. + + Returns: + SSHCertificate: A SSHCertificate subclass depending on the type of subject_pubkey + """ cert_class = subject_pubkey.__class__.__name__.replace( "PublicKey", "Certificate" ) return globals()[cert_class]( - subject_pubkey=subject_pubkey, ca_privkey=ca_privkey, fields=fields + subject_pubkey=subject_pubkey, + ca_privkey=ca_privkey, + fields=fields, + header=header, + footer=footer, ) @classmethod @@ -362,9 +387,22 @@ def from_file(cls, path: str, encoding: str = "utf-8"): Returns: SSHCertificate: SSHCertificate child class """ - return cls.from_string(open(path, "r", encoding=encoding).read()) + with open(path, "r", encoding=encoding) as file: + return cls.from_string(file.read()) def get(self, field: str): + """ + Fetch a field from any of the sections of the certificate. + + Args: + field (str): The field name to fetch + + Raises: + _EX.InvalidCertificateFieldException: Invalid field name provided + + Returns: + mixed: The certificate field contents + """ if field in ( self.header.getattrs() + self.fields.getattrs() + self.footer.getattrs() ): @@ -376,7 +414,20 @@ def get(self, field: str): raise _EX.InvalidCertificateFieldException(f"Unknown field {field}") - def set(self, field: str, value): + def set(self, field: str, value) -> None: + """ + Set a field in any of the sections of the certificate. + + Args: + field (str): The field name to set + value (mixed): The value to set the field to + + Raises: + _EX.InvalidCertificateFieldException: Invalid field name provided + + Returns: + mixed: The certificate field contents + """ if self.fields.get(field, False): setattr(self.fields, field, value) return @@ -391,7 +442,28 @@ def set(self, field: str, value): raise _EX.InvalidCertificateFieldException(f"Unknown field {field}") + def replace_ca(self, ca_privkey: PrivateKey): + """ + Replace the certificate authority private key with a new one. + + Args: + ca_privkey (PrivateKey): The new CA private key + """ + self.footer.ca_pubkey = ca_privkey.public_key + self.footer.replace_field( + "signature", _FIELD.SignatureField.from_object(ca_privkey) + ) + def can_sign(self) -> bool: + """ + Check if the certificate can be signed in its current state. + + Raises: + _EX.SignatureNotPossibleException: Exception if the certificate cannot be signed + + Returns: + bool: True if the certificate can be signed + """ valid_header = self.header.validate() valid_fields = self.fields.validate() check_keys = ( @@ -404,16 +476,12 @@ def can_sign(self) -> bool: ) if (valid_header, valid_fields, check_keys) != (True, True, True): + exceptions = [] + exceptions += valid_header if not isinstance(valid_header, bool) else [] + exceptions += valid_fields if not isinstance(valid_fields, bool) else [] + exceptions += check_keys if not isinstance(check_keys, bool) else [] raise _EX.SignatureNotPossibleException( - "\n".join( - valid_header - if isinstance(valid_header, Exception) - else [] + valid_fields - if isinstance(valid_fields, Exception) - else [] + check_keys - if isinstance(check_keys, Exception) - else [] - ) + "\n".join([str(e) for e in exceptions]) ) return True @@ -441,6 +509,32 @@ def sign(self) -> bool: return True raise _EX.NotSignedException("There was an error while signing the certificate") + def verify( + self, public_key: PublicKey = None, raise_on_error: bool = False + ) -> bool: + """Verify the signature on the certificate to make sure the data is not corrupted, + and that the signature comes from the given public key or the key included in the + certificate (insecure, useful for testing only) + + Args: + public_key (PublicKey, optional): The public key to use for verification + raise_on_error (bool, default False): Raise an exception if the certificate is invalid + + Raises: + _EX.InvalidSignatureException: The signature is invalid + """ + if not public_key: + public_key = self.get("ca_pubkey") + + try: + public_key.verify(self.get_signable(), self.footer.get("signature")) + except _EX.InvalidSignatureException as exception: + if raise_on_error: + raise exception + return False + + return True + def to_string(self, comment: str = "", encoding: str = "utf-8"): """Export the certificate to a string @@ -473,16 +567,19 @@ def to_file(self, filename: str, encoding: str = "utf-8"): class RsaCertificate(SSHCertificate): """The RSA Certificate class""" + DEFAULT_KEY_TYPE = "rsa-sha2-512-cert-v01@openssh.com" class DsaCertificate(SSHCertificate): """The DSA Certificate class""" + DEFAULT_KEY_TYPE = "ssh-dss-cert-v01@openssh.com" class EcdsaCertificate(SSHCertificate): """The ECDSA certificate class""" + DEFAULT_KEY_TYPE = "ecdsa-sha2-nistp[curve_size]-cert-v01@openssh.com" def __post_init__(self): @@ -494,4 +591,5 @@ def __post_init__(self): class Ed25519Certificate(SSHCertificate): """The ED25519 certificate class""" + DEFAULT_KEY_TYPE = "ssh-ed25519-cert-v01@openssh.com" diff --git a/src/sshkey_tools/fields.py b/src/sshkey_tools/fields.py index a23af5e..c4335e8 100644 --- a/src/sshkey_tools/fields.py +++ b/src/sshkey_tools/fields.py @@ -108,6 +108,12 @@ def __bytes__(self) -> bytes: @classmethod def get_name(cls) -> str: + """ + Fetch the name of the field (identifier format) + + Returns: + str: The name/id of the field + """ return "_".join(re.findall("[A-Z][^A-Z]*", cls.__name__)[:-1]).lower() @classmethod @@ -117,7 +123,8 @@ def __validate_type__(cls, value, do_raise: bool = False) -> Union[bool, Excepti """ if not isinstance(value, cls.DATA_TYPE): ex = _EX.InvalidDataException( - f"Invalid data type for {cls.get_name()} (expected {cls.DATA_TYPE}, got {type(value)})" + f"Invalid data type for {cls.get_name()}" + + f"(expected {cls.DATA_TYPE}, got {type(value)})" ) if do_raise: @@ -144,7 +151,7 @@ def __validate_value__(self) -> Union[bool, Exception]: """ return True - # pylint: disable=no-self-use + # pylint: disable=not-callable def validate(self) -> bool: """ Validates all field contents and types @@ -184,6 +191,7 @@ def from_decode(cls, data: bytes) -> Tuple["CertificateField", bytes]: return cls(value), data @classmethod + # pylint: disable=not-callable def factory(cls, blank: bool = False) -> "CertificateField": """ Factory to create field with default value if set, otherwise empty @@ -544,7 +552,7 @@ def encode(cls, value: Union[list, tuple, set]) -> bytes: cls.__validate_type__(value, True) try: - if sum([not isinstance(item, (str, bytes)) for item in value]) > 0: + if sum(not isinstance(item, (str, bytes)) for item in value) > 0: raise TypeError except TypeError: raise _EX.InvalidFieldDataException( @@ -717,9 +725,8 @@ def __validate_value__(self) -> Union[bool, Exception]: if ensure_string(self.value) not in self.ALLOWED_VALUES: return _EX.InvalidFieldDataException( - "Expected one of the following values: {}".format( - NEWLINE.join(self.ALLOWED_VALUES) - ) + "Expected one of the following values: " + + NEWLINE.join(self.ALLOWED_VALUES) ) return True @@ -1045,7 +1052,8 @@ def __validate_value__(self) -> Union[bool, Exception]: if check < datetime.now(): return _EX.InvalidCertificateFieldException( - "The certificate validity period is invalid (expected a future datetime object or timestamp)" + "The certificate validity period is invalid" + + " (expected a future datetime object or timestamp)" ) return True @@ -1414,6 +1422,7 @@ def from_decode(cls, data: bytes) -> Tuple["RsaSignatureField", bytes]: data, ) + # pylint: disable=unused-argument def sign(self, data: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA512, **kwargs) -> None: """ Signs the provided data with the provided private key @@ -1501,6 +1510,7 @@ def from_decode(cls, data: bytes) -> Tuple["DsaSignatureField", bytes]: return cls(private_key=None, signature=signature), data + # pylint: disable=unused-argument def sign(self, data: bytes, **kwargs) -> None: """ Signs the provided data with the provided private key @@ -1599,6 +1609,7 @@ def from_decode(cls, data: bytes) -> Tuple["EcdsaSignatureField", bytes]: data, ) + # pylint: disable=unused-argument def sign(self, data: bytes, **kwargs) -> None: """ Signs the provided data with the provided private key @@ -1678,6 +1689,7 @@ def from_decode(cls, data: bytes) -> Tuple["Ed25519SignatureField", bytes]: return cls(private_key=None, signature=signature), data + # pylint: disable=unused-argument def sign(self, data: bytes, **kwargs) -> None: """ Signs the provided data with the provided private key diff --git a/src/sshkey_tools/keys.py b/src/sshkey_tools/keys.py index bd521d0..4efbf63 100644 --- a/src/sshkey_tools/keys.py +++ b/src/sshkey_tools/keys.py @@ -173,7 +173,7 @@ def from_string( Returns: PublicKey: Any of the PublicKey child classes """ - split = ensure_bytestring(data).split(b" ") + split = ensure_bytestring(data, encoding).split(b" ") comment = None if len(split) > 2: comment = split[2] @@ -200,7 +200,20 @@ def from_file(cls, path: str) -> "PublicKey": return cls.from_string(data) @classmethod + # pylint: disable=broad-except def from_bytes(cls, data: bytes) -> "PublicKey": + """ + Loads a public key from byte data + + Args: + data (bytes): The bytestring containing the public key + + Raises: + _EX.InvalidKeyException: Invalid data input + + Returns: + PublicKey: PublicKey subclass depending on the key type + """ for key_class in PUBKEY_MAP.values(): try: key = globals()[key_class].from_raw_bytes(data) @@ -252,8 +265,8 @@ def to_string(self, encoding: str = "utf-8") -> str: """ return " ".join( [ - ensure_string(self.serialize()), - ensure_string(getattr(self, "comment", "")), + ensure_string(self.serialize(), encoding), + ensure_string(getattr(self, "comment", ""), encoding), ] ) @@ -462,7 +475,7 @@ def verify( hash_method (HashMethods): The hash method to use Raises: - Raises an sshkey_tools.exceptions.InvalidSignatureException if the signature is invalid + Raises a sshkey_tools.exceptions.InvalidSignatureException if the signature is invalid """ try: return self.key.verify( diff --git a/src/sshkey_tools/utils.py b/src/sshkey_tools/utils.py index 5a2803e..c46ba7c 100644 --- a/src/sshkey_tools/utils.py +++ b/src/sshkey_tools/utils.py @@ -24,23 +24,24 @@ def ensure_string( encoding (str, optional): The encoding of the provided strings. Defaults to 'utf-8'. Returns: - Union[str, List[str], Dict[str, str]]: Returns a string, list of strings or dictionary with strings + Union[str, List[str], Dict[str, str]]: Returns a string, list of strings or + dictionary with strings """ if (obj is None and not required) or isinstance(obj, str): return obj - elif isinstance(obj, bytes): + if isinstance(obj, bytes): return obj.decode(encoding) - elif isinstance(obj, (list, tuple, set)): + if isinstance(obj, (list, tuple, set)): return [ensure_string(o, encoding) for o in obj] - elif isinstance(obj, dict): + if isinstance(obj, dict): return { ensure_string(k, encoding): ensure_string(v, encoding) for k, v in obj.items() } - else: - raise TypeError( - f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}." - ) + + raise TypeError( + f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}." + ) def ensure_bytestring( @@ -55,23 +56,23 @@ def ensure_bytestring( encoding (str, optional): The encoding of the provided bytestrings. Defaults to 'utf-8'. Returns: - Union[str, List[str], Dict[str, str]]: Returns a bytestring, list of bytestrings or dictionary with bytestrings + Union[str, List[str], Dict[str, str]]: Returns a bytestring, list of bytestrings or + dictionary with bytestrings """ if (obj is None and not required) or isinstance(obj, bytes): return obj - elif isinstance(obj, str): + if isinstance(obj, str): return obj.encode(encoding) - elif isinstance(obj, (list, tuple, set)): + if isinstance(obj, (list, tuple, set)): return [ensure_bytestring(o, encoding) for o in obj] - elif isinstance(obj, dict): + if isinstance(obj, dict): return { ensure_bytestring(k, encoding): ensure_bytestring(v, encoding) for k, v in obj.items() } - else: - raise TypeError( - f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}." - ) + raise TypeError( + f"Expected one of (str, bytes, list, tuple, dict, set), got {type(obj).__name__}." + ) def concat_to_string(*strs, encoding: str = "utf-8") -> str: diff --git a/testcert b/testcert deleted file mode 100644 index 7ad0607..0000000 --- a/testcert +++ /dev/null @@ -1 +0,0 @@ -ssh-ed25519-cert-v01@openssh.com AAAAIHNzaC1lZDI1NTE5LWNlcnQtdjAxQG9wZW5zc2guY29tAAAAJjc0MzU0MDAxNzgzNDU3MzQxNTU4MDE1NTE3Njg3Mjk0ODY0MzcwAAAAIMxLXknwP4nAHkxKEbmw5CjRhWoGeUTZgYtQXB4tI7SLI11oLXZuTzkAAAACAAAAKml2c09iemxXY2VxR3BLS3NtSlRwZWxqb0RuYVpibW9PSXlPUndiYVBWUAAAAqEAAABheUNZYU1zSEFhbVptSXVYeEtRQ01mbmJBRWVkdk9scFFQVGplendMQ0tLRHpLY0JBU01zSGd6a1V4Z0FGeVF3cFZyZVVoYXdZa09kQ0VyR2pvUWZKWkp2SU1NS1laYXF2QgAAAElWT0h3akRMTVhkUE9lRW9obEZ0dlhPUkR1VHJyTEl1TUZOV2plVnhKTHJ0Ymdud3l6QnVGWEZaZU9PZkFBeVRHamlSVHZIUXpKAAAATHp6RU5uZUt0V09aemxGTUtoU2JySkxjbllyWVN0UlR3RGhDQ3hSS3VxbGRSR3FMRGdkR0NoamtYRWlQdmtnektNd3NWRlBOZ2dnbkEAAABlRHJnd1ZTV013eU9qR1RVVmdTZENuU0RDT21DVVlqY0VHV0t1YlVSekVaWUJTVG1RRkhPU3ZQc2RvYUhTeEFsS2N1TXNQdnl3T2d6aGtPamtTbXpvVXVsV1NvekdqTURvcFNHeWsAAABTeGd0ZmN0YVNwekJtWWVaTkZYemh6YW54cHZNU1JWdUR2Q0JUbGFSRU9QT1JEdVpSd0hBUUdOY2xPU1R3RW9kcVRoQlpYTnVqbFV3d1dQWVBEYkUAAAAxaWtYWWpqcmNmSE9xRnhGQXpZZVN1bmlsRlpuWG9PS0JHd2hzSWNra0RMbXhXTEV6bAAAAEVjVU5laU9FaXNDZ0pBbkRuRVNYc1hlaEFvUEtYZUpWZUdxeklvSG9uUEVyU1RyeUxudlhmUkVka0NDTXlFR1VFRFJucEoAAAAwTlNuUm1aQXlxSFlyTkpBZ21ZbU9LckVuR2RjdWtmUWlHRWZNZ0lqTUNjd1NZTkdJAAAAKUVPd2dqcG1NQUZwam5QdUJVaVVqUkpjUmtOUnJoRFFGbVNRUG9oSHlNAAAAAGlpCcwAAAAAe0tVjAAAAGsAAAANZm9yY2UtY29tbWFuZAAAABEAAAANc2Z0cC1pbnRlcm5hbAAAAA5zb3VyY2UtYWRkcmVzcwAAABgAAAAUMS4yLjMuNC84LDUuNi43LjgvMTYAAAAPdmVyaWZ5LXJlcXVpcmVkAAAAAAAAADwAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAAAAAAMwAAAAtzc2gtZWQyNTUxOQAAACDMS15J8D+JwB5MShG5sOQo0YVqBnlE2YGLUFweLSO0iwAAAFMAAAALc3NoLWVkMjU1MTkAAABAkm6FT2GepxZWvdzBGUt5UQbX5O+ppLp17uQeR4JeaW+ddhpUP6HVLnuyuQZ3p3Ck0jWj5Bbk0nnm5Yo8D2AoDA== \ No newline at end of file diff --git a/tests/test_keypairs.py b/tests/test_keypairs.py index 31f4267..6b60fdb 100644 --- a/tests/test_keypairs.py +++ b/tests/test_keypairs.py @@ -594,11 +594,8 @@ def test_rsa_fingerprint(self): f"tests/{self.folder}/rsa_key_sshkeygen", "password" ) - sshkey_fingerprint = ( - os.popen(f"ssh-keygen -lf tests/{self.folder}/rsa_key_sshkeygen") - .read() - .split(" ")[1] - ) + with os.popen(f"ssh-keygen -lf tests/{self.folder}/rsa_key_sshkeygen") as cmd: + sshkey_fingerprint = cmd.read().split(" ")[1] self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) @@ -607,11 +604,8 @@ def test_dsa_fingerprint(self): f"tests/{self.folder}/dsa_key_sshkeygen", ) - sshkey_fingerprint = ( - os.popen(f"ssh-keygen -lf tests/{self.folder}/dsa_key_sshkeygen") - .read() - .split(" ")[1] - ) + with os.popen(f"ssh-keygen -lf tests/{self.folder}/dsa_key_sshkeygen") as cmd: + sshkey_fingerprint = cmd.read().split(" ")[1] self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) @@ -619,12 +613,8 @@ def test_ecdsa_fingerprint(self): key = EcdsaPrivateKey.from_file( f"tests/{self.folder}/ecdsa_key_sshkeygen", ) - - sshkey_fingerprint = ( - os.popen(f"ssh-keygen -lf tests/{self.folder}/ecdsa_key_sshkeygen") - .read() - .split(" ")[1] - ) + with os.popen(f"ssh-keygen -lf tests/{self.folder}/ecdsa_key_sshkeygen") as cmd: + sshkey_fingerprint = cmd.read().split(" ")[1] self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) @@ -632,12 +622,10 @@ def test_ed25519_fingerprint(self): key = Ed25519PrivateKey.from_file( f"tests/{self.folder}/ed25519_key_sshkeygen", ) - - sshkey_fingerprint = ( - os.popen(f"ssh-keygen -lf tests/{self.folder}/ed25519_key_sshkeygen") - .read() - .split(" ")[1] - ) + with os.popen( + f"ssh-keygen -lf tests/{self.folder}/ed25519_key_sshkeygen" + ) as cmd: + sshkey_fingerprint = cmd.read().split(" ")[1] self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) From 9d38ef40c8b72c890566eefcc79c3f477a5a8505 Mon Sep 17 00:00:00 2001 From: Lars Scheibling Date: Wed, 20 Jul 2022 15:06:02 +0000 Subject: [PATCH 6/6] Updated pylintrc --- .pylintrc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pylintrc b/.pylintrc index 3f55642..ff33277 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,5 +1,5 @@ [MASTER] -; load-plugins=pylint_report +load-plugins=pylint_report [REPORTS] -; output-format=pylint_report.CustomJsonReporter \ No newline at end of file +output-format=pylint_report.CustomJsonReporter \ No newline at end of file