tunnel/tcp_writer.go (192 lines of code) (raw):

package tunnel import ( "bytes" "encoding/binary" "fmt" "io" "net" "time" utils "github.com/alibaba/MongoShake/v2/common" nimo "github.com/gugemichael/nimo4go" LOG "github.com/vinllen/log4go" ) // Network packet structure // // [ Big-edian ] // Header (12 Bytes) // Body (n Bytes) // // [ Header structure ] // ----------------------------------------------------------------------------------- // | magic(2B) | version(1B) | type(1B) | crc32(4B) | length(4B) | // ----------------------------------------------------------------------------------- // | 0x00201314 | 0x01 | 0x01 | 0xFFFFF | 4096 | // ----------------------------------------------------------------------------------- // // [ PacketWrite payload ] // ------------------------------------------------------------------------------------------------------------------------------------------------- // | cksum(4B) | tag(4B) | shard(4B) | compress(4B) | number(4B) | len(4B) | log([]byte) | len(4B) | log([]byte) | // ------------------------------------------------------------------------------------------------------------------------------------------------- // // [ PacketGetACK payload ] // --------------| // | (zero) | // --------------| // // [ PacketReturnACK payload ] // ------------------ // | ack(4B) | // ------------------ // const ( MagicNumber = 0xCAFE CurrentVersion = 0x01 HeaderLen = 12 ) const ( PacketIncomplete uint8 = 0x00 PacketGetACK uint8 = 0x01 PacketWrite uint8 = 0x02 PacketReturnACK uint8 = 0x3 UndefinedPacketType uint8 = 0x4 ) const ( TransferChannel = iota RecvAckChannel TotalQueueNum ) const NetworkDefaultTimeout = 60 * time.Second type Packet struct { magic uint16 version uint8 typeOf uint8 crc32 uint32 length uint32 payload []byte } func NewPacketV1(packetType uint8, payload []byte) *Packet { return &Packet{magic: MagicNumber, version: CurrentVersion, typeOf: packetType, length: uint32(len(payload)), payload: payload} } func (packet *Packet) setPayload(payload []byte) { packet.payload = payload packet.length = uint32(len(payload)) } func (packet *Packet) encode() []byte { buffer := bytes.Buffer{} binary.Write(&buffer, binary.BigEndian, packet.magic) binary.Write(&buffer, binary.BigEndian, packet.version) binary.Write(&buffer, binary.BigEndian, packet.typeOf) // TODO: now crc32 is marked zero binary.Write(&buffer, binary.BigEndian, packet.crc32) binary.Write(&buffer, binary.BigEndian, packet.length) buffer.Write(packet.payload) nimo.AssertTrue(buffer.Len() == (HeaderLen+len(packet.payload)), "write packet header length is bad") return buffer.Bytes() } func (packet *Packet) decodeHeader(buffer []byte) bool { nimo.AssertTrue(len(buffer) == HeaderLen, "read packet header length is bad") buf := bytes.NewBuffer(buffer) binary.Read(buf, binary.BigEndian, &packet.magic) binary.Read(buf, binary.BigEndian, &packet.version) binary.Read(buf, binary.BigEndian, &packet.typeOf) binary.Read(buf, binary.BigEndian, &packet.crc32) binary.Read(buf, binary.BigEndian, &packet.length) return packet.valid() } func (packet *Packet) valid() bool { return packet.magic == MagicNumber && packet.version == CurrentVersion && packet.typeOf < UndefinedPacketType } func (packet *Packet) String() string { return fmt.Sprintf("[magic:%d, ver:%d, type:%d, crc:%d, len:%d]", packet.magic, packet.version, packet.typeOf, packet.crc32, packet.length) } type TCPWriter struct { RemoteAddr string // for tcp stream channel channel [2]*TcpSocket ack int64 } type TcpSocket struct { addr *net.TCPAddr socket *net.TCPConn } func (tcp *TcpSocket) ensureNetwork() error { if tcp.socket == nil { var err error tcp.socket, err = net.DialTCP("tcp4", nil, tcp.addr) if err != nil { LOG.Critical("channel connect to %s error %s", tcp.addr.String(), err.Error()) return err } tcp.socket.SetNoDelay(false) // linger policy is not required. our data kept in sender util acked tcp.socket.SetLinger(0) // default 16K. we set 16MB tcp.socket.SetWriteBuffer(1024 * 1024 * 16) } return nil } func (tcp *TcpSocket) release() { tcp.socket.Close() tcp.socket = nil } func (tunnel *TCPWriter) Name() string { return "rpc" } func (writer *TCPWriter) pollRemoteAckValue() { queryAck := NewPacketV1(PacketGetACK, nil).encode() header := [HeaderLen]byte{} tcp := writer.channel[RecvAckChannel] nimo.GoRoutineInLoop(func() { defer utils.DelayFor(1000) if tcp.ensureNetwork() != nil { return } // send get ack request socketTimeout(tcp.socket, NetworkDefaultTimeout) tcp.socket.Write(queryAck) // read util we got a entire header if _, err := io.ReadAtLeast(tcp.socket, header[:], HeaderLen); err != nil { tcpErrorAndRelease(tcp, err.Error()) return } result := NewPacketV1(PacketIncomplete, nil) if !result.decodeHeader(header[:]) { tcpErrorAndRelease(tcp, "decode header failed") return } nimo.AssertTrue(result.typeOf == PacketReturnACK && result.length != 4, "acker receive bad type queryAck") // it's bad response if length < 4 payload := make([]byte, result.length) if _, err := io.ReadAtLeast(tcp.socket, payload, int(result.length)); err != nil { tcpErrorAndRelease(tcp, err.Error()) return } result.setPayload(payload) nimo.AssertTrue(result.length == uint32(len(result.payload)) && len(result.payload) != 0, "acker receive bad payload queryAck") binary.Read(bytes.NewBuffer(result.payload), binary.BigEndian, &writer.ack) }) } func (writer *TCPWriter) Send(message *WMessage) int64 { tcp := writer.channel[TransferChannel] var err error if err = tcp.ensureNetwork(); err != nil { return ReplyNetworkOpFail } message.Tag |= MsgResident packet := NewPacketV1(PacketWrite, message.ToBytes(binary.BigEndian)) // TODO: no timeout ?? socketTimeout(tcp.socket, 0) if _, err = tcp.socket.Write(packet.encode()); err != nil { if err, ok := err.(net.Error); ok && err.Timeout() { LOG.Warn("Tcp writer send data packet timeout") return ReplyNetworkTimeout } tcp.release() return ReplyNetworkOpFail } return writer.ack } func (writer *TCPWriter) Prepare() bool { var err error writer.channel = [2]*TcpSocket{new(TcpSocket), new(TcpSocket)} for i := 0; i != TotalQueueNum; i++ { writer.channel[i].addr, err = net.ResolveTCPAddr("tcp4", writer.RemoteAddr) if err != nil { LOG.Critical("Resolve channel listenAddress error: %s", err.Error()) return false } } writer.channel[RecvAckChannel].addr.Port = writer.channel[TransferChannel].addr.Port + 1 // continuously update the ACK value via separate socket writer.pollRemoteAckValue() if !InitialStageChecking { return true } for _, ch := range writer.channel { if err = ch.ensureNetwork(); err != nil { return false } } return true } func (writer *TCPWriter) AckRequired() bool { return true } func (writer *TCPWriter) ParsedLogsRequired() bool { return false } func socketTimeout(socket *net.TCPConn, duration time.Duration) { if duration != 0 { socket.SetWriteDeadline(time.Now().Add(duration)) } } func tcpErrorAndRelease(socket *TcpSocket, err string) { LOG.Critical("tcp operation error and release, %s", err) socket.release() }