thrift/lib/cpp2/protocol/BinaryProtocol-inl.h (504 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef THRIFT2_PROTOCOL_TBINARYPROTOCOL_TCC_ #define THRIFT2_PROTOCOL_TBINARYPROTOCOL_TCC_ 1 #include <thrift/lib/cpp2/protocol/BinaryProtocol.h> #include <limits> #include <string> namespace apache { namespace thrift { uint32_t BinaryProtocolWriter::writeMessageBegin( folly::StringPiece name, MessageType messageType, int32_t seqid) { int32_t version = (VERSION_1) | ((int32_t)messageType); uint32_t wsize = 0; wsize += writeI32(version); wsize += writeString(name); wsize += writeI32(seqid); return wsize; } uint32_t BinaryProtocolWriter::writeMessageEnd() { return 0; } uint32_t BinaryProtocolWriter::writeStructBegin(const char* /*name*/) { return 0; } uint32_t BinaryProtocolWriter::writeStructEnd() { return 0; } uint32_t BinaryProtocolWriter::writeFieldBegin( const char* /*name*/, TType fieldType, int16_t fieldId) { uint32_t wsize = 0; wsize += writeByte((int8_t)fieldType); wsize += writeI16(fieldId); return wsize; } uint32_t BinaryProtocolWriter::writeFieldEnd() { return 0; } uint32_t BinaryProtocolWriter::writeFieldStop() { return writeByte((int8_t)TType::T_STOP); } uint32_t BinaryProtocolWriter::writeMapBegin( const TType keyType, TType valType, uint32_t size) { uint32_t wsize = 0; wsize += writeByte((int8_t)keyType); wsize += writeByte((int8_t)valType); wsize += writeI32((int32_t)size); return wsize; } uint32_t BinaryProtocolWriter::writeMapEnd() { return 0; } uint32_t BinaryProtocolWriter::writeListBegin(TType elemType, uint32_t size) { uint32_t wsize = 0; wsize += writeByte((int8_t)elemType); wsize += writeI32((int32_t)size); return wsize; } uint32_t BinaryProtocolWriter::writeListEnd() { return 0; } uint32_t BinaryProtocolWriter::writeSetBegin(TType elemType, uint32_t size) { uint32_t wsize = 0; wsize += writeByte((int8_t)elemType); wsize += writeI32((int32_t)size); return wsize; } uint32_t BinaryProtocolWriter::writeSetEnd() { return 0; } uint32_t BinaryProtocolWriter::writeBool(bool value) { out_.write(detail::validate_bool(value)); return sizeof(value); } uint32_t BinaryProtocolWriter::writeByte(int8_t byte) { out_.write(byte); return sizeof(byte); } uint32_t BinaryProtocolWriter::writeI16(int16_t i16) { out_.writeBE(i16); return sizeof(i16); } uint32_t BinaryProtocolWriter::writeI32(int32_t i32) { out_.writeBE(i32); return sizeof(i32); } uint32_t BinaryProtocolWriter::writeI64(int64_t i64) { out_.writeBE(i64); return sizeof(i64); } uint32_t BinaryProtocolWriter::writeDouble(double dub) { static_assert(sizeof(double) == sizeof(uint64_t), ""); static_assert(std::numeric_limits<double>::is_iec559, ""); uint64_t bits = folly::bit_cast<uint64_t>(dub); out_.writeBE(bits); return sizeof(bits); } uint32_t BinaryProtocolWriter::writeFloat(float flt) { static_assert(sizeof(float) == sizeof(uint32_t), ""); static_assert(std::numeric_limits<float>::is_iec559, ""); uint32_t bits = folly::bit_cast<uint32_t>(flt); out_.writeBE(bits); return sizeof(bits); } uint32_t BinaryProtocolWriter::writeString(folly::StringPiece str) { return writeBinary(str); } uint32_t BinaryProtocolWriter::writeBinary(folly::StringPiece str) { return writeBinary(folly::ByteRange(str)); } uint32_t BinaryProtocolWriter::writeBinary(folly::ByteRange v) { uint32_t size = folly::to_narrow(v.size()); uint32_t result = writeI32((int32_t)size); out_.push(v.data(), size); return result + size; } uint32_t BinaryProtocolWriter::writeBinary( const std::unique_ptr<folly::IOBuf>& str) { if (!str) { return writeI32(0); } return writeBinary(*str); } uint32_t BinaryProtocolWriter::writeBinary(const folly::IOBuf& str) { return writeBinaryImpl<true>(str); } uint32_t BinaryProtocolWriter::writeRaw(const folly::IOBuf& str) { return writeBinaryImpl<false>(str); } template <bool kWriteSize> uint32_t BinaryProtocolWriter::writeBinaryImpl(const folly::IOBuf& str) { size_t size = str.computeChainDataLength(); // leave room for size uint32_t limit = std::numeric_limits<uint32_t>::max() - serializedSizeI32(); if (size > limit) { TProtocolException::throwExceededSizeLimit(size, limit); } uint32_t result = kWriteSize ? writeI32((int32_t)size) : 0; if (sharing_ != SHARE_EXTERNAL_BUFFER && !str.isManaged()) { const auto growth = size - out_.length(); for (folly::ByteRange buf : str) { const auto tailroom = out_.length(); if (tailroom < buf.size()) { out_.push(buf.uncheckedSubpiece(0, tailroom)); buf.uncheckedAdvance(tailroom); out_.ensure(growth); } out_.push(buf); } } else { out_.insert(str); } return result + static_cast<uint32_t>(size); } void BinaryProtocolWriter::rewriteDouble(double dub, int64_t offset) { auto cursor = RWCursor(out_); cursor.advanceToEnd(); cursor -= offset; cursor.writeBE(folly::bit_cast<uint64_t>(dub)); } folly::io::Cursor BinaryProtocolWriter::tail(size_t n) { auto cursor = RWCursor(out_); cursor.advanceToEnd(); return {cursor - n, n}; } /** * Functions that return the serialized size */ uint32_t BinaryProtocolWriter::serializedMessageSize( folly::StringPiece name) const { // I32{version} + String{name} + I32{seqid} return 2 * serializedSizeI32() + serializedSizeString(name); } uint32_t BinaryProtocolWriter::serializedFieldSize( const char* /*name*/, TType /*fieldType*/, int16_t /*fieldId*/) const { // byte + I16 return serializedSizeByte() + serializedSizeI16(); } uint32_t BinaryProtocolWriter::serializedStructSize(const char* /*name*/ ) const { return 0; } uint32_t BinaryProtocolWriter::serializedSizeMapBegin( TType /*keyType*/, TType /*valType*/, uint32_t /*size*/) const { return serializedSizeByte() + serializedSizeByte() + serializedSizeI32(); } uint32_t BinaryProtocolWriter::serializedSizeMapEnd() const { return 0; } uint32_t BinaryProtocolWriter::serializedSizeListBegin( TType /*elemType*/, uint32_t /*size*/ ) const { return serializedSizeByte() + serializedSizeI32(); } uint32_t BinaryProtocolWriter::serializedSizeListEnd() const { return 0; } uint32_t BinaryProtocolWriter::serializedSizeSetBegin( TType /*elemType*/, uint32_t /*size*/) const { return serializedSizeByte() + serializedSizeI32(); } uint32_t BinaryProtocolWriter::serializedSizeSetEnd() const { return 0; } uint32_t BinaryProtocolWriter::serializedSizeStop() const { return 1; } uint32_t BinaryProtocolWriter::serializedSizeBool(bool /*val*/) const { return 1; } uint32_t BinaryProtocolWriter::serializedSizeByte(int8_t /*val*/) const { return 1; } uint32_t BinaryProtocolWriter::serializedSizeI16(int16_t /*val*/) const { return 2; } uint32_t BinaryProtocolWriter::serializedSizeI32(int32_t /*val*/) const { return 4; } uint32_t BinaryProtocolWriter::serializedSizeI64(int64_t /*val*/) const { return 8; } uint32_t BinaryProtocolWriter::serializedSizeDouble(double /*val*/) const { return 8; } uint32_t BinaryProtocolWriter::serializedSizeFloat(float /*val*/) const { return 4; } uint32_t BinaryProtocolWriter::serializedSizeString( folly::StringPiece str) const { return serializedSizeBinary(str); } uint32_t BinaryProtocolWriter::serializedSizeBinary( folly::StringPiece str) const { return serializedSizeBinary(folly::ByteRange(str)); } uint32_t BinaryProtocolWriter::serializedSizeBinary( folly::ByteRange str) const { // I32{length of string} + binary{string contents} return serializedSizeI32() + static_cast<uint32_t>(str.size()); } uint32_t BinaryProtocolWriter::serializedSizeBinary( std::unique_ptr<folly::IOBuf> const& v) const { return v ? serializedSizeBinary(*v) : 0; } uint32_t BinaryProtocolWriter::serializedSizeBinary( folly::IOBuf const& v) const { size_t size = v.computeChainDataLength(); uint32_t limit = std::numeric_limits<uint32_t>::max() - serializedSizeI32(); if (size > limit) { TProtocolException::throwExceededSizeLimit(size, limit); } return serializedSizeI32() + static_cast<uint32_t>(size); } uint32_t BinaryProtocolWriter::serializedSizeZCBinary( folly::StringPiece str) const { return serializedSizeZCBinary(folly::ByteRange(str)); } uint32_t BinaryProtocolWriter::serializedSizeZCBinary( folly::ByteRange v) const { return serializedSizeBinary(v); } uint32_t BinaryProtocolWriter::serializedSizeZCBinary( std::unique_ptr<folly::IOBuf> const& v) const { return v ? serializedSizeZCBinary(*v) : 0; } uint32_t BinaryProtocolWriter::serializedSizeZCBinary( folly::IOBuf const& v) const { size_t size = v.computeChainDataLength(); return (size > folly::IOBufQueue::kMaxPackCopy) ? serializedSizeI32() // too big to pack: size only : static_cast<uint32_t>(size) + serializedSizeI32(); // size + packed data } /** * Reading functions */ void BinaryProtocolReader::readMessageBegin( std::string& name, MessageType& messageType, int32_t& seqid) { int32_t sz; readI32(sz); if (sz < 0) { // Check for correct version number int32_t version = sz & VERSION_MASK; if (version != VERSION_1) { throwBadVersionIdentifier(sz); } messageType = (MessageType)(sz & 0x000000ff); readString(name); readI32(seqid); } else { if (this->strict_read_) { throwMissingVersionIdentifier(sz); } else { // Handle pre-versioned input int8_t type; readStringBody(name, sz); readByte(type); messageType = (MessageType)type; readI32(seqid); } } } void BinaryProtocolReader::readMessageEnd() {} void BinaryProtocolReader::readStructBegin(std::string& name) { name = ""; } void BinaryProtocolReader::readStructEnd() {} void BinaryProtocolReader::readFieldBegin( std::string& /*name*/, TType& fieldType, int16_t& fieldId) { int8_t type; readByte(type); fieldType = (TType)type; if (fieldType == TType::T_STOP) { fieldId = 0; return; } readI16(fieldId); } void BinaryProtocolReader::readFieldEnd() {} void BinaryProtocolReader::readMapBegin( TType& keyType, TType& valType, uint32_t& size) { int8_t k, v; int32_t sizei; readByte(k); keyType = (TType)k; readByte(v); valType = (TType)v; readI32(sizei); checkContainerSize(sizei); size = (uint32_t)sizei; } void BinaryProtocolReader::readMapEnd() {} void BinaryProtocolReader::readListBegin(TType& elemType, uint32_t& size) { int8_t e; int32_t sizei; readByte(e); elemType = (TType)e; readI32(sizei); checkContainerSize(sizei); size = (uint32_t)sizei; } void BinaryProtocolReader::readListEnd() {} void BinaryProtocolReader::readSetBegin(TType& elemType, uint32_t& size) { int8_t e; int32_t sizei; readByte(e); elemType = (TType)e; readI32(sizei); checkContainerSize(sizei); size = (uint32_t)sizei; } void BinaryProtocolReader::readSetEnd() {} void BinaryProtocolReader::readBool(bool& value) { auto byte = in_.read<uint8_t>(); if (byte >= 2) { TProtocolException::throwBoolValueOutOfRange(byte); } value = static_cast<bool>(byte); } void BinaryProtocolReader::readBool(std::vector<bool>::reference value) { bool ret = false; readBool(ret); value = ret; } void BinaryProtocolReader::readByte(int8_t& byte) { byte = in_.read<int8_t>(); } void BinaryProtocolReader::readI16(int16_t& i16) { i16 = in_.readBE<int16_t>(); } void BinaryProtocolReader::readI32(int32_t& i32) { i32 = in_.readBE<int32_t>(); } void BinaryProtocolReader::readI64(int64_t& i64) { i64 = in_.readBE<int64_t>(); } void BinaryProtocolReader::readDouble(double& dub) { static_assert(sizeof(double) == sizeof(uint64_t), ""); static_assert(std::numeric_limits<double>::is_iec559, ""); uint64_t bits = in_.readBE<int64_t>(); dub = folly::bit_cast<double>(bits); } void BinaryProtocolReader::readFloat(float& flt) { static_assert(sizeof(float) == sizeof(uint32_t), ""); static_assert(std::numeric_limits<double>::is_iec559, ""); uint32_t bits = in_.readBE<int32_t>(); flt = folly::bit_cast<float>(bits); } void BinaryProtocolReader::checkStringSize(int32_t size) { // Catch error cases if (size < 0) { TProtocolException::throwNegativeSize(); } if (string_limit_ > 0 && size > string_limit_) { TProtocolException::throwExceededSizeLimit(size, string_limit_); } } void BinaryProtocolReader::checkContainerSize(int32_t size) { if (size < 0) { TProtocolException::throwNegativeSize(); } else if (container_limit_ && size > container_limit_) { TProtocolException::throwExceededSizeLimit(size, container_limit_); } } template <typename StrType> void BinaryProtocolReader::readString(StrType& str) { int32_t size; readI32(size); readStringBody(str, size); } template <typename StrType> void BinaryProtocolReader::readBinary(StrType& str) { readString(str); } void BinaryProtocolReader::readBinary(std::unique_ptr<folly::IOBuf>& str) { if (!str) { str = std::make_unique<folly::IOBuf>(); } readBinary(*str); } void BinaryProtocolReader::readBinary(folly::IOBuf& str) { int32_t size; readI32(size); checkStringSize(size); in_.clone(str, size); if (sharing_ != SHARE_EXTERNAL_BUFFER && !str.isManaged()) { str = str.cloneCoalescedAsValueWithHeadroomTailroom(0, 0); str.makeManaged(); } } template <typename StrType> void BinaryProtocolReader::readStringBody(StrType& str, int32_t size) { checkStringSize(size); // Catch empty string case if (size == 0) { str.clear(); return; } if (static_cast<int32_t>(in_.length()) < size) { if (!in_.canAdvance(size)) { protocol::TProtocolException::throwTruncatedData(); } str.reserve(size); // only reserve for multi iter case below } str.clear(); size_t size_left = size; while (size_left > 0) { auto data = in_.peekBytes(); auto data_avail = std::min(data.size(), size_left); if (data.empty()) { TProtocolException::throwTruncatedData(); } str.append((const char*)data.data(), data_avail); size_left -= data_avail; in_.skipNoAdvance(data_avail); } } bool BinaryProtocolReader::advanceToNextField( int16_t nextFieldId, TType nextFieldType, StructReadState& state) { if (nextFieldType == TType::T_STOP) { if (in_.length() && *in_.data() == TType::T_STOP) { in_.skipNoAdvance(1); return true; } } else { if (in_.length() >= 3) { uint8_t type = *in_.data(); if (nextFieldType == type) { int16_t fieldId = folly::Endian::big(folly::loadUnaligned<int16_t>(in_.data() + 1)); in_.skipNoAdvance(3); if (nextFieldId == fieldId) { return true; } state.fieldType = (TType)type; state.fieldId = fieldId; return false; } state.fieldType = (TType)type; if (type != TType::T_STOP) { state.fieldId = folly::Endian::big(folly::loadUnaligned<int16_t>(in_.data() + 1)); in_.skipNoAdvance(3); } else { in_.skipNoAdvance(1); } return false; } } state.readFieldBeginNoInline(this); return false; } void BinaryProtocolReader::readFieldBeginWithState(StructReadState& state) { int8_t type; readByte(type); state.fieldType = (TType)type; if (state.fieldType == TType::T_STOP) { return; } readI16(state.fieldId); } constexpr std::size_t BinaryProtocolReader::fixedSizeInContainer(TType type) { switch (type) { case TType::T_BOOL: case TType::T_BYTE: return 1; case TType::T_I16: return 2; case TType::T_I32: case TType::T_FLOAT: return 4; case TType::T_I64: case TType::T_DOUBLE: return 8; default: return 0; } } } // namespace thrift } // namespace apache #endif // #ifndef THRIFT2_PROTOCOL_TBINARYPROTOCOL_TCC_