// -*- Mode: C++; -*-
//                            Package   : omniORB2
// sslContext.cc              Created on: 12/11/98
//                            Author    : Tatsuo Nakajima (tatsuo)
//
//    Copyright (C) 1996, 1997 Olivetti & Oracle Research Laboratory
//
//    This file is part of the omniORB library
//
//    The omniORB library is free software; you can redistribute it and/or
//    modify it under the terms of the GNU Library General Public
//    License as published by the Free Software Foundation; either
//    version 2 of the License, or (at your option) any later version.
//
//    This library is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
//    Library General Public License for more details.
//
//    You should have received a copy of the GNU Library General Public
//    License along with this library; if not, write to the Free
//    Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  
//    02111-1307, USA
//
//
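// Description:
//	SSL support for omniORB2: creation and configuration of the shared
//	OpenSSL SSL_CTX, parsing of the -SSLport command-line option, and the
//	SSL-specific connection state and connection control objects.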

#include <omniORB2/CORBA.h>
#include <omniORB2/QOS.h>
#include <ropeFactory.h>
#include <inetSocketFactory.h>
#include <inetSocket.h>
#include <objectManager.h>

#include <stdlib.h>
#include <stdio.h>
#include <ssl/rsa.h>
#include <ssl/crypto.h>
#include <ssl/x509.h>
#include <ssl/pem.h>
#include <ssl/ssl.h>
#include <ssl/err.h>

// Process-wide state shared by every user of sslContext.
sslContext* sslContext::singleton = NULL;      // created by initContext*() / getContext()
SSL_CTX* sslContext::sslCtx = 0;               // built by setContext(), freed by terminateSSLContext()
CORBA::Boolean sslContext::ssl_only(0);        // raised by setUseSSL()
CORBA::Boolean sslContext::ssl_vbcompat(0);    // raised by setVBCompat()
int sslContext::ssl_verify(SSL_VERIFY_NONE);   // switched to SSL_VERIFY_PEER by setVerify()

// Remove nargs consecutive entries starting at argv[idx] by sliding the
// remaining arguments down, and shrink argc to match.  If the requested
// window extends past argc, argv and argc are left untouched.
static void
move_args(int& argc,char **argv,int idx,int nargs)
{
  if (idx + nargs > argc)
    return;

  char** dst = argv + idx;
  char** src = argv + idx + nargs;
  char** const end = argv + argc;
  while (src != end)
    *dst++ = *src++;

  argc -= nargs;
}

// Set to 1 once a -SSLport option has been processed.  initContext() checks
// it to decide whether it still has to open an SSL endpoint on a port chosen
// by the OS.
static int ssl_port = 0;

// Scan argv for SSL command-line options and consume them.
//
// Recognised options:
//   -SSLport <port>   accept SSL calls on <port> (1..65535)
//
// Each recognised option (with its value) is removed from argv and argc is
// reduced accordingly; every other argument is left in place.
//
// Returns 1 if all SSL options were well formed, 0 otherwise (a diagnostic
// is printed when omniORB::traceLevel > 0).
static CORBA::Boolean
parse_ssl_args(int &argc, char **argv)
{
  int idx = 1;

  while (argc > idx) {

    // Not an SSL option: leave it alone and look at the next argument.
    if (strcmp(argv[idx],"-SSLport") != 0) {
      idx++;
      continue;
    }

    // -SSLport <port number>
    if ((idx+1) >= argc) {
      if (omniORB::traceLevel > 0) {
        cerr << "sslContext failed: missing -SSLport parameter." << endl;
      }
      // Without this return argv[idx+1] (== argv[argc]) would be parsed.
      return 0;
    }

    unsigned int port;
    if (sscanf(argv[idx+1],"%u",&port) != 1 || port == 0 || port >= 65536) {
      if (omniORB::traceLevel > 0) {
        cerr << "sslContext failed: invalid -SSLport parameter." << endl;
      }
      return 0;
    }

    ssl_port = 1;

    CORBA::Char* hostname = (CORBA::Char*)"";
    inetSocketEndpoint e;
    e.inetHost(hostname);
    e.sslPort(port);
    e.status(TRANS_SSL);

    // Open the SSL endpoint on the first inet incoming rope factory that is
    // not already listening on it.
    omniObjectManager* rootObjectManager = omniObjectManager::root(0);
    ropeFactory_iterator iter(*rootObjectManager->incomingRopeFactories());
    incomingRopeFactory* factory;

    while ((factory = (incomingRopeFactory*)iter())) {
      if (factory->getType()->is_protocol(inetSocketEndpoint::protocol_name)) {
        if (!factory->isIncoming(&e)) {
          // This port has not been instantiated
          factory->instantiateIncoming(&e,1); // create ssl session
          if (omniORB::traceLevel >= 2) {
            cerr << "Accept SSL calls on port " << e.sslPort() << endl;
          }
          break;
        }
      }
    }

    // Drop "-SSLport <port>"; the next argument now sits at argv[idx], so
    // idx is deliberately not advanced.
    move_args(argc,argv,idx,2);
  }
  return 1;
}

// Construct an empty context: no connection interceptor factory installed
// and pd_status cleared (see checkInitializedByUser()).
sslContext::sslContext()
{
  this->pd_connection_interceptor_factory = NULL;
  this->pd_status = 0;
}

// Set once initContext(argc, argv) has passed its preconditions; any further
// call raises BAD_OPERATION.
static int init_flag = 0;

// Initialize SSL support from the command line.
//
// Must be called at most once and only after the BOA exists; otherwise
// CORBA::BAD_OPERATION(0, COMPLETED_NO) is thrown.  SSL options are consumed
// from argv by parse_ssl_args(); if they are malformed the function returns
// early without creating the SSL_CTX.  When no -SSLport was given, an SSL
// endpoint on an OS-chosen port is created for every inet incoming rope
// factory.  Finally the SSL_CTX is created and the default connection
// interceptor factory is installed.
void
sslContext::initContext(int &argc, char **argv)
{
  //
  // If the method has already been called, an exception is raised.
  //
  if(init_flag) {
    throw CORBA::BAD_OPERATION(0,CORBA::COMPLETED_NO);
  }

  //
  // If the method is called before the BOA is initialized, an exception is
  // raised.  BOA::getBOA() is expected to throw OBJ_ADAPTER in that case.
  // This check runs before init_flag is set so that a premature call does
  // not prevent a later, correctly ordered one.
  //
  CORBA::BOA_ptr boa;
  try {
    boa = CORBA::BOA::getBOA();
  } catch(CORBA::OBJ_ADAPTER&) {
    // Must actually throw: falling through would release an
    // uninitialized boa below.
    throw CORBA::BAD_OPERATION(0,CORBA::COMPLETED_NO);
  }
  CORBA::release(boa);

  if (!singleton) {
    singleton = new sslContext();
  }
  init_flag = 1;

  if(!parse_ssl_args(argc, argv)) {
    return;
  }

  //
  // No explicit -SSLport: create an incoming rope for SSL sessions on every
  // inet factory and let the OS pick the port number.
  //
  if(ssl_port == 0) {
    omniObjectManager* rootObjectManager = omniObjectManager::root(0);
    incomingRopeFactory* factory;

    ropeFactory_iterator iter(*rootObjectManager->incomingRopeFactories());

    while ((factory = (incomingRopeFactory*)iter())) {
      if (factory->getType()->is_protocol(inetSocketEndpoint::protocol_name)) {
        CORBA::Char* hostname = (CORBA::Char*)"";
        inetSocketEndpoint e;
        e.inetHost(hostname);
        e.status(TRANS_SSL);
        // instantiate a rope. Let the OS pick a port number.
        factory->instantiateIncoming(&e,1); // create ssl session
      }
    }
  }

  setContext();

  //
  // Set a factory object for the Default connection interceptor.
  //
  singleton->setConnectionInterceptorFactory((ConnectionInterceptorFactory *)
			new DefaultConnectionInterceptorFactory());
}

// Argument-less initialization.  Creates the context if none exists yet,
// recording it as initialized by the user (flag 1); does nothing if a
// context already exists.
void
sslContext::initContext()
{
  sslContext::initContextInternal(1);
}

// Create the singleton context on first call: build the SSL_CTX, install the
// default connection interceptor factory and store `flag` in pd_status
// (1 = created at the user's request, 0 = created implicitly).
// Subsequent calls leave the existing context untouched.
void
sslContext::initContextInternal(CORBA::Boolean flag)
{
  if (!singleton) {
    singleton = new sslContext();
    setContext();

    ConnectionInterceptorFactory* defaultFactory =
      (ConnectionInterceptorFactory *) new DefaultConnectionInterceptorFactory();
    singleton->setConnectionInterceptorFactory(defaultFactory);

    if (singleton->pd_status == 0) {
      singleton->pd_status = flag;
    }
  }
}

// Report whether the context was created at the user's request
// (pd_status set by initContextInternal()).  Returns 0 when no context
// exists yet instead of dereferencing a null singleton.
CORBA::Boolean
sslContext::checkInitializedByUser()
{
  if (!singleton) {
    return 0;
  }
  return singleton->pd_status;
}

// Install `factory` as the connection interceptor factory, taking ownership
// of it and deleting the previously installed one.  Re-installing the
// factory that is already current is a no-op; it must not be deleted, or
// the stored pointer would dangle.
void
sslContext::setConnectionInterceptorFactory(ConnectionInterceptorFactory *factory)
{
  if(factory == pd_connection_interceptor_factory) {
    return;
  }

  // delete on a null pointer is a no-op, so no separate check is needed.
  delete pd_connection_interceptor_factory;
  pd_connection_interceptor_factory = factory;
}

// Return the installed connection interceptor factory (NULL if none).
// Ownership stays with this sslContext.
ConnectionInterceptorFactory *
sslContext::getConnectionInterceptorFactory()
{
  ConnectionInterceptorFactory* current = this->pd_connection_interceptor_factory;
  return current;
}

// Raise the process-wide "SSL only" flag read back by checkUseSSL().
void
sslContext::setUseSSL()
{
  sslContext::ssl_only = 1;
}

// Return the flag set by setUseSSL() (0 until it is called).
CORBA::Boolean
sslContext::checkUseSSL()
{
  const CORBA::Boolean only = sslContext::ssl_only;
  return only;
}

// Raise the compatibility flag.  setContext() reads it when creating the
// SSL_CTX, so it only has an effect if set before the context is created.
void
sslContext::setVBCompat()
{
  sslContext::ssl_vbcompat = 1;
}

// Return the compatibility flag set by setVBCompat() (0 until it is called).
CORBA::Boolean
sslContext::checkVBCompat()
{
  const CORBA::Boolean compat = sslContext::ssl_vbcompat;
  return compat;
}

// Create the shared SSL_CTX once; later calls return immediately.
// With the compatibility flag set (setVBCompat(), presumably VisiBroker
// interoperability) the context uses SSLv23_method(), otherwise
// SSLv3_method().  A creation failure is reported on cerr only.
void
sslContext::setContext()
{
  if (sslCtx != NULL)
    return;

  if (omniORB::traceLevel >= 2)
    cerr << "omniORB2 Initialize SSL Context" << endl;

  // Initialize SSL stuff: algorithms and readable error strings.
  SSLeay_add_ssl_algorithms();
  SSL_load_error_strings();

  sslCtx = SSL_CTX_new(sslContext::checkVBCompat() ? SSLv23_method()
                                                   : SSLv3_method());

  if (sslCtx != NULL)
    return;

  cerr  << "PANIC !! SSL Context cannot be created\n";
  cerr << "This error should not occur." << endl;
}

// Load `cert` (PEM) as this process's certificate and also use it as the
// list of CA names sent to clients.
//
// Throws CORBA::BAD_OPERATION if the SSL_CTX has not been created yet.
// A load failure is reported on cerr when omniORB::traceLevel >= 2.
void
sslContext::setCertificateFile(char *cert)
{
  if(sslCtx == NULL) {
    throw CORBA::BAD_OPERATION(0,CORBA::COMPLETED_NO);
  }

  if(omniORB::traceLevel >= 2) {
    // Streaming a null char* is undefined behaviour.
    const char* shown = (cert != NULL) ? cert : "undefined";
    cerr << "We use " << shown << " as certificate file" << endl;
  }

  SSL_CTX_set_client_CA_list(sslCtx, SSL_load_client_CA_file(cert));

  // SSL_CTX_use_certificate_file() returns 1 on success and another value
  // (typically 0) on failure, so compare against 1 rather than -1.
  int rc = SSL_CTX_use_certificate_file(sslCtx, cert, SSL_FILETYPE_PEM);
  if(rc != 1) {
     if(omniORB::traceLevel >= 2) {
       cerr << "some error in certification file" << endl;
     }
  }
}

// Certificate verification callback installed by sslContext::setVerify().
// It never overrides the verdict: it returns `ok` unchanged and, when
// verification failed and omniORB::traceLevel >= 2, prints the reason.
int verify_callback(int ok, X509_STORE_CTX *ctx)
{
  const int reason = X509_STORE_CTX_get_error(ctx);

  if (ok || omniORB::traceLevel < 2)
    return ok;

  cerr << "SSL verify error: "
       << X509_verify_cert_error_string(reason) << endl;
  return ok;
}

void
sslContext::setVerify()
{
  ssl_verify = SSL_VERIFY_PEER;
  SSL_CTX_set_verify(sslCtx, ssl_verify, (int (*)())verify_callback);
}

// Load trusted CA certificates from `ca` (may be NULL).  If that fails, fall
// back to the OpenSSL default CA locations; if that fails as well, report it
// on cerr when omniORB::traceLevel >= 2.
//
// Throws CORBA::BAD_OPERATION if the SSL_CTX has not been created yet.
void
sslContext::setCertificateAuthorityFile(char *ca)
{
  if(sslCtx == NULL) {
    throw CORBA::BAD_OPERATION(0,CORBA::COMPLETED_NO);
  }

  if(omniORB::traceLevel >= 2) {
    // Substitute a placeholder only for display; the value handed to
    // OpenSSL must not depend on the trace level.
    const char* shown = (ca != NULL) ? ca : "undefined";
    cerr << "We use " << shown << " as CA file" << endl;
  }

  // Both calls return 1 on success and 0 on failure.
  int rc = SSL_CTX_load_verify_locations(sslCtx, ca, NULL);
  if(rc != 1) {
    rc = SSL_CTX_set_default_verify_paths(sslCtx);
    if(rc != 1) {
      if(omniORB::traceLevel >= 2) {
        cerr << "We cannot set CA location." << endl;
      }
    }
  }
}

// Load the RSA private key from `key` (PEM) and check that it matches the
// certificate loaded into the SSL_CTX.
//
// Throws CORBA::BAD_OPERATION if the SSL_CTX has not been created yet.
// Failures are reported on cerr when omniORB::traceLevel >= 2.
void
sslContext::setKeyFile(char *key)
{
  if(sslCtx == NULL) {
    throw CORBA::BAD_OPERATION(0,CORBA::COMPLETED_NO);
  }

  if(omniORB::traceLevel >= 2) {
    // Streaming a null char* is undefined behaviour.
    const char* shown = (key != NULL) ? key : "undefined";
    cerr << "We use " << shown << " as KEY file" << endl;
  }

  // Both OpenSSL calls return 1 on success and another value (typically 0)
  // on failure, so compare against 1 rather than -1.
  int rc = SSL_CTX_use_RSAPrivateKey_file(sslCtx, key, SSL_FILETYPE_PEM);
  if(rc != 1) {
    if(omniORB::traceLevel >= 2) {
      cerr << "some error in key file" << endl;
    }
  }

  rc = SSL_CTX_check_private_key(sslCtx);
  if(rc != 1) {
    if(omniORB::traceLevel >= 2) {
      cerr << "bad private key" << endl;
    }
  }
}

// Free the shared SSL_CTX, if any, and forget it so that a later
// setContext() creates a fresh one.
void
sslContext::terminateSSLContext()
{
  if (sslCtx == NULL)
    return;

  SSL_CTX_free(sslCtx);
  sslCtx = NULL;
}

// Return the shared SSL_CTX, or NULL if setContext() has not created it
// (or terminateSSLContext() has freed it).
SSL_CTX*
sslContext::getSSLContext()
{
  SSL_CTX* const ctx = sslContext::sslCtx;
  return ctx;
}

// Return the singleton context, creating it on first use.  A context created
// here is recorded as not initialized by the user (flag 0).
sslContext*
sslContext::getContext()
{
  if (!singleton)
    initContextInternal(0);

  return singleton;
}

//
// Connection State for SSL
//
// Connection state for strand `s` running over SSL: tagged TRANS_SSL and
// given an SSL-specific connection control object.
SSLConnectionState::SSLConnectionState(Strand *s)
  : ConnectionState(s)
{
  s_type = TRANS_SSL;
  s_connection_control = new SSLConnectionControl(s);
}

// Return the name of the cipher used on this connection, as reported by the
// strand (GetChipher query).  The result is a CORBA::string_dup() copy that
// the caller owns.  Throws CORBA::BAD_OPERATION if there is no strand.
char *
SSLConnectionState::getChipher()
{
  if(s_strand == NULL) {
    throw CORBA::BAD_OPERATION(0,CORBA::COMPLETED_NO);
  }

  // Start empty so that, if the strand leaves the buffer untouched, an empty
  // string is returned instead of uninitialized stack memory; force a
  // terminator in case the strand fills the whole buffer.
  char buf[500];
  buf[0] = '\0';
  s_strand->controlStrand(GetChipher, buf , sizeof(buf));
  buf[sizeof(buf) - 1] = '\0';

  return CORBA::string_dup(buf);
}

// Return the issuer of the peer certificate, as reported by the strand
// (GetIssuer query).  The result is a CORBA::string_dup() copy that the
// caller owns.  Throws CORBA::BAD_OPERATION if there is no strand.
char *
SSLConnectionState::getIssuer()
{
  if(s_strand == NULL) {
    throw CORBA::BAD_OPERATION(0,CORBA::COMPLETED_NO);
  }

  // Start empty so that, if the strand leaves the buffer untouched, an empty
  // string is returned instead of uninitialized stack memory; force a
  // terminator in case the strand fills the whole buffer.
  char buf[500];
  buf[0] = '\0';
  s_strand->controlStrand(GetIssuer, buf , sizeof(buf));
  buf[sizeof(buf) - 1] = '\0';

  return CORBA::string_dup(buf);
}

// Return the subject of the peer certificate, as reported by the strand
// (GetSubject query).  The result is a CORBA::string_dup() copy that the
// caller owns.  Throws CORBA::BAD_OPERATION if there is no strand.
char *
SSLConnectionState::getSubject()
{
  if(s_strand == NULL) {
    throw CORBA::BAD_OPERATION(0,CORBA::COMPLETED_NO);
  }

  // Start empty so that, if the strand leaves the buffer untouched, an empty
  // string is returned instead of uninitialized stack memory; force a
  // terminator in case the strand fills the whole buffer.
  char buf[500];
  buf[0] = '\0';
  s_strand->controlStrand(GetSubject, buf , sizeof(buf));
  buf[sizeof(buf) - 1] = '\0';

  return CORBA::string_dup(buf);
}

//
// Connection Control Object for SSL
//
// Connection control for an SSL strand: reuses the inet control behaviour
// and only retags the transport type as TRANS_SSL.
SSLConnectionControl::SSLConnectionControl(Strand *s)
  : inetConnectionControl(s)
{
  this->s_type = TRANS_SSL;
}
