agentendpoint/patch_task.go
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package agentendpoint

import (
	"context"
	"fmt"
	"time"

	"github.com/GoogleCloudPlatform/osconfig/agentconfig"
	"github.com/GoogleCloudPlatform/osconfig/clog"
	"github.com/GoogleCloudPlatform/osconfig/ospatch"
	"google.golang.org/protobuf/encoding/protojson"

	"cloud.google.com/go/osconfig/agentendpoint/apiv1/agentendpointpb"
)
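
// systemRebootRequired reports whether the operating system indicates that a
// reboot is required to finish applying updates.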
func systemRebootRequired(ctx context.Context) (bool, error) {
return ospatch.SystemRebootRequired(ctx)
}
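
// patchStep identifies the phase of a patch task. It is persisted as part of
// the task state so a task can resume at the correct step after a reboot or
// agent restart.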
type patchStep string

const (
	prePatch  patchStep = "PrePatch"
	patching  patchStep = "Patching"
	postPatch patchStep = "PostPatch"

	totalRebootCountLimit = 5
)
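
// patchTask tracks a single ApplyPatchesTask. Its exported fields are saved
// to the task state file so the task can be resumed after a reboot.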
type patchTask struct {
client *Client
lastProgressState map[agentendpointpb.ApplyPatchesTaskProgress_State]time.Time
state *taskState
TaskID string
Task *applyPatchesTask
StartedAt time.Time `json:",omitempty"`
PatchStep patchStep `json:",omitempty"`
PrePatchRebootCount int
PostPatchRebootCount int
// TODO: add Attempts and track number of retries with backoff, jitter, etc.
}
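
// saveState persists this patchTask to the task state file.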
func (r *patchTask) saveState() error {
r.state.PatchTask = r
return r.state.save(taskStateFile)
}
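
// complete clears the saved task state now that the task has finished.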
func (r *patchTask) complete(ctx context.Context) {
if err := (&taskState{}).save(taskStateFile); err != nil {
clog.Errorf(ctx, "Error saving state: %v", err)
}
}
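
// applyPatchesTask wraps agentendpointpb.ApplyPatchesTask so it is marshaled
// and unmarshaled with protojson rather than encoding/json.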
type applyPatchesTask struct {
*agentendpointpb.ApplyPatchesTask
}

// MarshalJSON marshals an applyPatchesTask using protojson.
func (a *applyPatchesTask) MarshalJSON() ([]byte, error) {
m := &protojson.MarshalOptions{AllowPartial: true, EmitUnpopulated: false}
return m.Marshal(a)
}

// UnmarshalJSON unmarshals an applyPatchesTask using protojson.
func (a *applyPatchesTask) UnmarshalJSON(b []byte) error {
a.ApplyPatchesTask = &agentendpointpb.ApplyPatchesTask{}
un := &protojson.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true}
return un.Unmarshal(b, a.ApplyPatchesTask)
}
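
// setStep records the current patch step and saves the task state.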
func (r *patchTask) setStep(step patchStep) error {
r.PatchStep = step
if err := r.saveState(); err != nil {
return fmt.Errorf("error saving state: %v", err)
}
return nil
}
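
// handleErrorState reports the task as canceled if the server requested a
// stop, and as failed with msg otherwise.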
func (r *patchTask) handleErrorState(ctx context.Context, msg string, err error) error {
if err == errServerCancel {
return r.reportCanceled(ctx)
}
return r.reportFailed(ctx, msg)
}
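
// reportFailed logs msg and reports the task as completed in the FAILED state.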
func (r *patchTask) reportFailed(ctx context.Context, msg string) error {
clog.Errorf(ctx, "%v", msg)
return r.reportCompletedState(ctx, msg, &agentendpointpb.ReportTaskCompleteRequest_ApplyPatchesTaskOutput{
ApplyPatchesTaskOutput: &agentendpointpb.ApplyPatchesTaskOutput{State: agentendpointpb.ApplyPatchesTaskOutput_FAILED},
})
}
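
// reportCanceled reports task completion after a server-requested
// cancellation.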
func (r *patchTask) reportCanceled(ctx context.Context) error {
clog.Infof(ctx, "Canceling patch execution")
return r.reportCompletedState(ctx, errServerCancel.Error(), &agentendpointpb.ReportTaskCompleteRequest_ApplyPatchesTaskOutput{
// Is this right? Maybe there should be a canceled state instead.
ApplyPatchesTaskOutput: &agentendpointpb.ApplyPatchesTaskOutput{State: agentendpointpb.ApplyPatchesTaskOutput_FAILED},
})
}
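
// reportCompletedState sends the final ReportTaskCompleteRequest for this
// task with the given error message and output.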
func (r *patchTask) reportCompletedState(ctx context.Context, errMsg string, output *agentendpointpb.ReportTaskCompleteRequest_ApplyPatchesTaskOutput) error {
req := &agentendpointpb.ReportTaskCompleteRequest{
TaskId: r.TaskID,
TaskType: agentendpointpb.TaskType_APPLY_PATCHES,
ErrorMessage: errMsg,
Output: output,
}
if err := r.client.reportTaskComplete(ctx, req); err != nil {
return fmt.Errorf("error reporting completed state: %v", err)
}
return nil
}
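
// reportContinuingState reports in-progress state to the service and returns
// errServerCancel if the server directs the agent to stop. Reports of the
// same state are throttled by the sameStateTimeWindow check.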
func (r *patchTask) reportContinuingState(ctx context.Context, patchState agentendpointpb.ApplyPatchesTaskProgress_State) error {
st, ok := r.lastProgressState[patchState]
if ok && st.After(time.Now().Add(sameStateTimeWindow)) {
// Don't resend the same state more than once every 5s.
return nil
}
req := &agentendpointpb.ReportTaskProgressRequest{
TaskId: r.TaskID,
TaskType: agentendpointpb.TaskType_APPLY_PATCHES,
Progress: &agentendpointpb.ReportTaskProgressRequest_ApplyPatchesTaskProgress{
ApplyPatchesTaskProgress: &agentendpointpb.ApplyPatchesTaskProgress{State: patchState},
},
}
res, err := r.client.reportTaskProgress(ctx, req)
if err != nil {
return fmt.Errorf("error reporting state %s: %v", patchState, err)
}
if res.GetTaskDirective() == agentendpointpb.TaskDirective_STOP {
return errServerCancel
}
if r.lastProgressState == nil {
r.lastProgressState = make(map[agentendpointpb.ApplyPatchesTaskProgress_State]time.Time)
}
r.lastProgressState[patchState] = time.Now()
return r.saveState()
}
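
// prePatchReboot reboots the system before patching if a reboot is needed.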
func (r *patchTask) prePatchReboot(ctx context.Context) error {
return r.rebootIfNeeded(ctx, true)
}
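
// postPatchReboot reboots the system after patching if a reboot is needed.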
func (r *patchTask) postPatchReboot(ctx context.Context) error {
return r.rebootIfNeeded(ctx, false)
}
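
// rebootIfNeeded decides whether to reboot based on the task's RebootConfig,
// the system's reported reboot requirement, and totalRebootCountLimit, then
// reboots and blocks until the reboot takes effect.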
func (r *patchTask) rebootIfNeeded(ctx context.Context, prePatch bool) error {
var reboot bool
var err error
if r.Task.GetPatchConfig().GetRebootConfig() == agentendpointpb.PatchConfig_ALWAYS && !prePatch && r.PostPatchRebootCount == 0 {
reboot = true
clog.Infof(ctx, "PatchConfig RebootConfig set to %s.", agentendpointpb.PatchConfig_ALWAYS)
} else {
reboot, err = systemRebootRequired(ctx)
if err != nil {
return fmt.Errorf("error checking if a system reboot is required: %v", err)
}
if reboot {
clog.Infof(ctx, "System indicates a reboot is required.")
totalRebootCount := r.PrePatchRebootCount + r.PostPatchRebootCount
if totalRebootCount >= totalRebootCountLimit {
clog.Infof(ctx, "Detected abnormal number of reboots for a single patch task (%d). Not rebooting to prevent a possible boot loop", totalRebootCount)
return nil
}
} else {
clog.Infof(ctx, "System indicates a reboot is not required.")
}
}
if !reboot {
return nil
}
if r.Task.GetPatchConfig().GetRebootConfig() == agentendpointpb.PatchConfig_NEVER {
clog.Infof(ctx, "Skipping reboot because of PatchConfig RebootConfig set to %s.", agentendpointpb.PatchConfig_NEVER)
return nil
}
if err := r.reportContinuingState(ctx, agentendpointpb.ApplyPatchesTaskProgress_REBOOTING); err != nil {
return err
}
if r.Task.GetDryRun() {
clog.Infof(ctx, "Dry run - not rebooting for ApplyPatchesTask")
return nil
}
if prePatch {
r.PrePatchRebootCount++
} else {
r.PostPatchRebootCount++
}
if err := r.saveState(); err != nil {
return fmt.Errorf("error saving state: %v", err)
}
if err := rebootSystem(); err != nil {
return fmt.Errorf("failed to reboot system: %v", err)
}
// Reboot can take a bit, pause here so other activities don't start.
for {
clog.Debugf(ctx, "Waiting for system reboot.")
time.Sleep(1 * time.Minute)
}
}
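
// run drives the patch task state machine from the saved PatchStep, reporting
// progress and the final state to the agent endpoint service. Panics are
// recovered and reported as task failures.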
func (r *patchTask) run(ctx context.Context) (err error) {
ctx = clog.WithLabels(ctx, r.state.Labels)
clog.Infof(ctx, "Beginning ApplyPatchesTask")
defer func() {
// This should not happen but the WUA libraries are complicated and
// recovering with an error is better than crashing.
if rec := recover(); rec != nil {
err = fmt.Errorf("Recovered from panic: %v", rec)
r.reportFailed(ctx, err.Error())
return
}
r.complete(ctx)
if agentconfig.OSInventoryEnabled() {
go r.client.ReportInventory(ctx)
}
}()
for {
clog.Debugf(ctx, "Running PatchStep %q.", r.PatchStep)
switch r.PatchStep {
default:
return r.reportFailed(ctx, fmt.Sprintf("unknown step: %q", r.PatchStep))
case prePatch:
r.StartedAt = time.Now()
if err := r.setStep(patching); err != nil {
return r.reportFailed(ctx, fmt.Sprintf("Error saving agent step: %v", err))
}
if err := r.reportContinuingState(ctx, agentendpointpb.ApplyPatchesTaskProgress_STARTED); err != nil {
return r.handleErrorState(ctx, err.Error(), err)
}
if err := r.prePatchReboot(ctx); err != nil {
return r.handleErrorState(ctx, fmt.Sprintf("Error running prePatchReboot: %v", err), err)
}
case patching:
if err := r.reportContinuingState(ctx, agentendpointpb.ApplyPatchesTaskProgress_APPLYING_PATCHES); err != nil {
return r.handleErrorState(ctx, err.Error(), err)
}
if err := r.runUpdates(ctx); err != nil {
return r.handleErrorState(ctx, fmt.Sprintf("Failed to apply patches: %v", err), err)
}
if err := r.postPatchReboot(ctx); err != nil {
return r.handleErrorState(ctx, fmt.Sprintf("Error running postPatchReboot: %v", err), err)
}
// We have not rebooted so patching is complete.
if err := r.setStep(postPatch); err != nil {
return r.reportFailed(ctx, fmt.Sprintf("Error saving agent step: %v", err))
}
case postPatch:
isRebootRequired, err := systemRebootRequired(ctx)
if err != nil {
return r.reportFailed(ctx, fmt.Sprintf("Error checking if system reboot is required: %v", err))
}
finalState := agentendpointpb.ApplyPatchesTaskOutput_SUCCEEDED
if isRebootRequired {
finalState = agentendpointpb.ApplyPatchesTaskOutput_SUCCEEDED_REBOOT_REQUIRED
}
if err := r.reportCompletedState(ctx, "", &agentendpointpb.ReportTaskCompleteRequest_ApplyPatchesTaskOutput{
ApplyPatchesTaskOutput: &agentendpointpb.ApplyPatchesTaskOutput{State: finalState},
}); err != nil {
return fmt.Errorf("failed to report state %s: %v", finalState, err)
}
clog.Infof(ctx, "Successfully completed ApplyPatchesTask")
return nil
}
}
}

// RunApplyPatches runs an ApplyPatchesTask.
func (c *Client) RunApplyPatches(ctx context.Context, task *agentendpointpb.Task) error {
r := &patchTask{
state: &taskState{Labels: task.GetServiceLabels()},
TaskID: task.GetTaskId(),
client: c,
Task: &applyPatchesTask{task.GetApplyPatchesTask()},
}
r.setStep(prePatch)
return r.run(ctx)
}