def load_tensors()

in aws_sagemaker_studio/sagemaker_debugger/mnist_tensor_plot/tensor_plot.py [0:0]


    def load_tensors(self):
        """Load tensors for the first ``self.steps`` recorded steps into the plot buffers.

        For every step returned by ``self.trial.steps()`` (truncated to
        ``self.steps``) this method fills:

        * ``self.input[step]`` -- a 2-D slice (first sample, channel 0) of each
          tensor matching ``self.label``; the last match wins. Only done when
          ``self.label`` is not None.
        * ``self.tensors[step]`` -- a list of ``[tensor_name, array]`` pairs built
          from every tensor matching ``self.regex``. Each tensor is shifted so its
          minimum is 0 and scaled so its maximum is 1 (a constant tensor becomes
          all zeros), so all plots share one colour scale.
        * ``self.output[step]`` -- the model prediction taken from tensors matching
          ``self.prediction``. Only done when ``self.prediction`` is not None.

        It also raises ``self.max_dim`` to the largest dimension of any tensor
        matching ``self.regex``.

        ``self.color_channel`` selects the layout: 1 means the colour channel is
        dim 1 of a batched tensor (channel axis first, e.g. MXNet); 3 means it is
        dim 3 (channel axis last, e.g. TF). Any other value yields no input
        image, no channel planes, and no image prediction.

        ``self.batch_sample_id`` controls the batch axis of layer tensors:
        None keeps the whole batch, -1 averages over it, a valid index selects
        that sample, and any other value falls back to the first sample.
        """

        def normalize(array):
            # Shift to min 0 and scale to max 1; skip the division for a
            # constant tensor to avoid dividing by zero.
            array = array - np.min(array)
            peak = np.max(array)
            if peak != 0:
                array = array / peak
            return array

        def channel_slices(sample):
            # Split one un-batched 3-D sample into its 2-D per-channel planes.
            if self.color_channel == 1:  # channel axis first
                return [sample[c, :, :] for c in range(sample.shape[0])]
            if self.color_channel == 3:  # channel axis last
                return [sample[:, :, c] for c in range(sample.shape[2])]
            return []

        available_steps = self.trial.steps()
        for step in available_steps[0 : self.steps]:
            self.tensors[step] = []

            # input image into the neural network (first sample, channel 0)
            if self.label is not None:
                for tname in self.trial.tensor_names(regex=self.label):
                    tensor = self.trial.tensor(tname).value(step)
                    if self.color_channel == 1:
                        self.input[step] = tensor[0, 0, :, :]
                    elif self.color_channel == 3:
                        # Channel index 0, as in the prediction branch below;
                        # index 3 would be out of range for 1- or 3-channel images.
                        self.input[step] = tensor[0, :, :, 0]

            # iterate over tensors that match the regex
            for tname in self.trial.tensor_names(regex=self.regex):
                tensor = self.trial.tensor(tname).value(step)
                # track the largest dimension so the plot axes can be sized to fit
                self.max_dim = max([self.max_dim, *tensor.shape])

                if self.batch_sample_id is not None:
                    # layer inputs/outputs have the batch size as first dimension
                    if self.batch_sample_id == -1:
                        # average over the batch dimension
                        tensor = np.mean(tensor, axis=0)
                    elif 0 <= self.batch_sample_id < tensor.shape[0]:
                        # plot the requested item from the batch
                        tensor = tensor[self.batch_sample_id]
                    else:
                        # out-of-range id: plot the first item from the batch
                        tensor = tensor[0]
                    tensor = normalize(tensor)
                    if tensor.ndim == 3:
                        planes = channel_slices(tensor)
                    elif tensor.ndim == 1:
                        planes = [tensor]
                    else:
                        planes = []
                else:
                    tensor = normalize(tensor)
                    if tensor.ndim == 4:
                        # every channel plane of every sample, sample-major order
                        planes = [plane for sample in tensor for plane in channel_slices(sample)]
                    elif tensor.ndim == 2:
                        planes = [tensor]
                    else:
                        planes = []

                self.tensors[step].extend([tname, plane] for plane in planes)

            # model output
            if self.prediction is not None:
                for tname in self.trial.tensor_names(regex=self.prediction):
                    tensor = self.trial.tensor(tname).value(step)
                    # predicted class (batch size, probabilities per class):
                    # keep the class index of the first sample
                    if tensor.ndim == 2:
                        self.output[step] = np.array([np.argmax(tensor[0])])
                    # predicted image (batch size, color channel, width, height)
                    elif tensor.ndim == 4:
                        # MXNet has the color channel in dim 1
                        if self.color_channel == 1:
                            self.output[step] = tensor[0, 0, :, :]
                        # TF has the color channel in dim 3
                        elif self.color_channel == 3:
                            self.output[step] = tensor[0, :, :, 0]