fbgemm_gpu/codegen/embedding_backward_code_generator.py (531 lines of code) (raw):

#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import jinja2

# Forward declarations so type checkers know the types of the module-level
# names that are bound later by tuple unpacking.
args: argparse.Namespace
_: List[str]
TENSOR: int
INT_TENSOR: int
LONG_TENSOR: int
INT: int
FLOAT: int

parser = argparse.ArgumentParser()
# By default the source template files are in the same folder as
# embedding_backward_code_generator.py;
# The install dir is by default the same as the current folder.
parser.add_argument("--install_dir", default=".", help="where to put generated file")
# Passing --opensource stores False into args.is_fbcode (it defaults to True).
parser.add_argument("--opensource", action="store_false", dest="is_fbcode")
# parse_known_args() is used so that unrelated command-line arguments are
# returned in `_` instead of aborting the program.
args, _ = parser.parse_known_args()


# Templates are looked up in the directory that contains this file.
env = jinja2.Environment(
    loader=jinja2.FileSystemLoader(os.path.dirname(os.path.abspath(__file__)))
)
# Upper Limit of "max_embedding_dim (max_D)":
# BT_block_size * sizeof(float) * 4 * kWarpSize * {{ kMaxVecsPerThread }}
# needs to be smaller than the allocated shared memory size (2/3 of 96 KB
# on V100 and 160 KB on A100).
# BT_block_size * 4 * 4 * 32 * (max_D // 128) <= 64 * 1024 (V100) or 96 * 1024 (A100)
# Since BT_block_size >= 1, max_D <= 16K (V100) or 24K (A100).
# Note that if we increase max_D, it will increase the compilation time significantly.
env.globals["max_embedding_dim"] = 1024
env.globals["dense"] = False

# NOTE(review): this file was recovered from a whitespace-collapsed copy. The
# tokens of every embedded C++/CUDA snippet are preserved, but the line breaks
# and indentation inside the template strings are reconstructed. Line breaks
# matter only after `//` comments and `#pragma` directives, which must end
# their line; the rest is cosmetic.


def write(filename: str, s: str) -> None:
    """Write `s` to `filename` inside `args.install_dir`, overwriting it."""
    with open(os.path.join(args.install_dir, filename), "w") as f:
        f.write(s)


def _arg_constructor(
    type: str, name: str, gpu: bool = True, precision: int = 32
) -> str:
    """Return the C++ expression that builds a 1-D accessor for tensor `name`.

    Args:
        type: C++ element type spelled into the accessor template.
        name: Name of the tensor variable in the generated code.
        gpu: Emit a packed accessor (with `at::RestrictPtrTraits`) when True,
            a plain `accessor` otherwise.
        precision: Suffix of `packed_accessor` (e.g. 32 or 64); only used
            when `gpu` is True.
    """
    return (
        f"{name}.packed_accessor{precision}<{type}, 1, at::RestrictPtrTraits>()"
        if gpu
        else f"{name}.accessor<{type}, 1>()"
    )


def _arg(type: str, name: str, gpu: bool = True, precision: int = 32) -> str:
    """Return the C++ parameter declaration for a 1-D accessor named `name`.

    Mirrors `_arg_constructor`: the declaration produced here is the
    parameter type that receives the expression produced there.
    """
    return (
        f"at::PackedTensorAccessor{precision}<{type}, 1, at::RestrictPtrTraits> {name}"
        if gpu
        else f"at::TensorAccessor<{type}, 1> {name}"
    )


def acc_cache_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
    """Accessor-construction expression for an `at::acc_type` tensor.

    The element type is `at::acc_type<cache_t, true>` for GPU and
    `at::acc_type<scalar_t, true>` for CPU; GPU accessors use precision 64.
    """
    return _arg_constructor(
        "at::acc_type<" + ("cache_t" if gpu else "scalar_t") + ", true>",
        name,
        gpu=gpu,
        precision=64,
    )


def acc_cache_tensor_arg(name: str, gpu: bool = True) -> str:
    """Parameter declaration matching `acc_cache_tensor_arg_constructor`."""
    return _arg(
        "at::acc_type<" + ("cache_t" if gpu else "scalar_t") + ", true>",
        name,
        gpu=gpu,
        precision=64,
    )


def long_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
    """Accessor-construction expression for an `int64_t` tensor."""
    return _arg_constructor("int64_t", name, gpu=gpu)


def long_tensor_arg(name: str, gpu: bool = True) -> str:
    """Parameter declaration for an `int64_t` tensor accessor."""
    return _arg("int64_t", name, gpu=gpu)


def int_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
    """Accessor-construction expression for an `int32_t` tensor."""
    return _arg_constructor("int32_t", name, gpu=gpu)


def int_tensor_arg(name: str, gpu: bool = True) -> str:
    """Parameter declaration for an `int32_t` tensor accessor."""
    return _arg("int32_t", name, gpu=gpu)


def tensor_arg(name: str) -> str:
    """Declaration of a `Tensor` parameter."""
    return f"Tensor {name}"


def double_arg(name: str) -> str:
    """Declaration of a `double` parameter."""
    return f"double {name}"


def float_arg(name: str) -> str:
    """Declaration of a `float` parameter."""
    return f"float {name}"


def int64_arg(name: str) -> str:
    """Declaration of an `int64_t` parameter."""
    return f"int64_t {name}"


def int_arg(name: str) -> str:
    """Declaration of an `int` parameter."""
    return f"int {name}"


def generate(**kwargs: Any) -> None:
    """Render and write every backward-pass source file for one optimizer.

    Required keyword arguments:
        optimizer: Optimizer name; it is spliced into the output file names
            and selects the "approx" CPU template when it contains "approx".
        args: Mapping with "cuda" and "cpu" entries (see `make_args`).

    All keyword arguments are forwarded to the templates. When `dense` is
    falsy, the host wrappers and the Python lookup invoker are generated in
    addition to the kernels.
    """
    gen_args = kwargs["args"]
    optimizer = kwargs.get("optimizer")

    # Generates CUDA variants.
    kwargs["args"] = gen_args["cuda"]

    template = env.get_template("embedding_backward_split_template.cu")
    src_cu = template.render(weighted=False, **kwargs)
    write(
        f"gen_embedding_backward_{optimizer}_split_unweighted_cuda.cu",
        src_cu,
    )
    src_cu = template.render(weighted=True, **kwargs)
    write(
        f"gen_embedding_backward_{optimizer}_split_weighted_cuda.cu",
        src_cu,
    )
    if not kwargs.get("dense"):
        template = env.get_template("embedding_backward_split_host_template.cpp")
        src_cpp = template.render(**kwargs)
        write(f"gen_embedding_backward_split_{optimizer}.cpp", src_cpp)

        # Generates Python invoker for CUDA + CPU
        template = env.get_template("split_embedding_codegen_lookup_invoker.template")
        src_py = template.render(is_fbcode=args.is_fbcode, **kwargs)
        write(f"lookup_{optimizer}.py", src_py)

    # Generates CPU variants.
    kwargs["args"] = gen_args["cpu"]

    is_approx = "approx" in optimizer
    template = (
        env.get_template("embedding_backward_split_cpu_approx_template.cpp")
        if is_approx
        else env.get_template("embedding_backward_split_cpu_template.cpp")
    )

    src_cpp = template.render(**kwargs)
    write(
        f"gen_embedding_backward_{optimizer}_split_cpu.cpp",
        src_cpp,
    )

    if not kwargs.get("dense"):
        template = env.get_template("embedding_backward_split_host_cpu_template.cpp")
        src_cpp = template.render(**kwargs)
        write(f"gen_embedding_backward_split_{optimizer}_cpu.cpp", src_cpp)


@dataclass
class Args:
    """Per-device argument lists handed to the templates as `args`.

    Every field is a list of C++/schema source fragments (or names) derived
    from the optimizer's argument spec by `make_args`.
    """

    split_kernel_args: List[str]
    split_kernel_arg_constructors: List[str]
    split_cpu_kernel_args: List[str]
    split_cpu_kernel_arg_constructors: List[str]
    split_function_args: List[str]
    split_saved_tensors: List[str]
    split_tensors: List[str]
    # (argument name, IValue cast method name) pairs for non-tensor arguments.
    saved_data: List[Tuple[str, str]]
    split_function_arg_names: List[str]
    split_function_schemas: List[str]
    split_variables: List[str]


# Type tags used in optimizer argument specs.
TENSOR, INT_TENSOR, LONG_TENSOR, INT, FLOAT = range(5)


def make_args(arg_spec: List[Tuple[int, str]]) -> Dict[str, Any]:
    """Expand an optimizer argument spec into per-device `Args`.

    Args:
        arg_spec: `(type tag, name)` pairs where the tag is `TENSOR`, `INT`
            or `FLOAT`.

    Returns:
        `{"cpu": Args, "cuda": Args}`. Each `TENSOR` argument is split into
        `<name>_host`, `<name>_placements`, `<name>_offsets` for CPU and into
        `<name>_dev`, `<name>_uvm`, `<name>_placements`, `<name>_offsets`
        for CUDA; scalar arguments are passed through unchanged.
    """

    def make_kernel_arg(ty: int, name: str) -> str:
        return {
            TENSOR: acc_cache_tensor_arg,
            INT_TENSOR: int_tensor_arg,
            LONG_TENSOR: long_tensor_arg,
            INT: int64_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_kernel_arg_constructor(ty: int, name: str) -> str:
        return {
            TENSOR: acc_cache_tensor_arg_constructor,
            INT_TENSOR: int_tensor_arg_constructor,
            LONG_TENSOR: long_tensor_arg_constructor,
            INT: lambda x: x,
            FLOAT: lambda x: x,
        }[ty](name)

    def make_cpu_kernel_arg(ty: int, name: str) -> str:
        return {
            TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False),
            INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False),
            LONG_TENSOR: lambda x: long_tensor_arg(x, gpu=False),
            INT: int64_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_cpu_kernel_arg_constructor(ty: int, name: str) -> str:
        return {
            TENSOR: lambda x: acc_cache_tensor_arg_constructor(x, gpu=False),
            INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False),
            LONG_TENSOR: lambda x: long_tensor_arg_constructor(x, gpu=False),
            INT: lambda x: x,
            FLOAT: lambda x: x,
        }[ty](name)

    def make_function_arg(ty: int, name: str) -> str:
        return {
            TENSOR: tensor_arg,
            INT_TENSOR: tensor_arg,
            LONG_TENSOR: tensor_arg,
            INT: int64_arg,
            FLOAT: double_arg,
        }[ty](name)

    def make_function_schema_arg(ty: int, name: str) -> str:
        return {
            TENSOR: tensor_arg,
            INT_TENSOR: tensor_arg,
            LONG_TENSOR: tensor_arg,
            INT: int_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_ivalue_cast(ty: int) -> str:
        return {INT: "toInt", FLOAT: "toDouble"}[ty]

    def make_args_for_compute_device(split_arg_spec: List[Tuple[int, str]]) -> Args:
        # `split_tensors` and `saved_data` deliberately read the *unsplit*
        # `arg_spec` of the enclosing call; everything else uses the
        # device-specific expansion.
        return Args(
            split_kernel_args=[
                make_kernel_arg(ty, name) for ty, name in split_arg_spec
            ],
            split_kernel_arg_constructors=[
                make_kernel_arg_constructor(ty, name) for ty, name in split_arg_spec
            ],
            split_cpu_kernel_args=[
                make_cpu_kernel_arg(ty, name) for ty, name in split_arg_spec
            ],
            split_cpu_kernel_arg_constructors=[
                make_cpu_kernel_arg_constructor(ty, name)
                for ty, name in split_arg_spec
            ],
            split_function_args=[
                make_function_arg(ty, name) for ty, name in split_arg_spec
            ],
            split_tensors=[name for ty, name in arg_spec if ty == TENSOR],
            split_saved_tensors=[
                name
                for ty, name in split_arg_spec
                if ty in (TENSOR, INT_TENSOR, LONG_TENSOR)
            ],
            saved_data=[
                (name, make_ivalue_cast(ty)) for ty, name in arg_spec if ty != TENSOR
            ],
            split_function_arg_names=[name for _ty, name in split_arg_spec],
            split_function_schemas=[
                make_function_schema_arg(ty, name) for ty, name in split_arg_spec
            ],
            split_variables=["Variable()" for _ in split_arg_spec],
        )

    split_arg_spec = []
    for ty, arg in arg_spec:
        if ty in (FLOAT, INT):
            split_arg_spec.append((ty, arg))
        else:
            assert ty == TENSOR, f"unsupported argument type tag {ty} for {arg}"
            split_arg_spec.extend(
                [
                    (TENSOR, f"{arg}_host"),
                    (INT_TENSOR, f"{arg}_placements"),
                    (LONG_TENSOR, f"{arg}_offsets"),
                ]
            )
    cpu = make_args_for_compute_device(split_arg_spec)

    split_arg_spec = []
    for ty, arg in arg_spec:
        if ty in (FLOAT, INT):
            split_arg_spec.append((ty, arg))
        else:
            assert ty == TENSOR, f"unsupported argument type tag {ty} for {arg}"
            split_arg_spec.extend(
                [
                    (TENSOR, f"{arg}_dev"),
                    (TENSOR, f"{arg}_uvm"),
                    (INT_TENSOR, f"{arg}_placements"),
                    (LONG_TENSOR, f"{arg}_offsets"),
                ]
            )
    cuda = make_args_for_compute_device(split_arg_spec)

    return {"cpu": cpu, "cuda": cuda}


def adagrad() -> None:
    """Generate the element-wise AdaGrad backward sources."""
    split_weight_update = """
        Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
        m_t.acc.x += grad.acc.x * grad.acc.x;
        m_t.acc.y += grad.acc.y * grad.acc.y;
        m_t.acc.z += grad.acc.z * grad.acc.z;
        m_t.acc.w += grad.acc.w * grad.acc.w;
        m_t.store(&momentum1[idx * D + d]);

        weight_new.acc.x -= learning_rate * grad.acc.x / (sqrtf(m_t.acc.x) + eps);
        weight_new.acc.y -= learning_rate * grad.acc.y / (sqrtf(m_t.acc.y) + eps);
        weight_new.acc.z -= learning_rate * grad.acc.z / (sqrtf(m_t.acc.z) + eps);
        weight_new.acc.w -= learning_rate * grad.acc.w / (sqrtf(m_t.acc.w) + eps);
    """
    split_weight_update_cpu = """
        for (int64_t d = 0; d < D; ++d) {
            momentum1_host[embedding_begin + d] +=
                grad_buffer[d] * grad_buffer[d];
            host_weights_data[embedding_begin + d] -=
                learning_rate * grad_buffer[d] /
                (sqrt(momentum1_host[embedding_begin + d]) + eps);
        }
    """

    generate(
        optimizer="adagrad",
        args=make_args(
            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
        ),
        split_precomputation="",
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def table_info_precomputation(momentum_prefix: str = "momentum1") -> str:
    """Return C++ that builds `table_info_map` for the given momentum tensor.

    Every literal `{momentum_prefix}` in the snippet is substituted with
    `momentum_prefix` using `str.replace` (not `str.format`, so the C++
    braces need no escaping).
    """
    template = """
        // table_begin -> (E, D, {momentum_prefix}_row_begin).
        std::map<int64_t, std::tuple<int64_t, int64_t, int64_t>> table_info_map;
        for (int64_t t = 0; t < T; ++t) {
            const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
            const auto table_begin = weights_offsets_data[t];
            const auto {momentum_prefix}_row_begin = {momentum_prefix}_offsets_data[t];
            table_info_map[table_begin] = std::make_tuple(0, D, {momentum_prefix}_row_begin);
        }
        int64_t previous_table_begin = host_weights.numel();
        // NOTE: table_info_map is sorted by table_begin!
        for (auto it = table_info_map.rbegin(); it != table_info_map.rend(); ++it) {
            const auto D = std::get<1>(it->second);
            // Calculates number of rows of each table.
            std::get<0>(it->second) = (previous_table_begin - it->first) / D;
            previous_table_begin = it->first;
        }
    """
    return template.replace("{momentum_prefix}", momentum_prefix)


def rowwise_adagrad() -> None:
    """Generate row-wise AdaGrad sources, exact and "approx" variants."""
    split_weight_update = """
        weight_new.fma_(grad, -multiplier);
    """
    split_precomputation = """
        at::acc_type<cache_t, true> g_local_sum_square = 0.0;
        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
            ++i) {
            g_local_sum_square += grad_sum[i].acc.x * grad_sum[i].acc.x +
                grad_sum[i].acc.y * grad_sum[i].acc.y +
                grad_sum[i].acc.z * grad_sum[i].acc.z +
                grad_sum[i].acc.w * grad_sum[i].acc.w;
        }
        const at::acc_type<cache_t, true> g_avg_square =
            warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

        at::acc_type<cache_t, true> multiplier;
        if (threadIdx.x == 0) {
            at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
            momentum1[idx] = new_sum_square_grads;
            multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
        }
        multiplier = shfl_sync(multiplier, 0);
    """
    split_weight_update_cpu = """
        at::acc_type<grad_t, true> g_local_sum_square = 0.0;
        for (int64_t d = 0; d < D; ++d) {
            g_local_sum_square += grad_buffer[d] * grad_buffer[d];
        }
        auto g_avg_square = g_local_sum_square / D;
        at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
        momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
        at::acc_type<grad_t, true> multiplier;
        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
        for (int64_t d = 0; d < D; ++d) {
            host_weights_data[embedding_begin + d] -= grad_buffer[d] * multiplier;
        }
    """

    generate(
        optimizer="rowwise_adagrad",
        args=make_args(
            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )

    approx_split_weight_update = """
        // dummy computation to avoid unused variable warning
        weight_new.fma_(grad, -multiplier);
        assert(false); // approx rowwise AdaGrad is not supported on GPU
    """

    generate(
        optimizer="approx_rowwise_adagrad",
        args=make_args(
            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=approx_split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def rowwise_adagrad_with_weight_decay() -> None:
    """Generate row-wise AdaGrad with weight decay, exact and "approx".

    In the emitted code `weight_decay_mode == 0` selects L2 regularization
    and `weight_decay_mode == 1` selects decoupled weight decay. Any other
    value now leaves the weights undecayed (`correction = 1.0`); previously
    `correction` was read uninitialized in that case.
    """
    split_weight_update = """
        weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x;
        weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y;
        weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z;
        weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w;
    """
    split_precomputation = """
        at::acc_type<cache_t, true> g_local_sum_square = 0.0;
        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
            ++i) {
            auto gx = grad_sum[i].acc.x;
            auto gy = grad_sum[i].acc.y;
            auto gz = grad_sum[i].acc.z;
            auto gw = grad_sum[i].acc.w;
            if (weight_decay_mode == 0) {
                // L2 regularization
                int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
                Vec4T<at::acc_type<cache_t, true>> weight = weight_row_template.load(d, qparams_template);
                gx += weight_decay * weight.acc.x;
                gy += weight_decay * weight.acc.y;
                gz += weight_decay * weight.acc.z;
                gw += weight_decay * weight.acc.w;
            }
            g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
        }
        const at::acc_type<cache_t, true> g_avg_square =
            warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

        at::acc_type<cache_t, true> multiplier;
        at::acc_type<cache_t, true> correction;
        if (threadIdx.x == 0) {
            at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
            momentum1[idx] = new_sum_square_grads;
            multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
            if (weight_decay_mode == 0) {
                // L2 regularization
                correction = 1.0 - multiplier * weight_decay;
            } else if (weight_decay_mode == 1) {
                // Decoupled weight decay
                correction = 1.0 - learning_rate * weight_decay;
            } else {
                // Unknown mode: apply no weight decay
                correction = 1.0;
            }
        }
        multiplier = shfl_sync(multiplier, 0);
        correction = shfl_sync(correction, 0);
    """
    split_weight_update_cpu = """
        at::acc_type<grad_t, true> g_local_sum_square = 0.0;
        for (int64_t d = 0; d < D; ++d) {
            auto grad = grad_buffer[d];
            if (weight_decay_mode == 0) {
                // L2 regularization
                grad += weight_decay * host_weights_data[embedding_begin + d];
            }
            g_local_sum_square += grad * grad;
        }
        auto g_avg_square = g_local_sum_square / D;
        at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
        momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
        at::acc_type<grad_t, true> multiplier;
        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
        at::acc_type<scalar_t, true> correction;
        if (weight_decay_mode == 0) {
            // L2 regularization
            correction = 1.0 - multiplier * weight_decay;
        } else if (weight_decay_mode == 1) {
            // Decoupled weight decay
            correction = 1.0 - learning_rate * weight_decay;
        } else {
            // Unknown mode: apply no weight decay
            correction = 1.0;
        }
        for (int64_t d = 0; d < D; ++d) {
            host_weights_data[embedding_begin + d] = correction * host_weights_data[embedding_begin + d] - grad_buffer[d] * multiplier;
        }
    """

    generate(
        optimizer="rowwise_adagrad_with_weight_decay",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (FLOAT, "eps"),
                (FLOAT, "learning_rate"),
                (FLOAT, "weight_decay"),
                (INT, "weight_decay_mode"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )

    approx_split_weight_update = """
        // dummy computation to avoid unused variable warning
        weight_new.fma_(grad, -multiplier);
        assert(false); // approx rowwise AdaGrad is not supported on GPU
    """

    generate(
        optimizer="approx_rowwise_adagrad_with_weight_decay",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (FLOAT, "eps"),
                (FLOAT, "learning_rate"),
                (FLOAT, "weight_decay"),
                (INT, "weight_decay_mode"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=approx_split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def rowwise_weighted_adagrad() -> None:
    """Generate row-wise weighted AdaGrad sources.

    The CPU snippet ignores `weight_decay` (see the comment in the snippet).
    """
    split_weight_update = """
        weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x;
        weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y;
        weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z;
        weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w;
    """
    split_precomputation = """
        at::acc_type<cache_t, true> g_local_sum_square = 0.0;
        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
            ++i) {
            int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
            Vec4T<at::acc_type<cache_t, true>> weight = weight_row_template.load(d, qparams_template);
            auto gx = grad_sum[i].acc.x + weight_decay * weight.acc.x;
            auto gy = grad_sum[i].acc.y + weight_decay * weight.acc.y;
            auto gz = grad_sum[i].acc.z + weight_decay * weight.acc.z;
            auto gw = grad_sum[i].acc.w + weight_decay * weight.acc.w;
            g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
        }
        const at::acc_type<cache_t, true> g_avg_square =
            warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

        at::acc_type<cache_t, true> multiplier;
        at::acc_type<cache_t, true> correction;
        if (threadIdx.x == 0) {
            at::acc_type<cache_t, true> lambda = sqrtf(iter + 1);
            at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + lambda * g_avg_square;
            momentum1[idx] = new_sum_square_grads;
            multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps);
            correction = 1.0 - multiplier * weight_decay;
        }
        multiplier = shfl_sync(multiplier, 0);
        correction = shfl_sync(correction, 0);
    """
    split_weight_update_cpu = """
        // weight_decay not supported for cpu version
        at::acc_type<scalar_t, true> g_local_sum_square = 0.0;
        for (int64_t d = 0; d < D; ++d) {
            g_local_sum_square += grad_buffer[d] * grad_buffer[d];
        }
        auto g_avg_square = g_local_sum_square / D;
        at::acc_type<scalar_t, true> lambda = sqrtf(iter + 1);
        at::acc_type<scalar_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + lambda * g_avg_square;
        momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
        at::acc_type<scalar_t, true> multiplier;
        multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps);
        for (int64_t d = 0; d < D; ++d) {
            host_weights_data[embedding_begin + d] -= grad_buffer[d] * multiplier;
        }
    """

    generate(
        optimizer="rowwise_weighted_adagrad",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (FLOAT, "eps"),
                (FLOAT, "learning_rate"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def sgd() -> None:
    """Generate SGD sources, exact and "approx" variants."""
    split_weight_update = """
        weight_new.fma_(grad, -learning_rate);
    """
    split_weight_update_cpu = """
        for (int64_t d = 0; d < D; ++d) {
            host_weights_data[embedding_begin + d] -= learning_rate * grad_buffer[d];
        }
    """

    generate(
        optimizer="sgd",
        args=make_args([(FLOAT, "learning_rate")]),
        split_precomputation="",
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )

    approx_split_weight_update = """
        // approx_sgd not supported for GPU.
        // Just do the same thing as exact sgd to avoid unused variable warning.
        weight_new.fma_(grad, -learning_rate);
        assert(false); // approx SGD is not supported on GPU
    """

    generate(
        optimizer="approx_sgd",
        args=make_args([(FLOAT, "learning_rate")]),
        split_precomputation="",
        split_weight_update=approx_split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def lamb() -> None:
    """Generate LAMB sources (the CPU weight-update snippet is empty)."""
    split_precomputation = """
        at::acc_type<cache_t, true> weight_sum_sq = 0.0;
        at::acc_type<cache_t, true> rtw_sum_sq = 0.0;
        auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
        float2 qparams;
        if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
            qparams = weight_row.load_qparams();
        }
        #pragma unroll 1
        for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
            ++i) {
            int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
            Vec4T<at::acc_type<cache_t, true>> weight = weight_row.load(d, qparams);
            Vec4T<at::acc_type<cache_t, true>> m1(&momentum1[idx * D + d]);

            m1.acc.x = beta1 * m1.acc.x + (1.0 - beta1) * grad_sum[i].acc.x;
            m1.acc.y = beta1 * m1.acc.y + (1.0 - beta1) * grad_sum[i].acc.y;
            m1.acc.z = beta1 * m1.acc.z + (1.0 - beta1) * grad_sum[i].acc.z;
            m1.acc.w = beta1 * m1.acc.w + (1.0 - beta1) * grad_sum[i].acc.w;
            m1.store(&momentum1[idx * D + d]);

            Vec4T<at::acc_type<cache_t, true>> m2(&momentum2[idx * D + d]);
            m2.acc.x = beta2 * m2.acc.x + (1.0 - beta2) * grad_sum[i].acc.x * grad_sum[i].acc.x;
            m2.acc.y = beta2 * m2.acc.y + (1.0 - beta2) * grad_sum[i].acc.y * grad_sum[i].acc.y;
            m2.acc.z = beta2 * m2.acc.z + (1.0 - beta2) * grad_sum[i].acc.z * grad_sum[i].acc.z;
            m2.acc.w = beta2 * m2.acc.w + (1.0 - beta2) * grad_sum[i].acc.w * grad_sum[i].acc.w;
            m2.store(&momentum2[idx * D + d]);

            // now, we are finished with grad_sum. We can *reuse* grad_sum to store r_t + weight_decay * weight;
            grad_sum[i].acc.x = (m1.acc.x / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.x;
            grad_sum[i].acc.y = (m1.acc.y / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.y;
            grad_sum[i].acc.z = (m1.acc.z / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.z;
            grad_sum[i].acc.w = (m1.acc.w / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.w;

            weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w;
            rtw_sum_sq += grad_sum[i].acc.x * grad_sum[i].acc.x + grad_sum[i].acc.y * grad_sum[i].acc.y + grad_sum[i].acc.z * grad_sum[i].acc.z + grad_sum[i].acc.w * grad_sum[i].acc.w;
        }
        const auto weight_norm =
            sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(weight_sum_sq));
        const auto rtw_norm =
            sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(rtw_sum_sq));
        const auto true_ratio = weight_norm / rtw_norm;
    """
    split_weight_update = """
        weight_new.fma_(grad, -learning_rate * true_ratio);
    """
    split_weight_update_cpu = ""

    generate(
        optimizer="lamb",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (TENSOR, "momentum2"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eps"),
                (FLOAT, "beta1"),
                (FLOAT, "beta2"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def partial_rowwise_lamb() -> None:
    """Generate partial row-wise LAMB sources (CPU snippet not implemented)."""
    split_precomputation = """
        at::acc_type<cache_t, true> g_local_sum_square = 0.0;

        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
            ++i) {
            g_local_sum_square += grad_sum[i].acc.x * grad_sum[i].acc.x +
                grad_sum[i].acc.y * grad_sum[i].acc.y +
                grad_sum[i].acc.z * grad_sum[i].acc.z +
                grad_sum[i].acc.w * grad_sum[i].acc.w;
        }
        const at::acc_type<cache_t, true> g_avg_square =
            warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

        at::acc_type<cache_t, true> m2;
        if (threadIdx.x == 0) {
            m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square;
            momentum2[idx] = m2;
        }
        m2 = shfl_sync(m2, 0);
        at::acc_type<cache_t, true> m2_hat = 1.0 / (sqrtf((m2 / (1.0 - powf(beta2, iter)))) + eps);

        at::acc_type<cache_t, true> weight_sum_sq = 0.0;
        at::acc_type<cache_t, true> rtw_sum_sq = 0.0;
        auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
        float2 qparams;
        if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
            qparams = weight_row.load_qparams();
        }
        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
            ++i) {
            int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;

            Vec4T<at::acc_type<cache_t, true>> m1(&momentum1[idx * D + d]);
            m1.acc.x = beta1 * m1.acc.x + (1.0 - beta1) * grad_sum[i].acc.x;
            m1.acc.y = beta1 * m1.acc.y + (1.0 - beta1) * grad_sum[i].acc.y;
            m1.acc.z = beta1 * m1.acc.z + (1.0 - beta1) * grad_sum[i].acc.z;
            m1.acc.w = beta1 * m1.acc.w + (1.0 - beta1) * grad_sum[i].acc.w;
            m1.store(&momentum1[idx * D + d]);

            // now, we are finished with grad_sum. We can *reuse* grad_sum to store r_t + weight_decay * weight;
            Vec4T<at::acc_type<cache_t, true>> weight = weight_row.load(d, qparams);
            grad_sum[i].acc.x = (m1.acc.x / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.x;
            grad_sum[i].acc.y = (m1.acc.y / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.y;
            grad_sum[i].acc.z = (m1.acc.z / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.z;
            grad_sum[i].acc.w = (m1.acc.w / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.w;

            weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w;
            rtw_sum_sq += grad_sum[i].acc.x * grad_sum[i].acc.x + grad_sum[i].acc.y * grad_sum[i].acc.y + grad_sum[i].acc.z * grad_sum[i].acc.z + grad_sum[i].acc.w * grad_sum[i].acc.w;
        }
        const auto weight_norm =
            sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(weight_sum_sq));
        const auto rtw_norm =
            sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(rtw_sum_sq));
        const auto true_ratio = weight_norm / rtw_norm;
    """
    split_weight_update = """
        weight_new.fma_(grad, -learning_rate * true_ratio);
    """
    split_weight_update_cpu = ""  # TODO

    generate(
        optimizer="partial_rowwise_lamb",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (TENSOR, "momentum2"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eps"),
                (FLOAT, "beta1"),
                (FLOAT, "beta2"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def adam() -> None:
    """Generate Adam sources (CPU snippet not implemented)."""
    split_weight_update = """
        Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
        m_t.acc.x *= beta1;
        m_t.acc.y *= beta1;
        m_t.acc.z *= beta1;
        m_t.acc.w *= beta1;
        m_t.fma_(grad, 1.0 - beta1);
        m_t.store(&momentum1[idx * D + d]);

        Vec4T<cache_t> v_t(&momentum2[idx * D + d]);
        v_t.acc.x *= beta2;
        v_t.acc.y *= beta2;
        v_t.acc.z *= beta2;
        v_t.acc.w *= beta2;

        grad.acc.x *= grad.acc.x;
        grad.acc.y *= grad.acc.y;
        grad.acc.z *= grad.acc.z;
        grad.acc.w *= grad.acc.w;
        v_t.fma_(grad, 1.0 - beta2);
        v_t.store(&momentum2[idx * D + d]);

        weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.x);
        weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.y);
        weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.z);
        weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.w);
    """
    split_weight_update_cpu = ""  # TODO

    generate(
        optimizer="adam",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (TENSOR, "momentum2"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eps"),
                (FLOAT, "beta1"),
                (FLOAT, "beta2"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation="",
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def partial_rowwise_adam() -> None:
    """Generate partial row-wise Adam sources (CPU snippet not implemented)."""
    split_precomputation = """
        at::acc_type<cache_t, true> g_local_sum_square = 0.0;
        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
            ++i) {
            g_local_sum_square += grad_sum[i].acc.x * grad_sum[i].acc.x +
                grad_sum[i].acc.y * grad_sum[i].acc.y +
                grad_sum[i].acc.z * grad_sum[i].acc.z +
                grad_sum[i].acc.w * grad_sum[i].acc.w;
        }
        const at::acc_type<cache_t, true> g_avg_square =
            warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

        at::acc_type<cache_t, true> v_hat_t;
        if (threadIdx.x == 0) {
            at::acc_type<cache_t, true> v_t = momentum2[idx] * beta2 + g_avg_square * (1.0 - beta2);
            momentum2[idx] = v_t;
            v_hat_t = v_t / (1.0 - powf(beta2, iter));
        }
        v_hat_t = shfl_sync(v_hat_t, 0);
    """

    split_weight_update = """
        Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
        m_t.acc.x *= beta1;
        m_t.acc.y *= beta1;
        m_t.acc.z *= beta1;
        m_t.acc.w *= beta1;
        m_t.fma_(grad, 1.0 - beta1);
        m_t.store(&momentum1[idx * D + d]);

        weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
        weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
        weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.z);
        weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.w);
    """
    split_weight_update_cpu = ""  # TODO

    generate(
        optimizer="partial_rowwise_adam",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (TENSOR, "momentum2"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eps"),
                (FLOAT, "beta1"),
                (FLOAT, "beta2"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def lars_sgd() -> None:
    """Generate LARS-SGD sources (CPU snippet not implemented)."""
    split_precomputation = """
        at::acc_type<cache_t, true> weight_sum_sq = 0.0;
        at::acc_type<cache_t, true> grad_sum_sq = 0.0;

        auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
        float2 qparams;
        if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
            qparams = weight_row.load_qparams();
        }
        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
            ++i) {
            int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
            Vec4T<at::acc_type<cache_t,true>> weight = weight_row.load(d, qparams);
            weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w;
            grad_sum_sq += grad_sum[i].acc.x * grad_sum[i].acc.x + grad_sum[i].acc.y * grad_sum[i].acc.y + grad_sum[i].acc.z * grad_sum[i].acc.z + grad_sum[i].acc.w * grad_sum[i].acc.w;
        }
        const auto weight_norm =
            sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(weight_sum_sq));
        const auto grad_norm =
            sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(grad_sum_sq));
        const at::acc_type<cache_t, true> adjusted_lr = learning_rate * eta * weight_norm / (grad_norm + weight_decay * weight_norm);
    """

    split_weight_update = """
        Vec4T<cache_t> m1(&momentum1[idx * D + d]);
        m1.acc.x = momentum * m1.acc.x + adjusted_lr * (grad.acc.x + weight_decay * weight_new.acc.x);
        m1.acc.y = momentum * m1.acc.y + adjusted_lr * (grad.acc.y + weight_decay * weight_new.acc.y);
        m1.acc.z = momentum * m1.acc.z + adjusted_lr * (grad.acc.z + weight_decay * weight_new.acc.z);
        m1.acc.w = momentum * m1.acc.w + adjusted_lr * (grad.acc.w + weight_decay * weight_new.acc.w);
        m1.store(&momentum1[idx * D + d]);

        weight_new.acc.x -= m1.acc.x;
        weight_new.acc.y -= m1.acc.y;
        weight_new.acc.z -= m1.acc.z;
        weight_new.acc.w -= m1.acc.w;
    """
    split_weight_update_cpu = ""  # TODO

    generate(
        optimizer="lars_sgd",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eta"),
                (FLOAT, "momentum"),
                (FLOAT, "weight_decay"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def forward_split() -> None:
    """Render the forward CUDA template for split/dense x weighted/unweighted."""
    template = env.get_template("embedding_forward_split_template.cu")

    src_cu = template.render(weighted=False)
    write("gen_embedding_forward_split_unweighted_codegen_cuda.cu", src_cu)
    src_cu = template.render(weighted=True)
    write("gen_embedding_forward_split_weighted_codegen_cuda.cu", src_cu)

    src_cu = template.render(weighted=False, dense=True)
    write("gen_embedding_forward_dense_unweighted_codegen_cuda.cu", src_cu)
    src_cu = template.render(weighted=True, dense=True)
    write("gen_embedding_forward_dense_weighted_codegen_cuda.cu", src_cu)


def forward_quantized() -> None:
    """Render the quantized forward templates (CUDA and CPU)."""

    @dataclass
    class elem_type:
        enum_name: str
        cpp_type_name: str

    # Keyed by an integer width (32, 16, 8, 4, 2); passed to the templates
    # as `type_map`.
    type_map = {
        32: elem_type("FP32", "float"),
        16: elem_type("FP16", "__half2"),
        8: elem_type("INT8", "uint32_t"),
        4: elem_type("INT4", "uint32_t"),
        2: elem_type("INT2", "uint32_t"),
    }

    template = env.get_template("embedding_forward_quantized_split_template.cu")
    src_cu = template.render(weighted=False, type_map=type_map)
    write("gen_embedding_forward_quantized_split_unweighted_codegen_cuda.cu", src_cu)
    src_cu = template.render(weighted=True, type_map=type_map)
    write("gen_embedding_forward_quantized_split_weighted_codegen_cuda.cu", src_cu)

    template = env.get_template("embedding_forward_quantized_cpu_template.cpp")
    src_cu = template.render(weighted=False, type_map=type_map)
    write("gen_embedding_forward_quantized_unweighted_codegen_cpu.cpp", src_cu)
    src_cu = template.render(weighted=True, type_map=type_map)
    write("gen_embedding_forward_quantized_weighted_codegen_cpu.cpp", src_cu)


def backward_indices() -> None:
    """Render the indice-weights backward template (split and dense)."""
    template = env.get_template("embedding_backward_split_indice_weights_template.cu")
    src_cu = template.render()
    write("gen_embedding_backward_split_indice_weights_codegen_cuda.cu", src_cu)
    src_cu = template.render(dense=True)
    write("gen_embedding_backward_dense_indice_weights_codegen_cuda.cu", src_cu)


def backward_dense() -> None:
    """Generate the dense backward sources via `generate(dense=True)`."""
    generate(
        optimizer="dense",
        dense=True,
        args=make_args(
            [
                (FLOAT, "unused"),
            ]
        ),
    )


def gen__init__py() -> None:
    """Render `__init__.template` into `__init__.py` in the install dir."""
    template = env.get_template("__init__.template")
    src_py = template.render()
    write("__init__.py", src_py)


def emb_codegen(
    install_dir: Optional[str] = None, is_fbcode: Optional[bool] = None
) -> None:
    """Run every generator, optionally overriding the parsed CLI options.

    Args:
        install_dir: Output directory; ignored when None or empty.
        is_fbcode: Overrides `args.is_fbcode` when not None.
    """
    if install_dir is not None and len(install_dir) != 0:
        args.install_dir = install_dir
    if is_fbcode is not None:
        args.is_fbcode = is_fbcode
    adagrad()
    adam()
    backward_indices()
    backward_dense()
    forward_quantized()
    forward_split()
    lamb()
    lars_sgd()
    partial_rowwise_adam()
    partial_rowwise_lamb()
    rowwise_adagrad()
    rowwise_adagrad_with_weight_decay()
    rowwise_weighted_adagrad()
    sgd()

    gen__init__py()


def main() -> None:
    """Command-line entry point: generate everything with the parsed args."""
    emb_codegen()


if __name__ == "__main__":
    main()