func()

in cli/azd/extensions/microsoft.azd.ai.builder/internal/pkg/qna/decision_tree.go [76:188]


// askQuestion runs a single question node of the decision tree and then
// recursively walks the questions it leads to.
//
// The sequence is:
//  1. Ensure question.State is a non-nil map.
//  2. Invoke the optional BeforeAsk hook with value (the answer that led to
//     this question).
//  3. Print the optional heading (hint formatted) and message, then ask the
//     optional Prompt. When Prompt is nil the response stays nil.
//  4. Invoke the optional AfterAsk hook with the response and apply the
//     question's binding.
//  5. If the question has Branches, the response (string, bool, or []string)
//     selects which branches to follow; every question in each selected
//     branch is asked with the selected value as its value. Selected values
//     that have no branch are logged and skipped.
//  6. Every question in Next is asked, in order, with the response as its
//     value.
//
// question is received by value, so changes hooks make to its fields through
// the pointer they receive are local to this call. The State map, however, is
// handed to child questions and merged back after each child returns.
//
// An error is returned when a hook or prompt fails, when a referenced question
// key does not exist, when merging state fails, or when Branches is non-empty
// and the response is not a string, bool, or []string (this includes the nil
// response of a question without a Prompt).
func (t *DecisionTree) askQuestion(ctx context.Context, question Question, value any) error {
	if question.State == nil {
		question.State = map[string]any{}
	}

	if question.BeforeAsk != nil {
		if err := question.BeforeAsk(ctx, &question, value); err != nil {
			return fmt.Errorf("before ask function failed: %w", err)
		}
	}

	if question.Heading != "" {
		fmt.Println()
		fmt.Println(output.WithHintFormat(question.Heading))
	}

	if question.Message != "" {
		fmt.Println(question.Message)
		fmt.Println()
	}

	// response stays nil when the question has no Prompt.
	var response any
	if question.Prompt != nil {
		var err error
		response, err = question.Prompt.Ask(ctx, question)
		if err != nil {
			return fmt.Errorf("failed to ask question: %w", err)
		}
	}

	if question.AfterAsk != nil {
		if err := question.AfterAsk(ctx, &question, response); err != nil {
			return fmt.Errorf("after ask function failed: %w", err)
		}
	}

	// NOTE(review): applyBinding is defined elsewhere; presumably it stores the
	// response into the question's bound target — confirm.
	t.applyBinding(question, response)

	// Handle the case where the branch is based on the user's response.
	if len(question.Branches) > 0 {
		var selectionValues []any

		switch result := response.(type) {
		case string, bool:
			selectionValues = []any{result}
		case []string:
			selectionValues = make([]any, 0, len(result))
			for _, selectedValue := range result {
				selectionValues = append(selectionValues, selectedValue)
			}
		default:
			return errors.New("unsupported value type")
		}

		// Process every branch chosen by the response, in selection order.
		for _, selectedValue := range selectionValues {
			steps, has := question.Branches[selectedValue]
			if !has {
				// %v rather than %s: selected values may be bools.
				log.Printf("branch not found for selected value: %v", selectedValue)
				continue
			}

			// Ask each question referenced by the branch.
			for _, questionReference := range steps {
				nextQuestion, has := t.questions[questionReference.Key]
				if !has {
					return fmt.Errorf(
						"question %s not found for branch: %v",
						questionReference.Key,
						selectedValue,
					)
				}

				// NOTE(review): State is a map, so this assignment aliases the
				// parent's map; the child reads and mutates it directly. Since the
				// child receives nextQuestion by value, the merge below only has
				// an effect if the two maps differ.
				nextQuestion.State = question.State
				if err := t.askQuestion(ctx, nextQuestion, selectedValue); err != nil {
					return fmt.Errorf("failed to ask question: %w", err)
				}

				if err := mergo.Merge(&question.State, nextQuestion.State, mergo.WithOverride); err != nil {
					return fmt.Errorf("failed to merge question states: %w", err)
				}
			}
		}
	}

	// After processing branches, ask each follow-up question in order.
	for _, nextQuestionRef := range question.Next {
		nextQuestion, has := t.questions[nextQuestionRef.Key]
		if !has {
			return fmt.Errorf("next question not found: %s", nextQuestionRef.Key)
		}

		// NOTE(review): nextQuestion.State aliases question.State. Assuming
		// mergo writes into the existing map in place, the reference's State
		// overrides also land in this question's State and stay visible to
		// later siblings — confirm this is intended.
		nextQuestion.State = question.State
		if nextQuestionRef.State != nil {
			if err := mergo.Merge(&nextQuestion.State, nextQuestionRef.State, mergo.WithOverride); err != nil {
				return fmt.Errorf("failed to merge question states: %w", err)
			}
		}

		if err := t.askQuestion(ctx, nextQuestion, response); err != nil {
			return fmt.Errorf("failed to ask next question: %w", err)
		}

		if err := mergo.Merge(&question.State, nextQuestion.State, mergo.WithOverride); err != nil {
			return fmt.Errorf("failed to merge question states: %w", err)
		}
	}

	return nil
}