diff --git a/src/Clients.c b/src/Clients.c index 585257d2..b22bf834 100644 --- a/src/Clients.c +++ b/src/Clients.c @@ -51,5 +51,5 @@ int clientSocketCompare(void* a, void* b) { Clients* client = (Clients*)a; /*printf("comparing %d with %d\n", (char*)a, (char*)b); */ - return client->net.socket == *(int*)b; + return client->net.socket == *(SOCKET*)b; } diff --git a/src/Clients.h b/src/Clients.h index ae4306e1..de2f5215 100644 --- a/src/Clients.h +++ b/src/Clients.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2021 IBM Corp. and Ian Craggs + * Copyright (c) 2009, 2022 IBM Corp. and Ian Craggs * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -24,15 +24,16 @@ #include #include "MQTTTime.h" -#if defined(OPENSSL) #if defined(_WIN32) || defined(_WIN64) #include #endif +#if defined(OPENSSL) #include #endif #include "MQTTClient.h" #include "LinkedList.h" #include "MQTTClientPersistence.h" +#include "Socket.h" /** * Stored publication data to minimize copying @@ -77,7 +78,7 @@ typedef struct typedef struct { - int socket; + SOCKET socket; START_TIME_TYPE lastSent; START_TIME_TYPE lastReceived; START_TIME_TYPE lastPing; diff --git a/src/Keysight_ws_add b/src/Keysight_ws_add new file mode 100644 index 00000000..f6103546 --- /dev/null +++ b/src/Keysight_ws_add @@ -0,0 +1,22 @@ +#if WINVER <= _WIN32_WINNT_WIN8 +#define HTON(x) hton((uint64_t) (x), sizeof(x)) +uint64_t hton(uint64_t x, size_t n) +{ + uint64_t y = 0; + size_t i = 0; + + for (i=0; i < n; ++i) + { + y = (y << 8) | (x & 0xff); + x = (x >> 8); + } + return y; +} +#define htons(x) (uint16_t) HTON(x) +#define htonl(x) (uint32_t) HTON(x) +#define htonll(x) (uint64_t) HTON(x) + +#define ntohs(x) htons(x) +#define ntohl(x) htonl(x) +#define ntohll(x) htonll(x) +#endif diff --git a/src/MQTTAsync.c b/src/MQTTAsync.c index d6fb2181..0437e28e 100644 --- a/src/MQTTAsync.c +++ b/src/MQTTAsync.c @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2021 IBM Corp., Ian Craggs and others + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs and others * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -486,7 +486,7 @@ void MQTTAsync_destroy(MQTTAsync* handle) if (m->c) { - int saved_socket = m->c->net.socket; + SOCKET saved_socket = m->c->net.socket; char* saved_clientid = MQTTStrdup(m->c->clientID); #if !defined(NO_PERSISTENCE) MQTTPersistence_close(m->c); diff --git a/src/MQTTAsyncUtils.c b/src/MQTTAsyncUtils.c index 1a37b561..3c507651 100644 --- a/src/MQTTAsyncUtils.c +++ b/src/MQTTAsyncUtils.c @@ -60,7 +60,7 @@ static int MQTTAsync_deliverMessage(MQTTAsyncs* m, char* topicName, size_t topic static int MQTTAsync_disconnect_internal(MQTTAsync handle, int timeout); static int cmdMessageIDCompare(void* a, void* b); static void MQTTAsync_retry(void); -static MQTTPacket* MQTTAsync_cycle(int* sock, unsigned long timeout, int* rc); +static MQTTPacket* MQTTAsync_cycle(SOCKET* sock, unsigned long timeout, int* rc); static int MQTTAsync_connecting(MQTTAsyncs* m); extern MQTTProtocol state; /* defined in MQTTAsync.c */ @@ -889,11 +889,12 @@ int MQTTAsync_addCommand(MQTTAsync_queuedCommand* command, int command_size) { ListDetach(MQTTAsync_commands, first_publish); - MQTTAsync_freeCommand(first_publish); #if !defined(NO_PERSISTENCE) if (command->client->c->persistence) MQTTAsync_unpersistCommand(first_publish); #endif + + MQTTAsync_freeCommand(first_publish); } } else @@ -976,7 +977,7 @@ void MQTTAsync_checkDisconnect(MQTTAsync handle, MQTTAsync_command* command) /** * Call Socket_noPendingWrites(int socket) with protection by socket_mutex, see https://github.com/eclipse/paho.mqtt.c/issues/385 */ -static int MQTTAsync_Socket_noPendingWrites(int socket) +static int MQTTAsync_Socket_noPendingWrites(SOCKET socket) { int rc; MQTTAsync_lock_mutex(socket_mutex); @@ -1059,7 +1060,7 @@ static void MQTTAsync_freeCommand(MQTTAsync_queuedCommand *command) } -void MQTTAsync_writeComplete(int socket, int rc) +void MQTTAsync_writeComplete(SOCKET socket, int rc) { ListElement* found = NULL; @@ -1976,7 +1977,7 @@ thread_return_type WINAPI MQTTAsync_receiveThread(void* n) while (!MQTTAsync_tostop) { int rc = SOCKET_ERROR; - int sock = -1; + SOCKET sock = -1; MQTTAsyncs* m = NULL; MQTTPacket* pack = NULL; @@ -2879,19 +2880,12 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) } -static MQTTPacket* MQTTAsync_cycle(int* sock, unsigned long timeout, int* rc) +static MQTTPacket* MQTTAsync_cycle(SOCKET* sock, unsigned long timeout, int* rc) { - struct timeval tp = {0L, 0L}; MQTTPacket* pack = NULL; + int rc1 = 0; FUNC_ENTRY; - if (timeout > 0L) - { - tp.tv_sec = timeout / 1000; - tp.tv_usec = (timeout % 1000) * 1000; /* this field is microseconds! */ - } - - int rc1 = 0; #if defined(OPENSSL) if ((*sock = SSLSocket_getPendingRead()) == -1) { @@ -2899,12 +2893,12 @@ static MQTTPacket* MQTTAsync_cycle(int* sock, unsigned long timeout, int* rc) int should_stop = 0; /* 0 from getReadySocket indicates no work to do, rc -1 == error */ - *sock = Socket_getReadySocket(0, &tp, socket_mutex, &rc1); + *sock = Socket_getReadySocket(0, (int)timeout, socket_mutex, &rc1); *rc = rc1; MQTTAsync_lock_mutex(mqttasync_mutex); should_stop = MQTTAsync_tostop; MQTTAsync_unlock_mutex(mqttasync_mutex); - if (!should_stop && *sock == 0 && (tp.tv_sec > 0L || tp.tv_usec > 0L)) + if (!should_stop && *sock == 0 && (timeout > 0L)) MQTTAsync_sleep(100L); #if defined(OPENSSL) } diff --git a/src/MQTTAsyncUtils.h b/src/MQTTAsyncUtils.h index 42a3145b..f65980e3 100644 --- a/src/MQTTAsyncUtils.h +++ b/src/MQTTAsyncUtils.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. and others + * Copyright (c) 2009, 2022 IBM Corp. and others * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -169,7 +169,7 @@ void MQTTAsync_closeSession(Clients* client, enum MQTTReasonCodes reasonCode, MQ int MQTTAsync_disconnect1(MQTTAsync handle, const MQTTAsync_disconnectOptions* options, int internal); int MQTTAsync_assignMsgId(MQTTAsyncs* m); int MQTTAsync_getNoBufferedMessages(MQTTAsyncs* m); -void MQTTAsync_writeComplete(int socket, int rc); +void MQTTAsync_writeComplete(SOCKET socket, int rc); void setRetryLoopInterval(int keepalive); #if defined(_WIN32) || defined(_WIN64) diff --git a/src/MQTTClient.c b/src/MQTTClient.c index 8202c467..7cba42b9 100644 --- a/src/MQTTClient.c +++ b/src/MQTTClient.c @@ -361,11 +361,11 @@ static MQTTResponse MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectO static int MQTTClient_disconnect1(MQTTClient handle, int timeout, int internal, int stop, enum MQTTReasonCodes, MQTTProperties*); static int MQTTClient_disconnect_internal(MQTTClient handle, int timeout); static void MQTTClient_retry(void); -static MQTTPacket* MQTTClient_cycle(int* sock, ELAPSED_TIME_TYPE timeout, int* rc); +static MQTTPacket* MQTTClient_cycle(SOCKET* sock, ELAPSED_TIME_TYPE timeout, int* rc); static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* rc, int64_t timeout); /*static int pubCompare(void* a, void* b); */ static void MQTTProtocol_checkPendingWrites(void); -static void MQTTClient_writeComplete(int socket, int rc); +static void MQTTClient_writeComplete(SOCKET socket, int rc); int MQTTClient_createWithOptions(MQTTClient* handle, const char* serverURI, const char* clientId, @@ -576,7 +576,7 @@ void MQTTClient_destroy(MQTTClient* handle) if (m->c) { - int saved_socket = m->c->net.socket; + SOCKET saved_socket = m->c->net.socket; char* saved_clientid = MQTTStrdup(m->c->clientID); #if !defined(NO_PERSISTENCE) MQTTPersistence_close(m->c); @@ -807,7 +807,7 @@ static thread_return_type WINAPI MQTTClient_run(void* n) while (!tostop) { int rc = SOCKET_ERROR; - int sock = -1; + SOCKET sock = -1; MQTTClients* m = NULL; MQTTPacket* pack = NULL; @@ -2484,27 +2484,24 @@ static void MQTTClient_retry(void) } -static MQTTPacket* MQTTClient_cycle(int* sock, ELAPSED_TIME_TYPE timeout, int* rc) +static MQTTPacket* MQTTClient_cycle(SOCKET* sock, ELAPSED_TIME_TYPE timeout, int* rc) { - struct timeval tp = {0L, 0L}; static Ack ack; MQTTPacket* pack = NULL; + int rc1 = 0; + START_TIME_TYPE start; FUNC_ENTRY; - if (timeout > 0L) - { - tp.tv_sec = (long)(timeout / 1000); - tp.tv_usec = (long)((timeout % 1000) * 1000); /* this field is microseconds! */ - } - - int rc1 = 0; #if defined(OPENSSL) if ((*sock = SSLSocket_getPendingRead()) == -1) { /* 0 from getReadySocket indicates no work to do, rc -1 == error */ #endif - *sock = Socket_getReadySocket(0, &tp, socket_mutex, rc); + start = MQTTTime_start_clock(); + *sock = Socket_getReadySocket(0, (int)timeout, socket_mutex, rc); *rc = rc1; + if (*sock == 0 && timeout >= 100L && MQTTTime_elapsed(start) < (int64_t)10) + MQTTTime_sleep(100L); #if defined(OPENSSL) } #endif @@ -2618,7 +2615,7 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r *rc = TCPSOCKET_COMPLETE; while (1) { - int sock = -1; + SOCKET sock = -1; pack = MQTTClient_cycle(&sock, 100L, rc); if (sock == m->c->net.socket) { @@ -2678,7 +2675,7 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r } } } - if (MQTTTime_elapsed(start) > (int64_t)timeout) + if (MQTTTime_elapsed(start) > (uint64_t)timeout) { pack = NULL; break; @@ -2723,7 +2720,7 @@ int MQTTClient_receive(MQTTClient handle, char** topicName, int* topicLen, MQTTC elapsed = MQTTTime_elapsed(start); do { - int sock = 0; + SOCKET sock = 0; MQTTClient_cycle(&sock, (timeout > elapsed) ? timeout - elapsed : 0L, &rc); if (rc == SOCKET_ERROR) @@ -2765,7 +2762,7 @@ void MQTTClient_yield(void) elapsed = MQTTTime_elapsed(start); do { - int sock = -1; + SOCKET sock = -1; MQTTClient_cycle(&sock, (timeout > elapsed) ? timeout - elapsed : 0L, &rc); Thread_lock_mutex(mqttclient_mutex); if (rc == SOCKET_ERROR && ListFindItem(handles, &sock, clientSockCompare)) @@ -3012,7 +3009,7 @@ static void MQTTProtocol_checkPendingWrites(void) } -static void MQTTClient_writeComplete(int socket, int rc) +static void MQTTClient_writeComplete(SOCKET socket, int rc) { ListElement* found = NULL; diff --git a/src/MQTTPersistence.c b/src/MQTTPersistence.c index c5634ded..7c156245 100644 --- a/src/MQTTPersistence.c +++ b/src/MQTTPersistence.c @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. + * Copyright (c) 2009, 2022 IBM Corp. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -431,7 +431,7 @@ void MQTTPersistence_insertInOrder(List* list, void* content, size_t size) * @param the MQTT version being used (>= MQTTVERSION_5 means properties included) * @return 0 if success, #MQTTCLIENT_PERSISTENCE_ERROR otherwise. */ -int MQTTPersistence_putPacket(int socket, char* buf0, size_t buf0len, int count, +int MQTTPersistence_putPacket(SOCKET socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int htype, int msgId, int scr, int MQTTVersion) { int rc = 0; diff --git a/src/MQTTPersistence.h b/src/MQTTPersistence.h index e4811161..3c62a79c 100644 --- a/src/MQTTPersistence.h +++ b/src/MQTTPersistence.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. + * Copyright (c) 2009, 2022 IBM Corp. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -63,7 +63,7 @@ int MQTTPersistence_clear(Clients* c); int MQTTPersistence_restorePackets(Clients* c); void* MQTTPersistence_restorePacket(int MQTTVersion, char* buffer, size_t buflen); void MQTTPersistence_insertInOrder(List* list, void* content, size_t size); -int MQTTPersistence_putPacket(int socket, char* buf0, size_t buf0len, int count, +int MQTTPersistence_putPacket(SOCKET socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int htype, int msgId, int scr, int MQTTVersion); int MQTTPersistence_remove(Clients* c, char* type, int qos, int msgId); void MQTTPersistence_wrapMsgID(Clients *c); diff --git a/src/MQTTProtocol.h b/src/MQTTProtocol.h index 52bcd159..3bb816d8 100644 --- a/src/MQTTProtocol.h +++ b/src/MQTTProtocol.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2014 IBM Corp. + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -27,7 +27,7 @@ typedef struct { - int socket; + SOCKET socket; Publications* p; } pending_write; diff --git a/src/MQTTProtocolClient.c b/src/MQTTProtocolClient.c index b467d391..3e834ecb 100644 --- a/src/MQTTProtocolClient.c +++ b/src/MQTTProtocolClient.c @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2021 IBM Corp. and Ian Craggs + * Copyright (c) 2009, 2022 IBM Corp. and Ian Craggs * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -314,7 +314,7 @@ void MQTTProtocol_removePublication(Publications* p) * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handlePublishes(void* pack, int sock) +int MQTTProtocol_handlePublishes(void* pack, SOCKET sock) { Publish* publish = (Publish*)pack; Clients* client = NULL; @@ -431,7 +431,7 @@ int MQTTProtocol_handlePublishes(void* pack, int sock) * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handlePubacks(void* pack, int sock) +int MQTTProtocol_handlePubacks(void* pack, SOCKET sock) { Puback* puback = (Puback*)pack; Clients* client = NULL; @@ -477,7 +477,7 @@ int MQTTProtocol_handlePubacks(void* pack, int sock) * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handlePubrecs(void* pack, int sock) +int MQTTProtocol_handlePubrecs(void* pack, SOCKET sock) { Pubrec* pubrec = (Pubrec*)pack; Clients* client = NULL; @@ -546,7 +546,7 @@ int MQTTProtocol_handlePubrecs(void* pack, int sock) * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handlePubrels(void* pack, int sock) +int MQTTProtocol_handlePubrels(void* pack, SOCKET sock) { Pubrel* pubrel = (Pubrel*)pack; Clients* client = NULL; @@ -629,7 +629,7 @@ int MQTTProtocol_handlePubrels(void* pack, int sock) * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handlePubcomps(void* pack, int sock) +int MQTTProtocol_handlePubcomps(void* pack, SOCKET sock) { Pubcomp* pubcomp = (Pubcomp*)pack; Clients* client = NULL; @@ -980,7 +980,7 @@ void MQTTProtocol_freeMessageList(List* msgList) * occur here are ignored. * @param socket the socket that is available for writing */ -void MQTTProtocol_writeAvailable(int socket) +void MQTTProtocol_writeAvailable(SOCKET socket) { Clients* client = NULL; ListElement* current = NULL; diff --git a/src/MQTTProtocolClient.h b/src/MQTTProtocolClient.h index 093711df..b9b22a4e 100644 --- a/src/MQTTProtocolClient.h +++ b/src/MQTTProtocolClient.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. + * Copyright (c) 2009, 2022 IBM Corp. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -39,11 +39,11 @@ int MQTTProtocol_assignMsgId(Clients* client); void MQTTProtocol_removePublication(Publications* p); void Protocol_processPublication(Publish* publish, Clients* client, int allocatePayload); -int MQTTProtocol_handlePublishes(void* pack, int sock); -int MQTTProtocol_handlePubacks(void* pack, int sock); -int MQTTProtocol_handlePubrecs(void* pack, int sock); -int MQTTProtocol_handlePubrels(void* pack, int sock); -int MQTTProtocol_handlePubcomps(void* pack, int sock); +int MQTTProtocol_handlePublishes(void* pack, SOCKET sock); +int MQTTProtocol_handlePubacks(void* pack, SOCKET sock); +int MQTTProtocol_handlePubrecs(void* pack, SOCKET sock); +int MQTTProtocol_handlePubrels(void* pack, SOCKET sock); +int MQTTProtocol_handlePubcomps(void* pack, SOCKET sock); void MQTTProtocol_closeSession(Clients* c, int sendwill); void MQTTProtocol_keepalive(START_TIME_TYPE); @@ -55,7 +55,7 @@ void MQTTProtocol_freeMessageList(List* msgList); char* MQTTStrncpy(char *dest, const char* src, size_t num); char* MQTTStrdup(const char* src); -void MQTTProtocol_writeAvailable(int socket); +void MQTTProtocol_writeAvailable(SOCKET socket); //#define MQTTStrdup(src) MQTTStrncpy(malloc(strlen(src)+1), src, strlen(src)+1) diff --git a/src/MQTTProtocolOut.c b/src/MQTTProtocolOut.c index 1d6aab83..c588819e 100644 --- a/src/MQTTProtocolOut.c +++ b/src/MQTTProtocolOut.c @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs and others * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -333,7 +333,10 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int websocket } if ( websocket ) { - rc = WebSocket_connect( &aClient->net, 0, ip_address ); +#if defined(OPENSSL) + rc = WebSocket_connect(&aClient->net, ssl, ip_address); +#endif + rc = WebSocket_connect(&aClient->net, 0, ip_address); if ( rc == TCPSOCKET_INTERRUPTED ) aClient->connect_state = WEBSOCKET_IN_PROGRESS; /* Websocket connect called - wait for completion */ } @@ -359,7 +362,7 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int websocket * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handlePingresps(void* pack, int sock) +int MQTTProtocol_handlePingresps(void* pack, SOCKET sock) { Clients* client = NULL; int rc = TCPSOCKET_COMPLETE; @@ -400,7 +403,7 @@ int MQTTProtocol_subscribe(Clients* client, List* topics, List* qoss, int msgID, * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handleSubacks(void* pack, int sock) +int MQTTProtocol_handleSubacks(void* pack, SOCKET sock) { Suback* suback = (Suback*)pack; Clients* client = NULL; @@ -438,7 +441,7 @@ int MQTTProtocol_unsubscribe(Clients* client, List* topics, int msgID, MQTTPrope * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handleUnsubacks(void* pack, int sock) +int MQTTProtocol_handleUnsubacks(void* pack, SOCKET sock) { Unsuback* unsuback = (Unsuback*)pack; Clients* client = NULL; @@ -459,7 +462,7 @@ int MQTTProtocol_handleUnsubacks(void* pack, int sock) * @param sock the socket on which the packet was received * @return completion code */ -int MQTTProtocol_handleDisconnects(void* pack, int sock) +int MQTTProtocol_handleDisconnects(void* pack, SOCKET sock) { Ack* disconnect = (Ack*)pack; Clients* client = NULL; diff --git a/src/MQTTProtocolOut.h b/src/MQTTProtocolOut.h index 784a5094..adc95abf 100644 --- a/src/MQTTProtocolOut.h +++ b/src/MQTTProtocolOut.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs, and others * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -55,12 +55,12 @@ int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int websock MQTTProperties* connectProperties, MQTTProperties* willProperties); #endif #endif -int MQTTProtocol_handlePingresps(void* pack, int sock); +int MQTTProtocol_handlePingresps(void* pack, SOCKET sock); int MQTTProtocol_subscribe(Clients* client, List* topics, List* qoss, int msgID, MQTTSubscribe_options* opts, MQTTProperties* props); -int MQTTProtocol_handleSubacks(void* pack, int sock); +int MQTTProtocol_handleSubacks(void* pack, SOCKET sock); int MQTTProtocol_unsubscribe(Clients* client, List* topics, int msgID, MQTTProperties* props); -int MQTTProtocol_handleUnsubacks(void* pack, int sock); -int MQTTProtocol_handleDisconnects(void* pack, int sock); +int MQTTProtocol_handleUnsubacks(void* pack, SOCKET sock); +int MQTTProtocol_handleDisconnects(void* pack, SOCKET sock); #endif diff --git a/src/SSLSocket.c b/src/SSLSocket.c index a303d370..ade31e11 100644 --- a/src/SSLSocket.c +++ b/src/SSLSocket.c @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -46,7 +46,7 @@ extern Sockets mod_s; -static int SSLSocket_error(char* aString, SSL* ssl, int sock, int rc, int (*cb)(const char *str, size_t len, void *u), void* u); +static int SSLSocket_error(char* aString, SSL* ssl, SOCKET sock, int rc, int (*cb)(const char *str, size_t len, void *u), void* u); char* SSL_get_verify_result_string(int rc); void SSL_CTX_info_callback(const SSL* ssl, int where, int ret); char* SSLSocket_get_version_string(int version); @@ -69,7 +69,7 @@ extern unsigned long SSLThread_id(void); extern void SSLLocks_callback(int mode, int n, const char *file, int line); int SSLSocket_createContext(networkHandles* net, MQTTClient_SSLOptions* opts); void SSLSocket_destroyContext(networkHandles* net); -void SSLSocket_addPendingRead(int sock); +void SSLSocket_addPendingRead(SOCKET sock); /* 1 ~ we are responsible for initializing openssl; 0 ~ openssl init is done externally */ static int handle_openssl_init = 1; @@ -94,7 +94,7 @@ static int tls_ex_index_ssl_opts; * @param u context to be passed as second argument to ERR_print_errors_cb * @return the specific TCP error code */ -static int SSLSocket_error(char* aString, SSL* ssl, int sock, int rc, int (*cb)(const char *str, size_t len, void *u), void* u) +static int SSLSocket_error(char* aString, SSL* ssl, SOCKET sock, int rc, int (*cb)(const char *str, size_t len, void *u), void* u) { int error; @@ -593,6 +593,8 @@ int SSLSocket_createContext(networkHandles* net, MQTTClient_SSLOptions* opts) } } + SSL_CTX_set_security_level(net->ctx, 1); + if (opts->keyStore) { if ((rc = SSL_CTX_use_certificate_chain_file(net->ctx, opts->keyStore)) != 1) @@ -728,7 +730,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, break; Log(TRACE_PROTOCOL, 1, "SSL cipher available: %d:%s", i, cipher); } - if ((rc = SSL_set_fd(net->ssl, net->socket)) != 1) { + if ((rc = (int)SSL_set_fd(net->ssl, (int)net->socket)) != 1) { if (opts->struct_version >= 3) SSLSocket_error("SSL_set_fd", net->ssl, net->socket, rc, opts->ssl_error_cb, opts->ssl_error_context); else @@ -757,7 +759,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, /* * Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure */ -int SSLSocket_connect(SSL* ssl, int sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u) +int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u) { int rc = 0; @@ -831,7 +833,7 @@ int SSLSocket_connect(SSL* ssl, int sock, const char* hostname, int verify, int * @param c the character read, returned * @return completion code */ -int SSLSocket_getch(SSL* ssl, int socket, char* c) +int SSLSocket_getch(SSL* ssl, SOCKET socket, char* c) { int rc = SOCKET_ERROR; @@ -871,7 +873,7 @@ int SSLSocket_getch(SSL* ssl, int socket, char* c) * @param actual_len the actual number of bytes read * @return completion code */ -char *SSLSocket_getdata(SSL* ssl, int socket, size_t bytes, size_t* actual_len, int* rc) +char *SSLSocket_getdata(SSL* ssl, SOCKET socket, size_t bytes, size_t* actual_len, int* rc) { char* buf; @@ -956,7 +958,7 @@ int SSLSocket_close(networkHandles* net) /* No SSL_writev() provided by OpenSSL. Boo. */ -int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, PacketBuffers bufs) +int SSLSocket_putdatas(SSL* ssl, SOCKET socket, char* buf0, size_t buf0len, PacketBuffers bufs) { int rc = 0; int i; @@ -996,7 +998,7 @@ int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, PacketB if (sslerror == SSL_ERROR_WANT_WRITE) { - int* sockmem = (int*)malloc(sizeof(int)); + SOCKET* sockmem = (SOCKET*)malloc(sizeof(SOCKET)); int free = 1; if (!sockmem) @@ -1010,7 +1012,7 @@ int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, PacketB SocketBuffer_pendingWrite(socket, ssl, 1, &iovec, &free, iovec.iov_len, 0); *sockmem = socket; ListAppend(mod_s.write_pending, sockmem, sizeof(int)); - FD_SET(socket, &(mod_s.pending_wset)); + //FD_SET(socket, &(mod_s.pending_wset)); rc = TCPSOCKET_INTERRUPTED; } else @@ -1039,12 +1041,12 @@ int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, PacketB } -void SSLSocket_addPendingRead(int sock) +void SSLSocket_addPendingRead(SOCKET sock) { FUNC_ENTRY; if (ListFindItem(&pending_reads, &sock, intcompare) == NULL) /* make sure we don't add the same socket twice */ { - int* psock = (int*)malloc(sizeof(sock)); + SOCKET* psock = (SOCKET*)malloc(sizeof(SOCKET)); if (psock) { *psock = sock; @@ -1058,9 +1060,9 @@ void SSLSocket_addPendingRead(int sock) } -int SSLSocket_getPendingRead(void) +SOCKET SSLSocket_getPendingRead(void) { - int sock = -1; + SOCKET sock = -1; if (pending_reads.count > 0) { diff --git a/src/SSLSocket.h b/src/SSLSocket.h index 86273c8e..7234b964 100644 --- a/src/SSLSocket.h +++ b/src/SSLSocket.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -39,14 +39,14 @@ int SSLSocket_initialize(void); void SSLSocket_terminate(void); int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, const char* hostname, size_t hostname_len); -int SSLSocket_getch(SSL* ssl, int socket, char* c); -char *SSLSocket_getdata(SSL* ssl, int socket, size_t bytes, size_t* actual_len, int* rc); +int SSLSocket_getch(SSL* ssl, SOCKET socket, char* c); +char *SSLSocket_getdata(SSL* ssl, SOCKET socket, size_t bytes, size_t* actual_len, int* rc); int SSLSocket_close(networkHandles* net); -int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, PacketBuffers bufs); -int SSLSocket_connect(SSL* ssl, int sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u); +int SSLSocket_putdatas(SSL* ssl, SOCKET socket, char* buf0, size_t buf0len, PacketBuffers bufs); +int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u); -int SSLSocket_getPendingRead(void); +SOCKET SSLSocket_getPendingRead(void); int SSLSocket_continueWrite(pending_writes* pw); #endif diff --git a/src/Socket.c b/src/Socket.c index 189a8412..dc4254ca 100644 --- a/src/Socket.c +++ b/src/Socket.c @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2021 IBM Corp., Ian Craggs and others + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs and others * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -44,16 +44,16 @@ #include "Heap.h" -int Socket_setnonblocking(int sock); -int Socket_error(char* aString, int sock); -int Socket_addSocket(int newSd); -int isReady(int socket, fd_set* read_set, fd_set* write_set); -int Socket_writev(int socket, iobuf* iovecs, int count, unsigned long* bytes); -int Socket_close_only(int socket); -int Socket_continueWrite(int socket); -int Socket_continueWrites(fd_set* pwset, int* socket, mutex_type mutex); -char* Socket_getaddrname(struct sockaddr* sa, int sock); -int Socket_abortWrite(int socket); +int Socket_setnonblocking(SOCKET sock); +int Socket_error(char* aString, SOCKET sock); +int Socket_addSocket(SOCKET newSd); +int isReady(int index); +int Socket_writev(SOCKET socket, iobuf* iovecs, int count, unsigned long* bytes); +int Socket_close_only(SOCKET socket); +int Socket_continueWrite(SOCKET socket); +int Socket_continueWrites(SOCKET* socket, mutex_type mutex); +char* Socket_getaddrname(struct sockaddr* sa, SOCKET sock); +int Socket_abortWrite(SOCKET socket); #if defined(_WIN32) || defined(_WIN64) #define iov_len len @@ -65,14 +65,13 @@ int Socket_abortWrite(int socket); * Structure to hold all socket data for this module */ Sockets mod_s; -static fd_set wset; /** * Set a socket non-blocking, OS independently * @param sock the socket to set non-blocking * @return TCP call error code */ -int Socket_setnonblocking(int sock) +int Socket_setnonblocking(SOCKET sock) { int rc; #if defined(_WIN32) || defined(_WIN64) @@ -99,7 +98,7 @@ int Socket_setnonblocking(int sock) * @param sock the socket on which the error occurred * @return the specific TCP error code */ -int Socket_error(char* aString, int sock) +int Socket_error(char* aString, SOCKET sock) { int err; @@ -134,14 +133,15 @@ void Socket_outInitialize(void) #endif SocketBuffer_initialize(); - mod_s.clientsds = ListInitialize(); mod_s.connect_pending = ListInitialize(); mod_s.write_pending = ListInitialize(); - mod_s.cur_clientsds = NULL; - FD_ZERO(&(mod_s.rset)); /* Initialize the descriptor set */ - FD_ZERO(&(mod_s.pending_wset)); - mod_s.maxfdp1 = 0; - memcpy((void*)&(mod_s.rset_saved), (void*)&(mod_s.rset), sizeof(mod_s.rset_saved)); + + mod_s.nfds = 0; + mod_s.fds = NULL; + + mod_s.saved.cur_fd = -1; + mod_s.saved.fds = NULL; + mod_s.saved.nfds = 0; FUNC_EXIT; } @@ -154,7 +154,11 @@ void Socket_outTerminate(void) FUNC_ENTRY; ListFree(mod_s.connect_pending); ListFree(mod_s.write_pending); - ListFree(mod_s.clientsds); + //ListFree(mod_s.clientsds); + if (mod_s.fds) + free(mod_s.fds); + if (mod_s.saved.fds) + free(mod_s.saved.fds); SocketBuffer_terminate(); #if defined(_WIN32) || defined(_WIN64) WSACleanup(); @@ -163,49 +167,55 @@ void Socket_outTerminate(void) } +static int cmpfds(const void *p1, const void *p2) +{ + SOCKET key1 = ((struct pollfd*)p1)->fd; + SOCKET key2 = ((struct pollfd*)p2)->fd; + + return (key1 == key2) ? 0 : ((key1 < key2) ? -1 : 1); +} + + +static int cmpsockfds(const void *p1, const void *p2) +{ + int key1 = *(int*)p1; + SOCKET key2 = ((struct pollfd*)p2)->fd; + + return (key1 == key2) ? 0 : ((key1 < key2) ? -1 : 1); +} + + /** * Add a socket to the list of socket to check with select * @param newSd the new socket to add */ -int Socket_addSocket(int newSd) +int Socket_addSocket(SOCKET newSd) { int rc = 0; FUNC_ENTRY; - if (ListFindItem(mod_s.clientsds, &newSd, intcompare) == NULL) /* make sure we don't add the same socket twice */ + mod_s.nfds++; + if (mod_s.fds) + mod_s.fds = realloc(mod_s.fds, mod_s.nfds * sizeof(mod_s.fds[0])); + else + mod_s.fds = malloc(mod_s.nfds * sizeof(mod_s.fds[0])); + if (!mod_s.fds) { - if (mod_s.clientsds->count >= FD_SETSIZE) - { - Log(LOG_ERROR, -1, "addSocket: exceeded FD_SETSIZE %d", FD_SETSIZE); - rc = SOCKET_ERROR; - } - else - { - int* pnewSd = (int*)malloc(sizeof(newSd)); - - if (!pnewSd) - { - rc = PAHO_MEMORY_ERROR; - goto exit; - } - *pnewSd = newSd; - if (!ListAppend(mod_s.clientsds, pnewSd, sizeof(newSd))) - { - free(pnewSd); - rc = PAHO_MEMORY_ERROR; - goto exit; - } - FD_SET(newSd, &(mod_s.rset_saved)); - mod_s.maxfdp1 = max(mod_s.maxfdp1, newSd + 1); - rc = Socket_setnonblocking(newSd); - if (rc == SOCKET_ERROR) - Log(LOG_ERROR, -1, "addSocket: setnonblocking"); - } + rc = PAHO_MEMORY_ERROR; + goto exit; } - else - Log(LOG_ERROR, -1, "addSocket: socket %d already in the list", newSd); -exit: + mod_s.fds[mod_s.nfds - 1].fd = newSd; + mod_s.fds[mod_s.nfds - 1].events = POLLIN | POLLOUT/* | POLLNVAL*/; + + /* sort the poll fds array by socket number */ + qsort(mod_s.fds, (size_t)mod_s.nfds, sizeof(mod_s.fds[0]), cmpfds); + + rc = Socket_setnonblocking(newSd); + if (rc == SOCKET_ERROR) + Log(LOG_ERROR, -1, "addSocket: setnonblocking"); + + exit: FUNC_EXIT_RC(rc); return rc; } @@ -214,20 +224,25 @@ int Socket_addSocket(int newSd) /** * Don't accept work from a client unless it is accepting work back, i.e. its socket is writeable * this seems like a reasonable form of flow control, and practically, seems to work. - * @param socket the socket to check - * @param read_set the socket read set (see select doc) - * @param write_set the socket write set (see select doc) + * @param index the socket index to check * @return boolean - is the socket ready to go? */ -int isReady(int socket, fd_set* read_set, fd_set* write_set) +int isReady(int index) { int rc = 1; + SOCKET* socket = &mod_s.fds[index].fd; FUNC_ENTRY; - if (ListFindItem(mod_s.connect_pending, &socket, intcompare) && FD_ISSET(socket, write_set)) - ListRemoveItem(mod_s.connect_pending, &socket, intcompare); + + if (mod_s.saved.fds[index].revents & POLLHUP || mod_s.saved.fds[index].revents & POLLNVAL) + ; /* signal work to be done if there is an error on the socket */ + else if (ListFindItem(mod_s.connect_pending, socket, intcompare) && + (mod_s.saved.fds[index].revents & POLLOUT)) + ListRemoveItem(mod_s.connect_pending, socket, intcompare); else - rc = FD_ISSET(socket, read_set) && FD_ISSET(socket, write_set) && Socket_noPendingWrites(socket); + rc = (mod_s.saved.fds[index].revents & POLLIN) && + (mod_s.saved.fds[index].revents & POLLOUT) && + Socket_noPendingWrites(*socket); FUNC_EXIT_RC(rc); return rc; } @@ -237,92 +252,84 @@ int isReady(int socket, fd_set* read_set, fd_set* write_set) * Returns the next socket ready for communications as indicated by select * @param more_work flag to indicate more work is waiting, and thus a timeout value of 0 should * be used for the select - * @param tp the timeout to be used for the select, unless overridden + * @param timeout the timeout to be used in ms * @param rc a value other than 0 indicates an error of the returned socket * @return the socket next ready, or 0 if none is ready */ -int Socket_getReadySocket(int more_work, struct timeval *tp, mutex_type mutex, int* rc) +SOCKET Socket_getReadySocket(int more_work, int timeout, mutex_type mutex, int* rc) { - int sock = 0; + SOCKET sock = 0; *rc = 0; - static struct timeval zero = {0L, 0L}; /* 0 seconds */ - static struct timeval one = {1L, 0L}; /* 1 second */ - struct timeval timeout = one; + int timeout_ms = 1000; FUNC_ENTRY; Thread_lock_mutex(mutex); - if (mod_s.clientsds->count == 0) + if (mod_s.nfds == 0) goto exit; if (more_work) - timeout = zero; - else if (tp) - timeout = *tp; + timeout_ms = 0; + else if (timeout >= 0) + timeout_ms = timeout; - while (mod_s.cur_clientsds != NULL) + while (mod_s.saved.cur_fd != -1) { - if (isReady(*((int*)(mod_s.cur_clientsds->content)), &(mod_s.rset), &wset)) + if (isReady(mod_s.saved.cur_fd)) break; - ListNextElement(mod_s.clientsds, &mod_s.cur_clientsds); + mod_s.saved.cur_fd = (mod_s.saved.cur_fd == mod_s.saved.nfds - 1) ? -1 : mod_s.saved.cur_fd + 1; } - if (mod_s.cur_clientsds == NULL) + if (mod_s.saved.cur_fd == -1) { - int rc1, maxfdp1_saved; - fd_set pwset; + if (mod_s.nfds != mod_s.saved.nfds) + { + mod_s.saved.nfds = mod_s.nfds; + if (mod_s.saved.fds) + mod_s.saved.fds = realloc(mod_s.saved.fds, mod_s.nfds * sizeof(struct pollfd)); + else + mod_s.saved.fds = malloc(mod_s.nfds * sizeof(struct pollfd)); + } + memcpy(mod_s.saved.fds, mod_s.fds, mod_s.nfds * sizeof(struct pollfd)); - memcpy((void*)&(mod_s.rset), (void*)&(mod_s.rset_saved), sizeof(mod_s.rset)); - memcpy((void*)&(pwset), (void*)&(mod_s.pending_wset), sizeof(pwset)); /* Prevent performance issue by unlocking the socket_mutex while waiting for a ready socket. */ - maxfdp1_saved = mod_s.maxfdp1; Thread_unlock_mutex(mutex); - *rc = select(maxfdp1_saved, &(mod_s.rset), &pwset, NULL, &timeout); + *rc = poll(mod_s.saved.fds, mod_s.saved.nfds, timeout_ms); Thread_lock_mutex(mutex); if (*rc == SOCKET_ERROR) { - Socket_error("read select", 0); + Socket_error("poll", 0); goto exit; } - Log(TRACE_MAX, -1, "Return code %d from read select", *rc); + Log(TRACE_MAX, -1, "Return code %d from poll", *rc); - if (Socket_continueWrites(&pwset, &sock, mutex) == SOCKET_ERROR) + if (Socket_continueWrites(&sock, mutex) == SOCKET_ERROR) { *rc = SOCKET_ERROR; goto exit; } - memcpy((void*)&wset, (void*)&(mod_s.rset_saved), sizeof(wset)); - if ((rc1 = select(mod_s.maxfdp1, NULL, &(wset), NULL, &zero)) == SOCKET_ERROR) - { - Socket_error("write select", 0); - *rc = rc1; - goto exit; - } - Log(TRACE_MAX, -1, "Return code %d from write select", rc1); - - if (*rc == 0 && rc1 == 0) + if (*rc == 0) { sock = 0; goto exit; /* no work to do */ } - mod_s.cur_clientsds = mod_s.clientsds->first; - while (mod_s.cur_clientsds != NULL) + mod_s.saved.cur_fd = 0; + while (mod_s.saved.cur_fd != -1) { - int cursock = *((int*)(mod_s.cur_clientsds->content)); - if (isReady(cursock, &(mod_s.rset), &wset)) + if (isReady(mod_s.saved.cur_fd)) break; - ListNextElement(mod_s.clientsds, &mod_s.cur_clientsds); + mod_s.saved.cur_fd = (mod_s.saved.cur_fd == mod_s.saved.nfds - 1) ? -1 : mod_s.saved.cur_fd + 1; } } *rc = 0; - if (mod_s.cur_clientsds == NULL) + if (mod_s.saved.cur_fd == -1) sock = 0; else { - sock = *((int*)(mod_s.cur_clientsds->content)); - ListNextElement(mod_s.clientsds, &mod_s.cur_clientsds); + sock = mod_s.saved.fds[mod_s.saved.cur_fd].fd; + mod_s.saved.cur_fd = (mod_s.saved.cur_fd == mod_s.saved.nfds - 1) ? -1 : mod_s.saved.cur_fd + 1; } exit: Thread_unlock_mutex(mutex); @@ -337,7 +344,7 @@ int Socket_getReadySocket(int more_work, struct timeval *tp, mutex_type mutex, i * @param c the character read, returned * @return completion code */ -int Socket_getch(int socket, char* c) +int Socket_getch(SOCKET socket, char* c) { int rc = SOCKET_ERROR; @@ -375,7 +382,7 @@ int Socket_getch(int socket, char* c) * @param actual_len the actual number of bytes read * @return completion code */ -char *Socket_getdata(int socket, size_t bytes, size_t* actual_len, int *rc) +char *Socket_getdata(SOCKET socket, size_t bytes, size_t* actual_len, int *rc) { char* buf; @@ -422,9 +429,9 @@ char *Socket_getdata(int socket, size_t bytes, size_t* actual_len, int *rc) * Indicate whether any data is pending outbound for a socket. * @return boolean - true == data pending. */ -int Socket_noPendingWrites(int socket) +int Socket_noPendingWrites(SOCKET socket) { - int cursock = socket; + SOCKET cursock = socket; return ListFindItem(mod_s.write_pending, &cursock, intcompare) == NULL; } @@ -438,7 +445,7 @@ int Socket_noPendingWrites(int socket) * @param bytes number of bytes actually written returned * @return completion code, especially TCPSOCKET_INTERRUPTED */ -int Socket_writev(int socket, iobuf* iovecs, int count, unsigned long* bytes) +int Socket_writev(SOCKET socket, iobuf* iovecs, int count, unsigned long* bytes) { int rc; @@ -510,7 +517,7 @@ for testing purposes only! * @param buflens an array of corresponding buffer lengths * @return completion code, especially TCPSOCKET_INTERRUPTED */ -int Socket_putdatas(int socket, char* buf0, size_t buf0len, PacketBuffers bufs) +int Socket_putdatas(SOCKET socket, char* buf0, size_t buf0len, PacketBuffers bufs) { unsigned long bytes = 0L; iobuf iovecs[5]; @@ -545,7 +552,7 @@ int Socket_putdatas(int socket, char* buf0, size_t buf0len, PacketBuffers bufs) rc = TCPSOCKET_COMPLETE; else { - int* sockmem = (int*)malloc(sizeof(int)); + SOCKET* sockmem = (SOCKET*)malloc(sizeof(SOCKET)); if (!sockmem) { @@ -566,7 +573,7 @@ int Socket_putdatas(int socket, char* buf0, size_t buf0len, PacketBuffers bufs) rc = PAHO_MEMORY_ERROR; goto exit; } - FD_SET(socket, &(mod_s.pending_wset)); + //FD_SET(socket, &(mod_s.pending_wset)); rc = TCPSOCKET_INTERRUPTED; } } @@ -582,9 +589,9 @@ int Socket_putdatas(int socket, char* buf0, size_t buf0len, PacketBuffers bufs) * ready to read and write states. * @param socket the socket to add */ -void Socket_addPendingWrite(int socket) +void Socket_addPendingWrite(SOCKET socket) { - FD_SET(socket, &(mod_s.pending_wset)); + //FD_SET(socket, &(mod_s.pending_wset)); } @@ -592,10 +599,10 @@ void Socket_addPendingWrite(int socket) * Clear a socket from the pending write list - if one was added with Socket_addPendingWrite * @param socket the socket to remove */ -void Socket_clearPendingWrite(int socket) +void Socket_clearPendingWrite(SOCKET socket) { - if (FD_ISSET(socket, &(mod_s.pending_wset))) - FD_CLR(socket, &(mod_s.pending_wset)); + /*if (FD_ISSET(socket, &(mod_s.pending_wset))) + FD_CLR(socket, &(mod_s.pending_wset));*/ } @@ -604,7 +611,7 @@ void Socket_clearPendingWrite(int socket) * @param socket the socket to close * @return completion code */ -int Socket_close_only(int socket) +int Socket_close_only(SOCKET socket) { int rc; @@ -632,35 +639,33 @@ int Socket_close_only(int socket) * @param socket the socket to close * @return completion code */ -void Socket_close(int socket) +void Socket_close(SOCKET socket) { + struct pollfd* fd; + struct pollfd* last_fd = &mod_s.fds[mod_s.nfds - 1]; + FUNC_ENTRY; Socket_close_only(socket); - FD_CLR(socket, &(mod_s.rset_saved)); - if (FD_ISSET(socket, &(mod_s.pending_wset))) - FD_CLR(socket, &(mod_s.pending_wset)); - if (mod_s.cur_clientsds != NULL && *(int*)(mod_s.cur_clientsds->content) == socket) - mod_s.cur_clientsds = mod_s.cur_clientsds->next; Socket_abortWrite(socket); SocketBuffer_cleanup(socket); ListRemoveItem(mod_s.connect_pending, &socket, intcompare); ListRemoveItem(mod_s.write_pending, &socket, intcompare); - if (ListRemoveItem(mod_s.clientsds, &socket, intcompare)) + /* find the socket in the fds structure */ + fd = bsearch(&socket, mod_s.fds, (size_t)mod_s.nfds, sizeof(mod_s.fds[0]), cmpsockfds); + if (fd) + { + if (fd != last_fd) + { + /* shift array to remove the socket in question */ + memmove(fd, fd + 1, (mod_s.fds + ((mod_s.nfds - 1) * sizeof(struct pollfd))) - fd); + } + mod_s.nfds--; + mod_s.fds = realloc(mod_s.fds, sizeof(mod_s.fds[0]) * mod_s.nfds); Log(TRACE_MIN, -1, "Removed socket %d", socket); + } else Log(LOG_ERROR, -1, "Failed to remove socket %d", socket); - if (socket + 1 >= mod_s.maxfdp1) - { - /* now we have to reset mod_s.maxfdp1 */ - ListElement* cur_clientsds = NULL; - - mod_s.maxfdp1 = 0; - while (ListNextElement(mod_s.clientsds, &cur_clientsds)) - mod_s.maxfdp1 = max(*((int*)(cur_clientsds->content)), mod_s.maxfdp1); - ++(mod_s.maxfdp1); - Log(TRACE_MAX, -1, "Reset max fdp1 to %d", mod_s.maxfdp1); - } FUNC_EXIT; } @@ -674,9 +679,9 @@ void Socket_close(int socket) * @return completion code */ #if defined(__GNUC__) && defined(__linux__) -int Socket_new(const char* addr, size_t addr_len, int port, int* sock, long timeout) +int Socket_new(const char* addr, size_t addr_len, int port, SOCKET* sock, long timeout) #else -int Socket_new(const char* addr, size_t addr_len, int port, int* sock) +int Socket_new(const char* addr, size_t addr_len, int port, SOCKET* sock) #endif { int type = SOCK_STREAM; @@ -779,7 +784,7 @@ int Socket_new(const char* addr, size_t addr_len, int port, int* sock) Log(LOG_ERROR, -1, "%s is not a valid IP address", addr_mem); else { - *sock = (int)socket(family, type, 0); + *sock = socket(family, type, 0); if (*sock == INVALID_SOCKET) rc = Socket_error("socket", *sock); else @@ -819,7 +824,7 @@ int Socket_new(const char* addr, size_t addr_len, int port, int* sock) rc = Socket_error("connect", *sock); if (rc == EINPROGRESS || rc == EWOULDBLOCK) { - int* pnewSd = (int*)malloc(sizeof(int)); + SOCKET* pnewSd = (SOCKET*)malloc(sizeof(SOCKET)); if (!pnewSd) { @@ -827,7 +832,7 @@ int Socket_new(const char* addr, size_t addr_len, int port, int* sock) goto exit; } *pnewSd = *sock; - if (!ListAppend(mod_s.connect_pending, pnewSd, sizeof(int))) + if (!ListAppend(mod_s.connect_pending, pnewSd, sizeof(SOCKET))) { free(pnewSd); rc = PAHO_MEMORY_ERROR; @@ -874,7 +879,7 @@ void Socket_setWriteAvailableCallback(Socket_writeAvailable* mywriteavailable) * @param socket that socket * @return completion code: 0=incomplete, 1=complete, -1=socket error */ -int Socket_continueWrite(int socket) +int Socket_continueWrite(SOCKET socket) { int rc = 0; pending_writes* pw; @@ -959,7 +964,7 @@ int Socket_continueWrite(int socket) * @param socket that socket * @return completion code: 0=incomplete, 1=complete, -1=socket error */ -int Socket_abortWrite(int socket) +int Socket_abortWrite(SOCKET socket) { int i = -1, rc = 0; pending_writes* pw; @@ -988,12 +993,11 @@ int Socket_abortWrite(int socket) /** - * Continue any outstanding writes for a socket set - * @param pwset the set of sockets + * Continue any outstanding socket writes * @param sock in case of a socket error contains the affected socket * @return completion code, 0 or SOCKET_ERROR */ -int Socket_continueWrites(fd_set* pwset, int* sock, mutex_type mutex) +int Socket_continueWrites(SOCKET* sock, mutex_type mutex) { int rc1 = 0; ListElement* curpending = mod_s.write_pending->first; @@ -1003,12 +1007,15 @@ int Socket_continueWrites(fd_set* pwset, int* sock, mutex_type mutex) { int socket = *(int*)(curpending->content); int rc = 0; + struct pollfd* fd; + + /* find the socket in the fds structure */ + fd = bsearch(&socket, mod_s.saved.fds, (size_t)mod_s.saved.nfds, sizeof(mod_s.saved.fds[0]), cmpsockfds); - if (FD_ISSET(socket, pwset) && ((rc = Socket_continueWrite(socket)) != 0)) + if ((fd->revents & POLLOUT) && ((rc = Socket_continueWrite(socket)) != 0)) { if (!SocketBuffer_writeComplete(socket)) Log(LOG_SEVERE, -1, "Failed to remove pending write from socket buffer list"); - FD_CLR(socket, &(mod_s.pending_wset)); if (!ListRemove(mod_s.write_pending, curpending->content)) { Log(LOG_SEVERE, -1, "Failed to remove pending write from list"); @@ -1029,7 +1036,7 @@ int Socket_continueWrites(fd_set* pwset, int* sock, mutex_type mutex) else ListNextElement(mod_s.write_pending, &curpending); - if(rc == SOCKET_ERROR) + if (rc == SOCKET_ERROR) { *sock = socket; rc1 = SOCKET_ERROR; @@ -1046,7 +1053,7 @@ int Socket_continueWrites(fd_set* pwset, int* sock, mutex_type mutex) * @param sock socket * @return the peer information */ -char* Socket_getaddrname(struct sockaddr* sa, int sock) +char* Socket_getaddrname(struct sockaddr* sa, SOCKET sock) { /** * maximum length of the address string @@ -1084,7 +1091,7 @@ char* Socket_getaddrname(struct sockaddr* sa, int sock) * @param sock the socket to inquire on * @return the peer information */ -char* Socket_getpeer(int sock) +char* Socket_getpeer(SOCKET sock) { struct sockaddr_in6 sa; socklen_t sal = sizeof(sa); diff --git a/src/Socket.h b/src/Socket.h index e9e61e08..390552b2 100644 --- a/src/Socket.h +++ b/src/Socket.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. and others + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs and others * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -26,6 +26,7 @@ #include #include #define MAXHOSTNAMELEN 256 +#define poll WSAPoll #if !defined(SSLSOCKET_H) #undef EAGAIN #define EAGAIN WSAEWOULDBLOCK @@ -51,6 +52,7 @@ #include #include #include +#include #include #else #include @@ -65,6 +67,7 @@ #include #include #define ULONG size_t +#define SOCKET int #endif #include "mutex_type.h" /* Needed for mutex_type */ @@ -108,41 +111,44 @@ typedef struct */ typedef struct { - fd_set rset, /**< socket read set (see select doc) */ - rset_saved; /**< saved socket read set */ - int maxfdp1; /**< max descriptor used +1 (again see select doc) */ - List* clientsds; /**< list of client socket descriptors */ - ListElement* cur_clientsds; /**< current client socket descriptor (iterator) */ List* connect_pending; /**< list of sockets for which a connect is pending */ List* write_pending; /**< list of sockets for which a write is pending */ - fd_set pending_wset; /**< socket pending write set for select */ + + unsigned int nfds; /**< no of file descriptors for poll */ + struct pollfd* fds; /**< poll read file descriptors */ + + struct { + int cur_fd; /**< index into the fds_saved array */ + unsigned int nfds; /**< number of fds in the fds_saved array */ + struct pollfd* fds; + } saved; } Sockets; void Socket_outInitialize(void); void Socket_outTerminate(void); -int Socket_getReadySocket(int more_work, struct timeval *tp, mutex_type mutex, int* rc); -int Socket_getch(int socket, char* c); -char *Socket_getdata(int socket, size_t bytes, size_t* actual_len, int* rc); -int Socket_putdatas(int socket, char* buf0, size_t buf0len, PacketBuffers bufs); -void Socket_close(int socket); +SOCKET Socket_getReadySocket(int more_work, int timeout, mutex_type mutex, int* rc); +int Socket_getch(SOCKET socket, char* c); +char *Socket_getdata(SOCKET socket, size_t bytes, size_t* actual_len, int* rc); +int Socket_putdatas(SOCKET socket, char* buf0, size_t buf0len, PacketBuffers bufs); +void Socket_close(SOCKET socket); #if defined(__GNUC__) && defined(__linux__) /* able to use GNU's getaddrinfo_a to make timeouts possible */ -int Socket_new(const char* addr, size_t addr_len, int port, int* socket, long timeout); +int Socket_new(const char* addr, size_t addr_len, int port, SOCKET* socket, long timeout); #else -int Socket_new(const char* addr, size_t addr_len, int port, int* socket); +int Socket_new(const char* addr, size_t addr_len, int port, SOCKET* socket); #endif -int Socket_noPendingWrites(int socket); -char* Socket_getpeer(int sock); +int Socket_noPendingWrites(SOCKET socket); +char* Socket_getpeer(SOCKET sock); -void Socket_addPendingWrite(int socket); -void Socket_clearPendingWrite(int socket); +void Socket_addPendingWrite(SOCKET socket); +void Socket_clearPendingWrite(SOCKET socket); -typedef void Socket_writeComplete(int socket, int rc); +typedef void Socket_writeComplete(SOCKET socket, int rc); void Socket_setWriteCompleteCallback(Socket_writeComplete*); -typedef void Socket_writeAvailable(int socket); +typedef void Socket_writeAvailable(SOCKET socket); void Socket_setWriteAvailableCallback(Socket_writeAvailable*); #endif /* SOCKET_H */ diff --git a/src/SocketBuffer.c b/src/SocketBuffer.c index f6b817af..42eff81f 100644 --- a/src/SocketBuffer.c +++ b/src/SocketBuffer.c @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs and others * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -148,7 +148,7 @@ void SocketBuffer_terminate(void) * Cleanup any buffers for a specific socket * @param socket the socket to clean up */ -void SocketBuffer_cleanup(int socket) +void SocketBuffer_cleanup(SOCKET socket) { FUNC_ENTRY; SocketBuffer_writeComplete(socket); /* clean up write buffers */ @@ -173,7 +173,7 @@ void SocketBuffer_cleanup(int socket) * @param actual_len the actual length returned * @return the actual data */ -char* SocketBuffer_getQueuedData(int socket, size_t bytes, size_t* actual_len) +char* SocketBuffer_getQueuedData(SOCKET socket, size_t bytes, size_t* actual_len) { socket_queue* queue = NULL; @@ -216,7 +216,7 @@ char* SocketBuffer_getQueuedData(int socket, size_t bytes, size_t* actual_len) * @param c the character returned if any * @return completion code */ -int SocketBuffer_getQueuedChar(int socket, char* c) +int SocketBuffer_getQueuedChar(SOCKET socket, char* c) { int rc = SOCKETBUFFER_INTERRUPTED; @@ -249,7 +249,7 @@ int SocketBuffer_getQueuedChar(int socket, char* c) * @param socket the socket to get queued data for * @param actual_len the actual length of data that was read */ -void SocketBuffer_interrupted(int socket, size_t actual_len) +void SocketBuffer_interrupted(SOCKET socket, size_t actual_len) { socket_queue* queue = NULL; @@ -278,7 +278,7 @@ void SocketBuffer_interrupted(int socket, size_t actual_len) * @param socket the socket for which the operation is now complete * @return pointer to the default queue data */ -char* SocketBuffer_complete(int socket) +char* SocketBuffer_complete(SOCKET socket) { FUNC_ENTRY; if (ListFindItem(queues, &socket, socketcompare)) @@ -300,7 +300,7 @@ char* SocketBuffer_complete(int socket) * @param socket the socket for which to queue char for * @param c the character to queue */ -void SocketBuffer_queueChar(int socket, char c) +void SocketBuffer_queueChar(SOCKET socket, char c) { int error = 0; socket_queue* curq = def_queue; @@ -344,9 +344,9 @@ void SocketBuffer_queueChar(int socket, char c) * @param bytes actual data length that was written */ #if defined(OPENSSL) -int SocketBuffer_pendingWrite(int socket, SSL* ssl, int count, iobuf* iovecs, int* frees, size_t total, size_t bytes) +int SocketBuffer_pendingWrite(SOCKET socket, SSL* ssl, int count, iobuf* iovecs, int* frees, size_t total, size_t bytes) #else -int SocketBuffer_pendingWrite(int socket, int count, iobuf* iovecs, int* frees, size_t total, size_t bytes) +int SocketBuffer_pendingWrite(SOCKET socket, int count, iobuf* iovecs, int* frees, size_t total, size_t bytes) #endif { int i = 0; @@ -396,7 +396,7 @@ int pending_socketcompare(void* a, void* b) * @param socket the socket to get queued data for * @return pointer to the queued data or NULL */ -pending_writes* SocketBuffer_getWrite(int socket) +pending_writes* SocketBuffer_getWrite(SOCKET socket) { ListElement* le = ListFindItem(&writes, &socket, pending_socketcompare); return (le) ? (pending_writes*)(le->content) : NULL; @@ -408,7 +408,7 @@ pending_writes* SocketBuffer_getWrite(int socket) * @param socket the socket for which the operation is now complete * @return completion code, boolean - was the queue removed? */ -int SocketBuffer_writeComplete(int socket) +int SocketBuffer_writeComplete(SOCKET socket) { return ListRemoveItem(&writes, &socket, pending_socketcompare); } @@ -421,7 +421,7 @@ int SocketBuffer_writeComplete(int socket) * @param payload the payload of the QoS 0 write * @return pointer to the updated queued data structure, or NULL */ -pending_writes* SocketBuffer_updateWrite(int socket, char* topic, char* payload) +pending_writes* SocketBuffer_updateWrite(SOCKET socket, char* topic, char* payload) { pending_writes* pw = NULL; ListElement* le = NULL; diff --git a/src/SocketBuffer.h b/src/SocketBuffer.h index 0fc7d6e7..1b2ab915 100644 --- a/src/SocketBuffer.h +++ b/src/SocketBuffer.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2009, 2020 IBM Corp. + * Copyright (c) 2009, 2022 IBM Corp., Ian Craggs and others * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v2.0 @@ -18,11 +18,7 @@ #if !defined(SOCKETBUFFER_H) #define SOCKETBUFFER_H -#if defined(_WIN32) || defined(_WIN64) -#include -#else -#include -#endif +#include "Socket.h" #if defined(OPENSSL) #include @@ -36,7 +32,7 @@ typedef struct { - int socket; + SOCKET socket; unsigned int index; size_t headerlen; char fixed_header[5]; /**< header plus up to 4 length bytes */ @@ -47,7 +43,8 @@ typedef struct typedef struct { - int socket, count; + SOCKET socket; + int count; size_t total; #if defined(OPENSSL) SSL* ssl; @@ -65,20 +62,20 @@ typedef struct int SocketBuffer_initialize(void); void SocketBuffer_terminate(void); -void SocketBuffer_cleanup(int socket); -char* SocketBuffer_getQueuedData(int socket, size_t bytes, size_t* actual_len); -int SocketBuffer_getQueuedChar(int socket, char* c); -void SocketBuffer_interrupted(int socket, size_t actual_len); -char* SocketBuffer_complete(int socket); -void SocketBuffer_queueChar(int socket, char c); +void SocketBuffer_cleanup(SOCKET socket); +char* SocketBuffer_getQueuedData(SOCKET socket, size_t bytes, size_t* actual_len); +int SocketBuffer_getQueuedChar(SOCKET socket, char* c); +void SocketBuffer_interrupted(SOCKET socket, size_t actual_len); +char* SocketBuffer_complete(SOCKET socket); +void SocketBuffer_queueChar(SOCKET socket, char c); #if defined(OPENSSL) -int SocketBuffer_pendingWrite(int socket, SSL* ssl, int count, iobuf* iovecs, int* frees, size_t total, size_t bytes); +int SocketBuffer_pendingWrite(SOCKET socket, SSL* ssl, int count, iobuf* iovecs, int* frees, size_t total, size_t bytes); #else -int SocketBuffer_pendingWrite(int socket, int count, iobuf* iovecs, int* frees, size_t total, size_t bytes); +int SocketBuffer_pendingWrite(SOCKET socket, int count, iobuf* iovecs, int* frees, size_t total, size_t bytes); #endif -pending_writes* SocketBuffer_getWrite(int socket); -int SocketBuffer_writeComplete(int socket); -pending_writes* SocketBuffer_updateWrite(int socket, char* topic, char* payload); +pending_writes* SocketBuffer_getWrite(SOCKET socket); +int SocketBuffer_writeComplete(SOCKET socket); +pending_writes* SocketBuffer_updateWrite(SOCKET socket, char* topic, char* payload); #endif diff --git a/test/test1.c b/test/test1.c index 4aa07c6e..8410f745 100644 --- a/test/test1.c +++ b/test/test1.c @@ -805,7 +805,14 @@ int test4_run(int qos) } } - MQTTClient_yield(); /* allow any unfinished protocol exchanges to finish */ + /* call yield a few times until unfinished protocol exchanges are finished */ + count = 0; + do + { + MQTTClient_yield(); + rc = MQTTClient_getPendingDeliveryTokens(c, &tokens); + assert("getPendingDeliveryTokens rc == 0", rc == MQTTCLIENT_SUCCESS, "rc was %d", rc); + } while (tokens != NULL && ++count < 10); rc = MQTTClient_getPendingDeliveryTokens(c, &tokens); assert("getPendingDeliveryTokens rc == 0", rc == MQTTCLIENT_SUCCESS, "rc was %d", rc);