func MultiSAR()

in tc-health-client/sar/multisar.go [52:223]


// MultiSAR sends a TCP SYN to every host:port in hosts from a single held
// ephemeral source port, waits for the listener to collect the replies, and
// returns at most one SARResult per unique HostPort.
//
// Hosts may be IPs or names. Names are resolved once with net.LookupHost and
// the first returned address is used for every port of that host. Hosts that
// fail to resolve get a result with Err set and are not probed. A result's RTT
// is the time between sending its SYN and the listener's reported response
// time, and is 0 whenever Err is non-nil.
//
// The listener's timeout is started only after the last packet is sent. When
// several names resolve to the same IP and are probed on the same port, each of
// them receives that IP:port's response, because at L4 only the IP's behavior
// matters.
//
// It returns (nil, error) if the local address can't be determined or parsed,
// the ephemeral port can't be held, a TCP header can't be built, a packet
// can't be sent, or the listener fails.
func MultiSAR(log llog.Log, hosts []HostPort, timeout time.Duration) ([]SARResult, error) {
	log = llog.LibInit(log)

	localAddrStr, err := GetLocalAddr()
	if err != nil {
		return nil, errors.New("getting local address: " + err.Error())
	}

	localAddr := net.ParseIP(localAddrStr)
	if localAddr == nil {
		return nil, errors.New("failed to parse local addr '" + localAddrStr + "' as IP")
	}
	if v4 := localAddr.To4(); v4 != nil {
		localAddr = v4
	}

	ephemeralPortHolder, err := GetAndHoldEphemeralPort(localAddrStr)
	if err != nil {
		return nil, errors.New("failed to listen on ephemeral port: " + err.Error())
	}
	defer ephemeralPortHolder.Close()

	srcPort := ephemeralPortHolder.Port()

	// Pre-construct all the packets, so we listen for as little time as possible.

	// TODO implement Initial Sequence Number ISN per RFC9293§3.4.1? It might be faster to use the same seq num for all packets
	seqNum := uint32(42)

	// probe is one SYN to send: the requested HostPort, the resolved address
	// it is sent to, and the prebuilt header.
	type probe struct {
		HostPort
		Addr   string
		TCPHdr TCPHdr
	}

	probes := make([]probe, 0, len(hosts))
	remoteInf := make([]sarRemoteInf, 0, len(hosts))

	// map[host]addr: cached successful resolutions, so every port of a host is
	// probed (and later matched) at the same IP. For IP hosts, host == addr.
	hostAddr := make(map[string]string, len(hosts))
	// map[addr][]HostPort: every probed HostPort at that address. Multiple FQDNs
	// may share an IP, and one IP may be probed on several ports.
	addrHostPorts := make(map[string][]HostPort, len(hosts))
	// Each HostPort is probed once; duplicates in hosts would otherwise produce
	// duplicate packets and duplicate results.
	seen := make(map[HostPort]struct{}, len(hosts))

	results := []SARResult{}

	for _, host := range hosts {
		if _, dup := seen[host]; dup {
			continue
		}
		seen[host] = struct{}{}

		makeHostErrResult := func(err error) SARResult {
			return SARResult{
				Host: host.Host,
				Port: host.Port,
				RTT:  0,
				Err:  err,
			}
		}

		addr, cached := hostAddr[host.Host]
		if !cached {
			addr = host.Host // host is an IP, so host and addr are the same
			if net.ParseIP(host.Host) == nil {
				// host isn't an IP, assume FQDN
				addrs, err := net.LookupHost(host.Host)
				if err != nil {
					results = append(results, makeHostErrResult(errors.New("lookup up host '"+host.Host+"': "+err.Error())))
					continue
				}
				if len(addrs) == 0 {
					results = append(results, makeHostErrResult(errors.New("looking up host '"+host.Host+"' succeeded, but no addresses were found.")))
					continue
				}
				addr = addrs[0]
			}
		}

		remoteAddr := net.ParseIP(addr)
		if remoteAddr == nil {
			results = append(results, makeHostErrResult(errors.New("failed to parse addr '"+host.Host+"' ip '"+addr+"' as IP")))
			continue
		}
		hostAddr[host.Host] = addr
		addrHostPorts[addr] = append(addrHostPorts[addr], host)

		if v4 := remoteAddr.To4(); v4 != nil {
			remoteAddr = v4
		}

		// TODO handle IPv6

		window := 256 * 10
		dataOffset := 5 // TCP header length in 32-bit words; 5 (20 bytes) means no options
		hdr, err := TCPHdrFromNative(TCPHdrNative{
			SrcPort:    uint16(srcPort),
			DestPort:   uint16(host.Port),
			SeqNum:     seqNum,
			DataOffset: uint8(dataOffset), // 4 bits
			SYN:        true,
			Window:     uint16(window),
		})
		if err != nil {
			return nil, errors.New("converting native header to byte: " + err.Error())
		}
		hdr.SetChecksum(MakeTCPChecksum(hdr, localAddr, remoteAddr))

		probes = append(probes, probe{HostPort: host, Addr: addr, TCPHdr: hdr})
		// Only hosts that will actually be sent a packet are listened for.
		remoteInf = append(remoteInf, sarRemoteInf{
			Host:   host.Host,
			Addr:   addr,
			Port:   host.Port,
			AckNum: seqNum + 1,
		})
	}

	sarListenerResp := []sarListenerResp{}
	sarListenerErr := error(nil)
	wg := sync.WaitGroup{}
	wg.Add(1)
	doneSending := make(chan struct{}, 1) // we don't want to start the timeout until after we send the last packet. Buffer 1, the main sending thread doesn't block
	go func() {
		sarListenerResp, sarListenerErr = sarListener(log, localAddrStr, localAddr, remoteInf, timeout, srcPort, doneSending)
		ephemeralPortHolder.Close() // this isn't strictly necessary, the defer will close this shortly. But this closes it ASAP
		wg.Done()
	}()

	sendTimes := make(map[HostPort]time.Time, len(probes))

	packetSendStart := time.Now()
	for _, p := range probes {
		// Send to the resolved address rather than the name: MakeTCPChecksum was
		// given that address, so a fresh resolution of the name must not be able
		// to pick a different destination.
		sendTime, err := SendPacket(p.TCPHdr, p.Addr)
		if err != nil {
			// Still signal the listener so it starts its timeout and its goroutine
			// can finish, instead of possibly waiting forever after we return.
			// The buffer of 1 guarantees this send doesn't block.
			doneSending <- struct{}{}
			return nil, errors.New("sending packet: " + err.Error())
		}
		sendTimes[p.HostPort] = sendTime
	}

	// Milliseconds() yields an int64; dividing a Duration by time.Millisecond
	// would still be a Duration and print as e.g. "12nsms".
	log.Infof("multisar main thread sent %v packets in %vms\n", len(probes), time.Since(packetSendStart).Milliseconds())

	doneSending <- struct{}{} // send the doneSending message to the listener, so it sets the timeout
	wg.Wait()                 // wait for the listener to return and set sarListenerResp and sarListenerErr

	if sarListenerErr != nil {
		return nil, errors.New("listening for ACKs: " + sarListenerErr.Error())
	}

	// reported maps a HostPort to the index of its entry in results, so each
	// HostPort is reported once even if several responses match it.
	reported := make(map[HostPort]int, len(sendTimes))
	for _, listenResp := range sarListenerResp {
		// multiple FQDNs may have the same IP, but at L4 we only care about the behavior of the IP
		for _, hp := range addrHostPorts[listenResp.Addr] {
			if hp.Port != listenResp.Port {
				continue // same IP, but probed on a different port; this response isn't for it
			}
			sendTime, ok := sendTimes[hp]
			if !ok {
				log.Errorf("SAR listener got packet that was never sent somehow! Should never happen! Response: %+v\n", listenResp)
				continue
			}
			roundTripTime := time.Duration(0)
			if listenResp.Err == nil {
				// this check shouldn't be necessary, the caller should never look at duration if err!=nil. This just makes it easier to debug if they do
				roundTripTime = listenResp.RespTime.Sub(sendTime)
			}
			res := SARResult{
				Host: hp.Host,
				Port: hp.Port,
				RTT:  roundTripTime,
				Err:  listenResp.Err,
			}
			if i, done := reported[hp]; done {
				// Prefer a successful response over an earlier failed one for the same endpoint.
				if results[i].Err != nil && res.Err == nil {
					results[i] = res
				}
				continue
			}
			reported[hp] = len(results)
			results = append(results, res)
		}
	}
	return results, nil
}