Skip to content

Commit

Permalink
feat: add sm2 encrypt and decrypt
Browse files Browse the repository at this point in the history
  • Loading branch information
j-z10 committed May 29, 2023
1 parent 7230a54 commit edabff6
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 20 deletions.
104 changes: 97 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
A Python ctypes GmSSL implementation
=======

## 1 INSTALL
### 1.1 install GmSSL
## INSTALL
### install GmSSL
```bash
git clone https://github.com/guanzhi/GmSSL.git
cd GmSSL && mkdir build && cd build && cmake ..
Expand All @@ -14,14 +14,104 @@ sudo ldconfig
gmssl version
```

### 1.2 install pygmssl
### install pygmssl
```bash
python -m pip install pygmssl
```

## 2 Usage
## USAGE

### SM3
```python3
import pygmssl
print(pygmssl.get_gmssl_version_str(), pygmssl.get_gmssl_version_num())
```
from pygmssl.sm3 import SM3

# sm3 hash all data
data = b'hello, world'
assert SM3(data).hexdigest() == '02df30dff15f2ccb72bffdcb44e68d4d09974036dc7a6927e556fbef421c7f34'

# sm3 hash data by part
s3 = SM3()
for part_data in [b'hel', b'lo', b', world']:
s3.update(part_data)
assert s3.hexdigest() == '02df30dff15f2ccb72bffdcb44e68d4d09974036dc7a6927e556fbef421c7f34'

# sm3 hash with sm2 public key and id extra
# if not id, id will be sm2 default id, which is b'1234567812345678'
sm2_pub_key = b'\xe8G\x0be\xc3P\x12\x10\xa9+\xe6n%\x9dc\xe9\xed\xaeBEf\xab\xd0\x12t\x01RQ\xb8\xceJ\xb0\x9b;\x17\xbb.\xf7i\x00\x18Nq~\xa3\xf4n\xf8\xd7\xdd%m-@\xa3\xc3tv\xe4\xe2\xf7\x81\x83\xe0'
assert SM3.hash_with_sm2(data, sm2_pub_key).hexdigest() == 'cad9730d3d178bf4c234ab7d2b1fc39569af314faecda258f30ee92456f53d2f'
assert SM3.hash_with_sm2(data, sm2_pub_key, id=b'1234567812345678').hexdigest() == 'cad9730d3d178bf4c234ab7d2b1fc39569af314faecda258f30ee92456f53d2f'
assert SM3.hash_with_sm2(data, sm2_pub_key, id=b'123').hexdigest() == 'd5ba879b0197c1a528283ff9a2b25f347474749b27ab5fd7c8a55648fff1f861'

# sm3 hash with sm2 public key by part
s3 = SM3.hash_with_sm2(b'', sm2_pub_key)
for part_data in [b'hel', b'lo', b', world']:
s3.update(part_data)
assert s3.hexdigest() == 'cad9730d3d178bf4c234ab7d2b1fc39569af314faecda258f30ee92456f53d2f'
```
### SM3-HMAC

```python3
from pygmssl.sm3 import SM3HMAC

# sm3 hmac all data
data = b'hello, world'
assert SM3HMAC(key=b'123', data=data).hexdigest() == '4410e0fef1ae0a641c7c4f1a7f6c7cef5b992f80607d5275f669d8942a77cc08'

# sm3 hmac data by part
s3 = SM3HMAC(key=b'123')
for part_data in [b'hel', b'lo', b', world']:
s3.update(part_data)
assert s3.hexdigest() == '4410e0fef1ae0a641c7c4f1a7f6c7cef5b992f80607d5275f669d8942a77cc08'
```

### SM4
```python3
from pygmssl.sm4 import SM4, MOD

# CBC, must 16 bytes key and 16 bytes iv
key = b'F\x7f\x8e7\x05\xc8\x14\x92\xa8P\x8feGx\xf6\xfc'
iv = b'W\xd3,A\x97L\x0e\xfd\xbe\xb5@\xa9\xb0\xe2L\xdf'
cipher = SM4(key, mode=MOD.CBC, iv=iv)
data = b'hello, world'
assert cipher.decrypt(cipher.encrypt(data)) == data

```

### SM2
```python3
from pygmssl.sm2 import SM2

# generate sm2 private key and public key
s2 = SM2.generate_new_pair()
print(s2.pub_key) # 64 byte public key
print(s2.pri_key) # 32 byte private key

# 64 byte public_key or 65 byte public key(which is b'\x04' + 64 byte)
test_pub_key = b'\xe8G\x0be\xc3P\x12\x10\xa9+\xe6n%\x9dc\xe9\xed\xaeBEf' \
b'\xab\xd0\x12t\x01RQ\xb8\xceJ\xb0\x9b;\x17\xbb.\xf7i\x00' \
b'\x18Nq~\xa3\xf4n\xf8\xd7\xdd%m-@\xa3\xc3tv\xe4\xe2\xf7\x81\x83\xe0'
test_pri_key = b'\x87\x95\x84V\xcej\x8cq\xd1\x10\x94\xa7\xb7\x8d\xc1\x9a' \
b'\x98\xcf\xe7\x84\x90\x9d\x8d\xd2\xff\xb4\xaeo2\xb8j\x1b'

# SM2 sign and verify with default id
signer = SM2(pub_key=test_pub_key, pri_key=test_pri_key)
data = b'hello, world'
sig = signer.sign(data) # if not id, id will be sm2.SM2_DEFAULT_ID
assert signer.verify(data, sig) == True

# SM2 sign and verify with id
signer2 = SM2(pub_key=test_pub_key, pri_key=test_pri_key)
data = b'hello, world'
sig = signer2.sign(data, id=b'123') # if not id, id will be sm2.SM2_DEFAULT_ID
assert signer2.verify(data, sig, id=b'123') == True
assert signer2.verify(data + b'\x00', sig, id=b'123') == False # libgmssl will print some fail info

# SM2 encrypt and decrypt, data's length <= sm2.SM2_MAX_PLAINTEXT_SIZE
en = SM2(pub_key=test_pub_key)
data = b'hello, world'
s_data = en.encrypt(data)

de = SM2(pri_key=test_pri_key)
d_data = de.decrypt(s_data)
assert d_data == data
```
2 changes: 1 addition & 1 deletion pygmssl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ def get_gmssl_version_num() -> str:
return _gm.gmssl_version_num()


VERSION = __version__ = '0.0.4'
VERSION = __version__ = '0.0.5'
24 changes: 24 additions & 0 deletions pygmssl/sm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
SM2_DEFAULT_ID = b'1234567812345678'
SM2_MIN_SIGNATURE_SIZE = 8
SM2_MAX_SIGNATURE_SIZE = 72
SM2_MIN_PLAINTEXT_SIZE = 1
SM2_MAX_PLAINTEXT_SIZE = 255
SM2_MIN_CIPHERTEXT_SIZE = 45
SM2_MAX_CIPHERTEXT_SIZE = 366


class _SM2_POINT(Structure):
Expand Down Expand Up @@ -87,3 +91,23 @@ def verify(self, data: bytes, sig: bytes, id: bytes = SM2_DEFAULT_ID) -> bool:
_gm.sm2_verify_update(byref(_verify_ctx), byref(buff), len(chunk))
ret = _gm.sm2_verify_finish(byref(_verify_ctx), c_char_p(sig), len(sig))
return ret == 1

def encrypt(self, data:bytes) -> bytes:
if len(data) > SM2_MAX_PLAINTEXT_SIZE:
raise ValueError('to encrypt data\'s length must <= sm2.SM2_MIN_PLAINTEXT_SIZE')
buff = (c_uint8 * SM2_MAX_PLAINTEXT_SIZE)()
buff[:len(data)] = data
out = (c_uint8 * SM2_MAX_CIPHERTEXT_SIZE)()
length = c_size_t()
_gm.sm2_encrypt(byref(self._sm2_key), byref(buff), len(data), byref(out), byref(length))
return bytes(out[:length.value])

def decrypt(self, data:bytes) -> bytes:
if len(data) > SM2_MAX_CIPHERTEXT_SIZE:
raise ValueError('to decrypt data\'s length must <= sm2.SM2_MAX_CIPHERTEXT_SIZE')
buff = (c_uint8 * SM2_MAX_CIPHERTEXT_SIZE)()
buff[:len(data)] = data
out = (c_uint8 * SM2_MAX_PLAINTEXT_SIZE)()
length = c_size_t()
_gm.sm2_decrypt(byref(self._sm2_key), byref(buff), len(data), byref(out), byref(length))
return bytes(out[:length.value])
16 changes: 11 additions & 5 deletions pygmssl/sm4.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ctypes import byref, c_uint8, c_size_t, c_uint32, Structure, c_char_p
from enum import Enum

from ._gm import _gm

Expand Down Expand Up @@ -29,6 +30,7 @@ def init(self, key: bytes, iv: bytes, encrypt: bool):
else:
_gm.sm4_cbc_decrypt_init(byref(self), c_char_p(key), c_char_p(iv))
self._result: list[bytes] = []
self._encrypt = encrypt

def encrypt_update(self, data: bytes):
outbuf = (c_uint8 * 4196)()
Expand Down Expand Up @@ -65,16 +67,20 @@ def decrypt_get(self) -> bytes:
return b''.join(self._result)


MOD_CTX_DICT = {
_MOD_CTX_DICT = {
'CBC': _SM4_CBC_CTX
}


class MOD(str, Enum):
CBC = 'CBC'


class SM4:
def __init__(self, key: bytes, *, mode: str, iv: bytes):
if mode.upper() not in MOD_CTX_DICT:
raise ValueError(u'Only support sm4 mod: %s' % MOD_CTX_DICT.keys())
self._ctx = MOD_CTX_DICT[mode.upper()]()
def __init__(self, key: bytes, *, mode: MOD, iv: bytes):
if mode.value.upper() not in _MOD_CTX_DICT:
raise ValueError(u'Only support sm4 mod: %s' % _MOD_CTX_DICT.keys())
self._ctx = _MOD_CTX_DICT[mode.value.upper()]()
self.key = key
self.iv = iv

Expand Down
11 changes: 11 additions & 0 deletions tests/test_sm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,14 @@ def test_006_sm2_sign_with_id(self):
self.assertFalse(self.k.verify(data, sig))
self.assertTrue(self.k.verify(data, sig=sig, id=b'123'))
self.assertFalse(self.k.verify(b'\x00' + data, sig=sig, id=b'123'))

def test_007_sm2_encrypt_and_decrypt(self):
data = b'hello, world'
self.assertEqual(self.k.decrypt(self.k.encrypt(data)), data)

def test_008_sm2_encrypt_and_decrypt_check(self):
data = b'1' * 1024
with self.assertRaises(ValueError):
self.k.encrypt(data)
with self.assertRaises(ValueError):
self.k.decrypt(data)
14 changes: 7 additions & 7 deletions tests/test_sm4.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from unittest import TestCase

from pygmssl.sm4 import SM4
from pygmssl.sm4 import SM4, MOD


class TestSM2(TestCase):
class TestSM4(TestCase):
def setUp(self) -> None:
self._16_key = b"\x97\xfc\xa3\xd7\t;\xf3\xd1'\x9c\x8c\x03\x92\x1c\xf5\xd2"
self._16_iv = b'p\xce\xb9\x8d\t$x\x9c\x0f]\xea\x92\xae\xa1\x96\x9d'
Expand All @@ -13,17 +13,17 @@ def setUp(self) -> None:

def test_000_valid_mod(self):
with self.assertRaises(ValueError):
SM4(self._16_key, mode='JBC', iv=b'123')
SM4(self._16_key, mode=MOD('CDC'), iv=b'123')

def test_001_cbc_encrypt(self):
k = SM4(self._16_key, mode='CBC', iv=self._16_iv)
k2 = SM4(self._16_key, mode='CBC', iv=self._16_iv)
k = SM4(self._16_key, mode=MOD.CBC, iv=self._16_iv)
k2 = SM4(self._16_key, mode=MOD.CBC, iv=self._16_iv)
e_data = k.encrypt(b'hello, world')
self.assertEqual(b'W\x855 su+\x95\xd9@\x0fGL\xacKk', e_data)
self.assertEqual(b'hello, world', k2.decrypt(e_data))

def test_002_cbc_bulk_encrypt(self):
k = SM4(self._16_key, mode='CBC', iv=self._16_iv)
k2 = SM4(self._16_key, mode='CBC', iv=self._16_iv)
k = SM4(self._16_key, mode=MOD.CBC, iv=self._16_iv)
k2 = SM4(self._16_key, mode=MOD.CBC, iv=self._16_iv)
bulk_data = os.urandom(512) * 2099
self.assertEqual(bulk_data, k2.decrypt(k.encrypt(bulk_data)))

0 comments on commit edabff6

Please sign in to comment.