def get_embeddings()

in src/scripts/gen_embeddings.py [0:0]


def _sample_to_text(sample) -> str:
    """Convert one raw dataset sample into a single-line string for an embedding request.

    Strings are used as-is and lists are joined with ``", "``; any other value
    is replaced by the literal ``"none"``. Newlines are replaced with spaces.
    """
    if isinstance(sample, str):
        return sample.replace("\n", " ")
    if isinstance(sample, list):
        return ", ".join(sample).replace("\n", " ")
    return "none"


def get_embeddings(dataset_name: Dataset,
                   engine: ModelType,
                   entity: Optional[Entity],
                   indices_path: Optional[str]
                   ):
    """Request embeddings for every record of a dataset and write them to disk.

    Each record of the prepared dataset is a sequence of samples; all samples of
    one record are sent to ``engine`` in a single ``try_request`` call. Records
    whose request fails (``try_request`` returns ``None``) are reported on stdout
    and excluded from the saved embedding array.

    Args:
        dataset_name: Dataset to embed; also used as a path component of the
            output folder ``./data/<dataset_name>/[<entity>/]``.
        engine: Embedding model to query; the embedding width is obtained from
            ``resolve_model_dim_size(engine)``.
        entity: Optional entity; when truthy, results go to a per-entity sub-folder.
        indices_path: Forwarded unchanged to ``prepare_dataset``.

    Returns:
        None. When the dataset has no records the function returns early without
        writing anything; otherwise results are passed to ``write_to_file`` as
        ``(results_folder, embeddings, successful_records, failed_requests)``,
        where ``embeddings`` has shape ``(n_successful, num_samples, model_dims)``.
    """
    results_folder = f'./data/{dataset_name}/{entity}/' if entity else f'./data/{dataset_name}/'
    os.makedirs(results_folder, exist_ok=True)
    dataset = prepare_dataset(dataset_name, entity, indices_path)
    records_names = list(dataset.keys())
    num_records = len(records_names)
    if num_records == 0:
        return
    # Width of the sample axis is taken from the first record.
    # NOTE(review): assumes every record holds the same number of samples — confirm
    # prepare_dataset guarantees this. (Previously a single-record dataset was forced
    # to width 1, which broke the row assignment below for multi-sample records.)
    num_samples = len(dataset[records_names[0]])
    model_dims = resolve_model_dim_size(engine)
    embeddings = np.zeros((num_records, num_samples, model_dims), dtype=np.float32)
    # Row mask of successful requests; used both to filter the array and to split
    # record names into successful / failed lists in original record order.
    succeeded = np.zeros(num_records, dtype=bool)
    with tqdm(total=num_records) as pbar:
        for i, key in enumerate(records_names):
            samples = [_sample_to_text(sample) for sample in dataset[key]]
            response = try_request(engine, samples)
            if response is None:
                # dataset[key] is a sequence of samples (not a string), so report
                # sizes from the normalized sample strings.
                num_chars = sum(len(text) for text in samples)
                num_words = sum(len(text.split()) for text in samples)
                print(
                    f'Failed request {i} ({key}), {len(samples)} samples, '
                    f'{num_chars} chars, {num_words} words')
            else:
                embeddings[i] = [item['embedding'] for item in response]
                succeeded[i] = True
            pbar.update()
    del dataset
    successful_records = [name for name, ok in zip(records_names, succeeded) if ok]
    failed_requests = [name for name, ok in zip(records_names, succeeded) if not ok]
    embeddings = embeddings[succeeded]

    write_to_file(results_folder, embeddings, successful_records, failed_requests)