]> sjero.net Git - wget/blobdiff - src/gnutls.c
Fix some other problems with GNU TLS and non blocking sockets.
[wget] / src / gnutls.c
index 0bb49ca533cd794b56ccd4420d74e51f31b15b0c..dfff00cf5bf6e0ffc950fc55f885ddc466b9b726 100644 (file)
@@ -1,6 +1,6 @@
 /* SSL support via GnuTLS library.
-   Copyright (C) 2005, 2006, 2007, 2008, 2009 Free Software Foundation,
-   Inc.
+   Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010, 2011 Free Software
+   Foundation, Inc.
 
 This file is part of GNU Wget.
 
@@ -32,39 +32,81 @@ as that of the covered work.  */
 
 #include <assert.h>
 #include <errno.h>
-#ifdef HAVE_UNISTD_H
-# include <unistd.h>
-#endif
+#include <unistd.h>
 #include <string.h>
 #include <stdio.h>
+#include <dirent.h>
+#include <stdlib.h>
 
 #include <gnutls/gnutls.h>
 #include <gnutls/x509.h>
+#include <sys/ioctl.h>
 
 #include "utils.h"
 #include "connect.h"
 #include "url.h"
+#include "ptimer.h"
 #include "ssl.h"
 
+#ifdef WIN32
+# include "w32sock.h"
+#endif
+
 /* Note: some of the functions private to this file have names that
    begin with "wgnutls_" (e.g. wgnutls_read) so that they wouldn't be
    confused with actual gnutls functions -- such as the gnutls_read
    preprocessor macro.  */
 
 static gnutls_certificate_credentials credentials;
-
 bool
 ssl_init ()
 {
+  const char *ca_directory;
+  DIR *dir;
+
   gnutls_global_init ();
   gnutls_certificate_allocate_credentials (&credentials);
+  gnutls_certificate_set_verify_flags(credentials,
+                                      GNUTLS_VERIFY_ALLOW_X509_V1_CA_CRT);
+
+  ca_directory = opt.ca_directory ? opt.ca_directory : "/etc/ssl/certs";
+
+  dir = opendir (ca_directory);
+  if (dir == NULL)
+    {
+      if (opt.ca_directory)
+        logprintf (LOG_NOTQUIET, _("ERROR: Cannot open directory %s.\n"),
+                   opt.ca_directory);
+    }
+  else
+    {
+      struct dirent *dent;
+      while ((dent = readdir (dir)) != NULL)
+        {
+          struct stat st;
+          char *ca_file;
+          asprintf (&ca_file, "%s/%s", ca_directory, dent->d_name);
+
+          stat (ca_file, &st);
+
+          if (S_ISREG (st.st_mode))
+            gnutls_certificate_set_x509_trust_file (credentials, ca_file,
+                                                    GNUTLS_X509_FMT_PEM);
+
+          free (ca_file);
+        }
+
+      closedir (dir);
+    }
+
   if (opt.ca_cert)
     gnutls_certificate_set_x509_trust_file (credentials, opt.ca_cert,
                                             GNUTLS_X509_FMT_PEM);
   return true;
 }
 
-struct wgnutls_transport_context {
+struct wgnutls_transport_context
+{
   gnutls_session session;       /* GnuTLS session handle */
   int last_error;               /* last error returned by read/write/... */
 
@@ -73,38 +115,119 @@ struct wgnutls_transport_context {
      is stored to PEEKBUF, and wgnutls_read checks that buffer before
      actually reading.  */
   char peekbuf[512];
-  int peekstart, peeklen;
+  int peeklen;
 };
 
 #ifndef MIN
 # define MIN(i, j) ((i) <= (j) ? (i) : (j))
 #endif
 
+
+static int
+wgnutls_read_timeout (int fd, char *buf, int bufsize, void *arg, double timeout)
+{
+#ifdef F_GETFL
+  int flags = 0;
+#endif
+  int ret = 0;
+  struct ptimer *timer;
+  struct wgnutls_transport_context *ctx = arg;
+  int timed_out = 0;
+
+  if (timeout)
+    {
+#ifdef F_GETFL
+      flags = fcntl (fd, F_GETFL, 0);
+      if (flags < 0)
+        return flags;
+#endif
+      timer = ptimer_new ();
+      if (timer == 0)
+        return -1;
+    }
+
+  do
+    {
+      double next_timeout = timeout - ptimer_measure (timer);
+      if (timeout && next_timeout < 0)
+        break;
+
+      ret = GNUTLS_E_AGAIN;
+      if (timeout == 0 || gnutls_record_check_pending (ctx->session)
+          || select_fd (fd, next_timeout, WAIT_FOR_READ))
+        {
+          if (timeout)
+            {
+#ifdef F_GETFL
+              ret = fcntl (fd, F_SETFL, flags | O_NONBLOCK);
+              if (ret < 0)
+                return ret;
+#else
+              /* XXX: Assume it was blocking before.  */
+              const int one = 1;
+              ret = ioctl (fd, FIONBIO, &one);
+              if (ret < 0)
+                return ret;
+#endif
+            }
+
+          ret = gnutls_record_recv (ctx->session, buf, bufsize);
+
+          if (timeout)
+            {
+              int status;
+#ifdef F_GETFL
+              status = fcntl (fd, F_SETFL, flags);
+              if (status < 0)
+                return status;
+#else
+              const int zero = 0;
+              status = ioctl (fd, FIONBIO, &zero);
+              if (status < 0)
+                return status;
+#endif
+            }
+        }
+
+      timed_out = timeout && ptimer_measure (timer) >= timeout;
+    }
+  while (ret == GNUTLS_E_INTERRUPTED || (ret == GNUTLS_E_AGAIN && !timed_out));
+
+  if (timeout)
+    ptimer_destroy (timer);
+
+  if (timeout && timed_out && ret == GNUTLS_E_AGAIN)
+    errno = ETIMEDOUT;
+
+  return ret;
+}
+
 static int
 wgnutls_read (int fd, char *buf, int bufsize, void *arg)
 {
-  int ret;
+#ifdef F_GETFL
+  int flags = 0;
+#endif
+  int ret = 0;
+  struct ptimer *timer;
   struct wgnutls_transport_context *ctx = arg;
 
   if (ctx->peeklen)
     {
       /* If we have any peek data, simply return that. */
       int copysize = MIN (bufsize, ctx->peeklen);
-      memcpy (buf, ctx->peekbuf + ctx->peekstart, copysize);
+      memcpy (buf, ctx->peekbuf, copysize);
       ctx->peeklen -= copysize;
       if (ctx->peeklen != 0)
-        ctx->peekstart += copysize;
-      else
-        ctx->peekstart = 0;
+        memmove (ctx->peekbuf, ctx->peekbuf + copysize, ctx->peeklen);
+
       return copysize;
     }
 
-  do
-    ret = gnutls_record_recv (ctx->session, buf, bufsize);
-  while (ret == GNUTLS_E_INTERRUPTED);
-
+  ret = wgnutls_read_timeout (fd, buf, bufsize, arg, opt.read_timeout);
   if (ret < 0)
     ctx->last_error = ret;
+
   return ret;
 }
 
@@ -115,7 +238,7 @@ wgnutls_write (int fd, char *buf, int bufsize, void *arg)
   struct wgnutls_transport_context *ctx = arg;
   do
     ret = gnutls_record_send (ctx->session, buf, bufsize);
-  while (ret == GNUTLS_E_INTERRUPTED);
+  while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
   if (ret < 0)
     ctx->last_error = ret;
   return ret;
@@ -124,31 +247,48 @@ wgnutls_write (int fd, char *buf, int bufsize, void *arg)
 static int
 wgnutls_poll (int fd, double timeout, int wait_for, void *arg)
 {
-  return 1;
+  struct wgnutls_transport_context *ctx = arg;
+  return ctx->peeklen || gnutls_record_check_pending (ctx->session)
+    || select_fd (fd, timeout, wait_for);
 }
 
 static int
 wgnutls_peek (int fd, char *buf, int bufsize, void *arg)
 {
-  int ret;
+  int read = 0;
   struct wgnutls_transport_context *ctx = arg;
-
-  /* We don't support peeks following peeks: the reader must drain all
-     peeked data before the next peek.  */
-  assert (ctx->peeklen == 0);
+  int offset = MIN (bufsize, ctx->peeklen);
   if (bufsize > sizeof ctx->peekbuf)
     bufsize = sizeof ctx->peekbuf;
 
-  do
-    ret = gnutls_record_recv (ctx->session, buf, bufsize);
-  while (ret == GNUTLS_E_INTERRUPTED);
+  if (ctx->peeklen)
+    memcpy (buf, ctx->peekbuf, offset);
 
-  if (ret >= 0)
+  if (bufsize > offset)
     {
-      memcpy (ctx->peekbuf, buf, ret);
-      ctx->peeklen = ret;
+      if (gnutls_record_check_pending (ctx->session) <= 0
+          && select_fd (fd, 0.0, WAIT_FOR_READ) <= 0)
+        read = 0;
+      else
+        read = wgnutls_read_timeout (fd, buf + offset, bufsize - offset,
+                                     ctx, opt.read_timeout);
+      if (read < 0)
+        {
+          if (offset)
+            read = 0;
+          else
+            return read;
+        }
+
+      if (read > 0)
+        {
+          memcpy (ctx->peekbuf + offset, buf + offset,
+                  read);
+          ctx->peeklen += read;
+        }
     }
-  return ret;
+
+  return offset + read;
 }
 
 static const char *
@@ -171,25 +311,73 @@ wgnutls_close (int fd, void *arg)
 /* gnutls_transport is the singleton that describes the SSL transport
    methods provided by this file.  */
 
-static struct transport_implementation wgnutls_transport = {
+static struct transport_implementation wgnutls_transport =
+{
   wgnutls_read, wgnutls_write, wgnutls_poll,
   wgnutls_peek, wgnutls_errstr, wgnutls_close
 };
 
 bool
-ssl_connect (int fd)
+ssl_connect_wget (int fd)
 {
-  static const int cert_type_priority[] = {
-    GNUTLS_CRT_X509, GNUTLS_CRT_OPENPGP, 0
-  };
   struct wgnutls_transport_context *ctx;
   gnutls_session session;
   int err;
   gnutls_init (&session, GNUTLS_CLIENT);
   gnutls_set_default_priority (session);
-  gnutls_certificate_type_set_priority (session, cert_type_priority);
   gnutls_credentials_set (session, GNUTLS_CRD_CERTIFICATE, credentials);
-  gnutls_transport_set_ptr (session, (gnutls_transport_ptr) fd);
+#ifndef FD_TO_SOCKET
+# define FD_TO_SOCKET(X) (X)
+#endif
+  gnutls_transport_set_ptr (session, (gnutls_transport_ptr) FD_TO_SOCKET (fd));
+
+  err = 0;
+#if HAVE_GNUTLS_PRIORITY_SET_DIRECT
+  switch (opt.secure_protocol)
+    {
+    case secure_protocol_auto:
+      break;
+    case secure_protocol_sslv2:
+    case secure_protocol_sslv3:
+      err = gnutls_priority_set_direct (session, "NORMAL:-VERS-TLS-ALL", NULL);
+      break;
+    case secure_protocol_tlsv1:
+      err = gnutls_priority_set_direct (session, "NORMAL:-VERS-SSL3.0", NULL);
+      break;
+    default:
+      abort ();
+    }
+#else
+  int allowed_protocols[4] = {0, 0, 0, 0};
+  switch (opt.secure_protocol)
+    {
+    case secure_protocol_auto:
+      break;
+    case secure_protocol_sslv2:
+    case secure_protocol_sslv3:
+      allowed_protocols[0] = GNUTLS_SSL3;
+      err = gnutls_protocol_set_priority (session, allowed_protocols);
+      break;
+
+    case secure_protocol_tlsv1:
+      allowed_protocols[0] = GNUTLS_TLS1_0;
+      allowed_protocols[1] = GNUTLS_TLS1_1;
+      allowed_protocols[2] = GNUTLS_TLS1_2;
+      err = gnutls_protocol_set_priority (session, allowed_protocols);
+      break;
+
+    default:
+      abort ();
+    }
+#endif
+
+  if (err < 0)
+    {
+      logprintf (LOG_NOTQUIET, "GnuTLS: %s\n", gnutls_strerror (err));
+      gnutls_deinit (session);
+      return false;
+    }
+
   err = gnutls_handshake (session);
   if (err < 0)
     {
@@ -197,6 +385,7 @@ ssl_connect (int fd)
       gnutls_deinit (session);
       return false;
     }
+
   ctx = xnew0 (struct wgnutls_transport_context);
   ctx->session = session;
   fd_register_transport (fd, &wgnutls_transport, ctx);