def process_time_features()

in covid19_spread/data/usa/convert.py [0:0]


def _expand_state_time_features(time_features, df):
    """Broadcast state-level feature rows to every county listed in ``df.index``.

    The state of each county is the text after the last ``", "`` in its name;
    feature rows are joined on it through the ``region`` column. In the returned
    frame, ``region`` holds the county names instead of the state names.
    """
    counties = df.rename_axis("county").reset_index()[["county"]]
    counties["region"] = counties["county"].apply(lambda name: name.split(", ")[-1])
    expanded = time_features.merge(counties, on="region").drop(columns="region")
    return expanded.rename(columns={"county": "region"})


def _filter_time_feature_regions(time_features, df, pth):
    """Warn about region mismatches with ``df`` and drop rows for unknown regions.

    Regions of ``df.index`` without feature rows are only reported. Feature rows
    whose region is not in ``df.index`` are reported and removed.
    """
    feature_regions = time_features["region"].unique()
    ncommon = len(df.index.intersection(feature_regions))
    if ncommon != len(df):
        missing = set(df.index).difference(feature_regions)
        # Sorted (key=str tolerates mixed key types) so messages are reproducible.
        warnings.warn(
            f"{pth}: Missing time features for the following regions: "
            f"{sorted(missing, key=str)}"
        )
    if ncommon != len(feature_regions):
        ignoring = set(feature_regions).difference(df.index)
        warnings.warn(
            f"{pth}: Ignoring time features for the following regions: "
            f"{sorted(ignoring, key=str)}"
        )
        time_features = time_features[time_features["region"].isin(set(df.index))]
    return time_features


def _align_time_features(time_features, dates, shift):
    """Pivot long-format features into a ``dates`` x (region, type) frame.

    Row ``d`` of the result holds the values recorded ``shift`` days before ``d``
    (``shift`` may be negative). Gaps are forward-filled from earlier rows, and
    anything still missing at the start is set to 0. Columns are sorted.
    """
    # Two-level columns (region, type), one row per date.
    wide = time_features.set_index(["region", "type"]).transpose()
    wide.index = pd.to_datetime(wide.index)
    # Move the date labels (not row positions), and do it BEFORE restricting to
    # `dates`: values recorded before dates.min() can then still lag into the
    # window instead of being trimmed off and replaced by zeros.
    wide = wide.shift(shift, freq="D")
    # Restrict to exactly the requested dates, inserting rows for missing ones.
    wide = wide.reindex(dates)
    wide = wide.ffill().fillna(0)
    return wide[wide.columns.sort_values()]


def _time_features_output_path(pth, input_resolution):
    """Return the ``.pt`` path that the tensors built from ``pth`` are saved to.

    Only the file name is rewritten, never the directories: for
    ``"county_state"`` input ``"state"`` becomes ``"county_state"``, and a
    trailing ``.csv`` is swapped for ``.pt`` (``.pt`` is appended otherwise, so
    the input file is never overwritten).
    """
    import os

    directory, filename = os.path.split(pth)
    if input_resolution == "county_state":
        filename = filename.replace("state", "county_state")
    if filename.endswith(".csv"):
        filename = filename[: -len(".csv")]
    return os.path.join(directory, filename + ".pt")


def process_time_features(df, pth, shift=0, merge_nyc=False, input_resolution="county"):
    """Convert a CSV of per-region time features into a saved dict of tensors.

    Args:
        df: Frame whose index lists the regions to keep and whose columns are
            the dates the features are aligned to.
        pth: Path to a CSV with ``region`` and ``type`` columns followed by one
            column per date.
        shift: Number of days to lag the features: the row for date ``d`` holds
            the values recorded at ``d - shift`` (gaps forward-filled, leading
            gaps set to 0).
        merge_nyc: If True, rows are passed through ``merge_nyc_boroughs``
            before pivoting.
        input_resolution: ``"county_state"`` means ``pth`` holds state-level
            rows, which are copied to every county in ``df.index`` whose name
            ends in ``", <state>"``. Any other value uses the rows as-is.

    The saved object maps each region to a tensor with one row per date in
    ``df.columns`` and one column per feature type (types in sorted order). It
    is written next to ``pth``: ``.csv`` is replaced by ``.pt``, and for
    ``"county_state"`` input ``state`` becomes ``county_state`` in the file name.
    """
    print(f"Processing {pth} at resolution: {input_resolution}")
    time_features = pd.read_csv(pth)
    if input_resolution == "county_state":
        time_features = _expand_state_time_features(time_features, df)
    time_features = _filter_time_feature_regions(time_features, df, pth)
    if merge_nyc:
        time_features = merge_nyc_boroughs(
            time_features, len(time_features["type"].unique())
        )
    time_features = _align_time_features(time_features, df.columns, shift)
    feature_tensors = {
        region: th.from_numpy(time_features[region].values)
        for region in time_features.columns.get_level_values(0).unique()
    }
    th.save(feature_tensors, _time_features_output_path(pth, input_resolution))