network/tools/wireguard-scale/collector/main.go (72 lines of code) (raw):
package main
import (
"flag"
"fmt"
"log"
"net/http"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.zx2c4.com/wireguard/wgctrl"
)
// Gauges exported on /metrics. They are created here and registered with the
// default Prometheus registry in main.
var (
	// wgPeersTotal holds the number of peers configured on the polled
	// WireGuard interface.
	wgPeersTotal = newGauge(
		"wireguard_peers_total",
		"Total number of WireGuard peers",
	)

	// wgAllowedIPsTotal holds the sum of allowed-IP entries over every peer.
	wgAllowedIPsTotal = newGauge(
		"wireguard_allowed_ips_total",
		"Total number of allowed IPs across all peers",
	)

	// wgLastHandshake is a 0/1 handshake-health indicator. It carries no
	// per-peer label, so it holds a single value for the whole interface.
	wgLastHandshake = newGauge(
		"wireguard_last_handshake",
		"Verifies the handshake status of a peer",
	)
)

// newGauge builds an unlabelled gauge with the given metric name and help text.
func newGauge(name, help string) prometheus.Gauge {
	return prometheus.NewGauge(prometheus.GaugeOpts{Name: name, Help: help})
}
// handshakeStaleAfter is the maximum age of a peer's most recent handshake
// before the peer is considered unhealthy. The 180s value is presumably chosen
// to match WireGuard's Reject-After-Time (180s) — confirm before changing.
const handshakeStaleAfter = 180 * time.Second

// collectMetrics reads the WireGuard device named ifname through client and
// refreshes the package-level gauges:
//   - wgPeersTotal: number of peers on the device;
//   - wgAllowedIPsTotal: sum of allowed-IP entries across all peers;
//   - wgLastHandshake: 1 when the device has at least one peer and every peer
//     completed a handshake within handshakeStaleAfter, 0 otherwise.
//
// If the device cannot be read, the process exits via log.Fatalf.
func collectMetrics(client *wgctrl.Client, ifname string) {
	device, err := client.Device(ifname)
	if err != nil {
		log.Fatalf("Failed to fetch WireGuard device %q: %v", ifname, err)
	}

	// Sample the clock once so every peer is judged against the same instant.
	now := time.Now()
	totalAllowedIPs := 0

	// wgLastHandshake has no per-peer label, so it must summarize all peers.
	// Setting it inside the loop would only ever reflect the last peer seen.
	// With no peers there is no handshake to vouch for, so report 0.
	allFresh := len(device.Peers) > 0
	for _, peer := range device.Peers {
		totalAllowedIPs += len(peer.AllowedIPs)
		// Per the wgtypes docs, a peer that never completed a handshake has
		// a zero LastHandshakeTime, which is always older than the threshold.
		if now.Sub(peer.LastHandshakeTime) > handshakeStaleAfter {
			allFresh = false
		}
	}

	handshakeStatus := 0.0
	if allFresh {
		handshakeStatus = 1
	}
	wgLastHandshake.Set(handshakeStatus)
	wgPeersTotal.Set(float64(len(device.Peers)))
	wgAllowedIPsTotal.Set(float64(totalAllowedIPs))
}
// main parses flags, registers the WireGuard gauges, and starts a background
// goroutine that refreshes them every -poll seconds. It then serves the gauges
// at /metrics on -address:-port.
//
// The process exits via log.Fatal/log.Fatalf in three cases: -poll is not
// positive, the wgctrl client cannot be created, or the HTTP server returns.
func main() {
	var port, poll int
	var address, ifname string
	flag.StringVar(&ifname, "ifname", "cilium_wg0", "name of the wireguard interface")
	flag.StringVar(&address, "address", "0.0.0.0", "address for metrics endpoint")
	flag.IntVar(&poll, "poll", 60, "polling interval in seconds")
	flag.IntVar(&port, "port", 8080, "metrics port")
	flag.Parse()

	// time.Sleep returns immediately for non-positive durations. A zero or
	// negative interval would therefore poll the device back-to-back with no
	// pause.
	if poll <= 0 {
		log.Fatalf("invalid -poll %d: must be a positive number of seconds", poll)
	}
	pollInterval := time.Duration(poll) * time.Second

	// Register Prometheus metrics
	prometheus.MustRegister(wgPeersTotal)
	prometheus.MustRegister(wgAllowedIPsTotal)
	prometheus.MustRegister(wgLastHandshake)

	// Start the metrics collection loop
	go func() {
		client, err := wgctrl.New()
		if err != nil {
			log.Fatalf("Failed to create WireGuard client: %v", err)
		}
		defer client.Close()
		for {
			fmt.Println("polling the wireguard interface")
			collectMetrics(client, ifname)
			time.Sleep(pollInterval)
		}
	}()

	// Serve Prometheus metrics on the default mux.
	listenAddr := fmt.Sprintf("%s:%d", address, port)
	http.Handle("/metrics", promhttp.Handler())
	srv := &http.Server{
		Addr: listenAddr,
		// Bound how long a client may take to send request headers, so idle
		// or slow connections cannot be held open indefinitely.
		ReadHeaderTimeout: 10 * time.Second,
	}
	// Report the address actually bound, including -address.
	fmt.Printf("WireGuard Prometheus Exporter running on %s\n", listenAddr)
	log.Fatal(srv.ListenAndServe())
}