in providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py [0:0]
def _get_conn_params(self) -> dict[str, str | None]:
    """
    Fetch connection params as a dict.

    This is used in ``get_uri()`` and ``get_connection()``.

    Values set on the hook itself (``schema``, ``database``, ``account``, ``warehouse``,
    ``region``, ``role``, ``authenticator``, ``session_parameters``) take precedence over the
    corresponding connection fields / extras whenever they are truthy.

    Authentication variants selected from the extras:

    * ``private_key_file`` / ``private_key_content`` (mutually exclusive): the PEM key is loaded,
      using the stripped connection password as its passphrase when it is non-blank, and is
      passed as unencrypted PKCS#8 DER bytes under ``private_key``; ``password`` is dropped.
    * ``refresh_token``: switches to ``authenticator="oauth"``, uses login/password as the OAuth
      client id/secret, drops ``user``/``password`` and stores the result of
      ``get_oauth_token()`` under ``token``.

    NOTE(review): the return annotation is narrower than reality -- values may also be ``bool``
    (the boolean flags), ``bytes`` (``private_key``) and whatever ``session_parameters``,
    ``host`` and ``port`` hold in the extras.

    :raises AirflowException: if both ``private_key_file`` and ``private_key_content`` are set.
    :raises ValueError: if ``private_key_file`` is not a regular file, is empty, or is larger
        than 4096 bytes.
    """
    conn = self.get_connection(self.snowflake_conn_id)  # type: ignore[attr-defined]
    extra_dict = conn.extra_dejson
    account = self._get_field(extra_dict, "account") or ""
    warehouse = self._get_field(extra_dict, "warehouse") or ""
    database = self._get_field(extra_dict, "database") or ""
    region = self._get_field(extra_dict, "region") or ""
    role = self._get_field(extra_dict, "role") or ""
    insecure_mode = _try_to_boolean(self._get_field(extra_dict, "insecure_mode"))
    json_result_force_utf8_decoding = _try_to_boolean(
        self._get_field(extra_dict, "json_result_force_utf8_decoding")
    )
    schema = conn.schema or ""
    client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict, "client_request_mfa_token"))
    client_store_temporary_credential = _try_to_boolean(
        self._get_field(extra_dict, "client_store_temporary_credential")
    )

    # authenticator and session_parameters never supported long name so we don't use _get_field
    authenticator = extra_dict.get("authenticator", "snowflake")
    session_parameters = extra_dict.get("session_parameters")

    conn_config = {
        "user": conn.login,
        "password": conn.password or "",
        "schema": self.schema or schema,
        "database": self.database or database,
        "account": self.account or account,
        "warehouse": self.warehouse or warehouse,
        "region": self.region or region,
        "role": self.role or role,
        "authenticator": self.authenticator or authenticator,
        "session_parameters": self.session_parameters or session_parameters,
        # application is used to track origin of the requests
        "application": os.environ.get("AIRFLOW_SNOWFLAKE_PARTNER", "AIRFLOW"),
    }

    # Optional boolean flags are only forwarded when truthy.
    if insecure_mode:
        conn_config["insecure_mode"] = insecure_mode
    if json_result_force_utf8_decoding:
        conn_config["json_result_force_utf8_decoding"] = json_result_force_utf8_decoding
    if client_request_mfa_token:
        conn_config["client_request_mfa_token"] = client_request_mfa_token
    if client_store_temporary_credential:
        conn_config["client_store_temporary_credential"] = client_store_temporary_credential

    # If private_key_file is specified in the extra json, load the contents of the file as a private key.
    # If private_key_content is specified in the extra json, use it as a private key.
    # As a next step, specify this private key in the connection configuration.
    # The connection password then becomes the passphrase for the private key.
    # If your private key is not encrypted (not recommended), then leave the password empty.
    private_key_file = self._get_field(extra_dict, "private_key_file")
    private_key_content = self._get_field(extra_dict, "private_key_content")
    if private_key_content and private_key_file:
        raise AirflowException(
            "The private_key_file and private_key_content extra fields are mutually exclusive. "
            "Please remove one."
        )

    private_key_pem = None
    if private_key_file:
        private_key_file_path = Path(private_key_file)
        # Check is_file() first so stat() is never attempted on a missing/non-regular path.
        if not private_key_file_path.is_file():
            raise ValueError("The private_key_file path points to an empty or invalid file.")
        # Stat once and reuse the size for both bounds checks.
        private_key_size = private_key_file_path.stat().st_size
        if private_key_size == 0:
            raise ValueError("The private_key_file path points to an empty or invalid file.")
        if private_key_size > 4096:
            raise ValueError("The private_key_file size is too big. Please keep it less than 4 KB.")
        private_key_pem = private_key_file_path.read_bytes()
    elif private_key_content:
        # private_key_content is expected to be base64-encoded PEM.
        private_key_pem = base64.b64decode(private_key_content)

    if private_key_pem:
        # Surrounding whitespace in the password is ignored. A password that is blank after
        # stripping means "no passphrase": passing b"" would count as a supplied password and
        # make cryptography reject an unencrypted key.
        stripped_password = (conn.password or "").strip()
        passphrase = stripped_password.encode() if stripped_password else None
        p_key = serialization.load_pem_private_key(
            private_key_pem, password=passphrase, backend=default_backend()
        )
        # Re-serialize as unencrypted PKCS#8 DER bytes for the connector.
        pkb = p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
        conn_config["private_key"] = pkb
        # The password was consumed as the key passphrase; don't also send it as a login password.
        conn_config.pop("password", None)

    refresh_token = self._get_field(extra_dict, "refresh_token") or ""
    if refresh_token:
        conn_config["refresh_token"] = refresh_token
        conn_config["authenticator"] = "oauth"
        # Login/password are reinterpreted as the OAuth client credentials.
        conn_config["client_id"] = conn.login
        conn_config["client_secret"] = conn.password
        conn_config.pop("user", None)
        conn_config.pop("password", None)
        conn_config["token"] = self.get_oauth_token(conn_config=conn_config)

    # configure custom target hostname and port, if specified
    snowflake_host = extra_dict.get("host")
    snowflake_port = extra_dict.get("port")
    if snowflake_host:
        conn_config["host"] = snowflake_host
    if snowflake_port:
        conn_config["port"] = snowflake_port

    # if a value for ocsp_fail_open is set, pass it along.
    # Note the check is for `is not None` so that we can pass along `False` as a value.
    ocsp_fail_open = extra_dict.get("ocsp_fail_open")
    if ocsp_fail_open is not None:
        conn_config["ocsp_fail_open"] = _try_to_boolean(ocsp_fail_open)
    return conn_config