absl::Status DatetimeBucket()

in sql_utils/public/functions/date_time_util.cc [4275:4414]


// Computes the start of the bucket that contains `input` and stores it in
// `*output`. Buckets are `bucket_width` wide and aligned so that `origin` is
// itself a bucket boundary; inputs that precede `origin` land in buckets
// counted backwards from it.
//
// `bucket_width` must have exactly one strictly positive part (months, days,
// or the combined micros/nano-fractions part) and no negative part. A width
// carrying nano fractions is rejected when `scale` is kMicroseconds.
//
// Returns an evaluation error for an unsupported `bucket_width`, or when the
// computed bucket start is not a valid DatetimeValue. Any `scale` other than
// kMicroseconds or kNanoseconds trips the SQL_RET_CHECK.
absl::Status DatetimeBucket(const DatetimeValue& input,
                            bigquery_ml_utils::IntervalValue bucket_width,
                            const DatetimeValue& origin, TimestampScale scale,
                            DatetimeValue* output) {
  SQL_RET_CHECK(scale == kMicroseconds || scale == kNanoseconds)
      << "Only kMicroseconds and kNanoseconds are acceptable values for scale";

  // Read every interval part once; all validation below works on these.
  const auto width_months = bucket_width.get_months();
  const auto width_days = bucket_width.get_days();
  const auto width_micros = bucket_width.get_micros();
  const auto width_nano_fractions = bucket_width.get_nano_fractions();

  if (scale == kMicroseconds && width_nano_fractions != 0) {
    return MakeEvalError() << "DATETIME_BUCKET doesn't support bucket width "
                              "INTERVAL with nanoseconds precision";
  }

  // The nano-fractions part is not checked here: the original code notes it
  // can never be negative.
  const bool any_part_negative =
      width_months < 0 || width_days < 0 || width_micros < 0;
  if (any_part_negative) {
    return MakeEvalError() << "DATETIME_BUCKET doesn't support negative "
                              "bucket width INTERVAL";
  }

  // Micros and nano fractions together describe a single logical part
  // (nanoseconds), so they are counted as one.
  int populated_parts = 0;
  if (width_months > 0) {
    ++populated_parts;
  }
  if (width_days > 0) {
    ++populated_parts;
  }
  if (width_micros > 0 || width_nano_fractions > 0) {
    ++populated_parts;
  }
  if (populated_parts != 1) {
    return MakeEvalError() << "DATETIME_BUCKET requires exactly one non-zero "
                              "INTERVAL part in bucket width";
  }

  if (width_months > 0) {
    // Month-sized buckets: a month has no fixed length, so the arithmetic is
    // carried out on absl::CivilMonth values rather than on a nanosecond
    // count.
    const absl::CivilMonth input_month(input.ConvertToCivilSecond());
    const absl::CivilMonth origin_month(origin.ConvertToCivilSecond());
    const int64_t months_past_boundary =
        (input_month - origin_month) % width_months;
    absl::CivilMonth bucket_month = input_month - months_past_boundary;

    // For comparison purposes only, an input that sits on the last day of its
    // month is treated as being on the origin's day when the origin is also
    // on the last day of its (longer) month.
    const int origin_day = origin.Day();
    int input_day = input.Day();
    if (input_day < origin_day) {
      if (IsLastDayOfTheMonth(origin.Year(), origin.Month(), origin.Day()) &&
          IsLastDayOfTheMonth(input.Year(), input.Month(), input.Day())) {
        input_day = origin_day;
      }
    }

    // Collapses everything finer than a month (with the supplied day standing
    // in for the value's own day) into one comparable nanosecond count.
    const auto sub_month_nanos = [](const DatetimeValue& value,
                                    int day) -> int64_t {
      const int hour = value.Hour();
      const int minute = value.Minute();
      const int second = value.Second();
      const int nanosecond = value.Nanoseconds();
      return kNumNanosPerDay * day + kNumNanosPerHour * hour +
             kNumNanosPerMinute * minute + kNumNanosPerSecond * second +
             nanosecond;
    };
    const int64_t input_within_month = sub_month_nanos(input, input_day);
    const int64_t origin_within_month = sub_month_nanos(origin, origin_day);

    // Step back one whole bucket in two situations:
    //  * the remainder is negative, i.e. the input's month precedes the
    //    origin's month and subtracting the remainder moved us forward;
    //  * the input's month is exactly on a boundary month, but within that
    //    month the input is earlier than the origin's day/time. The
    //    CivilMonth arithmetic above ignored those sub-month parts.
    const bool month_before_origin = months_past_boundary < 0;
    const bool boundary_month_but_early =
        months_past_boundary == 0 && input_within_month < origin_within_month;
    if (month_before_origin || boundary_month_but_early) {
      bucket_month -= width_months;
    }

    // The narrowing cast is considered safe by the original code, given the
    // method contract.
    int year = static_cast<int32_t>(bucket_month.year());
    int month = bucket_month.month();
    int day = origin.Day();
    // AdjustYearMonthDay is relied on for the end-of-month case: when the
    // bucket's month is shorter than the origin's day, the day becomes the
    // last day of the bucket's month.
    AdjustYearMonthDay(&year, &month, &day);
    *output = DatetimeValue::FromYMDHMSAndNanos(
        year, month, day, origin.Hour(), origin.Minute(), origin.Second(),
        origin.Nanoseconds());
  } else {
    // Day-sized or sub-day buckets. A DATETIME day is always 24 hours, so a
    // days part is simply expressed in nanoseconds.
    const absl::int128 width_nanos =
        width_days > 0 ? width_days * IntervalValue::kNanosInDay
                       : bucket_width.get_nanos();

    // A reference point far enough in the past that valid datetimes map to
    // non-negative offsets (absl normalizes the zero month/day fields).
    static constexpr absl::CivilSecond kBucketEpoch(-10000, 0, 0, 0, 0, 0);

    // Offsets are kept in int128, so no overflow checks are needed: the whole
    // supported range expressed in nanoseconds needs roughly 69 bits.
    const auto nanos_since_epoch =
        [](const DatetimeValue& value) -> absl::int128 {
      return absl::int128(value.ConvertToCivilSecond() - kBucketEpoch) *
                 IntervalValue::kNanosInSecond +
             value.Nanoseconds();
    };
    const absl::int128 input_nanos = nanos_since_epoch(input);
    const absl::int128 origin_nanos = nanos_since_epoch(origin);

    const absl::int128 offset_in_bucket =
        (input_nanos - origin_nanos) % width_nanos;
    absl::int128 bucket_start_nanos = input_nanos - offset_in_bucket;
    if (offset_in_bucket < 0) {
      // A negative remainder means the input precedes the origin, so the
      // subtraction above moved forward; go back one full bucket.
      bucket_start_nanos -= width_nanos;
    }

    // Seconds across the supported range fit comfortably in 64 bits (about
    // 40 bits are needed), so the narrowing below is lossless.
    const absl::CivilSecond bucket_start_second =
        kBucketEpoch + static_cast<int64_t>(bucket_start_nanos /
                                            IntervalValue::kNanosInSecond);
    // Thanks to the early epoch, an in-range bucket start is non-negative and
    // this remainder is a proper nanosecond field. Out-of-range results are
    // expected to be rejected by the IsValid() check below.
    const int32_t bucket_start_subsecond = static_cast<int32_t>(
        bucket_start_nanos % IntervalValue::kNanosInSecond);
    *output = DatetimeValue::FromYMDHMSAndNanos(
        static_cast<int32_t>(bucket_start_second.year()),
        bucket_start_second.month(), bucket_start_second.day(),
        bucket_start_second.hour(), bucket_start_second.minute(),
        bucket_start_second.second(), bucket_start_subsecond);
  }

  if (output->IsValid()) {
    return absl::OkStatus();
  }
  return MakeEvalError() << "Bucket for " << input.DebugString()
                         << " is outside of datetime range";
}