awscli/customizations/emr/emrutils.py (182 lines of code) (raw):
# Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import json
import logging
import os
from awscli.clidriver import CLIOperationCaller
from awscli.customizations.emr import constants
from awscli.customizations.emr import exceptions
from botocore.exceptions import WaiterError, NoCredentialsError
from botocore import xform_name
# Logger for this module, named after the module via __name__.
LOG = logging.getLogger(__name__)
def parse_tags(raw_tags_list):
    """Convert ``key=value`` tag strings into a list of Key/Value dicts.

    Each entry is split at its first '='; an entry without '=' becomes a
    key with an empty-string value. A falsy *raw_tags_list* (None or empty)
    yields an empty list.
    """
    if not raw_tags_list:
        return []
    # partition() returns (tag, '', '') when no '=' is present, which gives
    # the same (key, '') result as the explicit find() check.
    return [
        {'Key': tag_key, 'Value': tag_value}
        for tag_key, _, tag_value in (
            raw_tag.partition('=') for raw_tag in raw_tags_list)
    ]
def parse_key_value_string(key_value_string):
    """Parse a comma-separated ``k=v`` string into Key/Value dicts.

    Example input: "k1=v1,k2='v 2',k3,k4". Items are split on ',' and each
    item at its first '='; an item without '=' gets an empty-string value.
    Quotes are kept as part of the value. Returns None when
    *key_value_string* is None.
    """
    if key_value_string is None:
        return None
    return [
        {'Key': item_key, 'Value': item_value}
        for item_key, _, item_value in (
            item.partition('=') for item in key_value_string.split(','))
    ]
def apply_boolean_options(
        true_option, true_option_name, false_option, false_option_name):
    """Resolve a pair of mutually exclusive on/off flags to a single bool.

    Raises ValueError when both *true_option* and *false_option* are truthy.
    Otherwise returns True if *true_option* is truthy, else False.
    """
    if true_option and false_option:
        raise ValueError(
            'aws: error: cannot use both ' + true_option_name +
            ' and ' + false_option_name + ' options together.')
    return bool(true_option)
def apply(params, key, value):
    """Deprecated alias of apply_dict; new code should call apply_dict.

    Stores *value* under *key* in *params* only when *value* is truthy, and
    returns the same (possibly mutated) *params* object.
    """
    if not value:
        return params
    params[key] = value
    return params
def apply_dict(params, key, value):
    """Set ``params[key] = value`` only when *value* is truthy.

    Falsy values (None, '', 0, empty containers) are skipped. The same
    *params* object is returned so calls can be chained.
    """
    if not value:
        return params
    params[key] = value
    return params
def apply_params(src_params, src_key, dest_params, dest_key):
    """Copy ``src_params[src_key]`` to ``dest_params[dest_key]`` if truthy.

    Missing keys and falsy values are ignored. Returns *dest_params*.
    """
    # .get() neither raises for a missing key nor inserts one
    # (e.g. into a defaultdict), matching the original membership test.
    candidate = src_params.get(src_key)
    if candidate:
        dest_params[dest_key] = candidate
    return dest_params
def build_step(
        jar, name='Step',
        action_on_failure=constants.DEFAULT_FAILURE_ACTION,
        args=None,
        main_class=None,
        properties=None):
    """Build a step dict with 'Name', 'ActionOnFailure' and 'HadoopJarStep'.

    *jar* is required (check_required_field raises if it is falsy). The
    optional name, action_on_failure, args, main_class and properties are
    only included when truthy.
    """
    check_required_field(
        structure='HadoopJarStep', name='Jar', value=jar)

    hadoop_jar_step = {'Jar': jar}
    optional_jar_fields = (
        ('Args', args),
        ('MainClass', main_class),
        ('Properties', properties),
    )
    for field_name, field_value in optional_jar_fields:
        if field_value:
            hadoop_jar_step[field_name] = field_value

    # Key insertion order (Name, ActionOnFailure, HadoopJarStep) is kept
    # so serialized output is unchanged.
    step = {}
    if name:
        step['Name'] = name
    if action_on_failure:
        step['ActionOnFailure'] = action_on_failure
    step['HadoopJarStep'] = hadoop_jar_step
    return step
def build_bootstrap_action(
        path,
        name='Bootstrap Action',
        args=None):
    """Build a bootstrap-action dict wrapping a 'ScriptBootstrapAction'.

    Raises MissingParametersError when *path* is None. 'Name' and 'Args'
    are only included when truthy; 'Path' is always present.
    """
    if path is None:
        raise exceptions.MissingParametersError(
            object_name='ScriptBootstrapActionConfig', missing='Path')

    # 'Args' is inserted before 'Path' to keep the original key order.
    script_action = {}
    if args:
        script_action['Args'] = args
    script_action['Path'] = path

    bootstrap_action = {}
    if name:
        bootstrap_action['Name'] = name
    # script_action always contains 'Path', so it is always non-empty.
    bootstrap_action['ScriptBootstrapAction'] = script_action
    return bootstrap_action
def build_s3_link(relative_path='', region='us-east-1'):
    """Return ``s3://<region>.elasticmapreduce<relative_path>``.

    A *region* of None is treated as 'us-east-1'.
    """
    effective_region = 'us-east-1' if region is None else region
    return 's3://{0}.elasticmapreduce{1}'.format(
        effective_region, relative_path)
def get_script_runner(region='us-east-1'):
    """Return the S3 link for constants.SCRIPT_RUNNER_PATH in *region*.

    A *region* of None is treated as 'us-east-1'.
    """
    return build_s3_link(
        relative_path=constants.SCRIPT_RUNNER_PATH,
        region='us-east-1' if region is None else region)
def check_required_field(structure, name, value):
    """Raise MissingParametersError if *value* is falsy.

    *structure* and *name* identify the missing field in the error.
    """
    if value:
        return
    raise exceptions.MissingParametersError(
        object_name=structure, missing=name)
def check_empty_string_list(name, value):
    """Raise EmptyListError when *value* is empty or is a single blank string.

    A list of length one whose only element is whitespace-only counts as
    empty. *name* is reported as the offending parameter.
    """
    if value and not (len(value) == 1 and value[0].strip() == ""):
        return
    raise exceptions.EmptyListError(param=name)
def call(session, operation_name, parameters, region_name=None,
         endpoint_url=None, verify=None):
    """Create an EMR client and invoke *operation_name* with *parameters*.

    *parameters* is passed as keyword arguments to the client method named
    *operation_name*, and that method's return value is returned.

    Raises NoCredentialsError when ``session.get_credentials()`` is None.
    """
    # Check credentials before creating the client. Client setup can fail
    # first with a less helpful error (e.g. about a missing region), and a
    # missing-credentials error is the more useful message for the user.
    if session.get_credentials() is None:
        raise NoCredentialsError()
    client = session.create_client(
        'emr', region_name=region_name, endpoint_url=endpoint_url,
        verify=verify)
    # Lazy %-style args: the message is only formatted if DEBUG is enabled.
    LOG.debug('Calling %s', operation_name)
    return getattr(client, operation_name)(**parameters)
def get_example_file(command):
    """Open and return the example ``.rst`` file for *command*.

    The path 'awscli/examples/emr/<command>.rst' is relative, so it is
    resolved against the current working directory. The caller owns the
    returned file object and is responsible for closing it.
    """
    example_path = 'awscli/examples/emr/' + command + '.rst'
    return open(example_path)
def dict_to_string(dict, indent=2):
    """Serialize *dict* to a JSON string indented by *indent* spaces."""
    # The parameter name shadows the builtin ``dict``; it is kept for
    # keyword-argument compatibility with existing callers.
    serialized = json.dumps(dict, indent=indent)
    return serialized
def get_client(session, parsed_globals):
    """Create an EMR client from the session and the CLI's global options.

    The region comes from get_region(); endpoint_url and verify_ssl are
    read straight from *parsed_globals*.
    """
    region_name = get_region(session, parsed_globals)
    return session.create_client(
        'emr',
        region_name=region_name,
        endpoint_url=parsed_globals.endpoint_url,
        verify=parsed_globals.verify_ssl)
def get_cluster_state(session, parsed_globals, cluster_id):
    """Return ``Cluster.Status.State`` from describe_cluster for *cluster_id*."""
    emr_client = get_client(session, parsed_globals)
    response = emr_client.describe_cluster(ClusterId=cluster_id)
    cluster_status = response['Cluster']['Status']
    return cluster_status['State']
def find_master_dns(session, parsed_globals, cluster_id):
    """
    Returns the cluster's 'MasterPublicDnsName' from describe_cluster.

    Raises KeyError if the response's 'Cluster' entry has no
    'MasterPublicDnsName' key.
    """
    emr_client = get_client(session, parsed_globals)
    cluster = emr_client.describe_cluster(ClusterId=cluster_id)['Cluster']
    return cluster['MasterPublicDnsName']
def which(program):
    """Return the path of the first executable file named *program* on PATH.

    Each PATH entry has surrounding double quotes stripped before use.
    Returns None if no matching executable regular file is found.
    """
    # Use the platform default search path instead of raising KeyError
    # when PATH is not set in the environment.
    search_path = os.environ.get("PATH", os.defpath)
    for directory in search_path.split(os.pathsep):
        candidate = os.path.join(directory.strip('"'), program)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate
    return None
def call_and_display_response(session, operation_name, parameters,
                              parsed_globals):
    """Invoke the EMR *operation_name* through CLIOperationCaller.

    The invoke() result is not returned; this function returns None.
    """
    CLIOperationCaller(session).invoke(
        'emr', operation_name, parameters, parsed_globals)
def display_response(session, operation_name, result, parsed_globals):
    """Render *result* for *operation_name* using CLIOperationCaller."""
    operation_caller = CLIOperationCaller(session)
    # NOTE: _display_response is private to CLIOperationCaller. Switch to a
    # public API once the display logic is moved out of that class.
    operation_caller._display_response(
        operation_name, result, parsed_globals)
def get_region(session, parsed_globals):
    """Return the --region global option, else the session's 'region' config.

    Only None falls through to the session config; other values (including
    '') are returned as given.
    """
    if parsed_globals.region is not None:
        return parsed_globals.region
    return session.get_config_variable('region')
def join(values, separator=',', lastSeparator='and'):
    """
    Helper method to print a list of values in human-readable form.

    Every value is converted with str(). For example, [1, 2, 3] becomes
    '1, 2 and 3', a single value is returned on its own, and an empty input
    gives ''.
    """
    items = [str(item) for item in values]
    if not items:
        return ""
    if len(items) == 1:
        return items[0]
    leading = ('%s ' % separator).join(items[:-1])
    return ' '.join((leading, lastSeparator, items[-1]))
def split_to_key_value(string):
    """Split *string* at its first '=' into a ``(key, value)`` tuple.

    The value may itself contain '='. When *string* has no '=', the value
    is ''.
    """
    # partition() always yields a 3-tuple, so both cases now return a tuple.
    # The old '=' branch returned a list from str.split(); unpacking callers
    # are unaffected.
    key, _, value = string.partition('=')
    return key, value
def get_cluster(cluster_id, session, region,
                endpoint_url, verify_ssl):
    """Return the 'Cluster' entry of describe_cluster for *cluster_id*.

    Returns None if call() returns None or the response has no 'Cluster'.
    """
    response = call(
        session, 'describe_cluster', {'ClusterId': cluster_id},
        region, endpoint_url, verify_ssl)
    if response is None:
        return None
    return response.get('Cluster')
def get_release_label(cluster_id, session, region,
                      endpoint_url, verify_ssl):
    """Return the cluster's 'ReleaseLabel', or None if it is unavailable."""
    cluster = get_cluster(
        cluster_id, session, region, endpoint_url, verify_ssl)
    return None if cluster is None else cluster.get('ReleaseLabel')