diff --git a/src/sshkey_tools/keys.py b/src/sshkey_tools/keys.py index 9143cbf..46293cb 100644 --- a/src/sshkey_tools/keys.py +++ b/src/sshkey_tools/keys.py @@ -26,7 +26,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa as _RSA from . import exceptions as _EX -from .utils import ensure_bytestring, ensure_string +from .utils import ensure_bytestring, ensure_string, nullsafe_getattr from .utils import md5_fingerprint as _FP_MD5 from .utils import sha256_fingerprint as _FP_SHA256 from .utils import sha512_fingerprint as _FP_SHA512 @@ -129,12 +129,15 @@ def __init__( _SERIALIZATION.Encoding.OpenSSH, _SERIALIZATION.PublicFormat.OpenSSH, ] + + # Ensure comment is not None + self.comment = nullsafe_getattr(self, "comment", "") @classmethod def from_class( cls, key_class: PubkeyClasses, - comment: Union[str, bytes] = None, + comment: Union[str, bytes] = "", key_type: Union[str, bytes] = None, ) -> "PublicKey": """ @@ -266,7 +269,7 @@ def to_string(self, encoding: str = "utf-8") -> str: return " ".join( [ ensure_string(self.serialize(), encoding), - ensure_string(getattr(self, "comment", ""), encoding), + ensure_string(nullsafe_getattr(self, "comment", ""), encoding), ] ) diff --git a/src/sshkey_tools/utils.py b/src/sshkey_tools/utils.py index 0d98234..4b21af2 100644 --- a/src/sshkey_tools/utils.py +++ b/src/sshkey_tools/utils.py @@ -233,6 +233,22 @@ def sha512_fingerprint(data: bytes, prefix: bool = True) -> str: ) +def nullsafe_getattr(obj, attr: str, default): + """ + Null-safe getattr, ensuring the result is not None. + If the result is None, the default value is returned instead. + + Args: + obj: The object + attr: The attribute to get + default: The default value + """ + att = getattr(obj, attr, default) + if att is None: + att = default + + return att + def join_dicts(*dicts) -> dict: """ Joins two or more dictionaries together.