torchbenchmark/models/demucs/__init__.py (73 lines of code) (raw):

import json
import torch
import random
import numpy as np
from fractions import Fraction
from .demucs.model import Demucs
from .demucs.parser import get_name, get_parser
from .demucs.augment import FlipChannels, FlipSign, Remix, Shift
from .demucs.utils import capture_init, center_trim
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import OTHER
from torch import Tensor
from torch.nn.modules.container import Sequential
from torchbenchmark.models.demucs.demucs.model import Demucs
from typing import Optional, Tuple

# Module-level cuDNN configuration: let cuDNN autotune its algorithms and do
# not force deterministic ones. These are global flags and take effect as soon
# as this module is imported.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True


class DemucsWrapper(torch.nn.Module):
    """Bundle the augmentation pipeline and the Demucs model into one module.

    Wrapping both steps lets a single ``forward`` call go from the raw
    ``streams`` tensor to the ``(sources, estimates)`` pair consumed by the
    loss.
    """

    def __init__(self, model: Demucs, augment: Sequential) -> None:
        """Store the separation model and the augmentation module.

        Args:
            model: The Demucs network applied to the mixed signal.
            augment: Module applied to the source streams before mixing.
                NOTE(review): ``Model.__init__`` passes a bare ``Shift`` when
                ``args.augment`` is falsy, so this is not always a
                ``Sequential``.
        """
        super(DemucsWrapper, self).__init__()
        self.model = model
        self.augment = augment

    def forward(self, streams) -> Tuple[Tensor, Tensor]:
        """Augment the sources, mix them, and run the model on the mix.

        Args:
            streams: Tensor indexed as ``[batch, stream, ...]``. Index 0 along
                dim 1 is dropped; presumably it holds the pre-computed mix
                (TODO confirm against the dataset layout).

        Returns:
            ``(sources, estimates)``: the augmented sources and the model's
            output for their sum.
        """
        # Keep every stream except the first one along dim 1.
        sources = streams[:, 1:]
        sources = self.augment(sources)
        # Rebuild the mixture from the augmented sources so it stays
        # consistent with them.
        mix = sources.sum(dim=1)
        return sources, self.model(mix)


class Model(BenchmarkModel):
    """torchbenchmark entry point for the Demucs model."""

    task = OTHER.OTHER_TASKS
    # Original train batch size: 64
    # Source: https://github.com/facebookresearch/demucs/blob/3e5ea549ba921316c587e5f03c0afc0be47a0ced/conf/config.yaml#L37
    DEFAULT_TRAIN_BSIZE = 64
    DEFAULT_EVAL_BSIZE = 8

    def __init__(
        self,
        test,
        device,
        jit=False,
        batch_size=None,
        extra_args: Optional[list] = None,
    ) -> None:
        """Build the model, loss, augmentation and example inputs.

        Args:
            test: Benchmark mode; ``"train"`` also creates an Adam optimizer,
                ``"eval"`` switches the module to eval mode.
            device: Device the model and the example inputs are placed on.
            jit: Forwarded unchanged to ``BenchmarkModel.__init__``.
            batch_size: Forwarded unchanged to ``BenchmarkModel.__init__``.
            extra_args: Extra arguments forwarded to
                ``BenchmarkModel.__init__``. ``None`` (the default) is
                replaced by a fresh empty list.
        """
        # A literal ``[]`` default would be one list object shared by every
        # call; create a new list per call instead.
        if extra_args is None:
            extra_args = []
        super().__init__(test=test, device=device, jit=jit,
                         batch_size=batch_size, extra_args=extra_args)

        # Use the parser's defaults only: an empty argv is parsed on purpose.
        self.parser = get_parser()
        self.args = self.parser.parse_args([])
        args = self.args

        model = Demucs(channels=64)
        model.to(device)

        samples = 80000
        # Exact rational values derived from the parser defaults.
        # NOTE(review): presumably durations in seconds — confirm.
        self.duration = Fraction(samples + args.data_stride, args.samplerate)
        self.stride = Fraction(args.data_stride, args.samplerate)

        if args.mse:
            self.criterion = torch.nn.MSELoss()
        else:
            self.criterion = torch.nn.L1Loss()

        if args.augment:
            self.augment = torch.nn.Sequential(
                FlipSign(),
                FlipChannels(),
                Shift(args.data_stride),
                Remix(group_size=args.remix_group_size),
            ).to(device)
        else:
            self.augment = Shift(args.data_stride)

        self.model = DemucsWrapper(model, self.augment)

        if test == "train":
            self.model.train()
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)
        elif test == "eval":
            self.model.eval()

        # Random example input, created for every mode because get_module(),
        # eval() and train() all read it. ``self.batch_size`` is not assigned
        # in this class; presumably the base-class constructor sets it.
        self.example_inputs = (
            torch.rand([self.batch_size, 5, 2, 426888], device=device),
        )

    def get_module(self) -> Tuple[DemucsWrapper, Tuple[Tensor]]:
        """Return the wrapped module and the example inputs to call it with."""
        return self.model, self.example_inputs

    def eval(self, niter=1) -> Tuple[Tensor, Tensor]:
        """Run ``niter`` forward passes and return the last result.

        Args:
            niter: Number of iterations; must be at least 1, otherwise the
                returned names are never bound.

        Returns:
            ``(sources, estimates)`` from the final iteration, with
            ``sources`` trimmed by ``center_trim`` against ``estimates``.
        """
        for _ in range(niter):
            sources, estimates = self.model(*self.example_inputs)
            sources = center_trim(sources, estimates)
            # The loss is not returned; it is still evaluated so the
            # benchmarked work keeps the criterion computation.
            loss = self.criterion(estimates, sources)
        return (sources, estimates)

    def train(self, niter=1) -> None:
        """Run ``niter`` optimisation steps on the example inputs.

        Requires the instance to have been built with ``test == "train"``,
        since only that mode creates ``self.optimizer``.
        """
        for _ in range(niter):
            sources, estimates = self.model(*self.example_inputs)
            sources = center_trim(sources, estimates)
            loss = self.criterion(estimates, sources)
            loss.backward()
            self.optimizer.step()
            # Clear gradients after the step so the next iteration starts
            # from zero.
            self.optimizer.zero_grad()