daisy_workflows/linux_common/utils/common.py (421 lines of code) (raw):
#!/usr/bin/env python3
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for all VM scripts."""
import functools
import json
import logging
import os
import re
import stat
import subprocess
import sys
import time
import trace
import traceback
import typing
import urllib.error
import urllib.request
import uuid
from .guestfsprocess import run
# Custom log level (35) for success messages: above WARNING (30), just
# below ERROR (40), so success lines survive most level filters.
SUCCESS_LEVELNO = logging.ERROR - 5
def RetryOnFailure(stop_after_seconds=15 * 60, initial_delay_seconds=3):
  """Function decorator to retry on an exception.

  Performs exponential backoff (factor 1.5) between attempts until
  stop_after_seconds is reached. The wrapped function is always attempted
  at least once; on final failure the last exception is re-raised.

  Args:
    stop_after_seconds: Maximum amount of time (in seconds) to spend retrying.
    initial_delay_seconds: The delay before the first retry, in seconds."""
  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      ratio = 1.5
      wait = initial_delay_seconds
      ntries = 0
      start_time = time.time()
      end_time = start_time + stop_after_seconds
      exception = None
      # Guarantee one attempt even if stop_after_seconds <= 0; otherwise we
      # would fall through and `raise None` below (a TypeError).
      while ntries == 0 or time.time() < end_time:
        # Don't sleep on first attempt.
        if ntries > 0:
          time.sleep(wait)
          wait *= ratio
        ntries += 1
        try:
          response = func(*args, **kwargs)
        except Exception as e:
          exception = e
          logging.info(str(e))
          logging.info(
              'Function %s failed, waiting %d seconds, retrying %d ...',
              str(func), wait, ntries)
        else:
          logging.info(
              'Function %s executed in less than %d sec, with %d tentative(s)',
              str(func), time.time() - start_time, ntries)
          return response
      raise exception
    return wrapper
  return decorator
@RetryOnFailure()
def YumInstall(package_list):
  """Install packages via yum, refreshing package data once per process.

  Args:
    package_list: list[str], the packages to install.
  """
  if YumInstall.first_run:
    # '-y' keeps the update non-interactive; without it yum can prompt for
    # confirmation and hang when there is no TTY attached.
    Execute(['yum', '-y', 'update'])
    YumInstall.first_run = False
  Execute(['yum', '-y', 'install'] + package_list)
YumInstall.first_run = True
@RetryOnFailure()
def AptGetInstall(package_list, suite=None):
  """Install packages with apt-get, refreshing indexes until one success.

  When `apt update` fails to update a repo, it returns 0. This check
  ensures that we retry running update until we've had one successful
  install.
  https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=778357

  Args:
    package_list: list[str], the packages to install.
    suite: str, optional target release passed to apt-get via `-t`.

  Returns:
    The (returncode, output) tuple from Execute.
  """
  if not AptGetInstall.prior_success:
    Execute(['apt', '-y', 'update'])
  install_env = os.environ.copy()
  install_env['DEBIAN_FRONTEND'] = 'noninteractive'
  install_cmd = ['apt-get', '-q', '-y', 'install']
  if suite:
    install_cmd += ['-t', suite]
  result = Execute(install_cmd + package_list, env=install_env)
  AptGetInstall.prior_success = True
  return result
AptGetInstall.prior_success = False
def PipInstall(package_list):
  """Install Python modules via pip. Assumes pip is already installed."""
  pip_cmd = ['pip', 'install', '-U']
  return Execute(pip_cmd + package_list)
def Execute(cmd, cwd=None, capture_output=False, env=None, raise_errors=True):
  """Execute an external command (wrapper for Python subprocess).

  Args:
    cmd: list[str], the command and its arguments.
    cwd: str, working directory for the child process.
    capture_output: bool, when True the child's stdout is captured,
        decoded and returned (stderr is never captured).
    env: dict, environment for the child; None inherits the current one.
    raise_errors: bool, when True a non-zero exit status raises.

  Returns:
    (returncode, output): output is the decoded stdout when
    capture_output=True, otherwise None.

  Raises:
    subprocess.CalledProcessError: on non-zero exit if raise_errors is True.
  """
  logging.info('Executing command: %s', str(cmd))
  stdout = subprocess.PIPE if capture_output else None
  p = subprocess.Popen(cmd, cwd=cwd, env=env, stdout=stdout)
  output = p.communicate()[0]
  returncode = p.returncode
  if returncode != 0:
    if raise_errors:
      # Attach any captured output so callers can inspect what failed.
      raise subprocess.CalledProcessError(returncode, cmd, output=output)
    logging.info('Command returned error status %d', returncode)
  if output is not None:
    output = output.decode()
    logging.info(output)
  return returncode, output
def ClearEtcResolv(g):
  """Empty /etc/resolv.conf so DNS settings come from GCP's DHCP server.

  Args:
    g (guestfs.GuestFS): A mounted GuestFS instance.
  """
  # The file may carry the immutable attr; clear it first or the write fails.
  _ClearImmutableAttr(g, '/etc/resolv.conf')
  g.sh('echo "" > /etc/resolv.conf')
def _ClearImmutableAttr(g, fname):
"""Clears the immutable attr on the file associated with fname.
Args:
g (guestfs.GuestFS): A mounted GuestFS instance.
fname (str): File to have its immutable attr cleared.
"""
if g.exists(fname):
try:
g.set_e2attrs(fname, 'i', clear=True)
except BaseException:
# set_e2attrs will throw an error if the filesystem
# doesn't support chattr, in which case the file
# won't have the attr at all.
pass
def HttpGet(url, headers=None):
  """Perform an HTTP GET and return the raw response body.

  Args:
    url: str, the URL to fetch.
    headers: dict, optional headers to add (not forwarded on redirects).

  Returns:
    bytes, the response body.
  """
  request = urllib.request.Request(url)
  if headers:
    for key, value in headers.items():
      request.add_unredirected_header(key, value)
  # Context manager ensures the connection is closed after reading.
  with urllib.request.urlopen(request) as response:
    return response.read()
def _GetMetadataParam(name, default_value=None, raise_on_not_found=None):
  """Fetch one value from the GCE instance metadata server.

  Args:
    name: string, metadata path relative to .../v1/instance/.
    default_value: value returned when the key is missing and
        raise_on_not_found is falsy.
    raise_on_not_found: bool, raise ValueError when the key is missing.

  Returns:
    string, the decoded metadata value, or default_value when absent.
  """
  url = ('http://metadata.google.internal/computeMetadata/v1/instance/%s'
         % name)
  try:
    return HttpGet(url, headers={'Metadata-Flavor': 'Google'}).decode()
  except (urllib.error.HTTPError, urllib.error.URLError):
    if raise_on_not_found:
      raise ValueError('Metadata key "%s" not found' % name)
    return default_value
def GetMetadataAttribute(name, default_value=None, raise_on_not_found=False):
  """Fetch a custom metadata attribute of this instance.

  Args:
    name: string, the attribute key to read.
    default_value: value returned when the attribute is missing and
        raise_on_not_found is False.
    raise_on_not_found: bool, raise ValueError on a missing attribute.

  Returns:
    string, the attribute's value, or default_value when absent.
  """
  param = 'attributes/%s' % name
  return _GetMetadataParam(param, default_value, raise_on_not_found)
def GetCurrentLoginProfileUsername(user_lib, unique_id_user):
  """Return the POSIX username of the current OS Login profile.

  Equivalent of calling the gcloud equivalent:
  gcloud compute os-login describe-profile --format \
      value\\(posixAccounts.username\\)

  Args:
    user_lib: object, from GetOslogin().users()
    unique_id_user: string, user identifier in the format 'users/{user}'.

  Returns:
    string, username like 'sa_101330816214789148073'
  """
  profile = user_lib.getLoginProfile(name=unique_id_user).execute()
  return profile[u'posixAccounts'][0][u'username']
def GetServiceAccountUniqueIDUser():
  """Return the default service account identifier as 'users/{email}'.

  Used for retrieving LoginProfile and oslogin ssh key's operations.

  Returns:
    string, identifier built from the default service account's email.
  """
  raw = _GetMetadataParam('service-accounts/default/?recursive=True')
  account_info = json.loads(raw)
  return 'users/' + account_info['email']
def CommonRoutines(g):
  """Apply cleanup steps shared by the translate workflows.

  Args:
    g (guestfs.GuestFS): A mounted GuestFS instance.
  """
  # Remove udev file to force it to be re-generated on next boot.
  logging.info('Removing udev 70-persistent-net.rules.')
  udev_rules = '/etc/udev/rules.d/70-persistent-net.rules'
  _ClearImmutableAttr(g, udev_rules)
  g.rm_rf(udev_rules)
  # Remove SSH host keys so fresh ones are generated per instance.
  logging.info('Removing SSH host keys.')
  g.sh("rm -f /etc/ssh/ssh_host_*")
def RunTranslate(translate_func: typing.Callable,
                 run_with_tracing: bool = True):
  """Run `translate_func`, and communicate success or failure back to Daisy.

  Args:
    translate_func: Closure to execute.
    run_with_tracing: When enabled, the closure is executed under
        trace.Trace, so each executed line is printed to stdout.
  """
  exit_code = 0
  try:
    if not run_with_tracing:
      translate_func()
    else:
      line_tracer = trace.Trace(
          ignoredirs=[sys.prefix, sys.exec_prefix], trace=1, count=0)
      line_tracer.runfunc(translate_func)
    logging.success('Translation finished.')
  except Exception as e:
    exit_code = 1
    logging.debug(traceback.format_exc())
    logging.error('error: %s', str(e))
  # Flush handlers before terminating so Daisy sees the final status line.
  logging.shutdown()
  sys.exit(exit_code)
def MakeExecutable(file_path):
  """Add the owner-execute bit to file_path, preserving all other bits."""
  current_mode = os.stat(file_path).st_mode
  os.chmod(file_path, current_mode | stat.S_IEXEC)
def ReadFile(file_path, strip=False):
  """Return the contents of file_path as a string.

  Args:
    file_path: str, the path of the file to read.
    strip: bool, when True return the content with surrounding
        whitespace removed.

  Returns:
    str, the file's content.
  """
  # Use a context manager so the handle is closed deterministically
  # instead of relying on garbage collection.
  with open(file_path) as fp:
    content = fp.read()
  return content.strip() if strip else content
def WriteFile(file_path, content, mode='w'):
  """Write content to file_path, using the given open() mode."""
  with open(file_path, mode) as out:
    out.write(content)
def GenSshKey(user):
  """Generate an RSA ssh key pair for user.

  Args:
    user: string, the user to create the ssh key for.

  Returns:
    (entry, key_name): entry is '<user>:<public key>' suitable for GCE
    metadata; key_name is the file name of the generated private key.
  """
  key_name = 'daisy-test-key-' + str(uuid.uuid4())
  Execute(
      ['ssh-keygen', '-t', 'rsa', '-N', '', '-f', key_name, '-C', key_name])
  with open(key_name + '.pub', 'r') as pub_file:
    public_key = pub_file.read().strip()
  return "%s:%s" % (user, public_key), key_name
def ExecuteInSsh(
    key, user, machine, cmds, expect_fail=False, capture_output=False):
  """Execute commands through ssh.

  Args:
    key: string, the path of the private key to use in the ssh connection.
    user: string, the user used to connect through ssh.
    machine: string, the hostname of the machine to connect.
    cmds: list[string], the commands to be execute in the ssh session.
    expect_fail: bool, indicates if the failure in the execution is expected.
    capture_output: bool, indicates if the output of the command should be
        captured.

  Returns:
    ret, out if capture_output=True.

  Raises:
    ValueError: when the command's success does not match expect_fail.
  """
  target = '%s@%s' % (user, machine)
  ssh_options = [
      '-i', key,
      '-o', 'IdentitiesOnly=yes',
      '-o', 'ConnectTimeout=10',
      '-o', 'StrictHostKeyChecking=no',
      '-o', 'UserKnownHostsFile=/dev/null',
  ]
  ret, out = Execute(
      ['ssh'] + ssh_options + [target] + cmds,
      raise_errors=False, capture_output=capture_output)
  failed = ret != 0
  if failed != expect_fail:
    if expect_fail:
      raise ValueError('SSH command succeeded when expected to fail')
    raise ValueError('SSH command failed when expected to succeed')
  return ret, out
def GetCompute(discovery, credentials):
  """Build the Google Compute Engine v1 API client.

  Args:
    discovery: object, from googleapiclient.
    credentials: object, from google.auth.

  Returns:
    compute: object, the google compute api object.
  """
  return discovery.build('compute', 'v1', credentials=credentials)
def GetOslogin(discovery, credentials):
  """Build the Google OS Login v1 API client.

  Args:
    discovery: object, from googleapiclient.
    credentials: object, from google.auth.

  Returns:
    oslogin: object, the google oslogin api object.
  """
  return discovery.build('oslogin', 'v1', credentials=credentials)
def RunTest(test_func):
  """Run main test function and print logging.success() or logging.error().

  The function is executed under trace.Trace so each executed line is
  echoed to stdout. Unlike RunTranslate, this does not exit the process.

  Args:
    test_func: function, the function to be tested.
  """
  try:
    line_tracer = trace.Trace(
        ignoredirs=[sys.prefix, sys.exec_prefix], trace=1, count=0)
    line_tracer.runfunc(test_func)
    logging.success('Test finished.')
  except Exception as e:
    logging.error('error: ' + str(e))
    traceback.print_exc()
def DownloadFile(gcs_source_file, dest_file):
  """Downloads a file from GCS.

  Expects a source file in GCS and a local destination path.

  Args:
    gcs_source_file: string, the path of a source file to download.
        ex: gs://path/to/orig_file.tar.gz
    dest_file: string, the path to the resulting file.
        ex: /path/to/new/file.tar.gz

  Raises:
    ValueError: if gcs_source_file is not a valid gs:// URL.
  """
  # import 'google.cloud.storage' locally as 'google-cloud-storage' pip package
  # is not a mandatory package for all utils users
  from google.cloud import storage
  bucket = r'(?P<bucket>[a-z0-9][-_.a-z0-9]*[a-z0-9])'
  obj = r'(?P<obj>[^\*\?]+)'
  prefix = r'gs://'
  gs_regex = re.compile(r'{prefix}{bucket}/{obj}'.format(prefix=prefix,
                                                         bucket=bucket,
                                                         obj=obj))
  match = gs_regex.match(gcs_source_file)
  if not match:
    # Fail with a clear message instead of AttributeError on match.group.
    raise ValueError('Invalid GCS path: %s' % gcs_source_file)
  client = storage.Client()
  bucket = client.get_bucket(match.group('bucket'))
  blob = bucket.blob(match.group('obj'))
  blob.download_to_filename(dest_file)
def UploadFile(source_file, gcs_dest_file):
  """Uploads a file to GCS.

  Expects a local source file and a destination bucket and GCS path.

  Args:
    source_file: string, the path of a source file to upload.
        ex: /path/to/local/orig_file.tar.gz
    gcs_dest_file: string, the path to the resulting file in GCS
        ex: gs://new/path/orig_file.tar.gz

  Raises:
    ValueError: if gcs_dest_file is not a valid gs:// URL.
  """
  # import 'google.cloud.storage' locally as 'google-cloud-storage' pip package
  # is not a mandatory package for all utils users
  from google.cloud import storage
  bucket = r'(?P<bucket>[a-z0-9][-_.a-z0-9]*[a-z0-9])'
  obj = r'(?P<obj>[^\*\?]+)'
  prefix = r'gs://'
  gs_regex = re.compile(r'{prefix}{bucket}/{obj}'.format(prefix=prefix,
                                                         bucket=bucket,
                                                         obj=obj))
  match = gs_regex.match(gcs_dest_file)
  if not match:
    # Fail with a clear message instead of AttributeError on match.group.
    raise ValueError('Invalid GCS path: %s' % gcs_dest_file)
  client = storage.Client()
  bucket = client.get_bucket(match.group('bucket'))
  blob = bucket.blob(match.group('obj'))
  blob.upload_from_filename(source_file)
class LogFormatter(logging.Formatter):
  """Formatter that maps each log level to a Daisy status-prefix line.

  The optional 'prefix' instance metadata attribute is prepended to every
  level keyword (e.g. 'TranslateFailed: ').
  """
  # Fallback used for records whose level has no dedicated formatter.
  default_formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')

  def __init__(self):
    prefix = GetMetadataAttribute('prefix', default_value='')
    prefix_level = {
        logging.DEBUG: '%sDebug: ' % prefix,
        logging.INFO: '%sStatus: ' % prefix,
        logging.WARNING: '%sWarn: ' % prefix,
        logging.ERROR: '%sFailed: ' % prefix,
        SUCCESS_LEVELNO: '%sSuccess: ' % prefix
    }
    # Per-instance dict: the previous class-level attribute was shared
    # (and mutated) across all instances.
    self.formatters = {}
    for loglevel, level_prefix in prefix_level.items():
      self.formatters[loglevel] = logging.Formatter(
          level_prefix + '%(message)s')

  def format(self, record):
    formatter = self.formatters.get(record.levelno, self.default_formatter)
    return formatter.format(record)
def SetupLogging():
  """Configure Logging system.

  Routes DEBUG-and-above to stdout via LogFormatter, and registers a
  custom SUCCESS level with `success()` helpers on both the root logger
  and the logging module.
  """
  logger = logging.getLogger()
  logger.setLevel(logging.DEBUG)
  stdout = logging.StreamHandler(sys.stdout)
  stdout.setLevel(logging.DEBUG)
  stdout.setFormatter(LogFormatter())
  logger.addHandler(stdout)
  logging.addLevelName(SUCCESS_LEVELNO, 'SUCCESS')

  def success(message, *args, **kws):
    # Bug fix: the previous version declared `self` as first parameter but
    # was assigned to the logger instance (no binding), so calling
    # logger.success(msg) passed msg as self and raised TypeError.
    logger._log(SUCCESS_LEVELNO, message, args, **kws)

  logger.success = success
  logging.success = lambda *args: logging.log(SUCCESS_LEVELNO, *args)
# Configure logging as an import side effect so every consumer of this
# module emits Daisy-parseable status lines.
SetupLogging()
class MetadataManager:
  """Utilities to manage metadata.

  Keeps a local cache of metadata items (self.md_items) per level
  (INSTANCE_LEVEL / PROJECT_LEVEL). Mutating helpers edit the cache and,
  when store=True, push it back to the server via StoreMetadata.
  """
  # Metadata keys under which ssh public keys are stored.
  SSH_KEYS = 'ssh-keys'
  SSHKEYS_LEGACY = 'sshKeys'
  # Enum values selecting instance- or project-level metadata.
  INSTANCE_LEVEL = 1
  PROJECT_LEVEL = 2

  def __init__(self, compute, instance, ssh_user='tester'):
    """Constructor.

    Args:
      compute: object, from GetCompute.
      instance: string, the instance to manage the metadata.
      ssh_user: string, the user to create ssh keys and perform ssh tests.
    """
    # NOTE(review): FetchMetadataDefault reads .../instance/attributes/<name>,
    # so 'zone' and 'project' here are custom attributes presumably set by
    # the calling workflow, not the built-in instance paths — confirm.
    self.zone = self.FetchMetadataDefault('zone')
    self.region = self.zone[:-2]  # clears the "-[a-z]$" of the zone
    self.project = self.FetchMetadataDefault('project')
    self.compute = compute
    self.instance = instance
    self.last_fingerprint = None
    self.ssh_user = ssh_user
    # Local cache of metadata items, keyed by level; empty list when the
    # server-side metadata object has no 'items'.
    self.md_items = {}
    md_obj = self._FetchMetadata(self.INSTANCE_LEVEL)
    self.md_items[self.INSTANCE_LEVEL] = (
        md_obj['items'] if 'items' in md_obj else [])
    md_obj = self._FetchMetadata(self.PROJECT_LEVEL)
    self.md_items[self.PROJECT_LEVEL] = (
        md_obj['items'] if 'items' in md_obj else [])

  def _FetchMetadata(self, level):
    """Fetch metadata from the server.

    Args:
      level: enum, INSTANCE_LEVEL or PROJECT_LEVEL to fetch the metadata.

    Returns:
      dict, the metadata object from the API response.
    """
    if level == self.PROJECT_LEVEL:
      request = self.compute.projects().get(project=self.project)
      md_id = 'commonInstanceMetadata'
    else:
      request = self.compute.instances().get(
          project=self.project, zone=self.zone, instance=self.instance)
      md_id = 'metadata'
    response = request.execute()
    return response[md_id]

  @RetryOnFailure()
  def StoreMetadata(self, level):
    """Store Metadata.

    Pushes the locally cached items for `level` back to the server and
    blocks until the operation finishes. Retried on failure.

    Args:
      level: enum, INSTANCE_LEVEL or PROJECT_LEVEL to store the metadata.
    """
    # Re-fetch so the request body is based on the server's current
    # metadata object, with only 'items' replaced by the local cache.
    md_obj = self._FetchMetadata(level)
    md_obj['items'] = self.md_items[level]
    if level == self.PROJECT_LEVEL:
      request = self.compute.projects().setCommonInstanceMetadata(
          project=self.project, body=md_obj)
    else:
      request = self.compute.instances().setMetadata(
          project=self.project, zone=self.zone, instance=self.instance,
          body=md_obj)
    response = request.execute()
    self.Wait(response)

  def ExtractKeyItem(self, md_key, level):
    """Extract a given key value from the metadata.

    Args:
      md_key: string, the key of the metadata value to be searched.
      level: enum, INSTANCE_LEVEL or PROJECT_LEVEL to fetch the metadata.

    Returns:
      md_item: dict, in the format {'key': md_key, 'value': md_value}.
      None: if md_key was not found.
    """
    for md_item in self.md_items[level]:
      if md_item['key'] == md_key:
        return md_item

  def SetMetadata(self, md_key, md_value, level=None, store=True):
    """Add or update a metadata key with a new value in a given level.

    Passing md_value=None removes an existing key.

    Args:
      md_key: string, the key of the metadata.
      md_value: string, value to be added or updated; None deletes the key.
      level: enum, INSTANCE_LEVEL (default) or PROJECT_LEVEL to fetch the
          metadata.
      store: bool, if True, saves metadata to GCE server.
    """
    if not level:
      level = self.INSTANCE_LEVEL
    md_item = self.ExtractKeyItem(md_key, level)
    if md_item and md_value is None:
      # Deletion: drop the existing item from the cache.
      self.md_items[level].remove(md_item)
    elif not md_item:
      md_item = {'key': md_key, 'value': md_value}
      self.md_items[level].append(md_item)
    else:
      md_item['value'] = md_value
    if store:
      self.StoreMetadata(level)

  def AddSshKey(self, md_key, level=None, store=True):
    """Generate and add an ssh key to the metadata.

    Args:
      md_key: string, SSH_KEYS or SSHKEYS_LEGACY, defines where to add the key.
      level: enum, INSTANCE_LEVEL (default) or PROJECT_LEVEL to fetch the
          metadata.
      store: bool, if True, saves metadata to GCE server.

    Returns:
      key_name: string, the name of the file with the generated private key.
    """
    if not level:
      level = self.INSTANCE_LEVEL
    key, key_name = GenSshKey(self.ssh_user)
    md_item = self.ExtractKeyItem(md_key, level)
    if not md_item:
      md_item = {'key': md_key, 'value': key}
      self.md_items[level].append(md_item)
    else:
      # Existing entry: append the new key on its own line.
      md_item['value'] = '\n'.join([md_item['value'], key])
    if store:
      self.StoreMetadata(level)
    return key_name

  def RemoveSshKey(self, key, md_key, level=None, store=True):
    """Remove an ssh key from the metadata.

    NOTE(review): `key` is interpolated into a regex pattern unescaped, so
    this assumes key contains no regex metacharacters; it also assumes
    md_key exists at this level (md_item would be None otherwise) — confirm
    callers always pair this with a prior AddSshKey.

    Args:
      key: string, the key to be removed.
      md_key: string, SSH_KEYS or SSHKEYS_LEGACY, defines where to add the key.
      level: enum, INSTANCE_LEVEL (default) or PROJECT_LEVEL to fetch the
          metadata.
      store: bool, if True, saves metadata to GCE server.
    """
    if not level:
      level = self.INSTANCE_LEVEL
    md_item = self.ExtractKeyItem(md_key, level)
    # Clear the key (whole line), empty keys (if any) and the last break line.
    md_item['value'] = re.sub('\n$', '',
                              re.sub('\n\n', '\n',
                                     re.sub('.*%s.*' % key, '',
                                            md_item['value'])))
    if not md_item['value']:
      self.md_items[level].remove(md_item)
    if store:
      self.StoreMetadata(level)

  @RetryOnFailure()
  def TestSshLogin(self, key, as_root=False, expect_fail=False):
    """Try to login to self.instance using key.

    Retried on failure via the decorator.

    Args:
      key: string, the private key to be used in the ssh connection.
      as_root: bool, indicates if the test is executed with root privileges.
      expect_fail: bool, indicates if the failure in the execution is expected.
    """
    command = ['echo', 'Logged']
    if as_root:
      command.insert(0, 'sudo')
    ExecuteInSsh(
        key, self.ssh_user, self.instance, command,
        expect_fail=expect_fail)

  @classmethod
  def FetchMetadataDefault(cls, name):
    """Fetch Metadata from default metadata server (local machine).

    Reads from the instance's custom attributes path.

    Args:
      name: string, the metadata key to be fetched.

    Returns:
      value: the metadata value.

    Raises:
      ValueError: if the metadata key is not found.
    """
    try:
      url = 'http://metadata/computeMetadata/v1/instance/attributes/%s' % name
      return HttpGet(url, headers={'Metadata-Flavor': 'Google'}).decode()
    except urllib.error.HTTPError:
      raise ValueError('Metadata key "%s" not found' % name)

  def GetInstanceInfo(self, instance):
    """Get an instance's information.

    Args:
      instance: string, the name of the instance to fetch its state.

    Returns:
      value: dictionary: instance information
    """
    request = self.compute.instances().get(
        project=self.project, zone=self.zone, instance=instance)
    return request.execute()

  def GetInstanceIfaces(self, instance):
    """Get an instance's network interfaces.

    Args:
      instance: string, the name of the instance to fetch its state.

    Returns:
      value: list of dict, the network interfaces information.
    """
    return self.GetInstanceInfo(instance)[u'networkInterfaces']

  def GetInstanceState(self, instance):
    """Get an instance state (e.g: RUNNING, TERMINATED, STOPPING...)

    Args:
      instance: string, the name of the instance to fetch its state.

    Returns:
      value: string, the status string.
    """
    return self.GetInstanceInfo(instance)[u'status']

  def SetInstanceIface(self, instance, iface_info, iface_name='nic0'):
    """Update an instance's network interface information.

    Args:
      instance: string, the name of the instance to fetch its state.
      iface_info: dict, interface information to be set
      iface_name: string, interface name, by default, nic0

    Returns:
      response: dict, the request's response.
    """
    request = self.compute.instances().updateNetworkInterface(
        project=self.project, zone=self.zone, instance=instance,
        networkInterface=iface_name, body=iface_info)
    return request.execute()

  def StartInstance(self, instance):
    """Start an instance.

    Args:
      instance: string, the name of the instance to be started.
    """
    self.compute.instances().start(
        project=self.project, zone=self.zone, instance=instance).execute()

  def ResizeDiskGb(self, disk_name, new_size):
    """Resize a disk to a new size. Note: Only allows size growing.

    Args:
      disk_name: string, the name of the disks to be resized.
      new_size: int, the new size in gigabytes to be resized

    Returns:
      response: dict, the request's response.
    """
    # The API expects sizeGb as a string.
    body = {'sizeGb': "%d" % new_size}
    request = self.compute.disks().resize(
        project=self.project, zone=self.zone, disk=disk_name, body=body)
    return request.execute()

  def AttachDisk(self, instance, disk_name):
    """Attach disk on instance.

    Args:
      instance: string, the name of the instance to attach disk.
      disk_name: string, the name of the disks to be attached.

    Returns:
      response: dict, the request's response.
    """
    body = {'source': 'projects/%s/zones/%s/disks/%s' % (
        self.project, self.zone, disk_name)}
    request = self.compute.instances().attachDisk(
        project=self.project, zone=self.zone, instance=instance, body=body)
    return request.execute()

  def GetDiskDeviceNameFromAttached(self, instance, disk_name):
    """Retrieve deviceName of an attached disk based on disk source name.

    Args:
      instance: string, the name of the instance to detach disk.
      disk_name: string, the disk name to be compared to.

    Returns:
      string, the deviceName of the matching disk, or None if not found.
    """
    request = self.compute.instances().get(
        project=self.project, zone=self.zone, instance=instance)
    response = request.execute()
    for disk in response[u'disks']:
      # Substring match: 'source' is a full resource URL ending in the name.
      if disk_name in disk[u'source']:
        return disk[u'deviceName']

  def DetachDisk(self, instance, device_name):
    """Detach disk on instance.

    Args:
      instance: string, the name of the instance to detach disk.
      device_name: string, the device name of the disk to be detached.

    Returns:
      response: dictionary, the request's response.
    """
    request = self.compute.instances().detachDisk(
        project=self.project, zone=self.zone, instance=instance,
        deviceName=device_name)
    return request.execute()

  def Wait(self, response):
    """Blocks until operation completes.

    Code from GitHub's GoogleCloudPlatform/python-docs-samples

    Args:
      response: dict, a request's response

    Returns:
      dict, the completed operation resource.

    Raises:
      Exception: if the finished operation reports an error.
    """
    def _OperationGetter(response):
      # Pick the zonal, regional or global operations endpoint based on
      # which scope field the operation carries.
      operation = response[u'name']
      if response.get(u'zone'):
        return self.compute.zoneOperations().get(
            project=self.project, zone=self.zone, operation=operation)
      elif response.get(u'region'):
        return self.compute.regionOperations().get(
            project=self.project, region=self.region, operation=operation)
      else:
        return self.compute.globalOperations().get(
            project=self.project, operation=operation)

    while True:
      result = _OperationGetter(response).execute()
      if result['status'] == 'DONE':
        if 'error' in result:
          raise Exception(result['error'])
        return result
      time.sleep(1)

  def GetForwardingRuleIP(self, name):
    """Retrieves a forwarding rule ip.

    Args:
      name: string, the name of the forwarding rule.

    Returns:
      response: string, the forwarding rule ip.
    """
    request = self.compute.forwardingRules().get(
        project=self.project, region=self.region, forwardingRule=name)
    response = request.execute()
    return response[u'IPAddress']
@RetryOnFailure(stop_after_seconds=5 * 60, initial_delay_seconds=1)
def install_apt_packages(g, *pkgs):
  """Install packages inside a mounted guest using apt-get.

  Args:
    g: guestfs.GuestFS, mounted guest.
    *pkgs: str, names of the packages to install.
  """
  cmd = ('DEBIAN_FRONTEND=noninteractive apt-get '
         'install -y --no-install-recommends ' + ' '.join(pkgs))
  run(g, cmd)
@RetryOnFailure(stop_after_seconds=5 * 60, initial_delay_seconds=1)
def update_apt(g):
  """Runs apt update in a guest.

  Starting at apt 1.5, release info changes must be confirmed
  explicitly with `--allow-releaseinfo-change`. That flag,
  however, is not supported prior to 1.5.
  https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=931566

  Args:
    g: guestfs.GuestFS, mounted guest.
  """
  try:
    run(g, 'apt update -y')
  except RuntimeError:
    # Retry with the flag only after a failure, since older apt versions
    # reject it outright.
    run(g, 'apt update -y --allow-releaseinfo-change')