export function assign()

in modules/mlvis-common/src/utils/kmeans.js [157:176]


/**
 * Functional scatter-assignment: returns a copy of `X` in which the
 * elements/slices addressed by `indices` are replaced by `values`
 * (roughly `X[indices] = values` without mutating `X`).
 *
 * `indices` follows `tf.scatterND` conventions. When it has rank > 1, its
 * last dimension K is the index depth: each index addresses the first K axes
 * of `X` and selects a slice of shape `X.shape.slice(K)`. A rank-1 `indices`
 * is treated (as tf.scatterND does) as a list of positions along axis 0.
 *
 * @param {tf.Tensor} X - Tensor to copy and update; it is not modified.
 * @param {tf.Tensor} indices - Integer index tensor (see above).
 * @param {tf.Tensor} values - Either a rank-0 (scalar) tensor, written to
 *   every addressed position, or a tensor whose shape is the batch shape of
 *   `indices` followed by `X.shape.slice(K)`.
 * @returns {tf.Tensor} A new tensor with the same shape as `X`.
 *
 * NOTE: tf.scatterND accumulates updates for duplicate indices, so
 * duplicated indices produce summed values rather than a single assignment.
 */
export function assign(X, indices, values) {
  return tf.tidy(() => {
    if (indices.size === 0) {
      return X.clone();
    }

    // Shape tf.scatterND requires for its `updates` argument: the batch
    // dimensions of `indices` followed by the shape of one addressed slice.
    // Mirrors tf.scatterND's rule that rank-1 indices have index depth 1.
    const indexDepth = indices.rank > 1 ? indices.shape[indices.rank - 1] : 1;
    const batchShape =
      indices.rank > 1 ? indices.shape.slice(0, -1) : indices.shape;
    const updateShape = [...batchShape, ...X.shape.slice(indexDepth)];

    // A scalar `values` is broadcast to every addressed position.
    // TODO: support general broadcasting of non-scalar `values`.
    const updates =
      values.rank === 0 ? tf.mul(tf.ones(updateShape), values) : values;

    const scattered = tf.scatterND(indices, updates, X.shape);
    // Boolean mask of the positions being overwritten; scattering ones keeps
    // legitimately-zero entries of `values` distinguishable from untouched
    // positions.
    const mask = tf.scatterND(indices, tf.ones(updateShape), X.shape).toBool();
    return tf.where(mask, scattered, X);
  });
}