#include <sys/types.h>
#include <sys/time.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <stdlib.h>
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include "errlib.h"
#include "xmalloc.h"
#include "tcplib.h"
#include "libc.h"

/*
 *   eutl - A collection of useful libraries
 *   tcplib - A set of routines making tcp stuff easier.
 *
 *   (c) Copyright 1993 Eric Anderson 
 *
 * My thanks to Geoffrey Collyer and Henry Spencer for providing the basis
 * for this copyright.
 *
 * Permission is granted to anyone to use this software for any purpose on
 * any computer system, and to alter it and redistribute it freely, subject
 * to the following restrictions:
 *
 * 1. The authors are not responsible for the consequences of use of this
 *    software, no matter how awful, even if they arise from flaws in it.
 *
 * 2. The origin of this software must not be misrepresented, either by
 *    explicit claim or by omission.  Since few users ever read sources,
 *    credits must appear in the documentation.
 *
 * 3. Altered versions must be plainly marked as such, and must not be
 *    misrepresented as being the original software.  Since few users
 *    ever read sources, credits must appear in the documentation.
 *
 * 4. This notice may not be removed or altered.
 */

/* 
   System Configuration Options.
   Platform-specific overrides come first.  Each option is described
   where its default value is set, in the Default Section below.
*/

#ifdef sun
#define STATIC_XADDR 1
#endif

#ifdef sgi
#define BUFFER_READ 0
#define USE_IOVEC 0
#endif

/* 
   Default Section
*/

#ifndef STATIC_XADDR
/*
   STATIC_XADDR: when 1, AcceptConnection keeps its struct sockaddr_in and
   its address-length variable in static storage.  On the Sun, accept()
   was observed to fail with EFAULT otherwise; see the comment in
   AcceptConnection.
   NOTE(review): accept() reads *addrlen as the size of the address buffer,
   so a length passed in uninitialized would explain that EFAULT (static
   storage zero-initializes it).  Confirm before relying on this option.
*/
#define STATIC_XADDR 0
#endif

#ifndef BUFFER_READ
/*
   BUFFER_READ: when 1, each socket read asks for a whole buffer's worth of
   data, keeps the surplus in a per-socket buffer, and serves later reads
   from it.  It is disabled on the SGI because there the system seemed not
   to return the data already waiting on the socket immediately, pausing
   for a while instead, which made pre-reading a performance loss.
*/
#define BUFFER_READ 1
#endif

#ifndef USE_IOVEC
/* 
   USE_IOVEC: when 1, use writev()/readv() rather than a sequence of plain
   write()/read() calls.  If a system implements iovec no more cleverly
   than a loop of reads or writes, we are better off doing the loop
   ourselves because the behavior is more likely to be as expected.
   Defaults to HAS_IOVEC, which is expected to come from one of the project
   headers (an undefined HAS_IOVEC evaluates as 0 in #if).
*/
#define USE_IOVEC HAS_IOVEC
#endif

/*
   END of Default Section
*/

/* Initial autoflushlen of every new TcpSocket (see SendMsg). */
#define DEFAULT_AUTOFLUSHLEN 100000

/*
 * State behind a TcpSocket handle (TcpSocket is presumably declared in
 * tcplib.h as a pointer to this struct -- confirm).  initTcpSocket creates
 * it via xbzmalloc and sets only autoflushlen explicitly.
 */
struct __TcpSocket {
  int fd;			/* descriptor; not closed on release if isforeign */
  /* rcvbuf:   read-ahead storage filled by eread
     sendbuf:  output accumulated by SendMsg while bufferout is set
     sizedbuf: payload buffer returned by GetSizedMsg, reused per call */
  char *rcvbuf,*sendbuf,*sizedbuf;
  /* Allocated sizes of rcvbuf, sendbuf and sizedbuf; rcvbufvalid is the
     number of unread bytes in rcvbuf starting at rcvbufpos. */
  long rcvbufsize,sendbufsize,rcvbufvalid,sizedbufsize;
  /* rcvbufpos:  next unread byte in rcvbuf
     sendbufpos: next free byte in sendbuf */
  char *rcvbufpos,*sendbufpos;
  long autoflushlen; /* Length at which to automatically flush output.
			-- Prevents massively filling the output buffer,
			causing the system to allocate lots of space */
  struct {
    unsigned int hasinput:1;	/* input available (set by WaitForInput) */
    unsigned int haserror:1;	/* exceptional condition or read error seen */
    unsigned int isforeign:1;	/* fd supplied by caller; never closed here */
    unsigned int isserver:1;	/* listening socket made by BecomeServer */
    unsigned int bufferout:1;	/* SendMsg buffers output instead of writing */
    unsigned int incrsize:1;	/* grow rcvbuf by RCVINCR before next refill */
    unsigned int checkwrite:1;	/* WaitForInput also selects for writing */
    unsigned int readywrite:1;	/* select reported the socket writable */
  } flags;
};

/* Initial size of each socket's read-ahead buffer, and the amount it is
   enlarged by after a read fills it completely (see eread). */
const int RCVINIT = 8192;
const int RCVINCR = 8192;

/* Package identifier passed as the first argument of every error report. */
char *tcplib_packagever = "TcpLib V2.0";
/* Result of getprotobyname("tcp"), refreshed by every GetSocket call. */
static struct protoent *protox;
/* Error reporting hook called on every failure in this file.  The default,
   LongJmpErrorFunction from errlib, presumably longjmps to the nearest
   WITH_HANDLING block (confirm in errlib.h).  Replace it with
   tcplib_SetDefaultErrorFunction. */
static ErrorFunction TcpLibErf = LongJmpErrorFunction;

/* Error identifiers passed as the second argument to TcpLibErf, presumably
   so a handler can tell which failure occurred. */
char *tcplib_EProtoByName = "Error getting protocol by name";
char *tcplib_ESocket = "Error on socket system call";
char *tcplib_ESockOpt = "Error setting socket option";
char *tcplib_EBindFailed = "Bind system call failed";
char *tcplib_EListenFailed = "Listen system call failed";
char *tcplib_ELookupFailed = "Lookup of hostname failed";
char *tcplib_ECloseErr = "Error on close call";
char *tcplib_EConnectFailed = "Connect call failed";
char *tcplib_EInvalidArg = "Invalid Argument passed to function";
char *tcplib_EAcceptBad = "Error on accept system call";
char *tcplib_EWriteFailed = "Error on write system call";
char *tcplib_EWriteTooMuch = "Internal Error: write() wrote too much";
char *tcplib_EReadError = "Error on read system call";
char *tcplib_EBadWait = "Invalid arguments to Wait";
char *tcplib_ESelectError = "Error on select system call";
char *tcplib_ENothingRead = "Couldn't read data, socket probably closed";

/*
 * Install the error handler used by every routine in this library.
 * Passing NULL restores the default handler (LongJmpErrorFunction) rather
 * than leaving a null function pointer to be called on the next error.
 */
void tcplib_SetDefaultErrorFunction(ErrorFunction erf)
{
  TcpLibErf = (erf != NULL) ? erf : LongJmpErrorFunction;
}

static TcpSocket initTcpSocket()
{
  TcpSocket ret;

  ret = xbzmalloc(sizeof(struct __TcpSocket));
  ret->autoflushlen = DEFAULT_AUTOFLUSHLEN;
  return ret;
}

/* Return the operating-system descriptor wrapped by the TcpSocket `of`. */
int GetTcpSocketFD(TcpSocket of)
{
  int descriptor = of->fd;

  return descriptor;
}

#ifdef SO_LINGER
/* l_onoff == 0: lingering disabled (the default close() behaviour). */
static struct linger lingeroff={0,0};
#endif

/*
 * Create a TCP stream socket, enabling SO_REUSEADDR and disabling
 * SO_LINGER where the platform defines them.
 *
 * Returns the new descriptor.  Failures are reported through TcpLibErf.
 * A descriptor that was already created is closed first, so it does not
 * leak when the error function longjmps.  If the error function returns
 * instead, -1 is returned.
 */
static int GetSocket()
{
  int fd;

  protox = getprotobyname("tcp");
  if (protox==NULL) {
    TcpLibErf(tcplib_packagever,tcplib_EProtoByName,
	      "Error getting protocol by name\n");
    return -1;			/* never dereference a NULL protox */
  }

  fd = socket(PF_INET,SOCK_STREAM,protox->p_proto);
  /* 0 is a valid descriptor; only a negative result signals failure. */
  if (fd<0) {
    TcpLibErf(tcplib_packagever,tcplib_ESocket,"Error on socket command: %s\n",
	      strerror(errno));
    return -1;
  }

#ifdef SO_REUSEADDR
  {
    int one = 1;
    if (setsockopt(fd,SOL_SOCKET,SO_REUSEADDR,(char *)&one,sizeof(one))) {
      int saved = errno;	/* close() may overwrite errno */
      close(fd);
      TcpLibErf(tcplib_packagever,tcplib_ESockOpt,
		"Error setting ReuseAddr Socket option: %s\n",strerror(saved));
      return -1;
    }
  }
#endif

#ifdef SO_LINGER
  if (setsockopt(fd,SOL_SOCKET,SO_LINGER,
		 (char *)&lingeroff,sizeof(lingeroff))) {
    int saved = errno;
    close(fd);
    TcpLibErf(tcplib_packagever,tcplib_ESockOpt,
	      "Error setting Linger Socket option: %s\n",strerror(saved));
    return -1;
  }
#endif
  return fd;
}

/*
 * Allocate a TcpSocket and give it a fresh TCP descriptor from GetSocket.
 * If GetSocket raises an error (WITH_HANDLING/HANDLE/RERAISE come from
 * errlib.h and are presumably setjmp-based -- confirm), the half-built
 * TcpSocket is freed and the error is re-raised to the caller.
 */
static TcpSocket newSocket()
{
  TcpSocket ret;

  ret = initTcpSocket();
  WITH_HANDLING {
    ret->fd = GetSocket();
  } HANDLE {
    /* Don't leak the wrapper when socket creation fails. */
    free(ret);
    RERAISE();
  }
  END_HANDLING
  return ret;
}

/*
 * Create a listening TCP socket bound to `port` on all local interfaces
 * (INADDR_ANY), with a listen backlog of 5.
 *
 * Returns the server socket to pass to AcceptConnection.  If bind or
 * listen fails, the descriptor and wrapper are released before the error
 * is reported through TcpLibErf.  NULL is returned if the error function
 * returns.
 */
TcpSocket BecomeServer(int port)
{
  TcpSocket ret;
  struct sockaddr_in insock;
  int saved;

  ret = newSocket();
  ret->flags.isserver = 1;

  /* Zero the whole address so sin_zero (and sin_len, where it exists) is
     not left holding stack garbage. */
  memset(&insock, 0, sizeof(insock));
  insock.sin_family = AF_INET;
  insock.sin_port = htons((unsigned short)port);
  insock.sin_addr.s_addr = htonl(INADDR_ANY);

  if (bind(ret->fd,(struct sockaddr *)&insock,sizeof(insock))<0) {
    saved = errno;		/* close() may overwrite errno */
    close(ret->fd);
    free(ret);
    TcpLibErf(tcplib_packagever,tcplib_EBindFailed,
	      "Error binding socket to port %d: %s\n",port,strerror(saved));
    return NULL;		/* ret is freed; never touch it again */
  }

  if (listen(ret->fd,5)<0) {
    saved = errno;
    close(ret->fd);
    free(ret);
    TcpLibErf(tcplib_packagever,tcplib_EListenFailed,
	      "Error Listening for clients: %s\n",strerror(saved));
    return NULL;
  }
  return ret;
}

/*
 * Open a TCP connection to `hostname` on `port`.  A NULL hostname means
 * "localhost".
 *
 * Returns the connected socket.  Lookup and connect failures are reported
 * through TcpLibErf.  On connect failure the descriptor and wrapper are
 * released first, so nothing leaks when the error function longjmps.  NULL
 * is returned if the error function returns.
 */
TcpSocket GetConnection(char *hostname,int port)
{
  TcpSocket ret;
  struct hostent *remotehost;
  struct sockaddr_in insock;
  int saved;

  if (hostname==NULL)
    hostname = "localhost";

  remotehost = gethostbyname(hostname);
  if (remotehost == NULL) {
    /* gethostbyname reports its failures in h_errno, not errno. */
    TcpLibErf(tcplib_packagever,tcplib_ELookupFailed,
	      "Unable to find host %s: h_errno %d\n",hostname,h_errno);
    return NULL;
  }
  /* Only an IPv4 address fits in a sockaddr_in; copying any other length
     would overflow sin_addr. */
  if (remotehost->h_addrtype != AF_INET ||
      remotehost->h_length != (int)sizeof(insock.sin_addr) ||
      remotehost->h_addr_list[0] == NULL) {
    TcpLibErf(tcplib_packagever,tcplib_ELookupFailed,
	      "Unable to find an IPv4 address for host %s\n",hostname);
    return NULL;
  }

  memset(&insock, 0, sizeof(insock));
  insock.sin_family = AF_INET;
  insock.sin_port = htons((unsigned short)port);
  /* Copy the address out right away: gethostbyname's result may live in
     static storage shared with later lookups. */
  memcpy(&insock.sin_addr, remotehost->h_addr_list[0],
	 sizeof(insock.sin_addr));

  ret = newSocket();
  if (connect(ret->fd,(struct sockaddr *)&insock,sizeof(insock))) {
    saved = errno;		/* close() may overwrite errno */
    close(ret->fd);
    free(ret);
    TcpLibErf(tcplib_packagever,tcplib_EConnectFailed,
	      "Unable to connect to host %s port %d: %s\n",
	      hostname,port,strerror(saved));
    return NULL;
  }
  return ret;
}

#if STATIC_XADDR
#define SSTATIC_XADDR static
#else
#define SSTATIC_XADDR 
#endif

/*
 * Wait for and accept the next client on a listening socket.
 *
 * server - socket created by BecomeServer (flags.isserver must be set).
 * from   - if non-NULL, receives the peer's IPv4 address in host byte order.
 *
 * Returns a new TcpSocket for the client.  Errors are reported through
 * TcpLibErf; NULL is returned if the error function returns.
 */
TcpSocket AcceptConnection(TcpSocket server,unsigned long *from)
{
/* accept() reads *address_len as the size of the buffer it may fill, so
   the length must be set to sizeof(xaddr) before every call.  Leaving it
   uninitialized is the likely cause of the EFAULT that used to be avoided
   by making these variables static (static storage zero-initializes the
   length, which also meant the peer address was never returned).  With
   the length set explicitly, STATIC_XADDR is no longer needed, but it is
   still honoured. */
  SSTATIC_XADDR struct sockaddr_in xaddr;
  SSTATIC_XADDR socklen_t xaddrlen;
  int inputfd;
  TcpSocket ret;

  if (server->flags.isserver == 0) {
    TcpLibErf(tcplib_packagever,tcplib_EInvalidArg,
	      "Attempted to accept connection on non-server port\n");
    return NULL;
  }
  xaddrlen = sizeof(xaddr);
  if ((inputfd = accept(server->fd,(struct sockaddr *)&xaddr,&xaddrlen))<0) {
    TcpLibErf(tcplib_packagever,tcplib_EAcceptBad,
	      "Accept Failed: %s\n",strerror(errno));
    return NULL;
  }
  WITH_HANDLING {
    ret = initTcpSocket();
  } HANDLE {
    /* Don't leak the accepted descriptor if allocation fails. */
    close(inputfd);
    RERAISE();
  }
  END_HANDLING;
  if (from) {
    *from = ntohl(xaddr.sin_addr.s_addr);
  }
  ret->fd = inputfd;
  return ret;
}

/*
 * Wrap a descriptor owned by the caller in a TcpSocket.  The wrapper is
 * marked foreign, so CloseConnection frees its buffers but leaves fd open.
 */
TcpSocket MakeForeignFDConnection(int fd)
{
  TcpSocket wrapper = initTcpSocket();

  wrapper->flags.isforeign = 1;
  wrapper->fd = fd;
  return wrapper;
}

/*
 * Release a TcpSocket.  Its buffers and the wrapper are freed, and the
 * descriptor is closed unless the socket came from MakeForeignFDConnection.
 *
 * All memory is released before a close() failure is reported through
 * TcpLibErf, so nothing leaks if the error function longjmps.  Output
 * still sitting in the send buffer is discarded, not flushed.
 */
void CloseConnection(TcpSocket gone)
{
  int fd = gone->fd;
  int foreign = gone->flags.isforeign;

  /* free(NULL) is a no-op, so unallocated buffers need no guard. */
  free(gone->rcvbuf);
  free(gone->sendbuf);
  free(gone->sizedbuf);
  free(gone);

  if (!foreign && close(fd) != 0) {
    TcpLibErf(tcplib_packagever,tcplib_ECloseErr,
	      "Error Closing Connection: %s\n",strerror(errno));
  }
}

/*
 * Write all msglen bytes of msg to fd.  Short writes are continued, and
 * writes interrupted by a signal (EINTR) are retried.
 *
 * Other write errors are reported through TcpLibErf.  If the error
 * function returns, the rest of the message is abandoned; the -1 error
 * result is never applied to the position.
 */
static void __DoSendMsg(int fd,char *msg,unsigned long msglen)
{
  long amt;

  while (msglen>0) {
    amt = write(fd,msg,msglen);
    if (amt<0) {
      if (errno == EINTR)
	continue;
      TcpLibErf(tcplib_packagever,tcplib_EWriteFailed,
		"Error on write in SendMessage: %s\n",strerror(errno));
      return;
    }
    msg += amt;
    msglen -= amt;
  }
}

/*
 * Write out everything accumulated in the send buffer, then mark the
 * buffer empty.  Nothing is written when no output is pending.
 */
void TcpLibFlush(TcpSocket to)
{
  char *start = to->sendbuf;

  if (start < to->sendbufpos)
    __DoSendMsg(to->fd, start, to->sendbufpos - start);
  to->sendbufpos = start;
}

/*
 * Turn output buffering on (on != 0) or off.  Turning it off first
 * flushes anything already buffered, so no pending data is stranded.
 */
void TcpLibBufferOutput(TcpSocket sock,int on)
{
  if (!on) {
    TcpLibFlush(sock);
    sock->flags.bufferout = 0;
    return;
  }
  sock->flags.bufferout = 1;
}
    
/*
 * Send the concatenation of the nvecs buffers described by vecs.
 *
 * With output buffering on, each piece goes through SendMsg.  Otherwise
 * the data is written immediately: with writev() when USE_IOVEC is set,
 * or with one __DoSendMsg call per vector.
 *
 * NOTE: in the writev() path the caller's iovec array is modified as data
 * goes out (bases advanced, lengths reduced to 0), so it must not be
 * reused afterwards.
 */
void IOVSendMsg(TcpSocket to,struct iovec *vecs,int nvecs)
{
  int l;

  if (to->flags.bufferout) {
    for(l=0;l<nvecs;l++)
      SendMsg(to,vecs[l].iov_base,vecs[l].iov_len);
    return;
  }
#if USE_IOVEC
  {
    struct iovec *cur = vecs;	/* first vector not yet fully written */
    int left = nvecs;		/* vectors remaining from cur onwards */
    long amt;

    for (;;) {
      /* Step past vectors that are complete (or were empty to begin with). */
      while (left > 0 && cur->iov_len == 0) {
	cur++;
	left--;
      }
      if (left == 0)
	break;			/* We're done writing */

      amt = writev(to->fd,cur,left);
      if (amt<0) {
	if (errno == EINTR)
	  continue;
	TcpLibErf(tcplib_packagever,tcplib_EWriteFailed,
		  "Error on writev in IOVSendMsg: %s\n",strerror(errno));
	return;
      }
      /* Consume amt bytes from the front of the remaining vectors.  The
	 base of a partly written vector is advanced so the next writev
	 resumes at the first unsent byte instead of resending its start. */
      while (amt > 0) {
	if (left == 0) {
	  TcpLibErf(tcplib_packagever,tcplib_EWriteTooMuch,
		    "Internal Error: writev wrote more data than we have??\n");
	  return;
	}
	if ((unsigned long)amt >= cur->iov_len) {
	  amt -= cur->iov_len;
	  cur->iov_len = 0;
	  cur++;
	  left--;
	} else {
	  cur->iov_base = (char *)cur->iov_base + amt;
	  cur->iov_len -= amt;
	  amt = 0;
	}
      }
    }
  }
#else
  for(l=0;l<nvecs;l++)
    __DoSendMsg(to->fd,vecs[l].iov_base,vecs[l].iov_len);
#endif
}
  
/*
 * Send msglen bytes from msg.
 *
 * Without output buffering the data is written immediately.  With
 * buffering on, the data is appended to the send buffer.  If appending
 * would take the buffer past autoflushlen, the buffer is flushed and msg
 * is written directly instead.
 *
 * The send buffer grows geometrically (capped at autoflushlen), so a long
 * run of small messages does not reallocate and copy the whole buffer on
 * every call.
 */
void SendMsg(TcpSocket to,char *msg,unsigned long msglen)
{
  unsigned long used, need, newsize;

  if (msglen == 0)
    return;			/* nothing to send or to buffer */
  if (!to->flags.bufferout) {
    __DoSendMsg(to->fd,msg,msglen);
    return;
  }

  used = (to->sendbuf != NULL) ? (unsigned long)(to->sendbufpos - to->sendbuf) : 0;
  need = used + msglen;
  if (need > (unsigned long)to->autoflushlen) {
    TcpLibFlush(to);
    __DoSendMsg(to->fd,msg,msglen);
    return;
  }
  if (to->sendbuf == NULL || (unsigned long)to->sendbufsize < need) {
    newsize = to->sendbufsize > 0 ? (unsigned long)to->sendbufsize : need;
    while (newsize < need)
      newsize *= 2;
    /* need <= autoflushlen (checked above), so the cap keeps room for it. */
    if (newsize > (unsigned long)to->autoflushlen)
      newsize = (unsigned long)to->autoflushlen;
    to->sendbuf = xrealloc(to->sendbuf,newsize);
    to->sendbufpos = to->sendbuf + used;
    to->sendbufsize = newsize;
  }
  bcopy(msg,to->sendbufpos,msglen);
  to->sendbufpos += msglen;
}

/*
 * Send a length-prefixed message: a 4-byte big-endian length (the low 32
 * bits of size) followed by the size bytes of buf.  GetSizedMsg is the
 * matching receiver.
 */
void SendSizedMsg(TcpSocket to,char *buf,long size)
{
  unsigned long remaining = (unsigned long)size;
  char header[4];
  struct iovec parts[2];
  int i;

  /* Fill the header from its last byte backwards, least significant first. */
  for (i = 3; i >= 0; i--) {
    header[i] = remaining & 0xFF;
    remaining >>= 8;
  }

  parts[0].iov_base = header;
  parts[0].iov_len = sizeof(header);
  parts[1].iov_base = buf;
  parts[1].iov_len = size;
  IOVSendMsg(to,parts,2);
}

/*
 * Send a NUL-terminated string as a sized message.  The terminator is
 * counted in the transmitted length, so the receiver gets it too.
 */
void SendString(TcpSocket to,char *buf)
{
  long len = strlen(buf) + 1;	/* include the trailing '\0' */

  SendSizedMsg(to,buf,len);
}

#if ! BUFFER_READ
/*
 * eread - read up to amt bytes from the socket directly into buf
 * (unbuffered configuration).
 *
 * Returns the number of bytes read, or 0 at end of file.  Reads
 * interrupted by a signal are retried.  Other errors mark the socket and
 * are reported through TcpLibErf; if the error function returns, 0 is
 * returned, so callers never see a negative count.
 */
static long eread(TcpSocket from, char *buf,long amt)
{
  long ramt;

  do {
    ramt = read(from->fd,buf,amt);
  } while (ramt < 0 && errno == EINTR);
  if (ramt < 0) {
    from->flags.haserror = 1;
    TcpLibErf(tcplib_packagever,tcplib_EReadError,
	      "Error reading from socket: %s\n",strerror(errno));
    return 0;
  }
  return ramt;
}
#else
/*
 * Copy up to amt bytes of already-buffered input into buf, advance the
 * buffer position, and return how many bytes were copied.
 */
static long copyfrombuf(TcpSocket from,char *buf,long amt)
{
  long n = from->rcvbufvalid;

  if (amt < n)
    n = amt;
  bcopy(from->rcvbufpos,buf,n);
  from->rcvbufpos += n;
  from->rcvbufvalid -= n;
  return n;
}

/*
 * eread - read up to amt bytes into buf (buffered configuration).
 *
 * Input left over from an earlier read is returned first.  Otherwise one
 * read system call is made: it fills buf, and any surplus is kept in the
 * per-socket receive buffer (rcvbuf) for later calls.  When a read fills
 * rcvbuf completely, the buffer is enlarged by RCVINCR before the next
 * refill.
 *
 * Returns the number of bytes placed in buf, or 0 at end of file.  EINTR
 * is retried.  Other errors mark the socket and are reported through
 * TcpLibErf; 0 is returned if the error function returns.
 */
static long eread(TcpSocket from,char *buf,long amt)
{
  long ramt;

  if (from->rcvbufvalid)
    return copyfrombuf(from,buf,amt);

  if (from->rcvbuf == NULL) {
    from->rcvbuf = xmalloc(RCVINIT);
    from->rcvbufsize = RCVINIT;
  }
  if (from->flags.incrsize) {
    /* The buffer is empty here, so its contents need not be preserved. */
    free(from->rcvbuf);
    from->rcvbufsize += RCVINCR;
    from->rcvbuf = xmalloc(from->rcvbufsize);
    from->flags.incrsize = 0;
  }
  from->rcvbufpos = from->rcvbuf;
  from->rcvbufvalid = 0;
#if USE_IOVEC
  /* readv() lets the bytes destined for the caller land directly in buf,
     which saves one copy compared with reading everything into rcvbuf. */
  {
    struct iovec iov[2];

    iov[0].iov_base = buf;
    iov[0].iov_len = amt;
    iov[1].iov_base = from->rcvbuf;
    iov[1].iov_len = from->rcvbufsize;
    do {
      ramt = readv(from->fd,iov,2);
    } while (ramt < 0 && errno == EINTR);
    if (ramt <0) {
      from->flags.haserror = 1;
      TcpLibErf(tcplib_packagever,tcplib_EReadError,
		"Error reading from socket: %s\n",strerror(errno));
      return 0;
    }
    if (ramt<amt)
      return ramt;		/* nothing spilled over into rcvbuf */
    from->rcvbufvalid = ramt - amt;
    if (from->rcvbufvalid == from->rcvbufsize)
      from->flags.incrsize = 1;
    return amt;
  }
#else
  do {
    ramt = read(from->fd,from->rcvbuf,from->rcvbufsize);
  } while (ramt < 0 && errno == EINTR);
  if (ramt <0) {
    from->flags.haserror = 1;
    TcpLibErf(tcplib_packagever,tcplib_EReadError,
	      "Error reading from socket: %s\n",strerror(errno));
    return 0;
  }
  if (ramt==0)
    return 0;			/* end of file */
  from->rcvbufvalid = ramt;
  if (from->rcvbufvalid == from->rcvbufsize)
    from->flags.incrsize = 1;
  return copyfrombuf(from,buf,amt);
#endif
}
#endif

/*
 * Read exactly size bytes from the socket into buf, blocking until all of
 * them have arrived.  Buffered output on the same socket is flushed
 * first, presumably so a pending request reaches the peer before we wait
 * for its reply.
 *
 * If a read yields no data (the connection was probably closed), the
 * error is reported through TcpLibErf.  If the error function returns,
 * GetMsg returns early instead of spinning in a busy loop.
 */
void GetMsg(TcpSocket from,char *buf,long size)
{
  long t;

  if (from->sendbufpos > from->sendbuf)
    TcpLibFlush(from);
  while (size>0) {
    t = eread(from,buf,size);
    if (t<=0) {
      TcpLibErf(tcplib_packagever,tcplib_ENothingRead,
		"Tried to read data, but nothing was read, socket probably closed.\n");
      return;
    }
    size -= t;
    buf += t;
  }
}

/*
 * Receive one message sent by SendSizedMsg: a 4-byte big-endian length
 * followed by that many bytes of payload.
 *
 * Returns a pointer to the payload.  The payload is stored in a
 * per-socket buffer that the next GetSizedMsg call on the same socket
 * reuses (and may move).  A '\0' is always stored just past the payload,
 * so even a message the sender did not NUL-terminate cannot be over-read
 * when treated as a string.  If size is non-NULL it receives the payload
 * length.
 *
 * Lengths above 0x7FFFFFFE are rejected through TcpLibErf (NULL is
 * returned if the error function returns), so a corrupt header cannot
 * produce a negative length on systems with a 32-bit long.
 */
char *GetSizedMsg(TcpSocket from,long *size)
{
  unsigned char csize[4];
  unsigned long ux;
  long x;

  GetMsg(from,(char *)csize,4);
  /* Assemble in unsigned long: shifting an int-promoted byte >= 0x80 left
     by 24 would overflow a signed int. */
  ux = ((unsigned long)csize[0] << 24) | ((unsigned long)csize[1] << 16) |
       ((unsigned long)csize[2] << 8) | (unsigned long)csize[3];
  if (ux > 0x7FFFFFFEUL) {
    TcpLibErf(tcplib_packagever,tcplib_EReadError,
	      "Sized message length %lu is too large\n",ux);
    return NULL;
  }
  x = (long)ux;
  /* One spare byte for the terminator written below. */
  if (from->sizedbufsize < x + 1) {
    from->sizedbuf = xrealloc(from->sizedbuf,x + 1);
    from->sizedbufsize = x + 1;
  }
  GetMsg(from,from->sizedbuf,x);
  from->sizedbuf[x] = '\0';
  if (size)
    *size = x;
  return from->sizedbuf;
}

/*
 * Receive one sized message (normally sent with SendString) and return it
 * as a string.  The storage belongs to the socket and is reused by the
 * next GetSizedMsg/GetString call on it.
 */
char *GetString(TcpSocket from)
{
  char *text = GetSizedMsg(from,NULL);

  return text;
}

/*
 * Wait until one of socks has input or an exceptional condition, or
 * becomes writable (only sockets with checkwrite set are watched for
 * that), or until msTimeout milliseconds pass.  TCPLIB_FOREVER waits
 * indefinitely.
 *
 * Pending output on every socket is flushed first, and each socket's
 * hasinput, haserror and readywrite flags are recomputed.  If any socket
 * already has buffered input, 1 is returned at once without calling
 * select().
 *
 * Returns the count reported by select(), or 0 on timeout.  Invalid
 * arguments and select() errors are reported through TcpLibErf; if the
 * error function returns, 0 is returned.
 */
int WaitForInput(TcpSocket *socks,int nsocks,long msTimeout)
{
  fd_set rmaskfd,wmaskfd,emaskfd;
  int nfound,l,maxfd;
  struct timeval max_wait;

  if (nsocks<=0 && msTimeout == TCPLIB_FOREVER) {
    TcpLibErf(tcplib_packagever,tcplib_EBadWait,
	      "Tried to wait with no sockets and no timeout\n");
    return 0;			/* select() would otherwise block forever */
  }

  for(l=0;l<nsocks;l++) {
    TcpLibFlush(socks[l]);
    socks[l]->flags.haserror = 0;
    /* Clear writability as well, so a result from an earlier call is not
       mistaken for a current one. */
    socks[l]->flags.readywrite = 0;
    socks[l]->flags.hasinput = socks[l]->rcvbufvalid > 0;
  }
  for(l=0;l<nsocks;l++) {
    if (socks[l]->flags.hasinput)
      return 1;
  }

  FD_ZERO(&rmaskfd);FD_ZERO(&wmaskfd);FD_ZERO(&emaskfd);
  maxfd = 0;
  for(l=0;l<nsocks;l++) {
    /* FD_SET with a descriptor outside [0, FD_SETSIZE) writes out of bounds. */
    if (socks[l]->fd < 0 || socks[l]->fd >= FD_SETSIZE) {
      TcpLibErf(tcplib_packagever,tcplib_EBadWait,
		"Descriptor %d cannot be used with select\n",socks[l]->fd);
      return 0;
    }
    FD_SET(socks[l]->fd,&rmaskfd);
    FD_SET(socks[l]->fd,&emaskfd);
    if (socks[l]->flags.checkwrite)
      FD_SET(socks[l]->fd,&wmaskfd);
    if (socks[l]->fd >= maxfd)
      maxfd = socks[l]->fd + 1;
  }

  if (msTimeout != TCPLIB_FOREVER) {
    max_wait.tv_sec = msTimeout / 1000;
    max_wait.tv_usec = (msTimeout % 1000) * 1000;
  }

  nfound = select(maxfd,&rmaskfd,&wmaskfd,&emaskfd,
		  msTimeout != TCPLIB_FOREVER ? &max_wait : NULL);
  if (nfound == 0)
    return 0;
  if (nfound < 0) {
    TcpLibErf(tcplib_packagever,tcplib_ESelectError,
	      "Error on Select: %s\n",strerror(errno));
    return 0;			/* the fd_sets are unspecified after an error */
  }
  for(l=0;l<nsocks;l++) {
    if (FD_ISSET(socks[l]->fd,&rmaskfd))
      socks[l]->flags.hasinput = 1;
    if (FD_ISSET(socks[l]->fd,&wmaskfd))
      socks[l]->flags.readywrite = 1;
    if (FD_ISSET(socks[l]->fd,&emaskfd))
      socks[l]->flags.haserror = 1;
  }
  return nfound;
}

/*
 * Report whether the most recent WaitForInput call found input pending on
 * sock: 1 if it did, 0 otherwise.
 */
int HasInput(TcpSocket sock)
{
  if (sock->flags.hasinput)
    return 1;
  return 0;
}

