luckmatter/model_gen.py [158:184]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def reset_parameters(self):
        """Re-initialize every sub-layer owned by this model.

        Calls ``reset_parameters()`` on the hidden linear layers first, then on
        the batch-norm layers, and finally on the output layer ``self.final_w``.
        """
        layers = [*self.ws_linear, *self.ws_bn, self.final_w]
        for layer in layers:
            layer.reset_parameters()

    def normalize(self):
        """Apply ``normalize_layer`` to each linear layer, including ``final_w``.

        The hidden linear layers are processed in order, followed by the output
        layer. The batch-norm layers in ``self.ws_bn`` are not passed to
        ``normalize_layer``.
        """
        for layer in [*self.ws_linear, self.final_w]:
            normalize_layer(layer)

    def from_bottom_linear(self, j):
        """Return the weight data of the j-th linear layer counted from the input.

        Indices ``0 .. len(self.ws_linear) - 1`` select a hidden layer from
        ``self.ws_linear``; ``j == len(self.ws_linear)`` selects the output
        layer ``self.final_w``.

        Args:
            j: layer index counted from the bottom (input side).

        Returns:
            The ``.weight.data`` of the selected layer.

        Raises:
            RuntimeError: if ``j`` is greater than ``len(self.ws_linear)``.
        """
        num_hidden = len(self.ws_linear)
        if j < num_hidden:
            # NOTE(review): a negative j also lands here and indexes ws_linear
            # from the end, even though the error message documents [0, n].
            return self.ws_linear[j].weight.data
        if j == num_hidden:
            return self.final_w.weight.data
        # Report the same upper bound that the checks above enforce.
        raise RuntimeError("j[%d] is out of bound! should be [0, %d]" % (j, num_hidden))

    def from_bottom_aug_w(self, j):
        """Return ``get_aug_w`` of the j-th linear layer counted from the input.

        Indices ``0 .. len(self.ws_linear) - 1`` select a hidden layer from
        ``self.ws_linear``; ``j == len(self.ws_linear)`` selects the output
        layer ``self.final_w``.

        Args:
            j: layer index counted from the bottom (input side).

        Returns:
            Whatever ``get_aug_w`` returns for the selected layer.

        Raises:
            RuntimeError: if ``j`` is greater than ``len(self.ws_linear)``.
        """
        num_hidden = len(self.ws_linear)
        if j < num_hidden:
            # NOTE(review): a negative j also lands here and indexes ws_linear
            # from the end, even though the error message documents [0, n].
            return get_aug_w(self.ws_linear[j])
        if j == num_hidden:
            return get_aug_w(self.final_w)
        # Report the same upper bound that the checks above enforce.
        raise RuntimeError("j[%d] is out of bound! should be [0, %d]" % (j, num_hidden))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



student_specialization/model_gen.py [177:203]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def reset_parameters(self):
        """Re-initialize every sub-layer owned by this model.

        Calls ``reset_parameters()`` on the hidden linear layers first, then on
        the batch-norm layers, and finally on the output layer ``self.final_w``.
        """
        layers = [*self.ws_linear, *self.ws_bn, self.final_w]
        for layer in layers:
            layer.reset_parameters()

    def normalize(self):
        """Apply ``normalize_layer`` to each linear layer, including ``final_w``.

        The hidden linear layers are processed in order, followed by the output
        layer. The batch-norm layers in ``self.ws_bn`` are not passed to
        ``normalize_layer``.
        """
        for layer in [*self.ws_linear, self.final_w]:
            normalize_layer(layer)

    def from_bottom_linear(self, j):
        """Return the weight data of the j-th linear layer counted from the input.

        Indices ``0 .. len(self.ws_linear) - 1`` select a hidden layer from
        ``self.ws_linear``; ``j == len(self.ws_linear)`` selects the output
        layer ``self.final_w``.

        Args:
            j: layer index counted from the bottom (input side).

        Returns:
            The ``.weight.data`` of the selected layer.

        Raises:
            RuntimeError: if ``j`` is greater than ``len(self.ws_linear)``.
        """
        num_hidden = len(self.ws_linear)
        if j < num_hidden:
            # NOTE(review): a negative j also lands here and indexes ws_linear
            # from the end, even though the error message documents [0, n].
            return self.ws_linear[j].weight.data
        if j == num_hidden:
            return self.final_w.weight.data
        # Report the same upper bound that the checks above enforce.
        raise RuntimeError("j[%d] is out of bound! should be [0, %d]" % (j, num_hidden))

    def from_bottom_aug_w(self, j):
        """Return ``get_aug_w`` of the j-th linear layer counted from the input.

        Indices ``0 .. len(self.ws_linear) - 1`` select a hidden layer from
        ``self.ws_linear``; ``j == len(self.ws_linear)`` selects the output
        layer ``self.final_w``.

        Args:
            j: layer index counted from the bottom (input side).

        Returns:
            Whatever ``get_aug_w`` returns for the selected layer.

        Raises:
            RuntimeError: if ``j`` is greater than ``len(self.ws_linear)``.
        """
        num_hidden = len(self.ws_linear)
        if j < num_hidden:
            # NOTE(review): a negative j also lands here and indexes ws_linear
            # from the end, even though the error message documents [0, n].
            return get_aug_w(self.ws_linear[j])
        if j == num_hidden:
            return get_aug_w(self.final_w)
        # Report the same upper bound that the checks above enforce.
        raise RuntimeError("j[%d] is out of bound! should be [0, %d]" % (j, num_hidden))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



