]> sjero.net Git - wget/blobdiff - src/gnutls.c
Fix timeout option when used with SSL
[wget] / src / gnutls.c
index 2b13875fb87c75966efaf41e92ca9eb6e42133a1..06f90200c514ca303c8dd3739c7c59c7cbcafe52 100644 (file)
@@ -99,7 +99,7 @@ ssl_init (void)
   dir = opendir (ca_directory);
   if (dir == NULL)
     {
-      if (opt.ca_directory)
+      if (opt.ca_directory && *opt.ca_directory)
         logprintf (LOG_NOTQUIET, _("ERROR: Cannot open directory %s.\n"),
                    opt.ca_directory);
     }
@@ -188,7 +188,7 @@ wgnutls_read_timeout (int fd, char *buf, int bufsize, void *arg, double timeout)
   int flags = 0;
 #endif
   int ret = 0;
-  struct ptimer *timer;
+  struct ptimer *timer = NULL;
   struct wgnutls_transport_context *ctx = arg;
   int timed_out = 0;
 
@@ -198,63 +198,56 @@ wgnutls_read_timeout (int fd, char *buf, int bufsize, void *arg, double timeout)
       flags = fcntl (fd, F_GETFL, 0);
       if (flags < 0)
         return flags;
+      if (fcntl (fd, F_SETFL, flags | O_NONBLOCK))
+        return -1;
+#else
+      /* XXX: Assume it was blocking before.  */
+      const int one = 1;
+      if (ioctl (fd, FIONBIO, &one) < 0)
+        return -1;
 #endif
+
       timer = ptimer_new ();
-      if (timer == 0)
+      if (timer == NULL)
         return -1;
     }
 
   do
     {
-      double next_timeout;
-      if (timeout > 0.0)
-       {
-         next_timeout = timeout - ptimer_measure (timer);
-         if (next_timeout < 0.0)
-           break;
-       }
+      double next_timeout = 0;
+      if (timeout)
+        {
+          next_timeout = timeout - ptimer_measure (timer);
+          if (next_timeout < 0)
+            break;
+        }
 
       ret = GNUTLS_E_AGAIN;
       if (timeout == 0 || gnutls_record_check_pending (ctx->session)
           || select_fd (fd, next_timeout, WAIT_FOR_READ))
         {
-          if (timeout)
-            {
-#ifdef F_GETFL
-              if (fcntl (fd, F_SETFL, flags | O_NONBLOCK))
-               break;
-#else
-              /* XXX: Assume it was blocking before.  */
-              const int one = 1;
-              if (ioctl (fd, FIONBIO, &one) < 0)
-               break;
-#endif
-            }
-
           ret = gnutls_record_recv (ctx->session, buf, bufsize);
-
-          if (timeout)
-            {
-#ifdef F_GETFL
-              if (fcntl (fd, F_SETFL, flags) < 0)
-               break;
-#else
-              const int zero = 0;
-              if (ioctl (fd, FIONBIO, &zero) < 0)
-               break;
-#endif
-            }
+          timed_out = timeout && ptimer_measure (timer) >= timeout;
         }
-
-      timed_out = timeout && ptimer_measure (timer) >= timeout;
     }
   while (ret == GNUTLS_E_INTERRUPTED || (ret == GNUTLS_E_AGAIN && !timed_out));
 
   if (timeout)
-    ptimer_destroy (timer);
+    {
+      ptimer_destroy (timer);
 
-  if (timeout && timed_out && ret == GNUTLS_E_AGAIN)
-    errno = ETIMEDOUT;
+#ifdef F_GETFL
+      if (fcntl (fd, F_SETFL, flags) < 0)
+        return -1;
+#else
+      const int zero = 0;
+      if (ioctl (fd, FIONBIO, &zero) < 0)
+        return -1;
+#endif
+
+      if (timed_out && ret == GNUTLS_E_AGAIN)
+        errno = ETIMEDOUT;
+    }
 
   return ret;
 }
@@ -301,8 +294,12 @@ static int
 wgnutls_poll (int fd, double timeout, int wait_for, void *arg)
 {
   struct wgnutls_transport_context *ctx = arg;
-  return ctx->peeklen || gnutls_record_check_pending (ctx->session)
-    || select_fd (fd, timeout, wait_for);
+
+  if (timeout)
+    return ctx->peeklen || gnutls_record_check_pending (ctx->session)
+      || select_fd (fd, timeout, wait_for);
+  else
+    return ctx->peeklen || gnutls_record_check_pending (ctx->session);
 }
 
 static int
@@ -311,15 +308,19 @@ wgnutls_peek (int fd, char *buf, int bufsize, void *arg)
   int read = 0;
   struct wgnutls_transport_context *ctx = arg;
   int offset = MIN (bufsize, ctx->peeklen);
-  if (bufsize > sizeof ctx->peekbuf)
-    bufsize = sizeof ctx->peekbuf;
 
   if (ctx->peeklen)
-    memcpy (buf, ctx->peekbuf, offset);
+    {
+      memcpy (buf, ctx->peekbuf, offset);
+      return offset;
+    }
+
+  if (bufsize > sizeof ctx->peekbuf)
+    bufsize = sizeof ctx->peekbuf;
 
   if (bufsize > offset)
     {
-      if (gnutls_record_check_pending (ctx->session) <= 0
+      if (opt.read_timeout && gnutls_record_check_pending (ctx->session) == 0
           && select_fd (fd, 0.0, WAIT_FOR_READ) <= 0)
         read = 0;
       else
@@ -373,10 +374,14 @@ static struct transport_implementation wgnutls_transport =
 bool
 ssl_connect_wget (int fd, const char *hostname)
 {
+#ifdef F_GETFL
+  int flags = 0;
+#endif
   struct wgnutls_transport_context *ctx;
   gnutls_session_t session;
-  int err;
+  int err,alert;
   gnutls_init (&session, GNUTLS_CLIENT);
+  const char *str;
 
   /* We set the server name but only if it's not an IP address. */
   if (! is_valid_ip_address (hostname))
@@ -400,7 +405,7 @@ ssl_connect_wget (int fd, const char *hostname)
       break;
     case secure_protocol_sslv2:
     case secure_protocol_sslv3:
-      err = gnutls_priority_set_direct (session, "NORMAL:-VERS-TLS-ALL", NULL);
+      err = gnutls_priority_set_direct (session, "NORMAL:-VERS-TLS-ALL:+VERS-SSL3.0", NULL);
       break;
     case secure_protocol_tlsv1:
       err = gnutls_priority_set_direct (session, "NORMAL:-VERS-SSL3.0", NULL);
@@ -439,10 +444,83 @@ ssl_connect_wget (int fd, const char *hostname)
       return false;
     }
 
-  err = gnutls_handshake (session);
+  if (opt.connect_timeout)
+    {
+#ifdef F_GETFL
+      flags = fcntl (fd, F_GETFL, 0);
+      if (flags < 0)
+        return flags;
+      if (fcntl (fd, F_SETFL, flags | O_NONBLOCK))
+        return -1;
+#else
+      /* XXX: Assume it was blocking before.  */
+      const int one = 1;
+      if (ioctl (fd, FIONBIO, &one) < 0)
+        return -1;
+#endif
+    }
+
+  /* We don't stop the handshake process for non-fatal errors */
+  do
+    {
+      err = gnutls_handshake (session);
+
+      if (opt.connect_timeout && err == GNUTLS_E_AGAIN)
+        {
+          if (gnutls_record_get_direction (session))
+            {
+              /* wait for writeability */
+              err = select_fd (fd, opt.connect_timeout, WAIT_FOR_WRITE);
+            }
+          else
+            {
+              /* wait for readability */
+              err = select_fd (fd, opt.connect_timeout, WAIT_FOR_READ);
+            }
+
+          if (err <= 0)
+            {
+              if (err == 0)
+                {
+                  errno = ETIMEDOUT;
+                  err = -1;
+                }
+              break;
+            }
+
+          if (err <= 0)
+            break;
+        }
+      else if (err < 0)
+        {
+          logprintf (LOG_NOTQUIET, "GnuTLS: %s\n", gnutls_strerror (err));
+          if (err == GNUTLS_E_WARNING_ALERT_RECEIVED ||
+              err == GNUTLS_E_FATAL_ALERT_RECEIVED)
+            {
+              alert = gnutls_alert_get (session);
+              str = gnutls_alert_get_name (alert);
+              if (str == NULL)
+                str = "(unknown)";
+              logprintf (LOG_NOTQUIET, "GnuTLS: received alert [%d]: %s\n", alert, str);
+            }
+        }
+    }
+  while (err == GNUTLS_E_WARNING_ALERT_RECEIVED && gnutls_error_is_fatal (err) == 0);
+
+  if (opt.connect_timeout)
+    {
+#ifdef F_GETFL
+      if (fcntl (fd, F_SETFL, flags) < 0)
+        return -1;
+#else
+      const int zero = 0;
+      if (ioctl (fd, FIONBIO, &zero) < 0)
+        return -1;
+#endif
+    }
+
   if (err < 0)
     {
-      logprintf (LOG_NOTQUIET, "GnuTLS: %s\n", gnutls_strerror (err));
       gnutls_deinit (session);
       return false;
     }
@@ -514,7 +592,7 @@ ssl_check_certificate (int fd, const char *host)
         {
           logprintf (LOG_NOTQUIET, _("No certificate found\n"));
           success = false;
-          goto out;
+          goto crt_deinit;
         }
       err = gnutls_x509_crt_import (cert, cert_list, GNUTLS_X509_FMT_DER);
       if (err < 0)
@@ -522,7 +600,7 @@ ssl_check_certificate (int fd, const char *host)
           logprintf (LOG_NOTQUIET, _("Error parsing certificate: %s\n"),
                      gnutls_strerror (err));
           success = false;
-          goto out;
+          goto crt_deinit;
         }
       if (now < gnutls_x509_crt_get_activation_time (cert))
         {
@@ -541,6 +619,7 @@ ssl_check_certificate (int fd, const char *host)
                      quote (host));
           success = false;
         }
+ crt_deinit:
       gnutls_x509_crt_deinit (cert);
    }