=================================================================== RCS file: /cvsrepo/anoncvs/cvs/src/usr.bin/ssh/packet.c,v retrieving revision 1.204 retrieving revision 1.205 diff -u -r1.204 -r1.205 --- src/usr.bin/ssh/packet.c 2015/01/28 21:15:47 1.204 +++ src/usr.bin/ssh/packet.c 2015/01/30 01:13:33 1.205 @@ -1,4 +1,4 @@ -/* $OpenBSD: packet.c,v 1.204 2015/01/28 21:15:47 djm Exp $ */ +/* $OpenBSD: packet.c,v 1.205 2015/01/30 01:13:33 djm Exp $ */ /* * Author: Tatu Ylonen * Copyright (c) 1995 Tatu Ylonen , Espoo, Finland @@ -266,20 +266,26 @@ const struct sshcipher *none = cipher_by_name("none"); int r; - if (none == NULL) - fatal("%s: cannot load cipher 'none'", __func__); + if (none == NULL) { + error("%s: cannot load cipher 'none'", __func__); + return NULL; + } if (ssh == NULL) ssh = ssh_alloc_session_state(); - if (ssh == NULL) - fatal("%s: cound not allocate state", __func__); + if (ssh == NULL) { + error("%s: cound not allocate state", __func__); + return NULL; + } state = ssh->state; state->connection_in = fd_in; state->connection_out = fd_out; if ((r = cipher_init(&state->send_context, none, (const u_char *)"", 0, NULL, 0, CIPHER_ENCRYPT)) != 0 || (r = cipher_init(&state->receive_context, none, - (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) - fatal("%s: cipher_init failed: %s", __func__, ssh_err(r)); + (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) { + error("%s: cipher_init failed: %s", __func__, ssh_err(r)); + return NULL; + } state->newkeys[MODE_IN] = state->newkeys[MODE_OUT] = NULL; deattack_init(&state->deattack); return ssh; @@ -882,8 +888,8 @@ /* * Note that the packet is now only buffered in output. It won't be - * actually sent until packet_write_wait or packet_write_poll is - * called. + * actually sent until ssh_packet_write_wait or ssh_packet_write_poll + * is called. */ r = 0; out: @@ -1252,8 +1258,12 @@ if (setp == NULL) return SSH_ERR_ALLOC_FAIL; - /* Since we are blocking, ensure that all written packets have been sent. */ - ssh_packet_write_wait(ssh); + /* + * Since we are blocking, ensure that all written packets have + * been sent. + */ + if ((r = ssh_packet_write_wait(ssh)) != 0) + return r; /* Stay in the loop until we have received a complete packet. */ for (;;) { @@ -1339,16 +1349,22 @@ * that given, and gives a fatal error and exits if there is a mismatch. */ -void -ssh_packet_read_expect(struct ssh *ssh, int expected_type) +int +ssh_packet_read_expect(struct ssh *ssh, u_int expected_type) { - int type; + int r; + u_char type; - type = ssh_packet_read(ssh); - if (type != expected_type) - ssh_packet_disconnect(ssh, + if ((r = ssh_packet_read_seqnr(ssh, &type, NULL)) != 0) + return r; + if (type != expected_type) { + if ((r = sshpkt_disconnect(ssh, "Protocol error: expected packet type %d, got %d", - expected_type, type); + expected_type, type)) != 0) + return r; + return SSH_ERR_PROTOCOL_ERROR; + } + return 0; } /* Checks if a full packet is available in the data received so far via @@ -1365,6 +1381,7 @@ { struct session_state *state = ssh->state; u_int len, padded_len; + const char *emsg; const u_char *cp; u_char *p; u_int checksum, stored_checksum; @@ -1377,9 +1394,12 @@ return 0; /* Get length of incoming packet. */ len = PEEK_U32(sshbuf_ptr(state->input)); - if (len < 1 + 2 + 2 || len > 256 * 1024) - ssh_packet_disconnect(ssh, "Bad packet length %u.", - len); + if (len < 1 + 2 + 2 || len > 256 * 1024) { + if ((r = sshpkt_disconnect(ssh, "Bad packet length %u", + len)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } padded_len = (len + 8) & ~7; /* Check if the packet has been entirely received. */ @@ -1398,20 +1418,28 @@ * Ariel Futoransky(futo@core-sdi.com) */ if (!state->receive_context.plaintext) { + emsg = NULL; switch (detect_attack(&state->deattack, sshbuf_ptr(state->input), padded_len)) { case DEATTACK_OK: break; case DEATTACK_DETECTED: - ssh_packet_disconnect(ssh, - "crc32 compensation attack: network attack detected" - ); + emsg = "crc32 compensation attack detected"; + break; case DEATTACK_DOS_DETECTED: - ssh_packet_disconnect(ssh, - "deattack denial of service detected"); + emsg = "deattack denial of service detected"; + break; default: - ssh_packet_disconnect(ssh, "deattack error"); + emsg = "deattack error"; + break; } + if (emsg != NULL) { + error("%s", emsg); + if ((r = sshpkt_disconnect(ssh, "%s", emsg)) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } } /* Decrypt data to incoming_packet. */ @@ -1439,16 +1467,24 @@ goto out; /* Test check bytes. */ - if (len != sshbuf_len(state->incoming_packet)) - ssh_packet_disconnect(ssh, - "packet_read_poll1: len %d != sshbuf_len %zd.", + if (len != sshbuf_len(state->incoming_packet)) { + error("%s: len %d != sshbuf_len %zd", __func__, len, sshbuf_len(state->incoming_packet)); + if ((r = sshpkt_disconnect(ssh, "invalid packet length")) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } cp = sshbuf_ptr(state->incoming_packet) + len - 4; stored_checksum = PEEK_U32(cp); - if (checksum != stored_checksum) - ssh_packet_disconnect(ssh, - "Corrupted check bytes on input."); + if (checksum != stored_checksum) { + error("Corrupted check bytes on input"); + if ((r = sshpkt_disconnect(ssh, "connection corrupted")) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } if ((r = sshbuf_consume_end(state->incoming_packet, 4)) < 0) goto out; @@ -1466,9 +1502,13 @@ state->p_read.bytes += padded_len + 4; if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0) goto out; - if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) - ssh_packet_disconnect(ssh, - "Invalid ssh1 packet type: %d", *typep); + if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) { + error("Invalid ssh1 packet type: %d", *typep); + if ((r = sshpkt_disconnect(ssh, "invalid packet type")) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_PROTOCOL_ERROR; + } r = 0; out: return r; @@ -1622,7 +1662,6 @@ if ((r = sshbuf_consume(state->input, mac->mac_len)) != 0) goto out; } - /* XXX now it's safe to use fatal/packet_disconnect */ if (seqnr_p != NULL) *seqnr_p = state->p_read.seqnr; if (++state->p_read.seqnr == 0) @@ -1636,9 +1675,13 @@ /* get padlen */ padlen = sshbuf_ptr(state->incoming_packet)[4]; DBG(debug("input: padlen %d", padlen)); - if (padlen < 4) - ssh_packet_disconnect(ssh, - "Corrupted padlen %d on input.", padlen); + if (padlen < 4) { + if ((r = sshpkt_disconnect(ssh, + "Corrupted padlen %d on input.", padlen)) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } /* skip packet size + padlen, discard padding */ if ((r = sshbuf_consume(state->incoming_packet, 4 + 1)) != 0 || @@ -1665,9 +1708,13 @@ */ if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0) goto out; - if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) - ssh_packet_disconnect(ssh, - "Invalid ssh2 packet type: %d", *typep); + if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) { + if ((r = sshpkt_disconnect(ssh, + "Invalid ssh2 packet type: %d", *typep)) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_PROTOCOL_ERROR; + } if (*typep == SSH2_MSG_NEWKEYS) r = ssh_set_newkeys(ssh, MODE_IN); else if (*typep == SSH2_MSG_USERAUTH_SUCCESS && !state->server_side) @@ -1804,9 +1851,8 @@ * message is printed immediately, but only if the client is being executed * in verbose mode. These messages are primarily intended to ease debugging * authentication problems. The length of the formatted message must not - * exceed 1024 bytes. This will automatically call packet_write_wait. + * exceed 1024 bytes. This will automatically call ssh_packet_write_wait. */ - void ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...) { @@ -1834,16 +1880,37 @@ (r = sshpkt_send(ssh)) != 0) fatal("%s: %s", __func__, ssh_err(r)); } - ssh_packet_write_wait(ssh); + if ((r = ssh_packet_write_wait(ssh)) != 0) + fatal("%s: %s", __func__, ssh_err(r)); } /* + * Pretty-print connection-terminating errors and exit. + */ +void +sshpkt_fatal(struct ssh *ssh, const char *tag, int r) +{ + switch (r) { + case SSH_ERR_CONN_CLOSED: + logit("Connection closed by %.200s", ssh_remote_ipaddr(ssh)); + cleanup_exit(255); + case SSH_ERR_CONN_TIMEOUT: + logit("Connection to %.200s timed out while " + "waiting to write", ssh_remote_ipaddr(ssh)); + cleanup_exit(255); + default: + fatal("%s%sConnection to %.200s: %s", + tag != NULL ? tag : "", tag != NULL ? ": " : "", + ssh_remote_ipaddr(ssh), ssh_err(r)); + } +} + +/* * Logs the error plus constructs and sends a disconnect packet, closes the * connection, and exits. This function never returns. The error message * should not contain a newline. The length of the formatted message must * not exceed 1024 bytes. */ - void ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...) { @@ -1867,30 +1934,26 @@ /* Display the error locally */ logit("Disconnecting: %.100s", buf); - /* Send the disconnect message to the other side, and wait for it to get sent. */ - if (compat20) { - if ((r = sshpkt_start(ssh, SSH2_MSG_DISCONNECT)) != 0 || - (r = sshpkt_put_u32(ssh, SSH2_DISCONNECT_PROTOCOL_ERROR)) != 0 || - (r = sshpkt_put_cstring(ssh, buf)) != 0 || - (r = sshpkt_put_cstring(ssh, "")) != 0 || - (r = sshpkt_send(ssh)) != 0) - fatal("%s: %s", __func__, ssh_err(r)); - } else { - if ((r = sshpkt_start(ssh, SSH_MSG_DISCONNECT)) != 0 || - (r = sshpkt_put_cstring(ssh, buf)) != 0 || - (r = sshpkt_send(ssh)) != 0) - fatal("%s: %s", __func__, ssh_err(r)); - } - ssh_packet_write_wait(ssh); + /* + * Send the disconnect message to the other side, and wait + * for it to get sent. + */ + if ((r = sshpkt_disconnect(ssh, "%s", buf)) != 0) + sshpkt_fatal(ssh, __func__, r); + if ((r = ssh_packet_write_wait(ssh)) != 0) + sshpkt_fatal(ssh, __func__, r); + /* Close the connection. */ ssh_packet_close(ssh); cleanup_exit(255); } -/* Checks if there is any buffered output, and tries to write some of the output. */ - -void +/* + * Checks if there is any buffered output, and tries to write some of + * the output. + */ +int ssh_packet_write_poll(struct ssh *ssh) { struct session_state *state = ssh->state; @@ -1903,33 +1966,33 @@ sshbuf_ptr(state->output), len, &cont); if (len == -1) { if (errno == EINTR || errno == EAGAIN) - return; - fatal("Write failed: %.100s", strerror(errno)); + return 0; + return SSH_ERR_SYSTEM_ERROR; } if (len == 0 && !cont) - fatal("Write connection closed"); + return SSH_ERR_CONN_CLOSED; if ((r = sshbuf_consume(state->output, len)) != 0) - fatal("%s: %s", __func__, ssh_err(r)); + return r; } + return 0; } /* * Calls packet_write_poll repeatedly until all pending output data has been * written. */ - -void +int ssh_packet_write_wait(struct ssh *ssh) { fd_set *setp; - int ret, ms_remain = 0; + int ret, r, ms_remain = 0; struct timeval start, timeout, *timeoutp = NULL; struct session_state *state = ssh->state; setp = (fd_set *)calloc(howmany(state->connection_out + 1, NFDBITS), sizeof(fd_mask)); if (setp == NULL) - fatal("%s: calloc failed", __func__); + return SSH_ERR_ALLOC_FAIL; ssh_packet_write_poll(ssh); while (ssh_packet_have_data_to_write(ssh)) { memset(setp, 0, howmany(state->connection_out + 1, @@ -1959,13 +2022,16 @@ } } if (ret == 0) { - logit("Connection to %.200s timed out while " - "waiting to write", ssh_remote_ipaddr(ssh)); - cleanup_exit(255); + free(setp); + return SSH_ERR_CONN_TIMEOUT; } - ssh_packet_write_poll(ssh); + if ((r = ssh_packet_write_poll(ssh)) != 0) { + free(setp); + return r; + } } free(setp); + return 0; } /* Returns true if there is buffered data to write to the connection. */