tensorflow_estimator/python/estimator/canned/timeseries/math_utils.py
def _update_statistics_from_mini_batch(self, statistics, auxiliary_variables,
times, values):
"""Given mini-batch input, update `statistics` and `auxiliary_variables`."""
values = tf.cast(values, self._dtype)
# The average spacing between observations (time units per
# inter-observation gap) seen in each chunk of the mini-batch.
batch_inter_observation_duration = (
tf.cast(
tf.math.reduce_max(times, axis=1) -
tf.math.reduce_min(times, axis=1), self._dtype) /
tf.cast(tf.compat.v1.shape(times)[1] - 1, self._dtype))
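# Illustrative example (not from the original source): for a single chunk
# with times [[0, 2, 4, 6]], the time span is 6 and there are
# shape(times)[1] - 1 = 3 inter-observation gaps, giving an average
# spacing of 2.0 time units.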
# Co-locate updates with their variables to minimize race conditions when
# updating statistics.
with tf.compat.v1.device(auxiliary_variables.max_time_seen.device):
# There is a race condition if this value is being updated from multiple
# workers. However, it should eventually reach the correct value if the
# last chunk is presented enough times.
latest_time = tf.cast(tf.math.reduce_max(times), tf.dtypes.int64)
max_time_seen = tf.math.maximum(auxiliary_variables.max_time_seen,
latest_time)
max_time_seen_assign = tf.compat.v1.assign(
auxiliary_variables.max_time_seen, max_time_seen)
with tf.compat.v1.device(auxiliary_variables.chunk_count.device):
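# Each mini-batch element is one contiguous chunk, so the batch dimension
# of `times` is the number of chunks contributed by this mini-batch.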
chunk_count_assign = tf.compat.v1.assign_add(
auxiliary_variables.chunk_count,
tf.compat.v1.shape(times, out_type=tf.dtypes.int64)[0])
with tf.compat.v1.device(
auxiliary_variables.inter_observation_duration_sum.device):
inter_observation_duration_assign = tf.compat.v1.assign_add(
auxiliary_variables.inter_observation_duration_sum,
tf.math.reduce_sum(batch_inter_observation_duration))
with tf.compat.v1.device(auxiliary_variables.example_count.device):
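# `size(times)` is batch size * chunk length, i.e. the total number of
# individual observations in this mini-batch.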
example_count_assign = tf.compat.v1.assign_add(
auxiliary_variables.example_count,
tf.compat.v1.size(times, out_type=tf.dtypes.int64))
# Note: These mean/variance updates assume that all points are equally
# likely, which is not true if _chunks_ are sampled uniformly from the space
# of all possible contiguous chunks, since points at the start and end of
# the series are then members of fewer chunks. For series which are much
# longer than the chunk size (the usual/expected case), this effect becomes
# irrelevant.
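# (Illustrative example, not from the original source: with a series of
# length 10 and chunks of length 3, the first point of the series falls in
# only one possible chunk while an interior point falls in three, so
# uniform chunk sampling under-weights the endpoints.)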
with tf.compat.v1.device(auxiliary_variables.overall_feature_sum.device):
overall_feature_sum_assign = tf.compat.v1.assign_add(
auxiliary_variables.overall_feature_sum,
tf.math.reduce_sum(values, axis=[0, 1]))
with tf.compat.v1.device(
auxiliary_variables.overall_feature_sum_of_squares.device):
overall_feature_sum_of_squares_assign = tf.compat.v1.assign_add(
auxiliary_variables.overall_feature_sum_of_squares,
tf.math.reduce_sum(values**2, axis=[0, 1]))
per_chunk_aux_updates = tf.group(max_time_seen_assign, chunk_count_assign,
inter_observation_duration_assign,
example_count_assign,
overall_feature_sum_assign,
overall_feature_sum_of_squares_assign)
with tf.control_dependencies([per_chunk_aux_updates]):
example_count_float = tf.cast(auxiliary_variables.example_count,
self._dtype)
new_feature_mean = (
auxiliary_variables.overall_feature_sum / example_count_float)
overall_feature_mean_update = tf.compat.v1.assign(
statistics.overall_feature_moments.mean, new_feature_mean)
overall_feature_var_update = tf.compat.v1.assign(
statistics.overall_feature_moments.variance,
# n / (n - 1) (Bessel) correction to de-bias the variance estimate
example_count_float / (example_count_float - 1.) *
(auxiliary_variables.overall_feature_sum_of_squares /
example_count_float - new_feature_mean**2))
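# Written out, the variance update above is the standard de-biased
# estimator n / (n - 1) * (sum(x**2) / n - mean(x)**2), with
# n = example_count_float, which equals sum((x - mean(x))**2) / (n - 1).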
# TODO(b/35675805): Remove this cast
min_time_batch = tf.cast(
tf.compat.v1.math.argmin(times[:, 0]), tf.dtypes.int32)
def series_start_updates():
# If this is the lowest-time chunk that we have seen so far, update
# series start moments to reflect that. Note that these statistics are
# "best effort", as there are race conditions in the update (however,
# they should eventually converge if the start of the series is
# presented enough times).
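# `values[min_time_batch]` is the batch element whose chunk starts
# earliest; moments are taken over its first
# `_starting_variance_window_size` steps, separately for each feature.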
mean, variance = tf.compat.v1.nn.moments(
values[min_time_batch, :self._starting_variance_window_size],
axes=[0])
return tf.group(
tf.compat.v1.assign(statistics.series_start_moments.mean, mean),
tf.compat.v1.assign(statistics.series_start_moments.variance,
variance))
with tf.compat.v1.device(statistics.start_time.device):
series_start_update = tf.compat.v1.cond(
# Update the moments whenever the lowest time in this batch matches or
# precedes the lowest time seen so far (rather than only when it is
# strictly lower), to ensure that series start statistics are eventually
# updated to their correct values despite race conditions (i.e.
# statistics.start_time will eventually reflect the global lowest time,
# and given that, the series start moments will eventually be updated to
# their correct values).
tf.math.less_equal(times[min_time_batch, 0],
tf.cast(statistics.start_time, times.dtype)),
series_start_updates,
tf.no_op)
with tf.control_dependencies([series_start_update]):
# There is a race condition if this update is performed in parallel on
# multiple workers. Since models may be sensitive to being presented
# with times before the putative start time, the value of this
# variable is post-processed above to guarantee that each worker is
# presented with a start time which is at least as low as the lowest
# time in its current mini-batch.
min_time = tf.cast(tf.math.reduce_min(times), tf.dtypes.int64)
start_time = tf.math.minimum(statistics.start_time, min_time)
start_time_update = tf.compat.v1.assign(statistics.start_time,
start_time)
inter_observation_duration_estimate = (
auxiliary_variables.inter_observation_duration_sum /
tf.cast(auxiliary_variables.chunk_count, self._dtype))
# Estimate the total number of observations as the observed time span
# (end time - start time + 1) divided by the average inter-observation
# duration within a chunk.
total_observation_count_update = tf.compat.v1.assign(
statistics.total_observation_count,
tf.cast(
tf.math.round(
tf.cast(max_time_seen_assign - start_time_update + 1,
self._dtype) / inter_observation_duration_estimate),
tf.dtypes.int64))
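# Illustrative example (not from the original source): with
# max_time_seen = 99, start_time = 0, and an average inter-observation
# duration of 2.0, this yields round((99 - 0 + 1) / 2.0) = 50 estimated
# observations.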
per_chunk_stat_updates = tf.group(overall_feature_mean_update,
overall_feature_var_update,
series_start_update, start_time_update,
total_observation_count_update)
return per_chunk_stat_updates