#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/param.h>
#include <netdb.h>
#include "uint16.h"
#include "str.h"
#include "byte.h"
#include "fmt.h"
#include "scan.h"
#include "ip4.h"
#include "fd.h"
#include "exit.h"
#include "env.h"
#include "prot.h"
#include "open.h"
#include "wait.h"
#include "readwrite.h"
#include "stralloc.h"
#include "alloc.h"
#include "buffer.h"
#include "error.h"
#include "strerr.h"
#include "sgetopt.h"
#include "pathexec.h"
#include "socket.h"
#include "ndelay.h"
#include "remoteinfo.h"
#include "rules.h"
#include "sig.h"
#include "dns.h"
#include "/usr/local/include/maxminddb.h"

/* ISO country code of the current peer, filled in by dump_data() from the
   GeoLite2 "iso_code" value (two bytes copied; remaining bytes stay 0). */
char CountryCode[4];

/* GeoIP database handle; opened per connection in doit() when
   DENYCOUNTRY or ALLOWCOUNTRY is set. */
struct MMDB_s mmdb;

/* Open a MaxMind database, or print a diagnostic and exit(2). */
MMDB_s open_or_die(const char *fname);

/* Look ipstr up in mmdb, or print a diagnostic and exit(3) / exit(4). */
MMDB_lookup_result_s lookup_or_die(MMDB_s *mmdb, const char *ipstr);
/* Entry data for the current lookup; freed in doit() after dump_data(). */
MMDB_entry_data_list_s *entry_data_list;
/* Walk entry data, extracting the country ISO code into CountryCode. */
MMDB_entry_data_list_s *dump_data( MMDB_entry_data_list_s *entry_data_list, int *status );

/* dump_data() search state: 0 = searching, 1 = "country" key seen,
   2 = "iso_code" key seen, 3 = code copied into CountryCode. */
int FoundResult = 0;

/* Command-line settings (set in main's getopt loop). */
int verbosity = 1;              /* -q => 0, -Q => 1, -v => 2 */
int flagkillopts = 1;           /* -O/-o: strip IP options from accepted sockets */
int flagdelay = 1;              /* -d/-D: -D sets TCP_NODELAY on accepted sockets */
char *banner = "";              /* -B: text written to the client before exec */
int flagremoteinfo = 1;         /* -r/-R: query remoteinfo() for the peer */
int flagremotehost = 1;         /* -h/-H: reverse-resolve the peer address */
int flagparanoid = 0;           /* -p/-P: require the peer name to map back to its IP */
unsigned long timeout = 26;     /* -t: timeout passed to remoteinfo() */
struct stralloc pidfile = {0};  /* -w: pid file path (NUL-terminated when set) */

/* Result of remoteinfo(); its .s stays null until doit() fetches it. */
static stralloc tcpremoteinfo;

/* Local endpoint of the accepted connection and its printable forms. */
uint16 localport;
char localportstr[FMT_ULONG];
char localip[4];
char localipstr[IP4_FMT];
static stralloc localhostsa;
char *localhost = 0;            /* -l, otherwise reverse DNS of localip */

/* Peer endpoint: filled by socket_accept4() in main, formatted in doit(). */
uint16 remoteport;
char remoteportstr[FMT_ULONG];
char remoteip[4];
char remoteipstr[IP4_FMT];
static stralloc remotehostsa;
char *remotehost = 0;

/* Values of the DENYCOUNTRY / ALLOWCOUNTRY environment variables. */
char *deny_country;
char *allow_country;

/* Scratch buffers for fmt_ulong() output. */
char strnum[FMT_ULONG];
char strnum2[FMT_ULONG];

/* tmp: scratch for log lines and DNS results; fqdn/addresses: startup
   resolution of the listen host. */
static stralloc tmp;
static stralloc fqdn;
static stralloc addresses;

/* Small output buffer used for the -B banner and the -1 port line. */
char bspace[16];
buffer b;



/* ---------------------------- child */

/* Message prefix used when a child abandons its connection. */
#define DROP "tcpserver: warning: dropping connection, "

/* Set by found() for a 'D' rule or by the GeoIP country check;
   doit() exits with status 100 when it is set. */
int flagdeny = 0;
/* -X: a missing rules file is tolerated instead of dropping the connection. */
int flagallownorules = 0;
/* -x: path of the rules file, or 0 when no rules are configured. */
char *fnrules = 0;

/* Give up on the connection after an allocation failure
   (strerr_die2sys with code 111). */
void drop_nomem(void)
{
  strerr_die2sys(111,DROP,"out of memory");
}

/* Append the C string s to the scratch buffer tmp; die if out of memory. */
void cats(char *s)
{
  if (stralloc_cats(&tmp,s) == 0)
    drop_nomem();
}

/* Append the single byte *ch to tmp; die if out of memory. */
void append(char *ch)
{
  if (stralloc_append(&tmp,ch) == 0)
    drop_nomem();
}
/* Append up to 100 bytes of s to tmp for logging. Control characters,
   non-ASCII bytes, '%' and ':' are each replaced with '?'. If s is longer
   than 100 bytes, "..." is appended after the first 100. */
void safecats(char *s)
{
  int n;

  for (n = 0;n < 100;++n) {
    char c = s[n];

    if (c == 0) return;
    /* ':' is the log field separator; '%' upsets some loggers */
    if (c < 33 || c > 126 || c == '%' || c == ':')
      c = '?';
    append(&c);
  }
  cats("...");
}
/* Stage variable s=t in the environment handed to pathexec(); t may be 0
   (as for TCPREMOTEINFO when no info is available). Dies if out of memory. */
void env(char *s,char *t)
{
  if (pathexec_env(s,t) == 0)
    drop_nomem();
}

/* Give up on the connection because the -x rules file could not be read. */
void drop_rules(void)
{
  strerr_die4sys(111,DROP,"unable to read ",fnrules,": ");
}

/* Callback passed to rules(): data holds a sequence of NUL-terminated
   instructions. "D" marks the connection as denied; "+NAME=VALUE" exports
   NAME=VALUE to the program's environment. Any other instruction, and any
   trailing bytes without a terminating NUL, are ignored. Note that data is
   modified in place (the '=' is overwritten). */
void found(char *data,unsigned int datalen)
{
  unsigned int end;
  unsigned int eq;

  for (;;) {
    end = byte_chr(data,datalen,0);
    if (end >= datalen) break;          /* no complete instruction left */

    if (data[0] == 'D')
      flagdeny = 1;
    else if (data[0] == '+') {
      eq = str_chr(data + 1,'=');
      if (data[1 + eq] == '=') {
        data[1 + eq] = 0;               /* split NAME from VALUE */
        env(data + 1,data + 2 + eq);
      }
    }

    data += end + 1;
    datalen -= end + 1;
  }
}

void doit(int t)
{
  int j;
  int i,ipnum;
  int country_id;
  uint32_t size;
  int status;
  MMDB_lookup_result_s result;

  remoteipstr[ip4_fmt(remoteipstr,remoteip)] = 0;

  if (verbosity >= 2) {
    strnum[fmt_ulong(strnum,getpid())] = 0;
    strerr_warn4("tcpserver: pid ",strnum," from ",remoteipstr,0);
  }

  if (flagkillopts)
    socket_ipoptionskill(t);
  if (!flagdelay)
    socket_tcpnodelay(t);

  if (*banner) {
    buffer_init(&b,write,t,bspace,sizeof bspace);
    if (buffer_putsflush(&b,banner) == -1)
      strerr_die2sys(111,DROP,"unable to print banner: ");
  }

  if (socket_local4(t,localip,&localport) == -1)
    strerr_die2sys(111,DROP,"unable to get local address: ");

  localipstr[ip4_fmt(localipstr,localip)] = 0;
  remoteportstr[fmt_ulong(remoteportstr,remoteport)] = 0;

  if (!localhost)
    if (dns_name4(&localhostsa,localip) == 0)
      if (localhostsa.len) {
	if (!stralloc_0(&localhostsa)) drop_nomem();
	localhost = localhostsa.s;
      }
  env("PROTO","TCP");
  env("TCPLOCALIP",localipstr);
  env("TCPLOCALPORT",localportstr);
  env("TCPLOCALHOST",localhost);

  if (flagremotehost)
    if (dns_name4(&remotehostsa,remoteip) == 0)
      if (remotehostsa.len) {
	if (flagparanoid)
	  if (dns_ip4(&tmp,&remotehostsa) == 0)
	    for (j = 0;j + 4 <= tmp.len;j += 4)
	      if (byte_equal(remoteip,4,tmp.s + j)) {
		flagparanoid = 0;
		break;
	      }
	if (!flagparanoid) {
	  if (!stralloc_0(&remotehostsa)) drop_nomem();
	  remotehost = remotehostsa.s;
	}
      }
  env("TCPREMOTEIP",remoteipstr);
  env("TCPREMOTEPORT",remoteportstr);
  env("TCPREMOTEHOST",remotehost);

  deny_country = getenv("DENYCOUNTRY");
  allow_country = getenv("ALLOWCOUNTRY");

  if ( deny_country!=NULL || allow_country!=NULL ) {
    memset(CountryCode,0,sizeof(CountryCode));

    mmdb = open_or_die("/usr/share/GeoIP/GeoLite2-Country.mmdb");
    result = lookup_or_die(&mmdb, remoteipstr);
    entry_data_list = NULL;

    if (result.found_entry) {        

      status = MMDB_get_entry_data_list(&result.entry, &entry_data_list);
      if (MMDB_SUCCESS != status) {
          fprintf(stderr, "Got an error looking up the entry data - %s\n",
                  MMDB_strerror(status));
      }

      if (entry_data_list!=NULL) {
        dump_data( entry_data_list, &status );
        MMDB_free_entry_data_list(entry_data_list);
        MMDB_close(&mmdb);
      }

      if ( CountryCode[0]!=0 && CountryCode[0]!='-' && deny_country!=NULL && 
           strstr(deny_country,CountryCode) != NULL ) flagdeny = 1;
      if ( CountryCode[0]!=0 && CountryCode[0]!='-' && allow_country!=NULL && 
           strstr(allow_country,CountryCode) == NULL ) flagdeny = 1;

      if (verbosity >= 2 && flagdeny==1) {
        strnum[fmt_ulong(strnum,getpid())] = 0;
        if (!stralloc_copys(&tmp,"tcpserver: ")) drop_nomem();
        safecats(flagdeny ? "deny" : "ok");
        cats(" "); safecats(strnum);
        cats(" "); if (localhost) safecats(localhost);
        cats(":"); safecats(localipstr);
        cats(":"); safecats(localportstr);
        cats(" "); if (remotehost) safecats(remotehost);
        cats(":"); safecats(remoteipstr);
        cats(":"); if (flagremoteinfo) safecats(tcpremoteinfo.s);
        cats(":"); safecats(remoteportstr);
        if ( CountryCode[0]!=0 ) cats(":"); safecats(CountryCode);
        cats("\n");
        buffer_putflush(buffer_2,tmp.s,tmp.len);
      }
      if (flagdeny) _exit(100);
    }
  }


  if (flagremoteinfo) {
    if (remoteinfo(&tcpremoteinfo,remoteip,remoteport,localip,localport,timeout) == -1)
      flagremoteinfo = 0;
    if (!stralloc_0(&tcpremoteinfo)) drop_nomem();
  }
  env("TCPREMOTEINFO",flagremoteinfo ? tcpremoteinfo.s : 0);

  if (fnrules) {
    int fdrules;
    fdrules = open_read(fnrules);
    if (fdrules == -1) {
      if (errno != error_noent) drop_rules();
      if (!flagallownorules) drop_rules();
    }
    else {
      if (rules(found,fdrules,remoteipstr,remotehost,flagremoteinfo ? tcpremoteinfo.s : 0) == -1) drop_rules();
      close(fdrules);
    }
  }

  if (verbosity >= 2) {
    strnum[fmt_ulong(strnum,getpid())] = 0;
    if (!stralloc_copys(&tmp,"tcpserver: ")) drop_nomem();
    safecats(flagdeny ? "deny" : "ok");
    cats(" "); safecats(strnum);
    cats(" "); if (localhost) safecats(localhost);
    cats(":"); safecats(localipstr);
    cats(":"); safecats(localportstr);
    cats(" "); if (remotehost) safecats(remotehost);
    cats(":"); safecats(remoteipstr);
    cats(":"); if (flagremoteinfo) safecats(tcpremoteinfo.s);
    cats(":"); safecats(remoteportstr);
    if ( CountryCode[0]!=0 ) cats(":"); safecats(CountryCode);
    cats("\n");
    buffer_putflush(buffer_2,tmp.s,tmp.len);
  }

  if (flagdeny) _exit(100);
}


#define FATAL "tcpserver: fatal: "

/* Print the command-line synopsis on stderr and exit with status 100. */
void usage(void)
{
  strerr_warn1("tcpserver: usage: tcpserver "
               "[ -1UXpPhHrRoOdDqQv ] "
               "[ -c limit ] "
               "[ -x rules.cdb ] "
               "[ -B banner ] "
               "[ -g gid ] "
               "[ -u uid ] "
               "[ -b backlog ] "
               "[ -l localname ] "
               "[ -t timeout ] "
               "[ -w pidfile ] "
               "host port program",
               0);
  _exit(100);
}

/* -c: maximum number of simultaneous children. */
unsigned long limit = 40;
/* Children currently running: incremented after accept, decremented in
   sigchld() and when fork() fails. */
unsigned long numchildren = 0;

/* -1: print the bound local port on stdout at startup. */
int flag1 = 0;
/* -b: backlog passed to socket_listen(). */
unsigned long backlog = 20;
/* -u (or $UID with -U): uid to switch to after binding; 0 = unchanged. */
unsigned long uid = 0;
/* -g (or $GID with -U): gid to switch to after binding; 0 = unchanged. */
unsigned long gid = 0;

/* With -v, report "children in use / limit" on stderr. */
void printstatus(void)
{
  unsigned int len;

  if (verbosity >= 2) {
    len = fmt_ulong(strnum,numchildren);
    strnum[len] = 0;
    len = fmt_ulong(strnum2,limit);
    strnum2[len] = 0;
    strerr_warn4("tcpserver: status: ",strnum,"/",strnum2,0);
  }
}

/* SIGTERM handler: exit immediately with status 0 (children are not
   signalled by this handler). */
void sigterm()
{
  _exit(0);
}

/* SIGCHLD handler: reap every finished child without blocking, log each
   one with -v, and update the live-child count and status line. */
void sigchld()
{
  int status;
  int child;

  for (;;) {
    child = wait_nohang(&status);
    if (child <= 0) break;

    if (verbosity >= 2) {
      strnum[fmt_ulong(strnum,child)] = 0;
      strnum2[fmt_ulong(strnum2,status)] = 0;
      strerr_warn4("tcpserver: end ",strnum," status ",strnum2,0);
    }
    if (numchildren) {
      --numchildren;
    }
    printstatus();
  }
}

main(int argc,char **argv)
{
  char *hostname;
  char *portname;
  int opt;
  struct servent *se;
  char *x;
  unsigned long u;
  int s;
  int t;

  while ((opt = getopt(argc,argv,"dDvqQhHrR1UXx:t:u:g:l:b:B:c:pPoOw:")) != opteof)
    switch(opt) {
      case 'b': scan_ulong(optarg,&backlog); break;
      case 'c': scan_ulong(optarg,&limit); break;
      case 'X': flagallownorules = 1; break;
      case 'x': fnrules = optarg; break;
      case 'B': banner = optarg; break;
      case 'd': flagdelay = 1; break;
      case 'D': flagdelay = 0; break;
      case 'v': verbosity = 2; break;
      case 'q': verbosity = 0; break;
      case 'Q': verbosity = 1; break;
      case 'P': flagparanoid = 0; break;
      case 'p': flagparanoid = 1; break;
      case 'O': flagkillopts = 1; break;
      case 'o': flagkillopts = 0; break;
      case 'H': flagremotehost = 0; break;
      case 'h': flagremotehost = 1; break;
      case 'R': flagremoteinfo = 0; break;
      case 'r': flagremoteinfo = 1; break;
      case 't': scan_ulong(optarg,&timeout); break;
      case 'U': x = env_get("UID"); if (x) scan_ulong(x,&uid);
		x = env_get("GID"); if (x) scan_ulong(x,&gid); break;
      case 'u': scan_ulong(optarg,&uid); break;
      case 'g': scan_ulong(optarg,&gid); break;
      case '1': flag1 = 1; break;
      case 'l': localhost = optarg; break;
      case 'w': if (!stralloc_copys(&pidfile, optarg) ||
		    !stralloc_0(&pidfile) )
		  strerr_die2x(111,FATAL,"out of memory");
		break;
      default: usage();
    }
  argc -= optind;
  argv += optind;

  if (!verbosity)
    buffer_2->fd = -1;
 
  hostname = *argv++;
  if (!hostname) usage();
  if (str_equal(hostname,"")) hostname = "0.0.0.0";
  if (str_equal(hostname,"0")) hostname = "0.0.0.0";

  x = *argv++;
  if (!x) usage();
  if (!x[scan_ulong(x,&u)])
    localport = u;
  else {
    se = getservbyname(x,"tcp");
    if (!se)
      strerr_die3x(111,FATAL,"unable to figure out port number for ",x);
    localport = ntohs(se->s_port);
  }

  if (!*argv) usage();
 
  sig_block(sig_child);
  sig_catch(sig_child,sigchld);
  sig_catch(sig_term,sigterm);
  sig_ignore(sig_pipe);
 
  if (!stralloc_copys(&tmp,hostname))
    strerr_die2x(111,FATAL,"out of memory");
  if (dns_ip4_qualify(&addresses,&fqdn,&tmp) == -1)
    strerr_die4sys(111,FATAL,"temporarily unable to figure out IP address for ",hostname,": ");
  if (addresses.len < 4)
    strerr_die3x(111,FATAL,"no IP address for ",hostname);
  byte_copy(localip,4,addresses.s);

  s = socket_tcp();
  if (s == -1)
    strerr_die2sys(111,FATAL,"unable to create socket: ");
  if (socket_bind4_reuse(s,localip,localport) == -1)
    strerr_die2sys(111,FATAL,"unable to bind: ");
  if (socket_local4(s,localip,&localport) == -1)
    strerr_die2sys(111,FATAL,"unable to get local address: ");
  if (socket_listen(s,backlog) == -1)
    strerr_die2sys(111,FATAL,"unable to listen: ");
  ndelay_off(s);

  if (gid) if (prot_gid(gid) == -1)
    strerr_die2sys(111,FATAL,"unable to set gid: ");
  if (uid) if (prot_uid(uid) == -1)
    strerr_die2sys(111,FATAL,"unable to set uid: ");

 
  localportstr[fmt_ulong(localportstr,localport)] = 0;
  if (flag1) {
    buffer_init(&b,write,1,bspace,sizeof bspace);
    buffer_puts(&b,localportstr);
    buffer_puts(&b,"\n");
    buffer_flush(&b);
  }
 
  close(0);
  close(1);
  printstatus();

  if ( pidfile.len > 0 ) {
   int pidfd;
    pidfd = open_trunc(pidfile.s);
    strnum[fmt_ulong(strnum,getpid())] = 0;
    stralloc_copys(&tmp,strnum);
    write(pidfd, tmp.s, tmp.len);
    close(pidfd);
  }
 
  for (;;) {
    while (numchildren >= limit) sig_pause();

    sig_unblock(sig_child);
    t = socket_accept4(s,remoteip,&remoteport);
    sig_block(sig_child);

    if (t == -1) continue;
    ++numchildren; printstatus();
 
    switch(fork()) {
      case 0:
        close(s);
        doit(t);
        if ((fd_move(0,t) == -1) || (fd_copy(1,0) == -1))
	  strerr_die2sys(111,DROP,"unable to set up descriptors: ");
        sig_uncatch(sig_child);
        sig_unblock(sig_child);
        sig_uncatch(sig_term);
        sig_uncatch(sig_pipe);
        pathexec(argv);
	strerr_die4sys(111,DROP,"unable to run ",*argv,": ");
      case -1:
        strerr_warn2(DROP,"unable to fork: ",&strerr_sys);
        --numchildren; printstatus();
    }
    close(t);
  }
}

/* Look ipstr up in the open database mmdb and return the result.
   On a getaddrinfo failure the process exits with status 3; on a
   libmaxminddb error it exits with status 4 (diagnostics go to stderr). */
MMDB_lookup_result_s lookup_or_die(MMDB_s *mmdb, const char *ipstr)
{
    int gai_rc;
    int db_rc;
    MMDB_lookup_result_s found;

    found = MMDB_lookup_string(mmdb, ipstr, &gai_rc, &db_rc);

    if (gai_rc != 0) {
#ifdef _WIN32
        const char *why = gai_strerrorA(gai_rc);
#else
        const char *why = gai_strerror(gai_rc);
#endif
        fprintf(stderr,
                "\n  Error from call to getaddrinfo for %s - %s\n\n",
                ipstr, why);
        exit(3);
    }

    if (db_rc != MMDB_SUCCESS) {
        fprintf(stderr, "\n  Got an error from the maxminddb library: %s\n\n",
                MMDB_strerror(db_rc));
        exit(4);
    }

    return found;
}


/* Open the MaxMind database fname (memory-mapped) and return its handle.
   On failure, print a diagnostic to stderr (including the OS error for
   MMDB_IO_ERROR) and exit with status 2. */
MMDB_s open_or_die(const char *fname)
{
    MMDB_s db;
    int status;
    int saved_errno;

    status = MMDB_open(fname, MMDB_MODE_MMAP, &db);
    /* Capture errno now: the fprintf calls below may overwrite it. */
    saved_errno = errno;

    if (MMDB_SUCCESS != status) {
        fprintf(stderr, "\n  Can't open %s - %s\n", fname,
                MMDB_strerror(status));

        if (MMDB_IO_ERROR == status) {
            fprintf(stderr, "    IO error: %s\n", strerror(saved_errno));
        }

        fprintf(stderr, "\n");

        exit(2);
    }

    return db;
}

/* Return nonzero when the MMDB string key (len bytes, not NUL-terminated)
   is exactly equal to the C string want. A plain strncmp(key,want,len)
   would also accept any prefix of want, e.g. "c" or "" for "country". */
static int mmdb_key_is(const char *key, uint32_t len, const char *want)
{
    size_t wantlen = strlen(want);

    return len == wantlen && memcmp(key, want, wantlen) == 0;
}

/*
 * Walk an entry-data list (as produced by MMDB_get_entry_data_list) and copy
 * the country ISO code into CountryCode.
 *
 * Progress is tracked in FoundResult:
 *   0  nothing seen yet
 *   1  a map key "country" was seen
 *   2  a map key "iso_code" was seen after that
 *   3  the next string value was copied into CountryCode
 * Once the code is copied the walk stops early: NULL is returned with
 * *status == MMDB_SUCCESS, which ends every enclosing map/array loop.
 *
 * Returns the node following the value just consumed, or NULL on early stop
 * or error. *status is MMDB_SUCCESS, or MMDB_INVALID_DATA_ERROR when the list
 * is malformed (non-string map key, unknown type, or truncated list).
 */
MMDB_entry_data_list_s *dump_data( MMDB_entry_data_list_s *entry_data_list, int *status )
{
    if (entry_data_list == NULL) {
        /* A map promised a value that is not in the list. */
        *status = MMDB_INVALID_DATA_ERROR;
        return NULL;
    }

    switch (entry_data_list->entry_data.type) {
    case MMDB_DATA_TYPE_MAP:
        {
            uint32_t size = entry_data_list->entry_data.data_size;

            for (entry_data_list = entry_data_list->next;
                 size && entry_data_list; size--) {
                const char *key;
                uint32_t keylen;

                if (MMDB_DATA_TYPE_UTF8_STRING !=
                    entry_data_list->entry_data.type) {
                    *status = MMDB_INVALID_DATA_ERROR;
                    return NULL;
                }

                key = entry_data_list->entry_data.utf8_string;
                keylen = entry_data_list->entry_data.data_size;
                if (FoundResult == 0 && mmdb_key_is(key, keylen, "country")) {
                    FoundResult = 1;
                } else if (FoundResult == 1 && mmdb_key_is(key, keylen, "iso_code")) {
                    FoundResult = 2;
                }

                /* Step past the key and consume its value. */
                entry_data_list = dump_data(entry_data_list->next, status);
                if (MMDB_SUCCESS != *status) {
                    return NULL;
                }
            }
        }
        break;
    case MMDB_DATA_TYPE_ARRAY:
        {
            uint32_t size = entry_data_list->entry_data.data_size;

            for (entry_data_list = entry_data_list->next;
                 size && entry_data_list; size--) {
                entry_data_list = dump_data(entry_data_list, status);
                if (MMDB_SUCCESS != *status) {
                    return NULL;
                }
            }
        }
        break;
    case MMDB_DATA_TYPE_UTF8_STRING:
        if (FoundResult == 2) {
            /* Copy at most two bytes and never past data_size: utf8_string
               points into the database and is not NUL-terminated. */
            uint32_t n = entry_data_list->entry_data.data_size;

            if (n > 2) n = 2;
            memcpy(CountryCode, entry_data_list->entry_data.utf8_string, n);
            CountryCode[n] = 0;
            FoundResult = 3;
            *status = MMDB_SUCCESS;
            return NULL;   /* early stop: code found */
        }
        entry_data_list = entry_data_list->next;
        break;
    case MMDB_DATA_TYPE_BYTES:
    case MMDB_DATA_TYPE_DOUBLE:
    case MMDB_DATA_TYPE_FLOAT:
    case MMDB_DATA_TYPE_UINT16:
    case MMDB_DATA_TYPE_UINT32:
    case MMDB_DATA_TYPE_BOOLEAN:
    case MMDB_DATA_TYPE_UINT64:
    case MMDB_DATA_TYPE_UINT128:
    case MMDB_DATA_TYPE_INT32:
        /* Scalars occupy exactly one node. */
        entry_data_list = entry_data_list->next;
        break;
    default:
        *status = MMDB_INVALID_DATA_ERROR;
        return NULL;
    }

    *status = MMDB_SUCCESS;
    return entry_data_list;
}
