]> sjero.net Git - wget/blobdiff - src/gnutls.c
gnutls: remove deprecated gnutls types.
[wget] / src / gnutls.c
index 47c3f5d1272eb27754f18da49bed4fbdd7561281..2b13875fb87c75966efaf41e92ca9eb6e42133a1 100644 (file)
@@ -1,5 +1,5 @@
 /* SSL support via GnuTLS library.
-   Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010, 2011 Free Software
+   Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012 Free Software
    Foundation, Inc.
 
 This file is part of GNU Wget.
@@ -48,20 +48,44 @@ as that of the covered work.  */
 #include "ptimer.h"
 #include "ssl.h"
 
+#include <sys/fcntl.h>
+
 #ifdef WIN32
 # include "w32sock.h"
 #endif
 
+#include "host.h"
+
+static int
+key_type_to_gnutls_type (enum keyfile_type type)
+{
+  switch (type)
+    {
+    case keyfile_pem:
+      return GNUTLS_X509_FMT_PEM;
+    case keyfile_asn1:
+      return GNUTLS_X509_FMT_DER;
+    default:
+      abort ();
+    }
+}
+
 /* Note: some of the functions private to this file have names that
    begin with "wgnutls_" (e.g. wgnutls_read) so that they wouldn't be
    confused with actual gnutls functions -- such as the gnutls_read
    preprocessor macro.  */
 
-static gnutls_certificate_credentials credentials;
-
+static gnutls_certificate_credentials_t credentials;
 bool
-ssl_init ()
+ssl_init (void)
 {
+  /* Becomes true if GnuTLS is initialized. */
+  static bool ssl_initialized = false;
+
+  /* GnuTLS should be initialized only once. */
+  if (ssl_initialized)
+    return true;
+
   const char *ca_directory;
   DIR *dir;
 
@@ -100,15 +124,48 @@ ssl_init ()
       closedir (dir);
     }
 
+  /* Use the private key from the cert file unless otherwise specified. */
+  if (opt.cert_file && !opt.private_key)
+    {
+      opt.private_key = opt.cert_file;
+      opt.private_key_type = opt.cert_type;
+    }
+  /* Use the cert from the private key file unless otherwise specified. */
+  if (!opt.cert_file && opt.private_key)
+    {
+      opt.cert_file = opt.private_key;
+      opt.cert_type = opt.private_key_type;
+    }
+
+  if (opt.cert_file && opt.private_key)
+    {
+      int type;
+      if (opt.private_key_type != opt.cert_type)
+       {
+         /* GnuTLS can't handle this */
+         logprintf (LOG_NOTQUIET, _("ERROR: GnuTLS requires the key and the \
+cert to be of the same type.\n"));
+       }
+
+      type = key_type_to_gnutls_type (opt.private_key_type);
+
+      gnutls_certificate_set_x509_key_file (credentials, opt.cert_file,
+                                           opt.private_key,
+                                           type);
+    }
+
   if (opt.ca_cert)
     gnutls_certificate_set_x509_trust_file (credentials, opt.ca_cert,
                                             GNUTLS_X509_FMT_PEM);
+
+  ssl_initialized = true;
+
   return true;
 }
 
 struct wgnutls_transport_context
 {
-  gnutls_session session;       /* GnuTLS session handle */
+  gnutls_session_t session;       /* GnuTLS session handle */
   int last_error;               /* last error returned by read/write/... */
 
   /* Since GnuTLS doesn't support the equivalent to recv(...,
@@ -123,8 +180,9 @@ struct wgnutls_transport_context
 # define MIN(i, j) ((i) <= (j) ? (i) : (j))
 #endif
 
+
 static int
-wgnutls_read (int fd, char *buf, int bufsize, void *arg)
+wgnutls_read_timeout (int fd, char *buf, int bufsize, void *arg, double timeout)
 {
 #ifdef F_GETFL
   int flags = 0;
@@ -132,35 +190,14 @@ wgnutls_read (int fd, char *buf, int bufsize, void *arg)
   int ret = 0;
   struct ptimer *timer;
   struct wgnutls_transport_context *ctx = arg;
+  int timed_out = 0;
 
-  if (ctx->peeklen)
-    {
-      /* If we have any peek data, simply return that. */
-      int copysize = MIN (bufsize, ctx->peeklen);
-      memcpy (buf, ctx->peekbuf, copysize);
-      ctx->peeklen -= copysize;
-      if (ctx->peeklen != 0)
-        memmove (ctx->peekbuf, ctx->peekbuf + copysize, ctx->peeklen);
-
-      return copysize;
-    }
-
-  if (opt.read_timeout)
+  if (timeout)
     {
 #ifdef F_GETFL
       flags = fcntl (fd, F_GETFL, 0);
       if (flags < 0)
-        return ret;
-
-      ret = fcntl (fd, F_SETFL, flags | O_NONBLOCK);
-      if (ret < 0)
-        return ret;
-#else
-      /* XXX: Assume it was blocking before.  */
-      const int one = 1;
-      ret = ioctl (fd, FIONBIO, &one);
-      if (ret < 0)
-        return ret;
+        return flags;
 #endif
       timer = ptimer_new ();
       if (timer == 0)
@@ -169,27 +206,78 @@ wgnutls_read (int fd, char *buf, int bufsize, void *arg)
 
   do
     {
-      do
-        ret = gnutls_record_recv (ctx->session, buf, bufsize);
-      while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
-    }
-  while (opt.read_timeout == 0 || ptimer_measure (timer) < opt.read_timeout);
+      double next_timeout;
+      if (timeout > 0.0)
+       {
+         next_timeout = timeout - ptimer_measure (timer);
+         if (next_timeout < 0.0)
+           break;
+       }
+
+      ret = GNUTLS_E_AGAIN;
+      if (timeout == 0 || gnutls_record_check_pending (ctx->session)
+          || select_fd (fd, next_timeout, WAIT_FOR_READ))
+        {
+          if (timeout)
+            {
+#ifdef F_GETFL
+              if (fcntl (fd, F_SETFL, flags | O_NONBLOCK))
+               break;
+#else
+              /* XXX: Assume it was blocking before.  */
+              const int one = 1;
+              if (ioctl (fd, FIONBIO, &one) < 0)
+               break;
+#endif
+            }
 
-  if (opt.read_timeout)
-    {
-      ptimer_destroy (timer);
+          ret = gnutls_record_recv (ctx->session, buf, bufsize);
+
+          if (timeout)
+            {
 #ifdef F_GETFL
-      ret = fcntl (fd, F_SETFL, flags);
-      if (ret < 0)
-        return ret;
+              if (fcntl (fd, F_SETFL, flags) < 0)
+               break;
 #else
-      const int zero = 0;
-      ret = ioctl (fd, FIONBIO, &zero);
-      if (ret < 0)
-        return ret;
+              const int zero = 0;
+              if (ioctl (fd, FIONBIO, &zero) < 0)
+               break;
 #endif
+            }
+        }
+
+      timed_out = timeout && ptimer_measure (timer) >= timeout;
+    }
+  while (ret == GNUTLS_E_INTERRUPTED || (ret == GNUTLS_E_AGAIN && !timed_out));
+
+  if (timeout)
+    ptimer_destroy (timer);
+
+  if (timeout && timed_out && ret == GNUTLS_E_AGAIN)
+    errno = ETIMEDOUT;
+
+  return ret;
+}
+
+static int
+wgnutls_read (int fd, char *buf, int bufsize, void *arg)
+{
+  int ret = 0;
+  struct wgnutls_transport_context *ctx = arg;
+
+  if (ctx->peeklen)
+    {
+      /* If we have any peek data, simply return that. */
+      int copysize = MIN (bufsize, ctx->peeklen);
+      memcpy (buf, ctx->peekbuf, copysize);
+      ctx->peeklen -= copysize;
+      if (ctx->peeklen != 0)
+        memmove (ctx->peekbuf, ctx->peekbuf + copysize, ctx->peeklen);
+
+      return copysize;
     }
 
+  ret = wgnutls_read_timeout (fd, buf, bufsize, arg, opt.read_timeout);
   if (ret < 0)
     ctx->last_error = ret;
 
@@ -235,9 +323,8 @@ wgnutls_peek (int fd, char *buf, int bufsize, void *arg)
           && select_fd (fd, 0.0, WAIT_FOR_READ) <= 0)
         read = 0;
       else
-        read = gnutls_record_recv (ctx->session, buf + offset,
-                                   bufsize - offset);
-
+        read = wgnutls_read_timeout (fd, buf + offset, bufsize - offset,
+                                     ctx, opt.read_timeout);
       if (read < 0)
         {
           if (offset)
@@ -284,18 +371,26 @@ static struct transport_implementation wgnutls_transport =
 };
 
 bool
-ssl_connect_wget (int fd)
+ssl_connect_wget (int fd, const char *hostname)
 {
   struct wgnutls_transport_context *ctx;
-  gnutls_session session;
+  gnutls_session_t session;
   int err;
   gnutls_init (&session, GNUTLS_CLIENT);
+
+  /* We set the server name but only if it's not an IP address. */
+  if (! is_valid_ip_address (hostname))
+    {
+      gnutls_server_name_set (session, GNUTLS_NAME_DNS, hostname,
+                             strlen (hostname));
+    }
+
   gnutls_set_default_priority (session);
   gnutls_credentials_set (session, GNUTLS_CRD_CERTIFICATE, credentials);
 #ifndef FD_TO_SOCKET
 # define FD_TO_SOCKET(X) (X)
 #endif
-  gnutls_transport_set_ptr (session, (gnutls_transport_ptr) FD_TO_SOCKET (fd));
+  gnutls_transport_set_ptr (session, (gnutls_transport_ptr_t) FD_TO_SOCKET (fd));
 
   err = 0;
 #if HAVE_GNUTLS_PRIORITY_SET_DIRECT
@@ -402,8 +497,8 @@ ssl_check_certificate (int fd, const char *host)
   if (gnutls_certificate_type_get (ctx->session) == GNUTLS_CRT_X509)
     {
       time_t now = time (NULL);
-      gnutls_x509_crt cert;
-      const gnutls_datum *cert_list;
+      gnutls_x509_crt_t cert;
+      const gnutls_datum_t *cert_list;
       unsigned int cert_list_size;
 
       if ((err = gnutls_x509_crt_init (&cert)) < 0)