tensorflow_ops/time_ops_kernel.cc (441 lines of code) (raw):
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <limits>
#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "sql_utils/public/functions/cast_date_time.h"
#include "sql_utils/public/functions/date_time_util.h"
#include "sql_utils/public/functions/datetime.pb.h"
#include "sql_utils/public/functions/parse_date_time.h"
#include "tensorflow_ops/constants.h"
#include "tensorflow_ops/utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
using ::tensorflow::DEVICE_CPU;
using ::tensorflow::OpKernel;
using ::tensorflow::OpKernelConstruction;
using ::tensorflow::OpKernelContext;
using ::tensorflow::Tensor;
using ::tensorflow::tstring;
using ::tensorflow::errors::InvalidArgument;
namespace bigquery_ml_utils {
// Builds TIME values from parallel hour/minute/second tensors.
//
// Inputs:
//   0: hour   (int64)
//   1: minute (int64)
//   2: second (int64)
// Output:
//   0: string tensor with the shape of `hour`. Element i is the TIME built by
//      functions::ConstructTime(hour[i], minute[i], second[i]) and rendered by
//      FormatOutputTime.
// The op fails with the ConstructTime error for the first invalid triple.
class TimeFromComponents : public OpKernel {
 public:
  explicit TimeFromComponents(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& hour_tensor = context->input(0);
    auto hour = hour_tensor.flat<int64_t>();
    const Tensor& minute_tensor = context->input(1);
    auto minute = minute_tensor.flat<int64_t>();
    const Tensor& second_tensor = context->input(2);
    auto second = second_tensor.flat<int64_t>();
    // NOTE(review): only element counts are compared, not full shapes, so
    // e.g. [2, 3] and [6] are accepted; the output takes the shape of `hour`.
    OP_REQUIRES(
        context, hour.size() == minute.size() && hour.size() == second.size(),
        InvalidArgument(absl::Substitute("Errors in $0: Inputs must have the "
                                         "same shape, but are: $1, $2, $3",
                                         name(), hour.size(), minute.size(),
                                         second.size())));
    // The output has the shape of the hour tensor.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, hour_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<tstring>();
    // Eigen reports sizes with its (64-bit on typical platforms) index type;
    // a 64-bit counter avoids truncation and overflow on tensors with more
    // than INT_MAX elements.
    const int64_t num_elements = hour.size();
    for (int64_t i = 0; i < num_elements; ++i) {
      // Build the time from its components.
      TimeValue time;
      OP_REQUIRES_OK(context, ToTslStatus(name(), functions::ConstructTime(
                                                      hour(i), minute(i),
                                                      second(i), &time)));
      // Format the time as a string.
      std::string out;
      OP_REQUIRES_OK(context, FormatOutputTime(time, name(), &out));
      output_flat(i).reserve(out.size());
      output_flat(i) = std::move(out);
    }
  }
};
class TimeFromTimestamp : public OpKernel {
public:
explicit TimeFromTimestamp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the timestamp tensor
const Tensor& timestamp_tensor = context->input(0);
auto timestamp = timestamp_tensor.flat<tstring>();
// Grab the time zone tensor
const Tensor& time_zone_tensor = context->input(1);
std::string time_zone = time_zone_tensor.flat<tstring>()(0);
// Create an output tensor with the shape of the time tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
0, timestamp_tensor.shape(), &output_tensor));
auto output_flat = output_tensor->flat<tstring>();
const int N = timestamp.size();
for (int i = 0; i < N; i++) {
// Parse the timestamp.
int64_t ts;
OP_REQUIRES_OK(
context,
ParseInputTimestamp(timestamp(i), absl::UTCTimeZone(), name(), &ts));
// Extract time from timestamp.
TimeValue time;
OP_REQUIRES_OK(
context,
ToTslStatus(name(), functions::ConvertTimestampToTime(
absl::FromUnixMicros(ts), time_zone, &time)));
// Format time to string.
std::string out;
OP_REQUIRES_OK(context, FormatOutputTime(time, name(), &out));
// Set the output value.
output_flat(i).reserve(out.size());
output_flat(i) = std::move(out);
}
}
};
// Extracts the TIME part from DATETIME strings.
//
// Inputs:
//   0: datetime (string) parsed by ParseInputDatetime.
// Output:
//   0: string tensor with the shape of `datetime` holding the TIME produced by
//      functions::ExtractTimeFromDatetime, formatted by FormatOutputTime.
class TimeFromDatetime : public OpKernel {
 public:
  explicit TimeFromDatetime(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& datetime_tensor = context->input(0);
    auto datetime = datetime_tensor.flat<tstring>();
    // The output has the shape of the datetime tensor.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, datetime_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<tstring>();
    // 64-bit counter: Eigen sizes are not limited to INT_MAX.
    const int64_t num_elements = datetime.size();
    for (int64_t i = 0; i < num_elements; ++i) {
      // Parse the datetime.
      DatetimeValue dt;
      OP_REQUIRES_OK(context, ParseInputDatetime(datetime(i), name(), &dt));
      // Extract the time from the datetime.
      TimeValue time;
      OP_REQUIRES_OK(
          context,
          ToTslStatus(name(), functions::ExtractTimeFromDatetime(dt, &time)));
      // Format the time as a string.
      std::string out;
      OP_REQUIRES_OK(context, FormatOutputTime(time, name(), &out));
      output_flat(i).reserve(out.size());
      output_flat(i) = std::move(out);
    }
  }
};
class CastToTimeFromString : public OpKernel {
public:
explicit CastToTimeFromString(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the time_string tensor
const Tensor& time_string_tensor = context->input(0);
auto time_string = time_string_tensor.flat<tstring>();
// Grab the format tensor
const Tensor& format_tensor = context->input(1);
std::string format = format_tensor.flat<tstring>()(0);
// Grab the with_format tensor
const Tensor& with_format_tensor = context->input(2);
bool with_format = with_format_tensor.flat<bool>()(0);
// Create an output tensor with the shape of the time tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
0, time_string_tensor.shape(), &output_tensor));
auto output_flat = output_tensor->flat<tstring>();
const int N = time_string.size();
for (int i = 0; i < N; i++) {
// Convert string to time.
TimeValue time;
if (with_format) {
// Convert string with format
OP_REQUIRES_OK(
context, ToTslStatus(name(), functions::CastStringToTime(
format, time_string(i),
functions::kMicroseconds, &time)));
} else {
// Convert string without format
OP_REQUIRES_OK(
context, ToTslStatus(name(), functions::ConvertStringToTime(
time_string(i),
functions::kMicroseconds, &time)));
}
// Format time to string.
std::string out;
OP_REQUIRES_OK(context, FormatOutputTime(time, name(), &out));
// Set the output value.
output_flat(i).reserve(out.size());
output_flat(i) = std::move(out);
}
}
};
// Adds `interval` units of `time_part` to `time`, writing the result to `out`.
//
// Thin adapter over functions::AddTime that converts its status with
// ToTslStatus, tagging errors with `function_name`. The inputs are only read,
// so they are taken by const reference / by value; callers that pass
// lvalues continue to compile, and const or temporary values are now accepted
// too.
::tsl::Status TimeAddOperator(const TimeValue& time, int64_t interval,
                              functions::DateTimestampPart time_part,
                              absl::string_view function_name, TimeValue* out) {
  return ToTslStatus(function_name,
                     functions::AddTime(time, time_part, interval, out));
}
// Adds an integer interval of the given `part` to TIME strings.
//
// Inputs:
//   0: time     (string) TIME values parsed by ParseInputTime.
//   1: interval (int64)  same number of elements as `time`.
//   2: part     (string) only element 0 is read; it must parse (via
//                        ParseInputDateTimestampPart) to one of MICROSECOND,
//                        MILLISECOND, SECOND, MINUTE or HOUR.
// Output:
//   0: string tensor with the shape of `time` holding the formatted sums.
class TimeAdd : public OpKernel {
 public:
  explicit TimeAdd(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& time_tensor = context->input(0);
    auto time = time_tensor.flat<tstring>();
    const Tensor& diff_tensor = context->input(1);
    auto interval_int = diff_tensor.flat<int64_t>();
    OP_REQUIRES(context, time.size() == interval_int.size(),
                InvalidArgument(absl::Substitute(
                    "Error in $0: time and interval must have the same shape, "
                    "but are $1, $2",
                    name(), time.size(), interval_int.size())));
    const Tensor& part_tensor = context->input(2);
    // Only element 0 is read; reject an empty tensor instead of reading out
    // of bounds.
    OP_REQUIRES(context, part_tensor.NumElements() > 0,
                InvalidArgument(absl::Substitute(
                    "Error in $0: part must have at least one element",
                    name())));
    std::string part = part_tensor.flat<tstring>()(0);
    functions::DateTimestampPart part_enum;
    // Heap-allocated and never freed on purpose: a function-local static that
    // avoids destruction-order issues at program exit.
    static auto* supported_parts =
        new absl::flat_hash_set<functions::DateTimestampPart>(
            {functions::MICROSECOND, functions::MILLISECOND, functions::SECOND,
             functions::MINUTE, functions::HOUR});
    OP_REQUIRES_OK(context, ParseInputDateTimestampPart(
                                part, name(), &part_enum, *supported_parts));
    // The output has the shape of the time tensor.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, time_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<tstring>();
    // 64-bit counter: Eigen sizes are not limited to INT_MAX.
    const int64_t num_elements = time.size();
    for (int64_t i = 0; i < num_elements; ++i) {
      // Parse the time.
      TimeValue time_value;
      OP_REQUIRES_OK(context, ParseInputTime(time(i), name(), &time_value));
      // Add the interval.
      TimeValue out_time;
      OP_REQUIRES_OK(context, TimeAddOperator(time_value, interval_int(i),
                                              part_enum, name(), &out_time));
      // Format the time as a string.
      std::string out;
      OP_REQUIRES_OK(context, FormatOutputTime(out_time, name(), &out));
      output_flat(i).reserve(out.size());
      output_flat(i) = std::move(out);
    }
  }
};
// Subtracts an integer interval of the given `part` from TIME strings.
//
// Inputs:
//   0: time     (string) TIME values parsed by ParseInputTime.
//   1: interval (int64)  same number of elements as `time`.
//   2: part     (string) only element 0 is read; it must parse (via
//                        ParseInputDateTimestampPart) to one of MICROSECOND,
//                        MILLISECOND, SECOND, MINUTE or HOUR.
// Output:
//   0: string tensor with the shape of `time` holding the formatted results.
// Implemented as TimeAddOperator with the negated interval. An interval of
// INT64_MIN cannot be negated, so it is rejected with InvalidArgument.
class TimeSub : public OpKernel {
 public:
  explicit TimeSub(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& time_tensor = context->input(0);
    auto time = time_tensor.flat<tstring>();
    const Tensor& diff_tensor = context->input(1);
    auto interval_int = diff_tensor.flat<int64_t>();
    OP_REQUIRES(context, interval_int.size() == time.size(),
                InvalidArgument(absl::Substitute(
                    "Error in $0: time and interval must have the same shape, "
                    "but are $1, $2",
                    name(), time.size(), interval_int.size())));
    const Tensor& part_tensor = context->input(2);
    // Only element 0 is read; reject an empty tensor instead of reading out
    // of bounds.
    OP_REQUIRES(context, part_tensor.NumElements() > 0,
                InvalidArgument(absl::Substitute(
                    "Error in $0: part must have at least one element",
                    name())));
    std::string part = part_tensor.flat<tstring>()(0);
    functions::DateTimestampPart part_enum;
    // Heap-allocated and never freed on purpose: a function-local static that
    // avoids destruction-order issues at program exit.
    static auto* supported_parts =
        new absl::flat_hash_set<functions::DateTimestampPart>(
            {functions::MICROSECOND, functions::MILLISECOND, functions::SECOND,
             functions::MINUTE, functions::HOUR});
    OP_REQUIRES_OK(context, ParseInputDateTimestampPart(
                                part, name(), &part_enum, *supported_parts));
    // The output has the shape of the time tensor.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, time_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<tstring>();
    // 64-bit counter: Eigen sizes are not limited to INT_MAX.
    const int64_t num_elements = time.size();
    for (int64_t i = 0; i < num_elements; ++i) {
      // Parse the time.
      TimeValue time_value;
      OP_REQUIRES_OK(context, ParseInputTime(time(i), name(), &time_value));
      const int64_t interval = interval_int(i);
      // Negating INT64_MIN is signed overflow (undefined behavior), so
      // reject it explicitly before negating.
      OP_REQUIRES(context, interval != std::numeric_limits<int64_t>::min(),
                  InvalidArgument(absl::Substitute(
                      "Error in $0: interval $1 is out of range for "
                      "subtraction",
                      name(), interval)));
      // Subtract by adding the negated interval.
      TimeValue out_time;
      OP_REQUIRES_OK(context, TimeAddOperator(time_value, -interval, part_enum,
                                              name(), &out_time));
      // Format the time as a string.
      std::string out;
      OP_REQUIRES_OK(context, FormatOutputTime(out_time, name(), &out));
      output_flat(i).reserve(out.size());
      output_flat(i) = std::move(out);
    }
  }
};
// Computes the difference between two TIME tensors in units of `part`.
//
// Inputs:
//   0: time_a (string) TIME values parsed by ParseInputTime.
//   1: time_b (string) same number of elements as `time_a`.
//   2: part   (string) only element 0 is read; it must parse (via
//                      ParseInputDateTimestampPart) to one of MICROSECOND,
//                      MILLISECOND, SECOND, MINUTE or HOUR.
// Output:
//   0: int64 tensor with the shape of `time_a` holding
//      functions::DiffTimes(time_a[i], time_b[i], part).
//      NOTE(review): the sign convention is whatever DiffTimes defines
//      (presumably time_a - time_b); confirm.
class TimeDiff : public OpKernel {
 public:
  explicit TimeDiff(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& time_a_tensor = context->input(0);
    auto time_a = time_a_tensor.flat<tstring>();
    const Tensor& time_b_tensor = context->input(1);
    auto time_b = time_b_tensor.flat<tstring>();
    OP_REQUIRES(context, time_a.size() == time_b.size(),
                InvalidArgument(absl::Substitute(
                    "Error in $0: time_a and time_b must have the same shape, "
                    "but are $1, $2",
                    name(), time_a.size(), time_b.size())));
    const Tensor& part_tensor = context->input(2);
    // Only element 0 is read; reject an empty tensor instead of reading out
    // of bounds.
    OP_REQUIRES(context, part_tensor.NumElements() > 0,
                InvalidArgument(absl::Substitute(
                    "Error in $0: part must have at least one element",
                    name())));
    std::string part = part_tensor.flat<tstring>()(0);
    functions::DateTimestampPart part_enum;
    // Heap-allocated and never freed on purpose: a function-local static that
    // avoids destruction-order issues at program exit.
    static auto* supported_parts =
        new absl::flat_hash_set<functions::DateTimestampPart>(
            {functions::MICROSECOND, functions::MILLISECOND, functions::SECOND,
             functions::MINUTE, functions::HOUR});
    OP_REQUIRES_OK(context, ParseInputDateTimestampPart(
                                part, name(), &part_enum, *supported_parts));
    // The output has the shape of the time_a tensor.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, time_a_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int64_t>();
    // 64-bit counter: Eigen sizes are not limited to INT_MAX.
    const int64_t num_elements = time_a.size();
    for (int64_t i = 0; i < num_elements; ++i) {
      // Parse both times.
      TimeValue time_a_value;
      OP_REQUIRES_OK(context, ParseInputTime(time_a(i), name(), &time_a_value));
      TimeValue time_b_value;
      OP_REQUIRES_OK(context, ParseInputTime(time_b(i), name(), &time_b_value));
      // Compute the difference.
      int64_t out;
      OP_REQUIRES_OK(
          context,
          ToTslStatus(name(), functions::DiffTimes(time_a_value, time_b_value,
                                                   part_enum, &out)));
      output_flat(i) = out;
    }
  }
};
class TimeTrunc : public OpKernel {
public:
explicit TimeTrunc(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the time tensor
const Tensor& time_tensor = context->input(0);
auto time = time_tensor.flat<tstring>();
// Grab the part tensor
const Tensor& part_tensor = context->input(1);
std::string part = part_tensor.flat<tstring>()(0);
functions::DateTimestampPart part_enum;
static auto* supported_parts =
new absl::flat_hash_set<functions::DateTimestampPart>(
{functions::MICROSECOND, functions::MILLISECOND, functions::SECOND,
functions::MINUTE, functions::HOUR});
OP_REQUIRES_OK(context, ParseInputDateTimestampPart(
part, name(), &part_enum, *supported_parts));
// Create an output tensor with the shape of the time tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, time_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<tstring>();
const int N = time.size();
for (int i = 0; i < N; i++) {
// Parse the time.
TimeValue time_value;
OP_REQUIRES_OK(context, ParseInputTime(time(i), name(), &time_value));
// Extract time from datetime.
TimeValue out_time;
OP_REQUIRES_OK(
context, ToTslStatus(name(), functions::TruncateTime(
time_value, part_enum, &out_time)));
// Format time to string.
std::string out;
OP_REQUIRES_OK(context, FormatOutputTime(out_time, name(), &out));
// Set the output value.
output_flat(i).reserve(out.size());
output_flat(i) = std::move(out);
}
}
};
class ExtractFromTime : public OpKernel {
public:
explicit ExtractFromTime(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the time tensor
const Tensor& time_tensor = context->input(0);
auto time = time_tensor.flat<tstring>();
// Grab the part tensor
const Tensor& part_tensor = context->input(1);
std::string part = part_tensor.flat<tstring>()(0);
functions::DateTimestampPart part_enum;
static auto* supported_parts =
new absl::flat_hash_set<functions::DateTimestampPart>(
{functions::MICROSECOND, functions::MILLISECOND, functions::SECOND,
functions::MINUTE, functions::HOUR});
OP_REQUIRES_OK(context, ParseInputDateTimestampPart(
part, name(), &part_enum, *supported_parts));
// Create an output tensor with the shape of the time tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, time_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int64_t>();
const int N = time.size();
for (int i = 0; i < N; i++) {
// Parse the time.
TimeValue time_value;
OP_REQUIRES_OK(context, ParseInputTime(time(i), name(), &time_value));
// Extract time from datetime.
int32_t out;
OP_REQUIRES_OK(context,
ToTslStatus(name(), functions::ExtractFromTime(
part_enum, time_value, &out)));
// Set the output value.
// Currently, BQML util inference only supports int64.
output_flat(i) = static_cast<int64_t>(out);
}
}
};
class ParseTime : public OpKernel {
public:
explicit ParseTime(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the format tensor
const Tensor& format_tensor = context->input(0);
std::string format = format_tensor.flat<tstring>()(0);
// Grab the time tensor
const Tensor& time_string_tensor = context->input(1);
auto time_string = time_string_tensor.flat<tstring>();
// Create an output tensor with the shape of the time tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
0, time_string_tensor.shape(), &output_tensor));
auto output_flat = output_tensor->flat<tstring>();
const int N = time_string.size();
for (int i = 0; i < N; i++) {
// Parse time.
TimeValue out_time;
OP_REQUIRES_OK(context, ToTslStatus(name(), functions::ParseStringToTime(
format, time_string(i),
functions::kMicroseconds,
&out_time)));
// Format time to string.
std::string out;
OP_REQUIRES_OK(context, FormatOutputTime(out_time, name(), &out));
// Set the output value.
output_flat(i).reserve(out.size());
output_flat(i) = std::move(out);
}
}
};
class SafeParseTime : public OpKernel {
public:
explicit SafeParseTime(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the format tensor
const Tensor& format_tensor = context->input(0);
std::string format = format_tensor.flat<tstring>()(0);
// Grab the time tensor
const Tensor& time_string_tensor = context->input(1);
auto time_string = time_string_tensor.flat<tstring>();
// Create an output tensor with the shape of the time tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
0, time_string_tensor.shape(), &output_tensor));
auto output_flat = output_tensor->flat<tstring>();
const int N = time_string.size();
for (int i = 0; i < N; i++) {
// Parse time.
TimeValue out_time;
if (!functions::ParseStringToTime(format, time_string(i),
functions::kMicroseconds, &out_time)
.ok()) {
// Set the NULL-equivalent output value for unsuccessful parsing.
OP_REQUIRES_OK(
context,
ToTslStatus(name(), functions::ParseStringToTime(
kTimeFormatString, kNullTime,
functions::kMicroseconds, &out_time)));
}
// Format time to string.
std::string out;
OP_REQUIRES_OK(context, FormatOutputTime(out_time, name(), &out));
// Set the output value.
output_flat(i).reserve(out.size());
output_flat(i) = std::move(out);
}
}
};
class FormatTime : public OpKernel {
public:
explicit FormatTime(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the format tensor
const Tensor& format_tensor = context->input(0);
std::string format = format_tensor.flat<tstring>()(0);
// Grab the time tensor
const Tensor& time_string_tensor = context->input(1);
auto time = time_string_tensor.flat<tstring>();
// Create an output tensor with the shape of the time tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
0, time_string_tensor.shape(), &output_tensor));
auto output_flat = output_tensor->flat<tstring>();
const int N = time.size();
for (int i = 0; i < N; i++) {
// Parse the time.
TimeValue time_value;
OP_REQUIRES_OK(context, ParseInputTime(time(i), name(), &time_value));
// Format time.
std::string out;
OP_REQUIRES_OK(context,
ToTslStatus(name(), functions::FormatTimeToString(
format, time_value, &out)));
// Set the output value.
output_flat(i).reserve(out.size());
output_flat(i) = std::move(out);
}
}
};
// CPU kernel registrations. Each op name below must match an op definition
// registered elsewhere (presumably via REGISTER_OP in the ops library).

// Constructors / conversions producing TIME.
REGISTER_KERNEL_BUILDER(Name("TimeFromComponents").Device(DEVICE_CPU),
                        TimeFromComponents);
REGISTER_KERNEL_BUILDER(Name("TimeFromTimestamp").Device(DEVICE_CPU),
                        TimeFromTimestamp);
REGISTER_KERNEL_BUILDER(Name("TimeFromDatetime").Device(DEVICE_CPU),
                        TimeFromDatetime);
REGISTER_KERNEL_BUILDER(Name("CastToTimeFromString").Device(DEVICE_CPU),
                        CastToTimeFromString);

// TIME arithmetic.
REGISTER_KERNEL_BUILDER(Name("TimeAdd").Device(DEVICE_CPU),
                        TimeAdd);
REGISTER_KERNEL_BUILDER(Name("TimeSub").Device(DEVICE_CPU),
                        TimeSub);
REGISTER_KERNEL_BUILDER(Name("TimeDiff").Device(DEVICE_CPU),
                        TimeDiff);
REGISTER_KERNEL_BUILDER(Name("TimeTrunc").Device(DEVICE_CPU),
                        TimeTrunc);
REGISTER_KERNEL_BUILDER(Name("ExtractFromTime").Device(DEVICE_CPU),
                        ExtractFromTime);

// Parsing and formatting.
REGISTER_KERNEL_BUILDER(Name("ParseTime").Device(DEVICE_CPU),
                        ParseTime);
REGISTER_KERNEL_BUILDER(Name("SafeParseTime").Device(DEVICE_CPU),
                        SafeParseTime);
REGISTER_KERNEL_BUILDER(Name("FormatTime").Device(DEVICE_CPU),
                        FormatTime);
} // namespace bigquery_ml_utils