#!/usr/bin/python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

import jax
import numpy as np
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

# os.environ["TPU_STDERR_LOG_LEVEL"] = "0"
# os.environ["TPU_MIN_LOG_LEVEL"] = "0"
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
# Run on the PJRT C API TPU runtime.
os.environ["JAX_USE_PJRT_C_API_ON_TPU"] = "1"
def simple_timeit(f, tries=10, verbose=True):
  outcomes = []
  f()  # warm it up!
  for _ in range(tries):
    s = datetime.datetime.now()
    f()
    e = datetime.datetime.now()
    outcomes.append((e - s).total_seconds())
  average_time = sum(outcomes) / len(outcomes)
  if verbose:
    print(f"average time: {average_time}, timings (seconds) {outcomes}")
  return average_time


# GPT-1-scale transformer dimensions; the global batch grows with the number
# of attached devices (128 examples per device).
BATCH = len(jax.devices()) * 128
SEQUENCE_LENGTH = 512
D_MODEL = 768
D_HIDDEN = 3072
NUM_LAYERS = 12
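
# bfloat16 weights take 2 bytes each; each layer holds 4 weight matrices
# (WQ, WK, WV, FF) of D_MODEL * D_HIDDEN elements apiece.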
parameter_bytes = 2 * (4 * D_MODEL * D_HIDDEN * NUM_LAYERS)
ACTIVATIONS_PER_LAYER = 2
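# 2 bytes per bfloat16 activation, with ACTIVATIONS_PER_LAYER activations of
# shape BATCH x SEQUENCE_LENGTH x D_MODEL kept per layer.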
activation_bytes = (
    2 * (BATCH * SEQUENCE_LENGTH * D_MODEL) * NUM_LAYERS * ACTIVATIONS_PER_LAYER
)
memory_bytes = parameter_bytes + activation_bytes
print(
    f"total {memory_bytes/10**9} GB, parameters {parameter_bytes/10**9} GB, all layers of activations {activation_bytes/10**9} GB",
    flush=True,
)
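# Example: with 4 devices, BATCH = 512, giving ~0.23 GB of parameters and
# ~9.66 GB of activations (~9.89 GB total).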


def gen_layer(random_key):
  """One layer of small random bfloat16 weights: Q/K/V projections plus a
  feed-forward matrix mapping back to D_MODEL."""
  keys = jax.random.split(random_key, num=4)
  return {
      "WQ": 1e-4
      * jax.random.normal(keys[0], (D_MODEL, D_HIDDEN), dtype=jax.numpy.bfloat16),
      "WK": 1e-4
      * jax.random.normal(keys[1], (D_MODEL, D_HIDDEN), dtype=jax.numpy.bfloat16),
      "WV": 1e-4
      * jax.random.normal(keys[2], (D_MODEL, D_HIDDEN), dtype=jax.numpy.bfloat16),
      "FF": 1e-4
      * jax.random.normal(keys[3], (D_HIDDEN, D_MODEL), dtype=jax.numpy.bfloat16),
  }


def gen_layers(random_key):
  """Generate NUM_LAYERS layers, splitting the key once per layer."""
  layers = []
  for _ in range(NUM_LAYERS):
    random_key, sub_key = jax.random.split(random_key)
    layers.append(gen_layer(sub_key))
  return tuple(layers)


def gen_data(random_key):
  """Random bfloat16 activations of shape BATCH x SEQUENCE_LENGTH x D_MODEL."""
  return jax.random.uniform(
      random_key, (BATCH, SEQUENCE_LENGTH, D_MODEL), dtype=jax.numpy.bfloat16
  )
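

# Forward pass for one layer. ReLU stands in for the softmax and other
# low-arithmetic-intensity ops (see the TODOs below), keeping shapes and
# matmul FLOPs representative of an attention block.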
def multiply_layer(in_act, in_layer):
  # BATCH x SEQUENCE_LENGTH x D_HIDDEN, flops: 2 * BATCH * SEQUENCE_LENGTH * D_MODEL * D_HIDDEN
  Q = in_act @ in_layer["WQ"]
  # BATCH x SEQUENCE_LENGTH x D_HIDDEN, flops: 2 * BATCH * SEQUENCE_LENGTH * D_MODEL * D_HIDDEN
  K = in_act @ in_layer["WK"]
  # BATCH x SEQUENCE_LENGTH x D_HIDDEN, flops: 2 * BATCH * SEQUENCE_LENGTH * D_MODEL * D_HIDDEN
  V = in_act @ in_layer["WV"]
  # BATCH x SEQUENCE_LENGTH x SEQUENCE_LENGTH, flops: 2 * BATCH * SEQUENCE_LENGTH^2 * D_HIDDEN
  A = jax.numpy.einsum("bsd,btd->bst", Q, K)
  A = jax.nn.relu(A)  # TODO(correct low arithmetic intensity manips)
  # BATCH x SEQUENCE_LENGTH x D_HIDDEN, flops: 2 * BATCH * SEQUENCE_LENGTH^2 * D_HIDDEN
  post_attention = A @ V
  # BATCH x SEQUENCE_LENGTH x D_MODEL, flops: 2 * BATCH * SEQUENCE_LENGTH * D_HIDDEN * D_MODEL
  right_shape = post_attention @ in_layer["FF"]
  right_shape = jax.nn.relu(right_shape)  # TODO(correct low arithmetic intensity manips)
  return right_shape + 1 + in_act  # residual connection (plus a constant offset)


def multiply_layers(in_act, in_layers):
  """Apply every layer in sequence; also return the layers for convenience."""
  x = in_act
  for layer in in_layers:
    x = multiply_layer(x, layer)
  return x, in_layers


def multiply_layers_with_loss(in_act, in_layers):
  """Scalar 'loss': the sum of the final activations."""
  x, _ = multiply_layers(in_act, in_layers)
  return jax.numpy.sum(x)
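

# MXLA rules out the usual compiled-cost analysis here, so the test falls
# back to a stored TFLOP value for this workload.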
def calculate_tflops(f, *args, **kwargs):
  print(
      "Not calculating TFLOPS since MXLA is enabled -- for now just have a stored value for this test"
  )
  return 50
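
# With argnums=[1], value_and_grad returns (loss, (grad_wrt_layers,)); the
# gradients are unpacked as grad_layers[0] in training_loop below.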
multiply_layers_and_grad = jax.value_and_grad(
    multiply_layers_with_loss, argnums=[1]
)


def training_loop(in_act, in_layers):
  """One step of plain SGD: param <- param - 1e-4 * grad."""
  _, grad_layers = multiply_layers_and_grad(in_act, in_layers)
  out_layers = jax.tree_util.tree_map(
      lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0]
  )
  return out_layers


print("finished includes", flush=True)

# pjit NN: lay out the devices as a 2-D mesh, one axis across TPU slices and
# one across the chips within each slice.
devices = jax.devices()
try:
  num_slices = 1 + max(d.slice_index for d in devices)
except AttributeError:
  # Single-slice environments don't expose slice_index.
  num_slices = 1
mesh_shape = [num_slices, len(devices) // num_slices]
devices_array = np.asarray(devices).reshape(*mesh_shape)
print(f"mesh shape {mesh_shape}", flush=True)
print(f"device layout {devices_array}", flush=True)
mesh = Mesh(devices_array, ("slices", "tpus"))
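
# Sharding plan: input activations are data-parallel across both mesh axes,
# while each weight matrix is split along its first dimension over the
# "tpus" axis.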
pjit_func = pjit(
    training_loop,
    in_shardings=(PartitionSpec(("slices", "tpus")), PartitionSpec("tpus")),
    out_shardings=PartitionSpec("tpus"),
)
pjit_gen_data = pjit(
    gen_data, in_shardings=None, out_shardings=PartitionSpec(("slices", "tpus"))
)
pjit_gen_layers = pjit(
    gen_layers, in_shardings=None, out_shardings=PartitionSpec("tpus")
)
print("pjit wrappers created (compilation happens on first call)", flush=True)

with mesh:
  key = jax.random.PRNGKey(0)
  presharded_X = jax.block_until_ready(pjit_gen_data(key))
  presharded_layers = jax.block_until_ready(pjit_gen_layers(key))
  TFLOPs = calculate_tflops(training_loop, presharded_X, presharded_layers)
  # Capture a TensorBoard profile of the timed steps under /tmp/tb12.
  with jax.profiler.trace("/tmp/tb12"):
    time = simple_timeit(
        lambda: jax.block_until_ready(pjit_func(presharded_X, presharded_layers))
    )
  print(
      f"time is {time} seconds, TFLOP is {TFLOPs}, memory usage is {memory_bytes/10**9} GB, TFLOP/s is {TFLOPs/time}",
      flush=True,
  )
  # 275 TFLOP/s matches the bf16 peak of a TPU v4 chip; require at least half.
  assert (
      TFLOPs / time > 275 / 2
  ), "make sure that we're hitting the performance target, 50% of peak FLOPS"