in core/src/autogluon/core/utils/serialization.py [0:0]
def _load(f, map_location, pickle_module, **pickle_load_args):
    """Deserialize an object (and the storages it references) from ``f``.

    Two on-disk layouts are handled:

    * A legacy tar archive with ``storages``, ``tensors`` and ``pickle``
      members. It is only attempted when ``_should_read_directly(f)`` is
      truthy and ``f`` is positioned at offset 0.
    * The stream layout read sequentially from ``f``: magic number, protocol
      version, system info, the pickled object (storages are referenced via
      pickle persistent ids), the list of storage keys, and finally the raw
      data of every storage in that key order.

    SECURITY: this function unpickles data from ``f``. Unpickling can execute
    arbitrary code, so it must only be used on files from a trusted source.

    Args:
        f: A seekable binary file-like object (checked by ``_check_seekable``).
        map_location: Controls where storages are restored to.
            ``None`` uses ``default_restore_location`` unchanged; a ``dict``
            remaps the saved location tag before delegating to
            ``default_restore_location``; a string forces every storage to
            that location; any other value is called as
            ``map_location(storage, location)`` and a ``None`` result falls
            back to ``default_restore_location``.
        pickle_module: Module exposing ``load`` and ``Unpickler``.
        **pickle_load_args: Forwarded to every ``pickle_module.load`` and
            ``pickle_module.Unpickler`` call.

    Returns:
        The object produced by the unpickler.

    Raises:
        RuntimeError: If ``f`` is a zip archive, the magic number or protocol
            version does not match, or a persistent id has an unknown type.
    """
    # Cache shared by the persistent-id hooks and the final storage-filling
    # loop: maps storage / view keys to the objects rebuilt for them.
    deserialized_objects = {}

    # Pick the strategy used to place a freshly created storage.
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            # Unknown tags are kept as-is.
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            # The saved tag is ignored; everything goes to `map_location`.
            return default_restore_location(storage, map_location)
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result

    def _check_container_source(container_type, source_file, original_source):
        """Warn when the current source of ``container_type`` differs from the saved one.

        When ``container_type.dump_patches`` is truthy, a reverse patch is
        also written to ``<ClassName>.patch`` in the working directory.
        """
        try:
            current_source = inspect.getsource(container_type)
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as patch_file:
                        # Seek to the end to learn the current size, then rewind.
                        file_size = patch_file.seek(0, 2)
                        patch_file.seek(0)
                        if file_size == 0:
                            patch_file.write(lines)
                        elif file_size != len(lines) or patch_file.read() != lines:
                            # An existing, different patch file is never
                            # overwritten; report it through the IOError path.
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute")
            msg = ("source code of class has changed. {}"
                   .format(msg))
            warnings.warn(msg, SourceChangeWarning)

    def maybe_decode_ascii(bytes_str):
        # When using encoding='bytes' in Py3, some **internal** keys stored as
        # strings in Py2 are loaded as bytes. This function decodes them with
        # ascii encoding, one that Py3 uses by default.
        #
        # NOTE: This should only be used on internal keys (e.g., `typename` and
        #       `location` in `persistent_load` below!
        if isinstance(bytes_str, bytes):
            return bytes_str.decode('ascii')
        return bytes_str

    def persistent_load(saved_id):
        """Resolve a persistent id from the stream layout into a live object."""
        assert isinstance(saved_id, tuple)
        typename = maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            data_type, root_key, location, size, view_metadata = data
            location = maybe_decode_ascii(location)
            # Create each root storage once; its contents are filled in after
            # unpickling, by the loop at the end of `_load`.
            if root_key not in deserialized_objects:
                obj = data_type(size)
                deserialized_objects[root_key] = restore_location(obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            # Wrap in a 1-tuple so a tuple-valued id cannot break the `%`
            # formatting and mask this error with a TypeError.
            raise RuntimeError("Unknown saved id type: %s" % (saved_id[0],))

    def legacy_load(f):
        """Load the legacy tar layout (``storages``, ``tensors``, ``pickle`` members)."""
        deserialized_objects = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        # NOTE(review): `mkdtemp` is used as a context manager, so it is
        # presumably this module's own helper rather than `tempfile.mkdtemp`
        # — confirm.
        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as storages_file:
                num_storages = pickle_module.load(storages_file, **pickle_load_args)
                for _ in range(num_storages):
                    args = pickle_module.load(storages_file, **pickle_load_args)
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(storages_file)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj

                # Views are slices of already-loaded root storages.
                storage_views = pickle_module.load(storages_file, **pickle_load_args)
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = root[offset:offset + size]

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as tensors_file:
                num_tensors = pickle_module.load(tensors_file, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(tensors_file, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    # Binary header: little-endian int32 ndim ...
                    ndim, = struct.unpack('<i', tensors_file.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    tensors_file.read(4)
                    # ... then `ndim` int64 sizes, `ndim` int64 strides and
                    # one int64 storage offset.
                    size = struct.unpack('<{}q'.format(ndim), tensors_file.read(8 * ndim))
                    stride = struct.unpack('<{}q'.format(ndim), tensors_file.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', tensors_file.read(8))
                    tensor = tensor_type().set_(storage, storage_offset, size, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    _check_seekable(f)
    # NOTE(review): presumably truthy only for real OS-level files whose raw
    # bytes can be read without going through Python — confirm against the
    # helper's definition.
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # The legacy tar loader is only attempted from the start of the file.
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if zipfile.is_zipfile(f):
                raise RuntimeError("Please uncompress the file.")
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        # 1-tuple keeps the message intact even if a corrupt file yields a tuple.
        raise RuntimeError("Invalid protocol version: %s" % (protocol_version,))

    # Read only to advance the stream past the system-info record.
    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # Only the first storage is handed an explicit offset; after that `None`
    # is passed for every remaining storage.
    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
        offset = None

    return result