cmd/core_plugin/snapshot/snapshot_linux.go (204 lines of code) (raw):
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build linux
// Package snapshot is responsible for running scripts for guest flush snapshots.
package snapshot
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/GoogleCloudPlatform/galog"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/GoogleCloudPlatform/google-guest-agent/cmd/core_plugin/manager"
sspb "github.com/GoogleCloudPlatform/google-guest-agent/cmd/core_plugin/snapshot/proto/cloud_vmm"
"github.com/GoogleCloudPlatform/google-guest-agent/internal/cfg"
"github.com/GoogleCloudPlatform/google-guest-agent/internal/lru"
"github.com/GoogleCloudPlatform/google-guest-agent/internal/retry"
"github.com/GoogleCloudPlatform/google-guest-agent/internal/run"
"github.com/GoogleCloudPlatform/google-guest-agent/internal/utils/file"
)
const (
// defaultScriptsDir is the directory with snapshot pre/post scripts to be
// executed on request.
defaultScriptsDir = "/etc/google/snapshots/"
// maxIDCacheSize is the maximum size of the operation ID cache.
maxIDCacheSize = 128
// responseMaxAttempts is the maximum number of attempts to send the response
// to the snapshot service - we are considering re trying for 10 seconds.
responseMaxAttempts = 10
// snapshotModuleID is the ID of the snapshot module.
snapshotModuleID = "snapshot"
)
// clientOptions contains the options for the snapshot handler.
type clientOptions struct {
// protocol is the protocol of the snapshot service.
protocol string
// address is the address of the snapshot service.
address string
// timeoutInSeconds is the timeout for the snapshot service.
timeoutInSeconds time.Duration
// scriptDir is the directory with snapshot pre/post scripts to be executed
// on request.
scriptDir string
}
// snapshotClient is the snapshot handler implementation for linux.
type snapshotClient struct {
// seenPreOperationIDS is the cache of operation IDs that have been seen
// for pre snapshot operations.
seenPreOperationIDS *lru.Handle[int32]
// seenPostOperationIDS is the cache of operation IDs that have been seen
// for post snapshot operations.
seenPostOperationIDS *lru.Handle[int32]
// options are the options for the snapshot handler.
options clientOptions
}
// NewModule returns the snapshot module for late stage registration.
func NewModule(context.Context) *manager.Module {
return &manager.Module{
ID: snapshotModuleID,
Enabled: &cfg.Retrieve().Snapshots.Enabled,
Setup: moduleSetup,
Description: "Handles snapshot service requests and triggers pre/post snapshot scripts",
}
}
// moduleSetup runs the actual snapshot handler for linux.
func moduleSetup(ctx context.Context, _ any) error {
config := cfg.Retrieve().Snapshots
opts := clientOptions{
protocol: "tcp",
address: fmt.Sprintf("%s:%d", config.SnapshotServiceIP, config.SnapshotServicePort),
timeoutInSeconds: time.Duration(config.TimeoutInSeconds) * time.Second,
scriptDir: defaultScriptsDir,
}
handler, err := newClient(opts)
if err != nil {
return fmt.Errorf("failed to create snapshot handler: %w", err)
}
// If we don't trigger a new go routine here the snapshot service will block
// the module manager as it Wait()s for all modules to finish their work - and
// this module will keep running forever.
go func() { handler.run(ctx) }()
return nil
}
// newClient creates a new snapshot handler.
func newClient(options clientOptions) (*snapshotClient, error) {
return &snapshotClient{
seenPreOperationIDS: lru.New[int32](maxIDCacheSize),
seenPostOperationIDS: lru.New[int32](maxIDCacheSize),
options: options,
}, nil
}
// fullAddress returns the full address of the snapshot service.
func (op clientOptions) fullAddress() string {
// In case of unit tests force the unix domain socket scheme. Let the grpc
// library decide the other cases.
if op.protocol == "unix" {
return fmt.Sprintf("%s:///%s", op.protocol, op.address)
}
return op.address
}
// run runs the snapshot handler.
func (s *snapshotClient) run(ctx context.Context) error {
if !file.Exists(s.options.scriptDir, file.TypeDir) {
if err := os.MkdirAll(s.options.scriptDir, 0700); err != nil {
return fmt.Errorf("failed to create scripts directory %q: %w", s.options.scriptDir, err)
}
}
if err := s.listen(ctx); err != nil {
return fmt.Errorf("failed to listen for snapshot requests: %w", err)
}
return nil
}
// listen listens for snapshot requests from the snapshot service.
func (s *snapshotClient) listen(ctx context.Context) error {
galog.Infof("Starting to listen for snapshot requests.")
for context.Cause(ctx) == nil {
galog.Debugf("Attempting to connect to snapshot service at %q via %q.", s.options.address, s.options.protocol)
creds := grpc.WithTransportCredentials(insecure.NewCredentials())
conn, err := grpc.NewClient(s.options.fullAddress(), creds)
if err != nil {
return fmt.Errorf("failed to connect to snapshot service: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
galog.Errorf("Failed to close main connection to snapshot service: %v.", err)
}
}()
c := sspb.NewSnapshotServiceClient(conn)
guestReady := sspb.GuestReady{
RequestServerInfo: false,
}
r, err := c.CreateConnection(ctx, &guestReady)
if err != nil {
if !errors.Is(err, context.Canceled) {
galog.Errorf("Error creating connection with snapshot service: %v.", err)
}
continue
}
for {
request, err := r.Recv()
if err != nil {
galog.Errorf("Error reading snapshot request: %v.", err)
break
}
go func() {
if err := s.handleRequest(ctx, request.GetSnapshotRequest()); err != nil {
galog.Errorf("Failed to handle snapshot request: %v.", err)
}
}()
}
}
return nil
}
// handleRequest handles a single snapshot request. It runs the appropriate
// script and sends the response back to the snapshot service.
func (s *snapshotClient) handleRequest(ctx context.Context, request *sspb.SnapshotRequest) error {
type snapshotOperation struct {
cache *lru.Handle[int32]
scriptFileName string
name string
}
operationConfigs := map[sspb.OperationType]*snapshotOperation{
sspb.OperationType_PRE_SNAPSHOT: &snapshotOperation{
cache: s.seenPreOperationIDS,
scriptFileName: "pre.sh",
name: "pre",
},
sspb.OperationType_POST_SNAPSHOT: &snapshotOperation{
cache: s.seenPostOperationIDS,
scriptFileName: "post.sh",
name: "post",
},
}
// Determine if we know how to handle the operation type.
config, found := operationConfigs[request.GetType()]
if !found {
return fmt.Errorf("unhandled operation type %q", request.GetType())
}
// Have we seen this operation ID before?
if _, found := config.cache.Get(request.GetOperationId()); found {
return fmt.Errorf("duplicate %s snapshot request operation id %d", config.name, request.GetOperationId())
}
galog.Infof("Handling snapshot request type: %q, operation id: %d.", config.name, request.GetOperationId())
// Mark the operation ID as seen and avoid repeated execution.
config.cache.Put(request.GetOperationId(), true)
scriptPath := filepath.Join(s.options.scriptDir, config.scriptFileName)
// Trigger the execution of the script.
exitCode, errCode := s.runScript(ctx, scriptPath, request.GetDiskList())
response := &sspb.SnapshotResponse{
OperationId: request.GetOperationId(),
Type: request.GetType(),
ScriptsReturnCode: int32(exitCode),
AgentReturnCode: errCode,
}
// Send the response back to the snapshot service.
if err := s.sendResponse(ctx, response); err != nil {
return fmt.Errorf("failed to send snapshot response: %w", err)
}
galog.Debugf("Successfully handled snapshot request.")
return nil
}
// sendResponse sends the given response to the snapshot service.
func (s *snapshotClient) sendResponse(ctx context.Context, response *sspb.SnapshotResponse) error {
creds := grpc.WithTransportCredentials(insecure.NewCredentials())
conn, err := grpc.NewClient(s.options.fullAddress(), creds)
if err != nil {
return fmt.Errorf("failed to connect to snapshot service to send response: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
galog.Errorf("Failed to close snapshot response connection: %v.", err)
}
}()
c := sspb.NewSnapshotServiceClient(conn)
// retryCb is the the actual response sending function.
retryCb := func() error {
_, err = c.HandleResponsesFromGuest(ctx, response)
return err
}
// Retry sending the response to the snapshot service.
policy := retry.Policy{MaxAttempts: responseMaxAttempts, BackoffFactor: 1, Jitter: time.Second}
if err := retry.Run(ctx, policy, retryCb); err != nil {
return fmt.Errorf("failed to send snapshot response: %w", err)
}
galog.Debugf("Successfully sent snapshot response for operation id %d.", response.GetOperationId())
return nil
}
// runScript runs the script at the given path with the given disks as
// arguments and returns the process' exit code and the snapshot service error
// code.
func (s *snapshotClient) runScript(ctx context.Context, scriptPath string, disks string) (int, sspb.AgentErrorCode) {
galog.Infof("Running guest consistent snapshot script: %s, disks: %s.", scriptPath, disks)
if !file.Exists(scriptPath, file.TypeFile) {
return -1, sspb.AgentErrorCode_SCRIPT_NOT_FOUND
}
cmd := []string{scriptPath, disks}
opts := run.Options{Name: cmd[0], Args: cmd[1:], OutputType: run.OutputNone, Timeout: s.options.timeoutInSeconds}
_, err := run.WithContext(ctx, opts)
// Handle timeout error.
if _, ok := run.AsTimeoutError(err); ok {
return -1, sspb.AgentErrorCode_SCRIPT_TIMED_OUT
}
// Handle "unknown" exit error.
if xerr, ok := run.AsExitError(err); ok {
return xerr.ExitCode(), sspb.AgentErrorCode_UNHANDLED_SCRIPT_ERROR
}
galog.Infof("Snpashot script %q succeeded.", scriptPath)
return 0, sspb.AgentErrorCode_NO_ERROR
}