public static PipelineResult run(Options options)

in v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java [465:781]


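  /**
   * Sets up and runs the Spanner-to-source replication pipeline: reads Spanner change stream
   * records (or re-consumes them from the dead letter queue), assigns shard ids, writes the
   * records to the source database, and routes permanent, retryable and skipped records to the
   * DLQ directories on GCS.
   *
   * @param options the pipeline options for this template run
   * @return the result of running the pipeline
   */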
  public static PipelineResult run(Options options) {

    Pipeline pipeline = Pipeline.create(options);
    pipeline
        .getOptions()
        .as(DataflowPipelineWorkerPoolOptions.class)
        .setAutoscalingAlgorithm(AutoscalingAlgorithmType.THROUGHPUT_BASED);

    // Calculate the connection pool size per worker from the configured max shard connections.
    int maxNumWorkers =
        pipeline.getOptions().as(DataflowPipelineWorkerPoolOptions.class).getMaxNumWorkers() > 0
            ? pipeline.getOptions().as(DataflowPipelineWorkerPoolOptions.class).getMaxNumWorkers()
            : 1;
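    // For example, maxShardConnections = 100 with maxNumWorkers = 4 gives each worker a pool of
    // 25 connections; note that the integer division rounds down.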
    int connectionPoolSizePerWorker = (int) (options.getMaxShardConnections() / maxNumWorkers);
    if (connectionPoolSizePerWorker < 1) {
      // This can happen when the number of workers exceeds the max shard connections.
      // It can overload the source database, so error out and let the user know.
      LOG.error(
          "Max workers {} is more than max shard connections {}; this can lead to more database"
              + " connections than desired",
          maxNumWorkers,
          options.getMaxShardConnections());
      throw new IllegalArgumentException(
          "Max Dataflow workers "
              + maxNumWorkers
              + " is more than max per shard connections: "
              + options.getMaxShardConnections()
              + "; this can lead to more database connections than desired. Either reduce the max"
              + " allowed workers or increase the max shard connections.");
    }

    // Read the session file for the MySQL source type only
    Schema schema =
        MYSQL_SOURCE_TYPE.equals(options.getSourceType())
            ? SessionFileReader.read(
                options.getSessionFilePath()) // Read from session file for MYSQL source type
            : new Schema();

    // Prepare Spanner config
    SpannerConfig spannerConfig =
        SpannerConfig.create()
            .withProjectId(ValueProvider.StaticValueProvider.of(options.getSpannerProjectId()))
            .withInstanceId(ValueProvider.StaticValueProvider.of(options.getInstanceId()))
            .withDatabaseId(ValueProvider.StaticValueProvider.of(options.getDatabaseId()));

    // Create shadow tables.
    // Note that there is a limit on the number of tables that can be created per database: 5000.
    // Creating shadow tables per shard would therefore cause an explosion of tables. The shadow
    // table is keyed by the Spanner PK, so there is no need to further separate it by shard;
    // a lookup by the Spanner PK is sufficient.

    // Prepare the Spanner metadata config
    SpannerConfig spannerMetadataConfig =
        SpannerConfig.create()
            .withProjectId(ValueProvider.StaticValueProvider.of(options.getSpannerProjectId()))
            .withInstanceId(ValueProvider.StaticValueProvider.of(options.getMetadataInstance()))
            .withDatabaseId(ValueProvider.StaticValueProvider.of(options.getMetadataDatabase()));

    ShadowTableCreator shadowTableCreator =
        new ShadowTableCreator(
            spannerConfig,
            spannerMetadataConfig,
            SpannerAccessor.getOrCreate(spannerMetadataConfig)
                .getDatabaseAdminClient()
                .getDatabase(
                    spannerMetadataConfig.getInstanceId().get(),
                    spannerMetadataConfig.getDatabaseId().get())
                .getDialect(),
            options.getShadowTablePrefix());

    DataflowPipelineDebugOptions debugOptions = options.as(DataflowPipelineDebugOptions.class);

    shadowTableCreator.createShadowTablesInSpanner();
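    // Read the Spanner information schema; the resulting DDL drives shard id assignment and the
    // source writes below.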
    Ddl ddl = SpannerSchema.getInformationSchemaAsDdl(spannerConfig);
    List<Shard> shards;
    String shardingMode;
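    // Load shard connection details: an ordered list of MySQL shards, or a single Cassandra
    // configuration for the Cassandra source type.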
    if (MYSQL_SOURCE_TYPE.equals(options.getSourceType())) {
      ShardFileReader shardFileReader = new ShardFileReader(new SecretManagerAccessorImpl());
      shards = shardFileReader.getOrderedShardDetails(options.getSourceShardsFilePath());
      shardingMode = Constants.SHARDING_MODE_MULTI_SHARD;

    } else {
      CassandraConfigFileReader cassandraConfigFileReader = new CassandraConfigFileReader();
      shards = cassandraConfigFileReader.getCassandraShard(options.getSourceShardsFilePath());
      LOG.info("Cassandra config is: {}", shards.get(0));
      shardingMode = Constants.SHARDING_MODE_SINGLE_SHARD;
    }
    if (shards.size() == 1 && !options.getIsShardedMigration()) {
      shardingMode = Constants.SHARDING_MODE_SINGLE_SHARD;
      Shard shard = shards.get(0);
      if (shard.getLogicalShardId() == null) {
        shard.setLogicalShardId(Constants.DEFAULT_SHARD_ID);
        LOG.info(
            "Logical shard id was not found, hence setting it to: {}", Constants.DEFAULT_SHARD_ID);
      }
    }

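    // For a Cassandra source, build the schema by pairing the Spanner DDL with the Cassandra
    // information schema read via the shard's configuration.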
    if (options.getSourceType().equals(CASSANDRA_SOURCE_TYPE)) {
      Map<String, SpannerTable> spannerTableMap =
          SpannerSchema.convertDDLTableToSpannerTable(ddl.allTables());
      Map<String, NameAndCols> spannerTableNameColsMap =
          SpannerSchema.convertDDLTableToSpannerNameAndColsTable(ddl.allTables());
      try {
        CassandraSourceMetadata cassandraSourceMetadata =
            new CassandraSourceMetadata.Builder()
                .setResultSet(
                    CassandraSourceSchemaReader.getInformationSchemaAsResultSet(
                        (CassandraShard) shards.get(0)))
                .build();
        schema =
            new Schema(
                spannerTableMap,
                null,
                cassandraSourceMetadata.getSourceTableMap(),
                spannerTableNameColsMap,
                cassandraSourceMetadata.getNameAndColsMap());
      } catch (Exception e) {
        throw new IllegalArgumentException(e);
      }
    }

    boolean isRegularMode = "regular".equals(options.getRunMode());
    PCollectionTuple reconsumedElements = null;
    DeadLetterQueueManager dlqManager = buildDlqManager(options);

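    // Size the reshuffle buckets to the total expected worker-harness thread parallelism.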
    int reshuffleBucketSize =
        maxNumWorkers
            * (debugOptions.getNumberOfWorkerHarnessThreads() > 0
                ? debugOptions.getNumberOfWorkerHarnessThreads()
                : Constants.DEFAULT_WORKER_HARNESS_THREAD_COUNT);

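    // In regular mode, retryable DLQ files are re-consumed via Pub/Sub notifications on the DLQ
    // bucket; otherwise the retry directory is polled on the configured retry interval.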
    if (isRegularMode) {
      reconsumedElements =
          dlqManager.getReconsumerDataTransformForFiles(
              pipeline.apply(
                  "Read retry from PubSub",
                  new PubSubNotifiedDlqIO(
                      options.getDlqGcsPubSubSubscription(),
                      // file paths to ignore when re-consuming for retry
                      new ArrayList<String>(
                          Arrays.asList(
                              "/severe/",
                              "/tmp_retry",
                              "/tmp_severe/",
                              ".temp",
                              "/tmp_skip/",
                              "/" + options.getSkipDirectoryName())))));
    } else {
      reconsumedElements =
          dlqManager.getReconsumerDataTransform(
              pipeline.apply(dlqManager.dlqReconsumer(options.getDlqRetryMinutes())));
    }

    PCollection<FailsafeElement<String, String>> dlqJsonStrRecords =
        reconsumedElements
            .get(DeadLetterQueueManager.RETRYABLE_ERRORS)
            .setCoder(FailsafeElementCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));

    PCollection<TrimmedShardedDataChangeRecord> dlqRecords =
        dlqJsonStrRecords
            .apply(
                "Convert DLQ records to TrimmedShardedDataChangeRecord",
                ParDo.of(new ConvertDlqRecordToTrimmedShardedDataChangeRecordFn()))
            .setCoder(SerializableCoder.of(TrimmedShardedDataChangeRecord.class));
    PCollection<TrimmedShardedDataChangeRecord> mergedRecords = null;
    if (isRegularMode) {
      PCollection<TrimmedShardedDataChangeRecord> changeRecordsFromDB =
          pipeline
              // This emits PCollection<DataChangeRecord>, which is the Spanner change stream data.
              .apply(getReadChangeStreamDoFn(options, spannerConfig))
              .apply("Reshuffle", Reshuffle.viaRandomKey())
              .apply("Filtration", ParDo.of(new FilterRecordsFn(options.getFiltrationMode())))
              .apply("Preprocess", ParDo.of(new PreprocessRecordsFn()))
              .setCoder(SerializableCoder.of(TrimmedShardedDataChangeRecord.class));
      mergedRecords =
          PCollectionList.of(changeRecordsFromDB)
              .and(dlqRecords)
              .apply("Flatten", Flatten.pCollections())
              .setCoder(SerializableCoder.of(TrimmedShardedDataChangeRecord.class));
    } else {
      mergedRecords = dlqRecords;
    }
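    // Build the custom transformation (JAR path, class name and custom parameters) that is handed
    // to the source writer.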
    CustomTransformation customTransformation =
        CustomTransformation.builder(
                options.getTransformationJarPath(), options.getTransformationClassName())
            .setCustomParameters(options.getTransformationCustomParameters())
            .build();
    SourceWriterTransform.Result sourceWriterOutput =
        mergedRecords
            .apply(
                "AssignShardId", // This emits PCollection<KV<Long,
                // TrimmedShardedDataChangeRecord>> which is Spanner change stream data with key as
                // PK
                // mod
                // number of parallelism
                ParDo.of(
                    new AssignShardIdFn(
                        spannerConfig,
                        schema,
                        ddl,
                        shardingMode,
                        shards.get(0).getLogicalShardId(),
                        options.getSkipDirectoryName(),
                        options.getShardingCustomJarPath(),
                        options.getShardingCustomClassName(),
                        options.getShardingCustomParameters(),
                        options.getMaxShardConnections()
                            * shards.size()))) // currently assuming that all shards accept the same
            .setCoder(
                KvCoder.of(
                    VarLongCoder.of(), SerializableCoder.of(TrimmedShardedDataChangeRecord.class)))
            .apply("Reshuffle2", Reshuffle.of())
            .apply(
                "Write to source",
                new SourceWriterTransform(
                    shards,
                    schema,
                    spannerMetadataConfig,
                    options.getSourceDbTimezoneOffset(),
                    ddl,
                    options.getShadowTablePrefix(),
                    options.getSkipDirectoryName(),
                    connectionPoolSizePerWorker,
                    options.getSourceType(),
                    customTransformation));

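    // Route errors: permanent errors (from the DLQ reconsumer and the source writer) go to the
    // severe DLQ directory, retryable errors go back to the retry DLQ directory, and skipped
    // source writes are recorded under the skip directory.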
    PCollection<FailsafeElement<String, String>> dlqPermErrorRecords =
        reconsumedElements
            .get(DeadLetterQueueManager.PERMANENT_ERRORS)
            .setCoder(FailsafeElementCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));

    PCollection<FailsafeElement<String, String>> permErrorsFromSourceWriter =
        sourceWriterOutput
            .permanentErrors()
            .setCoder(StringUtf8Coder.of())
            .apply(
                "Reshuffle3", Reshuffle.<String>viaRandomKey().withNumBuckets(reshuffleBucketSize))
            .apply(
                "Convert permanent errors from source writer to DLQ format",
                ParDo.of(new ConvertChangeStreamErrorRecordToFailsafeElementFn()))
            .setCoder(FailsafeElementCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));

    PCollection<FailsafeElement<String, String>> permanentErrors =
        PCollectionList.of(dlqPermErrorRecords)
            .and(permErrorsFromSourceWriter)
            .apply(Flatten.pCollections())
            .apply("Reshuffle", Reshuffle.viaRandomKey());

    permanentErrors
        .apply("Update DLQ metrics", ParDo.of(new UpdateDlqMetricsFn(isRegularMode)))
        .apply(
            "DLQ: Write Severe errors to GCS",
            MapElements.via(new StringDeadLetterQueueSanitizer()))
        .setCoder(StringUtf8Coder.of())
        .apply(
            "Write To DLQ for severe errors",
            DLQWriteTransform.WriteDLQ.newBuilder()
                .withDlqDirectory(dlqManager.getSevereDlqDirectoryWithDateTime())
                .withTmpDirectory(options.getDeadLetterQueueDirectory() + "/tmp_severe/")
                .setIncludePaneInfo(true)
                .build());

    PCollection<FailsafeElement<String, String>> retryErrors =
        sourceWriterOutput
            .retryableErrors()
            .setCoder(StringUtf8Coder.of())
            .apply(
                "Reshuffle4", Reshuffle.<String>viaRandomKey().withNumBuckets(reshuffleBucketSize))
            .apply(
                "Convert retryable errors from source writer to DLQ format",
                ParDo.of(new ConvertChangeStreamErrorRecordToFailsafeElementFn()))
            .setCoder(FailsafeElementCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));

    retryErrors
        .apply(
            "DLQ: Write retryable Failures to GCS",
            MapElements.via(new StringDeadLetterQueueSanitizer()))
        .setCoder(StringUtf8Coder.of())
        .apply(
            "Write To DLQ for retryable errors",
            DLQWriteTransform.WriteDLQ.newBuilder()
                .withDlqDirectory(dlqManager.getRetryDlqDirectoryWithDateTime())
                .withTmpDirectory(options.getDeadLetterQueueDirectory() + "/tmp_retry/")
                .setIncludePaneInfo(true)
                .build());

    PCollection<FailsafeElement<String, String>> skippedRecords =
        sourceWriterOutput
            .skippedSourceWrites()
            .setCoder(StringUtf8Coder.of())
            .apply(
                "Reshuffle5", Reshuffle.<String>viaRandomKey().withNumBuckets(reshuffleBucketSize))
            .apply(
                "Convert skipped records from source writer to DLQ format",
                ParDo.of(new ConvertChangeStreamErrorRecordToFailsafeElementFn()))
            .setCoder(FailsafeElementCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));

    skippedRecords
        .apply(
            "Write skipped records to GCS", MapElements.via(new StringDeadLetterQueueSanitizer()))
        .setCoder(StringUtf8Coder.of())
        .apply(
            "Writing skipped records to GCS",
            DLQWriteTransform.WriteDLQ.newBuilder()
                .withDlqDirectory(
                    options.getDeadLetterQueueDirectory() + "/" + options.getSkipDirectoryName())
                .withTmpDirectory(options.getDeadLetterQueueDirectory() + "/tmp_skip/")
                .setIncludePaneInfo(true)
                .build());

    return pipeline.run();
  }
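
Usage sketch: how run(Options) is typically invoked from the template entry point, assuming the
standard Beam options-parsing pattern (the file's actual main() may differ):

  public static void main(String[] args) {
    // Parse the template options from the command-line arguments and launch the pipeline.
    Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class);
    run(options);
  }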