def download_tfrecord()

in easycv/datasets/utils/tfrecord_util.py [0:0]


def _download_tfrecord_slice(src_paths, target_path, slice_count, slice_id,
                             force):
    """Copy the share of ``src_paths`` owned by worker ``slice_id``.

    The i-th path in ``src_paths`` belongs to worker ``i % slice_count``.
    The position comes from ``enumerate`` so that skipping an already-present
    file never shifts the assignment of the files after it (every worker must
    compute the same assignment independently, or files get duplicated or
    never downloaded).

    Args:
        src_paths: list of str, source file paths.
        target_path: str, directory the files are copied into.
        slice_count: int, total number of download workers.
        slice_id: int, index of the current worker.
        force: If False, keep a target file that already exists.
            If True, copy it again and replace the existing file.
    """
    for position, src_path in enumerate(src_paths):
        if position % slice_count != slice_id:
            continue
        target_file = os.path.join(target_path, os.path.basename(src_path))
        if not force and io.exists(target_file):
            logging.info('%s already exists, skip download!', target_file)
            continue
        io.copy(src_path, target_file)
        logging.info('Finished download file: %s', src_path)


def download_tfrecord(file_list_or_path,
                      target_path,
                      slice_count=1,
                      slice_id=0,
                      force=False):
    """Download tfrecord data files and their index files (e.g. from oss).

    Downloading is sliced across processes (e.g. one per gpu): each process
    passes its own ``slice_id`` and copies only the files assigned to it, so
    ``slice_count`` should equal the number of processes. All processes then
    wait on ``dist_utils.barrier()`` before the local paths are returned.

    Supports tfrecord of ImageNet style::

        tfrecord_dir
            |---train1
            |---train1.idx
            |---train2
            |---train2.idx
            |---...

    Args:
        file_list_or_path: Either a list of absolute source paths, or the path
            of a text file listing one source path per line (surrounding
            whitespace is stripped and blank lines are ignored). Entries
            ending in ``.idx`` are index files, entries ending in ``.info``
            are ignored, and every other entry is a data file. If no ``.idx``
            entries are present, each data file's index is assumed to be at
            ``<data_path>.idx``.
        target_path: str, download directory; created if it does not exist.
        slice_count: int, number of download workers.
        slice_id: int, ID of the current download worker, expected to be in
            ``[0, slice_count)``.
        force: If False, skip download if the file already exists in the
            target path. If True, recopy and replace the original file.

    Returns:
        path: list of str, local paths of all data files (every worker's
            files, not only this worker's slice), in input order.
        index_path: list of str, local paths of all index files, in input
            order.
    """
    # dist_zero_exec presumably lets rank 0 run this first (verify in
    # dist_utils); exist_ok=True keeps makedirs safe if the directory already
    # exists or is created concurrently (e.g. by another node on a shared FS).
    with dist_utils.dist_zero_exec():
        os.makedirs(target_path, exist_ok=True)

    logging.info(f'num gpu(slice_count): {slice_count}')

    if isinstance(file_list_or_path, list):
        all_file_list = file_list_or_path
    else:
        with io.open(file_list_or_path, 'r') as f:
            lines = f.readlines()
        # Drop blank lines so they are not treated as data files with an
        # empty path.
        all_file_list = [line.strip() for line in lines if line.strip()]

    all_data_list = [
        path for path in all_file_list
        if not path.endswith(('.idx', '.info'))
    ]
    all_index_list = [path for path in all_file_list if path.endswith('.idx')]
    if not all_index_list:
        all_index_list = [path + '.idx' for path in all_data_list]

    _download_tfrecord_slice(all_data_list, target_path, slice_count,
                             slice_id, force)
    _download_tfrecord_slice(all_index_list, target_path, slice_count,
                             slice_id, force)

    logging.info('rank %s finish downloads!', slice_id)

    # Wait for every worker to finish its slice: the returned lists include
    # files downloaded by other workers.
    dist_utils.barrier()

    local_data_list = [
        os.path.join(target_path, os.path.basename(path))
        for path in all_data_list
    ]
    local_index_list = [
        os.path.join(target_path, os.path.basename(path))
        for path in all_index_list
    ]
    return local_data_list, local_index_list