ssiog/launch/main.tf (104 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
terraform {
  required_providers {
    # This is used to create Google Cloud Platform resources.
    google = {
      source  = "hashicorp/google"
      version = ">= 5.44.1"
    }
    # This is used to create the k8s resources within the existing GKE cluster
    # looked up by data.google_container_cluster.cluster (the cluster itself is
    # not created by this configuration).
    kubernetes = {
      source  = "hashicorp/kubernetes"
      version = "2.31.0"
    }
    # Used by random_id.uniq. Declared explicitly rather than relying on
    # Terraform's implicit hashicorp/* provider resolution.
    random = {
      source = "hashicorp/random"
    }
    # Used by the local_file resources that render the job manifests.
    local = {
      source = "hashicorp/local"
    }
  }
}
# Google provider. Resources and data sources in this configuration that do
# not set their own project fall back to var.project.
provider "google" {
  project = var.project
}
# Client configuration of the identity running Terraform. Its access_token is
# what the kubernetes provider uses to authenticate to the cluster.
data "google_client_config" "provider" {
}
# The provider's default project; its number and project_id are used to build
# the Workload Identity principal for the Kubernetes service account.
data "google_project" "project" {
}
# Look up the already-existing GKE cluster the benchmark runs on.
data "google_container_cluster" "cluster" {
  location = var.location
  name     = var.cluster_name
}
# Kubernetes provider pointed at the GKE cluster looked up above.
#
# Authentication uses the OAuth2 access token of the identity running
# Terraform (data.google_client_config.provider). Do not also add an `exec`
# credential-plugin block while `token` is set: client-go skips exec plugins
# whenever a static bearer token is configured, so such a block would silently
# have no effect.
# NOTE: this access token is short-lived (Google OAuth2 access tokens usually
# expire after about an hour). For longer-running applies, remove `token` and
# configure `exec { command = "gke-gcloud-auth-plugin" ... }` instead.
provider "kubernetes" {
  host  = "https://${data.google_container_cluster.cluster.endpoint}"
  token = data.google_client_config.provider.access_token
  cluster_ca_certificate = base64decode(
    data.google_container_cluster.cluster.master_auth[0].cluster_ca_certificate,
  )
}
# Eight random bytes (exposed by random_id as hex / base64 strings), meant for
# uniquely naming per-project or global resources.
# NOTE(review): nothing in this file references random_id.uniq; check the
# other files of this module before removing it.
resource "random_id" "uniq" {
  byte_length = 8
}
locals {
  # Workload Identity Federation principal for the Kubernetes service account,
  # used below as a `principal:` member in the bucket IAM bindings.
  # The `ns/default` segment means the KSA must live in the "default" namespace.
  k8s_sa_full = "//iam.googleapis.com/projects/${data.google_project.project.number}/locations/global/workloadIdentityPools/${data.google_project.project.project_id}.svc.id.goog/subject/ns/default/sa/${local.k8s_sa_name}"

  # Name of the Kubernetes service account the benchmark pods run as.
  k8s_sa_name = var.k8s_sa_name
}
# Kubernetes service account the benchmark pods run as.
# The namespace is set explicitly because local.k8s_sa_full hard-codes
# `ns/default` in the Workload Identity principal; the two must agree for the
# bucket IAM grants to apply to this account. "default" is also the value the
# provider uses when the namespace is omitted, so this does not change the
# deployed object.
resource "kubernetes_service_account" "ksa" {
  metadata {
    name      = local.k8s_sa_name
    namespace = "default"
  }
}
# Grant the KSA's Workload Identity principal object-level access
# (roles/storage.objectUser) on the bucket that receives benchmark metrics.
resource "google_storage_bucket_iam_member" "grant-ksa-permissions-on-metrics-bucket" {
  member = "principal:${local.k8s_sa_full}"
  role   = "roles/storage.objectUser"
  bucket = var.metrics_bucket_name
}
# Grant the KSA's Workload Identity principal object-level access
# (roles/storage.objectUser) on the bucket holding the benchmark data.
resource "google_storage_bucket_iam_member" "grant-ksa-permissions-on-data-bucket" {
  member = "principal:${local.k8s_sa_full}"
  role   = "roles/storage.objectUser"
  bucket = var.data_bucket_name
}
# Resolve the container image without a tag or digest, which selects the most
# recently pushed version. Its self_link includes the sha256 digest, so a newly
# pushed image changes the rendered manifests and forces a fresh pull.
# NOTE(review): the registry location is hard-coded to us-west1.
data "google_artifact_registry_docker_image" "image" {
  image_name    = var.image_name
  repository_id = var.repository_id
  location      = "us-west1"
}
locals {
  # Benchmark knobs taken directly from input variables.
  label              = var.label
  epochs             = var.epochs
  parallelism        = var.parallelism
  background_threads = var.background_threads
  # presumably gs:// prefixes inside var.data_bucket_name — confirm in callers
  prefixes = var.prefixes

  # Number of objects to read: parallelism * background_threads, i.e. roughly
  # one object per thread, which is convenient for troubleshooting.
  # Alternative used previously: a fixed 1024, large enough to make the
  # local-SSD cache irrelevant.
  object_count_limit = local.parallelism * local.background_threads

  # Forwarded unchanged to the job templates.
  # NOTE(review): -1 presumably means "no explicit memory setting"; confirm in
  # templates/*.tfpl.yaml. An earlier formula was
  # file_size_gib (2) * background_threads.
  memory = -1
}
locals {
  # Template variables shared by both benchmark manifests. Keeping a single map
  # guarantees the job-set and the independent-job manifests are rendered from
  # exactly the same settings.
  job_template_vars = {
    image               = data.google_artifact_registry_docker_image.image.self_link
    prefixes            = local.prefixes
    k8s_sa_name         = local.k8s_sa_name
    metrics_bucket_name = var.metrics_bucket_name
    data_bucket_name    = var.data_bucket_name
    parallelism         = local.parallelism
    epochs              = local.epochs
    background_threads  = local.background_threads
    object_count_limit  = local.object_count_limit
    memory              = local.memory
    label               = local.label
    steps               = var.steps
    batch_size          = var.batch_size
  }
}

# Generate the data loader job-set benchmark definition (the nodes talk to
# each other). The output file is written relative to the directory Terraform
# runs in; the template is resolved relative to this module.
resource "local_file" "ssiog-training-jobset" {
  filename = "job-set.yaml"
  content  = templatefile("${path.module}/templates/job-set.tfpl.yaml", local.job_template_vars)
}

# Generate the data loader job benchmark definition (every job is
# independent), from the same variables as the job-set above.
resource "local_file" "ssiog-training-job" {
  filename = "job.yaml"
  content  = templatefile("${path.module}/templates/job.tfpl.yaml", local.job_template_vars)
}