]> sjero.net Git - wget/blobdiff - src/gnutls.c
Fix build when libpsl is not available
[wget] / src / gnutls.c
index a1054a4d89c9804dd3f31b794e22924c2b7a022b..4f0fa962537b304c7766947c12c29a0c13e77795 100644 (file)
@@ -1,5 +1,5 @@
 /* SSL support via GnuTLS library.
-   Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010, 2011 Free Software
+   Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012 Free Software
    Foundation, Inc.
 
 This file is part of GNU Wget.
@@ -40,75 +40,160 @@ as that of the covered work.  */
 
 #include <gnutls/gnutls.h>
 #include <gnutls/x509.h>
-#include <fcntl.h>
 #include <sys/ioctl.h>
 
 #include "utils.h"
 #include "connect.h"
 #include "url.h"
+#include "ptimer.h"
+#include "hash.h"
 #include "ssl.h"
 
+#include <sys/fcntl.h>
+
 #ifdef WIN32
 # include "w32sock.h"
 #endif
 
+#include "host.h"
+
+static int
+key_type_to_gnutls_type (enum keyfile_type type)
+{
+  switch (type)
+    {
+    case keyfile_pem:
+      return GNUTLS_X509_FMT_PEM;
+    case keyfile_asn1:
+      return GNUTLS_X509_FMT_DER;
+    default:
+      abort ();
+    }
+}
+
 /* Note: some of the functions private to this file have names that
    begin with "wgnutls_" (e.g. wgnutls_read) so that they wouldn't be
    confused with actual gnutls functions -- such as the gnutls_read
    preprocessor macro.  */
 
-static gnutls_certificate_credentials credentials;
-
+static gnutls_certificate_credentials_t credentials;
 bool
-ssl_init ()
+ssl_init (void)
 {
+  /* Becomes true if GnuTLS is initialized. */
+  static bool ssl_initialized = false;
   const char *ca_directory;
   DIR *dir;
+  int ncerts = -1;
+
+  /* GnuTLS should be initialized only once. */
+  if (ssl_initialized)
+    return true;
 
   gnutls_global_init ();
   gnutls_certificate_allocate_credentials (&credentials);
-  gnutls_certificate_set_verify_flags(credentials,
-                                      GNUTLS_VERIFY_ALLOW_X509_V1_CA_CRT);
+  gnutls_certificate_set_verify_flags (credentials,
+                                       GNUTLS_VERIFY_ALLOW_X509_V1_CA_CRT);
 
-  ca_directory = opt.ca_directory ? opt.ca_directory : "/etc/ssl/certs";
+#if GNUTLS_VERSION_MAJOR >= 3
+  if (!opt.ca_directory)
+    ncerts = gnutls_certificate_set_x509_system_trust (credentials);
+#endif
 
-  dir = opendir (ca_directory);
-  if (dir == NULL)
-    {
-      if (opt.ca_directory)
-        logprintf (LOG_NOTQUIET, _("ERROR: Cannot open directory %s.\n"),
-                   opt.ca_directory);
-    }
-  else
+  /* If GnuTLS version is too old or CA loading failed, fallback to old behaviour.
+   * Also use old behaviour if the CA directory is user-provided.  */
+  if (ncerts <= 0)
     {
-      struct dirent *dent;
-      while ((dent = readdir (dir)) != NULL)
+      ca_directory = opt.ca_directory ? opt.ca_directory : "/etc/ssl/certs";
+      if ((dir = opendir (ca_directory)) == NULL)
         {
-          struct stat st;
-          char *ca_file;
-          asprintf (&ca_file, "%s/%s", ca_directory, dent->d_name);
+          if (opt.ca_directory && *opt.ca_directory)
+            logprintf (LOG_NOTQUIET, _("ERROR: Cannot open directory %s.\n"),
+                       opt.ca_directory);
+        }
+      else
+        {
+          struct hash_table *inode_map = hash_table_new (196, NULL, NULL);
+          struct dirent *dent;
+          size_t dirlen = strlen(ca_directory);
+          int rc;
+
+          ncerts = 0;
+
+          while ((dent = readdir (dir)) != NULL)
+            {
+              struct stat st;
+              char ca_file[dirlen + strlen(dent->d_name) + 2];
+
+              snprintf (ca_file, sizeof(ca_file), "%s/%s", ca_directory, dent->d_name);
+              if (stat (ca_file, &st) != 0)
+                continue;
+
+              if (! S_ISREG (st.st_mode))
+                continue;
+
+              /* avoid loading the same file twice by checking the inode.  */
+              if (hash_table_contains (inode_map, (void *)(intptr_t) st.st_ino))
+                continue;
+
+              hash_table_put (inode_map, (void *)(intptr_t) st.st_ino, NULL);
+              if ((rc = gnutls_certificate_set_x509_trust_file (credentials, ca_file,
+                                                                GNUTLS_X509_FMT_PEM)) <= 0)
+                logprintf (LOG_NOTQUIET, _("ERROR: Failed to open cert %s: (%d).\n"),
+                           ca_file, rc);
+              else
+                ncerts += rc;
+            }
+
+          hash_table_destroy (inode_map);
+          closedir (dir);
+        }
+    }
 
-          stat (ca_file, &st);
+  DEBUGP (("Certificates loaded: %d\n", ncerts));
 
-          if (S_ISREG (st.st_mode))
-            gnutls_certificate_set_x509_trust_file (credentials, ca_file,
-                                                    GNUTLS_X509_FMT_PEM);
+  /* Use the private key from the cert file unless otherwise specified. */
+  if (opt.cert_file && !opt.private_key)
+    {
+      opt.private_key = opt.cert_file;
+      opt.private_key_type = opt.cert_type;
+    }
+  /* Use the cert from the private key file unless otherwise specified. */
+  if (!opt.cert_file && opt.private_key)
+    {
+      opt.cert_file = opt.private_key;
+      opt.cert_type = opt.private_key_type;
+    }
 
-          free (ca_file);
+  if (opt.cert_file && opt.private_key)
+    {
+      int type;
+      if (opt.private_key_type != opt.cert_type)
+        {
+          /* GnuTLS can't handle this */
+          logprintf (LOG_NOTQUIET, _("ERROR: GnuTLS requires the key and the \
+cert to be of the same type.\n"));
         }
 
-      closedir (dir);
+      type = key_type_to_gnutls_type (opt.private_key_type);
+
+      gnutls_certificate_set_x509_key_file (credentials, opt.cert_file,
+                                            opt.private_key,
+                                            type);
     }
 
   if (opt.ca_cert)
     gnutls_certificate_set_x509_trust_file (credentials, opt.ca_cert,
                                             GNUTLS_X509_FMT_PEM);
+
+  ssl_initialized = true;
+
   return true;
 }
 
 struct wgnutls_transport_context
 {
-  gnutls_session session;       /* GnuTLS session handle */
+  gnutls_session_t session;       /* GnuTLS session handle */
   int last_error;               /* last error returned by read/write/... */
 
   /* Since GnuTLS doesn't support the equivalent to recv(...,
@@ -123,6 +208,78 @@ struct wgnutls_transport_context
 # define MIN(i, j) ((i) <= (j) ? (i) : (j))
 #endif
 
+
+static int
+wgnutls_read_timeout (int fd, char *buf, int bufsize, void *arg, double timeout)
+{
+#ifdef F_GETFL
+  int flags = 0;
+#endif
+  int ret = 0;
+  struct ptimer *timer = NULL;
+  struct wgnutls_transport_context *ctx = arg;
+  int timed_out = 0;
+
+  if (timeout)
+    {
+#ifdef F_GETFL
+      flags = fcntl (fd, F_GETFL, 0);
+      if (flags < 0)
+        return flags;
+      if (fcntl (fd, F_SETFL, flags | O_NONBLOCK))
+        return -1;
+#else
+      /* XXX: Assume it was blocking before.  */
+      const int one = 1;
+      if (ioctl (fd, FIONBIO, &one) < 0)
+        return -1;
+#endif
+
+      timer = ptimer_new ();
+      if (timer == NULL)
+        return -1;
+    }
+
+  do
+    {
+      double next_timeout = 0;
+      if (timeout)
+        {
+          next_timeout = timeout - ptimer_measure (timer);
+          if (next_timeout < 0)
+            break;
+        }
+
+      ret = GNUTLS_E_AGAIN;
+      if (timeout == 0 || gnutls_record_check_pending (ctx->session)
+          || select_fd (fd, next_timeout, WAIT_FOR_READ))
+        {
+          ret = gnutls_record_recv (ctx->session, buf, bufsize);
+          timed_out = timeout && ptimer_measure (timer) >= timeout;
+        }
+    }
+  while (ret == GNUTLS_E_INTERRUPTED || (ret == GNUTLS_E_AGAIN && !timed_out));
+
+  if (timeout)
+    {
+      ptimer_destroy (timer);
+
+#ifdef F_GETFL
+      if (fcntl (fd, F_SETFL, flags) < 0)
+        return -1;
+#else
+      const int zero = 0;
+      if (ioctl (fd, FIONBIO, &zero) < 0)
+        return -1;
+#endif
+
+      if (timed_out && ret == GNUTLS_E_AGAIN)
+        errno = ETIMEDOUT;
+    }
+
+  return ret;
+}
+
 static int
 wgnutls_read (int fd, char *buf, int bufsize, void *arg)
 {
@@ -141,10 +298,7 @@ wgnutls_read (int fd, char *buf, int bufsize, void *arg)
       return copysize;
     }
 
-  do
-    ret = gnutls_record_recv (ctx->session, buf, bufsize);
-  while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
-
+  ret = wgnutls_read_timeout (fd, buf, bufsize, arg, opt.read_timeout);
   if (ret < 0)
     ctx->last_error = ret;
 
@@ -152,7 +306,7 @@ wgnutls_read (int fd, char *buf, int bufsize, void *arg)
 }
 
 static int
-wgnutls_write (int fd, char *buf, int bufsize, void *arg)
+wgnutls_write (int fd _GL_UNUSED, char *buf, int bufsize, void *arg)
 {
   int ret;
   struct wgnutls_transport_context *ctx = arg;
@@ -168,42 +322,38 @@ static int
 wgnutls_poll (int fd, double timeout, int wait_for, void *arg)
 {
   struct wgnutls_transport_context *ctx = arg;
-  return ctx->peeklen || gnutls_record_check_pending (ctx->session)
-    || select_fd (fd, timeout, wait_for);
+
+  if (timeout)
+    return ctx->peeklen || gnutls_record_check_pending (ctx->session)
+      || select_fd (fd, timeout, wait_for);
+  else
+    return ctx->peeklen || gnutls_record_check_pending (ctx->session);
 }
 
 static int
 wgnutls_peek (int fd, char *buf, int bufsize, void *arg)
 {
-  int ret = 0, read = 0;
+  int read = 0;
   struct wgnutls_transport_context *ctx = arg;
   int offset = MIN (bufsize, ctx->peeklen);
-  if (bufsize > sizeof ctx->peekbuf)
-    bufsize = sizeof ctx->peekbuf;
 
   if (ctx->peeklen)
-    memcpy (buf, ctx->peekbuf, offset);
+    {
+      memcpy (buf, ctx->peekbuf, offset);
+      return offset;
+    }
+
+  if (bufsize > (int) sizeof ctx->peekbuf)
+    bufsize = sizeof ctx->peekbuf;
 
   if (bufsize > offset)
     {
-#ifdef F_GETFL
-      int flags;
-      flags = fcntl (fd, F_GETFL, 0);
-      if (flags < 0)
-        return ret;
-
-      ret = fcntl (fd, F_SETFL, flags | O_NONBLOCK);
-      if (ret < 0)
-        return ret;
-#else
-      /* XXX: Assume it was blocking before.  */
-      const int one = 1;
-      ret = ioctl (fd, FIONBIO, &one);
-      if (ret < 0)
-        return ret;
-#endif
-      read = gnutls_record_recv (ctx->session, buf + offset,
-                                 bufsize - offset);
+      if (opt.read_timeout && gnutls_record_check_pending (ctx->session) == 0
+          && select_fd (fd, 0.0, WAIT_FOR_READ) <= 0)
+        read = 0;
+      else
+        read = wgnutls_read_timeout (fd, buf + offset, bufsize - offset,
+                                     ctx, opt.read_timeout);
       if (read < 0)
         {
           if (offset)
@@ -218,24 +368,13 @@ wgnutls_peek (int fd, char *buf, int bufsize, void *arg)
                   read);
           ctx->peeklen += read;
         }
-
-#ifdef F_GETFL
-      ret = fcntl (fd, F_SETFL, flags);
-      if (ret < 0)
-        return ret;
-#else
-      const int zero = 0;
-      ret = ioctl (fd, FIONBIO, &zero);
-      if (ret < 0)
-        return ret;
-#endif
     }
 
   return offset + read;
 }
 
 static const char *
-wgnutls_errstr (int fd, void *arg)
+wgnutls_errstr (int fd _GL_UNUSED, void *arg)
 {
   struct wgnutls_transport_context *ctx = arg;
   return gnutls_strerror (ctx->last_error);
@@ -261,25 +400,59 @@ static struct transport_implementation wgnutls_transport =
 };
 
 bool
-ssl_connect_wget (int fd)
+ssl_connect_wget (int fd, const char *hostname)
 {
-  static const int cert_type_priority[] = {
-    GNUTLS_CRT_X509, GNUTLS_CRT_OPENPGP, 0
-  };
+#ifdef F_GETFL
+  int flags = 0;
+#endif
   struct wgnutls_transport_context *ctx;
-  gnutls_session session;
-  int err;
-  int allowed_protocols[4] = {0, 0, 0, 0};
+  gnutls_session_t session;
+  int err,alert;
   gnutls_init (&session, GNUTLS_CLIENT);
+  const char *str;
+
+  /* We set the server name but only if it's not an IP address. */
+  if (! is_valid_ip_address (hostname))
+    {
+      gnutls_server_name_set (session, GNUTLS_NAME_DNS, hostname,
+                              strlen (hostname));
+    }
+
   gnutls_set_default_priority (session);
-  gnutls_certificate_type_set_priority (session, cert_type_priority);
   gnutls_credentials_set (session, GNUTLS_CRD_CERTIFICATE, credentials);
 #ifndef FD_TO_SOCKET
 # define FD_TO_SOCKET(X) (X)
 #endif
-  gnutls_transport_set_ptr (session, (gnutls_transport_ptr) FD_TO_SOCKET (fd));
+#ifdef HAVE_INTPTR_T
+  gnutls_transport_set_ptr (session, (gnutls_transport_ptr_t) (intptr_t) FD_TO_SOCKET (fd));
+#else
+  gnutls_transport_set_ptr (session, (gnutls_transport_ptr_t) FD_TO_SOCKET (fd));
+#endif
 
   err = 0;
+#if HAVE_GNUTLS_PRIORITY_SET_DIRECT
+  switch (opt.secure_protocol)
+    {
+    case secure_protocol_auto:
+      break;
+    case secure_protocol_sslv2:
+    case secure_protocol_sslv3:
+      err = gnutls_priority_set_direct (session, "NORMAL:-VERS-TLS-ALL:+VERS-SSL3.0", NULL);
+      break;
+    case secure_protocol_tlsv1:
+      err = gnutls_priority_set_direct (session, "NORMAL:-VERS-SSL3.0", NULL);
+      break;
+    case secure_protocol_pfs:
+      err = gnutls_priority_set_direct (session, "PFS", NULL);
+      if (err != GNUTLS_E_SUCCESS)
+        /* fallback if PFS is not available */
+        err = gnutls_priority_set_direct (session, "NORMAL:-RSA", NULL);
+      break;
+    default:
+      abort ();
+    }
+#else
+  int allowed_protocols[4] = {0, 0, 0, 0};
   switch (opt.secure_protocol)
     {
     case secure_protocol_auto:
@@ -289,15 +462,19 @@ ssl_connect_wget (int fd)
       allowed_protocols[0] = GNUTLS_SSL3;
       err = gnutls_protocol_set_priority (session, allowed_protocols);
       break;
+
     case secure_protocol_tlsv1:
       allowed_protocols[0] = GNUTLS_TLS1_0;
       allowed_protocols[1] = GNUTLS_TLS1_1;
       allowed_protocols[2] = GNUTLS_TLS1_2;
       err = gnutls_protocol_set_priority (session, allowed_protocols);
       break;
+
     default:
       abort ();
     }
+#endif
+
   if (err < 0)
     {
       logprintf (LOG_NOTQUIET, "GnuTLS: %s\n", gnutls_strerror (err));
@@ -305,10 +482,82 @@ ssl_connect_wget (int fd)
       return false;
     }
 
-  err = gnutls_handshake (session);
+  if (opt.connect_timeout)
+    {
+#ifdef F_GETFL
+      flags = fcntl (fd, F_GETFL, 0);
+      if (flags < 0)
+        return flags;
+      if (fcntl (fd, F_SETFL, flags | O_NONBLOCK))
+        return -1;
+#else
+      /* XXX: Assume it was blocking before.  */
+      const int one = 1;
+      if (ioctl (fd, FIONBIO, &one) < 0)
+        return -1;
+#endif
+    }
+
+  /* We don't stop the handshake process for non-fatal errors */
+  do
+    {
+      err = gnutls_handshake (session);
+
+      if (opt.connect_timeout && err == GNUTLS_E_AGAIN)
+        {
+          if (gnutls_record_get_direction (session))
+            {
+              /* wait for writeability */
+              err = select_fd (fd, opt.connect_timeout, WAIT_FOR_WRITE);
+            }
+          else
+            {
+              /* wait for readability */
+              err = select_fd (fd, opt.connect_timeout, WAIT_FOR_READ);
+            }
+
+          if (err <= 0)
+            {
+              if (err == 0)
+                {
+                  errno = ETIMEDOUT;
+                  err = -1;
+                }
+              break;
+            }
+
+           err = GNUTLS_E_AGAIN;
+        }
+      else if (err < 0)
+        {
+          logprintf (LOG_NOTQUIET, "GnuTLS: %s\n", gnutls_strerror (err));
+          if (err == GNUTLS_E_WARNING_ALERT_RECEIVED ||
+              err == GNUTLS_E_FATAL_ALERT_RECEIVED)
+            {
+              alert = gnutls_alert_get (session);
+              str = gnutls_alert_get_name (alert);
+              if (str == NULL)
+                str = "(unknown)";
+              logprintf (LOG_NOTQUIET, "GnuTLS: received alert [%d]: %s\n", alert, str);
+            }
+        }
+    }
+  while (err && gnutls_error_is_fatal (err) == 0);
+
+  if (opt.connect_timeout)
+    {
+#ifdef F_GETFL
+      if (fcntl (fd, F_SETFL, flags) < 0)
+        return -1;
+#else
+      const int zero = 0;
+      if (ioctl (fd, FIONBIO, &zero) < 0)
+        return -1;
+#endif
+    }
+
   if (err < 0)
     {
-      logprintf (LOG_NOTQUIET, "GnuTLS: %s\n", gnutls_strerror (err));
       gnutls_deinit (session);
       return false;
     }
@@ -319,6 +568,14 @@ ssl_connect_wget (int fd)
   return true;
 }
 
+#define _CHECK_CERT(flag,msg) \
+  if (status & (flag))\
+    {\
+      logprintf (LOG_NOTQUIET, (msg),\
+                 severity, quote (host));\
+      success = false;\
+    }
+
 bool
 ssl_check_certificate (int fd, const char *host)
 {
@@ -341,30 +598,19 @@ ssl_check_certificate (int fd, const char *host)
       goto out;
     }
 
-  if (status & GNUTLS_CERT_INVALID)
-    {
-      logprintf (LOG_NOTQUIET, _("%s: The certificate of %s is not trusted.\n"),
-                 severity, quote (host));
-      success = false;
-    }
-  if (status & GNUTLS_CERT_SIGNER_NOT_FOUND)
-    {
-      logprintf (LOG_NOTQUIET, _("%s: The certificate of %s hasn't got a known issuer.\n"),
-                 severity, quote (host));
-      success = false;
-    }
-  if (status & GNUTLS_CERT_REVOKED)
-    {
-      logprintf (LOG_NOTQUIET, _("%s: The certificate of %s has been revoked.\n"),
-                 severity, quote (host));
-      success = false;
-    }
+  _CHECK_CERT (GNUTLS_CERT_INVALID, _("%s: The certificate of %s is not trusted.\n"));
+  _CHECK_CERT (GNUTLS_CERT_SIGNER_NOT_FOUND, _("%s: The certificate of %s hasn't got a known issuer.\n"));
+  _CHECK_CERT (GNUTLS_CERT_REVOKED, _("%s: The certificate of %s has been revoked.\n"));
+  _CHECK_CERT (GNUTLS_CERT_SIGNER_NOT_CA, _("%s: The certificate signer of %s was not a CA.\n"));
+  _CHECK_CERT (GNUTLS_CERT_INSECURE_ALGORITHM, _("%s: The certificate of %s was signed using an insecure algorithm.\n"));
+  _CHECK_CERT (GNUTLS_CERT_NOT_ACTIVATED, _("%s: The certificate of %s is not yet activated.\n"));
+  _CHECK_CERT (GNUTLS_CERT_EXPIRED, _("%s: The certificate of %s has expired.\n"));
 
   if (gnutls_certificate_type_get (ctx->session) == GNUTLS_CRT_X509)
     {
       time_t now = time (NULL);
-      gnutls_x509_crt cert;
-      const gnutls_datum *cert_list;
+      gnutls_x509_crt_t cert;
+      const gnutls_datum_t *cert_list;
       unsigned int cert_list_size;
 
       if ((err = gnutls_x509_crt_init (&cert)) < 0)
@@ -380,7 +626,7 @@ ssl_check_certificate (int fd, const char *host)
         {
           logprintf (LOG_NOTQUIET, _("No certificate found\n"));
           success = false;
-          goto out;
+          goto crt_deinit;
         }
       err = gnutls_x509_crt_import (cert, cert_list, GNUTLS_X509_FMT_DER);
       if (err < 0)
@@ -388,7 +634,7 @@ ssl_check_certificate (int fd, const char *host)
           logprintf (LOG_NOTQUIET, _("Error parsing certificate: %s\n"),
                      gnutls_strerror (err));
           success = false;
-          goto out;
+          goto crt_deinit;
         }
       if (now < gnutls_x509_crt_get_activation_time (cert))
         {
@@ -407,8 +653,14 @@ ssl_check_certificate (int fd, const char *host)
                      quote (host));
           success = false;
         }
+ crt_deinit:
       gnutls_x509_crt_deinit (cert);
-   }
+    }
+  else
+    {
+      logprintf (LOG_NOTQUIET, _("Certificate must be X.509\n"));
+      success = false;
+    }
 
  out:
   return opt.check_cert ? success : true;