lib/handler.go (300 lines of code) (raw):

/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ package dhcplb import ( "errors" "fmt" "net" "runtime/debug" "time" "github.com/golang/glog" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4/ztpv4" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/dhcpv6/ztpv6" ) // List of possible errors. const ( ErrUnknown = "E_UNKNOWN" ErrPanic = "E_PANIC" ErrRead = "E_READ" ErrConnect = "E_CONN" ErrWrite = "E_WRITE" ErrGi0 = "E_GI_0" ErrParse = "E_PARSE" ErrNoServer = "E_NO_SERVER" ErrConnRate = "E_CONN_RATE" ) func (s *Server) handleConnection() { buffer := s.bufPool.Get().([]byte) bytesRead, peer, err := s.conn.ReadFromUDP(buffer) if err != nil || bytesRead == 0 { s.bufPool.Put(buffer) msg := "error reading from %s: %v" glog.Errorf(msg, peer, err) s.logger.LogErr(time.Now(), nil, nil, peer, ErrRead, err) return } go func() { defer func() { // always release this routine's buffer back to the pool s.bufPool.Put(buffer) if r := recover(); r != nil { glog.Errorf("Panicked handling v%d packet from %s: %s", s.config.Version, peer, r) glog.Errorf("Offending packet: %x", buffer[:bytesRead]) err, _ := r.(error) s.logger.LogErr(time.Now(), nil, nil, peer, ErrPanic, err) glog.Errorf("%s: %s", r, debug.Stack()) } }() if s.config.Version == 4 { s.handleRawPacketV4(buffer[:bytesRead], peer) } else if s.config.Version == 6 { s.handleRawPacketV6(buffer[:bytesRead], peer) } }() } func selectDestinationServer(config *Config, message *DHCPMessage) (*DHCPServer, error) { server, err := handleOverride(config, message) if err != nil { glog.Errorf("Error handling override, drop due to: %s", err) return nil, err } if server == nil { server, err = config.Algorithm.SelectRatioBasedDhcpServer(message) } return server, err } func handleOverride(config *Config, message *DHCPMessage) (*DHCPServer, error) { if override, ok := config.Overrides[message.Mac.String()]; ok { // Checking if override is expired. If so, ignore it. Expiration field should // be a timestamp in the following format "2006/01/02 15:04 -0700". // For example, a timestamp in UTC would look as follows: "2017/05/06 14:00 +0000". var err error var expiration time.Time if override.Expiration != "" { expiration, err = time.Parse("2006/01/02 15:04 -0700", override.Expiration) if err != nil { glog.Errorf("Could not parse override expiration for MAC %s: %s", message.Mac.String(), err.Error()) return nil, nil } if time.Now().After(expiration) { glog.Errorf("Override rule for MAC %s expired on %s, ignoring", message.Mac.String(), expiration.Local()) return nil, nil } } if override.Expiration == "" { glog.Infof("Found override rule for %s without expiration", message.Mac.String()) } else { glog.Infof("Found override rule for %s, it will expire on %s", message.Mac.String(), expiration.Local()) } var server *DHCPServer if len(override.Host) > 0 { server, err = handleHostOverride(config, override.Host) } else if len(override.Tier) > 0 { server, err = handleTierOverride(config, override.Tier, message) } if err != nil { return nil, err } if server != nil { return server, nil } glog.Infof("Override didn't have host or tier, this shouldn't happen, proceeding with normal server selection") } return nil, nil } func handleHostOverride(config *Config, host string) (*DHCPServer, error) { addr := net.ParseIP(host) if addr == nil { return nil, fmt.Errorf("Failed to get IP for overridden host %s", host) } port := 67 if config.Version == 6 { port = 547 } server := NewDHCPServer(host, addr, port) return server, nil } func handleTierOverride(config *Config, tier string, message *DHCPMessage) (*DHCPServer, error) { servers, err := config.HostSourcer.GetServersFromTier(tier) if err != nil { return nil, fmt.Errorf("Failed to get servers from tier: %s", err) } if len(servers) == 0 { return nil, fmt.Errorf("Sourcer returned no servers") } // pick server according to the configured algorithm server, err := config.Algorithm.SelectServerFromList(servers, message) if err != nil { return nil, fmt.Errorf("Failed to select server: %s", err) } return server, nil } func (s *Server) sendToServer(start time.Time, server *DHCPServer, packet []byte, peer *net.UDPAddr) error { // Check for connection rate ok, err := s.throttle.OK(server.Address.String()) if !ok { glog.Errorf("Error writing to server %s, drop due to throttling", server.Hostname) s.logger.LogErr(time.Now(), server, packet, peer, ErrConnRate, err) return err } _, err = s.conn.WriteTo(packet, server.udpAddr()) if err != nil { glog.Errorf("Error writing to server %s, drop due to %s", server.Hostname, err) s.logger.LogErr(start, server, packet, peer, ErrWrite, err) return err } err = s.logger.LogSuccess(start, server, packet, peer) if err != nil { glog.Errorf("Failed to log request: %s", err) } return nil } func (s *Server) handleRawPacketV4(buffer []byte, peer *net.UDPAddr) { // runs in a separate go routine start := time.Now() var message DHCPMessage packet, err := dhcpv4.FromBytes(buffer) if err != nil { glog.Errorf("Error encoding DHCPv4 packet: %s", err) s.logger.LogErr(start, nil, nil, peer, ErrParse, err) return } if s.server { s.handleV4Server(start, packet, peer) return } message.XID = packet.TransactionID[:] message.Peer = peer message.ClientID = packet.ClientHWAddr message.Mac = packet.ClientHWAddr if vd, err := ztpv4.ParseVendorData(packet); err != nil { glog.V(2).Infof("error parsing vendor data: %s", err) } else { message.Serial = vd.Serial } packet.HopCount++ server, err := selectDestinationServer(s.config, &message) if err != nil { glog.Errorf("%s, Drop due to %s", packet.Summary(), err) s.logger.LogErr(start, nil, packet.ToBytes(), peer, ErrNoServer, err) return } s.sendToServer(start, server, packet.ToBytes(), peer) } func (s *Server) handleV4Server(start time.Time, packet *dhcpv4.DHCPv4, peer *net.UDPAddr) { reply, err := s.config.Handler.ServeDHCPv4(packet) logErr := s.logger.LogSuccess(start, nil, packet.ToBytes(), peer) if logErr != nil { glog.Errorf("Failed to log incoming packet: %s", logErr) } if err != nil { glog.Errorf("Error creating reply %s", err) s.logger.LogErr(start, nil, packet.ToBytes(), peer, fmt.Sprintf("%T", err), err) return } addr := &net.UDPAddr{ IP: packet.GatewayIPAddr, Port: dhcpv4.ServerPort, } s.conn.WriteTo(reply.ToBytes(), addr) err = s.logger.LogSuccess(start, nil, reply.ToBytes(), peer) if err != nil { glog.Errorf("Failed to log reply: %s", err) } return } func (s *Server) handleRawPacketV6(buffer []byte, peer *net.UDPAddr) { // runs in a separate go routine start := time.Now() packet, err := dhcpv6.FromBytes(buffer) if err != nil { glog.Errorf("Error encoding DHCPv6 packet: %s", err) s.logger.LogErr(start, nil, nil, peer, ErrParse, err) return } if s.server { s.handleV6Server(start, packet, peer) return } if packet.Type() == dhcpv6.MessageTypeRelayReply { s.handleV6RelayRepl(start, packet, peer) return } var message DHCPMessage msg, err := packet.GetInnerMessage() if err != nil { glog.Errorf("Error getting inner message: %s", err) s.logger.LogErr(start, nil, packet.ToBytes(), peer, ErrParse, err) return } message.XID = msg.TransactionID[:] message.Peer = peer duid := msg.Options.ClientID() if duid == nil { errMsg := errors.New("failed to extract Client ID") glog.Errorf("%v", errMsg) s.logger.LogErr(start, nil, packet.ToBytes(), peer, ErrParse, errMsg) return } message.ClientID = duid.ToBytes() mac, err := dhcpv6.ExtractMAC(packet) if err != nil { glog.Errorf("Failed to extract MAC, drop due to %s", err) s.logger.LogErr(start, nil, packet.ToBytes(), peer, ErrParse, err) return } message.Mac = mac if vendorData, err := ztpv6.ParseVendorData(msg); err != nil { glog.V(2).Infof("Failed to extract vendor data: %s", err) } else { message.Serial = vendorData.Serial } server, err := selectDestinationServer(s.config, &message) if err != nil { glog.Errorf("%s, Drop due to %s", packet.Summary(), err) s.logger.LogErr(start, nil, packet.ToBytes(), peer, ErrNoServer, err) return } relayMsg, err := dhcpv6.EncapsulateRelay(packet, dhcpv6.MessageTypeRelayForward, net.IPv6zero, peer.IP) s.sendToServer(start, server, relayMsg.ToBytes(), peer) } func (s *Server) handleV6RelayRepl(start time.Time, packet dhcpv6.DHCPv6, peer *net.UDPAddr) { // when we get a relay-reply, we need to unwind the message, removing the top // relay-reply info and passing on the inner part of the message msg, err := dhcpv6.DecapsulateRelay(packet) if err != nil { glog.Errorf("Failed to decapsulate packet, drop due to %s", err) s.logger.LogErr(start, nil, packet.ToBytes(), peer, ErrParse, err) return } peerAddr := packet.(*dhcpv6.RelayMessage).PeerAddr // send the packet to the peer addr addr := &net.UDPAddr{ IP: peerAddr, Port: dhcpv6.DefaultServerPort, Zone: "", } conn, err := net.DialUDP("udp", s.config.ReplyAddr, addr) if err != nil { glog.Errorf("Error creating udp connection %s", err) s.logger.LogErr(start, nil, packet.ToBytes(), peer, ErrConnect, err) return } conn.Write(msg.ToBytes()) err = s.logger.LogSuccess(start, nil, packet.ToBytes(), peer) if err != nil { glog.Errorf("Failed to log request: %s", err) } conn.Close() return } func (s *Server) handleV6Server(start time.Time, packet dhcpv6.DHCPv6, peer *net.UDPAddr) { reply, err := s.config.Handler.ServeDHCPv6(packet) logErr := s.logger.LogSuccess(start, nil, packet.ToBytes(), peer) if logErr != nil { glog.Errorf("Failed to log incoming packet: %s", logErr) } if err != nil { glog.Errorf("Error creating reply %s", err) s.logger.LogErr(start, nil, packet.ToBytes(), peer, fmt.Sprintf("%T", err), err) return } addr := &net.UDPAddr{ IP: peer.IP, Port: dhcpv6.DefaultServerPort, } s.conn.WriteTo(reply.ToBytes(), addr) err = s.logger.LogSuccess(start, nil, reply.ToBytes(), peer) if err != nil { glog.Errorf("Failed to log reply: %s", err) } return }