in research/gam/gam/trainer/trainer_cotrain.py
def __init__(self,
             model_cls,
             model_agr,
             max_num_iter_cotrain,
             min_num_iter_cls,
             max_num_iter_cls,
             num_iter_after_best_val_cls,
             min_num_iter_agr,
             max_num_iter_agr,
             num_iter_after_best_val_agr,
             num_samples_to_label,
             min_confidence_new_label=0.0,
             keep_label_proportions=False,
             num_warm_up_iter_agr=1,
             optimizer=tf.train.AdamOptimizer,
             gradient_clip=None,
             batch_size_agr=128,
             batch_size_cls=128,
             learning_rate_cls=1e-3,
             learning_rate_agr=1e-3,
             warm_start_cls=False,
             warm_start_agr=False,
             enable_summaries=True,
             enable_summaries_per_model=False,
             summary_dir=None,
             summary_step_cls=1000,
             summary_step_agr=1000,
             logging_step_cls=1,
             logging_step_agr=1,
             eval_step_cls=1,
             eval_step_agr=1,
             checkpoints_step=None,
             checkpoints_dir=None,
             data_dir=None,
             abs_loss_chg_tol=1e-10,
             rel_loss_chg_tol=1e-7,
             loss_chg_iter_below_tol=30,
             use_perfect_agr=False,
             use_perfect_cls=False,
             ratio_valid_agr=0,
             max_samples_valid_agr=None,
             weight_decay_cls=None,
             weight_decay_schedule_cls=None,
             weight_decay_agr=None,
             weight_decay_schedule_agr=None,
             reg_weight_ll=0,
             reg_weight_lu=0,
             reg_weight_uu=0,
             num_pairs_reg=100,
             reg_weight_vat=0,
             use_ent_min=False,
             penalize_neg_agr=False,
             use_l2_cls=True,
             first_iter_original=True,
             inductive=False,
             seed=None,
             eval_acc_pred_by_agr=False,
             num_neighbors_pred_by_agr=20,
             lr_decay_rate_cls=None,
             lr_decay_steps_cls=None,
             lr_decay_rate_agr=None,
             lr_decay_steps_agr=None,
             load_from_checkpoint=False,
             use_graph=False,
             always_agree=False,
             add_negative_edges_agr=False):
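  """Initializes the co-training trainer.

  The arguments configure the two models being co-trained (a classification
  model `model_cls` and an agreement model `model_agr`), the iteration budgets
  and early-stopping behavior of the co-training loop and of each model's
  training runs, the self-labeling step, and the optimization, regularization,
  summary and checkpointing settings of the two models.
  """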
  # Summaries require `summary_dir`; periodic checkpointing requires
  # `checkpoints_dir`.
  assert not enable_summaries or summary_dir is not None
  assert checkpoints_step is None or checkpoints_dir is not None
  super(TrainerCotraining, self).__init__(
      model=None,
      abs_loss_chg_tol=abs_loss_chg_tol,
      rel_loss_chg_tol=rel_loss_chg_tol,
      loss_chg_iter_below_tol=loss_chg_iter_below_tol)
  # Models being co-trained: the classifier (`cls`) and the agreement
  # model (`agr`).
  self.model_cls = model_cls
  self.model_agr = model_agr
  # Iteration budgets for the outer co-training loop and for each model's
  # training runs, including the early-stopping windows.
  self.max_num_iter_cotrain = max_num_iter_cotrain
  self.min_num_iter_cls = min_num_iter_cls
  self.max_num_iter_cls = max_num_iter_cls
  self.num_iter_after_best_val_cls = num_iter_after_best_val_cls
  self.min_num_iter_agr = min_num_iter_agr
  self.max_num_iter_agr = max_num_iter_agr
  self.num_iter_after_best_val_agr = num_iter_after_best_val_agr
  # Self-labeling and agreement warm-up settings.
  self.num_samples_to_label = num_samples_to_label
  self.min_confidence_new_label = min_confidence_new_label
  self.keep_label_proportions = keep_label_proportions
  self.num_warm_up_iter_agr = num_warm_up_iter_agr
  # Optimization settings for both models.
  self.optimizer = optimizer
  self.gradient_clip = gradient_clip
  self.batch_size_agr = batch_size_agr
  self.batch_size_cls = batch_size_cls
  self.learning_rate_cls = learning_rate_cls
  self.learning_rate_agr = learning_rate_agr
  self.warm_start_cls = warm_start_cls
  self.warm_start_agr = warm_start_agr
  # Summary, logging, evaluation, checkpoint and data-directory settings.
  self.enable_summaries = enable_summaries
  self.enable_summaries_per_model = enable_summaries_per_model
  self.summary_step_cls = summary_step_cls
  self.summary_step_agr = summary_step_agr
  self.summary_dir = summary_dir
  self.logging_step_cls = logging_step_cls
  self.logging_step_agr = logging_step_agr
  self.eval_step_cls = eval_step_cls
  self.eval_step_agr = eval_step_agr
  self.checkpoints_step = checkpoints_step
  self.checkpoints_dir = checkpoints_dir
  self.data_dir = data_dir
  # "Perfect"-model options and agreement-model validation settings.
  self.use_perfect_agr = use_perfect_agr
  self.use_perfect_cls = use_perfect_cls
  self.ratio_valid_agr = ratio_valid_agr
  self.max_samples_valid_agr = max_samples_valid_agr
  # Weight decay and regularization settings.
  self.weight_decay_cls = weight_decay_cls
  self.weight_decay_schedule_cls = weight_decay_schedule_cls
  self.weight_decay_agr = weight_decay_agr
  self.weight_decay_schedule_agr = weight_decay_schedule_agr
  self.reg_weight_ll = reg_weight_ll
  self.reg_weight_lu = reg_weight_lu
  self.reg_weight_uu = reg_weight_uu
  self.num_pairs_reg = num_pairs_reg
  self.reg_weight_vat = reg_weight_vat
  self.use_ent_min = use_ent_min
  self.penalize_neg_agr = penalize_neg_agr
  # Note: stored as `use_l2_classif`; the attribute name differs from the
  # constructor argument `use_l2_cls`.
  self.use_l2_classif = use_l2_cls
  self.first_iter_original = first_iter_original
  self.inductive = inductive
  self.seed = seed
  self.eval_acc_pred_by_agr = eval_acc_pred_by_agr
  self.num_neighbors_pred_by_agr = num_neighbors_pred_by_agr
  # Learning rate decay settings.
  self.lr_decay_rate_cls = lr_decay_rate_cls
  self.lr_decay_steps_cls = lr_decay_steps_cls
  self.lr_decay_rate_agr = lr_decay_rate_agr
  self.lr_decay_steps_agr = lr_decay_steps_agr
  self.load_from_checkpoint = load_from_checkpoint
  self.use_graph = use_graph
  self.always_agree = always_agree
  self.add_negative_edges_agr = add_negative_edges_agr
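
For orientation, the following is a minimal usage sketch of this constructor, not taken from the repository. It assumes a TF1-compatible TensorFlow import (for `tf.train.AdamOptimizer`), that the package is importable as `gam.trainer.trainer_cotrain`, and that `build_classifier` and `build_agreement_model` are hypothetical factories for the two wrapped models; all hyperparameter values are illustrative.

# Minimal sketch (not from the repository). `build_classifier` and
# `build_agreement_model` are hypothetical factories for the two models;
# hyperparameter values are illustrative only.
import tensorflow.compat.v1 as tf

from gam.trainer.trainer_cotrain import TrainerCotraining

trainer = TrainerCotraining(
    model_cls=build_classifier(),       # classification model (hypothetical factory)
    model_agr=build_agreement_model(),  # agreement model (hypothetical factory)
    max_num_iter_cotrain=30,            # outer co-training iterations
    min_num_iter_cls=200,               # classifier training bounds
    max_num_iter_cls=5000,
    num_iter_after_best_val_cls=500,    # classifier early-stopping patience
    min_num_iter_agr=200,               # agreement-model training bounds
    max_num_iter_agr=5000,
    num_iter_after_best_val_agr=500,    # agreement-model early-stopping patience
    num_samples_to_label=100,           # samples self-labeled per co-training iteration
    optimizer=tf.train.AdamOptimizer,
    enable_summaries=False,             # summary_dir may stay None when disabled
    checkpoints_step=None)              # checkpoints_dir may stay None when unset

Per the assertions above, `summary_dir` is only required when `enable_summaries` is True, and `checkpoints_dir` only when `checkpoints_step` is set.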