pkg/remoting/getty/readwriter.go (162 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package getty import ( "errors" "fmt" getty "github.com/apache/dubbo-getty" "seata.apache.org/seata-go/pkg/protocol/codec" "seata.apache.org/seata-go/pkg/protocol/message" "seata.apache.org/seata-go/pkg/util/bytes" ) // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 // +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+ // | magic |Proto | Full length | Head | Msg |Seria|Compr| RequestID | // | code |clVer | (head+body) | Length |Type |lizer|ess | | // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ // | | // | Head Map [Optional] | // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ // | | // | body | // | | // | ... ... | // +-----------------------------------------------------------------------------------------------+ // <li>Full Length: include all data </li> // <li>Head Length: include head data from magic code to head map. </li> // <li>Body Length: Full Length - Head Length</li> // </p> // https://github.com/seata/seata/issues/893 const ( Seatav1HeaderLength = 16 ) var ( magics = []uint8{0xda, 0xda} rpcPkgHandler = &RpcPackageHandler{} ) var ( ErrNotEnoughStream = errors.New("packet stream is not enough") ErrTooLargePackage = errors.New("package length is exceed the getty package's legal maximum length") ErrInvalidPackage = errors.New("invalid rpc package") ErrIllegalMagic = errors.New("package magic is not right") ) type RpcPackageHandler struct{} type SeataV1PackageHeader struct { Magic0 byte Magic1 byte Version byte TotalLength uint32 HeadLength uint16 MessageType message.GettyRequestType CodecType byte CompressType byte RequestID uint32 Meta map[string]string BodyLength uint32 Body interface{} } func (p *RpcPackageHandler) Read(ss getty.Session, data []byte) (interface{}, int, error) { in := bytes.NewByteBuffer(data) header := SeataV1PackageHeader{} magic0 := bytes.ReadByte(in) magic1 := bytes.ReadByte(in) if magic0 != magics[0] || magic1 != magics[1] { return nil, 0, fmt.Errorf("codec decode not found magic offset") } header.Magic0 = magic0 header.Magic1 = magic1 header.Version = bytes.ReadByte(in) // length of head and body header.TotalLength = bytes.ReadUInt32(in) header.HeadLength = bytes.ReadUInt16(in) header.MessageType = message.GettyRequestType(bytes.ReadByte(in)) header.CodecType = bytes.ReadByte(in) header.CompressType = bytes.ReadByte(in) header.RequestID = bytes.ReadUInt32(in) headMapLength := header.HeadLength - Seatav1HeaderLength header.Meta = decodeHeapMap(in, headMapLength) header.BodyLength = header.TotalLength - uint32(header.HeadLength) if uint32(len(data)) < header.TotalLength { return nil, int(header.TotalLength), nil } // r := byteio.BigEndianReader{Reader: bytes.NewReader(data)} rpcMessage := message.RpcMessage{ Codec: header.CodecType, ID: int32(header.RequestID), Compressor: header.CompressType, Type: header.MessageType, HeadMap: header.Meta, } if header.MessageType == message.GettyRequestTypeHeartbeatRequest { rpcMessage.Body = message.HeartBeatMessagePing } else if header.MessageType == message.GettyRequestTypeHeartbeatResponse { rpcMessage.Body = message.HeartBeatMessagePong } else { if header.BodyLength > 0 { msg := codec.GetCodecManager().Decode(codec.CodecType(header.CodecType), data[header.HeadLength:]) rpcMessage.Body = msg } } return rpcMessage, int(header.TotalLength), nil } // Write write rpc message to binary data func (p *RpcPackageHandler) Write(ss getty.Session, pkg interface{}) ([]byte, error) { msg, ok := pkg.(message.RpcMessage) if !ok { return nil, ErrInvalidPackage } totalLength := message.V1HeadLength headLength := message.V1HeadLength var headMapBytes []byte if len(msg.HeadMap) > 0 { hb, headMapLength := encodeHeapMap(msg.HeadMap) headMapBytes = hb headLength += headMapLength totalLength += headMapLength } var bodyBytes []byte if msg.Type != message.GettyRequestTypeHeartbeatRequest && msg.Type != message.GettyRequestTypeHeartbeatResponse { bodyBytes = codec.GetCodecManager().Encode(codec.CodecType(msg.Codec), msg.Body) totalLength += len(bodyBytes) } buf := bytes.NewByteBuffer([]byte{}) buf.WriteByte(message.MagicCodeBytes[0]) buf.WriteByte(message.MagicCodeBytes[1]) buf.WriteByte(message.VERSION) buf.WriteUint32(uint32(totalLength)) buf.WriteUint16(uint16(headLength)) buf.WriteByte(byte(msg.Type)) buf.WriteByte(msg.Codec) buf.WriteByte(msg.Compressor) buf.WriteUint32(uint32(msg.ID)) buf.Write(headMapBytes) buf.Write(bodyBytes) return buf.Bytes(), nil } func encodeHeapMap(data map[string]string) ([]byte, int) { buf := bytes.NewByteBuffer([]byte{}) for k, v := range data { if k == "" { buf.WriteUint16(uint16(0)) } else { buf.WriteUint16(uint16(len(k))) buf.WriteString(k) } if v == "" { buf.WriteUint16(uint16(0)) } else { buf.WriteUint16(uint16(len(v))) buf.WriteString(v) } } res := buf.Bytes() return res, len(res) } func decodeHeapMap(in *bytes.ByteBuffer, length uint16) map[string]string { res := make(map[string]string, 0) if length == 0 { return res } readedLength := uint16(0) for readedLength < length { var key, value string keyLength := bytes.ReadUInt16(in) if keyLength == 0 { key = "" } else { keyBytes := make([]byte, keyLength) in.Read(keyBytes) key = string(keyBytes) } valueLength := bytes.ReadUInt16(in) if valueLength == 0 { key = "" } else { valueBytes := make([]byte, valueLength) in.Read(valueBytes) value = string(valueBytes) } res[key] = value readedLength += 4 + keyLength + valueLength fmt.Sprintln("done") } return res }