def _train()

in training/built-in-frameworks/fastai_oxford_pets/source/pets.py [0:0]


# Extracts the class label from an image path: the last path component with
# its trailing "_<digits>.jpg" suffix removed, e.g.
# ".../great_pyrenees_173.jpg" -> "great_pyrenees". The dot is escaped so only
# a literal ".jpg" extension matches (an unescaped "." would match any char).
PET_LABEL_PATTERN = re.compile(r"/([^/]+)_\d+\.jpg$")


def _train(args):
    """Fine-tune a pretrained CNN on the Pets images and save the artifacts.

    Builds an ``ImageDataBunch`` from ``<args.data_dir>/images``, labelling
    every file with ``PET_LABEL_PATTERN``. It then runs ``fit_one_cycle(4)``,
    unfreezes the model and runs ``fit_one_cycle(2)`` with
    ``max_lr=slice(1e-6, 1e-4)``. Finally it exports the data object and saves
    the weights under ``args.model_dir``.

    Args:
        args: Parsed arguments. Reads ``data_dir``, ``batch_size``,
            ``image_size``, ``model_arch`` (the name of an attribute of
            ``models``) and ``model_dir``.

    Returns:
        Whatever ``learn.save`` returns for the saved weights.

    Raises:
        AttributeError: If ``args.model_arch`` is not an attribute of ``models``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Device Type: {}".format(device))

    logger.info("Loading Pets dataset")
    print(f"Batch size: {args.batch_size}")
    path = Path(args.data_dir)
    print(f"Data path is: {path}")
    path_img = path / "images"
    fnames = get_image_files(path_img)

    # Seed numpy's global RNG before building the DataBunch. Presumably fastai
    # draws its random train/validation split from it, so this keeps the split
    # reproducible across runs (TODO confirm for the fastai version in use).
    np.random.seed(2)
    print("Creating DataBunch object")
    data = ImageDataBunch.from_name_re(
        path_img,
        fnames,
        PET_LABEL_PATTERN,
        ds_tfms=get_transforms(),
        size=args.image_size,
        bs=args.batch_size,
    ).normalize(imagenet_stats)

    # create the CNN model
    print("Create CNN model from model zoo")
    print(f"Model architecture is {args.model_arch}")
    arch = getattr(models, args.model_arch)
    print("Creating pretrained conv net")
    learn = create_cnn(data, arch, metrics=error_rate)
    print("Fit for 4 cycles")
    learn.fit_one_cycle(4)
    learn.unfreeze()
    print("Unfreeze and fit for another 2 cycles")
    # max_lr given as a slice: fastai's form for spreading learning rates
    # across layer groups when fine-tuning the unfrozen network.
    learn.fit_one_cycle(2, max_lr=slice(1e-6, 1e-4))
    print("Finished Training")

    logger.info("Saving the model.")
    model_path = Path(args.model_dir)
    print("Export data object")
    data.export(model_path / "export.pkl")
    # Create an (empty) "models" sub-directory, presumably expected by the
    # loading side (TODO confirm). exist_ok avoids FileExistsError when
    # model_dir is reused across runs.
    (model_path / "models").mkdir(parents=True, exist_ok=True)
    print("Saving model weights")
    return learn.save(model_path / f"{args.model_arch}")