pkg/helpers/ssh/scp.go (90 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
package ssh
import (
"bytes"
"context"
"fmt"
"io"
"os"
"github.com/Azure/aks-engine-azurestack/pkg/api"
"github.com/pkg/errors"
)
// CopyToRemote copies a file to a remote host.
//
// Context ctx is only enforced during the process that establishes
// the SSH connection and creates the SSH client.
func CopyToRemote(ctx context.Context, host *RemoteHost, file *RemoteFile) (combinedOutput string, err error) {
c, err := clientWithRetry(ctx, host)
if err != nil {
return "", errors.Wrap(err, "creating SSH client")
}
defer c.Close()
s, err := c.NewSession()
if err != nil {
return "", errors.Wrap(err, "creating SSH session")
}
defer s.Close()
// Make this configurable if we find that consumers need to update the command
cmd := getUploadCommand(host.OperatingSystem)(file)
s.Stdin = bytes.NewReader(file.Content)
if co, err := s.CombinedOutput(cmd); err != nil {
return string(co), errors.Wrap(err, "uploading to remote host")
}
return "", nil
}
// CopyFromRemote copies a remote file to the local host.
//
// Context ctx is only enforced during the process that establishes
// the SSH connection and creates the SSH client.
func CopyFromRemote(ctx context.Context, host *RemoteHost, remoteFile *RemoteFile, destinationPath string) (stderr string, err error) {
f, err := os.OpenFile(destinationPath, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return "", errors.Wrap(err, "opening destination file")
}
defer f.Close()
c, err := clientWithRetry(ctx, host)
if err != nil {
return "", errors.Wrap(err, "creating SSH client")
}
defer c.Close()
s, err := c.NewSession()
if err != nil {
return "", errors.Wrap(err, "creating SSH session")
}
defer s.Close()
stdout, err := s.StdoutPipe()
if err != nil {
return "", errors.Wrap(err, "opening SSH session stdout pipe")
}
// Make this configurable if we find that consumers need to update the command
cmd := getDownloadCommand(host.OperatingSystem)(remoteFile)
if err = s.Start(cmd); err != nil {
return fmt.Sprintf("%s", s.Stderr), errors.Wrap(err, "downloading logs from remote host")
}
_, err = io.Copy(f, stdout)
if err != nil {
return "", errors.Wrap(err, "downloading logs")
}
return "", nil
}
type uploadCommandBuilder func(file *RemoteFile) string
func getUploadCommand(os api.OSType) uploadCommandBuilder {
switch os {
case api.Linux:
return func(f *RemoteFile) string {
return fmt.Sprintf("sudo bash -c \"mkdir -p $(dirname %s); cat /dev/stdin > %s; chmod %s %s; chown %s %s\"",
f.Path, f.Path, f.Permissions, f.Path, f.Owner, f.Path)
}
case api.Windows:
return func(f *RemoteFile) string {
return fmt.Sprintf("powershell -noprofile -command \"$Input | Out-File -Encoding ASCII %s\"",
f.Path)
}
default:
return nil
}
}
type downloadCommandBuilder func(file *RemoteFile) string
func getDownloadCommand(os api.OSType) downloadCommandBuilder {
switch os {
case api.Linux:
return func(f *RemoteFile) string {
return fmt.Sprintf("bash -c \"cat %s > /dev/stdout\"", f.Path)
}
case api.Windows:
return func(f *RemoteFile) string {
return fmt.Sprintf("type %s", f.Path)
}
default:
return nil
}
}