/***************************************************************************
 *
 * This file is covered by a dual licence. You can choose whether you
 * want to use it according to the terms of the GNU GPL version 2, or
 * under the terms of Zorp Professional Firewall System EULA located
 * on the Zorp installation CD.
 *
 * $Id: streamssl.c,v 1.45 2003/09/10 11:46:58 bazsi Exp $
 *
 * Author  : SaSa
 * Auditor :
 * Last audited version:
 * Notes:
 *
 ***************************************************************************/

#include <zorp/stream.h>
#include <zorp/log.h>
#include <zorp/ssl.h>
#include <zorp/zorplib.h>
#include <zorp/error.h>

#include <openssl/err.h>

#include <sys/types.h>
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <string.h>

#ifdef G_OS_WIN32
#  include <winsock2.h>
#else
#  include <sys/socket.h>
#  include <sys/poll.h>
#endif

/* Size of ZStreamSsl.error, the buffer that receives a formatted OpenSSL
   error string (filled by z_ssl_get_error_str). */
#define ERR_buflen 4096

/* Values of ZStreamSsl.what_if_called: which user callback the child
   stream's readiness callbacks (z_stream_ssl_read_callback /
   z_stream_ssl_write_callback) should invoke. */

/* normal mode: read readiness -> user read cb, write readiness -> user write cb */
#define DO_AS_USUAL          0
/* SSL_read() hit SSL_ERROR_WANT_WRITE: child write readiness calls the user's read cb */
#define CALL_READ_WHEN_WRITE 1
/* SSL_write() hit SSL_ERROR_WANT_READ: child read readiness calls the user's write cb */
#define CALL_WRITE_WHEN_READ 2

/*
 * ZStreamSsl: a ZStream that runs SSL/TLS (via an OpenSSL SSL object held in
 * a ZSSLSession) on top of a child ZStream.
 */
typedef struct _ZStreamSsl
{
  ZStream super;

  /* NOTE(review): cbv_read/cbv_write and old_want_read/old_want_write are
     read when switching what_if_called modes (as if they held the child's
     saved callbacks/conditions), but nothing visible in this file ever
     assigns them -- confirm whether a "get" step is missing. */
  ZStreamSetCb cbv_read;
  ZStreamSetCb cbv_write;
  gboolean old_want_read;
  gboolean old_want_write;
  /* one of DO_AS_USUAL, CALL_READ_WHEN_WRITE, CALL_WRITE_WHEN_READ */
  guint what_if_called;
  /* TRUE once SSL_shutdown() was issued; reads then report EOF and writes
     fail with ENOTCONN */
  gboolean shutdown;

  /* SSL session; a reference is taken in z_stream_ssl_new and dropped in
     the free method */
  ZSSLSession *ssl;
  /* last formatted OpenSSL error string, used for log lines and GError text */
  gchar error[ERR_buflen];
} ZStreamSsl;

/* class descriptor for ZStreamSsl, defined further down in this file */
extern ZClass ZStreamSsl__class;

/*
 * Read-readiness callback installed on the child stream.
 *
 * In normal operation it forwards to the user's read callback. While the
 * stream is in CALL_WRITE_WHEN_READ mode (an SSL_write() stalled on
 * SSL_ERROR_WANT_READ), the user's write callback is invoked instead, so the
 * stalled write can be retried once the child becomes readable.
 *
 * @s: the ZStreamSsl that registered this callback.
 * Returns whatever the chosen user callback returns.
 */
static gboolean
z_stream_ssl_read_callback(ZStream *stream G_GNUC_UNUSED, GIOCondition poll_cond, gpointer s)
{
  ZStreamSsl *self = (ZStreamSsl *) s;
  gboolean result;

  z_enter();
  result = (self->what_if_called == CALL_WRITE_WHEN_READ)
             ? self->super.write_cb(s, poll_cond, self->super.user_data_write)
             : self->super.read_cb(s, poll_cond, self->super.user_data_read);
  z_leave();
  return result;
}

/*
 * Write-readiness callback installed on the child stream.
 *
 * In normal operation it forwards to the user's write callback. While the
 * stream is in CALL_READ_WHEN_WRITE mode (an SSL_read() stalled on
 * SSL_ERROR_WANT_WRITE), the user's read callback is invoked instead, so the
 * stalled read can be retried once the child becomes writable.
 *
 * @s: the ZStreamSsl that registered this callback.
 * Returns whatever the chosen user callback returns.
 */
static gboolean
z_stream_ssl_write_callback(ZStream *stream G_GNUC_UNUSED, GIOCondition poll_cond, gpointer s)
{
  ZStreamSsl *self = (ZStreamSsl *) s;
  gboolean result;

  z_enter();
  result = (self->what_if_called == CALL_READ_WHEN_WRITE)
             ? self->super.read_cb(s, poll_cond, self->super.user_data_read)
             : self->super.write_cb(s, poll_cond, self->super.user_data_write);
  z_leave();
  return result;
}

/*
 * Priority-data callback installed on the child stream: always forwards to
 * the user's pri callback, passing the ZStreamSsl (@s) as the stream.
 */
static gboolean
z_stream_ssl_pri_callback(ZStream *stream G_GNUC_UNUSED, GIOCondition poll_cond, gpointer s)
{
  ZStreamSsl *self = (ZStreamSsl *) s;
  gboolean result;

  z_enter();
  result = self->super.pri_cb(s, poll_cond, self->super.user_data_pri);
  z_leave();
  return result;
}

/* virtual functions */

/*
 * z_stream_ssl_read_method:
 * Read decrypted application data from the SSL session.
 *
 * @stream:     the ZStreamSsl instance
 * @buf:        destination buffer
 * @count:      size of @buf; clamped to INT_MAX because SSL_read() takes an
 *              int length (a short read is always permitted)
 * @bytes_read: receives the number of bytes stored in @buf (0 unless the
 *              status is G_IO_STATUS_NORMAL)
 * @error:      set when G_IO_STATUS_ERROR is returned
 *
 * Returns:
 *  - G_IO_STATUS_NORMAL when data was read.
 *  - G_IO_STATUS_EOF when the stream was shut down locally, SSL_read()
 *    returned 0, SSL_ERROR_ZERO_RETURN was reported, or SSL_ERROR_SYSCALL
 *    came with errno 0.
 *  - G_IO_STATUS_AGAIN when OpenSSL needs more I/O on the child.
 *  - G_IO_STATUS_ERROR on OS or SSL failures.
 *
 * On SSL_ERROR_WANT_WRITE the child's write condition/callback are set up
 * and what_if_called becomes CALL_READ_WHEN_WRITE, so child write readiness
 * re-invokes the user's read callback (see z_stream_ssl_write_callback). A
 * successful read switches back to DO_AS_USUAL.
 */
static GIOStatus
z_stream_ssl_read_method(ZStream  *stream,
                           gchar  *buf,
                           gsize   count,
                           gsize  *bytes_read,
                          GError **error)
{
  ZStreamSsl *self = (ZStreamSsl *) stream;
  gint len;
  gint result;
  gint ssl_err;
  gint saved_errno;

  z_enter();
  g_return_val_if_fail ((error == NULL) || (*error == NULL), G_IO_STATUS_ERROR);

  if (self->what_if_called == CALL_WRITE_WHEN_READ)
    {
      /*LOG
        This message indicates an internal error. Please report this event to the Balabit
	QA Team (devel@balabit.com).
       */
      z_log(NULL, CORE_ERROR, 2, "Internal error; error='Read called, when only write might be called'");
    }

  if (self->shutdown)
    {
      *bytes_read = 0;
      z_leave();
      return G_IO_STATUS_EOF;
    }

  /* the SSL BIO performs its I/O through the child stream (see
     z_stream_ssl_new), so hand our timeout down to it */
  self->super.child->timeout = self->super.timeout;

  /* SSL_read() takes an int: clamp rather than let a huge gsize wrap negative */
  len = (count > (gsize) INT_MAX) ? INT_MAX : (gint) count;
  result = SSL_read(self->ssl->ssl, buf, len);

  if (result > 0)
    {
      if (self->what_if_called != DO_AS_USUAL)
        {
          /* the read made progress: undo the write-side redirection that was
             set up on SSL_ERROR_WANT_WRITE */
          z_stream_set_cond(self->super.child, Z_STREAM_FLAG_WRITE, self->old_want_write);
          z_stream_set_callback(self->super.child, Z_STREAM_FLAG_WRITE, self->cbv_write.cb, self->cbv_write.cb_data, self->cbv_write.cb_notify);
          self->what_if_called = DO_AS_USUAL;
        }
      *bytes_read = result;
      ERR_clear_error();
      z_leave();
      return G_IO_STATUS_NORMAL;
    }

  if (result == 0)
    {
      *bytes_read = 0;
      ERR_clear_error();
      z_leave();
      return G_IO_STATUS_EOF;
    }

  /* result < 0: let OpenSSL classify the failure */
  *bytes_read = 0;
  ssl_err = SSL_get_error(self->ssl->ssl, result);
  switch (ssl_err)
    {
    case SSL_ERROR_ZERO_RETURN:
      z_leave();
      return G_IO_STATUS_EOF;

    case SSL_ERROR_WANT_READ:
      z_leave();
      return G_IO_STATUS_AGAIN;

    case SSL_ERROR_WANT_WRITE:
      /* OpenSSL must write to the child before this read can progress */
      if (self->what_if_called == DO_AS_USUAL)
        {
          /* NOTE(review): old_want_write / cbv_write are never assigned in the
             visible source, so these two calls apply whatever those fields
             currently hold instead of saving the child's state; presumably a
             get-cond/get-callback step was intended -- confirm. */
          z_stream_set_cond(self->super.child, Z_STREAM_FLAG_WRITE, self->old_want_write);
          z_stream_set_callback(self->super.child, Z_STREAM_FLAG_WRITE, self->cbv_write.cb, self->cbv_write.cb_data, self->cbv_write.cb_notify);
          if (!self->old_want_write)
            {
              z_stream_set_cond(self->super.child, Z_STREAM_FLAG_WRITE, TRUE);
            }
          if (self->cbv_write.cb != z_stream_ssl_write_callback)
            {
              z_stream_set_callback(self->super.child, Z_STREAM_FLAG_WRITE, z_stream_ssl_write_callback, stream, NULL);
            }
        }
      self->what_if_called = CALL_READ_WHEN_WRITE;
      z_leave();
      return G_IO_STATUS_AGAIN;

    case SSL_ERROR_SYSCALL:
      if (z_errno_is(EAGAIN) || z_errno_is(EINTR))
        {
          z_leave();
          return G_IO_STATUS_AGAIN;
        }
      if (z_errno_is(0))
        {
          /* no OS error recorded: treat as end of stream */
          z_leave();
          return G_IO_STATUS_EOF;
        }
      /* capture errno before z_log() or other calls can overwrite it */
      saved_errno = errno;
      /*LOG
        This message indicates that an OS level error occurred during the SSL read. Check your OpenSSL
        installation and please report this event to the Balabit QA Team (devel@balabit.com).
       */
      z_log(self->super.name, CORE_ERROR, 3, "An OS error occurred during SSL read; error='%s'", g_strerror(saved_errno));
      g_set_error(error, G_IO_CHANNEL_ERROR, g_io_channel_error_from_errno(saved_errno), "%s", g_strerror(saved_errno));
      z_leave();
      return G_IO_STATUS_ERROR;

    case SSL_ERROR_SSL:
    default:
      z_ssl_get_error_str(self->error, ERR_buflen);
      ERR_clear_error();

      /*LOG
        This message indicates that an internal SSL error occurred during the SSL read. Check your OpenSSL
        installation and please report this event to the Balabit QA Team (devel@balabit.com).
       */
      z_log(self->super.name, CORE_ERROR, 3, "An SSL error occurred during SSL read; error='%s'", self->error);
      /* self->error is arbitrary OpenSSL text: never use it as a format string */
      g_set_error(error, G_IO_CHANNEL_ERROR, G_IO_CHANNEL_ERROR_IO, "%s", self->error);
      z_leave();
      return G_IO_STATUS_ERROR;
    }
}

/*
 * z_stream_ssl_write_method:
 * Encrypt and send application data through the SSL session.
 *
 * @stream:        the ZStreamSsl instance
 * @buf:           data to send
 * @count:         number of bytes in @buf; clamped to INT_MAX because
 *                 SSL_write() takes an int length (the short write is then
 *                 reported through @bytes_written)
 * @bytes_written: receives the number of bytes consumed from @buf (0 unless
 *                 the status is G_IO_STATUS_NORMAL)
 * @error:         set when G_IO_STATUS_ERROR is returned
 *
 * Returns:
 *  - G_IO_STATUS_NORMAL when data was written, or when @count is 0.
 *  - G_IO_STATUS_AGAIN when OpenSSL needs more I/O on the child.
 *  - G_IO_STATUS_EOF on SSL_ERROR_ZERO_RETURN.
 *  - G_IO_STATUS_ERROR after a local shutdown (ENOTCONN) or on OS/SSL
 *    failures.
 *
 * SSL_write() never returns 0 for success, so a 0 return is classified with
 * SSL_get_error() exactly like a negative return.
 *
 * On SSL_ERROR_WANT_READ the child's read condition/callback are set up and
 * what_if_called becomes CALL_WRITE_WHEN_READ, so child read readiness
 * re-invokes the user's write callback (see z_stream_ssl_read_callback). A
 * successful write switches back to DO_AS_USUAL.
 */
static GIOStatus
z_stream_ssl_write_method(ZStream  *stream,
                       const gchar  *buf,
                             gsize   count,
                             gsize  *bytes_written,
                            GError **error)
{
  ZStreamSsl *self = (ZStreamSsl *) stream;
  gint len;
  gint result;
  gint ssl_err;
  gint saved_errno;

  z_enter();
  g_return_val_if_fail ((error == NULL) || (*error == NULL), G_IO_STATUS_ERROR);
  if (self->shutdown)
    {
      *bytes_written = 0;
      g_set_error(error, G_IO_CHANNEL_ERROR,
                  g_io_channel_error_from_errno (ENOTCONN),
                  "%s", g_strerror (ENOTCONN));
      z_leave();
      return G_IO_STATUS_ERROR;
    }

  if (count == 0)
    {
      /* nothing to send; OpenSSL treats SSL_write() with num == 0 as an error */
      *bytes_written = 0;
      z_leave();
      return G_IO_STATUS_NORMAL;
    }

  /* the SSL BIO performs its I/O through the child stream (see
     z_stream_ssl_new), so hand our timeout down to it */
  self->super.child->timeout = self->super.timeout;

  /* SSL_write() takes an int: clamp rather than let a huge gsize wrap negative */
  len = (count > (gsize) INT_MAX) ? INT_MAX : (gint) count;
  result = SSL_write(self->ssl->ssl, buf, len);

  if (result > 0)
    {
      if (self->what_if_called != DO_AS_USUAL)
        {
          /* the write made progress: undo the read-side redirection that was
             set up on SSL_ERROR_WANT_READ */
          z_stream_set_cond(self->super.child, Z_STREAM_FLAG_READ, self->old_want_read);
          z_stream_set_callback(self->super.child, Z_STREAM_FLAG_READ, self->cbv_read.cb, self->cbv_read.cb_data, self->cbv_read.cb_notify);
          self->what_if_called = DO_AS_USUAL;
        }
      *bytes_written = result;
      ERR_clear_error();
      z_leave();
      return G_IO_STATUS_NORMAL;
    }

  /* result <= 0: let OpenSSL classify the failure */
  *bytes_written = 0;
  ssl_err = SSL_get_error(self->ssl->ssl, result);
  switch (ssl_err)
    {
    case SSL_ERROR_ZERO_RETURN:
      z_leave();
      return G_IO_STATUS_EOF;

    case SSL_ERROR_WANT_READ:
      /* OpenSSL must read from the child before this write can progress */
      if (self->what_if_called == DO_AS_USUAL)
        {
          /* NOTE(review): old_want_read / cbv_read are never assigned in the
             visible source, so these two calls apply whatever those fields
             currently hold instead of saving the child's state; presumably a
             get-cond/get-callback step was intended -- confirm. */
          z_stream_set_cond(self->super.child, Z_STREAM_FLAG_READ, self->old_want_read);
          z_stream_set_callback(self->super.child, Z_STREAM_FLAG_READ, self->cbv_read.cb, self->cbv_read.cb_data, self->cbv_read.cb_notify);
          if (!self->old_want_read)
            {
              z_stream_set_cond(self->super.child, Z_STREAM_FLAG_READ, TRUE);
            }
          if (self->cbv_read.cb != z_stream_ssl_read_callback)
            {
              z_stream_set_callback(self->super.child, Z_STREAM_FLAG_READ, z_stream_ssl_read_callback, stream, NULL);
            }
        }
      self->what_if_called = CALL_WRITE_WHEN_READ;
      z_leave();
      return G_IO_STATUS_AGAIN;

    case SSL_ERROR_WANT_WRITE:
      z_leave();
      return G_IO_STATUS_AGAIN;

    case SSL_ERROR_SYSCALL:
      if (z_errno_is(EAGAIN) || z_errno_is(EINTR))
        {
          z_leave();
          return G_IO_STATUS_AGAIN;
        }
      /* capture errno before z_log() or other calls can overwrite it */
      saved_errno = errno;
      if (saved_errno == 0)
        {
          /* no OS error recorded (e.g. transport EOF mid-write): report a
             broken pipe instead of the misleading "Success" */
          saved_errno = EPIPE;
        }
      /*LOG
        This message indicates that an OS level error occurred during the SSL write. Check your OpenSSL
        installation and please report this event to the Balabit QA Team (devel@balabit.com).
       */
      z_log(self->super.name, CORE_ERROR, 3, "An OS error occurred during SSL write; error='%s'", g_strerror(saved_errno));
      g_set_error(error, G_IO_CHANNEL_ERROR, g_io_channel_error_from_errno(saved_errno), "%s", g_strerror(saved_errno));
      z_leave();
      return G_IO_STATUS_ERROR;

    case SSL_ERROR_SSL:
    default:
      z_ssl_get_error_str(self->error, ERR_buflen);
      ERR_clear_error();

      /*LOG
        This message indicates that an internal SSL error occurred during the SSL write. Check your OpenSSL
        installation and please report this event to the Balabit QA Team (devel@balabit.com).
       */
      z_log(self->super.name, CORE_ERROR, 3, "An SSL error occurred during SSL write; error='%s'", self->error);
      /* self->error is arbitrary OpenSSL text: never use it as a format string */
      g_set_error(error, G_IO_CHANNEL_ERROR, G_IO_CHANNEL_ERROR_IO, "%s", self->error);
      z_leave();
      return G_IO_STATUS_ERROR;
    }
}

/*
 * Issue SSL_shutdown() on the session once and mark the stream as shut
 * down. Afterwards the read method reports EOF and the write method fails
 * with ENOTCONN. SSL_shutdown()'s result is not inspected, and any queued
 * OpenSSL errors are discarded. Repeated calls are no-ops.
 *
 * FIXME: the shutdown direction argument (@i) is ignored. Is that OK?
 *
 * Returns G_IO_STATUS_NORMAL.
 */
static GIOStatus
z_stream_ssl_shutdown_method(ZStream *stream, int i G_GNUC_UNUSED, GError **error)
{
  ZStreamSsl *self = (ZStreamSsl *) stream;

  z_enter();
  g_return_val_if_fail ((error == NULL) || (*error == NULL), G_IO_STATUS_ERROR);

  if (self->shutdown)
    {
      /* already done */
      z_leave();
      return G_IO_STATUS_NORMAL;
    }

  SSL_shutdown(self->ssl->ssl);
  ERR_clear_error();
  self->shutdown = TRUE;

  z_leave();
  return G_IO_STATUS_NORMAL;
}

/*
 * Control-message handler.
 *
 * Callback-setting messages (read/write/pri) are handled by the generic
 * z_stream_ctrl_method on this stream itself. Every other message is passed
 * on with ZST_CTRL_MSG_FORWARD set (presumably reaching the child stream).
 *
 * Returns the result of z_stream_ctrl_method.
 */
static gboolean
z_stream_ssl_ctrl_method(ZStream *s, guint function, gpointer value, guint vlen)
{
  gboolean is_callback_msg;
  gboolean ret;

  Z_CAST(s, ZStreamSsl);

  z_enter();
  is_callback_msg = ZST_CTRL_MSG(function) == ZST_CTRL_SET_CALLBACK_READ ||
                    ZST_CTRL_MSG(function) == ZST_CTRL_SET_CALLBACK_WRITE ||
                    ZST_CTRL_MSG(function) == ZST_CTRL_SET_CALLBACK_PRI;
  if (!is_callback_msg)
    function |= ZST_CTRL_MSG_FORWARD;

  ret = z_stream_ctrl_method(s, function, value, vlen);

  z_leave();
  return ret;
}


/*
 * Close the stream by closing the child stream and returning its status.
 * No SSL_shutdown() is performed here.
 */
static GIOStatus
z_stream_ssl_close_method(ZStream *stream, GError **error)
{
  ZStreamSsl *self = (ZStreamSsl *) stream;
  GIOStatus status;

  z_enter();
  g_return_val_if_fail (error == NULL || *error == NULL, G_IO_STATUS_ERROR);

  status = z_stream_close(self->super.child, error);

  z_leave();
  return status;
}


/*
 * Attach the child's event source to @context, then create and attach this
 * stream's own source if it does not have one yet. A temporary reference on
 * @stream is held for the duration.
 */
static void
z_stream_ssl_attach_source_method(ZStream *stream, GMainContext *context)
{
  ZStreamSsl *self = (ZStreamSsl *) stream;

  z_enter();
  z_stream_ref(stream);

  z_stream_attach_source(self->super.child, context);

  if (stream->source == NULL)
    {
      stream->source = z_stream_source_new(stream);
      g_source_attach(stream->source, context);
    }

  z_stream_unref(stream);
  z_leave();
}

/*
 * Destroy and release this stream's own source (if any), then detach the
 * child's source.
 */
static void
z_stream_ssl_detach_source_method(ZStream *stream)
{
  ZStreamSsl *self = (ZStreamSsl *) stream;
  GSource *src = stream->source;

  if (src != NULL)
    {
      stream->source = NULL;
      /* destroy first: the unref below may drop the last reference */
      g_source_destroy(src);
      g_source_unref(src);
    }

  z_stream_detach_source(self->super.child);
}

/*
 * Source prepare hook.
 *
 * When the user wants to read, the stream counts as ready immediately
 * (timeout 0) if it has been shut down (so the EOF gets delivered) or if
 * OpenSSL already holds decrypted bytes (SSL_pending). Neither condition is
 * signalled by the child's own poll condition. Otherwise *timeout is -1 and
 * FALSE is returned.
 */
static gboolean
z_stream_ssl_watch_prepare(ZStream *s, GSource *src G_GNUC_UNUSED, gint *timeout)
{
  ZStreamSsl *self = (ZStreamSsl *) s;
  gboolean ready;

  z_enter();
  ready = s->want_read && (self->shutdown || SSL_pending(self->ssl->ssl));
  *timeout = ready ? 0 : -1;
  z_leave();
  return ready;
}

/*
 * Source check hook.
 *
 * Reports the stream as ready when the user wants to read and either the
 * stream has been shut down (so the pending EOF is delivered) or OpenSSL
 * holds already-decrypted bytes (SSL_pending). This uses the same condition
 * as z_stream_ssl_watch_prepare, so prepare and check never disagree about
 * readiness.
 */
static gboolean
z_stream_ssl_watch_check(ZStream *s, GSource *src G_GNUC_UNUSED)
{
  ZStreamSsl *self = (ZStreamSsl *) s;
  gboolean ready;

  z_enter();
  ready = s->want_read && (self->shutdown || SSL_pending(self->ssl->ssl));
  z_leave();
  return ready;
}

/*
 * Source dispatch hook: the prepare/check hooks in this file only report
 * read-side readiness, so the user's read callback is invoked with G_IO_IN.
 * Returns the callback's result.
 */
static gboolean
z_stream_ssl_watch_dispatch(ZStream *s, GSource *src G_GNUC_UNUSED)
{
  gboolean result;

  z_enter();
  result = s->read_cb(s, G_IO_IN, s->user_data_read);
  z_leave();
  return result;
}


/*
 * Destructor: drop the SSL session reference taken in z_stream_ssl_new,
 * discard queued OpenSSL errors, then chain to the ZStream destructor.
 */
static void
z_stream_ssl_free_method(ZObject *s)
{
  ZStreamSsl *self = Z_CAST(s, ZStreamSsl);
  ZSSLSession *session = self->ssl;

  z_enter();
  z_ssl_session_unref(session);
  ERR_clear_error();
  /* parent class destructor */
  z_stream_free_method(s);
  z_leave();
}


/*
 * Virtual method table for ZStreamSsl. Entries are positional and must
 * follow the member order of ZStreamFuncs (declared in zorp/stream.h). The
 * NULL entries are slots this class does not implement; their meaning
 * depends on that declaration -- confirm there.
 */
ZStreamFuncs z_stream_ssl_funcs =
{
  {
    Z_FUNCS_COUNT(ZStream),
    z_stream_ssl_free_method,     /* ZObject destructor */
  },
  z_stream_ssl_read_method,
  z_stream_ssl_write_method,
  NULL,
  NULL,
  z_stream_ssl_shutdown_method,
  z_stream_ssl_close_method,
  z_stream_ssl_ctrl_method,
  z_stream_ssl_attach_source_method,
  z_stream_ssl_detach_source_method,
  z_stream_ssl_watch_prepare,
  z_stream_ssl_watch_check,
  z_stream_ssl_watch_dispatch,
  NULL,
  NULL,
  NULL,
  NULL
};

/*
 * Class descriptor: ZStreamSsl derives from ZStream, has instance size
 * sizeof(ZStreamSsl) and uses z_stream_ssl_funcs as its method table.
 */
ZClass ZStreamSsl__class =
{
  Z_CLASS_HEADER,
  &ZStream__class,              /* parent class */
  "ZStreamSsl",                 /* class name */
  sizeof(ZStreamSsl),           /* instance size */
  &z_stream_ssl_funcs.super,    /* method table */
};

/*
 * z_stream_ssl_new:
 * Wrap @child in an SSL stream driven by @ssl.
 *
 * @child: underlying stream carrying the encrypted traffic
 * @ssl:   SSL session; an additional reference is taken
 *
 * The new stream takes the child's name and timeout. A BIO built on the
 * child (z_ssl_bio_new) is installed as both the read and write BIO of the
 * SSL object. The child's read, write and pri callbacks are replaced with
 * this stream's trampolines, which route to the user callbacks according to
 * what_if_called.
 *
 * Returns the new stream.
 */
ZStream *
z_stream_ssl_new(ZStream *child, ZSSLSession *ssl)
{
  ZStreamSsl *self;
  ZStream *below;
  BIO *bio;

  z_enter();
  self = Z_CAST(z_stream_new(Z_CLASS(ZStreamSsl), child->name, child,
                             Z_STREAM_FLAG_READ | Z_STREAM_FLAG_WRITE),
                ZStreamSsl);
  self->ssl = z_ssl_session_ref(ssl);

  below = self->super.child;
  self->super.timeout = below->timeout;

  /* all OpenSSL record I/O goes through the child stream */
  bio = z_ssl_bio_new(below);
  SSL_set_bio(self->ssl->ssl, bio, bio);

  z_stream_set_callback(below, Z_STREAM_FLAG_READ,  z_stream_ssl_read_callback,  self, NULL);
  z_stream_set_callback(below, Z_STREAM_FLAG_WRITE, z_stream_ssl_write_callback, self, NULL);
  z_stream_set_callback(below, Z_STREAM_FLAG_PRI,   z_stream_ssl_pri_callback,   self, NULL);

  z_leave();
  return &self->super;
}

