def _upsert()

in genai-on-vertex-ai/gemini/evals_playbook/utils/evals_playbook.py [0:0]


    def _upsert(self, table_class, rows, debug=False):
        """Inserts or updates rows in the specified BigQuery table.

        Builds a dynamic MERGE statement keyed on the table's update keys so
        existing rows are updated and new rows are inserted in a single query.

        Args:
            table_class: Table class resolved to a BigQuery table name and
                the list of update (match) keys via ``get_table_name_keys``.
            rows: A single dict or a list of dicts, each representing a row.
            debug: If True, print the generated MERGE query and the
                parameterized rows before executing.

        Raises:
            ValueError: If any update key is missing from any row.
        """
        table_name, update_keys = get_table_name_keys(table_class)

        # Accept a single row dict for convenience.
        if isinstance(rows, dict):
            rows = [rows]

        # Union of all column names seen across the rows; also validate that
        # every row carries every update key (required for the MERGE ON clause).
        all_keys = set().union(*(d.keys() for d in rows))
        for row in rows:
            for key in update_keys:
                if key not in row:
                    raise ValueError(f"Update key '{key}' not found in row: {row}")

        # Fetch the table schema to map column name -> BigQuery field type.
        table_id = f"{cfg.PROJECT_ID}.{cfg.BQ_DATASET_ID}.{table_name}"
        client = bigquery.Client(project=cfg.PROJECT_ID)
        table = client.get_table(table_id)
        schema = {field.name: field.field_type for field in table.schema}

        # Construct the MERGE query dynamically.
        merge_query = f"""
            MERGE INTO `{table_id}` AS target
            USING (
                SELECT * FROM UNNEST(@rows)
            ) AS source
            ON {" AND ".join(f"target.{key} = source.{key}" for key in update_keys)}
        """

        if update_keys:
            # Never overwrite the match keys or the original creation timestamp.
            skip_keys = set(update_keys) | {"create_datetime"}
            merge_query += f"""     WHEN MATCHED THEN
                UPDATE SET {", ".join(f"target.{key} = source.{key}" for key in all_keys if key not in skip_keys)}
        """

        merge_query += f"""     WHEN NOT MATCHED THEN
                INSERT({", ".join(all_keys)})
                VALUES({", ".join(f"source.{key}" for key in all_keys)})
        """

        # Convert rows to BigQuery STRUCT query parameters.
        rows_for_query = []
        for row in rows:
            row_params = []
            for key, val in row.items():
                field_type = schema.get(key)
                # The query-parameter API expects "BOOL", while table schemas
                # report the legacy name "BOOLEAN".
                if field_type == "BOOLEAN":
                    field_type = "BOOL"
                if val is None:
                    # NOTE(review): None-valued fields are omitted from the
                    # struct (original behavior) — confirm BigQuery tolerates
                    # heterogeneous structs across rows for this use case.
                    continue
                if isinstance(val, datetime.datetime):
                    val = val.isoformat()
                if isinstance(val, list):
                    row_params.append(bigquery.ArrayQueryParameter(key, field_type, val))
                else:
                    row_params.append(bigquery.ScalarQueryParameter(key, field_type, val))
            rows_for_query.append(bigquery.StructQueryParameter("x", *row_params))

        job_config = bigquery.QueryJobConfig(
            query_parameters=[bigquery.ArrayQueryParameter("rows", "STRUCT", rows_for_query)]
        )

        # Honor the debug flag instead of printing unconditionally.
        if debug:
            print("MERGE Query:")
            print(merge_query)
            print("\nRows:")
            print(rows_for_query)

        query_job = client.query(merge_query, job_config=job_config)
        query_job.result()  # Wait for the MERGE to complete