/*
** Copyright 2000-2001 Double Precision, Inc.
** See COPYING for distribution information.
*/
#include	"config.h"
#include	"argparse.h"
#include	"spipe.h"
#include	"rfc1035/rfc1035.h"
#include	"soxwrap/soxwrap.h"
#ifdef  getc
#undef  getc
#endif
#include	<stdio.h>
#include	<string.h>
#include	<stdlib.h>
#include	<ctype.h>
#include	<netdb.h>
#if HAVE_DIRENT_H
#include <dirent.h>
#define NAMLEN(dirent) strlen((dirent)->d_name)
#else
#define dirent direct
#define NAMLEN(dirent) (dirent)->d_namlen
#if HAVE_SYS_NDIR_H
#include <sys/ndir.h>
#endif
#if HAVE_SYS_DIR_H
#include <sys/dir.h>
#endif
#if HAVE_NDIR_H
#include <ndir.h>
#endif
#endif
#if	HAVE_UNISTD_H
#include	<unistd.h>
#endif
#if	HAVE_FCNTL_H
#include	<fcntl.h>
#endif
#include	<errno.h>
#if	HAVE_SYS_TYPES_H
#include	<sys/types.h>
#endif
#if	HAVE_SYS_STAT_H
#include	<sys/stat.h>
#endif
#include	<sys/socket.h>
#include	<arpa/inet.h>
#define	DEBUG_SAFESTACK	1	/* For openssl 0.9.6 */

#include	<openssl/ssl.h>
#include	<openssl/err.h>
#include	<sys/time.h>

static const char rcsid[]="$Id: starttls.c,v 1.25 2002/03/15 19:17:18 mrsam Exp $";

#ifndef NO_RSA
/* Temporary RSA key callback, installed for server contexts in create_tls() */
static RSA *rsa_callback(SSL *, int, int);
#endif

/*
** TLS settings.  main() fills these from the environment variables
** TLS_CIPHER_LIST, TLS_TIMEOUT, TLS_DHCERTFILE, TLS_CERTFILE and
** TLS_PROTOCOL.  For protocol, create_tls() recognizes "SSL2" and "SSL3";
** anything else (or unset) selects TLSv1.
*/
const char *ssl_cipher_list=0;
int session_timeout=0;
const char *dhcertfile=0;
const char *certfile=0;
const char *protocol=0;

/*
** Peer verification mode passed to SSL_CTX_set_verify(); main() sets it
** from TLS_VERIFYPEER (NONE, PEER or REQUIREPEER).  Possible values:
*/
int peer_verify_level=SSL_VERIFY_PEER;
		/* SSL_VERIFY_NONE */
		/* SSL_VERIFY_PEER */
		/* SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT */
/* Trusted CA location: main() sets one of these from TLS_TRUSTCERTS */
const char *peer_cert_dir=0;
const char *peer_cert_file=0;

/* Command-line options (filled by argparse() from arginfo in main()): */
const char *clienthost=0;
const char *clientport=0;

const char *server=0;
const char *localfd=0;
const char *remotefd=0;
const char *statusfd=0;
const char *tcpd=0;
const char *peer_verify_domain=0;
const char *verify_fail_msg=0;
const char *fdprotocol=0;
/* Result of the peer domain check; set by connect_tls() */
int peer_domain_verified;

/* --printx509=N: descriptor number that receives the certificate report */
const char *printx509=0;
FILE *printx509_fp;

/* Where diagnostics go: stderr, or the --statusfd stream once opened */
FILE *errfp;

/* -------------------------------------------------------------------- */

static int ssl_verify_callback(int, X509_STORE_CTX *);

/*
** Report the oldest error on the OpenSSL error queue to errfp as
** "starttls: <pfix>: <error text>".
*/
static void sslerror(const char *pfix)
{
	unsigned long code=ERR_get_error();
	char text[256];

	ERR_error_string(code, text);
	fprintf(errfp, "starttls: %s: %s\n", pfix, text);
}

/*
** Report a failed system/library call to errfp as "<pfix>: <strerror(errno)>".
*/
static void nonsslerror(const char *pfix)
{
	const char *reason=strerror(errno);

	fprintf(errfp, "%s: %s\n", pfix, reason);
}

/*
** Add the subject names of all PEM certificates found in hashed-name
** files ("<anything>.<digits>") under dir to the CA list that a server
** context sends to clients.  An unopenable directory is silently skipped;
** allocation and fopen() failures are fatal (exit(1)).
*/
static void add_client_ca_dir(SSL_CTX *ctx, const char *dir)
{
	DIR *dirp;
	struct dirent *de;

	if ((dirp=opendir(dir)) == NULL)
		return;

	while ((de=readdir(dirp)) != NULL)
	{
		const char *p=strrchr(de->d_name, '.');
		char *q;
		FILE *fp;
		X509 *x;

		/*
		** strrchr() returns NULL when the name has no '.', so test
		** for that before looking at the suffix.
		*/
		if (!p || !p[1])
			continue;

		/* Only names whose suffix after the last '.' is all digits */
		while (*++p)
		{
			if (!isdigit((int)(unsigned char)*p))
				break;
		}
		if (*p)
			continue;

		/* dir + '/' + name + NUL */
		q=malloc(strlen(dir) + strlen(de->d_name) + 2);
		if (!q)
		{
			nonsslerror("malloc");
			exit(1);
		}
		sprintf(q, "%s/%s", dir, de->d_name);

		fp=fopen(q, "r");
		if (!fp)
		{
			nonsslerror(q);
			exit(1);
		}
		free(q);

		while ((x=PEM_read_X509(fp, NULL, NULL, NULL)) != NULL)
		{
			SSL_CTX_add_client_CA(ctx, x);
			X509_free(x);
		}
		fclose(fp);
	}
	closedir(dirp);
}

/*
** Create and configure an SSL_CTX from the global TLS settings
** (protocol, ssl_cipher_list, session_timeout, certfile, dhcertfile,
** peer_cert_file/peer_cert_dir, peer_verify_level).
**
** isserver - nonzero for a server context.  This installs the temporary
**            RSA key callback, loads DH parameters from dhcertfile, and
**            sets the CA names to request client certificates against.
**
** Returns the new context, or 0 after printing a diagnostic to errfp.
*/
SSL_CTX *create_tls(int isserver)
{
SSL_CTX *ctx;

	SSL_load_error_strings();
	SSLeay_add_ssl_algorithms();

	/*
	** NOTE(review): "SSL3" selects SSLv23_method() rather than
	** SSLv3_method(); confirm that is intended.
	*/
	ctx=SSL_CTX_new(protocol && strcmp(protocol, "SSL2") == 0
							? SSLv2_method():
		protocol && strcmp(protocol, "SSL3") == 0 ? SSLv23_method():
		TLSv1_method());

	if (!ctx)
	{
		/* SSL_CTX_new() reports failure on the OpenSSL error queue */
		sslerror("SSL_CTX_new");
		return (0);
	}
	SSL_CTX_set_options(ctx, SSL_OP_ALL);

	if (ssl_cipher_list)
		SSL_CTX_set_cipher_list(ctx, ssl_cipher_list);
	SSL_CTX_set_timeout(ctx, session_timeout);

	if (isserver)
	{
#ifndef NO_RSA
		SSL_CTX_set_tmp_rsa_callback(ctx, rsa_callback);
#endif

#ifndef	NO_DH
		if (dhcertfile)
		{
		BIO	*bio;
		DH	*dh;
		int	cert_done=0;

			if ((bio=BIO_new_file(dhcertfile, "r")) != 0)
			{
				if ((dh=PEM_read_bio_DHparams(bio, NULL, NULL,
					NULL)) != 0)
				{
					SSL_CTX_set_tmp_dh(ctx, dh);
					cert_done=1;
					DH_free(dh);
				}
				else
					sslerror(dhcertfile);
				BIO_free(bio);
			}
			else
				sslerror(dhcertfile);
			if (!cert_done)
			{
				fprintf(errfp, "starttls: DH init failed!\n");
				SSL_CTX_free(ctx);
				return (0);
			}
		}
#endif
	}
	SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_BOTH);
	if (certfile)
	{
		if(!SSL_CTX_use_certificate_file(ctx, certfile,
			SSL_FILETYPE_PEM))
		{
			sslerror(certfile);
			SSL_CTX_free(ctx);
			return (0);
		}
		/* The private key is read from the same PEM file */
#ifndef	NO_RSA
		if(!SSL_CTX_use_RSAPrivateKey_file(ctx, certfile,
			SSL_FILETYPE_PEM))
#else
		if(!SSL_CTX_use_PrivateKey_file(ctx, certfile,
			SSL_FILETYPE_PEM))
#endif
		{
			sslerror(certfile);
			SSL_CTX_free(ctx);
			return (0);
		}
	}

	if (peer_cert_dir || peer_cert_file)
	{
		if ((!SSL_CTX_set_default_verify_paths(ctx))
			|| (!SSL_CTX_load_verify_locations(ctx, peer_cert_file,
				peer_cert_dir)))
		{
			/* Only one of the two may be set; never pass NULL to %s */
			sslerror(peer_cert_dir ? peer_cert_dir:peer_cert_file);
			SSL_CTX_free(ctx);
			return (0);
		}

		if (isserver && peer_cert_file)
		{
			SSL_CTX_set_client_CA_list(ctx,
						   SSL_load_client_CA_file
						   (peer_cert_file));
		}

		if (isserver && peer_cert_dir)
			add_client_ca_dir(ctx, peer_cert_dir);
	}
	SSL_CTX_set_verify(ctx, peer_verify_level, ssl_verify_callback);
	return (ctx);
}

/*
** OpenSSL verify callback.  Passes OpenSSL's own verdict (goodcert)
** through when a --verify domain was given or peer_verify_level is not
** SSL_VERIFY_NONE.  Otherwise every certificate is accepted.
*/
static int ssl_verify_callback(int goodcert, X509_STORE_CTX *x509)
{
	int enforce= peer_verify_domain != 0 || peer_verify_level != 0;

	return (goodcert || !enforce);
}

#ifndef NO_RSA
/*
** Temporary RSA key callback installed by create_tls() for server
** contexts.  Generates a new key of the requested length on every call.
**
** Compiled only when RSA is enabled.  This matches its forward declaration
** and its only use, both of which are guarded by NO_RSA.  Otherwise
** RSA_generate_key() would be referenced in a build without RSA.
*/
static RSA *rsa_callback(SSL *s, int export, int keylength)
{
	(void)s;	/* unused */
	(void)export;	/* unused */
	return (RSA_generate_key(keylength,RSA_F4,NULL,NULL));
}
#endif

/*
** Common Name of the most recently dumped certificate subject, or "" if
** the last CN seen was unusable.  connect_tls() compares it against
** peer_verify_domain.
*/
static char domain[256];

/*
** Walk the subject name of x509.  Each entry is printed as "   SN=value"
** to printx509_fp when that is not NULL.  Any CN entry is copied into
** domain[].
*/
static void dump_x509(X509 *x509, FILE *printx509_fp)
{
	X509_NAME *subj=X509_get_subject_name(x509);
	int nentries, j;

	if (!subj)
		return;

	if (printx509_fp)
		fprintf(printx509_fp, "Subject:\n");

	nentries=X509_NAME_entry_count(subj);
	for (j=0; j<nentries; j++)
	{
		const char *obj_name;
		X509_NAME_ENTRY *e;
		ASN1_OBJECT *o;
		ASN1_STRING *d;
		int dlen;
		unsigned char *ddata;

		if ((e=X509_NAME_get_entry(subj, j)) == NULL)
			continue;

		o=X509_NAME_ENTRY_get_object(e);
		d=X509_NAME_ENTRY_get_data(e);

		if (!o || !d)
			continue;

		/* Guard against a NULL short name before strcasecmp()/printf */
		if ((obj_name=OBJ_nid2sn(OBJ_obj2nid(o))) == NULL)
			obj_name="UNDEF";

		dlen=ASN1_STRING_length(d);
		ddata=ASN1_STRING_data(d);

		if (strcasecmp(obj_name, "CN") == 0)
		{
			/*
			** connect_tls() compares domain[] as a C string.
			** Truncating an over-long CN, or copying one with an
			** embedded NUL (e.g. "mail.example.com\0.evil.test"),
			** would let that compare match only a prefix of the
			** certificate's real name.  Record no name instead.
			*/
			if ((size_t)dlen >= sizeof(domain) ||
			    memchr(ddata, 0, dlen) != NULL)
				domain[0]=0;
			else
			{
				memcpy(domain, ddata, dlen);
				domain[dlen]=0;
			}
		}

		if (printx509_fp)
		{
			fprintf(printx509_fp, "   %s=", obj_name);
			fwrite(ddata, dlen, 1, printx509_fp);
			fprintf(printx509_fp, "\n");
		}
	}
	if (printx509_fp)
		fprintf(printx509_fp, "\n");
}

/*
** Create an SSL session on fd and run the TLS handshake: as a server
** (SSL_accept) when isserver is nonzero, otherwise as a client
** (SSL_connect).
**
** Effects on globals:
**  - peer_domain_verified is 1 when no --verify domain was given.
**    Otherwise it is 1 only if the peer certificate's CN matches
**    peer_verify_domain (a "*.x" CN matches exactly one leading label).
**  - If printx509 names a descriptor, the peer certificate subject(s)
**    and the negotiated cipher are written to it.  The stream is then
**    closed and printx509 cleared.
**
** Returns the SSL handle, or 0 if SSL_new() or the handshake failed.
*/
SSL *connect_tls(SSL_CTX *ctx, int isserver, int fd)
{
SSL *ssl;
int rc;

	if (!(ssl=SSL_new(ctx)))
	{
		sslerror("SSL_new");
		return (0);
	}

	SSL_set_fd(ssl, fd);

	/* Nothing to check unless a --verify domain was given */
	peer_domain_verified= peer_verify_domain ? 0:1;

	if (printx509)
	{
		printx509_fp=fdopen(atoi(printx509), "w");
		if (!printx509_fp)
		{
			nonsslerror("fdopen");
			printx509=0;
		}
	}

	if (isserver)
	{
		SSL_set_accept_state(ssl);
		rc=SSL_accept(ssl);
	}
	else
	{
		SSL_set_connect_state(ssl);
		rc=SSL_connect(ssl);
	}

	if (rc <= 0)
	{
		sslerror(isserver ? "accept":"connect");
		SSL_set_shutdown(ssl,
			SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN);
		SSL_free(ssl);
		ERR_remove_state(0);
		if (printx509)
			fclose(printx509_fp);
		printx509=0;
		return (0);
	}

	{
		STACK_OF(X509) *peer_cert_chain=SSL_get_peer_cert_chain(ssl);
		X509 *peer;
		int i;

		/* Report every certificate the peer sent */
		for (i=0; peer_cert_chain && i<sk_X509_num(peer_cert_chain);
		     i++)
			dump_x509(sk_X509_value(peer_cert_chain, i),
				  printx509 ? printx509_fp:NULL);

		/*
		** dump_x509() leaves the CN of the last subject it processed
		** in domain[].  The chain above may hold intermediate CAs, and
		** any extra certificates the peer chose to append.  So take
		** the name only from the peer's own certificate.
		*/
		domain[0]=0;
		peer=SSL_get_peer_certificate(ssl);
		if (peer)
		{
			/*
			** On the server side the chain omits the peer
			** certificate, so print it here.  On the client side
			** it was already printed as part of the chain.
			*/
			dump_x509(peer, isserver && printx509 ?
				  printx509_fp:NULL);
			/* SSL_get_peer_certificate() took a reference */
			X509_free(peer);
		}
	}

	if (peer_verify_domain && domain[0])
	{
		const char *p=domain;

		if (*p == '*')
		{
			int	pl, l;

			pl=strlen(++p);
			l=strlen(peer_verify_domain);

			/*
			** "*.example.com" matches "a.example.com", but not
			** "example.com", ".example.com" or "a.b.example.com".
			*/
			if (*p == '.' && pl < l &&
			    strcasecmp(peer_verify_domain+l-pl, p) == 0 &&
			    memchr(peer_verify_domain, '.', l-pl) == NULL)
				peer_domain_verified=1;
		}
		else if (strcasecmp(peer_verify_domain, p) == 0)
			peer_domain_verified=1;
	}

	if (printx509)
	{
		SSL_CIPHER *cipher;

		cipher=SSL_get_current_cipher(ssl);

		if (cipher)
		{
			const char *c;

			c=SSL_CIPHER_get_name(cipher);

			if (c)
				fprintf(printx509_fp, "Cipher: %s\n", c);

			c=SSL_CIPHER_get_version(cipher);
			if (c)
				fprintf(printx509_fp, "Version: %s\n", c);

			fprintf(printx509_fp, "Bits: %d\n",
				SSL_CIPHER_get_bits(cipher, NULL));
		}
		fclose(printx509_fp);
	}
	printx509=0;
	return (ssl);
}

/*
** Release a session (if any) and its context.  The session is first
** flagged as shut down in both directions with SSL_set_shutdown(), which
** does not itself send anything to the peer.
*/
void disconnect_tls(SSL_CTX *ctx, SSL *ssl)
{
	if (!ssl)
	{
		SSL_CTX_free(ctx);
		return;
	}

	SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN);
	SSL_free(ssl);
	ERR_remove_state(0);
	SSL_CTX_free(ctx);
}

/*
** Relay data between a TLS session and the local side until either side
** reaches EOF or an error occurs, then return.
**
** ssl, sslfd - the TLS session and its underlying socket.
** stdinfd    - plaintext read here is sent through the TLS session.
** stdoutfd   - plaintext received from the TLS session is written here.
**              (May be the same descriptor as stdinfd.)
**
** All three descriptors are switched to non-blocking mode.  Each
** direction has a single BUFSIZ buffer; while it holds unwritten data,
** nothing more is read in that direction.
**
** suppress_read / suppress_write record that the last SSL call asked for
** the opposite kind of readiness (SSL_ERROR_WANT_WRITE / _WANT_READ).
** They keep sslfd out of the read / write select() set respectively,
** until an SSL call completes.
**
** NOTE(review): after SSL_read() reports WANT_WRITE, sslfd is only put in
** the write set when to_ssl_cnt is nonzero.  So a pending handshake write
** may wait until stdin supplies data.  Confirm this cannot stall.
*/
void transfer_tls(SSL *ssl, int sslfd, int stdinfd, int stdoutfd)
{
char	from_ssl_buf[BUFSIZ], to_ssl_buf[BUFSIZ];
char	*from_ssl_ptr=0, *to_ssl_ptr=0;
int	from_ssl_cnt, to_ssl_cnt;
fd_set	fdr, fdw;
int	maxfd=sslfd;
int	suppress_read;
int	suppress_write;

	/* F_SETFL with only O_NONBLOCK replaces any other status flags */
	if (fcntl(sslfd, F_SETFL, O_NONBLOCK)
	    || fcntl(stdinfd, F_SETFL, O_NONBLOCK)
	    || fcntl(stdoutfd, F_SETFL, O_NONBLOCK)

	    )
	{
		nonsslerror("fcntl");
		return;
	}

	if (maxfd < stdinfd)	maxfd=stdinfd;
	if (maxfd < stdoutfd)	maxfd=stdoutfd;

	/* Bytes still waiting to be written out of each buffer */
	from_ssl_cnt=0;
	to_ssl_cnt=0;

	suppress_read=0;
	suppress_write=0;

	for (;;)
	{
		FD_ZERO(&fdr);
		FD_ZERO(&fdw);
		if (from_ssl_cnt)
			FD_SET(stdoutfd, &fdw);	/* drain decrypted data first */
		else if (!suppress_read)
		{
		int	n=SSL_pending(ssl);

			/*
			** Data OpenSSL has already buffered and decrypted is
			** not visible to select() on sslfd; consume it now.
			*/
			if (n > 0)
			{
				if (n >= sizeof(from_ssl_buf))
					n=sizeof(from_ssl_buf);

				n=SSL_read(ssl, from_ssl_buf, n);
				switch (SSL_get_error(ssl, n))	{
				case SSL_ERROR_NONE:
					if (n <= 0)
						return;
					break;
				case SSL_ERROR_WANT_WRITE:
					suppress_read=1;
					suppress_write=0;
					continue;
				case SSL_ERROR_WANT_READ:
					suppress_read=0;
					suppress_write=1;
					continue;
				case SSL_ERROR_WANT_X509_LOOKUP:
					continue;
				default:
					return;
				}
				from_ssl_cnt=n;
				from_ssl_ptr=from_ssl_buf;
				suppress_read=0;
				suppress_write=0;
				continue;
			}
			else
				FD_SET(sslfd, &fdr);
		}
		if (to_ssl_cnt)
		{
			if (!suppress_write)
				FD_SET(sslfd, &fdw);
		}
		else
			FD_SET(stdinfd, &fdr);

		/* No timeout; any select() failure (including EINTR) ends it */
		if (select(maxfd+1, &fdr, &fdw, 0, 0) <= 0)
		{
			nonsslerror("select");
			return;
		}

		/* TLS -> local: write pending data, or read more from TLS */
		if (from_ssl_cnt && FD_ISSET(stdoutfd, &fdw))
		{
		int n=write(stdoutfd, from_ssl_ptr, from_ssl_cnt);

			if (n <= 0)	return;
			from_ssl_ptr += n;
			from_ssl_cnt -= n;
		}
		else if (!from_ssl_cnt && FD_ISSET(sslfd, &fdr))
		{
		int n=SSL_read(ssl, from_ssl_buf, sizeof(from_ssl_buf));

			switch (SSL_get_error(ssl, n))	{
			case SSL_ERROR_NONE:
				if (n <= 0)	return;
				break;
			case SSL_ERROR_WANT_WRITE:
				suppress_read=1;
				suppress_write=0;
				continue;
			case SSL_ERROR_WANT_READ:
				suppress_read=0;
				suppress_write=1;
				continue;
			case SSL_ERROR_WANT_X509_LOOKUP:
				continue;
			default:
				return;
			}

			from_ssl_ptr=from_ssl_buf;
			from_ssl_cnt=n;
			suppress_read=0;
			suppress_write=0;
		}

		/* local -> TLS: send pending data, or read more from stdin */
		if (to_ssl_cnt && FD_ISSET(sslfd, &fdw))
		{
		int n=SSL_write(ssl, to_ssl_ptr, to_ssl_cnt);

			switch (SSL_get_error(ssl, n))	{
			case SSL_ERROR_NONE:
				if (n <= 0)	return;
				break;
			case SSL_ERROR_WANT_WRITE:
				suppress_write=0;
				suppress_read=1;
				continue;
			case SSL_ERROR_WANT_READ:
				suppress_read=0;
				suppress_write=1;
				continue;
			case SSL_ERROR_WANT_X509_LOOKUP:
				continue;
			default:
				return;
			}

			to_ssl_ptr += n;
			to_ssl_cnt -= n;
			suppress_read=0;
			suppress_write=0;
		}
		else if (!to_ssl_cnt && FD_ISSET(stdinfd, &fdr))
		{
		int n=read(stdinfd, to_ssl_buf, sizeof(to_ssl_buf));

			if (n <= 0)	return;
			to_ssl_ptr=to_ssl_buf;
			to_ssl_cnt=n;
		}
	}
}

/* ----------------------------------------------------------------------- */

/*
** Prepare the network socket for the TLS handshake: clear the file status
** flags (including O_NONBLOCK), enable SO_KEEPALIVE, and disable
** SO_LINGER, where those options exist.
** Returns 0 on success, 1 after printing a diagnostic.
*/
static int prepsocket(int sockfd)
{
	/* Blocking mode for now; transfer_tls() switches it back later */
	if (fcntl(sockfd, F_SETFL, 0) != 0)
	{
		nonsslerror("fcntl");
		return (1);
	}

#ifdef  SO_KEEPALIVE
	{
		int	enable=1;

		if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
			(const char *)&enable, sizeof(enable)) < 0)
		{
			nonsslerror("setsockopt");
			return (1);
		}
	}
#endif

#ifdef  SO_LINGER
	{
		struct linger lg;

		lg.l_onoff=lg.l_linger=0;

		if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER,
			(const char *)&lg, sizeof(lg)) < 0)
		{
			nonsslerror("setsockopt");
			return (1);
		}
	}
#endif
	return (0);
}

/*
** Set up the plaintext side of the tunnel.
**
**  --localfd=N     : use descriptor N for both directions.
**  no program args : leave *stdin_fd / *stdout_fd as the caller set them
**                    (interactive use).
**  otherwise       : fork and exec argv[argn..argc-1] with its stdin and
**                    stdout on one end of a stream pipe.  The parent uses
**                    the other end for both directions.  The child closes
**                    the network socket fd.
**
** Failure to create the pipe, fork, or redirect is fatal (exit(1)).
*/
static void startclient(int argn, int argc, char **argv, int fd,
	int *stdin_fd, int *stdout_fd)
{
pid_t	p;
int	streampipe[2];

	if (localfd)
	{
		*stdin_fd= *stdout_fd= atoi(localfd);
		return;
	}

	if (argn >= argc)	return;		/* Interactive */

	if (s_pipe(streampipe))
	{
		nonsslerror("s_pipe");
		exit(1);
	}
	if ((p=fork()) == -1)
	{
		nonsslerror("fork");
		close(streampipe[0]);
		close(streampipe[1]);
		exit(1);
	}
	if (p == 0)
	{
	char **argvec;
	int n;

		close(fd);	/* Child process doesn't need it */

		/*
		** Close the parent's end first.  If it happened to be
		** descriptor 0 or 1, closing it after the dup2() calls
		** would close the child's new stdin/stdout instead.
		*/
		close(streampipe[0]);

		/*
		** dup2() names the target descriptors explicitly.  The
		** close()+dup() idiom relied on the lowest free descriptor,
		** and broke when the pipe itself was on descriptor 0 or 1.
		*/
		if (dup2(streampipe[1], 0) < 0 || dup2(streampipe[1], 1) < 0)
		{
			nonsslerror("dup2");
			exit(1);
		}
		if (streampipe[1] > 1)
			close(streampipe[1]);

		argvec=malloc(sizeof(char *)*(argc-argn+1));
		if (!argvec)
		{
			nonsslerror("malloc");
			exit(1);
		}
		for (n=0; n<argc-argn; n++)
			argvec[n]=argv[argn+n];
		argvec[n]=0;
		execvp(argvec[0], argvec);
		nonsslerror(argvec[0]);
		exit(1);
	}
	close(streampipe[1]);

	*stdin_fd= *stdout_fd= streampipe[0];
}

/*
** Open a TCP connection to host:port.
**
** port is either a positive decimal number or a service name, which is
** looked up with getservbyname(..., "tcp").  host is either a literal
** address or a name resolved with rfc1035_a().  The resolved addresses
** are tried in order.
**
** Returns the connected descriptor, or -1 after printing a diagnostic.
*/
static int connectremote(const char *host, const char *port)
{
int	fd;

RFC1035_ADDR addr;
int	af;
RFC1035_ADDR *addrs;
unsigned	naddrs, n;

RFC1035_NETADDR addrbuf;
const struct sockaddr *saddr;
int     saddrlen;
int	port_num;
int	saved_errno;

	port_num=atoi(port);
	if (port_num <= 0)
	{
	struct servent *servent;

		servent=getservbyname(port, "tcp");

		if (!servent)
		{
			fprintf(errfp, "%s: invalid port.\n", port);
			return (-1);
		}
		port_num=servent->s_port;	/* already in network order */
	}
	else
		port_num=htons(port_num);

	if (rfc1035_aton(host, &addr) == 0) /* An explicit IP addr */
	{
		if ((addrs=malloc(sizeof(addr))) == 0)
		{
			nonsslerror("malloc");
			return (-1);
		}
		memcpy(addrs, &addr, sizeof(addr));
		naddrs=1;
	}
	else if (rfc1035_a(&rfc1035_default_resolver, host, &addrs, &naddrs))
	{
		fprintf(errfp, "%s: not found.\n", host);
		return (-1);
	}

	/*
	** POSIX leaves a socket's state unspecified after connect() fails.
	** So every address is tried on a fresh socket.
	*/
	fd= -1;
	saved_errno=errno;
	for (n=0; n<naddrs; n++)
	{
		if ((fd=rfc1035_mksocket(SOCK_STREAM, 0, &af)) < 0)
		{
			saved_errno=errno;
			free(addrs);		/* don't leak the address list */
			errno=saved_errno;
			nonsslerror("socket");
			return (-1);
		}

		if (rfc1035_mkaddress(af, &addrbuf, addrs+n, port_num,
			&saddr, &saddrlen) == 0 &&
		    sox_connect(fd, saddr, saddrlen) == 0)
			break;

		saved_errno=errno;
		close(fd);
		fd= -1;
	}
	free(addrs);

	if (fd < 0)
	{
		errno=saved_errno;	/* free()/close() may have changed errno */
		nonsslerror("connect");
		return (-1);
	}

	return (fd);
}

/*
** Look for an address-specific variant of a certificate file,
** "<file>.<ip>".  If ip is an IPv4-mapped IPv6 address ("::ffff:a.b.c.d")
** and no file exists for that form, "<file>.a.b.c.d" is tried too.
**
** Returns a malloc()ed path to a readable file (the caller owns it).
** Returns NULL if no such file exists or memory is exhausted.
*/
char *get_ip_certfile(const char *file, const char *ip) {
	/* room for "<file>" + "." + "<ip>" + NUL */
	char *test_file = malloc(strlen(file)+strlen(ip)+2);

	if (!test_file)
		return NULL;

	strcpy(test_file, file);
	strcat(test_file, ".");
	strcat(test_file, ip);

	if (access(test_file, R_OK) == 0)
		return test_file;

	free(test_file);

	/* Check for ipv4 version if ip is ipv6 */
	if (strncmp(ip, "::ffff:", 7) == 0 && strchr(ip, '.'))
		return get_ip_certfile(file, ip+7);

	return NULL;
}

/*
** Run the TLS session on fd, in this order:
**  1. Pick per-address certificate files when TCPLOCALIP is set.
**  2. Prepare the socket and create the context.
**  3. Do the handshake.
**  4. Start the plaintext side (startclient()).
**  5. Check the peer domain.
**  6. Relay data until either side closes.
**
** statusfp is the --statusfd stream, or NULL.  It receives verify_fail_msg
** if the domain check fails.  Otherwise it is closed before the relay
** starts, and later diagnostics go to stderr.
**
** Returns the process exit status.  This is 1 if setup or the handshake
** failed, else 0.  NOTE(review): a failed domain check also returns 0;
** confirm callers rely on the status message instead.
*/
static int dossl(int fd, int argn, int argc, char **argv, FILE *statusfp)
{
SSL_CTX *ctx;
SSL	*ssl;

int	stdin_fd, stdout_fd;
const char *ip = NULL;

	if ((ip = getenv("TCPLOCALIP")) != NULL) {
		/*
		** Prefer "<file>.<local ip>" when it exists.  The malloc()ed
		** names are kept for the rest of the process.
		*/
		if (certfile) {
			const char *newcertfile =
				get_ip_certfile(certfile, ip);
			if (newcertfile)
				certfile = newcertfile;
		}
		if (dhcertfile) {
			const char *newdhcertfile =
				get_ip_certfile(dhcertfile, ip);
			if (newdhcertfile)
				dhcertfile = newdhcertfile;
		}
	}

	if (prepsocket(fd))
		return (1);

	ctx=create_tls(server ? 1:0);
	if (ctx == 0)	return (1);

	ssl=connect_tls(ctx, server ? 1:0, fd);
	if (!ssl)
	{
		SSL_CTX_free(ctx);
		close(fd);
		return (1);
	}

	stdin_fd=0;
	stdout_fd=1;

	startclient(argn, argc, argv, fd, &stdin_fd, &stdout_fd);

	if (!peer_domain_verified)
	{
		disconnect_tls(ctx, ssl);
		/* statusfp is NULL without --statusfd; errfp is stderr then */
		if (verify_fail_msg)
			fprintf(statusfp ? statusfp:errfp, "%s",
				verify_fail_msg);
		else
			fprintf(errfp,
				"starttls: unable to verify peer domain.\n");
		return (0);
	}

	if (statusfp)
	{
		fclose(statusfp);
		errfp=stderr;
	}

	transfer_tls(ssl, fd, stdin_fd, stdout_fd);
	disconnect_tls(ctx, ssl);
	return (0);
}

/*
** getenv() that never returns NULL: an unset variable yields "".
*/
static const char *safe_getenv(const char *n)
{
	const char *value=getenv(n);

	return (value ? value:"");
}

/*
** Buffered line reader state for the cleartext exchange that precedes
** STARTTLS (used by smtp_proto() and imap_proto()).
*/
struct protoreadbuf {
	char buffer[512];	/* bytes from the most recent read() */
	char *bufptr;		/* next unconsumed byte in buffer */
	int bufleft;		/* number of unconsumed bytes at bufptr */

	char line[256];		/* line most recently returned by prb_getline() */
} ;

/* Start empty, so the first PRB_GETCH() calls protoread() */
#define PRB_INIT(p) ( (p)->bufptr=0, (p)->bufleft=0)

/*
** Refill prb->buffer from fd and return its first byte.
**
** Waits up to 60 seconds for input.  A timeout, a select() error, EOF or a
** read() error is fatal: a diagnostic is printed, then exit(1).
*/
static char protoread(int fd, struct protoreadbuf *prb)
{
	fd_set fds;
	struct timeval tv;
	int rc;

	FD_ZERO(&fds);
	FD_SET(fd, &fds);

	tv.tv_sec=60;
	tv.tv_usec=0;

	if ((rc=select(fd+1, &fds, NULL, NULL, &tv)) <= 0)
	{
		/* select() returns 0 on timeout and leaves errno stale */
		if (rc == 0)
			errno=ETIMEDOUT;
		nonsslerror("select");
		exit(1);
	}

	if ( (prb->bufleft=read(fd, prb->buffer, sizeof(prb->buffer))) <= 0)
	{
		/* Report EOF as a reset; keep read()'s own errno on error */
		if (prb->bufleft == 0)
			errno=ECONNRESET;
		nonsslerror("read");
		exit(1);
	}

	/* Return the first byte; bufleft then counts the bytes after it */
	prb->bufptr= prb->buffer;

	--prb->bufleft;
	return (*prb->bufptr++);
}

/*
** Next byte from prb, refilling from fd via protoread() once the buffer
** is exhausted.  Evaluates prb several times.
*/
#define PRB_GETCH(fd,prb) ( (prb)->bufleft-- > 0 ? *(prb)->bufptr++:\
				protoread( (fd), (prb)))

/*
** Read bytes up to and including the next '\n', and return them without
** the '\n' in prb->line.  Bytes beyond sizeof(prb->line)-1 are discarded.
** A trailing '\r' is kept.
*/
static const char *prb_getline(int fd, struct protoreadbuf *prb)
{
	size_t len=0;
	char ch;

	for (;;)
	{
		ch=PRB_GETCH(fd, prb);
		if (ch == '\n')
			break;
		if (len + 1 < sizeof(prb->line))
			prb->line[len++]=ch;
	}
	prb->line[len]=0;
	return (prb->line);
}

/*
** Echo command p to stdout, then write all of it to fd.  A write error is
** fatal (exit(1)).  prb is not used.
*/
static void prb_write(int fd, struct protoreadbuf *prb, const char *p)
{
	size_t remaining=strlen(p);

	printf("%s", p);
	while (remaining > 0)
	{
		int l=write(fd, p, remaining);

		if (l <= 0)
		{
			nonsslerror("write");
			exit(1);
		}
		p += l;
		remaining -= l;
	}
}

/*
** Classify an IMAP server line.  Only our own tag ("x ") and untagged
** ("*") responses are considered; anything else returns 0.
** BAD, BYE or NO terminates the program (exit(1)).  Returns 1 for OK,
** otherwise 0.
*/
static int goodimap(const char *p)
{
	if (p[0] == 'x' && p[1] != 0 && isspace((int)(unsigned char)p[1]))
		p += 1;
	else if (p[0] == '*')
		p += 1;
	else
		return (0);

	for (; *p && isspace((int)(unsigned char)*p); ++p)
		;

	if (strncasecmp(p, "BAD", 3) == 0 ||
	    strncasecmp(p, "BYE", 3) == 0 ||
	    strncasecmp(p, "NO", 2) == 0)
		exit(1);

	return (strncasecmp(p, "OK", 2) == 0);
}

static void imap_proto(int fd)
{
	struct protoreadbuf prb;
	const char *p;

	PRB_INIT(&prb);

	do
	{
		p=prb_getline(fd, &prb);
		printf("%s\n", p);

	} while (!goodimap(p));

	prb_write(fd, &prb, "x STARTTLS\r\n");

	do
	{
		p=prb_getline(fd, &prb);
		printf("%s\n", p);
	} while (!goodimap(p));
}

static void smtp_proto(int fd)
{
	struct protoreadbuf prb;
	const char *p;

	PRB_INIT(&prb);

	do
	{
		p=prb_getline(fd, &prb);
		printf("%s\n", p);
	} while ( ! ( isdigit((int)(unsigned char)p[0]) && 
		      isdigit((int)(unsigned char)p[1]) &&
		      isdigit((int)(unsigned char)p[2]) &&
		      (p[3] == 0 || isspace((int)(unsigned char)p[3]))));
	if (strchr("123", *p) == 0)
		exit(1);

	prb_write(fd, &prb, "STARTTLS\r\n");

	do
	{
		p=prb_getline(fd, &prb);
		printf("%s\n", p);
	} while ( ! ( isdigit((int)(unsigned char)p[0]) && 
		      isdigit((int)(unsigned char)p[1]) &&
		      isdigit((int)(unsigned char)p[2]) &&
		      (p[3] == 0 || isspace((int)(unsigned char)p[3]))));
	if (strchr("123", *p) == 0)
		exit(1);

}

/*
** Entry point.
**
** TLS settings come from the environment (TLS_* variables) and options
** come from argparse().  Any remaining arguments, from argn on, name a
** program to run on the plaintext side (see startclient()).
**
** The remote end is chosen in this order: descriptor 0 (--tcpd), then
** --remotefd=N, then a new connection to --host and --port.  An optional
** --protocol=smtp|imap exchange issues STARTTLS in cleartext before the
** handshake.
**
** Returns the process exit status (see dossl()).
*/
int main(int argc, char **argv)
{
const char *s;
int	argn;
int	fd;
struct stat stat_buf;
static struct args arginfo[] = {
	{ "host", &clienthost },
	{ "localfd", &localfd},
	{ "port", &clientport },
	{ "printx509", &printx509},
	{ "remotefd", &remotefd},
	{ "server", &server},
	{ "tcpd", &tcpd},
	{ "verify", &peer_verify_domain},
	{ "verifyfailmsg", &verify_fail_msg},
	{ "statusfd", &statusfd},
	{ "protocol", &fdprotocol},
	{0}};
FILE	*statusfp=0;
void	(*protocol_func)(int)=0;

	errfp=stderr;

	s=safe_getenv("TLS_PROTOCOL");
	if (*s) protocol=s;

	s=safe_getenv("TLS_CIPHER_LIST");
	if (*s)	ssl_cipher_list=s;

	s=safe_getenv("TLS_TIMEOUT");
	session_timeout=atoi(s);

	s=safe_getenv("TLS_DHCERTFILE");
	if (*s)	dhcertfile=s;

	s=safe_getenv("TLS_CERTFILE");
	if (*s)	certfile=s;

	s=safe_getenv("TLS_VERIFYPEER");
	switch (*s)	{
	case 'n':
	case 'N':		/* NONE */
		peer_verify_level=SSL_VERIFY_NONE;
		break;
	case 'p':
	case 'P':		/* PEER */
		peer_verify_level=SSL_VERIFY_PEER;
		break;
	case 'r':
	case 'R':		/* REQUIREPEER */
		peer_verify_level=
			SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
		break;
	}

	/* A directory or a single file of trusted CA certificates */
	s=safe_getenv("TLS_TRUSTCERTS");
	if (*s && stat(s, &stat_buf) == 0)
	{
		if (S_ISDIR(stat_buf.st_mode))
			peer_cert_dir=s;
		else
			peer_cert_file=s;
	}

	argn=argparse(argc, argv, arginfo);

	if (statusfd)
	{
		statusfp=fdopen(atoi(statusfd), "w");
		if (!statusfp)
			nonsslerror("fdopen");	/* still reported on stderr */
	}

	if (statusfp)
		errfp=statusfp;

	if (fdprotocol)
	{
		if (strcmp(fdprotocol, "smtp") == 0)
			protocol_func= &smtp_proto;
		else if (strcmp(fdprotocol, "imap") == 0)
			protocol_func= &imap_proto;
		else
		{
			fprintf(stderr, "--protocol=%s - unknown protocol.\n",
				fdprotocol);
			exit(1);
		}
	}

	if (tcpd)
	{
		/*
		** The connection is on descriptor 0.  Point stdout at stderr;
		** as before, stdout ends up closed if that is impossible.
		*/
		if (dup2(2, 1) < 0)
			close(1);
		fd=0;
	}
	else if (remotefd)
		fd=atoi(remotefd);
	else if (clienthost && clientport)
		fd=connectremote(clienthost, clientport);
	else
	{
		fprintf(errfp, "%s: specify remote location.\n", argv[0]);
		return (1);
	}

	if (fd < 0)	return (1);
	if (protocol_func)
		(*protocol_func)(fd);
	return (dossl(fd, argn, argc, argv, statusfp));
}
