cmd/core_plugin/wsfchealthcheck/wsfchealthcheck.go (127 lines of code) (raw):
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package wsfchealthcheck implements an agent that is used to support Windows
// Server Failover Cluster (WSFC) in GCE. The agent will listen on a TCP port
// and respond to health check requests from the WSFC cluster. Agent checks if
// the IP address in the request is present on any of the interfaces and return
// a response accordingly.
package wsfchealthcheck
import (
"context"
"fmt"
"net"
"github.com/GoogleCloudPlatform/galog"
"github.com/GoogleCloudPlatform/google-guest-agent/cmd/core_plugin/manager"
"github.com/GoogleCloudPlatform/google-guest-agent/internal/cfg"
"github.com/GoogleCloudPlatform/google-guest-agent/internal/events"
"github.com/GoogleCloudPlatform/google-guest-agent/internal/metadata"
)
type overrideContextKey string
const (
// wsfcModuleID is the ID of the WSFC health check module.
wsfcModuleID = "wsfc-health-check"
// wsfcDefaultAgentPort is the default port where agent listens on for health
// check requests.
wsfcDefaultAgentPort = "59998"
// tcpProtocol is the protocol used for health check connections.
tcpProtocol = "tcp"
// overrideIPExistCheck is the context key for overriding the IP check in unit
// tests.
overrideIPExistCheck overrideContextKey = "override-ip-check"
)
// NewModule returns a new WSFC health check module for late registration.
func NewModule(context.Context) *manager.Module {
m := newWsfcManager(connectOpts{protocol: tcpProtocol})
// Register the cert refresher module.
return &manager.Module{
ID: wsfcModuleID,
Setup: m.moduleSetup,
Quit: m.teardown,
}
}
// moduleSetup is the initialization function for wsfc module that registers
// itself to listen MDS events.
func (wm *wsfcManager) moduleSetup(ctx context.Context, _ any) error {
sub := events.EventSubscriber{Name: wsfcModuleID, Callback: wm.metadataSubscriber}
events.FetchManager().Subscribe(metadata.LongpollEvent, sub)
return nil
}
// teardown unsubscribes the wsfc module from listening any new MDS events.
func (wm *wsfcManager) teardown(ctx context.Context) {
events.FetchManager().Unsubscribe(metadata.LongpollEvent, wsfcModuleID)
if err := wm.agent.stop(ctx); err != nil {
galog.Errorf("Failed to stop wsfc agent: %v", err)
}
}
// wsfcManager is the handler for the health check agent.
type wsfcManager struct {
// agent is the health check agent implementation reference.
agent healthCheck
}
// isWsfcEnabled returns true if its set in instance config file or instance
// or project level metadata attributes. Order of precedence is instance config,
// instance metadata then project metadata. By default its disabled. Note that
// if its enabled via config file agent expects address to be set as well.
func isWsfcEnabled(desc *metadata.Descriptor) bool {
config := cfg.Retrieve()
if config.WSFC != nil && config.WSFC.Enable && config.WSFC.Addresses != "" {
return true
}
if desc.Instance().Attributes().EnableWSFC() != nil {
return *desc.Instance().Attributes().EnableWSFC()
}
if desc.Instance().Attributes().WSFCAddresses() != "" {
return true
}
if desc.Project().Attributes().EnableWSFC() != nil {
return *desc.Project().Attributes().EnableWSFC()
}
if desc.Project().Attributes().WSFCAddresses() != "" {
return true
}
return false
}
// listenerAddr returns the address where agent should listens on.
func listenerAddr(desc *metadata.Descriptor) string {
config := cfg.Retrieve()
if config.WSFC != nil && config.WSFC.Port != "" {
return config.WSFC.Port
}
if port := desc.Instance().Attributes().WSFCAgentPort(); port != "" {
return port
}
if port := desc.Project().Attributes().WSFCAgentPort(); port != "" {
return port
}
return wsfcDefaultAgentPort
}
// newWsfcManager returns a new wsfcManager instance.
func newWsfcManager(opts connectOpts) *wsfcManager {
return &wsfcManager{agent: newWSFCAgent(opts)}
}
// reset resets the wsfc agent state if required.
func (wm *wsfcManager) reset(ctx context.Context, desc *metadata.Descriptor) error {
newAddr := listenerAddr(desc)
newState := isWsfcEnabled(desc)
galog.Debugf("WSFC enabled: %t, on address: %s", newState, newAddr)
// If WSFC is disabled or listener address has changed stop the currently
// running agent.
if !newState || newAddr != wm.agent.address() {
if err := wm.agent.stop(ctx); err != nil {
return fmt.Errorf("failed to stop agent: %w", err)
}
}
if !newState {
return nil
}
// If WSFC is enabled or listener address has changed start the agent.
if newAddr != wm.agent.address() {
wm.agent.setAddress(newAddr)
}
if err := wm.agent.run(ctx); err != nil {
return fmt.Errorf("failed to run agent: %w", err)
}
return nil
}
// metadataSubscriber is the callback function for MDS events, any new MDS
// response will trigger it. Always return true to continue listening.
func (wm *wsfcManager) metadataSubscriber(ctx context.Context, evType string, data any, evData *events.EventData) bool {
// There could be transient errors with MDS, just log and continue.
if evData.Error != nil {
galog.Debugf("Metadata event watcher reported error: %s, skiping.", evData.Error)
return true
}
desc, ok := evData.Data.(*metadata.Descriptor)
// If the event manager is passing a non expected data type log it and
// don't renew the subscriber.
if !ok {
galog.Errorf("Metadata event watcher reported data type %T, expected *metadata.Descriptor", evData.Data)
return false
}
if err := wm.reset(ctx, desc); err != nil {
galog.Errorf("Failed to change wsfc agent state: %v", err)
}
return true
}
// checkIPExist returns 1 if IP exists on any of the interfaces otherwise 0.
// 0/1 is based off of the protocol and the values expected by the server.
func checkIPExist(ctx context.Context, ip string) (string, error) {
if got := ctx.Value(overrideIPExistCheck); got != nil {
return got.(string), nil
}
addrs, err := net.InterfaceAddrs()
if err != nil {
return "0", err
}
for _, address := range addrs {
if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ip == ipnet.IP.String() {
return "1", nil
}
}
}
return "0", nil
}