internal/remoting/handshake.go (82 lines of code) (raw):

/* * Copyright (c) 2023 Alibaba Group Holding Ltd. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package remoting import ( "context" "encoding/binary" "fmt" "io" "net" "time" "google.golang.org/protobuf/proto" "github.com/alibaba/schedulerx-worker-go/internal/constants" "github.com/alibaba/schedulerx-worker-go/internal/proto/akka" "github.com/alibaba/schedulerx-worker-go/internal/remoting/trans" "github.com/alibaba/schedulerx-worker-go/internal/utils" "github.com/alibaba/schedulerx-worker-go/logger" ) func Handshake(ctx context.Context, conn net.Conn) error { if err := sendHandshake(ctx, conn); err != nil { return fmt.Errorf("Write handshake to remote failed, err=%s. ", err.Error()) } waitRespTimeout := 5 * time.Second waitTimeout := time.After(waitRespTimeout) for { select { case <-waitTimeout: return fmt.Errorf("Wait handshake response timeout, timeout=%s ", waitRespTimeout.String()) default: var dataLen uint32 hdrBuf := make([]byte, constants.TransportHeaderSize) n, err := io.ReadFull(conn, hdrBuf) if err == io.EOF { continue } if n < constants.TransportHeaderSize { logger.Errorf("Read header from connection failed, read bytes=%d but expect bytes=%d", n, constants.TransportHeaderSize) continue } dataLen = binary.BigEndian.Uint32(hdrBuf) dataBuf := make([]byte, dataLen) n, err = io.ReadFull(conn, dataBuf) if err == io.EOF { continue } if n < int(dataLen) { logger.Errorf("Read payload from connection failed, read bytes=%d but expect bytes=%d", n, dataLen) continue } msg, err := trans.ReadAkkaMsg(dataBuf) if err != nil { return fmt.Errorf("handshake read akka msg err=%+v ", err) } if controlMsg := msg.Instruction; controlMsg != nil && controlMsg.CommandType != nil { if int32(*controlMsg.CommandType) == int32(akka.CommandType_ASSOCIATE) { logger.Infof("Receive handshake msg, msg=%+v ", controlMsg) return nil } } else { return fmt.Errorf("Receive unknown msg type when wait handshake response, msg=%+v ", msg) } } } } func sendHandshake(ctx context.Context, conn net.Conn) error { host, port, err := utils.ParseIPAddr(conn.LocalAddr().String()) if err != nil { return err } akkaMsg := &akka.AkkaProtocolMessage{ Instruction: &akka.AkkaControlMessage{ CommandType: akka.CommandType_ASSOCIATE.Enum(), HandshakeInfo: &akka.AkkaHandshakeInfo{ Origin: &akka.AddressData{ System: proto.String(utils.GetWorkerId()), Hostname: proto.String(host), Port: proto.Uint32(uint32(port)), Protocol: proto.String("tcp"), }, Uid: proto.Uint64(utils.GetHandshakeUid()), }, }, } return trans.WriteAkkaMsg(akkaMsg, conn) }