proxygen
AsyncSSLSocketTest2.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2012-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
17 
18 #include <folly/futures/Promise.h>
19 #include <folly/init/Init.h>
26 #include <folly/ssl/Init.h>
27 
28 using std::cerr;
29 using std::endl;
30 using std::list;
31 using std::min;
32 using std::string;
33 using std::vector;
34 
35 namespace folly {
36 
37 struct EvbAndContext {
39  ctx_.reset(new SSLContext());
40  ctx_->setOptions(SSL_OP_NO_TICKET);
41  ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
42  }
43 
44  std::shared_ptr<AsyncSSLSocket> createSocket() {
46  }
47 
49  return evb_.getEventBase();
50  }
51 
53  socket.attachEventBase(getEventBase());
54  socket.attachSSLContext(ctx_);
55  }
56 
58  std::shared_ptr<SSLContext> ctx_;
59 };
60 
64  private:
65  // two threads here - we'll create the socket in one, connect
66  // in the other, and then read/write in the initial one
69  std::shared_ptr<AsyncSSLSocket> sslSocket_;
71  char buf_[128];
72  char readbuf_[128];
74  // promise to fulfill when done
76 
77  void detach() {
78  sslSocket_->detachEventBase();
79  sslSocket_->detachSSLContext();
80  }
81 
82  public:
83  explicit AttachDetachClient(const folly::SocketAddress& address)
84  : address_(address), bytesRead_(0) {}
85 
87  return promise_.getFuture();
88  }
89 
90  void connect() {
91  // create in one and then move to another
92  auto t1Evb = t1_.getEventBase();
93  t1Evb->runInEventBaseThread([this] {
94  sslSocket_ = t1_.createSocket();
95  // ensure we can detach and reattach the context before connecting
96  for (int i = 0; i < 1000; ++i) {
97  sslSocket_->detachSSLContext();
98  sslSocket_->attachSSLContext(t1_.ctx_);
99  }
100  // detach from t1 and connect in t2
101  detach();
102  auto t2Evb = t2_.getEventBase();
103  t2Evb->runInEventBaseThread([this] {
104  t2_.attach(*sslSocket_);
105  sslSocket_->connect(this, address_);
106  });
107  });
108  }
109 
110  void connectSuccess() noexcept override {
111  auto t2Evb = t2_.getEventBase();
112  EXPECT_TRUE(t2Evb->isInEventBaseThread());
113  cerr << "client SSL socket connected" << endl;
114  for (int i = 0; i < 1000; ++i) {
115  sslSocket_->detachSSLContext();
116  sslSocket_->attachSSLContext(t2_.ctx_);
117  }
118 
119  // detach from t2 and then read/write in t1
120  t2Evb->runInEventBaseThread([this] {
121  detach();
122  auto t1Evb = t1_.getEventBase();
123  t1Evb->runInEventBaseThread([this] {
124  t1_.attach(*sslSocket_);
125  sslSocket_->write(this, buf_, sizeof(buf_));
126  sslSocket_->setReadCB(this);
127  memset(readbuf_, 'b', sizeof(readbuf_));
128  bytesRead_ = 0;
129  });
130  });
131  }
132 
133  void connectErr(const AsyncSocketException& ex) noexcept override {
134  cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
135  sslSocket_.reset();
136  }
137 
138  void writeSuccess() noexcept override {
139  cerr << "client write success" << endl;
140  }
141 
142  void writeErr(
143  size_t /* bytesWritten */,
144  const AsyncSocketException& ex) noexcept override {
145  cerr << "client writeError: " << ex.what() << endl;
146  }
147 
148  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
149  *bufReturn = readbuf_ + bytesRead_;
150  *lenReturn = sizeof(readbuf_) - bytesRead_;
151  }
152  void readEOF() noexcept override {
153  cerr << "client readEOF" << endl;
154  }
155 
156  void readErr(const AsyncSocketException& ex) noexcept override {
157  cerr << "client readError: " << ex.what() << endl;
158  promise_.setException(ex);
159  }
160 
161  void readDataAvailable(size_t len) noexcept override {
163  EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
164  cerr << "client read data: " << len << endl;
165  bytesRead_ += len;
166  if (len == sizeof(buf_)) {
167  EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
168  sslSocket_->closeNow();
169  sslSocket_.reset();
170  promise_.setValue(true);
171  }
172  }
173 };
174 
178 TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
179  // Start listening on a local port
180  WriteCallbackBase writeCallback;
181  ReadCallback readCallback(&writeCallback);
182  HandshakeCallback handshakeCallback(&readCallback);
183  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
184  TestSSLServer server(&acceptCallback);
185 
186  std::shared_ptr<AttachDetachClient> client(
187  new AttachDetachClient(server.getAddress()));
188 
189  auto f = client->getFuture();
190  client->connect();
191  EXPECT_TRUE(std::move(f).within(std::chrono::seconds(3)).get());
192 }
193 
195  public:
196  ConnectClient() = default;
197 
199  return promise_.getFuture();
200  }
201 
203  t1_.getEventBase()->runInEventBaseThread([&] {
204  socket_ = t1_.createSocket();
205  socket_->connect(this, addr);
206  });
207  }
208 
209  void connectSuccess() noexcept override {
210  socket_.reset();
211  promise_.setValue(true);
212  }
213 
214  void connectErr(const AsyncSocketException& /* ex */) noexcept override {
215  socket_.reset();
216  promise_.setValue(false);
217  }
218 
219  void setCtx(std::shared_ptr<SSLContext> ctx) {
220  t1_.ctx_ = ctx;
221  }
222 
223  private:
225  // promise to fulfill when done with a value of true if connect succeeded
227  std::shared_ptr<AsyncSSLSocket> socket_;
228 };
229 
231  public:
234  }
235 
236  void getReadBuffer(void** buf, size_t* lenReturn) override {
237  *buf = &buffer_;
238  *lenReturn = 1;
239  }
240  void readDataAvailable(size_t) noexcept override {}
241 
243 };
244 
245 TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
246  // Start listening on a local port
247  NoopReadCallback readCallback;
248  HandshakeCallback handshakeCallback(&readCallback);
249  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
250  auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
251  TestSSLServer server(&acceptCallback, ctx);
252  server.loadTestCerts();
253 
254  // create a default client
255  auto c1 = std::make_unique<ConnectClient>();
256  auto f1 = c1->getFuture();
257  c1->connect(server.getAddress());
258  EXPECT_TRUE(std::move(f1).within(std::chrono::seconds(3)).get());
259 }
260 
261 TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
262  // Start listening on a local port
263  NoopReadCallback readCallback;
264  HandshakeCallback handshakeCallback(
265  &readCallback, HandshakeCallback::EXPECT_ERROR);
266  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
267  auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
268  TestSSLServer server(&acceptCallback, ctx);
269  server.loadTestCerts();
270 
271  // create a client that doesn't speak TLS 1.2
272  auto c2 = std::make_unique<ConnectClient>();
273  auto clientCtx = std::make_shared<SSLContext>();
274  clientCtx->setOptions(SSL_OP_NO_TLSv1_2);
275  c2->setCtx(clientCtx);
276  auto f2 = c2->getFuture();
277  c2->connect(server.getAddress());
278  EXPECT_FALSE(std::move(f2).within(std::chrono::seconds(3)).get());
279 }
280 
281 } // namespace folly
282 
283 int main(int argc, char* argv[]) {
285 #ifdef SIGPIPE
286  signal(SIGPIPE, SIG_IGN);
287 #endif
288  testing::InitGoogleTest(&argc, argv);
289  folly::init(&argc, &argv);
290  return RUN_ALL_TESTS();
291  OPENSSL_cleanup();
292 }
void setCtx(std::shared_ptr< SSLContext > ctx)
auto f
void getReadBuffer(void **bufReturn, size_t *lenReturn) override
int RUN_ALL_TESTS() GTEST_MUST_USE_RESULT_
Definition: gtest.h:2232
#define EXPECT_EQ(val1, val2)
Definition: gtest.h:1922
constexpr detail::Map< Move > move
Definition: Base-inl.h:2567
std::shared_ptr< SSLContext > ctx_
std::shared_ptr< AsyncSSLSocket > createSocket()
void init()
Definition: Init.cpp:54
void setException(exception_wrapper ew)
Definition: Promise-inl.h:111
std::shared_ptr< AsyncSSLSocket > socket_
—— Concurrent Priority Queue Implementation ——
Definition: AtomicBitSet.h:29
requires E e noexcept(noexcept(s.error(std::move(e))))
#define nullptr
Definition: http_parser.c:41
folly::Promise< bool > promise_
void connectSuccess() noexceptoverride
static std::shared_ptr< AsyncSSLSocket > newSocket(const std::shared_ptr< folly::SSLContext > &ctx, EventBase *evb, int fd, bool server=true, bool deferSecurityNegotiation=false)
void init(int *argc, char ***argv, bool removeFlags)
Definition: Init.cpp:34
AttachDetachClient(const folly::SocketAddress &address)
char ** argv
bool isInEventBaseThread() const
Definition: EventBase.h:504
LogLevel min
Definition: LogLevel.cpp:30
folly::Promise< bool > promise_
void readDataAvailable(size_t) noexceptoverride
int main(int argc, char *argv[])
AsyncServerSocket::UniquePtr socket_
Encoder::MutableCompressedList list
bool runInEventBaseThread(void(*fn)(T *), T *arg)
Definition: EventBase.h:794
Future< T > getFuture()
Definition: Promise-inl.h:97
Promise< Unit > promise_
void writeSuccess() noexceptoverride
NetworkSocket socket(int af, int type, int protocol)
Definition: NetOps.cpp:412
void readEOF() noexceptoverride
void attachEventBase(EventBase *eventBase) override
folly::ScopedEventBaseThread evb_
#define EXPECT_TRUE(condition)
Definition: gtest.h:1859
void connectErr(const AsyncSocketException &) noexceptoverride
std::enable_if< std::is_same< Unit, B >::value, void >::type setValue()
Definition: Promise.h:326
void connectSuccess() noexceptoverride
const char * string
Definition: Conv.cpp:212
void readErr(const AsyncSocketException &ex) noexceptoverride
void connectErr(const AsyncSocketException &ex) noexceptoverride
void connect(const folly::SocketAddress &addr)
GTEST_API_ void InitGoogleTest(int *argc, char **argv)
Definition: gtest.cc:5370
#define EXPECT_FALSE(condition)
Definition: gtest.h:1862
folly::SocketAddress address_
void getReadBuffer(void **buf, size_t *lenReturn) override
void readDataAvailable(size_t len) noexceptoverride
void writeErr(size_t, const AsyncSocketException &ex) noexceptoverride
std::shared_ptr< AsyncSSLSocket > sslSocket_
std::unique_ptr< unsigned char[]> buffer_
Definition: Random.cpp:105
ThreadPoolListHook * addr
TEST(SequencedExecutor, CPUThreadPoolExecutor)
const SocketAddress & getAddress() const
state
Definition: http_parser.c:272
void attach(AsyncSSLSocket &socket)