def run()

in 11_realtime/create_traindata.py
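
For context, run() relies on module-level imports and helpers defined elsewhere in create_traindata.py. A minimal sketch of what the excerpt assumes (the exact import form for the shared transforms module, and the definitions of dict_to_csv and CSV_HEADER, are assumptions and are not shown here):

import json
import logging
import os

import apache_beam as beam

import flights_transforms as ftxf   # shared transforms in ./flights_transforms.py (assumed import form)

# dict_to_csv() and CSV_HEADER are assumed to be defined in the same module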


def run(project, bucket, region, input):
    if input == 'local':
        logging.info('Running locally on small extract')
        argv = [
            '--runner=DirectRunner'
        ]
        flights_output = '/tmp/'
    else:
        logging.info('Running in the cloud on the full dataset; input={}'.format(input))
        argv = [
            '--project={0}'.format(project),
            '--job_name=ch11traindata',
            # '--save_main_session', # not needed as we are running as a package now
            '--staging_location=gs://{0}/flights/staging/'.format(bucket),
            '--temp_location=gs://{0}/flights/temp/'.format(bucket),
            '--setup_file=./setup.py',
            '--autoscaling_algorithm=THROUGHPUT_BASED',
            '--max_num_workers=20',
            # '--max_num_workers=4', '--worker_machine_type=m1-ultramem-40', '--disk_size_gb=500',  # for full 2015-2019 dataset
            '--region={}'.format(region),
            '--runner=DataflowRunner'
        ]
        flights_output = 'gs://{}/ch11/data/'.format(bucket)

    with beam.Pipeline(argv=argv) as pipeline:
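        # exiting this with-block submits the pipeline and waits for it to finish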

        # read the event stream
        if input == 'local':
            input_file = './alldata_sample.json'
            logging.info("Reading from {} ... Writing to {}".format(input_file, flights_output))
            events = (
                    pipeline
                    | 'read_input' >> beam.io.ReadFromText(input_file)
                    | 'parse_input' >> beam.Map(json.loads)
            )
        elif input == 'bigquery':
            input_table = 'dsongcp.flights_tzcorr'
            logging.info("Reading from {} ... Writing to {}".format(input_table, flights_output))
            events = (
                    pipeline
                    | 'read_input' >> beam.io.ReadFromBigQuery(table=input_table)
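                    # ReadFromBigQuery yields one dict per row, keyed by column name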
            )
        else:
            logging.error("Unknown input type {}".format(input))
            return

        # events -> features.  See ./flights_transforms.py for the code shared between training & prediction
        features = ftxf.transform_events_to_features(events)

        # shuffle globally so that we are not at the mercy of TensorFlow's shuffle buffer
        features = (
            features
            | 'into_global' >> beam.WindowInto(beam.window.GlobalWindows())
            | 'shuffle' >> beam.util.Reshuffle()
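            # Reshuffle redistributes elements across workers, so records land in the output files in random order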
        )

        # write out
        for split in ['ALL', 'TRAIN', 'VALIDATE', 'TEST']:
            feats = features
            if split != 'ALL':
                # bind split at definition time; a plain closure over the loop variable would
                # see only its final value when the lambdas are serialized at pipeline run time
                feats = feats | 'only_{}'.format(split) >> beam.Filter(lambda f, split=split: f['data_split'] == split)
            (
                feats
                | '{}_to_string'.format(split) >> beam.FlatMap(dict_to_csv)
                | '{}_to_gcs'.format(split) >> beam.io.textio.WriteToText(os.path.join(flights_output, split.lower()),
                                                                          file_name_suffix='.csv', header=CSV_HEADER,
                                                                          # workaround b/207384805
                                                                          num_shards=1)
            )
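
For reference, a minimal sketch of how run() might be wired to a command line; the argument names and defaults here are illustrative assumptions, not necessarily the script's actual entry point:

import argparse

if __name__ == '__main__':
    # run() and logging are assumed to be available from the module above
    parser = argparse.ArgumentParser(description='Create training data for the flights model')
    parser.add_argument('--project', required=True, help='GCP project id')
    parser.add_argument('--bucket', required=True, help='GCS bucket for staging, temp, and output')
    parser.add_argument('--region', required=True, help='Dataflow region, e.g. us-central1')
    parser.add_argument('--input', default='local', choices=['local', 'bigquery'],
                        help="'local' reads the small extract; 'bigquery' reads the full dataset")
    args = parser.parse_args()

    logging.getLogger().setLevel(logging.INFO)
    run(project=args.project, bucket=args.bucket, region=args.region, input=args.input)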