/*
 * xbmsp.c	XBMSP 1.0 library
 *
 * Version:	$Id: $
 *
 * Copyright:	(C)2007 Miquel van Smoorenburg <miquels@cistron.nl>
 *
 *		This program is free software; you can redistribute it and/or
 *		modify it under the terms of the GNU General Public License
 *		as published by the Free Software Foundation; either version
 *		2 of the License, or (at your option) any later version.
 */
#include <sys/types.h>
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <errno.h>

#include "xbmsp.h"

/*
 * Known packet types: numeric type, printable name (used by
 * xbmsp_dumppacket()) and attribute signature.  Each signature character
 * describes one payload field, in order, as encoded by
 * xbmsp_buildpacket() and decoded by decodeattrs():
 *
 *   'b'  one byte
 *   'i'  32-bit big-endian integer
 *   'l'  64-bit big-endian integer (two 32-bit halves, high half first)
 *   's'  string: 32-bit big-endian length followed by that many bytes
 *   'd'  raw data occupying the rest of the packet (no length prefix)
 *
 * decodeattrs() stops its lookup at the XBMSP_PACKET_UNKNOWN entry, so
 * that entry doubles as the result for unrecognised types; the final
 * { -1, NULL } row terminates the name-based scan in xbmsp_buildpacket().
 */
struct xbmsp_type xbmsp_types[] = {
   {  XBMSP_PACKET_OK,			"OK",			""	},
   {  XBMSP_PACKET_ERROR,		"Error",		"bs"	},
   {  XBMSP_PACKET_HANDLE,		"Handle",		"i"	},
   {  XBMSP_PACKET_FILE_DATA,		"File Data",		"ss"	},
   {  XBMSP_PACKET_FILE_CONTENTS,	"File Contents",	"s"	},
   {  XBMSP_PACKET_AUTHENTICATION_CONTINUE,	"Auth Continue", "d"	},
   {  XBMSP_PACKET_NULL,		"NULL",			"d"	},
   {  XBMSP_PACKET_SETCWD,		"Setcwd",		"s"	},
   {  XBMSP_PACKET_FILELIST_OPEN,	"Filelist Open",	""	},
   {  XBMSP_PACKET_FILELIST_READ,	"Filelist Read",	"i"	},
   {  XBMSP_PACKET_FILE_INFO,		"File Info",		"s"	},
   {  XBMSP_PACKET_FILE_OPEN,		"File Open",		"s"	},
   {  XBMSP_PACKET_FILE_READ,		"File Read",		"ii"	},
   {  XBMSP_PACKET_FILE_SEEK,		"File Seek",		"ibl"	},
   {  XBMSP_PACKET_CLOSE,		"Close",		"i"	},
   {  XBMSP_PACKET_CLOSE_ALL,		"Close All",		""	},
   {  XBMSP_PACKET_SET_CONFIGURATION_OPTION,	"Set Config Option", "ss" },
   {  XBMSP_PACKET_AUTHENTICATION_INIT,	"Auth Init",		"s"	},
   {  XBMSP_PACKET_AUTHENTICATE,	"Authenticate",		"id"	},
   {  XBMSP_PACKET_UPCWD,		"Up CWD",		"i"	},
   {  XBMSP_PACKET_SERVER_DISCOVERY_QUERY,	"Discover query", "s"	},
   {  XBMSP_PACKET_SERVER_DISCOVERY_REPLY,	"Discover reply", "ssss" },
   {  XBMSP_PACKET_UNKNOWN,		"Unknown type",		"i"	},
   {  -1,				NULL,			""	},
};

/*
 * Messages for the library's own XBMSP_E_* error codes, looked up by
 * xbmsp_strerror(); the { 0, NULL } row terminates the table.  Codes
 * not listed here are treated by xbmsp_strerror() as negated errno values.
 */
struct xbmsp_err {
	int		errval;		/* XBMSP_E_* code */
	char		*errstr;	/* human-readable message */
} xbmsp_err[] = {
   {  XBMSP_E_SHORTREAD,		"Short read"		},
   {  XBMSP_E_SHORTWRITE,		"Short write"		},
   {  XBMSP_E_PKTSHORT,			"Short packet",		},
   {  XBMSP_E_PKTCORR,			"Packet corrupt",	},
   {  XBMSP_E_ATTRMISS,			"Missing attribute",	},
   {  XBMSP_E_ATTRCORR,			"Corrupt attribute",	},
   {  XBMSP_E_UNKNOWNTYPE,		"Unknown type",		},
   {  XBMSP_E_UNEXPECTEDEOF,		"Unexpected EOF",	},
   {  0,				NULL,			},
};

#ifdef __linux__
/*
 * Fallback strlcpy()/strlcat() for Linux builds.
 * NOTE(review): recent glibc (2.38+) and musl declare their own
 * strlcpy/strlcat in <string.h> (size_t return, const char * and size_t
 * parameters); with those headers these definitions may clash with the
 * system prototypes -- confirm against the target libc.
 */

/*
 * Copy the NUL-terminated string "src" into "dst", a buffer of "sz"
 * bytes.  At most sz - 1 characters are copied and dst is always
 * NUL-terminated.  Nothing at all is written when sz <= 0.
 */
void strlcpy(char *dst, char *src, int sz)
{
	int	i;

	if (sz <= 0)
		return;
	for (i = 0; i < sz - 1 && src[i] != '\0'; i++)
		dst[i] = src[i];
	dst[i] = '\0';
}

/*
 * Append "src" to the NUL-terminated string in "dst", a buffer of "sz"
 * bytes, truncating so that the result (including its NUL) fits in sz
 * bytes.  If dst has no NUL within its first sz bytes (or sz <= 0),
 * dst is left untouched.
 */
void strlcat(char *dst, char *src, int sz)
{
	/* Skip the existing string, counting down the space that is left. */
	while (sz > 0 && *dst) {
		sz--;
		dst++;
	}
	/* Copy into the remaining room; a no-op when there is none. */
	strlcpy(dst, src, sz);
}
#endif

/*
 *	Blocking-read "sz" bytes from "fd" into "buf", retrying reads that
 *	were interrupted by a signal (EINTR).
 *
 *	Returns the number of bytes read.  That is fewer than sz only if
 *	end-of-file or an error occurred after some data had arrived.  The
 *	function returns 0 on immediate end-of-file or when sz <= 0.  It
 *	returns read()'s negative result (errno set) if an error occurred
 *	before any data was read.
 */
static int xread(int fd, char *buf, int sz)
{
	/* n must start at 0: with sz <= 0 the loop never assigns it. */
	int		n = 0;
	int		done = 0;

	while (done < sz) {
		n = read(fd, buf + done, sz - done);
		if (n < 0 && errno == EINTR)
			continue;
		if (n <= 0)
			break;
		done += n;
	}

	return (done > 0) ? done : n;
}


/*
 *	Helper function to decode the attributes in an XBMSP packet.
 *
 *	Looks up pkt->type in xbmsp_types and stores the matching entry (or
 *	the XBMSP_PACKET_UNKNOWN entry when there is no match) in
 *	pkt->xbmsp_type.  Unless "justtype" is set, the payload in pkt->data
 *	(pkt->length - 5 bytes) is then parsed according to the entry's
 *	attribute signature:
 *
 *	  's'  4-byte big-endian length followed by that many bytes
 *	       (clipped to what is actually left in the packet)
 *	  'd'  raw data running to the end of the packet (may be empty)
 *	  'b'  one unsigned byte
 *	  'i'  32-bit big-endian unsigned integer
 *	  'l'  64-bit big-endian unsigned integer
 *
 *	String/data attributes point into pkt->data; nothing is copied.
 *	Returns 0 on success (also for unknown types or justtype),
 *	XBMSP_E_ATTRMISS if the payload ends before a required attribute,
 *	or XBMSP_E_ATTRCORR if a fixed-size attribute is truncated.
 */
static int decodeattrs(struct xbmsp_packet *pkt, int justtype)
{
	struct xbmsp_type		*xt;
	char				*ptr, *m;
	int				left;
	uint32_t			length, value, v2;

	pkt->nattrs = 0;

	/* The XBMSP_PACKET_UNKNOWN entry doubles as the "not found" result. */
	for (xt = xbmsp_types; xt->type != XBMSP_PACKET_UNKNOWN; xt++) {
		if (xt->type == pkt->type)
			break;
	}
	pkt->xbmsp_type = xt;

	if (justtype || xt->type == XBMSP_PACKET_UNKNOWN)
		return 0;

	/* pkt->length also counts the type byte and msgid (5 bytes). */
	ptr = pkt->data;
	left = pkt->length - 5;

	for (m = xt->attrs; *m; m++) {

		/*
		 * Every attribute except 'd' needs at least one byte; 'd' may
		 * legitimately be empty (xbmsp_buildpacket() can emit that).
		 */
		if (left <= 0 && *m != 'd')
			return XBMSP_E_ATTRMISS;

		switch (*m) {

		case 's':
			if (left < 4)
				return XBMSP_E_ATTRCORR;
			memcpy(&length, ptr, 4);
			length = ntohl(length);
			/* Clip an over-long string to the bytes actually present. */
			if (length > (uint32_t)(left - 4))
				length = left - 4;
			pkt->attrs[pkt->nattrs].type = XBMSP_ATTR_STRING;
			pkt->attrs[pkt->nattrs].length = length;
			pkt->attrs[pkt->nattrs].string = ptr + 4;
			pkt->nattrs++;
			left -= 4 + (int)length;
			ptr += 4 + length;
			break;

		case 'd':
			/* Raw data has no length prefix: it starts right at ptr. */
			if (left < 0)
				left = 0;
			pkt->attrs[pkt->nattrs].type = XBMSP_ATTR_DATA;
			pkt->attrs[pkt->nattrs].length = left;
			pkt->attrs[pkt->nattrs].string = ptr;
			pkt->nattrs++;
			ptr += left;
			left = 0;
			break;

		case 'b':
			pkt->attrs[pkt->nattrs].type = XBMSP_ATTR_BYTE;
			pkt->attrs[pkt->nattrs].length = 1;
			/* Go through unsigned char so 0x80..0xff don't sign-extend. */
			pkt->attrs[pkt->nattrs].val32 = *(unsigned char *)ptr;
			pkt->nattrs++;
			left--;
			ptr++;
			break;

		case 'i':
			if (left < 4)
				return XBMSP_E_ATTRCORR;
			memcpy(&value, ptr, 4);
			pkt->attrs[pkt->nattrs].type = XBMSP_ATTR_INT32;
			pkt->attrs[pkt->nattrs].length = 4;
			pkt->attrs[pkt->nattrs].val32 = ntohl(value);
			pkt->nattrs++;
			left -= 4;
			ptr += 4;
			break;

		case 'l':
			if (left < 8)
				return XBMSP_E_ATTRCORR;
			/* High 32 bits first, each half big-endian. */
			memcpy(&value, ptr, 4);
			memcpy(&v2, ptr + 4, 4);
			pkt->attrs[pkt->nattrs].type = XBMSP_ATTR_INT64;
			pkt->attrs[pkt->nattrs].length = 8;
			pkt->attrs[pkt->nattrs].val64 =
				(((uint64_t)ntohl(value)) << 32) + ntohl(v2);
			pkt->nattrs++;
			left -= 8;
			ptr += 8;
			break;

		default:
			/* Unknown signature letter in the table: ignore it. */
			break;
		}
	}
	return 0;
}

/*
 *	Render "ssz" bytes of "str" into "buf" (capacity "bufsz" bytes) as a
 *	quoted debug string with a leading space, e.g.  "abc" -> ` "abc"`.
 *
 *	\r, \n, \t, backslash and double quote get C-style escapes.  Other
 *	printable ASCII (32..126) is copied verbatim.  Every other byte is
 *	written as a 3-digit octal escape (\ooo).  If the output would not
 *	fit, it is cut short and "..." is added before the closing quote.
 *	The result is always NUL-terminated and never exceeds bufsz bytes.
 *	Buffers smaller than 7 bytes receive an empty string.
 */
static void esc_str(char *buf, int bufsz, char *str, int ssz)
{
	/*
	 * One input byte expands to at most 4 output bytes (\ooo).  After
	 * that there must still be room for "...", the closing quote and
	 * the NUL: 5 more bytes.
	 */
	enum { MAX_ESC = 4, TAIL = 5, MIN_BUF = 7 };
	int		c;
	int		done;

	/* ` "` + `..."` + NUL is the smallest output we may have to produce. */
	if (bufsz < MIN_BUF) {
		if (bufsz > 0)
			buf[0] = 0;
		return;
	}

	buf[0] = ' ';
	buf[1] = '"';
	buf[2] = 0;
	done = 2;

	while (ssz-- > 0) {

		if (done + MAX_ESC + TAIL > bufsz) {
			strcpy(buf + done, "...");
			done += 3;
			break;
		}

		c = *(unsigned char *)str++;
		switch (c) {
		case '\r':
			strcpy(buf + done, "\\r");
			done += 2;
			break;
		case '\n':
			strcpy(buf + done, "\\n");
			done += 2;
			break;
		case '\t':
			strcpy(buf + done, "\\t");
			done += 2;
			break;
		case '\\':
			strcpy(buf + done, "\\\\");
			done += 2;
			break;
		case '"':
			strcpy(buf + done, "\\\"");
			done += 2;
			break;
		default:
			if (c >= 32 && c < 127) {
				buf[done++] = c;
			} else {
				/* Non-printable (including DEL): octal escape. */
				buf[done++] = '\\';
				buf[done++] = '0' + (c >> 6);
				buf[done++] = '0' + ((c >> 3) & 7);
				buf[done++] = '0' + (c & 7);
			}
			buf[done] = 0;
			break;
		}
	}
	strcpy(buf + done, "\"");		/* Room reserved by TAIL */
}

/*
 *	Debug helper: decode packet to a string.
 *
 *	Writes a one-line summary of "packet" into "res" (capacity "sz"
 *	bytes).  The line holds the length, the msgid in hex and the type
 *	name (or "Type N" for unknown types), followed by each decoded
 *	attribute.  Strings are escaped via esc_str() and clipped to 200
 *	bytes for File Contents packets; data attributes are shown as
 *	"[data]".  Output is truncated to fit.  Always returns 0.
 */
int xbmsp_dumppacket(struct xbmsp_packet *packet, char *res, int sz)
{
	char			*type;
	char			val[256];
	char			typebuf[32];
	int			len;
	int			i;

	if (sz <= 0)
		return 0;

	if (packet->xbmsp_type &&
	    packet->xbmsp_type->type != XBMSP_PACKET_UNKNOWN) {
		type = packet->xbmsp_type->name;
	} else {
		snprintf(typebuf, sizeof(typebuf), "Type %d", packet->type);
		type = typebuf;
	}

	snprintf(res, sz, "%5d %08x %s",
		packet->length, packet->msgid, type);

	for (i = 0; i < packet->nattrs; i++) {
		switch (packet->attrs[i].type) {
			case XBMSP_ATTR_BYTE:
			case XBMSP_ATTR_INT32:
				snprintf(val, sizeof(val), " %u",
					packet->attrs[i].val32);
				break;
			case XBMSP_ATTR_INT64:
				/* uint64_t need not be unsigned long long. */
				snprintf(val, sizeof(val), " %llu",
					(unsigned long long)packet->attrs[i].val64);
				break;
			case XBMSP_ATTR_STRING:
				len = sizeof(val);
				if (packet->type == XBMSP_PACKET_FILE_CONTENTS)
					len = 200; /* XXX 29 is nice */
				esc_str(val, len,
					packet->attrs[i].string,
					packet->attrs[i].length);
				break;
			case XBMSP_ATTR_DATA:
				snprintf(val, sizeof(val), " [data]");
				break;
			default:
				/* Never append a stale/uninitialised val[]. */
				snprintf(val, sizeof(val), " [?]");
				break;
		}
		strlcat(res, val, sz);
	}
	return 0;
}

/*
 *	Return XBMSP error string.
 */
char *xbmsp_strerror(int err)
{
	struct xbmsp_err		*xe;

	for (xe = xbmsp_err; xe->errstr; xe++)
		if (xe->errval == err)
			return xe->errstr;

	return strerror(-err);
}

/*
 *	Build a packet.
 *
 *	Allocates a new packet of the given type and msgid.  Its attributes
 *	are filled from the variadic arguments, following the signature
 *	that xbmsp_types gives for that type:
 *
 *	  'b'  int                          one byte on the wire
 *	  'i'  uint32_t
 *	  'l'  uint64_t
 *	  's'  uint32_t length, char *bytes  sent with a 4-byte length prefix
 *	  'd'  uint32_t length, char *bytes  sent raw, without a prefix
 *
 *	The argument list is walked twice.  Pass 0 adds up the encoded
 *	size; pass 1 allocates the packet and serialises into pkt->data.
 *	String/data attributes of the result point into pkt->data.
 *
 *	On success the packet is stored in *ppkt (the caller frees it) and
 *	0 is returned.  Returns XBMSP_E_UNKNOWNTYPE for a type missing from
 *	the table, or -ENOMEM if allocation fails.  On error *ppkt is NULL.
 */
int xbmsp_buildpacket(struct xbmsp_packet **ppkt, int type, uint32_t msgid, ...)
{
	struct xbmsp_type	*xt;
	struct xbmsp_packet	*pkt;
	va_list			ap;
	char			*m;
	char			*string;
	char			*data;
	uint32_t		val32, val32_2;
	uint64_t		val64;
	int			byte;
	int			state;
	int			n;
	int			length;

	*ppkt = NULL;
	pkt = NULL;
	data = NULL;
	n = 0;

	for (xt = xbmsp_types; xt->name; xt++)
		if (xt->type == type)
			break;
	if (xt->name == NULL)
		return XBMSP_E_UNKNOWNTYPE;

	/* "length" counts type byte + msgid (5 bytes) plus the payload. */
	length = 5;

	for (state = 0; state < 2; state++) {

		if (state == 1) {
			pkt = malloc(sizeof(*pkt) + length - 5);
			if (pkt == NULL)
				return -ENOMEM;
			memset(pkt, 0, sizeof(*pkt));
			pkt->length = length;
			pkt->type = type;
			pkt->msgid = msgid;
			pkt->xbmsp_type = xt;
			data = pkt->data;
		}

		va_start(ap, msgid);
		for (m = xt->attrs; *m; m++) switch (*m) {

			case 'b':
				byte = va_arg(ap, int);
				if (state == 0) {
					length++;
				} else {
					/* One byte on the wire, as decodeattrs() records. */
					pkt->attrs[n].type = XBMSP_ATTR_BYTE;
					pkt->attrs[n].length = 1;
					pkt->attrs[n++].val32 = byte;
					*data++ = (char)byte;
				}
				break;
			case 'i':
				val32 = va_arg(ap, uint32_t);
				if (state == 0) {
					length += 4;
				} else {
					pkt->attrs[n].type = XBMSP_ATTR_INT32;
					pkt->attrs[n].length = 4;
					pkt->attrs[n++].val32 = val32;
					val32_2 = htonl(val32);
					memcpy(data, &val32_2, 4);
					data += 4;
				}
				break;
			case 'l':
				val64 = va_arg(ap, uint64_t);
				if (state == 0) {
					length += 8;
				} else {
					/* High half first, each half big-endian. */
					pkt->attrs[n].type = XBMSP_ATTR_INT64;
					pkt->attrs[n].length = 8;
					pkt->attrs[n++].val64 = val64;
					val32 = htonl(val64 >> 32);
					val32_2 = htonl(val64 & 0xFFFFFFFF);
					memcpy(data, &val32, 4);
					memcpy(data + 4, &val32_2, 4);
					data += 8;
				}
				break;
			case 's':
				val32 = va_arg(ap, uint32_t);
				string = va_arg(ap, char *);
				if (state == 0) {
					length += val32 + 4;
				} else {
					pkt->attrs[n].type = XBMSP_ATTR_STRING;
					pkt->attrs[n].length = val32;
					pkt->attrs[n++].string = data + 4;
					val32_2 = htonl(val32);
					memcpy(data, &val32_2, 4);
					memcpy(data + 4, string, val32);
					data += val32 + 4;
				}
				break;
			case 'd':
				val32 = va_arg(ap, uint32_t);
				string = va_arg(ap, char *);
				if (state == 0) {
					length += val32;
				} else {
					pkt->attrs[n].type = XBMSP_ATTR_DATA;
					pkt->attrs[n].length = val32;
					pkt->attrs[n++].string = data;
					memcpy(data, string, val32);
					data += val32;
				}
				break;
		}
		va_end(ap);
	}

	pkt->nattrs = n;
	*ppkt = pkt;
	return 0;
}

struct xbmsp_packet *xbmsp_clonepacket(struct xbmsp_packet *pkt)
{
	struct xbmsp_packet	*cpkt;
	int			l;

	l = sizeof(struct xbmsp_packet) + 4 + pkt->length;
	cpkt = malloc(l);
	memcpy(cpkt, pkt, l);

	return cpkt;
}

/*
 *	Receive one packet in blocking mode.
 *
 *	With addr == NULL, fd is read as a stream.  The 9-byte header
 *	(4-byte big-endian length, 1-byte type, 4-byte big-endian msgid)
 *	is read first, then length - 5 payload bytes.  With addr non-NULL,
 *	a single datagram of at most 512 bytes is received with recvfrom(),
 *	and the sender's address is stored in addr and *addrlen.
 *
 *	Returns:
 *	  0                  packet received and decoded (*pkt set), or EOF /
 *	                     empty datagram before any byte (*pkt NULL)
 *	  -errno             read/recvfrom error, or -ENOMEM
 *	  XBMSP_E_SHORTREAD  header or payload truncated
 *	  XBMSP_E_PKTCORR    length field out of range, or datagram shorter
 *	                     than its length field claims
 *	  XBMSP_E_ATTRMISS / XBMSP_E_ATTRCORR from attribute decoding
 *
 *	Once at least the type byte has arrived, *pkt is set even on error
 *	(type looked up, no attributes) so the caller can report it.  The
 *	caller must free *pkt whenever it is non-NULL.
 */
int xbmsp_recvpacketfrom(int fd, struct xbmsp_packet **pkt,
			struct sockaddr *addr, int *addrlen)
{
	char			head[512];
	socklen_t		fromlen;
	uint32_t		length;
	uint8_t			type;
	uint32_t		msgid;
	size_t			datasz;
	int			e, n;
	int			isudp;

	*pkt = NULL;
	isudp = addr ? 1 : 0;

	/* Read header data */
	if (isudp) {
		fromlen = *addrlen;
		n = recvfrom(fd, head, sizeof(head), 0, addr, &fromlen);
		*addrlen = fromlen;
	} else {
		n = xread(fd, head, 9);
	}

	if (n < 0)
		return -errno;
	if (n == 0)
		return 0;
	if (n < 9) {
		if (n >= 5) {
			/* At least set the type: header byte 4. */
			*pkt = malloc(sizeof(struct xbmsp_packet));
			if (*pkt == NULL)
				return -ENOMEM;
			memset(*pkt, 0, sizeof(struct xbmsp_packet));
			(*pkt)->type = (uint8_t)head[4];
			decodeattrs(*pkt, 1);
		}
		return XBMSP_E_SHORTREAD;
	}

	/* Fill packet struct */
	memcpy(&length, head, 4);
	memcpy(&type, head + 4, 1);
	memcpy(&msgid, head + 5, 4);
	length = ntohl(length);
	msgid = ntohl(msgid);

	/*
	 * Do not let an absurd length field (from a corrupt or hostile peer)
	 * drive a multi-gigabyte allocation.  Such packets are rejected
	 * below anyway, so they get a header only.
	 */
	datasz = (length <= 2097152) ? (size_t)length + 4 : 0;
	*pkt = malloc(sizeof(struct xbmsp_packet) + datasz);
	if (*pkt == NULL)
		return -ENOMEM;
	memset(*pkt, 0, sizeof(struct xbmsp_packet));
	(*pkt)->length = length;
	(*pkt)->type = type;
	(*pkt)->msgid = msgid;

	if (length < 5 || length > 2097152) {
		decodeattrs(*pkt, 1);
		return XBMSP_E_PKTCORR;
	}

	/* A datagram must hold the whole packet: 4-byte length + length. */
	if (isudp && length + 4 > (uint32_t)n) {
		decodeattrs(*pkt, 1);
		return XBMSP_E_PKTCORR;
	}

	/* Read the data: the payload is length - 5 bytes after the header. */
	if (isudp) {
		memcpy((*pkt)->data, head + 9, length - 5);
	} else if (length > 5) {
		n = xread(fd, (*pkt)->data, (int)(length - 5));
		if (n != (int)(length - 5)) {
			/* Capture errno before anything else can change it. */
			e = (n < 0) ? -errno : XBMSP_E_SHORTREAD;
			decodeattrs(*pkt, 1);
			return e;
		}
	}

	/* Decode the data */
	return decodeattrs(*pkt, 0);
}

/*
 *	Receive one packet from a stream socket (blocking).  This is
 *	xbmsp_recvpacketfrom() without a peer address.
 */
int xbmsp_recvpacket(int fd, struct xbmsp_packet **pkt)
{
	struct sockaddr	*peer = NULL;
	int		*peerlen = NULL;

	return xbmsp_recvpacketfrom(fd, pkt, peer, peerlen);
}

/*
 *	Send one packet. Blocking.
 *
 *	Writes the 9-byte header (big-endian length, type byte, big-endian
 *	msgid) and the packet->length - 5 payload bytes with one sendmsg()
 *	call.  When addr is non-NULL (datagram sockets) the packet goes to
 *	addr/addrlen.  Calls interrupted by a signal are retried, as in
 *	xread().
 *
 *	Returns the number of bytes sent (packet->length + 4), -errno on
 *	failure, -EINVAL if packet->length is below the 5-byte minimum,
 *	or XBMSP_E_SHORTWRITE if fewer bytes went out.
 */
int xbmsp_sendpacketto(int fd, struct xbmsp_packet *packet,
			struct sockaddr *addr, int addrlen)
{
	struct iovec	vec[2];
	struct msghdr	msgh;
	char		head[9];
	uint32_t	length;
	uint32_t	msgid;
	int		n;

	/* Type byte + msgid are always present; below that iov_len wraps. */
	if (packet->length < 5)
		return -EINVAL;

	length = htonl(packet->length);
	msgid = htonl(packet->msgid);
	memcpy(head, &length, 4);
	head[4] = packet->type;
	memcpy(head + 5, &msgid, 4);

	vec[0].iov_base = head;
	vec[0].iov_len = 9;
	vec[1].iov_base = packet->data;
	vec[1].iov_len = packet->length - 5;

	/* Zero everything, including any implementation padding fields. */
	memset(&msgh, 0, sizeof(msgh));
	msgh.msg_name = addr;
	msgh.msg_namelen = addrlen;
	msgh.msg_iov = vec;
	msgh.msg_iovlen = 2;

	do {
		n = sendmsg(fd, &msgh, 0);
	} while (n < 0 && errno == EINTR);

	if (n < 0)
		return -errno;
	if (n != packet->length + 4)
		return XBMSP_E_SHORTWRITE;

	return n;
}

/*
 *	Send one packet on a connected socket (blocking).  This is
 *	xbmsp_sendpacketto() without a destination address.
 */
int xbmsp_sendpacket(int fd, struct xbmsp_packet *packet)
{
	struct sockaddr	*dest = NULL;

	return xbmsp_sendpacketto(fd, packet, dest, 0);
}


/*
 *	Read ident banner from remote host up to and including the first
 *	'\n'.  It reads one byte at a time, so no data following the
 *	newline is consumed.
 *
 *	At most "sz" bytes are stored in "ident", which is NOT
 *	NUL-terminated.  Returns the number of bytes stored.  Returns
 *	-errno on a read error, or XBMSP_E_UNEXPECTEDEOF if the peer
 *	closed the connection first.  Reads interrupted by a signal
 *	(EINTR) are retried.
 */
int xbmsp_readident(int fd, char *ident, int sz)
{
	int			n;
	int			done = 0;
	int			e = 0;

	while (done < sz) {
		n = read(fd, ident + done, 1);
		if (n == 1) {
			if (ident[done++] == '\n')
				break;
			continue;
		}
		if (n < 0 && errno == EINTR)
			continue;
		if (n == 0)
			e = XBMSP_E_UNEXPECTEDEOF;
		if (n < 0)
			e = -errno;
		break;
	}

	return e ? e : done;
}

/*
 *	Write ident banner to remote host.
 *
 *	Writes "sz" bytes of "ident", or strlen(ident) bytes when sz is 0.
 *	Partial writes are continued and EINTR is retried, so a blocking
 *	socket either receives the whole banner or an error is reported.
 *	Returns the number of bytes written, -errno on a write error, or
 *	XBMSP_E_SHORTWRITE if write() makes no progress.
 */
int xbmsp_sendident(int fd, char *ident, int sz)
{
	int			n;
	int			done = 0;

	if (sz == 0) sz = strlen(ident);

	while (done < sz) {
		n = write(fd, ident + done, sz - done);
		if (n < 0 && errno == EINTR)
			continue;
		if (n < 0)
			return -errno;
		if (n == 0)
			return XBMSP_E_SHORTWRITE;
		done += n;
	}

	return done;
}

/*
 *	Free every packet on the singly linked list headed by *pktlist
 *	(linked through ->next), then reset the head to NULL.
 */
void xbmsp_freepackets(struct xbmsp_packet **pktlist)
{
	struct xbmsp_packet	*cur = *pktlist;

	while (cur != NULL) {
		struct xbmsp_packet	*after = cur->next;

		free(cur);
		cur = after;
	}
	*pktlist = NULL;
}

