def getColumnInterval()

in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala [417:770]


  def getColumnInterval(
      aggregate: Aggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the estimation itself lives in the shared private helper.
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
   * Gets interval of the given column on TableAggregates.
   *
   * @param aggregate
   *   TableAggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on TableAggregate
   */
  def getColumnInterval(
      aggregate: TableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the TableAggregate is handed to the shared private helper as-is.
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
   * Gets interval of the given column on batch group aggregate.
   *
   * @param aggregate
   *   batch group aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on batch group aggregate
   */
  def getColumnInterval(
      aggregate: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the batch group aggregate is forwarded to the shared helper.
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
   * Gets interval of the given column on stream group aggregate.
   *
   * @param aggregate
   *   stream group aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on stream group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamPhysicalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the stream group aggregate is forwarded to the shared helper.
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
   * Gets interval of the given column on stream group table aggregate.
   *
   * @param aggregate
   *   stream group table aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on stream group TableAggregate
   */
  def getColumnInterval(
      aggregate: StreamPhysicalGroupTableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the stream group table aggregate is forwarded to the shared helper.
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
   * Gets interval of the given column on stream local group aggregate.
   *
   * @param aggregate
   *   stream local group aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on stream local group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamPhysicalLocalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the stream local group aggregate is forwarded to the shared helper.
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
   * Gets interval of the given column on stream global group aggregate.
   *
   * @param aggregate
   *   stream global group aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on stream global group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamPhysicalGlobalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the stream global group aggregate is forwarded to the shared helper.
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
   * Gets interval of the given column on window aggregate.
   *
   * @param agg
   *   window aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on window Aggregate
   */
  def getColumnInterval(
      agg: WindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the window aggregate is forwarded to the shared helper.
    estimateColumnIntervalOfAggregate(agg, mq, index)
  }

  /**
   * Gets interval of the given column on batch window aggregate.
   *
   * @param agg
   *   batch window aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on batch window Aggregate
   */
  def getColumnInterval(
      agg: BatchPhysicalWindowAggregateBase,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the batch window aggregate is forwarded to the shared helper.
    estimateColumnIntervalOfAggregate(agg, mq, index)
  }

  /**
   * Gets interval of the given column on stream window aggregate.
   *
   * @param agg
   *   stream window aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on stream window Aggregate
   */
  def getColumnInterval(
      agg: StreamPhysicalGroupWindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the stream group window aggregate is forwarded to the shared helper.
    estimateColumnIntervalOfAggregate(agg, mq, index)
  }

  /**
   * Gets interval of the given column on stream window table aggregate.
   *
   * @param agg
   *   stream window table aggregate RelNode
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the given column
   * @return
   *   interval of the given column on stream window TableAggregate
   */
  def getColumnInterval(
      agg: StreamPhysicalGroupWindowTableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    // Pure dispatch overload: the stream group window table aggregate is forwarded to the
    // shared helper.
    estimateColumnIntervalOfAggregate(agg, mq, index)
  }

  /**
   * Estimates the value interval of one output column of an aggregate-like node.
   *
   * Output columns whose index is below the size of the resolved group set are answered by asking
   * the metadata query for the interval of the corresponding input column. For every other column
   * the matching aggregate call is resolved depending on the concrete node type:
   *   - "merge"/global nodes (stream global group aggregate, batch hash/sort aggregate with
   *     `isMerge`) return the interval of the matching column of their input directly;
   *   - otherwise SUM/SUM0 and COUNT calls yield a semi-infinite interval (see inline comments),
   *     and any other call yields `null`.
   *
   * @param aggregate
   *   the node to estimate; it must be one of the types listed in the group-set match below,
   *   any other type fails with a `MatchError`
   * @param mq
   *   RelMetadataQuery instance
   * @param index
   *   the index of the requested output column
   * @return
   *   the estimated interval, or `null` when no estimate is available
   */
  private def estimateColumnIntervalOfAggregate(
      aggregate: SingleRel,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val input = aggregate.getInput
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    // Input field indices that make up the leading (non aggregate-call) output columns.
    // Case order matters: more specific node types are listed before their base types.
    val groupSet = aggregate match {
      case agg: StreamPhysicalGroupAggregate => agg.grouping
      case agg: StreamPhysicalLocalGroupAggregate => agg.grouping
      case agg: StreamPhysicalGlobalGroupAggregate => agg.grouping
      case agg: StreamPhysicalIncrementalGroupAggregate => agg.partialAggGrouping
      case agg: StreamPhysicalGroupWindowAggregate => agg.grouping
      case agg: BatchPhysicalGroupAggregateBase => agg.grouping ++ agg.auxGrouping
      case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg)
      case agg: BatchPhysicalLocalSortWindowAggregate =>
        // grouping + assignTs + auxGrouping
        agg.grouping ++ Array(agg.inputTimeFieldIndex) ++ agg.auxGrouping
      case agg: BatchPhysicalLocalHashWindowAggregate =>
        // grouping + assignTs + auxGrouping
        agg.grouping ++ Array(agg.inputTimeFieldIndex) ++ agg.auxGrouping
      case agg: BatchPhysicalWindowAggregateBase => agg.grouping ++ agg.auxGrouping
      case agg: TableAggregate => agg.getGroupSet.toArray
      case agg: StreamPhysicalGroupTableAggregate => agg.grouping
      case agg: StreamPhysicalGroupWindowTableAggregate => agg.grouping
    }

    if (index < groupSet.length) {
      // estimates group keys according to the input relNodes.
      val sourceFieldIndex = groupSet(index)
      fmq.getColumnInterval(input, sourceFieldIndex)
    } else {
      // Resolves the aggregate call behind output position `index` of a local aggregate via the
      // output-index -> agg-call-index map; yields null when the position is not in the map.
      def getAggCallFromLocalAgg(
          index: Int,
          aggCalls: Seq[AggregateCall],
          inputType: RelDataType,
          isBounded: Boolean): AggregateCall = {
        val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
          unwrapTypeFactory(input),
          aggCalls,
          inputType,
          isBounded)
        if (outputIndexToAggCallIndexMap.containsKey(index)) {
          val realIndex = outputIndexToAggCallIndexMap.get(index)
          aggCalls(realIndex)
        } else {
          null
        }
      }

      // Reverse lookup on the same map: finds the first output position whose mapped agg-call
      // index equals `index`; yields null when there is none.
      def getAggCallIndexInLocalAgg(
          index: Int,
          globalAggCalls: Seq[AggregateCall],
          inputRowType: RelDataType,
          isBounded: Boolean): Integer = {
        val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
          unwrapTypeFactory(input),
          globalAggCalls,
          inputRowType,
          isBounded)

        // collectFirst stops at the first hit without relying on a non-local `return` thrown
        // out of a closure.
        outputIndexToAggCallIndexMap.collectFirst { case (k, v) if v == index => k }.orNull
      }

      val aggCallIndex = index - groupSet.length
      val aggCall = aggregate match {
        case agg: StreamPhysicalGroupAggregate if agg.aggCalls.length > aggCallIndex =>
          agg.aggCalls(aggCallIndex)
        case agg: StreamPhysicalGlobalGroupAggregate if agg.aggCalls.length > aggCallIndex =>
          val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
            aggCallIndex,
            agg.aggCalls,
            agg.localAggInputRowType,
            isBounded = false)
          if (aggCallIndexInLocalAgg != null) {
            // The answer is taken straight from the matching column of the input node.
            return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
          } else {
            null
          }
        case agg: StreamPhysicalLocalGroupAggregate =>
          getAggCallFromLocalAgg(
            aggCallIndex,
            agg.aggCalls,
            agg.getInput.getRowType,
            isBounded = false)
        case agg: StreamPhysicalIncrementalGroupAggregate
            if agg.partialAggCalls.length > aggCallIndex =>
          agg.partialAggCalls(aggCallIndex)
        case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
          agg.aggCalls(aggCallIndex)
        case agg: BatchPhysicalLocalHashAggregate =>
          getAggCallFromLocalAgg(
            aggCallIndex,
            agg.getAggCallList,
            agg.getInput.getRowType,
            isBounded = true)
        case agg: BatchPhysicalHashAggregate if agg.isMerge =>
          val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
            aggCallIndex,
            agg.getAggCallList,
            agg.aggInputRowType,
            isBounded = true)
          if (aggCallIndexInLocalAgg != null) {
            // The answer is taken straight from the matching column of the input node.
            return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
          } else {
            null
          }
        case agg: BatchPhysicalLocalSortAggregate =>
          getAggCallFromLocalAgg(
            aggCallIndex,
            agg.getAggCallList,
            agg.getInput.getRowType,
            isBounded = true)
        case agg: BatchPhysicalSortAggregate if agg.isMerge =>
          val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
            aggCallIndex,
            agg.getAggCallList,
            agg.aggInputRowType,
            isBounded = true)
          if (aggCallIndexInLocalAgg != null) {
            // The answer is taken straight from the matching column of the input node.
            return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
          } else {
            null
          }
        case agg: BatchPhysicalGroupAggregateBase if agg.getAggCallList.length > aggCallIndex =>
          agg.getAggCallList(aggCallIndex)
        case agg: Aggregate =>
          val (_, aggCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
          if (aggCalls.length > aggCallIndex) {
            aggCalls(aggCallIndex)
          } else {
            null
          }
        case agg: BatchPhysicalWindowAggregateBase if agg.getAggCallList.length > aggCallIndex =>
          agg.getAggCallList(aggCallIndex)
        case _ => null
      }

      if (aggCall != null) {
        aggCall.getAggregation.getKind match {
          case SUM | SUM0 =>
            val inputInterval = fmq.getColumnInterval(input, aggCall.getArgList.get(0))
            if (inputInterval != null) {
              inputInterval match {
                // A numeric lower bound is examined first: a non-negative lower bound is kept as
                // the lower bound of the result, a negative one yields no estimate (the upper
                // bound is not consulted in that case).
                case withLower: WithLower if withLower.lower.isInstanceOf[Number] =>
                  if (withLower.lower.asInstanceOf[Number].doubleValue() >= 0.0) {
                    RightSemiInfiniteValueInterval(withLower.lower, withLower.includeLower)
                  } else {
                    null.asInstanceOf[ValueInterval]
                  }
                // Otherwise a non-positive numeric upper bound is kept as the upper bound.
                case withUpper: WithUpper if withUpper.upper.isInstanceOf[Number] =>
                  if (withUpper.upper.asInstanceOf[Number].doubleValue() <= 0.0) {
                    LeftSemiInfiniteValueInterval(withUpper.upper, withUpper.includeUpper)
                  } else {
                    null
                  }
                case _ => null
              }
            } else {
              null
            }
          case COUNT =>
            // COUNT is estimated as [0, +inf).
            RightSemiInfiniteValueInterval(JBigDecimal.valueOf(0), includeLower = true)
          // TODO add more built-in agg functions
          case _ => null
        }
      } else {
        null
      }
    }
  }