in src/train_data.py
def initialize(self, config, load_data=True, log_path=None, training=True):
    self.config_file = config
    self.base_log_dir = self.config_file.logDir

    if config.randomSeed != -1:
        torch.manual_seed(config.randomSeed)
        np.random.seed(config.randomSeed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    self.device = f"cuda:{config.device}" if torch.cuda.is_available() else "cpu"
    # torch.cuda.get_device_name() rejects non-CUDA devices, so only query it when a GPU is in use
    print(f"device {torch.cuda.get_device_name(self.device) if torch.cuda.is_available() else 'cpu'}")
    self.f_in, self.f_out = FeatureSet.get_sets(config, self.device)
    self.dataset_info = DatasetInfo(config, self)
    self.initialize_features()
    self.copy_to_gpu = not config.storeFullData

    self.models = []
    self.optimizers = []
    self.encodings = []
    self.enc_args = []
    self.losses = []
    self.loss_weights = []
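    # One network is created per input/output feature-set pair in the loop below, so models,
    # optimizers, encodings, enc_args, losses and loss_weights are all index-aligned with f_in/f_out.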

    # Initialize per-feature list defaults here so that the exported config file
    # contains all defaults that are otherwise only set inside the feature init code.
    if config.rayMarchSamplingNoise is None:
        config.rayMarchSamplingNoise = []
    if config.zNear is None:
        config.zNear = []
    if config.zFar is None:
        config.zFar = []

    for i in range(len(self.f_in)):
        model = ModelSelection.getModel(config, self.f_in[i].n_feat, self.f_out[i].n_feat, self.device, i)
        self.models.append(model)
        self.optimizers.append(torch.optim.Adam(model.parameters(), lr=config.lrate))
        self.encodings.append(config.posEnc[i])
        self.enc_args.append(config.posEncArgs[i])
        self.losses.append(get_loss_by_name(config.losses[i], config=config, net_idx=i))
        self.loss_weights.append(config.lossWeights[i])

        # Append per-network defaults so the exported config lists are fully populated
        if len(config.rayMarchSamplingNoise) <= i:
            config.rayMarchSamplingNoise.append(0.0)
        if len(config.zNear) <= i:
            config.zNear.append(0.001)
        if len(config.zFar) <= i:
            config.zFar.append(1.0)
        if hasattr(self.losses[i], 'requires_alpha_beta'):
            if len(config.lossAlpha) <= i:
                config.lossAlpha.append(1.0)
            if len(config.lossBeta) <= i:
                config.lossBeta.append(0.0)

    depth_transform = ""
    if config.depthTransform and config.depthTransform != "linear":
        depth_transform = config.depthTransform[0:2] + "_"
    scale_interpolation = ""
    if config.scaleInterpolation and config.scaleInterpolation != "median":
        scale_interpolation = config.scaleInterpolation[0:2] + "_"
    experiment_name = depth_transform + scale_interpolation + \
        config_to_name(self.f_in, self.f_out, self.models, self.encodings, self.enc_args,
                       config.lossAlpha, config.lossBeta)
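    # Non-default settings are abbreviated into the experiment-name prefix, e.g. (hypothetical
    # values) depthTransform="log" and scaleInterpolation="mean" yield the prefix "lo_me_".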
    dataset_name = os.path.basename(os.path.normpath(config.data)) + "/"

    if log_path is None:
        self.logDir = os.path.join(config.logDir, dataset_name, experiment_name) + "/"
        self.dataset_name = dataset_name
        self.experiment_name = experiment_name
    else:
        self.logDir = log_path
    # Keep the config in sync with the resolved log directory so that any later
    # code reading config.logDir points at the same run directory
    self.config_file.logDir = self.logDir
    os.makedirs(config.logDir, exist_ok=True)
    os.makedirs(f"{config.logDir}/test_opt/", exist_ok=True)

    self.epochs = self.config_file.epochs

    # load previous best validation loss (if any)
    if os.path.exists(os.path.join(config.logDir, "opt.txt")):
        with open(os.path.join(config.logDir, "opt.txt")) as f:
            line = f.readline()
            match = re.search(r'\d+\.\d+', line)
            self.best_valid_loss = float(match.group(0))
    for i in range(len(self.models)):
        if os.path.exists(os.path.join(config.logDir, f"opt_{i}.txt")):
            with open(os.path.join(config.logDir, f"opt_{i}.txt")) as f:
                line = f.readline()
                match = re.search(r'\d+\.\d+', line)
                self.best_valid_loss_pretrain.append(float(match.group(0)))
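    # opt.txt / opt_<i>.txt are expected to start with a line containing the best validation loss
    # as a decimal number; the regex above extracts the first such value from that line.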

    if not os.path.exists(os.path.join(config.logDir, "config.ini")):
        # Export the effective config (including command line overrides) by serializing the
        # config dict; the translation table below strips single quotes (ord 39), so string
        # values and lists are written without quoting.
        translation = {39: None}
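        # Example (hypothetical values): posEnc = ['nerf', 'none'] is exported as
        #   posEnc = [nerf, none]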
        with open(os.path.join(config.logDir, "config.ini"), 'w') as f:
            for key in self.config_file.__dict__:
                val = self.config_file.__dict__[key]
                if val is not None:
                    # Skip empty lists
                    if isinstance(val, list) and len(val) == 0:
                        continue
                    f.write(f'{key} = {str(val).translate(translation)}\n')

    if load_data:
        self.pixel_idx_sequence_gen = getattr(import_module("sampled_sequence"), self.config_file.sampleGenerator)(
            dims=2, device='cpu', base_log_dir=self.base_log_dir, num_pregeneration=30000000)
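        # config.sampleGenerator names a class in the sampled_sequence module; it is resolved
        # dynamically and, judging by num_pregeneration, pre-generates up to 30,000,000
        # two-dimensional pixel indices on the CPU.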

        if config.storeFullData:
            from datasets import FullyLoadedViewCellDataset as Dataset
            num_workers = 0
            pin_memory = False
            worker_init_fn = None
        else:
            from datasets import OnTheFlyViewCellDataset as Dataset
            num_workers = config.numWorkers
            pin_memory = True
            worker_init_fn = worker_offset_sequence
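        # Fully loaded data is already resident in memory, so no loader workers or pinned memory
        # are needed; the on-the-fly dataset streams samples with config.numWorkers workers and
        # (presumably) uses worker_offset_sequence to give each worker its own sampling offset.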

        if training:
            self.train_dataset = Dataset(self.config_file, self, self.dataset_info, set_name="train",
                                         num_samples=self.config_file.samples)
            self.train_data_loader = DataLoader(self.train_dataset, batch_size=self.config_file.batchImages,
                                                shuffle=True, num_workers=num_workers, pin_memory=pin_memory,
                                                persistent_workers=pin_memory, worker_init_fn=worker_init_fn)
            self.valid_dataset = Dataset(self.config_file, self, self.dataset_info, set_name="val",
                                         num_samples=self.config_file.samples)
            self.valid_data_loader = DataLoader(self.valid_dataset, batch_size=self.config_file.batchImages,
                                                shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
                                                persistent_workers=pin_memory, worker_init_fn=worker_init_fn)

            # Create a separate DataLoader for pretraining so it can use a different batch size:
            # num_samples can simply be changed on the dataset, but a DataLoader's batch size
            # cannot be changed after creation.
            if self.config_file.epochsPretrain is not None and len(self.config_file.epochsPretrain) != 0 and \
                    self.config_file.batchImagesPretrain != -1:
                self.pretrain_data_loader = DataLoader(self.train_dataset,
                                                       batch_size=self.config_file.batchImagesPretrain,
                                                       shuffle=True, num_workers=num_workers, pin_memory=pin_memory,
                                                       persistent_workers=pin_memory, worker_init_fn=worker_init_fn)

        self.test_dataset = Dataset(self.config_file, self, self.dataset_info, set_name="test",
                                    num_samples=self.dataset_info.w * self.dataset_info.h)
        self.test_data_loader = DataLoader(self.test_dataset, batch_size=1,
                                           shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
                                           persistent_workers=pin_memory, worker_init_fn=worker_init_fn)