in src/envs/ode.py [0:0]
def __init__(self, env, task, train, params, path, size=None):
    """Build a dataset for ``task``, optionally reloading samples from disk.

    Args:
        env: environment instance; its ``mtrx_separator`` attribute is used
            to strip the Jacobian from ``ode_convergence_speed`` targets when
            ``params.predict_jacobian`` is false.
        task: task name; must be one of ``ODEEnvironment.TRAINING_TASKS``.
        train: whether this is the training split. A training dataset
            reports a size of ``1 << 60`` and, when reloading from a file,
            keeps only lines whose index modulo ``params.n_gpu_per_node``
            equals ``params.local_rank``.
        params: experiment parameters. Read here: ``batch_size``,
            ``env_base_seed``, ``global_rank``, ``num_workers``, and (when
            reloading) ``reload_size``, ``n_gpu_per_node``, ``local_rank``,
            ``reversed_eval``, ``qualitative``, ``predict_bounds``,
            ``predict_jacobian``.
        path: optional path to a UTF-8 data file whose lines look like
            ``<prefix>|<x><TAB><y>``. When None, no data is loaded here.
        size: number of samples of a non-training dataset. It must be None
            when ``train`` is True. If None, it defaults to the number of
            loaded samples, or to 5000 when no file is given.
    """
    # One-argument ``super(EnvDataset)`` creates an unbound super object,
    # so calling ``__init__`` on it never reached the parent initializer.
    # The zero-argument form binds to this instance.
    super().__init__()
    self.env = env
    self.train = train
    self.task = task
    self.env_base_seed = params.env_base_seed
    self.path = path
    self.global_rank = params.global_rank
    self.count = 0
    assert task in ODEEnvironment.TRAINING_TASKS
    # an explicit size only makes sense for finite (valid / test) datasets
    assert size is None or not self.train

    # batching
    self.num_workers = params.num_workers
    self.batch_size = params.batch_size

    # generation, or reloading from file
    if path is not None:
        assert os.path.isfile(path)
        logger.info(f"Loading data from {path} ...")
        with io.open(path, mode="r", encoding="utf-8") as f:
            # valid / test: reload the entire file.
            # train: stop at line index ``params.reload_size`` and keep only
            # the lines assigned to this local rank.
            if not train:
                lines = [line.rstrip().split("|") for line in f]
            else:
                lines = []
                for i, line in enumerate(f):
                    if i == params.reload_size:
                        break
                    if i % params.n_gpu_per_node == params.local_rank:
                        lines.append(line.rstrip().split("|"))
        # NOTE: every line must contain exactly one "|" (otherwise the
        # unpacking below raises ValueError). Samples whose payload is not
        # exactly "x<TAB>y" are silently dropped.
        self.data = [xy.split("\t") for _, xy in lines]
        self.data = [xy for xy in self.data if len(xy) == 2]
        logger.info(f"Loaded {len(self.data)} equations from the disk.")

        # ode_control with reversed evaluation: target "INT+ 0" becomes
        # "INT+ 1", and any other target becomes "INT+ 0".
        if task == "ode_control" and params.reversed_eval and not self.train:
            self.data = [
                (x, "INT+ 1" if y == "INT+ 0" else "INT+ 0") for (x, y) in self.data
            ]

        # qualitative convergence: label 1 when the target starts with a
        # "FLOAT- " token, 0 otherwise.
        if task == "ode_convergence_speed" and params.qualitative:
            self.data = [
                (x, "INT+ 1" if y[:7] == "FLOAT- " else "INT+ 0")
                for (x, y) in self.data
            ]

        # Without bound prediction, keep the first 25 characters, i.e. the
        # length of "INT+ X <SPECIAL_3> INT+ X".
        # NOTE(review): assumes each X is a single character -- TODO confirm.
        if task == "fourier_cond_init" and not params.predict_bounds:
            self.data = [(x, y[:25]) for (x, y) in self.data]

        # if we are not predicting the Jacobian, remove it: drop everything
        # up to the separator, the separator itself and the one character
        # after it (presumably a space)
        if task == "ode_convergence_speed" and not params.predict_jacobian:
            self.data = [
                (x, y[y.index(env.mtrx_separator) + len(env.mtrx_separator) + 1 :])
                if env.mtrx_separator in y
                else (x, y)
                for (x, y) in self.data
            ]

    # dataset size: infinite iterator for train,
    # finite for valid / test (default of 5000 if no file provided)
    if self.train:
        self.size = 1 << 60
    elif size is None:
        self.size = 5000 if path is None else len(self.data)
    else:
        assert size > 0
        self.size = size