=================================================================== RCS file: /cvsrepo/anoncvs/cvs/src/usr.bin/ssh/dispatch.c,v retrieving revision 1.22 retrieving revision 1.23 diff -u -r1.22 -r1.23 --- src/usr.bin/ssh/dispatch.c 2008/10/31 15:05:34 1.22 +++ src/usr.bin/ssh/dispatch.c 2015/01/19 20:07:45 1.23 @@ -1,4 +1,4 @@ -/* $OpenBSD: dispatch.c,v 1.22 2008/10/31 15:05:34 stevesk Exp $ */ +/* $OpenBSD: dispatch.c,v 1.23 2015/01/19 20:07:45 markus Exp $ */ /* * Copyright (c) 2000 Markus Friedl. All rights reserved. * @@ -34,69 +34,107 @@ #include "dispatch.h" #include "packet.h" #include "compat.h" +#include "ssherr.h" -#define DISPATCH_MAX 255 - -dispatch_fn *dispatch[DISPATCH_MAX]; - -void -dispatch_protocol_error(int type, u_int32_t seq, void *ctxt) +int +dispatch_protocol_error(int type, u_int32_t seq, void *ctx) { + struct ssh *ssh = active_state; /* XXX */ + int r; + logit("dispatch_protocol_error: type %d seq %u", type, seq); if (!compat20) fatal("protocol error"); - packet_start(SSH2_MSG_UNIMPLEMENTED); - packet_put_int(seq); - packet_send(); - packet_write_wait(); + if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 || + (r = sshpkt_put_u32(ssh, seq)) != 0 || + (r = sshpkt_send(ssh)) != 0) + fatal("%s: %s", __func__, ssh_err(r)); + ssh_packet_write_wait(ssh); + return 0; } -void -dispatch_protocol_ignore(int type, u_int32_t seq, void *ctxt) + +int +dispatch_protocol_ignore(int type, u_int32_t seq, void *ssh) { logit("dispatch_protocol_ignore: type %d seq %u", type, seq); + return 0; } + void -dispatch_init(dispatch_fn *dflt) +ssh_dispatch_init(struct ssh *ssh, dispatch_fn *dflt) { u_int i; for (i = 0; i < DISPATCH_MAX; i++) - dispatch[i] = dflt; + ssh->dispatch[i] = dflt; } + void -dispatch_range(u_int from, u_int to, dispatch_fn *fn) +ssh_dispatch_range(struct ssh *ssh, u_int from, u_int to, dispatch_fn *fn) { u_int i; for (i = from; i <= to; i++) { if (i >= DISPATCH_MAX) break; - dispatch[i] = fn; + ssh->dispatch[i] = fn; } } + void -dispatch_set(int type, dispatch_fn *fn) +ssh_dispatch_set(struct ssh *ssh, int type, dispatch_fn *fn) { - dispatch[type] = fn; + ssh->dispatch[type] = fn; } -void -dispatch_run(int mode, volatile sig_atomic_t *done, void *ctxt) + +int +ssh_dispatch_run(struct ssh *ssh, int mode, volatile sig_atomic_t *done, + void *ctxt) { - for (;;) { - int type; - u_int32_t seqnr; + int r; + u_char type; + u_int32_t seqnr; + for (;;) { if (mode == DISPATCH_BLOCK) { - type = packet_read_seqnr(&seqnr); + r = ssh_packet_read_seqnr(ssh, &type, &seqnr); + if (r != 0) + return r; } else { - type = packet_read_poll_seqnr(&seqnr); + r = ssh_packet_read_poll_seqnr(ssh, &type, &seqnr); + if (r != 0) + return r; if (type == SSH_MSG_NONE) - return; + return 0; } - if (type > 0 && type < DISPATCH_MAX && dispatch[type] != NULL) - (*dispatch[type])(type, seqnr, ctxt); - else - packet_disconnect("protocol error: rcvd type %d", type); + if (type > 0 && type < DISPATCH_MAX && + ssh->dispatch[type] != NULL) { + if (ssh->dispatch_skip_packets) { + debug2("skipped packet (type %u)", type); + ssh->dispatch_skip_packets--; + continue; + } + /* XXX 'ssh' will replace 'ctxt' later */ + r = (*ssh->dispatch[type])(type, seqnr, ctxt); + if (r != 0) + return r; + } else { + r = sshpkt_disconnect(ssh, + "protocol error: rcvd type %d", type); + if (r != 0) + return r; + return SSH_ERR_DISCONNECTED; + } if (done != NULL && *done) - return; + return 0; } +} + +void +ssh_dispatch_run_fatal(struct ssh *ssh, int mode, volatile sig_atomic_t *done, + void *ctxt) +{ + int r; + + if ((r = ssh_dispatch_run(ssh, mode, done, ctxt)) != 0) + fatal("%s: %s", __func__, ssh_err(r)); }