#include <k5-int.h>
#include <k5-queue.h>
#include "internal.h"
#include <string.h>
#include <unistd.h>
#include <sys/un.h>
/* Shorthands for the libverto event flags used throughout this file. */
#define FLAGS_NONE VERTO_EV_FLAG_NONE
#define FLAGS_READ VERTO_EV_FLAG_IO_READ
#define FLAGS_WRITE VERTO_EV_FLAG_IO_WRITE
/*
 * Flags carried by every I/O event this file registers: persistent, and
 * reporting I/O errors.  Parenthesized so the expansion groups as a single
 * operand inside any surrounding expression (e.g. with & or ~).
 */
#define FLAGS_BASE (VERTO_EV_FLAG_PERSIST | VERTO_EV_FLAG_IO_ERROR)
/* Head type for a remote's list of outstanding requests (rr->list). */
K5_TAILQ_HEAD(request_head, request_st);
typedef struct request_st request;
/*
 * One request queued on a remote.  Created by request_new(), linked into the
 * remote's list by kr_remote_send(), and released by request_finish() for
 * every outcome except ETIMEDOUT (in which case it stays queued).
 */
struct request_st {
/* Link in the owning remote's rr->list. */
K5_TAILQ_ENTRY(request_st) list;
/* The remote this request was queued on. */
krad_remote *rr;
/* The request packet; freed by request_finish(). */
krad_packet *request;
/* Completion callback and its opaque argument, invoked by request_finish()
 * as cb(retval, request, response, data). */
krad_cb cb;
void *data;
/* Response timer: armed once the packet has been fully written, NULL when
 * not armed (including after it has fired). */
verto_ev *timer;
/* Per-attempt timeout handed to verto_add_timeout(); kr_remote_send()
 * divides the caller's total timeout by retries + 1. */
int timeout;
/* Remaining resend attempts, consumed by on_timeout(). */
size_t retries;
/* Number of bytes of the encoded packet written so far. */
size_t sent;
};
/* State for one remote server endpoint: address, shared secret, socket and
 * the queue of requests sent to it. */
struct krad_remote_st {
krb5_context kctx;
verto_ctx *vctx;
/* Socket descriptor, or -1 while not connected. */
int fd;
/* libverto I/O event watching fd, or NULL when none is registered. */
verto_ev *io;
/* Private copy of the shared secret (strdup'd in kr_remote_new). */
char *secret;
/* Private copy of the server addrinfo; ai_addr is copied too, while
 * ai_next and ai_canonname are cleared. */
struct addrinfo *info;
/* Outstanding requests, in submission order. */
struct request_head list;
/* Receive storage.  buffer.data points at buffer_, and buffer.length counts
 * the bytes received so far for the packet currently being read. */
char buffer_[KRAD_PACKET_SIZE_MAX];
krb5_data buffer;
};
/*
 * Forward declarations: these libverto callbacks are referenced (by
 * remote_add_flags and request_start_timer) before they are defined below.
 */
static void
on_io(verto_ctx *ctx, verto_ev *ev);
static void
on_timeout(verto_ctx *ctx, verto_ev *ev);
/*
 * Iterator callback handed to the krad packet functions.  Walks a remote's
 * request list, yielding each request's packet in turn.
 *
 * data points at a cursor (request *) that is advanced past the entry being
 * returned.  Returns NULL when cancel is set or the cursor has reached the
 * end of the list.
 */
static const krad_packet *
iterator(void *data, krb5_boolean cancel)
{
    request **cursor = data;
    request *cur = *cursor;

    if (!cancel && cur != NULL) {
        *cursor = K5_TAILQ_NEXT(cur, list);
        return cur->request;
    }

    return NULL;
}
/*
 * Allocate a zeroed request for packet rqst on remote rr and store it in
 * *out.  The request is not linked into rr->list; the caller does that.
 * Ownership of rqst passes to the request, and request_finish() frees it.
 * Returns 0 on success or ENOMEM.
 */
static krb5_error_code
request_new(krad_remote *rr, krad_packet *rqst, int timeout, size_t retries,
            krad_cb cb, void *data, request **out)
{
    request *fresh = calloc(1, sizeof(*fresh));

    if (fresh != NULL) {
        fresh->rr = rr;
        fresh->request = rqst;
        fresh->cb = cb;
        fresh->data = data;
        fresh->timeout = timeout;
        fresh->retries = retries;
        *out = fresh;
        return 0;
    }

    return ENOMEM;
}
/*
 * Report the outcome of a request to its callback.
 *
 * For any result other than ETIMEDOUT this is final: the request is unlinked
 * from its remote before the callback runs, and afterwards its packet, timer
 * and the request itself are released.  On ETIMEDOUT the callback is invoked
 * but the request remains queued and owned by the remote.
 */
static inline void
request_finish(request *req, krb5_error_code retval,
               const krad_packet *response)
{
    krb5_boolean done = (retval != ETIMEDOUT);

    if (done)
        K5_TAILQ_REMOVE(&req->rr->list, req, list);

    req->cb(retval, req->request, response, req->data);

    if (!done)
        return;

    krad_packet_free(req->request);
    verto_del(req->timer);
    free(req);
}
/*
 * (Re)arm the response timer for request r on vctx, replacing any existing
 * timer.  The timer is one-shot (no persist flag), fires on_timeout() after
 * r->timeout, and carries r as its private data.  Returns 0, or ENOMEM if
 * libverto could not create the event (r->timer is then NULL).
 */
static krb5_error_code
request_start_timer(request *r, verto_ctx *vctx)
{
    verto_ev *ev;

    verto_del(r->timer);
    ev = verto_add_timeout(vctx, VERTO_EV_FLAG_NONE, on_timeout, r->timeout);
    r->timer = ev;
    if (ev == NULL)
        return ENOMEM;

    verto_set_private(ev, r, NULL);
    return 0;
}
/*
 * Return the remote to the unconnected state: close the socket (if open),
 * drop its I/O event, and set fd = -1 and io = NULL.
 *
 * Any partially received packet is discarded as well.  It belonged to the
 * old connection, and keeping it in rr->buffer would cause the first bytes
 * read on the next connection to be appended to a stale, half-read packet.
 */
static void
remote_disconnect(krad_remote *rr)
{
    if (rr->fd >= 0)
        close(rr->fd);
    verto_del(rr->io);
    rr->fd = -1;
    rr->io = NULL;
    rr->buffer.length = 0;
}
/*
 * Make sure the remote's I/O event watches (at least) the requested
 * READ/WRITE flags.  Other bits in flags are ignored.
 *
 * If there is no socket, a new one is created from remote->info and
 * connected first, and any stale I/O event is discarded.  If there is no I/O
 * event, one is registered with FLAGS_BASE | flags.
 *
 * Returns 0 on success.  Returns EINVAL for a NULL remote or when no
 * READ/WRITE bit was requested.  Returns the errno from socket()/connect()
 * (the socket is closed after a failed connect), or ENOMEM if libverto
 * could not create the event.
 */
static krb5_error_code
remote_add_flags(krad_remote *remote, verto_ev_flag flags)
{
    verto_ev_flag have;
    int err;

    flags &= (FLAGS_READ | FLAGS_WRITE);
    if (remote == NULL || flags == FLAGS_NONE)
        return EINVAL;

    /* No socket yet (or it was shut down): open and connect a new one. */
    if (remote->fd < 0) {
        verto_del(remote->io);
        remote->io = NULL;

        remote->fd = socket(remote->info->ai_family,
                            remote->info->ai_socktype,
                            remote->info->ai_protocol);
        if (remote->fd < 0)
            return errno;

        if (connect(remote->fd, remote->info->ai_addr,
                    remote->info->ai_addrlen) < 0) {
            err = errno;
            remote_disconnect(remote);
            return err;
        }
    }

    /* No event registered for the socket: create one with these flags. */
    if (remote->io == NULL) {
        remote->io = verto_add_io(remote->vctx, FLAGS_BASE | flags, on_io,
                                  remote->fd);
        if (remote->io == NULL)
            return ENOMEM;
        verto_set_private(remote->io, remote, NULL);
    }

    /* Widen the existing event only if some requested flag is missing. */
    have = verto_get_flags(remote->io);
    if ((have & flags) != flags)
        verto_set_flags(remote->io, FLAGS_BASE | have | flags);

    return 0;
}
/*
 * Stop watching the given READ/WRITE flags on the remote's I/O event.
 * If no READ/WRITE interest would remain, the event is deleted and
 * remote->io becomes NULL.  Does nothing if there is no remote or no event.
 */
static void
remote_del_flags(krad_remote *remote, verto_ev_flag flags)
{
    verto_ev_flag keep;

    if (remote != NULL && remote->io != NULL) {
        keep = verto_get_flags(remote->io) & (FLAGS_READ | FLAGS_WRITE);
        keep &= ~flags;

        if (keep != FLAGS_NONE) {
            verto_set_flags(remote->io, FLAGS_BASE | keep);
        } else {
            verto_del(remote->io);
            remote->io = NULL;
        }
    }
}
/*
 * Tear down the remote's connection after an I/O failure.
 *
 * Requests that have no armed timer would never be resolved once the
 * connection is gone, so a timer is started for each of them.  When it fires,
 * on_timeout() either retries (reconnecting) or reports ETIMEDOUT.  If a
 * timer cannot be created, the request is finished with that error instead.
 */
static void
remote_shutdown(krad_remote *rr)
{
    krb5_error_code ret;
    request *cur, *nxt;

    remote_disconnect(rr);

    K5_TAILQ_FOREACH_SAFE(cur, &rr->list, list, nxt) {
        if (cur->timer != NULL)
            continue;

        ret = request_start_timer(cur, rr->vctx);
        if (ret != 0)
            request_finish(cur, ret, NULL);
    }
}
/*
 * Timer callback: the request got no response within its per-attempt
 * timeout.
 *
 * If retries remain, one is consumed, the send offset is rewound, and write
 * interest is requested so on_io_write() resends the packet.  Otherwise, or
 * if re-arming the remote fails, the outcome (ETIMEDOUT or the
 * remote_add_flags error) is reported through request_finish().
 *
 * The retry counter is unsigned and is decremented only while nonzero.
 * Decrementing it unconditionally would wrap it to SIZE_MAX after the last
 * attempt.  A request whose timer is re-armed later (e.g. by remote_shutdown)
 * would then be resent practically without limit.
 */
static void
on_timeout(verto_ctx *ctx, verto_ev *ev)
{
    request *req = verto_get_private(ev);
    krb5_error_code retval = ETIMEDOUT;

    /* The one-shot timer (see request_start_timer) has fired; forget it. */
    req->timer = NULL;

    if (req->retries > 0) {
        req->retries--;
        req->sent = 0;
        retval = remote_add_flags(req->rr, FLAGS_WRITE);
        if (retval == 0)
            return;
    }

    request_finish(req, retval, NULL);
}
/*
 * Socket is writable: push the first request whose packet is not yet fully
 * written.  At most one sendto() is issued per call.
 *
 * - If every queued packet has been written, write interest is dropped.
 * - Transient send errors (EWOULDBLOCK, EAGAIN, ENOBUFS, EINTR) wait for the
 *   next event; any other error shuts the remote down.
 * - Once a packet is complete, its response timer is armed (the request is
 *   finished with ENOMEM if that fails) and read interest is added (the
 *   remote is shut down if that fails).
 */
static void
on_io_write(krad_remote *rr)
{
    const krb5_data *pkt = NULL;
    ssize_t n;
    request *r;

    K5_TAILQ_FOREACH(r, &rr->list, list) {
        pkt = krad_packet_encode(r->request);
        if (r->sent != pkt->length)
            break;
    }

    /* Nothing left to write. */
    if (r == NULL) {
        remote_del_flags(rr, FLAGS_WRITE);
        return;
    }

    n = sendto(verto_get_fd(rr->io), pkt->data + r->sent,
               pkt->length - r->sent, 0, NULL, 0);
    if (n < 0) {
        if (errno != EWOULDBLOCK && errno != EAGAIN && errno != ENOBUFS &&
            errno != EINTR)
            remote_shutdown(rr);
        return;
    }

    r->sent += n;
    if (r->sent != pkt->length)
        return;

    /* Packet fully written: wait for its response. */
    if (request_start_timer(r, rr->vctx) != 0) {
        request_finish(r, ENOMEM, NULL);
        return;
    }
    if (remote_add_flags(rr, FLAGS_READ) != 0)
        remote_shutdown(rr);
}
/*
 * Socket is readable: receive (part of) a response into rr->buffer.  Once a
 * whole packet is present, decode it against the outstanding requests and
 * complete the matching request, provided that request was fully sent.
 *
 * Stream sockets are read incrementally, only as many bytes as
 * krad_packet_bytes_needed() reports.  A negative count on a stream means
 * the packet framing cannot be trusted, whether detected before or after
 * the read.  The connection is then shut down rather than decoding garbage
 * and continuing out of sync.  Datagram sockets are read in a single recv().
 *
 * Whenever the connection is shut down here, the partial buffer is discarded
 * first, so a reconnect does not append new data to a stale packet.
 * Transient recv errors (EWOULDBLOCK, EAGAIN, EINTR) just wait for the next
 * event.  Responses that fail to decode are dropped.
 */
static void
on_io_read(krad_remote *rr)
{
    const krad_packet *req = NULL;
    krad_packet *rsp = NULL;
    krb5_error_code retval;
    ssize_t pktlen;
    request *tmp, *r;
    int i;
    krb5_boolean stream = (rr->info->ai_socktype == SOCK_STREAM);

    pktlen = sizeof(rr->buffer_) - rr->buffer.length;
    if (stream) {
        pktlen = krad_packet_bytes_needed(&rr->buffer);
        if (pktlen < 0) {
            /* Malformed packet on a stream: assume it is unrecoverable. */
            rr->buffer.length = 0;
            remote_shutdown(rr);
            return;
        }
    }

    i = recv(verto_get_fd(rr->io), rr->buffer.data + rr->buffer.length,
             pktlen, 0);

    /* Transient errors: try again on the next read event. */
    if (i < 0 && (errno == EWOULDBLOCK || errno == EAGAIN || errno == EINTR))
        return;

    /* Any other error, or EOF: the connection is unusable. */
    if (i <= 0) {
        rr->buffer.length = 0;
        remote_shutdown(rr);
        return;
    }

    rr->buffer.length += i;

    if (stream) {
        pktlen = krad_packet_bytes_needed(&rr->buffer);

        /* Partial packet: keep buffering. */
        if (pktlen > 0)
            return;

        /* The header just completed is malformed; framing is lost. */
        if (pktlen < 0) {
            rr->buffer.length = 0;
            remote_shutdown(rr);
            return;
        }
    }

    tmp = K5_TAILQ_FIRST(&rr->list);
    retval = krad_packet_decode_response(rr->kctx, rr->secret, &rr->buffer,
                                         iterator, &tmp, &req, &rsp);
    rr->buffer.length = 0;
    if (retval != 0)
        return;

    /* Complete the request this response answers, if it was fully sent. */
    if (req != NULL) {
        K5_TAILQ_FOREACH(r, &rr->list, list) {
            if (r->request == req &&
                r->sent == krad_packet_encode(req)->length) {
                request_finish(r, 0, rsp);
                break;
            }
        }
    }

    krad_packet_free(rsp);
}
/*
 * libverto I/O callback for a remote's socket.  The remote is the event's
 * private data.  Writability takes precedence: a write-ready event is handed
 * to on_io_write() only, and any other event to on_io_read().
 */
static void
on_io(verto_ctx *ctx, verto_ev *ev)
{
    krad_remote *rr = verto_get_private(ev);
    verto_ev_flag state = verto_get_fd_state(ev);

    if (!(state & VERTO_EV_FLAG_IO_WRITE)) {
        on_io_read(rr);
        return;
    }

    on_io_write(rr);
}
/*
 * Create a remote for the server address in info (only the first entry is
 * used) with the given shared secret.  The secret, the addrinfo and its
 * ai_addr are all copied; ai_next and ai_canonname are cleared in the copy.
 * No socket is opened here.
 *
 * On success stores the new remote in *rr and returns 0.  On failure returns
 * the allocation error (ENOMEM) and leaves *rr untouched.
 */
krb5_error_code
kr_remote_new(krb5_context kctx, verto_ctx *vctx, const struct addrinfo *info,
              const char *secret, krad_remote **rr)
{
    krb5_error_code retval = ENOMEM;
    krad_remote *tmp;

    tmp = calloc(1, sizeof(*tmp));
    if (tmp == NULL)
        goto error;
    tmp->kctx = kctx;
    tmp->vctx = vctx;
    tmp->buffer = make_data(tmp->buffer_, 0);
    K5_TAILQ_INIT(&tmp->list);
    tmp->fd = -1;

    tmp->secret = strdup(secret);
    if (tmp->secret == NULL)
        goto error;

    tmp->info = k5memdup(info, sizeof(*info), &retval);
    if (tmp->info == NULL)
        goto error;

    /* The copy must not reference the caller's list or name. */
    tmp->info->ai_next = NULL;
    tmp->info->ai_canonname = NULL;

    /* Check the duplicated address itself, not the (already valid) info:
     * otherwise an allocation failure here would yield a remote with a NULL
     * ai_addr and a success return. */
    tmp->info->ai_addr = k5memdup(info->ai_addr, info->ai_addrlen, &retval);
    if (tmp->info->ai_addr == NULL)
        goto error;

    *rr = tmp;
    return 0;

error:
    kr_remote_free(tmp);
    return retval;
}
/*
 * Finish every request queued on rr with ECANCELED, always taking the head
 * of the list until it is empty.
 */
void
kr_remote_cancel_all(krad_remote *rr)
{
    request *head;

    for (head = K5_TAILQ_FIRST(&rr->list); head != NULL;
         head = K5_TAILQ_FIRST(&rr->list))
        request_finish(head, ECANCELED, NULL);
}
/*
 * Destroy a remote.  Outstanding requests are cancelled (their callbacks see
 * ECANCELED), the copied secret and address are freed, and the socket and
 * I/O event are released.  A NULL remote is ignored.
 */
void
kr_remote_free(krad_remote *rr)
{
    struct addrinfo *ai;

    if (rr == NULL)
        return;

    kr_remote_cancel_all(rr);
    free(rr->secret);

    ai = rr->info;
    if (ai != NULL)
        free(ai->ai_addr);
    free(ai);

    remote_disconnect(rr);
    free(rr);
}
/*
 * Build a request packet with the given code and attributes and queue it on
 * rr.  cb(retval, request, response, data) is invoked when the request
 * completes, fails, is cancelled, or times out.
 *
 * timeout is the total time budget, divided evenly between the first attempt
 * and each of the retries.  Retries are disabled for stream sockets.
 *
 * On success the packet (owned by the remote) is optionally returned via
 * *pkt and 0 is returned.  On failure nothing is queued and the error is
 * returned.
 */
krb5_error_code
kr_remote_send(krad_remote *rr, krad_code code, krad_attrset *attrs,
               krad_cb cb, void *data, int timeout, size_t retries,
               const krad_packet **pkt)
{
    krad_packet *tmp = NULL;
    krb5_error_code retval;
    request *r, *new_request = NULL;
    size_t attempts;

    if (rr->info->ai_socktype == SOCK_STREAM)
        retries = 0;

    /* The iterator lets packet creation see the already-queued packets. */
    r = K5_TAILQ_FIRST(&rr->list);
    retval = krad_packet_new_request(rr->kctx, rr->secret, code, attrs,
                                     iterator, &r, &tmp);
    if (retval != 0)
        goto error;

    K5_TAILQ_FOREACH(r, &rr->list, list) {
        if (r->request == tmp) {
            retval = EALREADY;
            goto error;
        }
    }

    /*
     * Per-attempt timeout = total / (retries + 1).  retries + 1 wraps to 0
     * when retries == SIZE_MAX; that many attempts leave nothing per attempt,
     * so use 0 instead of dividing by zero.  Only positive timeouts are
     * divided, so a negative int is never converted to size_t.
     */
    attempts = retries + 1;
    if (timeout > 0)
        timeout = (attempts == 0) ? 0 : (int)((size_t)timeout / attempts);

    retval = request_new(rr, tmp, timeout, retries, cb, data, &new_request);
    if (retval != 0)
        goto error;

    retval = remote_add_flags(rr, FLAGS_WRITE);
    if (retval != 0)
        goto error;

    K5_TAILQ_INSERT_TAIL(&rr->list, new_request, list);
    if (pkt != NULL)
        *pkt = tmp;
    return 0;

error:
    free(new_request);
    krad_packet_free(tmp);
    return retval;
}
/*
 * Cancel the queued request whose packet is pkt, finishing it with
 * ECANCELED.  Only the first match is cancelled; unknown packets are
 * ignored.
 */
void
kr_remote_cancel(krad_remote *rr, const krad_packet *pkt)
{
    request *cur = K5_TAILQ_FIRST(&rr->list);

    while (cur != NULL && cur->request != pkt)
        cur = K5_TAILQ_NEXT(cur, list);

    if (cur != NULL)
        request_finish(cur, ECANCELED, NULL);
}
/*
 * Decide whether rr already represents the server described by info and
 * secret.
 *
 * The secret, address length, family, socket type, protocol and flags must
 * all match.  The addresses must then be byte-identical.  For AF_UNIX
 * addresses, which may differ in bytes that are not part of the path, the
 * sun_path strings are compared instead.
 */
krb5_boolean
kr_remote_equals(const krad_remote *rr, const struct addrinfo *info,
                 const char *secret)
{
    const struct addrinfo *mine = rr->info;
    struct sockaddr_un *a, *b;

    if (strcmp(rr->secret, secret) != 0 ||
        info->ai_addrlen != mine->ai_addrlen ||
        info->ai_family != mine->ai_family ||
        info->ai_socktype != mine->ai_socktype ||
        info->ai_protocol != mine->ai_protocol ||
        info->ai_flags != mine->ai_flags)
        return FALSE;

    if (memcmp(mine->ai_addr, info->ai_addr, info->ai_addrlen) == 0)
        return TRUE;

    if (info->ai_family != AF_UNIX)
        return FALSE;

    a = (struct sockaddr_un *)info->ai_addr;
    b = (struct sockaddr_un *)mine->ai_addr;
    if (strncmp(a->sun_path, b->sun_path, sizeof(a->sun_path)) != 0)
        return FALSE;

    return TRUE;
}