]> sjero.net Git - wget/blobdiff - src/gnutls.c
Fix timeout option when used with SSL
[wget] / src / gnutls.c
index 769b0059d16b7707e906fb12cb778a2161657726..06f90200c514ca303c8dd3739c7c59c7cbcafe52 100644 (file)
@@ -374,10 +374,14 @@ static struct transport_implementation wgnutls_transport =
 bool
 ssl_connect_wget (int fd, const char *hostname)
 {
+#ifdef F_GETFL
+  int flags = 0;
+#endif
   struct wgnutls_transport_context *ctx;
   gnutls_session_t session;
-  int err;
+  int err,alert;
   gnutls_init (&session, GNUTLS_CLIENT);
+  const char *str;
 
   /* We set the server name but only if it's not an IP address. */
   if (! is_valid_ip_address (hostname))
@@ -440,10 +444,83 @@ ssl_connect_wget (int fd, const char *hostname)
       return false;
     }
 
-  err = gnutls_handshake (session);
+  if (opt.connect_timeout)
+    {
+#ifdef F_GETFL
+      flags = fcntl (fd, F_GETFL, 0);
+      if (flags < 0)
+        return flags;
+      if (fcntl (fd, F_SETFL, flags | O_NONBLOCK))
+        return -1;
+#else
+      /* XXX: Assume it was blocking before.  */
+      const int one = 1;
+      if (ioctl (fd, FIONBIO, &one) < 0)
+        return -1;
+#endif
+    }
+
+  /* We don't stop the handshake process for non-fatal errors */
+  do
+    {
+      err = gnutls_handshake (session);
+
+      if (opt.connect_timeout && err == GNUTLS_E_AGAIN)
+        {
+          if (gnutls_record_get_direction (session))
+            {
+              /* wait for writeability */
+              err = select_fd (fd, opt.connect_timeout, WAIT_FOR_WRITE);
+            }
+          else
+            {
+              /* wait for readability */
+              err = select_fd (fd, opt.connect_timeout, WAIT_FOR_READ);
+            }
+
+          if (err <= 0)
+            {
+              if (err == 0)
+                {
+                  errno = ETIMEDOUT;
+                  err = -1;
+                }
+              break;
+            }
+
+          if (err <= 0)
+            break;
+        }
+      else if (err < 0)
+        {
+          logprintf (LOG_NOTQUIET, "GnuTLS: %s\n", gnutls_strerror (err));
+          if (err == GNUTLS_E_WARNING_ALERT_RECEIVED ||
+              err == GNUTLS_E_FATAL_ALERT_RECEIVED)
+            {
+              alert = gnutls_alert_get (session);
+              str = gnutls_alert_get_name (alert);
+              if (str == NULL)
+                str = "(unknown)";
+              logprintf (LOG_NOTQUIET, "GnuTLS: received alert [%d]: %s\n", alert, str);
+            }
+        }
+    }
+  while (err == GNUTLS_E_WARNING_ALERT_RECEIVED && gnutls_error_is_fatal (err) == 0);
+
+  if (opt.connect_timeout)
+    {
+#ifdef F_GETFL
+      if (fcntl (fd, F_SETFL, flags) < 0)
+        return -1;
+#else
+      const int zero = 0;
+      if (ioctl (fd, FIONBIO, &zero) < 0)
+        return -1;
+#endif
+    }
+
   if (err < 0)
     {
-      logprintf (LOG_NOTQUIET, "GnuTLS: %s\n", gnutls_strerror (err));
       gnutls_deinit (session);
       return false;
     }