in Synthesis_incorporation/value_search/value_search.py [0:0]
def _get_predicted_values(values, predicted_operation, constants_values, end_time, settings, statistics):
    """Enumerate the results of ``predicted_operation`` applied to ``values``.

    Some operations need their arguments adapted before enumeration:

    * ``torch.cat`` / ``torch.stack`` with more than one value: the values
      are first combined by ``PairCreationOperation`` and the combined value
      is passed as the single given argument. If that operation returns
      ``None`` an empty list is returned.
    * ``torch.mul`` with two values: tensor values whose dtype is
      ``torch.bool`` are passed through ``IntOperation``.
    * ``torch.where`` with three values: the first value is passed through
      ``BoolOperation`` unless its dtype is already ``torch.bool``.
    * ``torch.argmax`` with one value: the value is passed through
      ``IntOperation`` unless its dtype is ``torch.int``.

    Every other combination forwards the values unchanged.

    Args:
        values: Sequence of value objects exposing ``.value`` and
            ``.is_tensor``. The sequence itself is never modified.
        predicted_operation: Object exposing ``.name`` and
            ``enumerate_values_with_values(...)``.
        constants_values: Forwarded as ``potential_value_list``.
        end_time: Forwarded unchanged to the enumeration call.
        settings: Forwarded to the helper operations and the enumeration call.
        statistics: Forwarded unchanged to the enumeration call.

    Returns:
        Whatever ``predicted_operation.enumerate_values_with_values`` returns,
        or ``[]`` when the values could not be combined for cat/stack.
    """
    stacking_names = (
        'torch.cat(tensors, dim)',
        'torch.stack(tensors)',
        'torch.stack(tensors, dim)',
    )
    where_names = (
        'torch.where(condition, input, other)',
        'torch.where(condition, self, other)',
    )
    argmax_names = ('torch.argmax(input)', 'torch.argmax(input, dim)')

    def coerce(value, operation_name):
        # Run one of the registered single-argument operations on `value`.
        return all_operations.find_operation_with_name(operation_name).apply([value], settings)

    op_name = predicted_operation.name
    count = len(values)
    # Work on a copy so the coercions below never overwrite entries of the
    # caller's sequence (and so tuples are accepted as well).
    arguments = list(values)

    # The branches form ONE chain: each special case only prepares
    # `arguments`, and the single enumeration call at the bottom consumes
    # them. Keeping the stacking case inside the chain matters - otherwise
    # the generic fallback would replace its result.
    if count > 1 and op_name in stacking_names:
        stacked_value = all_operations.find_operation_with_name('PairCreationOperation').apply(values, settings)
        if stacked_value is None:
            return []
        arguments = [stacked_value]
    elif count == 2 and op_name == 'torch.mul(input, other)':
        arguments = [
            coerce(value, 'IntOperation')
            if value.is_tensor and value.value.dtype == torch.bool
            else value
            for value in arguments
        ]
    elif count == 3 and op_name in where_names:
        # NOTE(review): unlike the mul branch there is no `is_tensor` check
        # here; a condition whose `.value` has no `dtype` would raise.
        if arguments[0].value.dtype != torch.bool:
            arguments[0] = coerce(arguments[0], 'BoolOperation')
    elif count == 1 and op_name in argmax_names:
        # NOTE(review): `torch.int` is the 32-bit integer dtype, so 64-bit
        # integer inputs are also routed through IntOperation - confirm that
        # this is intended.
        if arguments[0].value.dtype != torch.int:
            arguments[0] = coerce(arguments[0], 'IntOperation')

    return predicted_operation.enumerate_values_with_values(
        given_values=[[value] for value in arguments],
        potential_value_list=constants_values,
        end_time=end_time,
        settings=settings,
        statistics=statistics,
    )