in spark/sql/dataframe.go [1350:1394]
func (df *dataFrameImpl) RandomSplit(ctx context.Context, weights []float64) ([]DataFrame, error) {
	// Check that we don't have negative weights:
	total := 0.0
	for _, w := range weights {
		if w < 0.0 {
			return nil, sparkerrors.WithType(fmt.Errorf("weights must not be negative"), sparkerrors.InvalidArgumentError)
		}
		total += w
	}
	seed := rand.Int64()
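	// The weights do not have to sum to 1; they are normalized below so that
	// each weight becomes a fraction of the total.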
	normalizedWeights := make([]float64, len(weights))
	for i, w := range weights {
		normalizedWeights[i] = w / total
	}
	// Calculate the cumulative sum of the weights:
	cumulativeWeights := make([]float64, len(weights)+1)
	cumulativeWeights[0] = 0.0
	for i := 0; i < len(normalizedWeights); i++ {
		cumulativeWeights[i+1] = cumulativeWeights[i] + normalizedWeights[i]
	}
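	// For example, weights {8, 2} normalize to {0.8, 0.2} and yield the
	// boundaries {0.0, 0.8, 1.0}.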
	// Use each pair of consecutive cumulative weights as the sampling bounds
	// for one resulting DataFrame:
	dataFrames := make([]DataFrame, len(weights))
	withReplacement := false
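	// Every split samples with the same seed and DeterministicOrder, so the
	// disjoint [lower, upper) intervals are intended to place each row in
	// exactly one of the resulting DataFrames.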
	for i := 1; i < len(cumulativeWeights); i++ {
		sampleRelation := &proto.Relation{
			Common: &proto.RelationCommon{
				PlanId: newPlanId(),
			},
			RelType: &proto.Relation_Sample{
				Sample: &proto.Sample{
					Input:              df.relation,
					LowerBound:         cumulativeWeights[i-1],
					UpperBound:         cumulativeWeights[i],
					WithReplacement:    &withReplacement,
					Seed:               &seed,
					DeterministicOrder: true,
				},
			},
		}
		dataFrames[i-1] = NewDataFrame(df.session, sampleRelation)
	}
	return dataFrames, nil
}
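
// A minimal usage sketch (not part of dataframe.go): RandomSplit returns one
// DataFrame per weight, so an 80/20 train/test split can be taken as below.
// The helper name exampleRandomSplit and the caller-supplied ctx/df are
// hypothetical; session setup is assumed to happen elsewhere.
func exampleRandomSplit(ctx context.Context, df DataFrame) (train, test DataFrame, err error) {
	splits, err := df.RandomSplit(ctx, []float64{0.8, 0.2})
	if err != nil {
		return nil, nil, err
	}
	// The splits come back in the same order as the weights.
	return splits[0], splits[1], nil
}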