DistilVitFlow.py (29 lines of code) (raw):
from metaflow import step, kubernetes, environment, Parameter
from GenAIFlow import GenAIFlow
from distilvit.train import train, parse_args, environ_dict, get_arg_parser
class DistilVitFlow(GenAIFlow):
"""
DistilVit Trainer
"""
GenAIFlow.import_argparse_to_params(get_arg_parser("./"))
@kubernetes(
image="us-docker.pkg.dev/moz-fx-mozsoc-ml-nonprod/metaflow-dockers/metaflow_gpu:rolf-distilvit-build-test",
gpu=1,
disk=100000
)
@environment(
vars=environ_dict
)
@step
def start(self):
self.load_remote_env()
args = self.params_to_args()
print(f"Parsed args are as follows:{args}")
train(parse_args(args))
self.next(self.end)
@kubernetes(
image="us-docker.pkg.dev/moz-fx-mozsoc-ml-nonprod/metaflow-dockers/metaflow_gpu:rolf-distilvit-build-test",
gpu=1
)
@step
def end(self):
pass
if __name__ == '__main__':
DistilVitFlow()