tzrec/ops/mm.py (13 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the mm ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# thanks to their public work.

import torch

from tzrec.ops import Kernel
from tzrec.ops.triton.triton_addmm import triton_addmm


def addmm(
    input: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
    """Dispatch an ``addmm`` call to the backend selected by ``kernel``.

    With the default backend this is exactly ``torch.addmm(input, mat1, mat2)``,
    i.e. ``input + mat1 @ mat2`` with PyTorch's default scaling factors.

    Args:
        input: tensor forwarded as the first argument of the backend call.
        mat1: tensor forwarded as the second argument of the backend call.
        mat2: tensor forwarded as the third argument of the backend call.
        kernel: backend selector. ``Kernel.TRITON`` routes to ``triton_addmm``;
            every other value routes to ``torch.addmm``.

    Returns:
        Whatever the selected backend returns for ``(input, mat1, mat2)``.
    """
    # NOTE(review): triton_addmm is presumably a drop-in replacement for
    # torch.addmm (same shapes/semantics) -- confirm against its definition.
    use_triton = kernel == Kernel.TRITON
    if not use_triton:
        # Any non-Triton kernel value falls back to the stock PyTorch op.
        return torch.addmm(input, mat1, mat2)
    return triton_addmm(input, mat1, mat2)