Pārlūkot izejas kodu

Add Support TLS Mode
1. modbus_new_tls
2. modbus_tls_listen
3. modbus_tls_accept

Derek Tsai 10 mēneši atpakaļ
vecāks
revīzija
aee7f1d2fd
4 mainītis faili ar 564 papildinājumiem un 0 dzēšanām
  1. 12 0
      configure.ac
  2. 14 0
      src/modbus-tcp-private.h
  3. 534 0
      src/modbus-tcp.c
  4. 4 0
      src/modbus-tcp.h

+ 12 - 0
configure.ac

@@ -153,6 +153,17 @@ AC_ARG_ENABLE(tests,
 	[enable_tests=yes])
 AM_CONDITIONAL(BUILD_TESTS, [test $enable_tests != no])
 
+AC_ARG_ENABLE(tls,
+	AS_HELP_STRING([--enable-tls],
+	[Build with SSL/TLS support (default: no)]),
+	[enable_tls=yes],[enable_tls=no])
+AM_CONDITIONAL(BUILD_TLS, [test $enable_tls != no])
+
+AS_IF([test $enable_tls != no],
+	[AC_DEFINE([USE_TLS], [1], [Define if SSL/TLS feature is enabled])]
+	[AC_CHECK_LIB([ssl], [SSL_new])]
+	[AC_CHECK_LIB([crypto], [CRYPTO_new_ex_data])],)
+  
 AC_CONFIG_HEADERS([config.h tests/unit-test.h])
 AC_CONFIG_FILES([
         Makefile
@@ -177,5 +188,6 @@ AC_MSG_RESULT([
         cflags:                 ${CFLAGS}
         ldflags:                ${LDFLAGS}
 
+        tls:                    ${enable_tls}
         tests:                  ${enable_tests}
 ])

+ 14 - 0
src/modbus-tcp-private.h

@@ -38,4 +38,18 @@ typedef struct _modbus_tcp_pi {
     char *service;
 } modbus_tcp_pi_t;
 
+#if defined(USE_TLS)
+typedef struct _modbus_tls {
+    /* Transaction ID */
+    uint16_t t_id;
+    /* TCP port */
+    int port;
+    /* IP address */
+    char ip[16];
+    /* TLS context and connection */
+    SSL_CTX *ctx;
+    SSL *ssl;
+} modbus_tls_t;
+#endif
+
 #endif /* MODBUS_TCP_PRIVATE_H */

+ 534 - 0
src/modbus-tcp.c

@@ -56,6 +56,12 @@
 #endif
 // clang-format on
 
+#if defined(USE_TLS)
+#include <openssl/crypto.h>
+#include <openssl/ssl.h>
+#include <openssl/err.h>
+#endif
+
 #include "modbus-private.h"
 
 #include "modbus-tcp-private.h"
@@ -176,6 +182,52 @@ static ssize_t _modbus_tcp_send(modbus_t *ctx, const uint8_t *req, int req_lengt
     return send(ctx->s, (const char *) req, req_length, MSG_NOSIGNAL);
 }
 
+#if defined(USE_TLS)
+static ssize_t _modbus_tls_send(modbus_t *ctx, const uint8_t *req, int req_length)
+{
+    int ret;
+    modbus_tls_t *ctx_tls = (modbus_tls_t *)ctx->backend_data;
+    fd_set fds;
+
+    do {
+        ret = SSL_write(ctx_tls->ssl, req, req_length);
+
+        if(ret < 0) {
+            ret = SSL_get_error(ctx_tls->ssl, ret);
+
+            FD_ZERO(&fds);
+            FD_SET(ctx->s, &fds);
+
+            switch(ret) {
+                case SSL_ERROR_WANT_READ:
+                    ret = select(ctx->s+1, &fds, NULL, NULL, NULL);    /*TODO: send timeout*/
+                    if(ret <= 0) {
+                        /*TODO: handle timeout*/
+                        return -1;
+                    }
+
+                    ret = -1;
+                    break;
+                case SSL_ERROR_WANT_WRITE:
+                    select(ctx->s+1, NULL, &fds, NULL, NULL);    /*TODO send timeout*/
+                    if(ret <= 0) {
+                        /*TODO: handle timeout*/
+                        return -1;
+                    }
+
+                    ret = -1;
+                    break;
+                default:
+                    return -1;
+            }
+        }
+    } while(ret < 0);
+
+    return ret;
+}
+#endif
+
+
 static int _modbus_tcp_receive(modbus_t *ctx, uint8_t *req)
 {
     return _modbus_receive_msg(ctx, req, MSG_INDICATION);
@@ -186,6 +238,52 @@ static ssize_t _modbus_tcp_recv(modbus_t *ctx, uint8_t *rsp, int rsp_length)
     return recv(ctx->s, (char *) rsp, rsp_length, 0);
 }
 
+#if defined(USE_TLS)
+static ssize_t _modbus_tls_recv(modbus_t *ctx, uint8_t *rsp, int rsp_length)
+{
+    int ret;
+    modbus_tls_t *ctx_tls = (modbus_tls_t *)ctx->backend_data;
+    fd_set fds;
+
+    do {
+        ret = SSL_read(ctx_tls->ssl, rsp, rsp_length);
+
+        if(ret < 0) {
+            ret = SSL_get_error(ctx_tls->ssl, ret);
+
+            FD_ZERO(&fds);
+            FD_SET(ctx->s, &fds);
+
+            switch(ret) {
+                case SSL_ERROR_WANT_READ:
+                    ret = select(ctx->s+1, &fds, NULL, NULL, NULL);    /*TODO: receive timeout*/
+                    if(ret <= 0) {
+                        /*TODO: handle timeout*/
+                        return -1;
+                    }
+
+                    ret = -1;
+                    break;
+                case SSL_ERROR_WANT_WRITE:
+                    select(ctx->s+1, NULL, &fds, NULL, NULL);    /*TODO receive timeout*/
+                    if(ret <= 0) {
+                        /*TODO: handle timeout*/
+                        return -1;
+                    }
+
+                    ret = -1;
+                    break;
+                default:
+                    return -1;
+            }
+        }
+    } while(ret < 0);
+
+    return ret;
+}
+#endif
+
+
 static int _modbus_tcp_check_integrity(modbus_t *ctx, uint8_t *msg, const int msg_length)
 {
     return msg_length;
@@ -449,6 +547,110 @@ static int _modbus_tcp_pi_connect(modbus_t *ctx)
     return 0;
 }
 
+#if defined(USE_TLS)
+static int _modbus_tls_connect(modbus_t *ctx)
+{
+    int ret;
+    modbus_tls_t *ctx_tls = (modbus_tls_t *)ctx->backend_data;
+    X509 *cert;
+    fd_set fds, errfds;
+    struct timeval timeout;
+
+    ret = _modbus_tcp_connect(ctx);
+    if(ret < 0) {
+        return ret;
+    }
+
+    ctx_tls->ssl = SSL_new(ctx_tls->ctx);
+    if(ctx_tls->ssl == NULL) {
+        close(ctx->s);
+        ctx->s = -1;
+
+        if(ctx->debug) {
+            ERR_print_errors_fp(stderr);
+        }
+
+        return -1;
+    }
+
+    SSL_set_fd(ctx_tls->ssl, ctx->s);
+
+    do {
+        ret = SSL_connect(ctx_tls->ssl);
+        if(ret == -1) {
+            ret = SSL_get_error(ctx_tls->ssl, ret);
+
+            FD_ZERO(&fds);
+            FD_ZERO(&errfds);
+            FD_SET(ctx->s, &fds);
+            FD_SET(ctx->s, &errfds);
+            timeout.tv_sec = ctx->response_timeout.tv_sec;
+            timeout.tv_usec = ctx->response_timeout.tv_usec;
+
+            switch(ret) {
+                case SSL_ERROR_WANT_READ:
+                    ret = select(ctx->s+1, &fds, NULL, &errfds, &timeout);
+                    if(ret <= 0 || FD_ISSET(ctx->s, &errfds)) {
+                        SSL_free(ctx_tls->ssl);
+                        close(ctx->s);
+                        ctx->s = -1;
+                        return -1;
+                    }
+
+                    ret = -1;
+                    break;
+                case SSL_ERROR_WANT_WRITE:
+                    ret = select(ctx->s+1, NULL, &fds, &errfds, &timeout);
+                    if(ret <= 0 || FD_ISSET(ctx->s, &errfds)) {
+                        SSL_free(ctx_tls->ssl);
+                        close(ctx->s);
+                        ctx->s = -1;
+                        return -1;
+                    }
+
+                    ret = -1;
+                    break;
+                default:
+                    SSL_free(ctx_tls->ssl);
+                    close(ctx->s);
+                    ctx->s = -1;
+                    return -1;
+            }
+        }
+    } while(ret <= 0);
+
+    if(ret == -1) {
+        if(ctx->debug) {
+            ERR_print_errors_fp(stderr);
+        }
+        close(ctx->s);
+        ctx->s = -1;
+        SSL_free(ctx_tls->ssl);
+        return -1;
+    }
+
+    cert = SSL_get_peer_certificate(ctx_tls->ssl);
+    if(cert == NULL) {
+        close(ctx->s);
+        ctx->s = -1;
+        SSL_free(ctx_tls->ssl);
+        return -1;
+    }
+
+    if(SSL_get_verify_result(ctx_tls->ssl) != X509_V_OK) {
+        X509_free(cert);
+        SSL_free(ctx_tls->ssl);
+        close(ctx->s);
+        ctx->s = -1;
+        return -1;
+    }
+
+    X509_free(cert);
+
+    return 0;
+}
+#endif
+
 static unsigned int _modbus_tcp_is_connected(modbus_t *ctx)
 {
     return ctx->s >= 0;
@@ -464,6 +666,21 @@ static void _modbus_tcp_close(modbus_t *ctx)
     }
 }
 
+#if defined(USE_TLS)
+static void _modbus_tls_close(modbus_t *ctx)
+{
+    modbus_tls_t *ctx_tls = (modbus_tls_t *)ctx->backend_data;
+
+    if(ctx->s != -1) {
+        SSL_shutdown(ctx_tls->ssl);
+        shutdown(ctx->s, SHUT_RDWR);
+        close(ctx->s);
+        ctx->s = -1;
+        SSL_free(ctx_tls->ssl);
+    }
+}
+#endif
+
 static int _modbus_tcp_flush(modbus_t *ctx)
 {
     int rc;
@@ -501,6 +718,46 @@ static int _modbus_tcp_flush(modbus_t *ctx)
     return rc_sum;
 }
 
+#if defined(USE_TLS)
+static int _modbus_tls_flush(modbus_t *ctx)
+{
+    int rc, rc_sum = 0;
+    char devnull[MODBUS_TCP_MAX_ADU_LENGTH];
+    fd_set rset;
+    struct timeval tv = {.tv_sec = 0, .tv_usec = 0};
+    modbus_tls_t *ctx_tls = (modbus_tls_t *)ctx->backend_data;
+
+    if((rc = SSL_pending(ctx_tls->ssl)) > 0) {
+        rc_sum = SSL_read(ctx_tls->ssl, devnull, rc);
+    }
+
+    FD_ZERO(&rset);
+    FD_SET(ctx->s, &rset);
+
+    do {
+        rc = select(ctx->s+1, &rset, NULL, NULL, &tv);
+
+        switch(rc) {
+            case -1:
+                if(errno == EINTR) {
+                    continue;
+                }
+                return -1;
+
+            case 1:
+                rc = SSL_read(ctx_tls->ssl, devnull, MODBUS_TCP_MAX_ADU_LENGTH);
+                if(rc > 0) {
+                    rc_sum += rc;
+                }
+
+                rc = -1;
+        }
+    } while(rc < 0);
+
+    return rc_sum;
+}
+#endif
+
 /* Listens for any request from one or many modbus masters in TCP */
 int modbus_tcp_listen(modbus_t *ctx, int nb_connection)
 {
@@ -690,6 +947,13 @@ int modbus_tcp_pi_listen(modbus_t *ctx, int nb_connection)
     return new_s;
 }
 
+#if defined(USE_TLS)
+int modbus_tls_listen(modbus_t *ctx, int nb_connection)
+{
+    return modbus_tcp_listen(ctx, nb_connection);
+}
+#endif
+
 int modbus_tcp_accept(modbus_t *ctx, int *s)
 {
     struct sockaddr_in addr;
@@ -758,6 +1022,102 @@ int modbus_tcp_pi_accept(modbus_t *ctx, int *s)
     return ctx->s;
 }
 
+#if defined(USE_TLS)
+int modbus_tls_accept(modbus_t *ctx, int *s)
+{
+    int ret;
+    modbus_tls_t *ctx_tls = (modbus_tls_t *)ctx->backend_data;
+    fd_set fds;
+    X509 *cert;
+
+    ret = modbus_tcp_accept(ctx, s);
+
+    if(ret < 0) {
+        return ret;
+    }
+
+    ctx_tls->ssl = SSL_new(ctx_tls->ctx);
+    if(ctx_tls->ssl == NULL) {
+        close(ctx->s);
+        ctx->s = -1;
+
+        if(ctx->debug) {
+            ERR_print_errors_fp(stderr);
+        }
+
+        return -1;
+    }
+
+    SSL_set_fd(ctx_tls->ssl, ctx->s);
+
+    do {
+        ret = SSL_accept(ctx_tls->ssl);
+        if(ret == -1) {
+            ret = SSL_get_error(ctx_tls->ssl, ret);
+
+            FD_ZERO(&fds);
+            FD_SET(ctx->s, &fds);
+
+            switch(ret) {
+                case SSL_ERROR_WANT_READ:
+                    ret = select(ctx->s+1, &fds, NULL, NULL, NULL);
+                    if(ret <= 0) {
+                        if(errno == EINTR) {
+                            continue;
+                        }
+                        SSL_free(ctx_tls->ssl);
+                        close(ctx->s);
+                        ctx->s = -1;
+                        return -1;
+                    }
+
+                    ret = -1;
+                    break;
+                case SSL_ERROR_WANT_WRITE:
+                    select(ctx->s+1, NULL, &fds, NULL, NULL);
+                    if(ret <= 0) {
+                        if(errno == EINTR) {
+                            continue;
+                        }
+                        SSL_free(ctx_tls->ssl);
+                        close(ctx->s);
+                        ctx->s = -1;
+                        return -1;
+                    }
+
+                    ret = -1;
+                    break;
+                default:
+                    SSL_free(ctx_tls->ssl);
+                    close(ctx->s);
+                    ctx->s = -1;
+                    return -1;
+            }
+        }
+    } while(ret < 0);
+
+    cert = SSL_get_peer_certificate(ctx_tls->ssl);
+    if(cert == NULL) {
+        SSL_free(ctx_tls->ssl);
+        close(ctx->s);
+        ctx->s = -1;
+        return -1;
+    }
+
+    if(SSL_get_verify_result(ctx_tls->ssl) != X509_V_OK) {
+        X509_free(cert);
+        SSL_free(ctx_tls->ssl);
+        close(ctx->s);
+        ctx->s = -1;
+        return -1;
+    }
+
+    X509_free(cert);
+
+    return ctx->s;
+}
+#endif
+
 static int
 _modbus_tcp_select(modbus_t *ctx, fd_set *rset, struct timeval *tv, int length_to_read)
 {
@@ -783,6 +1143,21 @@ _modbus_tcp_select(modbus_t *ctx, fd_set *rset, struct timeval *tv, int length_t
     return s_rc;
 }
 
+#if defined(USE_TLS)
+static int _modbus_tls_select(modbus_t *ctx, fd_set *rset, struct timeval *tv, int length_to_read)
+{
+    modbus_tls_t *ctx_tls = (modbus_tls_t *)ctx->backend_data;
+
+    /*Check SSL buffer for pending data*/
+    if(SSL_pending(ctx_tls->ssl) > 0) {
+        return 1;
+    }
+
+    /*Run normal TCP select*/
+    return _modbus_tcp_select(ctx, rset, tv, length_to_read);
+}
+#endif
+
 static void _modbus_tcp_free(modbus_t *ctx)
 {
     if (ctx->backend_data) {
@@ -803,6 +1178,14 @@ static void _modbus_tcp_pi_free(modbus_t *ctx)
     free(ctx);
 }
 
+#if defined(USE_TLS)
+static void _modbus_tls_free(modbus_t *ctx) {
+    modbus_tls_t *ctx_tls = (modbus_tls_t *)ctx->backend_data;
+    SSL_CTX_free(ctx_tls->ctx);
+    _modbus_tcp_free(ctx);
+}
+#endif
+
 // clang-format off
 const modbus_backend_t _modbus_tcp_backend = {
     _MODBUS_BACKEND_TYPE_TCP,
@@ -850,6 +1233,31 @@ const modbus_backend_t _modbus_tcp_pi_backend = {
     _modbus_tcp_pi_free
 };
 
+#if defined(USE_TLS)
+const modbus_backend_t _modbus_tls_backend = {
+    _MODBUS_BACKEND_TYPE_TCP,
+    _MODBUS_TCP_HEADER_LENGTH,
+    _MODBUS_TCP_CHECKSUM_LENGTH,
+    MODBUS_TCP_MAX_ADU_LENGTH,
+    _modbus_set_slave,
+    _modbus_tcp_build_request_basis,
+    _modbus_tcp_build_response_basis,
+    _modbus_tcp_prepare_response_tid,
+    _modbus_tcp_send_msg_pre,
+    _modbus_tls_send,
+    _modbus_tcp_receive,
+    _modbus_tls_recv,
+    _modbus_tcp_check_integrity,
+    _modbus_tcp_pre_check_confirmation,
+    _modbus_tls_connect,
+    _modbus_tcp_is_connected,
+    _modbus_tls_close,
+    _modbus_tls_flush,
+    _modbus_tls_select,
+    _modbus_tls_free
+};
+#endif
+
 // clang-format on
 
 modbus_t *modbus_new_tcp(const char *ip, int port)
@@ -972,3 +1380,129 @@ modbus_t *modbus_new_tcp_pi(const char *node, const char *service)
 
     return ctx;
 }
+
+#if defined(USE_TLS)
+modbus_t* modbus_new_tls(const char *ip, int port, const char *cert, const char *key, const char *ca)
+{
+    int ret;
+    modbus_t *ctx;
+    modbus_tls_t *ctx_tls;
+    size_t dest_size;
+    size_t ret_size;
+    struct sigaction sa;
+
+    /* There is no MSG_NOSIGNAL equivalent for SSL_write, so we
+     * install the ignore handler for SIGPIPE */
+    sa.sa_handler = SIG_IGN;
+    if (sigaction(SIGPIPE, &sa, NULL) < 0) {
+        fprintf(stderr, "Could not install SIGPIPE handler.\n");
+        return NULL;
+    }
+
+    ctx = (modbus_t *)malloc(sizeof(modbus_t));
+    if (ctx == NULL) {
+        return NULL;
+    }
+    _modbus_init_common(ctx);
+
+    /* Could be changed after to reach a remote serial Modbus device */
+    ctx->slave = MODBUS_TCP_SLAVE;
+
+    ctx->backend = &_modbus_tls_backend;
+
+    ctx->backend_data = (modbus_tls_t *)malloc(sizeof(modbus_tls_t));
+    if (ctx->backend_data == NULL) {
+        modbus_free(ctx);
+        errno = ENOMEM;
+        return NULL;
+    }
+    ctx_tls = (modbus_tls_t *)ctx->backend_data;
+
+    if (ip != NULL) {
+        dest_size = sizeof(char) * 16;
+        ret_size = strlcpy(ctx_tls->ip, ip, dest_size);
+        if (ret_size == 0) {
+            fprintf(stderr, "The IP string is empty\n");
+            modbus_free(ctx);
+            errno = EINVAL;
+            return NULL;
+        }
+
+        if (ret_size >= dest_size) {
+            fprintf(stderr, "The IP string has been truncated\n");
+            modbus_free(ctx);
+            errno = EINVAL;
+            return NULL;
+        }
+    } else {
+        ctx_tls->ip[0] = '0';
+    }
+    ctx_tls->port = port;
+    ctx_tls->t_id = 0;
+
+    SSL_load_error_strings();
+    SSL_library_init();
+
+    ctx_tls->ctx = SSL_CTX_new(TLS_method());
+    SSL_CTX_set_min_proto_version(ctx_tls->ctx, TLS1_2_VERSION);
+
+    if(!ctx_tls->ctx) {
+        fprintf(stderr, "Cannot create SSL context\n");
+        ERR_print_errors_fp(stderr);
+        modbus_free(ctx);
+        errno = EPROTO;
+        return NULL;
+    }
+
+    if(cert) {
+        ret = SSL_CTX_use_certificate_file(ctx_tls->ctx, cert, SSL_FILETYPE_PEM);
+
+        if(ret <= 0) {
+            fprintf(stderr, "Cannot parse certificate file %s:\n", cert);
+            ERR_print_errors_fp(stderr);
+            modbus_free(ctx);
+            errno = EINVAL;
+            return NULL;
+        }
+    }
+
+    if(key) {
+        ret = SSL_CTX_use_PrivateKey_file(ctx_tls->ctx, key, SSL_FILETYPE_PEM);
+
+        if(ret <= 0) {
+            fprintf(stderr, "Cannot parse key file %s:\n", key);
+            ERR_print_errors_fp(stderr);
+            modbus_free(ctx);
+            errno = EINVAL;
+            return NULL;
+        }
+
+        ret = SSL_CTX_check_private_key(ctx_tls->ctx);
+        if(ret != 1) {
+            fprintf(stderr, "Invalid parse key %s for certificate %s :\n", key, cert);
+            ERR_print_errors_fp(stderr);
+            modbus_free(ctx);
+            errno = EINVAL;
+            return NULL;
+        }
+    }
+
+    if(ca) {
+        ret = SSL_CTX_load_verify_locations(ctx_tls->ctx, ca, NULL);
+
+        if(ret != 1) {
+            fprintf(stderr, "Cannot parse CA file %s:\n", ca);
+            ERR_print_errors_fp(stderr);
+            modbus_free(ctx);
+            errno = EINVAL;
+            return NULL;
+        }
+    }
+
+    SSL_CTX_set_verify(ctx_tls->ctx, SSL_VERIFY_PEER, NULL);
+
+    SSL_CTX_set_verify_depth(ctx_tls->ctx, 1);
+
+    return ctx;
+}
+#endif

+ 4 - 0
src/modbus-tcp.h

@@ -47,6 +47,10 @@ MODBUS_API modbus_t *modbus_new_tcp_pi(const char *node, const char *service);
 MODBUS_API int modbus_tcp_pi_listen(modbus_t *ctx, int nb_connection);
 MODBUS_API int modbus_tcp_pi_accept(modbus_t *ctx, int *s);
 
+MODBUS_API modbus_t* modbus_new_tls(const char *ip_address, int port, const char *cert, const char *key, const char *ca);
+MODBUS_API int modbus_tls_listen(modbus_t *ctx, int nb_connection);
+MODBUS_API int modbus_tls_accept(modbus_t *ctx, int *s);
+
 MODBUS_END_DECLS
 
 #endif /* MODBUS_TCP_H */