def cluster_docs()

in src/graph_utils.py


import time

import networkx as nx
import numpy as np
from nltk.tokenize import sent_tokenize
from numpy.linalg import norm
from rouge import Rouge
from scipy.spatial import distance

# compute_base_features is assumed to be defined elsewhere in src/graph_utils.py


def cluster_docs(docs,
                 graph_method,
                 similarity_metric="tfidf",
                 threshold=None,
                 neighbour_weights=None,
                 per_doc_sent_limit=100,
                 time_analysis=False):
    """ Module to cluster the documents and assign
    rank to each sentence based on the importance of
    that sentence in the cluster. This importance is
    calculated based on the defined `graph_method'. 
    """
    program_start_time = time.time()
    if similarity_metric == "tfidf":
        tfidfs, word_index_mapping, total_num_words = compute_base_features(docs)
    elif "rouge" in similarity_metric:
        rouge = Rouge()
    else:
        raise ValueError("Unknown similarity_metric: {}".format(similarity_metric))

    # Transform the docs into a flat cluster of sentences
    start_time = time.time()
    cluster = []
    for i, doc in enumerate(docs):
        for j, sent in enumerate(sent_tokenize(doc)[:per_doc_sent_limit]):
            info = {}
            if similarity_metric == "tfidf":
                vector = np.zeros(total_num_words)
                for word in sent.split():
                    try:
                        vector[word_index_mapping[word]] = tfidfs[i][word]
                    except KeyError:
                        print("Key missing from the TF-IDF mapping; ignoring for now")
            else:
                vector = None

            info["id"] = "d{}_s{}".format(i, j)
            info["sentence"] = sent
            info["vector"] = vector
            cluster.append(info)
    
    if time_analysis:
        print(f"execution time for transform module is {time.time()-start_time}")

    start_time = time.time()
    
    # Create an adjacency matrix
    sim_mat = np.zeros((len(cluster), len(cluster)))
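    # Fill the upper triangle with pairwise similarities, mirror it into the
    # lower triangle, and leave the diagonal at zero. When neighbour_weights
    # is set, adjacent sentence pairs get that fixed weight instead of a
    # computed score.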
                
    for i in range(len(cluster)):
        for j in range(len(cluster)):
            # tmp_counter marks the boundary of the fixed-weight band: with
            # neighbour_weights set, the pair (i, i + 1) falls through to the
            # adjacent-nodes branch below
            if neighbour_weights is None:
                tmp_counter = j
            else:
                tmp_counter = j - 1
            if i < tmp_counter:
                if similarity_metric == "tfidf":
                    # special case where the vector is all but empty (the sentence may have a single word)
                    if norm(cluster[i]["vector"]) < 1e-20 or norm(cluster[j]["vector"]) < 1e-20:
                        sim_mat[i, j] = 0
                    else:
                        # similarity = 1 - cosine distance
                        sim_mat[i, j] = 1 - distance.cosine(cluster[i]["vector"], cluster[j]["vector"])
                else:
                    try:
                        if len(cluster[i]["sentence"].split()) == 1 or \
                                len(cluster[j]["sentence"].split()) == 1:
                            sim_mat[i, j] = 0
                        else:
                            sim_mat[i, j] = rouge.get_scores(cluster[i]["sentence"], cluster[j]["sentence"])[0][similarity_metric]["f"]
                    except Exception:
                        print(f"Rouge exception for pair sample: {cluster[i]['sentence']} ||| {cluster[j]['sentence']}")
                        sim_mat[i, j] = 0

            elif i > tmp_counter:
                sim_mat[i, j] = sim_mat[j, i]  # mirror the upper triangle
            else:
                if i == j:
                    sim_mat[i, j] = 0  # no self-loops
                else:  # adjacent nodes (only reachable when neighbour_weights is set)
                    sim_mat[i, j] = neighbour_weights
    if time_analysis:
        print(f"execution time for adjacency matrix creation is {time.time()-start_time}")

    start_time = time.time()
    # Graph score calculation
    if threshold is not None:
        sim_mat[sim_mat < threshold] = 0  # prune weak edges
    #sim_mat = np.exp(sim_mat)
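    # Nodes are sentences and edge weights are similarities; a node's
    # centrality in this graph is treated as the sentence's importance.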
    nx_graph = nx.from_numpy_array(sim_mat)
    final_graph_method = graph_method
    if graph_method == "pagerank":
        try:
            scores = nx.pagerank(nx_graph)
        except nx.PowerIterationFailedConvergence:
            print("Power iteration failed; switching to generic...")
            scores = np.sum(sim_mat, axis=1)
            final_graph_method = "generic"
    elif graph_method == "generic":
        scores = np.sum(sim_mat, axis=1)
    else:
        raise ValueError("Unknown graph_method: {}".format(graph_method))

    if time_analysis:
        print(f"execution time for graph calculation is {time.time()-start_time}")

    # Format the output data
    start_time = time.time()
    cluster_dict = {}
    cluster_dict["cluster"] = {}
    for i, sent_info in enumerate(cluster):
        sent_info["score"] = scores[i]
        sent_info["vector"] = None
        cluster_dict["cluster"][sent_info["id"]] = sent_info

    cluster_dict["sim_mat"] = sim_mat.tolist()
    cluster_dict["graph_method"] = final_graph_method

    if time_analysis:
        print(f"execution time for data formatting module is {time.time()-start_time}")
        print(f"execution time for whole module is {time.time()-program_start_time}")

    return cluster_dict
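

A minimal usage sketch. The two sample documents are hypothetical, and it
assumes nltk's punkt tokenizer data has been downloaded and that
compute_base_features is available in this module:

if __name__ == "__main__":
    docs = [
        "The cat sat on the mat. It purred quietly. Then it slept.",
        "A dog barked at the cat. The cat ran away from the dog.",
    ]
    result = cluster_docs(docs,
                          graph_method="pagerank",
                          similarity_metric="tfidf",
                          time_analysis=True)
    # Rank sentences by graph score, highest first
    ranked = sorted(result["cluster"].values(),
                    key=lambda s: s["score"],
                    reverse=True)
    for sent_info in ranked[:3]:
        print(sent_info["id"], round(sent_info["score"], 4), sent_info["sentence"])

Passing similarity_metric="rouge-1" (or "rouge-2", "rouge-l") swaps TF-IDF
cosine similarity for ROUGE F1 scores; threshold prunes weak edges and
neighbour_weights pins the edge weight between adjacent sentences.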