in providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py [0:0]
def execute(self, context: Context) -> int:
    """Create and run an Azure Container Instances container group, then stream its result.

    Builds the container spec from the operator's (possibly templated) fields,
    creates or updates the container group, monitors it until completion, and
    optionally pushes logs to XCom.

    :param context: Airflow task execution context (used for XCom pushes).
    :return: the container's exit code (0 on success).
    :raises AirflowException: if the group already exists and ``fail_if_exists``
        is set, if the group could not be started, or if the container finished
        with a non-zero exit code.
    """
    # Check name again in case it was templated.
    self._check_name(self.name)

    self._ci_hook = AzureContainerInstanceHook(azure_conn_id=self.ci_conn_id)

    if self.fail_if_exists:
        self.log.info("Testing if container group already exists")
        if self._ci_hook.exists(self.resource_group, self.name):
            raise AirflowException("Container group exists")

    # Resolve private-registry credentials only when a registry connection is configured.
    if self.registry_conn_id:
        registry_hook = AzureContainerRegistryHook(self.registry_conn_id)
        image_registry_credentials: list | None = [
            registry_hook.connection,
        ]
    else:
        image_registry_credentials = None

    # Secured variables are passed via ``secure_value`` so they are not exposed
    # in the container group's properties.
    environment_variables = []
    for key, value in self.environment_variables.items():
        if key in self.secured_variables:
            e = EnvironmentVariable(name=key, secure_value=value)
        else:
            e = EnvironmentVariable(name=key, value=value)
        environment_variables.append(e)

    volumes: list[_AzureVolume] = []
    volume_mounts: list[VolumeMount] = []
    for conn_id, account_name, share_name, mount_path, read_only in self.volumes:
        hook = AzureContainerVolumeHook(conn_id)
        mount_name = f"mount-{len(volumes)}"
        volumes.append(hook.get_file_volume(mount_name, share_name, account_name, read_only))
        volume_mounts.append(VolumeMount(name=mount_name, mount_path=mount_path, read_only=read_only))

    # Default to non-zero so the ``finally`` cleanup logic treats any failure
    # before ``_monitor_logging`` returns as an error.
    exit_code = 1
    try:
        self.log.info("Starting container group with %.1f cpu %.1f mem", self.cpu, self.memory_in_gb)
        if self.gpu:
            self.log.info("GPU count: %.1f, GPU SKU: %s", self.gpu.count, self.gpu.sku)

        resources = ResourceRequirements(
            requests=ResourceRequests(memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu)
        )

        # An IP address without any exposed ports is invalid; fall back to port 80.
        if self.ip_address and not self.ports:
            self.ports = [ContainerPort(port=80)]
            self.log.info("Default port set. Container will listen on port 80")

        container = Container(
            name=self.name,
            image=self.image,
            resources=resources,
            command=self.command,
            environment_variables=environment_variables,
            volume_mounts=volume_mounts,
            ports=self.ports,
        )

        container_group = ContainerGroup(
            location=self.region,
            containers=[
                container,
            ],
            image_registry_credentials=image_registry_credentials,
            volumes=volumes,
            restart_policy=self.restart_policy,
            os_type=self.os_type,
            tags=self.tags,
            ip_address=self.ip_address,
            subnet_ids=self.subnet_ids,
            dns_config=self.dns_config,
            diagnostics=self.diagnostics,
            priority=self.priority,
        )

        self._ci_hook.create_or_update(self.resource_group, self.name, container_group)

        self.log.info("Container group started %s/%s", self.resource_group, self.name)

        exit_code = self._monitor_logging(self.resource_group, self.name)

        # ``xcom_all`` set (True or False) means the user wants logs on XCom;
        # True pushes everything, False pushes only the last entry.
        if self.xcom_all is not None:
            logs = self._ci_hook.get_logs(self.resource_group, self.name)
            if logs is None:
                context["ti"].xcom_push(key="logs", value=[])
            else:
                if self.xcom_all:
                    context["ti"].xcom_push(key="logs", value=logs)
                else:
                    # slice off the last entry in the list logs and return it as a list
                    context["ti"].xcom_push(key="logs", value=logs[-1:])

        self.log.info("Container had exit code: %s", exit_code)
        if exit_code != 0:
            raise AirflowException(f"Container had a non-zero exit code, {exit_code}")
        return exit_code
    except CloudError as e:
        self.log.exception("Could not start container group")
        # Chain the original cloud error so the root cause survives in the traceback.
        raise AirflowException("Could not start container group") from e
    finally:
        # Clean up the container group on success, or on failure when the
        # operator is configured to remove it on error.
        if exit_code == 0 or self.remove_on_error:
            self.on_kill()