diff --git a/.github/workflows/test_and_lint.yml b/.github/workflows/test_and_lint.yml index d17d2ce..3a35aeb 100644 --- a/.github/workflows/test_and_lint.yml +++ b/.github/workflows/test_and_lint.yml @@ -16,17 +16,17 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: "3.10" + python-version: "3.9" - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install pylint-report + pip install pylint pylint-report - name: Run tests run: | - pylint -f json src/sshkey_tools/ | pylint_report.py > report.html + pylint src/sshkey_tools/ | pylint_report.py > report.html - name: Upload report uses: actions/upload-artifact@v1 @@ -40,7 +40,6 @@ jobs: fail-fast: false matrix: python-version: - - "3.6" - "3.7" - "3.8" - "3.9" @@ -60,15 +59,13 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install coverage + pip install -r requirements-test.txt - name: Run tests run: | - python test.py - python -m coverage run test.py - python -m coverage report - python -m coverage html + coverage run -m unittest discover tests/ + coverage report --omit="tests/*.py" + coverage html - name: Upload report uses: actions/upload-artifact@v1 diff --git a/.gitignore b/.gitignore index e39e2f4..5c054b9 100644 --- a/.gitignore +++ b/.gitignore @@ -142,4 +142,7 @@ test_dss* test_ed25519* test_flow.py temptest.py -oldsrc/* \ No newline at end of file +oldsrc/* +tempfolder +report.html +test_certificate \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..ff33277 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,5 @@ +[MASTER] +load-plugins=pylint_report + +[REPORTS] +output-format=pylint_report.CustomJsonReporter \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..b4811ed --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,9 @@ +click +cryptography +bcrypt +enum34 +paramiko +coverage +black +pytest-cov +faker \ No newline at end of file diff --git a/src/sshkey_tools/__version__.py b/src/sshkey_tools/__version__.py index 3596a2c..55dc504 100644 --- a/src/sshkey_tools/__version__.py +++ b/src/sshkey_tools/__version__.py @@ -2,19 +2,19 @@ import os # The title and description of the package -__title__ = 'sshkey-tools' -__description__ = ''' +__title__ = "sshkey-tools" +__description__ = """ A Python module for generating, parsing and handling OpenSSH keys and certificates -''' +""" # The version and build number # Without specifying a unique number, you cannot overwrite packages in the PyPi repo -__version__ = os.getenv('RELEASE_NAME', '0.0.1-dev' + os.getenv('GITHUB_RUN_ID') ) +__version__ = os.getenv("RELEASE_NAME", "0.0.1-dev" + os.getenv("GITHUB_RUN_ID")) # Author and license information -__author__ = 'Lars Scheibling' -__author_email__ = 'lars@scheibling.se' -__license__ = 'GnuPG 3.0' +__author__ = "Lars Scheibling" +__author_email__ = "lars@scheibling.se" +__license__ = "GnuPG 3.0" # URL to the project __url__ = f"https://github.com/scheiblingco/{__title__}" diff --git a/src/sshkey_tools/cert.py b/src/sshkey_tools/cert.py index 9d04d1c..979540e 100644 --- a/src/sshkey_tools/cert.py +++ b/src/sshkey_tools/cert.py @@ -15,56 +15,77 @@ RSAPublicKey, DSAPublicKey, ECDSAPublicKey, - ED25519PublicKey + ED25519PublicKey, ) from . import fields as _FIELD from . import exceptions as _EX from .keys import RsaAlgs +from .utils import join_dicts CERTIFICATE_FIELDS = { - 'serial': _FIELD.SerialField, - 'cert_type': _FIELD.CertificateTypeField, - 'key_id': _FIELD.KeyIDField, - 'principals': _FIELD.PrincipalsField, - 'valid_after': _FIELD.ValidityStartField, - 'valid_before': _FIELD.ValidityEndField, - 'critical_options': _FIELD.CriticalOptionsField, - 'extensions': _FIELD.ExtensionsField + "serial": _FIELD.SerialField, + "cert_type": _FIELD.CertificateTypeField, + "key_id": _FIELD.KeyIDField, + "principals": _FIELD.PrincipalsField, + "valid_after": _FIELD.ValidityStartField, + "valid_before": _FIELD.ValidityEndField, + "critical_options": _FIELD.CriticalOptionsField, + "extensions": _FIELD.ExtensionsField, } CERT_TYPES = { - 'ssh-rsa-cert-v01@openssh.com': ('RSACertificate', '_FIELD.RSAPubkeyField'), - 'rsa-sha2-256-cert-v01@openssh.com': ('RSACertificate', '_FIELD.RSAPubkeyField'), - 'rsa-sha2-512-cert-v01@openssh.com': ('RSACertificate', '_FIELD.RSAPubkeyField'), - 'ssh-dss-cert-v01@openssh.com': ('DSACertificate', '_FIELD.DSAPubkeyField'), - 'ecdsa-sha2-nistp256-cert-v01@openssh.com': ('ECDSACertificate', '_FIELD.ECDSAPubkeyField'), - 'ecdsa-sha2-nistp384-cert-v01@openssh.com': ('ECDSACertificate', '_FIELD.ECDSAPubkeyField'), - 'ecdsa-sha2-nistp521-cert-v01@openssh.com': ('ECDSACertificate', '_FIELD.ECDSAPubkeyField'), - 'ssh-ed25519-cert-v01@openssh.com': ('ED25519Certificate', '_FIELD.ED25519PubkeyField'), + "ssh-rsa-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), + "rsa-sha2-256-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), + "rsa-sha2-512-cert-v01@openssh.com": ("RSACertificate", "_FIELD.RSAPubkeyField"), + "ssh-dss-cert-v01@openssh.com": ("DSACertificate", "_FIELD.DSAPubkeyField"), + "ecdsa-sha2-nistp256-cert-v01@openssh.com": ( + "ECDSACertificate", + "_FIELD.ECDSAPubkeyField", + ), + "ecdsa-sha2-nistp384-cert-v01@openssh.com": ( + "ECDSACertificate", + "_FIELD.ECDSAPubkeyField", + ), + "ecdsa-sha2-nistp521-cert-v01@openssh.com": ( + "ECDSACertificate", + "_FIELD.ECDSAPubkeyField", + ), + "ssh-ed25519-cert-v01@openssh.com": ( + "ED25519Certificate", + "_FIELD.ED25519PubkeyField", + ), } + class SSHCertificate: """ General class for SSH Certificates, used for loading and parsing. To create new certificates, use the respective keytype classes or the from_public_key classmethod """ + def __init__( self, subject_pubkey: PublicKey = None, ca_privkey: PrivateKey = None, decoded: dict = None, - **kwargs + **kwargs, ) -> None: + if self.__class__.__name__ == "SSHCertificate": + raise _EX.InvalidClassCallException( + "You cannot instantiate SSHCertificate directly. Use \n" + + "one of the child classes, or call via decode, \n" + + "or one of the from_-classmethods" + ) if decoded is not None: - self.signature = decoded.pop('signature') - self.signature_pubkey = decoded.pop('ca_pubkey') + self.signature = decoded.pop("signature") + self.signature_pubkey = decoded.pop("ca_pubkey") self.header = { - 'pubkey_type': decoded.pop('pubkey_type'), - 'nonce': decoded.pop('nonce'), - 'public_key': decoded.pop('public_key') + "pubkey_type": decoded.pop("pubkey_type"), + "nonce": decoded.pop("nonce"), + "public_key": decoded.pop("public_key"), } self.fields = decoded @@ -72,54 +93,61 @@ def __init__( return if subject_pubkey is None: - raise _EX.SSHCertificateException( - "The subject public key is required" - ) + raise _EX.SSHCertificateException("The subject public key is required") self.header = { - 'pubkey_type': _FIELD.PubkeyTypeField, - 'nonce': _FIELD.NonceField(), - 'public_key': _FIELD.PublicKeyField.from_object( - subject_pubkey - ) + "pubkey_type": _FIELD.PubkeyTypeField, + "nonce": _FIELD.NonceField(), + "public_key": _FIELD.PublicKeyField.from_object(subject_pubkey), } - self.signature = _FIELD.SignatureField.from_object( - ca_privkey - ) - - self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( - ca_privkey.public_key - ) + if ca_privkey is not None: + self.signature = _FIELD.SignatureField.from_object(ca_privkey) + self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( + ca_privkey.public_key + ) self.fields = dict(CERTIFICATE_FIELDS) self.set_opts(**kwargs) def __str__(self): - principals = '\n' + '\n'.join( - ''.join([' ']*32) + ( - x.decode('utf-8') if isinstance(x, bytes) else x - ) for x in self.fields['principals'].value - ) if len(self.fields['principals'].value) > 0 else 'none' - - critical = '\n' + '\n'.join( - ''.join([' ']*32) + ( - x.decode('utf-8') if isinstance(x, bytes) else x - ) for x in self.fields['critical_options'].value - ) if len(self.fields['critical_options'].value) > 0 else 'none' - - extensions = '\n' + '\n'.join( - ''.join([' ']*32) + ( - x.decode('utf-8') if isinstance(x, bytes) else x - ) for x in self.fields['extensions'].value - ) if len(self.fields['extensions'].value) > 0 else 'none' - - signature_val = b64encode(self.signature.value).decode('utf-8') if isinstance( - self.signature.value, - bytes - ) else "Not signed" - - return f''' + principals = ( + "\n" + + "\n".join( + "".join([" "] * 32) + (x.decode("utf-8") if isinstance(x, bytes) else x) + for x in self.fields["principals"].value + ) + if len(self.fields["principals"].value) > 0 + else "none" + ) + + critical = ( + "\n" + + "\n".join( + "".join([" "] * 32) + (x.decode("utf-8") if isinstance(x, bytes) else x) + for x in self.fields["critical_options"].value + ) + if len(self.fields["critical_options"].value) > 0 + else "none" + ) + + extensions = ( + "\n" + + "\n".join( + "".join([" "] * 32) + (x.decode("utf-8") if isinstance(x, bytes) else x) + for x in self.fields["extensions"].value + ) + if len(self.fields["extensions"].value) > 0 + else "none" + ) + + signature_val = ( + b64encode(self.signature.value).decode("utf-8") + if isinstance(self.signature.value, bytes) + else "Not signed" + ) + + return f""" Certificate: Pubkey Type: {self.header['pubkey_type'].value} Public Key: {str(self.header['public_key'])} @@ -132,10 +160,12 @@ def __str__(self): Critical options: {critical} Extensions: {extensions} Signature: {signature_val} - ''' + """ @staticmethod - def decode(cert_bytes: bytes, pubkey_class: _FIELD.PublicKeyField = None) -> 'SSHCertificate': + def decode( + cert_bytes: bytes, pubkey_class: _FIELD.PublicKeyField = None + ) -> "SSHCertificate": """ Decode an existing certificate and import it into a new object @@ -151,48 +181,53 @@ def decode(cert_bytes: bytes, pubkey_class: _FIELD.PublicKeyField = None) -> 'SS SSHCertificate: SSHCertificate child class """ if pubkey_class is None: - cert_type = _FIELD.StringField.decode(cert_bytes)[0].encode('utf-8') + cert_type = _FIELD.StringField.decode(cert_bytes)[0].encode("utf-8") pubkey_class = CERT_TYPES.get(cert_type, False) if pubkey_class is False: raise _EX.InvalidCertificateFormatException( - "Could not determine certificate type, please use one " + - "of the specific classes or specify the pubkey_class" + "Could not determine certificate type, please use one " + + "of the specific classes or specify the pubkey_class" ) - decode_fields = { - 'pubkey_type': _FIELD.PubkeyTypeField, - 'nonce': _FIELD.NonceField, - 'public_key': pubkey_class, - } | CERTIFICATE_FIELDS | { - 'reserved': _FIELD.ReservedField, - 'ca_pubkey': _FIELD.CAPublicKeyField, - 'signature': _FIELD.SignatureField - } + decode_fields = join_dicts( + { + "pubkey_type": _FIELD.PubkeyTypeField, + "nonce": _FIELD.NonceField, + "public_key": pubkey_class, + }, + CERTIFICATE_FIELDS, + { + "reserved": _FIELD.ReservedField, + "ca_pubkey": _FIELD.CAPublicKeyField, + "signature": _FIELD.SignatureField, + }, + ) cert = {} for item in decode_fields.keys(): cert[item], cert_bytes = decode_fields[item].from_decode(cert_bytes) - if cert_bytes != b'': + if cert_bytes != b"": raise _EX.InvalidCertificateFormatException( "The certificate has additional data after everything has been extracted" ) - pubkey_type = cert['pubkey_type'].value + pubkey_type = cert["pubkey_type"].value if isinstance(pubkey_type, bytes): - pubkey_type = pubkey_type.decode('utf-8') + pubkey_type = pubkey_type.decode("utf-8") cert_type = CERT_TYPES[pubkey_type] - + cert.pop("reserved") return globals()[cert_type[0]]( - subject_pubkey=cert['public_key'].value, - decoded=cert + subject_pubkey=cert["public_key"].value, decoded=cert ) @classmethod - def from_public_class(cls, public_key: PublicKey, **kwargs) -> 'SSHCertificate': + def from_public_class( + cls, public_key: PublicKey, ca_privkey: PrivateKey = None, **kwargs + ) -> "SSHCertificate": """ Creates a new certificate from a supplied public key @@ -203,11 +238,8 @@ def from_public_class(cls, public_key: PublicKey, **kwargs) -> 'SSHCertificate': SSHCertificate: SSHCertificate child class """ return globals()[ - public_key.__class__.__name__.replace('PublicKey', 'Certificate') - ]( - public_key, - **kwargs - ) + public_key.__class__.__name__.replace("PublicKey", "Certificate") + ](public_key, ca_privkey, **kwargs) @classmethod def from_bytes(cls, cert_bytes: bytes): @@ -221,11 +253,11 @@ def from_bytes(cls, cert_bytes: bytes): SSHCertificate: SSHCertificate child class """ cert_type, _ = _FIELD.StringField.decode(cert_bytes) - target_class = CERT_TYPES[cert_type.decode('utf-8')] + target_class = CERT_TYPES[cert_type] return globals()[target_class[0]].decode(cert_bytes) @classmethod - def from_string(cls, cert_str: Union[str, bytes], encoding: str = 'utf-8'): + def from_string(cls, cert_str: Union[str, bytes], encoding: str = "utf-8"): """ Loads an existing certificate from a string in the format [certificate-type] [base64-encoded-certificate] [optional-comment] @@ -240,15 +272,11 @@ def from_string(cls, cert_str: Union[str, bytes], encoding: str = 'utf-8'): if isinstance(cert_str, str): cert_str = cert_str.encode(encoding) - certificate = b64decode( - cert_str.split(b' ')[1] - ) - return cls.from_bytes( - cert_bytes=certificate - ) + certificate = b64decode(cert_str.split(b" ")[1]) + return cls.from_bytes(cert_bytes=certificate) @classmethod - def from_file(cls, path: str, encoding: str = 'utf-8'): + def from_file(cls, path: str, encoding: str = "utf-8"): """ Loads an existing certificate from a file @@ -259,8 +287,18 @@ def from_file(cls, path: str, encoding: str = 'utf-8'): Returns: SSHCertificate: SSHCertificate child class """ - return cls.from_string( - open(path, 'r', encoding=encoding).read() + return cls.from_string(open(path, "r", encoding=encoding).read()) + + def set_ca(self, ca_privkey: PrivateKey): + """ + Set the CA Private Key for signing the certificate + + Args: + ca_privkey (PrivateKey): The CA private key + """ + self.signature = _FIELD.SignatureField.from_object(ca_privkey) + self.signature_pubkey = _FIELD.CAPublicKeyField.from_object( + ca_privkey.public_key ) def set_type(self, pubkey_type: str): @@ -271,10 +309,8 @@ def set_type(self, pubkey_type: str): Args: pubkey_type (str): Public key type, e.g. ssh-rsa-cert-v01@openssh.com """ - if not getattr(self.header['pubkey_type'], 'value', False): - self.header['pubkey_type'] = self.header['pubkey_type']( - pubkey_type - ) + if not getattr(self.header["pubkey_type"], "value", False): + self.header["pubkey_type"] = self.header["pubkey_type"](pubkey_type) def set_opt(self, key: str, value): """ @@ -289,11 +325,11 @@ def set_opt(self, key: str, value): """ if key not in self.fields: raise _EX.InvalidCertificateFieldException( - f'{key} is not a valid certificate field' + f"{key} is not a valid certificate field" ) try: - if self.fields[key].value not in [None, False, '', [], ()]: + if self.fields[key].value not in [None, False, "", [], ()]: self.fields[key].value = value except AttributeError: self.fields[key] = self.fields[key](value) @@ -305,6 +341,7 @@ def set_opts(self, **kwargs): for key, value in kwargs.items(): self.set_opt(key, value) + # pylint: disable=used-before-assignment def can_sign(self) -> bool: """ Determine if the certificate is ready to be signed @@ -316,22 +353,36 @@ def can_sign(self) -> bool: Returns: bool: True/False if the certificate can be signed """ - can_sign = [ - x.validate() for x in self.fields.values() - ] + exceptions = [] + for field in self.fields.values(): + try: + valid = field.validate() + except TypeError: + valid = _EX.SignatureNotPossibleException( + f"The field {field} is missing a value" + ) + finally: + if isinstance(valid, Exception): + exceptions.append(valid) + + if ( + getattr(self, "signature", False) is False + or getattr(self, "signature_pubkey", False) is False + ): + exceptions.append( + _EX.SignatureNotPossibleException("No CA private key is set") + ) - if self.signature.can_sign() is True and can_sign: - return True + if len(exceptions) > 0: + raise _EX.SignatureNotPossibleException(exceptions) - for item in self.fields.values(): - if isinstance(item, Exception): - raise item + if self.signature.can_sign() is True: + return True - raise _EX.NoPrivateKeyException( - "The certificate cannot be signed, the private key is not loaded" + raise _EX.SignatureNotPossibleException( + "The certificate cannot be signed, the CA private key is not loaded" ) - def get_signable_data(self) -> bytes: """ Gets the signable byte string from the certificate fields @@ -339,11 +390,16 @@ def get_signable_data(self) -> bytes: Returns: bytes: The data in the certificate which is signed """ - return b''.join([ - bytes(x) for x in - tuple(self.header.values()) + - tuple(self.fields.values()) - ]) + bytes(_FIELD.ReservedField()) + bytes(self.signature_pubkey) + return ( + b"".join( + [ + bytes(x) + for x in tuple(self.header.values()) + tuple(self.fields.values()) + ] + ) + + bytes(_FIELD.ReservedField()) + + bytes(self.signature_pubkey) + ) def sign(self): """ @@ -353,9 +409,7 @@ def sign(self): SSHCertificate: The signed certificate class """ if self.can_sign(): - self.signature.sign( - data=self.get_signable_data() - ) + self.signature.sign(data=self.get_signable_data()) return self @@ -370,14 +424,13 @@ def to_bytes(self) -> bytes: bytes: The certificate bytes """ if self.signature.is_signed is True: - return ( - self.get_signable_data() + - bytes(self.signature) - ) + return self.get_signable_data() + bytes(self.signature) raise _EX.NotSignedException("The certificate has not been signed") - def to_string(self, comment: Union[str, bytes] = None, encoding: str = 'utf-8') -> str: + def to_string( + self, comment: Union[str, bytes] = None, encoding: str = "utf-8" + ) -> str: """ Export the signed certificate to a string, ready to be written to file @@ -389,16 +442,18 @@ def to_string(self, comment: Union[str, bytes] = None, encoding: str = 'utf-8') str: Certificate string """ return ( - self.header['pubkey_type'].value.encode(encoding) + - b' ' + - b64encode( + self.header["pubkey_type"].value.encode(encoding) + + b" " + + b64encode( self.to_bytes(), - ) + - b' ' + - (comment if comment else b'') - ).decode('utf-8') + ) + + b" " + + (comment if comment else b"") + ).decode("utf-8") - def to_file(self, path: str, comment: Union[str, bytes] = None, encoding: str = 'utf-8'): + def to_file( + self, path: str, comment: Union[str, bytes] = None, encoding: str = "utf-8" + ): """ Saves the certificate to a file @@ -408,15 +463,15 @@ def to_file(self, path: str, comment: Union[str, bytes] = None, encoding: str = Defaults to None. encoding (str, optional): Encoding for the file. Defaults to 'utf-8'. """ - with open(path, 'w', encoding=encoding) as file: - file.write( - self.to_string(comment, encoding) - ) + with open(path, "w", encoding=encoding) as file: + file.write(self.to_string(comment, encoding)) + class RSACertificate(SSHCertificate): """ Specific class for RSA Certificates. Inherits from SSHCertificate """ + def __init__( self, subject_pubkey: RSAPublicKey, @@ -427,11 +482,11 @@ def __init__( super().__init__(subject_pubkey, ca_privkey, **kwargs) self.rsa_alg = rsa_alg - self.set_type(f'{rsa_alg.value[0]}-cert-v01@openssh.com') + self.set_type(f"{rsa_alg.value[0]}-cert-v01@openssh.com") @classmethod # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> 'SSHCertificate': + def decode(cls, cert_bytes: bytes) -> "SSHCertificate": """ Decode an existing RSA Certificate @@ -441,27 +496,23 @@ def decode(cls, cert_bytes: bytes) -> 'SSHCertificate': Returns: RSACertificate: The decoded certificate """ - return super().decode( - cert_bytes, - _FIELD.RSAPubkeyField - ) + return super().decode(cert_bytes, _FIELD.RSAPubkeyField) + class DSACertificate(SSHCertificate): """ Specific class for DSA/DSS Certificates. Inherits from SSHCertificate """ + def __init__( - self, - subject_pubkey: DSAPublicKey, - ca_privkey: PrivateKey = None, - **kwargs + self, subject_pubkey: DSAPublicKey, ca_privkey: PrivateKey = None, **kwargs ): super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.set_type('ssh-dss-cert-v01@openssh.com') + self.set_type("ssh-dss-cert-v01@openssh.com") @classmethod # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> 'DSACertificate': + def decode(cls, cert_bytes: bytes) -> "DSACertificate": """ Decode an existing DSA Certificate @@ -471,29 +522,25 @@ def decode(cls, cert_bytes: bytes) -> 'DSACertificate': Returns: DSACertificate: The decoded certificate """ - return super().decode( - cert_bytes, - _FIELD.DSAPubkeyField - ) + return super().decode(cert_bytes, _FIELD.DSAPubkeyField) + class ECDSACertificate(SSHCertificate): """ Specific class for ECDSA Certificates. Inherits from SSHCertificate """ + def __init__( - self, - subject_pubkey: ECDSAPublicKey, - ca_privkey: PrivateKey = None, - **kwargs + self, subject_pubkey: ECDSAPublicKey, ca_privkey: PrivateKey = None, **kwargs ): super().__init__(subject_pubkey, ca_privkey, **kwargs) self.set_type( - f'ecdsa-sha2-nistp{subject_pubkey.key.curve.key_size}-cert-v01@openssh.com' + f"ecdsa-sha2-nistp{subject_pubkey.key.curve.key_size}-cert-v01@openssh.com" ) @classmethod # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> 'ECDSACertificate': + def decode(cls, cert_bytes: bytes) -> "ECDSACertificate": """ Decode an existing ECDSA Certificate @@ -503,29 +550,23 @@ def decode(cls, cert_bytes: bytes) -> 'ECDSACertificate': Returns: ECDSACertificate: The decoded certificate """ - return super().decode( - cert_bytes, - _FIELD.ECDSAPubkeyField - ) + return super().decode(cert_bytes, _FIELD.ECDSAPubkeyField) + class ED25519Certificate(SSHCertificate): """ Specific class for ED25519 Certificates. Inherits from SSHCertificate """ + def __init__( - self, - subject_pubkey: ED25519PublicKey, - ca_privkey: PrivateKey = None, - **kwargs + self, subject_pubkey: ED25519PublicKey, ca_privkey: PrivateKey = None, **kwargs ): super().__init__(subject_pubkey, ca_privkey, **kwargs) - self.set_type( - 'ssh-ed25519-cert-v01@openssh.com' - ) + self.set_type("ssh-ed25519-cert-v01@openssh.com") @classmethod # pylint: disable=arguments-differ - def decode(cls, cert_bytes: bytes) -> 'ED25519Certificate': + def decode(cls, cert_bytes: bytes) -> "ED25519Certificate": """ Decode an existing ED25519 Certificate @@ -535,7 +576,4 @@ def decode(cls, cert_bytes: bytes) -> 'ED25519Certificate': Returns: ED25519Certificate: The decoded certificate """ - return super().decode( - cert_bytes, - _FIELD.ED25519PubkeyField - ) + return super().decode(cert_bytes, _FIELD.ED25519PubkeyField) diff --git a/src/sshkey_tools/exceptions.py b/src/sshkey_tools/exceptions.py index c6d8ae9..c71e508 100644 --- a/src/sshkey_tools/exceptions.py +++ b/src/sshkey_tools/exceptions.py @@ -2,39 +2,46 @@ Exceptions thrown by sshkey_tools """ + class InvalidKeyException(ValueError): """ Raised when a key is invalid. """ + class InvalidFieldDataException(ValueError): """ Raised when a field contains invalid data """ + class InvalidCurveException(ValueError): """ Raised when the ECDSA curve is not supported. """ + class InvalidHashException(ValueError): """ Raised when the hash type is not available """ + class InvalidDataException(ValueError): """ Raised when the data passed to a function is invalid """ + class InvalidCertificateFieldException(KeyError): """ Raised when the certificate field is not found/not editable """ + class InsecureNonceException(ValueError): """ Raised when the nonce is too short to be secure. @@ -42,11 +49,13 @@ class InsecureNonceException(ValueError): https://billatnapier.medium.com/ecdsa-weakness-where-nonces-are-reused-2be63856a01a """ + class IntegerOverflowException(ValueError): """ Raised when the integer is too large to be represented """ + class SignatureNotPossibleException(ValueError): """ Raised when the signature of a certificate is not possible, @@ -54,17 +63,20 @@ class SignatureNotPossibleException(ValueError): field is empty. """ + class NotSignedException(ValueError): """ Raised when trying to export a certificate that has not been signed by a private key """ + class InvalidCertificateFormatException(ValueError): """ Raised when the format of the certificate is invalid """ + class InvalidKeyFormatException(ValueError): """ Raised when the format of the chosen key is invalid, @@ -72,12 +84,26 @@ class InvalidKeyFormatException(ValueError): a public key or vice versa """ + class NoPrivateKeyException(ValueError): """ Raised when no private key is present to sign with """ + class SSHCertificateException(ValueError): """ Raised when the SSH Certificate is invalid """ + + +class InvalidSignatureException(ValueError): + """ + Raised when the signature checked is invalid + """ + + +class InvalidClassCallException(ValueError): + """ + Raised when trying to instantiate a parent class + """ diff --git a/src/sshkey_tools/fields.py b/src/sshkey_tools/fields.py index d754316..66f97ea 100644 --- a/src/sshkey_tools/fields.py +++ b/src/sshkey_tools/fields.py @@ -9,7 +9,7 @@ from base64 import b64encode from cryptography.hazmat.primitives.asymmetric.utils import ( decode_dss_signature, - encode_dss_signature + encode_dss_signature, ) from . import exceptions as _EX from .keys import ( @@ -23,60 +23,60 @@ ECDSAPublicKey, ECDSAPrivateKey, ED25519PublicKey, - ED25519PrivateKey + ED25519PrivateKey, ) -from .utils import ( - long_to_bytes, - bytes_to_long, - generate_secure_nonce -) +from .utils import long_to_bytes, bytes_to_long, generate_secure_nonce MAX_INT32 = 2**32 MAX_INT64 = 2**64 ECDSA_CURVE_MAP = { - 'secp256r1': 'nistp256', - 'secp384r1': 'nistp384', - 'secp521r1': 'nistp521' + "secp256r1": "nistp256", + "secp384r1": "nistp384", + "secp521r1": "nistp521", } SUBJECT_PUBKEY_MAP = { - RSAPublicKey: 'RSAPubkeyField', - DSAPublicKey: 'DSAPubkeyField', - ECDSAPublicKey: 'ECDSAPubkeyField', - ED25519PublicKey: 'ED25519PubkeyField' + RSAPublicKey: "RSAPubkeyField", + DSAPublicKey: "DSAPubkeyField", + ECDSAPublicKey: "ECDSAPubkeyField", + ED25519PublicKey: "ED25519PubkeyField", } CA_SIGNATURE_MAP = { - RSAPrivateKey: 'RSASignatureField', - DSAPrivateKey: 'DSASignatureField', - ECDSAPrivateKey: 'ECDSASignatureField', - ED25519PrivateKey: 'ED25519SignatureField' + RSAPrivateKey: "RSASignatureField", + DSAPrivateKey: "DSASignatureField", + ECDSAPrivateKey: "ECDSASignatureField", + ED25519PrivateKey: "ED25519SignatureField", } SIGNATURE_TYPE_MAP = { - b'rsa': 'RSASignatureField', - b'dss': 'DSASignatureField', - b'ecdsa': 'ECDSASignatureField', - b'ed25519': 'ED25519SignatureField' + b"rsa": "RSASignatureField", + b"dss": "DSASignatureField", + b"ecdsa": "ECDSASignatureField", + b"ed25519": "ED25519SignatureField", } + class CERT_TYPE(Enum): """ Certificate types, User certificate/Host certificate """ + USER = 1 HOST = 2 + class CertificateField: """ The base class for certificate fields """ + is_set = None - def __init__(self, value, name = None): + def __init__(self, value, name=None): self.name = name self.value = value self.exception = None @@ -108,7 +108,7 @@ def validate(self) -> Union[bool, Exception]: return True @classmethod - def from_decode(cls, data: bytes) -> Tuple['CertificateField', bytes]: + def from_decode(cls, data: bytes) -> Tuple["CertificateField", bytes]: """ Creates a field class based on encoded bytes @@ -118,10 +118,12 @@ def from_decode(cls, data: bytes) -> Tuple['CertificateField', bytes]: value, data = cls.decode(data) return cls(value), data + class BooleanField(CertificateField): """ Field representing a boolean value (True/False) """ + @staticmethod def encode(value: bool) -> bytes: """ @@ -133,7 +135,12 @@ def encode(value: bool) -> bytes: Returns: bytes: Packed byte representing the boolean """ - return pack('B', 1 if value else 0) + if not isinstance(value, bool): + raise _EX.InvalidFieldDataException( + f"Expected bool, got {value.__class__.__name__}" + ) + + return pack("B", 1 if value else 0) @staticmethod def decode(data: bytes) -> Tuple[bool, bytes]: @@ -143,7 +150,7 @@ def decode(data: bytes) -> Tuple[bool, bytes]: Args: data (bytes): The byte string starting with an encoded boolean """ - return bool(unpack('B', data[:1])[0]), data[1:] + return bool(unpack("B", data[:1])[0]), data[1:] def validate(self) -> Union[bool, Exception]: """ @@ -156,12 +163,14 @@ def validate(self) -> Union[bool, Exception]: return True -class StringField(CertificateField): + +class BytestringField(CertificateField): """ - Field representing a string value + Field representing a bytestring value """ + @staticmethod - def encode(value: Union[str, bytes], encoding: str = 'utf-8') -> bytes: + def encode(value: bytes) -> bytes: """ Encodes a string or bytestring into a packed byte string @@ -172,16 +181,15 @@ def encode(value: Union[str, bytes], encoding: str = 'utf-8') -> bytes: Returns: bytes: Packed byte string containing the source data """ - if isinstance(value, str): - value = value.encode(encoding) - - if isinstance(value, bytes): - return pack('>I', len(value)) + value + if not isinstance(value, bytes): + raise _EX.InvalidFieldDataException( + f"Expected bytes, got {value.__class__.__name__}" + ) - raise _EX.InvalidDataException(f"Expected unicode or bytes, got {type(value).__name__}.") + return pack(">I", len(value)) + value @staticmethod - def decode(data: bytes) -> Tuple[str, bytes]: + def decode(data: bytes) -> Tuple[bytes, bytes]: """ Unpacks the next string from a packed byte string @@ -189,29 +197,80 @@ def decode(data: bytes) -> Tuple[str, bytes]: data (bytes): The packed byte string to unpack Returns: - tuple(str, bytes): The next string from the packed byte + tuple(bytes, bytes): The next block of bytes from the packed byte string and remainder of the data """ - length = unpack('>I', data[:4])[0] + 4 + length = unpack(">I", data[:4])[0] + 4 return data[4:length], data[length:] def validate(self) -> Union[bool, Exception]: """ Validate the field data """ - if not isinstance(self.value, Union[str, bytes]): + if not isinstance(self.value, bytes): return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not a string or bytestring" + f"Passed value type ({type(self.value)}) is not a bytestring" ) return True +class StringField(BytestringField): + """ + Field representing a string value + """ + + @staticmethod + def encode(value: str, encoding: str = "utf-8"): + """ + Encodes a string or bytestring into a packed byte string + + Args: + value (Union[str, bytes]): The string/bytestring to encode + encoding (str): The encoding to user for the string + + Returns: + bytes: Packed byte string containing the source data + """ + if not isinstance(value, str): + raise _EX.InvalidFieldDataException( + f"Expected str, got {value.__class__.__name__}" + ) + return BytestringField.encode(value.encode(encoding)) + + @staticmethod + def decode(data: bytes, encoding: str = "utf-8") -> Tuple[str, bytes]: + """ + Unpacks the next string from a packed byte string + + Args: + data (bytes): The packed byte string to unpack + + Returns: + tuple(bytes, bytes): The next block of bytes from the packed byte + string and remainder of the data + """ + value, data = BytestringField.decode(data) + + return value.decode(encoding), data + + def validate(self) -> Union[bool, Exception]: + """ + Validate the field data + """ + if not isinstance(self.value, str): + return _EX.InvalidFieldDataException( + f"Passed value type ({type(self.value)}) is not a string" + ) + + return True + class Integer32Field(CertificateField): """ Certificate field representing a 32-bit integer """ + @staticmethod def encode(value: int) -> bytes: """Encodes a 32-bit integer value to a packed byte string @@ -223,9 +282,11 @@ def encode(value: int) -> bytes: bytes: Packed byte string containing integer """ if not isinstance(value, int): - raise _EX.InvalidDataException(f"Expected integer, got {type(value).__name__}.") + raise _EX.InvalidFieldDataException( + f"Expected int, got {value.__class__.__name__}" + ) - return pack('>I', value) + return pack(">I", value) @staticmethod def decode(data: bytes) -> Tuple[int, bytes]: @@ -237,7 +298,7 @@ def decode(data: bytes) -> Tuple[int, bytes]: Returns: tuple: Tuple with integer and remainder of data """ - return int(unpack('>I', data[:4])[0]), data[4:] + return int(unpack(">I", data[:4])[0]), data[4:] def validate(self) -> Union[bool, Exception]: """ @@ -255,10 +316,12 @@ def validate(self) -> Union[bool, Exception]: return True + class Integer64Field(CertificateField): """ Certificate field representing a 64-bit integer """ + @staticmethod def encode(value: int) -> bytes: """Encodes a 64-bit integer value to a packed byte string @@ -270,9 +333,11 @@ def encode(value: int) -> bytes: bytes: Packed byte string containing integer """ if not isinstance(value, int): - raise _EX.InvalidDataException(f"Expected integer, got {type(value).__name__}.") + raise _EX.InvalidFieldDataException( + f"Expected int, got {value.__class__.__name__}" + ) - return pack('>Q', value) + return pack(">Q", value) @staticmethod def decode(data: bytes) -> Tuple[int, bytes]: @@ -284,7 +349,7 @@ def decode(data: bytes) -> Tuple[int, bytes]: Returns: tuple: Tuple with integer and remainder of data """ - return int(unpack('>Q', data[:8])[0]), data[8:] + return int(unpack(">Q", data[:8])[0]), data[8:] def validate(self) -> Union[bool, Exception]: """ @@ -302,39 +367,65 @@ def validate(self) -> Union[bool, Exception]: return True + class DateTimeField(Integer64Field): """ Certificate field representing a datetime value. The value is saved as a 64-bit integer (unix timestamp) """ + @staticmethod - def encode(value: datetime) -> bytes: - return Integer64Field.encode(int(value.timestamp())) + def encode(value: Union[datetime, int]) -> bytes: + """Encodes a datetime object to a byte string + + Args: + value (datetime): Datetime object + + Returns: + bytes: Packed byte string containing datetime timestamp + """ + if not isinstance(value, (datetime, int)): + raise _EX.InvalidFieldDataException( + f"Expected datetime, got {value.__class__.__name__}" + ) + + if isinstance(value, datetime): + value = int(value.timestamp()) + + return Integer64Field.encode(value) @staticmethod def decode(data: bytes) -> datetime: + """Decodes a datetime object from a block of bytes + + Args: + data (bytes): Block of bytes containing a datetime object + + Returns: + tuple: Tuple with datetime and remainder of data + """ timestamp, data = Integer64Field.decode(data) - return datetime.fromtimestamp( - timestamp - ), data + return datetime.fromtimestamp(timestamp), data def validate(self) -> Union[bool, Exception]: """ Validate the field data """ - if not isinstance(self.value, datetime): + if not isinstance(self.value, (datetime, int)): return _EX.InvalidFieldDataException( f"Passed value type ({type(self.value)}) is not a datetime object" ) return True -class MpIntegerField(StringField): + +class MpIntegerField(BytestringField): """ Certificate field representing a multiple precision integer, an integer too large to fit in 64 bits. """ + @staticmethod # pylint: disable=arguments-differ def encode(value: int) -> bytes: @@ -348,9 +439,12 @@ def encode(value: int) -> bytes: Returns: bytes: Packed byte string containing integer """ - return StringField.encode( - long_to_bytes(value) - ) + if not isinstance(value, int): + raise _EX.InvalidFieldDataException( + f"Expected int, got {value.__class__.__name__}" + ) + + return BytestringField.encode(long_to_bytes(value)) @staticmethod def decode(data: bytes) -> Tuple[int, bytes]: @@ -362,7 +456,7 @@ def decode(data: bytes) -> Tuple[int, bytes]: Returns: tuple: Tuple with integer and remainder of data """ - mpint, data = StringField.decode(data) + mpint, data = BytestringField.decode(data) return bytes_to_long(mpint), data def validate(self) -> Union[bool, Exception]: @@ -376,12 +470,14 @@ def validate(self) -> Union[bool, Exception]: return True -class StandardListField(CertificateField): + +class ListField(CertificateField): """ Certificate field representing a list or tuple of strings """ + @staticmethod - def encode(value: Union[list, tuple]) -> bytes: + def encode(value: Union[list, tuple, set]) -> bytes: """Encodes a list or tuple to a byte string Args: @@ -391,16 +487,20 @@ def encode(value: Union[list, tuple]) -> bytes: Returns: bytes: Packed byte string containing the source data """ - if sum([ not isinstance(item, Union[str, bytes]) for item in value]) > 0: - raise TypeError("Expected list or tuple containing strings or bytes") - - return StringField.encode( - b''.join( - [ - StringField.encode(x) for x in value - ] + if not isinstance(value, (list, tuple, set)): + raise _EX.InvalidFieldDataException( + f"Expected (list, tuple, set), got {value.__class__.__name__}" ) - ) + + try: + if sum([not isinstance(item, (str, bytes)) for item in value]) > 0: + raise TypeError + except TypeError: + raise _EX.InvalidFieldDataException( + "Expected list or tuple containing strings or bytes" + ) from TypeError + + return BytestringField.encode(b"".join([StringField.encode(x) for x in value])) @staticmethod def decode(data: bytes) -> Tuple[list, bytes]: @@ -411,7 +511,7 @@ def decode(data: bytes) -> Tuple[list, bytes]: Returns: tuple: _description_ """ - list_bytes, data = StringField.decode(data) + list_bytes, data = BytestringField.decode(data) decoded = [] while len(list_bytes) > 0: @@ -424,46 +524,65 @@ def validate(self) -> Union[bool, Exception]: """ Validate the field data """ - if not isinstance(self.value, Union[list, tuple]): + if not isinstance(self.value, (list, tuple)): return _EX.InvalidFieldDataException( f"Passed value type ({type(self.value)}) is not a list/tuple" ) + if sum([not isinstance(item, (str, bytes)) for item in self.value]) > 0: + return _EX.InvalidFieldDataException( + "Expected list or tuple containing strings or bytes" + ) + return True -class SeparatedListField(CertificateField): + +class KeyValueField(CertificateField): """ Certificate field representing a list or integer in python, separated in byte-form by null-bytes. """ @staticmethod - def encode(value: Union[list, tuple]) -> bytes: + def encode(value: Union[list, tuple, dict, set]) -> bytes: """ - Encodes a list or tuple to a byte string separated by a null byte + Encodes a dict, set, list or tuple into a key-value byte string. + If a set, list or tuple is provided, the items are considered keys + and added with empty values. Args: - source_list (list): list of strings + source_list (dict, set, list, tuple): list of strings Returns: bytes: Packed byte string containing the source data """ - if sum([ not isinstance(item, Union[str, bytes]) for item in value ]) > 0: - raise TypeError("Expected list or tuple containing strings or bytes") + if not isinstance(value, (list, tuple, dict, set)): + raise _EX.InvalidFieldDataException( + f"Expected (list, tuple), got {value.__class__.__name__}" + ) - if len(value) < 1: - return StandardListField.encode(value) + if not isinstance(value, dict): + value = {item: "" for item in value} - null_byte = StringField.encode('') + list_data = b"" - return StringField.encode( - null_byte.join( - StringField.encode(item) for item in value - ) + null_byte - ) + for key, item in value.items(): + list_data += StringField.encode(key) + + item = ( + StringField.encode(item) + if item in ["", b""] + else ListField.encode( + [item] if isinstance(item, (str, bytes)) else item + ) + ) + + list_data += item + + return BytestringField.encode(list_data) @staticmethod - def decode(data: bytes) -> Tuple[list, bytes]: + def decode(data: bytes) -> Tuple[dict, bytes]: """Decodes a list of strings from a block of bytes Args: @@ -471,14 +590,20 @@ def decode(data: bytes) -> Tuple[list, bytes]: Returns: tuple: _description_ """ - list_bytes, data = StringField.decode(data) + list_bytes, data = BytestringField.decode(data) - decoded = [] + decoded = {} while len(list_bytes) > 0: - elem, list_bytes = StringField.decode(list_bytes) + key, list_bytes = StringField.decode(list_bytes) + value, list_bytes = BytestringField.decode(list_bytes) + + if value != b"": + value = StringField.decode(value)[0] - if elem != b'': - decoded.append(elem) + decoded[key] = "" if value == b"" else value + + if "".join(decoded.values()) == "": + return list(decoded.keys()), data return decoded, data @@ -486,23 +611,42 @@ def validate(self) -> Union[bool, Exception]: """ Validate the field data """ - if not isinstance(self.value, Union[list, tuple]): + if not isinstance(self.value, (list, tuple, dict, set)): return _EX.InvalidFieldDataException( - f"Passed value type ({type(self.value)}) is not a list/tuple" + f"Passed value type ({type(self.value)}) is not a list/tuple/dict/set" + ) + + if isinstance(self.value, (dict)): + if ( + sum([not isinstance(item, (str, bytes)) for item in self.value.keys()]) + > 0 + ): + return _EX.InvalidFieldDataException( + "Expected a dict with string or byte keys" + ) + + if ( + isinstance(self.value, (list, tuple, set)) + and sum([not isinstance(item, (str, bytes)) for item in self.value]) > 0 + ): + return _EX.InvalidFieldDataException( + "Expected list, tuple or set containing strings or bytes" ) return True + class PubkeyTypeField(StringField): """ Contains the certificate type, which is based on the public key type the certificate is created for, e.g. 'ssh-ed25519-cert-v01@openssh.com' for an ED25519 key """ + def __init__(self, value: str): super().__init__( value=value, - name='pubkey_type', + name="pubkey_type", ) def validate(self) -> Union[bool, Exception]: @@ -510,29 +654,30 @@ def validate(self) -> Union[bool, Exception]: Validate the field data """ if self.value not in ( - 'ssh-rsa-cert-v01@openssh.com', - 'rsa-sha2-256-cert-v01@openssh.com', - 'rsa-sha2-512-cert-v01@openssh.com', - 'ssh-dss-cert-v01@openssh.com', - 'ecdsa-sha2-nistp256-cert-v01@openssh.com', - 'ecdsa-sha2-nistp384-cert-v01@openssh.com', - 'ecdsa-sha2-nistp521-cert-v01@openssh.com', - 'ssh-ed25519-cert-v01@openssh.com' + "ssh-rsa-cert-v01@openssh.com", + "rsa-sha2-256-cert-v01@openssh.com", + "rsa-sha2-512-cert-v01@openssh.com", + "ssh-dss-cert-v01@openssh.com", + "ecdsa-sha2-nistp256-cert-v01@openssh.com", + "ecdsa-sha2-nistp384-cert-v01@openssh.com", + "ecdsa-sha2-nistp521-cert-v01@openssh.com", + "ssh-ed25519-cert-v01@openssh.com", ): return _EX.InvalidDataException(f"Invalid pubkey type: {self.value}") return True + class NonceField(StringField): """ Contains the nonce for the certificate, randomly generated this protects the integrity of the private key, especially for ecdsa. """ + def __init__(self, value: str = None): super().__init__( - value=value if value is not None else generate_secure_nonce(), - name='nonce' + value=value if value is not None else generate_secure_nonce(), name="nonce" ) def validate(self) -> Union[bool, Exception]: @@ -541,31 +686,33 @@ def validate(self) -> Union[bool, Exception]: """ if len(self.value) < 32: self.exception = _EX.InsecureNonceException( - "Nonce must be at least 32 bytes long to be secure" + "Nonce should be at least 32 bytes long to be secure. " + + "This is especially important for ECDSA" ) return False return True + class PublicKeyField(CertificateField): """ Contains the subject (User or Host) public key for whom/which the certificate is created. """ + def __init__(self, value: PublicKey): - super().__init__( - value=value, - name='public_key' - ) + super().__init__(value=value, name="public_key") def __str__(self) -> str: - return ' '.join([ - self.__class__.__name__.replace('PubkeyField', ''), - self.value.get_fingerprint() - ]) + return " ".join( + [ + self.__class__.__name__.replace("PubkeyField", ""), + self.value.get_fingerprint(), + ] + ) @staticmethod - def encode(value: RSAPublicKey) -> bytes: + def encode(value: PublicKey) -> bytes: """ Encode the certificate field to a byte string @@ -575,10 +722,12 @@ def encode(value: RSAPublicKey) -> bytes: Returns: bytes: A byte string with the encoded public key """ + if not isinstance(value, PublicKey): + raise _EX.InvalidFieldDataException( + f"Expected PublicKey, got {value.__class__.__name__}" + ) - return StringField.decode( - value.raw_bytes() - )[1] + return BytestringField.decode(value.raw_bytes())[1] @staticmethod def from_object(public_key: PublicKey): @@ -598,13 +747,10 @@ class or childclass to the chosen public key """ try: - return globals()[SUBJECT_PUBKEY_MAP[public_key.__class__]]( - value=public_key - ) + return globals()[SUBJECT_PUBKEY_MAP[public_key.__class__]](value=public_key) except KeyError: - raise _EX.InvalidKeyException( - "The public key is invalid" - ) from KeyError + raise _EX.InvalidKeyException("The public key is invalid") from KeyError + class RSAPubkeyField(PublicKeyField): """ @@ -626,10 +772,7 @@ def decode(data: bytes) -> Tuple[RSAPublicKey, bytes]: e, data = MpIntegerField.decode(data) n, data = MpIntegerField.decode(data) - return RSAPublicKey.from_numbers( - e=e, - n=n - ), data + return RSAPublicKey.from_numbers(e=e, n=n), data def validate(self) -> Union[bool, Exception]: """ @@ -642,6 +785,7 @@ def validate(self) -> Union[bool, Exception]: return True + class DSAPubkeyField(PublicKeyField): """ Holds the DSA Public Key for DSA Certificates @@ -664,9 +808,7 @@ def decode(data: bytes) -> Tuple[DSAPublicKey, bytes]: g, data = MpIntegerField.decode(data) y, data = MpIntegerField.decode(data) - return DSAPublicKey.from_numbers( - p=p, q=q, g=g, y=y - ), data + return DSAPublicKey.from_numbers(p=p, q=q, g=g, y=y), data def validate(self) -> Union[bool, Exception]: """ @@ -679,6 +821,7 @@ def validate(self) -> Union[bool, Exception]: return True + class ECDSAPubkeyField(PublicKeyField): """ Holds the ECDSA Public Key for ECDSA Certificates @@ -697,18 +840,22 @@ def decode(data: bytes) -> Tuple[ECDSAPublicKey, bytes]: Tuple[ECPublicKey, bytes]: The PublicKey field and remainder of the data """ curve, data = StringField.decode(data) - key, data = StringField.decode(data) - - key_type = b'ecdsa-sha2-' + curve - - return ECDSAPublicKey.from_string( - key_type + b' ' + - b64encode( - StringField.encode(key_type) + - StringField.encode(curve) + - StringField.encode(key) - ) - ), data + key, data = BytestringField.decode(data) + + key_type = "ecdsa-sha2-" + curve + + return ( + ECDSAPublicKey.from_string( + key_type + + " " + + b64encode( + StringField.encode(key_type) + + StringField.encode(curve) + + BytestringField.encode(key) + ).decode("utf-8") + ), + data, + ) def validate(self) -> Union[bool, Exception]: """ @@ -721,6 +868,7 @@ def validate(self) -> Union[bool, Exception]: return True + class ED25519PubkeyField(PublicKeyField): """ Holds the ED25519 Public Key for ED25519 Certificates @@ -738,11 +886,9 @@ def decode(data: bytes) -> Tuple[ED25519PublicKey, bytes]: Returns: Tuple[ED25519PublicKey, bytes]: The PublicKey field and remainder of the data """ - pubkey, data = StringField.decode(data) + pubkey, data = BytestringField.decode(data) - return ED25519PublicKey.from_raw_bytes( - pubkey - ), data + return ED25519PublicKey.from_raw_bytes(pubkey), data def validate(self) -> Union[bool, Exception]: """ @@ -755,16 +901,16 @@ def validate(self) -> Union[bool, Exception]: return True + class SerialField(Integer64Field): """ Contains the numeric serial number of the certificate, maximum is (2**64)-1 """ + def __init__(self, value: int): - super().__init__( - value=value, - name='serial' - ) + super().__init__(value=value, name="serial") + class CertificateTypeField(Integer32Field): """ @@ -772,80 +918,100 @@ class CertificateTypeField(Integer32Field): User certificate: CERT_TYPE.USER/1 Host certificate: CERT_TYPE.HOST/2 """ + def __init__(self, value: Union[CERT_TYPE, int]): super().__init__( - value=value.value if isinstance(value, CERT_TYPE) else value, - name='type' + value=value.value if isinstance(value, CERT_TYPE) else value, name="type" ) + @staticmethod + def encode(value: Union[CERT_TYPE, int]) -> bytes: + """ + Encode the certificate type field to a byte string + + Args: + value (Union[CERT_TYPE, int]): The type of the certificate + + Returns: + bytes: A byte string with the encoded public key + """ + if not isinstance(value, (CERT_TYPE, int)): + raise _EX.InvalidFieldDataException( + f"Expected (CERT_TYPE, int), got {value.__class__.__name__}" + ) + + if isinstance(value, CERT_TYPE): + value = value.value + + return Integer32Field.encode(value) + def validate(self) -> Union[bool, Exception]: """ Validates that the field contains a valid type """ - if 0 > self.value > 3: - self.exception = _EX.InvalidDataException( + if not isinstance(self.value, (CERT_TYPE, int)): + return _EX.InvalidFieldDataException( + f"Passed value type ({type(self.value)}) is not an integer" + ) + + if not isinstance(self.value, CERT_TYPE) and (self.value > 2 or self.value < 1): + return _EX.InvalidDataException( "The certificate type is invalid (1: User, 2: Host)" ) - return False return True + class KeyIDField(StringField): """ Contains the key identifier (subject) of the certificate, alphanumeric string """ + def __init__(self, value: str): - super().__init__( - value=value, - name='key_id' - ) + super().__init__(value=value, name="key_id") def validate(self) -> Union[bool, Exception]: """ Validates that the field is set and not empty """ - if self.value in [None, False, '', ' ']: - return _EX.InvalidDataException( - "You need to provide a Key ID" - ) + if self.value in [None, False, "", " "]: + return _EX.InvalidDataException("You need to provide a Key ID") + + return super().validate() - return True -class PrincipalsField(StandardListField): +class PrincipalsField(ListField): """ Contains a list of principals for the certificate, e.g. SERVERHOSTNAME01 or all-web-servers """ + def __init__(self, value: Union[list, tuple]): - super().__init__( - value=list(value), - name='principals' - ) + super().__init__(value=list(value), name="principals") + class ValidityStartField(DateTimeField): """ Contains the start of the validity period for the certificate, represented by a datetime object """ + def __init__(self, value: datetime): - super().__init__( - value=value, - name='valid_after' - ) + super().__init__(value=value, name="valid_after") + class ValidityEndField(DateTimeField): """ Contains the end of the validity period for the certificate, represented by a datetime object """ + def __init__(self, value: datetime): - super().__init__( - value=value, - name='valid_before' - ) + super().__init__(value=value, name="valid_before") + -class CriticalOptionsField(SeparatedListField): +class CriticalOptionsField(KeyValueField): """ Contains the critical options part of the certificate (optional). This should be a list of strings with one of the following @@ -860,55 +1026,36 @@ class CriticalOptionsField(SeparatedListField): verify_required= If set to true, the user must verify their identity if using a hardware token - - Additionally, the following flags are also supported (no value): - flags: - no-touch-required - The user doesn't need to touch the - physical key to authenticate. - - permit-X11-forwarding - Permits the user to use X11 Forwarding - - permit-agent-forwarding - Permits the user to use agent forwarding - - permit-port-forwarding - Permits the user to forward ports - - permit-pty - Permits the user to use a pseudo-terminal - - permit-user-rc - Permits the user to use the user rc file - """ - def __init__(self, value: Union[list, tuple]): - super().__init__( - value=value, - name='critical_options' - ) + + def __init__(self, value: Union[list, tuple, dict]): + super().__init__(value=value, name="critical_options") def validate(self) -> Union[bool, Exception]: """ Validate that the field contains a valid list of options """ - valid_opts = ( - 'force-command', - 'source-address', - 'verify-required' - ) + valid_opts = ("force-command", "source-address", "verify-required") - for item in self.value: - split = item.split('=') - if split[0] not in valid_opts: - return _EX.InvalidFieldDataException( - f"The option {item} is invalid" - ) + if not isinstance(self.value, (list, tuple, dict, set)): + return _EX.InvalidFieldDataException( + "You need to provide a list, tuple or set of strings or a dict" + ) + + if not all( + elem in valid_opts + for elem in ( + self.value.keys() if isinstance(self.value, dict) else self.value + ) + ): + return _EX.InvalidFieldDataException( + "You have provided invalid data to the critical options field" + ) return True -class ExtensionsField(SeparatedListField): + +class ExtensionsField(KeyValueField): """ Contains a list of extensions for the certificate, set to give the user limitations and/or additional @@ -935,87 +1082,78 @@ class ExtensionsField(SeparatedListField): Permits the user to use the user rc file """ + def __init__(self, value: Union[list, tuple]): - super().__init__( - value=value, - name='extensions' - ) + super().__init__(value=value, name="extensions") def validate(self) -> Union[bool, Exception]: """ Validates that the options provided are valid """ valid_opts = ( - 'no-touch-required', - 'permit-X11-forwarding', - 'permit-agent-forwarding', - 'permit-port-forwarding', - 'permit-pty', - 'permit-user-rc' + "no-touch-required", + "permit-X11-forwarding", + "permit-agent-forwarding", + "permit-port-forwarding", + "permit-pty", + "permit-user-rc", ) for item in self.value: if item not in valid_opts: - self.exception = _EX.InvalidDataException( - f"The extension '{item}' is invalid" - ) - return False + return _EX.InvalidDataException(f"The extension '{item}' is invalid") return True + class ReservedField(StringField): """ This field is reserved for future use, and doesn't contain any actual data, just an empty string. """ - def __init__(self, value: str = ''): - super().__init__( - value=value, - name='reserved' - ) + + def __init__(self, value: str = ""): + super().__init__(value=value, name="reserved") def validate(self) -> Union[bool, Exception]: """ Validate that the field only contains an empty string """ - if self.value == '': + if self.value == "": return True return _EX.InvalidDataException( "The reserved field needs to be an empty string" ) -class CAPublicKeyField(StringField): + +class CAPublicKeyField(BytestringField): """ Contains the public key of the certificate authority that is used to sign the certificate. """ + def __init__(self, value: PublicKey): - super().__init__( - value=value, - name='ca_public_key' - ) + super().__init__(value=value, name="ca_public_key") def __str__(self) -> str: - return ' '.join([ - ( - self. - value. - __class__. - __name__.replace('PublicKey', ''). - replace('EllipticCurve', 'ECDSA') - ), - self.value.get_fingerprint() - ]) + return " ".join( + [ + ( + self.value.__class__.__name__.replace("PublicKey", "").replace( + "EllipticCurve", "ECDSA" + ) + ), + self.value.get_fingerprint(), + ] + ) def validate(self) -> Union[bool, Exception]: """ Validates the contents of the field """ - if self.value in [None, False, '', ' ']: - return _EX.InvalidFieldDataException( - "You need to provide a CA public key" - ) + if self.value in [None, False, "", " "]: + return _EX.InvalidFieldDataException("You need to provide a CA public key") if not isinstance(self.value, PublicKey): return _EX.InvalidFieldDataException( @@ -1036,39 +1174,33 @@ def decode(data) -> Tuple[PublicKey, bytes]: Returns: Tuple[PublicKey, bytes]: The PublicKey field and remainder of the data """ - pubkey, data = StringField.decode(data) + pubkey, data = BytestringField.decode(data) pubkey_type = StringField.decode(pubkey)[0] - return PublicKey.from_string( - f"{pubkey_type.decode('utf-8')} {b64encode(pubkey).decode('utf-8')}" - ), data + return ( + PublicKey.from_string(f"{pubkey_type} {b64encode(pubkey).decode('utf-8')}"), + data, + ) def __bytes__(self) -> bytes: - return self.encode( - self.value.raw_bytes() - ) + return self.encode(self.value.raw_bytes()) @classmethod - def from_object(cls, public_key: PublicKey) -> 'CAPublicKeyField': + def from_object(cls, public_key: PublicKey) -> "CAPublicKeyField": """ Creates a new CAPublicKeyField from a PublicKey object """ - return cls( - value=public_key - ) + return cls(value=public_key) class SignatureField(CertificateField): """ Creates and contains the signature of the certificate """ + # pylint: disable=super-init-not-called - def __init__( - self, - private_key: PrivateKey = None, - signature: bytes = None - ): - self.name = 'signature' + def __init__(self, private_key: PrivateKey = None, signature: bytes = None): + self.name = "signature" self.private_key = private_key self.is_signed = False self.value = signature @@ -1094,11 +1226,11 @@ def from_object(private_key: PrivateKey): ) except KeyError: raise _EX.InvalidKeyException( - 'The private key provided is invalid or not supported' + "The private key provided is invalid or not supported" ) from KeyError @staticmethod - def from_decode(data: bytes) -> Tuple['SignatureField', bytes]: + def from_decode(data: bytes) -> Tuple["SignatureField", bytes]: """ Generates a SignatureField child class from the encoded signature @@ -1111,18 +1243,14 @@ def from_decode(data: bytes) -> Tuple['SignatureField', bytes]: Returns: SignatureField: child of SignatureField """ - signature, _ = StringField.decode(data) - signature_type = StringField.decode(signature)[0] + signature, _ = BytestringField.decode(data) + signature_type = BytestringField.decode(signature)[0] for key, value in SIGNATURE_TYPE_MAP.items(): if key in signature_type: - return globals()[value].from_decode( - data - ) + return globals()[value].from_decode(data) - raise _EX.InvalidDataException( - "No matching signature type found" - ) + raise _EX.InvalidDataException("No matching signature type found") def can_sign(self): """ @@ -1137,25 +1265,25 @@ def sign(self, data: bytes) -> None: """ def __bytes__(self) -> None: - return self.encode( - self.value - ) + return self.encode(self.value) + class RSASignatureField(SignatureField): """ Creates and contains the RSA signature from an RSA Private Key """ + def __init__( self, private_key: RSAPrivateKey = None, hash_alg: RsaAlgs = RsaAlgs.SHA512, - signature: bytes = None + signature: bytes = None, ): super().__init__(private_key, signature) self.hash_alg = hash_alg @staticmethod - #pylint: disable=arguments-renamed + # pylint: disable=arguments-renamed def encode(signature: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256) -> bytes: """ Encodes the signature to a byte string @@ -1168,19 +1296,17 @@ def encode(signature: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256) -> bytes: Returns: bytes: The encoded byte string """ - return StringField.encode( - StringField.encode(hash_alg.value[0]) + - StringField.encode(signature) + if not isinstance(signature, bytes): + raise _EX.InvalidFieldDataException( + f"Expected bytes, got {signature.__class__.__name__}" + ) + + return BytestringField.encode( + StringField.encode(hash_alg.value[0]) + BytestringField.encode(signature) ) @staticmethod - def decode(data: bytes) -> Tuple[ - Tuple[ - bytes, - bytes - ], - bytes - ]: + def decode(data: bytes) -> Tuple[Tuple[bytes, bytes], bytes]: """ Decodes a bytestring containing a signature @@ -1190,15 +1316,15 @@ def decode(data: bytes) -> Tuple[ Returns: Tuple[ Tuple[ bytes, bytes ], bytes ]: (signature_type, signature), remainder of data """ - signature, data = StringField.decode(data) + signature, data = BytestringField.decode(data) sig_type, signature = StringField.decode(signature) - signature, _ = StringField.decode(signature) + signature, _ = BytestringField.decode(signature) return (sig_type, signature), data @classmethod - def from_decode(cls, data: bytes) -> Tuple['RSASignatureField', bytes]: + def from_decode(cls, data: bytes) -> Tuple["RSASignatureField", bytes]: """ Generates an RSASignatureField class from the encoded signature @@ -1213,17 +1339,16 @@ def from_decode(cls, data: bytes) -> Tuple['RSASignatureField', bytes]: """ signature, data = cls.decode(data) - return cls( - private_key=None, - hash_alg=[alg for alg in RsaAlgs if alg.value[0] == signature[0].decode('utf-8')][0], - signature=signature[1] - ), data + return ( + cls( + private_key=None, + hash_alg=[alg for alg in RsaAlgs if alg.value[0] == signature[0]][0], + signature=signature[1], + ), + data, + ) - def sign( - self, - data: bytes, - hash_alg: RsaAlgs = RsaAlgs.SHA256 - ) -> None: + def sign(self, data: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA256) -> None: """ Signs the provided data with the provided private key @@ -1232,29 +1357,22 @@ def sign( hash_alg (RsaAlgs, optional): The RSA algorithm to use for hashing. Defaults to RsaAlgs.SHA256. """ - self.value = self.private_key.sign( - data, - hash_alg - ) + self.value = self.private_key.sign(data, hash_alg) self.hash_alg = hash_alg self.is_signed = True def __bytes__(self): - return self.encode( - self.value, - self.hash_alg - ) + return self.encode(self.value, self.hash_alg) class DSASignatureField(SignatureField): """ Creates and contains the DSA signature from an DSA Private Key """ + def __init__( - self, - private_key: DSAPrivateKey = None, - signature: bytes = None + self, private_key: DSAPrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @@ -1270,14 +1388,16 @@ def encode(signature: bytes): Returns: bytes: The encoded byte string """ + if not isinstance(signature, bytes): + raise _EX.InvalidFieldDataException( + f"Expected bytes, got {signature.__class__.__name__}" + ) + r, s = decode_dss_signature(signature) - return StringField.encode( - StringField.encode('ssh-dss') + - StringField.encode( - long_to_bytes(r, 20) + - long_to_bytes(s, 20) - ) + return BytestringField.encode( + StringField.encode("ssh-dss") + + BytestringField.encode(long_to_bytes(r, 20) + long_to_bytes(s, 20)) ) @staticmethod @@ -1291,11 +1411,9 @@ def decode(data: bytes) -> Tuple[bytes, bytes]: Returns: Tuple[ bytes, bytes ]: signature, remainder of the data """ - signature, data = StringField.decode(data) + signature, data = BytestringField.decode(data) - signature = StringField.decode( - StringField.decode(signature)[1] - )[0] + signature = BytestringField.decode(BytestringField.decode(signature)[1])[0] r = bytes_to_long(signature[:20]) s = bytes_to_long(signature[20:]) @@ -1304,7 +1422,7 @@ def decode(data: bytes) -> Tuple[bytes, bytes]: return signature, data @classmethod - def from_decode(cls, data: bytes) -> Tuple['DSASignatureField', bytes]: + def from_decode(cls, data: bytes) -> Tuple["DSASignatureField", bytes]: """ Creates a signature field class from the encoded signature @@ -1316,10 +1434,7 @@ def from_decode(cls, data: bytes) -> Tuple['DSASignatureField', bytes]: """ signature, data = cls.decode(data) - return cls( - private_key=None, - signature=signature - ), data + return cls(private_key=None, signature=signature), data def sign(self, data: bytes) -> None: """ @@ -1328,31 +1443,31 @@ def sign(self, data: bytes) -> None: Args: data (bytes): The data to be signed """ - self.value = self.private_key.sign( - data - ) + self.value = self.private_key.sign(data) self.is_signed = True + class ECDSASignatureField(SignatureField): """ Creates and contains the ECDSA signature from an ECDSA Private Key """ + def __init__( self, private_key: ECDSAPrivateKey = None, signature: bytes = None, - curve_name: str = None + curve_name: str = None, ) -> None: super().__init__(private_key, signature) if curve_name is None: curve_size = self.private_key.public_key.key.curve.key_size - curve_name = f'ecdsa-sha2-nistp{curve_size}' + curve_name = f"ecdsa-sha2-nistp{curve_size}" self.curve = curve_name @staticmethod - #pylint: disable=arguments-renamed + # pylint: disable=arguments-renamed def encode(signature: bytes, curve_name: str = None) -> bytes: """ Encodes the signature to a byte string @@ -1365,20 +1480,22 @@ def encode(signature: bytes, curve_name: str = None) -> bytes: Returns: bytes: The encoded byte string """ + if not isinstance(signature, bytes): + raise _EX.InvalidFieldDataException( + f"Expected bytes, got {signature.__class__.__name__}" + ) + r, s = decode_dss_signature(signature) - return StringField.encode( - StringField.encode( - curve_name - ) + - StringField.encode( - MpIntegerField.encode(r) + - MpIntegerField.encode(s) + return BytestringField.encode( + StringField.encode(curve_name) + + BytestringField.encode( + MpIntegerField.encode(r) + MpIntegerField.encode(s) ) ) @staticmethod - def decode(data: bytes) -> Tuple[ Tuple[ bytes, bytes ], bytes]: + def decode(data: bytes) -> Tuple[Tuple[bytes, bytes], bytes]: """ Decodes a bytestring containing a signature @@ -1388,10 +1505,10 @@ def decode(data: bytes) -> Tuple[ Tuple[ bytes, bytes ], bytes]: Returns: Tuple[ Tuple[ bytes, bytes ], bytes]: (curve, signature), remainder of the data """ - signature, data = StringField.decode(data) + signature, data = BytestringField.decode(data) curve, signature = StringField.decode(signature) - signature, _ = StringField.decode(signature) + signature, _ = BytestringField.decode(signature) r, signature = MpIntegerField.decode(signature) s, _ = MpIntegerField.decode(signature) @@ -1401,7 +1518,7 @@ def decode(data: bytes) -> Tuple[ Tuple[ bytes, bytes ], bytes]: return (curve, signature), data @classmethod - def from_decode(cls, data: bytes) -> Tuple['ECDSASignatureField', bytes]: + def from_decode(cls, data: bytes) -> Tuple["ECDSASignatureField", bytes]: """ Creates a signature field class from the encoded signature @@ -1413,11 +1530,10 @@ def from_decode(cls, data: bytes) -> Tuple['ECDSASignatureField', bytes]: """ signature, data = cls.decode(data) - return cls( - private_key=None, - signature = signature[1], - curve_name = signature[0] - ), data + return ( + cls(private_key=None, signature=signature[1], curve_name=signature[0]), + data, + ) def sign(self, data: bytes) -> None: """ @@ -1426,25 +1542,20 @@ def sign(self, data: bytes) -> None: Args: data (bytes): The data to be signed """ - self.value = self.private_key.sign( - data - ) + self.value = self.private_key.sign(data) self.is_signed = True def __bytes__(self): - return self.encode( - self.value, - self.curve - ) + return self.encode(self.value, self.curve) + class ED25519SignatureField(SignatureField): """ Creates and contains the ED25519 signature from an ED25519 Private Key """ + def __init__( - self, - private_key: ED25519PrivateKey = None, - signature: bytes = None + self, private_key: ED25519PrivateKey = None, signature: bytes = None ) -> None: super().__init__(private_key, signature) @@ -1460,9 +1571,13 @@ def encode(signature: bytes) -> None: Returns: bytes: The encoded byte string """ - return StringField.encode( - StringField.encode('ssh-ed25519') + - StringField.encode(signature) + if not isinstance(signature, bytes): + raise _EX.InvalidFieldDataException( + f"Expected bytes, got {signature.__class__.__name__}" + ) + + return BytestringField.encode( + StringField.encode("ssh-ed25519") + BytestringField.encode(signature) ) @staticmethod @@ -1476,16 +1591,14 @@ def decode(data: bytes) -> Tuple[bytes, bytes]: Returns: Tuple[ bytes, bytes ]: signature, remainder of the data """ - signature, data = StringField.decode(data) + signature, data = BytestringField.decode(data) - signature = StringField.decode( - StringField.decode(signature)[1] - )[0] + signature = BytestringField.decode(BytestringField.decode(signature)[1])[0] return signature, data @classmethod - def from_decode(cls, data: bytes) -> Tuple['ED25519SignatureField', bytes]: + def from_decode(cls, data: bytes) -> Tuple["ED25519SignatureField", bytes]: """ Creates a signature field class from the encoded signature @@ -1497,10 +1610,7 @@ def from_decode(cls, data: bytes) -> Tuple['ED25519SignatureField', bytes]: """ signature, data = cls.decode(data) - return cls( - private_key=None, - signature=signature - ), data + return cls(private_key=None, signature=signature), data def sign(self, data: bytes) -> None: """ @@ -1511,7 +1621,5 @@ def sign(self, data: bytes) -> None: hash_alg (RsaAlgs, optional): The RSA algorithm to use for hashing. Defaults to RsaAlgs.SHA256. """ - self.value = self.private_key.sign( - data - ) + self.value = self.private_key.sign(data) self.is_signed = True diff --git a/src/sshkey_tools/keys.py b/src/sshkey_tools/keys.py index d637578..f2b7ed1 100644 --- a/src/sshkey_tools/keys.py +++ b/src/sshkey_tools/keys.py @@ -4,128 +4,129 @@ from typing import Union from enum import Enum from base64 import b64decode +from struct import unpack from cryptography.hazmat.backends.openssl.rsa import _RSAPublicKey, _RSAPrivateKey from cryptography.hazmat.backends.openssl.dsa import _DSAPublicKey, _DSAPrivateKey -from cryptography.hazmat.backends.openssl.ed25519 import _Ed25519PublicKey, _Ed25519PrivateKey +from cryptography.hazmat.backends.openssl.ed25519 import ( + _Ed25519PublicKey, + _Ed25519PrivateKey, +) from cryptography.hazmat.backends.openssl.ec import ( _EllipticCurvePublicKey, - _EllipticCurvePrivateKey + _EllipticCurvePrivateKey, ) from cryptography.hazmat.primitives import ( serialization as _SERIALIZATION, - hashes as _HASHES + hashes as _HASHES, ) from cryptography.hazmat.primitives.asymmetric import ( rsa as _RSA, dsa as _DSA, ec as _ECDSA, ed25519 as _ED25519, - padding as _PADDING + padding as _PADDING, ) +from cryptography.exceptions import InvalidSignature + from . import exceptions as _EX from .utils import ( md5_fingerprint as _FP_MD5, sha256_fingerprint as _FP_SHA256, - sha512_fingerprint as _FP_SHA512 + sha512_fingerprint as _FP_SHA512, ) PUBKEY_MAP = { _RSAPublicKey: "RSAPublicKey", _DSAPublicKey: "DSAPublicKey", _EllipticCurvePublicKey: "ECDSAPublicKey", - _Ed25519PublicKey: "ED25519PublicKey" + _Ed25519PublicKey: "ED25519PublicKey", } PRIVKEY_MAP = { _RSAPrivateKey: "RSAPrivateKey", _DSAPrivateKey: "DSAPrivateKey", _EllipticCurvePrivateKey: "ECDSAPrivateKey", - _Ed25519PrivateKey: "ED25519PrivateKey" + _Ed25519PrivateKey: "ED25519PrivateKey", } ECDSA_HASHES = { - 'secp256r1': _HASHES.SHA256, - 'secp384r1': _HASHES.SHA384, - 'secp521r1': _HASHES.SHA512, + "secp256r1": _HASHES.SHA256, + "secp384r1": _HASHES.SHA384, + "secp521r1": _HASHES.SHA512, } PubkeyClasses = Union[ _RSA.RSAPublicKey, _DSA.DSAPublicKey, _ECDSA.EllipticCurvePublicKey, - _ED25519.Ed25519PublicKey + _ED25519.Ed25519PublicKey, ] PrivkeyClasses = Union[ _RSA.RSAPrivateKey, _DSA.DSAPrivateKey, _ECDSA.EllipticCurvePrivateKey, - _ED25519.Ed25519PrivateKey + _ED25519.Ed25519PrivateKey, ] + class RsaAlgs(Enum): """ RSA Algorithms """ - SHA1 = ( - 'ssh-rsa', - _HASHES.SHA1 - ) - SHA256 = ( - 'rsa-sha2-256', - _HASHES.SHA256 - ) - SHA512 = ( - 'rsa-sha2-512', - _HASHES.SHA512 - ) + + SHA1 = ("ssh-rsa", _HASHES.SHA1) + SHA256 = ("rsa-sha2-256", _HASHES.SHA256) + SHA512 = ("rsa-sha2-512", _HASHES.SHA512) + class EcdsaCurves(Enum): """ ECDSA Curves """ + P256 = _ECDSA.SECP256R1 P384 = _ECDSA.SECP384R1 P521 = _ECDSA.SECP521R1 + class FingerprintHashes(Enum): """ Fingerprint hashes """ + MD5 = _FP_MD5 SHA256 = _FP_SHA256 SHA512 = _FP_SHA512 + class PublicKey: """ Class for handling SSH public keys """ + def __init__( - self, - key: PrivkeyClasses = None, - comment: Union[str, bytes] = None, - **kwargs + self, key: PrivkeyClasses = None, comment: Union[str, bytes] = None, **kwargs ) -> None: self.key = key self.comment = comment - self.public_numbers = kwargs.get('public_numbers', None) - self.key_type = kwargs.get('key_type', None) - self.serialized = kwargs.get('serialized', None) + self.public_numbers = kwargs.get("public_numbers", None) + self.key_type = kwargs.get("key_type", None) + self.serialized = kwargs.get("serialized", None) self.export_opts = [ _SERIALIZATION.Encoding.OpenSSH, _SERIALIZATION.PublicFormat.OpenSSH, ] - @classmethod def from_class( cls, key_class: PubkeyClasses, comment: Union[str, bytes] = None, - key_type: Union[str, bytes] = None - ) -> 'PublicKey': + key_type: Union[str, bytes] = None, + ) -> "PublicKey": """ Creates a new SSH Public key from a cryptography class @@ -142,18 +143,16 @@ def from_class( """ try: return globals()[PUBKEY_MAP[key_class.__class__]]( - key_class, - comment, - key_type + key_class, comment, key_type ) except KeyError: - raise _EX.InvalidKeyException( - "Invalid public key" - ) from KeyError + raise _EX.InvalidKeyException("Invalid public key") from KeyError @classmethod - def from_string(cls, data: Union[str, bytes]) -> 'PublicKey': + def from_string( + cls, data: Union[str, bytes], encoding: str = "utf-8" + ) -> "PublicKey": """ Loads an SSH public key from a string containing the data in OpenSSH format (SubjectPublickeyInfo) @@ -165,23 +164,20 @@ def from_string(cls, data: Union[str, bytes]) -> 'PublicKey': PublicKey: Any of the PublicKey child classes """ if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode(encoding) - split = data.split(b' ') + split = data.split(b" ") comment = None if len(split) > 2: comment = split[2] return cls.from_class( - key_class=_SERIALIZATION.load_ssh_public_key( - b' '.join(split[:2]) - ), - comment=comment + key_class=_SERIALIZATION.load_ssh_public_key(b" ".join(split[:2])), + comment=comment, ) - @classmethod - def from_file(cls, path: str) -> 'PublicKey': + def from_file(cls, path: str) -> "PublicKey": """ Loads an SSH Public key from a file @@ -191,14 +187,13 @@ def from_file(cls, path: str) -> 'PublicKey': Returns: PublicKey: Any of the PublicKey child classes """ - with open(path, 'rb') as file: + with open(path, "rb") as file: data = file.read() return cls.from_string(data) def get_fingerprint( - self, - hash_method: FingerprintHashes = FingerprintHashes.SHA256 + self, hash_method: FingerprintHashes = FingerprintHashes.SHA256 ) -> str: """ Generates a fingerprint of the public key @@ -218,9 +213,7 @@ def serialize(self) -> bytes: Returns: bytes: The serialized key in OpenSSH format """ - return self.key.public_bytes( - *self.export_opts - ) + return self.key.public_bytes(*self.export_opts) def raw_bytes(self) -> bytes: """ @@ -229,9 +222,9 @@ def raw_bytes(self) -> bytes: Returns: bytes: The raw certificate bytes """ - return b64decode(self.serialize().split(b' ')[1]) + return b64decode(self.serialize().split(b" ")[1]) - def to_string(self, encoding: str = 'utf-8') -> str: + def to_string(self, encoding: str = "utf-8") -> str: """ Export the public key as a string @@ -242,11 +235,11 @@ def to_string(self, encoding: str = 'utf-8') -> str: public_bytes = self.serialize() if self.comment is not None: - public_bytes += b' ' + self.comment + public_bytes += b" " + self.comment return public_bytes.decode(encoding) - def to_file(self, path: str, encoding: str = 'utf-8') -> None: + def to_file(self, path: str, encoding: str = "utf-8") -> None: """ Export the public key to a file @@ -254,23 +247,20 @@ def to_file(self, path: str, encoding: str = 'utf-8') -> None: path (str): The path of the file encoding(str, optional): The encoding of the file. Defaults to 'utf-8'. """ - with open(path, 'w', encoding=encoding) as pubkey_file: + with open(path, "w", encoding=encoding) as pubkey_file: pubkey_file.write(self.to_string()) + class PrivateKey: """ Class for handling SSH Private keys """ - def __init__( - self, - key: PrivkeyClasses, - public_key: PublicKey, - **kwargs - ) -> None: + + def __init__(self, key: PrivkeyClasses, public_key: PublicKey, **kwargs) -> None: self.key = key self.public_key = public_key - self.private_numbers = kwargs.get('private_numbers', None) + self.private_numbers = kwargs.get("private_numbers", None) self.export_opts = { "encoding": _SERIALIZATION.Encoding.PEM, "format": _SERIALIZATION.PrivateFormat.OpenSSH, @@ -278,7 +268,7 @@ def __init__( } @classmethod - def from_class(cls, key_class: PrivkeyClasses) -> 'PrivateKey': + def from_class(cls, key_class: PrivkeyClasses) -> "PrivateKey": """ Import an SSH Private key from a cryptography key class @@ -301,8 +291,8 @@ def from_string( cls, key_data: Union[str, bytes], password: Union[str, bytes] = None, - encoding: str = 'utf-8' - ) -> 'PrivateKey': + encoding: str = "utf-8", + ) -> "PrivateKey": """ Loads an SSH private key from a string containing the key data @@ -320,10 +310,7 @@ def from_string( if isinstance(password, str): password = password.encode(encoding) - private_key = _SERIALIZATION.load_ssh_private_key( - key_data, - password=password - ) + private_key = _SERIALIZATION.load_ssh_private_key(key_data, password=password) return cls.from_class(private_key) @@ -332,8 +319,7 @@ def from_file( cls, path: str, password: Union[str, bytes] = None, - encoding: str = 'utf-8' - ) -> 'PrivateKey': + ) -> "PrivateKey": """ Loads an SSH private key from a file @@ -345,9 +331,23 @@ def from_file( Returns: PrivateKey: Any of the PrivateKey child classes """ - with open(path, 'rb', encoding=encoding) as key_file: + with open(path, "rb") as key_file: return cls.from_string(key_file.read(), password) + def get_fingerprint( + self, hash_method: FingerprintHashes = FingerprintHashes.SHA256 + ) -> str: + """ + Generates a fingerprint of the private key + + Args: + hash_method (FingerprintHashes, optional): Type of hash. Defaults to SHA256. + + Returns: + str: The hash of the private key + """ + return self.public_key.get_fingerprint(hash_method) + def to_bytes(self, password: Union[str, bytes] = None) -> bytes: """ Exports the private key to a byte string @@ -360,19 +360,19 @@ def to_bytes(self, password: Union[str, bytes] = None) -> bytes: bytes: The private key in PEM format """ if isinstance(password, str): - password = password.encode('utf-8') + password = password.encode("utf-8") encryption = _SERIALIZATION.NoEncryption() if password is not None: - encryption = self.export_opts['encryption'](password) + encryption = self.export_opts["encryption"](password) return self.key.private_bytes( - self.export_opts['encoding'], - self.export_opts['format'], - encryption + self.export_opts["encoding"], self.export_opts["format"], encryption ) - def to_string(self, password: Union[str, bytes] = None, encoding: str = 'utf-8') -> str: + def to_string( + self, password: Union[str, bytes] = None, encoding: str = "utf-8" + ) -> str: """ Exports the private key to a string @@ -387,10 +387,7 @@ def to_string(self, password: Union[str, bytes] = None, encoding: str = 'utf-8') return self.to_bytes(password).decode(encoding) def to_file( - self, - path: str, - password: Union[str, bytes] = None, - encoding: str = 'utf-8' + self, path: str, password: Union[str, bytes] = None, encoding: str = "utf-8" ) -> None: """ Exports the private key to a file @@ -402,36 +399,33 @@ def to_file( Returns: bytes: The private key in PEM format """ - with open(path, 'w', encoding=encoding) as key_file: - key_file.write( - self.to_string( - password, - encoding - ) - ) + with open(path, "w", encoding=encoding) as key_file: + key_file.write(self.to_string(password, encoding)) + class RSAPublicKey(PublicKey): """ Class for holding RSA public keys """ + def __init__( self, key: _RSA.RSAPublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, - serialized: bytes = None + serialized: bytes = None, ): super().__init__( key=key, comment=comment, key_type=key_type, public_numbers=key.public_numbers(), - serialized=serialized + serialized=serialized, ) @classmethod # pylint: disable=invalid-name - def from_numbers(cls, e: int, n: int) -> 'RSAPublicKey': + def from_numbers(cls, e: int, n: int) -> "RSAPublicKey": """ Loads an RSA Public Key from the public numbers e and n @@ -442,21 +436,42 @@ def from_numbers(cls, e: int, n: int) -> 'RSAPublicKey': Returns: RSAPublicKey: _description_ """ - return cls( - key=_RSA.RSAPublicNumbers(e, n).public_key() - ) + return cls(key=_RSA.RSAPublicNumbers(e, n).public_key()) + + def verify( + self, data: bytes, signature: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA512 + ) -> None: + """ + Verifies a signature + + Args: + data (bytes): The data to verify + signature (bytes): The signature to verify + hash_method (HashMethods): The hash method to use + + Raises: + Raises an sshkey_tools.exceptions.InvalidSignatureException if the signature is invalid + """ + try: + return self.key.verify( + signature, data, _PADDING.PKCS1v15(), hash_alg.value[1]() + ) + except InvalidSignature: + raise _EX.InvalidSignatureException( + "The signature is invalid for the given data" + ) from InvalidSignature + class RSAPrivateKey(PrivateKey): """ Class for holding RSA private keys """ + def __init__(self, key: _RSA.RSAPrivateKey): super().__init__( key=key, - public_key=RSAPublicKey( - key.public_key() - ), - private_numbers=key.private_numbers() + public_key=RSAPublicKey(key.public_key()), + private_numbers=key.private_numbers(), ) @classmethod @@ -470,8 +485,8 @@ def from_numbers( q: int = None, dmp1: int = None, dmq1: int = None, - iqmp: int = None - ) -> 'RSAPrivateKey': + iqmp: int = None, + ) -> "RSAPrivateKey": """ Load an RSA private key from numbers @@ -511,16 +526,14 @@ def from_numbers( d=d, dmp1=_RSA.rsa_crt_dmp1(d, p), dmq1=_RSA.rsa_crt_dmq1(d, q), - iqmp=_RSA.rsa_crt_iqmp(p, q) + iqmp=_RSA.rsa_crt_iqmp(p, q), ).private_key() ) @classmethod def generate( - cls, - key_size: int = 4096, - public_exponent: int = 65537 - ) -> 'RSAPrivateKey': + cls, key_size: int = 4096, public_exponent: int = 65537 + ) -> "RSAPrivateKey": """ Generates a new RSA private key @@ -533,8 +546,7 @@ def generate( """ return cls.from_class( _RSA.generate_private_key( - public_exponent=public_exponent, - key_size=key_size + public_exponent=public_exponent, key_size=key_size ) ) @@ -550,41 +562,33 @@ def sign(self, data: bytes, hash_alg: RsaAlgs = RsaAlgs.SHA512) -> bytes: Returns: bytes: The signature bytes """ - return self.key.sign( - data, - _PADDING.PKCS1v15(), - hash_alg.value[1]() - ) + return self.key.sign(data, _PADDING.PKCS1v15(), hash_alg.value[1]()) + class DSAPublicKey(PublicKey): """ Class for holding DSA public keys """ + def __init__( self, key: _DSA.DSAPublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, - serialized: bytes = None + serialized: bytes = None, ): super().__init__( key=key, comment=comment, key_type=key_type, public_numbers=key.public_numbers(), - serialized=serialized + serialized=serialized, ) self.parameters = key.parameters().parameter_numbers() @classmethod # pylint: disable=invalid-name - def from_numbers( - cls, - p: int, - q: int, - g: int, - y: int - ) -> 'DSAPublicKey': + def from_numbers(cls, p: int, q: int, g: int, y: int) -> "DSAPublicKey": """ Create a DSA public key from public numbers and parameters @@ -599,38 +603,44 @@ def from_numbers( """ return cls( key=_DSA.DSAPublicNumbers( - y=y, - parameter_numbers=_DSA.DSAParameterNumbers( - p=p, - q=q, - g=g - ) + y=y, parameter_numbers=_DSA.DSAParameterNumbers(p=p, q=q, g=g) ).public_key() ) + def verify(self, data: bytes, signature: bytes) -> None: + """ + Verifies a signature + + Args: + data (bytes): The data to verify + signature (bytes): The signature to verify + + Raises: + Raises an sshkey_tools.exceptions.InvalidSignatureException if the signature is invalid + """ + try: + return self.key.verify(signature, data, _HASHES.SHA1()) + except InvalidSignature: + raise _EX.InvalidSignatureException( + "The signature is invalid for the given data" + ) from InvalidSignature + + class DSAPrivateKey(PrivateKey): """ Class for holding DSA private keys """ + def __init__(self, key: _DSA.DSAPrivateKey): super().__init__( key=key, - public_key=DSAPublicKey( - key.public_key() - ), - private_numbers=key.private_numbers() + public_key=DSAPublicKey(key.public_key()), + private_numbers=key.private_numbers(), ) @classmethod # pylint: disable=invalid-name,too-many-arguments - def from_numbers( - cls, - p: int, - q: int, - g: int, - y: int, - x: int - ) -> 'DSAPrivateKey': + def from_numbers(cls, p: int, q: int, g: int, y: int, x: int) -> "DSAPrivateKey": """ Creates a new DSAPrivateKey object from parameters and public/private numbers @@ -647,33 +657,22 @@ def from_numbers( return cls( key=_DSA.DSAPrivateNumbers( public_numbers=_DSA.DSAPublicNumbers( - y=y, - parameter_numbers=_DSA.DSAParameterNumbers( - p=p, - q=q, - g=g - ) + y=y, parameter_numbers=_DSA.DSAParameterNumbers(p=p, q=q, g=g) ), - x=x + x=x, ).private_key() ) @classmethod - def generate(cls, key_size: int = 4096) -> 'DSAPrivateKey': + def generate(cls) -> "DSAPrivateKey": """ Generate a new DSA private key - - Args: - key_size (int, optional): Number of key bytes. Defaults to 4096. + Key size is fixed since OpenSSH only supports 1024-bit DSA keys Returns: DSAPrivateKey: An instance of DSAPrivateKey """ - return cls.from_class( - _DSA.generate_private_key( - key_size=key_size - ) - ) + return cls.from_class(_DSA.generate_private_key(key_size=1024)) def sign(self, data: bytes): """ @@ -685,38 +684,34 @@ def sign(self, data: bytes): Returns: bytes: The signature bytes """ - return self.key.sign( - data, - _HASHES.SHA1() - ) + return self.key.sign(data, _HASHES.SHA1()) + class ECDSAPublicKey(PublicKey): """ Class for holding ECDSA public keys """ + def __init__( self, key: _ECDSA.EllipticCurvePublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, - serialized: bytes = None + serialized: bytes = None, ): super().__init__( key=key, comment=comment, key_type=key_type, public_numbers=key.public_numbers(), - serialized=serialized + serialized=serialized, ) @classmethod - #pylint: disable=invalid-name + # pylint: disable=invalid-name def from_numbers( - cls, - curve: Union[str, _ECDSA.EllipticCurve], - x: int, - y: int - ) -> 'ECDSAPublicKey': + cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int + ) -> "ECDSAPublicKey": """ Create an ECDSA public key from public numbers and parameters @@ -738,33 +733,48 @@ def from_numbers( return cls( key=_ECDSA.EllipticCurvePublicNumbers( - curve=ECDSA_HASHES[curve]() if isinstance(curve, str) else curve, + curve=getattr(_ECDSA, curve.upper())(), x=x, - y=y + y=y, ).public_key() ) + def verify(self, data: bytes, signature: bytes) -> None: + """ + Verifies a signature + + Args: + data (bytes): The data to verify + signature (bytes): The signature to verify + + Raises: + Raises an sshkey_tools.exceptions.InvalidSignatureException if the signature is invalid + """ + try: + curve_hash = ECDSA_HASHES[self.key.curve.name]() + return self.key.verify(signature, data, _ECDSA.ECDSA(curve_hash)) + except InvalidSignature: + raise _EX.InvalidSignatureException( + "The signature is invalid for the given data" + ) from InvalidSignature + + class ECDSAPrivateKey(PrivateKey): """ Class for holding ECDSA private keys """ + def __init__(self, key: _ECDSA.EllipticCurvePrivateKey): super().__init__( key=key, - public_key=ECDSAPublicKey( - key.public_key() - ), - private_numbers=key.private_numbers() + public_key=ECDSAPublicKey(key.public_key()), + private_numbers=key.private_numbers(), ) @classmethod - #pylint: disable=invalid-name + # pylint: disable=invalid-name def from_numbers( - cls, - curve: Union[str, _ECDSA.EllipticCurve], - x: int, - y: int, - private_value: int + cls, curve: Union[str, _ECDSA.EllipticCurve], x: int, y: int, private_value: int ): """ Creates a new ECDSAPrivateKey object from parameters and public/private numbers @@ -789,11 +799,11 @@ def from_numbers( return cls( key=_ECDSA.EllipticCurvePrivateNumbers( public_numbers=_ECDSA.EllipticCurvePublicNumbers( - curve=ECDSA_HASHES[curve]() if isinstance(curve, str) else curve, + curve=getattr(_ECDSA, curve.upper())(), x=x, - y=y + y=y, ), - private_value=private_value + private_value=private_value, ).private_key() ) @@ -808,11 +818,7 @@ def generate(cls, curve: EcdsaCurves = EcdsaCurves.P521): Returns: ECDSAPrivateKey: An instance of ECDSAPrivateKey """ - return cls.from_class( - _ECDSA.generate_private_key( - curve=curve.value - ) - ) + return cls.from_class(_ECDSA.generate_private_key(curve=curve.value)) def sign(self, data: bytes): """ @@ -824,32 +830,28 @@ def sign(self, data: bytes): Returns: bytes: The signature bytes """ - curve = ECDSA_HASHES[self.key.curve.name]() - return self.key.sign( - data, - _ECDSA.ECDSA(curve) - ) + curve_hash = ECDSA_HASHES[self.key.curve.name]() + return self.key.sign(data, _ECDSA.ECDSA(curve_hash)) + class ED25519PublicKey(PublicKey): """ Class for holding ED25519 public keys """ + def __init__( self, key: _ED25519.Ed25519PublicKey, comment: Union[str, bytes] = None, key_type: Union[str, bytes] = None, - serialized: bytes = None + serialized: bytes = None, ): super().__init__( - key=key, - comment=comment, - key_type=key_type, - serialized=serialized + key=key, comment=comment, key_type=key_type, serialized=serialized ) @classmethod - def from_raw_bytes(cls, raw_bytes: bytes) -> 'ED25519PublicKey': + def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PublicKey": """ Load an ED25519 public key from raw bytes @@ -859,26 +861,43 @@ def from_raw_bytes(cls, raw_bytes: bytes) -> 'ED25519PublicKey': Returns: ED25519PublicKey: Instance of ED25519PublicKey """ + if b"ssh-ed25519" in raw_bytes: + id_length = unpack(">I", raw_bytes[:4])[0] + 8 + raw_bytes = raw_bytes[id_length:] + return cls.from_class( - _ED25519.Ed25519PublicKey.from_public_bytes( - data=raw_bytes - ) + _ED25519.Ed25519PublicKey.from_public_bytes(data=raw_bytes) ) + def verify(self, data: bytes, signature: bytes) -> None: + """ + Verifies a signature + + Args: + data (bytes): The data to verify + signature (bytes): The signature to verify + + Raises: + Raises an sshkey_tools.exceptions.InvalidSignatureException if the signature is invalid + """ + try: + return self.key.verify(signature, data) + except InvalidSignature: + raise _EX.InvalidSignatureException( + "The signature is invalid for the given data" + ) from InvalidSignature + + class ED25519PrivateKey(PrivateKey): """ Class for holding ED25519 private keys """ + def __init__(self, key: _ED25519.Ed25519PrivateKey): - super().__init__( - key=key, - public_key=ED25519PublicKey( - key.public_key() - ) - ) + super().__init__(key=key, public_key=ED25519PublicKey(key.public_key())) @classmethod - def from_raw_bytes(cls, raw_bytes: bytes) -> 'ED25519PrivateKey': + def from_raw_bytes(cls, raw_bytes: bytes) -> "ED25519PrivateKey": """ Load an ED25519 private key from raw bytes @@ -889,22 +908,18 @@ def from_raw_bytes(cls, raw_bytes: bytes) -> 'ED25519PrivateKey': ED25519PrivateKey: Instance of ED25519PrivateKey """ return cls.from_class( - _ED25519.Ed25519PrivateKey.from_private_bytes( - data=raw_bytes - ) + _ED25519.Ed25519PrivateKey.from_private_bytes(data=raw_bytes) ) @classmethod - def generate(cls) -> 'ED25519PrivateKey': + def generate(cls) -> "ED25519PrivateKey": """ Generates a new ED25519 Private Key Returns: ED25519PrivateKey: Instance of ED25519PrivateKey """ - return cls.from_class( - _ED25519.Ed25519PrivateKey.generate() - ) + return cls.from_class(_ED25519.Ed25519PrivateKey.generate()) def raw_bytes(self) -> bytes: """ @@ -916,7 +931,7 @@ def raw_bytes(self) -> bytes: return self.key.private_bytes( encoding=_SERIALIZATION.Encoding.Raw, format=_SERIALIZATION.PrivateFormat.Raw, - encryption_algorithm=_SERIALIZATION.NoEncryption() + encryption_algorithm=_SERIALIZATION.NoEncryption(), ) def sign(self, data: bytes): diff --git a/src/sshkey_tools/utils.py b/src/sshkey_tools/utils.py index dfff6ea..43ccc72 100644 --- a/src/sshkey_tools/utils.py +++ b/src/sshkey_tools/utils.py @@ -1,13 +1,16 @@ """ Utilities for handling keys and certificates """ +import sys from secrets import randbits from base64 import b64encode import hashlib as hl -def long_to_bytes(source_int: int, force_length: int = None, byteorder: str = 'big') -> bytes: - """ Converts a positive integer to a byte string conforming with the certificate format. +def long_to_bytes( + source_int: int, force_length: int = None, byteorder: str = "big" +) -> bytes: + """Converts a positive integer to a byte string conforming with the certificate format. Equivalent to paramiko.util.deflate_long() Args: source_int (int): Integer to convert @@ -18,7 +21,9 @@ def long_to_bytes(source_int: int, force_length: int = None, byteorder: str = 'b str: Byte string representing the chosen long integer """ if source_int < 0: - raise ValueError("You can only convert positive long integers to bytes with this method") + raise ValueError( + "You can only convert positive long integers to bytes with this method" + ) if not isinstance(source_int, int): raise TypeError(f"Expected integer, got {type(source_int).__name__}.") @@ -26,7 +31,8 @@ def long_to_bytes(source_int: int, force_length: int = None, byteorder: str = 'b length = (source_int.bit_length() // 8 + 1) if not force_length else force_length return source_int.to_bytes(length, byteorder) -def bytes_to_long(source_bytes: bytes, byteorder: str = 'big') -> int: + +def bytes_to_long(source_bytes: bytes, byteorder: str = "big") -> int: """The opposite of long_to_bytes, converts a byte string to a long integer Equivalent to paramiko.util.inflate_long() Args: @@ -41,9 +47,12 @@ def bytes_to_long(source_bytes: bytes, byteorder: str = 'big') -> int: return int.from_bytes(source_bytes, byteorder) -def generate_secure_nonce(length: int = 64): - """ Generates a secure random nonce of the specified length. + +def generate_secure_nonce(length: int = 128): + """Generates a secure random nonce of the specified length. Mainly important for ECDSA keys, but is used with all key/certificate types + https://blog.trailofbits.com/2020/06/11/ecdsa-handle-with-care/ + https://datatracker.ietf.org/doc/html/rfc6979 Args: length (int, optional): Length of the nonce. Defaults to 64. @@ -52,6 +61,7 @@ def generate_secure_nonce(length: int = 64): """ return str(randbits(length)) + def md5_fingerprint(data: bytes, prefix: bool = True) -> str: """ Returns an MD5 fingerprint of the given data. @@ -64,7 +74,10 @@ def md5_fingerprint(data: bytes, prefix: bool = True) -> str: str: The fingerprint (OpenSSH style MD5:xx:xx:xx...) """ digest = hl.md5(data).hexdigest() - return ("MD5:" if prefix else "") + ':'.join(a + b for a, b in zip(digest[::2], digest[1::2])) + return ("MD5:" if prefix else "") + ":".join( + a + b for a, b in zip(digest[::2], digest[1::2]) + ) + def sha256_fingerprint(data: bytes, prefix: bool = True) -> str: """ @@ -78,7 +91,10 @@ def sha256_fingerprint(data: bytes, prefix: bool = True) -> str: str: The fingerprint (OpenSSH style SHA256:xx:xx:xx...) """ digest = hl.sha256(data).digest() - return ("SHA256:" if prefix else "") + b64encode(digest).replace(b"=", b"").decode('utf-8') + return ("SHA256:" if prefix else "") + b64encode(digest).replace(b"=", b"").decode( + "utf-8" + ) + def sha512_fingerprint(data: bytes, prefix: bool = True) -> str: """ @@ -92,4 +108,30 @@ def sha512_fingerprint(data: bytes, prefix: bool = True) -> str: str: The fingerprint (OpenSSH style SHA256:xx:xx:xx...) """ digest = hl.sha512(data).digest() - return ("SHA512:" if prefix else "") + b64encode(digest).replace(b"=", b"").decode('utf-8') + return ("SHA512:" if prefix else "") + b64encode(digest).replace(b"=", b"").decode( + "utf-8" + ) + + +def join_dicts(*dicts) -> dict: + """ + Joins two or more dictionaries together. + In case of duplicate keys, the latest one wins. + + Returns: + dict: Joined dictionary + """ + py_version = sys.version_info[0:2] + return_dict = {} + + if py_version[0] == 3 and py_version[1] > 9: + + for add_dict in dicts: + return_dict = return_dict | add_dict + + return return_dict + + for add_dict in dicts: + return_dict = {**return_dict, **add_dict} + + return return_dict diff --git a/test.py b/test.py index 2cff923..dea405a 100644 --- a/test.py +++ b/test.py @@ -1,7 +1,211 @@ -print("Done") +import os +import unittest +import random +import datetime +import faker +import src.sshkey_tools.keys as _KEY +import src.sshkey_tools.fields as _FIELD +import src.sshkey_tools.cert as _CERT +import src.sshkey_tools.exceptions as _EX +ll = [ + 'permit-x11-forwarding', + 'permit-pty' +] +test = _FIELD.CriticalOptionsField(ll) +test.validate() + + +# print(_FIELD.StandardListField.encode(ll)) +# print(_FIELD.KeyValueField.encode(ll)) +# print(_FIELD.SeparatedListField.encode(ll)) + +# cert_opts = { +# 'serial': 1234567890, +# 'cert_type': _FIELD.CERT_TYPE.USER, +# 'key_id': 'KeyIdentifier', +# 'principals': [ +# 'pr_a', +# 'pr_b', +# 'pr_c' +# ], +# 'valid_after': 1968491468, +# 'valid_before': 1968534668, +# 'critical_options': { +# 'force-command': 'sftp-internal', +# 'source-address': '1.2.3.4/8,5.6.7.8/16', +# 'verify-required': '' +# }, +# 'extensions': [ +# 'permit-agent-forwarding', +# 'permit-X11-forwarding' +# ] +# } + +# user_pub = _KEY.RSAPrivateKey.generate(1024).public_key +# ca_priv = _KEY.RSAPrivateKey.generate(1024) + +# cert = _CERT.SSHCertificate.from_public_class(user_pub, ca_priv, **cert_opts) + +# cert.sign() +# cert.to_file('test_certificate') + +# cert2 = _CERT.SSHCertificate.from_file('test_certificate') + +# assert cert.get_signable_data() == cert2.get_signable_data() + +# print("Hold") + + +# cert = _CERT.RSACertificate( +# user_pub, +# ca_priv, +# **cert_opts +# ) +# cert.sign() + +# cert.to_file('test_certificate') +# cert2 = _CERT.SSHCertificate.from_file('test_certificate') + +# print(cert2) + +# # print(cert2.fields['critical_options'].value) +# # print(cert2.fields['extensions'].value) + + + +# os.system('ssh-keygen -Lf test_certificate') + + + +# print((datetime.now() + timedelta(weeks=52*10)).timestamp()) +# print((datetime.now() + timedelta(weeks=52*10, hours=12)).timestamp()) + +# _FLD.BooleanField.encode('Hello') + +# test = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] +# test2 = [b'a', b'b', b'c', b'd', b'e', b'f', b'g'] + +# field = _FLD.SeparatedListField(test) +# field2 = _FLD.SeparatedListField(test2) + +# by = bytes(field) +# by2 = bytes(field) + +# fieldout = _FLD.SeparatedListField.decode(by)[0] +# fieldout2 = _FLD.SeparatedListField.decode(by2)[0] + +# print("Hold") + +# allowed_values = ( +# "ssh-rsa-cert-v01@openssh.com", +# "rsa-sha2-256-cert-v01@openssh.com", +# "rsa-sha2-512-cert-v01@openssh.com", +# "ssh-dss-cert-v01@openssh.com", +# "ecdsa-sha2-nistp256-cert-v01@openssh.com", +# "ecdsa-sha2-nistp384-cert-v01@openssh.com", +# "ecdsa-sha2-nistp521-cert-v01@openssh.com", +# "ssh-ed25519-cert-v01@openssh.com", +# ) + +# for value in allowed_values: +# print(f''' ('{value}', {_FLD.PubkeyTypeField.encode(value)}) ''') + +# randomized = _FLD.NonceField() +# randomized.value + +# print(_FLD.BooleanField.encode(False)) +# print(_FLD.BooleanField.encode(True)) + +# test = True +# test = not test +# print(test) + +# rsa = RSAPrivateKey.generate() +# dsa = DSAPrivateKey.generate() +# ecdsa = ECDSAPrivateKey.generate() +# ed25519 = ED25519PrivateKey.generate() + + + +# data = b'Hello World' + +# rsa = PrivateKey.from_string('''-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAACFwAAAAdzc2gtcnNhAAAA\nAwEAAQAAAgEAsbdWQ1yxhB5UdbU7cGtra5DDSzg9pHKlo1Zq7XOspt0krpv2PK9Gt3QSsrrD1Yr0\nNqRMC9XqsCAoNaOKFImlwFzeM8u5H409iAoaOqsT2gowXx8K0vJz5KhPoufgT3Ez9yuGs8/OH3S+\n2A9VQceIVf7ThSOWcdA5vXudC7KQOMq5wddfKcduwl5ouUbyOK1qF4yRGJytaWyP+oY/gAUHREON\nZvnKl0GYcvU41zhMvujPlb3ZcEqu/F8HWc492TlbJrB4FWMJd5ZtIxEQR+BgkXY2HrnTEG4+tSi7\nR7c7086kR+KJXrBhLGNHQwt3KrhTpt5QjFZstPaNzqM0daRVcv0nddP36uPd+4SUIFpfqGZNv4x8\nSCSB+3x8gWDj4YMnYMJ2gBsUPc3UowaXIld35rQtfeYNc4tC+vTiwKeSueYE3ec77I2xg1Eqbhmj\nOgabtS9cmGENvG59In/GvgrooXd8L1K+bl7ysLIH03f+RZE8BX/o0eAxeko83WaqVGlxsmQ8aFif\nLEgVfJRxGFudvrPbArFGW1wGba5yH3Gz3d4X65XHklWibUoq2nQjHpYccWmXa+VPKqjRldPG78wq\nLz+Jcxwc1+JjkfYLchHqrYeoQ678qayVg3nqoiI00qh83MNJIRMTrRvvgBSeXQfgHH2N5TtspdOq\n2AvqvNlCcQsAAAc4QkfASkJHwEoAAAAHc3NoLXJzYQAAAgEAsbdWQ1yxhB5UdbU7cGtra5DDSzg9\npHKlo1Zq7XOspt0krpv2PK9Gt3QSsrrD1Yr0NqRMC9XqsCAoNaOKFImlwFzeM8u5H409iAoaOqsT\n2gowXx8K0vJz5KhPoufgT3Ez9yuGs8/OH3S+2A9VQceIVf7ThSOWcdA5vXudC7KQOMq5wddfKcdu\nwl5ouUbyOK1qF4yRGJytaWyP+oY/gAUHREONZvnKl0GYcvU41zhMvujPlb3ZcEqu/F8HWc492Tlb\nJrB4FWMJd5ZtIxEQR+BgkXY2HrnTEG4+tSi7R7c7086kR+KJXrBhLGNHQwt3KrhTpt5QjFZstPaN\nzqM0daRVcv0nddP36uPd+4SUIFpfqGZNv4x8SCSB+3x8gWDj4YMnYMJ2gBsUPc3UowaXIld35rQt\nfeYNc4tC+vTiwKeSueYE3ec77I2xg1EqbhmjOgabtS9cmGENvG59In/GvgrooXd8L1K+bl7ysLIH\n03f+RZE8BX/o0eAxeko83WaqVGlxsmQ8aFifLEgVfJRxGFudvrPbArFGW1wGba5yH3Gz3d4X65XH\nklWibUoq2nQjHpYccWmXa+VPKqjRldPG78wqLz+Jcxwc1+JjkfYLchHqrYeoQ678qayVg3nqoiI0\n0qh83MNJIRMTrRvvgBSeXQfgHH2N5TtspdOq2AvqvNlCcQsAAAADAQABAAACAAMCF3O3HfTJOU9v\nbJIlP1boHGYpjYw7Dz1fORrL2nWjSKZWqCm0Iyj3zgPjJW137KpVcvQVquNQUrNAZsCc6TFYYRUq\nCE2Aa9+MTDqx//lbiCC+uxrW/8nfD3oHyBmQJlEIwOmfmt2YHE2L9OV9eyakKZsXVHSYvGF4ti/R\n1fR1egTN9c5p5yC4eGKqe28k1ablujmwbT8GQhRQ3Bej/iYpqTsU/1jlbgSEIhzX1x9kJsoMwfbP\nTNGjdNG7AVEBUjRVcwg++j9hTHeg0lBlJpKlGEVs39BnYqhZCCfZR39OVXmMsXE+Nbw1R1TbMikx\noDjdin+ATFbD1aKpyzmH4+pced64oO8cROyXT/p9gTUyjFrKjFTQrW6tSx5Z+dBU8GotlMsgoIc1\nNn2jQC0GzbIcc8pj9ipxtKkKWYVSutrxIxAX8hnWVfA3Sgm85eJT+M0netpfWEkzJav/w7GE/7AY\nMsfdvSMBaVU1I0G+m9qWqy3zGV9DzIooedvm6rx2TGEDFP8ap6N1oJ9B0nBnV+iBSJndclI5uhAe\n8JZq0YT9JeRj0qKVuyUmgccS0EckKDMjtaGZVi5k5Vr1CpzBsy3mlIY7L53NhykzFjER9TZt172j\niOh2S/Tqn5Fcs9AgTRSO4uc0t8qRCqAi/MoYgYeW0wD9xq2KLX6MXXA3vFgBAAABAHInUt76VvFr\nuE6FMaif6gUp1y+/74qKakW6pDIEMQoQOZWntlPrsR0PYObiivbzkFtGHLG4D7YcDBCxbAzh+SQn\n7KPhMWj53i6dOMulheFznUzhhOLGL+SXswXrxnHyQGVnU4giu0//pK6Af5krokw1YoWS8GK4hgm+\nmGX85cA6rxUYPVcvuh5vuehuU8P2m30GXDp9GqbnMnAED4KwerfDsgBKtP5yHLq3Rs6mLsVJLxB3\nnf+DsyDQ4mO7iMoQMHtGXrLHQGOGTh6PXzit+84NPufD44MNZSbvejZBp2BJG+tx44FHwZhz30VW\nzXt7mCWiIOgcQSwVcU6FDrByhuUAAAEBAMzHkdgqRJG5egqBbXTRueCy9SSV9WGDzRfTHK5zbsIU\nnPiGkwUDl0G+J91Q5XCr73mPhjt3J2gOYulcOTIlrcGQQ5Op5W+bGS21uNA75vgAbrDcKpMzWNd2\nGOUuq5q9UceNnGz4GPmBeLWnOSAsXik/+RbaZhEJ5MvrqZHdnvg0bDtKB7BRUrXYJ0A6NuTMNzOW\nmb8lmtfLmSKI/Y6m4tR2VRKyf/KQOQJhAsdYpsg/KjpVnalikzGRd8vPlzwSL/GwOK5xMc80Wdi2\nc/ay1Pur6ABsrPt1p3bDjg3ElEtUzH3HeYZnUtzCpmEiyQCBPTKvOD82ivth85EK9TLkh10AAAEB\nAN4q2FR/jlnTsR1Tg3o9TuLI8W2kAjigyq/wjV4X6m01bhjxm9XWOiU7+iTbGNRShh4tnn4mqlIu\na/c9/T87H4dgoGbBnck3XT1ls6+lCqD2F6j2J8Ippd/nPAqqPL1t1AVsGO53ZIAiJxunoUwRAyG6\n+5OspRPQq7h3k6b1xm8y+L7x34HTNUdzlzn6DzTmfMO0THkpcW2zongODMsQt6cexYA6+k1w7IO6\nZp2OUxahv4NjCseqzsF2ThfpZH2loL6FCn7IV/b2SDYWfq/sEHUGBCgg6S7fu7mAnsll5geltiAX\nohk2XSGrWQVsQFhhT9OAkH9dVvSB9KqRzJwVW4cAAAAAAQID\n-----END OPENSSH PRIVATE KEY-----\n''') +# dsa = PrivateKey.from_string('''-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABsQAAAAdzc2gtZHNzAAAA\ngQD8R6aUzO+egB6wdIe/mclkiZ3fvIiPOseBHuradx5GSujuNQNqZuy9CIPfDmNjjyyLLKleaIn8\nLMb8DGVSDgKOoNGTWuTwQ6kjA50jDJzMgL/fzlYPyArvr6sIH287FZkp506rRERHUyHAxY407sqw\n5zJ/adhHAFOYSWmFvJOc2wAAABUApA6ZZMbc6qS4CBCAWWekCcYFLdcAAACAfuEuvf7h1rx0kK+D\nOU3POGAdWP4IVQfcwFOQoi6M4etM7CpfBgIl1j1ZwSi0E56uiB+gST1rY/P0xVUbFtUd0VbpJrkQ\n1AWb/Jb+oElwgaEYUi11exMvyRzwCGpOX4fPmqiQZXXtdF4Ba5KwGoxmmN4eGgoqhx9EoD7fxWqQ\n3hEAAACAS7cdXTTEw5hpkNr757fV2M4zH0/CMjvKCvAUbZZgeuZpQ0frFlaFAneG3BeMMlYqbtEE\n4mBOPGNe58VJovb9ANAE3kVkZUbnZF8ofCKyam7vp0jRMqd/QQvRrSEVo/yb4d9QHoQ15Y1ZxbRK\nxvaiKEt0pnC4/9GwMM+SfhLatdoAAAHYd4p1QHeKdUAAAAAHc3NoLWRzcwAAAIEA/EemlMzvnoAe\nsHSHv5nJZImd37yIjzrHgR7q2nceRkro7jUDambsvQiD3w5jY48siyypXmiJ/CzG/AxlUg4CjqDR\nk1rk8EOpIwOdIwyczIC/385WD8gK76+rCB9vOxWZKedOq0RER1MhwMWONO7KsOcyf2nYRwBTmElp\nhbyTnNsAAAAVAKQOmWTG3OqkuAgQgFlnpAnGBS3XAAAAgH7hLr3+4da8dJCvgzlNzzhgHVj+CFUH\n3MBTkKIujOHrTOwqXwYCJdY9WcEotBOerogfoEk9a2Pz9MVVGxbVHdFW6Sa5ENQFm/yW/qBJcIGh\nGFItdXsTL8kc8AhqTl+Hz5qokGV17XReAWuSsBqMZpjeHhoKKocfRKA+38VqkN4RAAAAgEu3HV00\nxMOYaZDa++e31djOMx9PwjI7ygrwFG2WYHrmaUNH6xZWhQJ3htwXjDJWKm7RBOJgTjxjXufFSaL2\n/QDQBN5FZGVG52RfKHwismpu76dI0TKnf0EL0a0hFaP8m+HfUB6ENeWNWcW0Ssb2oihLdKZwuP/R\nsDDPkn4S2rXaAAAAFCLiyWA/bpyBew4ZTgniE3LzEdmTAAAAAAECAw==\n-----END OPENSSH PRIVATE KEY-----\n''') +# ecdsa = PrivateKey.from_string('''-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAArAAAABNlY2RzYS1zaGEy\nLW5pc3RwNTIxAAAACG5pc3RwNTIxAAAAhQQBmnxyVUU1ZALvV6pVwIOr0E6qUAXh4+7fpmFV71YP\nzhoeim7u+AtnaNYSEBvnEcogK9IQXJ3bjkWbYQuQJhWVG+0B5AEgUnZAnEmklN+MlxV+Iam15vJV\n4dQSfdSWnOu6iz04pWpSBTTKQyd/PpoxrgNoQ1ZY4FwFptYHrtm+xHGEkmcAAAEADxJkPw8SZD8A\nAAATZWNkc2Etc2hhMi1uaXN0cDUyMQAAAAhuaXN0cDUyMQAAAIUEAZp8clVFNWQC71eqVcCDq9BO\nqlAF4ePu36ZhVe9WD84aHopu7vgLZ2jWEhAb5xHKICvSEFyd245Fm2ELkCYVlRvtAeQBIFJ2QJxJ\npJTfjJcVfiGptebyVeHUEn3Ulpzruos9OKVqUgU0ykMnfz6aMa4DaENWWOBcBabWB67ZvsRxhJJn\nAAAAQgG80oMoSfdQpLDNHJLmJqGt29TF+t5961uU//nJqVRgOwNfR52urpQf0shljtvWNhdNkab9\nn55bYQhTPSIjqj47jwAAAAABAg==\n-----END OPENSSH PRIVATE KEY-----\n''') +# ed25519 = PrivateKey.from_string('''-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZWQyNTUx\nOQAAACDogQjCZHK2fnTSnuoaI8ZlkrwntznXOrr628xfZPHUJAAAAIgX9FhFF/RYRQAAAAtzc2gt\nZWQyNTUxOQAAACDogQjCZHK2fnTSnuoaI8ZlkrwntznXOrr628xfZPHUJAAAAED0ZjfmuHF9M5kh\n0U5tCgKgVNvKZem6HPkpyY0DLTkjZeiBCMJkcrZ+dNKe6hojxmWSvCe3Odc6uvrbzF9k8dQkAAAA\nAAECAwQF\n-----END OPENSSH PRIVATE KEY-----\n''') + +# signatures = [ +# { +# 'data': b'@OW=KT:J?KI32=;V3^`=', +# 'rsa': '', +# 'dsa': '', +# 'ecdsa': '', +# 'ed25519': '' +# }, +# { +# 'data': b'>_TVGM:bM77NG8O=Lab7', +# 'rsa': '', +# 'dsa': '', +# 'ecdsa': '', +# 'ed25519': '' +# }, +# { +# 'data': b'>bH[LNFG7=cNcEYJ;TEN', +# 'rsa': '', +# 'dsa': '', +# 'ecdsa': '', +# 'ed25519': '' +# }, +# { +# 'data': b'^Z`KVX@:XL:6?`@TYcOX', +# 'rsa': '', +# 'dsa': '', +# 'ecdsa': '', +# 'ed25519': '' +# }, +# { +# 'data': b'N[U4\=>:5;bPD', +# 'rsa': '', +# 'dsa': '', +# 'ecdsa': '', +# 'ed25519': '' +# } +# ] + +# from base64 import b64encode + +# for item in signatures: +# rsa_sig = b64encode(rsa.sign(item['data'])) +# dsa_sig = b64encode(dsa.sign(item['data'])) +# ecdsa_sig = b64encode(ecdsa.sign(item['data'])) +# ed25519_sig = b64encode(ed25519.sign(item['data'])) + +# print(f''' +# 'data': {item['data']}, +# 'rsa': {rsa_sig}, +# 'dsa': {dsa_sig}, +# 'ecdsa': {ecdsa_sig}, +# 'ed25519': {ed25519_sig} +# ''') + + + +# a = privatekey.sign(data) +# for _ in range(10): +# b = privatekey.sign(data) +# assert a == b + +# from src.sshkey_tools.cert import main # import os diff --git a/tests.py b/tests.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_certificates.py b/tests/test_certificates.py new file mode 100644 index 0000000..5ad7205 --- /dev/null +++ b/tests/test_certificates.py @@ -0,0 +1,733 @@ +# Test all certificate combinations (rsa-rsa, rsa-dsa, dsa-rsa, etc.) +# Random values for fields, push specifications +# All exceptions +# to/from file and string +# Generated by ssh-keygen and decoded by script, created by script and verified by ssh-keygen + +import os +import shutil +import unittest +import random +import datetime +import faker + +import src.sshkey_tools.keys as _KEY +import src.sshkey_tools.fields as _FIELD +import src.sshkey_tools.cert as _CERT +import src.sshkey_tools.exceptions as _EX + +CERTIFICATE_TYPES = ['rsa', 'dsa', 'ecdsa', 'ed25519'] + +class TestCertificateFields(unittest.TestCase): + def setUp(self): + self.faker = faker.Faker() + self.rsa_key = _KEY.RSAPrivateKey.generate(1024) + self.dsa_key = _KEY.DSAPrivateKey.generate() + self.ecdsa_key = _KEY.ECDSAPrivateKey.generate() + self.ed25519_key = _KEY.ED25519PrivateKey.generate() + + def assertRandomResponse(self, field_class, values = None, random_function = None): + if values is None: + values = [random_function() for _ in range(100)] + + fields = [] + + bytestring = b'' + for value in values: + bytestring += field_class.encode(value) + + field = field_class(value) + self.assertTrue(field.validate()) + fields.append(field) + + self.assertEqual( + bytestring, + b''.join(bytes(x) for x in fields) + ) + + decoded = [] + while bytestring != b'': + decode, bytestring = field_class.decode(bytestring) + decoded.append(decode) + + self.assertEqual( + decoded, + values + ) + + def assertExpectedResponse(self, field_class, input, expected_output): + self.assertEqual( + field_class.encode(input), + expected_output + ) + + def test_boolean_field(self): + self.assertRandomResponse( + _FIELD.BooleanField, + random_function=lambda : self.faker.pybool() + ) + + def test_invalid_boolean_field(self): + field = _FIELD.BooleanField("SomeInvalidData") + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_bytestring_field(self): + self.assertRandomResponse( + _FIELD.BytestringField, + random_function=lambda : self.faker.pystr(1, 100).encode('utf-8') + ) + + def test_invalid_bytestring_field(self): + field = _FIELD.BytestringField('Hello') + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode('String') + + def test_string_field(self): + self.assertRandomResponse( + _FIELD.StringField, + random_function=lambda : self.faker.pystr(1, 100) + ) + + def test_invalid_string_field(self): + field = _FIELD.StringField(b'Hello') + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(b'String') + + def test_integer32_field(self): + self.assertRandomResponse( + _FIELD.Integer32Field, + random_function=lambda : random.randint(2**2, 2**32) + ) + + def test_invalid_integer32_field(self): + field = _FIELD.Integer32Field(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.Integer32Field(_FIELD.MAX_INT32 + 1) + + self.assertIsInstance( + field.validate(), + _EX.IntegerOverflowException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_integer64_field(self): + self.assertRandomResponse( + _FIELD.Integer64Field, + random_function=lambda : random.randint(2**32, 2**64) + ) + + def test_invalid_integer64_field(self): + field = _FIELD.Integer64Field(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.Integer64Field(_FIELD.MAX_INT64 + 1) + + self.assertIsInstance( + field.validate(), + _EX.IntegerOverflowException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_datetime_field(self): + self.assertRandomResponse( + _FIELD.DateTimeField, + random_function=lambda : self.faker.date_time() + ) + + def test_invalid_datetime_field(self): + field = _FIELD.DateTimeField(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_mp_integer_field(self): + self.assertRandomResponse( + _FIELD.MpIntegerField, + random_function=lambda : random.randint(2**128, 2**512) + ) + + def test_invalid_mp_integer_field(self): + field = _FIELD.MpIntegerField(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.MpIntegerField('InvalidData') + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_list_field(self): + self.assertRandomResponse( + _FIELD.ListField, + random_function=lambda : [self.faker.pystr(0, 100) for _ in range(10)] + ) + + + def test_invalid_list_field(self): + field = _FIELD.ListField(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.ListField([ValueError, ValueError, ValueError]) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_key_value_field(self): + self.assertRandomResponse( + _FIELD.KeyValueField, + random_function=lambda : { + self.faker.pystr(1, 10): self.faker.pystr(1, 100) + for _ in range(10) + } + ) + + def test_invalid_key_value_field(self): + field = _FIELD.KeyValueField(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.KeyValueField([ValueError, ValueError, ValueError]) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_pubkey_type_field(self): + allowed_values = ( + ('ssh-rsa-cert-v01@openssh.com', b'\x00\x00\x00\x1cssh-rsa-cert-v01@openssh.com'), + ('rsa-sha2-256-cert-v01@openssh.com', b'\x00\x00\x00!rsa-sha2-256-cert-v01@openssh.com'), + ('rsa-sha2-512-cert-v01@openssh.com', b'\x00\x00\x00!rsa-sha2-512-cert-v01@openssh.com'), + ('ssh-dss-cert-v01@openssh.com', b'\x00\x00\x00\x1cssh-dss-cert-v01@openssh.com'), + ('ecdsa-sha2-nistp256-cert-v01@openssh.com', b'\x00\x00\x00(ecdsa-sha2-nistp256-cert-v01@openssh.com'), + ('ecdsa-sha2-nistp384-cert-v01@openssh.com', b'\x00\x00\x00(ecdsa-sha2-nistp384-cert-v01@openssh.com'), + ('ecdsa-sha2-nistp521-cert-v01@openssh.com', b'\x00\x00\x00(ecdsa-sha2-nistp521-cert-v01@openssh.com'), + ('ssh-ed25519-cert-v01@openssh.com', b'\x00\x00\x00 ssh-ed25519-cert-v01@openssh.com') + ) + + for value in allowed_values: + self.assertExpectedResponse( + _FIELD.PubkeyTypeField, + value[0], + value[1] + ) + + def test_invalid_pubkey_type_field(self): + field = _FIELD.PubkeyTypeField('HelloWorld') + + self.assertIsInstance( + field.validate(), + _EX.InvalidDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_nonce_field(self): + randomized = _FIELD.NonceField() + + self.assertTrue(randomized.validate()) + + specific = ( + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz', + b'\x00\x00\x004abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz' + ) + + self.assertExpectedResponse( + _FIELD.NonceField, + specific[0], + specific[1] + ) + + def test_pubkey_class_assignment(self): + rsa_field = _FIELD.PublicKeyField.from_object(self.rsa_key.public_key) + dsa_field = _FIELD.PublicKeyField.from_object(self.dsa_key.public_key) + ecdsa_field = _FIELD.PublicKeyField.from_object(self.ecdsa_key.public_key) + ed25519_field = _FIELD.PublicKeyField.from_object(self.ed25519_key.public_key) + + self.assertIsInstance(rsa_field, _FIELD.RSAPubkeyField) + self.assertIsInstance(dsa_field, _FIELD.DSAPubkeyField) + self.assertIsInstance(ecdsa_field, _FIELD.ECDSAPubkeyField) + self.assertIsInstance(ed25519_field, _FIELD.ED25519PubkeyField) + + self.assertTrue(rsa_field.validate()) + self.assertTrue(dsa_field.validate()) + self.assertTrue(ecdsa_field.validate()) + self.assertTrue(ed25519_field.validate()) + + def assertPubkeyOutput(self, key_class, *opts): + key = getattr(self, key_class.__name__.replace('PrivateKey', '').lower() + '_key').public_key + raw_start = key.raw_bytes() + field = _FIELD.PublicKeyField.from_object(key) + byte_data = bytes(field) + decoded = field.decode(byte_data)[0] + + self.assertTrue(field.validate()) + + self.assertEqual( + raw_start, + decoded.raw_bytes() + ) + + self.assertEqual( + f'{key.__class__.__name__.replace("PublicKey", "")} {key.get_fingerprint()}', + str(field) + ) + + + def test_rsa_pubkey_output(self): + self.assertPubkeyOutput( + _KEY.RSAPrivateKey, + 1024 + ) + + def test_dsa_pubkey_output(self): + self.assertPubkeyOutput( + _KEY.DSAPrivateKey + ) + + def test_ecdsa_pubkey_output(self): + self.assertPubkeyOutput( + _KEY.ECDSAPrivateKey + ) + + def test_ed25519_pubkey_output(self): + self.assertPubkeyOutput( + _KEY.ED25519PrivateKey + ) + + def test_serial_field(self): + self.assertRandomResponse( + _FIELD.SerialField, + random_function=lambda : random.randint(0, 2**64 - 1) + ) + + def test_invalid_serial_field(self): + field = _FIELD.SerialField('abcdefg') + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.SerialField(random.randint(2**65, 2**66)) + + self.assertIsInstance( + field.validate(), + _EX.IntegerOverflowException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_certificate_type_field(self): + self.assertExpectedResponse( + _FIELD.CertificateTypeField, + 1, + b'\x00\x00\x00\x01' + ) + + self.assertExpectedResponse( + _FIELD.CertificateTypeField, + 2, + b'\x00\x00\x00\x02' + ) + + self.assertExpectedResponse( + _FIELD.CertificateTypeField, + _FIELD.CERT_TYPE.USER, + b'\x00\x00\x00\x01' + ) + + self.assertExpectedResponse( + _FIELD.CertificateTypeField, + _FIELD.CERT_TYPE.HOST, + b'\x00\x00\x00\x02' + ) + + def test_invalid_certificate_field(self): + field = _FIELD.CertificateTypeField(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.CertificateTypeField(3) + + self.assertIsInstance( + field.validate(), + _EX.InvalidDataException + ) + + field = _FIELD.CertificateTypeField(0) + + self.assertIsInstance( + field.validate(), + _EX.InvalidDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_key_id_field(self): + self.assertRandomResponse( + _FIELD.KeyIDField, + random_function=lambda : self.faker.pystr(8, 128) + ) + + def test_invalid_key_id_field(self): + field = _FIELD.KeyIDField('') + + self.assertIsInstance( + field.validate(), + _EX.InvalidDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_principals_field(self): + self.assertRandomResponse( + _FIELD.PrincipalsField, + random_function=lambda : [self.faker.pystr(8, 128) for _ in range(20)] + ) + + comparison = [self.faker.pystr(8, 128) for _ in range(20)] + + self.assertExpectedResponse( + _FIELD.PrincipalsField, + comparison, + _FIELD.ListField.encode(comparison) + ) + + def test_invalid_principals_field(self): + field = _FIELD.PrincipalsField([ValueError, ValueError, ValueError]) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_validity_start_field(self): + self.assertRandomResponse( + _FIELD.ValidityStartField, + random_function=lambda : self.faker.date_time() + ) + + def test_invalid_validity_start_field(self): + field = _FIELD.ValidityStartField(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_validity_end_field(self): + self.assertRandomResponse( + _FIELD.ValidityEndField, + random_function=lambda : self.faker.date_time() + ) + + def test_invalid_validity_end_field(self): + field = _FIELD.ValidityEndField(ValueError) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_critical_options_field(self): + valid_opts_dict = { + 'force-command': 'sftp-internal', + 'source-address': '1.2.3.4/8,5.6.7.8/16', + 'verify-required': '' + } + + valid_opts_list = [ + 'force-command', + 'source-address', + 'verify-required' + ] + + verify_dict = _FIELD.CriticalOptionsField(valid_opts_dict) + verify_list = _FIELD.CriticalOptionsField(valid_opts_list) + + encoded_dict = _FIELD.CriticalOptionsField.encode(valid_opts_dict) + encoded_list = _FIELD.CriticalOptionsField.encode(valid_opts_list) + + decoded_dict = _FIELD.CriticalOptionsField.decode(encoded_dict)[0] + decoded_list = _FIELD.CriticalOptionsField.decode(encoded_list)[0] + + self.assertTrue(verify_dict.validate()) + self.assertTrue(verify_list.validate()) + + self.assertEqual( + valid_opts_dict, + decoded_dict + ) + + self.assertEqual( + valid_opts_list, + decoded_list + ) + + def test_invalid_critical_options_field(self): + field = _FIELD.CriticalOptionsField([ValueError, 'permit-pty', 'unpermit']) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.CriticalOptionsField('InvalidData') + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.CriticalOptionsField(['no-touch-required', 'InvalidOption']) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_extensions_field(self): + valid_values = [ + 'no-touch-required', + 'permit-X11-forwarding', + 'permit-agent-forwarding', + 'permit-port-forwarding', + 'permit-pty', + 'permit-user-rc' + ] + + self.assertRandomResponse( + _FIELD.ExtensionsField, + values=[valid_values[0:random.randint(1,len(valid_values))] for _ in range(10)] + ) + + def test_invalid_extensions_field(self): + field = _FIELD.CriticalOptionsField([ValueError, 'permit-pty', b'unpermit']) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.CriticalOptionsField('InvalidData') + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + field = _FIELD.CriticalOptionsField(['no-touch-required', 'InvalidOption']) + + self.assertIsInstance( + field.validate(), + _EX.InvalidFieldDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def test_reserved_field(self): + self.assertExpectedResponse( + _FIELD.ReservedField, + '', + b'\x00\x00\x00\x00' + ) + + def test_invalid_reserved_field(self): + field = _FIELD.ReservedField('InvalidData') + + self.assertIsInstance( + field.validate(), + _EX.InvalidDataException + ) + + with self.assertRaises(_EX.InvalidFieldDataException): + field.encode(ValueError) + + def assertCAPubkeyField(self, type): + key = getattr(self, f'{type}_key').public_key + field = _FIELD.CAPublicKeyField.from_object(key) + encoded = bytes(field) + decoded = field.decode(encoded)[0] + + self.assertEqual( + key.raw_bytes(), + decoded.raw_bytes() + ) + + self.assertEqual( + f'{type.upper()} {key.get_fingerprint()}', + str(field) + ) + +class TestCertificates(unittest.TestCase): + def setUp(self): + if not os.path.isdir('tests/certificates'): + os.mkdir('tests/certificates') + + self.faker = faker.Faker() + + self.rsa_ca = _KEY.RSAPrivateKey.generate(1024) + self.dsa_ca = _KEY.DSAPrivateKey.generate() + self.ecdsa_ca = _KEY.ECDSAPrivateKey.generate() + self.ed25519_ca = _KEY.ED25519PrivateKey.generate() + + self.rsa_user = _KEY.RSAPrivateKey.generate(1024).public_key + self.dsa_user = _KEY.DSAPrivateKey.generate().public_key + self.ecdsa_user = _KEY.ECDSAPrivateKey.generate().public_key + self.ed25519_user = _KEY.ED25519PrivateKey.generate().public_key + + + self.cert_opts = { + 'serial': 1234567890, + 'cert_type': _FIELD.CERT_TYPE.USER, + 'key_id': 'KeyIdentifier', + 'principals': [ + 'pr_a', + 'pr_b', + 'pr_c' + ], + 'valid_after': 1968491468, + 'valid_before': 1968534668, + 'critical_options': { + 'force-command': 'sftp-internal', + 'source-address': '1.2.3.4/8,5.6.7.8/16', + 'verify-required': '' + }, + 'extensions': [ + 'permit-agent-forwarding', + 'permit-X11-forwarding' + ] + } + + def tearDown(self): + shutil.rmtree(f'tests/certificates') + os.mkdir(f'tests/certificates') + + def test_cert_type_assignment(self): + rsa_cert = _CERT.SSHCertificate.from_public_class(self.rsa_user) + dsa_cert = _CERT.SSHCertificate.from_public_class(self.dsa_user) + ecdsa_cert = _CERT.SSHCertificate.from_public_class(self.ecdsa_user) + ed25519_cert = _CERT.SSHCertificate.from_public_class(self.ed25519_user) + + self.assertIsInstance(rsa_cert, _CERT.RSACertificate) + self.assertIsInstance(dsa_cert, _CERT.DSACertificate) + self.assertIsInstance(ecdsa_cert, _CERT.ECDSACertificate) + self.assertIsInstance(ed25519_cert, _CERT.ED25519Certificate) + + + def assertCertificateCreated(self, sub_type, ca_type): + sub_pubkey = getattr(self, f'{sub_type}_user') + ca_privkey = getattr(self, f'{ca_type}_ca') + + certificate = _CERT.SSHCertificate.from_public_class( + public_key=sub_pubkey, + ca_privkey=ca_privkey, + **self.cert_opts + ) + + self.assertTrue(certificate.can_sign()) + certificate.sign() + certificate.to_file(f'tests/certificates/{sub_type}_{ca_type}-cert.pub') + + self.assertEqual( + 0, + os.system(f'ssh-keygen -Lf tests/certificates/{sub_type}_{ca_type}-cert.pub') + ) + + reloaded_cert = _CERT.SSHCertificate.from_file(f'tests/certificates/{sub_type}_{ca_type}-cert.pub') + + self.assertEqual( + certificate.get_signable_data(), + reloaded_cert.get_signable_data() + ) + + def test_certificate_creation(self): + for ca_type in CERTIFICATE_TYPES: + for user_type in CERTIFICATE_TYPES: + print("Testing certificate creation for {} CA and {} user".format(ca_type, user_type)) + self.assertCertificateCreated( + user_type, + ca_type + ) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_keypairs.py b/tests/test_keypairs.py new file mode 100644 index 0000000..90393aa --- /dev/null +++ b/tests/test_keypairs.py @@ -0,0 +1,872 @@ +# Test privkey generation with valid and invalid parameters +# Test privkey import +# Test import to right class +# Test pubkey generation from priv +import os +import shutil +import unittest +import src.sshkey_tools.exceptions as _EX +from base64 import b64encode +from src.sshkey_tools.keys import ( + PrivkeyClasses, + PrivateKey, + RSAPrivateKey, + DSAPrivateKey, + ECDSAPrivateKey, + ED25519PrivateKey, + PubkeyClasses, + PublicKey, + RSAPublicKey, + DSAPublicKey, + ECDSAPublicKey, + ED25519PublicKey, + EcdsaCurves +) + +from cryptography.hazmat.primitives.asymmetric import ( + rsa as _RSA, + dsa as _DSA, + ec as _EC, + ed25519 as _ED25519 +) + +from cryptography.exceptions import InvalidSignature + +class KeypairMethods(unittest.TestCase): + def generateClasses(self): + self.rsa_key = RSAPrivateKey.generate(2048) + self.dsa_key = DSAPrivateKey.generate() + self.ecdsa_key = ECDSAPrivateKey.generate(EcdsaCurves.P256) + self.ed25519_key = ED25519PrivateKey.generate() + + def generateFiles(self, folder): + self.folder = folder + try: + os.mkdir(f'tests/{folder}') + except FileExistsError: + shutil.rmtree(f'tests/{folder}') + os.mkdir(f'tests/{folder}') + + os.system(f'ssh-keygen -t rsa -b 2048 -f tests/{folder}/rsa_key_sshkeygen -N "password" > /dev/null 2>&1') + os.system(f'ssh-keygen -t dsa -b 1024 -f tests/{folder}/dsa_key_sshkeygen -N "" > /dev/null 2>&1') + os.system(f'ssh-keygen -t ecdsa -b 256 -f tests/{folder}/ecdsa_key_sshkeygen -N "" > /dev/null 2>&1') + os.system(f'ssh-keygen -t ed25519 -f tests/{folder}/ed25519_key_sshkeygen -N "" > /dev/null 2>&1') + + + def setUp(self): + self.generateClasses() + self.generateFiles('KeypairMethods') + + def tearDown(self): + shutil.rmtree(f'tests/{self.folder}') + + def assertEqualPrivateKeys( + self, + priv_class, + pub_class, + a, + b, + privkey_attr = ['private_numbers'] + ): + self.assertIsInstance(a, priv_class) + self.assertIsInstance(b, priv_class) + + for att in privkey_attr: + try: + self.assertEqual(getattr(a, att), getattr(b, att)) + except AssertionError: + print("Hold") + + self.assertEqualPublicKeys( + pub_class, + a.public_key, + b.public_key + ) + + def assertEqualPublicKeys( + self, + keyclass, + a, + b + ): + self.assertIsInstance(a, keyclass) + self.assertIsInstance(b, keyclass) + + self.assertEqual(a.raw_bytes(), b.raw_bytes()) + + def assertEqualKeyFingerprint( + self, + file_a, + file_b + ): + self.assertEqual(0, + os.system( + f'''bash -c " + diff \ + <( ssh-keygen -lf {file_a}) \ + <( ssh-keygen -lf {file_b}) \ + " + ''' + ) + ) + +class TestKeypairMethods(KeypairMethods): + def test_fail_assertions(self): + with self.assertRaises(AssertionError): + self.assertEqualPrivateKeys( + RSAPrivateKey, + RSAPublicKey, + RSAPrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password'), + DSAPrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') + ) + + with self.assertRaises(AssertionError): + self.assertEqualPublicKeys( + RSAPublicKey, + RSAPublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub'), + DSAPublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') + ) + + with self.assertRaises(AssertionError): + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/rsa_key_sshkeygen', + f'tests/{self.folder}/dsa_key_sshkeygen' + ) + + def test_successful_assertions(self): + self.assertTrue( + os.path.isfile( + f'tests/{self.folder}/rsa_key_sshkeygen' + ) + ) + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/rsa_key_sshkeygen', + f'tests/{self.folder}/rsa_key_sshkeygen.pub' + ) + + self.assertTrue( + os.path.isfile( + f'tests/{self.folder}/dsa_key_sshkeygen' + ) + ) + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/dsa_key_sshkeygen', + f'tests/{self.folder}/dsa_key_sshkeygen.pub' + ) + + self.assertTrue( + os.path.isfile( + f'tests/{self.folder}/ecdsa_key_sshkeygen' + ) + ) + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ecdsa_key_sshkeygen', + f'tests/{self.folder}/ecdsa_key_sshkeygen.pub' + ) + + self.assertTrue( + os.path.isfile( + f'tests/{self.folder}/ed25519_key_sshkeygen' + ) + ) + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ed25519_key_sshkeygen', + f'tests/{self.folder}/ed25519_key_sshkeygen.pub' + ) + +class TestKeyGeneration(KeypairMethods): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_rsa(self): + key_bits = [ + 512, + 1024, + 2048, + 4096, + 8192 + ] + + for bits in key_bits: + key = RSAPrivateKey.generate(bits) + + assert isinstance(key, RSAPrivateKey) + assert isinstance(key, PrivateKey) + assert isinstance(key.key, _RSA.RSAPrivateKey) + assert isinstance(key.private_numbers, _RSA.RSAPrivateNumbers) + + assert isinstance(key.public_key, RSAPublicKey) + assert isinstance(key.public_key, PublicKey) + assert isinstance(key.public_key.key, _RSA.RSAPublicKey) + assert isinstance(key.public_key.public_numbers, _RSA.RSAPublicNumbers) + + def test_rsa_incorrect_keysize(self): + with self.assertRaises(ValueError): + RSAPrivateKey.generate(256) + + def test_dsa(self): + + key = DSAPrivateKey.generate() + + assert isinstance(key, DSAPrivateKey) + assert isinstance(key, PrivateKey) + assert isinstance(key.key, _DSA.DSAPrivateKey) + assert isinstance(key.private_numbers, _DSA.DSAPrivateNumbers) + + assert isinstance(key.public_key, DSAPublicKey) + assert isinstance(key.public_key, PublicKey) + assert isinstance(key.public_key.key, _DSA.DSAPublicKey) + assert isinstance(key.public_key.public_numbers, _DSA.DSAPublicNumbers) + assert isinstance(key.public_key.parameters, _DSA.DSAParameterNumbers) + + def test_ecdsa(self): + curves = [ + EcdsaCurves.P256, + EcdsaCurves.P384, + EcdsaCurves.P521 + ] + + for curve in curves: + key = ECDSAPrivateKey.generate(curve) + + + assert isinstance(key, ECDSAPrivateKey) + assert isinstance(key, PrivateKey) + assert isinstance(key.key, _EC.EllipticCurvePrivateKey) + assert isinstance(key.private_numbers, _EC.EllipticCurvePrivateNumbers) + + assert isinstance(key.public_key, ECDSAPublicKey) + assert isinstance(key.public_key, PublicKey) + assert isinstance(key.public_key.key, _EC.EllipticCurvePublicKey) + assert isinstance(key.public_key.public_numbers, _EC.EllipticCurvePublicNumbers) + + + def test_ecdsa_not_a_curve(self): + with self.assertRaises(AttributeError): + ECDSAPrivateKey.generate('p256') + + def test_ed25519(self): + key = ED25519PrivateKey.generate() + + assert isinstance(key, ED25519PrivateKey) + assert isinstance(key, PrivateKey) + assert isinstance(key.key, _ED25519.Ed25519PrivateKey) + + assert isinstance(key.public_key, ED25519PublicKey) + assert isinstance(key.public_key, PublicKey) + assert isinstance(key.public_key.key, _ED25519.Ed25519PublicKey) + +class TestToFromFiles(KeypairMethods): + def setUp(self): + self.generateClasses() + self.generateFiles('TestToFromFiles') + + def test_encoding(self): + with open(f'tests/{self.folder}/rsa_key_sshkeygen', 'r', encoding='utf-8') as file: + from_string = PrivateKey.from_string(file.read(), 'password', 'utf-8') + + with open(f'tests/{self.folder}/rsa_key_sshkeygen.pub', 'r', encoding='utf-8') as file: + from_string_pub = PublicKey.from_string(file.read(), 'utf-8') + + from_file = PrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') + from_file_pub = PublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') + + self.assertEqualPrivateKeys( + RSAPrivateKey, + RSAPublicKey, + from_string, + from_file + ) + + self.assertEqualPublicKeys( + RSAPublicKey, + from_string_pub, + from_file_pub + ) + + + def test_rsa_files(self): + parent = PrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') + child = RSAPrivateKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen', 'password') + + parent_pub = PublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') + child_pub = RSAPublicKey.from_file(f'tests/{self.folder}/rsa_key_sshkeygen.pub') + + parent.to_file(f'tests/{self.folder}/rsa_key_saved_parent', 'password') + child.to_file(f'tests/{self.folder}/rsa_key_saved_child') + + parent_pub.to_file(f'tests/{self.folder}/rsa_key_saved_parent.pub') + child_pub.to_file(f'tests/{self.folder}/rsa_key_saved_child.pub') + + + self.assertEqualPrivateKeys( + RSAPrivateKey, + RSAPublicKey, + parent, + child + ) + + self.assertEqualPublicKeys( + RSAPublicKey, + parent_pub, + child_pub + ) + + self.assertEqualPublicKeys( + RSAPublicKey, + parent.public_key, + child_pub + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/rsa_key_sshkeygen', + f'tests/{self.folder}/rsa_key_saved_parent' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/rsa_key_sshkeygen.pub', + f'tests/{self.folder}/rsa_key_saved_parent.pub' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/rsa_key_saved_parent', + f'tests/{self.folder}/rsa_key_saved_child' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/rsa_key_saved_parent.pub', + f'tests/{self.folder}/rsa_key_sshkeygen.pub' + ) + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/rsa_key_saved_parent.pub', + f'tests/{self.folder}/rsa_key_sshkeygen.pub' + ) + + def test_dsa_files(self): + parent = PrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') + child = DSAPrivateKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen') + + parent_pub = PublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') + child_pub = DSAPublicKey.from_file(f'tests/{self.folder}/dsa_key_sshkeygen.pub') + + parent.to_file(f'tests/{self.folder}/dsa_key_saved_parent') + child.to_file(f'tests/{self.folder}/dsa_key_saved_child') + + parent_pub.to_file(f'tests/{self.folder}/dsa_key_saved_parent.pub') + child_pub.to_file(f'tests/{self.folder}/dsa_key_saved_child.pub') + + self.assertEqualPrivateKeys( + DSAPrivateKey, + DSAPublicKey, + parent, + child + ) + + self.assertEqualPublicKeys( + DSAPublicKey, + parent_pub, + child_pub + ) + + self.assertEqualPublicKeys( + DSAPublicKey, + parent.public_key, + child_pub + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/dsa_key_sshkeygen', + f'tests/{self.folder}/dsa_key_saved_parent' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/dsa_key_sshkeygen.pub', + f'tests/{self.folder}/dsa_key_saved_parent.pub' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/dsa_key_saved_parent', + f'tests/{self.folder}/dsa_key_saved_child' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/dsa_key_saved_parent.pub', + f'tests/{self.folder}/dsa_key_sshkeygen.pub' + ) + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/dsa_key_saved_parent.pub', + f'tests/{self.folder}/dsa_key_sshkeygen.pub' + ) + + def test_ecdsa_files(self): + parent = PrivateKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen') + child = ECDSAPrivateKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen') + + parent_pub = PublicKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen.pub') + child_pub = ECDSAPublicKey.from_file(f'tests/{self.folder}/ecdsa_key_sshkeygen.pub') + + parent.to_file(f'tests/{self.folder}/ecdsa_key_saved_parent') + child.to_file(f'tests/{self.folder}/ecdsa_key_saved_child') + + parent_pub.to_file(f'tests/{self.folder}/ecdsa_key_saved_parent.pub') + child_pub.to_file(f'tests/{self.folder}/ecdsa_key_saved_child.pub') + + self.assertEqualPrivateKeys( + ECDSAPrivateKey, + ECDSAPublicKey, + parent, + child + ) + + self.assertEqualPublicKeys( + ECDSAPublicKey, + parent_pub, + child_pub + ) + + self.assertEqualPublicKeys( + ECDSAPublicKey, + parent.public_key, + child_pub + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ecdsa_key_sshkeygen', + f'tests/{self.folder}/ecdsa_key_saved_parent' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ecdsa_key_sshkeygen.pub', + f'tests/{self.folder}/ecdsa_key_saved_parent.pub' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ecdsa_key_saved_parent', + f'tests/{self.folder}/ecdsa_key_saved_child' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ecdsa_key_saved_parent.pub', + f'tests/{self.folder}/ecdsa_key_sshkeygen.pub' + ) + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ecdsa_key_saved_parent.pub', + f'tests/{self.folder}/ecdsa_key_sshkeygen.pub' + ) + + def test_ed25519_files(self): + parent = PrivateKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen') + child = ED25519PrivateKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen') + + parent_pub = PublicKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen.pub') + child_pub = ED25519PublicKey.from_file(f'tests/{self.folder}/ed25519_key_sshkeygen.pub') + + parent.to_file(f'tests/{self.folder}/ed25519_key_saved_parent') + child.to_file(f'tests/{self.folder}/ed25519_key_saved_child') + + parent_pub.to_file(f'tests/{self.folder}/ed25519_key_saved_parent.pub') + child_pub.to_file(f'tests/{self.folder}/ed25519_key_saved_child.pub') + + self.assertEqualPrivateKeys( + ED25519PrivateKey, + ED25519PublicKey, + parent, + child + ) + + self.assertEqualPublicKeys( + ED25519PublicKey, + parent_pub, + child_pub + ) + + self.assertEqualPublicKeys( + ED25519PublicKey, + parent.public_key, + child_pub + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ed25519_key_sshkeygen', + f'tests/{self.folder}/ed25519_key_saved_parent' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ed25519_key_sshkeygen.pub', + f'tests/{self.folder}/ed25519_key_saved_parent.pub' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ed25519_key_saved_parent', + f'tests/{self.folder}/ed25519_key_saved_child' + ) + + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ed25519_key_saved_parent.pub', + f'tests/{self.folder}/ed25519_key_sshkeygen.pub' + ) + self.assertEqualKeyFingerprint( + f'tests/{self.folder}/ed25519_key_saved_parent.pub', + f'tests/{self.folder}/ed25519_key_sshkeygen.pub' + ) + +class TestFromClass(KeypairMethods): + def setUp(self): + self.rsa_key = _RSA.generate_private_key( + public_exponent=65537, + key_size=2048 + ) + self.dsa_key = _DSA.generate_private_key( + key_size=1024 + ) + self.ecdsa_key = _EC.generate_private_key( + curve=_EC.SECP384R1() + ) + self.ed25519_key = _ED25519.Ed25519PrivateKey.generate() + + def tearDown(self): + pass + + def test_invalid_key_exception(self): + with self.assertRaises(_EX.InvalidKeyException): + PublicKey.from_class( + key_class=self.rsa_key, + key_type='invalid-key-type', + comment='Comment' + ) + + + def test_rsa_from_class(self): + parent = PrivateKey.from_class(self.rsa_key) + child = RSAPrivateKey.from_class(self.rsa_key) + + self.assertEqualPrivateKeys( + RSAPrivateKey, + RSAPublicKey, + parent, + child + ) + + def test_dsa_from_class(self): + parent = PrivateKey.from_class(self.dsa_key) + child = DSAPrivateKey.from_class(self.dsa_key) + + self.assertEqualPrivateKeys( + DSAPrivateKey, + DSAPublicKey, + parent, + child + ) + + def test_ecdsa_from_class(self): + parent = PrivateKey.from_class(self.ecdsa_key) + child = ECDSAPrivateKey.from_class(self.ecdsa_key) + + self.assertEqualPrivateKeys( + ECDSAPrivateKey, + ECDSAPublicKey, + parent, + child + ) + + def test_ed25519_from_class(self): + parent = PrivateKey.from_class(self.ed25519_key) + child = ED25519PrivateKey.from_class(self.ed25519_key) + + self.assertEqualPrivateKeys( + ED25519PrivateKey, + ED25519PublicKey, + parent, + child + ) + +class TestFromComponents(KeypairMethods): + def setUp(self): + self.generateClasses() + + def tearDown(self): + pass + + def test_rsa_from_numbers(self): + from_numbers = RSAPrivateKey.from_numbers( + n=self.rsa_key.public_key.public_numbers.n, + e=self.rsa_key.public_key.public_numbers.e, + d=self.rsa_key.private_numbers.d + ) + + from_numbers_pub = RSAPublicKey.from_numbers( + n=self.rsa_key.public_key.public_numbers.n, + e=self.rsa_key.public_key.public_numbers.e + ) + + self.assertEqualPublicKeys( + RSAPublicKey, + from_numbers_pub, + from_numbers.public_key + ) + + self.assertIsInstance(from_numbers, RSAPrivateKey) + + self.assertEqual( + self.rsa_key.public_key.public_numbers.n, + from_numbers.public_key.public_numbers.n + ) + + self.assertEqual( + self.rsa_key.public_key.public_numbers.e, + from_numbers.public_key.public_numbers.e + ) + + self.assertEqual( + self.rsa_key.private_numbers.d, + from_numbers.private_numbers.d + ) + + def test_dsa_from_numbers(self): + from_numbers = DSAPrivateKey.from_numbers( + p=self.dsa_key.public_key.parameters.p, + q=self.dsa_key.public_key.parameters.q, + g=self.dsa_key.public_key.parameters.g, + y=self.dsa_key.public_key.public_numbers.y, + x=self.dsa_key.private_numbers.x + ) + + from_numbers_pub = DSAPublicKey.from_numbers( + p=self.dsa_key.public_key.parameters.p, + q=self.dsa_key.public_key.parameters.q, + g=self.dsa_key.public_key.parameters.g, + y=self.dsa_key.public_key.public_numbers.y + ) + + self.assertEqualPrivateKeys( + DSAPrivateKey, + DSAPublicKey, + self.dsa_key, + from_numbers + ) + + self.assertEqualPublicKeys( + DSAPublicKey, + from_numbers_pub, + self.dsa_key.public_key + ) + + def test_ecdsa_from_numbers(self): + from_numbers = ECDSAPrivateKey.from_numbers( + curve=self.ecdsa_key.public_key.key.curve, + x=self.ecdsa_key.public_key.public_numbers.x, + y=self.ecdsa_key.public_key.public_numbers.y, + private_value=self.ecdsa_key.private_numbers.private_value + ) + + from_numbers_pub = ECDSAPublicKey.from_numbers( + curve=self.ecdsa_key.public_key.key.curve, + x=self.ecdsa_key.public_key.public_numbers.x, + y=self.ecdsa_key.public_key.public_numbers.y + ) + + self.assertEqualPrivateKeys( + ECDSAPrivateKey, + ECDSAPublicKey, + self.ecdsa_key, + from_numbers + ) + + self.assertEqualPublicKeys( + ECDSAPublicKey, + from_numbers_pub, + self.ecdsa_key.public_key + ) + + from_numbers = ECDSAPrivateKey.from_numbers( + curve=self.ecdsa_key.public_key.key.curve.name, + x=self.ecdsa_key.public_key.public_numbers.x, + y=self.ecdsa_key.public_key.public_numbers.y, + private_value=self.ecdsa_key.private_numbers.private_value + ) + + self.assertEqualPrivateKeys( + ECDSAPrivateKey, + ECDSAPublicKey, + self.ecdsa_key, + from_numbers + ) + + def test_ed25519_from_raw_bytes(self): + from_raw = ED25519PrivateKey.from_raw_bytes( + self.ed25519_key.raw_bytes() + ) + from_raw_pub = ED25519PublicKey.from_raw_bytes( + self.ed25519_key.public_key.raw_bytes() + ) + + self.assertEqualPrivateKeys( + ED25519PrivateKey, + ED25519PublicKey, + self.ed25519_key, + from_raw, + [] + ) + + self.assertEqualPublicKeys( + ED25519PublicKey, + self.ed25519_key.public_key, + from_raw_pub + ) + + +class TestFingerprint(KeypairMethods): + + def setUp(self): + self.generateFiles('TestFingerprint') + + def test_rsa_fingerprint(self): + key = RSAPrivateKey.from_file( + f'tests/{self.folder}/rsa_key_sshkeygen', + 'password' + ) + + sshkey_fingerprint = os.popen(f'ssh-keygen -lf tests/{self.folder}/rsa_key_sshkeygen').read().split(' ')[1] + + self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) + + def test_dsa_fingerprint(self): + key = DSAPrivateKey.from_file( + f'tests/{self.folder}/dsa_key_sshkeygen', + ) + + sshkey_fingerprint = os.popen(f'ssh-keygen -lf tests/{self.folder}/dsa_key_sshkeygen').read().split(' ')[1] + + self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) + + def test_ecdsa_fingerprint(self): + key = ECDSAPrivateKey.from_file( + f'tests/{self.folder}/ecdsa_key_sshkeygen', + ) + + sshkey_fingerprint = os.popen(f'ssh-keygen -lf tests/{self.folder}/ecdsa_key_sshkeygen').read().split(' ')[1] + + self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) + + def test_ed25519_fingerprint(self): + key = ED25519PrivateKey.from_file( + f'tests/{self.folder}/ed25519_key_sshkeygen', + ) + + sshkey_fingerprint = os.popen(f'ssh-keygen -lf tests/{self.folder}/ed25519_key_sshkeygen').read().split(' ')[1] + + self.assertEqual(key.get_fingerprint(), sshkey_fingerprint) + +class TestSignatures(KeypairMethods): + def setUp(self): + self.generateClasses() + + def tearDown(self): + pass + + def test_rsa_signature(self): + data = b"\x00"+os.urandom(32)+b"\x00" + signature = self.rsa_key.sign(data) + + + self.assertIsNone( + self.rsa_key.public_key.verify( + data, + signature + ) + ) + + with self.assertRaises(_EX.InvalidSignatureException): + self.rsa_key.public_key.verify( + data, + signature+b'\x00' + ) + + def test_dsa_signature(self): + data = b"\x00"+os.urandom(32)+b"\x00" + signature = self.dsa_key.sign(data) + + + self.assertIsNone( + self.dsa_key.public_key.verify( + data, + signature + ) + ) + + with self.assertRaises(_EX.InvalidSignatureException): + self.dsa_key.public_key.verify( + data, + signature+b'\x00' + ) + + def test_ecdsa_signature(self): + data = b"\x00"+os.urandom(32)+b"\x00" + signature = self.ecdsa_key.sign(data) + + + self.assertIsNone( + self.ecdsa_key.public_key.verify( + data, + signature + ) + ) + + with self.assertRaises(_EX.InvalidSignatureException): + self.ecdsa_key.public_key.verify( + data, + signature+b'\x00' + ) + + def test_ed25519_signature(self): + data = b"\x00"+os.urandom(32)+b"\x00" + signature = self.ed25519_key.sign(data) + + + self.assertIsNone( + self.ed25519_key.public_key.verify( + data, + signature + ) + ) + + with self.assertRaises(_EX.InvalidSignatureException): + self.ed25519_key.public_key.verify( + data, + signature+b'\x00' + ) + +class TestExceptions(KeypairMethods): + def setUp(self): + self.generateClasses() + + def tearDown(self): + pass + + def test_invalid_private_key(self): + with self.assertRaises(_EX.InvalidKeyException): + key = PrivateKey.from_class(KeypairMethods) + + def test_invalid_ecdsa_curve(self): + with self.assertRaises(_EX.InvalidCurveException): + key = ECDSAPublicKey.from_numbers( + 'abc123', + x=self.ecdsa_key.public_key.public_numbers.x, + y=self.ecdsa_key.public_key.public_numbers.y + ) + + with self.assertRaises(_EX.InvalidCurveException): + key = ECDSAPrivateKey.from_numbers( + 'abc123', + x=self.ecdsa_key.public_key.public_numbers.x, + y=self.ecdsa_key.public_key.public_numbers.y, + private_value=self.ecdsa_key.private_numbers.private_value + ) + +if __name__ == '__main__': + unittest.main() + + diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..6421ef2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,140 @@ +import unittest +from random import randint +import src.sshkey_tools.utils as utils + +from paramiko.util import deflate_long, inflate_long + +EXPECTED_LONG_CONVERSIONS = [ + (0, b'\x00'), + (1, b'\x01'), + (11638779394004435200 ,b'\x00\xa1\x85@1\xa6\xc5A\x00'), + (15203631582360839337 ,b'\x00\xd2\xfe&_2\x87$\xa9'), + (15302225898842444598 ,b'\x00\xd4\\ma]IS6'), + (15945599391219780268 ,b'\x00\xddJ&)\xb4W\n\xac'), + (15242635186864927689 ,b'\x00\xd3\x88\xb7\xf1\x89\xf43\xc9'), + (17517368859399259630 ,b'\x00\xf3\x1a2*\xa7\xfb\x85\xee'), + (11464229064000469348 ,b'\x00\x9f\x19\x1f\x97\xf7t\xf9d'), +] + +EXPECTED_HASHES = [ + ( b'HYoNIlxkwde]GX]qBNdtH\^ZYGyRWf', '15:6d:0c:c7:cb:9f:9a:85:dd:5e:f4:ee:b3:2b:9d:3f', 'tchAVFXDgszqtTYzokVmUfdQOHi5UvSsPVQzFsHnfXg', 'mof7eHzmuTZzEkPnjAUbaFABqZl8YUanE3Ips+jy5B9aiRtA1D8fIJgmmfb4V/T0AZz08Gu4AsMOKfybrjOUAA' ), + ( b'sseJmSLI_RgSRciYac\M`BjxkziCFD', 'cc:a4:4e:cc:1a:a4:3b:e6:26:f2:1b:a4:70:b6:0a:5a', 'XYWSEK4LHNhkVSMrMTVBg73r4Nu0ElpFuQ1efJrp5Ks', '/vZr65Oj0vyIIaC+iVCFiXNnmf7Ntg3njkBwJZddOxYaZ3mvs3Ra2OOh/VN1bMDbKaZik9BOkzDTjIXNBMDgUg' ), + ( b'AGGPP]bLy[SKwkNGfgjkwGw_vPa\]h', '25:de:7e:8f:65:da:5c:1d:c2:aa:79:8b:58:72:69:12', 'f6C8TqeIHwAIt0PvvjnR23uBzLcgr6MfTr+150u2+XE', 'cLVxsX2k3HEX6/e+S3NROdAzEIM45gO/stYLyZiJrQ5n14fBGgZX25bwJFcAtB/FNZjrIwmG1m3jRJ+CgXgBBQ' ), + ( b'OTUwhulOPckus_M]EyuxXWz^URykVP', 'df:d8:12:e2:bb:ad:70:0a:0a:e2:ba:b2:2c:82:0e:f7', 'pWW/P8tP5qOFeEYeoDMxzJlTHowbs6VwkTzHeQWrg8g', 'gF5ZmdFWxYVcDMUMyC4hIgFBXtgrxv8VZbw809UiqgTzpxOEyY1mVupma9joUGRHr3IjbDJdr+Uiq4TuZwVTjw' ), + ( b'JPvFmb^HoOw]KPwYAgVlWAhtZt]YSb', '5e:6e:12:09:50:b3:b4:e7:8f:f9:a0:d0:a6:40:e2:4c', 'hsrRKQ8vR7oDIxYAALz48kyg0BZNs61S09KCrgNJMyk', 'AmKjHbklKxWgk1EcUwtJenIYfvGC4bTfYwXnkOGvfFu2nvClgrTJix8MgZ+RtQq/3usweE2CIHgD+d3FH052kQ' ), + ( b'lLzVwDidpzTCNUbAmYgxaCV^ASbdy[', 'c5:e9:35:ea:f9:32:e6:2b:68:c2:18:8a:9f:93:ba:01', 'METh/occInzgqIvcRrJlyQa8E6nr095BXlrn9izySpA', 'G3e2YkowIzM2k5VyKxSk7O8+TJlGzEYE8pL7SPS1Ts4OOg37q06jinNXeRVSDt3ICih3KNtJPXHnFV3M41xDSw' ), + ( b'`NSyrnrn[ZntLugeydWWiaSTlVTNuU', '1a:cb:13:fb:e9:a1:ea:84:55:fd:a1:c4:7a:05:7b:d0', 'yQOxfh54aRBnH9tPxYsvDF/TLTy15nRze0AuftJWRcg', 'u9xyMyKaldgr3EZ9WFUkpGVh3C+lqLci6D3fD9lq3k2vtqK90CzTwHDNA/OHBnsc/ukWPe+kxaGn5CtHROjtLg' ), + ( b'JCUdozAoqggVFeD`dAXo]ElrDOrgVV', '42:5c:7b:4b:44:ac:f4:21:e8:fb:17:fb:b4:46:62:d4', 'uuGOfXTNMvcwfS6nol/cJ5ijVw0DVBA+4rpt18PsQjo', 'j6JOK8j8cjZnddm3IjjjnmHQobIewIelGOeSGMa9WEcKd8nkvTD17coSYdObQ9/X+kbU+9nPDSaRjjA3eQW6QA' ), + ( b']Beog]hAviJSgTZlbTDcytqftqaDof', 'e0:d3:0a:36:75:31:eb:e8:20:73:25:c3:2e:67:aa:54', 'FRreFuNrKbZ1lkJAhtaQSyqCHkrkRoZ6oKc7JNyLTAE', 'MBrk8xn3kkXmV+KefP7Lg0plxd+rI0dZ+QExCL3NlSfC54Y8j6GENtmheUFknHwaLpiwgKSRtwPN6ZP6EfaK0A' ), + ( b'XKHCjGvvtJgPdDkGSGDyPkZDBif^BZ', '7e:e4:29:df:b4:77:2a:6f:5d:eb:9f:25:8e:bd:45:b6', 'BmF8Pt4E/z0M8/rMy6mXkUpVlDG9Zje5+KA3dBVIR7c', 'I9lVFWPx6YkwnsZRMf21TPFvquV59+ng11F3EFFhDLKHrK/l6cdGQu1K0idWQQfRK44d77z00TK/aKmPmgZ2kg' ), +] + +class TestLongConversion(unittest.TestCase): + def test_expected_deflation(self): + """ + Ensure the built-in function handles deflation as expected + compared to the established function + """ + for before, after in EXPECTED_LONG_CONVERSIONS: + self.assertEqual( + utils.long_to_bytes(before), + after, + f"Failed to convert {before} to a byte string " + + "(expected: {after}, got: {long_to_bytes(before)})" + ) + + self.assertEqual( + deflate_long(before), + utils.long_to_bytes(before), + f"The comparative function failed to deliver the same result as built-in" + ) + + def test_expected_inflation(self): + """ + Ensure the built-in function handles inflation as expected + compared to the established function + """ + for after, before in EXPECTED_LONG_CONVERSIONS: + self.assertEqual( + utils.bytes_to_long(before), + after, + f"Failed to convert {before} to a byte string " + + f"(expected: {after}, got: {utils.bytes_to_long(before)})" + ) + + self.assertEqual( + inflate_long(before), + utils.bytes_to_long(before), + f"The comparative function failed to deliver the same result as built-in" + ) + + def test_expected_exception(self): + """ + Ensure appropriate exceptions are thrown when the input is invalid + """ + with self.assertRaises(ValueError): + utils.long_to_bytes(-1) + + with self.assertRaises(TypeError): + utils.long_to_bytes('one') + + def test_random_values(self): + """ + Extend testing with random results, comparing to the established function. + """ + start_length = 16 + for _ in range(15): + print(start_length) + + for _ in range(10): + value = randint(2**start_length-1, 2**start_length) + + builtin = utils.long_to_bytes(value) + compare = deflate_long(value) + + self.assertEqual( + builtin, + compare + ) + + self.assertEqual( + utils.bytes_to_long(compare), + inflate_long(builtin) + ) + + start_length = start_length*2 + +class TestNonceGeneration(unittest.TestCase): + def test_nonce_generation(self): + """ + Ensure the nonce is generated correctly + """ + for _ in range(10): + self.assertIsInstance( + utils.generate_secure_nonce(), + str + ) + +class TestHashGeneration(unittest.TestCase): + def test_hashing_functions(self): + """ + Test the md5 hash function + """ + for bytestring, md5, sha256, sha512 in EXPECTED_HASHES: + md5_2 = utils.md5_fingerprint(bytestring, False) + md5_3 = utils.md5_fingerprint(bytestring, True) + sha256_2 = utils.sha256_fingerprint(bytestring, False) + sha256_3 = utils.sha256_fingerprint(bytestring, True) + sha512_2 = utils.sha512_fingerprint(bytestring, False) + sha512_3 = utils.sha512_fingerprint(bytestring, True) + + self.assertEqual(md5, md5_2) + self.assertEqual(sha256, sha256_2) + self.assertEqual(sha512, sha512_2) + + self.assertEqual(f"MD5:{md5}", md5_3) + self.assertEqual(f"SHA256:{sha256}", sha256_3) + self.assertEqual(f"SHA512:{sha512}", sha512_3) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file