void AssignOpAttr()

in tfjs-node/binding/tfjs_backend.cc [502:649]


// Applies a single JS-described attribute to a TensorFlow eager op.
//
// `attr_value` is a JS object from which three properties are read:
//   - "name":  the attribute name (string),
//   - "type":  the attribute type, read as an int32 and interpreted as a
//              TF_AttrType,
//   - "value": the attribute payload; how it is decoded depends on "type".
//
// Supported types are TF_ATTR_STRING, TF_ATTR_INT, TF_ATTR_FLOAT,
// TF_ATTR_BOOL (each of INT/FLOAT/BOOL accepts either a scalar or a JS array),
// TF_ATTR_TYPE and TF_ATTR_SHAPE. Any other type is handed to
// REPORT_UNKNOWN_TF_ATTR_TYPE.
//
// Every N-API call is followed by ENSURE_NAPI_OK (defined elsewhere;
// presumably it reports the failure and returns from this function early --
// confirm against the macro definition). Nothing is returned to the caller.
void AssignOpAttr(napi_env env, TFE_Op *tfe_op, napi_value attr_value) {
  napi_status nstatus;

  napi_value attr_name_value;
  nstatus = napi_get_named_property(env, attr_value, "name", &attr_name_value);
  ENSURE_NAPI_OK(env, nstatus);

  std::string attr_name_string;
  nstatus = GetStringParam(env, attr_name_value, attr_name_string);
  ENSURE_NAPI_OK(env, nstatus);

  // OpAttr will be used beyond the scope of this function call. Stash ops in
  // a set for re-use instead of dynamically reallocating strings for
  // operations. `attr_name` points at the string stored inside ATTR_NAME_SET,
  // not at the local `attr_name_string`.
  const char *attr_name =
      ATTR_NAME_SET.insert(attr_name_string.c_str()).first->c_str();

  napi_value attr_type_value;
  nstatus = napi_get_named_property(env, attr_value, "type", &attr_type_value);
  ENSURE_NAPI_OK(env, nstatus);

  // NOTE(review): the cast writes an int32 straight into the enum's storage,
  // which assumes TF_AttrType is 32 bits wide -- confirm for all targets.
  TF_AttrType tf_attr_type;
  nstatus = napi_get_value_int32(env, attr_type_value,
                                 reinterpret_cast<int32_t *>(&tf_attr_type));
  ENSURE_NAPI_OK(env, nstatus);

  napi_value js_value;
  nstatus = napi_get_named_property(env, attr_value, "value", &js_value);
  ENSURE_NAPI_OK(env, nstatus);

  switch (tf_attr_type) {
    case TF_ATTR_STRING: {
      // NOTE: String attribute values do not have to be utf8 encoded strings
      // (could be arbitrary byte sequences).
      std::string str_value;
      nstatus = GetStringParam(env, js_value, str_value);
      ENSURE_NAPI_OK(env, nstatus);

      // Pass the explicit length rather than relying on NUL termination.
      TFE_OpSetAttrString(tfe_op, attr_name, str_value.c_str(),
                          str_value.size());
      break;
    }

    case TF_ATTR_INT: {
      if (IsArray(env, nstatus, &js_value)) {
        uint32_t length;
        nstatus = napi_get_array_length(env, js_value, &length);
        ENSURE_NAPI_OK(env, nstatus);
        std::unique_ptr<int64_t[]> data(new int64_t[length]);
        for (uint32_t i = 0; i < length; ++i) {
          napi_value element;
          nstatus = napi_get_element(env, js_value, i, &element);
          ENSURE_NAPI_OK(env, nstatus);
          // Read each element as int64, matching the scalar branch below and
          // the int64_t buffer handed to TFE_OpSetAttrIntList. Reading as
          // int32 would keep only the low 32 bits of larger values.
          int64_t value;
          nstatus = napi_get_value_int64(env, element, &value);
          ENSURE_NAPI_OK(env, nstatus);
          data[i] = value;
        }
        TFE_OpSetAttrIntList(tfe_op, attr_name, data.get(),
                             static_cast<int>(length));
      } else {
        int64_t value;
        nstatus = napi_get_value_int64(env, js_value, &value);
        ENSURE_NAPI_OK(env, nstatus);

        TFE_OpSetAttrInt(tfe_op, attr_name, value);
      }
      break;
    }

    case TF_ATTR_FLOAT: {
      if (IsArray(env, nstatus, &js_value)) {
        uint32_t length;
        nstatus = napi_get_array_length(env, js_value, &length);
        ENSURE_NAPI_OK(env, nstatus);
        std::unique_ptr<float[]> data(new float[length]);
        for (uint32_t i = 0; i < length; ++i) {
          napi_value element;
          nstatus = napi_get_element(env, js_value, i, &element);
          ENSURE_NAPI_OK(env, nstatus);
          // JS numbers are read as double and then narrowed to float.
          double value;
          nstatus = napi_get_value_double(env, element, &value);
          ENSURE_NAPI_OK(env, nstatus);
          data[i] = static_cast<float>(value);
        }
        TFE_OpSetAttrFloatList(tfe_op, attr_name, data.get(),
                               static_cast<int>(length));
      } else {
        double value;
        nstatus = napi_get_value_double(env, js_value, &value);
        ENSURE_NAPI_OK(env, nstatus);
        TFE_OpSetAttrFloat(tfe_op, attr_name, static_cast<float>(value));
      }
      break;
    }

    case TF_ATTR_BOOL: {
      if (IsArray(env, nstatus, &js_value)) {
        uint32_t length;
        nstatus = napi_get_array_length(env, js_value, &length);
        ENSURE_NAPI_OK(env, nstatus);
        // Booleans are passed to TensorFlow as one unsigned char (0 or 1)
        // per element.
        std::unique_ptr<unsigned char[]> data(new unsigned char[length]);
        for (uint32_t i = 0; i < length; ++i) {
          napi_value element;
          nstatus = napi_get_element(env, js_value, i, &element);
          ENSURE_NAPI_OK(env, nstatus);
          bool value;
          nstatus = napi_get_value_bool(env, element, &value);
          ENSURE_NAPI_OK(env, nstatus);
          data[i] = value ? 1 : 0;
        }
        TFE_OpSetAttrBoolList(tfe_op, attr_name, data.get(),
                              static_cast<int>(length));
      } else {
        bool value;
        nstatus = napi_get_value_bool(env, js_value, &value);
        ENSURE_NAPI_OK(env, nstatus);
        TFE_OpSetAttrBool(tfe_op, attr_name, value ? 1 : 0);
      }
      break;
    }

    case TF_ATTR_TYPE: {
      // Only a single data type value is handled here (no array form).
      // NOTE(review): same 32-bit enum-size assumption as for TF_AttrType.
      TF_DataType tf_data_type;
      nstatus = napi_get_value_int32(
          env, js_value, reinterpret_cast<int32_t *>(&tf_data_type));
      ENSURE_NAPI_OK(env, nstatus);

      TFE_OpSetAttrType(tfe_op, attr_name, tf_data_type);
      break;
    }

    case TF_ATTR_SHAPE: {
      // ExtractArrayShape (defined elsewhere) fills `shape_vector` from the
      // JS value; its result is not checked here.
      std::vector<int64_t> shape_vector;
      ExtractArrayShape(env, js_value, &shape_vector);

      TF_AutoStatus tf_status;
      TFE_OpSetAttrShape(tfe_op, attr_name, shape_vector.data(),
                         shape_vector.size(), tf_status.status);
      ENSURE_TF_OK(env, tf_status);
      break;
    }

    default:
      REPORT_UNKNOWN_TF_ATTR_TYPE(env, tf_attr_type);
      break;
  }
}