protected static void configureComparisonPipeline()

in spanner-data-validator-java/src/main/java/com/google/migration/JDBCToSpannerDVTWithHash.java [238:449]


  /**
   * Wires the per-table comparison stages into the given pipeline.
   *
   * <p>The stages added, all suffixed with the table name, are:
   * <ol>
   *   <li>Compute the partition ranges for the table and publish them both as a
   *       {@code PCollection} and as a list side input.</li>
   *   <li>Read {@link HashResult}s from Spanner and from JDBC (using the sharded JDBC reader
   *       when {@code options.getSourceConfigURL()} is non-empty).</li>
   *   <li>Key every record with {@code MapWithRangeFn} in {@code RANGE_PLUS_HASH} mode and
   *       co-group the two sides by that key.</li>
   *   <li>Tag the grouped results with {@code CountMatchesDoFn}, count each tagged output per
   *       key, and co-group those counts.</li>
   *   <li>Convert the grouped counts to {@link ComparerResult}s and apply
   *       {@code comparerResultWrite} to them.</li>
   * </ol>
   *
   * <p>A positive partition count or partition filter ratio on {@code tableSpec} takes
   * precedence over the corresponding value in {@code options}.
   *
   * @param p pipeline the stages are added to
   * @param pipelineTracker tracker that receives the Spanner and JDBC read collections and
   *     supplies the wait applied to the ranges before the JDBC read
   * @param options run-wide options (partitioning defaults, source config URL, run name)
   * @param tableSpec per-table settings (queries, range field, thresholds, overrides)
   * @param comparerResultWrite sink applied to the per-key comparison results
   * @param jdbcConflictingRecordsWriter sink for unmatched JDBC record values; when
   *     {@code null} that output is not written
   * @param spannerConflictingRecordsWriter sink for unmatched Spanner record values; when
   *     {@code null} that output is not written
   * @param customTransformation passed through unchanged to the JDBC reader
   * @param schema passed through unchanged to the JDBC reader
   */
  protected static void configureComparisonPipeline(Pipeline p,
      PipelineTracker pipelineTracker,
      DVTOptionsCore options,
      TableSpec tableSpec,
      BigQueryIO.Write<ComparerResult> comparerResultWrite,
      BigQueryIO.Write<HashResult> jdbcConflictingRecordsWriter,
      BigQueryIO.Write<HashResult> spannerConflictingRecordsWriter,
      CustomTransformation customTransformation,
      Schema schema) {

    final String tableName = tableSpec.getTableName();
    final String shardConfigurationFileUrl = options.getSourceConfigURL();

    // Table-level settings override the run-wide defaults when they are positive.
    Integer partitionCount = options.getPartitionCount();
    if (tableSpec.getPartitionCount() > 0) {
      partitionCount = tableSpec.getPartitionCount();
    }

    Integer partitionFilterRatio = options.getPartitionFilterRatio();
    if (tableSpec.getPartitionFilterRatio() > 0) {
      partitionFilterRatio = tableSpec.getPartitionFilterRatio();
    }

    List<PartitionRange> bRanges = getPartitionRanges(tableSpec,
        partitionCount,
        partitionFilterRatio);

    Helpers.printPartitionRanges(bRanges, tableName);

    String createRangesForTableStep = String.format("CreateRangesForTable-%s", tableName);
    PCollection<PartitionRange> pRanges = p.apply(createRangesForTableStep, Create.of(bRanges));

    // The same ranges as a side input, used to map each record to its range.
    String partitionRangesViewStep = String.format("PartitionRangesForTable-%s", tableName);
    final PCollectionView<List<PartitionRange>> partitionRangesView =
        pRanges.apply(partitionRangesViewStep, View.asList());

    // ---- Spanner side -------------------------------------------------------------------
    PCollection<HashResult> spannerRecords =
        getSpannerRecords(tableName,
            pipelineTracker,
            tableSpec.getDestQuery(),
            tableSpec.getRangeFieldIndex(),
            tableSpec.getRangeFieldType(),
            options,
            pRanges,
            tableSpec.getTimestampThresholdColIndex());

    pipelineTracker.addToSpannerReadList(spannerRecords);

    // Map Range [start, end) + hash => HashResult (spanner)
    String mapWithRangesForSpannerStep =
        String.format("MapWithRangesSpannerRecordsForTable-%s", tableName);
    PCollection<KV<String, HashResult>> mappedWithHashSpannerRecords =
        spannerRecords.apply(mapWithRangesForSpannerStep,
            ParDo.of(new MapWithRangeFn(partitionRangesView,
                    MapWithRangeType.RANGE_PLUS_HASH,
                    tableSpec.getRangeFieldType()))
                .withSideInputs(partitionRangesView));

    // ---- JDBC side ----------------------------------------------------------------------
    // The JDBC read consumes the ranges returned by the tracker's wait step, whereas the
    // Spanner read above consumes the original ranges. Kept in a separate local so that
    // distinction stays visible.
    @SuppressWarnings("unchecked") // cast mirrors the original; the tracker returns the ranges
    PCollection<PartitionRange> jdbcRanges =
        (PCollection<PartitionRange>) pipelineTracker.applyJDBCWait(pRanges);

    PCollection<HashResult> jdbcRecords;

    if (Helpers.isNullOrEmpty(shardConfigurationFileUrl)) {
      jdbcRecords =
          getJDBCRecords(tableName,
              pipelineTracker,
              tableSpec.getSourceQuery(),
              tableSpec.getRangeFieldIndex(),
              tableSpec.getRangeFieldType(),
              options,
              jdbcRanges,
              tableSpec.getTimestampThresholdColIndex(),
              customTransformation,
              schema);
    } else {
      jdbcRecords =
          getJDBCRecordsWithSharding(tableName,
              pipelineTracker,
              tableSpec.getSourceQuery(),
              tableSpec.getRangeFieldIndex(),
              tableSpec.getRangeFieldType(),
              options,
              jdbcRanges,
              tableSpec.getTimestampThresholdColIndex(),
              customTransformation,
              schema);
    }

    pipelineTracker.addToJDBCReadList(jdbcRecords);

    // Map Range [start, end) + hash => HashResult (JDBC)
    String mapWithRangesForJDBCStep =
        String.format("MapWithRangesJDBCRecordsForTable-%s", tableName);
    PCollection<KV<String, HashResult>> mappedWithHashJdbcRecords =
        jdbcRecords.apply(mapWithRangesForJDBCStep,
            ParDo.of(new MapWithRangeFn(partitionRangesView,
                    MapWithRangeType.RANGE_PLUS_HASH,
                    tableSpec.getRangeFieldType()))
                .withSideInputs(partitionRangesView));

    // ---- Join and classify --------------------------------------------------------------
    // Group by range [start, end) + hash =>
    //   {JDBC HashResult if it exists, Spanner HashResult if it exists}
    String groupByKeyStep = String.format("GroupByKeyForTable-%s", tableName);
    PCollection<KV<String, CoGbkResult>> results =
        KeyedPCollectionTuple.of(jdbcTag, mappedWithHashJdbcRecords)
            .and(spannerTag, mappedWithHashSpannerRecords)
            .apply(groupByKeyStep, CoGroupByKey.create());

    // Tag the joined results; each tag below is a separate output of CountMatchesDoFn.
    PCollectionTuple countMatches = results.apply(
        String.format("CountMatchesForTable-%s", tableName),
        ParDo.of(new CountMatchesDoFn(tableSpec.getTimestampThresholdValue(),
                tableSpec.getTimestampThresholdDeltaInMins()))
            .withOutputTags(matchedRecordsTag,
                TupleTagList.of(unmatchedSpannerRecordsTag)
                    .and(unmatchedJDBCRecordsTag)
                    .and(sourceRecordsTag)
                    .and(targetRecordsTag)
                    .and(unmatchedSpannerRecordValuesTag)
                    .and(unmatchedJDBCRecordValuesTag)));

    // ---- Count each tagged output per key -----------------------------------------------
    // Step names mirror the tag being counted so the runner UI/metrics are unambiguous.
    PCollection<KV<String, Long>> matchedRecordCount =
        countMatches
            .get(matchedRecordsTag)
            .apply(String.format("MatchedCountForTable-%s", tableName), Count.perKey());

    PCollection<KV<String, Long>> unmatchedJDBCRecordCount =
        countMatches
            .get(unmatchedJDBCRecordsTag)
            .apply(String.format("UnmatchedJDBCCountForTable-%s", tableName), Count.perKey());

    PCollection<KV<String, Long>> unmatchedSpannerRecordCount =
        countMatches
            .get(unmatchedSpannerRecordsTag)
            .apply(String.format("UnmatchedSpannerCountForTable-%s", tableName), Count.perKey());

    PCollection<KV<String, Long>> sourceRecordCount =
        countMatches
            .get(sourceRecordsTag)
            .apply(String.format("SourceCountForTable-%s", tableName), Count.perKey());

    PCollection<KV<String, Long>> targetRecordCount =
        countMatches
            .get(targetRecordsTag)
            .apply(String.format("TargetCountForTable-%s", tableName), Count.perKey());

    // ---- Optional conflicting-record sinks ----------------------------------------------
    if (spannerConflictingRecordsWriter != null) {
      PCollection<HashResult> unmatchedSpannerValues =
          countMatches.get(unmatchedSpannerRecordValuesTag);

      unmatchedSpannerValues.apply(
          String.format("SpannerConflictingRecordsWriter-%s", tableName),
          spannerConflictingRecordsWriter);

      LOG.info("****** Writing spanner conflicting records");
    } else {
      LOG.info("****** Not writing spanner conflicting records");
    }

    if (jdbcConflictingRecordsWriter != null) {
      PCollection<HashResult> unmatchedJDBCValues =
          countMatches.get(unmatchedJDBCRecordValuesTag);

      unmatchedJDBCValues.apply(
          String.format("JDBCConflictingRecordsWriter-%s", tableName),
          jdbcConflictingRecordsWriter);

      LOG.info("****** Writing JDBC conflicting records");
    } else {
      LOG.info("****** Not writing JDBC conflicting records");
    }

    // ---- Assemble and write the report --------------------------------------------------
    // group above counts by key
    PCollection<KV<String, CoGbkResult>> comparerResults =
        KeyedPCollectionTuple.of(matchedRecordCountTag, matchedRecordCount)
            .and(unmatchedSpannerRecordCountTag, unmatchedSpannerRecordCount)
            .and(unmatchedJDBCRecordCountTag, unmatchedJDBCRecordCount)
            .and(sourceRecordCountTag, sourceRecordCount)
            .and(targetRecordCountTag, targetRecordCount)
            .apply(String.format("GroupCountsByKeyForTable-%s", tableName), CoGroupByKey.create());

    final String runName = options.getRunName();

    // assign grouped counts to object that can then be written to BQ
    PCollection<ComparerResult> reportOutput =
        comparerResults.apply(String.format("ReportOutputForTable-%s", tableName),
            ParDo.of(
                new DoFn<KV<String, CoGbkResult>, ComparerResult>() {
                  @ProcessElement
                  public void processElement(ProcessContext c) {
                    KV<String, CoGbkResult> element = c.element();
                    CoGbkResult counts = element.getValue();

                    ComparerResult comparerResult =
                        new ComparerResult(runName, element.getKey());

                    comparerResult.matchCount =
                        getCountForTag(counts, matchedRecordCountTag);
                    comparerResult.sourceConflictCount =
                        getCountForTag(counts, unmatchedJDBCRecordCountTag);
                    comparerResult.targetConflictCount =
                        getCountForTag(counts, unmatchedSpannerRecordCountTag);
                    comparerResult.sourceCount =
                        getCountForTag(counts, sourceRecordCountTag);
                    comparerResult.targetCount =
                        getCountForTag(counts, targetRecordCountTag);

                    c.output(comparerResult);
                  }
                }));

    reportOutput.apply(String.format("BQWriteForTable-%s", tableName),
        comparerResultWrite);
  }