/*
 * Copyright (c) 2018 Emmanuel Dreyfus
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. All advertising materials mentioning features or use of this software
 *    must display the following acknowledgement:
 *        This product includes software developed by Emmanuel Dreyfus
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 * OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <sys/types.h>
#include <sys/queue.h>
#include <sys/stat.h>
#include <sys/utsname.h>

#include <netdb.h>
#include <netinet/in.h>
#if defined(AF_INET6) && defined(__NetBSD__)
#include <netinet6/in6.h>
#endif
#include <arpa/nameser.h>
#include <arpa/inet.h>
#include <resolv.h>

#include <ctype.h>
#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <limits.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sysexits.h>
#include <time.h>
#include <unistd.h>

/*
 * Plugin exit statuses: main() returns one of them and panic() exits
 * with STATE_UNKNOWN. Only defined when no previously included header
 * already provides them.
 */
#ifndef STATE_OK
#define STATE_OK        0
#define STATE_WARNING   1
#define STATE_CRITICAL  2
#define STATE_UNKNOWN   3
#endif

/* Default global deadline, in seconds after startup (overridden by -t). */
#define DEFAULT_TIMEOUT 120
/* Error message limit; main()'s shared buffer is ERRMSG_LEN + 1 bytes. */
#define ERRMSG_LEN	256

/*
 * One NS host name taken from a reply (uncompressed NS rdata).
 * Queues of these are kept in strcmp() order by insert_sorted().
 */
struct raddrlist {
	char ral_addr[MAXHOSTNAMELEN + 1];
	TAILQ_ENTRY(raddrlist) ral_next;
};
TAILQ_HEAD(raddrlist_head, raddrlist);

/*
 * One domain to check, as listed with -H (see load_domain()).
 */
struct addrlist {
	char al_addr[MAXHOSTNAMELEN + 1];	/* domain name queried */
	struct raddrlist_head al_local_rep;	/* NS names, local resolver */
	struct raddrlist_head al_remote_rep;	/* NS names, remote server */
	int al_match;				/* set to 1 when both lists are equal */
	SLIST_ENTRY(addrlist) al_next;
};
SLIST_HEAD(addrlist_head, addrlist);

/*
 * Argument block for check_dnsmaster_thread(). Allocated by
 * check_dnsmaster() and freed by the thread when it is done; the
 * pointers it carries refer to data owned by main().
 */
struct check_dnsmaster_args {
	char			*cda_addr;	/* domain to query */
	struct options		*cda_opts;
	union res_sockaddr_union*cda_nameserver_socks;	/* servers to query */
	int			 cda_nameserver_count;	/* 0: system resolver setup */
	struct raddrlist_head	*cda_ralp;	/* receives the NS names */
	pthread_mutex_t		*cda_mtx;	/* guards ralp, errmsg and count */
	pthread_cond_t		*cda_cond;	/* signalled when the thread is done */
	char			*cda_errmsg;	/* shared error message */
	int			*cda_count;	/* number of finished threads */
};


/*
 * Command line settings, initialized by main() and parse_options().
 */
struct options {
	int o_verbose;			/* -v */
	struct timespec o_timeout;	/* absolute deadline: time(NULL) + timeout */
	char *o_servaddr_local;		/* -L; NULL means system resolver */
	char *o_servaddr_remote;	/* -R */
	struct addrlist_head *o_alp;	/* domains collected from -H */
};

/*
 * Forward declarations of the functions defined in this file.
 */
void	 panic(char *msg);
void	 free_addrlist(struct addrlist_head *alp);
void	 usage(char *name);
void	 parse_options(int argc, char **argv, struct options *opts);
void	 load_domain(struct addrlist_head *alp, char *addrs);
int	 check_dnsmaster(char *addr, struct options *opts,
	     union res_sockaddr_union *nameserver_socks, int nameserver_count,
	     struct raddrlist_head *ralp, pthread_mutex_t *mtx,
	     pthread_cond_t *wait_cond, char *errmsg, int *count);
void	*check_dnsmaster_thread(void *args);
int	 main(int argc, char **argv);

/*
 * Print msg, followed by a newline, on standard output and terminate
 * the process with the STATE_UNKNOWN status. Does not return.
 */
void
panic(char *msg)
{
	(void)puts(msg);
	exit(STATE_UNKNOWN);
}

void
free_addrlist(alp)
	struct addrlist_head *alp;
{
	struct addrlist *al;
     
	while (!SLIST_EMPTY(alp)) {
		al = SLIST_FIRST(alp);
		SLIST_REMOVE_HEAD(alp, al_next);
		free(al);
	}

	free(alp);
}


/*
 * Print the command line synopsis, including every option accepted by
 * parse_options(), and exit with EX_USAGE.
 */
void
usage(char *name)
{
	printf("%s [-L local_dns] -R remote_dns -H domain[,domain...] "
	       "[-t timeout] [-v]\n", name);

	exit(EX_USAGE);
}

/*
 * Parse the command line into opts.
 *
 *   -L server   local DNS server (default: system resolver setup)
 *   -R server   remote DNS server
 *   -H domains  comma/space separated domains; may be repeated
 *   -t seconds  global timeout (default DEFAULT_TIMEOUT)
 *   -v          verbose output
 *
 * o_timeout is stored as an absolute deadline (time(NULL) + timeout)
 * because main() hands it to pthread_cond_timedwait(). -h, an unknown
 * option or a malformed -t value prints the usage and exits.
 */
void
parse_options(int argc, char **argv, struct options *opts)
{
	time_t now;
	char *progname = argv[0];
	char *end;
	long secs;
	int ch;

	now = time(NULL);
	opts->o_timeout.tv_sec = now + DEFAULT_TIMEOUT;
	opts->o_timeout.tv_nsec = 0;

	while ((ch = getopt(argc, argv, "H:L:R:hvt:")) != -1) {
		switch (ch) {
		case 'L':
			opts->o_servaddr_local = optarg;
			break;
		case 'R':
			opts->o_servaddr_remote = optarg;
			break;
		case 'H':
			load_domain(opts->o_alp, optarg);
			break;
		case 'v':
			opts->o_verbose = 1;
			break;
		case 't':
			/*
			 * Validate strictly: atoi() turned garbage into 0,
			 * i.e. a deadline that has already expired, and an
			 * unbounded value could overflow now + secs.
			 */
			errno = 0;
			secs = strtol(optarg, &end, 10);
			if (end == optarg || *end != '\0' || errno == ERANGE ||
			    secs < 0 || secs > INT_MAX)
				usage(progname);
			opts->o_timeout.tv_sec = now + secs;
			opts->o_timeout.tv_nsec = 0;
			break;
		case 'h':
		default:
			usage(progname);
			/* NOTREACHED */
			break;
		}
	}
}

/*
 * Split addrs on spaces and commas and push one struct addrlist per
 * domain onto the head of alp, with empty reply lists and al_match
 * cleared. addrs is modified in place by strtok_r(). Names longer than
 * al_addr are truncated by strlcpy(). Allocation failure calls panic().
 */
void
load_domain(struct addrlist_head *alp, char *addrs)
{
	static const char delims[] = " ,";
	char *save = NULL;
	char *tok = strtok_r(addrs, delims, &save);

	while (tok != NULL) {
		struct addrlist *entry = malloc(sizeof(*entry));

		if (entry == NULL)
			panic("malloc failed");

		strlcpy(entry->al_addr, tok, sizeof(entry->al_addr));
		TAILQ_INIT(&entry->al_local_rep);
		TAILQ_INIT(&entry->al_remote_rep);
		entry->al_match = 0;
		SLIST_INSERT_HEAD(alp, entry, al_next);

		tok = strtok_r(NULL, delims, &save);
	}
}

/*
 * Compare two reply lists element by element with strcmp().
 * Returns 0 when both hold the same names in the same order (which,
 * for lists built by insert_sorted(), means the same set), -1 otherwise.
 */
int
addrlistcmp(struct raddrlist_head *ralp1, struct raddrlist_head *ralp2)
{
	struct raddrlist *a, *b;

	for (a = TAILQ_FIRST(ralp1), b = TAILQ_FIRST(ralp2);
	     a != NULL && b != NULL;
	     a = TAILQ_NEXT(a, ral_next), b = TAILQ_NEXT(b, ral_next)) {
		if (strcmp(a->ral_addr, b->ral_addr) != 0)
			return -1;
	}

	/* Equal only if both lists were exhausted together. */
	return (a == NULL && b == NULL) ? 0 : -1;
}

/*
 * Insert ral into ralp, keeping the queue in ascending strcmp() order.
 * An entry equal to existing ones is placed after them.
 */
void
insert_sorted(struct raddrlist_head *ralp, struct raddrlist *ral)
{
	struct raddrlist *cur;

	/* Locate the first element that sorts strictly after ral. */
	for (cur = TAILQ_FIRST(ralp); cur != NULL;
	     cur = TAILQ_NEXT(cur, ral_next)) {
		if (strcmp(cur->ral_addr, ral->ral_addr) > 0)
			break;
	}

	/* None found (including the empty-queue case): append. */
	if (cur == NULL)
		TAILQ_INSERT_TAIL(ralp, ral, ral_next);
	else
		TAILQ_INSERT_BEFORE(cur, ral, ral_next);
}

/*
 * Start a worker thread (check_dnsmaster_thread()) that queries the NS
 * records of addr through nameserver_socks (or the system resolver
 * setup when nameserver_count is 0) and stores them in ralp.
 *
 * The pointers are handed to the thread as-is, not copied. The argument
 * block is allocated here and freed by the thread. Always returns 0;
 * allocation or thread-creation failures terminate through panic().
 */
int
check_dnsmaster(char *addr, struct options *opts,
		union res_sockaddr_union *nameserver_socks, int nameserver_count,
		struct raddrlist_head *ralp, pthread_mutex_t *mtx,
		pthread_cond_t *wait_cond, char *errmsg, int *count)
{
	struct check_dnsmaster_args *cda;
	pthread_t thread;

	/* thread is to free it */
	if ((cda = malloc(sizeof(*cda))) == NULL)
		panic("malloc failed");

	cda->cda_addr = addr;
	cda->cda_opts = opts;
	cda->cda_nameserver_socks = nameserver_socks;
	cda->cda_nameserver_count = nameserver_count;
	cda->cda_ralp = ralp;
	cda->cda_mtx = mtx;
	cda->cda_cond = wait_cond;
	cda->cda_errmsg = errmsg;
	cda->cda_count = count;

	if (pthread_create(&thread, NULL, check_dnsmaster_thread, cda) != 0)
		panic("pthread_create failed");

	/*
	 * Workers report through the condition variable and are never
	 * joined: detach them so their resources are reclaimed on exit.
	 */
	if (pthread_detach(thread) != 0)
		panic("pthread_detach failed");

	return 0;
}


/*
 * Worker thread started by check_dnsmaster().
 *
 * Queries the NS records of cda_addr with a private resolver state,
 * using cda_nameserver_socks when cda_nameserver_count > 0 and the
 * system resolver setup otherwise. Each NS name found is inserted into
 * cda_ralp in sorted order. Then, under cda_mtx, the thread increments
 * *cda_count and, for resolver failures other than "no such name" or
 * "no data", appends a description to cda_errmsg. Finally it signals
 * cda_cond and frees the argument block. Always returns NULL.
 */
void *
check_dnsmaster_thread(void *args)
{
	struct check_dnsmaster_args *cda = args;
	char *queryname = cda->cda_addr;
	struct options *opts = cda->cda_opts;
	union res_sockaddr_union *nameserver_socks = cda->cda_nameserver_socks;
	int nameserver_count = cda->cda_nameserver_count;
	struct raddrlist_head *ralp = cda->cda_ralp;
	pthread_mutex_t *mtx = cda->cda_mtx;
	pthread_cond_t *wait_cond = cda->cda_cond;
	char *errmsg = cda->cda_errmsg;
	int *count = cda->cda_count;

	struct __res_state res;
	unsigned char ans[NS_MAXMSG + 1];
	int anslen;
	int res_error;
	const char *failure = NULL;
	ns_msg handle;
	ns_rr rr;
	int retry = 0;
	int i;

	/*
	 * Initialize a private resolver state. resolver(3) requires the
	 * structure to be zeroed before the first res_ninit() call.
	 */
	memset(&res, 0, sizeof(res));
	if (res_ninit(&res) != 0)
		panic("Cannot init resolver");

	if (nameserver_count > 0)
		res_setservers(&res, nameserver_socks, nameserver_count);

	/*
	 * Send the query at most five times, retrying only on TRY_AGAIN.
	 */
	do {
		retry++;

		anslen = res_nquery(&res, queryname, C_IN, T_NS,
				    ans, NS_MAXMSG + 1);
		res_error = res.res_h_errno;
	} while (anslen == -1 && retry < 5 && res_error == TRY_AGAIN);

	if (anslen == -1 && opts->o_verbose)
		printf("%s try %d error %d %s\n", queryname, retry,
		       res_error,
		       (res_error == -1) ? strerror(errno) : "");

	/* The answer is in ans[]; release the resolver's sockets. */
	res_nclose(&res);

	/*
	 * Collect the NS names from the answer section.
	 */
	if (anslen != -1) {
		int msg_count;

		if (ns_initparse(ans, anslen, &handle) != 0)
			panic("ns_initparse IPv4 failed");

		msg_count = ns_msg_count(handle, ns_s_an);
		for (i = 0; i < msg_count; i++) {
			struct raddrlist *ral;

			if ((ns_parserr(&handle, ns_s_an, i, &rr)) != 0)
				continue;

			if (rr.type != T_NS)
				continue;

			/*
			 * Parse result
			 */
			if ((ral = malloc(sizeof(*ral))) == NULL)
				panic("malloc failed");

			if (ns_name_uncompress(ns_msg_base(handle),
					       ns_msg_end(handle),
					       ns_rr_rdata(rr),
					       ral->ral_addr,
					       sizeof(ral->ral_addr)) < 0)
				panic("ns_name_uncompress failed");

			/*
			 * Store result
			 */
			if (pthread_mutex_lock(mtx) != 0)
				panic("pthread_mutex_lock failed");

			insert_sorted(ralp, ral);

			if (pthread_mutex_unlock(mtx) != 0)
				panic("pthread_mutex_unlock failed");
		}
	}

	/*
	 * Classify a failed query. HOST_NOT_FOUND and NO_DATA are not
	 * errors here: no NS name is recorded for this server.
	 */
	if (anslen == -1) {
		switch (res_error) {
		case 0:		/* NETDB_SUCCESS */
		case HOST_NOT_FOUND:
		case NO_DATA:
			break;
		case -1:	/* NETDB_INTERNAL */
		case TRY_AGAIN:
			failure = ": resolver timeout";
			break;
		case NO_RECOVERY:
			failure = ":fatal resolver error";
			break;
		default:
			failure = ":unexepected resolver error";
			break;
		}
	}

	if (pthread_mutex_lock(mtx) != 0)
		panic("pthread_mutex_lock failed");

	(*count)++;

	/*
	 * Append to the shared message after its current end. Passing
	 * errmsg both as the snprintf() destination and as a "%s" argument
	 * is undefined behaviour. The total stays below ERRMSG_LEN chars.
	 */
	if (failure != NULL) {
		size_t len = strlen(errmsg);

		if (len < ERRMSG_LEN)
			snprintf(errmsg + len, ERRMSG_LEN - len, "%s%s%s",
				 len ? ", " : "", queryname, failure);
	}

	if (pthread_mutex_unlock(mtx) != 0)
		panic("pthread_mutex_unlock failed");

	if (pthread_cond_signal(wait_cond) != 0)
		panic("pthread_cond_signal failed");

	free(cda);
	return NULL;
}

/*
 * Resolve name (host name or numeric address) into the socket address
 * array consumed by res_setservers(). Only AF_INET and AF_INET6 results
 * are kept, with the "domain" service port.
 *
 * On success, stores a malloc'ed array in *addrsp and its (non-zero)
 * length in *countp, then returns 0. On failure, prints a warning and
 * returns -1, leaving *addrsp and *countp untouched. Allocation failure
 * exits through err().
 */
int
get_server_addr(char *name, union res_sockaddr_union **addrsp, int *countp)
{
	struct addrinfo hints, *res, *res0;
	int error;
	int count;
	union res_sockaddr_union *addrs;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_DGRAM;
	error = getaddrinfo(name, "domain", &hints, &res0);
	if (error) {
		warnx("getaddrinfo %s failed: %s", name, gai_strerror(error));
		return -1;
	}

	count = 0;
	for (res = res0; res; res = res->ai_next) {
		if (res->ai_family == AF_INET || res->ai_family == AF_INET6)
			count++;
	}

	/*
	 * No usable address. Fail cleanly instead of calling malloc(0),
	 * which may legitimately return NULL and be misreported as an
	 * allocation failure.
	 */
	if (count == 0) {
		warnx("no IPv4 or IPv6 address found for %s", name);
		freeaddrinfo(res0);
		return -1;
	}

	if ((addrs = calloc(count, sizeof(*addrs))) == NULL)
		err(1, "malloc failed");

	count = 0;
	for (res = res0; res; res = res->ai_next) {
		switch (res->ai_family) {
		case AF_INET:
			memcpy(&addrs[count].sin, res->ai_addr,
			       sizeof(addrs[count].sin));
			count++;
			break;
		case AF_INET6:
			memcpy(&addrs[count].sin6, res->ai_addr,
			       sizeof(addrs[count].sin6));
			count++;
			break;
		default:
			break;
		}
	}
	freeaddrinfo(res0);

	*addrsp = addrs;
	*countp = count;

	return 0;
}

/*
 * Compare the NS records of every domain given with -H. The local set
 * comes from the local resolver (-L, or the system setup when -L is
 * absent) and the remote set from the remote server (-R).
 *
 * Two worker threads are started per domain. main then waits for all of
 * them until the -t deadline and prints one status line. It returns
 * STATE_OK when all NS sets match, STATE_CRITICAL on at least one
 * mismatch, and STATE_UNKNOWN on a resolver error or global timeout.
 */
int
main(int argc, char **argv)
{
	struct options opts;
	union res_sockaddr_union *nameserver_socks_local = NULL;
	union res_sockaddr_union *nameserver_socks_remote = NULL;
	int nameserver_count_local = 0;
	int nameserver_count_remote = 0;
	int status_code = STATE_UNKNOWN;
	struct addrlist *al;
	struct addrlist_head *alp;
	pthread_mutex_t mtx = PTHREAD_MUTEX_INITIALIZER;
	pthread_cond_t wait_cond = PTHREAD_COND_INITIALIZER;
	struct raddrlist *ral;
	char errmsg[ERRMSG_LEN + 1] = "";
	int count = 0;		/* finished workers, protected by mtx */
	int total_count = 0;	/* started workers */

	/*
	 * Initialize address list
	 */
	if ((alp = malloc(sizeof(*alp))) == NULL)
		err(EX_OSERR, "failed to malloc addrlist head");
	SLIST_INIT(alp);

	opts.o_verbose = 0;
	opts.o_servaddr_local = NULL;
	opts.o_servaddr_remote = NULL;
	opts.o_alp = alp;

	parse_options(argc, argv, &opts);

	if (SLIST_EMPTY(alp))
		panic("No domain specified");

	if (opts.o_servaddr_local == NULL) {
		if (opts.o_verbose)
			printf("No local DNS provided, using system\n");
		/* Use system setup */
	} else {
		if (get_server_addr(opts.o_servaddr_local,
				    &nameserver_socks_local,
				    &nameserver_count_local) != 0)
			panic("local DNS server resolution failed");
	}

	if (opts.o_servaddr_remote == NULL)
		panic("No remote DNS specified");

	if (get_server_addr(opts.o_servaddr_remote,
			    &nameserver_socks_remote,
			    &nameserver_count_remote) != 0)
		panic("remote DNS server resolution failed");

	/*
	 * Run the resolvers: one worker per (domain, server) pair.
	 */
	SLIST_FOREACH(al, alp, al_next) {
		if (opts.o_verbose)
			printf("trying %s\n", al->al_addr);

		total_count++;
		check_dnsmaster(al->al_addr, &opts,
				nameserver_socks_local,
				nameserver_count_local,
				&al->al_local_rep, &mtx,
				&wait_cond, errmsg, &count);

		total_count++;
		check_dnsmaster(al->al_addr, &opts,
				nameserver_socks_remote,
				nameserver_count_remote,
				&al->al_remote_rep, &mtx,
				&wait_cond, errmsg, &count);
	}

	/*
	 * Wait for replies.
	 *
	 * Test the predicate before every wait. Workers may finish and
	 * signal before we hold the mutex, and a condition variable does
	 * not remember a signal nobody was waiting for.
	 * pthread_cond_timedwait() reports errors through its return
	 * value; it does not set errno.
	 */
	if (pthread_mutex_lock(&mtx) != 0)
		panic("pthread_mutex_lock failed");

	while (count < total_count) {
		int res;

		res = pthread_cond_timedwait(&wait_cond,
					     &mtx, &opts.o_timeout);
		if (res == ETIMEDOUT) {
			size_t len = strlen(errmsg);

			/*
			 * The last reply may have raced with the deadline.
			 * Append after the current end, because snprintf()
			 * with overlapping source and destination is
			 * undefined.
			 */
			if (count < total_count && len < ERRMSG_LEN)
				snprintf(errmsg + len, ERRMSG_LEN - len,
					 "%sglobal timeout with %d/%d replies",
					 len ? ", " : "",
					 count, total_count);
			break;
		}

		if (res != 0)
			panic("pthread_cond_timedwait failed");
	}

	/*
	 * mtx deliberately stays locked from here on. Workers still
	 * running after a timeout block on it instead of modifying the
	 * reply lists or errmsg while they are read below.
	 */

	/*
	 * Display results
	 */
	status_code = (strlen(errmsg) == 0) ? STATE_OK : STATE_UNKNOWN;
	SLIST_FOREACH(al, alp, al_next) {
		if (opts.o_verbose) {
			char *servaddr_local = opts.o_servaddr_local;

			if (servaddr_local == NULL)
				servaddr_local = "system resolver";

			printf("--- %s from %s\n", al->al_addr,
			       servaddr_local);
			printf("+++ %s from %s\n", al->al_addr,
			       opts.o_servaddr_remote);

			TAILQ_FOREACH(ral, &al->al_local_rep, ral_next)
				printf("- %s\n", ral->ral_addr);
			TAILQ_FOREACH(ral, &al->al_remote_rep, ral_next)
				printf("+ %s\n", ral->ral_addr);

			printf("\n");
		}

		if (addrlistcmp(&al->al_local_rep, &al->al_remote_rep) == 0) {
			al->al_match = 1;
		} else {
			/* A resolver error (UNKNOWN) takes precedence. */
			if (status_code == STATE_OK)
				status_code = STATE_CRITICAL;
		}
	}

	switch (status_code) {
	case STATE_OK:
		printf("OK: ");
		SLIST_FOREACH(al, alp, al_next)
			printf("%s ", al->al_addr);
		printf("\n");
		break;
	case STATE_CRITICAL:
		printf("CRITICAL: DNS master mismatches for ");
		SLIST_FOREACH(al, alp, al_next) {
			if (al->al_match == 0)
				printf("%s ", al->al_addr);
		}
		printf("\n");
		break;
	default:
		printf("UNKNOWN: %s\n", errmsg);
		break;
	}

	/*
	 * Do not free anything as some other thread may still use it.
	 * exit() will cleanup anyway
	 */

	return status_code;
}
