forked from awslabs/aws-crt-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mqtt_test.py
141 lines (118 loc) · 4.93 KB
/
mqtt_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0.
import argparse
from awscrt import io, mqtt
from awscrt.io import LogLevel
import threading
import uuid
# Seconds given to each step of the test before giving up.
TIMEOUT = 5

# Random per-run suffix; prevents simultaneously-running tests from
# interfering with each other.
UNIQUE_ID = str(uuid.uuid4())

# The client id, the topic and the payload each carry the per-run suffix.
CLIENT_ID, TOPIC, MESSAGE = (
    prefix + UNIQUE_ID
    for prefix in ('test_pubsub_', 'test/pubsub/', 'test message ')
)
# Command-line interface. Only --endpoint is mandatory; every other option
# defaults to None when omitted.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--endpoint',
    help="Connect to this endpoint (aka host-name)",
    required=True,
)
parser.add_argument(
    '--port',
    help="Override default connection port",
    type=int,
)
parser.add_argument(
    '--cert',
    help="File path to your client certificate, in PEM format",
)
parser.add_argument(
    '--key',
    help="File path to your private key, in PEM format",
)
parser.add_argument(
    '--root-ca',
    help="File path to root certificate authority, in PEM format",
)
# Turn on awscrt logging at Trace level, directed at 'stderr'. This runs at
# import time, before the command line is parsed.
io.init_logging(LogLevel.Trace, 'stderr')
def on_connection_interrupted(connection, error, **kwargs):
    """Print the error that accompanied a connection interruption.

    Args:
        connection: The connection that was interrupted (unused).
        error: The error to report; printed as-is.
        **kwargs: Ignored; accepted so extra keyword arguments do not break
            the callback.
    """
    report = ("Connection has been interrupted with error", error)
    print(*report)
def on_connection_resumed(connection, return_code, session_present, **kwargs):
    """Report a resumed connection and resubscribe if the session was lost.

    Args:
        connection: The resumed connection. It must provide
            ``resubscribe_existing_topics()`` returning a
            ``(future, packet_id)`` pair.
        return_code: Return code of the reconnect; only printed.
        session_present: When falsy, the existing topics are resubscribed
            and the outcome is verified once the returned future completes.
        **kwargs: Ignored; accepted so extra keyword arguments do not break
            the callback.
    """
    # Local import: `sys` is not among this file's top-level imports.
    import sys

    print("Connection has been resumed with return code", return_code, "and session present:", session_present)

    if session_present:
        return

    print("Resubscribing to existing topics")
    resubscribe_future, packet_id = connection.resubscribe_existing_topics()

    def on_resubscribe_complete(resubscribe_future):
        """Validate the resubscribe result; exit with status -1 on any failure."""
        try:
            resubscribe_results = resubscribe_future.result()
            print("Resubscribe results:", resubscribe_results)

            # Explicit checks instead of `assert`: they still run under
            # `python -O` and give the failure message below real content.
            if resubscribe_results['packet_id'] != packet_id:
                raise ValueError("unexpected packet_id {!r}, expected {!r}".format(
                    resubscribe_results['packet_id'], packet_id))

            for topic, qos in resubscribe_results['topics']:
                # A qos of None is treated as a failed resubscription.
                if qos is None:
                    raise ValueError("no qos granted for topic {!r}".format(topic))
        except Exception as e:
            print("Resubscribe failure:", e)
            # sys.exit rather than the bare `exit` builtin, which only exists
            # when the `site` module has been loaded.
            # NOTE(review): this raises SystemExit in whichever thread
            # completes the future - confirm that it actually ends the test
            # process.
            sys.exit(-1)

    resubscribe_future.add_done_callback(on_resubscribe_complete)
# Shared state between the message callback and the main script: the
# callback stores the message fields here and then sets the event, which the
# main script waits on.
receive_results = {}
receive_event = threading.Event()


def on_receive_message(topic, payload, dup, qos, retain, **kwargs):
    """Record the fields of a received message and signal its arrival.

    Args:
        topic: Topic of the received message.
        payload: Payload of the received message.
        dup: Stored under the 'dup' key.
        qos: Stored under the 'qos' key.
        retain: Stored under the 'retain' key.
        **kwargs: Ignored; accepted so extra keyword arguments do not break
            the callback.
    """
    receive_results.update(
        topic=topic,
        payload=payload,
        dup=dup,
        qos=qos,
        retain=retain,
    )
    receive_event.set()
# Run
args = parser.parse_args()

# Validate option combinations up front with a usage error. The original
# `assert args.key` is skipped under `python -O` and otherwise produced a
# bare AssertionError traceback.
# NOTE(review): --key without --cert is still accepted; in that case the key
# is never used (see the TLS branch below).
if args.cert and not args.key:
    parser.error("--key is required when --cert is given")

# One event loop, a host resolver on top of it, and the bootstrap that the
# MQTT client is built from.
event_loop_group = io.EventLoopGroup(1)
host_resolver = io.DefaultHostResolver(event_loop_group)
client_bootstrap = io.ClientBootstrap(event_loop_group, host_resolver)

# TLS is configured only if at least one TLS-related option was given.
tls_options = None
if args.cert or args.key or args.root_ca:
    if args.cert:
        tls_options = io.TlsContextOptions.create_client_with_mtls_from_path(args.cert, args.key)
    else:
        tls_options = io.TlsContextOptions()

    if args.root_ca:
        with open(args.root_ca, mode='rb') as ca:
            rootca = ca.read()
        tls_options.override_default_trust_store(rootca)

# Port selection: an explicit --port wins; otherwise 443 when ALPN is
# available (adding the ALPN protocol if TLS is configured), else 8883.
if args.port:
    port = args.port
elif io.is_alpn_available():
    port = 443
    if tls_options:
        tls_options.alpn_list = ['x-amzn-mqtt-ca']
else:
    port = 8883

# Without TLS options the client is created with no TLS context.
tls_context = io.ClientTlsContext(tls_options) if tls_options else None
mqtt_client = mqtt.Client(client_bootstrap, tls_context)
# Connect
# Open a connection under this run's unique client id and wait at most
# TIMEOUT seconds for the result.
print("Connecting to {}:{} with client-id:{}".format(args.endpoint, port, CLIENT_ID))
mqtt_connection = mqtt.Connection(
    client=mqtt_client,
    client_id=CLIENT_ID,
    host_name=args.endpoint,
    port=port,
    on_connection_resumed=on_connection_resumed,
    on_connection_interrupted=on_connection_interrupted,
)
connect_results = mqtt_connection.connect().result(TIMEOUT)

# The connection is expected to report that no session was present.
assert connect_results['session_present'] == False
# Subscribe
# Subscribe to this run's unique topic and check that the result echoes the
# packet id, topic and QoS that were requested.
print("Subscribing to:", TOPIC)
qos = mqtt.QoS.AT_LEAST_ONCE
subscribe_future, subscribe_packet_id = mqtt_connection.subscribe(
    callback=on_receive_message, qos=qos, topic=TOPIC)
subscribe_results = subscribe_future.result(TIMEOUT)

assert subscribe_results['packet_id'] == subscribe_packet_id
assert subscribe_results['topic'] == TOPIC
print(subscribe_results)
assert subscribe_results['qos'] == qos
# Publish
# Send the test message to the subscribed topic and check that the result
# carries the packet id returned by publish().
print("Publishing to '{}': {}".format(TOPIC, MESSAGE))
publish_future, publish_packet_id = mqtt_connection.publish(
    qos=mqtt.QoS.AT_LEAST_ONCE, payload=MESSAGE, topic=TOPIC)
publish_results = publish_future.result(TIMEOUT)

assert publish_results['packet_id'] == publish_packet_id
# Receive Message
# Block until the message callback sets the event (or TIMEOUT elapses), then
# check that the recorded topic and payload match what was published. The
# assertion messages distinguish a timeout from a content mismatch.
print("Waiting to receive message")
assert receive_event.wait(TIMEOUT), \
    "no message received within {} seconds".format(TIMEOUT)
assert receive_results['topic'] == TOPIC, \
    "unexpected topic: {!r}".format(receive_results['topic'])
assert receive_results['payload'].decode() == MESSAGE, \
    "unexpected payload: {!r}".format(receive_results['payload'])
# Unsubscribe
# Drop the subscription and check that the result carries the packet id
# returned by unsubscribe().
print("Unsubscribing from topic")
unsubscribe_future, unsubscribe_packet_id = mqtt_connection.unsubscribe(TOPIC)
unsubscribe_results = unsubscribe_future.result(TIMEOUT)

assert unsubscribe_results['packet_id'] == unsubscribe_packet_id
# Disconnect
# Wait at most TIMEOUT seconds for the disconnect to complete.
print("Disconnecting")
disconnect_future = mqtt_connection.disconnect()
disconnect_future.result(TIMEOUT)

# Done
# Reaching this line means every step above completed without raising.
print("Test Success")