internal/sqlservermetrics/guestoscollector/remote/remote.go (158 lines of code) (raw):

/* Copyright 2024 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Package remote ssh'es into remote machines and runs a command package remote import ( "bufio" "bytes" "fmt" "net" "os" "path/filepath" "strconv" "strings" "golang.org/x/crypto/ssh/knownhosts" "golang.org/x/crypto/ssh" "github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/log" ) // SSHClientInterface abstracts the client struct from ssh package type SSHClientInterface interface { ssh.Conn NewSession() (*ssh.Session, error) } // SSHSessionInterface abstracts the session struct from ssh package type SSHSessionInterface interface { Output(string) ([]byte, error) Close() error } // Executor interface for executing remote commands type Executor interface { SetupKeys(string) error CreateClient() error CreateSession(string) (SSHSessionInterface, error) Run(string, SSHSessionInterface) (string, error) Close() error } // remote contains the key for remote ssh'ing type remote struct { user string ip string port int32 key *key client SSHClientInterface } type key struct { PrivateKey ssh.Signer PublicKey ssh.PublicKey knownHostsPath string } // NewRemote attempts to find connect to remote ssh server with private key func NewRemote(ipaddr, user string, port int32) Executor { return &remote{ ip: ipaddr, port: port, user: user, key: &key{}, } } // SetupKeys load the key from given path and returns error if it failed to read the key file. func (r *remote) SetupKeys(privateKeyPath string) error { if err := r.privateKey(privateKeyPath); err != nil { return err } knownHostsPath := filepath.Join(filepath.Dir(privateKeyPath), "known_hosts") if err := r.publicKey(r.ip, knownHostsPath); err != nil { return err } return nil } func (r *remote) privateKey(privateKeyPath string) error { privateKeyBytes, err := os.ReadFile(privateKeyPath) if err != nil { return fmt.Errorf("an error occurred while reading the key file. %v", err) } privateKey, err := ssh.ParsePrivateKey(privateKeyBytes) if err != nil { return fmt.Errorf("an error occurred while parsing the private key. %v", err) } r.key.PrivateKey = privateKey return nil } // publicKey scans the known hosts file and gets a public key for the valid host that we are trying to ssh into func (r *remote) publicKey(host, knownHostsPath string) error { // parse OpenSSH known_hosts file // ssh or use ssh-keyscan to get initial key fd, err := os.Open(knownHostsPath) if err != nil { return fmt.Errorf("an error occurred when opening known_hosts. %v", err) } defer fd.Close() // support -H parameter for ssh-keyscan hashhost := knownhosts.HashHostname(host) scanner := bufio.NewScanner(fd) for scanner.Scan() { _, hosts, key, _, _, err := ssh.ParseKnownHosts(scanner.Bytes()) if err != nil { log.Logger.Errorf("failed to parse known_hosts: %s", scanner.Text()) continue } for _, h := range hosts { if h == host || h == hashhost { r.key.PublicKey = key return nil } } } return fmt.Errorf("known host file does not contain host %s; please SSH into host first to verify fingerprint", host) } // CreateClient creates ssh client based on private key and public key from Remote struct. func (r *remote) CreateClient() error { if r.key.PublicKey == nil { return fmt.Errorf("no public key found. please make sure SetupKeys() is called before calling CreateClient()") } if r.key.PrivateKey == nil { return fmt.Errorf("no private key found. please make sure SetupKeys() is called before calling CreateClient()") } c, err := ssh.Dial("tcp", net.JoinHostPort(r.ip, strconv.FormatInt(int64(r.port), 10)), &ssh.ClientConfig{ User: r.user, HostKeyCallback: ssh.FixedHostKey(r.key.PublicKey), Auth: []ssh.AuthMethod{ ssh.PublicKeys(r.key.PrivateKey), }, }) if err != nil { return fmt.Errorf("an error occurred while ssh dialing. %v", err) } r.client = c return nil } // CreateSession creates ssh session. func (r *remote) CreateSession(input string) (SSHSessionInterface, error) { if r.client == nil { return nil, fmt.Errorf("no client created. please make sure CreateClient() is called before calling CreateSession()") } session, err := r.client.NewSession() if err != nil { return nil, err } if input != "" { session.Stdin = bytes.NewBufferString(input) } return session, nil } func (r *remote) Close() error { return r.client.Close() } // Run runs a remote ssh command ex: output, err := remoteRun("root", "MY_IP", "privateKey", "22", "ls -l") func (r *remote) Run(cmd string, session SSHSessionInterface) (string, error) { output, err := session.Output(cmd) if err != nil { return "", fmt.Errorf("An error occurred while running the cmd %v, %v", cmd, err) } return strings.TrimSuffix(string(output), "\n"), nil } // RunCommandWithPipes runs consecutive remote commands that have | func RunCommandWithPipes(cmd string, e Executor) (string, error) { commands := strings.Split(cmd, "|") input := "" for _, command := range commands { f := func() error { s, err := e.CreateSession(input) if err != nil { return fmt.Errorf("Failed to create a session. %v", err) } defer s.Close() if input, err = e.Run(command, s); err != nil { return err } return nil } if err := f(); err != nil { return "", err } } // the last "input" from Run is the final return value return input, nil }