def _get_predicted_values()

in Synthesis_incorporation/value_search/value_search.py [0:0]


def _get_predicted_values(values, predicted_operation, constants_values, end_time, settings, statistics):
    if len(values) > 1 and predicted_operation.name in ['torch.cat(tensors, dim)', 'torch.stack(tensors)', 'torch.stack(tensors, dim)']:
        stacked_value = all_operations.find_operation_with_name('PairCreationOperation').apply(values, settings)
        if stacked_value is None:
            predicted_values = []
        else:
            predicted_values = predicted_operation.enumerate_values_with_values(
                given_values=[[stacked_value]],
                potential_value_list=constants_values,
                end_time=end_time,
                settings=settings,
                statistics=statistics
            )
    if len(values) == 2 and predicted_operation.name in ['torch.mul(input, other)']:
        new_values = []
        for value in values:
            if value.is_tensor and value.value.dtype == torch.bool:
                new_values.append(all_operations.find_operation_with_name('IntOperation').apply([value], settings))
            else:
                new_values.append(value)
        # values = [all_operations.find_operation_with_name('IntOperation').apply(value, settings) if value.is value.value.dtype == torch.bool and value is not None else value for value in values]
        predicted_values = predicted_operation.enumerate_values_with_values(
            given_values=[[value] for value in new_values],
            potential_value_list=constants_values,
            end_time=end_time,
            settings=settings,
            statistics=statistics
        )
    elif len(values) == 3 and predicted_operation.name in ['torch.where(condition, input, other)', 'torch.where(condition, self, other)']:
        if values[0].value.dtype != torch.bool:
            values[0] = all_operations.find_operation_with_name('BoolOperation').apply([values[0]], settings)
        predicted_values = predicted_operation.enumerate_values_with_values(
                given_values= [[value] for valule in values],
                potential_value_list=constants_values,
                end_time=end_time,
                settings=settings,
                statistics=statistics
            )
    elif len(values) == 1 and predicted_operation.name in ['torch.argmax(input)', 'torch.argmax(input, dim)']:
        if values[0].value.dtype != torch.int:
            values[0] = all_operations.find_operation_with_name('IntOperation').apply([values[0]], settings)
        predicted_values = predicted_operation.enumerate_values_with_values(
                given_values=[[values[0]]],
                potential_value_list=constants_values,
                end_time=end_time,
                settings=settings,
                statistics=statistics
            )
    else:
        predicted_values = predicted_operation.enumerate_values_with_values(
            given_values=[[value] for value in values],
            potential_value_list=constants_values,
            end_time=end_time,
            settings=settings,
            statistics=statistics
        )
    return predicted_values