in utils/masks.py [0:0]
def _fraction_to_count(fraction, total):
    """Return ``int(fraction * total)``, tolerating floating-point error.

    Plain truncation of a float product can lose a whole element when the
    exact product is an integer: ``0.29 * 100`` evaluates to
    ``28.999999999999996``, so ``int(0.29 * 100) == 28``. Products within
    floating-point tolerance of an integer snap to that integer; every other
    product is truncated exactly as ``int(...)`` would.
    """
    import math

    raw = fraction * total
    nearest = int(round(raw))
    if math.isclose(raw, nearest, rel_tol=1e-9, abs_tol=1e-9):
        return nearest
    return int(raw)


def _consecutive_masks(n_data, permutation, start, sizes):
    """Build one mask per named run of consecutive ``permutation`` entries.

    Starting at index ``start`` of ``permutation``, each ``(name, size)`` in
    ``sizes`` (iterated in dict order) takes the next ``size`` permuted
    indices and turns them into a mask with ``to_mask(n_data, ...)``.

    Returns:
        ``(masks, end)``: ``masks`` maps each name to its mask, and ``end`` is
        the offset just past the last run, so callers can check the runs
        exactly tile the intended slice.
    """
    masks = {}
    offset = start
    for name, size in sizes.items():
        masks[name] = to_mask(n_data, permutation[offset:offset + size])
        offset += size
    return masks, offset


def generate_masks(n_data, split_config):
    """Randomly partition ``n_data`` indices into public/private masks.

    A single random permutation of ``range(n_data)`` is drawn with
    ``np.random.permutation``. Its first ``n_public`` entries form the public
    part, and the remaining entries form the private part.

    Args:
        n_data: Number of data points to split.
        split_config: Dict with keys ``"public"`` and ``"private"``.
            ``"public"`` is either a single fraction of ``n_data``, or a dict
            of name -> fraction of ``n_data``; in the dict case the public
            size is computed from the sum of its values.
            ``"private"`` must be a dict. It is passed, together with the
            private size, to ``multiply_round``, whose returned name -> size
            dict defines the named private subsets.

    Returns:
        ``(known_masks, hidden_masks)``:

        * ``known_masks`` has keys ``"public"`` and ``"private"``, each the
          ``to_mask`` of the corresponding part of the permutation.
        * ``hidden_masks["private"]`` maps each private name to its mask.
        * ``hidden_masks["public"]`` maps each public name to its mask when
          ``split_config["public"]`` is a dict. Otherwise it is the very same
          object as ``known_masks["public"]``.

    Raises:
        AssertionError: if ``split_config`` is malformed, if the public
            fraction yields a size outside ``[0, n_data]``, or if the named
            sizes returned by ``multiply_round`` do not exactly cover their
            part.
    """
    assert isinstance(split_config, dict)
    assert "public" in split_config and "private" in split_config
    assert isinstance(split_config["private"], dict)

    public_cfg = split_config["public"]
    public_is_dict = isinstance(public_cfg, dict)

    # Uses NumPy's global random state; seed np.random for reproducible splits.
    permutation = np.random.permutation(n_data)

    public_fraction = sum(public_cfg.values()) if public_is_dict else public_cfg
    n_public = _fraction_to_count(public_fraction, n_data)
    assert 0 <= n_public <= n_data, (
        "public fraction %r gives %d of %d points"
        % (public_fraction, n_public, n_data)
    )
    n_private = n_data - n_public

    known_masks = {
        "public": to_mask(n_data, permutation[:n_public]),
        "private": to_mask(n_data, permutation[n_public:]),
    }

    hidden_masks = {}
    private_sizes = multiply_round(n_private, split_config["private"])
    print(' Private', private_sizes)
    hidden_masks["private"], end = _consecutive_masks(
        n_data, permutation, n_public, private_sizes
    )
    # The named private subsets must exactly tile permutation[n_public:].
    assert end == n_data

    if public_is_dict:
        # NOTE(review): these values are fractions of n_data, but here they
        # are split against n_public. The sizes only add up to n_public if
        # multiply_round normalizes them (or they happen to sum to 1) —
        # confirm against multiply_round. The assert below catches a mismatch.
        public_sizes = multiply_round(n_public, public_cfg)
        print('Public', public_sizes)
        hidden_masks["public"], end = _consecutive_masks(
            n_data, permutation, 0, public_sizes
        )
        # The named public subsets must exactly tile permutation[:n_public].
        assert end == n_public
    else:
        # No named public subsets: the hidden view is the known public mask.
        hidden_masks["public"] = known_masks["public"]

    return known_masks, hidden_masks