data/ami/utils.py (159 lines of code) (raw):
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT-style license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import os
import random
from collections import namedtuple
import sox
# Lightweight immutable records shared by the split-building helpers below.
# Speaker: a speaker id together with its gender tag.
Speaker = namedtuple("Speaker", "id gender")
# FileRecord: one utterance handle, its length, and the speaker it belongs to.
FileRecord = namedtuple("FileRecord", "fid length speaker")
def split_audio(line):
    """
    Cut one utterance out of a headset recording and store it as FLAC.

    ``line`` is a space-separated record holding, in order: audio root
    directory, meeting id, headset name, speaker, start, end and the
    transcript (the transcript may itself contain spaces).

    The segment is written to ``<root>/segments/<meeting>/<key>.flac`` where
    ``key`` joins meeting id, headset, speaker, start and end with "_".

    Returns ``[meeting id, key, flac path, duration * 1000 as a string,
    lower-cased transcript]`` when the duration sox reports for the written
    file is within 0.5 of ``end - start``; returns None otherwise.
    """
    fields = line.strip().split(" ", 6)
    apath, meetid, hset, spk, start, end, transcript = fields
    key = "_".join((meetid, hset, spk, start, end))

    segment_dir = os.path.join(apath, "segments", meetid)
    os.makedirs(segment_dir, exist_ok=True)

    # The last character of the headset name selects the source wav file.
    infile = os.path.join(apath, meetid, f"{meetid}.Headset-{hset[-1]}.wav")
    assert os.path.exists(infile), f"{infile} doesn't exist"
    new_path = os.path.join(segment_dir, key + ".flac")

    transformer = sox.Transformer()
    transformer.set_output_format(
        file_type="flac", encoding="signed-integer", bits=16, rate=16000
    )
    start_time = float(start)
    end_time = float(end)
    transformer.trim(start_time, end_time)
    transformer.build(infile, new_path)

    # Sanity-check the written segment against the requested span.
    duration = sox.file_info.duration(new_path)
    if duration is None or not abs(duration - end_time + start_time) < 0.5:
        return None
    return [
        meetid,
        key,
        new_path,
        str(round(duration * 1000, 2)),
        transcript.lower(),
    ]
def do_split(all_records, spkrs, total_seconds, handles_chosen=None):
    """
    Greedily pick records, speaker by speaker, until the time budget is met.

    Speakers are visited round-robin. Each visit takes that speaker's next
    record whose ``fid`` is not in ``handles_chosen``; a speaker with no
    usable records left is dropped from the rotation. Selection stops as soon
    as the accumulated time comes within 10 seconds of ``total_seconds`` (or
    passes it), or when every speaker is exhausted.

    Args:
        all_records: mapping from speaker to an ordered list of records; each
            record must expose ``fid`` and ``length`` attributes.
        spkrs: iterable of speakers (keys of ``all_records``) to draw from.
            It is not modified.
        total_seconds: target amount of audio, in the same unit as
            ``record.length``.
        handles_chosen: optional container of already-used ``fid`` values
            that must be skipped.

    Returns:
        ``(records, time_taken)``: the selected records in pick order and
        the sum of their lengths.
    """
    time_taken = 0.0
    records_filtered = []
    # Work on a private copy: exhausted speakers are removed from the
    # rotation and the caller's collection must stay intact.
    speakers = list(spkrs)
    next_record_idx = {spk: 0 for spk in speakers}
    idx = 0

    while speakers:
        speaker = speakers[idx % len(speakers)]
        idx += 1

        # Advance to this speaker's next record that was not chosen before.
        speaker_records = all_records[speaker]
        cur_record = None
        while next_record_idx[speaker] < len(speaker_records):
            candidate = speaker_records[next_record_idx[speaker]]
            next_record_idx[speaker] += 1
            if handles_chosen is None or candidate.fid not in handles_chosen:
                cur_record = candidate
                break

        if cur_record is None:
            # Nothing left for this speaker; take them out of the rotation.
            speakers.remove(speaker)
            continue

        records_filtered.append(cur_record)
        time_taken += cur_record.length

        # Stop once we are within 10 s of the target. Testing only the lower
        # edge (instead of an absolute difference) also stops when a single
        # long record jumps across the whole +/-10 s window; otherwise the
        # loop would keep consuming records until every speaker ran dry.
        if time_taken > total_seconds - 10:
            break

    return records_filtered, time_taken
def get_speakers(train_file):
    """
    Collect the distinct speakers referenced by a list file.

    The speaker id is the third "_"-separated field of the first
    whitespace-separated column of each line; its first character is taken
    as the gender. Lines whose gender is neither "M" nor "F" are ignored.

    Returns a list of ``Speaker`` tuples in order of first appearance.
    """
    seen_ids = set()
    all_speakers = []
    with open(train_file) as f:
        for line in f:
            handle = line.split()[0]
            speaker_id = handle.split("_")[2]
            gender = speaker_id[0]
            if gender not in ("M", "F") or speaker_id in seen_ids:
                continue
            seen_ids.add(speaker_id)
            all_speakers.append(Speaker(id=speaker_id, gender=gender))
    return all_speakers
def get_fid2length(train_file):
    """
    Read ``(handle, length)`` pairs from a list file.

    For every line, the first whitespace-separated column is the handle and
    the third column, parsed as a float and divided by 1000, is the length.

    Returns the pairs as a list, in file order.
    """
    pairs = []
    with open(train_file) as f:
        for line in f:
            columns = line.split()
            pairs.append((columns[0], float(columns[2]) / 1000))
    return pairs
def full_records(speakers, fid2length, subset_name=None):
    """
    Attach speaker information to every ``(fid, length)`` pair.

    Args:
        speakers: iterable of speaker objects exposing an ``id`` attribute.
        fid2length: iterable of ``(fid, length)`` pairs; the speaker id is
            the third "_"-separated field of ``fid``.
        subset_name: when given, every matched speaker's ``subset``
            attribute must equal it.

    Returns:
        A list of ``FileRecord`` tuples, one per input pair, in input order.

    Raises:
        AssertionError: if a handle refers to a speaker that is not in
            ``speakers``, or if the subset check fails.
    """
    # Index speakers by id. This must be a mapping: each record below looks
    # its speaker up by the id embedded in the handle.
    speakers_by_id = {speaker.id: speaker for speaker in speakers}

    all_records = []
    for fid, length in fid2length:
        speaker_id = fid.split("_")[2]
        assert speaker_id in speakers_by_id, f"Unknown speaker! {speaker_id}"
        speaker = speakers_by_id[speaker_id]

        if subset_name is not None:
            # NOTE(review): this needs speaker objects that carry a `subset`
            # attribute; the Speaker tuple defined in this file only has
            # `id` and `gender` -- confirm before passing subset_name.
            assert subset_name == speaker.subset

        all_records.append(FileRecord(speaker=speaker, length=length, fid=fid))
    return all_records
def get_speaker2time(records, lambda_key, lambda_value):
    """
    Sum a per-record value grouped by a per-record key.

    ``lambda_key(record)`` selects the group and ``lambda_value(record)`` the
    amount added to that group's running total.

    Returns a ``defaultdict(int)`` mapping each key to its total.
    """
    from collections import defaultdict

    totals = defaultdict(int)
    for rec in records:
        group = lambda_key(rec)
        amount = lambda_value(rec)
        totals[group] += amount
    return totals
def create_limited_sup(list_dir):
    """
    Build limited-supervision training lists from ``<list_dir>/train.lst``.

    For each gender ("M", then "F"), up to 15 speakers with at least 15
    minutes of audio are drawn at random (the RNG is seeded with 0) and
    their records are shuffled. From these the function builds:

      * six 10-minute splits (a 5-minute budget per gender, taken from 3
        randomly sampled speakers of that gender), written to
        ``train_10min_<i>.lst`` for i in 0..5;
      * one 9-hour split (a 4.5-hour budget per gender), written to
        ``train_9hr.lst``.

    Records placed in a 10-minute split are skipped by every later split.
    Each output file contains the matching ``train.lst`` lines, one per line.

    Raises:
        AssertionError: if ``train.lst`` does not exist in ``list_dir``.
    """
    random.seed(0)
    train_file = os.path.join(list_dir, "train.lst")
    assert os.path.exists(train_file)

    speakers = get_speakers(train_file)
    print("Found speakers", len(speakers))

    write_records = {}
    # fids already assigned to a 10-minute split (used as a set).
    chosen_records = {}

    fid2length = get_fid2length(train_file)
    all_records = full_records(speakers, fid2length)

    for gender in ["M", "F"]:
        print(f"Selecting from gender {gender}")
        records = [rec for rec in all_records if rec.speaker.gender == gender]
        speaker2time = get_speaker2time(
            records, lambda_key=lambda r: r.speaker.id, lambda_value=lambda r: r.length
        )

        # select 15 random speakers among those with enough audio
        min_minutes_per_speaker = 15
        speakers_10hr = {
            r.speaker.id
            for r in records
            if speaker2time[r.speaker.id] >= min_minutes_per_speaker * 60
        }
        # Sort before shuffling so the result depends only on the seed, not
        # on set iteration order.
        speakers_10hr = sorted(speakers_10hr)
        random.shuffle(speakers_10hr)
        speakers_10hr = speakers_10hr[:15]

        print(f"Selected speakers from gender {gender} ", speakers_10hr)

        cur_records = {}
        for speaker in speakers_10hr:
            cur_records[speaker] = [r for r in records if r.speaker.id == speaker]
            random.shuffle(cur_records[speaker])

        # 1 hr as 6 x 10min splits
        key = "10min_" + gender
        write_records[key] = {}
        for i in range(6):
            # Sample from the list itself: it already holds unique ids, and
            # random.sample() rejects sets on Python 3.11+ (and gave
            # hash-order-dependent results before that).
            speakers_10min = random.sample(speakers_10hr, 3)
            write_records[key][i], _ = do_split(
                cur_records, speakers_10min, 10 * 60 / 2, chosen_records
            )
            for kk in write_records[key][i]:
                chosen_records[kk.fid] = 1

        # 9 hr
        key = "9hr_" + gender
        write_records[key], _ = do_split(
            cur_records, speakers_10hr, (9 * 60 * 60) / 2, chosen_records
        )

    # Map each handle to its (stripped) original line for re-emission.
    train_lines = {}
    with open(train_file) as f:
        for line in f:
            train_lines[line.split()[0]] = line.strip()

    print("Writing 6 x 10min list files...")
    for i in range(6):
        with open(os.path.join(list_dir, f"train_10min_{i}.lst"), "w") as fo:
            for record in write_records["10min_M"][i] + write_records["10min_F"][i]:
                # The stored lines are stripped, so terminate each one
                # explicitly to keep one record per line.
                fo.write(train_lines[record.fid] + "\n")

    print("Writing 9hr list file...")
    with open(os.path.join(list_dir, "train_9hr.lst"), "w") as fo:
        for record in write_records["9hr_M"] + write_records["9hr_F"]:
            fo.write(train_lines[record.fid] + "\n")