in training/train_unsupervised.py [0:0]
def main():
# Other args
nodata_value = args.nodata_val
num_cluster_samples_per_file = args.num_cluster_samples_per_file
num_clusters = args.num_clusters
r = args.radius
num_train_samples_per_file = args.num_train_samples_per_file
num_val_samples_per_file = args.num_val_samples_per_file
sample_size = args.patch_size
num_epochs = args.num_epochs
batch_size = args.batch_size
initial_lr = 0.01
assert sample_size % (2*r+1) == 0
input_fns = glob.glob(args.input_fn)
num_files = len(input_fns)
num_cluster_samples = num_cluster_samples_per_file * num_files
num_train_samples = num_train_samples_per_file * num_files
num_val_samples = num_val_samples_per_file * num_files
output_dir = os.path.dirname(args.output_fn)
#--------------------------------------------------
print("Starting unsupervised pre-training script with %d inputs" % (len(input_fns)))
start_time = float(time.time())
#--------------------------------------------------
print("Loading data")
tic = float(time.time())
all_data = []
for fn in input_fns:
with rasterio.open(fn) as f:
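# rasterio reads bands-first (bands, rows, cols); rollaxis moves the band axis last, giving (rows, cols, bands)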
all_data.append(np.rollaxis(f.read(), 0, 3))
_, __, num_channels = all_data[0].shape
all_masks = []
for data in all_data:
assert data.shape[2] == num_channels
all_masks.append(np.sum(data == nodata_value, axis=2) == num_channels) # We assume that if the `nodata_value` is present across all channels then the pixel is actually nodata
print("Finished loading %d files in %0.4f seconds" % (len(all_data), time.time() - tic))
#--------------------------------------------------
# We first randomly sample (2*r+1, 2*r+1) patches to use to fit a KMeans model.
# We will later sample (sample_size, sample_size) patches, then apply this KMeans
# model to every pixel within those patches to get corresponding target labels.
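# (apply_model_to_data, defined elsewhere in this file, is assumed to assign each interior pixel of its
# input the KMeans cluster id of that pixel's (2*r+1, 2*r+1) neighborhood, returning a per-pixel label array.)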
print("Sampling dataset for KMeans and fitting model")
tic = float(time.time())
x_all = np.zeros((num_cluster_samples, 2*r+1, 2*r+1, num_channels), dtype=float)
idx = 0
for data, mask in zip(all_data, all_masks):
height, width, _ = data.shape
for i in range(num_cluster_samples_per_file):
x = np.random.randint(r, width-r)
y = np.random.randint(r, height-r)
while mask[y,x]: # resample until the center pixel is not nodata
x = np.random.randint(r, width-r)
y = np.random.randint(r, height-r)
x_all[idx] = data[y-r:y+r+1, x-r:x+r+1].copy()
idx += 1
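# Flatten each (2*r+1, 2*r+1, num_channels) patch into a single feature vector so KMeans clusters raw pixel neighborhoods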
x_all_flat = x_all.reshape((num_cluster_samples, -1))
kmeans = KMeans(n_clusters=num_clusters, verbose=0, n_init=20)
kmeans = kmeans.fit(x_all_flat)
print("Finished fitting KMeans in %0.4f seconds" % (time.time() - tic))
#--------------------------------------------------
print("Sampling training dataset")
tic = float(time.time())
x_train = np.zeros((num_train_samples, sample_size, sample_size, num_channels), dtype=float)
x_val = np.zeros((num_val_samples, sample_size, sample_size, num_channels), dtype=float)
y_train = np.zeros((num_train_samples, sample_size, sample_size), dtype=int)
y_val = np.zeros((num_val_samples, sample_size, sample_size), dtype=int)
idx = 0
for data, mask in zip(all_data, all_masks):
height, width, _ = data.shape
for i in range(num_train_samples_per_file):
if idx % 1000 == 0:
print("%d/%d" % (idx, num_train_samples))
x = np.random.randint(r, width-sample_size-r)
y = np.random.randint(r, height-sample_size-r)
while mask[y,x]:
x = np.random.randint(r, width-sample_size-r)
y = np.random.randint(r, height-sample_size-r)
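# Extract the sample patch plus an r-pixel halo so every pixel in the (sample_size, sample_size) patch has a full (2*r+1, 2*r+1) neighborhood for cluster assignment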
window = data[y-r:y+sample_size+r, x-r:x+sample_size+r]
labels = apply_model_to_data(window, r, kmeans)
x_train[idx] = data[y:y+sample_size,x:x+sample_size].copy()
y_train[idx] = labels
idx += 1
print("Finished sampling training dataset in %0.4f seconds" % (time.time() - tic))
#--------------------------------------------------
print("Sampling validation dataset")
tic = float(time.time())
idx = 0
for data, mask in zip(all_data, all_masks):
height, width, _ = data.shape
for i in range(num_val_samples_per_file):
if idx % 1000 == 0:
print("%d/%d" % (idx, num_val_samples))
x = np.random.randint(r, width-sample_size-r)
y = np.random.randint(r, height-sample_size-r)
while mask[y,x]:
x = np.random.randint(r, width-sample_size-r)
y = np.random.randint(r, height-sample_size-r)
window = data[y-r:y+sample_size+r, x-r:x+sample_size+r]
labels = apply_model_to_data(window, r, kmeans)
x_val[idx] = data[y:y+sample_size,x:x+sample_size].copy()
y_val[idx] = labels
idx += 1
print("Finished sampling validation dataset in %0.4f seconds" % (time.time() - tic))
#--------------------------------------------------
print("Normalizing sampled imagery")
means = 0
stds = 1
if (args.normalization_means is not None) and (args.normalization_stds is not None):
means = np.array(list(map(float, args.normalization_means.split(","))))
stds = np.array(list(map(float, args.normalization_stds.split(","))))
assert means.shape[0] == num_channels
assert stds.shape[0] == num_channels
x_train = (x_train - means) / stds
x_val = (x_val - means) / stds
#--------------------------------------------------
print("Converting labels to categorical")
tic = float(time.time())
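# to_categorical one-hot encodes the per-pixel cluster ids, giving targets of shape (N, sample_size, sample_size, num_clusters)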
y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes=num_clusters)
y_val = tensorflow.keras.utils.to_categorical(y_val, num_classes=num_clusters)
print("Finished converting labels to categorical in %0.4f seconds" % (time.time() - tic))
#--------------------------------------------------
print("Creating and fitting model")
tic = float(time.time())
model = basic_model((sample_size, sample_size, num_channels), num_clusters, lr=initial_lr)
if args.verbose:
model.summary()
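# All geometric/photometric augmentations are disabled; the only augmentation is "cutout", where
# image_cutout_builder presumably masks a random square (side between 5 and 20 pixels) with replacement_val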
train_datagen = ImageDataGenerator(
rotation_range=0,
width_shift_range=0,
height_shift_range=0,
channel_shift_range=0.0,
horizontal_flip=False,
vertical_flip=False,
preprocessing_function=image_cutout_builder(mask_size=(5,20), replacement_val=means),
dtype=np.float32
)
val_datagen = ImageDataGenerator()
model_checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(args.output_fn, monitor="val_loss", verbose=1, save_best_only=True, save_weights_only=False, mode="min", save_freq="epoch")
lr_reducer = tensorflow.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=2, verbose=1)
early_stopper = tensorflow.keras.callbacks.EarlyStopping(monitor="val_loss", mode="min", verbose=1, patience=8)
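# Fit with the cutout-augmented generator; keep only the best-val_loss checkpoint, cut the learning rate
# 10x after 2 epochs without improvement, and stop early after 8 stagnant epochs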
history = model.fit(
train_datagen.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
steps_per_epoch=x_train.shape[0] // batch_size - 1,
epochs=num_epochs,
callbacks=[model_checkpoint, lr_reducer, early_stopper],
validation_data=val_datagen.flow(x_val, y_val, batch_size=batch_size),
validation_steps=x_val.shape[0] // batch_size - 1
)
print("Finished fitting model in %0.4f seconds" % (time.time() - tic))
with open(os.path.join(output_dir, "model_fit_history.p"), "wb") as f:
pickle.dump(
history.history,
f
)
print("Finished in %0.4f seconds" % (time.time() - start_time))