# def main() — from Damage Assessment Visualizer/utils/inference.py


def main(args):
    """Run siamese damage-assessment inference over a pre/post image pair.

    Args:
        args: parsed command-line namespace with attributes
            ``pre_imagery`` (path to pre-event 3-band uint8 GeoTIFF),
            ``post_imagery`` (path to post-event 3-band uint8 GeoTIFF,
            same CRS and dimensions as the pre image),
            ``output_fn`` (path for the output GeoTIFF; must not exist),
            and ``gpu`` (CUDA device index used when a GPU is available).

    Side effects:
        Downloads pretrained model weights, runs tiled inference, and
        writes a single-band uint8 GeoTIFF of damage classes (with an
        embedded colormap) to ``args.output_fn``.
    """
    print('starting inference with variables:')

    # Setup
    pre_fn = args.pre_imagery
    post_fn = args.post_imagery
    output_fn = args.output_fn
    print(pre_fn, post_fn, output_fn, sep="\n")
    assert os.path.exists(pre_fn)
    assert os.path.exists(post_fn)
    # Refuse to clobber an existing output file.
    assert not os.path.exists(output_fn)

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
    else:
        print(
            "WARNING: GPU is not available -- defaulting to CPU -- inference will be slow"
        )
        device = torch.device("cpu")

    # Validate input data: both rasters must be 3-band uint8 and share the
    # same CRS and pixel dimensions so their chips align 1:1.
    with rasterio.open(pre_fn) as f:
        assert f.profile["dtype"] == "uint8"
        assert f.count == 3
        input_height = f.height
        input_width = f.width
        input_crs = f.crs

    # BUGFIX: this block previously re-opened `pre_fn`, so the post-event
    # imagery was never actually validated against the pre-event raster.
    with rasterio.open(post_fn) as f:
        assert f.profile["dtype"] == "uint8"
        assert f.count == 3
        assert f.crs == input_crs
        assert f.width == input_width
        assert f.height == input_height

    # Load data from the intersection
    print("Loading data")
    with rasterio.open(pre_fn) as f:
        pre_data = f.read()
        # Pixels where all 3 bands are zero in the pre image are nodata;
        # remembered in image (H, W) shape for masking the final output.
        twod_nodata_mask = (pre_data == 0).sum(axis=0) == 3
        pre_data = pre_data / 255.0
        input_profile = f.profile.copy()
        # Flatten to (num_pixels, bands) for per-band statistics.
        pre_data = pre_data.reshape(pre_data.shape[0], -1).T.copy()

    with rasterio.open(post_fn) as f:
        post_data = f.read()
        post_data = post_data / 255.0
        post_data = post_data.reshape(post_data.shape[0], -1).T.copy()

    print("Computing data statistics")
    all_data = np.concatenate([pre_data, post_data], axis=0)
    # Keep only pixels that are not all-zero (i.e. not nodata) so the
    # normalization statistics are not biased toward zero.
    valid_mask = (all_data == 0).sum(axis=1) != 3
    all_means = np.mean(all_data[valid_mask], axis=0, dtype=np.float64)
    all_stdevs = np.std(all_data[valid_mask], axis=0, dtype=np.float64)

    # Create dataloaders
    print("Creating dataloaders")

    def transform_by_all(img):
        """Scale a (H, W, C) uint8 chip to [0, 1], standardize with the
        whole-scene statistics, and return a (C, H, W) float tensor."""
        img = img / 255.0
        img = (img - all_means) / all_stdevs
        img = np.rollaxis(img, 2, 0)
        img = torch.from_numpy(img).float()
        return img

    ds1 = TileInferenceDataset(
        pre_fn,
        CHIP_SIZE,
        CHIP_STRIDE,
        transform=transform_by_all,
        windowed_sampling=False,
        verbose=False,
    )
    ds2 = TileInferenceDataset(
        post_fn,
        CHIP_SIZE,
        CHIP_STRIDE,
        transform=transform_by_all,
        windowed_sampling=False,
        verbose=False,
    )
    # Zip the two tile streams so each batch yields aligned pre/post chips.
    ds = ZipDataset(ds1, ds2)
    dataloader = DataLoader(
        ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=False
    )

    # Init model
    print("Initializing model")
    model = SiamUnet()
    state_dict = load_state_dict_from_url(
        "https://github.com/microsoft/building-damage-assessment-cnn-siamese/raw/main/models/model_best.pth.tar",
        map_location="cpu",
    )["state_dict"]
    model.load_state_dict(state_dict)
    model = model.eval()
    model = model.to(device)

    # Run model
    print("Running model inference")
    # Accumulate soft predictions (5 damage classes) over overlapping chips,
    # weighting chip interiors 5x so edge artifacts are down-weighted.
    output = np.zeros((5, input_height, input_width), dtype=np.float64)
    kernel = np.ones((CHIP_SIZE, CHIP_SIZE), dtype=np.float32)
    kernel[HALF_PADDING:-HALF_PADDING, HALF_PADDING:-HALF_PADDING] = 5
    counts = np.zeros((input_height, input_width), dtype=np.float32)
    for i, (x1, x2, coords) in enumerate(dataloader):
        x1 = x1.to(device)
        x2 = x2.to(device)

        with torch.no_grad():
            y1, y2, damage = model.forward(x1, x2)

            # Gate damage logits by the building-segmentation argmax so
            # non-building pixels contribute nothing.
            y1 = y1.argmax(dim=1)

            damage = y1.unsqueeze(1) * damage
            damage = F.softmax(damage, dim=1).cpu().numpy()

        for j in range(damage.shape[0]):
            y, x = coords[j]
            output[:, y : y + CHIP_SIZE, x : x + CHIP_SIZE] += damage[j] * kernel
            counts[y : y + CHIP_SIZE, x : x + CHIP_SIZE] += kernel

    # Average overlapping contributions, then take the hard class label.
    output = output / counts
    output = output.argmax(axis=0).astype(np.uint8)
    # Zero out nodata pixels and fold class 1 (no damage) into nodata.
    output[twod_nodata_mask] = 0
    output[output == 1] = 0

    # Save results
    print("Saving output")
    input_profile["count"] = 1
    input_profile["nodata"] = 0
    input_profile["height"] = input_height
    input_profile["width"] = input_width
    input_profile["compress"] = "lzw"
    input_profile["predictor"] = 2
    print("writing output to file:")
    print(output_fn)
    with rasterio.open(output_fn, "w", **input_profile) as f:
        f.write(output, 1)
        # Classes 0/1 are transparent; 2-4 shade from orange to dark red.
        f.write_colormap(
            1,
            {
                0: (0, 0, 0, 0),
                1: (0, 0, 0, 0),
                2: (252, 112, 80, 255),
                3: (212, 32, 32, 255),
                4: (103, 0, 13, 255),
            },
        )
    print("finished writing output")