def main()

in training/train_unsupervised.py [0:0]

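Note: main() assumes an argparse namespace `args` built earlier in the script (elided here), plus three helpers defined elsewhere in training/train_unsupervised.py: apply_model_to_data, basic_model, and image_cutout_builder. A plausible module-level import block for the calls used below (the exact set is an assumption, not taken from the file):

import glob
import os
import pickle
import time

import numpy as np
import rasterio
import tensorflow
from sklearn.cluster import KMeans
from tensorflow.keras.preprocessing.image import ImageDataGenerator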

def main():
    # Other args
    nodata_value = args.nodata_val
    num_cluster_samples_per_file = args.num_cluster_samples_per_file
    num_clusters = args.num_clusters
    r = args.radius
    num_train_samples_per_file = args.num_train_samples_per_file
    num_val_samples_per_file = args.num_val_samples_per_file
    sample_size = args.patch_size
    num_epochs = args.num_epochs
    batch_size = args.batch_size

    initial_lr = 0.01
    # The KMeans window size (2*r+1) must evenly divide the training patch size.
    assert sample_size % (2*r+1) == 0

    input_fns = glob.glob(args.input_fn)
    num_files = len(input_fns)
    num_cluster_samples = num_cluster_samples_per_file * num_files
    num_train_samples = num_train_samples_per_file * num_files
    num_val_samples = num_val_samples_per_file * num_files
    output_dir = os.path.dirname(args.output_fn)

    #--------------------------------------------------
    print("Starting unsupervised pre-training script with %d inputs" % (len(input_fns)))
    start_time = float(time.time())

    #--------------------------------------------------
    print("Loading data")
    tic = float(time.time())

    all_data = []
    for fn in input_fns:
        with rasterio.open(fn) as f:
            all_data.append(np.rollaxis(f.read(), 0, 3))

    _, __, num_channels = all_data[0].shape
    all_masks = []
    for data in all_data:
        assert data.shape[2] == num_channels
        # We assume a pixel is truly nodata only when `nodata_value` appears in every channel.
        all_masks.append(np.sum(data == nodata_value, axis=2) == num_channels)
    print("Finished loading %d files in %0.4f seconds" % (len(all_data), time.time() - tic))


    #--------------------------------------------------
    # We first randomly sample (2*r+1, 2*r+1) patches to use to fit a KMeans model.
    # We will later sample (sample_size, sample_size) patches, then apply this KMeans
    # model to every pixel within those patches to get corresponding target labels.
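    # For example, with r = 2 and sample_size = 25, KMeans is fit on flattened
    # 5x5xC neighborhoods, and each 25x25 patch later yields 25x25 labels.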
    print("Sampling dataset for KMeans and fitting model")
    tic = float(time.time())
    x_all = np.zeros((num_cluster_samples, 2*r+1, 2*r+1, num_channels), dtype=float)
    idx = 0
    for data, mask in zip(all_data, all_masks):
        height, width, _ = data.shape
        for i in range(num_cluster_samples_per_file):

            x = np.random.randint(r, width-r)
            y = np.random.randint(r, height-r)
            while mask[y,x]: # resample until the center pixel is valid
                x = np.random.randint(r, width-r)
                y = np.random.randint(r, height-r)

            x_all[idx] = data[y-r:y+r+1, x-r:x+r+1].copy()
            idx += 1
    x_all_flat = x_all.reshape((num_cluster_samples, -1))

    kmeans = KMeans(n_clusters=num_clusters, verbose=0, n_init=20).fit(x_all_flat)
    print("Finished fitting KMeans in %0.4f seconds" % (time.time() - tic))


    #--------------------------------------------------
    print("Sampling training dataset")
    tic = float(time.time())

    x_train = np.zeros((num_train_samples,sample_size,sample_size,num_channels), dtype=float)
    x_val = np.zeros((num_val_samples,sample_size,sample_size,num_channels), dtype=float)

    y_train = np.zeros((num_train_samples,sample_size,sample_size), dtype=int)
    y_val = np.zeros((num_val_samples,sample_size,sample_size), dtype=int)
    idx = 0
    for data, mask in zip(all_data, all_masks):
        height, width, _ = data.shape
        for i in range(num_train_samples_per_file):
            if idx % 1000 == 0:
                print("%d/%d" % (idx, num_train_samples))
            x = np.random.randint(r, width-sample_size-r)
            y = np.random.randint(r, height-sample_size-r)

            while mask[y,x]:
                x = np.random.randint(r, width-sample_size-r)
                y = np.random.randint(r, height-sample_size-r)

            window = data[y-r:y+sample_size+r, x-r:x+sample_size+r]
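            # apply_model_to_data (defined elsewhere in this file) is assumed
            # to label the (2r+1, 2r+1) neighborhood of each interior pixel
            # with the fitted KMeans model; see the sketch after this function.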
            labels = apply_model_to_data(window, r, kmeans)
            x_train[idx] = data[y:y+sample_size,x:x+sample_size].copy()
            y_train[idx] = labels
            idx += 1
    print("Finished sampling training dataset in %0.4f seconds" % (time.time() - tic))


    #--------------------------------------------------
    print("Sampling validation dataset")
    tic = float(time.time())
    idx = 0
    for data, mask in zip(all_data, all_masks):
        height, width, _ = data.shape
        for i in range(num_val_samples_per_file):
            if idx % 1000 == 0:
                print("%d/%d" % (idx, num_val_samples))
            x = np.random.randint(r, width-sample_size-r)
            y = np.random.randint(r, height-sample_size-r)

            while mask[y,x]:
                x = np.random.randint(r, width-sample_size-r)
                y = np.random.randint(r, height-sample_size-r)

            window = data[y-r:y+sample_size+r, x-r:x+sample_size+r]
            labels = apply_model_to_data(window, r, kmeans)
            x_val[idx] = data[y:y+sample_size,x:x+sample_size].copy()
            y_val[idx] = labels
            idx += 1
    print("Finished sampling validation dataset in %0.4f seconds" % (time.time() - tic))


    #--------------------------------------------------
    print("Normalizing sampled imagery")
    means = 0
    stds = 1
    if (args.normalization_means is not None) and (args.normalization_stds is not None):
        means = np.array(list(map(float, args.normalization_means.split(","))))
        stds = np.array(list(map(float, args.normalization_stds.split(","))))
        assert means.shape[0] == num_channels
        assert stds.shape[0] == num_channels
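    # means/stds broadcast across the trailing channel axis; with the scalar
    # defaults (0 and 1) the normalization below is a no-op.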
    x_train = (x_train - means) / stds
    x_val = (x_val - means) / stds


    #--------------------------------------------------
    print("Converting labels to categorical")
    tic = float(time.time())
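    # One-hot encode: labels go from (N, sample_size, sample_size) ints to
    # (N, sample_size, sample_size, num_clusters) float32 arrays.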
    y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes=num_clusters)
    y_val = tensorflow.keras.utils.to_categorical(y_val, num_classes=num_clusters)
    print("Finished converting labels to categorical in %0.4f seconds" % (time.time() - tic))

    #--------------------------------------------------
    print("Creating and fitting model")
    tic = float(time.time())
    model = basic_model((sample_size, sample_size, num_channels), num_clusters, lr=initial_lr)
    if args.verbose:
        model.summary()

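    # Geometric and photometric augmentations are all disabled below; the only
    # active augmentation is the cutout preprocessing function, which (judging
    # by its arguments) masks a random patch of side 5-20 pixels and fills it
    # with the per-channel means.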
    train_datagen = ImageDataGenerator(
        rotation_range=0,
        width_shift_range=0,
        height_shift_range=0,
        channel_shift_range=0.0,
        horizontal_flip=False,
        vertical_flip=False,
        preprocessing_function=image_cutout_builder(mask_size=(5,20), replacement_val=means),
        dtype=np.float32
    )
    val_datagen = ImageDataGenerator()

    model_checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(
        args.output_fn, monitor="val_loss", verbose=1, save_best_only=True,
        save_weights_only=False, mode="min", save_freq="epoch"
    )
    lr_reducer = tensorflow.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=2, verbose=1)
    early_stopper = tensorflow.keras.callbacks.EarlyStopping(monitor="val_loss", mode="min", verbose=1, patience=8)

    history = model.fit(
        train_datagen.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
        steps_per_epoch=x_train.shape[0] // batch_size - 1,
        epochs=num_epochs,
        callbacks=[model_checkpoint, lr_reducer, early_stopper],
        validation_data=val_datagen.flow(x_val, y_val, batch_size=batch_size),
        validation_steps=x_val.shape[0] // batch_size - 1
    )

    print("Finished fitting model in %0.4f seconds" % (time.time() - tic))

    with open(os.path.join(output_dir, "model_fit_history.p"), "wb") as f:
        pickle.dump(history.history, f)

    print("Finished in %0.4f seconds" % (time.time() - start_time))