in tensorflow_similarity/samplers/tfdataset_samplers.py [0:0]
def __init__(self,
             dataset_name: str,
             classes_per_batch: int,
             x_key: str = "image",
             y_key: str = "label",
             splits: Union[str, Sequence[str]] = ("train", "test"),
             examples_per_class_per_batch: int = 2,
             steps_per_epoch: int = 1000,
             class_list: Optional[Sequence[int]] = None,
             total_examples_per_class: Optional[int] = None,
             preprocess_fn: Optional[PreProcessFn] = None,
             augmenter: Optional[Augmenter] = None,
             warmup: int = -1):
    """Create a multishot in-memory sampler from a dataset downloaded from
    the [TensorFlow datasets catalogue](https://www.tensorflow.org/datasets/catalog/)

    The sampler keeps each batch well balanced by ensuring that each batch
    aims to contain `examples_per_class_per_batch` examples of
    `classes_per_batch` classes.

    The `batch_size` used during training will be equal to:
    `classes_per_batch * examples_per_class_per_batch` unless an `augmenter`
    that alters the number of examples returned is used. Then the batch_size
    is a function of how many augmented examples are returned by the
    `augmenter`.

    Multishot samplers are to be used when you have multiple examples for
    the same class. If this is not the case, then see the
    [SingleShotMemorySampler()](single_memory.md) for using single examples
    with augmentation.

    Memory samplers are good for datasets that fit in memory. If you have
    larger ones that need to sample from disk then use the
    [TFRecordDatasetSampler()](tfdataset_sampler.md)

    Args:
        dataset_name: the name of the dataset to download and use as
        referenced in the TensorFlow catalog dataset page.

        classes_per_batch: Numbers of distinct classes that each batch
        aims to contain.

        x_key: name of the dictionary key that contains the data to feed as
        model input as referenced in the TensorFlow catalog dataset page.
        Defaults to "image".

        y_key: name of the dictionary key that contains the labels as
        referenced in the TensorFlow catalog dataset page.
        Defaults to "label".

        splits: which dataset split(s) to use. Default
        is ("train", "test"). Refer to the catalog page for
        the list of available splits.

        examples_per_class_per_batch: How many examples of each class to
        use per batch. Defaults to 2.

        steps_per_epoch: How many steps/batches per epoch.
        Defaults to 1000.

        class_list: Filter the list of examples to only keep those which
        belong to the supplied class list. Defaults to None - no filtering.

        total_examples_per_class: Restrict the number of examples for EACH
        class to total_examples_per_class if set. If not set, all the
        available examples are selected. Defaults to None - no selection.

        preprocess_fn: Preprocess function to apply to the dataset after
        download e.g to resize images. Takes an x and a y.
        Defaults to None.

        augmenter: A function that takes a batch in and returns a batch out.
        Can alter the number of examples returned which in turn changes the
        batch_size used. Defaults to None.

        warmup: Keep track of warmup epochs and let the augmenter know
        when the warmup is over by passing along with each batch data a
        boolean `is_warmup`. See `self._get_examples()`.
        Defaults to -1 - no warmup.

    Raises:
        ValueError: if `x_key` or `y_key` is not a feature of the
        downloaded dataset.
    """
    # dealing with users passing a single split e.g "train"
    # instead of ["train"]
    if isinstance(splits, str):
        splits = [splits]

    # we are reusing the memory sampler so "all we need to do" is convert
    # the splits into memory arrays and call the super.
    x = []
    y = []
    for split in splits:
        ds, ds_info = tfds.load(dataset_name, split=split, with_info=True)

        # fail fast with the list of valid keys if the user passed a
        # feature name that does not exist in this dataset.
        if x_key not in ds_info.features:
            raise ValueError("x_key not found - available features are:",
                             str(ds_info.features.keys()))
        if y_key not in ds_info.features:
            raise ValueError("y_key not found - available features are:",
                             str(ds_info.features.keys()))

        pb = tqdm(total=ds_info.splits[split].num_examples,
                  desc="converting %s" % split)

        for e in ds:
            x.append(e[x_key])
            y.append(e[y_key])
            pb.update()

        pb.close()

    # apply preprocess if needed.
    if preprocess_fn:
        x_pre = []
        y_pre = []
        # iterate pairwise so the preprocess function sees matching (x, y)
        for xe, ye in tqdm(zip(x, y), total=len(x),
                           desc="Preprocessing data"):
            xb, yb = preprocess_fn(xe, ye)
            x_pre.append(xb)
            y_pre.append(yb)
        x = x_pre
        y = y_pre

    # delegate batching/sampling to the base memory sampler.
    super().__init__(
        x,
        y,
        classes_per_batch=classes_per_batch,
        examples_per_class_per_batch=examples_per_class_per_batch,
        steps_per_epoch=steps_per_epoch,
        class_list=class_list,
        total_examples_per_class=total_examples_per_class,
        augmenter=augmenter,
        warmup=warmup)