Built motion from commit 6a09e18b.|2.6.11
[motion2.git] / legacy-libs / grpc / deps / grpc / src / core / lib / security / transport / security_handshaker.cc
1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/lib/security/transport/security_handshaker.h"
22
23 #include <stdbool.h>
24 #include <string.h>
25
26 #include <grpc/slice_buffer.h>
27 #include <grpc/support/alloc.h>
28 #include <grpc/support/log.h>
29
30 #include "src/core/lib/channel/channel_args.h"
31 #include "src/core/lib/channel/handshaker.h"
32 #include "src/core/lib/channel/handshaker_registry.h"
33 #include "src/core/lib/gprpp/ref_counted_ptr.h"
34 #include "src/core/lib/security/context/security_context.h"
35 #include "src/core/lib/security/transport/secure_endpoint.h"
36 #include "src/core/lib/security/transport/tsi_error.h"
37 #include "src/core/lib/slice/slice_internal.h"
38 #include "src/core/tsi/transport_security_grpc.h"
39
40 #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256
41
42 namespace grpc_core {
43
44 namespace {
45
46 class SecurityHandshaker : public Handshaker {
47  public:
48   SecurityHandshaker(tsi_handshaker* handshaker,
49                      grpc_security_connector* connector);
50   ~SecurityHandshaker() override;
51   void Shutdown(grpc_error* why) override;
52   void DoHandshake(grpc_tcp_server_acceptor* acceptor,
53                    grpc_closure* on_handshake_done,
54                    HandshakerArgs* args) override;
55   const char* name() const override { return "security"; }
56
57  private:
58   grpc_error* DoHandshakerNextLocked(const unsigned char* bytes_received,
59                                      size_t bytes_received_size);
60
61   grpc_error* OnHandshakeNextDoneLocked(
62       tsi_result result, const unsigned char* bytes_to_send,
63       size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
64   void HandshakeFailedLocked(grpc_error* error);
65   void CleanupArgsForFailureLocked();
66
67   static void OnHandshakeDataReceivedFromPeerFn(void* arg, grpc_error* error);
68   static void OnHandshakeDataSentToPeerFn(void* arg, grpc_error* error);
69   static void OnHandshakeNextDoneGrpcWrapper(
70       tsi_result result, void* user_data, const unsigned char* bytes_to_send,
71       size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
72   static void OnPeerCheckedFn(void* arg, grpc_error* error);
73   void OnPeerCheckedInner(grpc_error* error);
74   size_t MoveReadBufferIntoHandshakeBuffer();
75   grpc_error* CheckPeerLocked();
76
77   // State set at creation time.
78   tsi_handshaker* handshaker_;
79   RefCountedPtr<grpc_security_connector> connector_;
80
81   gpr_mu mu_;
82
83   bool is_shutdown_ = false;
84   // Endpoint and read buffer to destroy after a shutdown.
85   grpc_endpoint* endpoint_to_destroy_ = nullptr;
86   grpc_slice_buffer* read_buffer_to_destroy_ = nullptr;
87
88   // State saved while performing the handshake.
89   HandshakerArgs* args_ = nullptr;
90   grpc_closure* on_handshake_done_ = nullptr;
91
92   size_t handshake_buffer_size_;
93   unsigned char* handshake_buffer_;
94   grpc_slice_buffer outgoing_;
95   grpc_closure on_handshake_data_sent_to_peer_;
96   grpc_closure on_handshake_data_received_from_peer_;
97   grpc_closure on_peer_checked_;
98   RefCountedPtr<grpc_auth_context> auth_context_;
99   tsi_handshaker_result* handshaker_result_ = nullptr;
100 };
101
102 SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
103                                        grpc_security_connector* connector)
104     : handshaker_(handshaker),
105       connector_(connector->Ref(DEBUG_LOCATION, "handshake")),
106       handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
107       handshake_buffer_(
108           static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))) {
109   gpr_mu_init(&mu_);
110   grpc_slice_buffer_init(&outgoing_);
111   GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer_,
112                     &SecurityHandshaker::OnHandshakeDataSentToPeerFn, this,
113                     grpc_schedule_on_exec_ctx);
114   GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer_,
115                     &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn,
116                     this, grpc_schedule_on_exec_ctx);
117   GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn,
118                     this, grpc_schedule_on_exec_ctx);
119 }
120
121 SecurityHandshaker::~SecurityHandshaker() {
122   gpr_mu_destroy(&mu_);
123   tsi_handshaker_destroy(handshaker_);
124   tsi_handshaker_result_destroy(handshaker_result_);
125   if (endpoint_to_destroy_ != nullptr) {
126     grpc_endpoint_destroy(endpoint_to_destroy_);
127   }
128   if (read_buffer_to_destroy_ != nullptr) {
129     grpc_slice_buffer_destroy_internal(read_buffer_to_destroy_);
130     gpr_free(read_buffer_to_destroy_);
131   }
132   gpr_free(handshake_buffer_);
133   grpc_slice_buffer_destroy_internal(&outgoing_);
134   auth_context_.reset(DEBUG_LOCATION, "handshake");
135   connector_.reset(DEBUG_LOCATION, "handshake");
136 }
137
138 size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() {
139   size_t bytes_in_read_buffer = args_->read_buffer->length;
140   if (handshake_buffer_size_ < bytes_in_read_buffer) {
141     handshake_buffer_ = static_cast<uint8_t*>(
142         gpr_realloc(handshake_buffer_, bytes_in_read_buffer));
143     handshake_buffer_size_ = bytes_in_read_buffer;
144   }
145   size_t offset = 0;
146   while (args_->read_buffer->count > 0) {
147     grpc_slice* next_slice = grpc_slice_buffer_peek_first(args_->read_buffer);
148     memcpy(handshake_buffer_ + offset, GRPC_SLICE_START_PTR(*next_slice),
149            GRPC_SLICE_LENGTH(*next_slice));
150     offset += GRPC_SLICE_LENGTH(*next_slice);
151     grpc_slice_buffer_remove_first(args_->read_buffer);
152   }
153   return bytes_in_read_buffer;
154 }
155
156 // Set args_ fields to NULL, saving the endpoint and read buffer for
157 // later destruction.
158 void SecurityHandshaker::CleanupArgsForFailureLocked() {
159   endpoint_to_destroy_ = args_->endpoint;
160   args_->endpoint = nullptr;
161   read_buffer_to_destroy_ = args_->read_buffer;
162   args_->read_buffer = nullptr;
163   grpc_channel_args_destroy(args_->args);
164   args_->args = nullptr;
165 }
166
167 // If the handshake failed or we're shutting down, clean up and invoke the
168 // callback with the error.
169 void SecurityHandshaker::HandshakeFailedLocked(grpc_error* error) {
170   if (error == GRPC_ERROR_NONE) {
171     // If we were shut down after the handshake succeeded but before an
172     // endpoint callback was invoked, we need to generate our own error.
173     error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown");
174   }
175   const char* msg = grpc_error_string(error);
176   gpr_log(GPR_DEBUG, "Security handshake failed: %s", msg);
177
178   if (!is_shutdown_) {
179     // TODO(ctiller): It is currently necessary to shutdown endpoints
180     // before destroying them, even if we know that there are no
181     // pending read/write callbacks.  This should be fixed, at which
182     // point this can be removed.
183     grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(error));
184     // Not shutting down, so the write failed.  Clean up before
185     // invoking the callback.
186     CleanupArgsForFailureLocked();
187     // Set shutdown to true so that subsequent calls to
188     // security_handshaker_shutdown() do nothing.
189     is_shutdown_ = true;
190   }
191   // Invoke callback.
192   GRPC_CLOSURE_SCHED(on_handshake_done_, error);
193 }
194
195 void SecurityHandshaker::OnPeerCheckedInner(grpc_error* error) {
196   MutexLock lock(&mu_);
197   if (error != GRPC_ERROR_NONE || is_shutdown_) {
198     HandshakeFailedLocked(error);
199     return;
200   }
201   // Create zero-copy frame protector, if implemented.
202   tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
203   tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector(
204       handshaker_result_, nullptr, &zero_copy_protector);
205   if (result != TSI_OK && result != TSI_UNIMPLEMENTED) {
206     error = grpc_set_tsi_error_result(
207         GRPC_ERROR_CREATE_FROM_STATIC_STRING(
208             "Zero-copy frame protector creation failed"),
209         result);
210     HandshakeFailedLocked(error);
211     return;
212   }
213   // Create frame protector if zero-copy frame protector is NULL.
214   tsi_frame_protector* protector = nullptr;
215   if (zero_copy_protector == nullptr) {
216     result = tsi_handshaker_result_create_frame_protector(handshaker_result_,
217                                                           nullptr, &protector);
218     if (result != TSI_OK) {
219       error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
220                                             "Frame protector creation failed"),
221                                         result);
222       HandshakeFailedLocked(error);
223       return;
224     }
225   }
226   // Get unused bytes.
227   const unsigned char* unused_bytes = nullptr;
228   size_t unused_bytes_size = 0;
229   result = tsi_handshaker_result_get_unused_bytes(
230       handshaker_result_, &unused_bytes, &unused_bytes_size);
231   // Create secure endpoint.
232   if (unused_bytes_size > 0) {
233     grpc_slice slice =
234         grpc_slice_from_copied_buffer((char*)unused_bytes, unused_bytes_size);
235     args_->endpoint = grpc_secure_endpoint_create(
236         protector, zero_copy_protector, args_->endpoint, &slice, 1);
237     grpc_slice_unref_internal(slice);
238   } else {
239     args_->endpoint = grpc_secure_endpoint_create(
240         protector, zero_copy_protector, args_->endpoint, nullptr, 0);
241   }
242   tsi_handshaker_result_destroy(handshaker_result_);
243   handshaker_result_ = nullptr;
244   // Add auth context to channel args.
245   grpc_arg auth_context_arg = grpc_auth_context_to_arg(auth_context_.get());
246   grpc_channel_args* tmp_args = args_->args;
247   args_->args = grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1);
248   grpc_channel_args_destroy(tmp_args);
249   // Invoke callback.
250   GRPC_CLOSURE_SCHED(on_handshake_done_, GRPC_ERROR_NONE);
251   // Set shutdown to true so that subsequent calls to
252   // security_handshaker_shutdown() do nothing.
253   is_shutdown_ = true;
254 }
255
256 void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error* error) {
257   RefCountedPtr<SecurityHandshaker>(static_cast<SecurityHandshaker*>(arg))
258       ->OnPeerCheckedInner(GRPC_ERROR_REF(error));
259 }
260
261 grpc_error* SecurityHandshaker::CheckPeerLocked() {
262   tsi_peer peer;
263   tsi_result result =
264       tsi_handshaker_result_extract_peer(handshaker_result_, &peer);
265   if (result != TSI_OK) {
266     return grpc_set_tsi_error_result(
267         GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result);
268   }
269   connector_->check_peer(peer, args_->endpoint, &auth_context_,
270                          &on_peer_checked_);
271   return GRPC_ERROR_NONE;
272 }
273
274 grpc_error* SecurityHandshaker::OnHandshakeNextDoneLocked(
275     tsi_result result, const unsigned char* bytes_to_send,
276     size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
277   grpc_error* error = GRPC_ERROR_NONE;
278   // Handshaker was shutdown.
279   if (is_shutdown_) {
280     return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown");
281   }
282   // Read more if we need to.
283   if (result == TSI_INCOMPLETE_DATA) {
284     GPR_ASSERT(bytes_to_send_size == 0);
285     grpc_endpoint_read(args_->endpoint, args_->read_buffer,
286                        &on_handshake_data_received_from_peer_, /*urgent=*/true);
287     return error;
288   }
289   if (result != TSI_OK) {
290     return grpc_set_tsi_error_result(
291         GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake failed"), result);
292   }
293   // Update handshaker result.
294   if (handshaker_result != nullptr) {
295     GPR_ASSERT(handshaker_result_ == nullptr);
296     handshaker_result_ = handshaker_result;
297   }
298   if (bytes_to_send_size > 0) {
299     // Send data to peer, if needed.
300     grpc_slice to_send = grpc_slice_from_copied_buffer(
301         reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size);
302     grpc_slice_buffer_reset_and_unref_internal(&outgoing_);
303     grpc_slice_buffer_add(&outgoing_, to_send);
304     grpc_endpoint_write(args_->endpoint, &outgoing_,
305                         &on_handshake_data_sent_to_peer_, nullptr);
306   } else if (handshaker_result == nullptr) {
307     // There is nothing to send, but need to read from peer.
308     grpc_endpoint_read(args_->endpoint, args_->read_buffer,
309                        &on_handshake_data_received_from_peer_, /*urgent=*/true);
310   } else {
311     // Handshake has finished, check peer and so on.
312     error = CheckPeerLocked();
313   }
314   return error;
315 }
316
317 void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper(
318     tsi_result result, void* user_data, const unsigned char* bytes_to_send,
319     size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
320   RefCountedPtr<SecurityHandshaker> h(
321       static_cast<SecurityHandshaker*>(user_data));
322   MutexLock lock(&h->mu_);
323   grpc_error* error = h->OnHandshakeNextDoneLocked(
324       result, bytes_to_send, bytes_to_send_size, handshaker_result);
325   if (error != GRPC_ERROR_NONE) {
326     h->HandshakeFailedLocked(error);
327   } else {
328     h.release();  // Avoid unref
329   }
330 }
331
332 grpc_error* SecurityHandshaker::DoHandshakerNextLocked(
333     const unsigned char* bytes_received, size_t bytes_received_size) {
334   // Invoke TSI handshaker.
335   const unsigned char* bytes_to_send = nullptr;
336   size_t bytes_to_send_size = 0;
337   tsi_handshaker_result* hs_result = nullptr;
338   tsi_result result = tsi_handshaker_next(
339       handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
340       &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this);
341   if (result == TSI_ASYNC) {
342     // Handshaker operating asynchronously. Nothing else to do here;
343     // callback will be invoked in a TSI thread.
344     return GRPC_ERROR_NONE;
345   }
346   // Handshaker returned synchronously. Invoke callback directly in
347   // this thread with our existing exec_ctx.
348   return OnHandshakeNextDoneLocked(result, bytes_to_send, bytes_to_send_size,
349                                    hs_result);
350 }
351
352 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(void* arg,
353                                                            grpc_error* error) {
354   RefCountedPtr<SecurityHandshaker> h(static_cast<SecurityHandshaker*>(arg));
355   MutexLock lock(&h->mu_);
356   if (error != GRPC_ERROR_NONE || h->is_shutdown_) {
357     h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
358         "Handshake read failed", &error, 1));
359     return;
360   }
361   // Copy all slices received.
362   size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer();
363   // Call TSI handshaker.
364   error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size);
365
366   if (error != GRPC_ERROR_NONE) {
367     h->HandshakeFailedLocked(error);
368   } else {
369     h.release();  // Avoid unref
370   }
371 }
372
373 void SecurityHandshaker::OnHandshakeDataSentToPeerFn(void* arg,
374                                                      grpc_error* error) {
375   RefCountedPtr<SecurityHandshaker> h(static_cast<SecurityHandshaker*>(arg));
376   MutexLock lock(&h->mu_);
377   if (error != GRPC_ERROR_NONE || h->is_shutdown_) {
378     h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
379         "Handshake write failed", &error, 1));
380     return;
381   }
382   // We may be done.
383   if (h->handshaker_result_ == nullptr) {
384     grpc_endpoint_read(h->args_->endpoint, h->args_->read_buffer,
385                        &h->on_handshake_data_received_from_peer_,
386                        /*urgent=*/true);
387   } else {
388     error = h->CheckPeerLocked();
389     if (error != GRPC_ERROR_NONE) {
390       h->HandshakeFailedLocked(error);
391       return;
392     }
393   }
394   h.release();  // Avoid unref
395 }
396
397 //
398 // public handshaker API
399 //
400
401 void SecurityHandshaker::Shutdown(grpc_error* why) {
402   MutexLock lock(&mu_);
403   if (!is_shutdown_) {
404     is_shutdown_ = true;
405     tsi_handshaker_shutdown(handshaker_);
406     grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(why));
407     CleanupArgsForFailureLocked();
408   }
409   GRPC_ERROR_UNREF(why);
410 }
411
412 void SecurityHandshaker::DoHandshake(grpc_tcp_server_acceptor* acceptor,
413                                      grpc_closure* on_handshake_done,
414                                      HandshakerArgs* args) {
415   auto ref = Ref();
416   MutexLock lock(&mu_);
417   args_ = args;
418   on_handshake_done_ = on_handshake_done;
419   size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer();
420   grpc_error* error =
421       DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
422   if (error != GRPC_ERROR_NONE) {
423     HandshakeFailedLocked(error);
424   } else {
425     ref.release();  // Avoid unref
426   }
427 }
428
429 //
430 // FailHandshaker
431 //
432
433 class FailHandshaker : public Handshaker {
434  public:
435   const char* name() const override { return "security_fail"; }
436   void Shutdown(grpc_error* why) override { GRPC_ERROR_UNREF(why); }
437   void DoHandshake(grpc_tcp_server_acceptor* acceptor,
438                    grpc_closure* on_handshake_done,
439                    HandshakerArgs* args) override {
440     GRPC_CLOSURE_SCHED(on_handshake_done,
441                        GRPC_ERROR_CREATE_FROM_STATIC_STRING(
442                            "Failed to create security handshaker"));
443   }
444
445  private:
446   virtual ~FailHandshaker() = default;
447 };
448
449 //
450 // handshaker factories
451 //
452
453 class ClientSecurityHandshakerFactory : public HandshakerFactory {
454  public:
455   void AddHandshakers(const grpc_channel_args* args,
456                       grpc_pollset_set* interested_parties,
457                       HandshakeManager* handshake_mgr) override {
458     auto* security_connector =
459         reinterpret_cast<grpc_channel_security_connector*>(
460             grpc_security_connector_find_in_args(args));
461     if (security_connector) {
462       security_connector->add_handshakers(interested_parties, handshake_mgr);
463     }
464   }
465   ~ClientSecurityHandshakerFactory() override = default;
466 };
467
468 class ServerSecurityHandshakerFactory : public HandshakerFactory {
469  public:
470   void AddHandshakers(const grpc_channel_args* args,
471                       grpc_pollset_set* interested_parties,
472                       HandshakeManager* handshake_mgr) override {
473     auto* security_connector =
474         reinterpret_cast<grpc_server_security_connector*>(
475             grpc_security_connector_find_in_args(args));
476     if (security_connector) {
477       security_connector->add_handshakers(interested_parties, handshake_mgr);
478     }
479   }
480   ~ServerSecurityHandshakerFactory() override = default;
481 };
482
483 }  // namespace
484
485 //
486 // exported functions
487 //
488
489 RefCountedPtr<Handshaker> SecurityHandshakerCreate(
490     tsi_handshaker* handshaker, grpc_security_connector* connector) {
491   // If no TSI handshaker was created, return a handshaker that always fails.
492   // Otherwise, return a real security handshaker.
493   if (handshaker == nullptr) {
494     return MakeRefCounted<FailHandshaker>();
495   } else {
496     return MakeRefCounted<SecurityHandshaker>(handshaker, connector);
497   }
498 }
499
500 void SecurityRegisterHandshakerFactories() {
501   HandshakerRegistry::RegisterHandshakerFactory(
502       false /* at_start */, HANDSHAKER_CLIENT,
503       UniquePtr<HandshakerFactory>(New<ClientSecurityHandshakerFactory>()));
504   HandshakerRegistry::RegisterHandshakerFactory(
505       false /* at_start */, HANDSHAKER_SERVER,
506       UniquePtr<HandshakerFactory>(New<ServerSecurityHandshakerFactory>()));
507 }
508
509 }  // namespace grpc_core
510
511 grpc_handshaker* grpc_security_handshaker_create(
512     tsi_handshaker* handshaker, grpc_security_connector* connector) {
513   return SecurityHandshakerCreate(handshaker, connector).release();
514 }