in core/src/main/java/org/apache/sdap/mudrod/utils/MatrixUtil.java [131:221]
/**
 * Builds a word-by-document term-count matrix from a corpus of tokenized documents.
 *
 * <p>Each document is assigned a column index equal to its position in {@code uniqueDocRDD}
 * (via {@code zipWithIndex}). Each distinct word becomes one row. That row's sparse vector holds,
 * at each document's column, the number of times the word occurs in that document's word list.
 *
 * @param uniqueDocRDD pairs of (document key, words of that document)
 * @return a {@link LabeledRowMatrix} whose {@code rowMatrix} has one row per distinct word and one
 *     column per document, whose {@code rowkeys} are the words and whose {@code colkeys} are the
 *     document keys
 */
public static LabeledRowMatrix createWordDocMatrix(JavaPairRDD<String, List<String>> uniqueDocRDD) {
  // Column index of each document = its position in the RDD.
  JavaPairRDD<List<String>, Long> corpus = uniqueDocRDD.values().zipWithIndex();

  // Count the occurrences of every (word, document index) pair.
  JavaPairRDD<Tuple2<String, Long>, Double> worddocNumRDD = corpus.flatMapToPair(new PairFlatMapFunction<Tuple2<List<String>, Long>, Tuple2<String, Long>, Double>() {
    private static final long serialVersionUID = 1L;

    @Override
    public Iterator<Tuple2<Tuple2<String, Long>, Double>> call(Tuple2<List<String>, Long> docwords) throws Exception {
      List<String> words = docwords._1;
      List<Tuple2<Tuple2<String, Long>, Double>> pairs = new ArrayList<>(words.size());
      for (String word : words) {
        Tuple2<String, Long> worddoc = new Tuple2<>(word, docwords._2);
        pairs.add(new Tuple2<Tuple2<String, Long>, Double>(worddoc, 1.0));
      }
      return pairs.iterator();
    }
  }).reduceByKey(new Function2<Double, Double, Double>() {
    private static final long serialVersionUID = 1L;

    @Override
    public Double call(Double first, Double second) throws Exception {
      return first + second;
    }
  });

  // Re-key by word. Each record carries a single (column index, count) entry, so index and value
  // can never drift apart the way two parallel lists could.
  JavaPairRDD<String, List<Tuple2<Integer, Double>>> wordDocnumRDD = worddocNumRDD
      .mapToPair(new PairFunction<Tuple2<Tuple2<String, Long>, Double>, String, List<Tuple2<Integer, Double>>>() {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<String, List<Tuple2<Integer, Double>>> call(Tuple2<Tuple2<String, Long>, Double> worddocNum) throws Exception {
          List<Tuple2<Integer, Double>> entries = new ArrayList<>();
          entries.add(new Tuple2<Integer, Double>(worddocNum._1._2.intValue(), worddocNum._2));
          return new Tuple2<>(worddocNum._1._1, entries);
        }
      });

  // Vector length: one column per document.
  final int corporsize = (int) uniqueDocRDD.count();

  JavaPairRDD<String, Vector> wordVectorRDD = wordDocnumRDD.reduceByKey(new Function2<List<Tuple2<Integer, Double>>, List<Tuple2<Integer, Double>>, List<Tuple2<Integer, Double>>>() {
    private static final long serialVersionUID = 1L;

    @Override
    public List<Tuple2<Integer, Double>> call(List<Tuple2<Integer, Double>> first, List<Tuple2<Integer, Double>> second) throws Exception {
      // Each list is freshly allocated per record in the mapToPair above, so appending in place
      // does not affect any other record.
      first.addAll(second);
      return first;
    }
  }).mapToPair(new PairFunction<Tuple2<String, List<Tuple2<Integer, Double>>>, String, Vector>() {
    private static final long serialVersionUID = 1L;

    @Override
    public Tuple2<String, Vector> call(Tuple2<String, List<Tuple2<Integer, Double>>> wordEntries) throws Exception {
      // reduceByKey concatenates entries in no particular order, but a sparse vector requires
      // strictly increasing indices. This Vectors.sparse overload sorts the (index, value) pairs
      // by index and rejects duplicates, which cannot occur since (word, doc) pairs are unique.
      Vector sv = Vectors.sparse(corporsize, wordEntries._2);
      return new Tuple2<>(wordEntries._1, sv);
    }
  });

  RowMatrix wordDocMatrix = new RowMatrix(wordVectorRDD.values().rdd());
  LabeledRowMatrix labeledRowMatrix = new LabeledRowMatrix();
  labeledRowMatrix.rowMatrix = wordDocMatrix;
  // NOTE(review): the row labels are collected in a separate job from the one that evaluates the
  // matrix rows. They line up only if wordVectorRDD yields its records in the same order on every
  // evaluation. Consider persisting it if misalignment is ever observed.
  labeledRowMatrix.rowkeys = wordVectorRDD.keys().collect();
  // NOTE(review): the column labels match the zipWithIndex indices above only if uniqueDocRDD is
  // evaluated deterministically (same partitioning and order each time).
  labeledRowMatrix.colkeys = uniqueDocRDD.keys().collect();
  return labeledRowMatrix;
}