gen_simple_module.py (88 lines of code) (raw):
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import base64
import io
import logging
import os
import zipfile
from collections import namedtuple
try:
import lzma # noqa(F401)
import tarfile
has_tar_lzma = True
except ImportError:
has_tar_lzma = False
Args = namedtuple("Args", ["inputs", "output", "tarxz"])
def parse_args():
parser = argparse.ArgumentParser(
description="Generate simple module with the given file"
)
parser.add_argument("args", nargs="+", help="name=filename list")
parser.add_argument(
"-o",
"--out",
nargs=1,
type=os.path.realpath,
help="Generated python wrapper",
)
parser.add_argument(
"--force-zip",
action="store_true",
help="Force the use of zip, even when tar and lzma are available",
)
args = parser.parse_args()
global has_tar_lzma
return Args(args.args, args.out[0], has_tar_lzma and not args.force_zip)
def compress_zip(inputs):
logging.info("Compressing as zip")
buf = io.BytesIO(b"")
with zipfile.ZipFile(buf, "w") as zf:
for input in inputs:
logging.info("Adding %s", input)
with open(input, "rb") as f:
zf.writestr(os.path.basename(input), f)
buf.seek(0)
return buf
def compress_tar_xz(inputs):
logging.info("Compressing as tar.xz")
buf = io.BytesIO(b"")
tar = tarfile.open(fileobj=buf, mode="w:xz")
for input in inputs:
logging.info("Adding %s", input)
info = tar.gettarinfo(input)
info.name = os.path.basename(input)
with open(input, "rb") as f:
tar.addfile(info, fileobj=f)
tar.close()
buf.seek(0)
return buf
def compress_and_base_64(inputs, tar_xz):
with compress_tar_xz(inputs) if tar_xz else compress_zip(inputs) as buf:
return base64.b64encode(buf.getbuffer())
_FILE_TEMPLATE = """
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import base64
import io
import re
{extra_imports}
_BASE64BLOB = "{base64_blob}"
{compression_specific_api}
"""
# def get_api_level_file(level):
# name = f"framework_classes_api_{{level}}.txt"
# return _load(name)
# def get_api_levels():
# name_re = re.compile(r"^framework_classes_api_(\\d+)\\.txt$")
# return {{
# int(match.group(1)) for name in _all() for match in [name_re.match(name)] if match
# }}
_TAR_XZ_IMPORTS = """
import lzma # noqa(F401)
import tarfile
"""
_TAR_XZ_API = """
_TAR = tarfile.open(mode="r:xz", fileobj=io.BytesIO(base64.b64decode(_BASE64BLOB)))
def _load(name):
global _TAR
return _TAR.extractfile(name).read()
def _all():
global _TAR
return _TAR.getnames()
"""
_ZIP_IMPORTS = "import zipfile"
_ZIP_API = """
_ZIP = zipfile.ZipFile(io.BytesIO(base64.b64decode(_BASE64BLOB)), "r")
def _load(name):
global _ZIP
return _ZIP.read(names)
def _all():
global _ZIP
return _ZIP.namelist()
"""
def write_py_wrapper(base_64_bytes_blob, files, filename, tar_xz):
base64_str = base_64_bytes_blob.decode("ascii")
with open(filename, "w") as f:
f.write(
_FILE_TEMPLATE.format(
extra_imports=_TAR_XZ_IMPORTS if tar_xz else _ZIP_IMPORTS,
base64_blob=base64_str,
compression_specific_api=_TAR_XZ_API if tar_xz else _ZIP_API,
)
)
for key, val in files.items():
f.write(f'{key} = _load("{os.path.basename(val)}")\n')
def main():
args = parse_args()
files = {
key_val[: key_val.find("=")]: key_val[key_val.find("=") + 1 :]
for key_val in args.inputs
}
base64_blob = compress_and_base_64(files.values(), args.tarxz)
write_py_wrapper(base64_blob, files, args.output, args.tarxz)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()