in training/built-in-frameworks/fastai_oxford_pets/source/pets.py [0:0]
def _train(args):
    """Fine-tune a pretrained CNN on the Oxford-IIIT Pets images and save it.

    Attributes read from ``args``:
      * ``data_dir``   -- dataset root; images are listed from ``<data_dir>/images``.
      * ``batch_size`` -- mini-batch size for the DataBunch.
      * ``image_size`` -- target image size passed to the DataBunch.
      * ``model_arch`` -- name of an attribute on fastai ``models`` (the
        architecture to build); also used as the saved-weights name.
      * ``model_dir``  -- output directory for the exported data object and
        the saved weights.

    Training runs one one-cycle schedule of 4 epochs, then unfreezes the whole
    network and runs one more one-cycle schedule of 2 epochs with learning
    rates spread over ``slice(1e-6, 1e-4)``.

    Returns:
        Whatever ``learn.save`` returns. NOTE(review): in fastai v1 this is
        ``None`` unless ``return_path=True`` -- confirm against the pinned version.

    Raises:
        AttributeError: if ``args.model_arch`` is not an attribute of ``models``.
    """
    # The device is only logged; it is not passed to any fastai call below
    # (presumably fastai selects the device itself).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Device Type: {}".format(device))
    logger.info("Loading Pets dataset")
    print(f"Batch size: {args.batch_size}")
    path = Path(args.data_dir)
    print(f"Data path is: {path}")
    path_img = path / "images"
    fnames = get_image_files(path_img)

    # Fixed seed so the random train/validation split is reproducible
    # (presumably the split performed inside ImageDataBunch.from_name_re).
    np.random.seed(2)
    # The label is the file-name prefix before the trailing "_<digits>.jpg",
    # e.g. ".../great_pyrenees_173.jpg" -> "great_pyrenees". The dot is
    # escaped so only a literal ".jpg" suffix matches.
    pat = re.compile(r"/([^/]+)_\d+\.jpg$")
    print("Creating DataBunch object")
    data = ImageDataBunch.from_name_re(
        path_img,
        fnames,
        pat,
        ds_tfms=get_transforms(),
        size=args.image_size,
        bs=args.batch_size,
    ).normalize(imagenet_stats)

    # Create the CNN model from the model zoo.
    print("Create CNN model from model zoo")
    print(f"Model architecture is {args.model_arch}")
    arch = getattr(models, args.model_arch)
    print("Creating pretrained conv net")
    # NOTE(review): create_cnn is the older fastai v1 name (later renamed
    # cnn_learner) -- confirm against the pinned fastai version.
    learn = create_cnn(data, arch, metrics=error_rate)

    # fit_one_cycle(n) runs ONE cycle spanning n epochs.
    print("Fit one cycle of 4 epochs")
    learn.fit_one_cycle(4)
    learn.unfreeze()
    print("Unfreeze and fit one more cycle of 2 epochs")
    # Discriminative learning rates: earlier layer groups get the smaller rate.
    learn.fit_one_cycle(2, max_lr=slice(1e-6, 1e-4))
    print("Finished Training")

    logger.info("Saving the model.")
    model_path = Path(args.model_dir)
    # Ensure the output directory exists before writing anything into it.
    model_path.mkdir(parents=True, exist_ok=True)
    print("Export data object")
    data.export(model_path / "export.pkl")
    # Create the (empty) "models" sub-directory; exist_ok avoids the
    # FileExistsError that a plain mkdir raised when it was already present.
    (model_path / "models").mkdir(exist_ok=True)
    print("Saving model weights")
    return learn.save(model_path / f"{args.model_arch}")