in src/ab/plugins/data/engine.py [0:0]
def read_data(self, ds: DataSource):
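    """Read ``ds`` into a Spark DataFrame.

    Returns a ``(sample_rate, sample_count, dataframe)`` tuple: in local mode the
    data is sampled via ``ds.sample()``, otherwise the full table is loaded and
    the rate is reported as 100 with a count of None.
    """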
    # local mode, force sampling
    if app.config.SPARK['spark.master'].startswith('local'):
        sample_rate, sample_count, sample = ds.sample()
        spark_dataframe = self.convert_to_spark_data_type(sample, ds.get_table_info())
        return sample_rate, sample_count, spark_dataframe
    # else load all data
    if ds.type_ in ('mysql', 'ads'):
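        # Options the JDBC source does not recognise (useUnicode, characterEncoding)
        # are passed through to the MySQL driver as connection properties.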
        return 100, None, spark.get_or_create().read.format('jdbc') \
            .option('url', ds.db.jdbc_url) \
            .option('driver', 'com.mysql.jdbc.Driver') \
            .option('dbtable', ds.table_name) \
            .option('user', ds.db.username) \
            .option('password', ds.db.password) \
            .option('useUnicode', True) \
            .option('characterEncoding', 'UTF-8') \
            .load()
    elif ds.type_ == 'hive':
        sql = 'SELECT * FROM {db}.{table_name}'.format(db=ds.db.db, table_name=ds.table_name)
        if ds.partitions:
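            # ODPS.join_partitions is assumed to render the partition specs into
            # a boolean condition usable in the WHERE clause below.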
            condition = ODPS.join_partitions(ds.partitions)
            sql += ' where {condition}'.format(condition=condition)
        logger.info('sample sql: %s', sql)
        return 100, None, spark.get_or_create().sql(sql)
    elif ds.type_ == 'odps':
        if not ds.db.tunnel_endpoint:
            raise AlgorithmException('please set odps tunnel endpoint')
        # cluster mode, read all data
        data = spark.get_or_create().read \
            .format("org.apache.spark.aliyun.odps.datasource") \
            .option("odpsUrl", ds.db.endpoint) \
            .option("tunnelUrl", ds.db.tunnel_endpoint) \
            .option("table", ds.table_name) \
            .option("project", ds.db.project) \
            .option("accessKeyId", ds.db.access_id) \
            .option("accessKeySecret", ds.db.access_key)
        num_partitions = spark.get_or_create().sparkContext.getConf().get('odps.numPartitions')
        if num_partitions:
            data = data.option("numPartitions", int(num_partitions))
        if ds.partitions:
            assert len(ds.partitions) == 1, 'spark can only read one odps partition'
            data = data.option('partitionSpec', ds.partitions[0])
        return 100, None, data.load()
    raise AlgorithmException('unrecognized data source type for spark: {ds.type_}'.format(ds=ds))
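
# Usage sketch (hypothetical; `engine` stands for an instance of the class that
# defines read_data and `ds` for a configured DataSource, neither of which is
# defined in this excerpt):
#
#   sample_rate, sample_count, df = engine.read_data(ds)
#   df.printSchema()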