in python/singa/model.py [0:0]
def buffer_operation(func):
    """Decorate a model method so that, in graph training mode, its
    operations are buffered into a device graph once and the graph is run on
    every call.

    The wrapped method is called as ``func(self, *args, **kwargs)``. When both
    ``self.graph_mode`` and ``self.training`` are true:

    * The first positional argument must be a ``Tensor`` (or a non-empty
      list whose first element is a ``Tensor``). Its ``device`` is used for
      buffering and running the graph.
    * On the first call (``self._buffered`` false), ``func`` runs between
      ``dev.EnableGraph(True)`` and ``dev.EnableGraph(False)``. Its return
      value is cached in ``self._results``. ``creator`` is then cleared on
      every tensor in that result and ``gc.collect()`` is run.
    * Every call then runs ``dev.RunGraph(self.sequential)`` and returns the
      cached ``self._results``.

    Otherwise ``func`` is called directly and its result returned.

    Raises:
        ValueError: in graph training mode, if no positional argument is
            given or the first argument is an empty list.
        AssertionError: in graph training mode, if the first input is not a
            ``Tensor``.
    """
    from collections.abc import Mapping

    def remove_creator(tensors):
        """Recursively set ``creator = None`` on every Tensor in *tensors*.

        Nested iterables are walked, and mappings are walked through their
        values. Strings and bytes are skipped: iterating a one-character
        string yields the same string again, which would recurse forever.
        """
        if not tensors:
            return
        if isinstance(tensors, (str, bytes)):
            return
        if isinstance(tensors, Mapping):
            # The tensors of a dict result are its values, not its keys.
            tensors = tensors.values()
        if isinstance(tensors, Iterable):
            for item in tensors:
                if isinstance(item, Iterable):
                    remove_creator(item)
                elif isinstance(item, tensor.Tensor):
                    item.creator = None
        elif isinstance(tensors, tensor.Tensor):
            tensors.creator = None

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not (self.graph_mode and self.training):
            return func(self, *args, **kwargs)

        if len(args) == 0:
            raise ValueError('expect at least one input tensor')
        first = args[0]
        if isinstance(first, list):
            if not first:
                # An empty list has no tensor to take the device from.
                raise ValueError('expect at least one input tensor')
            first = first[0]
        assert isinstance(
            first, Tensor), ('function expects PlaceHolders or Tensors')
        dev = first.device

        if not self._buffered:
            # buffer operations
            dev.EnableGraph(True)
            try:
                self._results = func(self, *args, **kwargs)
                dev.Sync()
            finally:
                # Disable graph buffering even if func/Sync raised, so the
                # device is not left in the enabled state.
                dev.EnableGraph(False)
            self._buffered = True

            # deconstruct Operations before running the entire graph
            remove_creator(self._results)

            # make sure all Operations are deallocated
            gc.collect()

        # run graph
        dev.RunGraph(self.sequential)
        return self._results

    return wrapper