Built motion from commit 6a09e18b.|2.6.11
[motion2.git] / legacy-libs / grpc / deps / grpc / src / core / lib / security / transport / secure_endpoint.cc
diff --git a/legacy-libs/grpc/deps/grpc/src/core/lib/security/transport/secure_endpoint.cc b/legacy-libs/grpc/deps/grpc/src/core/lib/security/transport/secure_endpoint.cc
new file mode 100644 (file)
index 0000000..0aac7d8
--- /dev/null
@@ -0,0 +1,445 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+/* With the addition of a libuv endpoint, sockaddr.h now includes uv.h when
+   using that endpoint. Because of various transitive includes in uv.h,
+   including windows.h on Windows, uv.h must be included before other system
+   headers. Therefore, sockaddr.h must always be included first */
+#include <grpc/support/port_platform.h>
+
+#include <new>
+
+#include "src/core/lib/iomgr/sockaddr.h"
+
+#include <grpc/slice.h>
+#include <grpc/slice_buffer.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/sync.h>
+#include "src/core/lib/debug/trace.h"
+#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/memory.h"
+#include "src/core/lib/profiling/timers.h"
+#include "src/core/lib/security/transport/secure_endpoint.h"
+#include "src/core/lib/security/transport/tsi_error.h"
+#include "src/core/lib/slice/slice_internal.h"
+#include "src/core/lib/slice/slice_string_helpers.h"
+#include "src/core/tsi/transport_security_grpc.h"
+
+#define STAGING_BUFFER_SIZE 8192
+
+static void on_read(void* user_data, grpc_error* error);
+
+namespace {
+struct secure_endpoint {
+  secure_endpoint(const grpc_endpoint_vtable* vtable,
+                  tsi_frame_protector* protector,
+                  tsi_zero_copy_grpc_protector* zero_copy_protector,
+                  grpc_endpoint* transport, grpc_slice* leftover_slices,
+                  size_t leftover_nslices)
+      : wrapped_ep(transport),
+        protector(protector),
+        zero_copy_protector(zero_copy_protector) {
+    base.vtable = vtable;
+    gpr_mu_init(&protector_mu);
+    GRPC_CLOSURE_INIT(&on_read, ::on_read, this, grpc_schedule_on_exec_ctx);
+    grpc_slice_buffer_init(&source_buffer);
+    grpc_slice_buffer_init(&leftover_bytes);
+    for (size_t i = 0; i < leftover_nslices; i++) {
+      grpc_slice_buffer_add(&leftover_bytes,
+                            grpc_slice_ref_internal(leftover_slices[i]));
+    }
+    grpc_slice_buffer_init(&output_buffer);
+    gpr_ref_init(&ref, 1);
+  }
+
+  ~secure_endpoint() {
+    grpc_endpoint_destroy(wrapped_ep);
+    tsi_frame_protector_destroy(protector);
+    tsi_zero_copy_grpc_protector_destroy(zero_copy_protector);
+    grpc_slice_buffer_destroy_internal(&source_buffer);
+    grpc_slice_buffer_destroy_internal(&leftover_bytes);
+    grpc_slice_unref_internal(read_staging_buffer);
+    grpc_slice_unref_internal(write_staging_buffer);
+    grpc_slice_buffer_destroy_internal(&output_buffer);
+    gpr_mu_destroy(&protector_mu);
+  }
+
+  grpc_endpoint base;
+  grpc_endpoint* wrapped_ep;
+  struct tsi_frame_protector* protector;
+  struct tsi_zero_copy_grpc_protector* zero_copy_protector;
+  gpr_mu protector_mu;
+  /* saved upper level callbacks and user_data. */
+  grpc_closure* read_cb = nullptr;
+  grpc_closure* write_cb = nullptr;
+  grpc_closure on_read;
+  grpc_slice_buffer* read_buffer = nullptr;
+  grpc_slice_buffer source_buffer;
+  /* saved handshaker leftover data to unprotect. */
+  grpc_slice_buffer leftover_bytes;
+  /* buffers for read and write */
+  grpc_slice read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
+  grpc_slice write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
+  grpc_slice_buffer output_buffer;
+
+  gpr_refcount ref;
+};
+}  // namespace
+
+grpc_core::TraceFlag grpc_trace_secure_endpoint(false, "secure_endpoint");
+
+static void destroy(secure_endpoint* ep) { grpc_core::Delete(ep); }
+
+#ifndef NDEBUG
+#define SECURE_ENDPOINT_UNREF(ep, reason) \
+  secure_endpoint_unref((ep), (reason), __FILE__, __LINE__)
+#define SECURE_ENDPOINT_REF(ep, reason) \
+  secure_endpoint_ref((ep), (reason), __FILE__, __LINE__)
+static void secure_endpoint_unref(secure_endpoint* ep, const char* reason,
+                                  const char* file, int line) {
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) {
+    gpr_atm val = gpr_atm_no_barrier_load(&ep->ref.count);
+    gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
+            "SECENDP unref %p : %s %" PRIdPTR " -> %" PRIdPTR, ep, reason, val,
+            val - 1);
+  }
+  if (gpr_unref(&ep->ref)) {
+    destroy(ep);
+  }
+}
+
+static void secure_endpoint_ref(secure_endpoint* ep, const char* reason,
+                                const char* file, int line) {
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) {
+    gpr_atm val = gpr_atm_no_barrier_load(&ep->ref.count);
+    gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
+            "SECENDP   ref %p : %s %" PRIdPTR " -> %" PRIdPTR, ep, reason, val,
+            val + 1);
+  }
+  gpr_ref(&ep->ref);
+}
+#else
+#define SECURE_ENDPOINT_UNREF(ep, reason) secure_endpoint_unref((ep))
+#define SECURE_ENDPOINT_REF(ep, reason) secure_endpoint_ref((ep))
+static void secure_endpoint_unref(secure_endpoint* ep) {
+  if (gpr_unref(&ep->ref)) {
+    destroy(ep);
+  }
+}
+
+static void secure_endpoint_ref(secure_endpoint* ep) { gpr_ref(&ep->ref); }
+#endif
+
+static void flush_read_staging_buffer(secure_endpoint* ep, uint8_t** cur,
+                                      uint8_t** end) {
+  grpc_slice_buffer_add(ep->read_buffer, ep->read_staging_buffer);
+  ep->read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
+  *cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer);
+  *end = GRPC_SLICE_END_PTR(ep->read_staging_buffer);
+}
+
+static void call_read_cb(secure_endpoint* ep, grpc_error* error) {
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) {
+    size_t i;
+    for (i = 0; i < ep->read_buffer->count; i++) {
+      char* data = grpc_dump_slice(ep->read_buffer->slices[i],
+                                   GPR_DUMP_HEX | GPR_DUMP_ASCII);
+      gpr_log(GPR_INFO, "READ %p: %s", ep, data);
+      gpr_free(data);
+    }
+  }
+  ep->read_buffer = nullptr;
+  GRPC_CLOSURE_SCHED(ep->read_cb, error);
+  SECURE_ENDPOINT_UNREF(ep, "read");
+}
+
+static void on_read(void* user_data, grpc_error* error) {
+  unsigned i;
+  uint8_t keep_looping = 0;
+  tsi_result result = TSI_OK;
+  secure_endpoint* ep = static_cast<secure_endpoint*>(user_data);
+  uint8_t* cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer);
+  uint8_t* end = GRPC_SLICE_END_PTR(ep->read_staging_buffer);
+
+  if (error != GRPC_ERROR_NONE) {
+    grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer);
+    call_read_cb(ep, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
+                         "Secure read failed", &error, 1));
+    return;
+  }
+
+  if (ep->zero_copy_protector != nullptr) {
+    // Use zero-copy grpc protector to unprotect.
+    result = tsi_zero_copy_grpc_protector_unprotect(
+        ep->zero_copy_protector, &ep->source_buffer, ep->read_buffer);
+  } else {
+    // Use frame protector to unprotect.
+    /* TODO(yangg) check error, maybe bail out early */
+    for (i = 0; i < ep->source_buffer.count; i++) {
+      grpc_slice encrypted = ep->source_buffer.slices[i];
+      uint8_t* message_bytes = GRPC_SLICE_START_PTR(encrypted);
+      size_t message_size = GRPC_SLICE_LENGTH(encrypted);
+
+      while (message_size > 0 || keep_looping) {
+        size_t unprotected_buffer_size_written = static_cast<size_t>(end - cur);
+        size_t processed_message_size = message_size;
+        gpr_mu_lock(&ep->protector_mu);
+        result = tsi_frame_protector_unprotect(
+            ep->protector, message_bytes, &processed_message_size, cur,
+            &unprotected_buffer_size_written);
+        gpr_mu_unlock(&ep->protector_mu);
+        if (result != TSI_OK) {
+          gpr_log(GPR_ERROR, "Decryption error: %s",
+                  tsi_result_to_string(result));
+          break;
+        }
+        message_bytes += processed_message_size;
+        message_size -= processed_message_size;
+        cur += unprotected_buffer_size_written;
+
+        if (cur == end) {
+          flush_read_staging_buffer(ep, &cur, &end);
+          /* Force to enter the loop again to extract buffered bytes in
+             protector. The bytes could be buffered because of running out of
+             staging_buffer. If this happens at the end of all slices, doing
+             another unprotect avoids leaving data in the protector. */
+          keep_looping = 1;
+        } else if (unprotected_buffer_size_written > 0) {
+          keep_looping = 1;
+        } else {
+          keep_looping = 0;
+        }
+      }
+      if (result != TSI_OK) break;
+    }
+
+    if (cur != GRPC_SLICE_START_PTR(ep->read_staging_buffer)) {
+      grpc_slice_buffer_add(
+          ep->read_buffer,
+          grpc_slice_split_head(
+              &ep->read_staging_buffer,
+              static_cast<size_t>(
+                  cur - GRPC_SLICE_START_PTR(ep->read_staging_buffer))));
+    }
+  }
+
+  /* TODO(yangg) experiment with moving this block after read_cb to see if it
+     helps latency */
+  grpc_slice_buffer_reset_and_unref_internal(&ep->source_buffer);
+
+  if (result != TSI_OK) {
+    grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer);
+    call_read_cb(
+        ep, grpc_set_tsi_error_result(
+                GRPC_ERROR_CREATE_FROM_STATIC_STRING("Unwrap failed"), result));
+    return;
+  }
+
+  call_read_cb(ep, GRPC_ERROR_NONE);
+}
+
+static void endpoint_read(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
+                          grpc_closure* cb, bool urgent) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  ep->read_cb = cb;
+  ep->read_buffer = slices;
+  grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer);
+
+  SECURE_ENDPOINT_REF(ep, "read");
+  if (ep->leftover_bytes.count) {
+    grpc_slice_buffer_swap(&ep->leftover_bytes, &ep->source_buffer);
+    GPR_ASSERT(ep->leftover_bytes.count == 0);
+    on_read(ep, GRPC_ERROR_NONE);
+    return;
+  }
+
+  grpc_endpoint_read(ep->wrapped_ep, &ep->source_buffer, &ep->on_read, urgent);
+}
+
+static void flush_write_staging_buffer(secure_endpoint* ep, uint8_t** cur,
+                                       uint8_t** end) {
+  grpc_slice_buffer_add(&ep->output_buffer, ep->write_staging_buffer);
+  ep->write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
+  *cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer);
+  *end = GRPC_SLICE_END_PTR(ep->write_staging_buffer);
+}
+
+static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
+                           grpc_closure* cb, void* arg) {
+  GPR_TIMER_SCOPE("secure_endpoint.endpoint_write", 0);
+
+  unsigned i;
+  tsi_result result = TSI_OK;
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  uint8_t* cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer);
+  uint8_t* end = GRPC_SLICE_END_PTR(ep->write_staging_buffer);
+
+  grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer);
+
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) {
+    for (i = 0; i < slices->count; i++) {
+      char* data =
+          grpc_dump_slice(slices->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII);
+      gpr_log(GPR_INFO, "WRITE %p: %s", ep, data);
+      gpr_free(data);
+    }
+  }
+
+  if (ep->zero_copy_protector != nullptr) {
+    // Use zero-copy grpc protector to protect.
+    result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector,
+                                                  slices, &ep->output_buffer);
+  } else {
+    // Use frame protector to protect.
+    for (i = 0; i < slices->count; i++) {
+      grpc_slice plain = slices->slices[i];
+      uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain);
+      size_t message_size = GRPC_SLICE_LENGTH(plain);
+      while (message_size > 0) {
+        size_t protected_buffer_size_to_send = static_cast<size_t>(end - cur);
+        size_t processed_message_size = message_size;
+        gpr_mu_lock(&ep->protector_mu);
+        result = tsi_frame_protector_protect(ep->protector, message_bytes,
+                                             &processed_message_size, cur,
+                                             &protected_buffer_size_to_send);
+        gpr_mu_unlock(&ep->protector_mu);
+        if (result != TSI_OK) {
+          gpr_log(GPR_ERROR, "Encryption error: %s",
+                  tsi_result_to_string(result));
+          break;
+        }
+        message_bytes += processed_message_size;
+        message_size -= processed_message_size;
+        cur += protected_buffer_size_to_send;
+
+        if (cur == end) {
+          flush_write_staging_buffer(ep, &cur, &end);
+        }
+      }
+      if (result != TSI_OK) break;
+    }
+    if (result == TSI_OK) {
+      size_t still_pending_size;
+      do {
+        size_t protected_buffer_size_to_send = static_cast<size_t>(end - cur);
+        gpr_mu_lock(&ep->protector_mu);
+        result = tsi_frame_protector_protect_flush(
+            ep->protector, cur, &protected_buffer_size_to_send,
+            &still_pending_size);
+        gpr_mu_unlock(&ep->protector_mu);
+        if (result != TSI_OK) break;
+        cur += protected_buffer_size_to_send;
+        if (cur == end) {
+          flush_write_staging_buffer(ep, &cur, &end);
+        }
+      } while (still_pending_size > 0);
+      if (cur != GRPC_SLICE_START_PTR(ep->write_staging_buffer)) {
+        grpc_slice_buffer_add(
+            &ep->output_buffer,
+            grpc_slice_split_head(
+                &ep->write_staging_buffer,
+                static_cast<size_t>(
+                    cur - GRPC_SLICE_START_PTR(ep->write_staging_buffer))));
+      }
+    }
+  }
+
+  if (result != TSI_OK) {
+    /* TODO(yangg) do different things according to the error type? */
+    grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer);
+    GRPC_CLOSURE_SCHED(
+        cb, grpc_set_tsi_error_result(
+                GRPC_ERROR_CREATE_FROM_STATIC_STRING("Wrap failed"), result));
+    return;
+  }
+
+  grpc_endpoint_write(ep->wrapped_ep, &ep->output_buffer, cb, arg);
+}
+
+static void endpoint_shutdown(grpc_endpoint* secure_ep, grpc_error* why) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  grpc_endpoint_shutdown(ep->wrapped_ep, why);
+}
+
+static void endpoint_destroy(grpc_endpoint* secure_ep) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  SECURE_ENDPOINT_UNREF(ep, "destroy");
+}
+
+static void endpoint_add_to_pollset(grpc_endpoint* secure_ep,
+                                    grpc_pollset* pollset) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  grpc_endpoint_add_to_pollset(ep->wrapped_ep, pollset);
+}
+
+static void endpoint_add_to_pollset_set(grpc_endpoint* secure_ep,
+                                        grpc_pollset_set* pollset_set) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  grpc_endpoint_add_to_pollset_set(ep->wrapped_ep, pollset_set);
+}
+
+static void endpoint_delete_from_pollset_set(grpc_endpoint* secure_ep,
+                                             grpc_pollset_set* pollset_set) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  grpc_endpoint_delete_from_pollset_set(ep->wrapped_ep, pollset_set);
+}
+
+static char* endpoint_get_peer(grpc_endpoint* secure_ep) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  return grpc_endpoint_get_peer(ep->wrapped_ep);
+}
+
+static int endpoint_get_fd(grpc_endpoint* secure_ep) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  return grpc_endpoint_get_fd(ep->wrapped_ep);
+}
+
+static grpc_resource_user* endpoint_get_resource_user(
+    grpc_endpoint* secure_ep) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  return grpc_endpoint_get_resource_user(ep->wrapped_ep);
+}
+
+static bool endpoint_can_track_err(grpc_endpoint* secure_ep) {
+  secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
+  return grpc_endpoint_can_track_err(ep->wrapped_ep);
+}
+
+static const grpc_endpoint_vtable vtable = {endpoint_read,
+                                            endpoint_write,
+                                            endpoint_add_to_pollset,
+                                            endpoint_add_to_pollset_set,
+                                            endpoint_delete_from_pollset_set,
+                                            endpoint_shutdown,
+                                            endpoint_destroy,
+                                            endpoint_get_resource_user,
+                                            endpoint_get_peer,
+                                            endpoint_get_fd,
+                                            endpoint_can_track_err};
+
+grpc_endpoint* grpc_secure_endpoint_create(
+    struct tsi_frame_protector* protector,
+    struct tsi_zero_copy_grpc_protector* zero_copy_protector,
+    grpc_endpoint* transport, grpc_slice* leftover_slices,
+    size_t leftover_nslices) {
+  secure_endpoint* ep = grpc_core::New<secure_endpoint>(
+      &vtable, protector, zero_copy_protector, transport, leftover_slices,
+      leftover_nslices);
+  return &ep->base;
+}