in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala [417:770]
def getColumnInterval(aggregate: Aggregate, mq: RelMetadataQuery, index: Int): ValueInterval = {
  // Logical Aggregate: delegate to the estimation shared by all aggregate-like nodes.
  estimateColumnIntervalOfAggregate(aggregate = aggregate, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a logical TableAggregate.
 *
 * @param aggregate
 *   the TableAggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(
    aggregate: TableAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = aggregate, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a batch group aggregate
 * (any subclass of BatchPhysicalGroupAggregateBase).
 *
 * @param aggregate
 *   the batch group aggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(
    aggregate: BatchPhysicalGroupAggregateBase,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = aggregate, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a streaming group aggregate.
 *
 * @param aggregate
 *   the StreamPhysicalGroupAggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(
    aggregate: StreamPhysicalGroupAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = aggregate, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a streaming group table aggregate.
 *
 * @param aggregate
 *   the StreamPhysicalGroupTableAggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(
    aggregate: StreamPhysicalGroupTableAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = aggregate, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a streaming local (first-phase)
 * group aggregate.
 *
 * @param aggregate
 *   the StreamPhysicalLocalGroupAggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(
    aggregate: StreamPhysicalLocalGroupAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = aggregate, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a streaming global (second-phase)
 * group aggregate.
 *
 * @param aggregate
 *   the StreamPhysicalGlobalGroupAggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(
    aggregate: StreamPhysicalGlobalGroupAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = aggregate, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a logical WindowAggregate.
 *
 * @param agg
 *   the WindowAggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(agg: WindowAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = agg, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a batch window aggregate
 * (any subclass of BatchPhysicalWindowAggregateBase).
 *
 * @param agg
 *   the batch window aggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(
    agg: BatchPhysicalWindowAggregateBase,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = agg, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a streaming group window aggregate.
 *
 * @param agg
 *   the StreamPhysicalGroupWindowAggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column, or null when it cannot be estimated
 */
def getColumnInterval(
    agg: StreamPhysicalGroupWindowAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = agg, mq = mq, index = index)
}
/**
 * Returns the value interval of an output column of a streaming group window table aggregate.
 *
 * @param agg
 *   the StreamPhysicalGroupWindowTableAggregate node whose output column is inspected
 * @param mq
 *   metadata query used to look up intervals of the input
 * @param index
 *   zero-based position of the output column
 * @return
 *   the estimated interval of that column on the stream window table aggregate, or null
 *   when it cannot be estimated
 */
def getColumnInterval(
    agg: StreamPhysicalGroupWindowTableAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate = agg, mq = mq, index = index)
}
/**
 * Estimates the value interval of the `index`-th output column of an aggregate-like node.
 *
 * Output columns are laid out as the grouping columns (as listed in `groupSet` below),
 * followed by the aggregate-call columns.
 *
 *   - A grouping column keeps the interval of the input field it is forwarded from.
 *   - For a global / merge aggregate, an agg-call column reuses the interval of the input
 *     column found at `groupSet.length + <index of the call in the local agg output>`.
 *     That input column is expected to hold the partial result of the same call.
 *   - Otherwise only a few functions are estimated:
 *     - SUM / SUM0 yield `[lower, +inf)` when the argument's lower bound is a non-negative
 *       number, or `(-inf, upper]` when its upper bound is a non-positive number.
 *     - COUNT yields `[0, +inf)`.
 *
 * @param aggregate
 *   the aggregate-like RelNode
 * @param mq
 *   RelMetadataQuery instance
 * @param index
 *   the index of the given output column
 * @return
 *   the estimated interval, or null when it cannot be estimated
 */
private def estimateColumnIntervalOfAggregate(
    aggregate: SingleRel,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  val input = aggregate.getInput
  val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
  // Input field index of each output grouping column, in output order.
  val groupSet = aggregate match {
    case agg: StreamPhysicalGroupAggregate => agg.grouping
    case agg: StreamPhysicalLocalGroupAggregate => agg.grouping
    case agg: StreamPhysicalGlobalGroupAggregate => agg.grouping
    case agg: StreamPhysicalIncrementalGroupAggregate => agg.partialAggGrouping
    case agg: StreamPhysicalGroupWindowAggregate => agg.grouping
    case agg: BatchPhysicalGroupAggregateBase => agg.grouping ++ agg.auxGrouping
    case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg)
    case agg: BatchPhysicalLocalSortWindowAggregate =>
      // grouping + assignTs + auxGrouping
      agg.grouping ++ Array(agg.inputTimeFieldIndex) ++ agg.auxGrouping
    case agg: BatchPhysicalLocalHashWindowAggregate =>
      // grouping + assignTs + auxGrouping
      agg.grouping ++ Array(agg.inputTimeFieldIndex) ++ agg.auxGrouping
    case agg: BatchPhysicalWindowAggregateBase => agg.grouping ++ agg.auxGrouping
    case agg: TableAggregate => agg.getGroupSet.toArray
    case agg: StreamPhysicalGroupTableAggregate => agg.grouping
    case agg: StreamPhysicalGroupWindowTableAggregate => agg.grouping
  }

  if (index < groupSet.length) {
    // Grouping columns are forwarded from the input, so they keep the input field's interval.
    fmq.getColumnInterval(input, groupSet(index))
  } else {
    // Position of the requested column among the aggregate-call columns.
    val aggCallIndex = index - groupSet.length

    // Returns the AggregateCall a local aggregate emits at agg-output position `outputIndex`,
    // or null if no call is mapped to that position.
    def getAggCallFromLocalAgg(
        outputIndex: Int,
        aggCalls: Seq[AggregateCall],
        inputType: RelDataType,
        isBounded: Boolean): AggregateCall = {
      val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
        unwrapTypeFactory(input),
        aggCalls,
        inputType,
        isBounded)
      if (outputIndexToAggCallIndexMap.containsKey(outputIndex)) {
        val realIndex = outputIndexToAggCallIndexMap.get(outputIndex)
        aggCalls(realIndex)
      } else {
        null
      }
    }

    // Reverse lookup: returns the first local-agg output position mapped to the global agg
    // call `globalAggCallIndex`, or null if there is none.
    def getAggCallIndexInLocalAgg(
        globalAggCallIndex: Int,
        globalAggCalls: Seq[AggregateCall],
        inputRowType: RelDataType,
        isBounded: Boolean): Integer = {
      val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
        unwrapTypeFactory(input),
        globalAggCalls,
        inputRowType,
        isBounded)
      // `find` instead of `return` inside a `foreach` lambda (a non-local return).
      outputIndexToAggCallIndexMap
        .find { case (_, aggCallIdx) => aggCallIdx == globalAggCallIndex }
        .map { case (localOutputIdx, _) => localOutputIdx }
        .orNull
    }

    // Global / merge phase: reuse the interval of the input column that holds the partial
    // result of the same call, or null if that column cannot be located.
    def getColumnIntervalFromLocalAgg(
        globalAggCalls: Seq[AggregateCall],
        localAggInputRowType: RelDataType,
        isBounded: Boolean): ValueInterval = {
      val aggCallIndexInLocalAgg =
        getAggCallIndexInLocalAgg(aggCallIndex, globalAggCalls, localAggInputRowType, isBounded)
      if (aggCallIndexInLocalAgg != null) {
        fmq.getColumnInterval(input, groupSet.length + aggCallIndexInLocalAgg)
      } else {
        null
      }
    }

    // True iff `bound` is a java.lang.Number whose double value satisfies `p`.
    def isNumberWhere(bound: Any)(p: Double => Boolean): Boolean = bound match {
      case n: Number => p(n.doubleValue())
      case _ => false
    }

    // Estimates the result interval of a single agg call from its argument's interval.
    // Returns null for a null call or for unsupported functions.
    def estimateAggCallInterval(aggCall: AggregateCall): ValueInterval =
      if (aggCall == null) {
        null
      } else {
        aggCall.getAggregation.getKind match {
          case SUM | SUM0 =>
            // A null argument interval matches neither type pattern and falls to `case _`.
            fmq.getColumnInterval(input, aggCall.getArgList.get(0)) match {
              // The sign test is part of the guard, not the body. Otherwise a bounded
              // interval with a negative lower bound would stop here and never reach the
              // upper-bound case below.
              case withLower: WithLower if isNumberWhere(withLower.lower)(_ >= 0.0) =>
                RightSemiInfiniteValueInterval(withLower.lower, withLower.includeLower)
              case withUpper: WithUpper if isNumberWhere(withUpper.upper)(_ <= 0.0) =>
                LeftSemiInfiniteValueInterval(withUpper.upper, withUpper.includeUpper)
              case _ => null
            }
          case COUNT =>
            RightSemiInfiniteValueInterval(JBigDecimal.valueOf(0), includeLower = true)
          // TODO add more built-in agg functions
          case _ => null
        }
      }

    // Case order matters (guards and subtype relations); it mirrors the original dispatch.
    aggregate match {
      case agg: StreamPhysicalGroupAggregate if agg.aggCalls.length > aggCallIndex =>
        estimateAggCallInterval(agg.aggCalls(aggCallIndex))
      case agg: StreamPhysicalGlobalGroupAggregate if agg.aggCalls.length > aggCallIndex =>
        getColumnIntervalFromLocalAgg(agg.aggCalls, agg.localAggInputRowType, isBounded = false)
      case agg: StreamPhysicalLocalGroupAggregate =>
        estimateAggCallInterval(
          getAggCallFromLocalAgg(
            aggCallIndex,
            agg.aggCalls,
            agg.getInput.getRowType,
            isBounded = false))
      case agg: StreamPhysicalIncrementalGroupAggregate
          if agg.partialAggCalls.length > aggCallIndex =>
        estimateAggCallInterval(agg.partialAggCalls(aggCallIndex))
      case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
        estimateAggCallInterval(agg.aggCalls(aggCallIndex))
      case agg: BatchPhysicalLocalHashAggregate =>
        estimateAggCallInterval(
          getAggCallFromLocalAgg(
            aggCallIndex,
            agg.getAggCallList,
            agg.getInput.getRowType,
            isBounded = true))
      case agg: BatchPhysicalHashAggregate if agg.isMerge =>
        getColumnIntervalFromLocalAgg(agg.getAggCallList, agg.aggInputRowType, isBounded = true)
      case agg: BatchPhysicalLocalSortAggregate =>
        estimateAggCallInterval(
          getAggCallFromLocalAgg(
            aggCallIndex,
            agg.getAggCallList,
            agg.getInput.getRowType,
            isBounded = true))
      case agg: BatchPhysicalSortAggregate if agg.isMerge =>
        getColumnIntervalFromLocalAgg(agg.getAggCallList, agg.aggInputRowType, isBounded = true)
      case agg: BatchPhysicalGroupAggregateBase if agg.getAggCallList.length > aggCallIndex =>
        estimateAggCallInterval(agg.getAggCallList(aggCallIndex))
      case agg: Aggregate =>
        val (_, aggCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
        if (aggCalls.length > aggCallIndex) {
          estimateAggCallInterval(aggCalls(aggCallIndex))
        } else {
          null
        }
      case agg: BatchPhysicalWindowAggregateBase if agg.getAggCallList.length > aggCallIndex =>
        estimateAggCallInterval(agg.getAggCallList(aggCallIndex))
      case _ => null
    }
  }
}