/* $Id: sock.c,v 1.58 2006/04/08 22:53:28 holger Exp $ */

/*
 * Copyright (c) 2004, 2005 Holger Weiss
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#if HAVE_CONFIG_H
#include <config.h>
#endif				/* HAVE_CONFIG_H */

#if HAVE_SYS_TIME_H
#include <sys/time.h>
#endif				/* HAVE_SYS_TIME_H */
#if HAVE_SYS_TYPES_H
#include <sys/types.h>
#endif				/* HAVE_SYS_TYPES_H */
#if HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif				/* HAVE_SYS_SOCKET_H */
#if HAVE_SYS_RESOURCE_H
#include <sys/resource.h>
#endif				/* HAVE_SYS_RESOURCE_H */
#if HAVE_NETINET_IN_H
#include <netinet/in.h>
#endif				/* HAVE_NETINET_IN_H */
#include <errno.h>
#include <fcntl.h>
#if HAVE_INTTYPES_H
#include <inttypes.h>
#else
#if HAVE_STDINT_H
#include <stdint.h>
#endif				/* HAVE_STDINT_H */
#endif				/* HAVE_INTTYPES_H */
#if HAVE_NETDB_H
#include <netdb.h>
#endif				/* HAVE_NETDB_H */
#include <stdio.h>
#include <string.h>
#if HAVE_UNISTD_H
#include <unistd.h>
#endif				/* HAVE_UNISTD_H */

#if WITH_SSL
#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/rand.h>
#endif				/* WITH_SSL */

#include "report.h"
#include "sock.h"
#include "system.h"

#define SOCK_TIMEOUT	60		/* socket timeout in seconds */
#define SHUT_TIMEOUT	 1		/* SSL shutdown timeout in seconds */

/*
 * Note: errno is provided by <errno.h> (which may define it as a macro), so
 * it must not be re-declared here.
 */
static int      sd = -1;		/* static socket descriptor */

#if WITH_SSL
static bool     ssl_conn;		/* are we using SSL? */
static SSL_CTX *ctx;			/* static SSL context */
static SSL     *ssl;			/* static SSL connection */
#endif				/* WITH_SSL */

/* set O_NONBLOCK on the static socket descriptor */
static int      unblock_descriptor(void);

/*
 * Initiate an SSL/TLS session on top of the static socket descriptor.
 *
 * server: host name the peer certificate's common name is compared against
 *         (a mismatch is only reported, it does not abort the connection).
 * proto:  TLSv1 selects TLSv1_client_method(3); any other value selects
 *         SSLv23_client_method(3).
 *
 * We return 0 on success and -1 otherwise.  On failure, the SSL context and
 * SSL object created so far are released again and ssl_conn is left false,
 * so nothing is leaked even though sock_close() skips the SSL teardown.
 */
#if WITH_SSL
int
sock_ssl_open(const char *server, short proto)
{
	int             i;
	BIO            *bio;
	X509           *peer_cert;
	char            peer_info[256];

	/*
	 * Forget handles of any earlier session, so that the error path
	 * below never frees a stale pointer.
	 */
	ctx = NULL;
	ssl = NULL;
	ssl_conn = false;

	SSL_library_init();
	SSL_load_error_strings();

	if (proto == TLSv1)
		ctx = SSL_CTX_new(TLSv1_client_method());
	else
		ctx = SSL_CTX_new(SSLv23_client_method());

	if (ctx == NULL) {
		report(LOG_SERROR, "SSL_CTX_new(3) failed");
		goto fail;
	}
	if ((ssl = SSL_new(ctx)) == NULL) {
		report(LOG_SERROR, "SSL_new(3) failed");
		goto fail;
	}
	if ((bio = BIO_new_socket(sd, BIO_NOCLOSE)) == NULL) {
		report(LOG_ERROR, "SSL error: BIO_new_socket(3) failed");
		goto fail;
	}
	/* from here on, the BIO is owned (and freed) by the SSL object */
	SSL_set_bio(ssl, bio, bio);

	if (RAND_status())
		report(LOG_DEBUG, "your system has a useable random device");
	else {
#if HAVE_GETRUSAGE
		struct rusage   ru;

#endif				/* HAVE_GETRUSAGE */
#if HAVE_GETPID
		pid_t           pid;
		pid_t           ppid;

#endif				/* HAVE_GETPID */
#if HAVE_GETUID
		uid_t           uid;
		gid_t           gid;

#endif				/* HAVE_GETUID */

		/*
		 * The PRNG isn't seeded; feed it whatever process-specific
		 * data we can get hold of until OpenSSL is satisfied.
		 */
		report(LOG_DEBUG, "your system lacks a useable random device");
		report(LOG_DEBUG, "trying to collect some random seed data");
#if HAVE_GETRUSAGE
		if (getrusage(RUSAGE_SELF, &ru) == 0) {
			report(LOG_DEBUG, "adding resource usage data to random seed");
			RAND_add(&ru, sizeof(ru), 0.1);
		}
#endif				/* HAVE_GETRUSAGE */
#if HAVE_GETPID
		pid = getpid();
		report(LOG_DEBUG, "adding PID %d to random seed", (int) pid);
		RAND_add(&pid, sizeof(pid), 0.1);
		ppid = getppid();
		report(LOG_DEBUG, "adding parent PID %d to random seed", (int) ppid);
		RAND_add(&ppid, sizeof(ppid), 0.1);
#endif				/* HAVE_GETPID */
#if HAVE_GETUID
		uid = getuid();
		report(LOG_DEBUG, "adding UID %d to random seed", (int) uid);
		RAND_add(&uid, sizeof(uid), 0.1);
		gid = getgid();
		report(LOG_DEBUG, "adding GID %d to random seed", (int) gid);
		RAND_add(&gid, sizeof(gid), 0.1);
#endif				/* HAVE_GETUID */
#if HAVE_GETTIMEOFDAY
		i = 0;
		report(LOG_DEBUG, "adding system clock time to random seed");
		do {
			struct timeval  tv;

			gettimeofday(&tv, NULL);
			RAND_add(&tv, sizeof(tv), 0.1);
		} while ((++i < 9999) && !RAND_status());
#endif				/* HAVE_GETTIMEOFDAY */
		if (RAND_status())
			report(LOG_DEBUG, "got some random seed data");
		else {
			report(LOG_ERROR, "can't get enough random seed data on your system");
			report(LOG_ERROR, "consider using an entropy gathering daemon like EGD or PRNGD:");
			report(LOG_ERROR, "http://egd.sourceforge.net/");
			report(LOG_ERROR, "http://www.aet.tu-cottbus.de/personen/jaenicke/postfix_tls/prngd.html");
			goto fail;
		}
	}

	/*
	 * The socket is non-blocking, so SSL_connect(3) is retried until the
	 * handshake is done, waiting via select(2) in the direction OpenSSL
	 * asks for.
	 */
	do {
		if ((i = SSL_connect(ssl)) <= 0) {
			struct timeval  tv = {SOCK_TIMEOUT, 0};
			fd_set          wd;
			fd_set          rd;

			switch (SSL_get_error(ssl, i)) {
			    case SSL_ERROR_WANT_WRITE:
				report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
				FD_ZERO(&wd);
				FD_SET(sd, &wd);
				if (select(sd + 1, NULL, &wd, 0, &tv) <= 0) {
					report(LOG_ERROR, "timeout connecting to server");
					goto fail;
				}
				break;
			    case SSL_ERROR_WANT_READ:
				report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
				FD_ZERO(&rd);
				FD_SET(sd, &rd);
				if (select(sd + 1, &rd, NULL, 0, &tv) <= 0) {
					report(LOG_ERROR, "timeout connecting to server");
					goto fail;
				}
				break;
			    default:
				report(LOG_SERROR, "SSL_connect(3) failed");
				goto fail;
			}
		}
	} while (i <= 0);

	/* read peer certificate fields (just out of curiosity) */
	if ((peer_cert = SSL_get_peer_certificate(ssl)) == NULL)
		report(LOG_DEFAULT, "peer does not present a certificate");
	else {
		if (X509_NAME_get_text_by_NID(X509_get_issuer_name(peer_cert),
					      NID_organizationName,
					      peer_info,
					      sizeof(peer_info)) == -1)
			report(LOG_DEBUG, "issuer organization name not specified in peer certificate");
		else
			report(LOG_INFO, "SSL certificate issuer organization: %s", peer_info);

		if (X509_NAME_get_text_by_NID(X509_get_issuer_name(peer_cert),
					      NID_commonName,
					      peer_info,
					      sizeof(peer_info)) == -1)
			report(LOG_DEBUG, "issuer common name not specified in peer certificate");
		else
			report(LOG_INFO, "SSL certificate issuer common name: %s", peer_info);

		if (X509_NAME_get_text_by_NID(X509_get_subject_name(peer_cert),
					      NID_commonName,
					      peer_info,
					      sizeof(peer_info)) == -1)
			report(LOG_DEFAULT, "hostname not specified in peer certificate");
		else {
			report(LOG_INFO, "SSL certificate server common name: %s", peer_info);
			/* a mismatch is reported, but deliberately not fatal */
			if (strncasecmp(peer_info, server, sizeof(peer_info)))
				report(LOG_DEFAULT,
				       "peer certificate common name mismatch: %s != %s",
				       peer_info,
				       server);
			else
				report(LOG_DEBUG,
				       "peer certificate common name matches %s",
				       peer_info);
		}
		/* currently, we do no further cert checks */
		X509_free(peer_cert);
	}

	/* cipher information */
	report(LOG_INFO,
	       "SSL cipher: %s (%d bit)",
	       SSL_get_cipher_name(ssl),
	       SSL_get_cipher_bits(ssl, NULL));
	report(LOG_INFO,
	       "SSL protocol version: %s",
	       SSL_get_cipher_version(ssl));

	ssl_conn = true;
	return 0;

fail:
	/*
	 * Release whatever was set up so far.  SSL_free(3) also releases the
	 * BIO once it has been attached; the socket itself stays open because
	 * the BIO was created with BIO_NOCLOSE.
	 */
	if (ssl != NULL) {
		SSL_free(ssl);
		ssl = NULL;
	}
	if (ctx != NULL) {
		SSL_CTX_free(ctx);
		ctx = NULL;
	}
	ssl_conn = false;
	return -1;
}
#endif				/* WITH_SSL */

/*
 * Resolve `server' and connect the static socket descriptor to the given TCP
 * port.  The socket is put into non-blocking mode, so the connection attempt
 * is completed with select(2), bounded by SOCK_TIMEOUT seconds.
 *
 * We return 0 on success and -1 otherwise.  On failure, the socket (if one
 * was created) is closed again and the descriptor is reset to -1.
 */
int
sock_open(const char *server, int port)
{
	struct hostent *h;
	struct sockaddr_in srv_addr;

	if ((h = gethostbyname(server)) == NULL) {
		report(LOG_ERROR, "unknown host %s", server);
		return -1;
	}
	/*
	 * Only an IPv4 address fits into struct sockaddr_in; refuse anything
	 * else instead of copying past the end of sin_addr.
	 */
	if (h->h_addrtype != AF_INET
	    || h->h_addr_list[0] == NULL
	    || h->h_length != (int) sizeof(srv_addr.sin_addr)) {
		report(LOG_ERROR, "no IPv4 address found for host %s", server);
		return -1;
	}
	memset(&srv_addr, 0, sizeof(srv_addr));
	srv_addr.sin_family = AF_INET;
	memcpy(&srv_addr.sin_addr, h->h_addr_list[0], sizeof(srv_addr.sin_addr));
	srv_addr.sin_port = htons((uint16_t) port);

	/* non-blocking socket for use with select(2) */
	if ((sd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		report(LOG_PERROR, "can't create socket");
		return -1;
	}
	if (unblock_descriptor() == -1) {
		report(LOG_PERROR, "can't set non-blocking flag on socket");
		goto fail;
	}
	if (connect(sd, (struct sockaddr *) &srv_addr, sizeof(srv_addr)) == -1) {
		if (errno == EINPROGRESS) {
			struct timeval  tv = {SOCK_TIMEOUT, 0};
			fd_set          rd,
			                wd;

			/* call select(2) if needed */
			report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
			FD_ZERO(&rd);
			FD_ZERO(&wd);
			FD_SET(sd, &rd);
			FD_SET(sd, &wd);
			if (select(sd + 1, &rd, &wd, 0, &tv) <= 0) {
				report(LOG_ERROR, "timeout connecting to server");
				goto fail;
			}
			/*
			 * A second connect(2) tells us how the attempt ended:
			 * EISCONN means the connection is established.
			 */
			if ((connect(sd, (struct sockaddr *) &srv_addr, sizeof(srv_addr)) == -1)
			    && (errno != EISCONN)) {
				report(LOG_PERROR, "connection failed");
				goto fail;
			}
		} else {
			report(LOG_PERROR, "connection failed");
			goto fail;
		}
	}

#if WITH_SSL
	ssl_conn = false;
#endif				/* WITH_SSL */
	return 0;

fail:
	/* errors were reported above, so close(2) may clobber errno now */
	close(sd);
	sd = -1;
	return -1;
}

/*
 * send(3) wrapper which tries to push all data out until either an error
 * occurs or select(2) runs into our timeout.
 *
 * buf: data to send
 * len: number of bytes to send
 *
 * If select(2) timed out or an error occurred, we return -1.  Otherwise, the
 * number of sent bytes is returned (which is 0 if len is 0, in which case
 * nothing is done at all).
 */
int
sock_write(const char *buf, size_t len)
{
	int             n,
	                sent = 0;

	/*
	 * Nothing to do.  A zero-length SSL_write(3) is undefined and a
	 * zero-length send(3) would be mistaken for a failure below.
	 */
	if (len == 0)
		return 0;

#if WITH_SSL
	if (ssl_conn) {
		do {
			do {
				/*
				 * select(2) may modify its timeout argument,
				 * so every attempt gets a fresh one.
				 */
				struct timeval  tv = {SOCK_TIMEOUT, 0};
				fd_set          wd;
				fd_set          rd;

				n = SSL_write(ssl, buf + sent, len - sent);
				switch (SSL_get_error(ssl, n)) {
				    case SSL_ERROR_NONE:
					sent += n;
					report(LOG_DEBUG, "sent %d of %lu byte", sent, (unsigned long) len);
					break;
				    case SSL_ERROR_WANT_WRITE:
					report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
					FD_ZERO(&wd);
					FD_SET(sd, &wd);
					if (select(sd + 1, NULL, &wd, 0, &tv) <= 0) {
						report(LOG_ERROR, "timeout sending data");
						return -1;
					}
					break;
				    case SSL_ERROR_WANT_READ:
					/* OpenSSL needs to read first (re-handshake) */
					report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
					FD_ZERO(&rd);
					FD_SET(sd, &rd);
					if (select(sd + 1, &rd, NULL, 0, &tv) <= 0) {
						report(LOG_ERROR, "re-handshake timeout while sending data");
						return -1;
					}
					break;
				    case SSL_ERROR_ZERO_RETURN:
					/* peer sent close_notify */
					report(LOG_ERROR, "connection reset by peer (using SSL close_notify)");
					return -1;
				    default:
					if (n == 0)
						/* peer sent TCP FIN */
						report(LOG_ERROR, "connection reset by peer (using TCP FIN)");
					else
						report(LOG_SERROR, "SSL_write(3) failed");
					return -1;
				}
			} while (n <= 0);
		} while ((size_t) sent < len);
	} else
#endif				/* WITH_SSL */
		do {
			if ((n = send(sd, buf + sent, len - sent, 0)) <= 0) {
				/* errno is only meaningful if send(3) returned -1 */
				if ((n < 0) && ((errno == EWOULDBLOCK) || (errno == EAGAIN))) {
					struct timeval  tv = {SOCK_TIMEOUT, 0};
					fd_set          wd;

					report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
					FD_ZERO(&wd);
					FD_SET(sd, &wd);
					if (select(sd + 1, NULL, &wd, 0, &tv) <= 0) {
						report(LOG_ERROR, "timeout sending data");
						return -1;
					}
					if ((n = send(sd, buf + sent, len - sent, 0)) <= 0) {
						report(LOG_PERROR, "send(3) failed");
						return -1;
					}
				} else {
					report(LOG_PERROR, "send(3) failed");
					return -1;
				}
			}
			sent += n;
			report(LOG_DEBUG, "sent %d of %lu byte", sent, (unsigned long) len);
		} while ((size_t) sent < len);

	return sent;
}

/*
 * recv(3) wrapper which uses select(2) if no data is on the socket.
 *
 * buf: buffer the received data is stored in
 * len: size of that buffer in bytes
 *
 * If select(2) timed out, we return -1, otherwise, the return value of
 * recv(3) or SSL_read(3) is returned (so 0 or a negative value indicates
 * that the peer closed the connection or that an error occurred).
 */
int
sock_read(char *buf, size_t len)
{
	int             n;

#if WITH_SSL
	if (ssl_conn) {
		do {
			/*
			 * select(2) may modify its timeout argument, so every
			 * attempt gets a fresh one.
			 */
			struct timeval  tv = {SOCK_TIMEOUT, 0};
			fd_set          rd;
			fd_set          wd;

			n = SSL_read(ssl, buf, len);
			switch (SSL_get_error(ssl, n)) {
			    case SSL_ERROR_NONE:
				report(LOG_DEBUG, "received %d byte", n);
				break;
			    case SSL_ERROR_WANT_READ:
				report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
				FD_ZERO(&rd);
				FD_SET(sd, &rd);
				if (select(sd + 1, &rd, NULL, 0, &tv) <= 0) {
					report(LOG_ERROR, "timeout receiving data");
					return -1;
				}
				break;
			    case SSL_ERROR_WANT_WRITE:
				/* OpenSSL needs to write first (re-handshake) */
				report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
				FD_ZERO(&wd);
				FD_SET(sd, &wd);
				if (select(sd + 1, NULL, &wd, 0, &tv) <= 0) {
					report(LOG_ERROR, "re-handshake timeout while receiving data");
					return -1;
				}
				break;
			    case SSL_ERROR_ZERO_RETURN:
				/*
				 * Peer sent close_notify.  The caller must
				 * decide whether or not this is a problem, so
				 * we won't print an error.
				 */
				report(LOG_DEBUG, "peer sent close_notify");
				return n;
			    default:
				if (n == 0)
					report(LOG_DEBUG, "peer sent TCP FIN packet");
				else
					report(LOG_SERROR, "SSL_read(3) failed");
				return n;
			}
		} while (n <= 0);
	} else {
#endif				/* WITH_SSL */
		/* wait for data only if the non-blocking recv(3) found none */
		if (((n = recv(sd, buf, len, 0)) == -1)
		    && ((errno == EWOULDBLOCK) || (errno == EAGAIN))) {
			struct timeval  tv = {SOCK_TIMEOUT, 0};
			fd_set          rd;

			report(LOG_DEBUG, "select(2)ing, timeout is %d seconds", SOCK_TIMEOUT);
			FD_ZERO(&rd);
			FD_SET(sd, &rd);
			if (select(sd + 1, &rd, NULL, 0, &tv) <= 0) {
				report(LOG_ERROR, "timeout receiving data");
				return -1;
			}
			n = recv(sd, buf, len, 0);
		}
		switch (n) {
		    case -1:
			report(LOG_PERROR, "recv(3) failed");
			break;
		    case 0:
			report(LOG_DEBUG, "peer sent TCP FIN packet");
			break;
		    default:
			report(LOG_DEBUG, "received %d byte", n);
		}
#if WITH_SSL
	}
#endif				/* WITH_SSL */
	return n;
}

/*
 * Look whether there is data on the socket or in the SSL record buffer.
 * Return 1 if there is at least one byte to read, 0 if not, -1 on error.
 * Similar to select(2), we return 1 if we received a TCP FIN packet (EOF).
 */
int
sock_peek(void)
{
	int             i;
#if WITH_SSL
	int             j;
#endif				/* WITH_SSL */
	char            byte;	/* our one-byte buffer :-) */

	if ((i = recv(sd, &byte, 1, MSG_PEEK)) == -1) {
		if (errno == EWOULDBLOCK)
			i = 0;
		else
			report(LOG_PERROR, "recv(3) failed");
	} else if (i == 0)
		i = 1;

#if WITH_SSL
	/*
	 * For SSL connections, we return 1 if there is data available on the
	 * socket (see above) and/or in the SSL record buffer (in which case
	 * SSL_pending(3) returns the number of bytes available).
	 */
	if (ssl_conn) {
		/*
		 * The BUGS section in SSL_pending(3) says: "Up to OpenSSL
		 * 0.9.6, SSL_pending() does not check if the record type of
		 * pending data is application data."  However, we don't
		 * support OpenSSL < 0.9.7, users should upgrade anyway.
		 */
		if ((j = SSL_pending(ssl)))
			i = 1;
		report(LOG_DEBUG, "SSL_pending(3) returned %d", j);
	}
#endif				/* WITH_SSL */

	return i;
}

/*
 * Shut down the SSL session (if any) and close the socket.  Afterwards, the
 * static connection state is reset, so stale SSL handles or a stale socket
 * descriptor can't be used by accident.
 */
void
sock_close(void)
{
#if WITH_SSL
	bool            done = false;

	if (ssl_conn) {
		do {
			struct timeval  tv = {SHUT_TIMEOUT, 0};
			fd_set          wd;
			fd_set          rd;
			int             i;

			/*
			 * RFC 2246, 7.2.1 states that we're required to send
			 * a close_notify (via SSL_shutdown(3)), but since we
			 * won't re-use the SSL connection, we're not required
			 * to check the responding close_notify (by using a
			 * second call to SSL_shutdown(3)).  If we run into
			 * trouble, we'll forget about the shutdown silently.
			 */
			if ((i = SSL_shutdown(ssl)) == -1)
				switch (SSL_get_error(ssl, i)) {
				    case SSL_ERROR_WANT_WRITE:
					report(LOG_DEBUG, "select(2)ing, timeout is %d second(s)", SHUT_TIMEOUT);
					FD_ZERO(&wd);
					FD_SET(sd, &wd);
					if (select(sd + 1, NULL, &wd, 0, &tv) <= 0) {
						report(LOG_DEBUG, "timeout shutting down SSL connection");
						done = true;
					}
					break;
				    case SSL_ERROR_WANT_READ:
					report(LOG_DEBUG, "select(2)ing, timeout is %d second(s)", SHUT_TIMEOUT);
					FD_ZERO(&rd);
					FD_SET(sd, &rd);
					if (select(sd + 1, &rd, NULL, 0, &tv) <= 0) {
						report(LOG_DEBUG, "timeout shutting down SSL connection");
						done = true;
					}
					break;
				    default:
					report(LOG_DEBUG, "SSL_shutdown(3) failed");
					done = true;
				}
			else {
				report(LOG_DEBUG, "sent close_notify to shutdown SSL connection");
				done = true;
			}
		} while (!done);

		SSL_free(ssl);
		SSL_CTX_free(ctx);
		/* don't leave dangling handles behind */
		ssl = NULL;
		ctx = NULL;
		ssl_conn = false;
	}
#endif				/* WITH_SSL */

	report(LOG_DEBUG, "closing socket");
	if (sd != -1) {
		close(sd);
		sd = -1;
	}
}

/*
 * Add O_NONBLOCK to the file status flags of the static socket descriptor.
 * We return 0 on success and -1 if either fcntl(2) call failed.
 */
static int
unblock_descriptor(void)
{
	int             current = fcntl(sd, F_GETFL, 0);

	/* only touch the flags if they could be read in the first place */
	if ((current != -1) && (fcntl(sd, F_SETFL, current | O_NONBLOCK) != -1))
		return 0;

	return -1;
}
