def __setitem__()

in python/pyspark/pandas/indexing.py


    def __setitem__(self, key: Any, value: Any) -> None:
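        """Assign ``value`` to the rows (and, for a DataFrame, columns) selected by ``key``."""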
        from pyspark.pandas.frame import DataFrame
        from pyspark.pandas.series import Series, first_series

        if self._is_series:
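            # A Series key or value that is not anchored to this Series'
            # frame (or any Series under positional .iloc indexing) cannot be
            # combined directly, so it is first joined into a single frame.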
            if (
                isinstance(key, Series)
                and (isinstance(self, iLocIndexer) or not same_anchor(key, self._psdf_or_psser))
            ) or (
                isinstance(value, Series)
                and (isinstance(self, iLocIndexer) or not same_anchor(value, self._psdf_or_psser))
            ):
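                # Work against a frame that owns the target column; an
                # unnamed Series is first promoted to a single-column frame.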
                if self._psdf_or_psser.name is None:
                    psdf = self._psdf_or_psser.to_frame()
                    column_label = psdf._internal.column_labels[0]
                else:
                    psdf = self._psdf_or_psser._psdf.copy()
                    column_label = self._psdf_or_psser._column_label
                temp_natural_order = verify_temp_column_name(psdf, "__temp_natural_order__")
                temp_key_col = verify_temp_column_name(psdf, "__temp_key_col__")
                temp_value_col = verify_temp_column_name(psdf, "__temp_value_col__")

                # Capture the natural row order, attach the foreign key/value
                # columns (each assignment joins on the index), then restore
                # the order, since the join does not preserve it.
                psdf[temp_natural_order] = F.monotonically_increasing_id()
                if isinstance(key, Series):
                    psdf[temp_key_col] = key
                if isinstance(value, Series):
                    psdf[temp_value_col] = value
                psdf = psdf.sort_values(temp_natural_order).drop(columns=temp_natural_order)

                psser = psdf._psser_for(column_label)
                # The key/value now live in the same frame; refer to them by
                # their backing Spark columns so that the recursive call sees
                # same-anchor inputs and takes the non-joining path below.
                if isinstance(key, Series):
                    key = F.col(
                        "`{}`".format(psdf[temp_key_col]._internal.data_spark_column_names[0])
                    )
                if isinstance(value, Series):
                    value = F.col(
                        "`{}`".format(psdf[temp_value_col]._internal.data_spark_column_names[0])
                    )

                type(self)(psser)[key] = value

                # Drop the placeholder name that to_frame() assigned to an
                # unnamed Series, then write the result back into the original
                # frame, keeping only its original columns so the temporary
                # ones are discarded.
                if self._psdf_or_psser.name is None:
                    psser = psser.rename()

                self._psdf_or_psser._psdf._update_internal_frame(
                    psser._psdf[
                        self._psdf_or_psser._psdf._internal.column_labels
                    ]._internal.resolved_copy,
                    check_same_anchor=False,
                )
                return

            if isinstance(value, DataFrame):
                raise ValueError("Incompatible indexer with DataFrame")

            # Resolve the row selector into a boolean condition and an
            # optional row limit (iloc tracks positions via a sequence column).
            cond, limit, remaining_index = self._select_rows(key)
            if cond is None:
                cond = F.lit(True)
            if limit is not None:
                cond = cond & (
                    self._internal.spark_frame[cast(iLocIndexer, self)._sequence_col] < F.lit(limit)
                )

            # Normalize the value to a Spark column; plain scalars become
            # literals. A Series or Column value is rejected when the row
            # selector leaves no index to align on.
            if isinstance(value, (Series, PySparkColumn)):
                if remaining_index is not None and remaining_index == 0:
                    raise ValueError(
                        "No axis named {} for object type {}".format(key, type(value).__name__)
                    )
                if isinstance(value, Series):
                    value = value.spark.column
            else:
                value = F.lit(value)
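            # Build the replacement column: rows matching the condition take
            # the new value; all other rows keep their current value.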
            scol = (
                F.when(cond, value)
                .otherwise(self._internal.spark_column_for(self._psdf_or_psser._column_label))
                .alias(name_like_string(self._psdf_or_psser.name or SPARK_DEFAULT_SERIES_NAME))
            )

            # Swap the rebuilt column into the internal frame and propagate
            # the change to the anchored DataFrame.
            internal = self._internal.with_new_spark_column(
                self._psdf_or_psser._column_label, scol  # TODO: dtype?
            )
            self._psdf_or_psser._psdf._update_internal_frame(internal, check_same_anchor=False)
        else:
            assert self._is_df
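            # DataFrame target: the key may address both rows and columns,
            # and assignment may create columns for labels not present yet.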

            # A tuple key selects (rows, columns); any other key selects rows
            # across all columns.
            if isinstance(key, tuple):
                if len(key) != 2:
                    raise SparkPandasIndexingError("Only accepts pairs of candidates")
                rows_sel, cols_sel = key
            else:
                rows_sel = key
                cols_sel = None

            # A single-column DataFrame value is unwrapped into its Series.
            if isinstance(value, DataFrame):
                if len(value.columns) == 1:
                    value = first_series(value)
                else:
                    raise ValueError("Only a dataframe with one column can be assigned")

            if (
                isinstance(rows_sel, Series)
                and (
                    isinstance(self, iLocIndexer) or not same_anchor(rows_sel, self._psdf_or_psser)
                )
            ) or (
                isinstance(value, Series)
                and (isinstance(self, iLocIndexer) or not same_anchor(value, self._psdf_or_psser))
            ):
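                # Mirror of the Series branch above: join the foreign row
                # selector and/or value into a copy of this frame, rewrite
                # them as same-anchor Spark columns, then recurse.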
                psdf = cast(DataFrame, self._psdf_or_psser.copy())
                temp_natural_order = verify_temp_column_name(psdf, "__temp_natural_order__")
                temp_key_col = verify_temp_column_name(psdf, "__temp_key_col__")
                temp_value_col = verify_temp_column_name(psdf, "__temp_value_col__")

                psdf[temp_natural_order] = F.monotonically_increasing_id()
                if isinstance(rows_sel, Series):
                    psdf[temp_key_col] = rows_sel
                if isinstance(value, Series):
                    psdf[temp_value_col] = value
                psdf = psdf.sort_values(temp_natural_order).drop(columns=temp_natural_order)

                if isinstance(rows_sel, Series):
                    rows_sel = F.col(
                        "`{}`".format(psdf[temp_key_col]._internal.data_spark_column_names[0])
                    )
                if isinstance(value, Series):
                    value = F.col(
                        "`{}`".format(psdf[temp_value_col]._internal.data_spark_column_names[0])
                    )

                type(self)(psdf)[rows_sel, cols_sel] = value

                # Write the result back, restricted to the original columns
                # so the temporary ones are discarded.
                self._psdf_or_psser._update_internal_frame(
                    psdf[list(self._psdf_or_psser.columns)]._internal.resolved_copy,
                    check_same_anchor=False,
                )
                return

            # Resolve rows into a condition/limit; resolve columns into Spark
            # columns, collecting labels that do not exist yet in missing_keys
            # so they can be created below.
            cond, limit, remaining_index = self._select_rows(rows_sel)
            missing_keys: List[Name] = []
            _, data_spark_columns, _, _, _ = self._select_cols(cols_sel, missing_keys=missing_keys)

            if cond is None:
                cond = F.lit(True)
            if limit is not None:
                cond = cond & (
                    self._internal.spark_frame[cast(iLocIndexer, self)._sequence_col] < F.lit(limit)
                )

            # As in the Series branch, normalize the value; a Series or Column
            # value may only target a single column.
            if isinstance(value, (Series, PySparkColumn)):
                if remaining_index is not None and remaining_index == 0:
                    raise ValueError("Incompatible indexer with Series")
                if len(data_spark_columns) > 1:
                    raise ValueError("shape mismatch")
                if isinstance(value, Series):
                    value = value.spark.column
            else:
                value = F.lit(value)

            # Rebuild every data column: columns matched by the selector get
            # the conditional replacement and refreshed type metadata; the
            # rest pass through unchanged.
            new_data_spark_columns = []
            new_fields = []
            for new_scol, spark_column_name, new_field in zip(
                self._internal.data_spark_columns,
                self._internal.data_spark_column_names,
                self._internal.data_fields,
            ):
                for scol in data_spark_columns:
                    if spark_column_equals(new_scol, scol):
                        new_scol = F.when(cond, value).otherwise(scol).alias(spark_column_name)
                        new_field = InternalField.from_struct_field(
                            self._internal.spark_frame.select(new_scol).schema[0],
                            use_extension_dtypes=new_field.is_extension_dtype,
                        )
                        break
                new_data_spark_columns.append(new_scol)
                new_fields.append(new_field)

            # Labels selected but absent from the frame become new columns:
            # short labels are padded with empty levels, labels deeper than
            # the column index are rejected, and the new column holds the
            # value where the condition is true (null elsewhere).
            column_labels = self._internal.column_labels.copy()
            for missing in missing_keys:
                if is_name_like_tuple(missing):
                    label = cast(Label, missing)
                else:
                    label = cast(Label, (missing,))
                if len(label) < self._internal.column_labels_level:
                    label = tuple(
                        list(label) + ([""] * (self._internal.column_labels_level - len(label)))
                    )
                elif len(label) > self._internal.column_labels_level:
                    raise KeyError(
                        "Key length ({}) exceeds index depth ({})".format(
                            len(label), self._internal.column_labels_level
                        )
                    )
                column_labels.append(label)
                new_data_spark_columns.append(F.when(cond, value).alias(name_like_string(label)))
                new_fields.append(None)

            # Commit the rebuilt columns to the shared internal frame.
            internal = self._internal.with_new_columns(
                new_data_spark_columns, column_labels=column_labels, data_fields=new_fields
            )
            self._psdf_or_psser._update_internal_frame(internal, check_same_anchor=False)
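
For reference, a minimal usage sketch of the paths implemented above. The
frame and values here are illustrative assumptions, not from the source; a
key or value anchored to a different frame would additionally require
ps.set_option("compute.ops_on_diff_frames", True):

    import pyspark.pandas as ps

    psdf = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # Series target with a same-anchor boolean key: takes the direct
    # F.when(cond, value) path and updates the anchored frame in place.
    psdf["a"].loc[psdf["a"] > 1] = 0

    # DataFrame target with a (rows, cols) tuple key; "c" does not exist
    # yet, so it is created through the missing_keys handling above
    # (null wherever the condition does not hold).
    psdf.loc[psdf["b"] >= 5, "c"] = 99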