src/app/app.py [246:523]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if args.pdb_dir:
            if not os.path.exists(args.pdb_dir):
                os.makedirs(args.pdb_dir)
            pdb_filepath = os.path.join(args.pdb_dir, prot_id.replace("/", "_") + ".pdb")
            with open(pdb_filepath, "w") as wfp:
                wfp.write(pdb)
        c_alpha, c_beta = calc_distance_maps(pdb, args.chain, processed_seq)
        cmap = c_alpha[args.chain]['contact-map'] if args.cmap_type == "C_alpha" else c_beta[args.chain]['contact-map']
        # use the specific threshold to transform the float contact map into 0-1 contact map
        cmap = np.less_equal(cmap, args.cmap_thresh).astype(np.int32)
        struct_contact_map = cmap
        real_shape = struct_contact_map.shape
        if real_shape[0] > args.struct_max_length:
            if args.trunc_type == "left":
                struct_contact_map = struct_contact_map[-args.struct_max_length:, -args.struct_max_length:]
            else:
                struct_contact_map = struct_contact_map[:args.struct_max_length, :args.struct_max_length]
            contact_map_padding_length = 0
        else:
            contact_map_padding_length = args.struct_max_length - real_shape[0]
        assert contact_map_padding_length == padding_length

        if contact_map_padding_length > 0:
            if pad_on_left:
                struct_input_ids = [pad_token] * padding_length + struct_input_ids
                struct_contact_map = np.pad(struct_contact_map, [(contact_map_padding_length, 0), (contact_map_padding_length, 0)], mode='constant', constant_values=pad_token)
            else:
                struct_input_ids = struct_input_ids + ([pad_token] * padding_length)
                struct_contact_map = np.pad(struct_contact_map, [(0, contact_map_padding_length), (0, contact_map_padding_length)], mode='constant', constant_values=pad_token)

        assert len(struct_input_ids) == args.struct_max_length, "Error with input length {} vs {}".format(len(struct_input_ids), args.struct_max_length)
        assert struct_contact_map.shape[0] == args.struct_max_length, "Error with input length {}x{} vs {}x{}".format(struct_contact_map.shape[0], struct_contact_map.shape[1], args.struct_max_length, args.struct_max_length)
    else:
        struct_input_ids = None
        struct_contact_map = None
        real_struct_node_size = None

    if args.embedding_type:
        # for embedding
        embedding_info, processed_seq = predict_embedding(
            [prot_id, protein_seq],
            args.trunc_type,
            "representations" if args.embedding_type == "matrix" else args.embedding_type,
            repr_layers=[-1],
            truncation_seq_length=args.truncation_seq_length - 2,
            device=args.device
        )
        # failure on GPU, then using CPU for embedding
        if embedding_info is None:
            # 失败,则调用cpu进行embedding推理
            embedding_info, processed_seq = predict_embedding(
                [prot_id, protein_seq],
                args.trunc_type,
                "representations" if args.embedding_type == "matrix" else args.embedding_type,
                repr_layers=[-1],
                truncation_seq_length=args.truncation_seq_length - 2,
                device=torch.device("cpu")
            )
        if args.emb_dir:
            if not os.path.exists(args.emb_dir):
                os.makedirs(args.emb_dir)

            embedding_filepath = os.path.join(args.emb_dir, prot_id.replace("/", "_") + ".pt")
            torch.save(embedding_info, embedding_filepath)
        if args.embedding_type == "contacts":
            emb_l = embedding_info.shape[0]
            embedding_attention_mask = [1 if mask_padding_with_zero else 0] * emb_l
            if emb_l > args.embedding_max_length:
                if args.trunc_type == "left":
                    embedding_info = embedding_info[-args.embedding_max_length:, -args.embedding_max_length:]
                else:
                    embedding_info = embedding_info[:args.embedding_max_length, :args.embedding_max_length]
                embedding_attention_mask = [1 if mask_padding_with_zero else 0] * args.embedding_max_length
            else:
                embedding_padding_length = args.embedding_max_length - emb_l
                if embedding_padding_length > 0:
                    if pad_on_left:
                        embedding_attention_mask = [0 if mask_padding_with_zero else 1] * embedding_padding_length + embedding_attention_mask
                        embedding_info = np.pad(embedding_info, [(embedding_padding_length, 0), (embedding_padding_length, 0)], mode='constant', constant_values=pad_token)
                    else:
                        embedding_attention_mask = embedding_attention_mask + [0 if mask_padding_with_zero else 1] * embedding_padding_length
                        embedding_info = np.pad(embedding_info, [(0, embedding_padding_length), (0, embedding_padding_length)], mode='constant', constant_values=pad_token)
        elif args.embedding_type == "matrix":
            emb_l = embedding_info.shape[0]
            embedding_attention_mask = [1 if mask_padding_with_zero else 0] * emb_l
            if emb_l > args.embedding_max_length:
                if args.trunc_type == "left":
                    embedding_info = embedding_info[-args.embedding_max_length:, :]
                else:
                    embedding_info = embedding_info[:args.embedding_max_length, :]
                embedding_attention_mask = [1 if mask_padding_with_zero else 0] * args.embedding_max_length
            else:
                embedding_padding_length = args.embedding_max_length - emb_l
                if embedding_padding_length > 0:
                    if pad_on_left:
                        embedding_attention_mask = [0 if mask_padding_with_zero else 1] * embedding_padding_length + embedding_attention_mask
                        embedding_info = np.pad(embedding_info, [(embedding_padding_length, 0), (0, 0)], mode='constant', constant_values=pad_token)
                    else:
                        embedding_attention_mask = embedding_attention_mask + [0 if mask_padding_with_zero else 1] * embedding_padding_length
                        embedding_info = np.pad(embedding_info, [(0, embedding_padding_length), (0, 0)], mode='constant', constant_values=pad_token)
        elif args.embedding_type == "bos":
            embedding_attention_mask = None
        else:
            raise Exception("Not support arg: --embedding_type=%s" % args.embedding_type)
    else:
        embedding_info = None
        embedding_attention_mask = None
    features.append(
        InputFeatures(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            real_token_len=real_token_len,
            struct_input_ids=struct_input_ids,
            struct_contact_map=struct_contact_map,
            real_struct_node_size=real_struct_node_size,
            embedding_info=embedding_info,
            embedding_attention_mask=embedding_attention_mask,
            label=None
        )
    )
    batch_input = {}
    # "labels": torch.tensor([f.label for f in features], dtype=torch.long).to(args.device),
    if seq_tokenizer:
        batch_input.update(
            {
                "input_ids": torch.tensor([f.input_ids for f in features], dtype=torch.long).to(args.device),
                "attention_mask": torch.tensor([f.attention_mask for f in features], dtype=torch.long).to(args.device),
                "token_type_ids": torch.tensor([f.token_type_ids for f in features], dtype=torch.long).to(args.device),
            }
        )
    if struct_tokenizer:
        batch_input.update(
            {
                "struct_input_ids": torch.tensor([f.struct_input_ids for f in features], dtype=torch.long).to(args.device),
                "struct_contact_map": torch.tensor([f.struct_contact_map for f in features], dtype=torch.long).to(args.device),
            }
        )
    if args.embedding_type:
        batch_input["embedding_info"] = torch.tensor(np.array([f.embedding_info for f in features], dtype=np.float32), dtype=torch.float32).to(args.device)
        if args.embedding_type != "bos":
            batch_input["embedding_attention_mask"] = torch.tensor([f.embedding_attention_mask for f in features], dtype=torch.long).to(args.device)

    return batch_info, batch_input


def predict_probs(
        args,
        seq_tokenizer,
        subword,
        struct_tokenizer,
        model,
        row
):
    '''
    Run the model on one sample and return its raw output scores as numpy.
    :param args: run configuration; args.device is the device the model runs on
    :param seq_tokenizer: forwarded to transform_sample_2_feature
    :param subword: forwarded to transform_sample_2_feature
    :param struct_tokenizer: forwarded to transform_sample_2_feature
    :param model: callable model; element [1] of its output is returned
    :param row: one sample
    :return: (batch_info, probs) where batch_info comes from
        transform_sample_2_feature and probs is model(**batch_input)[1]
        converted to a numpy array on the CPU
    '''
    # Park the model on the CPU while features are built. Feature building may
    # run an embedding model on args.device, so this presumably frees
    # accelerator memory for embedding longer sequences.
    model.to(torch.device("cpu"))
    try:
        batch_info, batch_input = transform_sample_2_feature(args, row, seq_tokenizer, subword, struct_tokenizer)
    finally:
        # Always move the model back, even if feature building fails, so
        # later calls do not silently run on the CPU.
        model.to(args.device)
    # Inference only: no autograd graph is needed for the returned numpy values.
    with torch.no_grad():
        output = model(**batch_input)
    # .cpu() is a no-op for CPU tensors and is required for any other device
    # (CUDA, MPS, ...) before conversion to numpy. Deciding this by
    # torch.cuda.is_available() would break on non-CUDA accelerators.
    probs = output[1].detach().cpu().numpy()
    return batch_info, probs


def predict_binary_class(
        args,
        label_id_2_name,
        seq_tokenizer,
        subword,
        struct_tokenizer,
        model,
        row
):
    '''
    Assign a positive/negative label to one sample by thresholding the model score.
    :param args: run configuration; args.threshold is the decision cut-off
    :param label_id_2_name: lookup indexed by the 0/1 prediction; its values are emitted as label names
    :param seq_tokenizer: forwarded to predict_probs
    :param subword: forwarded to predict_probs
    :param struct_tokenizer: forwarded to predict_probs
    :param model: forwarded to predict_probs
    :param row: one sample
    :return: one list per batch_info entry:
        [info[0], info[1], score of column 0, predicted label name, *info[2:]]
    '''
    batch_info, probs = predict_probs(args, seq_tokenizer, subword, struct_tokenizer, model, row)
    # 1 where the score reaches the threshold, else 0, as a flat vector.
    hits = (probs >= args.threshold).astype(int).flatten()
    results = []
    for sample_idx, sample_info in enumerate(batch_info):
        record = [
            sample_info[0],
            sample_info[1],
            float(probs[sample_idx][0]),
            label_id_2_name[hits[sample_idx]],
        ]
        # Carry through any extra per-sample fields unchanged.
        if len(sample_info) > 2:
            record.extend(sample_info[2:])
        results.append(record)
    return results


def predict_multi_class(
        args,
        label_id_2_name,
        seq_tokenizer,
        subword,
        struct_tokenizer,
        model,
        row
):
    '''
    Predict a single class for one sample by taking the arg-max over the model output.
    :param args: run configuration, forwarded to predict_probs
    :param label_id_2_name: lookup indexed by class id; its values are emitted as label names
    :param seq_tokenizer: forwarded to predict_probs
    :param subword: forwarded to predict_probs
    :param struct_tokenizer: forwarded to predict_probs
    :param model: forwarded to predict_probs
    :param row: one sample
    :return: one list per batch_info entry:
        [info[0], info[1], score of the chosen class, chosen label name, *info[2:]]
    '''
    batch_info, probs = predict_probs(args, seq_tokenizer, subword, struct_tokenizer, model, row)
    top_label_ids = np.argmax(probs, axis=-1)
    results = []
    for sample_idx, sample_info in enumerate(batch_info):
        top_id = top_label_ids[sample_idx]
        record = [
            sample_info[0],
            sample_info[1],
            float(probs[sample_idx][top_id]),
            label_id_2_name[top_id],
        ]
        # Carry through any extra per-sample fields unchanged.
        if len(sample_info) > 2:
            record.extend(sample_info[2:])
        results.append(record)
    return results


def predict_multi_label(
        args,
        label_id_2_name,
        seq_tokenizer,
        subword,
        struct_tokenizer,
        model,
        row
):
    '''
    Predict every label whose score reaches args.threshold for one sample.
    :param args: run configuration; args.threshold is the per-label cut-off
    :param label_id_2_name: lookup indexed by label id; its values are emitted as label names
    :param seq_tokenizer: forwarded to predict_probs
    :param subword: forwarded to predict_probs
    :param struct_tokenizer: forwarded to predict_probs
    :param model: forwarded to predict_probs
    :param row: one sample
    :return: one list per batch_info entry:
        [info[0], info[1], [scores of selected labels], [selected label names], *info[2:]]
    '''
    batch_info, probs = predict_probs(args, seq_tokenizer, subword, struct_tokenizer, model, row)
    # 0/1 matrix marking which labels pass the threshold; relevant_indexes
    # turns it into the selected label ids for each sample.
    hit_matrix = (probs >= args.threshold).astype(int)
    selected_per_sample = relevant_indexes(hit_matrix)
    results = []
    for sample_idx, sample_info in enumerate(batch_info):
        selected = selected_per_sample[sample_idx]
        scores = [float(probs[sample_idx][label_index]) for label_index in selected]
        names = [label_id_2_name[label_index] for label_index in selected]
        record = [sample_info[0], sample_info[1], scores, names]
        # Carry through any extra per-sample fields unchanged.
        if len(sample_info) > 2:
            record.extend(sample_info[2:])
        results.append(record)
    return results
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/predict_many_samples.py [262:539]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if args.pdb_dir:
            if not os.path.exists(args.pdb_dir):
                os.makedirs(args.pdb_dir)
            pdb_filepath = os.path.join(args.pdb_dir, prot_id.replace("/", "_") + ".pdb")
            with open(pdb_filepath, "w") as wfp:
                wfp.write(pdb)
        c_alpha, c_beta = calc_distance_maps(pdb, args.chain, processed_seq)
        cmap = c_alpha[args.chain]['contact-map'] if args.cmap_type == "C_alpha" else c_beta[args.chain]['contact-map']
        # use the specific threshold to transform the float contact map into 0-1 contact map
        cmap = np.less_equal(cmap, args.cmap_thresh).astype(np.int32)
        struct_contact_map = cmap
        real_shape = struct_contact_map.shape
        if real_shape[0] > args.struct_max_length:
            if args.trunc_type == "left":
                struct_contact_map = struct_contact_map[-args.struct_max_length:, -args.struct_max_length:]
            else:
                struct_contact_map = struct_contact_map[:args.struct_max_length, :args.struct_max_length]
            contact_map_padding_length = 0
        else:
            contact_map_padding_length = args.struct_max_length - real_shape[0]
        assert contact_map_padding_length == padding_length

        if contact_map_padding_length > 0:
            if pad_on_left:
                struct_input_ids = [pad_token] * padding_length + struct_input_ids
                struct_contact_map = np.pad(struct_contact_map, [(contact_map_padding_length, 0), (contact_map_padding_length, 0)], mode='constant', constant_values=pad_token)
            else:
                struct_input_ids = struct_input_ids + ([pad_token] * padding_length)
                struct_contact_map = np.pad(struct_contact_map, [(0, contact_map_padding_length), (0, contact_map_padding_length)], mode='constant', constant_values=pad_token)

        assert len(struct_input_ids) == args.struct_max_length, "Error with input length {} vs {}".format(len(struct_input_ids), args.struct_max_length)
        assert struct_contact_map.shape[0] == args.struct_max_length, "Error with input length {}x{} vs {}x{}".format(struct_contact_map.shape[0], struct_contact_map.shape[1], args.struct_max_length, args.struct_max_length)
    else:
        struct_input_ids = None
        struct_contact_map = None
        real_struct_node_size = None

    if args.embedding_type:
        # for embedding
        embedding_info, processed_seq = predict_embedding(
            [prot_id, protein_seq],
            args.trunc_type,
            "representations" if args.embedding_type == "matrix" else args.embedding_type,
            repr_layers=[-1],
            truncation_seq_length=args.truncation_seq_length - 2,
            device=args.device
        )
        # failure on GPU, then using CPU for embedding
        if embedding_info is None:
            # 失败,则调用cpu进行embedding推理
            embedding_info, processed_seq = predict_embedding(
                [prot_id, protein_seq],
                args.trunc_type,
                "representations" if args.embedding_type == "matrix" else args.embedding_type,
                repr_layers=[-1],
                truncation_seq_length=args.truncation_seq_length - 2,
                device=torch.device("cpu")
            )
        if args.emb_dir:
            if not os.path.exists(args.emb_dir):
                os.makedirs(args.emb_dir)

            embedding_filepath = os.path.join(args.emb_dir, prot_id.replace("/", "_") + ".pt")
            torch.save(embedding_info, embedding_filepath)
        if args.embedding_type == "contacts":
            emb_l = embedding_info.shape[0]
            embedding_attention_mask = [1 if mask_padding_with_zero else 0] * emb_l
            if emb_l > args.embedding_max_length:
                if args.trunc_type == "left":
                    embedding_info = embedding_info[-args.embedding_max_length:, -args.embedding_max_length:]
                else:
                    embedding_info = embedding_info[:args.embedding_max_length, :args.embedding_max_length]
                embedding_attention_mask = [1 if mask_padding_with_zero else 0] * args.embedding_max_length
            else:
                embedding_padding_length = args.embedding_max_length - emb_l
                if embedding_padding_length > 0:
                    if pad_on_left:
                        embedding_attention_mask = [0 if mask_padding_with_zero else 1] * embedding_padding_length + embedding_attention_mask
                        embedding_info = np.pad(embedding_info, [(embedding_padding_length, 0), (embedding_padding_length, 0)], mode='constant', constant_values=pad_token)
                    else:
                        embedding_attention_mask = embedding_attention_mask + [0 if mask_padding_with_zero else 1] * embedding_padding_length
                        embedding_info = np.pad(embedding_info, [(0, embedding_padding_length), (0, embedding_padding_length)], mode='constant', constant_values=pad_token)
        elif args.embedding_type == "matrix":
            emb_l = embedding_info.shape[0]
            embedding_attention_mask = [1 if mask_padding_with_zero else 0] * emb_l
            if emb_l > args.embedding_max_length:
                if args.trunc_type == "left":
                    embedding_info = embedding_info[-args.embedding_max_length:, :]
                else:
                    embedding_info = embedding_info[:args.embedding_max_length, :]
                embedding_attention_mask = [1 if mask_padding_with_zero else 0] * args.embedding_max_length
            else:
                embedding_padding_length = args.embedding_max_length - emb_l
                if embedding_padding_length > 0:
                    if pad_on_left:
                        embedding_attention_mask = [0 if mask_padding_with_zero else 1] * embedding_padding_length + embedding_attention_mask
                        embedding_info = np.pad(embedding_info, [(embedding_padding_length, 0), (0, 0)], mode='constant', constant_values=pad_token)
                    else:
                        embedding_attention_mask = embedding_attention_mask + [0 if mask_padding_with_zero else 1] * embedding_padding_length
                        embedding_info = np.pad(embedding_info, [(0, embedding_padding_length), (0, 0)], mode='constant', constant_values=pad_token)
        elif args.embedding_type == "bos":
            embedding_attention_mask = None
        else:
            raise Exception("Not support arg: --embedding_type=%s" % args.embedding_type)
    else:
        embedding_info = None
        embedding_attention_mask = None
    features.append(
        InputFeatures(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            real_token_len=real_token_len,
            struct_input_ids=struct_input_ids,
            struct_contact_map=struct_contact_map,
            real_struct_node_size=real_struct_node_size,
            embedding_info=embedding_info,
            embedding_attention_mask=embedding_attention_mask,
            label=None
        )
    )
    batch_input = {}
    # "labels": torch.tensor([f.label for f in features], dtype=torch.long).to(args.device),
    if seq_tokenizer:
        batch_input.update(
            {
                "input_ids": torch.tensor([f.input_ids for f in features], dtype=torch.long).to(args.device),
                "attention_mask": torch.tensor([f.attention_mask for f in features], dtype=torch.long).to(args.device),
                "token_type_ids": torch.tensor([f.token_type_ids for f in features], dtype=torch.long).to(args.device),
            }
        )
    if struct_tokenizer:
        batch_input.update(
            {
                "struct_input_ids": torch.tensor([f.struct_input_ids for f in features], dtype=torch.long).to(args.device),
                "struct_contact_map": torch.tensor([f.struct_contact_map for f in features], dtype=torch.long).to(args.device),
            }
        )
    if args.embedding_type:
        batch_input["embedding_info"] = torch.tensor(np.array([f.embedding_info for f in features], dtype=np.float32), dtype=torch.float32).to(args.device)
        if args.embedding_type != "bos":
            batch_input["embedding_attention_mask"] = torch.tensor([f.embedding_attention_mask for f in features], dtype=torch.long).to(args.device)

    return batch_info, batch_input


def predict_probs(
        args,
        seq_tokenizer,
        subword,
        struct_tokenizer,
        model,
        row
):
    '''
    Run the model on one sample and return its raw output scores as numpy.
    :param args: run configuration; args.device is the device the model runs on
    :param seq_tokenizer: forwarded to transform_sample_2_feature
    :param subword: forwarded to transform_sample_2_feature
    :param struct_tokenizer: forwarded to transform_sample_2_feature
    :param model: callable model; element [1] of its output is returned
    :param row: one sample
    :return: (batch_info, probs) where batch_info comes from
        transform_sample_2_feature and probs is model(**batch_input)[1]
        converted to a numpy array on the CPU
    '''
    # Park the model on the CPU while features are built. Feature building may
    # run an embedding model on args.device, so this presumably frees
    # accelerator memory for embedding longer sequences.
    model.to(torch.device("cpu"))
    try:
        batch_info, batch_input = transform_sample_2_feature(args, row, seq_tokenizer, subword, struct_tokenizer)
    finally:
        # Always move the model back, even if feature building fails, so
        # later calls do not silently run on the CPU.
        model.to(args.device)
    # Inference only: no autograd graph is needed for the returned numpy values.
    with torch.no_grad():
        output = model(**batch_input)
    # .cpu() is a no-op for CPU tensors and is required for any other device
    # (CUDA, MPS, ...) before conversion to numpy. Deciding this by
    # torch.cuda.is_available() would break on non-CUDA accelerators.
    probs = output[1].detach().cpu().numpy()
    return batch_info, probs


def predict_binary_class(
        args,
        label_id_2_name,
        seq_tokenizer,
        subword,
        struct_tokenizer,
        model,
        row
):
    '''
    Assign a positive/negative label to one sample by thresholding the model score.
    :param args: run configuration; args.threshold is the decision cut-off
    :param label_id_2_name: lookup indexed by the 0/1 prediction; its values are emitted as label names
    :param seq_tokenizer: forwarded to predict_probs
    :param subword: forwarded to predict_probs
    :param struct_tokenizer: forwarded to predict_probs
    :param model: forwarded to predict_probs
    :param row: one sample
    :return: one list per batch_info entry:
        [info[0], info[1], score of column 0, predicted label name, *info[2:]]
    '''
    batch_info, probs = predict_probs(args, seq_tokenizer, subword, struct_tokenizer, model, row)
    # 1 where the score reaches the threshold, else 0, as a flat vector.
    hits = (probs >= args.threshold).astype(int).flatten()
    results = []
    for sample_idx, sample_info in enumerate(batch_info):
        record = [
            sample_info[0],
            sample_info[1],
            float(probs[sample_idx][0]),
            label_id_2_name[hits[sample_idx]],
        ]
        # Carry through any extra per-sample fields unchanged.
        if len(sample_info) > 2:
            record.extend(sample_info[2:])
        results.append(record)
    return results


def predict_multi_class(
        args,
        label_id_2_name,
        seq_tokenizer,
        subword,
        struct_tokenizer,
        model,
        row
):
    '''
    Predict a single class for one sample by taking the arg-max over the model output.
    :param args: run configuration, forwarded to predict_probs
    :param label_id_2_name: lookup indexed by class id; its values are emitted as label names
    :param seq_tokenizer: forwarded to predict_probs
    :param subword: forwarded to predict_probs
    :param struct_tokenizer: forwarded to predict_probs
    :param model: forwarded to predict_probs
    :param row: one sample
    :return: one list per batch_info entry:
        [info[0], info[1], score of the chosen class, chosen label name, *info[2:]]
    '''
    batch_info, probs = predict_probs(args, seq_tokenizer, subword, struct_tokenizer, model, row)
    top_label_ids = np.argmax(probs, axis=-1)
    results = []
    for sample_idx, sample_info in enumerate(batch_info):
        top_id = top_label_ids[sample_idx]
        record = [
            sample_info[0],
            sample_info[1],
            float(probs[sample_idx][top_id]),
            label_id_2_name[top_id],
        ]
        # Carry through any extra per-sample fields unchanged.
        if len(sample_info) > 2:
            record.extend(sample_info[2:])
        results.append(record)
    return results


def predict_multi_label(
        args,
        label_id_2_name,
        seq_tokenizer,
        subword,
        struct_tokenizer,
        model,
        row
):
    '''
    Predict every label whose score reaches args.threshold for one sample.
    :param args: run configuration; args.threshold is the per-label cut-off
    :param label_id_2_name: lookup indexed by label id; its values are emitted as label names
    :param seq_tokenizer: forwarded to predict_probs
    :param subword: forwarded to predict_probs
    :param struct_tokenizer: forwarded to predict_probs
    :param model: forwarded to predict_probs
    :param row: one sample
    :return: one list per batch_info entry:
        [info[0], info[1], [scores of selected labels], [selected label names], *info[2:]]
    '''
    batch_info, probs = predict_probs(args, seq_tokenizer, subword, struct_tokenizer, model, row)
    # 0/1 matrix marking which labels pass the threshold; relevant_indexes
    # turns it into the selected label ids for each sample.
    hit_matrix = (probs >= args.threshold).astype(int)
    selected_per_sample = relevant_indexes(hit_matrix)
    results = []
    for sample_idx, sample_info in enumerate(batch_info):
        selected = selected_per_sample[sample_idx]
        scores = [float(probs[sample_idx][label_index]) for label_index in selected]
        names = [label_id_2_name[label_index] for label_index in selected]
        record = [sample_info[0], sample_info[1], scores, names]
        # Carry through any extra per-sample fields unchanged.
        if len(sample_info) > 2:
            record.extend(sample_info[2:])
        results.append(record)
    return results
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



