in redshift_connector/plugin/saml_credentials_provider.py [0:0]
def refresh(self: "SamlCredentialsProvider") -> None:
    """
    Obtain fresh AWS credentials by exchanging a SAML assertion with STS.

    Steps:
    1. Fetch a base64-encoded SAML assertion via ``self.get_saml_assertion()``.
    2. Decode it and collect role ARN -> SAML provider ARN pairs from the
       ``https://aws.amazon.com/SAML/Attributes/Role`` attribute.
    3. Use ``self.preferred_role`` if set, otherwise pick a role at random
       from those found in the assertion.
    4. Call STS ``assume_role_with_saml`` and store the resulting credentials
       (with metadata read from the assertion) in ``self.cache`` under
       ``self.get_cache_key()``.

    Raises:
        InterfaceError: if the assertion cannot be retrieved, contains no
            role/provider pair, or does not contain the preferred role.
        Exception: any error raised while assuming the role or building and
            caching the credentials is logged and re-raised unchanged.
    """
    import boto3  # type: ignore
    import bs4  # type: ignore

    try:
        # get SAML assertion from specific identity provider
        saml_assertion = self.get_saml_assertion()
    except Exception as e:
        _logger.error("Get saml assertion failed: %s", e)
        raise InterfaceError(e) from e

    # decode SAML assertion into xml format
    doc: bytes = base64.b64decode(saml_assertion)
    soup = bs4.BeautifulSoup(doc, "xml")

    # Extract RoleArn and PrincipalArn pairs from the SAML assertion.
    # "[-a-z]*" after "aws" accepts every AWS partition (aws, aws-cn,
    # aws-us-gov, ...) rather than only the commercial one.
    role_pattern = re.compile(r"arn:aws[-a-z]*:iam::\d*:role/\S+")
    provider_pattern = re.compile(r"arn:aws[-a-z]*:iam::\d*:saml-provider/\S+")
    roles: typing.Dict[str, str] = {}
    for attr in soup.find_all("Attribute"):
        # .get() so an Attribute without a Name is skipped rather than
        # raising KeyError.
        if attr.get("Name") != "https://aws.amazon.com/SAML/Attributes/Role":
            continue
        for value in attr.find_all("AttributeValue"):
            # Each value is a comma-separated list of ARNs; each part is
            # tested against both patterns, so their order does not matter.
            # get_text() yields "" for an empty element, where contents[0]
            # would raise IndexError.
            role: str = ""
            provider: str = ""
            for arn in value.get_text().split(","):
                arn = arn.strip()  # remove trailing or leading whitespace
                if role_pattern.match(arn):
                    role = arn
                if provider_pattern.match(arn):
                    provider = arn
            if role and provider:
                roles[role] = provider

    if not roles:
        raise InterfaceError("No role found in SamlAssertion")

    role_arn: str
    if self.preferred_role:
        role_arn = self.preferred_role
        if role_arn not in roles:
            raise InterfaceError("Preferred role not found in SamlAssertion")
    else:
        # NOTE(review): with several roles and no preferred_role the choice is
        # random, so repeated refreshes may assume different roles.
        role_arn = random.choice(list(roles))
    principal: str = roles[role_arn]

    client = boto3.client("sts")
    try:
        response = client.assume_role_with_saml(
            RoleArn=role_arn,
            PrincipalArn=principal,
            SAMLAssertion=saml_assertion,
        )
        stscred: typing.Dict[str, typing.Any] = response["Credentials"]
        credentials: CredentialsHolder = CredentialsHolder(stscred)
        # get metadata from SAML assertion
        credentials.set_metadata(self.read_metadata(doc))
        key: str = self.get_cache_key()
        self.cache[key] = credentials
    # Each handler logs the failure type and re-raises the original
    # exception; bare `raise` keeps the original traceback intact.
    except AttributeError as e:
        _logger.error("AttributeError: %s", e)
        raise
    except KeyError as e:
        _logger.error("KeyError: %s", e)
        raise
    except client.exceptions.MalformedPolicyDocumentException as e:
        _logger.error("MalformedPolicyDocumentException: %s", e)
        raise
    except client.exceptions.PackedPolicyTooLargeException as e:
        _logger.error("PackedPolicyTooLargeException: %s", e)
        raise
    except client.exceptions.IDPRejectedClaimException as e:
        _logger.error("IDPRejectedClaimException: %s", e)
        raise
    except client.exceptions.InvalidIdentityTokenException as e:
        _logger.error("InvalidIdentityTokenException: %s", e)
        raise
    except client.exceptions.ExpiredTokenException as e:
        _logger.error("ExpiredTokenException: %s", e)
        raise
    except client.exceptions.RegionDisabledException as e:
        _logger.error("RegionDisabledException: %s", e)
        raise
    except Exception as e:
        _logger.error("Other Exception: %s", e)
        raise