in src/scripts/gen_embeddings.py [0:0]
def _normalize_sample(sample) -> str:
    """Turn one raw sample into a single-line string for the embedding request."""
    if isinstance(sample, str):
        return sample.replace("\n", " ")
    if isinstance(sample, list):
        return ", ".join(sample).replace("\n", " ")
    # Anything that is neither a string nor a list is sent as a placeholder.
    return "none"


def get_embeddings(dataset_name: Dataset,
                   engine: ModelType,
                   entity: Optional[Entity],
                   indices_path: Optional[str]
                   ):
    """Request embeddings for every record of a dataset and write them out.

    The dataset returned by ``prepare_dataset`` is treated as a mapping from
    record name to a collection of samples. Each record's samples are
    normalized to single-line strings and sent to ``try_request`` in one call.
    Records whose request returns ``None`` are reported on stdout, excluded
    from the embedding array and passed to ``write_to_file`` as failures.

    Args:
        dataset_name: Dataset identifier; also used to build the output folder.
        engine: Model identifier passed to ``resolve_model_dim_size`` and
            ``try_request``.
        entity: Optional entity; when truthy it adds a sub-folder to the
            output path and is forwarded to ``prepare_dataset``.
        indices_path: Optional path forwarded to ``prepare_dataset``.

    Returns:
        None. Results are handed to ``write_to_file``; nothing is written when
        the dataset is empty.
    """
    results_folder = f'./data/{dataset_name}/{entity}/' if entity else f'./data/{dataset_name}/'
    os.makedirs(results_folder, exist_ok=True)

    dataset = prepare_dataset(dataset_name, entity, indices_path)
    records_names = list(dataset.keys())
    num_records = len(records_names)
    if num_records == 0:
        return

    # NOTE(review): with exactly one record the sample axis is forced to 1,
    # regardless of how many samples that record holds -- confirm this is the
    # intended shape for single-record datasets.
    num_samples = len(dataset[records_names[0]]) if num_records > 1 else 1
    model_dims = resolve_model_dim_size(engine)
    embeddings = np.zeros((num_records, num_samples, model_dims), dtype=np.float32)

    successful_records = []
    with tqdm(total=num_records) as pbar:
        for i, key in enumerate(records_names):
            samples = [_normalize_sample(sample) for sample in dataset[key]]
            response = try_request(engine, samples)
            if response is None:
                # Size the failed payload from the normalized strings: the raw
                # record is a collection of samples, so it has no .split().
                num_chars = sum(len(sample) for sample in samples)
                num_words = sum(len(sample.split()) for sample in samples)
                print(
                    f'Failed request {i} ({key}), '
                    f'{num_chars} chars, {num_words} words')
            else:
                # Separate index name so the record index `i` is not shadowed.
                embeddings[i] = [response[j]['embedding'] for j in range(len(response))]
                successful_records.append(key)
            pbar.update()
    del dataset  # release the raw texts before filtering/writing the array

    # One set gives O(1) membership for both the failure list and the row
    # mask, and keeps failed_requests in dataset order.
    succeeded = set(successful_records)
    failed_requests = [name for name in records_names if name not in succeeded]
    drop_rows_mask = np.array([name in succeeded for name in records_names])
    embeddings = embeddings[drop_rows_mask]
    write_to_file(results_folder, embeddings, successful_records, failed_requests)