proxygen
AsyncSSLSocketTest.cpp
Go to the documentation of this file.
/*
 * Copyright 2011-present Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 
18 #include <folly/SocketAddress.h>
19 #include <folly/String.h>
20 #include <folly/io/Cursor.h>
30 #include <folly/ssl/Init.h>
31 
33 
34 #include <fcntl.h>
35 #include <signal.h>
36 #include <sys/types.h>
37 
38 #include <fstream>
39 #include <iostream>
40 #include <list>
41 #include <set>
42 #include <thread>
43 
44 #ifdef __linux__
45 #include <dlfcn.h>
46 #endif
47 
48 #if FOLLY_OPENSSL_IS_110
49 #include <openssl/async.h>
50 #endif
51 
52 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
53 #include <sys/utsname.h>
54 #endif
55 
56 using std::cerr;
57 using std::endl;
58 using std::list;
59 using std::min;
60 using std::string;
61 using std::vector;
62 
63 using namespace testing;
64 
65 #if defined __linux__
namespace {

// Signature of libc's setsockopt(); used to forward calls that the interposed
// ::setsockopt() below does not handle itself.
using setsockopt_ptr = int (*)(int, int, int, const void*, socklen_t);

// libc's original setsockopt(), looked up with dlsym(RTLD_NEXT, ...).
setsockopt_ptr real_setsockopt_ = nullptr;

// Global struct initialized before main() runs. We could initialize within a
// test or in main(), but this approach seems to be the least intrusive and
// most universal.
struct GlobalStatic {
  GlobalStatic() {
    real_setsockopt_ =
        reinterpret_cast<setsockopt_ptr>(dlsym(RTLD_NEXT, "setsockopt"));
  }

  // Forgets every fd recorded so far.
  void reset() noexcept {
    ttlsDisabledSet.clear();
  }

  // fds for which SO_NO_TRANSPARENT_TLS was set through the interposed
  // setsockopt(); an fd being present means TTLS was disabled on it.
  std::set<int /* fd */> ttlsDisabledSet;
};

// The constructor will be called before main(), which is all we care about.
GlobalStatic globalStatic;

} // namespace
89 
90 // we intercept setsoctopt to test setting NO_TRANSPARENT_TLS opt
91 // this name has to be global
92 int setsockopt(
93  int sockfd,
94  int level,
95  int optname,
96  const void* optval,
97  socklen_t optlen) {
98  if (optname == SO_NO_TRANSPARENT_TLS) {
99  globalStatic.ttlsDisabledSet.insert(sockfd);
100  return 0;
101  }
102  return real_setsockopt_(sockfd, level, optname, optval, optlen);
103 }
104 #endif
105 
106 namespace folly {
107 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
108 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
109 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
110 
111 constexpr size_t SSLClient::kMaxReadBufferSz;
112 constexpr size_t SSLClient::kMaxReadsPerEvent;
113 
114 void getfds(int fds[2]) {
115  if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
116  FAIL() << "failed to create socketpair: " << errnoStr(errno);
117  }
118  for (int idx = 0; idx < 2; ++idx) {
119  int flags = fcntl(fds[idx], F_GETFL, 0);
120  if (flags == -1) {
121  FAIL() << "failed to get flags for socket " << idx << ": "
122  << errnoStr(errno);
123  }
124  if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
125  FAIL() << "failed to put socket " << idx
126  << " in non-blocking mode: " << errnoStr(errno);
127  }
128  }
129 }
130 
131 void getctx(
132  std::shared_ptr<folly::SSLContext> clientCtx,
133  std::shared_ptr<folly::SSLContext> serverCtx) {
134  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
135 
136  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
137  serverCtx->loadCertificate(kTestCert);
138  serverCtx->loadPrivateKey(kTestKey);
139 }
140 
142  EventBase* eventBase,
143  AsyncSSLSocket::UniquePtr* clientSock,
144  AsyncSSLSocket::UniquePtr* serverSock) {
145  auto clientCtx = std::make_shared<folly::SSLContext>();
146  auto serverCtx = std::make_shared<folly::SSLContext>();
147  int fds[2];
148  getfds(fds);
149  getctx(clientCtx, serverCtx);
150  clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false));
151  serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true));
152 
153  // (*clientSock)->setSendTimeout(100);
154  // (*serverSock)->setSendTimeout(100);
155 }
156 
// Client protocol filters.

// Always selects "ponies". The protocol is handed back as a length-prefixed
// byte string: the first byte is the length of the name that follows (6), and
// *client_len is the size of the whole buffer including that prefix byte (7).
// The buffer is static so the returned pointer stays valid after returning.
bool clientProtoFilterPickPony(
    unsigned char** client,
    unsigned int* client_len,
    const unsigned char*,
    unsigned int) {
  static unsigned char p[7] = {6, 'p', 'o', 'n', 'i', 'e', 's'};
  *client = p;
  *client_len = sizeof(p);
  return true;
}
170 
// Rejects every candidate: returns false without touching its outputs.
bool clientProtoFilterPickNone(
    unsigned char**,
    unsigned int*,
    const unsigned char*,
    unsigned int) {
  return false;
}
178 
179 std::string getFileAsBuf(const char* fileName) {
181  folly::readFile(fileName, buffer);
182  return buffer;
183 }
184 
189 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
190  // Start listening on a local port
191  WriteCallbackBase writeCallback;
192  ReadCallback readCallback(&writeCallback);
193  HandshakeCallback handshakeCallback(&readCallback);
194  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
195  TestSSLServer server(&acceptCallback);
196 
197  // Set up SSL context.
198  std::shared_ptr<SSLContext> sslContext(new SSLContext());
199  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
200  // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
201  // sslContext->authenticate(true, false);
202 
203  // connect
204  auto socket =
205  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
206  socket->open(std::chrono::milliseconds(10000));
207 
208  // write()
209  uint8_t buf[128];
210  memset(buf, 'a', sizeof(buf));
211  socket->write(buf, sizeof(buf));
212 
213  // read()
214  uint8_t readbuf[128];
215  uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
216  EXPECT_EQ(bytesRead, 128);
217  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
218 
219  // close()
220  socket->close();
221 
222  cerr << "ConnectWriteReadClose test completed" << endl;
223  EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
224 }
225 
229 TEST(AsyncSSLSocketTest, ReadAfterClose) {
230  // Start listening on a local port
231  WriteCallbackBase writeCallback;
232  ReadEOFCallback readCallback(&writeCallback);
233  HandshakeCallback handshakeCallback(&readCallback);
234  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
235  auto server = std::make_unique<TestSSLServer>(&acceptCallback);
236 
237  // Set up SSL context.
238  auto sslContext = std::make_shared<SSLContext>();
239  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
240 
241  auto socket =
242  std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
243  socket->open();
244 
245  // This should trigger an EOF on the client.
246  auto evb = handshakeCallback.getSocket()->getEventBase();
247  evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
248  std::array<uint8_t, 128> readbuf;
249  auto bytesRead = socket->read(readbuf.data(), readbuf.size());
250  EXPECT_EQ(0, bytesRead);
251 }
252 
256 #if !defined(OPENSSL_IS_BORINGSSL)
257 TEST(AsyncSSLSocketTest, Renegotiate) {
258  EventBase eventBase;
259  auto clientCtx = std::make_shared<SSLContext>();
260  auto dfServerCtx = std::make_shared<SSLContext>();
261  std::array<int, 2> fds;
262  getfds(fds.data());
263  getctx(clientCtx, dfServerCtx);
264 
265  AsyncSSLSocket::UniquePtr clientSock(
266  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
267  AsyncSSLSocket::UniquePtr serverSock(
268  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
269  SSLHandshakeClient client(std::move(clientSock), true, true);
270  RenegotiatingServer server(std::move(serverSock));
271 
272  while (!client.handshakeSuccess_ && !client.handshakeError_) {
273  eventBase.loopOnce();
274  }
275 
277 
278  auto sslSock = std::move(client).moveSocket();
279  sslSock->detachEventBase();
280  // This is nasty, however we don't want to add support for
281  // renegotiation in AsyncSSLSocket.
282  SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
283 
284  auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
285 
286  std::thread t([&]() { eventBase.loopForever(); });
287 
288  // Trigger the renegotiation.
289  std::array<uint8_t, 128> buf;
290  memset(buf.data(), 'a', buf.size());
291  try {
292  socket->write(buf.data(), buf.size());
293  } catch (AsyncSocketException& e) {
294  LOG(INFO) << "client got error " << e.what();
295  }
296  eventBase.terminateLoopSoon();
297  t.join();
298 
299  eventBase.loop();
301 }
302 #endif
303 
307 TEST(AsyncSSLSocketTest, HandshakeError) {
308  // Start listening on a local port
309  WriteCallbackBase writeCallback;
310  WriteErrorCallback readCallback(&writeCallback);
311  HandshakeCallback handshakeCallback(&readCallback);
312  HandshakeErrorCallback acceptCallback(&handshakeCallback);
313  TestSSLServer server(&acceptCallback);
314 
315  // Set up SSL context.
316  std::shared_ptr<SSLContext> sslContext(new SSLContext());
317  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
318 
319  // connect
320  auto socket =
321  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
322  // read()
323  bool ex = false;
324  try {
325  socket->open();
326 
327  uint8_t readbuf[128];
328  uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
329  LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
330  } catch (AsyncSocketException&) {
331  ex = true;
332  }
333  EXPECT_TRUE(ex);
334 
335  // close()
336  socket->close();
337  cerr << "HandshakeError test completed" << endl;
338 }
339 
343 TEST(AsyncSSLSocketTest, ReadError) {
344  // Start listening on a local port
345  WriteCallbackBase writeCallback;
346  ReadErrorCallback readCallback(&writeCallback);
347  HandshakeCallback handshakeCallback(&readCallback);
348  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
349  TestSSLServer server(&acceptCallback);
350 
351  // Set up SSL context.
352  std::shared_ptr<SSLContext> sslContext(new SSLContext());
353  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
354 
355  // connect
356  auto socket =
357  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
358  socket->open();
359 
360  // write something to trigger ssl handshake
361  uint8_t buf[128];
362  memset(buf, 'a', sizeof(buf));
363  socket->write(buf, sizeof(buf));
364 
365  socket->close();
366  cerr << "ReadError test completed" << endl;
367 }
368 
372 TEST(AsyncSSLSocketTest, WriteError) {
373  // Start listening on a local port
374  WriteCallbackBase writeCallback;
375  WriteErrorCallback readCallback(&writeCallback);
376  HandshakeCallback handshakeCallback(&readCallback);
377  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
378  TestSSLServer server(&acceptCallback);
379 
380  // Set up SSL context.
381  std::shared_ptr<SSLContext> sslContext(new SSLContext());
382  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
383 
384  // connect
385  auto socket =
386  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
387  socket->open();
388 
389  // write something to trigger ssl handshake
390  uint8_t buf[128];
391  memset(buf, 'a', sizeof(buf));
392  socket->write(buf, sizeof(buf));
393 
394  socket->close();
395  cerr << "WriteError test completed" << endl;
396 }
397 
401 TEST(AsyncSSLSocketTest, SocketWithDelay) {
402  // Start listening on a local port
403  WriteCallbackBase writeCallback;
404  ReadCallback readCallback(&writeCallback);
405  HandshakeCallback handshakeCallback(&readCallback);
406  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
407  TestSSLServer server(&acceptCallback);
408 
409  // Set up SSL context.
410  std::shared_ptr<SSLContext> sslContext(new SSLContext());
411  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
412 
413  // connect
414  auto socket =
415  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
416  socket->open();
417 
418  // write()
419  uint8_t buf[128];
420  memset(buf, 'a', sizeof(buf));
421  socket->write(buf, sizeof(buf));
422 
423  // read()
424  uint8_t readbuf[128];
425  uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
426  EXPECT_EQ(bytesRead, 128);
427  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
428 
429  // close()
430  socket->close();
431 
432  cerr << "SocketWithDelay test completed" << endl;
433 }
434 
435 #if FOLLY_OPENSSL_HAS_ALPN
436 class NextProtocolTest : public Test {
437  // For matching protos
438  public:
439  void SetUp() override {
440  getctx(clientCtx, serverCtx);
441  }
442 
443  void connect(bool unset = false) {
444  getfds(fds);
445 
446  if (unset) {
447  // unsetting NPN for any of [client, server] is enough to make NPN not
448  // work
449  clientCtx->unsetNextProtocols();
450  }
451 
452  AsyncSSLSocket::UniquePtr clientSock(
453  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
454  AsyncSSLSocket::UniquePtr serverSock(
455  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
456  client = std::make_unique<AlpnClient>(std::move(clientSock));
457  server = std::make_unique<AlpnServer>(std::move(serverSock));
458 
459  eventBase.loop();
460  }
461 
462  void expectProtocol(const std::string& proto) {
463  expectHandshakeSuccess();
464  EXPECT_NE(client->nextProtoLength, 0);
465  EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
466  EXPECT_EQ(
467  memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
468  0);
469  string selected((const char*)client->nextProto, client->nextProtoLength);
470  EXPECT_EQ(proto, selected);
471  }
472 
473  void expectNoProtocol() {
474  expectHandshakeSuccess();
475  EXPECT_EQ(client->nextProtoLength, 0);
476  EXPECT_EQ(server->nextProtoLength, 0);
477  EXPECT_EQ(client->nextProto, nullptr);
478  EXPECT_EQ(server->nextProto, nullptr);
479  }
480 
481  void expectHandshakeSuccess() {
482  EXPECT_FALSE(client->except.hasValue())
483  << "client handshake error: " << client->except->what();
484  EXPECT_FALSE(server->except.hasValue())
485  << "server handshake error: " << server->except->what();
486  }
487 
488  void expectHandshakeError() {
489  EXPECT_TRUE(client->except.hasValue())
490  << "Expected client handshake error!";
491  EXPECT_TRUE(server->except.hasValue())
492  << "Expected server handshake error!";
493  }
494 
495  EventBase eventBase;
496  std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
497  std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
498  int fds[2];
499  std::unique_ptr<AlpnClient> client;
500  std::unique_ptr<AlpnServer> server;
501 };
502 
503 TEST_F(NextProtocolTest, AlpnTestOverlap) {
504  clientCtx->setAdvertisedNextProtocols({"blub", "baz"});
505  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
506 
507  connect();
508 
509  expectProtocol("baz");
510 }
511 
512 TEST_F(NextProtocolTest, AlpnTestUnset) {
513  // Identical to above test, except that we want unset NPN before
514  // looping.
515  clientCtx->setAdvertisedNextProtocols({"blub", "baz"});
516  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
517 
518  connect(true /* unset */);
519 
520  expectNoProtocol();
521 }
522 
523 TEST_F(NextProtocolTest, AlpnTestNoOverlap) {
524  clientCtx->setAdvertisedNextProtocols({"blub"});
525  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
526  connect();
527 
528  expectNoProtocol();
529 }
530 
531 TEST_F(NextProtocolTest, RandomizedAlpnTest) {
532  // Probability that this test will fail is 2^-64, which could be considered
533  // as negligible.
534  const int kTries = 64;
535 
536  clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
537  serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}});
538 
539  std::set<string> selectedProtocols;
540  for (int i = 0; i < kTries; ++i) {
541  connect();
542 
543  EXPECT_NE(client->nextProtoLength, 0);
544  EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
545  EXPECT_EQ(
546  memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
547  0);
548  string selected((const char*)client->nextProto, client->nextProtoLength);
549  selectedProtocols.insert(selected);
550  expectHandshakeSuccess();
551  }
552  EXPECT_EQ(selectedProtocols.size(), 2);
553 }
554 #endif
555 
556 #ifndef OPENSSL_NO_TLSEXT
557 
563 TEST(AsyncSSLSocketTest, SNITestMatch) {
564  EventBase eventBase;
565  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
566  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
567  // Use the same SSLContext to continue the handshake after
568  // tlsext_hostname match.
569  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
570  const std::string serverName("xyz.newdev.facebook.com");
571  int fds[2];
572  getfds(fds);
573  getctx(clientCtx, dfServerCtx);
574 
575  AsyncSSLSocket::UniquePtr clientSock(
576  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
577  AsyncSSLSocket::UniquePtr serverSock(
578  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
579  SNIClient client(std::move(clientSock));
580  SNIServer server(
581  std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
582 
583  eventBase.loop();
584 
587 }
588 
595 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
596  EventBase eventBase;
597  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
598  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
599  // Use the same SSLContext to continue the handshake after
600  // tlsext_hostname match.
601  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
602  const std::string clientRequestingServerName("foo.com");
603  const std::string serverExpectedServerName("xyz.newdev.facebook.com");
604 
605  int fds[2];
606  getfds(fds);
607  getctx(clientCtx, dfServerCtx);
608 
610  clientCtx, &eventBase, fds[0], clientRequestingServerName));
611  AsyncSSLSocket::UniquePtr serverSock(
612  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
613  SNIClient client(std::move(clientSock));
614  SNIServer server(
615  std::move(serverSock),
616  dfServerCtx,
617  hskServerCtx,
618  serverExpectedServerName);
619 
620  eventBase.loop();
621 
622  EXPECT_TRUE(!client.serverNameMatch);
623  EXPECT_TRUE(!server.serverNameMatch);
624 }
631 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
632  EventBase eventBase;
633  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
634  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
635  // Use the same SSLContext to continue the handshake after
636  // tlsext_hostname match.
637  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
638  const std::string serverName("xyz.newdev.facebook.com");
639  int fds[2];
640  getfds(fds);
641  getctx(clientCtx, dfServerCtx);
642 
643  AsyncSSLSocket::UniquePtr clientSock(
644  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
645  // Change the server name
646  std::string newName("new.com");
647  clientSock->setServerName(newName);
648  AsyncSSLSocket::UniquePtr serverSock(
649  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
650  SNIClient client(std::move(clientSock));
651  SNIServer server(
652  std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
653 
654  eventBase.loop();
655 
656  EXPECT_TRUE(!client.serverNameMatch);
657 }
658 
663 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
664  EventBase eventBase;
665  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
666  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
667  // Use the same SSLContext to continue the handshake after
668  // tlsext_hostname match.
669  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
670  const std::string serverExpectedServerName("xyz.newdev.facebook.com");
671 
672  int fds[2];
673  getfds(fds);
674  getctx(clientCtx, dfServerCtx);
675 
676  AsyncSSLSocket::UniquePtr clientSock(
677  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
678  AsyncSSLSocket::UniquePtr serverSock(
679  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
680  SNIClient client(std::move(clientSock));
681  SNIServer server(
682  std::move(serverSock),
683  dfServerCtx,
684  hskServerCtx,
685  serverExpectedServerName);
686 
687  eventBase.loop();
688 
689  EXPECT_TRUE(!client.serverNameMatch);
690  EXPECT_TRUE(!server.serverNameMatch);
691 }
692 
693 #endif
694 
697 TEST(AsyncSSLSocketTest, SSLClientTest) {
698  // Start listening on a local port
699  WriteCallbackBase writeCallback;
700  ReadCallback readCallback(&writeCallback);
701  HandshakeCallback handshakeCallback(&readCallback);
702  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
703  TestSSLServer server(&acceptCallback);
704 
705  // Set up SSL client
706  EventBase eventBase;
707  auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
708 
709  client->connect();
710  EventBaseAborter eba(&eventBase, 3000);
711  eventBase.loop();
712 
713  EXPECT_EQ(client->getMiss(), 1);
714  EXPECT_EQ(client->getHit(), 0);
715 
716  cerr << "SSLClientTest test completed" << endl;
717 }
718 
722 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
723  // Start listening on a local port
724  WriteCallbackBase writeCallback;
725  ReadCallback readCallback(&writeCallback);
726  HandshakeCallback handshakeCallback(&readCallback);
727  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
728  TestSSLServer server(&acceptCallback);
729 
730  // Set up SSL client
731  EventBase eventBase;
732  auto client =
733  std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
734 
735  client->connect();
736  EventBaseAborter eba(&eventBase, 3000);
737  eventBase.loop();
738 
739  EXPECT_EQ(client->getMiss(), 1);
740  EXPECT_EQ(client->getHit(), 9);
741 
742  cerr << "SSLClientTestReuse test completed" << endl;
743 }
744 
748 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
749  // Start listening on a local port
750  EmptyReadCallback readCallback;
751  HandshakeCallback handshakeCallback(
752  &readCallback, HandshakeCallback::EXPECT_ERROR);
753  HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
754  TestSSLServer server(&acceptCallback);
755 
756  // Set up SSL client
757  EventBase eventBase;
758  auto client =
759  std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
760  client->connect(true /* write before connect completes */);
761  EventBaseAborter eba(&eventBase, 3000);
762  eventBase.loop();
763 
764  usleep(100000);
765  // This is checking that the connectError callback precedes any queued
766  // writeError callbacks. This matches AsyncSocket's behavior
767  EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
768  EXPECT_EQ(client->getErrors(), 1);
769  EXPECT_EQ(client->getMiss(), 0);
770  EXPECT_EQ(client->getHit(), 0);
771 
772  cerr << "SSLClientTimeoutTest test completed" << endl;
773 }
774 
// The next 4 tests need an FB-only extension, and will fail without it
776 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
777 
780 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
781  // Start listening on a local port
782  WriteCallbackBase writeCallback;
783  ReadCallback readCallback(&writeCallback);
784  HandshakeCallback handshakeCallback(&readCallback);
785  SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
786  TestSSLAsyncCacheServer server(&acceptCallback);
787 
788  // Set up SSL client
789  EventBase eventBase;
790  auto client =
791  std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
792 
793  client->connect();
794  EventBaseAborter eba(&eventBase, 3000);
795  eventBase.loop();
796 
797  EXPECT_EQ(server.getAsyncCallbacks(), 18);
798  EXPECT_EQ(server.getAsyncLookups(), 9);
799  EXPECT_EQ(client->getMiss(), 10);
800  EXPECT_EQ(client->getHit(), 0);
801 
802  cerr << "SSLServerAsyncCacheTest test completed" << endl;
803 }
804 
808 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
809  // Start listening on a local port
810  WriteCallbackBase writeCallback;
811  ReadCallback readCallback(&writeCallback);
812  HandshakeCallback handshakeCallback(&readCallback);
813  SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
814  TestSSLAsyncCacheServer server(&acceptCallback);
815 
816  // Set up SSL client
817  EventBase eventBase;
818  // only do a TCP connect
819  std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
820  sock->connect(nullptr, server.getAddress());
821 
822  EmptyReadCallback clientReadCallback;
823  clientReadCallback.tcpSocket_ = sock;
824  sock->setReadCB(&clientReadCallback);
825 
826  EventBaseAborter eba(&eventBase, 3000);
827  eventBase.loop();
828 
829  EXPECT_EQ(readCallback.state, STATE_WAITING);
830 
831  cerr << "SSLServerTimeoutTest test completed" << endl;
832 }
833 
837 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
838  // Start listening on a local port
839  WriteCallbackBase writeCallback;
840  ReadCallback readCallback(&writeCallback);
841  HandshakeCallback handshakeCallback(&readCallback);
842  SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
843  TestSSLAsyncCacheServer server(&acceptCallback);
844 
845  // Set up SSL client
846  EventBase eventBase;
847  auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
848 
849  client->connect();
850  EventBaseAborter eba(&eventBase, 3000);
851  eventBase.loop();
852 
853  EXPECT_EQ(server.getAsyncCallbacks(), 1);
854  EXPECT_EQ(server.getAsyncLookups(), 1);
855  EXPECT_EQ(client->getErrors(), 1);
856  EXPECT_EQ(client->getMiss(), 1);
857  EXPECT_EQ(client->getHit(), 0);
858 
859  cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
860 }
861 
865 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
866  // Start listening on a local port
867  WriteCallbackBase writeCallback;
868  ReadCallback readCallback(&writeCallback);
869  HandshakeCallback handshakeCallback(
870  &readCallback, HandshakeCallback::EXPECT_ERROR);
871  SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
872  TestSSLAsyncCacheServer server(&acceptCallback, 500);
873 
874  // Set up SSL client
875  EventBase eventBase;
876  auto client =
877  std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
878 
879  client->connect();
880  EventBaseAborter eba(&eventBase, 3000);
881  eventBase.loop();
882 
884  [&handshakeCallback] { handshakeCallback.closeSocket(); });
885  // give time for the cache lookup to come back and find it closed
886  handshakeCallback.waitForHandshake();
887 
888  EXPECT_EQ(server.getAsyncCallbacks(), 1);
889  EXPECT_EQ(server.getAsyncLookups(), 1);
890  EXPECT_EQ(client->getErrors(), 1);
891  EXPECT_EQ(client->getMiss(), 1);
892  EXPECT_EQ(client->getHit(), 0);
893 
894  cerr << "SSLServerCacheCloseTest test completed" << endl;
895 }
896 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
897 
901 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
902  EventBase eventBase;
903  auto clientCtx = std::make_shared<SSLContext>();
904  auto serverCtx = std::make_shared<SSLContext>();
905  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
906  serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
907  serverCtx->loadPrivateKey(kTestKey);
908  serverCtx->loadCertificate(kTestCert);
909  serverCtx->loadTrustedCertificates(kTestCA);
910  serverCtx->loadClientCAList(kTestCA);
911 
912  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
913  clientCtx->ciphers("AES256-SHA:AES128-SHA");
914  clientCtx->loadPrivateKey(kTestKey);
915  clientCtx->loadCertificate(kTestCert);
916  clientCtx->loadTrustedCertificates(kTestCA);
917 
918  int fds[2];
919  getfds(fds);
920 
921  AsyncSSLSocket::UniquePtr clientSock(
922  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
923  AsyncSSLSocket::UniquePtr serverSock(
924  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
925 
926  SSLHandshakeClient client(std::move(clientSock), true, true);
927  SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
928 
929  eventBase.loop();
930 
931 #if defined(OPENSSL_IS_BORINGSSL)
932  EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
933 #else
934  EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
935 #endif
936  EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
939  EXPECT_TRUE(!client.handshakeError_);
942  EXPECT_TRUE(!server.handshakeError_);
943 }
944 
948 TEST(AsyncSSLSocketTest, GetClientCertificate) {
949  EventBase eventBase;
950  auto clientCtx = std::make_shared<SSLContext>();
951  auto serverCtx = std::make_shared<SSLContext>();
952  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
953  serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
954  serverCtx->loadPrivateKey(kTestKey);
955  serverCtx->loadCertificate(kTestCert);
956  serverCtx->loadTrustedCertificates(kClientTestCA);
957  serverCtx->loadClientCAList(kClientTestCA);
958 
959  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
960  clientCtx->ciphers("AES256-SHA:AES128-SHA");
961  clientCtx->loadPrivateKey(kClientTestKey);
962  clientCtx->loadCertificate(kClientTestCert);
963  clientCtx->loadTrustedCertificates(kTestCA);
964 
965  std::array<int, 2> fds;
966  getfds(fds.data());
967 
968  AsyncSSLSocket::UniquePtr clientSock(
969  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
970  AsyncSSLSocket::UniquePtr serverSock(
971  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
972 
973  SSLHandshakeClient client(std::move(clientSock), true, true);
974  SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
975 
976  eventBase.loop();
977 
978  // Handshake should succeed.
981 
982  // Reclaim the sockets from SSLHandshakeBase.
983  auto cliSocket = std::move(client).moveSocket();
984  auto srvSocket = std::move(server).moveSocket();
985 
986  // Client cert retrieved from server side.
987  folly::ssl::X509UniquePtr serverPeerCert = srvSocket->getPeerCert();
988  CHECK(serverPeerCert);
989 
990  // Client cert retrieved from client side.
991  const X509* clientSelfCert = cliSocket->getSelfCert();
992  CHECK(clientSelfCert);
993 
994  // The two certs should be the same.
995  EXPECT_EQ(0, X509_cmp(clientSelfCert, serverPeerCert.get()));
996 }
997 
998 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
999  EventBase eventBase;
1000  auto ctx = std::make_shared<SSLContext>();
1001 
1002  int fds[2];
1003  getfds(fds);
1004 
1005  int bufLen = 42;
1006  uint8_t majorVersion = 18;
1007  uint8_t minorVersion = 25;
1008 
1009  // Create callback buf
1010  auto buf = IOBuf::create(bufLen);
1011  buf->append(bufLen);
1012  folly::io::RWPrivateCursor cursor(buf.get());
1013  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1014  cursor.write<uint16_t>(0);
1015  cursor.write<uint8_t>(38);
1016  cursor.write<uint8_t>(majorVersion);
1017  cursor.write<uint8_t>(minorVersion);
1018  cursor.skip(32);
1019  cursor.write<uint32_t>(0);
1020 
1021  SSL* ssl = ctx->createSSL();
1022  SCOPE_EXIT {
1023  SSL_free(ssl);
1024  };
1026  new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1027  sock->enableClientHelloParsing();
1028 
1029  // Test client hello parsing in one packet
1030  AsyncSSLSocket::clientHelloParsingCallback(
1031  0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1032  buf.reset();
1033 
1034  auto parsedClientHello = sock->getClientHelloInfo();
1035  EXPECT_TRUE(parsedClientHello != nullptr);
1036  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1037  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1038 }
1039 
1040 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1041  EventBase eventBase;
1042  auto ctx = std::make_shared<SSLContext>();
1043 
1044  int fds[2];
1045  getfds(fds);
1046 
1047  int bufLen = 42;
1048  uint8_t majorVersion = 18;
1049  uint8_t minorVersion = 25;
1050 
1051  // Create callback buf
1052  auto buf = IOBuf::create(bufLen);
1053  buf->append(bufLen);
1054  folly::io::RWPrivateCursor cursor(buf.get());
1055  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1056  cursor.write<uint16_t>(0);
1057  cursor.write<uint8_t>(38);
1058  cursor.write<uint8_t>(majorVersion);
1059  cursor.write<uint8_t>(minorVersion);
1060  cursor.skip(32);
1061  cursor.write<uint32_t>(0);
1062 
1063  SSL* ssl = ctx->createSSL();
1064  SCOPE_EXIT {
1065  SSL_free(ssl);
1066  };
1068  new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1069  sock->enableClientHelloParsing();
1070 
1071  // Test parsing with two packets with first packet size < 3
1072  auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1073  AsyncSSLSocket::clientHelloParsingCallback(
1074  0,
1075  0,
1076  SSL3_RT_HANDSHAKE,
1077  bufCopy->data(),
1078  bufCopy->length(),
1079  ssl,
1080  sock.get());
1081  bufCopy.reset();
1082  bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1083  AsyncSSLSocket::clientHelloParsingCallback(
1084  0,
1085  0,
1086  SSL3_RT_HANDSHAKE,
1087  bufCopy->data(),
1088  bufCopy->length(),
1089  ssl,
1090  sock.get());
1091  bufCopy.reset();
1092 
1093  auto parsedClientHello = sock->getClientHelloInfo();
1094  EXPECT_TRUE(parsedClientHello != nullptr);
1095  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1096  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1097 }
1098 
1099 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1100  EventBase eventBase;
1101  auto ctx = std::make_shared<SSLContext>();
1102 
1103  int fds[2];
1104  getfds(fds);
1105 
1106  int bufLen = 42;
1107  uint8_t majorVersion = 18;
1108  uint8_t minorVersion = 25;
1109 
1110  // Create callback buf
1111  auto buf = IOBuf::create(bufLen);
1112  buf->append(bufLen);
1113  folly::io::RWPrivateCursor cursor(buf.get());
1114  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1115  cursor.write<uint16_t>(0);
1116  cursor.write<uint8_t>(38);
1117  cursor.write<uint8_t>(majorVersion);
1118  cursor.write<uint8_t>(minorVersion);
1119  cursor.skip(32);
1120  cursor.write<uint32_t>(0);
1121 
1122  SSL* ssl = ctx->createSSL();
1123  SCOPE_EXIT {
1124  SSL_free(ssl);
1125  };
1127  new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1128  sock->enableClientHelloParsing();
1129 
1130  // Test parsing with multiple small packets
1131  for (std::size_t i = 0; i < buf->length(); i += 3) {
1132  auto bufCopy = folly::IOBuf::copyBuffer(
1133  buf->data() + i, std::min((std::size_t)3, buf->length() - i));
1134  AsyncSSLSocket::clientHelloParsingCallback(
1135  0,
1136  0,
1137  SSL3_RT_HANDSHAKE,
1138  bufCopy->data(),
1139  bufCopy->length(),
1140  ssl,
1141  sock.get());
1142  bufCopy.reset();
1143  }
1144 
1145  auto parsedClientHello = sock->getClientHelloInfo();
1146  EXPECT_TRUE(parsedClientHello != nullptr);
1147  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1148  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1149 }
1150 
1154 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1155  EventBase eventBase;
1156  auto clientCtx = std::make_shared<SSLContext>();
1157  auto dfServerCtx = std::make_shared<SSLContext>();
1158 
1159  int fds[2];
1160  getfds(fds);
1161  getctx(clientCtx, dfServerCtx);
1162 
1163  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1164  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1165 
1166  AsyncSSLSocket::UniquePtr clientSock(
1167  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1168  AsyncSSLSocket::UniquePtr serverSock(
1169  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1170 
1171  SSLHandshakeClient client(std::move(clientSock), true, true);
1172  clientCtx->loadTrustedCertificates(kTestCA);
1173 
1174  SSLHandshakeServer server(std::move(serverSock), true, true);
1175 
1176  eventBase.loop();
1177 
1178  EXPECT_TRUE(client.handshakeVerify_);
1180  EXPECT_TRUE(!client.handshakeError_);
1181  EXPECT_LE(0, client.handshakeTime.count());
1182  EXPECT_TRUE(!server.handshakeVerify_);
1184  EXPECT_TRUE(!server.handshakeError_);
1185  EXPECT_LE(0, server.handshakeTime.count());
1186 }
1187 
1192 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1193  EventBase eventBase;
1194  auto clientCtx = std::make_shared<SSLContext>();
1195  auto dfServerCtx = std::make_shared<SSLContext>();
1196 
1197  int fds[2];
1198  getfds(fds);
1199  getctx(clientCtx, dfServerCtx);
1200 
1201  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1202  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1203 
1204  AsyncSSLSocket::UniquePtr clientSock(
1205  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1206  AsyncSSLSocket::UniquePtr serverSock(
1207  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1208 
1209  SSLHandshakeClient client(std::move(clientSock), true, false);
1210  clientCtx->loadTrustedCertificates(kTestCA);
1211 
1212  SSLHandshakeServer server(std::move(serverSock), true, true);
1213 
1214  eventBase.loop();
1215 
1216  EXPECT_TRUE(client.handshakeVerify_);
1217  EXPECT_TRUE(!client.handshakeSuccess_);
1218  EXPECT_TRUE(client.handshakeError_);
1219  EXPECT_LE(0, client.handshakeTime.count());
1220  EXPECT_TRUE(!server.handshakeVerify_);
1221  EXPECT_TRUE(!server.handshakeSuccess_);
1222  EXPECT_TRUE(server.handshakeError_);
1223  EXPECT_LE(0, server.handshakeTime.count());
1224 }
1225 
1232 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1233  EventBase eventBase;
1234  auto clientCtx = std::make_shared<SSLContext>();
1235  auto dfServerCtx = std::make_shared<SSLContext>();
1236 
1237  int fds[2];
1238  getfds(fds);
1239  getctx(clientCtx, dfServerCtx);
1240 
1241  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1242  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1243 
1244  AsyncSSLSocket::UniquePtr clientSock(
1245  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1246  AsyncSSLSocket::UniquePtr serverSock(
1247  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1248 
1249  SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1250  clientCtx->loadTrustedCertificates(kTestCA);
1251 
1252  SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1253 
1254  eventBase.loop();
1255 
1256  EXPECT_TRUE(!client.handshakeVerify_);
1258  EXPECT_TRUE(!client.handshakeError_);
1259  EXPECT_LE(0, client.handshakeTime.count());
1260  EXPECT_TRUE(!server.handshakeVerify_);
1262  EXPECT_TRUE(!server.handshakeError_);
1263  EXPECT_LE(0, server.handshakeTime.count());
1264 }
1265 
1271 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1272  EventBase eventBase;
1273  auto clientCtx = std::make_shared<SSLContext>();
1274  auto serverCtx = std::make_shared<SSLContext>();
1275  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1276  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1277  serverCtx->loadPrivateKey(kTestKey);
1278  serverCtx->loadCertificate(kTestCert);
1279  serverCtx->loadTrustedCertificates(kTestCA);
1280  serverCtx->loadClientCAList(kTestCA);
1281 
1282  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1283  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1284  clientCtx->loadPrivateKey(kTestKey);
1285  clientCtx->loadCertificate(kTestCert);
1286  clientCtx->loadTrustedCertificates(kTestCA);
1287 
1288  int fds[2];
1289  getfds(fds);
1290 
1291  AsyncSSLSocket::UniquePtr clientSock(
1292  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1293  AsyncSSLSocket::UniquePtr serverSock(
1294  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1295 
1296  SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1297  SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1298 
1299  eventBase.loop();
1300 
1301  EXPECT_TRUE(client.handshakeVerify_);
1303  EXPECT_FALSE(client.handshakeError_);
1304  EXPECT_LE(0, client.handshakeTime.count());
1305  EXPECT_TRUE(server.handshakeVerify_);
1307  EXPECT_FALSE(server.handshakeError_);
1308  EXPECT_LE(0, server.handshakeTime.count());
1309 }
1310 
1315 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1316  EventBase eventBase;
1317  auto clientCtx = std::make_shared<SSLContext>();
1318  auto dfServerCtx = std::make_shared<SSLContext>();
1319 
1320  int fds[2];
1321  getfds(fds);
1322  getctx(clientCtx, dfServerCtx);
1323 
1324  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1325  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1326 
1327  AsyncSSLSocket::UniquePtr clientSock(
1328  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1329  AsyncSSLSocket::UniquePtr serverSock(
1330  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1331 
1332  SSLHandshakeClient client(std::move(clientSock), false, true);
1333  SSLHandshakeServer server(std::move(serverSock), true, true);
1334 
1335  eventBase.loop();
1336 
1337  EXPECT_TRUE(client.handshakeVerify_);
1339  EXPECT_TRUE(!client.handshakeError_);
1340  EXPECT_LE(0, client.handshakeTime.count());
1341  EXPECT_TRUE(!server.handshakeVerify_);
1343  EXPECT_TRUE(!server.handshakeError_);
1344  EXPECT_LE(0, server.handshakeTime.count());
1345 }
1346 
1352 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1353  EventBase eventBase;
1354  auto clientCtx = std::make_shared<SSLContext>();
1355  auto dfServerCtx = std::make_shared<SSLContext>();
1356 
1357  int fds[2];
1358  getfds(fds);
1359  getctx(clientCtx, dfServerCtx);
1360 
1361  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1362  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1363 
1364  AsyncSSLSocket::UniquePtr clientSock(
1365  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1366  AsyncSSLSocket::UniquePtr serverSock(
1367  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1368 
1369  SSLHandshakeClient client(std::move(clientSock), false, false);
1370  SSLHandshakeServer server(std::move(serverSock), false, false);
1371 
1372  eventBase.loop();
1373 
1374  EXPECT_TRUE(!client.handshakeVerify_);
1376  EXPECT_TRUE(!client.handshakeError_);
1377  EXPECT_LE(0, client.handshakeTime.count());
1378  EXPECT_TRUE(!server.handshakeVerify_);
1380  EXPECT_TRUE(!server.handshakeError_);
1381  EXPECT_LE(0, server.handshakeTime.count());
1382 }
1383 
1387 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1388  EventBase eventBase;
1389  auto clientCtx = std::make_shared<SSLContext>();
1390  auto serverCtx = std::make_shared<SSLContext>();
1391  serverCtx->setVerificationOption(
1392  SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1393  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1394  serverCtx->loadPrivateKey(kTestKey);
1395  serverCtx->loadCertificate(kTestCert);
1396  serverCtx->loadTrustedCertificates(kTestCA);
1397  serverCtx->loadClientCAList(kTestCA);
1398 
1399  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1400  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1401  clientCtx->loadPrivateKey(kTestKey);
1402  clientCtx->loadCertificate(kTestCert);
1403  clientCtx->loadTrustedCertificates(kTestCA);
1404 
1405  int fds[2];
1406  getfds(fds);
1407 
1408  AsyncSSLSocket::UniquePtr clientSock(
1409  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1410  AsyncSSLSocket::UniquePtr serverSock(
1411  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1412 
1413  SSLHandshakeClient client(std::move(clientSock), true, true);
1414  SSLHandshakeServer server(std::move(serverSock), true, true);
1415 
1416  eventBase.loop();
1417 
1418  EXPECT_TRUE(client.handshakeVerify_);
1420  EXPECT_FALSE(client.handshakeError_);
1421  EXPECT_LE(0, client.handshakeTime.count());
1422  EXPECT_TRUE(server.handshakeVerify_);
1424  EXPECT_FALSE(server.handshakeError_);
1425  EXPECT_LE(0, server.handshakeTime.count());
1426 
1427  // check certificates
1428  auto clientSsl = std::move(client).moveSocket();
1429  auto serverSsl = std::move(server).moveSocket();
1430 
1431  auto clientPeer = clientSsl->getPeerCertificate();
1432  auto clientSelf = clientSsl->getSelfCertificate();
1433  auto serverPeer = serverSsl->getPeerCertificate();
1434  auto serverSelf = serverSsl->getSelfCertificate();
1435 
1436  EXPECT_NE(clientPeer, nullptr);
1437  EXPECT_NE(clientSelf, nullptr);
1438  EXPECT_NE(serverPeer, nullptr);
1439  EXPECT_NE(serverSelf, nullptr);
1440 
1441  EXPECT_EQ(clientPeer->getIdentity(), serverSelf->getIdentity());
1442  EXPECT_EQ(clientSelf->getIdentity(), serverPeer->getIdentity());
1443 }
1444 
1448 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1449  EventBase eventBase;
1450  auto clientCtx = std::make_shared<SSLContext>();
1451  auto serverCtx = std::make_shared<SSLContext>();
1452  serverCtx->setVerificationOption(
1453  SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1454  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1455  serverCtx->loadPrivateKey(kTestKey);
1456  serverCtx->loadCertificate(kTestCert);
1457  serverCtx->loadTrustedCertificates(kTestCA);
1458  serverCtx->loadClientCAList(kTestCA);
1459  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1460  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1461 
1462  int fds[2];
1463  getfds(fds);
1464 
1465  AsyncSSLSocket::UniquePtr clientSock(
1466  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1467  AsyncSSLSocket::UniquePtr serverSock(
1468  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1469 
1470  SSLHandshakeClient client(std::move(clientSock), false, false);
1471  SSLHandshakeServer server(std::move(serverSock), false, false);
1472 
1473  eventBase.loop();
1474 
1477  EXPECT_TRUE(server.handshakeError_);
1478  EXPECT_LE(0, client.handshakeTime.count());
1479  EXPECT_LE(0, server.handshakeTime.count());
1480 }
1481 
1485 #if FOLLY_OPENSSL_IS_110
1486 
// Creates a pipe in pipefds (read end [0], write end [1]) and puts both ends
// into non-blocking mode.
// Throws std::runtime_error if the pipe cannot be created or either end cannot
// be switched to non-blocking. On the latter failure both descriptors are
// closed first so they do not leak.
static void makeNonBlockingPipe(int pipefds[2]) {
  if (pipe(pipefds) != 0) {
    throw std::runtime_error("Cannot create pipe");
  }
  for (int i = 0; i < 2; ++i) {
    // Read-modify-write so any status flags already set are preserved.
    const int flags = ::fcntl(pipefds[i], F_GETFL);
    if (flags == -1 ||
        ::fcntl(pipefds[i], F_SETFL, flags | O_NONBLOCK) != 0) {
      ::close(pipefds[0]);
      ::close(pipefds[1]);
      throw std::runtime_error("Cannot set pipe to nonblocking");
    }
  }
}
1498 
1499 // Custom RSA private key encryption method
1500 static int kRSAExIndex = -1;
1501 static int kRSAEvbExIndex = -1;
1502 static int kRSASocketExIndex = -1;
1503 static constexpr StringPiece kEngineId = "AsyncSSLSocketTest";
1504 
// RSA priv_enc hook installed on the dummy RSA's method by setupCustomRSA.
//
// Reads from the dummy RSA's ex data the EventBase to run on
// (kRSAEvbExIndex), the real key (kRSAExIndex) and an optional AsyncSSLSocket
// (kRSASocketExIndex). It must be called from inside an OpenSSL ASYNC job and
// throws std::runtime_error otherwise.
//
// The read end of a fresh non-blocking pipe is registered as the job's wait
// fd under kEngineId. The real private-encrypt (RSA_PKCS1_OpenSSL's priv_enc
// with the real key) is posted to that EventBase, and the job is paused.
// Once the posted lambda has stored its result and written one byte (1 on
// success, 0 otherwise) to the pipe, the job is resumed and that result is
// returned. Presumably the byte wakes whoever polls the wait fd — confirm
// against AsyncSSLSocket's async handling.
1505 static int customRsaPrivEnc(
1506  int flen,
1507  const unsigned char* from,
1508  unsigned char* to,
1509  RSA* rsa,
1510  int padding) {
1511  LOG(INFO) << "rsa_priv_enc";
1512  EventBase* asyncJobEvb =
1513  reinterpret_cast<EventBase*>(RSA_get_ex_data(rsa, kRSAEvbExIndex));
1514  CHECK(asyncJobEvb);
1515 
1516  RSA* actualRSA = reinterpret_cast<RSA*>(RSA_get_ex_data(rsa, kRSAExIndex));
1517  CHECK(actualRSA);
1518 
// May be null; only OpenSSL110AsyncTestClosedWithCallbackPending sets it.
1519  AsyncSSLSocket* socket = reinterpret_cast<AsyncSSLSocket*>(
1520  RSA_get_ex_data(rsa, kRSASocketExIndex));
1521 
1522  ASYNC_JOB* job = ASYNC_get_current_job();
1523  if (job == nullptr) {
1524  throw std::runtime_error("Expected call in job context");
1525  }
1526  ASYNC_WAIT_CTX* waitctx = ASYNC_get_wait_ctx(job);
1527  OSSL_ASYNC_FD pipefds[2] = {0, 0};
1528  makeNonBlockingPipe(pipefds);
// NOTE(review): the read end is registered with no cleanup callback and is not
// closed in this function; looks like it may leak one fd per call — confirm.
1529  if (!ASYNC_WAIT_CTX_set_wait_fd(
1530  waitctx, kEngineId.data(), pipefds[0], nullptr, nullptr)) {
1531  throw std::runtime_error("Cannot set wait fd");
1532  }
// Result slot written by the posted lambda. Presumably it stays valid because
// this frame lives on the paused job's stack until ASYNC_pause_job() returns.
1533  int ret = 0;
1534  int* retptr = &ret;
1535 
// The writer (owning the pipe's write end) is moved into the lambda below.
1536  auto asyncPipeWriter =
1537  folly::AsyncPipeWriter::newWriter(asyncJobEvb, pipefds[1]);
1538 
1539  asyncJobEvb->runInEventBaseThread([retptr = retptr,
1540  flen = flen,
1541  from = from,
1542  to = to,
1543  padding = padding,
1544  actualRSA = actualRSA,
1545  writer = std::move(asyncPipeWriter),
1546  socket = socket]() {
1547  LOG(INFO) << "Running job";
// Simulates the socket being closed while the async operation is pending.
1548  if (socket) {
1549  LOG(INFO) << "Got a socket passed in, closing it...";
1550  socket->closeNow();
1551  }
// Do the real private-key operation with the built-in OpenSSL RSA method.
1552  *retptr = RSA_meth_get_priv_enc(RSA_PKCS1_OpenSSL())(
1553  flen, from, to, actualRSA, padding);
1554  LOG(INFO) << "Finished job, writing to pipe";
1555  uint8_t byte = *retptr > 0 ? 1 : 0;
1556  writer->write(nullptr, &byte, 1);
1557  });
1558 
1559  LOG(INFO) << "About to pause job";
1560 
// Yield back to the SSL caller; execution continues here when the job resumes.
1561  ASYNC_pause_job();
1562  LOG(INFO) << "Resumed job with ret: " << ret;
1563  return ret;
1564 }
1565 
1566 void rsaFree(void*, void* ptr, CRYPTO_EX_DATA*, int, long, void*) {
1567  LOG(INFO) << "RSA_free is called with ptr " << std::hex << ptr;
1568  if (ptr == nullptr) {
1569  LOG(INFO) << "Returning early from rsaFree because ptr is null";
1570  return;
1571  }
1572  RSA* rsa = (RSA*)ptr;
1573  auto meth = RSA_get_method(rsa);
1574  if (meth != RSA_get_default_method()) {
1575  auto nonconst = const_cast<RSA_METHOD*>(meth);
1576  RSA_meth_free(nonconst);
1577  RSA_set_method(rsa, RSA_get_default_method());
1578  }
1579  RSA_free(rsa);
1580 }
1581 
1582 struct RSAPointers {
1583  RSA* actualrsa{nullptr};
1584  RSA* dummyrsa{nullptr};
1585  RSA_METHOD* meth{nullptr};
1586 };
1587 
1588 inline void RSAPointersFree(RSAPointers* p) {
1589  if (p->meth && p->dummyrsa && RSA_get_method(p->dummyrsa) == p->meth) {
1590  RSA_set_method(p->dummyrsa, RSA_get_default_method());
1591  }
1592 
1593  if (p->meth) {
1594  LOG(INFO) << "Freeing meth";
1595  RSA_meth_free(p->meth);
1596  }
1597 
1598  if (p->actualrsa) {
1599  LOG(INFO) << "Freeing actualrsa";
1600  RSA_free(p->actualrsa);
1601  }
1602 
1603  if (p->dummyrsa) {
1604  LOG(INFO) << "Freeing dummyrsa";
1605  RSA_free(p->dummyrsa);
1606  }
1607 
1608  delete p;
1609 }
1610 
// Deleter type for the std::unique_ptr<RSAPointers, ...> returned by
// setupCustomRSA.
// NOTE(review): the right-hand side of this alias is missing from this
// listing; presumably it is a stateless deleter that calls RSAPointersFree —
// confirm against the original source.
1611 using RSAPointersDeleter =
1613 
1614 std::unique_ptr<RSAPointers, RSAPointersDeleter>
1615 setupCustomRSA(const char* certPath, const char* keyPath, EventBase* jobEvb) {
1616  auto certPEM = getFileAsBuf(certPath);
1617  auto keyPEM = getFileAsBuf(keyPath);
1618 
1619  ssl::BioUniquePtr certBio(
1620  BIO_new_mem_buf((void*)certPEM.data(), certPEM.size()));
1621  ssl::BioUniquePtr keyBio(
1622  BIO_new_mem_buf((void*)keyPEM.data(), keyPEM.size()));
1623 
1624  ssl::X509UniquePtr cert(
1625  PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1626  ssl::EvpPkeyUniquePtr evpPkey(
1627  PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1628  ssl::EvpPkeyUniquePtr publicEvpPkey(X509_get_pubkey(cert.get()));
1629 
1630  std::unique_ptr<RSAPointers, RSAPointersDeleter> ret(new RSAPointers());
1631 
1632  RSA* actualrsa = EVP_PKEY_get1_RSA(evpPkey.get());
1633  LOG(INFO) << "actualrsa ptr " << std::hex << (void*)actualrsa;
1634  RSA* dummyrsa = EVP_PKEY_get1_RSA(publicEvpPkey.get());
1635  if (dummyrsa == nullptr) {
1636  throw std::runtime_error("Couldn't get RSA cert public factors");
1637  }
1638  RSA_METHOD* meth = RSA_meth_dup(RSA_get_default_method());
1639  if (meth == nullptr || RSA_meth_set1_name(meth, "Async RSA method") == 0 ||
1640  RSA_meth_set_priv_enc(meth, customRsaPrivEnc) == 0 ||
1641  RSA_meth_set_flags(meth, RSA_METHOD_FLAG_NO_CHECK) == 0) {
1642  throw std::runtime_error("Cannot create async RSA_METHOD");
1643  }
1644  RSA_set_method(dummyrsa, meth);
1645  RSA_set_flags(dummyrsa, RSA_FLAG_EXT_PKEY);
1646 
1647  kRSAExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
1648  kRSAEvbExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
1649  kRSASocketExIndex =
1650  RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
1651  CHECK_NE(kRSAExIndex, -1);
1652  CHECK_NE(kRSAEvbExIndex, -1);
1653  CHECK_NE(kRSASocketExIndex, -1);
1654  RSA_set_ex_data(dummyrsa, kRSAExIndex, actualrsa);
1655  RSA_set_ex_data(dummyrsa, kRSAEvbExIndex, jobEvb);
1656 
1657  ret->actualrsa = actualrsa;
1658  ret->dummyrsa = dummyrsa;
1659  ret->meth = meth;
1660 
1661  return ret;
1662 }
1663 
1664 // TODO: disabled under ASAN, which doesn't play nice with ASYNC for some reason
1665 #ifndef FOLLY_SANITIZE_ADDRESS
1666 TEST(AsyncSSLSocketTest, OpenSSL110AsyncTest) {
1667  ASYNC_init_thread(1, 1);
1668  EventBase eventBase;
1669  ScopedEventBaseThread jobEvbThread;
1670  auto clientCtx = std::make_shared<SSLContext>();
1671  auto serverCtx = std::make_shared<SSLContext>();
1672  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1673  serverCtx->loadCertificate(kTestCert);
1674  serverCtx->loadTrustedCertificates(kTestCA);
1675  serverCtx->loadClientCAList(kTestCA);
1676 
1677  auto rsaPointers =
1678  setupCustomRSA(kTestCert, kTestKey, jobEvbThread.getEventBase());
1679  CHECK(rsaPointers->dummyrsa);
1680  // up-refs dummyrsa
1681  SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
1682  SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);
1683 
1684  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1685  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1686 
1687  int fds[2];
1688  getfds(fds);
1689 
1690  AsyncSSLSocket::UniquePtr clientSock(
1691  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1692  AsyncSSLSocket::UniquePtr serverSock(
1693  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1694 
1695  SSLHandshakeClient client(std::move(clientSock), false, false);
1696  SSLHandshakeServer server(std::move(serverSock), false, false);
1697 
1698  eventBase.loop();
1699 
1702  ASYNC_cleanup_thread();
1703 }
1704 
1705 TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) {
1706  ASYNC_init_thread(1, 1);
1707  EventBase eventBase;
1708  ScopedEventBaseThread jobEvbThread;
1709  auto clientCtx = std::make_shared<SSLContext>();
1710  auto serverCtx = std::make_shared<SSLContext>();
1711  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1712  serverCtx->loadCertificate(kTestCert);
1713  serverCtx->loadTrustedCertificates(kTestCA);
1714  serverCtx->loadClientCAList(kTestCA);
1715  // Set the wrong key for the cert
1716  auto rsaPointers =
1717  setupCustomRSA(kTestCert, kClientTestKey, jobEvbThread.getEventBase());
1718  CHECK(rsaPointers->dummyrsa);
1719  SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
1720  SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);
1721 
1722  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1723  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1724 
1725  int fds[2];
1726  getfds(fds);
1727 
1728  AsyncSSLSocket::UniquePtr clientSock(
1729  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1730  AsyncSSLSocket::UniquePtr serverSock(
1731  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1732 
1733  SSLHandshakeClient client(std::move(clientSock), false, false);
1734  SSLHandshakeServer server(std::move(serverSock), false, false);
1735 
1736  eventBase.loop();
1737 
1738  EXPECT_TRUE(server.handshakeError_);
1739  EXPECT_TRUE(client.handshakeError_);
1740  ASYNC_cleanup_thread();
1741 }
1742 
1743 TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestClosedWithCallbackPending) {
1744  ASYNC_init_thread(1, 1);
1745  EventBase eventBase;
1746  ScopedEventBaseThread jobEvbThread;
1747  auto clientCtx = std::make_shared<SSLContext>();
1748  auto serverCtx = std::make_shared<SSLContext>();
1749  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1750  serverCtx->loadCertificate(kTestCert);
1751  serverCtx->loadTrustedCertificates(kTestCA);
1752  serverCtx->loadClientCAList(kTestCA);
1753 
1754  auto rsaPointers =
1755  setupCustomRSA(kTestCert, kTestKey, jobEvbThread.getEventBase());
1756  CHECK(rsaPointers->dummyrsa);
1757  // up-refs dummyrsa
1758  SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
1759  SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);
1760 
1761  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1762  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1763 
1764  int fds[2];
1765  getfds(fds);
1766 
1767  AsyncSSLSocket::UniquePtr clientSock(
1768  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1769  AsyncSSLSocket::UniquePtr serverSock(
1770  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1771 
1772  RSA_set_ex_data(rsaPointers->dummyrsa, kRSASocketExIndex, serverSock.get());
1773 
1774  SSLHandshakeClient client(std::move(clientSock), false, false);
1775  SSLHandshakeServer server(std::move(serverSock), false, false);
1776 
1777  eventBase.loop();
1778 
1779  EXPECT_TRUE(server.handshakeError_);
1780  EXPECT_TRUE(client.handshakeError_);
1781  ASYNC_cleanup_thread();
1782 }
1783 #endif // FOLLY_SANITIZE_ADDRESS
1784 
1785 #endif // FOLLY_OPENSSL_IS_110
1786 
1787 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1789  auto cert = getFileAsBuf(kTestCert);
1790  auto key = getFileAsBuf(kTestKey);
1791 
1792  ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1793  BIO_write(certBio.get(), cert.data(), cert.size());
1794  ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1795  BIO_write(keyBio.get(), key.data(), key.size());
1796 
1797  // Create SSL structs from buffers to get properties
1798  ssl::X509UniquePtr certStruct(
1799  PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1800  ssl::EvpPkeyUniquePtr keyStruct(
1801  PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1802  certBio = nullptr;
1803  keyBio = nullptr;
1804 
1805  auto origCommonName = OpenSSLUtils::getCommonName(certStruct.get());
1806  auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1807  certStruct = nullptr;
1808  keyStruct = nullptr;
1809 
1810  auto ctx = std::make_shared<SSLContext>();
1811  ctx->loadPrivateKeyFromBufferPEM(key);
1812  ctx->loadCertificateFromBufferPEM(cert);
1813  ctx->loadTrustedCertificates(kTestCA);
1814 
1815  ssl::SSLUniquePtr ssl(ctx->createSSL());
1816 
1817  auto newCert = SSL_get_certificate(ssl.get());
1818  auto newKey = SSL_get_privatekey(ssl.get());
1819 
1820  // Get properties from SSL struct
1821  auto newCommonName = OpenSSLUtils::getCommonName(newCert);
1822  auto newKeySize = EVP_PKEY_bits(newKey);
1823 
1824  // Check that the key and cert have the expected properties
1825  EXPECT_EQ(origCommonName, newCommonName);
1826  EXPECT_EQ(origKeySize, newKeySize);
1827 }
1828 
1829 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1830  EventBase eb;
1831 
1832  // Set up SSL context.
1833  auto sslContext = std::make_shared<SSLContext>();
1834  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1835 
1836  // create SSL socket
1837  AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1838 
1839  EXPECT_EQ(1500, socket->getMinWriteSize());
1840 
1841  socket->setMinWriteSize(0);
1842  EXPECT_EQ(0, socket->getMinWriteSize());
1843  socket->setMinWriteSize(50000);
1844  EXPECT_EQ(50000, socket->getMinWriteSize());
1845 }
1846 
// NOTE(review): the class header, constructor signature and private member
// declaration of this ReadCallback subclass are missing from this listing.
// Presumably this is ReadCallbackTerminator (constructed as (base, wcb) in
// UnencryptedTest) — confirm against the original source.
1848  public:
1850  : ReadCallback(wcb), base_(base) {}
1851 
1852  // Do not write data back, terminate the loop.
// Records the just-filled buffer (with length len) in `buffers`, resets
// currentBuffer, detaches this callback from socket_, and asks base_ to stop
// its loop.
1853  void readDataAvailable(size_t len) noexcept override {
1854  std::cerr << "readDataAvailable, len " << len << std::endl;
1855 
1856  currentBuffer.length = len;
1857 
1858  buffers.push_back(currentBuffer);
1859  currentBuffer.reset();
1861 
1862  socket_->setReadCB(nullptr);
1863  base_->terminateLoopSoon();
1864  }
1865 
1866  private:
1868 };
1869 
1873 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1874  EventBase base;
1875 
1876  auto clientCtx = std::make_shared<folly::SSLContext>();
1877  auto serverCtx = std::make_shared<folly::SSLContext>();
1878  int fds[2];
1879  getfds(fds);
1880  getctx(clientCtx, serverCtx);
1881  auto client =
1882  AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
1883  auto server = AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);
1884 
1885  ReadCallbackTerminator readCallback(&base, nullptr);
1886  server->setReadCB(&readCallback);
1887  readCallback.setSocket(server);
1888 
1889  uint8_t buf[128];
1890  memset(buf, 'a', sizeof(buf));
1891  client->write(nullptr, buf, sizeof(buf));
1892 
1893  // Check that bytes are unencrypted
1894  char c;
1895  EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1896  EXPECT_EQ('a', c);
1897 
1898  EventBaseAborter eba(&base, 3000);
1899  base.loop();
1900 
1901  EXPECT_EQ(1, readCallback.buffers.size());
1902  EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1903 
1904  server->setReadCB(&readCallback);
1905 
1906  // Unencrypted
1907  server->sslAccept(nullptr);
1908  client->sslConn(nullptr);
1909 
1910  // Do NOT wait for handshake, writing should be queued and happen after
1911 
1912  client->write(nullptr, buf, sizeof(buf));
1913 
1914  // Check that bytes are *not* unencrypted
1915  char c2;
1916  EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1917  EXPECT_NE('a', c2);
1918 
1919  base.loop();
1920 
1921  EXPECT_EQ(2, readCallback.buffers.size());
1922  EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1923 }
1924 
1925 TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
1926  auto clientCtx = std::make_shared<folly::SSLContext>();
1927  auto serverCtx = std::make_shared<folly::SSLContext>();
1928  getctx(clientCtx, serverCtx);
1929 
1930  WriteCallbackBase writeCallback;
1931  ReadCallback readCallback(&writeCallback);
1932  HandshakeCallback handshakeCallback(&readCallback);
1933  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1934  TestSSLServer server(&acceptCallback);
1935 
1936  EventBase evb;
1937  std::shared_ptr<AsyncSSLSocket> socket =
1938  AsyncSSLSocket::newSocket(clientCtx, &evb, true);
1939  socket->connect(nullptr, server.getAddress(), 0);
1940 
1941  evb.loop();
1942 
1943  EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState());
1944  socket->sslConn(nullptr);
1945  evb.loop();
1946  EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState());
1947 
1948  // write()
1949  std::array<uint8_t, 128> buf;
1950  memset(buf.data(), 'a', buf.size());
1951  socket->write(nullptr, buf.data(), buf.size());
1952 
1953  socket->close();
1954 }
1955 
1959 TEST(AsyncSSLSocketTest, SSLAcceptRunnerBasic) {
1960  EventBase eventBase;
1961  auto clientCtx = std::make_shared<SSLContext>();
1962  auto serverCtx = std::make_shared<SSLContext>();
1963  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1964  serverCtx->loadPrivateKey(kTestKey);
1965  serverCtx->loadCertificate(kTestCert);
1966 
1967  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1968  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1969  clientCtx->loadTrustedCertificates(kTestCA);
1970 
1971  int fds[2];
1972  getfds(fds);
1973 
1974  AsyncSSLSocket::UniquePtr clientSock(
1975  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1976  AsyncSSLSocket::UniquePtr serverSock(
1977  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1978 
1979  serverCtx->sslAcceptRunner(std::make_unique<SSLAcceptEvbRunner>(&eventBase));
1980 
1981  SSLHandshakeClient client(std::move(clientSock), true, true);
1982  SSLHandshakeServer server(std::move(serverSock), true, true);
1983 
1984  eventBase.loop();
1985 
1987  EXPECT_FALSE(client.handshakeError_);
1988  EXPECT_LE(0, client.handshakeTime.count());
1990  EXPECT_FALSE(server.handshakeError_);
1991  EXPECT_LE(0, server.handshakeTime.count());
1992 }
1993 
1994 TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptError) {
1995  EventBase eventBase;
1996  auto clientCtx = std::make_shared<SSLContext>();
1997  auto serverCtx = std::make_shared<SSLContext>();
1998  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1999  serverCtx->loadPrivateKey(kTestKey);
2000  serverCtx->loadCertificate(kTestCert);
2001 
2002  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
2003  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2004  clientCtx->loadTrustedCertificates(kTestCA);
2005 
2006  int fds[2];
2007  getfds(fds);
2008 
2009  AsyncSSLSocket::UniquePtr clientSock(
2010  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2011  AsyncSSLSocket::UniquePtr serverSock(
2012  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
2013 
2014  serverCtx->sslAcceptRunner(
2015  std::make_unique<SSLAcceptErrorRunner>(&eventBase));
2016 
2017  SSLHandshakeClient client(std::move(clientSock), true, true);
2018  SSLHandshakeServer server(std::move(serverSock), true, true);
2019 
2020  eventBase.loop();
2021 
2023  EXPECT_TRUE(client.handshakeError_);
2025  EXPECT_TRUE(server.handshakeError_);
2026 }
2027 
2028 TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptClose) {
2029  EventBase eventBase;
2030  auto clientCtx = std::make_shared<SSLContext>();
2031  auto serverCtx = std::make_shared<SSLContext>();
2032  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2033  serverCtx->loadPrivateKey(kTestKey);
2034  serverCtx->loadCertificate(kTestCert);
2035 
2036  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
2037  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2038  clientCtx->loadTrustedCertificates(kTestCA);
2039 
2040  int fds[2];
2041  getfds(fds);
2042 
2043  AsyncSSLSocket::UniquePtr clientSock(
2044  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2045  AsyncSSLSocket::UniquePtr serverSock(
2046  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
2047 
2048  serverCtx->sslAcceptRunner(
2049  std::make_unique<SSLAcceptCloseRunner>(&eventBase, serverSock.get()));
2050 
2051  SSLHandshakeClient client(std::move(clientSock), true, true);
2052  SSLHandshakeServer server(std::move(serverSock), true, true);
2053 
2054  eventBase.loop();
2055 
2057  EXPECT_TRUE(client.handshakeError_);
2059  EXPECT_TRUE(server.handshakeError_);
2060 }
2061 
2062 TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) {
2063  EventBase eventBase;
2064  auto clientCtx = std::make_shared<SSLContext>();
2065  auto serverCtx = std::make_shared<SSLContext>();
2066  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2067  serverCtx->loadPrivateKey(kTestKey);
2068  serverCtx->loadCertificate(kTestCert);
2069 
2070  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
2071  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2072  clientCtx->loadTrustedCertificates(kTestCA);
2073 
2074  int fds[2];
2075  getfds(fds);
2076 
2077  AsyncSSLSocket::UniquePtr clientSock(
2078  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2079  AsyncSSLSocket::UniquePtr serverSock(
2080  new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
2081 
2082  SSLHandshakeClient client(std::move(clientSock), true, true);
2083  SSLHandshakeServer server(std::move(serverSock), true, true);
2084 
2085  serverCtx->sslAcceptRunner(
2086  std::make_unique<SSLAcceptDestroyRunner>(&eventBase, &server));
2087 
2088  eventBase.loop();
2089 
2091  EXPECT_TRUE(client.handshakeError_);
2093  EXPECT_TRUE(server.handshakeError_);
2094 }
2095 
2096 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
2097  // Start listening on a local port
2098  WriteCallbackBase writeCallback;
2099  WriteErrorCallback readCallback(&writeCallback);
2100  HandshakeCallback handshakeCallback(
2101  &readCallback, HandshakeCallback::EXPECT_ERROR);
2102  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2103  TestSSLServer server(&acceptCallback);
2104 
2105  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2106  socket->open();
2107  uint8_t buf[3] = {0x16, 0x03, 0x01};
2108  socket->write(buf, sizeof(buf));
2109  socket->closeWithReset();
2110 
2111  handshakeCallback.waitForHandshake();
2112  EXPECT_NE(
2113  handshakeCallback.errorString_.find("Network error"), std::string::npos);
2114  EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
2115 }
2116 
2117 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
2118  // Start listening on a local port
2119  WriteCallbackBase writeCallback;
2120  WriteErrorCallback readCallback(&writeCallback);
2121  HandshakeCallback handshakeCallback(
2122  &readCallback, HandshakeCallback::EXPECT_ERROR);
2123  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2124  TestSSLServer server(&acceptCallback);
2125 
2126  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2127  socket->open();
2128  uint8_t buf[3] = {0x16, 0x03, 0x01};
2129  socket->write(buf, sizeof(buf));
2130  socket->close();
2131 
2132  handshakeCallback.waitForHandshake();
2133 #if FOLLY_OPENSSL_IS_110
2134  EXPECT_NE(
2135  handshakeCallback.errorString_.find("Network error"), std::string::npos);
2136 #else
2137  EXPECT_NE(
2138  handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
2139 #endif
2140 }
2141 
2142 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
2143  // Start listening on a local port
2144  WriteCallbackBase writeCallback;
2145  WriteErrorCallback readCallback(&writeCallback);
2146  HandshakeCallback handshakeCallback(
2147  &readCallback, HandshakeCallback::EXPECT_ERROR);
2148  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2149  TestSSLServer server(&acceptCallback);
2150 
2151  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2152  socket->open();
2153  uint8_t buf[256] = {0x16, 0x03};
2154  memset(buf + 2, 'a', sizeof(buf) - 2);
2155  socket->write(buf, sizeof(buf));
2156  socket->close();
2157 
2158  handshakeCallback.waitForHandshake();
2159  EXPECT_NE(
2160  handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
2161 #if defined(OPENSSL_IS_BORINGSSL)
2162  EXPECT_NE(
2163  handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
2164  std::string::npos);
2165 #elif FOLLY_OPENSSL_IS_110
2166  EXPECT_NE(
2167  handshakeCallback.errorString_.find("packet length too long"),
2168  std::string::npos);
2169 #else
2170  EXPECT_NE(
2171  handshakeCallback.errorString_.find("unknown protocol"),
2172  std::string::npos);
2173 #endif
2174 }
2175 
2176 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
2178  EXPECT_EQ(
2179  OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
2180  // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
2181  EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
2182  // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
2183  EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
2184 }
2185 
2186 #if defined __linux__
2187 
2190 TEST(AsyncSSLSocketTest, TTLSDisabled) {
2191  // clear all setsockopt tracking history
2192  globalStatic.reset();
2193 
2194  // Start listening on a local port
2195  WriteCallbackBase writeCallback;
2196  ReadCallback readCallback(&writeCallback);
2197  HandshakeCallback handshakeCallback(&readCallback);
2198  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2199  TestSSLServer server(&acceptCallback, false);
2200 
2201  // Set up SSL context.
2202  auto sslContext = std::make_shared<SSLContext>();
2203 
2204  // connect
2205  auto socket =
2206  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2207  socket->open();
2208 
2209  EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getSocketFD()));
2210 
2211  // write()
2212  std::array<uint8_t, 128> buf;
2213  memset(buf.data(), 'a', buf.size());
2214  socket->write(buf.data(), buf.size());
2215 
2216  // close()
2217  socket->close();
2218 }
2219 #endif
2220 
2221 #if FOLLY_ALLOW_TFO
2222 
2223 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
2224  public:
2225  using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
2226 
2227  explicit MockAsyncTFOSSLSocket(
2228  std::shared_ptr<folly::SSLContext> sslCtx,
2229  EventBase* evb)
2230  : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
2231 
2232  MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
2233 };
2234 
2235 #if defined __linux__
2236 
2239 TEST(AsyncSSLSocketTest, TTLSDisabledWithTFO) {
2240  // clear all setsockopt tracking history
2241  globalStatic.reset();
2242 
2243  // Start listening on a local port
2244  WriteCallbackBase writeCallback;
2245  ReadCallback readCallback(&writeCallback);
2246  HandshakeCallback handshakeCallback(&readCallback);
2247  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2248  TestSSLServer server(&acceptCallback, true);
2249 
2250  // Set up SSL context.
2251  auto sslContext = std::make_shared<SSLContext>();
2252 
2253  // connect
2254  auto socket =
2255  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2256  socket->enableTFO();
2257  socket->open();
2258 
2259  EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getSocketFD()));
2260 
2261  // write()
2262  std::array<uint8_t, 128> buf;
2263  memset(buf.data(), 'a', buf.size());
2264  socket->write(buf.data(), buf.size());
2265 
2266  // close()
2267  socket->close();
2268 }
2269 #endif
2270 
2275 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
2276  // Start listening on a local port
2277  WriteCallbackBase writeCallback;
2278  ReadCallback readCallback(&writeCallback);
2279  HandshakeCallback handshakeCallback(&readCallback);
2280  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2281  TestSSLServer server(&acceptCallback, true);
2282 
2283  // Set up SSL context.
2284  auto sslContext = std::make_shared<SSLContext>();
2285 
2286  // connect
2287  auto socket =
2288  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2289  socket->enableTFO();
2290  socket->open();
2291 
2292  // write()
2293  std::array<uint8_t, 128> buf;
2294  memset(buf.data(), 'a', buf.size());
2295  socket->write(buf.data(), buf.size());
2296 
2297  // read()
2298  std::array<uint8_t, 128> readbuf;
2299  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2300  EXPECT_EQ(bytesRead, 128);
2301  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2302 
2303  // close()
2304  socket->close();
2305 }
2306 
2311 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
2312  // Start listening on a local port
2313  WriteCallbackBase writeCallback;
2314  ReadCallback readCallback(&writeCallback);
2315  HandshakeCallback handshakeCallback(&readCallback);
2316  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2317  TestSSLServer server(&acceptCallback, false);
2318 
2319  // Set up SSL context.
2320  auto sslContext = std::make_shared<SSLContext>();
2321 
2322  // connect
2323  auto socket =
2324  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2325  socket->enableTFO();
2326  socket->open();
2327 
2328  // write()
2329  std::array<uint8_t, 128> buf;
2330  memset(buf.data(), 'a', buf.size());
2331  socket->write(buf.data(), buf.size());
2332 
2333  // read()
2334  std::array<uint8_t, 128> readbuf;
2335  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2336  EXPECT_EQ(bytesRead, 128);
2337  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2338 
2339  // close()
2340  socket->close();
2341 }
2342 
2344  public:
2345  void connectSuccess() noexcept override {
2346  state = State::SUCCESS;
2347  }
2348 
2349  void connectErr(const AsyncSocketException& ex) noexcept override {
2350  state = State::ERROR;
2351  error = ex.what();
2352  }
2353 
2354  enum class State { WAITING, SUCCESS, ERROR };
2355 
2356  State state{State::WAITING};
2358 };
2359 
2360 template <class Cardinality>
2361 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
2362  EventBase* evb,
2363  const SocketAddress& address,
2364  Cardinality cardinality) {
2365  // Set up SSL context.
2366  auto sslContext = std::make_shared<SSLContext>();
2367 
2368  // connect
2369  auto socket = MockAsyncTFOSSLSocket::UniquePtr(
2370  new MockAsyncTFOSSLSocket(sslContext, evb));
2371  socket->enableTFO();
2372 
2373  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2374  .Times(cardinality)
2375  .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2376  sockaddr_storage addr;
2377  auto len = address.getAddress(&addr);
2378  return connect(fd, (const struct sockaddr*)&addr, len);
2379  }));
2380  return socket;
2381 }
2382 
2383 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
2384  // Start listening on a local port
2385  WriteCallbackBase writeCallback;
2386  ReadCallback readCallback(&writeCallback);
2387  HandshakeCallback handshakeCallback(&readCallback);
2388  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2389  TestSSLServer server(&acceptCallback, true);
2390 
2391  EventBase evb;
2392 
2393  auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
2394  ConnCallback ccb;
2395  socket->connect(&ccb, server.getAddress(), 30);
2396 
2397  evb.loop();
2398  EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
2399 
2400  evb.runInEventBaseThread([&] { socket->detachEventBase(); });
2401  evb.loop();
2402 
2404  // write()
2405  std::array<uint8_t, 128> buf;
2406  memset(buf.data(), 'a', buf.size());
2407  sock.write(buf.data(), buf.size());
2408 
2409  // read()
2410  std::array<uint8_t, 128> readbuf;
2411  uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
2412  EXPECT_EQ(bytesRead, 128);
2413  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2414 
2415  // close()
2416  sock.close();
2417 }
2418 
2419 #if !defined(OPENSSL_IS_BORINGSSL)
2420 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
2421  // Start listening on a local port
2422  ConnectTimeoutCallback acceptCallback;
2423  TestSSLServer server(&acceptCallback, true);
2424 
2425  // Set up SSL context.
2426  auto sslContext = std::make_shared<SSLContext>();
2427 
2428  // connect
2429  auto socket =
2430  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2431  socket->enableTFO();
2432  EXPECT_THROW(
2433  socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
2434 }
2435 #endif
2436 
2437 #if !defined(OPENSSL_IS_BORINGSSL)
2438 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
2439  // Start listening on a local port
2440  ConnectTimeoutCallback acceptCallback;
2441  TestSSLServer server(&acceptCallback, true);
2442 
2443  EventBase evb;
2444 
2445  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2446  ConnCallback ccb;
2447  // Set a short timeout
2448  socket->connect(&ccb, server.getAddress(), 1);
2449 
2450  evb.loop();
2451  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2452 }
2453 #endif
2454 
2455 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
2456  // Start listening on a local port
2457  EmptyReadCallback readCallback;
2458  HandshakeCallback handshakeCallback(
2459  &readCallback, HandshakeCallback::EXPECT_ERROR);
2460  HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
2461  TestSSLServer server(&acceptCallback, true);
2462 
2463  EventBase evb;
2464 
2465  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2466  ConnCallback ccb;
2467  socket->connect(&ccb, server.getAddress(), 100);
2468 
2469  evb.loop();
2470  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2471  EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
2472 }
2473 
2474 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
2475  // Start listening on a local port
2476  EventBase evb;
2477 
2478  // Hopefully nothing is listening on this address
2479  SocketAddress addr("127.0.0.1", 65535);
2480  auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
2481  ConnCallback ccb;
2482  socket->connect(&ccb, addr, 100);
2483 
2484  evb.loop();
2485  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2486  EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
2487 }
2488 
2489 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
2490  EventBase eventBase;
2491  auto clientCtx = std::make_shared<SSLContext>();
2492  auto dfServerCtx = std::make_shared<SSLContext>();
2493  std::array<int, 2> fds;
2494  getfds(fds.data());
2495  getctx(clientCtx, dfServerCtx);
2496 
2497  AsyncSSLSocket::UniquePtr clientSockPtr(
2498  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2499  AsyncSSLSocket::UniquePtr serverSockPtr(
2500  new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
2501  auto clientSock = clientSockPtr.get();
2502  auto serverSock = serverSockPtr.get();
2503  SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2504 
2505  // Steal some data from the server.
2506  std::array<uint8_t, 10> buf;
2507  auto bytesReceived = recv(fds[1], buf.data(), buf.size(), 0);
2508  checkUnixError(bytesReceived, "recv failed");
2509 
2510  serverSock->setPreReceivedData(
2511  IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived)));
2512  SSLHandshakeServer server(std::move(serverSockPtr), true, true);
2513  while (!client.handshakeSuccess_ && !client.handshakeError_) {
2514  eventBase.loopOnce();
2515  }
2516 
2519  EXPECT_EQ(
2520  serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2521 }
2522 
2523 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
2524  EventBase eventBase;
2525  auto clientCtx = std::make_shared<SSLContext>();
2526  auto dfServerCtx = std::make_shared<SSLContext>();
2527  std::array<int, 2> fds;
2528  getfds(fds.data());
2529  getctx(clientCtx, dfServerCtx);
2530 
2531  AsyncSSLSocket::UniquePtr clientSockPtr(
2532  new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2533  AsyncSocket::UniquePtr serverSockPtr(new AsyncSocket(&eventBase, fds[1]));
2534  auto clientSock = clientSockPtr.get();
2535  auto serverSock = serverSockPtr.get();
2536  SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2537 
2538  // Steal some data from the server.
2539  std::array<uint8_t, 10> buf;
2540  auto bytesReceived = recv(fds[1], buf.data(), buf.size(), 0);
2541  checkUnixError(bytesReceived, "recv failed");
2542 
2543  serverSock->setPreReceivedData(
2544  IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived)));
2545  AsyncSSLSocket::UniquePtr serverSSLSockPtr(
2546  new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
2547  auto serverSSLSock = serverSSLSockPtr.get();
2548  SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
2549  while (!client.handshakeSuccess_ && !client.handshakeError_) {
2550  eventBase.loopOnce();
2551  }
2552 
2555  EXPECT_EQ(
2556  serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2557 }
2558 
2563 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
2564  // Start listening on a local port
2565  SendMsgFlagsCallback msgCallback;
2566  ExpectWriteErrorCallback writeCallback(&msgCallback);
2567  ReadCallback readCallback(&writeCallback);
2568  HandshakeCallback handshakeCallback(&readCallback);
2569  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2570  TestSSLServer server(&acceptCallback);
2571 
2572  // Set up SSL context.
2573  auto sslContext = std::make_shared<SSLContext>();
2574  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2575 
2576  // connect
2577  auto socket =
2578  std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2579  socket->open();
2580 
2581  // Setting flags to "-1" to trigger "Invalid argument" error
2582  // on attempt to use this flags in sendmsg() system call.
2583  msgCallback.resetFlags(-1);
2584 
2585  // write()
2586  std::vector<uint8_t> buf(128, 'a');
2587  ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
2588 
2589  // close()
2590  socket->close();
2591 
2592  cerr << "SendMsgParamsCallback test completed" << endl;
2593 }
2594 
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
// Attaches an SO_TIMESTAMPING control message (via SendMsgDataCallback) to
// the server's echo write. It then checks the echoed payload and the
// timestamp notifications collected by WriteCheckTimestampCallback.
TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
  // Skip on kernels older than 4.6; the test requires 4.6 or later.
  struct utsname unameInfo;
  memset(&unameInfo, 0, sizeof(unameInfo));
  ASSERT_EQ(uname(&unameInfo), 0);
  int kernelMajor, kernelMinor;
  folly::StringPiece rest;
  const bool parsed = folly::split<false>(
      '.',
      std::string(unameInfo.release) + ".",
      kernelMajor,
      kernelMinor,
      rest);
  if (parsed &&
      (kernelMajor < 4 || (kernelMajor == 4 && kernelMinor < 6))) {
    LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
              << "kernel ver. " << unameInfo.release << " detected).";
    return;
  }

  // Local echo server whose writes go through the data callback.
  SendMsgDataCallback msgCb;
  WriteCheckTimestampCallback writeCb(&msgCb);
  ReadCallback readCb(&writeCb);
  HandshakeCallback handshakeCb(&readCb);
  SSLServerAcceptCallback acceptCb(&handshakeCb);
  TestSSLServer server(&acceptCb);

  auto sslContext = std::make_shared<SSLContext>();
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  auto client =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  client->open();

  // MSG_EOR asks for a timestamp on the last byte of the message.
  msgCb.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR);

  // Build one SOL_SOCKET / SO_TIMESTAMPING control message carrying the
  // requested timestamp kinds. Copy its bytes into the callback's buffer.
  constexpr size_t kCtrlLen = CMSG_LEN(sizeof(uint32_t));
  union {
    uint8_t bytes[kCtrlLen];
    struct cmsghdr hdr;
  } ctrlMsg;
  ctrlMsg.hdr.cmsg_level = SOL_SOCKET;
  ctrlMsg.hdr.cmsg_type = SO_TIMESTAMPING;
  ctrlMsg.hdr.cmsg_len = kCtrlLen;
  const uint32_t tsFlags = SOF_TIMESTAMPING_TX_SCHED |
      SOF_TIMESTAMPING_TX_SOFTWARE | SOF_TIMESTAMPING_TX_ACK;
  memcpy(CMSG_DATA(&ctrlMsg.hdr), &tsFlags, sizeof(uint32_t));
  std::vector<char> ctrl(kCtrlLen);
  memcpy(ctrl.data(), ctrlMsg.bytes, kCtrlLen);
  msgCb.resetData(std::move(ctrl));

  std::vector<uint8_t> payload(128, 'a');
  client->write(payload.data(), payload.size());

  std::vector<uint8_t> echoed(payload.size());
  uint32_t bytesRead = client->readAll(echoed.data(), echoed.size());
  EXPECT_EQ(bytesRead, payload.size());
  EXPECT_TRUE(std::equal(payload.begin(), payload.end(), echoed.begin()));

  writeCb.checkForTimestampNotifications();

  client->close();

  cerr << "SendMsgDataCallback test completed" << endl;
}
#endif // FOLLY_HAVE_MSG_ERRQUEUE
2670 
2671 #endif
2672 
2673 } // namespace folly
2674 
#ifdef SIGPIPE
// Ignore SIGPIPE process-wide during static initialization. Writes to sockets
// whose peer has already closed then fail with an error instead of
// terminating the test binary.
namespace {
struct SigpipeIgnorer {
  SigpipeIgnorer() {
    signal(SIGPIPE, SIG_IGN);
  }
};
SigpipeIgnorer sigpipeIgnorer;
} // namespace
#endif
#define EXPECT_LE(val1, val2)
Definition: gtest.h:1928
void * ptr
std::vector< uint8_t > buffer(kBufferSize+16)
void resetData(std::vector< char > &&data)
std::string getFileAsBuf(const char *fileName)
const char * kTestCert
bool readFile(int fd, Container &out, size_t num_bytes=std::numeric_limits< size_t >::max())
Definition: FileUtil.h:125
flags
Definition: http_parser.h:127
bool clientProtoFilterPickNone(unsigned char **, unsigned int *, const unsigned char *, unsigned int)
#define FAIL()
Definition: gtest.h:1822
int connect(NetworkSocket s, const sockaddr *name, socklen_t namelen)
Definition: NetOps.cpp:94
#define EXPECT_THROW(statement, expected_exception)
Definition: gtest.h:1843
#define ASSERT_EQ(val1, val2)
Definition: gtest.h:1956
void readDataAvailable(size_t len) noexceptoverride
std::unique_ptr< X509, X509Deleter > X509UniquePtr
GTEST_API_ Cardinality AtMost(int n)
std::unique_ptr< BIO, BioDeleter > BioUniquePtr
int setsockopt(NetworkSocket s, int level, int optname, const void *optval, socklen_t optlen)
Definition: NetOps.cpp:384
TEST_F(TestInfoTest, Names)
const std::string kTestKey
std::enable_if< std::is_arithmetic< T >::value >::type write(T value)
Definition: Cursor.h:737
#define EXPECT_EQ(val1, val2)
Definition: gtest.h:1922
const char * kClientTestKey
constexpr detail::Map< Move > move
Definition: Base-inl.h:2567
EventBase & getEventBase()
std::vector< Buffer > buffers
#define SCOPE_EXIT
Definition: ScopeGuard.h:274
void sslsocketpair(EventBase *eventBase, AsyncSSLSocket::UniquePtr *clientSock, AsyncSSLSocket::UniquePtr *serverSock)
std::unique_ptr< EVP_PKEY, EvpPkeyDeleter > EvpPkeyUniquePtr
—— Concurrent Priority Queue Implementation ——
Definition: AtomicBitSet.h:29
requires E e noexcept(noexcept(s.error(std::move(e))))
const char * kClientTestCert
requires And< SemiMovable< VN >... > &&SemiMovable< E > auto error(E e)
Definition: error.h:48
#define MOCK_METHOD3(m,...)
std::unique_ptr< AsyncSSLSocket, Destructor > UniquePtr
State
See Core for details.
Definition: Core.h:43
LogLevel min
Definition: LogLevel.cpp:30
StateEnum state
folly::Synchronized< EventBase * > base_
const char * kTestCA
bool loopOnce(int flags=0)
Definition: EventBase.cpp:271
PolymorphicMatcher< internal::HasSubstrMatcher< internal::string > > HasSubstr(const internal::string &substring)
PolymorphicAction< internal::InvokeAction< FunctionImpl > > Invoke(FunctionImpl function_impl)
std::enable_if< detail::is_chrono_conversion< Tgt, Src >::value, Tgt >::type to(const Src &value)
Definition: Conv.h:677
TEST(GTestEnvVarTest, Dummy)
void terminateLoopSoon()
Definition: EventBase.cpp:493
AsyncServerSocket::UniquePtr socket_
Encoder::MutableCompressedList list
int32_t readAll(uint8_t *buf, size_t len)
bool runInEventBaseThread(void(*fn)(T *), T *arg)
Definition: EventBase.h:794
bool clientProtoFilterPickPony(unsigned char **client, unsigned int *client_len, const unsigned char *, unsigned int)
void getfds(int fds[2])
NetworkSocket socket(int af, int type, int protocol)
Definition: NetOps.cpp:412
socklen_t getAddress(sockaddr_storage *addr) const
PUSHMI_INLINE_VAR constexpr struct folly::pushmi::operators::from_fn from
fbstring errnoStr(int err)
Definition: String.cpp:463
void checkUnixError(ssize_t ret, Args &&...args)
Definition: Exception.h:101
#define EXPECT_TRUE(condition)
Definition: gtest.h:1859
std::chrono::nanoseconds handshakeTime
#define EXPECT_THAT(value, matcher)
std::shared_ptr< AsyncSSLSocket > getSocket()
const char * string
Definition: Conv.cpp:212
Range< const unsigned char * > ByteRange
Definition: Range.h:1163
void getctx(std::shared_ptr< folly::SSLContext > clientCtx, std::shared_ptr< folly::SSLContext > serverCtx)
#define EXPECT_NE(val1, val2)
Definition: gtest.h:1926
static UniquePtr newWriter(Args &&...args)
Definition: AsyncPipe.h:104
#define EXPECT_CALL(obj, call)
const internal::AnythingMatcher _
#define EXPECT_FALSE(condition)
Definition: gtest.h:1862
ssize_t recv(NetworkSocket s, void *buf, size_t len, int flags)
Definition: NetOps.cpp:180
const char * kClientTestCA
static std::unique_ptr< IOBuf > copyBuffer(const void *buf, std::size_t size, std::size_t headroom=0, std::size_t minTailroom=0)
Definition: IOBuf.h:1587
std::shared_ptr< AsyncSocket > tcpSocket_
ReadCallbackTerminator(EventBase *base, WriteCallbackBase *wcb)
char c
#define ASSERT_TRUE(condition)
Definition: gtest.h:1865
ThreadPoolListHook * addr
std::unique_ptr< AsyncSocket, Destructor > UniquePtr
Definition: AsyncSocket.h:83
const SocketAddress & getAddress() const
void pipe(CPUExecutor cpu, IOExecutor io)
state
Definition: http_parser.c:272
int32_t write(uint8_t const *buf, size_t len)
void closeNow() override
std::unique_ptr< SSL, SSLDeleter > SSLUniquePtr
void setSocket(const std::shared_ptr< AsyncSSLSocket > &socket)
int socketpair(int domain, int type, int protocol, NetworkSocket sv[2])
Definition: NetOps.cpp:416