velox/experimental/codegen/ast/ASTNode.h (304 lines of code) (raw):
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <memory>
#include <optional>
#include <shared_mutex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "velox/experimental/codegen/CodegenExceptions.h"
#include "velox/experimental/codegen/ast/CodegenCtx.h"
#include "velox/experimental/codegen/compiler_utils/LibraryDescriptor.h"
#include "velox/experimental/codegen/udf_manager/ExpressionNullMode.h"
#include "velox/experimental/codegen/udf_manager/UDFManager.h"
#include "velox/type/Type.h"
namespace facebook {
namespace velox {
namespace codegen {
// Shared-ownership handle to an AST node. The elaborated `class ASTNode`
// inside the template argument also forward-declares ASTNode in this
// namespace, so the alias is usable before the class is defined.
using ASTNodePtr = std::shared_ptr<class ASTNode>;
/// Intermediate result of ASTNode::generateCode: a fragment of C++ source
/// text plus the name of the variable that the fragment stores its result in.
struct CodeSnippet {
  CodeSnippet(
      const std::string& outputVarName = "",
      const std::string& code = "")
      : outputVarName_(outputVarName), code_(code) {}

  /// Returns the generated source text.
  const std::string& code() const {
    return code_;
  }

  /// Renders the snippet as a declaration of a capture-less lambda named
  /// `lambdaName`; calling that lambda runs the code and returns the value of
  /// the output variable.
  std::string getAsLambda(const std::string& lambdaName) const {
    return fmt::format(
        "auto {name} = [](){{ {code} return {outputVar}; }}",
        fmt::arg("name", lambdaName),
        fmt::arg("code", code_),
        fmt::arg("outputVar", outputVarName_));
  }

 private:
  /// Variable that holds the expression result once `code_` has run.
  std::string outputVarName_;
  /// Source text to execute.
  std::string code_;
};
/// Abstract class for all expression supported in codegen
class ASTNode {
public:
explicit ASTNode(const TypePtr& type)
: type_(std::const_pointer_cast<Type>(type)) {}
virtual ~ASTNode() = 0;
// Validate if the expression node is complete and ready for codeGen
virtual void validate() const = 0;
// Returns a list of all the children expressions
virtual const std::vector<ASTNodePtr> children() const = 0;
// Uses default propagation, can be overridden
virtual void propagateNullability() {
defaultNullabilityPropagation();
}
// List of include files required for the expression to be executed
virtual std::vector<std::string> getHeaderFiles() const {
return {};
}
// List of libs required for the expression to be executed
virtual std::vector<compiler_utils::LibraryDescriptor> getLibs() const {
return {};
}
template <typename T>
const T* as() const {
return dynamic_cast<const T*>(this);
}
template <typename T>
T* as() {
return dynamic_cast<T*>(this);
}
virtual bool isConstantExpression() const {
return false;
}
// Generate code with nullable info and return the name of the variable
// storing the results. Result should be written to output. For complex types
// output should not be null before written.
virtual CodeSnippet generateCode(
CodegenCtx& exprCodegenCtx,
const std::string& outputVarName) const = 0;
// Return the null mode of the expression, used in default
// propagateNullability.
virtual ExpressionNullMode getNullMode() const = 0;
// Return the sql data type of the expression
const velox::Type& type() const {
return *type_.get();
}
/// Return the sql data type of the expression
const velox::TypePtr typePtr() const {
return type_;
}
/// Check weather the expression is typed
bool typed() const {
return type_ != nullptr && type_->kind() != TypeKind::INVALID;
}
void validateTyped() const {
if (!typed()) {
throw ASTValidationException("ast node not typed");
}
}
// Returns the nullability of the node
bool maybeNull() const {
return maybeNull_;
}
void markAllInputsNotNullable();
protected:
// Set the nullability of the node
void setMaybeNull(bool maybeNull) {
maybeNull_ = maybeNull;
}
private:
// The default function that propagates nullability for function call, which
// depends on FunctionNullMode.
void defaultNullabilityPropagation() {
for (auto& child : children()) {
child->propagateNullability();
}
switch (getNullMode()) {
case ExpressionNullMode::NullInNullOut:
case ExpressionNullMode::NullableInNullableOut:
for (auto& child : children()) {
if (child->maybeNull()) {
setMaybeNull(true);
return;
}
}
case ExpressionNullMode::NotNull:
setMaybeNull(false);
return;
case ExpressionNullMode::Custom:
throw CodegenNotSupported(
"Error in AST Design, propagateNullability must be overridden");
}
}
// Stores the type of the expression
std::shared_ptr<velox::Type> type_ = nullptr;
// Stores the nullability of the node
bool maybeNull_ = true;
};
/// Node that references a value (column) of the input row, identified by both
/// a name and a column index.
class InputRefExpr final : public ASTNode {
 public:
  InputRefExpr(const TypePtr& type, const std::string& name, size_t index)
      : ASTNode(type), name_(name), index_(index) {}

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  /// An input reference is a leaf and has no children.
  const std::vector<ASTNodePtr> children() const override {
    return {};
  }

  ExpressionNullMode getNullMode() const override {
    return ExpressionNullMode::Custom;
  }

  void propagateNullability() override {
    // Intentionally a no-op: an input's nullability is not derived from
    // children; it is set explicitly through setMaybeNull().
  }

  /// Sets the nullability of the input (publicly exposes the base setter).
  void setMaybeNull(bool maybeNull) {
    ASTNode::setMaybeNull(maybeNull);
  }

  void validate() const override {
    // index_ is unsigned, so there is no negative-index case to reject; the
    // former `index_ < 0` check could never fire.
    validateTyped();
  }

  /// Returns the index of the referenced column.
  size_t index() const {
    return index_;
  }

  /// Returns the name of the referenced column.
  const std::string& name() const {
    return name_;
  }

 private:
  // Name of the referenced column.
  std::string name_;
  // Index of the referenced column.
  size_t index_;
};
// An expression that combines the children expressions into a row (tuple)
class MakeRowExpression final : public ASTNode {
public:
MakeRowExpression(
const TypePtr& type,
const std::vector<ASTNodePtr>& children)
: ASTNode(type), children_(children) {}
CodeSnippet generateCode(
CodegenCtx& exprCodegenCtx,
const std::string& outputVarName) const override;
const std::vector<ASTNodePtr> children() const override {
return children_;
}
ExpressionNullMode getNullMode() const override {
return ExpressionNullMode::NotNull;
}
void validate() const override {
validateTyped();
if (children_.size() == 0) {
throw ASTValidationException(
"output expression should have at least one child");
}
for (auto& child : children_) {
child->validate();
}
}
/// Return the size of the output tuple
size_t width() {
return children_.size();
}
private:
/// The output expression, a child at index X write to output[X]
std::vector<ASTNodePtr> children_;
};
// If expression AST node
class IfExpression final : public ASTNode {
public:
IfExpression(
const TypePtr& type,
ASTNodePtr condition,
ASTNodePtr thenPart,
ASTNodePtr elsePart,
bool isEager = false)
: ASTNode(type),
condition_(condition),
thenPart_(thenPart),
elsePart_(elsePart),
isEager_(isEager) {}
IfExpression(
const TypePtr& type,
ASTNodePtr condition,
ASTNodePtr thenPart,
bool isEager = false)
: ASTNode(type),
condition_(condition),
thenPart_(thenPart),
elsePart_(nullptr),
isEager_(isEager) {}
void propagateNullability() override;
CodeSnippet generateCode(
CodegenCtx& exprCodegenCtx,
const std::string& outputVarName) const override;
void validate() const override;
const std::vector<ASTNodePtr> children() const override {
return {condition_, thenPart_, elsePart_};
}
ExpressionNullMode getNullMode() const override {
return ExpressionNullMode::Custom;
}
private:
// AST node representing the condition
ASTNodePtr condition_;
// AST node representing then part
ASTNodePtr thenPart_;
// AST node representing else part
ASTNodePtr elsePart_;
// If isEager_ then and else always executed for side effects but only one
// value returned
bool isEager_;
};
/// Switch-expression AST node.
class SwitchExpression final : public ASTNode {
 public:
  /// `inputs` is taken by const reference (it is copied), so temporaries and
  /// const vectors are accepted, matching the other node constructors.
  SwitchExpression(const TypePtr& type, const std::vector<ASTNodePtr>& inputs)
      : ASTNode(type), inputs_(inputs) {}

  void propagateNullability() override;

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  void validate() const override;

  const std::vector<ASTNodePtr> children() const override {
    return inputs_;
  }

  ExpressionNullMode getNullMode() const override {
    return ExpressionNullMode::Custom;
  }

 private:
  // A sequence of WHEN ... THEN ... pairs, optionally followed by a trailing
  // ELSE.
  std::vector<ASTNodePtr> inputs_;
};
/// Generic function-call node: invokes the UDF described by an
/// UDFInformation, with no codegen-specific optimizations or special handling.
class UDFCallExpr final : public ASTNode {
 public:
  UDFCallExpr(
      const TypePtr& type,
      const UDFInformation& udfInformation,
      const std::vector<ASTNodePtr>& children)
      : ASTNode(type), udfInformation_(udfInformation), children_(children) {
    // Velox-side names are not required to be set at this point.
    udfInformation.validate(false /*veloxNamesMustBeSet*/);
  }

  /// Function-call arguments.
  const std::vector<ASTNodePtr> children() const override {
    return children_;
  }

  void validate() const override;

  /// Header files declared by the UDF, or none if it declares none.
  std::vector<std::string> getHeaderFiles() const override {
    if (!udfInformation_.hasHeaderFiles()) {
      return {};
    }
    return udfInformation_.getHeaderFiles();
  }

  /// Libraries declared by the UDF, or none if it declares none.
  std::vector<compiler_utils::LibraryDescriptor> getLibs() const override {
    if (!udfInformation_.hasLibs()) {
      return {};
    }
    return udfInformation_.getLibs();
  }

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  /// Null mode as declared by the UDF.
  ExpressionNullMode getNullMode() const override {
    return udfInformation_.getNullMode();
  }

  /// Whether the UDF declares an optional (nullable) output.
  bool isNullableOutput() const {
    return udfInformation_.isOptionalOutput();
  }

 protected:
  /// Name of the function as it appears in the generated code.
  const std::string getFunctionName() const {
    return udfInformation_.getCalledFunctionName();
  }

 private:
  // Description of the called UDF.
  const UDFInformation udfInformation_;
  // Function-call input arguments.
  std::vector<ASTNodePtr> children_;
};
/// COALESCE expression node over a list of child expressions. Nullability
/// propagation and code generation are implemented out of line.
class CoalesceExpr final : public ASTNode {
 public:
  CoalesceExpr(const TypePtr& type, const std::vector<ASTNodePtr>& children)
      : ASTNode(type), children_(children) {}

  /// Only the node's own type is checked; children are not validated here.
  void validate() const override {
    validateTyped();
  }

  const std::vector<ASTNodePtr> children() const override {
    return children_;
  }

  void propagateNullability() override;

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  ExpressionNullMode getNullMode() const override {
    return ExpressionNullMode::Custom;
  }

 private:
  // Candidate expressions.
  std::vector<ASTNodePtr> children_;
};
} // namespace codegen
} // namespace velox
} // namespace facebook