def augment_dataframe()

in augmentation/augment_bert.py [0:0]


def augment_dataframe(original, augmenter, input_key, batch_size=1):
	"""Return a copy of ``original`` whose ``input_key`` column has been augmented.

	If ``original`` has an ``origin`` column, it is carried over under the name
	``obtained_from``; ``origin`` is then set to ``'augmented'`` on every row.
	``original`` itself is not modified.

	Args:
		original: DataFrame to augment.
		augmenter: object exposing ``augment(...)``. With ``batch_size == 1`` it
			is called once per cell value; otherwise it is called with a list of
			up to ``batch_size`` values and must return one result per value
			(presumably an nlpaug-style augmenter — confirm against callers).
		input_key: name of the column whose values are augmented.
		batch_size: number of values passed to ``augmenter.augment`` per call.

	Returns:
		A new DataFrame with the same rows/index as ``original``.

	Raises:
		ValueError: if ``batch_size`` is smaller than 1, or if the augmenter
			returns a different number of results than there are rows.
	"""
	if batch_size < 1:
		raise ValueError(f'batch_size must be >= 1, got {batch_size}')
	# DataFrame.rename returns a new frame instead of renaming in place; the
	# result must be kept, otherwise 'origin' is never preserved.
	augmented = original.rename(columns={'origin': 'obtained_from'})
	texts = list(original[input_key])
	if batch_size == 1:
		outputs = [augmenter.augment(text) for text in tqdm(texts)]
	else:
		outputs = []
		for start in tqdm(range(0, len(texts), batch_size)):
			outputs.extend(augmenter.augment(texts[start:start + batch_size]))
	if len(outputs) != len(texts):
		raise ValueError(
			f'augmenter returned {len(outputs)} results for {len(texts)} rows'
		)
	# Assign the whole column at once (positionally): writing into rows from
	# iterrows() or via chained `df[col].iloc[i] = ...` does not reliably
	# update the DataFrame.
	augmented[input_key] = outputs
	augmented['origin'] = 'augmented'
	return augmented