def _get_easyrec()

in tzrec/tools/convert_easyrec_config_to_tzrec_config.py [0:0]


def _get_easyrec(pkg_path=None, timeout=60):
    """Fetch an EasyRec package, extract it, and import its protobuf modules.

    The package is extracted into a fresh temporary directory, which is put at
    the front of ``sys.path``. The extracted ``easy_rec/__init__.py`` is
    truncated to an empty file before importing. The imported ``pipeline_pb2``,
    ``feature_config_pb2`` and ``loss_pb2`` modules are then published as the
    module globals ``easyrec_pipeline_pb2``, ``easyrec_feature_config_pb2``
    and ``easyrec_loss_pb2``.

    The temporary directory is intentionally not removed: it stays on
    ``sys.path`` for the imported ``easy_rec`` modules.

    Args:
        pkg_path (str, optional): URL (``http://`` / ``https://``) or local
            path of the EasyRec package. Paths containing ``.tar`` or ending
            in ``.tgz`` are read as tar archives; anything else is read as a
            zip / wheel. Defaults to the release wheel for ``EASYREC_VERSION``.
        timeout (float): timeout in seconds passed to ``requests.get`` when
            ``pkg_path`` is a URL.

    Raises:
        RuntimeError: if the download fails, or the archive cannot be read.
    """
    local_cache_dir = tempfile.mkdtemp(prefix="tzrec_tmp")
    if pkg_path is None:
        pkg_path = (
            f"https://easyrec.oss-cn-beijing.aliyuncs.com/release/whls/"
            f"easy_rec-{EASYREC_VERSION}-py2.py3-none-any.whl"
        )

    if pkg_path.startswith(("http://", "https://")):
        logger.info(f"downloading easyrec from {pkg_path}")
        try:
            r = requests.get(pkg_path, timeout=timeout)
            # Without this check an HTTP error page would be handed to the
            # archive reader and reported as a misleading "invalid ... whl".
            r.raise_for_status()
        except requests.RequestException as e:
            raise RuntimeError(f"failed to download {pkg_path}.") from e
        content = r.content
    else:
        with open(pkg_path, "rb") as f:
            content = f.read()

    if ".tar" in pkg_path or pkg_path.endswith(".tgz"):
        try:
            with tarfile.open(fileobj=io.BytesIO(content)) as tar:
                # The archive may come from the network: use the "data"
                # extraction filter (rejects absolute paths, "..", etc.)
                # on interpreters that support it.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(path=local_cache_dir, filter="data")
                else:
                    tar.extractall(path=local_cache_dir)
        except Exception as e:
            raise RuntimeError(f"invalid {pkg_path} tar.") from e
    else:
        try:
            with zipfile.ZipFile(io.BytesIO(content)) as zf:
                zf.extractall(local_cache_dir)
        except zipfile.BadZipFile as e:
            raise RuntimeError(f"invalid {pkg_path} whl.") from e

    # Blank the package __init__ so importing the proto submodules does not
    # run the package's own initialisation code (presumably to avoid pulling
    # in EasyRec's runtime dependencies -- confirm).
    init_path = os.path.join(local_cache_dir, "easy_rec", "__init__.py")
    with open(init_path, "w") as f:
        f.write("")
    # Prepend so the freshly extracted package takes precedence over any
    # easy_rec that is already importable from the environment.
    sys.path.insert(0, local_cache_dir)
    # Give the default symbol database a fresh descriptor pool before the
    # EasyRec *_pb2 modules are imported. NOTE(review): presumably this keeps
    # their descriptors from clashing with ones already registered -- confirm.
    _sym = symbol_database.Default()
    _sym.pool = descriptor_pool.DescriptorPool()
    from easy_rec.python.protos import feature_config_pb2 as _feature_config_pb2
    from easy_rec.python.protos import loss_pb2 as _loss_pb2
    from easy_rec.python.protos import pipeline_pb2 as _pipeline_pb2

    globals()["easyrec_pipeline_pb2"] = _pipeline_pb2
    globals()["easyrec_feature_config_pb2"] = _feature_config_pb2
    globals()["easyrec_loss_pb2"] = _loss_pb2