in tzrec/tools/convert_easyrec_config_to_tzrec_config.py [0:0]
def _get_easyrec(pkg_path=None):
    """Get easyrec package, extract it and import its proto modules.

    The package is read from a URL when ``pkg_path`` starts with ``"http"``,
    otherwise from a local file. A ``pkg_path`` containing ``".tar"`` is
    treated as a tar archive; anything else is treated as a wheel (zip
    archive). The archive is extracted into a fresh temporary directory, which
    is appended to ``sys.path``. EasyRec's ``pipeline_pb2``,
    ``feature_config_pb2`` and ``loss_pb2`` modules are then imported and
    bound to the module-level globals ``easyrec_pipeline_pb2``,
    ``easyrec_feature_config_pb2`` and ``easyrec_loss_pb2``.

    Args:
        pkg_path (str, optional): URL or local path of the EasyRec package.
            Defaults to the release wheel for ``EASYREC_VERSION``.

    Raises:
        requests.HTTPError: if the download returns an HTTP error status.
        RuntimeError: if the archive cannot be read or extracted. This
            includes tar members that would land outside the extraction
            directory.
    """
    # Deliberately not cleaned up: the proto modules are imported from this
    # directory, so it has to outlive this call.
    local_package_dir = tempfile.mkdtemp(prefix="tzrec_tmp")
    if pkg_path is None:
        pkg_path = (
            "https://easyrec.oss-cn-beijing.aliyuncs.com/release/whls/"
            f"easy_rec-{EASYREC_VERSION}-py2.py3-none-any.whl"
        )

    if pkg_path.startswith("http"):
        logger.info(f"downloading easyrec from {pkg_path}")
        # Without a timeout a stalled connection blocks forever. The value
        # bounds each socket operation, not the total download time.
        r = requests.get(pkg_path, timeout=60)
        # Fail on 4xx/5xx instead of trying to unpack an HTML error page,
        # which would otherwise surface as a misleading "invalid ... whl".
        r.raise_for_status()
        content = r.content
    else:
        with open(pkg_path, "rb") as f:
            content = f.read()

    if ".tar" in pkg_path:
        try:
            with tarfile.open(fileobj=io.BytesIO(content)) as tar:
                if hasattr(tarfile, "data_filter"):
                    # The "data" filter rejects absolute paths, ".." escapes
                    # and links pointing outside the destination.
                    tar.extractall(path=local_package_dir, filter="data")
                else:
                    # Interpreters without extraction filters: at least
                    # refuse members whose path escapes the target directory.
                    root = os.path.realpath(local_package_dir)
                    for member in tar.getmembers():
                        target = os.path.realpath(os.path.join(root, member.name))
                        if os.path.commonpath([root, target]) != root:
                            raise RuntimeError(
                                f"unsafe member {member.name} in {pkg_path}."
                            )
                    tar.extractall(path=local_package_dir)
        except Exception as e:
            raise RuntimeError(f"invalid {pkg_path} tar.") from e
    else:
        try:
            with zipfile.ZipFile(io.BytesIO(content)) as zf:
                zf.extractall(local_package_dir)
        except zipfile.BadZipFile as e:
            raise RuntimeError(f"invalid {pkg_path} whl.") from e

    # Replace (or create) easy_rec/__init__.py with an empty file so that
    # ``easy_rec`` is importable as a package. NOTE(review): presumably this
    # is also meant to skip EasyRec's own package initialisation when the
    # proto submodules are imported below — confirm.
    with open(os.path.join(local_package_dir, "easy_rec/__init__.py"), "w") as f:
        f.write("")
    sys.path.append(local_package_dir)

    # Swap the default symbol database's pool for a fresh one.
    # NOTE(review): presumably this keeps EasyRec's proto definitions from
    # clashing with descriptors already registered by tzrec — confirm it is
    # effective for the installed protobuf version.
    _sym = symbol_database.Default()
    _sym.pool = descriptor_pool.DescriptorPool()
    from easy_rec.python.protos import feature_config_pb2 as _feature_config_pb2
    from easy_rec.python.protos import loss_pb2 as _loss_pb2
    from easy_rec.python.protos import pipeline_pb2 as _pipeline_pb2

    # Expose the lazily imported modules as module-level globals.
    globals()["easyrec_pipeline_pb2"] = _pipeline_pb2
    globals()["easyrec_feature_config_pb2"] = _feature_config_pb2
    globals()["easyrec_loss_pb2"] = _loss_pb2