#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import jinja2

args: argparse.Namespace
_: List[str]
TENSOR: int
INT_TENSOR: int
LONG_TENSOR: int
INT: int
FLOAT: int


parser = argparse.ArgumentParser()
# By default, the source template files are in the same folder as
# embedding_backward_code_generator.py, and the install dir defaults to the
# current folder.
parser.add_argument("--install_dir", default=".", help="where to put generated files")
parser.add_argument("--opensource", action="store_false", dest="is_fbcode")
args, _ = parser.parse_known_args()
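# Example invocation (the output path below is only illustrative):
#   python embedding_backward_code_generator.py --install_dir=/tmp/fbgemm_codegen
# Pass --opensource to render the OSS (non-fbcode) variant of the Python invoker.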


env = jinja2.Environment(
    loader=jinja2.FileSystemLoader(os.path.dirname(os.path.abspath(__file__)))
)
# Upper Limit of "max_embedding_dim (max_D)":
# BT_block_size * sizeof(float) * 4 * kWarpSize * {{ kMaxVecsPerThread }}
# needs to be smaller than the allocated shared memory size (2/3 of 96 KB
# on V100 and 160 KB on A100).
# BT_block_size * 4 * 4 * 32 * (max_D // 128) <= 64 * 1024 (V100) or 96 * 1024 (A100)
# Since BT_block_size >= 1, max_D <= 16K (V100) or 24K (A100).
# Note that if we increase max_D, it will increase the compilation time significantly.
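# For example, with BT_block_size = 1 and max_D = 1024 (the value set below),
# the bound evaluates to 1 * 4 * 4 * 32 * (1024 // 128) = 4 KB, well under both limits.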
env.globals["max_embedding_dim"] = 1024
env.globals["dense"] = False


def write(filename: str, s: str) -> None:
    with open(os.path.join(args.install_dir, filename), "w") as f:
        f.write(s)


def _arg_constructor(
    type: str, name: str, gpu: bool = True, precision: int = 32
) -> str:
    return (
        f"{name}.packed_accessor{precision}<{type}, 1, at::RestrictPtrTraits>()"
        if gpu
        else f"{name}.accessor<{type}, 1>()"
    )


def _arg(type: str, name: str, gpu: bool = True, precision: int = 32) -> str:
    return (
        f"at::PackedTensorAccessor{precision}<{type}, 1, at::RestrictPtrTraits> {name}"
        if gpu
        else f"at::TensorAccessor<{type}, 1> {name}"
    )


def acc_cache_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
    return _arg_constructor(
        "at::acc_type<" + ("cache_t" if gpu else "scalar_t") + ", true>",
        name,
        gpu=gpu,
        precision=64,
    )


def acc_cache_tensor_arg(name: str, gpu: bool = True) -> str:
    return _arg(
        "at::acc_type<" + ("cache_t" if gpu else "scalar_t") + ", true>",
        name,
        gpu=gpu,
        precision=64,
    )


def long_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
    return _arg_constructor("int64_t", name, gpu=gpu)


def long_tensor_arg(name: str, gpu: bool = True) -> str:
    return _arg("int64_t", name, gpu=gpu)


def int_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
    return _arg_constructor("int32_t", name, gpu=gpu)


def int_tensor_arg(name: str, gpu: bool = True) -> str:
    return _arg("int32_t", name, gpu=gpu)


def tensor_arg(name: str) -> str:
    return f"Tensor {name}"


def double_arg(name: str) -> str:
    return f"double {name}"


def float_arg(name: str) -> str:
    return f"float {name}"


def int64_arg(name: str) -> str:
    return f"int64_t {name}"


def int_arg(name: str) -> str:
    return f"int {name}"


def generate(**kwargs: Any) -> None:
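    # Renders the backward templates for one optimizer: CUDA unweighted/weighted
    # kernels, the CPU kernel, and (for non-dense optimizers) the autograd host
    # bindings plus the Python lookup invoker.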
    gen_args = kwargs["args"]

    # Generates CUDA variants.
    kwargs["args"] = gen_args["cuda"]

    template = env.get_template("embedding_backward_split_template.cu")
    src_cu = template.render(weighted=False, **kwargs)
    write(
        f"gen_embedding_backward_{kwargs.get('optimizer')}_split_unweighted_cuda.cu",
        src_cu,
    )
    src_cu = template.render(weighted=True, **kwargs)
    write(
        f"gen_embedding_backward_{kwargs.get('optimizer')}_split_weighted_cuda.cu",
        src_cu,
    )
    if not kwargs.get("dense"):
        template = env.get_template("embedding_backward_split_host_template.cpp")
        src_cpp = template.render(**kwargs)
        write(f"gen_embedding_backward_split_{kwargs.get('optimizer')}.cpp", src_cpp)

        # Generates Python invoker for CUDA + CPU
        template = env.get_template("split_embedding_codegen_lookup_invoker.template")
        src_py = template.render(is_fbcode=args.is_fbcode, **kwargs)
        write(f"lookup_{kwargs.get('optimizer')}.py", src_py)

    # Generates CPU variants.
    kwargs["args"] = gen_args["cpu"]

    is_approx = "approx" in kwargs.get("optimizer")
    template = (
        env.get_template("embedding_backward_split_cpu_approx_template.cpp")
        if is_approx
        else env.get_template("embedding_backward_split_cpu_template.cpp")
    )

    src_cpp = template.render(**kwargs)
    write(
        f"gen_embedding_backward_{kwargs.get('optimizer')}_split_cpu.cpp",
        src_cpp,
    )

    if not kwargs.get("dense"):
        template = env.get_template("embedding_backward_split_host_cpu_template.cpp")
        src_cpp = template.render(**kwargs)
        write(
            f"gen_embedding_backward_split_{kwargs.get('optimizer')}_cpu.cpp", src_cpp
        )


@dataclass
class Args:
    split_kernel_args: List[str]
    split_kernel_arg_constructors: List[str]
    split_cpu_kernel_args: List[str]
    split_cpu_kernel_arg_constructors: List[str]
    split_function_args: List[str]
    split_saved_tensors: List[str]
    split_tensors: List[str]
    saved_data: List[Tuple[str, str]]
    split_function_arg_names: List[str]
    split_function_schemas: List[str]
    split_variables: List[str]


TENSOR, INT_TENSOR, LONG_TENSOR, INT, FLOAT = range(5)


def make_args(arg_spec: List[Tuple[int, str]]) -> Dict[str, Any]:
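    # Expands an optimizer argument spec into the kernel / host-function / schema
    # argument strings the templates expect, once for CPU and once for CUDA.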
    def make_kernel_arg(ty: int, name: str) -> str:
        return {
            TENSOR: acc_cache_tensor_arg,
            INT_TENSOR: int_tensor_arg,
            LONG_TENSOR: long_tensor_arg,
            INT: int64_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_kernel_arg_constructor(ty: int, name: str) -> str:
        return {
            TENSOR: acc_cache_tensor_arg_constructor,
            INT_TENSOR: int_tensor_arg_constructor,
            LONG_TENSOR: long_tensor_arg_constructor,
            INT: lambda x: x,
            FLOAT: lambda x: x,
        }[ty](name)

    def make_cpu_kernel_arg(ty: int, name: str) -> str:
        return {
            TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False),
            INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False),
            LONG_TENSOR: lambda x: long_tensor_arg(x, gpu=False),
            INT: int64_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_cpu_kernel_arg_constructor(ty: int, name: str) -> str:
        return {
            TENSOR: lambda x: acc_cache_tensor_arg_constructor(x, gpu=False),
            INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False),
            LONG_TENSOR: lambda x: long_tensor_arg_constructor(x, gpu=False),
            INT: lambda x: x,
            FLOAT: lambda x: x,
        }[ty](name)

    def make_function_arg(ty: int, name: str) -> str:
        return {
            TENSOR: tensor_arg,
            INT_TENSOR: tensor_arg,
            LONG_TENSOR: tensor_arg,
            INT: int64_arg,
            FLOAT: double_arg,
        }[ty](name)

    def make_function_schema_arg(ty: int, name: str) -> str:
        return {
            TENSOR: tensor_arg,
            INT_TENSOR: tensor_arg,
            LONG_TENSOR: tensor_arg,
            INT: int_arg,
            FLOAT: float_arg,
        }[ty](name)

    def make_ivalue_cast(ty: int) -> str:
        return {INT: "toInt", FLOAT: "toDouble"}[ty]

    def make_args_for_compute_device(split_arg_spec: List[Tuple[int, str]]) -> Args:
        return Args(
            split_kernel_args=[
                make_kernel_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_kernel_arg_constructors=[
                make_kernel_arg_constructor(ty, name) for (ty, name) in split_arg_spec
            ],
            split_cpu_kernel_args=[
                make_cpu_kernel_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_cpu_kernel_arg_constructors=[
                make_cpu_kernel_arg_constructor(ty, name)
                for (ty, name) in split_arg_spec
            ],
            split_function_args=[
                make_function_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_tensors=[name for (ty, name) in arg_spec if ty == TENSOR],
            split_saved_tensors=[
                name
                for (ty, name) in split_arg_spec
                if ty in (TENSOR, INT_TENSOR, LONG_TENSOR)
            ],
            saved_data=[
                (name, make_ivalue_cast(ty)) for (ty, name) in arg_spec if ty != TENSOR
            ],
            split_function_arg_names=[name for (ty, name) in split_arg_spec],
            split_function_schemas=[
                make_function_schema_arg(ty, name) for (ty, name) in split_arg_spec
            ],
            split_variables=["Variable()" for _ in split_arg_spec],
        )

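    # CPU: every TENSOR argument is split into {name}_host, {name}_placements,
    # and {name}_offsets (e.g. "momentum1" becomes momentum1_host,
    # momentum1_placements, momentum1_offsets).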
    split_arg_spec = []
    for (ty, arg) in arg_spec:
        if ty in (FLOAT, INT):
            split_arg_spec.append((ty, arg))
        else:
            assert ty == TENSOR
            split_arg_spec.extend(
                [
                    (TENSOR, f"{arg}_host"),
                    (INT_TENSOR, f"{arg}_placements"),
                    (LONG_TENSOR, f"{arg}_offsets"),
                ]
            )
    cpu = make_args_for_compute_device(split_arg_spec)

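    # CUDA: every TENSOR argument is split into {name}_dev, {name}_uvm,
    # {name}_placements, and {name}_offsets.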
    split_arg_spec = []
    for (ty, arg) in arg_spec:
        if ty in (FLOAT, INT):
            split_arg_spec.append((ty, arg))
        else:
            assert ty == TENSOR
            split_arg_spec.extend(
                [
                    (TENSOR, f"{arg}_dev"),
                    (TENSOR, f"{arg}_uvm"),
                    (INT_TENSOR, f"{arg}_placements"),
                    (LONG_TENSOR, f"{arg}_offsets"),
                ]
            )
    cuda = make_args_for_compute_device(split_arg_spec)

    return {"cpu": cpu, "cuda": cuda}


def adagrad() -> None:
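    # Exact Adagrad, applied element-wise:
    #   m_t += g^2;  w -= learning_rate * g / (sqrt(m_t) + eps).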
    split_weight_update = """
      Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
      m_t.acc.x += grad.acc.x * grad.acc.x;
      m_t.acc.y += grad.acc.y * grad.acc.y;
      m_t.acc.z += grad.acc.z * grad.acc.z;
      m_t.acc.w += grad.acc.w * grad.acc.w;
      m_t.store(&momentum1[idx * D + d]);

      weight_new.acc.x -= learning_rate * grad.acc.x / (sqrtf(m_t.acc.x) + eps);
      weight_new.acc.y -= learning_rate * grad.acc.y / (sqrtf(m_t.acc.y) + eps);
      weight_new.acc.z -= learning_rate * grad.acc.z / (sqrtf(m_t.acc.z) + eps);
      weight_new.acc.w -= learning_rate * grad.acc.w / (sqrtf(m_t.acc.w) + eps);
    """
    split_weight_update_cpu = """
      for (int64_t d = 0; d < D; ++d) {
        momentum1_host[embedding_begin + d] +=
            grad_buffer[d] * grad_buffer[d];
        host_weights_data[embedding_begin + d] -=
            learning_rate * grad_buffer[d] /
            (sqrt(momentum1_host[embedding_begin + d]) + eps);
      }
    """

    generate(
        optimizer="adagrad",
        args=make_args(
            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
        ),
        split_precomputation="",
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def table_info_precomputation(momentum_prefix: str = "momentum1") -> str:
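    # Returns a C++ snippet that maps each table's weights offset (table_begin)
    # to its (num_rows, D, {momentum_prefix}_row_begin) triple.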
    template = """
      // table_begin -> (E, D, {momentum_prefix}_row_begin).
      std::map<int64_t, std::tuple<int64_t, int64_t, int64_t>> table_info_map;
      for (int64_t t = 0; t < T; ++t) {
        const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
        const auto table_begin = weights_offsets_data[t];
        const auto {momentum_prefix}_row_begin = {momentum_prefix}_offsets_data[t];
        table_info_map[table_begin] = std::make_tuple(0, D, {momentum_prefix}_row_begin);
      }
      int64_t previous_table_begin = host_weights.numel();
      // NOTE: table_info_map is sorted by table_begin!
      for (auto it = table_info_map.rbegin(); it != table_info_map.rend(); ++it) {
        const auto D = std::get<1>(it->second);
        // Calculates number of rows of each table.
        std::get<0>(it->second) = (previous_table_begin - it->first) / D;
        previous_table_begin = it->first;
      }
    """
    return template.replace("{momentum_prefix}", momentum_prefix)


def rowwise_adagrad() -> None:
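    # Rowwise Adagrad: momentum1 holds one accumulator per row; the mean squared
    # gradient over D updates it, and a single multiplier scales the row update.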
    split_weight_update = """
      weight_new.fma_(grad, -multiplier);
    """
    split_precomputation = """
    at::acc_type<cache_t, true> g_local_sum_square = 0.0;
    #pragma unroll kMaxVecsPerThread
    for (int32_t i = 0;
        i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
        ++i) {
    g_local_sum_square += grad_sum[i].acc.x * grad_sum[i].acc.x +
        grad_sum[i].acc.y * grad_sum[i].acc.y +
        grad_sum[i].acc.z * grad_sum[i].acc.z +
        grad_sum[i].acc.w * grad_sum[i].acc.w;
    }
    const at::acc_type<cache_t, true> g_avg_square =
        warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

    at::acc_type<cache_t, true> multiplier;
    if (threadIdx.x == 0) {
        at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
        momentum1[idx] = new_sum_square_grads;
        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
    }
    multiplier = shfl_sync(multiplier, 0);
    """
    split_weight_update_cpu = """
        at::acc_type<grad_t, true> g_local_sum_square = 0.0;
        for (int64_t d = 0; d < D; ++d) {
            g_local_sum_square += grad_buffer[d] * grad_buffer[d];
        }
        auto g_avg_square = g_local_sum_square / D;
        at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
        momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
        at::acc_type<grad_t, true> multiplier;
        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
        for (int64_t d = 0; d < D; ++d) {
            host_weights_data[embedding_begin + d] -= grad_buffer[d] * multiplier;
        }
    """

    generate(
        optimizer="rowwise_adagrad",
        args=make_args(
            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )

    approx_split_weight_update = """
      // dummy computation to avoid unused variable warning
      weight_new.fma_(grad, -multiplier);
      assert(false); // approx rowwise AdaGrad is not supported on GPU
    """

    generate(
        optimizer="approx_rowwise_adagrad",
        args=make_args(
            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=approx_split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def rowwise_adagrad_with_weight_decay() -> None:
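    # Rowwise Adagrad with weight decay. weight_decay_mode == 0 folds L2
    # regularization into the squared-gradient accumulator and the `correction`
    # factor; weight_decay_mode == 1 applies decoupled weight decay via
    # `correction` only.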
    split_weight_update = """
        weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x;
        weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y;
        weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z;
        weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w;
    """
    split_precomputation = """
    at::acc_type<cache_t, true> g_local_sum_square = 0.0;
    #pragma unroll kMaxVecsPerThread
    for (int32_t i = 0;
        i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
        ++i) {
        auto gx = grad_sum[i].acc.x;
        auto gy = grad_sum[i].acc.y;
        auto gz = grad_sum[i].acc.z;
        auto gw = grad_sum[i].acc.w;
        if (weight_decay_mode == 0) {
            // L2 regularization
            int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
            Vec4T<at::acc_type<cache_t, true>> weight = weight_row_template.load(d, qparams_template);
            gx += weight_decay * weight.acc.x;
            gy += weight_decay * weight.acc.y;
            gz += weight_decay * weight.acc.z;
            gw += weight_decay * weight.acc.w;
        }
        g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
    }
    const at::acc_type<cache_t, true> g_avg_square =
        warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

    at::acc_type<cache_t, true> multiplier;
    at::acc_type<cache_t, true> correction;
    if (threadIdx.x == 0) {
        at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
        momentum1[idx] = new_sum_square_grads;
        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
        if (weight_decay_mode == 0) {
            // L2 regularization
            correction = 1.0 - multiplier * weight_decay;
        } else if (weight_decay_mode == 1) {
            // Decoupled weight decay
            correction = 1.0 - learning_rate * weight_decay;
        }
    }
    multiplier = shfl_sync(multiplier, 0);
    correction = shfl_sync(correction, 0);
    """
    split_weight_update_cpu = """
        at::acc_type<grad_t, true> g_local_sum_square = 0.0;
        for (int64_t d = 0; d < D; ++d) {
            auto grad = grad_buffer[d];
            if (weight_decay_mode == 0) {
                // L2 regularization
                grad += weight_decay * host_weights_data[embedding_begin + d];
            }
            g_local_sum_square += grad * grad;
        }
        auto g_avg_square = g_local_sum_square / D;
        at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
        momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
        at::acc_type<grad_t, true> multiplier;
        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
        at::acc_type<grad_t, true> correction;
        if (weight_decay_mode == 0) {
            // L2 regularization
            correction = 1.0 - multiplier * weight_decay;
        } else if (weight_decay_mode == 1) {
            // Decoupled weight decay
            correction = 1.0 - learning_rate * weight_decay;
        }
        for (int64_t d = 0; d < D; ++d) {
            host_weights_data[embedding_begin + d] = correction * host_weights_data[embedding_begin + d] - grad_buffer[d] * multiplier;
        }
    """

    generate(
        optimizer="rowwise_adagrad_with_weight_decay",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (FLOAT, "eps"),
                (FLOAT, "learning_rate"),
                (FLOAT, "weight_decay"),
                (INT, "weight_decay_mode"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )

    approx_split_weight_update = """
      // dummy computation to avoid unused variable warning
      weight_new.fma_(grad, -multiplier);
      assert(false); // approx rowwise AdaGrad is not supported on GPU
    """

    generate(
        optimizer="approx_rowwise_adagrad_with_weight_decay",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (FLOAT, "eps"),
                (FLOAT, "learning_rate"),
                (FLOAT, "weight_decay"),
                (INT, "weight_decay_mode"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=approx_split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def rowwise_weighted_adagrad() -> None:
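    # Rowwise weighted Adagrad: the per-row accumulator is incremented by
    # lambda * g_avg_square with lambda = sqrt(iter + 1), and the multiplier uses
    # the cube root (cbrtf) of the accumulator; the CPU path skips weight decay.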
    split_weight_update = """
      weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x;
      weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y;
      weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z;
      weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w;
    """
    split_precomputation = """
    at::acc_type<cache_t, true> g_local_sum_square = 0.0;
    #pragma unroll kMaxVecsPerThread
    for (int32_t i = 0;
        i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
        ++i) {
        int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
        Vec4T<at::acc_type<cache_t, true>> weight = weight_row_template.load(d, qparams_template);
        auto gx = grad_sum[i].acc.x + weight_decay * weight.acc.x;
        auto gy = grad_sum[i].acc.y + weight_decay * weight.acc.y;
        auto gz = grad_sum[i].acc.z + weight_decay * weight.acc.z;
        auto gw = grad_sum[i].acc.w + weight_decay * weight.acc.w;
        g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
    }
    const at::acc_type<cache_t, true> g_avg_square =
        warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

    at::acc_type<cache_t, true> multiplier;
    at::acc_type<cache_t, true> correction;
    if (threadIdx.x == 0) {
        at::acc_type<cache_t, true> lambda = sqrtf(iter + 1);
        at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + lambda * g_avg_square;
        momentum1[idx] = new_sum_square_grads;
        multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps);
        correction = 1.0 - multiplier * weight_decay;
    }
    multiplier = shfl_sync(multiplier, 0);
    correction = shfl_sync(correction, 0);
    """
    split_weight_update_cpu = """
        // weight_decay not supported for cpu version
        at::acc_type<scalar_t, true> g_local_sum_square = 0.0;
        for (int64_t d = 0; d < D; ++d) {
            g_local_sum_square += grad_buffer[d] * grad_buffer[d];
        }
        auto g_avg_square = g_local_sum_square / D;
        at::acc_type<scalar_t, true> lambda = sqrtf(iter + 1);
        at::acc_type<scalar_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + lambda * g_avg_square;
        momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
        at::acc_type<scalar_t, true> multiplier;
        multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps);
        for (int64_t d = 0; d < D; ++d) {
            host_weights_data[embedding_begin + d] -= grad_buffer[d] * multiplier;
        }
    """

    generate(
        optimizer="rowwise_weighted_adagrad",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (FLOAT, "eps"),
                (FLOAT, "learning_rate"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def sgd() -> None:
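    # Plain SGD: w -= learning_rate * g; no optimizer state.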
    split_weight_update = """
      weight_new.fma_(grad, -learning_rate);
    """
    split_weight_update_cpu = """
      for (int64_t d = 0; d < D; ++d) {
        host_weights_data[embedding_begin + d] -= learning_rate * grad_buffer[d];
      }
    """

    generate(
        optimizer="sgd",
        args=make_args([(FLOAT, "learning_rate")]),
        split_precomputation="",
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )

    approx_split_weight_update = """
      // approx_sgd not supported for GPU.
      // Just do the same thing as exact sgd to avoid unused variable warning.
      weight_new.fma_(grad, -learning_rate);
      assert(false); // approx SGD is not supported on GPU
    """

    generate(
        optimizer="approx_sgd",
        args=make_args([(FLOAT, "learning_rate")]),
        split_precomputation="",
        split_weight_update=approx_split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def lamb() -> None:
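    # LAMB: bias-corrected Adam-style moments form r_t, weight decay is added to
    # r_t, and the update is scaled by the trust ratio ||w|| / ||r_t + wd * w||.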
    split_precomputation = """
  at::acc_type<cache_t, true> weight_sum_sq = 0.0;
  at::acc_type<cache_t, true> rtw_sum_sq = 0.0;
  auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
  float2 qparams;
  if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
    qparams = weight_row.load_qparams();
  }
#pragma unroll 1
  for (int32_t i = 0;
      i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
      ++i) {
    int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
    Vec4T<at::acc_type<cache_t, true>> weight = weight_row.load(d, qparams);
    Vec4T<at::acc_type<cache_t, true>> m1(&momentum1[idx * D + d]);

    m1.acc.x = beta1 * m1.acc.x + (1.0 - beta1) * grad_sum[i].acc.x;
    m1.acc.y = beta1 * m1.acc.y + (1.0 - beta1) * grad_sum[i].acc.y;
    m1.acc.z = beta1 * m1.acc.z + (1.0 - beta1) * grad_sum[i].acc.z;
    m1.acc.w = beta1 * m1.acc.w + (1.0 - beta1) * grad_sum[i].acc.w;
    m1.store(&momentum1[idx * D + d]);

    Vec4T<at::acc_type<cache_t, true>> m2(&momentum2[idx * D + d]);
    m2.acc.x = beta2 * m2.acc.x + (1.0 - beta2) * grad_sum[i].acc.x * grad_sum[i].acc.x;
    m2.acc.y = beta2 * m2.acc.y + (1.0 - beta2) * grad_sum[i].acc.y * grad_sum[i].acc.y;
    m2.acc.z = beta2 * m2.acc.z + (1.0 - beta2) * grad_sum[i].acc.z * grad_sum[i].acc.z;
    m2.acc.w = beta2 * m2.acc.w + (1.0 - beta2) * grad_sum[i].acc.w * grad_sum[i].acc.w;
    m2.store(&momentum2[idx * D + d]);

    // now, we are finished with grad_sum. We can *reuse* grad_sum to store r_t + weight_decay * weight;
    grad_sum[i].acc.x = (m1.acc.x / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.x;
    grad_sum[i].acc.y = (m1.acc.y / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.y;
    grad_sum[i].acc.z = (m1.acc.z / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.z;
    grad_sum[i].acc.w = (m1.acc.w / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.w;

    weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w;
    rtw_sum_sq += grad_sum[i].acc.x * grad_sum[i].acc.x + grad_sum[i].acc.y * grad_sum[i].acc.y + grad_sum[i].acc.z * grad_sum[i].acc.z + grad_sum[i].acc.w * grad_sum[i].acc.w;
  }
  const auto weight_norm =
      sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(weight_sum_sq));
  const auto rtw_norm =
      sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(rtw_sum_sq));
  const auto true_ratio = weight_norm / rtw_norm;
"""
    split_weight_update = """
      weight_new.fma_(grad, -learning_rate * true_ratio);
    """
    split_weight_update_cpu = ""

    generate(
        optimizer="lamb",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (TENSOR, "momentum2"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eps"),
                (FLOAT, "beta1"),
                (FLOAT, "beta2"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def partial_rowwise_lamb() -> None:
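    # Partial rowwise LAMB: same trust-ratio update as LAMB, but the second
    # moment (momentum2) is a single per-row value built from the mean squared
    # gradient.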
    split_precomputation = """
    at::acc_type<cache_t, true> g_local_sum_square = 0.0;

    #pragma unroll kMaxVecsPerThread
    for (int32_t i = 0;
        i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
        ++i) {
    g_local_sum_square += grad_sum[i].acc.x * grad_sum[i].acc.x +
        grad_sum[i].acc.y * grad_sum[i].acc.y +
        grad_sum[i].acc.z * grad_sum[i].acc.z +
        grad_sum[i].acc.w * grad_sum[i].acc.w;
    }
    const at::acc_type<cache_t, true> g_avg_square =
        warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

    at::acc_type<cache_t, true> m2;
    if (threadIdx.x == 0) {
        m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square;
        momentum2[idx] = m2;
    }
    m2 = shfl_sync(m2, 0);
    at::acc_type<cache_t, true> m2_hat = 1.0 / (sqrtf((m2 / (1.0 - powf(beta2, iter)))) + eps);

    at::acc_type<cache_t, true> weight_sum_sq = 0.0;
    at::acc_type<cache_t, true> rtw_sum_sq = 0.0;
    auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
    float2 qparams;
    if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
        qparams = weight_row.load_qparams();
    }
    #pragma unroll kMaxVecsPerThread
    for (int32_t i = 0;
        i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
        ++i) {
        int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;

        Vec4T<at::acc_type<cache_t, true>> m1(&momentum1[idx * D + d]);
        m1.acc.x = beta1 * m1.acc.x + (1.0 - beta1) * grad_sum[i].acc.x;
        m1.acc.y = beta1 * m1.acc.y + (1.0 - beta1) * grad_sum[i].acc.y;
        m1.acc.z = beta1 * m1.acc.z + (1.0 - beta1) * grad_sum[i].acc.z;
        m1.acc.w = beta1 * m1.acc.w + (1.0 - beta1) * grad_sum[i].acc.w;
        m1.store(&momentum1[idx * D + d]);

        // now, we are finished with grad_sum. We can *reuse* grad_sum to store r_t + weight_decay * weight;
        Vec4T<at::acc_type<cache_t, true>> weight = weight_row.load(d, qparams);
        grad_sum[i].acc.x = (m1.acc.x / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.x;
        grad_sum[i].acc.y = (m1.acc.y / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.y;
        grad_sum[i].acc.z = (m1.acc.z / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.z;
        grad_sum[i].acc.w = (m1.acc.w / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.w;

        weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w;
        rtw_sum_sq += grad_sum[i].acc.x * grad_sum[i].acc.x + grad_sum[i].acc.y * grad_sum[i].acc.y + grad_sum[i].acc.z * grad_sum[i].acc.z + grad_sum[i].acc.w * grad_sum[i].acc.w;
    }
    const auto weight_norm =
        sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(weight_sum_sq));
    const auto rtw_norm =
        sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(rtw_sum_sq));
    const auto true_ratio = weight_norm / rtw_norm;
    """

    split_weight_update = """
      weight_new.fma_(grad, -learning_rate * true_ratio);
    """
    split_weight_update_cpu = ""  # TODO

    generate(
        optimizer="partial_rowwise_lamb",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (TENSOR, "momentum2"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eps"),
                (FLOAT, "beta1"),
                (FLOAT, "beta2"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def adam() -> None:
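    # Adam with element-wise moments; weight decay is applied to the weights
    # inside the update term (decoupled from the moments).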
    split_weight_update = """
      Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
      m_t.acc.x *= beta1;
      m_t.acc.y *= beta1;
      m_t.acc.z *= beta1;
      m_t.acc.w *= beta1;
      m_t.fma_(grad, 1.0 - beta1);
      m_t.store(&momentum1[idx * D + d]);

      Vec4T<cache_t> v_t(&momentum2[idx * D + d]);
      v_t.acc.x *= beta2;
      v_t.acc.y *= beta2;
      v_t.acc.z *= beta2;
      v_t.acc.w *= beta2;

      grad.acc.x *= grad.acc.x;
      grad.acc.y *= grad.acc.y;
      grad.acc.z *= grad.acc.z;
      grad.acc.w *= grad.acc.w;
      v_t.fma_(grad, 1.0 - beta2);
      v_t.store(&momentum2[idx * D + d]);

      weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.x);
      weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.y);
      weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.z);
      weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.w);
    """
    split_weight_update_cpu = ""  # TODO

    generate(
        optimizer="adam",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (TENSOR, "momentum2"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eps"),
                (FLOAT, "beta1"),
                (FLOAT, "beta2"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation="",
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def partial_rowwise_adam() -> None:
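    # Partial rowwise Adam: momentum2 holds one value per row (mean squared
    # gradient); momentum1 stays element-wise.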
    split_precomputation = """
    at::acc_type<cache_t, true> g_local_sum_square = 0.0;
    #pragma unroll kMaxVecsPerThread
    for (int32_t i = 0;
        i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
        ++i) {
    g_local_sum_square += grad_sum[i].acc.x * grad_sum[i].acc.x +
        grad_sum[i].acc.y * grad_sum[i].acc.y +
        grad_sum[i].acc.z * grad_sum[i].acc.z +
        grad_sum[i].acc.w * grad_sum[i].acc.w;
    }
    const at::acc_type<cache_t, true> g_avg_square =
        warpReduceAllSum<at::acc_type<cache_t, true>>(g_local_sum_square) / D;

    at::acc_type<cache_t, true> v_hat_t;
    if (threadIdx.x == 0) {
        at::acc_type<cache_t, true> v_t = momentum2[idx] * beta2 + g_avg_square * (1.0 - beta2);
        momentum2[idx] = v_t;
        v_hat_t = v_t / (1.0 - powf(beta2, iter));
    }
    v_hat_t = shfl_sync(v_hat_t, 0);
    """

    split_weight_update = """
      Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
      m_t.acc.x *= beta1;
      m_t.acc.y *= beta1;
      m_t.acc.z *= beta1;
      m_t.acc.w *= beta1;
      m_t.fma_(grad, 1.0 - beta1);
      m_t.store(&momentum1[idx * D + d]);

      weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
      weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
      weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.z);
      weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.w);
    """
    split_weight_update_cpu = ""  # TODO

    generate(
        optimizer="partial_rowwise_adam",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (TENSOR, "momentum2"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eps"),
                (FLOAT, "beta1"),
                (FLOAT, "beta2"),
                (FLOAT, "weight_decay"),
                (INT, "iter"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def lars_sgd() -> None:
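    # LARS-SGD: layer-wise adaptive rate
    #   adjusted_lr = learning_rate * eta * ||w|| / (||g|| + weight_decay * ||w||),
    # applied with classic momentum (momentum1).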
    split_precomputation = """
  at::acc_type<cache_t, true> weight_sum_sq = 0.0;
  at::acc_type<cache_t, true> grad_sum_sq = 0.0;

  auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
  float2 qparams;
  if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
      qparams = weight_row.load_qparams();
  }
#pragma unroll kMaxVecsPerThread
  for (int32_t i = 0;
      i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
      ++i) {
    int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
    Vec4T<at::acc_type<cache_t,true>> weight = weight_row.load(d, qparams);
    weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w;
    grad_sum_sq += grad_sum[i].acc.x * grad_sum[i].acc.x + grad_sum[i].acc.y * grad_sum[i].acc.y + grad_sum[i].acc.z * grad_sum[i].acc.z + grad_sum[i].acc.w * grad_sum[i].acc.w;
  }
  const auto weight_norm =
      sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(weight_sum_sq));
  const auto grad_norm =
      sqrtf(warpReduceAllSum<at::acc_type<cache_t, true>>(grad_sum_sq));
  const at::acc_type<cache_t, true> adjusted_lr = learning_rate * eta * weight_norm / (grad_norm + weight_decay * weight_norm);
"""

    split_weight_update = """
      Vec4T<cache_t> m1(&momentum1[idx * D + d]);
      m1.acc.x = momentum * m1.acc.x + adjusted_lr * (grad.acc.x + weight_decay * weight_new.acc.x);
      m1.acc.y = momentum * m1.acc.y + adjusted_lr * (grad.acc.y + weight_decay * weight_new.acc.y);
      m1.acc.z = momentum * m1.acc.z + adjusted_lr * (grad.acc.z + weight_decay * weight_new.acc.z);
      m1.acc.w = momentum * m1.acc.w + adjusted_lr * (grad.acc.w + weight_decay * weight_new.acc.w);
      m1.store(&momentum1[idx * D + d]);

      weight_new.acc.x -= m1.acc.x;
      weight_new.acc.y -= m1.acc.y;
      weight_new.acc.z -= m1.acc.z;
      weight_new.acc.w -= m1.acc.w;
    """
    split_weight_update_cpu = ""  # TODO

    generate(
        optimizer="lars_sgd",
        args=make_args(
            [
                (TENSOR, "momentum1"),
                (FLOAT, "learning_rate"),
                (FLOAT, "eta"),
                (FLOAT, "momentum"),
                (FLOAT, "weight_decay"),
            ]
        ),
        split_precomputation=split_precomputation,
        split_weight_update=split_weight_update,
        split_weight_update_cpu=split_weight_update_cpu,
    )


def forward_split() -> None:
    template = env.get_template("embedding_forward_split_template.cu")

    src_cu = template.render(weighted=False)
    write("gen_embedding_forward_split_unweighted_codegen_cuda.cu", src_cu)
    src_cu = template.render(weighted=True)
    write("gen_embedding_forward_split_weighted_codegen_cuda.cu", src_cu)

    src_cu = template.render(weighted=False, dense=True)
    write("gen_embedding_forward_dense_unweighted_codegen_cuda.cu", src_cu)
    src_cu = template.render(weighted=True, dense=True)
    write("gen_embedding_forward_dense_weighted_codegen_cuda.cu", src_cu)


def forward_quantized() -> None:
    @dataclass
    class elem_type:
        enum_name: str
        cpp_type_name: str

    type_map = {
        32: elem_type("FP32", "float"),
        16: elem_type("FP16", "__half2"),
        8: elem_type("INT8", "uint32_t"),
        4: elem_type("INT4", "uint32_t"),
        2: elem_type("INT2", "uint32_t"),
    }

    template = env.get_template("embedding_forward_quantized_split_template.cu")
    src_cu = template.render(weighted=False, type_map=type_map)
    write("gen_embedding_forward_quantized_split_unweighted_codegen_cuda.cu", src_cu)
    src_cu = template.render(weighted=True, type_map=type_map)
    write("gen_embedding_forward_quantized_split_weighted_codegen_cuda.cu", src_cu)

    template = env.get_template("embedding_forward_quantized_cpu_template.cpp")
    src_cu = template.render(weighted=False, type_map=type_map)
    write("gen_embedding_forward_quantized_unweighted_codegen_cpu.cpp", src_cu)
    src_cu = template.render(weighted=True, type_map=type_map)
    write("gen_embedding_forward_quantized_weighted_codegen_cpu.cpp", src_cu)


def backward_indices() -> None:
    template = env.get_template("embedding_backward_split_indice_weights_template.cu")
    src_cu = template.render()
    write("gen_embedding_backward_split_indice_weights_codegen_cuda.cu", src_cu)
    src_cu = template.render(dense=True)
    write("gen_embedding_backward_dense_indice_weights_codegen_cuda.cu", src_cu)


def backward_dense() -> None:
    generate(
        optimizer="dense",
        dense=True,
        args=make_args(
            [
                (FLOAT, "unused"),
            ]
        ),
    )


def gen__init__py() -> None:
    template = env.get_template("__init__.template")
    src_py = template.render()
    write("__init__.py", src_py)


def emb_codegen(
    install_dir: Optional[str] = None, is_fbcode: Optional[bool] = None
) -> None:
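    # Programmatic entry point: optionally overrides the output directory and
    # fbcode flag, then regenerates all embedding codegen sources.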
    if install_dir is not None and len(install_dir) != 0:
        args.install_dir = install_dir
    if is_fbcode is not None:
        args.is_fbcode = is_fbcode
    adagrad()
    adam()
    backward_indices()
    backward_dense()
    forward_quantized()
    forward_split()
    lamb()
    lars_sgd()
    partial_rowwise_adam()
    partial_rowwise_lamb()
    rowwise_adagrad()
    rowwise_adagrad_with_weight_decay()
    rowwise_weighted_adagrad()
    sgd()

    gen__init__py()


def main() -> None:
    emb_codegen()


if __name__ == "__main__":
    main()
