awscli/customizations/emrcontainers/update_role_trust_policy.py (142 lines of code) (raw):
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import json
import logging
from awscli.customizations.commands import BasicCommand
from awscli.customizations.emrcontainers.constants \
import TRUST_POLICY_STATEMENT_FORMAT, \
TRUST_POLICY_STATEMENT_ALREADY_EXISTS, \
TRUST_POLICY_UPDATE_SUCCESSFUL
from awscli.customizations.emrcontainers.base36 import Base36
from awscli.customizations.emrcontainers.eks import EKS
from awscli.customizations.emrcontainers.iam import IAM
from awscli.customizations.utils import uni_print, get_policy_arn_suffix
LOG = logging.getLogger(__name__)
def get_region(session, parsed_globals):
    """Return the region this command should operate in.

    The ``--region`` global option takes precedence; when it was not given
    (``None``), the ``region`` value from the session configuration is used.
    """
    explicit_region = parsed_globals.region
    if explicit_region is not None:
        return explicit_region
    return session.get_config_variable('region')
def check_if_statement_exists(expected_statement, actual_assume_role_document):
    """Return True if the trust policy already contains *expected_statement*.

    :param expected_statement: the statement dict to look for.
    :param actual_assume_role_document: the role's parsed trust policy
        document, or ``None`` when no document is available.
    :returns: True when some existing statement matches *expected_statement*
        according to ``check_if_dict_matches``, otherwise False.
    """
    if actual_assume_role_document is None:
        return False
    # A present-but-null "Statement" is treated like a missing one instead
    # of failing when iterated.
    existing_statements = actual_assume_role_document.get("Statement") or []
    # The IAM policy grammar allows "Statement" to be a single object rather
    # than a list; iterating a dict directly would walk its keys instead.
    if isinstance(existing_statements, dict):
        existing_statements = [existing_statements]
    return any(check_if_dict_matches(expected_statement, statement)
               for statement in existing_statements)
def check_if_dict_matches(expected_dict, actual_dict):
    """Return True if *actual_dict* structurally matches *expected_dict*.

    Both mappings must have the same number of keys. Nested dicts in
    *expected_dict* are compared recursively; every other expected value is
    compared against the actual value after conversion with ``str``.

    Returns False (instead of raising) when *actual_dict*, or a value nested
    inside it, is not a dict where *expected_dict* has a dict.
    """
    # Guards against e.g. a string/None/list where a mapping is expected;
    # without it, string indexing would raise and ``in`` would do substring
    # matching.
    if not isinstance(actual_dict, dict):
        return False
    if len(expected_dict) != len(actual_dict):
        return False
    for key, expected_value in expected_dict.items():
        key_str = str(key)
        if isinstance(expected_value, dict):
            if not check_if_dict_matches(expected_value,
                                         actual_dict.get(key_str, {})):
                return False
        elif key_str not in actual_dict or \
                actual_dict[key_str] != str(expected_value):
            return False
    return True
class UpdateRoleTrustPolicyCommand(BasicCommand):
    """CLI command that adds the EMR on EKS trust statement to an IAM role.

    The statement is rendered from TRUST_POLICY_STATEMENT_FORMAT using the
    EKS cluster's account ID and OIDC issuer, the namespace and the Base36
    encoded role name, and is appended to the role's trust policy unless an
    identical statement is already present.
    """

    NAME = 'update-role-trust-policy'

    DESCRIPTION = BasicCommand.FROM_FILE(
        'emr-containers',
        'update-role-trust-policy',
        '_description.rst'
    )

    ARG_TABLE = [
        {
            'name': 'cluster-name',
            'help_text': ("Specify the name of the Amazon EKS cluster with "
                          "which the IAM Role would be used."),
            'required': True
        },
        {
            'name': 'namespace',
            'help_text': ("Specify the namespace from the Amazon EKS cluster "
                          "with which the IAM Role would be used."),
            'required': True
        },
        {
            'name': 'role-name',
            'help_text': ("Specify the IAM Role name that you want to use "
                          "with Amazon EMR on EKS."),
            'required': True
        },
        {
            'name': 'iam-endpoint',
            'no_paramfile': True,
            'help_text': ("The IAM endpoint to call for updating the role "
                          "trust policy. This is optional and should only be "
                          "specified when a custom endpoint should be called "
                          "for IAM operations."),
            'required': False
        },
        {
            'name': 'dry-run',
            'action': 'store_true',
            'default': False,
            'help_text': ("Print the merged trust policy document to "
                          "stdout instead of updating the role trust "
                          "policy directly."),
            'required': False
        }
    ]

    def _run_main(self, parsed_args, parsed_globals):
        """Run the command: update (or dry-run) the role trust policy.

        Stores the parsed arguments on the instance, performs the update and
        prints the resulting message followed by a newline.

        :returns: 0 on success.
        """
        self._cluster_name = parsed_args.cluster_name
        self._namespace = parsed_args.namespace
        self._role_name = parsed_args.role_name
        self._region = get_region(self._session, parsed_globals)
        self._endpoint_url = parsed_args.iam_endpoint
        self._dry_run = parsed_args.dry_run

        result = self._update_role_trust_policy(parsed_globals)
        uni_print(result)
        uni_print("\n")
        return 0

    def _update_role_trust_policy(self, parsed_globals):
        """Add the EMR on EKS trust statement to the role if it is missing.

        :param parsed_globals: parsed global CLI options; ``verify_ssl`` is
            forwarded to the EKS and IAM clients.
        :returns: the text to print -- the merged policy document (dry run),
            TRUST_POLICY_UPDATE_SUCCESSFUL, or
            TRUST_POLICY_STATEMENT_ALREADY_EXISTS.
        """
        base36 = Base36()
        eks_client = EKS(self._session.create_client(
            'eks',
            region_name=self._region,
            verify=parsed_globals.verify_ssl
        ))
        account_id = eks_client.get_account_id(self._cluster_name)
        oidc_provider = eks_client.get_oidc_issuer_id(self._cluster_name)

        base36_encoded_role_name = base36.encode(self._role_name)
        LOG.debug('Base36 encoded role name: %s', base36_encoded_role_name)

        trust_policy_statement = json.loads(TRUST_POLICY_STATEMENT_FORMAT % {
            "AWS_ACCOUNT_ID": account_id,
            "OIDC_PROVIDER": oidc_provider,
            "NAMESPACE": self._namespace,
            "BASE36_ENCODED_ROLE_NAME": base36_encoded_role_name,
            "AWS_PARTITION": get_policy_arn_suffix(self._region)
        })
        LOG.debug('Computed Trust Policy Statement:\n%s', json.dumps(
            trust_policy_statement, indent=2))

        iam_client = IAM(self._session.create_client(
            'iam',
            region_name=self._region,
            endpoint_url=self._endpoint_url,
            verify=parsed_globals.verify_ssl
        ))
        assume_role_document = iam_client.get_assume_role_policy(
            self._role_name)

        if check_if_statement_exists(trust_policy_statement,
                                     assume_role_document):
            return TRUST_POLICY_STATEMENT_ALREADY_EXISTS % self._role_name

        LOG.debug('Role %s does not have the required trust policy ',
                  self._role_name)
        assume_role_document = self._merge_statement(assume_role_document,
                                                     trust_policy_statement)

        if self._dry_run:
            return json.dumps(assume_role_document, indent=2)

        LOG.debug('Updating trust policy of role %s', self._role_name)
        iam_client.update_assume_role_policy(self._role_name,
                                             assume_role_document)
        return TRUST_POLICY_UPDATE_SUCCESSFUL % self._role_name

    @staticmethod
    def _merge_statement(assume_role_document, statement):
        """Return *assume_role_document* with *statement* appended.

        Handles a missing document (``None``, for which a fresh document is
        created), a missing or null ``Statement`` entry, and a ``Statement``
        given as a single object rather than a list. An existing document
        is modified in place and returned.
        """
        if assume_role_document is None:
            return {"Version": "2012-10-17", "Statement": [statement]}
        existing_statements = assume_role_document.get("Statement")
        if existing_statements is None:
            assume_role_document["Statement"] = [statement]
        elif isinstance(existing_statements, dict):
            # Single-object form is valid IAM policy grammar; promote it to a
            # list so the new statement can be added alongside it.
            assume_role_document["Statement"] = [existing_statements,
                                                 statement]
        else:
            existing_statements.append(statement)
        return assume_role_document