vmextension/vmextension.go (277 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package vmextension
import (
	"fmt"
	"os"
	"sort"
	"strings"

	"github.com/Azure/azure-extension-platform/pkg/environmentmanager"
	"github.com/Azure/azure-extension-platform/pkg/exithelper"
	"github.com/Azure/azure-extension-platform/pkg/extensionerrors"
	"github.com/Azure/azure-extension-platform/pkg/extensionevents"
	"github.com/Azure/azure-extension-platform/pkg/handlerenv"
	"github.com/Azure/azure-extension-platform/pkg/logging"
	"github.com/Azure/azure-extension-platform/pkg/seqno"
	"github.com/Azure/azure-extension-platform/pkg/settings"
	"github.com/Azure/azure-extension-platform/pkg/status"
	"github.com/pkg/errors"
)
// handlerEnvFileName is the file name of the Handler Environment as placed by the
// Azure Guest Agent.
const handlerEnvFileName = "HandlerEnvironment.json"
// OperationName identifies one of the operations an extension can be asked to
// run. Its underlying value is the lower-case name of the operation.
type OperationName string

// The operation names known to this package.
const (
	InstallOperation    OperationName = "install"
	UninstallOperation  OperationName = "uninstall"
	EnableOperation     OperationName = "enable"
	UpdateOperation     OperationName = "update"
	DisableOperation    OperationName = "disable"
	ResetStateOperation OperationName = "resetstate"

	// invalid is the placeholder handed back when a string matches none of
	// the operations above.
	invalid OperationName = "invalid"
)

// ToString returns the operation name as a plain string.
func (op OperationName) ToString() string {
	return string(op)
}

// ToStatusName returns the operation name with the first letter of every word
// upper-cased (for example "enable" becomes "Enable").
//
// NOTE(review): strings.Title is marked deprecated in recent Go releases; it
// is kept here so the produced names stay exactly as they are today.
func (op OperationName) ToStatusName() string {
	plain := op.ToString()
	return strings.Title(plain)
}
// cmdFunc is the signature of a function that carries out an operation for the
// given extension. It returns a message and an error.
type cmdFunc func(ext *VMExtension) (msg string, err error)
// OperationNameFromString converts operation into the matching OperationName.
// The comparison is exact (case-sensitive). When operation is not one of the
// known operation names, the invalid placeholder is returned together with
// extensionerrors.ErrInvalidOperationName.
func OperationNameFromString(operation string) (OperationName, error) {
	switch candidate := OperationName(operation); candidate {
	case InstallOperation,
		UninstallOperation,
		EnableOperation,
		UpdateOperation,
		DisableOperation,
		ResetStateOperation:
		return candidate, nil
	}
	return invalid, extensionerrors.ErrInvalidOperationName
}
// cmd is an internal structure that specifies how an operation should run
type cmd struct {
f cmdFunc // function that carries out the operation
operation OperationName // name of the operation this command performs
shouldReportStatus bool // determines if running this should log to a .status file
failExitCode int // exit code to use when the command fails
}
// executionInfo contains internal information necessary for the extension to execute
type executionInfo struct {
cmds map[OperationName]cmd // Execution commands keyed by operation
requiresSeqNoChange bool // True if Enable will only execute if the sequence number changes
supportsDisable bool // Whether to run extension agnostic disable code
supportsResetState bool // Whether to run the extension agnostic ResetState code
enableCallback EnableCallbackFunc // A method provided by the extension for Enable
updateCallback CallbackFunc // A method provided by the extension for Update
disableCallback CallbackFunc // A method provided by the extension for Disable
resetStateCallBack CallbackFunc // A method provided by the extension for ResetState
installCallback CallbackFunc // A method provided by the extension for Install
uninstallCallback CallbackFunc // A method provided by the extension for Uninstall
manager environmentmanager.IGetVMExtensionEnvironmentManager // Used by tests to mock the environment
}
// VMExtension is an abstraction for standard extension operations in an OS agnostic manner
type VMExtension struct {
Name string // The name of the extension. This will contain 'Windows' or 'Linux'
Version string // The version of the extension
GetRequestedSequenceNumber func() (uint, error) // Function to get the requested sequence number to run
CurrentSequenceNumber *uint // The last run sequence number, nil means no existing sequence number was found
HandlerEnv *handlerenv.HandlerEnvironment // Contains information about the folders necessary for the extension
GetSettings func() (*settings.HandlerSettings, error) // Function to get settings passed to the extension
ExtensionEvents *extensionevents.ExtensionEventManager // Allows extensions to raise events
ExtensionLogger *logging.ExtensionLogger // Automatically logs to the log directory
exec *executionInfo // Internal information necessary for the extension to run
statusFormatter status.StatusMessageFormatter // Formats status messages; custom formatter from the initialization info if one was given
}
// prodGetVMExtensionEnvironmentManager is the environment manager used by
// GetVMExtension. Each method forwards to the corresponding function of the
// handlerenv, seqno or settings package.
type prodGetVMExtensionEnvironmentManager struct {
}
// GetHandlerEnvironment forwards to handlerenv.GetHandlerEnvironment.
func (*prodGetVMExtensionEnvironmentManager) GetHandlerEnvironment(name string, version string) (*handlerenv.HandlerEnvironment, error) {
return handlerenv.GetHandlerEnvironment(name, version)
}
// FindSeqNum forwards to seqno.FindSeqNum for the given config folder.
func (*prodGetVMExtensionEnvironmentManager) FindSeqNum(el logging.ILogger, configFolder string) (uint, error) {
return seqno.FindSeqNum(el, configFolder)
}
// GetCurrentSequenceNumber forwards to seqno.GetCurrentSequenceNumber.
func (*prodGetVMExtensionEnvironmentManager) GetCurrentSequenceNumber(el logging.ILogger, retriever seqno.ISequenceNumberRetriever, name, version string) (uint, error) {
return seqno.GetCurrentSequenceNumber(el, retriever, name, version)
}
// GetHandlerSettings looks up the sequence number in he.ConfigFolder and then
// loads the handler settings for that sequence number. If the sequence number
// cannot be found, that error is returned and no settings are loaded.
func (em *prodGetVMExtensionEnvironmentManager) GetHandlerSettings(el logging.ILogger, he *handlerenv.HandlerEnvironment) (*settings.HandlerSettings, error) {
seqNo, err := em.FindSeqNum(el, he.ConfigFolder)
if err != nil {
return nil, err
}
return settings.GetHandlerSettings(el, he, seqNo)
}
// SetSequenceNumberInternal forwards to seqno.SetSequenceNumber.
func (*prodGetVMExtensionEnvironmentManager) SetSequenceNumberInternal(extensionName, extensionVersion string, seqNo uint) error {
return seqno.SetSequenceNumber(extensionName, extensionVersion, seqNo)
}
// GetVMExtension builds a new VMExtension from initInfo, backed by the
// production environment manager.
func GetVMExtension(initInfo *InitializationInfo) (ext *VMExtension, _ error) {
	prodManager := &prodGetVMExtensionEnvironmentManager{}
	return getVMExtensionInternal(initInfo, prodManager)
}
// GetVMExtensionForTesting returns a new VMExtension whose environment access
// goes through the supplied manager, so that the environment part of the VM
// extension can be mocked when testing your extension.
func GetVMExtensionForTesting(initInfo *InitializationInfo, manager environmentmanager.IGetVMExtensionEnvironmentManager) (ext *VMExtension, _ error) {
return getVMExtensionInternal(initInfo, manager)
}
// getVMExtensionInternal builds a VMExtension from initInfo. Every interaction
// with the environment goes through manager, which is what allows unit tests
// to substitute a mock.
//
// It returns extensionerrors.ErrArgCannotBeNull when initInfo, manager or
// initInfo.EnableCallback is nil, extensionerrors.ErrArgCannotBeNullOrEmpty
// when the extension name or version is empty, and otherwise any error
// reported while loading the handler environment or reading the current
// sequence number.
func getVMExtensionInternal(initInfo *InitializationInfo, manager environmentmanager.IGetVMExtensionEnvironmentManager) (ext *VMExtension, _ error) {
	// A nil manager would panic on the first call below, so reject it up front
	// in the same way as the other required arguments.
	if initInfo == nil || manager == nil {
		return nil, extensionerrors.ErrArgCannotBeNull
	}
	if len(initInfo.Name) < 1 || len(initInfo.Version) < 1 {
		return nil, extensionerrors.ErrArgCannotBeNullOrEmpty
	}
	if initInfo.EnableCallback == nil {
		return nil, extensionerrors.ErrArgCannotBeNull
	}

	handlerEnv, err := manager.GetHandlerEnvironment(initInfo.Name, initInfo.Version)
	if err != nil {
		return nil, err
	}

	extensionLogger := logging.NewWithName(handlerEnv, initInfo.LogFileNamePattern)

	// Create our event manager. This will be disabled if no eventsFolder exists
	extensionEvents := extensionevents.New(extensionLogger, handlerEnv)

	// The requested sequence number is looked up each time it is asked for.
	newSeqNo := func() (uint, error) { return manager.FindSeqNum(extensionLogger, handlerEnv.ConfigFolder) }

	// Determine the current sequence number. It stays nil when no existing
	// sequence number could be found, which is not treated as a failure.
	var currentSeqNo *uint
	retriever := seqno.ProdSequenceNumberRetriever{}
	retrievedSequenceNumber, err := manager.GetCurrentSequenceNumber(extensionLogger, &retriever, initInfo.Name, initInfo.Version)
	switch {
	case err == nil:
		currentSeqNo = &retrievedSequenceNumber
	case err == extensionerrors.ErrNoSettingsFiles || err == extensionerrors.ErrNoMrseqFile:
		// current sequence number could not be found, this is a special error
	default:
		// %w keeps the message unchanged while letting callers unwrap the cause.
		return nil, fmt.Errorf("failed to read the current sequence number due to '%w'", err)
	}

	cmdInstall := cmd{f: install, operation: InstallOperation, shouldReportStatus: false, failExitCode: initInfo.InstallExitCode}
	cmdEnable := cmd{f: enable, operation: EnableOperation, shouldReportStatus: true, failExitCode: initInfo.OtherExitCode}
	cmdUninstall := cmd{f: uninstall, operation: UninstallOperation, shouldReportStatus: false, failExitCode: initInfo.OtherExitCode}

	// Update, Disable and ResetState default to noop and only get a real
	// implementation when the extension asks for one. They use a fixed failure
	// exit code of 3 rather than initInfo.OtherExitCode.
	cmdUpdate := cmd{f: noop, operation: UpdateOperation, shouldReportStatus: false, failExitCode: 3}
	if initInfo.UpdateCallback != nil {
		cmdUpdate.f = update
	}

	cmdDisable := cmd{f: noop, operation: DisableOperation, shouldReportStatus: true, failExitCode: 3}
	if initInfo.SupportsDisable || initInfo.DisableCallback != nil {
		cmdDisable.f = disable
	}

	cmdResetState := cmd{f: noop, operation: ResetStateOperation, shouldReportStatus: false, failExitCode: 3}
	if initInfo.SupportsResetState || initInfo.ResetStateCallback != nil {
		cmdResetState.f = resetState
	}

	// Named getSettings so the local does not shadow the settings package.
	getSettings := func() (*settings.HandlerSettings, error) {
		return manager.GetHandlerSettings(extensionLogger, handlerEnv)
	}

	// Fall back to the default status message formatter when none was supplied.
	var statusFormatter status.StatusMessageFormatter = status.StatusMsg
	if initInfo.CustomStatusFormatter != nil {
		statusFormatter = initInfo.CustomStatusFormatter
	}

	ext = &VMExtension{
		Name:                       initInfo.Name,
		Version:                    initInfo.Version,
		GetRequestedSequenceNumber: newSeqNo,
		CurrentSequenceNumber:      currentSeqNo,
		HandlerEnv:                 handlerEnv,
		GetSettings:                getSettings,
		ExtensionEvents:            extensionEvents,
		ExtensionLogger:            extensionLogger,
		statusFormatter:            statusFormatter,
		exec: &executionInfo{
			manager:             manager,
			requiresSeqNoChange: initInfo.RequiresSeqNoChange,
			supportsDisable:     initInfo.SupportsDisable,
			supportsResetState:  initInfo.SupportsResetState,
			enableCallback:      initInfo.EnableCallback,
			disableCallback:     initInfo.DisableCallback,
			updateCallback:      initInfo.UpdateCallback,
			resetStateCallBack:  initInfo.ResetStateCallback,
			installCallback:     initInfo.InstallCallback,
			uninstallCallback:   initInfo.UninstallCallback,
			cmds: map[OperationName]cmd{
				InstallOperation:    cmdInstall,
				UninstallOperation:  cmdUninstall,
				EnableOperation:     cmdEnable,
				UpdateOperation:     cmdUpdate,
				DisableOperation:    cmdDisable,
				ResetStateOperation: cmdResetState,
			},
		},
	}

	return ext, nil
}
// Do is the main worker method of the extension: it determines which operation
// was requested on the command line and runs it.
//
// If the operation returns an error, the error is logged and the exit helper
// is asked to exit with the operation's failExitCode.
func (ve *VMExtension) Do() {
	// parse command line arguments
	eh := exithelper.Exiter
	selected := ve.parseCmd(os.Args, eh)
	if selected.f == nil {
		// parseCmd hands back a zero cmd after it has printed the usage and
		// requested an exit. This point is only reached when the exit helper
		// returns instead of terminating the process; calling the nil function
		// would panic, so stop here.
		return
	}

	if _, err := selected.f(ve); err != nil {
		ve.ExtensionLogger.Error("failed to handle: %v", err)
		eh.Exit(selected.failExitCode)
	}
}
// reportStatus writes the status of an operation, with the optional given
// message, to the status folder of the extension handler.
//
// Nothing is written (and nil is returned) when the given cmd is not
// configured to report status. A failure to obtain the requested sequence
// number is returned as-is; a failure to save the status is logged and
// returned wrapped.
func reportStatus(ve *VMExtension, t status.StatusType, c cmd, msg string) error {
	if !c.shouldReportStatus {
		ve.ExtensionLogger.Info("status not reported for operation (by design)")
		return nil
	}

	seqNum, seqErr := ve.GetRequestedSequenceNumber()
	if seqErr != nil {
		return seqErr
	}

	statusName := c.operation.ToStatusName()
	formatted := ve.statusFormatter(statusName, t, msg)
	report := status.New(t, statusName, formatted)

	saveErr := report.Save(ve.HandlerEnv.StatusFolder, seqNum)
	if saveErr == nil {
		return nil
	}
	ve.ExtensionLogger.Error("Failed to save handler status: %v", saveErr)
	return errors.Wrap(saveErr, "failed to save handler status")
}
// reportErrorWithClarification writes an error status for the operation of the
// given cmd, carrying errorCode and msg as the error clarification, to the
// status folder of the extension handler.
//
// Nothing is written (and nil is returned) when the given cmd is not
// configured to report status. A failure to obtain the requested sequence
// number is returned as-is; a failure to save the status is logged and
// returned wrapped.
func reportErrorWithClarification(ve *VMExtension, c cmd, errorCode int, msg string) error {
	if !c.shouldReportStatus {
		ve.ExtensionLogger.Info("status not reported for operation (by design)")
		return nil
	}

	seqNum, seqErr := ve.GetRequestedSequenceNumber()
	if seqErr != nil {
		return seqErr
	}

	clarification := status.ErrorClarification{Code: errorCode, Message: msg}
	report := status.NewError(c.operation.ToStatusName(), clarification)

	saveErr := report.Save(ve.HandlerEnv.StatusFolder, seqNum)
	if saveErr == nil {
		return nil
	}
	ve.ExtensionLogger.Error("Failed to save handler status: %v", saveErr)
	return errors.Wrap(saveErr, "failed to save handler status")
}
// parseCmd picks the command to run from args, which must hold exactly two
// entries with the operation name in the second one. If the arguments are
// invalid, it prints the usage string and an error message, asks eh to exit
// with code 2 and returns a zero cmd.
func (ve *VMExtension) parseCmd(args []string, eh exithelper.IExitHelper) cmd {
	if len(args) != 2 {
		ve.printUsage(args)
		fmt.Println("Incorrect usage.")
		eh.Exit(2)
		return cmd{}
	}

	requested := args[1]
	// The conversion error is deliberately ignored: whatever name comes back
	// is simply looked up among the registered commands.
	opName, _ := OperationNameFromString(requested)
	if selected, found := ve.exec.cmds[opName]; found {
		return selected
	}

	ve.printUsage(args)
	fmt.Printf("Incorrect command: %q\n", requested)
	eh.Exit(2)
	return cmd{}
}
// printUsage prints the help string and version of the program to stdout with a
// trailing new line.
//
// The program name is taken from args[0], falling back to os.Args[0] when args
// is empty. The supported operations are listed in alphabetical order,
// separated by '|', so that the output is the same on every run.
func (ve *VMExtension) printUsage(args []string) {
	// Prefer the arguments that were actually parsed; only fall back to the
	// process arguments when none were supplied.
	program := ""
	switch {
	case len(args) > 0:
		program = args[0]
	case len(os.Args) > 0:
		program = os.Args[0]
	}

	// Map iteration order is unspecified, so collect and sort the names.
	names := make([]string, 0, len(ve.exec.cmds))
	for name := range ve.exec.cmds {
		names = append(names, name.ToString())
	}
	sort.Strings(names)

	fmt.Printf("Usage: %s %s\n", program, strings.Join(names, "|"))
	fmt.Println(ve.Version)
}
// noop is the command function used for operations that have nothing to run.
// It logs "noop" and returns an empty message and a nil error.
func noop(ext *VMExtension) (string, error) {
ext.ExtensionLogger.Info("noop")
return "", nil
}