cfg/optimize.go (205 lines of code) (raw):
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
package cfg
import (
"fmt"
"io"
"net/http"
"reflect"
"slices"
"strings"
"time"
"unicode"
)
////////////////////////////////////////////////////////////////////////
// Constants
////////////////////////////////////////////////////////////////////////
const (
// maxRetries is the number of passes getMachineType makes over
// metadataEndpoints before giving up. The loop condition is
// "retry < maxRetries", so this is the total attempt count, not the number
// of additional retries after a first failure.
maxRetries = 2
// httpTimeout is the Timeout configured on the http.Client used for
// metadata requests.
httpTimeout = 50 * time.Millisecond
// machineTypeFlg is the name of the flag that is consulted for a machine
// type before the metadata endpoints are queried.
machineTypeFlg = "machine-type"
)
////////////////////////////////////////////////////////////////////////
// Types
////////////////////////////////////////////////////////////////////////
// isValueSet is the minimal view of the parsed flags that the optimization
// code needs: whether a flag was set, and the string or bool value of a flag.
// NOTE(review): the concrete implementation is not visible in this file —
// presumably a wrapper around the flag/config library; confirm.
type isValueSet interface {
// IsSet reports whether the named flag has been set. Overrides are skipped
// for flags where this returns true.
IsSet(string) bool
// GetString returns the string value of the named flag.
GetString(string) string
// GetBool returns the boolean value of the named flag.
GetBool(string) bool
}
// flagOverride carries the replacement value for a single flag.
//
// newValue is held untyped; setFlagValue type-asserts it against the kind of
// the destination config field (bool, int or string) when applying it.
type flagOverride struct {
	newValue any
}
// flagOverrideSet is a named collection of flag overrides. A machineType
// selects a set by name via its flagOverrideSetName field.
type flagOverrideSet struct {
// name identifies the set; it is compared for equality against
// machineType.flagOverrideSetName.
name string
// overrides maps a flag name (a dot-separated path such as
// "metadata-cache.ttl-secs") to the value that should replace it.
overrides map[string]flagOverride
}
// machineType groups one or more machine type names that share the same
// flag override set.
type machineType struct {
// names lists the machine type names in this group. Matching is done with
// strings.HasPrefix, so each entry acts as a prefix of the detected
// machine type rather than an exact name.
names []string
// flagOverrideSetName is the name of the flagOverrideSet to apply when one
// of names matches.
flagOverrideSetName string
}
// optimizationConfig holds the configuration for machine-specific
// optimizations: the available override sets and the machine type groups that
// refer to them by name.
type optimizationConfig struct {
// flagOverrideSets are the override sets that machine types can select.
flagOverrideSets []flagOverrideSet
// machineTypes are searched in order; the first group with a matching name
// determines which override set is applied.
machineTypes []machineType
}
////////////////////////////////////////////////////////////////////////
// Variables
////////////////////////////////////////////////////////////////////////
var (
// defaultOptimizationConfig is the optimization configuration that Optimize
// passes to applyMachineTypeOptimizations.
//
// Override keys are dot-separated flag paths; setFlagValue resolves each
// segment to a Config field via reflection. Override values must be a bool,
// int or string matching the kind of the destination field.
defaultOptimizationConfig = optimizationConfig{
flagOverrideSets: []flagOverrideSet{
{
name: "high-performance",
overrides: map[string]flagOverride{
"metadata-cache.negative-ttl-secs": {newValue: 0},
"metadata-cache.ttl-secs": {newValue: -1},
"metadata-cache.stat-cache-max-size-mb": {newValue: 1024},
"metadata-cache.type-cache-max-size-mb": {newValue: 128},
"implicit-dirs": {newValue: true},
"file-system.rename-dir-limit": {newValue: 200000},
},
},
},
machineTypes: []machineType{
{
// These names are matched as prefixes of the detected machine type,
// so an entry also covers any longer name that starts with it.
names: []string{
"a2-megagpu-16g", "a2-ultragpu-8g", "a3-edgegpu-8g", "a3-highgpu-8g", "a3-megagpu-8g", "a3-ultragpu-8g", "a4-highgpu-8g-lowmem",
"ct5l-hightpu-8t", "ct5lp-hightpu-8t", "ct5p-hightpu-4t", "ct5p-hightpu-4t-tpu", "ct6e-standard-4t", "ct6e-standard-4t-tpu", "ct6e-standard-8t", "ct6e-standard-8t-tpu"},
flagOverrideSetName: "high-performance",
},
// Add more machine types here as needed.
},
}
// metadataEndpoints are the endpoints to try, in order, for fetching the
// machine type from the metadata server.
// A slice is used to make provision for an https endpoint in the future:
// https://cloud.google.com/compute/docs/metadata/querying-metadata#metadata_server_endpoints
metadataEndpoints = []string{
"http://metadata.google.internal/computeMetadata/v1/instance/machine-type",
}
)
////////////////////////////////////////////////////////////////////////
// Helper Functions
////////////////////////////////////////////////////////////////////////
// getMetadata performs one GET request against endpoint using client and
// returns the raw response body.
//
// The request carries the "Metadata-Flavor: Google" header. An error is
// returned when the request cannot be built, when the round trip fails, when
// the response status is anything other than 200 OK, or when the body cannot
// be read; the underlying cause is wrapped where one exists.
func getMetadata(client *http.Client, endpoint string) ([]byte, error) {
	request, buildErr := http.NewRequest(http.MethodGet, endpoint, nil)
	if buildErr != nil {
		return nil, fmt.Errorf("failed to create request for %s: %w", endpoint, buildErr)
	}
	// The header map of a freshly built request is empty, so Set and Add are
	// interchangeable here.
	request.Header.Set("Metadata-Flavor", "Google")

	response, sendErr := client.Do(request)
	if sendErr != nil {
		return nil, fmt.Errorf("request to %s failed: %w", endpoint, sendErr)
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("request to %s returned non-OK status: %d", endpoint, status)
	}

	payload, readErr := io.ReadAll(response.Body)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read response body from %s: %w", endpoint, readErr)
	}
	return payload, nil
}
// getMachineType determines the machine type of the current host.
//
// If the machine-type flag is set to a non-empty value, that value is returned
// as-is and no network request is made. Otherwise every endpoint in
// metadataEndpoints is queried, for up to maxRetries passes in total, and the
// first successful response is used. The response is treated as a
// slash-separated path of which only the last segment (with surrounding
// whitespace removed) is returned.
//
// When every attempt fails, the returned error wraps the most recent
// getMetadata failure so the cause is not lost.
func getMachineType(isSet isValueSet) (string, error) {
	// An explicit, non-empty flag value takes precedence over the metadata
	// server.
	if isSet.IsSet(machineTypeFlg) {
		if flagValue := isSet.GetString(machineTypeFlg); flagValue != "" {
			return flagValue, nil
		}
	}

	client := http.Client{Timeout: httpTimeout}
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		for _, endpoint := range metadataEndpoints {
			body, err := getMetadata(&client, endpoint)
			if err != nil {
				// Remember the failure and move on to the next endpoint/pass.
				lastErr = err
				continue
			}
			resourcePath := strings.TrimSpace(string(body))
			// LastIndex yields -1 when there is no slash, in which case the
			// whole value is returned.
			return resourcePath[strings.LastIndex(resourcePath, "/")+1:], nil
		}
	}

	if lastErr != nil {
		return "", fmt.Errorf("failed to get machine type from any metadata endpoint after retries: %w", lastErr)
	}
	// Reached only when no request was attempted (no endpoints or no passes).
	return "", fmt.Errorf("failed to get machine type from any metadata endpoint after retries")
}
// applyMachineTypeOptimizations applies the flag override set that config
// associates with the current machine type to cfg.
//
// The machine type is obtained from getMachineType and matched against
// config.machineTypes by prefix; the first matching group wins. Its override
// set is then looked up by name and each override is applied via setFlagValue,
// in sorted flag-name order so that the result is deterministic.
//
// It returns the names of the flags for which setFlagValue reported no error.
// NOTE(review): setFlagValue also returns nil when it skips a flag because the
// user set it, so such flags are included in the result even though their
// value was left untouched — confirm whether callers expect that.
//
// The returned error is always nil: failure to detect the machine type, an
// unknown machine type, and a missing override set are all treated as "nothing
// to optimize".
func applyMachineTypeOptimizations(config *optimizationConfig, cfg *Config, isSet isValueSet) ([]string, error) {
	currentMachineType, err := getMachineType(isSet)
	if err != nil {
		return nil, nil // Non-fatal error, continue with default settings.
	}

	var optimizedFlags []string

	// Find the first machine type group with a name that prefixes the
	// detected machine type.
	mtIndex := slices.IndexFunc(config.machineTypes, func(mt machineType) bool {
		return slices.ContainsFunc(mt.names, func(name string) bool {
			return strings.HasPrefix(currentMachineType, name)
		})
	})
	if mtIndex == -1 {
		return optimizedFlags, nil
	}
	mt := &config.machineTypes[mtIndex]

	// Find the override set that the group refers to.
	setIndex := slices.IndexFunc(config.flagOverrideSets, func(fos flagOverrideSet) bool {
		return fos.name == mt.flagOverrideSetName
	})
	if setIndex == -1 {
		return optimizedFlags, nil
	}
	overrides := config.flagOverrideSets[setIndex].overrides

	// Map iteration order is randomized, so walk the flag names in sorted
	// order to make the returned slice stable across runs.
	flagNames := make([]string, 0, len(overrides))
	for flag := range overrides {
		flagNames = append(flagNames, flag)
	}
	slices.Sort(flagNames)

	for _, flag := range flagNames {
		// Overrides that cannot be applied are skipped rather than failing
		// the whole optimization pass.
		if err := setFlagValue(cfg, flag, overrides[flag], isSet); err == nil {
			optimizedFlags = append(optimizedFlags, flag)
		}
	}
	return optimizedFlags, nil
}
// Optimize applies the default machine-type specific optimizations to cfg and
// returns the names of the flags reported as optimized.
//
// When the disable-autoconfig flag is true, nothing is changed and (nil, nil)
// is returned. Otherwise the result of applyMachineTypeOptimizations with
// defaultOptimizationConfig is returned unchanged.
func Optimize(cfg *Config, isSet isValueSet) ([]string, error) {
	if autoconfigDisabled := isSet.GetBool("disable-autoconfig"); autoconfigDisabled {
		return nil, nil
	}
	return applyMachineTypeOptimizations(&defaultOptimizationConfig, cfg, isSet)
}
// convertToCamelCase converts a hyphen-separated (kebab-case) string to
// CamelCase: the input is split on "-", the first rune of every non-empty
// segment is upper-cased, and the segments are concatenated. For example,
// "metadata-cache" becomes "MetadataCache". An empty input yields "".
func convertToCamelCase(input string) string {
	var out strings.Builder
	for _, segment := range strings.Split(input, "-") {
		// Empty segments (from leading, trailing or doubled hyphens)
		// contribute nothing to the result.
		if segment == "" {
			continue
		}
		letters := []rune(segment)
		letters[0] = unicode.ToUpper(letters[0])
		out.WriteString(string(letters))
	}
	return out.String()
}
// setFlagValue uses reflection to apply override to the field of cfg that
// corresponds to flag, unless the user has set that flag.
//
// flag is a dot-separated path (for example "metadata-cache.ttl-secs"). Each
// segment is converted with convertToCamelCase and resolved as a struct field,
// starting at Config and descending into nested structs.
//
// It returns an error when cfg is nil, when the path does not resolve to a
// field (including when a segment is applied to a non-struct value), when the
// field is not settable, when the field kind is not bool, int, int64 or
// string, or when override.newValue does not have the matching Go type (bool,
// int or string respectively). It returns nil without modifying cfg when
// isSet reports the lower-cased flag name as set.
func setFlagValue(cfg *Config, flag string, override flagOverride, isSet isValueSet) error {
	if cfg == nil {
		return fmt.Errorf("cannot set flag %s: config is nil", flag)
	}

	// Walk from Config down through nested structs, one path segment at a time.
	v := reflect.ValueOf(cfg).Elem()
	var field reflect.Value
	for _, part := range strings.Split(flag, ".") {
		// FieldByName panics on anything that is not a struct, so a path that
		// descends into a scalar field must be rejected here.
		if v.Kind() != reflect.Struct {
			return fmt.Errorf("invalid flag name: %s", flag)
		}
		field = v.FieldByName(convertToCamelCase(part))
		if !field.IsValid() {
			return fmt.Errorf("invalid flag name: %s", flag)
		}
		v = field
	}

	if !field.CanSet() {
		return fmt.Errorf("cannot set flag: %s", flag)
	}

	// Only override if the user hasn't set it. The lookup uses the
	// lower-cased flag name.
	if isSet.IsSet(strings.ToLower(flag)) {
		return nil
	}

	// Set the value based on the field type.
	switch field.Kind() {
	case reflect.Bool:
		boolValue, ok := override.newValue.(bool)
		if !ok {
			return fmt.Errorf("invalid boolean value for flag %s: %v", flag, override.newValue)
		}
		field.SetBool(boolValue)
	case reflect.Int, reflect.Int64:
		intValue, ok := override.newValue.(int)
		if !ok {
			return fmt.Errorf("invalid integer value for flag %s: %v", flag, override.newValue)
		}
		field.SetInt(int64(intValue))
	case reflect.String:
		stringValue, ok := override.newValue.(string)
		if !ok {
			return fmt.Errorf("invalid string value for flag %s: %v", flag, override.newValue)
		}
		field.SetString(stringValue)
	default:
		return fmt.Errorf("unsupported flag type for flag %s", flag)
	}
	return nil
}
// isFlagPresent reports whether flag occurs in flags. A nil or empty slice
// never contains a flag.
func isFlagPresent(flags []string, flag string) bool {
	return slices.Index(flags, flag) >= 0
}