public static LabeledRowMatrix createWordDocMatrix()

in core/src/main/java/org/apache/sdap/mudrod/utils/MatrixUtil.java [131:221]


  /**
   * Builds a word-by-document count matrix from tokenized documents.
   *
   * <p>Every document (a value of {@code uniqueDocRDD}) receives a column index through
   * {@code zipWithIndex()}. Every distinct word becomes one row: a sparse vector whose length is
   * the number of documents and whose entry at a document's index is the number of times the
   * word occurs in that document.
   *
   * @param uniqueDocRDD pairs of (document key, list of words of that document)
   * @return a {@link LabeledRowMatrix} whose {@code rowMatrix} holds the word vectors, whose
   *         {@code rowkeys} are the words and whose {@code colkeys} are the keys of
   *         {@code uniqueDocRDD}
   */
  public static LabeledRowMatrix createWordDocMatrix(JavaPairRDD<String, List<String>> uniqueDocRDD) {
    // Index documents with unique IDs; the index is the document's column position.
    JavaPairRDD<List<String>, Long> corpus = uniqueDocRDD.values().zipWithIndex();

    // Count the occurrences of every (word, document index) pair.
    JavaPairRDD<Tuple2<String, Long>, Double> worddocNumRDD = corpus.flatMapToPair(new PairFlatMapFunction<Tuple2<List<String>, Long>, Tuple2<String, Long>, Double>() {
      private static final long serialVersionUID = 1L;

      @Override
      public Iterator<Tuple2<Tuple2<String, Long>, Double>> call(Tuple2<List<String>, Long> docwords) throws Exception {
        List<Tuple2<Tuple2<String, Long>, Double>> pairs = new ArrayList<>();
        for (String word : docwords._1) {
          Tuple2<String, Long> worddoc = new Tuple2<>(word, docwords._2);
          pairs.add(new Tuple2<Tuple2<String, Long>, Double>(worddoc, 1.0));
        }
        return pairs.iterator();
      }
    }).reduceByKey(new Function2<Double, Double, Double>() {
      private static final long serialVersionUID = 1L;

      @Override
      public Double call(Double first, Double second) throws Exception {
        return first + second;
      }
    });

    // Re-key by word: each record starts as a one-element list of (document index, count).
    JavaPairRDD<String, List<Tuple2<Integer, Double>>> wordEntriesRDD = worddocNumRDD
        .mapToPair(new PairFunction<Tuple2<Tuple2<String, Long>, Double>, String, List<Tuple2<Integer, Double>>>() {
          private static final long serialVersionUID = 1L;

          @Override
          public Tuple2<String, List<Tuple2<Integer, Double>>> call(Tuple2<Tuple2<String, Long>, Double> worddocNum) throws Exception {
            List<Tuple2<Integer, Double>> entries = new ArrayList<>();
            // Vector indices are ints, so the long document index is narrowed here.
            entries.add(new Tuple2<Integer, Double>(worddocNum._1._2.intValue(), worddocNum._2));
            return new Tuple2<>(worddocNum._1._1, entries);
          }
        });

    // Every vector has one slot per document.
    final int corporsize = (int) uniqueDocRDD.keys().count();

    // Gather all (document index, count) entries of a word and turn them into a sparse vector.
    JavaPairRDD<String, Vector> wordVectorRDD = wordEntriesRDD
        .reduceByKey(new Function2<List<Tuple2<Integer, Double>>, List<Tuple2<Integer, Double>>, List<Tuple2<Integer, Double>>>() {
          private static final long serialVersionUID = 1L;

          @Override
          public List<Tuple2<Integer, Double>> call(List<Tuple2<Integer, Double>> arg0, List<Tuple2<Integer, Double>> arg1) throws Exception {
            // Merge into the first list; the resulting entry order is arbitrary.
            arg0.addAll(arg1);
            return arg0;
          }
        }).mapToPair(new PairFunction<Tuple2<String, List<Tuple2<Integer, Double>>>, String, Vector>() {
          private static final long serialVersionUID = 1L;

          @Override
          public Tuple2<String, Vector> call(Tuple2<String, List<Tuple2<Integer, Double>>> wordEntries) throws Exception {
            // The merged entries are not ordered by document index, so use the overload that
            // accepts unordered (index, value) pairs instead of the raw index/value arrays,
            // which must be strictly increasing. Counts are kept as doubles (no int narrowing).
            Vector sv = Vectors.sparse(corporsize, wordEntries._2);
            return new Tuple2<>(wordEntries._1, sv);
          }
        });

    // NOTE(review): wordVectorRDD is not persisted here, so the matrix rows and the row keys
    // below come from separate evaluations of the same lineage; their alignment assumes both
    // evaluations produce the same ordering - confirm.
    RowMatrix wordDocMatrix = new RowMatrix(wordVectorRDD.values().rdd());

    LabeledRowMatrix labeledRowMatrix = new LabeledRowMatrix();
    labeledRowMatrix.rowMatrix = wordDocMatrix;
    labeledRowMatrix.rowkeys = wordVectorRDD.keys().collect();
    labeledRowMatrix.colkeys = uniqueDocRDD.keys().collect();
    return labeledRowMatrix;
  }