/* Copyright (C) 2002, 2003 Thorsten Kukuk
   Author: Thorsten Kukuk <kukuk@suse.de>

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License version 2 as
   published by the Free Software Foundation.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software Foundation,
   Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.  */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#define _GNU_SOURCE

#include <assert.h>
#include <getopt.h>
#include <errno.h>
#include <ctype.h>
#include <grp.h>
#include <pwd.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <termios.h>
#include <locale.h>
#include <libintl.h>
#include <arpa/inet.h>
#include <sys/param.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <syslog.h>
#include <signal.h>
#include <string.h>
#include <netdb.h>

#include <security/pam_appl.h>
#include <security/pam_misc.h>

#include <openssl/crypto.h>
#include <openssl/x509.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <openssl/err.h>

#include "dbg_log.h"
#include "rpasswd-client.h"
#include "error_codes.h"

/* Shorthand used throughout this file to mark translatable strings and
   look them up with gettext().  It is only defined here when nothing
   earlier has already defined `_'.  */
#ifndef _
#define _(String) gettext (String)
#endif

/* Write the program name, package name and version, followed by the
   (translated) copyright and warranty notice, to standard output.  */
static void
print_version (const char *program)
{
  static const char copyright_years[] = "2002, 2003";

  printf ("%s (%s) %s\n", program, PACKAGE, VERSION);
  printf (gettext ("Copyright (C) %s Thorsten Kukuk.\n"
		   "This is free software; see the source for copying "
		   "conditions.  There is NO\n"
		   "warranty; not even for MERCHANTABILITY or FITNESS FOR A "
		   "PARTICULAR PURPOSE.\n"),
	  copyright_years);
}

/* Print the one-line command synopsis for PROGRAM to STREAM.  */
static void
print_usage (FILE * stream, const char *program)
{
  const char *synopsis =
    _("Usage: %s [-4|-6][-a][-f config-file][-h hostname][-p port][-v][name]\n");

  fprintf (stream, synopsis, program);
}

/* Print the usage synopsis, a one-line description of PROGRAM and the
   list of supported options to standard output.  */
static void
print_help (const char *program)
{
  FILE *out = stdout;

  print_usage (out, program);
  fprintf (out, _("%s - change password information\n\n"), program);

  /* One fputs per option so that every line is its own translatable
     message.  */
  fputs (_("  -4             Use IPv4 only\n"), out);
  fputs (_("  -6             Use IPv6 only\n"), out);
  fputs (_("  -a             Admin mode, special admin password is required\n"),
	 out);
  fputs (_("  -f config-file Specify a different config file\n"), out);
  fputs (_("  -h hostname    Specify the remote server\n"), out);
  fputs (_("  -p port        Specify port remote server is listening on\n"),
	 out);
  fputs (_("  -v, --verbose  Be verbose, print SSL connection data\n"), out);
  fputs (_("  --help         Give this help list\n"), out);
  fputs (_("  --usage        Give a short usage message\n"), out);
  fputs (_("  --version      Print program version\n"), out);
}

/* Tell the user on stderr how to obtain help for PROGRAM; used after a
   command line error.  */
static void
print_error (const char *program)
{
  const char *hint =
    _("Try `%s --help' or `%s --usage' for more information.\n");

  fprintf (stderr, hint, program, program);
}

/* Open the conversation with the server on SSL.

   A request_header is sent first, then the value of $LANG (only when it
   is set) and finally USERNAME; both strings are sent including their
   terminating NUL.  ADMIN_MODE selects a START_ADMIN request instead of
   a plain START.

   Returns 0 on success and -1 when USERNAME is NULL or when any write
   fails or is short.  */
static int
start_request (SSL * ssl, char *username, int admin_mode)
{
  request_header req;
  char *locale = getenv ("LANG");

  /* main() hands over an unchecked strdup() result.  */
  if (username == NULL)
    return -1;

  /* The header layout is declared elsewhere; clear it first so that no
     byte which is not explicitly set below (padding or otherwise) is
     sent to the peer uninitialised.  */
  memset (&req, 0, sizeof (req));

  if (admin_mode)
    req.request = START_ADMIN;
  else
    req.request = START;
  req.version = RPASSWD_VERSION;
  req.data_len = strlen (username) + 1;
  if (locale)
    req.locale_len = strlen (locale) + 1;
  else
    req.locale_len = 0;

  if (SSL_write (ssl, &req, sizeof (request_header)) !=
      sizeof (request_header))
    return -1;

  if (locale)
    if (SSL_write (ssl, locale, req.locale_len) != req.locale_len)
      return -1;

  if (SSL_write (ssl, username, req.data_len) != req.data_len)
    return -1;

  return 0;
}

/* Read a line of input string, giving prompt when appropriate.  */
static int
read_string (int echo, const char *prompt, char **retstr)
{
  struct termios term_before, term_tmp;
  char line[PAM_MAX_MSG_SIZE];
  int nc = -1, have_term = 0;

  D (("called with echo='%s', prompt='%s'.", echo ? "ON" : "OFF", prompt));

  if (isatty (STDIN_FILENO))
    {				/* terminal state */
      /* is a terminal so record settings and flush it */
      if (tcgetattr (STDIN_FILENO, &term_before) != 0)
	{
	  fprintf (stderr, ("Error: failed to get terminal settings\n"));
	  *retstr = NULL;
	  return -1;
	}
      memcpy (&term_tmp, &term_before, sizeof (term_tmp));
      if (!echo)
	term_tmp.c_lflag &= ~(ECHO);

      have_term = 1;
    }
  else if (!echo)
      fprintf (stderr, _("Warning: cannot turn echo off\n"));

  /* reading the line */
  fprintf (stderr, "%s", prompt);
  /* this may, or may not set echo off -- drop pending input */
  if (have_term)
    (void) tcsetattr (STDIN_FILENO, TCSAFLUSH, &term_tmp);

  nc = read (STDIN_FILENO, line, PAM_MAX_MSG_SIZE - 1);
  if (have_term)
    {
      (void) tcsetattr (STDIN_FILENO, TCSADRAIN, &term_before);
      if (!echo)	/* do we need a newline? */
	fprintf (stderr, "\n");
    }

  if (nc > 0) /* We got some user input.  */
    {
      if (line[nc - 1] == '\n') /* <NUL> terminate */
	line[--nc] = '\0';
      else
	line[nc] = '\0';

      *retstr = x_strdup (line);	/* return malloc()ed string */
      _pam_overwrite (line);

      return nc;
    }
  else if (nc == 0) /* Ctrl-D */
    {
      D (("user did not want to type anything"));
      fprintf (stderr, "\n");
    }

  /* getting here implies that there was an error or Ctrl-D pressed.  */
  if (have_term)
    (void) tcsetattr (STDIN_FILENO, TCSADRAIN, &term_before);

  memset (line, 0, PAM_MAX_MSG_SIZE);	/* clean up */
  *retstr = NULL;
  return nc;
}

/* Send one conversation reply to the server on SSL: a conv_header
   carrying RET and the payload length, followed by STR (including its
   terminating NUL) when STR is not NULL.

   Returns E_SUCCESS when everything was written and E_FAILURE when a
   write failed or was short.  */
static int
send_string (SSL * ssl, u_int32_t ret, const char *str)
{
  conv_header hdr;

  hdr.retval = ret;
  hdr.data_len = (str != NULL) ? strlen (str) + 1 : 0;

  if (TEMP_FAILURE_RETRY (SSL_write (ssl, &hdr, sizeof (hdr)))
      != sizeof (hdr))
    return E_FAILURE;

  /* A header with zero length is the complete reply.  */
  if (str == NULL)
    return E_SUCCESS;

  if (TEMP_FAILURE_RETRY (SSL_write (ssl, str, hdr.data_len))
      != hdr.data_len)
    return E_FAILURE;

  return E_SUCCESS;
}

#define CONV_ECHO_ON  1		/* types of echo state */
#define CONV_ECHO_OFF 0

/* Overwrite the NUL-terminated string S with zero bytes and free it.
   The write goes through a volatile pointer so that it is not dropped
   as a dead store.  A NULL S is ignored and errno is left untouched.  */
static void
wipe_and_free (char *s)
{
  int saved_errno = errno;

  if (s != NULL)
    {
      volatile char *p = s;

      while (*p != '\0')
	*p++ = '\0';
      free (s);
    }
  errno = saved_errno;
}

/* Process messages from the server on SSL until a FINISH message is
   seen.

   Every message is a response_header followed by data_len bytes of
   payload.  TEXT_INFO payloads are printed on stdout, ERROR_MSG
   payloads on stderr.  For PROMPT_ECHO_ON/PROMPT_ECHO_OFF the payload
   is used as prompt, a line is read from the user and sent back with
   PAM_SUCCESS, or with PAM_CONV_ERR when reading failed.  Unknown
   message types are skipped.

   Returns the first payload byte of the FINISH message, or E_FAILURE
   when reading from or writing to the server fails, when the announced
   payload length is not usable, when memory runs out, or when FINISH
   carries no payload.  */
static int
handle_responses (SSL *ssl)
{
  response_header resp;
  int retval = E_SUCCESS;
  int finished = 0;

  while (!finished)
    {
      char err_buf[256];
      char *buf;
      size_t len;
      int want;

      errno = 0;
      if (TEMP_FAILURE_RETRY (SSL_read (ssl, &resp, sizeof (resp)))
	  != sizeof (resp))
	{
	  if (errno == 0)
	    fprintf (stderr, _("error while reading request: %s"),
		     _("wrong data received"));
	  else
	    fprintf (stderr, _("error while reading request: %s"),
		     strerror_r (errno, err_buf, sizeof (err_buf)));
	  fputs ("\n", stderr);
	  return E_FAILURE;
	}

      /* The length comes from the peer.  SSL_read() takes an int, so
	 refuse anything that does not survive the conversion; this also
	 makes the "+ 1" below safe.  */
      len = resp.data_len;
      want = (int) len;
      if (want < 0 || (size_t) want != len)
	{
	  fprintf (stderr, _("error while reading request data: %s"),
		   _("wrong data received"));
	  fputs ("\n", stderr);
	  return E_FAILURE;
	}

      /* Heap instead of alloca(): the size is not under our control and
	 the buffer must not pile up on the stack for every message.  One
	 extra byte guarantees NUL termination for the %s uses below.  */
      buf = malloc (len + 1);
      if (buf == NULL)
	{
	  fprintf (stderr, _("error while reading request data: %s"),
		   strerror_r (errno, err_buf, sizeof (err_buf)));
	  fputs ("\n", stderr);
	  return E_FAILURE;
	}

      if (want > 0
	  && TEMP_FAILURE_RETRY (SSL_read (ssl, buf, want)) != want)
	{
	  fprintf (stderr, _("error while reading request data: %s"),
		   strerror_r (errno, err_buf, sizeof (err_buf)));
	  fputs ("\n", stderr);
	  free (buf);
	  return E_FAILURE;
	}
      buf[len] = '\0';

      switch (resp.type)
	{
	case TEXT_INFO:
	  printf ("%s\n", buf);
	  break;
	case ERROR_MSG:
	  fprintf (stderr, "%s\n", buf);
	  break;
	case PROMPT_ECHO_OFF:
	case PROMPT_ECHO_ON:
	  {
	    char *string = NULL;
	    int echo = (resp.type == PROMPT_ECHO_ON
			? CONV_ECHO_ON : CONV_ECHO_OFF);
	    int nc = read_string (echo, buf, &string);

	    retval = send_string (ssl, nc < 0 ? PAM_CONV_ERR : PAM_SUCCESS,
				  string);
	    /* The answer is typically a password: do not leave it lying
	       around in memory.  */
	    wipe_and_free (string);

	    if (retval != 0)
	      {
		fprintf (stderr, _("Cannot send input back to server: %s\n"),
			 strerror_r (errno, err_buf, sizeof (err_buf)));
		free (buf);
		return E_FAILURE;
	      }
	  }
	  break;
	case FINISH:
	  /* The first payload byte is the final status; a FINISH without
	     payload is treated as a failure.  */
	  retval = (len > 0) ? buf[0] : E_FAILURE;
	  finished = 1;
	  break;
	default:
	  break;
	}

      free (buf);
    }

  return retval;
}

/* Return the first white space delimited word that follows the keyword
   of KEYLEN bytes at the start of LINE.  The word is NUL-terminated in
   place, so LINE is modified.  Returns NULL when only white space
   follows the keyword.  */
static char *
config_word (char *line, size_t keylen)
{
  char *word = line + keylen;
  char *end;

  while (isspace ((unsigned char) *word))
    ++word;
  if (*word == '\0')
    return NULL;

  for (end = word; *end != '\0' && !isspace ((unsigned char) *end); ++end)
    ;
  *end = '\0';

  return word;
}

/* Load the config file (/etc/rpasswd.conf)

   CONFIGFILE is read line by line.  Everything after a `#' is a
   comment; leading and trailing white space and empty lines are
   ignored.  Recognised entries are

     server <name>      *HOSTP receives a strdup()ed copy of <name>
     port <port>        *PORTP receives a strdup()ed copy of <port>
     reqcert <level>    *REQCERTP is set to 0 (never), 1 (allow),
                        2 (try) or 3 (demand, hard)

   HOSTP, PORTP and REQCERTP may be NULL, in which case the matching
   entry is accepted but its value is not stored.  Any other line is
   reported as invalid: on stdout when CHECK_SYNTAX is set, on stderr
   otherwise.  With VERBOSE greater than 1 or with CHECK_SYNTAX set each
   entry is printed before it is evaluated.

   Returns 0 when at least one valid entry was found (and, with
   CHECK_SYNTAX, no invalid one); returns 1 otherwise or when CONFIGFILE
   cannot be opened.  */
static int
load_config (const char *configfile, int verbose, int check_syntax,
	     char **hostp, char **portp, int *reqcertp)
{
  FILE *fp;
  char *buf = NULL;
  size_t buflen = 0;
  int have_entries = 0; /* # of entries we found in config file */
  int bad_entries = 0;

  fp = fopen (configfile, "r");
  if (NULL == fp)
    return 1;

  if (verbose > 1)
    printf (_("parsing config file"));

  while (!feof (fp))
    {
      char *tmp, *cp;
      size_t len;
#if defined(HAVE_GETLINE)
      ssize_t n = getline (&buf, &buflen, fp);
#elif defined (HAVE_GETDELIM)
      ssize_t n = getdelim (&buf, &buflen, '\n', fp);
#else
      ssize_t n;

      if (buf == NULL)
        {
          buflen = 8096;
          buf = malloc (buflen);
          if (buf == NULL)
            break;
        }
      if (fgets (buf, (int) buflen, fp) != NULL)
        n = strlen (buf);
      else
        n = 0;
#endif /* HAVE_GETLINE / HAVE_GETDELIM */

      if (n < 1)
        break;

      cp = buf;

      tmp = strchr (cp, '#');  /* remove comments */
      if (tmp)
        *tmp = '\0';
      while (isspace ((unsigned char) *cp))    /* remove spaces and tabs */
        ++cp;
      if (*cp == '\0')        /* ignore empty lines */
        continue;

      /* Drop the newline and any other trailing white space so that a
         value is not rejected just because blanks follow it.  cp[0] is
         not white space here, so LEN never reaches zero.  */
      len = strlen (cp);
      while (len > 0 && isspace ((unsigned char) cp[len - 1]))
        cp[--len] = '\0';

      if (verbose > 1)
        printf ("%s %s", _("Trying entry:"), cp);

      if (check_syntax)
        printf ("%s %s\n", _("Trying entry:"), cp);

      if (strncmp (cp, "server", 6) == 0 && isspace ((unsigned char) cp[6]))
	{
	  /* The value is taken directly from the line instead of being
	     copied into a fixed-size buffer with an unbounded "%s".  */
	  char *word = config_word (cp, 6);

	  if (word != NULL)
	    {
	      if (hostp != NULL)
		*hostp = strdup (word);
	      ++have_entries;
	      continue;
	    }
	}
      else if (strncmp (cp, "port", 4) == 0 && isspace ((unsigned char) cp[4]))
	{
	  char *word = config_word (cp, 4);

	  if (word != NULL)
	    {
	      if (portp != NULL)
		*portp = strdup (word);
	      ++have_entries;
	      continue;
	    }
	}
      else if (strncmp (cp, "reqcert", 7) == 0
	       && isspace ((unsigned char) cp[7]))
	{
	  const char *p = &cp[7];
	  int level = -1;

	  while (isspace ((unsigned char) *p))
	    ++p;

	  /* The whole remainder has to match; an unknown level falls
	     through and is reported as an invalid entry.  */
	  if (strcmp (p, "never") == 0)
	    level = 0;
	  else if (strcmp (p, "allow") == 0)
	    level = 1;
	  else if (strcmp (p, "try") == 0)
	    level = 2;
	  else if (strcmp (p, "demand") == 0 ||
		   strcmp (p, "hard") == 0)
	    level = 3;

	  if (level >= 0)
	    {
	      if (reqcertp != NULL)
		*reqcertp = level;
	      ++have_entries;
	      continue;
	    }
	}

      if (check_syntax)
        {
          printf (_("Entry \"%s\" is not valid!\n"), cp);
          ++bad_entries;
        }
      else
        fprintf (stderr, _("Entry \"%s\" is not valid, ignored!\n"), cp);
    }
  fclose (fp);

  free (buf);

  if (check_syntax)
    {
      if (bad_entries)
        {
          printf (_("Bad entries found.\n"));
          return 1;
        }
      if (!have_entries)
        {
          printf (_("No entry found.\n"));
          return 1;
        }
    }

  if (!have_entries)
    {
      if (verbose > 1)
	printf (_("No entry found."));
      return 1;
    }

  return 0;
}

/* rpasswd client entry point.

   Parses the command line, fills in server, port and certificate policy
   from the config file for whatever was not given as an option,
   determines the user name (the optional argument, otherwise the name
   belonging to the calling uid), connects to the server, sets up SSL,
   checks the server certificate according to the `reqcert' level and
   then runs the password conversation.

   The `reqcert' level is used as follows: with 0 the certificate is
   only looked at in verbose mode; from 1 on problems are reported; from
   2 on a certificate that does not verify is fatal; from 3 on a missing
   certificate is fatal as well.

   Returns the status delivered by handle_responses(), or one of the
   E_* codes when something fails earlier.  */
int
main (int argc, char **argv)
{
  const char *config_file = _PATH_RPASSWDCONF;
  const char *program = "rpasswd";
  char *hostp = NULL, *portp = NULL;
  int sock = -1, ipv4 = 1, ipv6 = 1;
  int verbose = 0;
  int admin_mode = 0;
  int reqcert = 3;
  int retval;
  char *username;
  SSL_CTX *ctx;
  SSL *ssl;


  /* Set locale via LC_ALL.  */
  setlocale (LC_ALL, "");
  /* Set the text message domain.  */
  textdomain (PACKAGE);

  /* Ignore all signals which can make trouble later.  */
  signal (SIGXFSZ, SIG_IGN);
  signal (SIGPIPE, SIG_IGN);

  /* Parse program arguments */
  while (1)
    {
      int c;
      int option_index = 0;
      static struct option long_options[] = {
	{"admin", no_argument, NULL, 'a'},
	{"config-file", required_argument, NULL, 'f'},
	{"host", required_argument, NULL, 'h'},
	{"ipv4", no_argument, NULL, '4'},
	{"ipv6", no_argument, NULL, '6'},
	/* Like -p, --port takes a value; with no_argument optarg would
	   be NULL and the option silently ignored.  */
	{"port", required_argument, NULL, 'p'},
	{"verbose", no_argument, NULL, 'v'},
	{"version", no_argument, NULL, '\255'},
	{"usage", no_argument, NULL, '\254'},
	{"help", no_argument, NULL, '\253'},
	{NULL, 0, NULL, '\0'}
      };

      c = getopt_long (argc, argv, "af:h:p:v46", long_options, &option_index);
      if (c == EOF)
	break;
      switch (c)
	{
	case 'a':
	  admin_mode = 1;
	  break;
	case 'f':
	  config_file = optarg;
	  break;
	case 'h':
	  hostp = optarg;
	  break;
	case '4':
	  /* -4 and -6 exclude each other and may be given only once.  */
	  if (ipv4 == 0 || ipv6 == 0)
	    {
	      print_usage (stderr, program);
	      return E_USAGE;
	    }
	  ipv6 = 0;
	  break;
	case '6':
	  if (ipv4 == 0 || ipv6 == 0)
	    {
	      print_usage (stderr, program);
	      return E_USAGE;
	    }
	  ipv4 = 0;
	  break;
	case 'p':
	  portp = optarg;
	  break;
	case 'v':
	  verbose++;
	  break;
	case '\253':
	  print_help (program);
	  return 0;
	case '\255':
	  print_version (program);
	  return 0;
	case '\254':
	  print_usage (stdout, program);
	  return E_USAGE;
	default:
	  print_error (program);
	  return E_BAD_ARG;
	}
    }

  argc -= optind;
  argv += optind;

  if (argc > 1)
    {
      fprintf (stderr, _("%s: Too many arguments\n"), program);
      print_error (program);
      return E_USAGE;
    }

  /* Values given on the command line win: only ask the config file for
     what is still missing.  */
  if (hostp)
    {
      if (portp)
	load_config (config_file, verbose, 0, NULL, NULL, &reqcert);
      else
	load_config (config_file, verbose, 0, NULL, &portp, &reqcert);
    }
  else
    {
      if (portp)
	load_config (config_file, verbose, 0, &hostp, NULL, &reqcert);
      else
	load_config (config_file, verbose, 0, &hostp, &portp, &reqcert);
    }

  if (portp == NULL)
    portp = "rpasswd";

  /* Get the login name of the calling user. This could be the one
     argument we still have or we use getpwuid/getuid to determine
     the login name.  */
  if (argc == 1)
    username = strdup (argv[0]);
  else
    {
      int pw_buflen = 256;
      char *pw_buffer = alloca (pw_buflen);
      struct passwd pw_resultbuf;
      struct passwd *pw = NULL;

      /* getpwuid_r reports a too small buffer through its return
	 value, not through errno.  */
      while (getpwuid_r (getuid (), &pw_resultbuf, pw_buffer, pw_buflen,
			 &pw) == ERANGE)
	{
	  pw_buflen += 256;
	  pw_buffer = alloca (pw_buflen);
	}
      if (pw == NULL)
	{
	  fprintf (stderr, _("Go away, you do not exist!"));
	  fputs ("\n", stderr);
	  return E_UNKNOWN_USER;
	}
      username = strdup (pw->pw_name);
    }

  if (username == NULL)
    {
      fprintf (stderr, "%s: %s\n", program, strerror (errno));
      return E_FAILURE;
    }

  if (hostp != NULL)
    {
#ifdef NI_WITHSCOPEID
      const int niflags = NI_NUMERICHOST | NI_WITHSCOPEID;
#else
      const int niflags = NI_NUMERICHOST;
#endif
      struct addrinfo hints, *res, *res0;
      int error;

      memset (&hints, 0, sizeof (hints));
      if (ipv4 && ipv6)
	hints.ai_family = PF_UNSPEC;
      else if (ipv6)
	hints.ai_family = PF_INET6;
      else if (ipv4)
	hints.ai_family = PF_INET;
      else
	hints.ai_family = PF_UNSPEC;

      hints.ai_socktype = SOCK_STREAM;
      hints.ai_flags = AI_CANONNAME;

      error = getaddrinfo (hostp, portp, &hints, &res0);
      if (error)
	{
	  if (error == EAI_NONAME)
	    {
	      fprintf (stderr,
		       _("\
Hostname or service not known for specified protocol\n"));
	      return E_FAILURE;
	    }
	  else if (error == EAI_SERVICE)
	    {
	      /* if port cannot be resolved, try compiled in
		 port number. If this works, don't abort here.  */
	      char portbuf[16];

	      snprintf (portbuf, sizeof (portbuf), "%d", RPASSWDD_PORT);
	      error = getaddrinfo (hostp, portbuf, &hints, &res0);
	      if (error)
		{
		  fprintf (stderr, _("bad port: %s\n"), portp);
		  return E_FAILURE;
		}
	    }
	  else
	    {
	      fprintf (stderr, "%s: %s\n", hostp, gai_strerror (error));
	      return E_FAILURE;
	    }
	}

      /* Try every address until one accepts the connection.  */
      for (res = res0; res; res = res->ai_next)
	{
	  char hbuf[NI_MAXHOST];

	  if (getnameinfo (res->ai_addr, res->ai_addrlen,
			   hbuf, sizeof (hbuf), NULL, 0, niflags) != 0)
	    strcpy (hbuf, "(invalid)");
	  printf (_("Trying %s...\r\n"), hbuf);

	  /* Create the socket.  */
	  sock = socket (res->ai_family, res->ai_socktype, res->ai_protocol);
	  if (sock < 0)
	    continue;

	  if (connect (sock, res->ai_addr, res->ai_addrlen) < 0)
	    {
	      if (getnameinfo (res->ai_addr, res->ai_addrlen,
			       hbuf, sizeof (hbuf), NULL, 0, niflags) != 0)
		strcpy (hbuf, "(invalid)");
	      fprintf (stderr, _("connect to address %s: %s\n"), hbuf,
		       strerror (errno));
	      close (sock);
	      sock = -1;
	      continue;
	    }
	  fputs ("\n", stdout);
	  break;
	}
      freeaddrinfo (res0);
      if (sock < 0)
	return E_FAILURE;
    }
  else
    {
      fprintf (stderr, _("No server specified\n"));
      return E_USAGE;
    }

  /* Do SSL */
  {
    X509 *server_cert;
    char *str;
    SSL_METHOD *meth;
    long verify_result;
    int err;

    SSLeay_add_ssl_algorithms ();
    meth = SSLv23_client_method ();
    SSL_load_error_strings ();
    ctx = SSL_CTX_new (meth);
    if (ctx == NULL)
      {
	/* Never use the library's message as a format string.  */
	fprintf (stderr, "%s\n", ERR_error_string (ERR_get_error (), NULL));
	return E_SSL_FAILURE;
      }

#if 0
    /* This is only necessary if we configure a unusual path.
       XXX Make this a program option.  */
    if (!SSL_CTX_load_verify_locations (ctx, NULL, "/etc/ssl/certs"))
      {
	fprintf (stderr, _("error loading default verify locations: %s\n"),
		 ERR_error_string (ERR_get_error (), NULL));
	if (reqcert > 1)
	  return E_SSL_FAILURE;
      }
#endif
    if (!SSL_CTX_set_default_verify_paths(ctx))
      {
	fprintf (stderr, _("error setting default verify path: %s\n"),
		 ERR_error_string (ERR_get_error (), NULL));
	if (reqcert > 1)
	  return E_SSL_FAILURE;
      }

    /* Now we have a TCP connection. Start SSL negotiation. */
    ssl = SSL_new (ctx);
    if (ssl == NULL)
      {
	fprintf (stderr, "%s\n", ERR_error_string (ERR_get_error (), NULL));
	return E_SSL_FAILURE;
      }
    SSL_set_fd (ssl, sock);
    err = SSL_connect (ssl);
    if (err < 1)
      {
	/* The return value of SSL_connect is not an error code; the
	   reason has to be taken from the error queue.  */
	fprintf (stderr, "SSL_connect: %s\n",
		 ERR_error_string (ERR_get_error (), NULL));
	close (sock);
	return E_SSL_FAILURE;
      }

    if (reqcert > 0 || verbose)
      {
	/* Get server's certificate (note: beware of dynamic allocation).  */
	server_cert = SSL_get_peer_certificate (ssl);

	/* Verify server's certificate.  */
	verify_result = SSL_get_verify_result (ssl);

	/* Printing the connection data is optional and not required for
	   the data exchange to be successful, except when the client
	   could not verify the server certificate.  */
	if (verify_result || verbose)
	  printf (_("SSL connection using %s\n\n"), SSL_get_cipher (ssl));

	/* The verify result is also X509_V_OK when the server did not
	   present a certificate at all, so the missing certificate has
	   to be checked for regardless of verbosity.  */
	if (server_cert == NULL)
	  {
	    fprintf (stderr, _("Server does not have a certificate?\n"));
	    if (reqcert >= 3)
	      return E_SSL_FAILURE;
	  }
	else if (verify_result || verbose)
	  {
	    printf (_("Server certificate:\n"));

	    str = X509_NAME_oneline (X509_get_subject_name (server_cert),
				     0, 0);
	    if (str)
	      {
		printf (_("  subject: %s\n"), str);
		free (str);
	      }
	    str = X509_NAME_oneline (X509_get_issuer_name (server_cert),
				     0, 0);
	    if (str)
	      {
		printf (_("  issuer: %s\n"), str);
		free (str);
	      }
	    /* We could do all sorts of certificate verification stuff
	       here before deallocating the certificate.  */

	    fputs ("\n", stdout);
	  }

	if (verify_result != X509_V_OK)
	  {
	    fprintf (stderr, "Server certificate is not ok: %s!\n",
		     X509_verify_cert_error_string (verify_result));
	    if (reqcert >= 2)
	      return E_SSL_FAILURE;
	  }

	X509_free (server_cert);
      }
  }

  if ((retval = start_request (ssl, username, admin_mode)) == 0)
    retval = handle_responses (ssl);
  else
    retval = E_FAILURE;

  free (username);
  close (sock);
  SSL_free (ssl);
  SSL_CTX_free (ctx);

  return retval;
}
