=================================================================== RCS file: /cvsrepo/anoncvs/cvs/src/usr.bin/nc/netcat.c,v retrieving revision 1.124 retrieving revision 1.125 diff -c -r1.124 -r1.125 *** src/usr.bin/nc/netcat.c 2014/10/26 13:59:30 1.124 --- src/usr.bin/nc/netcat.c 2014/10/30 16:06:07 1.125 *************** *** 1,4 **** ! /* $OpenBSD: netcat.c,v 1.124 2014/10/26 13:59:30 millert Exp $ */ /* * Copyright (c) 2001 Eric Jackson * --- 1,4 ---- ! /* $OpenBSD: netcat.c,v 1.125 2014/10/30 16:06:07 tedu Exp $ */ /* * Copyright (c) 2001 Eric Jackson * *************** *** 64,69 **** --- 64,75 ---- #define PORT_MAX_LEN 6 #define UNIX_DG_TMP_SOCKET_SIZE 19 + #define POLL_STDIN 0 + #define POLL_NETOUT 1 + #define POLL_NETIN 2 + #define POLL_STDOUT 3 + #define BUFSIZE 2048 + /* Command Line Options */ int dflag; /* detached, no stdin */ int Fflag; /* fdpass sock to stdout */ *************** *** 111,116 **** --- 117,124 ---- int map_tos(char *, int *); void report_connect(const struct sockaddr *, socklen_t); void usage(int); + ssize_t drainbuf(int, unsigned char *, size_t *); + ssize_t fillbuf(int, unsigned char *, size_t *); int main(int argc, char *argv[]) *************** *** 390,396 **** &len); if (connfd == -1) { /* For now, all errnos are fatal */ ! err(1, "accept"); } if (vflag) report_connect((struct sockaddr *)&cliaddr, len); --- 398,404 ---- &len); if (connfd == -1) { /* For now, all errnos are fatal */ ! err(1, "accept"); } if (vflag) report_connect((struct sockaddr *)&cliaddr, len); *************** *** 729,794 **** * Loop that polls on the network file descriptor and stdin. */ void ! readwrite(int nfd) { ! struct pollfd pfd[2]; ! unsigned char buf[16 * 1024]; ! int n, wfd = fileno(stdin); ! int lfd = fileno(stdout); ! int plen; ! plen = sizeof(buf); ! /* Setup Network FD */ ! pfd[0].fd = nfd; ! pfd[0].events = POLLIN; ! /* Set up STDIN FD. */ ! pfd[1].fd = wfd; ! pfd[1].events = POLLIN; ! while (pfd[0].fd != -1) { if (iflag) sleep(iflag); ! if ((n = poll(pfd, 2 - dflag, timeout)) < 0) { ! int saved_errno = errno; ! close(nfd); ! errc(1, saved_errno, "Polling Error"); } ! if (n == 0) return; ! if (pfd[0].revents & (POLLIN|POLLHUP)) { ! if ((n = read(nfd, buf, plen)) < 0) ! return; ! else if (n == 0) { ! shutdown(nfd, SHUT_RD); ! pfd[0].fd = -1; ! pfd[0].events = 0; ! } else { ! if (tflag) ! atelnet(nfd, buf, n); ! if (atomicio(vwrite, lfd, buf, n) != n) ! return; } } ! if (!dflag && pfd[1].revents & (POLLIN|POLLHUP)) { ! if ((n = read(wfd, buf, plen)) < 0) ! return; ! else if (n == 0) { ! if (Nflag) ! shutdown(nfd, SHUT_WR); ! pfd[1].fd = -1; ! pfd[1].events = 0; ! } else { ! if (atomicio(vwrite, nfd, buf, n) != n) ! return; } } } } /* --- 737,958 ---- * Loop that polls on the network file descriptor and stdin. */ void ! readwrite(int net_fd) { ! struct pollfd pfd[4]; ! int stdin_fd = STDIN_FILENO; ! int stdout_fd = STDOUT_FILENO; ! unsigned char netinbuf[BUFSIZE]; ! size_t netinbufpos = 0; ! unsigned char stdinbuf[BUFSIZE]; ! size_t stdinbufpos = 0; ! int n, num_fds, flags; ! ssize_t ret; ! /* don't read from stdin if requested */ ! if (dflag) ! stdin_fd = -1; ! /* stdin */ ! pfd[POLL_STDIN].fd = stdin_fd; ! pfd[POLL_STDIN].events = POLLIN; ! /* network out */ ! pfd[POLL_NETOUT].fd = net_fd; ! pfd[POLL_NETOUT].events = 0; ! /* network in */ ! pfd[POLL_NETIN].fd = net_fd; ! pfd[POLL_NETIN].events = POLLIN; ! ! /* stdout */ ! pfd[POLL_STDOUT].fd = stdout_fd; ! pfd[POLL_STDOUT].events = 0; ! ! while (1) { ! /* both inputs are gone, buffers are empty, we are done */ ! if (pfd[POLL_STDIN].fd == -1 && pfd[POLL_NETIN].fd == -1 ! && stdinbufpos == 0 && netinbufpos == 0) { ! close(net_fd); ! return; ! } ! /* both outputs are gone, we can't continue */ ! if (pfd[POLL_NETOUT].fd == -1 && pfd[POLL_STDOUT].fd == -1) { ! close(net_fd); ! return; ! } ! /* listen and net in gone, queues empty, done */ ! if (lflag && pfd[POLL_NETIN].fd == -1 ! && stdinbufpos == 0 && netinbufpos == 0) { ! close(net_fd); ! return; ! } ! ! /* help says -i is for "wait between lines sent". We read and ! * write arbitrary amounts of data, and we don't want to start ! * scanning for newlines, so this is as good as it gets */ if (iflag) sleep(iflag); ! /* poll */ ! num_fds = poll(pfd, 4, timeout); ! ! /* treat poll errors */ ! if (num_fds == -1) { ! close(net_fd); ! err(1, "polling error"); } ! /* timeout happened */ ! if (num_fds == 0) return; ! /* treat socket error conditions */ ! for (n = 0; n < 4; n++) { ! if (pfd[n].revents & (POLLERR|POLLNVAL)) { ! pfd[n].fd = -1; } } + /* reading is possible after HUP */ + if (pfd[POLL_STDIN].events & POLLIN && + pfd[POLL_STDIN].revents & POLLHUP && + ! (pfd[POLL_STDIN].revents & POLLIN)) + pfd[POLL_STDIN].fd = -1; ! if (pfd[POLL_NETIN].events & POLLIN && ! pfd[POLL_NETIN].revents & POLLHUP && ! ! (pfd[POLL_NETIN].revents & POLLIN)) ! pfd[POLL_NETIN].fd = -1; ! ! if (pfd[POLL_NETOUT].revents & POLLHUP) { ! if (Nflag) ! shutdown(pfd[POLL_NETOUT].fd, SHUT_WR); ! pfd[POLL_NETOUT].fd = -1; ! } ! /* if HUP, stop watching stdout */ ! if (pfd[POLL_STDOUT].revents & POLLHUP) ! pfd[POLL_STDOUT].fd = -1; ! /* if no net out, stop watching stdin */ ! if (pfd[POLL_NETOUT].fd == -1) ! pfd[POLL_STDIN].fd = -1; ! /* if no stdout, stop watching net in */ ! if (pfd[POLL_STDOUT].fd == -1) { ! if (pfd[POLL_NETIN].fd != -1) ! shutdown(pfd[POLL_NETIN].fd, SHUT_RD); ! pfd[POLL_NETIN].fd = -1; ! } ! ! /* try to read from stdin */ ! if (pfd[POLL_STDIN].revents & POLLIN && stdinbufpos < BUFSIZE) { ! ret = fillbuf(pfd[POLL_STDIN].fd, stdinbuf, ! &stdinbufpos); ! /* error or eof on stdin - remove from pfd */ ! if (ret == 0 || ret == -1) ! pfd[POLL_STDIN].fd = -1; ! /* read something - poll net out */ ! if (stdinbufpos > 0) ! pfd[POLL_NETOUT].events = POLLOUT; ! /* filled buffer - remove self from polling */ ! if (stdinbufpos == BUFSIZE) ! pfd[POLL_STDIN].events = 0; ! } ! /* try to write to network */ ! if (pfd[POLL_NETOUT].revents & POLLOUT && stdinbufpos > 0) { ! ret = drainbuf(pfd[POLL_NETOUT].fd, stdinbuf, ! &stdinbufpos); ! if (ret == -1) ! pfd[POLL_NETOUT].fd = -1; ! /* buffer empty - remove self from polling */ ! if (stdinbufpos == 0) ! pfd[POLL_NETOUT].events = 0; ! /* buffer no longer full - poll stdin again */ ! if (stdinbufpos < BUFSIZE) ! pfd[POLL_STDIN].events = POLLIN; ! } ! /* try to read from network */ ! if (pfd[POLL_NETIN].revents & POLLIN && netinbufpos < BUFSIZE) { ! ret = fillbuf(pfd[POLL_NETIN].fd, netinbuf, ! &netinbufpos); ! if (ret == -1) ! pfd[POLL_NETIN].fd = -1; ! /* eof on net in - remove from pfd */ ! if (ret == 0) { ! shutdown(pfd[POLL_NETIN].fd, SHUT_RD); ! pfd[POLL_NETIN].fd = -1; } + /* read something - poll stdout */ + if (netinbufpos > 0) + pfd[POLL_STDOUT].events = POLLOUT; + /* filled buffer - remove self from polling */ + if (netinbufpos == BUFSIZE) + pfd[POLL_NETIN].events = 0; + /* handle telnet */ + if (tflag) + atelnet(pfd[POLL_NETIN].fd, netinbuf, + netinbufpos); } + /* try to write to stdout */ + if (pfd[POLL_STDOUT].revents & POLLOUT && netinbufpos > 0) { + ret = drainbuf(pfd[POLL_STDOUT].fd, netinbuf, + &netinbufpos); + if (ret == -1) + pfd[POLL_STDOUT].fd = -1; + /* buffer empty - remove self from polling */ + if (netinbufpos == 0) + pfd[POLL_STDOUT].events = 0; + /* buffer no longer full - poll net in again */ + if (netinbufpos < BUFSIZE) + pfd[POLL_NETIN].events = POLLIN; + } + + /* stdin gone and queue empty? */ + if (pfd[POLL_STDIN].fd == -1 && stdinbufpos == 0) { + if (pfd[POLL_NETOUT].fd != -1 && Nflag) + shutdown(pfd[POLL_NETOUT].fd, SHUT_WR); + pfd[POLL_NETOUT].fd = -1; + } + /* net in gone and queue empty? */ + if (pfd[POLL_NETIN].fd == -1 && netinbufpos == 0) { + pfd[POLL_STDOUT].fd = -1; + } } + } + + ssize_t + drainbuf(int fd, unsigned char *buf, size_t *bufpos) + { + ssize_t n; + ssize_t adjust; + + n = write(fd, buf, *bufpos); + /* don't treat EAGAIN, EINTR as error */ + if (n == -1 && (errno == EAGAIN || errno == EINTR)) + n = -2; + if (n <= 0) + return n; + /* adjust buffer */ + adjust = *bufpos - n; + if (adjust > 0) + memmove(buf, buf + n, adjust); + *bufpos -= n; + return n; + } + + + ssize_t + fillbuf(int fd, unsigned char *buf, size_t *bufpos) + { + size_t num = BUFSIZE - *bufpos; + ssize_t n; + + n = read(fd, buf + *bufpos, num); + /* don't treat EAGAIN, EINTR as error */ + if (n == -1 && (errno == EAGAIN || errno == EINTR)) + n = -2; + if (n <= 0) + return n; + *bufpos += n; + return n; } /*