in econml/sklearn_extensions/linear_model.py [0:0]
def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
    """Check dimensions and other assertions, then apply sample weights.

    Missing optional arguments are filled with defaults, input lengths and
    shapes are validated, consistency warnings are issued for
    ``freq_weight``/``sample_var`` combinations that invalidate inference,
    and finally ``X``, ``y`` and ``sample_var`` are scaled by the sample weights.

    Parameters
    ----------
    X : 2-D array with ``n`` rows, or None
        Covariates. ``None`` means no covariates (an ``(n, 0)`` array is used).
        When ``self.fit_intercept`` is true, ``add_constant`` is applied with
        ``has_constant='add'`` to add an intercept column.
    y : array of shape (n,) or (n, p)
        Outcomes.
    sample_weight : array of shape (n,) or None
        Per-row weights; ``None`` means all ones.
    freq_weight : array of shape (n,) or None
        Integer number of original observations summarized by each row;
        ``None`` means all ones.
    sample_var : array with the same shape as ``y``, or None
        Variance of the original observations summarized by each row;
        ``None`` means all zeros.

    Returns
    -------
    weighted_X : array
        ``X`` (after the optional intercept column) with each row multiplied by
        ``sqrt(sample_weight)``.
    weighted_y : array with the shape of ``y``
        ``y`` with each row multiplied by ``sqrt(sample_weight)``.
    freq_weight : array of shape (n,)
        The (possibly defaulted) frequency weights, unchanged.
    sample_var : array with the shape of ``y``
        ``sample_var`` with each row multiplied by ``sample_weight``.

    Raises
    ------
    AttributeError
        If any frequency weight is not an integer.
    AssertionError
        If the input lengths or shapes are not compatible.
    """
    n_rows = y.shape[0]
    if X is None:
        X = np.empty((n_rows, 0))
    if self.fit_intercept:
        X = add_constant(X, has_constant='add')
    # set default values for None
    if sample_weight is None:
        sample_weight = np.ones(n_rows)
    if freq_weight is None:
        freq_weight = np.ones(n_rows)
    if sample_var is None:
        sample_var = np.zeros(y.shape)

    # freq_weight counts original observations, so it must be integral
    if np.any(np.not_equal(np.mod(freq_weight, 1), 0)):
        raise AttributeError("Frequency weights must all be integers for inference to be valid!")

    # Validate shapes before combining arrays elementwise: mismatched inputs should
    # fail with a clear message, not a broadcasting error or a silently broadcast
    # (n, n) result. Explicit raises (rather than ``assert``) survive ``python -O``.
    if not (X.shape[0] == n_rows == sample_weight.shape[0] ==
            freq_weight.shape[0] == sample_var.shape[0]):
        raise AssertionError("Input lengths not compatible!")
    if sample_var.shape != y.shape:
        raise AssertionError("Input shapes not compatible: {}, {}!".format(
            y.shape, sample_var.shape))

    # Reduce multi-output variances to one value per row. sample_var is expected to be
    # non-negative, so a zero row sum means every output has zero variance.
    row_var = sample_var if sample_var.ndim < 2 else np.sum(sample_var, axis=1)
    single_obs = np.equal(freq_weight, 1)
    has_var = np.not_equal(row_var, 0)
    # The two inconsistencies are independent, so report each one that occurs.
    if np.any(single_obs & has_var):
        warnings.warn(
            "Variance was set to non-zero for an observation with freq_weight=1! "
            "sample_var represents the variance of the original observations that are "
            "summarized in this sample. Hence, cannot have a non-zero variance if only "
            "one observation was summarized. Inference will be invalid!")
    if np.any(np.logical_not(single_obs) & np.logical_not(has_var)):
        warnings.warn(
            "Variance was set to zero for an observation with freq_weight>1! "
            "sample_var represents the variance of the original observations that are "
            "summarized in this sample. If it's zero, please use sample_weight instead "
            "to reflect the weight for each individual sample!")

    # weight X and y by sqrt(sample_weight) and sample_var by sample_weight
    sqrt_weight = np.sqrt(sample_weight)
    weighted_X = X * sqrt_weight.reshape(-1, 1)
    if y.ndim < 2:
        weighted_y = y * sqrt_weight
        sample_var = sample_var * sample_weight
    else:
        weighted_y = y * sqrt_weight.reshape(-1, 1)
        sample_var = sample_var * sample_weight.reshape(-1, 1)
    return weighted_X, weighted_y, freq_weight, sample_var