=================================================================== RCS file: /cvsrepo/anoncvs/cvs/src/usr.bin/ssh/sshconnect.c,v retrieving revision 1.201 retrieving revision 1.202 diff -u -r1.201 -r1.202 --- src/usr.bin/ssh/sshconnect.c 2007/08/23 03:23:26 1.201 +++ src/usr.bin/ssh/sshconnect.c 2007/09/04 11:15:55 1.202 @@ -1,4 +1,4 @@ -/* $OpenBSD: sshconnect.c,v 1.201 2007/08/23 03:23:26 djm Exp $ */ +/* $OpenBSD: sshconnect.c,v 1.202 2007/09/04 11:15:55 djm Exp $ */ /* * Author: Tatu Ylonen * Copyright (c) 1995 Tatu Ylonen , Espoo, Finland @@ -64,6 +64,23 @@ static int show_other_keys(const char *, Key *); static void warn_changed_key(Key *); +static void +ms_subtract_diff(struct timeval *start, int *ms) +{ + struct timeval diff, finish; + + gettimeofday(&finish, NULL); + timersub(&finish, start, &diff); + *ms -= (diff.tv_sec * 1000) + (diff.tv_usec / 1000); +} + +static void +ms_to_timeval(struct timeval *tv, int ms) +{ + tv->tv_sec = ms / 1000; + tv->tv_usec = (ms % 1000) * 1000; +} + /* * Connect to the given ssh server using a proxy command. */ @@ -210,30 +227,36 @@ static int timeout_connect(int sockfd, const struct sockaddr *serv_addr, - socklen_t addrlen, int timeout) + socklen_t addrlen, int *timeoutp) { fd_set *fdset; - struct timeval tv; + struct timeval tv, t_start; socklen_t optlen; int optval, rc, result = -1; - if (timeout <= 0) - return (connect(sockfd, serv_addr, addrlen)); + gettimeofday(&t_start, NULL); + if (*timeoutp <= 0) { + result = connect(sockfd, serv_addr, addrlen); + goto done; + } + set_nonblock(sockfd); rc = connect(sockfd, serv_addr, addrlen); if (rc == 0) { unset_nonblock(sockfd); - return (0); + result = 0; + goto done; } - if (errno != EINPROGRESS) - return (-1); + if (errno != EINPROGRESS) { + result = -1; + goto done; + } fdset = (fd_set *)xcalloc(howmany(sockfd + 1, NFDBITS), sizeof(fd_mask)); FD_SET(sockfd, fdset); - tv.tv_sec = timeout; - tv.tv_usec = 0; + ms_to_timeval(&tv, *timeoutp); for (;;) { rc = select(sockfd + 1, NULL, fdset, NULL, &tv); @@ -272,6 +295,16 @@ } xfree(fdset); + + done: + if (result == 0 && *timeoutp > 0) { + ms_subtract_diff(&t_start, timeoutp); + if (*timeoutp <= 0) { + errno = ETIMEDOUT; + result = -1; + } + } + return (result); } @@ -288,8 +321,8 @@ */ int ssh_connect(const char *host, struct sockaddr_storage * hostaddr, - u_short port, int family, int connection_attempts, - int needpriv, const char *proxy_command) + u_short port, int family, int connection_attempts, int *timeout_ms, + int want_keepalive, int needpriv, const char *proxy_command) { int gaierr; int on = 1; @@ -342,7 +375,7 @@ continue; if (timeout_connect(sock, ai->ai_addr, ai->ai_addrlen, - options.connection_timeout) >= 0) { + timeout_ms) >= 0) { /* Successful connection. */ memcpy(hostaddr, ai->ai_addr, ai->ai_addrlen); break; @@ -369,7 +402,7 @@ debug("Connection established."); /* Set SO_KEEPALIVE if requested. */ - if (options.tcp_keep_alive && + if (want_keepalive && setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (void *)&on, sizeof(on)) < 0) error("setsockopt SO_KEEPALIVE: %.100s", strerror(errno)); @@ -385,7 +418,7 @@ * identification string. */ static void -ssh_exchange_identification(void) +ssh_exchange_identification(int timeout_ms) { char buf[256], remote_version[256]; /* must be same size! */ int remote_major, remote_minor, mismatch; @@ -393,16 +426,44 @@ int connection_out = packet_get_connection_out(); int minor1 = PROTOCOL_MINOR_1; u_int i, n; + size_t len; + int fdsetsz, remaining, rc; + struct timeval t_start, t_remaining; + fd_set *fdset; + fdsetsz = howmany(connection_in + 1, NFDBITS) * sizeof(fd_mask); + fdset = xcalloc(1, fdsetsz); + /* Read other side's version identification. */ + remaining = timeout_ms; for (n = 0;;) { for (i = 0; i < sizeof(buf) - 1; i++) { - size_t len = atomicio(read, connection_in, &buf[i], 1); + if (timeout_ms > 0) { + gettimeofday(&t_start, NULL); + ms_to_timeval(&t_remaining, remaining); + FD_SET(connection_in, fdset); + rc = select(connection_in + 1, fdset, NULL, + fdset, &t_remaining); + ms_subtract_diff(&t_start, &remaining); + if (rc == 0 || remaining <= 0) + fatal("Connection timed out during " + "banner exchange"); + if (rc == -1) { + if (errno == EINTR) + continue; + fatal("ssh_exchange_identification: " + "select: %s", strerror(errno)); + } + } + len = atomicio(read, connection_in, &buf[i], 1); + if (len != 1 && errno == EPIPE) - fatal("ssh_exchange_identification: Connection closed by remote host"); + fatal("ssh_exchange_identification: " + "Connection closed by remote host"); else if (len != 1) - fatal("ssh_exchange_identification: read: %.100s", strerror(errno)); + fatal("ssh_exchange_identification: " + "read: %.100s", strerror(errno)); if (buf[i] == '\r') { buf[i] = '\n'; buf[i + 1] = 0; @@ -413,7 +474,8 @@ break; } if (++n > 65536) - fatal("ssh_exchange_identification: No banner received"); + fatal("ssh_exchange_identification: " + "No banner received"); } buf[sizeof(buf) - 1] = 0; if (strncmp(buf, "SSH-", 4) == 0) @@ -421,6 +483,7 @@ debug("ssh_exchange_identification: %s", buf); } server_version_string = xstrdup(buf); + xfree(fdset); /* * Check that the versions match. In future this might accept @@ -929,7 +992,7 @@ */ void ssh_login(Sensitive *sensitive, const char *orighost, - struct sockaddr *hostaddr, struct passwd *pw) + struct sockaddr *hostaddr, struct passwd *pw, int timeout_ms) { char *host, *cp; char *server_user, *local_user; @@ -944,7 +1007,7 @@ *cp = (char)tolower(*cp); /* Exchange protocol version identification strings with the server. */ - ssh_exchange_identification(); + ssh_exchange_identification(timeout_ms); /* Put the connection into non-blocking mode. */ packet_set_nonblocking();