in superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py [0:0]
def postprocess_datasets(session: Session) -> None:  # noqa: C901
    """
    Postprocess datasets after insertion to
      - Quote table names for physical datasets (if needed)
      - Link referenced tables to virtual datasets

    Datasets are read in pages of ``limit`` rows from the join of
    ``NewDataset``, ``SqlaTable`` (matched on ``uuid``) and ``Database``
    (outer join). For every row:

      - a physical dataset whose quoted ``expression`` differs from the stored
        one gets its ``expression`` rewritten;
      - a row with a schema gets that schema copied into ``extra_json``;
      - a virtual dataset gets one association row per table returned by
        ``find_tables`` for the references extracted from its ``expression``.

    :param session: SQLAlchemy session used for the reads and the updates.
    :raises AssertionError: if the joined row count differs from the number
        of ``SqlaTable`` rows.
    """
    total = session.query(SqlaTable).count()
    if not total:
        return

    offset = 0
    limit = 10000

    joined_tables = sa.join(
        NewDataset,
        SqlaTable,
        NewDataset.uuid == SqlaTable.uuid,
    ).join(
        Database,
        Database.id == SqlaTable.database_id,
        isouter=True,
    )
    joined_total = session.query(func.count()).select_from(joined_tables).scalar()
    assert (
        joined_total == total
    ), f"Expected {total} joined datasets, found {joined_total}"

    print(f">> Run postprocessing on {total} datasets")

    # The statement is the same for every page, so build it once. The explicit
    # ORDER BY is required for OFFSET/LIMIT paging to be reliable: without it
    # the database may return rows in a different order for each page
    # (especially since rows are updated between pages), which would skip some
    # datasets and process others twice.
    paged_query = (
        select(
            [
                NewDataset.database_id,
                NewDataset.id.label("dataset_id"),
                NewDataset.expression,
                SqlaTable.extra,
                NewDataset.is_physical,
                SqlaTable.schema,
                Database.sqlalchemy_uri,
            ]
        )
        .select_from(joined_tables)
        .order_by(NewDataset.id)
    )

    update_count = 0

    while offset < total:
        print(
            f" Process dataset {offset + 1}~{min(total, offset + limit)}..."
            + " " * 30
        )
        for (
            database_id,
            dataset_id,
            expression,
            extra,
            is_physical,
            schema,
            sqlalchemy_uri,
        ) in session.execute(paged_query.offset(offset).limit(limit)):
            # The outer join may leave the URI empty; then there is no driver
            # and both the quoting and the table-linking steps are skipped.
            drivername = (sqlalchemy_uri or "").split("://")[0]
            updates = {}
            updated = False

            if is_physical and drivername and expression:
                quoted_expression = get_identifier_quoter(drivername)(expression)
                if quoted_expression != expression:
                    updates["expression"] = quoted_expression

            # add schema name to `dataset.extra_json` so we don't have to join
            # tables in order to use datasets
            if schema:
                try:
                    extra_json = json.loads(extra) if extra else {}
                except json.JSONDecodeError:
                    extra_json = {}
                # Valid JSON is not necessarily an object ("null", "[]", "1");
                # treat anything else like malformed input instead of crashing
                # on the item assignment below.
                if not isinstance(extra_json, dict):
                    extra_json = {}
                extra_json["schema"] = schema
                updates["extra_json"] = json.dumps(extra_json)

            if updates:
                session.execute(
                    sa.update(NewDataset)
                    .where(NewDataset.id == dataset_id)
                    .values(**updates)
                )
                updated = True

            if not is_physical and drivername and expression:
                table_references = extract_table_references(
                    expression, get_dialect_name(drivername), show_warning=False
                )
                found_tables = find_tables(
                    session, database_id, schema, table_references
                )
                if found_tables:
                    op.bulk_insert(
                        dataset_table_association_table,
                        [
                            {"dataset_id": dataset_id, "table_id": table.id}
                            for table in found_tables
                        ],
                    )
                    updated = True

            if updated:
                update_count += 1
                if SHOW_PROGRESS:
                    # "\r" keeps the running counter on a single console line.
                    print(
                        f" Will update {update_count} datasets" + " " * 20,
                        end="\r",
                    )

        # Push this page's pending changes before reading the next page.
        session.flush()
        offset += limit

    if SHOW_PROGRESS:
        # Terminate the "\r"-updated progress line.
        print("")