velox/experimental/codegen/ast/ASTNode.h (304 lines of code) (raw):

/* * Copyright (c) Facebook, Inc. and its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include <memory> #include <optional> #include <shared_mutex> #include <sstream> #include <string> #include <vector> #include "velox/experimental/codegen/CodegenExceptions.h" #include "velox/experimental/codegen/ast/CodegenCtx.h" #include "velox/experimental/codegen/compiler_utils/LibraryDescriptor.h" #include "velox/experimental/codegen/udf_manager/ExpressionNullMode.h" #include "velox/experimental/codegen/udf_manager/UDFManager.h" #include "velox/type/Type.h" namespace facebook { namespace velox { namespace codegen { using ASTNodePtr = std::shared_ptr<class ASTNode>; // A structure that represent the intermediate partial code generated by the // generateCode. struct CodeSnippet { public: CodeSnippet( const std::string& outputVarName = "", const std::string& code = "") : outputVarName_(outputVarName), code_(code) {} const std::string& code() const { return code_; } // Return a lambda that executes the code and returns the result value stored // in the output variable. std::string getAsLambda(const std::string& lambdaName) const { return fmt::format( "auto {name} = [](){{ {code} return {outputVar}; }}", fmt::arg("code", code_), fmt::arg("outputVar", outputVarName_), fmt::arg("name", lambdaName)); } private: // The name of the output variable that holds the expression result value std::string outputVarName_; // The code that need to be executed std::string code_; }; /// Abstract class for all expression supported in codegen class ASTNode { public: explicit ASTNode(const TypePtr& type) : type_(std::const_pointer_cast<Type>(type)) {} virtual ~ASTNode() = 0; // Validate if the expression node is complete and ready for codeGen virtual void validate() const = 0; // Returns a list of all the children expressions virtual const std::vector<ASTNodePtr> children() const = 0; // Uses default propagation, can be overridden virtual void propagateNullability() { defaultNullabilityPropagation(); } // List of include files required for the expression to be executed virtual std::vector<std::string> getHeaderFiles() const { return {}; } // List of libs required for the expression to be executed virtual std::vector<compiler_utils::LibraryDescriptor> getLibs() const { return {}; } template <typename T> const T* as() const { return dynamic_cast<const T*>(this); } template <typename T> T* as() { return dynamic_cast<T*>(this); } virtual bool isConstantExpression() const { return false; } // Generate code with nullable info and return the name of the variable // storing the results. Result should be written to output. For complex types // output should not be null before written. virtual CodeSnippet generateCode( CodegenCtx& exprCodegenCtx, const std::string& outputVarName) const = 0; // Return the null mode of the expression, used in default // propagateNullability. virtual ExpressionNullMode getNullMode() const = 0; // Return the sql data type of the expression const velox::Type& type() const { return *type_.get(); } /// Return the sql data type of the expression const velox::TypePtr typePtr() const { return type_; } /// Check weather the expression is typed bool typed() const { return type_ != nullptr && type_->kind() != TypeKind::INVALID; } void validateTyped() const { if (!typed()) { throw ASTValidationException("ast node not typed"); } } // Returns the nullability of the node bool maybeNull() const { return maybeNull_; } void markAllInputsNotNullable(); protected: // Set the nullability of the node void setMaybeNull(bool maybeNull) { maybeNull_ = maybeNull; } private: // The default function that propagates nullability for function call, which // depends on FunctionNullMode. void defaultNullabilityPropagation() { for (auto& child : children()) { child->propagateNullability(); } switch (getNullMode()) { case ExpressionNullMode::NullInNullOut: case ExpressionNullMode::NullableInNullableOut: for (auto& child : children()) { if (child->maybeNull()) { setMaybeNull(true); return; } } case ExpressionNullMode::NotNull: setMaybeNull(false); return; case ExpressionNullMode::Custom: throw CodegenNotSupported( "Error in AST Design, propagateNullability must be overridden"); } } // Stores the type of the expression std::shared_ptr<velox::Type> type_ = nullptr; // Stores the nullability of the node bool maybeNull_ = true; }; /// Node represents input value reference (reference to value/col in the input /// row) class InputRefExpr final : public ASTNode { public: InputRefExpr(const TypePtr& type, const std::string& name, size_t index) : ASTNode(type), name_(name), index_(index) {} CodeSnippet generateCode( CodegenCtx& exprCodegenCtx, const std::string& outputVarName) const override; const std::vector<ASTNodePtr> children() const override { return {}; } ExpressionNullMode getNullMode() const override { return ExpressionNullMode::Custom; } void propagateNullability() override { // set in constructor } /// Set the nullability of the input void setMaybeNull(bool maybeNull) { ASTNode::setMaybeNull(maybeNull); } void validate() const override { validateTyped(); if (index_ < 0) { throw ASTValidationException( "input reference expression expect a positive index"); } } /// Return the index of referenced column size_t index() const { return index_; } /// Return the column name const std::string& name() const { return name_; } private: /// Name of the referenced column std::string name_; // The index of referenced column size_t index_; }; // An expression that combines the children expressions into a row (tuple) class MakeRowExpression final : public ASTNode { public: MakeRowExpression( const TypePtr& type, const std::vector<ASTNodePtr>& children) : ASTNode(type), children_(children) {} CodeSnippet generateCode( CodegenCtx& exprCodegenCtx, const std::string& outputVarName) const override; const std::vector<ASTNodePtr> children() const override { return children_; } ExpressionNullMode getNullMode() const override { return ExpressionNullMode::NotNull; } void validate() const override { validateTyped(); if (children_.size() == 0) { throw ASTValidationException( "output expression should have at least one child"); } for (auto& child : children_) { child->validate(); } } /// Return the size of the output tuple size_t width() { return children_.size(); } private: /// The output expression, a child at index X write to output[X] std::vector<ASTNodePtr> children_; }; // If expression AST node class IfExpression final : public ASTNode { public: IfExpression( const TypePtr& type, ASTNodePtr condition, ASTNodePtr thenPart, ASTNodePtr elsePart, bool isEager = false) : ASTNode(type), condition_(condition), thenPart_(thenPart), elsePart_(elsePart), isEager_(isEager) {} IfExpression( const TypePtr& type, ASTNodePtr condition, ASTNodePtr thenPart, bool isEager = false) : ASTNode(type), condition_(condition), thenPart_(thenPart), elsePart_(nullptr), isEager_(isEager) {} void propagateNullability() override; CodeSnippet generateCode( CodegenCtx& exprCodegenCtx, const std::string& outputVarName) const override; void validate() const override; const std::vector<ASTNodePtr> children() const override { return {condition_, thenPart_, elsePart_}; } ExpressionNullMode getNullMode() const override { return ExpressionNullMode::Custom; } private: // AST node representing the condition ASTNodePtr condition_; // AST node representing then part ASTNodePtr thenPart_; // AST node representing else part ASTNodePtr elsePart_; // If isEager_ then and else always executed for side effects but only one // value returned bool isEager_; }; // Switch expression AST node class SwitchExpression final : public ASTNode { public: SwitchExpression(const TypePtr& type, std::vector<ASTNodePtr>& inputs) : ASTNode(type), inputs_(inputs) {} void propagateNullability() override; CodeSnippet generateCode( CodegenCtx& exprCodegenCtx, const std::string& outputVarName) const override; void validate() const override; const std::vector<ASTNodePtr> children() const override { return inputs_; } ExpressionNullMode getNullMode() const override { return ExpressionNullMode::Custom; } private: // A sequence of WHEN ... THEN ... and maybe a trailing ELSE std::vector<ASTNodePtr> inputs_; }; /// AST node represent a general function call with no code-gen specific /// optimizations and handling. class UDFCallExpr final : public ASTNode { public: UDFCallExpr( const TypePtr& type, const UDFInformation& udfInformation, const std::vector<ASTNodePtr>& children) : ASTNode(type), udfInformation_(udfInformation), children_(children) { udfInformation.validate(false /*veloxNamesMustBeSet*/); } void validate() const override; const std::vector<ASTNodePtr> children() const override { return children_; } std::vector<std::string> getHeaderFiles() const override { if (udfInformation_.hasHeaderFiles()) { return udfInformation_.getHeaderFiles(); } return {}; } std::vector<compiler_utils::LibraryDescriptor> getLibs() const override { if (udfInformation_.hasLibs()) { return udfInformation_.getLibs(); } return {}; } CodeSnippet generateCode( CodegenCtx& exprCodegenCtx, const std::string& outputVarName) const override; ExpressionNullMode getNullMode() const override { return udfInformation_.getNullMode(); } bool isNullableOutput() const { return udfInformation_.isOptionalOutput(); } protected: // Return the function name used in the generated code const std::string getFunctionName() const { return udfInformation_.getCalledFunctionName(); } private: // Information about the called udf are stored in this structure const UDFInformation udfInformation_; // Function call input arguments std::vector<ASTNodePtr> children_; }; class CoalesceExpr final : public ASTNode { public: CoalesceExpr(const TypePtr& type, const std::vector<ASTNodePtr>& children) : ASTNode(type), children_(children) {} void validate() const override { validateTyped(); } const std::vector<ASTNodePtr> children() const override { return children_; } void propagateNullability() override; CodeSnippet generateCode( CodegenCtx& exprCodegenCtx, const std::string& outputVarName) const override; virtual ExpressionNullMode getNullMode() const override { return ExpressionNullMode::Custom; } private: std::vector<ASTNodePtr> children_; }; } // namespace codegen } // namespace velox } // namespace facebook