func (df *dataFrameImpl) RandomSplit

in spark/sql/dataframe.go [1350:1394]

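RandomSplit splits a DataFrame into multiple DataFrames according to the supplied weights. It validates that no weight is negative, normalizes the weights to sum to 1, converts them into cumulative bounds on the interval [0, 1], and then builds one Sample relation per [lower, upper) sub-interval over the same input plan, all with a shared seed and deterministic ordering so that every row falls into exactly one split.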

func (df *dataFrameImpl) RandomSplit(ctx context.Context, weights []float64) ([]DataFrame, error) {
	// Check that we don't have negative weights:
	total := 0.0
	for _, w := range weights {
		if w < 0.0 {
			return nil, sparkerrors.WithType(fmt.Errorf("weights must not be negative"), sparkerrors.InvalidArgumentError)
		}
		total += w
	}
	// Guard against an all-zero weight vector, which would otherwise make the
	// normalization below divide by zero and produce NaN bounds.
	if total == 0.0 {
		return nil, sparkerrors.WithType(fmt.Errorf("sum of weights must be positive"), sparkerrors.InvalidArgumentError)
	}
	// A single seed is shared by every sample; combined with DeterministicOrder
	// this makes the intervals partition the input rather than sample it
	// independently. (rand.Int64 is from math/rand/v2.)
	seed := rand.Int64()
	normalizedWeights := make([]float64, len(weights))
	for i, w := range weights {
		normalizedWeights[i] = w / total
	}

	// Calculate the cumulative sum of the weights:
	cumulativeWeights := make([]float64, len(weights)+1)
	cumulativeWeights[0] = 0.0
	for i := 0; i < len(normalizedWeights); i++ {
		cumulativeWeights[i+1] = cumulativeWeights[i] + normalizedWeights[i]
	}

	// Use adjacent cumulative weights as the [lower, upper) bounds of each
	// sample and create one DataFrame per interval:
	dataFrames := make([]DataFrame, len(weights))
	withReplacement := false
	for i := 1; i < len(cumulativeWeights); i++ {
		sampleRelation := &proto.Relation{
			Common: &proto.RelationCommon{
				PlanId: newPlanId(),
			},
			RelType: &proto.Relation_Sample{
				Sample: &proto.Sample{
					Input:              df.relation,
					LowerBound:         cumulativeWeights[i-1],
					UpperBound:         cumulativeWeights[i],
					WithReplacement:    &withReplacement,
					Seed:               &seed,
					DeterministicOrder: true,
				},
			},
		}
		dataFrames[i-1] = NewDataFrame(df.session, sampleRelation)
	}
	return dataFrames, nil
}
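
A minimal usage sketch, assuming the session-builder API shown in the spark-connect-go examples; the server address sc://localhost:15002 and the v35 module path are placeholders to adjust for your setup:

package main

import (
	"context"
	"log"

	"github.com/apache/spark-connect-go/v35/spark/sql"
)

func main() {
	ctx := context.Background()

	// Placeholder address; point this at your Spark Connect server.
	spark, err := sql.NewSessionBuilder().Remote("sc://localhost:15002").Build(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer spark.Stop()

	df, err := spark.Sql(ctx, "SELECT * FROM range(1000)")
	if err != nil {
		log.Fatal(err)
	}

	// Weights are normalized internally, so {7, 2, 1} approximates a
	// 70/20/10 split of the rows.
	splits, err := df.RandomSplit(ctx, []float64{7, 2, 1})
	if err != nil {
		log.Fatal(err)
	}
	for i, split := range splits {
		count, err := split.Count(ctx)
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("split %d: %d rows", i, count)
	}
}

Because RandomSplit normalizes the weights itself, callers do not need to pass fractions that sum to 1; any non-negative weights with a positive total work.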