def initialize()

in covid19_spread/bar.py [0:0]


    def initialize(self, args):
        """Load case data, build the BAR model and store it on ``self.func``.

        Args:
            args: namespace-like configuration object. Attributes read
                directly here: ``fdat``, ``t0``, ``window``, ``decay``,
                ``loss`` and ``weight_decay`` (the latter only once a beta
                net has been built); optionally ``cuda``, ``graph``,
                ``dropout``, ``self_correlation`` and
                ``no_cross_correlation``. Optional ``features`` /
                ``time_features`` are resolved through ``_get_arg`` /
                ``_get_dict``. NOTE: ``args.window`` is clipped in place.

        Returns:
            Tuple ``(new_cases, regions, basedate, device)``. ``new_cases``
            is a float tensor of per-step increments of the cumulative case
            counts, restricted to columns ``args.t0:`` and moved to
            ``device``.

        Raises:
            ValueError: if ``args.decay`` does not start with one of the
                known prefixes ``latent``, ``lstm`` or ``gru``.
        """
        device = th.device(
            "cuda" if th.cuda.is_available() and getattr(args, "cuda", True) else "cpu"
        )
        cases, regions, basedate = load.load_confirmed_csv(args.fdat)
        # NaN is the only value not equal to itself: reject missing data early.
        assert (cases == cases).all(), th.where(cases != cases)

        # Cumulative max across time, so the cumulative series never
        # decreases and the increments below are non-negative.
        cases = np.maximum.accumulate(cases, axis=1)

        # Per-step increments; column 0 has no predecessor and stays 0.
        new_cases = th.zeros_like(cases)
        new_cases[:, 1:] = cases[:, 1:] - cases[:, :-1]

        assert (new_cases >= 0).all(), new_cases[th.where(new_cases < 0)]
        new_cases = new_cases.float().to(device)[:, args.t0 :]

        print("Number of Regions =", new_cases.size(0))
        print("Timeseries length =", new_cases.size(1))
        print(
            "Increase: max all = {}, max last = {}, min last = {}".format(
                new_cases.max().item(),
                new_cases[:, -1].max().item(),
                new_cases[:, -1].min().item(),
            )
        )
        tmax = new_cases.size(1) + 1

        # adjust max window size to available data
        # NOTE(review): becomes <= 0 for series shorter than 5 steps — confirm
        # callers never pass such short data.
        args.window = min(args.window, new_cases.size(1) - 4)

        # setup optional features. An attribute that exists but is None
        # (e.g. an argparse default) means "no graph", not th.load(None).
        graph_path = getattr(args, "graph", None)
        graph = (
            th.load(graph_path).to(device).float() if graph_path is not None else None
        )
        features = _get_arg(args, "features", device, regions)
        time_features = _get_dict(args, "time_features", device, regions)
        if time_features is not None:
            # After the transpose dim 0 is trimmed to the same
            # [t0, t0 + len) window as new_cases, i.e. it is the time axis.
            time_features = time_features.transpose(0, 1)
            time_features = time_features.narrow(0, args.t0, new_cases.size(1))
            print("Feature size = {} x {} x {}".format(*time_features.size()))
            print(time_features.min(), time_features.max())

        self.weight_decay = 0
        # setup beta function: ``args.decay`` is "<prefix><dim>_<layers>"; the
        # prefix picks the recurrent cell class, the rest is parsed into ints.
        beta_cells = (("latent", BetaRNN), ("lstm", BetaLSTM), ("gru", BetaGRU))
        for prefix, cell_cls in beta_cells:
            if args.decay.startswith(prefix):
                break
        else:
            raise ValueError("Unknown beta function")

        dim, layers = args.decay[len(prefix) :].split("_")

        def fbeta(M, input_dim):
            # Factory handed to BetaLatent; ints are parsed lazily, on call.
            return cell_cls(
                M,
                int(layers),
                int(dim),
                input_dim,
                dropout=getattr(args, "dropout", 0.0),
            )

        beta_net = BetaLatent(fbeta, regions, tmax, time_features)
        self.weight_decay = args.weight_decay

        self.func = BAR(
            regions,
            beta_net,
            args.window,
            args.loss,
            graph,
            features,
            self_correlation=getattr(args, "self_correlation", True),
            cross_correlation=not getattr(args, "no_cross_correlation", False),
            # NOTE(review): offset is the first cumulative column even though
            # new_cases starts at args.t0 — confirm this is intended for t0 > 0.
            offset=cases[:, 0].unsqueeze(1).to(device).float(),
        ).to(device)

        return new_cases, regions, basedate, device