func AsDoFn()

in sdks/go/pkg/beam/core/graph/fn.go [490:649]


// AsDoFn validates that fn's methods form a well-formed DoFn and, if so,
// returns fn viewed as a *DoFn.
//
// numMainIn is the number of main inputs the caller expects ProcessElement to
// take (side inputs excluded). When it is MainUnknown and ProcessElement
// declares exactly one input, MainSingle is assumed for the remaining checks.
//
// AsDoFn mutates fn: it initializes fn.methods when nil, registers fn.Fn (when
// set) as the ProcessElement method, and assigns the struct field name as the
// key of any state field on fn.Recv whose key is empty. The returned *DoFn is a
// pointer conversion of fn, so both share the same underlying value. Every
// returned error is annotated with fn's name.
func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
	addContext := func(err error) error {
		return errors.WithContextf(err, "graph.AsDoFn: for Fn named %v", fn.Name())
	}

	if fn.methods == nil {
		fn.methods = make(map[string]*funcx.Fn)
	}
	if fn.Fn != nil {
		fn.methods[processElementName] = fn.Fn
	}

	processFn, ok := fn.methods[processElementName]
	if !ok {
		err := errors.Errorf("failed to find %v method", processElementName)
		// A ProcessElement declared with a pointer receiver is not in the method
		// set of a struct value, so hint at that common mistake.
		if fn.Recv != nil && reflect.ValueOf(fn.Recv).Kind() != reflect.Ptr {
			err = errors.Wrap(err, "structural DoFn passed by value, ensure that the ProcessElement method has a value receiver or pass the DoFn by pointer")
		}
		return nil, addContext(err)
	}

	// Make sure that all state entries have keys. If they don't, default them to
	// the struct field name.
	if fn.Recv != nil {
		v := reflect.Indirect(reflect.ValueOf(fn.Recv))
		if v.Kind() == reflect.Struct {
			for i := 0; i < v.NumField(); i++ {
				f := v.Field(i)
				if !f.CanInterface() {
					continue // unexported fields cannot be inspected
				}
				ps, isState := f.Interface().(state.PipelineState)
				if !isState || ps.StateKey() != "" {
					continue
				}
				fieldName := v.Type().Field(i).Name
				// Writing the key needs an addressable struct field with a settable
				// string Key. A DoFn passed by value is not addressable, and
				// reflect panics on SetString in that case, so report it instead.
				var key reflect.Value
				if f.Kind() == reflect.Struct {
					key = f.FieldByName("Key")
				}
				if !key.IsValid() || !key.CanSet() || key.Kind() != reflect.String {
					err := errors.Errorf("state field %v has no key and one cannot be assigned automatically; pass the DoFn by pointer or set the state Key explicitly", fieldName)
					return nil, addContext(err)
				}
				key.SetString(fieldName)
			}
		}
	}

	// Validate ProcessElement has correct number of main inputs (as indicated by
	// numMainIn), and that main inputs are before side inputs.
	if err := validateMainInputs(fn, processFn, processElementName, numMainIn); err != nil {
		return nil, addContext(err)
	}

	// If numMainIn is unknown, we can try inferring it from the number of inputs in ProcessElement.
	inPos, inNum, hasInputs := processFn.Inputs()
	if numMainIn == MainUnknown && inNum == 1 {
		numMainIn = MainSingle
	}
	// Only slice by the reported position when inputs exist; this mirrors the
	// handling of Emits below and avoids relying on inPos when there are none.
	processFnInputs := processFn.Param[0:0]
	if hasInputs {
		processFnInputs = processFn.Param[inPos : inPos+inNum]
	}

	// If the ProcessElement function includes side inputs or emit functions those must also be
	// present in the signatures of startBundle and finishBundle. All side-input
	// checks run before any emit check, preserving which error is reported first.
	bundleMethods := []string{startBundleName, finishBundleName}
	for _, name := range bundleMethods {
		if method, found := fn.methods[name]; found {
			if err := validateSideInputs(processFnInputs, method, name, numMainIn); err != nil {
				return nil, addContext(err)
			}
		}
	}

	processFnEmits := processFn.Param[0:0]
	if emitPos, emitNum, hasEmits := processFn.Emits(); hasEmits {
		processFnEmits = processFn.Param[emitPos : emitPos+emitNum]
	}
	for _, name := range bundleMethods {
		if method, found := fn.methods[name]; found {
			if err := validateEmits(processFnEmits, method, name); err != nil {
				return nil, addContext(err)
			}
		}
	}

	// Check that Setup and Teardown have no parameters other than Context.
	for _, name := range []string{setupName, teardownName} {
		if method, found := fn.methods[name]; found {
			params := method.Param
			if len(params) > 1 || (len(params) == 1 && params[0].Kind != funcx.FnContext) {
				err := errors.Errorf(
					"method %v has invalid parameters, "+
						"only allowed an optional context.Context", name)
				err = errors.SetTopLevelMsgf(err,
					"Method %v of DoFns should have no parameters other than "+
						"an optional context.Context, but invalid parameters are "+
						"present in DoFn %v.",
					name, fn.Name())
				return nil, addContext(err)
			}
		}
	}

	// Check that none of the methods (except ProcessElement) have any return
	// values other than error.
	for _, name := range []string{setupName, startBundleName, finishBundleName, teardownName} {
		if method, found := fn.methods[name]; found {
			returns := method.Ret
			if len(returns) > 1 || (len(returns) == 1 && returns[0].Kind != funcx.RetError) {
				err := errors.Errorf(
					"method %v has invalid return values, "+
						"only allowed an optional error", name)
				err = errors.SetTopLevelMsgf(err,
					"Method %v of DoFns should have no return values other "+
						"than an optional error, but invalid return values are present "+
						"in DoFn %v.",
					name, fn.Name())
				return nil, addContext(err)
			}
		}
	}

	// Check whether to perform SDF validation, and if so validate the SDF
	// method signatures.
	isSdf, err := validateIsSdf(fn)
	if err != nil {
		return nil, addContext(err)
	}
	if isSdf {
		if err := validateSdfSignatures(fn, numMainIn); err != nil {
			return nil, addContext(err)
		}
	}

	isWatermarkEstimating, err := validateIsWatermarkEstimating(fn, isSdf)
	if err != nil {
		return nil, addContext(err)
	}
	if isWatermarkEstimating {
		if err := validateWatermarkSig(fn, int(numMainIn)); err != nil {
			return nil, addContext(err)
		}
	}

	doFn := (*DoFn)(fn)
	if err := validateState(doFn, numMainIn); err != nil {
		return nil, addContext(err)
	}
	if err := validateTimer(doFn, numMainIn); err != nil {
		return nil, addContext(err)
	}

	return doFn, nil
}