#include <errno.h>
#include <stdlib.h>
#include <stdbool.h>
#include <err.h>
#include <sysexits.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/ip.h>
struct context {
unsigned short divert_port;
bool divert_back;
int fd;
struct sockaddr_in sin;
socklen_t sin_len;
char pkt[IP_MAXPACKET];
ssize_t pkt_n;
};
static void
init(struct context *c)
{
c->fd = socket(PF_DIVERT, SOCK_RAW, 0);
if (c->fd == -1)
errx(EX_OSERR, "init: Cannot create divert socket.");
memset(&c->sin, 0, sizeof(c->sin));
c->sin.sin_family = AF_INET;
c->sin.sin_port = htons(c->divert_port);
c->sin.sin_addr.s_addr = INADDR_ANY;
c->sin_len = sizeof(struct sockaddr_in);
if (bind(c->fd, (struct sockaddr *) &c->sin, c->sin_len) != 0)
errx(EX_OSERR, "init: Cannot bind divert socket.");
}
static ssize_t
recv_pkt(struct context *c)
{
fd_set readfds;
struct timeval timeout;
int s;
FD_ZERO(&readfds);
FD_SET(c->fd, &readfds);
timeout.tv_sec = 3;
timeout.tv_usec = 0;
s = select(c->fd + 1, &readfds, 0, 0, &timeout);
if (s == -1)
errx(EX_IOERR, "recv_pkt: select() errors.");
if (s != 1)
return (-1);
c->pkt_n = recvfrom(c->fd, c->pkt, sizeof(c->pkt), 0,
(struct sockaddr *) &c->sin, &c->sin_len);
if (c->pkt_n == -1)
errx(EX_IOERR, "recv_pkt: recvfrom() errors.");
return (c->pkt_n);
}
static void
send_pkt(struct context *c)
{
ssize_t n;
n = sendto(c->fd, c->pkt, c->pkt_n, 0,
(struct sockaddr *) &c->sin, c->sin_len);
if (n == -1)
err(EX_IOERR, "send_pkt: sendto() errors");
if (n != c->pkt_n)
errx(EX_IOERR, "send_pkt: sendto() sent %zd of %zd bytes.",
n, c->pkt_n);
}
int
main(int argc, char *argv[])
{
struct context c;
int npkt;
if (argc < 2)
errx(EX_USAGE,
"Usage: %s <divert-port> [divert-back]", argv[0]);
memset(&c, 0, sizeof(struct context));
c.divert_port = (unsigned short) strtol(argv[1], NULL, 10);
if (c.divert_port == 0)
errx(EX_USAGE, "divert port is not defined.");
if (argc >= 3 && strcmp(argv[2], "divert-back") == 0)
c.divert_back = true;
init(&c);
npkt = 0;
while (recv_pkt(&c) > 0) {
if (c.divert_back)
send_pkt(&c);
npkt++;
if (npkt >= 10)
break;
}
if (npkt != 1)
errx(EXIT_FAILURE, "%d: npkt=%d.", c.divert_port, npkt);
return (EXIT_SUCCESS);
}