in data/ami/utils.py [0:0]
def _write_list_file(path, records, train_lines):
    """Write the ``train.lst`` line of every record in *records* to *path*.

    Lines in *train_lines* are stored stripped, so a newline is appended to
    each one to keep the output a one-entry-per-line list file.
    """
    with open(path, "w") as fo:
        for record in records:
            fo.write(train_lines[record.fid] + "\n")


def create_limited_sup(list_dir):
    """Create limited-supervision subsets of ``<list_dir>/train.lst``.

    For each gender ("M" and "F") this selects up to 15 speakers whose total
    time in ``train.lst`` is at least ``15 * 60`` (presumably seconds, i.e.
    15 minutes -- NOTE(review): confirm the unit of ``record.length``), and
    then builds with ``do_split``:

    * six "10min" splits, each drawn from 3 randomly chosen speakers among
      the selected ones with a per-gender target of ``10 * 60 / 2``, and
    * one "9hr" split over all selected speakers with a per-gender target
      of ``9 * 60 * 60 / 2``.

    Records placed in a 10min split are recorded in ``chosen_records`` and
    passed to every later ``do_split`` call (presumably so they are not drawn
    again -- verify against ``do_split``).

    The resulting lists are written to ``train_10min_{0..5}.lst`` and
    ``train_9hr.lst`` inside *list_dir*. Each M and F split of the same kind
    is merged into one file. Every line is copied (stripped) from
    ``train.lst``. The global ``random`` module is seeded with 0 so the
    selection is reproducible.

    Args:
        list_dir: Directory that contains ``train.lst``; output files are
            written here as well.

    Raises:
        FileNotFoundError: If ``train.lst`` does not exist in *list_dir*.
    """
    random.seed(0)
    train_file = os.path.join(list_dir, "train.lst")
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if not os.path.exists(train_file):
        raise FileNotFoundError(train_file)
    speakers = get_speakers(train_file)
    print("Found speakers", len(speakers))

    num_speakers = 15  # speakers kept per gender
    min_minutes_per_speaker = 15  # minimum total time for a speaker to qualify
    num_10min_splits = 6  # 1 hr as 6 x 10min splits
    speakers_per_10min_split = 3  # speakers sampled per gender for each 10min split

    write_records = {}
    # fid -> 1 for every record already placed in a 10min split.
    chosen_records = {}
    fid2length = get_fid2length(train_file)
    all_records = full_records(speakers, fid2length)
    for gender in ["M", "F"]:
        print(f"Selecting from gender {gender}")
        records = [rec for rec in all_records if rec.speaker.gender == gender]
        speaker2time = get_speaker2time(
            records, lambda_key=lambda r: r.speaker.id, lambda_value=lambda r: r.length
        )
        # Select up to ``num_speakers`` random speakers with enough total time.
        eligible = {
            r.speaker.id
            for r in records
            if speaker2time[r.speaker.id] >= min_minutes_per_speaker * 60
        }
        # Sort before shuffling: set iteration order is not reproducible
        # across interpreter runs (hash randomization), sorted order is.
        speakers_10hr = sorted(eligible)
        random.shuffle(speakers_10hr)
        speakers_10hr = speakers_10hr[:num_speakers]
        print(f"Selected speakers from gender {gender} ", speakers_10hr)

        cur_records = {}
        for speaker in speakers_10hr:
            cur_records[speaker] = [r for r in records if r.speaker.id == speaker]
            random.shuffle(cur_records[speaker])

        # 1 hr as 6 x 10min splits
        key = "10min_" + gender
        write_records[key] = {}
        for i in range(num_10min_splits):
            # Sample from the (deterministically ordered) list, not a set:
            # random.sample rejects sets on Python >= 3.11, and set order
            # would defeat the fixed seed.
            speakers_10min = random.sample(speakers_10hr, speakers_per_10min_split)
            write_records[key][i], _ = do_split(
                cur_records, speakers_10min, 10 * 60 / 2, chosen_records
            )
            for kk in write_records[key][i]:
                chosen_records[kk.fid] = 1

        # 9 hr
        key = "9hr_" + gender
        write_records[key], _ = do_split(
            cur_records, speakers_10hr, (9 * 60 * 60) / 2, chosen_records
        )

    # Map fid (first whitespace-separated field) -> original stripped line.
    train_lines = {}
    with open(train_file) as f:
        for line in f:
            fields = line.split()
            if not fields:  # tolerate blank lines
                continue
            train_lines[fields[0]] = line.strip()

    print("Writing 6 x 10min list files...")
    for i in range(num_10min_splits):
        _write_list_file(
            os.path.join(list_dir, f"train_10min_{i}.lst"),
            write_records["10min_M"][i] + write_records["10min_F"][i],
            train_lines,
        )
    print("Writing 9hr list file...")
    _write_list_file(
        os.path.join(list_dir, "train_9hr.lst"),
        write_records["9hr_M"] + write_records["9hr_F"],
        train_lines,
    )