github.com/jlmucb/cloudproxy@v0.0.0-20170830161738-b5aa0b619bc4/go/apps/simpleexample/SimpleClientCpp/junkyard/helpers.cc.old (about)

     1  #include <stdio.h>
     2  #include <stdlib.h>
     3  #include <sys/types.h>
     4  #include <sys/stat.h>
     5  #include <fcntl.h>
     6  #include <unistd.h>
     7  #include <string.h>
     8  #include <errno.h>
     9  
    10  #include <netinet/in.h>
    11  #include <sys/socket.h>
    12  #include <arpa/inet.h>
    13  
    14  #include <helpers.h>
    15  
    16  #include <openssl/rsa.h>
    17  #include <openssl/x509.h>
    18  #include <openssl/ssl.h>
    19  #include <openssl/evp.h>
    20  #include <openssl/asn1.h>
    21  #include <openssl/err.h>
    22  #include <openssl/aes.h>
    23  #include <openssl/hmac.h>
    24  #include <openssl/rand.h>
    25  
    26  #include <string>
    27  #include <thread>
    28  
    29  #include <messages.pb.h>
    30  
    31  using std::string;
    32  using std::unique_ptr;
    33  using std::thread;
    34  using std::vector;
    35  
    36  //
    37  // Copyright 2015 Google Corporation, All Rights Reserved.
    38  //
    39  // Licensed under the Apache License, Version 2.0 (the "License");
    40  // you may not use this file except in compliance with the License.
    41  // You may obtain a copy of the License at
    42  //     http://www.apache.org/licenses/LICENSE-2.0
// or in the file LICENSE-2.0.txt in the top level source directory
    44  // Unless required by applicable law or agreed to in writing, software
    45  // distributed under the License is distributed on an "AS IS" BASIS,
    46  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    47  // See the License for the specific language governing permissions and
    48  // limitations under the License
    49  //
    50  // Portions of this code were derived TPM2.0-TSS published
    51  // by Intel under the license set forth in intel_license.txt
    52  // and downloaded on or about August 6, 2015.
    53  // File: helpers.cc
    54  
    55  
    56  void PrintBytes(int n, byte* in) {
    57    for (int i = 0; i < n; i++) printf("%02x", in[i]);
    58  }
    59  
    60  bool ReadFile(string& file_name, string* out) {
    61    struct stat file_info;
    62    int k = stat(file_name.c_str(), &file_info);
    63    if (k != 0)
    64      return false;
    65    unique_ptr<byte> buf(new byte[file_info.st_size]);
    66    int fd = open(file_name.c_str(), O_RDONLY);
    67    if (fd < 0)
    68      return false;
    69    size_t n = read(fd, buf.get(), file_info.st_size);
    70    out->assign((const char*)buf.get(), n);
    71    close(fd);
    72    return true;
    73  }
    74  
    75  bool WriteFile(string& file_name, string& in) {
    76    int fd = creat(file_name.c_str(), S_IRWXU | S_IRWXG);
    77    if (fd < 0)
    78      return false;
    79    int n = write(fd, (byte*)in.data(), in.size());
    80    close(fd);
    81    return n > 0;
    82    return true;
    83  }
    84  
    85  bool SerializePrivateKey(string& key_type, EVP_PKEY* key, string* out_buf) {
    86    simpleexample_messages::PrivateKeyMessage msg;
    87  
    88    if (key_type == "RSA") {
    89      RSA* rsa_key = EVP_PKEY_get1_RSA(key);
    90      msg.set_allocated_rsa_key(new simpleexample_messages::RsaPrivateKeyMessage());
    91      string* m_str = BN_to_bin(*rsa_key->n);
    92      msg.mutable_rsa_key()->set_m(*m_str);
    93      string* e_str = BN_to_bin(*rsa_key->e);
    94      msg.mutable_rsa_key()->set_e(*e_str);
    95      string* d_str = BN_to_bin(*rsa_key->d);
    96      msg.mutable_rsa_key()->set_d(*d_str);
    97      msg.set_key_type("RSA");
    98    } else if (key_type == "ECC") {
    99      EC_KEY* ec_key = EVP_PKEY_get1_EC_KEY(key);
   100      msg.set_allocated_ec_key(new simpleexample_messages::EcPrivateKeyMessage());
   101      byte out[4096];
   102      byte* ptr = out;
   103      int n = i2d_ECPrivateKey(ec_key, &ptr);
   104      if (n <= 0) {
   105        printf("Can't i2d ECC private key\n");
   106        return false;
   107      }
   108      msg.mutable_ec_key()->set_der_blob((void*)out, (size_t)n);
   109      msg.set_key_type("ECC");
   110    } else {
   111      printf("SerializePrivateKey: Unknown key type\n");
   112      return false;
   113    }
   114  
   115    if (!msg.SerializeToString(out_buf)) {
   116      return false;
   117    }
   118    return true;
   119  }
   120  
   121  bool DeserializePrivateKey(string& in_buf, string* key_type, EVP_PKEY** key) {
   122    simpleexample_messages::PrivateKeyMessage msg;
   123  
   124    if (!msg.ParseFromString(in_buf)) {
   125      return false;
   126    }
   127    key_type->assign(msg.key_type());
   128    if (msg.key_type() == "RSA") {
   129      if (!msg.has_rsa_key()) {
   130        return false;
   131      }
   132      RSA* rsa_key = RSA_new();
   133      if (msg.rsa_key().has_m())
   134        rsa_key->n = bin_to_BN(msg.rsa_key().m().size(), (byte*)msg.rsa_key().m().data());
   135      if (msg.rsa_key().has_e())
   136        rsa_key->e = bin_to_BN(msg.rsa_key().e().size(), (byte*)msg.rsa_key().e().data());
   137      if (msg.rsa_key().has_d())
   138        rsa_key->d = bin_to_BN(msg.rsa_key().d().size(), (byte*)msg.rsa_key().d().data());
   139        EVP_PKEY* pKey = new EVP_PKEY();
   140        EVP_PKEY_assign_RSA(pKey, rsa_key);
   141        *key = pKey;
   142    } else if (msg.key_type() == "ECC") {
   143      if (!msg.has_ec_key()) {
   144        return false;
   145      }
   146      const byte* ptr = (byte*)msg.ec_key().der_blob().data();
   147      EC_KEY* ec_key = d2i_ECPrivateKey(nullptr, &ptr,
   148                                        msg.ec_key().der_blob().size());
   149      if (ec_key == nullptr) {
   150        printf("Can't i2d ECC private key\n");
   151        return false;
   152      }
   153      EVP_PKEY* pKey = new EVP_PKEY();
   154      EVP_PKEY_assign_EC_KEY(pKey, ec_key);
   155      *key = pKey;
   156    } else {
   157      printf("DeserializePrivateKey: Unknown key type\n");
   158      return false;
   159    }
   160    return true;
   161  }
   162  
// Standard buffer size: the number of elements used when declaring
// fixed-size scratch buffers in this file.
#define MAX_SIZE_PARAMS 4096
   165  
   166  void PrintPrivateRSAKey(RSA& key) {
   167    if (key.n != nullptr) {
   168      printf("\nModulus: \n");
   169      BN_print_fp(stdout, key.n);
   170      printf("\n");
   171    }
   172    if (key.e != nullptr) {
   173      printf("\ne: \n");
   174      BN_print_fp(stdout, key.e);
   175      printf("\n");
   176    }
   177    if (key.d != nullptr) {
   178      printf("\nd: \n");
   179      BN_print_fp(stdout, key.d);
   180      printf("\n");
   181    }
   182    if (key.p != nullptr) {
   183      printf("\np: \n");
   184      BN_print_fp(stdout, key.p);
   185      printf("\n");
   186    }
   187    if (key.q != nullptr) {
   188      printf("\nq: \n");
   189      BN_print_fp(stdout, key.q);
   190      printf("\n");
   191    }
   192  #if 0
   193    if (key.dmp1 != nullptr) {
   194      printf("\ndmp1: \n");
   195      BN_print_fp(stdout, key.dmp1);
   196      printf("\n");
   197    }
   198    if (key.dmq1 != nullptr) {
   199      printf("\ndmq1: \n");
   200      BN_print_fp(stdout, key.dmq1);
   201      printf("\n");
   202    }
   203    if (key.iqmp != nullptr) {
   204      printf("\niqmp: \n");
   205      BN_print_fp(stdout, key.iqmp);
   206      printf("\n");
   207    }
   208  #endif
   209  }
   210  
   211  BIGNUM* bin_to_BN(int len, byte* buf) {
   212    BIGNUM* bn = BN_bin2bn(buf, len, nullptr);
   213    return bn;
   214  }
   215  
   216  string* BN_to_bin(BIGNUM& n) {
   217    byte buf[MAX_SIZE_PARAMS];
   218  
   219    int len = BN_bn2bin(&n, buf);
   220    return new string((const char*)buf, len);
   221  }
   222  
   223  bool BN_to_string(BIGNUM& n, string* out) {
   224    byte buf[MAX_SIZE_PARAMS];
   225  
   226    int len = BN_bn2bin(&n, buf);
   227    out->assign((const char*)buf, len);
   228    return true;
   229  }
   230  
   231  EVP_PKEY* GenerateKey(string& keyType, int keySize) {
   232    EVP_PKEY* pKey = EVP_PKEY_new();
   233    if (pKey == nullptr) {
   234      return nullptr;
   235    }
   236    if (keyType == "RSA") {
   237      RSA* rsa_key = RSA_generate_key(keySize, 0x010001ULL, nullptr, nullptr);
   238      if (rsa_key == nullptr) {
   239        printf("GenerateKey: couldn't generate RSA key.\n");
   240        return nullptr;
   241      }
   242      EVP_PKEY_assign_RSA(pKey, rsa_key);
   243      pKey->type = EVP_PKEY_RSA;
   244    } else if (keyType == "ECC") {
   245      EC_KEY* ec_key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
   246      EC_KEY_set_asn1_flag(ec_key, OPENSSL_EC_NAMED_CURVE);
   247      if (ec_key == nullptr) {
   248        printf("GenerateKey: couldn't generate ECC program key (1).\n");
   249        return nullptr;
   250      }
   251      if (1 != EC_KEY_generate_key(ec_key)) {
   252        printf("GenerateKey: couldn't generate ECC program key(2).\n");
   253        return nullptr;
   254      }
   255      EVP_PKEY_assign_EC_KEY(pKey, ec_key);
   256      pKey->type = EVP_PKEY_EC;
   257    } else {
   258      printf("GenerateKey: unsupported key type.\n");
   259      return nullptr;
   260    }
   261    return pKey;
   262  }
   263  
   264  class extEntry {
   265  public:
   266    char* key_;
   267    char* value_;
   268  
   269    extEntry(const char* k, const char* v);
   270    extEntry();
   271    char* getKey();
   272    char* getValue();
   273  };
   274  
   275  extEntry::extEntry(const char* k, const char* v) {
   276    key_ = (char*)strdup(k);
   277    value_ = (char*)strdup(v);
   278  }
   279  
   280  extEntry::extEntry() {
   281    key_ = nullptr;
   282    value_ = nullptr;
   283  }
   284  
   285  char* extEntry::getKey() {
   286    return key_;
   287  }
   288  
   289  char* extEntry::getValue() {
   290    return value_;
   291  }
   292  
   293  bool addExtensionsToCert(int num_entry, extEntry** entries, X509* cert) {
   294    // add extensions
   295    X509V3_CTX ctx;
   296    X509V3_set_ctx_nodb(&ctx);
   297    X509V3_set_ctx(&ctx, cert, cert, NULL, NULL, 0);
   298    for (int i = 0; i < num_entry; i++) {
   299      if (entries[i]->getValue() == nullptr || strlen(entries[i]->getValue()) ==0)
   300        continue;
   301      int nid = OBJ_txt2nid(entries[i]->getKey());
   302      X509_EXTENSION* ext = X509V3_EXT_conf_nid(NULL, &ctx, nid, entries[i]->getValue());
   303      if (ext == 0) {
   304        printf("Bad ext_conf %d\n", i);
   305        printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error()));
   306        return false;
   307      }
   308      if (!X509_add_ext(cert, ext, -1)) {
   309        printf("Bad add ext %d\n", i);
   310        printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error()));
   311        return false;
   312      }
   313      X509_EXTENSION_free(ext);
   314    }
   315    return true;
   316  }
   317  
   318  bool GenerateX509CertificateRequest(string& key_type, string& common_name,
   319              EVP_PKEY* subjectKey, bool sign_request, X509_REQ* req) {
   320    X509_NAME* subject = X509_NAME_new();
   321    X509_REQ_set_version(req, 2L);
   322    if (subject == nullptr) {
   323      printf("Can't alloc x509 name\n");
   324      return false;
   325    }
   326    if (common_name.size() > 0) {
   327      int nid = OBJ_txt2nid("CN");
   328      X509_NAME_ENTRY* ent = X509_NAME_ENTRY_create_by_NID(nullptr, nid,
   329          MBSTRING_ASC, (byte*)common_name.c_str(), -1);
   330      if (ent == nullptr) {
   331        printf("X509_NAME_ENTRY return is null, nid: %d\n", nid);
   332        return false;
   333      }
   334      if (X509_NAME_add_entry(subject, ent, -1, 0) != 1) {
   335        printf("Can't add name ent\n");
   336        return false;
   337      }
   338    }
   339    // TODO: do the foregoing for the other name components
   340    if (X509_REQ_set_subject_name(req, subject) != 1)  {
   341      printf("Can't set x509 subject\n");
   342      return false;
   343    }
   344  
   345    // fill key parameters in request
   346    if (sign_request) {
   347      const EVP_MD* digest = EVP_sha256();
   348      if (!X509_REQ_sign(req, subjectKey, digest)) {
   349        printf("Sign request fails\n");
   350        printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error()));
   351      }
   352    }
   353    if (X509_REQ_set_pubkey(req, subjectKey) ==0) {
   354        printf("X509_REQ_set_pubkey failed\n");
   355    }
   356    return true;
   357  }
   358  
   359  bool SignX509Certificate(EVP_PKEY* signingKey, bool f_isCa,
   360                           bool f_canSign, string& signing_issuer,
   361                           string& keyUsage, string& extendedKeyUsage,
   362                           int64 duration, EVP_PKEY* signedKey,
   363                           X509_REQ* req, bool verify_req_sig, X509* cert) {
   364    if (signedKey == nullptr)
   365      signedKey = X509_REQ_get_pubkey(req);
   366    if (signedKey == nullptr) {
   367      printf("Can't get pubkey\n");
   368      return false;
   369    }
   370  
   371    if (verify_req_sig) {
   372      if (X509_REQ_verify(req, signedKey) != 1) {
   373        printf("Req does not verify\n");
   374        return false;
   375      }
   376    }
   377    
   378    uint64_t serial = 1;
   379    const EVP_MD* digest = EVP_sha256();
   380    X509_NAME* name;
   381    X509_set_version(cert, 2L);
   382    ASN1_INTEGER_set(X509_get_serialNumber(cert), serial++);
   383  
   384    name = X509_REQ_get_subject_name(req);
   385    if (X509_set_subject_name(cert, name) != 1) {
   386      printf("Can't set subject name\n");
   387      return false;
   388    }
   389    if (X509_set_pubkey(cert, signedKey) != 1) {
   390      printf("Can't set pubkey\n");
   391      return false;
   392    }
   393    if (!X509_gmtime_adj(X509_get_notBefore(cert), 0)) {
   394      printf("Can't adj notBefore\n");
   395      return false;
   396    }
   397    if (!X509_gmtime_adj(X509_get_notAfter(cert), duration)) {
   398      printf("Can't adj notAfter\n");
   399      return false;
   400    }
   401    X509_NAME* issuer = X509_NAME_new();
   402    int nid = OBJ_txt2nid("CN");
   403    X509_NAME_ENTRY* ent = X509_NAME_ENTRY_create_by_NID(nullptr, nid,
   404        MBSTRING_ASC, (byte*)signing_issuer.c_str(), -1);
   405    if (X509_NAME_add_entry(issuer, ent, -1, 0) != 1) {
   406      printf("Can't add issuer name ent: %s, %ld\n",
   407             signing_issuer.c_str(), (long unsigned)ent);
   408      printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error()));
   409      return false;
   410    }
   411    if (X509_set_issuer_name(cert, issuer) != 1) {
   412      printf("Can't set issuer name\n");
   413      return false;
   414    }
   415  
   416    // Add extensions which should be
   417    //    X509v3 extensions:
   418    //        X509v3 Key Usage: critical
   419    //            Key Agreement, Certificate Sign
   420    //        X509v3 Extended Key Usage: 
   421    //            TLS Web Server Authentication, TLS Web Client Authentication
   422    //        X509v3 Basic Constraints: critical
   423    //            CA:TRUE
   424    extEntry* entries[128];
   425    int n = 0;
   426    if (f_isCa) {
   427      entries[n++] = new extEntry("basicConstraints", "critical,CA:TRUE");
   428    }
   429    entries[n++] = new extEntry("keyUsage", keyUsage.c_str());
   430    entries[n++] = new extEntry("extendedKeyUsage", extendedKeyUsage.c_str());
   431    if (!addExtensionsToCert(n, entries, cert)) {
   432      printf("Can't add extensions\n");
   433      return false;
   434    }
   435  
   436    if (!X509_sign(cert, signingKey, digest)) {
   437      printf("Bad PKEY type\n");
   438      return false;
   439    }
   440    return true;
   441  }
   442  
   443  void XorBlocks(int size, byte* in1, byte* in2, byte* out) {
   444    int i;
   445  
   446    for (i = 0; i < size; i++)
   447      out[i] = in1[i] ^ in2[i];
   448  }
   449  
   450  bool AesCtrCrypt(int key_size_bits, byte* key, int size,
   451                   byte* in, byte* out) {
   452    AES_KEY ectx;
   453    uint64_t ctr[2] = {0ULL, 0ULL};
   454    byte block[32];
   455  
   456    if (key_size_bits != 128) {
   457      return false;
   458    }
   459    
   460    AES_set_encrypt_key(key, 128, &ectx);
   461    while (size > 0) {
   462      ctr[1]++;
   463      AES_encrypt((byte*)ctr, block, &ectx);
   464      XorBlocks(16, block, in, out);
   465      in += 16;
   466      out += 16;
   467      size -= 16;
   468    }
   469    return true;
   470  }
   471  
   472  #define AESBLKSIZE 16
   473  
   474  bool AesCFBEncrypt(byte* key, int in_size, byte* in, int iv_size, byte* iv,
   475                     int* out_size, byte* out) {
   476    byte last_cipher[32];
   477    byte cipher_block[32];
   478    int size = 0;
   479    int current_size;
   480  
   481    AES_KEY ectx;
   482    AES_set_encrypt_key(key, 128, &ectx);
   483  
   484    // Don't write iv, called already knows it
   485    if(iv_size != AESBLKSIZE) return false;
   486    memcpy(last_cipher, iv, AESBLKSIZE);
   487  
   488    while (in_size > 0) {
   489      if ((size + AESBLKSIZE) > *out_size) return false; 
   490      // C[0] = IV, C[i] = P[i] ^ E(K, C[i-1])
   491      AES_encrypt(last_cipher, cipher_block, &ectx);
   492      if (in_size >= AESBLKSIZE)
   493        current_size = AESBLKSIZE;
   494      else
   495        current_size = in_size;
   496      XorBlocks(AESBLKSIZE, cipher_block, in, last_cipher);
   497      memcpy(out, last_cipher, current_size);
   498      out += current_size;
   499      size += current_size;
   500      in += current_size;
   501      in_size -= current_size;
   502    }
   503    *out_size = size;
   504    return true;
   505  }
   506  
   507  bool AesCFBDecrypt(byte* key, int in_size, byte* in, int iv_size, byte* iv,
   508                     int* out_size, byte* out) {
   509    byte last_cipher[32];
   510    byte cipher_block[32];
   511    int size = 0;
   512    int current_size;
   513  
   514    AES_KEY ectx;
   515    AES_set_encrypt_key(key, 128, &ectx);
   516  
   517    // Don't write iv, called already knows it
   518    if(iv_size != AESBLKSIZE) return false;
   519    memcpy(last_cipher, iv, AESBLKSIZE);
   520  
   521    while (in_size > 0) {
   522      if ((size + AESBLKSIZE) > *out_size) return false; 
   523      // P[i] = C[i] ^ E(K, C[i-1])
   524      AES_encrypt(last_cipher, cipher_block, &ectx);
   525      if (in_size >= AESBLKSIZE)
   526        current_size = AESBLKSIZE;
   527      else
   528        current_size = in_size;
   529      XorBlocks(current_size, cipher_block, in, out);
   530      memcpy(last_cipher, in, current_size);
   531      out += current_size;
   532      size += current_size;
   533      in += current_size;
   534      in_size -= current_size;
   535    }
   536    *out_size = size;
   537    return true;
   538  }
   539  
   540  bool VerifyX509CertificateChain(X509* cacert, X509* cert) {
   541    X509_STORE_CTX *store_ctx = X509_STORE_CTX_new();
   542    X509_STORE *store = X509_STORE_new();
   543    X509_STORE_add_cert(store, cacert);
   544    // int X509_STORE_CTX_init(X509_STORE_CTX *ctx, X509_STORE *store, X509 *x509, STACK_OF(X509) *chain);
   545    X509_STORE_CTX_init(store_ctx, store, cacert, nullptr);
   546    int ret = X509_verify_cert(store_ctx);
   547    if (ret <= 0)
   548      printf("Error: %s\n", X509_verify_cert_error_string(store_ctx->error));
   549    return ret;
   550  }
   551  
   552  SslChannel::SslChannel() {
   553    fd_ = -1;
   554    ssl_ctx_ = nullptr;
   555    ssl_ = nullptr;
   556    peer_cert_ = nullptr;
   557    store_ = nullptr;
   558    private_key_ = nullptr;
   559  }
   560  
   561  SslChannel::~SslChannel() {
   562    if (fd_ > 0) {
   563      close(fd_);
   564    }
   565    fd_ = -1;
   566    // clear private_key_;
   567  #if 0
   568    // Doesn't need to be freed, context free takes care of it.
   569    if (ssl_ != nullptr) {
   570      SSL_free(ssl_);
   571    }
   572    ssl_ = nullptr;
   573  #endif
   574    if (peer_cert_ != nullptr) {
   575      X509_free(peer_cert_);
   576    }
   577    peer_cert_ = nullptr;
   578    if (ssl_ctx_ != nullptr) {
   579      SSL_CTX_free(ssl_ctx_);
   580    }
   581    ssl_ctx_ = nullptr;
   582    if (store_ != nullptr) {
   583      X509_STORE_free(store_);
   584    }
   585    store_ = nullptr;
   586  }
   587  
   588  int SslChannel::CreateServerSocket(string& address, string& port) {
   589    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
   590    struct sockaddr_in dest_addr;
   591    uint16_t s_port = atoi(port.c_str());
   592    memset((byte*)&dest_addr, 0, sizeof(dest_addr));
   593  
   594    dest_addr.sin_family = AF_INET;
   595    dest_addr.sin_port = htons(s_port);
   596    dest_addr.sin_addr.s_addr = INADDR_ANY;
   597    inet_aton(address.c_str(), &dest_addr.sin_addr);
   598  
   599    if (bind(sockfd, (struct sockaddr*)&dest_addr, sizeof(dest_addr)) < 0) {
   600      printf("Unable to bind\n");
   601      return -1;
   602    }
   603  
   604    if (listen(sockfd, 1) < 0) {
   605      printf("Unable to listen\n");
   606      return -1;
   607    }
   608    return sockfd;
   609  }
   610  
   611  
   612  int SslChannel::CreateClientSocket(string& addr, string& port) {
   613    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
   614    struct sockaddr_in dest_addr;
   615    uint16_t s_port = atoi(port.c_str());
   616    memset((byte*)&dest_addr, 0, sizeof(dest_addr));
   617  
   618    dest_addr.sin_family = AF_INET;
   619    dest_addr.sin_port = htons(s_port);
   620    dest_addr.sin_addr.s_addr = INADDR_ANY;
   621    inet_aton(addr.c_str(), &dest_addr.sin_addr);
   622  
   623    if (connect(sockfd, (struct sockaddr *) &dest_addr,
   624                sizeof(struct sockaddr)) == -1) {
   625      printf("Error: Cannot connect to host\n");
   626      return -1;
   627    }
   628    return sockfd;
   629  }
   630  
   631  bool SslChannel::InitServerSslChannel(string& network, string& address,
   632                  string& port, X509* policyCert, X509* programCert,
   633                  string& keyType, EVP_PKEY* privateKey, int verify) {
   634     SSL_library_init();
   635     OpenSSL_add_all_algorithms();
   636     ERR_load_crypto_strings();
   637  
   638    // I'm a server.
   639    server_role_ = true;
   640  
   641    if (privateKey == nullptr) {
   642      printf("Private key is null.\n");
   643      return false;
   644    }
   645  
   646    // Create socket and contexts.
   647    fd_ = CreateServerSocket(address, port);
   648    if(fd_ <= 0) {
   649      printf("CreateServerSocket failed.\n");
   650      return false;
   651    }
   652  
   653    ssl_ctx_ = SSL_CTX_new(TLSv1_2_server_method());
   654    if (ssl_ctx_ == nullptr) {
   655      printf("SSL_CTX_new failed(server).\n");
   656      ERR_print_errors_fp(stderr);
   657      return false;
   658    }
   659  
   660    SSL_CTX_clear_extra_chain_certs(ssl_ctx_);
   661    private_key_ = privateKey;
   662    SSL_CTX_use_certificate(ssl_ctx_, programCert);
   663    if (EVP_PKEY_id(private_key_) == EVP_PKEY_EC) {
   664      if (!SSL_CTX_set_tmp_ecdh(ssl_ctx_, EVP_PKEY_get1_EC_KEY(private_key_))) {
   665         printf("SSL_CTX_set_tmp_ecdh failed.\n");
   666         return false;
   667      }
   668      SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_ECDH_USE);
   669    }
   670    if(SSL_CTX_use_PrivateKey(ssl_ctx_, private_key_) <= 0) {
   671      printf("SSL_CTX_use_PrivateKey failed.\n");
   672      ERR_print_errors_fp(stderr);
   673      return false;
   674    }
   675  
   676    // Setup verification stuff.
   677    switch(verify) {
   678      case SSL_NO_SERVER_VERIFY_NO_CLIENT_AUTH:
   679      case SSL_NO_SERVER_VERIFY_NO_CLIENT_VERIFY:
   680      case SSL_SERVER_VERIFY_NO_CLIENT_VERIFY:
   681        SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_NONE, nullptr);
   682        SSL_CTX_set_verify_depth(ssl_ctx_, 3);
   683        break;
   684      case SSL_SERVER_VERIFY_CLIENT_VERIFY:
   685        SSL_CTX_add_extra_chain_cert(ssl_ctx_, programCert);
   686        SSL_CTX_add_extra_chain_cert(ssl_ctx_, policyCert);
   687        store_ = X509_STORE_new();
   688        if (store_ == nullptr) {
   689          printf("X509_STORE_new failed.\n");
   690          return false;
   691        }
   692        X509_STORE_add_cert(store_, policyCert);
   693        SSL_CTX_set_cert_store(ssl_ctx_, store_);
   694        SSL_CTX_set_verify(ssl_ctx_,
   695            SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
   696        SSL_CTX_set_verify_depth(ssl_ctx_, 3);
   697        break;
   698      default:
   699        printf("Unknown verification mode.\n");
   700        return false;
   701    }
   702    return true;
   703  }
   704  
   705  bool SslChannel::InitClientSslChannel(string& network, string& address,
   706                  string& port, X509* policyCert, X509* programCert,
   707                  string& keyType, EVP_PKEY* privateKey, int verify) {
   708     SSL_library_init();
   709     OpenSSL_add_all_algorithms();
   710     ERR_load_crypto_strings();
   711  
   712    // I'm a client.
   713    server_role_ = false;
   714  
   715    // Create socket and contexts.
   716    fd_ = CreateClientSocket(address, port);
   717    if(fd_ <= 0) {
   718      printf("CreateClientSocket failed.\n");
   719      return false;
   720    }
   721  
   722    ssl_ctx_ = SSL_CTX_new(TLSv1_2_client_method());
   723    if (ssl_ctx_ == nullptr) {
   724      printf("SSL_CTX_new failed(client).\n");
   725      return false;
   726    }
   727    SSL_CTX_clear_extra_chain_certs(ssl_ctx_);
   728    if (privateKey == nullptr) {
   729      printf("Private key is null\n");
   730      return false;
   731    }
   732    private_key_ = privateKey;
   733  
   734    // Setup verification stuff.
   735    switch(verify) {
   736      case SSL_NO_SERVER_VERIFY_NO_CLIENT_AUTH:
   737        SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_NONE, nullptr);
   738        SSL_CTX_set_verify_depth(ssl_ctx_, 3);
   739        break;
   740      case SSL_NO_SERVER_VERIFY_NO_CLIENT_VERIFY:
   741      case SSL_SERVER_VERIFY_NO_CLIENT_VERIFY:
   742      case SSL_SERVER_VERIFY_CLIENT_VERIFY:
   743        if (privateKey == nullptr) {
   744          printf("Private key is null\n");
   745          return false;
   746        }
   747        if (EVP_PKEY_id(private_key_) == EVP_PKEY_EC) {
   748          if (!SSL_CTX_set_tmp_ecdh(ssl_ctx_,
   749                  EVP_PKEY_get1_EC_KEY(private_key_))) {
   750            printf("SSL_CTX_set_tmp_ecdh failed.\n");
   751            return false;
   752          }
   753          SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_ECDH_USE);
   754        }
   755        if(SSL_CTX_use_PrivateKey(ssl_ctx_, private_key_) <= 0) {
   756          printf("SSL_CTX_use_PrivateKey failed.\n");
   757          ERR_print_errors_fp(stderr);
   758          return false;
   759        }
   760        SSL_CTX_use_certificate(ssl_ctx_, programCert);
   761        SSL_CTX_add_extra_chain_cert(ssl_ctx_, programCert);
   762        SSL_CTX_add_extra_chain_cert(ssl_ctx_, policyCert);
   763        store_ = X509_STORE_new();
   764        if (store_ == nullptr) {
   765          printf("X509_STORE_new failed.\n");
   766          return false;
   767        }
   768        X509_STORE_add_cert(store_, policyCert);
   769        SSL_CTX_set_cert_store(ssl_ctx_, store_);
   770        SSL_CTX_set_verify(ssl_ctx_,
   771          SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
   772        SSL_CTX_set_verify_depth(ssl_ctx_, 3);
   773        if (verify == SSL_NO_SERVER_VERIFY_NO_CLIENT_VERIFY)
   774          SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_NONE, nullptr); 
   775        break;
   776      default:
   777        printf("Unknown verification mode.\n");
   778        return false;
   779    }
   780  
   781    ssl_ = SSL_new(ssl_ctx_);
   782    if (ssl_ == nullptr) {
   783      printf("SSL_new failed(client).\n");
   784      return false;
   785    }
   786  
   787    SSL_set_fd(ssl_, fd_);
   788    SSL_set_connect_state(ssl_);
   789  
   790    // Connect.
   791    if (SSL_connect(ssl_) != 1) {
   792      printf("SSL_connect failed.\n");
   793      ERR_print_errors_fp(stderr);
   794      return false;
   795    }
   796    peer_cert_ = SSL_get_peer_certificate(ssl_);
   797    return true;
   798  }
   799  
   800  bool SslChannel::ServerLoop(void(*server_loop)(SslChannel*,  SSL*, int)) {
   801    bool fContinue = true;
   802    printf("ServerLoop\n");
   803  
   804    while(fContinue) {
   805      struct sockaddr_in addr;
   806      uint len = sizeof(addr);
   807      memset((byte*)&addr, 0, len);
   808  
   809      int client = accept(fd_, (struct sockaddr*)&addr, &len);
   810      if (client < 0) {
   811        printf("Unable to accept\n");
   812        printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error()));
   813        continue;
   814      }
   815  
   816      SSL* ssl = SSL_new(ssl_ctx_);
   817      if (private_key_ == nullptr) {
   818        printf("private_key_ is null.\n");
   819        return false;
   820      }
   821      SSL_set_fd(ssl, client);
   822      SSL_set_accept_state(ssl);
   823      if (SSL_accept(ssl) <= 0) {
   824        printf("Unable to ssl_accept\n");
   825        ERR_print_errors_fp(stderr);
   826        continue;
   827      } 
   828      server_loop(this, ssl, client);
   829      // thread t(server_loop, this, ssl, client);
   830    }
   831    return true;
   832  }
   833  
   834  void SslChannel::Close() {
   835    if (fd_ > 0) {
   836      close(fd_);
   837    }
   838    fd_ = -1;
   839    if (ssl_ != nullptr) {
   840      SSL_free(ssl_);
   841    }
   842    ssl_ = nullptr;
   843    if (peer_cert_ != nullptr) {
   844      X509_free(peer_cert_);
   845    }
   846    peer_cert_ = nullptr;
   847    if (ssl_ctx_ != nullptr) {
   848      SSL_CTX_free(ssl_ctx_);
   849    }
   850    ssl_ctx_ = nullptr;
   851    if (store_ != nullptr) {
   852      X509_STORE_free(store_);
   853    }
   854    store_ = nullptr;
   855  }
   856  
   857  X509* SslChannel::GetPeerCert() {
   858    return peer_cert_;
   859  }
   860  
   861  int SslMessageRead(SSL* ssl, int size, byte* buf) {
   862    byte new_buf[8192];
   863    int tmp_size = SslRead(ssl, size, new_buf);
   864    if (tmp_size <= 0)
   865      return tmp_size;
   866    int real_size = __builtin_bswap32(*((int*)new_buf));
   867    if (tmp_size == sizeof(int)) {
   868      return SslRead(ssl, real_size, buf);
   869    }
   870    memcpy(buf, &new_buf[4], real_size);
   871    return real_size;
   872  }
   873  
   874  int SslMessageWrite(SSL* ssl, int size, byte* buf) {
   875    // write 32 bit size and buffer
   876    int big_endian_size = __builtin_bswap32(size);
   877    byte new_buf[4096];
   878    memcpy(new_buf, (byte*)&big_endian_size, sizeof(int));
   879    memcpy(&new_buf[sizeof(int)], buf, size);
   880    return SslWrite(ssl, size + sizeof(int), new_buf) - sizeof(int);
   881  }
   882  
   883  int SslRead(SSL* ssl, int size, byte* buf) {
   884    return SSL_read(ssl, buf, size);
   885  }
   886  
   887  int SslWrite(SSL* ssl, int size, byte* buf) {
   888    return SSL_write(ssl, buf, size);
   889  }
   890  
   891  // TODO: consider using std::to_string
   892  int NumHexInBytes(int size, byte* in) { return 2 * size; }
   893  
   894  int NumBytesInHex(char* in) {
   895    if (in == nullptr)
   896      return -1;
   897    int len = strlen(in);
   898    return ((len + 1) / 2);
   899  }
   900  
   901  char ValueToHex(byte x) {
   902    if (x >= 0 && x <= 9) {
   903      return x + '0';
   904    } else if (x >= 10 && x <= 15) {
   905      return x - 10 + 'a';
   906    } else {
   907      return ' ';
   908    }
   909  }
   910  
   911  byte HexToValue(char x) {
   912    if (x >= '0' && x <= '9') {
   913      return x - '0';
   914    } else if (x >= 'a' && x <= 'f') {
   915      return x + 10 - 'a';
   916    } else {
   917      return 0;
   918    }
   919  }
   920  
   921  string* ByteToHexLeftToRight(int size, byte* in) {
   922    if (in == nullptr)
   923      return nullptr;
   924    int n = NumHexInBytes(size, in);
   925    string* out = new string(n, 0);
   926    char* str = (char*)out->c_str();
   927    byte a, b;
   928  
   929    while (size > 0) {
   930      a = (*in) >> 4;
   931      b = (*in) & 0xf;
   932      in++;
   933      *(str++) = ValueToHex(a);
   934      *(str++) = ValueToHex(b);
   935      size--;
   936    }
   937    return out;
   938  }
   939  
   940  int HexToByteLeftToRight(char* in, int size, byte* out) {
   941    if (in == nullptr)
   942      return -1;
   943    int n = NumBytesInHex(in);
   944    int m = strlen(in);
   945    byte a, b;
   946  
   947    if (n > size) {
   948      return -1;
   949    }
   950    while (m > 0) {
   951      a = HexToValue(*(in++));
   952      b = HexToValue(*(in++));
   953      *(out++) = (a << 4) | b;
   954      m -= 2;
   955    }
   956    return n;
   957  }
   958  
   959  string* ByteToHexRightToLeft(int size, byte* in) {
   960    if (in == nullptr)
   961      return nullptr;
   962    int n = NumHexInBytes(size, in);
   963    string* out = new string(n, 0);
   964    char* str = (char*)out->c_str();
   965    byte a, b;
   966  
   967    in += size - 1;
   968    while (size > 0) {
   969      a = (*in) >> 4;
   970      b = (*in) & 0xf;
   971      in--;
   972      *(str++) = ValueToHex(a);
   973      *(str++) = ValueToHex(b);
   974      size--;
   975    }
   976    return out;
   977  }
   978  
   979  int HexToByteRightToLeft(char* in, int size, byte* out) {
   980    if (in == nullptr) {
   981      return -1;
   982    }
   983    int n = NumBytesInHex(in);
   984    int m = strlen(in);
   985    byte a, b;
   986  
   987    out += n - 1;
   988    if (m < 0) {
   989      return -1;
   990    }
   991    while (m > 0) {
   992      a = HexToValue(*(in++));
   993      b = HexToValue(*(in++));
   994      *(out--) = (a << 4) | b;
   995      m -= 2;
   996    }
   997    return n;
   998  }