github.com/jlmucb/cloudproxy@v0.0.0-20170830161738-b5aa0b619bc4/go/apps/simpleexample/SimpleClientCpp/junkyard/helpers.cc.old (about) 1 #include <stdio.h> 2 #include <stdlib.h> 3 #include <sys/types.h> 4 #include <sys/stat.h> 5 #include <fcntl.h> 6 #include <unistd.h> 7 #include <string.h> 8 #include <errno.h> 9 10 #include <netinet/in.h> 11 #include <sys/socket.h> 12 #include <arpa/inet.h> 13 14 #include <helpers.h> 15 16 #include <openssl/rsa.h> 17 #include <openssl/x509.h> 18 #include <openssl/ssl.h> 19 #include <openssl/evp.h> 20 #include <openssl/asn1.h> 21 #include <openssl/err.h> 22 #include <openssl/aes.h> 23 #include <openssl/hmac.h> 24 #include <openssl/rand.h> 25 26 #include <string> 27 #include <thread> 28 29 #include <messages.pb.h> 30 31 using std::string; 32 using std::unique_ptr; 33 using std::thread; 34 using std::vector; 35 36 // 37 // Copyright 2015 Google Corporation, All Rights Reserved. 38 // 39 // Licensed under the Apache License, Version 2.0 (the "License"); 40 // you may not use this file except in compliance with the License. 41 // You may obtain a copy of the License at 42 // http://www.apache.org/licenses/LICENSE-2.0 43 // or in the the file LICENSE-2.0.txt in the top level sourcedirectory 44 // Unless required by applicable law or agreed to in writing, software 45 // distributed under the License is distributed on an "AS IS" BASIS, 46 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 47 // See the License for the specific language governing permissions and 48 // limitations under the License 49 // 50 // Portions of this code were derived TPM2.0-TSS published 51 // by Intel under the license set forth in intel_license.txt 52 // and downloaded on or about August 6, 2015. 53 // File: helpers.cc 54 55 56 void PrintBytes(int n, byte* in) { 57 for (int i = 0; i < n; i++) printf("%02x", in[i]); 58 } 59 60 bool ReadFile(string& file_name, string* out) { 61 struct stat file_info; 62 int k = stat(file_name.c_str(), &file_info); 63 if (k != 0) 64 return false; 65 unique_ptr<byte> buf(new byte[file_info.st_size]); 66 int fd = open(file_name.c_str(), O_RDONLY); 67 if (fd < 0) 68 return false; 69 size_t n = read(fd, buf.get(), file_info.st_size); 70 out->assign((const char*)buf.get(), n); 71 close(fd); 72 return true; 73 } 74 75 bool WriteFile(string& file_name, string& in) { 76 int fd = creat(file_name.c_str(), S_IRWXU | S_IRWXG); 77 if (fd < 0) 78 return false; 79 int n = write(fd, (byte*)in.data(), in.size()); 80 close(fd); 81 return n > 0; 82 return true; 83 } 84 85 bool SerializePrivateKey(string& key_type, EVP_PKEY* key, string* out_buf) { 86 simpleexample_messages::PrivateKeyMessage msg; 87 88 if (key_type == "RSA") { 89 RSA* rsa_key = EVP_PKEY_get1_RSA(key); 90 msg.set_allocated_rsa_key(new simpleexample_messages::RsaPrivateKeyMessage()); 91 string* m_str = BN_to_bin(*rsa_key->n); 92 msg.mutable_rsa_key()->set_m(*m_str); 93 string* e_str = BN_to_bin(*rsa_key->e); 94 msg.mutable_rsa_key()->set_e(*e_str); 95 string* d_str = BN_to_bin(*rsa_key->d); 96 msg.mutable_rsa_key()->set_d(*d_str); 97 msg.set_key_type("RSA"); 98 } else if (key_type == "ECC") { 99 EC_KEY* ec_key = EVP_PKEY_get1_EC_KEY(key); 100 msg.set_allocated_ec_key(new simpleexample_messages::EcPrivateKeyMessage()); 101 byte out[4096]; 102 byte* ptr = out; 103 int n = i2d_ECPrivateKey(ec_key, &ptr); 104 if (n <= 0) { 105 printf("Can't i2d ECC private key\n"); 106 return false; 107 } 108 msg.mutable_ec_key()->set_der_blob((void*)out, (size_t)n); 109 msg.set_key_type("ECC"); 110 } else { 111 printf("SerializePrivateKey: Unknown key type\n"); 112 return false; 113 } 114 115 if (!msg.SerializeToString(out_buf)) { 116 return false; 117 } 118 return true; 119 } 120 121 bool DeserializePrivateKey(string& in_buf, string* key_type, EVP_PKEY** key) { 122 simpleexample_messages::PrivateKeyMessage msg; 123 124 if (!msg.ParseFromString(in_buf)) { 125 return false; 126 } 127 key_type->assign(msg.key_type()); 128 if (msg.key_type() == "RSA") { 129 if (!msg.has_rsa_key()) { 130 return false; 131 } 132 RSA* rsa_key = RSA_new(); 133 if (msg.rsa_key().has_m()) 134 rsa_key->n = bin_to_BN(msg.rsa_key().m().size(), (byte*)msg.rsa_key().m().data()); 135 if (msg.rsa_key().has_e()) 136 rsa_key->e = bin_to_BN(msg.rsa_key().e().size(), (byte*)msg.rsa_key().e().data()); 137 if (msg.rsa_key().has_d()) 138 rsa_key->d = bin_to_BN(msg.rsa_key().d().size(), (byte*)msg.rsa_key().d().data()); 139 EVP_PKEY* pKey = new EVP_PKEY(); 140 EVP_PKEY_assign_RSA(pKey, rsa_key); 141 *key = pKey; 142 } else if (msg.key_type() == "ECC") { 143 if (!msg.has_ec_key()) { 144 return false; 145 } 146 const byte* ptr = (byte*)msg.ec_key().der_blob().data(); 147 EC_KEY* ec_key = d2i_ECPrivateKey(nullptr, &ptr, 148 msg.ec_key().der_blob().size()); 149 if (ec_key == nullptr) { 150 printf("Can't i2d ECC private key\n"); 151 return false; 152 } 153 EVP_PKEY* pKey = new EVP_PKEY(); 154 EVP_PKEY_assign_EC_KEY(pKey, ec_key); 155 *key = pKey; 156 } else { 157 printf("DeserializePrivateKey: Unknown key type\n"); 158 return false; 159 } 160 return true; 161 } 162 163 // standard buffer size 164 #define MAX_SIZE_PARAMS 4096 165 166 void PrintPrivateRSAKey(RSA& key) { 167 if (key.n != nullptr) { 168 printf("\nModulus: \n"); 169 BN_print_fp(stdout, key.n); 170 printf("\n"); 171 } 172 if (key.e != nullptr) { 173 printf("\ne: \n"); 174 BN_print_fp(stdout, key.e); 175 printf("\n"); 176 } 177 if (key.d != nullptr) { 178 printf("\nd: \n"); 179 BN_print_fp(stdout, key.d); 180 printf("\n"); 181 } 182 if (key.p != nullptr) { 183 printf("\np: \n"); 184 BN_print_fp(stdout, key.p); 185 printf("\n"); 186 } 187 if (key.q != nullptr) { 188 printf("\nq: \n"); 189 BN_print_fp(stdout, key.q); 190 printf("\n"); 191 } 192 #if 0 193 if (key.dmp1 != nullptr) { 194 printf("\ndmp1: \n"); 195 BN_print_fp(stdout, key.dmp1); 196 printf("\n"); 197 } 198 if (key.dmq1 != nullptr) { 199 printf("\ndmq1: \n"); 200 BN_print_fp(stdout, key.dmq1); 201 printf("\n"); 202 } 203 if (key.iqmp != nullptr) { 204 printf("\niqmp: \n"); 205 BN_print_fp(stdout, key.iqmp); 206 printf("\n"); 207 } 208 #endif 209 } 210 211 BIGNUM* bin_to_BN(int len, byte* buf) { 212 BIGNUM* bn = BN_bin2bn(buf, len, nullptr); 213 return bn; 214 } 215 216 string* BN_to_bin(BIGNUM& n) { 217 byte buf[MAX_SIZE_PARAMS]; 218 219 int len = BN_bn2bin(&n, buf); 220 return new string((const char*)buf, len); 221 } 222 223 bool BN_to_string(BIGNUM& n, string* out) { 224 byte buf[MAX_SIZE_PARAMS]; 225 226 int len = BN_bn2bin(&n, buf); 227 out->assign((const char*)buf, len); 228 return true; 229 } 230 231 EVP_PKEY* GenerateKey(string& keyType, int keySize) { 232 EVP_PKEY* pKey = EVP_PKEY_new(); 233 if (pKey == nullptr) { 234 return nullptr; 235 } 236 if (keyType == "RSA") { 237 RSA* rsa_key = RSA_generate_key(keySize, 0x010001ULL, nullptr, nullptr); 238 if (rsa_key == nullptr) { 239 printf("GenerateKey: couldn't generate RSA key.\n"); 240 return nullptr; 241 } 242 EVP_PKEY_assign_RSA(pKey, rsa_key); 243 pKey->type = EVP_PKEY_RSA; 244 } else if (keyType == "ECC") { 245 EC_KEY* ec_key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); 246 EC_KEY_set_asn1_flag(ec_key, OPENSSL_EC_NAMED_CURVE); 247 if (ec_key == nullptr) { 248 printf("GenerateKey: couldn't generate ECC program key (1).\n"); 249 return nullptr; 250 } 251 if (1 != EC_KEY_generate_key(ec_key)) { 252 printf("GenerateKey: couldn't generate ECC program key(2).\n"); 253 return nullptr; 254 } 255 EVP_PKEY_assign_EC_KEY(pKey, ec_key); 256 pKey->type = EVP_PKEY_EC; 257 } else { 258 printf("GenerateKey: unsupported key type.\n"); 259 return nullptr; 260 } 261 return pKey; 262 } 263 264 class extEntry { 265 public: 266 char* key_; 267 char* value_; 268 269 extEntry(const char* k, const char* v); 270 extEntry(); 271 char* getKey(); 272 char* getValue(); 273 }; 274 275 extEntry::extEntry(const char* k, const char* v) { 276 key_ = (char*)strdup(k); 277 value_ = (char*)strdup(v); 278 } 279 280 extEntry::extEntry() { 281 key_ = nullptr; 282 value_ = nullptr; 283 } 284 285 char* extEntry::getKey() { 286 return key_; 287 } 288 289 char* extEntry::getValue() { 290 return value_; 291 } 292 293 bool addExtensionsToCert(int num_entry, extEntry** entries, X509* cert) { 294 // add extensions 295 X509V3_CTX ctx; 296 X509V3_set_ctx_nodb(&ctx); 297 X509V3_set_ctx(&ctx, cert, cert, NULL, NULL, 0); 298 for (int i = 0; i < num_entry; i++) { 299 if (entries[i]->getValue() == nullptr || strlen(entries[i]->getValue()) ==0) 300 continue; 301 int nid = OBJ_txt2nid(entries[i]->getKey()); 302 X509_EXTENSION* ext = X509V3_EXT_conf_nid(NULL, &ctx, nid, entries[i]->getValue()); 303 if (ext == 0) { 304 printf("Bad ext_conf %d\n", i); 305 printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error())); 306 return false; 307 } 308 if (!X509_add_ext(cert, ext, -1)) { 309 printf("Bad add ext %d\n", i); 310 printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error())); 311 return false; 312 } 313 X509_EXTENSION_free(ext); 314 } 315 return true; 316 } 317 318 bool GenerateX509CertificateRequest(string& key_type, string& common_name, 319 EVP_PKEY* subjectKey, bool sign_request, X509_REQ* req) { 320 X509_NAME* subject = X509_NAME_new(); 321 X509_REQ_set_version(req, 2L); 322 if (subject == nullptr) { 323 printf("Can't alloc x509 name\n"); 324 return false; 325 } 326 if (common_name.size() > 0) { 327 int nid = OBJ_txt2nid("CN"); 328 X509_NAME_ENTRY* ent = X509_NAME_ENTRY_create_by_NID(nullptr, nid, 329 MBSTRING_ASC, (byte*)common_name.c_str(), -1); 330 if (ent == nullptr) { 331 printf("X509_NAME_ENTRY return is null, nid: %d\n", nid); 332 return false; 333 } 334 if (X509_NAME_add_entry(subject, ent, -1, 0) != 1) { 335 printf("Can't add name ent\n"); 336 return false; 337 } 338 } 339 // TODO: do the foregoing for the other name components 340 if (X509_REQ_set_subject_name(req, subject) != 1) { 341 printf("Can't set x509 subject\n"); 342 return false; 343 } 344 345 // fill key parameters in request 346 if (sign_request) { 347 const EVP_MD* digest = EVP_sha256(); 348 if (!X509_REQ_sign(req, subjectKey, digest)) { 349 printf("Sign request fails\n"); 350 printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error())); 351 } 352 } 353 if (X509_REQ_set_pubkey(req, subjectKey) ==0) { 354 printf("X509_REQ_set_pubkey failed\n"); 355 } 356 return true; 357 } 358 359 bool SignX509Certificate(EVP_PKEY* signingKey, bool f_isCa, 360 bool f_canSign, string& signing_issuer, 361 string& keyUsage, string& extendedKeyUsage, 362 int64 duration, EVP_PKEY* signedKey, 363 X509_REQ* req, bool verify_req_sig, X509* cert) { 364 if (signedKey == nullptr) 365 signedKey = X509_REQ_get_pubkey(req); 366 if (signedKey == nullptr) { 367 printf("Can't get pubkey\n"); 368 return false; 369 } 370 371 if (verify_req_sig) { 372 if (X509_REQ_verify(req, signedKey) != 1) { 373 printf("Req does not verify\n"); 374 return false; 375 } 376 } 377 378 uint64_t serial = 1; 379 const EVP_MD* digest = EVP_sha256(); 380 X509_NAME* name; 381 X509_set_version(cert, 2L); 382 ASN1_INTEGER_set(X509_get_serialNumber(cert), serial++); 383 384 name = X509_REQ_get_subject_name(req); 385 if (X509_set_subject_name(cert, name) != 1) { 386 printf("Can't set subject name\n"); 387 return false; 388 } 389 if (X509_set_pubkey(cert, signedKey) != 1) { 390 printf("Can't set pubkey\n"); 391 return false; 392 } 393 if (!X509_gmtime_adj(X509_get_notBefore(cert), 0)) { 394 printf("Can't adj notBefore\n"); 395 return false; 396 } 397 if (!X509_gmtime_adj(X509_get_notAfter(cert), duration)) { 398 printf("Can't adj notAfter\n"); 399 return false; 400 } 401 X509_NAME* issuer = X509_NAME_new(); 402 int nid = OBJ_txt2nid("CN"); 403 X509_NAME_ENTRY* ent = X509_NAME_ENTRY_create_by_NID(nullptr, nid, 404 MBSTRING_ASC, (byte*)signing_issuer.c_str(), -1); 405 if (X509_NAME_add_entry(issuer, ent, -1, 0) != 1) { 406 printf("Can't add issuer name ent: %s, %ld\n", 407 signing_issuer.c_str(), (long unsigned)ent); 408 printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error())); 409 return false; 410 } 411 if (X509_set_issuer_name(cert, issuer) != 1) { 412 printf("Can't set issuer name\n"); 413 return false; 414 } 415 416 // Add extensions which should be 417 // X509v3 extensions: 418 // X509v3 Key Usage: critical 419 // Key Agreement, Certificate Sign 420 // X509v3 Extended Key Usage: 421 // TLS Web Server Authentication, TLS Web Client Authentication 422 // X509v3 Basic Constraints: critical 423 // CA:TRUE 424 extEntry* entries[128]; 425 int n = 0; 426 if (f_isCa) { 427 entries[n++] = new extEntry("basicConstraints", "critical,CA:TRUE"); 428 } 429 entries[n++] = new extEntry("keyUsage", keyUsage.c_str()); 430 entries[n++] = new extEntry("extendedKeyUsage", extendedKeyUsage.c_str()); 431 if (!addExtensionsToCert(n, entries, cert)) { 432 printf("Can't add extensions\n"); 433 return false; 434 } 435 436 if (!X509_sign(cert, signingKey, digest)) { 437 printf("Bad PKEY type\n"); 438 return false; 439 } 440 return true; 441 } 442 443 void XorBlocks(int size, byte* in1, byte* in2, byte* out) { 444 int i; 445 446 for (i = 0; i < size; i++) 447 out[i] = in1[i] ^ in2[i]; 448 } 449 450 bool AesCtrCrypt(int key_size_bits, byte* key, int size, 451 byte* in, byte* out) { 452 AES_KEY ectx; 453 uint64_t ctr[2] = {0ULL, 0ULL}; 454 byte block[32]; 455 456 if (key_size_bits != 128) { 457 return false; 458 } 459 460 AES_set_encrypt_key(key, 128, &ectx); 461 while (size > 0) { 462 ctr[1]++; 463 AES_encrypt((byte*)ctr, block, &ectx); 464 XorBlocks(16, block, in, out); 465 in += 16; 466 out += 16; 467 size -= 16; 468 } 469 return true; 470 } 471 472 #define AESBLKSIZE 16 473 474 bool AesCFBEncrypt(byte* key, int in_size, byte* in, int iv_size, byte* iv, 475 int* out_size, byte* out) { 476 byte last_cipher[32]; 477 byte cipher_block[32]; 478 int size = 0; 479 int current_size; 480 481 AES_KEY ectx; 482 AES_set_encrypt_key(key, 128, &ectx); 483 484 // Don't write iv, called already knows it 485 if(iv_size != AESBLKSIZE) return false; 486 memcpy(last_cipher, iv, AESBLKSIZE); 487 488 while (in_size > 0) { 489 if ((size + AESBLKSIZE) > *out_size) return false; 490 // C[0] = IV, C[i] = P[i] ^ E(K, C[i-1]) 491 AES_encrypt(last_cipher, cipher_block, &ectx); 492 if (in_size >= AESBLKSIZE) 493 current_size = AESBLKSIZE; 494 else 495 current_size = in_size; 496 XorBlocks(AESBLKSIZE, cipher_block, in, last_cipher); 497 memcpy(out, last_cipher, current_size); 498 out += current_size; 499 size += current_size; 500 in += current_size; 501 in_size -= current_size; 502 } 503 *out_size = size; 504 return true; 505 } 506 507 bool AesCFBDecrypt(byte* key, int in_size, byte* in, int iv_size, byte* iv, 508 int* out_size, byte* out) { 509 byte last_cipher[32]; 510 byte cipher_block[32]; 511 int size = 0; 512 int current_size; 513 514 AES_KEY ectx; 515 AES_set_encrypt_key(key, 128, &ectx); 516 517 // Don't write iv, called already knows it 518 if(iv_size != AESBLKSIZE) return false; 519 memcpy(last_cipher, iv, AESBLKSIZE); 520 521 while (in_size > 0) { 522 if ((size + AESBLKSIZE) > *out_size) return false; 523 // P[i] = C[i] ^ E(K, C[i-1]) 524 AES_encrypt(last_cipher, cipher_block, &ectx); 525 if (in_size >= AESBLKSIZE) 526 current_size = AESBLKSIZE; 527 else 528 current_size = in_size; 529 XorBlocks(current_size, cipher_block, in, out); 530 memcpy(last_cipher, in, current_size); 531 out += current_size; 532 size += current_size; 533 in += current_size; 534 in_size -= current_size; 535 } 536 *out_size = size; 537 return true; 538 } 539 540 bool VerifyX509CertificateChain(X509* cacert, X509* cert) { 541 X509_STORE_CTX *store_ctx = X509_STORE_CTX_new(); 542 X509_STORE *store = X509_STORE_new(); 543 X509_STORE_add_cert(store, cacert); 544 // int X509_STORE_CTX_init(X509_STORE_CTX *ctx, X509_STORE *store, X509 *x509, STACK_OF(X509) *chain); 545 X509_STORE_CTX_init(store_ctx, store, cacert, nullptr); 546 int ret = X509_verify_cert(store_ctx); 547 if (ret <= 0) 548 printf("Error: %s\n", X509_verify_cert_error_string(store_ctx->error)); 549 return ret; 550 } 551 552 SslChannel::SslChannel() { 553 fd_ = -1; 554 ssl_ctx_ = nullptr; 555 ssl_ = nullptr; 556 peer_cert_ = nullptr; 557 store_ = nullptr; 558 private_key_ = nullptr; 559 } 560 561 SslChannel::~SslChannel() { 562 if (fd_ > 0) { 563 close(fd_); 564 } 565 fd_ = -1; 566 // clear private_key_; 567 #if 0 568 // Doesn't need to be freed, context free takes care of it. 569 if (ssl_ != nullptr) { 570 SSL_free(ssl_); 571 } 572 ssl_ = nullptr; 573 #endif 574 if (peer_cert_ != nullptr) { 575 X509_free(peer_cert_); 576 } 577 peer_cert_ = nullptr; 578 if (ssl_ctx_ != nullptr) { 579 SSL_CTX_free(ssl_ctx_); 580 } 581 ssl_ctx_ = nullptr; 582 if (store_ != nullptr) { 583 X509_STORE_free(store_); 584 } 585 store_ = nullptr; 586 } 587 588 int SslChannel::CreateServerSocket(string& address, string& port) { 589 int sockfd = socket(AF_INET, SOCK_STREAM, 0); 590 struct sockaddr_in dest_addr; 591 uint16_t s_port = atoi(port.c_str()); 592 memset((byte*)&dest_addr, 0, sizeof(dest_addr)); 593 594 dest_addr.sin_family = AF_INET; 595 dest_addr.sin_port = htons(s_port); 596 dest_addr.sin_addr.s_addr = INADDR_ANY; 597 inet_aton(address.c_str(), &dest_addr.sin_addr); 598 599 if (bind(sockfd, (struct sockaddr*)&dest_addr, sizeof(dest_addr)) < 0) { 600 printf("Unable to bind\n"); 601 return -1; 602 } 603 604 if (listen(sockfd, 1) < 0) { 605 printf("Unable to listen\n"); 606 return -1; 607 } 608 return sockfd; 609 } 610 611 612 int SslChannel::CreateClientSocket(string& addr, string& port) { 613 int sockfd = socket(AF_INET, SOCK_STREAM, 0); 614 struct sockaddr_in dest_addr; 615 uint16_t s_port = atoi(port.c_str()); 616 memset((byte*)&dest_addr, 0, sizeof(dest_addr)); 617 618 dest_addr.sin_family = AF_INET; 619 dest_addr.sin_port = htons(s_port); 620 dest_addr.sin_addr.s_addr = INADDR_ANY; 621 inet_aton(addr.c_str(), &dest_addr.sin_addr); 622 623 if (connect(sockfd, (struct sockaddr *) &dest_addr, 624 sizeof(struct sockaddr)) == -1) { 625 printf("Error: Cannot connect to host\n"); 626 return -1; 627 } 628 return sockfd; 629 } 630 631 bool SslChannel::InitServerSslChannel(string& network, string& address, 632 string& port, X509* policyCert, X509* programCert, 633 string& keyType, EVP_PKEY* privateKey, int verify) { 634 SSL_library_init(); 635 OpenSSL_add_all_algorithms(); 636 ERR_load_crypto_strings(); 637 638 // I'm a server. 639 server_role_ = true; 640 641 if (privateKey == nullptr) { 642 printf("Private key is null.\n"); 643 return false; 644 } 645 646 // Create socket and contexts. 647 fd_ = CreateServerSocket(address, port); 648 if(fd_ <= 0) { 649 printf("CreateServerSocket failed.\n"); 650 return false; 651 } 652 653 ssl_ctx_ = SSL_CTX_new(TLSv1_2_server_method()); 654 if (ssl_ctx_ == nullptr) { 655 printf("SSL_CTX_new failed(server).\n"); 656 ERR_print_errors_fp(stderr); 657 return false; 658 } 659 660 SSL_CTX_clear_extra_chain_certs(ssl_ctx_); 661 private_key_ = privateKey; 662 SSL_CTX_use_certificate(ssl_ctx_, programCert); 663 if (EVP_PKEY_id(private_key_) == EVP_PKEY_EC) { 664 if (!SSL_CTX_set_tmp_ecdh(ssl_ctx_, EVP_PKEY_get1_EC_KEY(private_key_))) { 665 printf("SSL_CTX_set_tmp_ecdh failed.\n"); 666 return false; 667 } 668 SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_ECDH_USE); 669 } 670 if(SSL_CTX_use_PrivateKey(ssl_ctx_, private_key_) <= 0) { 671 printf("SSL_CTX_use_PrivateKey failed.\n"); 672 ERR_print_errors_fp(stderr); 673 return false; 674 } 675 676 // Setup verification stuff. 677 switch(verify) { 678 case SSL_NO_SERVER_VERIFY_NO_CLIENT_AUTH: 679 case SSL_NO_SERVER_VERIFY_NO_CLIENT_VERIFY: 680 case SSL_SERVER_VERIFY_NO_CLIENT_VERIFY: 681 SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_NONE, nullptr); 682 SSL_CTX_set_verify_depth(ssl_ctx_, 3); 683 break; 684 case SSL_SERVER_VERIFY_CLIENT_VERIFY: 685 SSL_CTX_add_extra_chain_cert(ssl_ctx_, programCert); 686 SSL_CTX_add_extra_chain_cert(ssl_ctx_, policyCert); 687 store_ = X509_STORE_new(); 688 if (store_ == nullptr) { 689 printf("X509_STORE_new failed.\n"); 690 return false; 691 } 692 X509_STORE_add_cert(store_, policyCert); 693 SSL_CTX_set_cert_store(ssl_ctx_, store_); 694 SSL_CTX_set_verify(ssl_ctx_, 695 SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); 696 SSL_CTX_set_verify_depth(ssl_ctx_, 3); 697 break; 698 default: 699 printf("Unknown verification mode.\n"); 700 return false; 701 } 702 return true; 703 } 704 705 bool SslChannel::InitClientSslChannel(string& network, string& address, 706 string& port, X509* policyCert, X509* programCert, 707 string& keyType, EVP_PKEY* privateKey, int verify) { 708 SSL_library_init(); 709 OpenSSL_add_all_algorithms(); 710 ERR_load_crypto_strings(); 711 712 // I'm a client. 713 server_role_ = false; 714 715 // Create socket and contexts. 716 fd_ = CreateClientSocket(address, port); 717 if(fd_ <= 0) { 718 printf("CreateClientSocket failed.\n"); 719 return false; 720 } 721 722 ssl_ctx_ = SSL_CTX_new(TLSv1_2_client_method()); 723 if (ssl_ctx_ == nullptr) { 724 printf("SSL_CTX_new failed(client).\n"); 725 return false; 726 } 727 SSL_CTX_clear_extra_chain_certs(ssl_ctx_); 728 if (privateKey == nullptr) { 729 printf("Private key is null\n"); 730 return false; 731 } 732 private_key_ = privateKey; 733 734 // Setup verification stuff. 735 switch(verify) { 736 case SSL_NO_SERVER_VERIFY_NO_CLIENT_AUTH: 737 SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_NONE, nullptr); 738 SSL_CTX_set_verify_depth(ssl_ctx_, 3); 739 break; 740 case SSL_NO_SERVER_VERIFY_NO_CLIENT_VERIFY: 741 case SSL_SERVER_VERIFY_NO_CLIENT_VERIFY: 742 case SSL_SERVER_VERIFY_CLIENT_VERIFY: 743 if (privateKey == nullptr) { 744 printf("Private key is null\n"); 745 return false; 746 } 747 if (EVP_PKEY_id(private_key_) == EVP_PKEY_EC) { 748 if (!SSL_CTX_set_tmp_ecdh(ssl_ctx_, 749 EVP_PKEY_get1_EC_KEY(private_key_))) { 750 printf("SSL_CTX_set_tmp_ecdh failed.\n"); 751 return false; 752 } 753 SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_ECDH_USE); 754 } 755 if(SSL_CTX_use_PrivateKey(ssl_ctx_, private_key_) <= 0) { 756 printf("SSL_CTX_use_PrivateKey failed.\n"); 757 ERR_print_errors_fp(stderr); 758 return false; 759 } 760 SSL_CTX_use_certificate(ssl_ctx_, programCert); 761 SSL_CTX_add_extra_chain_cert(ssl_ctx_, programCert); 762 SSL_CTX_add_extra_chain_cert(ssl_ctx_, policyCert); 763 store_ = X509_STORE_new(); 764 if (store_ == nullptr) { 765 printf("X509_STORE_new failed.\n"); 766 return false; 767 } 768 X509_STORE_add_cert(store_, policyCert); 769 SSL_CTX_set_cert_store(ssl_ctx_, store_); 770 SSL_CTX_set_verify(ssl_ctx_, 771 SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); 772 SSL_CTX_set_verify_depth(ssl_ctx_, 3); 773 if (verify == SSL_NO_SERVER_VERIFY_NO_CLIENT_VERIFY) 774 SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_NONE, nullptr); 775 break; 776 default: 777 printf("Unknown verification mode.\n"); 778 return false; 779 } 780 781 ssl_ = SSL_new(ssl_ctx_); 782 if (ssl_ == nullptr) { 783 printf("SSL_new failed(client).\n"); 784 return false; 785 } 786 787 SSL_set_fd(ssl_, fd_); 788 SSL_set_connect_state(ssl_); 789 790 // Connect. 791 if (SSL_connect(ssl_) != 1) { 792 printf("SSL_connect failed.\n"); 793 ERR_print_errors_fp(stderr); 794 return false; 795 } 796 peer_cert_ = SSL_get_peer_certificate(ssl_); 797 return true; 798 } 799 800 bool SslChannel::ServerLoop(void(*server_loop)(SslChannel*, SSL*, int)) { 801 bool fContinue = true; 802 printf("ServerLoop\n"); 803 804 while(fContinue) { 805 struct sockaddr_in addr; 806 uint len = sizeof(addr); 807 memset((byte*)&addr, 0, len); 808 809 int client = accept(fd_, (struct sockaddr*)&addr, &len); 810 if (client < 0) { 811 printf("Unable to accept\n"); 812 printf("ERR: %s\n", ERR_lib_error_string(ERR_get_error())); 813 continue; 814 } 815 816 SSL* ssl = SSL_new(ssl_ctx_); 817 if (private_key_ == nullptr) { 818 printf("private_key_ is null.\n"); 819 return false; 820 } 821 SSL_set_fd(ssl, client); 822 SSL_set_accept_state(ssl); 823 if (SSL_accept(ssl) <= 0) { 824 printf("Unable to ssl_accept\n"); 825 ERR_print_errors_fp(stderr); 826 continue; 827 } 828 server_loop(this, ssl, client); 829 // thread t(server_loop, this, ssl, client); 830 } 831 return true; 832 } 833 834 void SslChannel::Close() { 835 if (fd_ > 0) { 836 close(fd_); 837 } 838 fd_ = -1; 839 if (ssl_ != nullptr) { 840 SSL_free(ssl_); 841 } 842 ssl_ = nullptr; 843 if (peer_cert_ != nullptr) { 844 X509_free(peer_cert_); 845 } 846 peer_cert_ = nullptr; 847 if (ssl_ctx_ != nullptr) { 848 SSL_CTX_free(ssl_ctx_); 849 } 850 ssl_ctx_ = nullptr; 851 if (store_ != nullptr) { 852 X509_STORE_free(store_); 853 } 854 store_ = nullptr; 855 } 856 857 X509* SslChannel::GetPeerCert() { 858 return peer_cert_; 859 } 860 861 int SslMessageRead(SSL* ssl, int size, byte* buf) { 862 byte new_buf[8192]; 863 int tmp_size = SslRead(ssl, size, new_buf); 864 if (tmp_size <= 0) 865 return tmp_size; 866 int real_size = __builtin_bswap32(*((int*)new_buf)); 867 if (tmp_size == sizeof(int)) { 868 return SslRead(ssl, real_size, buf); 869 } 870 memcpy(buf, &new_buf[4], real_size); 871 return real_size; 872 } 873 874 int SslMessageWrite(SSL* ssl, int size, byte* buf) { 875 // write 32 bit size and buffer 876 int big_endian_size = __builtin_bswap32(size); 877 byte new_buf[4096]; 878 memcpy(new_buf, (byte*)&big_endian_size, sizeof(int)); 879 memcpy(&new_buf[sizeof(int)], buf, size); 880 return SslWrite(ssl, size + sizeof(int), new_buf) - sizeof(int); 881 } 882 883 int SslRead(SSL* ssl, int size, byte* buf) { 884 return SSL_read(ssl, buf, size); 885 } 886 887 int SslWrite(SSL* ssl, int size, byte* buf) { 888 return SSL_write(ssl, buf, size); 889 } 890 891 // TODO: consider using std::to_string 892 int NumHexInBytes(int size, byte* in) { return 2 * size; } 893 894 int NumBytesInHex(char* in) { 895 if (in == nullptr) 896 return -1; 897 int len = strlen(in); 898 return ((len + 1) / 2); 899 } 900 901 char ValueToHex(byte x) { 902 if (x >= 0 && x <= 9) { 903 return x + '0'; 904 } else if (x >= 10 && x <= 15) { 905 return x - 10 + 'a'; 906 } else { 907 return ' '; 908 } 909 } 910 911 byte HexToValue(char x) { 912 if (x >= '0' && x <= '9') { 913 return x - '0'; 914 } else if (x >= 'a' && x <= 'f') { 915 return x + 10 - 'a'; 916 } else { 917 return 0; 918 } 919 } 920 921 string* ByteToHexLeftToRight(int size, byte* in) { 922 if (in == nullptr) 923 return nullptr; 924 int n = NumHexInBytes(size, in); 925 string* out = new string(n, 0); 926 char* str = (char*)out->c_str(); 927 byte a, b; 928 929 while (size > 0) { 930 a = (*in) >> 4; 931 b = (*in) & 0xf; 932 in++; 933 *(str++) = ValueToHex(a); 934 *(str++) = ValueToHex(b); 935 size--; 936 } 937 return out; 938 } 939 940 int HexToByteLeftToRight(char* in, int size, byte* out) { 941 if (in == nullptr) 942 return -1; 943 int n = NumBytesInHex(in); 944 int m = strlen(in); 945 byte a, b; 946 947 if (n > size) { 948 return -1; 949 } 950 while (m > 0) { 951 a = HexToValue(*(in++)); 952 b = HexToValue(*(in++)); 953 *(out++) = (a << 4) | b; 954 m -= 2; 955 } 956 return n; 957 } 958 959 string* ByteToHexRightToLeft(int size, byte* in) { 960 if (in == nullptr) 961 return nullptr; 962 int n = NumHexInBytes(size, in); 963 string* out = new string(n, 0); 964 char* str = (char*)out->c_str(); 965 byte a, b; 966 967 in += size - 1; 968 while (size > 0) { 969 a = (*in) >> 4; 970 b = (*in) & 0xf; 971 in--; 972 *(str++) = ValueToHex(a); 973 *(str++) = ValueToHex(b); 974 size--; 975 } 976 return out; 977 } 978 979 int HexToByteRightToLeft(char* in, int size, byte* out) { 980 if (in == nullptr) { 981 return -1; 982 } 983 int n = NumBytesInHex(in); 984 int m = strlen(in); 985 byte a, b; 986 987 out += n - 1; 988 if (m < 0) { 989 return -1; 990 } 991 while (m > 0) { 992 a = HexToValue(*(in++)); 993 b = HexToValue(*(in++)); 994 *(out--) = (a << 4) | b; 995 m -= 2; 996 } 997 return n; 998 }