]> sjero.net Git - wget/blobdiff - src/openssl.c
Fix timeout option when used with SSL
[wget] / src / openssl.c
index 1823f5935655a15a4b8d92279a9cb5129acb63c2..e2eec4f7c2932b704fd5cd63e0ed48a735f15813 100644 (file)
@@ -1,6 +1,6 @@
 /* SSL support via OpenSSL library.
    Copyright (C) 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008,
-   2009 Free Software Foundation, Inc.
+   2009, 2010, 2011, 2012 Free Software Foundation, Inc.
    Originally contributed by Christian Fraenkel.
 
 This file is part of GNU Wget.
@@ -33,9 +33,7 @@ as that of the covered work.  */
 
 #include <assert.h>
 #include <errno.h>
-#ifdef HAVE_UNISTD_H
-# include <unistd.h>
-#endif
+#include <unistd.h>
 #include <string.h>
 
 #include <openssl/ssl.h>
@@ -48,6 +46,10 @@ as that of the covered work.  */
 #include "url.h"
 #include "ssl.h"
 
+#ifdef WINDOWS
+# include <w32sock.h>
+#endif
+
 /* Application-wide SSL context.  This is common to all SSL
    connections.  */
 static SSL_CTX *ssl_ctx;
@@ -157,9 +159,9 @@ key_type_to_ssl_type (enum keyfile_type type)
    Returns true on success, false otherwise.  */
 
 bool
-ssl_init ()
+ssl_init (void)
 {
-  SSL_METHOD *meth;
+  SSL_METHOD const *meth;
 
   if (ssl_ctx)
     /* The SSL has already been initialized. */
@@ -184,9 +186,11 @@ ssl_init ()
     case secure_protocol_auto:
       meth = SSLv23_client_method ();
       break;
+#ifndef OPENSSL_NO_SSL2
     case secure_protocol_sslv2:
       meth = SSLv2_client_method ();
       break;
+#endif
     case secure_protocol_sslv3:
       meth = SSLv3_client_method ();
       break;
@@ -197,7 +201,9 @@ ssl_init ()
       abort ();
     }
 
-  ssl_ctx = SSL_CTX_new (meth);
+  /* The type cast below accommodates older OpenSSL versions (0.9.8)
+     where SSL_CTX_new() is declared without a "const" argument. */
+  ssl_ctx = SSL_CTX_new ((SSL_METHOD *)meth);
   if (!ssl_ctx)
     goto error;
 
@@ -245,23 +251,50 @@ ssl_init ()
   return false;
 }
 
-struct openssl_transport_context {
+struct openssl_transport_context
+{
   SSL *conn;                    /* SSL connection handle */
   char *last_error;             /* last error printed with openssl_errstr */
 };
 
-static int
-openssl_read (int fd, char *buf, int bufsize, void *arg)
+struct openssl_read_args
 {
-  int ret;
-  struct openssl_transport_context *ctx = arg;
+  int fd;
+  struct openssl_transport_context *ctx;
+  char *buf;
+  int bufsize;
+  int retval;
+};
+
+static void openssl_read_callback(void *arg)
+{
+  struct openssl_read_args *args = (struct openssl_read_args *) arg;
+  struct openssl_transport_context *ctx = args->ctx;
   SSL *conn = ctx->conn;
+  char *buf = args->buf;
+  int bufsize = args->bufsize;
+  int ret;
+
   do
     ret = SSL_read (conn, buf, bufsize);
-  while (ret == -1
-         && SSL_get_error (conn, ret) == SSL_ERROR_SYSCALL
+  while (ret == -1 && SSL_get_error (conn, ret) == SSL_ERROR_SYSCALL
          && errno == EINTR);
-  return ret;
+  args->retval = ret;
+}
+
+static int
+openssl_read (int fd, char *buf, int bufsize, void *arg)
+{
+  struct openssl_read_args args;
+  args.fd = fd;
+  args.buf = buf;
+  args.bufsize = bufsize;
+  args.ctx = (struct openssl_transport_context*) arg;
+
+  if (run_with_timeout(opt.read_timeout, openssl_read_callback, &args)) {
+    return -1;
+  }
+  return args.retval;
 }
 
 static int
@@ -283,10 +316,10 @@ openssl_poll (int fd, double timeout, int wait_for, void *arg)
 {
   struct openssl_transport_context *ctx = arg;
   SSL *conn = ctx->conn;
-  if (timeout == 0)
-    return 1;
   if (SSL_pending (conn))
     return 1;
+  if (timeout == 0)
+    return 1;
   return select_fd (fd, timeout, wait_for);
 }
 
@@ -296,6 +329,8 @@ openssl_peek (int fd, char *buf, int bufsize, void *arg)
   int ret;
   struct openssl_transport_context *ctx = arg;
   SSL *conn = ctx->conn;
+  if (! openssl_poll (fd, 0.0, WAIT_FOR_READ, arg))
+    return 0;
   do
     ret = SSL_peek (conn, buf, bufsize);
   while (ret == -1
@@ -364,11 +399,7 @@ openssl_close (int fd, void *arg)
   xfree_null (ctx->last_error);
   xfree (ctx);
 
-#if defined(WINDOWS) || defined(USE_WATT32)
-  closesocket (fd);
-#else
   close (fd);
-#endif
 
   DEBUGP (("Closed %d/SSL 0x%0*lx\n", fd, PTR_FORMAT (conn)));
 }
@@ -381,6 +412,19 @@ static struct transport_implementation openssl_transport = {
   openssl_peek, openssl_errstr, openssl_close
 };
 
+struct scwt_context
+{
+  SSL *ssl;
+  int result;
+};
+
+static void
+ssl_connect_with_timeout_callback(void *arg)
+{
+  struct scwt_context *ctx = (struct scwt_context *)arg;
+  ctx->result = SSL_connect(ctx->ssl);
+}
+
 /* Perform the SSL handshake on file descriptor FD, which is assumed
    to be connected to an SSL server.  The SSL handle provided by
    OpenSSL is registered with the file descriptor FD using
@@ -390,9 +434,10 @@ static struct transport_implementation openssl_transport = {
    Returns true on success, false on failure.  */
 
 bool
-ssl_connect_wget (int fd)
+ssl_connect_wget (int fd, const char *hostname)
 {
   SSL *conn;
+  struct scwt_context scwt_ctx;
   struct openssl_transport_context *ctx;
 
   DEBUGP (("Initiating SSL handshake.\n"));
@@ -401,10 +446,33 @@ ssl_connect_wget (int fd)
   conn = SSL_new (ssl_ctx);
   if (!conn)
     goto error;
-  if (!SSL_set_fd (conn, fd))
+#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
+  /* If the SSL library was build with support for ServerNameIndication
+     then use it whenever we have a hostname.  If not, don't, ever. */
+  if (! is_valid_ip_address (hostname))
+    {
+      if (! SSL_set_tlsext_host_name (conn, hostname))
+       {
+       DEBUGP (("Failed to set TLS server-name indication."));
+       goto error;
+       }
+    }
+#endif
+
+#ifndef FD_TO_SOCKET
+# define FD_TO_SOCKET(X) (X)
+#endif
+  if (!SSL_set_fd (conn, FD_TO_SOCKET (fd)))
     goto error;
   SSL_set_connect_state (conn);
-  if (SSL_connect (conn) <= 0 || conn->state != SSL_ST_OK)
+
+  scwt_ctx.ssl = conn;
+  if (run_with_timeout(opt.read_timeout, ssl_connect_with_timeout_callback,
+                       &scwt_ctx)) {
+    DEBUGP (("SSL handshake timed out.\n"));
+    goto timeout;
+  }
+  if (scwt_ctx.result <= 0 || conn->state != SSL_ST_OK)
     goto error;
 
   ctx = xnew0 (struct openssl_transport_context);
@@ -420,6 +488,7 @@ ssl_connect_wget (int fd)
  error:
   DEBUGP (("SSL handshake failed.\n"));
   print_errors ();
+ timeout:
   if (conn)
     SSL_free (conn);
   return false;