def postprocess_columns()

in superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py [0:0]


def postprocess_columns(session: Session) -> None:  # noqa: C901
    """
    Post-process the rows of the new columns table, in batches.

    At this step, we will
      - Add engine specific quotes to `expression` of physical columns
      - Tuck some extra metadata to `extra_json`
      - Duplicate every physical column as a table-level column and, once all
        batches are done, register those copies in the table/column
        association table

    :param session: session used to run the queries and the row updates
    """
    # Function-scope import so this block does not depend on the file-level
    # import list; imported once instead of once per batch.
    from sqlalchemy.orm import aliased

    total = session.query(NewColumn).count()
    if not total:
        return

    def get_joined_tables(offset, limit):
        """Build the join used to read one batch of columns with their context."""
        # Create alias of NewColumn
        new_column_alias = aliased(NewColumn)
        # Get subquery and give it the alias "sl_columns_2"
        subquery = (
            session.query(new_column_alias)
            # OFFSET/LIMIT without ORDER BY gives no guarantee about which rows
            # land in which page. Rows of this table are updated and inserted
            # between pages, so without a stable order a row could be skipped
            # or processed twice (duplicating its physical column copy).
            # NOTE(review): this assumes the copies inserted below receive ids
            # greater than the existing ones, so they sort after the first
            # `total` rows — confirm for the target database.
            .order_by(new_column_alias.id)
            .offset(offset)
            .limit(limit)
            .subquery("sl_columns_2")
        )

        return (
            sa.join(
                subquery,
                NewColumn,
                # Use column id from subquery
                subquery.c.id == NewColumn.id,
            )
            .join(
                dataset_column_association_table,
                # Use column id from subquery
                dataset_column_association_table.c.column_id == subquery.c.id,
            )
            .join(
                NewDataset,
                NewDataset.id == dataset_column_association_table.c.dataset_id,
            )
            .join(
                dataset_table_association_table,
                # Join tables with physical datasets
                and_(
                    NewDataset.is_physical,
                    dataset_table_association_table.c.dataset_id == NewDataset.id,
                ),
                isouter=True,
            )
            .join(Database, Database.id == NewDataset.database_id)
            .join(
                TableColumn,
                # Use column uuid from subquery
                TableColumn.uuid == subquery.c.uuid,
                isouter=True,
            )
            .join(
                SqlMetric,
                # Use column uuid from subquery
                SqlMetric.uuid == subquery.c.uuid,
                isouter=True,
            )
        )

    offset = 0
    limit = 100000

    print(f">> Run postprocessing on {total:,} columns")

    update_count = 0

    def print_update_count():
        """Print the running update counter on a single, rewritten line."""
        if SHOW_PROGRESS:
            print(
                f"   Will update {update_count} columns" + " " * 20,
                end="\r",
            )

    while offset < total:
        query = (
            select(
                # sorted alphabetically; the order must match the tuple
                # unpacking in the `for` loop below
                [
                    NewColumn.id.label("column_id"),
                    TableColumn.column_name,
                    NewColumn.changed_by_fk,
                    NewColumn.changed_on,
                    NewColumn.created_on,
                    NewColumn.description,
                    SqlMetric.d3format,
                    NewDataset.external_url,
                    NewColumn.extra_json,
                    NewColumn.is_dimensional,
                    NewColumn.is_filterable,
                    NewDataset.is_managed_externally,
                    NewColumn.is_physical,
                    SqlMetric.metric_type,
                    TableColumn.python_date_format,
                    Database.sqlalchemy_uri,
                    dataset_table_association_table.c.table_id,
                    func.coalesce(
                        TableColumn.verbose_name, SqlMetric.verbose_name
                    ).label("verbose_name"),
                    NewColumn.warning_text,
                ]
            )
            .select_from(get_joined_tables(offset, limit))
            .where(
                # pre-filter to columns with potential updates: physical
                # columns, plus any column carrying one of the values that is
                # copied into `extra_json` below
                or_(
                    NewColumn.is_physical,
                    TableColumn.verbose_name.isnot(None),
                    TableColumn.python_date_format.isnot(None),
                    SqlMetric.verbose_name.isnot(None),
                    SqlMetric.d3format.isnot(None),
                    SqlMetric.metric_type.isnot(None),
                )
            )
        )

        start = offset + 1
        end = min(total, offset + limit)
        count = session.query(func.count()).select_from(query).scalar()
        print(f"   [Column {start:,} to {end:,}] {count:,} may be updated")

        # table-level copies of the physical columns seen in this batch
        physical_columns = []

        for (
            # sorted alphabetically
            column_id,
            column_name,
            changed_by_fk,
            changed_on,
            created_on,
            description,
            d3format,
            external_url,
            extra_json,
            is_dimensional,
            is_filterable,
            is_managed_externally,
            is_physical,
            metric_type,
            python_date_format,
            sqlalchemy_uri,
            table_id,
            verbose_name,
            warning_text,
        ) in session.execute(query):
            try:
                extra = json.loads(extra_json) if extra_json else {}
            except json.JSONDecodeError:
                extra = {}
            if not isinstance(extra, dict):
                # valid JSON that is not an object (e.g. `null`) cannot be
                # merged with new keys; treat it like malformed JSON
                extra = {}
            updated_extra = {**extra}
            updates = {}

            if is_managed_externally:
                updates["is_managed_externally"] = True
            if external_url:
                updates["external_url"] = external_url

            # update extra json
            for key, val in (
                {
                    "verbose_name": verbose_name,
                    "python_date_format": python_date_format,
                    "d3format": d3format,
                    "metric_type": metric_type,
                }
            ).items():
                # save the original val, including if it's `false`
                if val is not None:
                    updated_extra[key] = val

            if updated_extra != extra:
                updates["extra_json"] = json.dumps(updated_extra)

            # update expression for physical table columns
            if is_physical:
                if column_name and sqlalchemy_uri:
                    # the part of the URI before "://" is used as driver name
                    drivername = sqlalchemy_uri.split("://")[0]
                    if drivername:
                        quoted_expression = get_identifier_quoter(drivername)(
                            column_name
                        )
                        if quoted_expression != column_name:
                            updates["expression"] = quoted_expression
                # duplicate physical columns for tables
                physical_columns.append(
                    dict(  # noqa: C408
                        created_on=created_on,
                        changed_on=changed_on,
                        changed_by_fk=changed_by_fk,
                        description=description,
                        expression=updates.get("expression", column_name),
                        external_url=external_url,
                        extra_json=updates.get("extra_json", extra_json),
                        is_aggregation=False,
                        is_dimensional=is_dimensional,
                        is_filterable=is_filterable,
                        is_managed_externally=is_managed_externally,
                        is_physical=True,
                        name=column_name,
                        table_id=table_id,
                        warning_text=warning_text,
                    )
                )

            if updates:
                session.execute(
                    sa.update(NewColumn)
                    .where(NewColumn.id == column_id)
                    .values(**updates)
                )
                update_count += 1
                print_update_count()

        if physical_columns:
            op.bulk_insert(NewColumn.__table__, physical_columns)

        session.flush()
        offset += limit

    if SHOW_PROGRESS:
        # finish the line that was being rewritten with "\r"
        print("")

    print("   Assign table column relations...")
    # every physical column that carries a `table_id` (the copies created
    # above) is registered in the table/column association table
    insert_from_select(
        table_column_association_table,
        select([NewColumn.table_id, NewColumn.id.label("column_id")])
        .select_from(NewColumn)
        .where(and_(NewColumn.is_physical, NewColumn.table_id.isnot(None))),
    )