// pkg/testutils/kustainer/kustainer.go
//go:build !disableDocker

package kustainer
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Azure/adx-mon/pkg/testutils"
"github.com/Azure/azure-kusto-go/kusto"
"github.com/Azure/azure-kusto-go/kusto/kql"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/k3s"
"github.com/testcontainers/testcontainers-go/wait"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
kwait "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/portforward"
"k8s.io/client-go/transport/spdy"
)
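// KustainerContainer wraps a testcontainers.Container running the Kusto
// emulator (kustainer), plus the endpoint and port-forward stop channel that
// are populated when the emulator runs inside a k8s cluster.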
type KustainerContainer struct {
testcontainers.Container
endpoint string
stop chan struct{}
}
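// Run creates a kustainer container from img, with ACCEPT_EULA set and port
// 8080 exposed, then applies the given customizers. A minimal usage sketch,
// assuming a local Docker daemon (the image tag below is illustrative, not
// pinned by this package):
//
//	kc, err := kustainer.Run(ctx, "mcr.microsoft.com/azuredataexplorer/kustainer-linux:latest", kustainer.WithStarted())
//	if err != nil {
//		// handle error
//	}
//	defer kc.Close()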
func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*KustainerContainer, error) {
req := testcontainers.ContainerRequest{
Image: img,
ExposedPorts: []string{"8080/tcp"},
Env: map[string]string{
"ACCEPT_EULA": "Y",
},
WaitingFor: wait.ForAll(
wait.ForListeningPort("8080/tcp"),
wait.ForLog(".*Hit 'CTRL-C' or 'CTRL-BREAK' to quit.*").AsRegexp(),
),
}
genericContainerReq := testcontainers.GenericContainerRequest{
ContainerRequest: req,
}
for _, opt := range opts {
if err := opt.Customize(&genericContainerReq); err != nil {
return nil, err
}
}
container, err := testcontainers.GenericContainer(ctx, genericContainerReq)
var c *KustainerContainer
if container != nil {
c = &KustainerContainer{Container: container}
}
if err != nil {
return c, fmt.Errorf("generic container: %w", err)
}
return c, nil
}
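// PortForward waits for the kustainer pod (label app=kustainer in the default
// namespace) to be running and ready, waits for the emulator's readiness log
// line, then forwards a random local port to pod port 8080, retrying with
// exponential backoff. A sketch, assuming cfg is a *rest.Config for the
// cluster the manifest was deployed to:
//
//	if err := kc.PortForward(ctx, cfg); err != nil {
//		// handle error
//	}
//	endpoint := kc.ConnectionUrl() // http://localhost:<local port>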
func (c *KustainerContainer) PortForward(ctx context.Context, config *rest.Config) error {
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
}
var podName string
// Wait for pod to exist and be running/ready
err = kwait.PollUntilContextTimeout(ctx, 1*time.Second, 10*time.Minute, true, func(ctx context.Context) (bool, error) {
pods, err := clientset.CoreV1().Pods("default").List(ctx, metav1.ListOptions{
LabelSelector: "app=kustainer",
})
if err != nil || len(pods.Items) == 0 {
return false, nil
}
pod := pods.Items[0]
if pod.Status.Phase != corev1.PodRunning {
return false, nil
}
for _, cs := range pod.Status.ContainerStatuses {
if !cs.Ready {
return false, nil
}
}
podName = pod.Name
return true, nil
})
if err != nil {
return fmt.Errorf("failed to get ready pod: %w", err)
}
err = kwait.PollUntilContextTimeout(ctx, 1*time.Second, 10*time.Minute, true, func(ctx context.Context) (bool, error) {
if err := c.waitForLog(ctx, clientset, podName, "Hit 'CTRL-C' or 'CTRL-BREAK' to quit"); err != nil {
return false, nil
}
return true, nil
})
if err != nil {
return fmt.Errorf("failed create container: %w", err)
}
// Retry port-forward on failure, with backoff
var lastErr error
for i := range 5 {
err = c.connect(ctx, config, podName)
if err == nil {
return nil
}
lastErr = err
// Exponential backoff
time.Sleep(time.Second * time.Duration(2<<i))
}
return fmt.Errorf("failed to connect to kustainer after retries: %w", lastErr)
}
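// Close stops the port-forward goroutine, if one was started. It does not
// terminate the underlying container.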
func (c *KustainerContainer) Close() error {
if c.stop != nil {
close(c.stop)
}
return nil
}
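// WithCluster registers a pre-create hook that copies this package's k8s.yaml
// manifest into the k3s container's auto-deploy manifests directory, so k3s
// schedules the kustainer pod itself instead of Docker running the image.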
func WithCluster(ctx context.Context, k *k3s.K3sContainer) testcontainers.CustomizeRequestOption {
return func(req *testcontainers.GenericContainerRequest) error {
req.LifecycleHooks = append(req.LifecycleHooks, testcontainers.ContainerLifecycleHooks{
PreCreates: []testcontainers.ContainerRequestHook{
func(ctx context.Context, req testcontainers.ContainerRequest) error {
rootDir, err := testutils.GetGitRootDir()
if err != nil {
return fmt.Errorf("failed to get git root dir: %w", err)
}
lfp := filepath.Join(rootDir, "pkg/testutils/kustainer/k8s.yaml")
rfp := filepath.Join(testutils.K3sManifests, "kustainer.yaml")
if err := k.CopyFileToContainer(ctx, lfp, rfp, 0644); err != nil {
return fmt.Errorf("failed to copy file to container: %w", err)
}
return nil
},
},
})
return nil
}
}
// WithStarted starts the container as soon as it is created. Skip this option
// when the image is meant to be loaded into a k8s cluster (see WithCluster)
// rather than run directly by Docker.
func WithStarted() testcontainers.CustomizeRequestOption {
return func(req *testcontainers.GenericContainerRequest) error {
req.Started = true
return nil
}
}
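// waitForLog follows the pod's log stream and returns nil once logMsg appears,
// or an error if the stream fails or ends without it.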
func (c *KustainerContainer) waitForLog(ctx context.Context, client *kubernetes.Clientset, podName, logMsg string) error {
req := client.CoreV1().Pods("default").GetLogs(podName, &corev1.PodLogOptions{
Follow: true,
})
stream, err := req.Stream(ctx)
if err != nil {
return fmt.Errorf("failed to stream logs: %w", err)
}
defer stream.Close()
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, logMsg) {
return nil
}
}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading log stream: %w", err)
	}
	// The stream ended without the expected message; report an error so the
	// caller's poll loop retries instead of treating this as success.
	return fmt.Errorf("log message %q not found before stream ended", logMsg)
}
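// connect starts a single port-forward attempt from a random local port to the
// pod's 8080 port over the API server's SPDY portforward subresource, and
// records the resulting local endpoint on success.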
func (c *KustainerContainer) connect(ctx context.Context, config *rest.Config, podName string) error {
// Create port forward
transport, upgrader, err := spdy.RoundTripperFor(config)
if err != nil {
return fmt.Errorf("failed to create round tripper: %w", err)
}
path := fmt.Sprintf("/api/v1/namespaces/default/pods/%s/portforward", podName)
	// config.Host is a URL like "https://127.0.0.1:6443"; parse it rather than
	// trimming scheme characters by hand, which corrupts hosts whose names
	// begin with any of the trimmed letters.
	hostURL, err := url.Parse(config.Host)
	if err != nil {
		return fmt.Errorf("failed to parse host %q: %w", config.Host, err)
	}
	serverURL := url.URL{Scheme: "https", Path: path, Host: hostURL.Host}
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, &serverURL)
ports := []string{"0:8080"}
stopChan, readyChan := make(chan struct{}, 1), make(chan struct{})
out, errOut := threadsafeBuffer{}, threadsafeBuffer{}
pf, err := portforward.New(dialer, ports, stopChan, readyChan, &out, &errOut)
if err != nil {
return fmt.Errorf("failed to create port forward: %w", err)
}
go func() {
if err := pf.ForwardPorts(); err != nil {
fmt.Fprintf(os.Stderr, "port forward error: %v\n", err)
}
}()
select {
case <-readyChan:
// Port forward is ready, get the local port
localPorts, err := pf.GetPorts()
if err != nil {
return fmt.Errorf("failed to get local port: %w", err)
}
endpoint := fmt.Sprintf("http://localhost:%d", localPorts[0].Local)
c.stop = stopChan
c.endpoint = endpoint
return nil
case <-time.After(10 * time.Second):
close(stopChan)
return fmt.Errorf("timeout waiting for port forward")
	case <-ctx.Done():
		close(stopChan)
		return fmt.Errorf("context done waiting for port forward: %w", ctx.Err())
}
}
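// ConnectionUrl returns the emulator's HTTP endpoint: the port-forwarded
// address when running inside a cluster, otherwise the Docker-mapped 8080
// port. It returns "" if neither is available.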
func (c *KustainerContainer) ConnectionUrl() string {
if c.endpoint == "" {
// This means we're running out-of-cluster
port, err := c.MappedPort(context.Background(), "8080")
if err != nil {
return ""
}
return "http://localhost:" + port.Port()
}
return c.endpoint
}
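// CreateDatabase creates a volatile database on the emulator, treating
// "already exists" as success. A sketch of typical test usage (the database
// name is illustrative):
//
//	if err := kc.CreateDatabase(ctx, "Metrics"); err != nil {
//		// handle error
//	}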
func (c *KustainerContainer) CreateDatabase(ctx context.Context, dbName string) error {
	// Use ConnectionUrl rather than c.endpoint directly so this also works
	// out-of-cluster, where the endpoint comes from the Docker-mapped port.
	cb := kusto.NewConnectionStringBuilder(c.ConnectionUrl())
client, err := kusto.New(cb)
if err != nil {
return fmt.Errorf("new kusto client: %w", err)
}
defer client.Close()
stmt := kql.New(".create database ").AddUnsafe(dbName).AddLiteral(" volatile")
_, err = client.Mgmt(ctx, "", stmt)
if err != nil && !strings.Contains(err.Error(), "already exists") {
return fmt.Errorf("create database %s: %w", dbName, err)
}
return nil
}
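// IngestionBatchingPolicy mirrors the JSON shape of Kusto's ingestion
// batching policy; MaximumBatchingTimeSpan is held as a time.Duration in Go
// and converted to Kusto's "hh:mm:ss" notation when (un)marshaled.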
type IngestionBatchingPolicy struct {
MaximumBatchingTimeSpan time.Duration `json:"MaximumBatchingTimeSpan,omitempty"`
MaximumNumberOfItems int `json:"MaximumNumberOfItems,omitempty"`
MaximumRawDataSizeMB int `json:"MaximumRawDataSizeMB,omitempty"`
}
// MarshalJSON customizes the JSON representation of IngestionBatchingPolicy
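// For example, a policy with MaximumBatchingTimeSpan of 30s and
// MaximumNumberOfItems of 500 marshals to roughly:
//
//	{"MaximumBatchingTimeSpan":"00:00:30","MaximumNumberOfItems":500}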
func (p IngestionBatchingPolicy) MarshalJSON() ([]byte, error) {
type Alias IngestionBatchingPolicy
return json.Marshal(&struct {
MaximumBatchingTimeSpan string `json:"MaximumBatchingTimeSpan"`
*Alias
}{
		MaximumBatchingTimeSpan: p.String(),
Alias: (*Alias)(&p),
})
}
// UnmarshalJSON customizes the JSON unmarshalling of IngestionBatchingPolicy
func (p *IngestionBatchingPolicy) UnmarshalJSON(data []byte) error {
type Alias IngestionBatchingPolicy
aux := &struct {
MaximumBatchingTimeSpan string `json:"MaximumBatchingTimeSpan"`
*Alias
}{
Alias: (*Alias)(p),
}
if err := json.Unmarshal(data, aux); err != nil {
return err
}
	// Kusto serializes timespans as "hh:mm:ss" (see MarshalJSON), a format
	// time.ParseDuration does not accept, so parse the components directly.
	var h, m, s int
	if _, err := fmt.Sscanf(aux.MaximumBatchingTimeSpan, "%d:%d:%d", &h, &m, &s); err != nil {
		return fmt.Errorf("parse timespan %q: %w", aux.MaximumBatchingTimeSpan, err)
	}
	p.MaximumBatchingTimeSpan = time.Duration(h)*time.Hour + time.Duration(m)*time.Minute + time.Duration(s)*time.Second
return nil
}
// String formats the batching timespan in Kusto's "hh:mm:ss" notation, e.g.
// 90*time.Second renders as "00:01:30".
func (p IngestionBatchingPolicy) String() string {
return fmt.Sprintf("%02d:%02d:%02d", int(p.MaximumBatchingTimeSpan.Hours()), int(p.MaximumBatchingTimeSpan.Minutes())%60, int(p.MaximumBatchingTimeSpan.Seconds())%60)
}
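// SetIngestionBatchingPolicy alters dbName's ingestionbatching policy to p.
// A sketch of typical test usage, tightening batching so ingested rows become
// queryable quickly (values are illustrative):
//
//	err := kc.SetIngestionBatchingPolicy(ctx, "Metrics", IngestionBatchingPolicy{
//		MaximumBatchingTimeSpan: 30 * time.Second,
//	})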
func (c *KustainerContainer) SetIngestionBatchingPolicy(ctx context.Context, dbName string, p IngestionBatchingPolicy) error {
	cb := kusto.NewConnectionStringBuilder(c.ConnectionUrl())
client, err := kusto.New(cb)
if err != nil {
return fmt.Errorf("new kusto client: %w", err)
}
defer client.Close()
policy, err := json.Marshal(p)
if err != nil {
return fmt.Errorf("marshal policy: %w", err)
}
stmt := kql.New(".alter database ").AddUnsafe(dbName).AddLiteral(" policy ingestionbatching").
AddLiteral("```").AddUnsafe(string(policy)).AddLiteral("```")
_, err = client.Mgmt(ctx, "", stmt)
if err != nil {
return fmt.Errorf("create database %s: %w", dbName, err)
}
return nil
}
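// threadsafeBuffer is a bytes.Buffer guarded by a mutex so the port-forward
// goroutine can write its out/err streams without racing other readers.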
type threadsafeBuffer struct {
sync.Mutex
buffer bytes.Buffer
}
func (b *threadsafeBuffer) Write(p []byte) (n int, err error) {
b.Lock()
defer b.Unlock()
return b.buffer.Write(p)
}
func (b *threadsafeBuffer) Read(p []byte) (n int, err error) {
b.Lock()
defer b.Unlock()
return b.buffer.Read(p)
}