in cli/azd/extensions/microsoft.azd.ai.builder/internal/pkg/qna/decision_tree.go [76:188]
// askQuestion runs a single question of the decision tree and then recursively
// asks the questions it leads to.
//
// The steps, in order, are:
//  1. Initialize question.State if it is nil, then run the BeforeAsk hook with value.
//  2. Print the optional Heading (hint formatted) and Message.
//  3. Ask the Prompt, if one is set, and run the AfterAsk hook with the response.
//  4. Apply the question's binding with the response.
//  5. If the question has Branches, turn the response into selected values
//     (a string, a bool, or each element of a []string). For each selected value,
//     ask every question referenced by its branch, passing the selected value
//     as the child's value. Selected values without a branch are logged and skipped.
//     Any other response type is an error.
//  6. Ask every question in Next, passing this question's response as the
//     child's value. A reference's own State is merged on top of the shared state
//     before that child is asked.
//
// Child questions are handed question.State itself (a map, so the same
// underlying storage, not a copy). Writes a child makes into that map are
// therefore visible to this question and to later siblings. Merging a
// reference's State for a Next question also writes into this shared map.
//
// NOTE(review): question and each child are passed by value. A child hook that
// replaces its State with a new map, instead of mutating it, is not seen by the
// caller. The mergo merges after each child therefore only re-apply the shared
// map to itself. Confirm whether propagating replaced maps was intended.
func (t *DecisionTree) askQuestion(ctx context.Context, question Question, value any) error {
	if question.State == nil {
		question.State = map[string]any{}
	}

	if question.BeforeAsk != nil {
		if err := question.BeforeAsk(ctx, &question, value); err != nil {
			return fmt.Errorf("before ask function failed: %w", err)
		}
	}

	if question.Heading != "" {
		fmt.Println()
		fmt.Println(output.WithHintFormat(question.Heading))
	}

	if question.Message != "" {
		fmt.Println(question.Message)
		fmt.Println()
	}

	// response stays nil when the question has no prompt (e.g. a pure branching or
	// informational node).
	var response any
	if question.Prompt != nil {
		var err error
		response, err = question.Prompt.Ask(ctx, question)
		if err != nil {
			return fmt.Errorf("failed to ask question: %w", err)
		}
	}

	if question.AfterAsk != nil {
		if err := question.AfterAsk(ctx, &question, response); err != nil {
			return fmt.Errorf("after ask function failed: %w", err)
		}
	}

	t.applyBinding(question, response)

	// Handle the case where the branch is based on the user's response.
	if len(question.Branches) > 0 {
		var selectionValues []any
		switch result := response.(type) {
		case string:
			selectionValues = []any{result}
		case bool:
			selectionValues = []any{result}
		case []string:
			// Multi-select: every selected option may lead to its own branch.
			selectionValues = make([]any, 0, len(result))
			for _, selectedValue := range result {
				selectionValues = append(selectionValues, selectedValue)
			}
		default:
			return errors.New("unsupported value type")
		}

		// Process the branches of every selected value, in selection order.
		for _, selectedValue := range selectionValues {
			steps, has := question.Branches[selectedValue]
			if !has {
				// %v rather than %s: selected values may be bools, which %s would
				// render as "%!s(bool=...)".
				log.Printf("branch not found for selected value: %v", selectedValue)
				continue
			}

			for _, questionReference := range steps {
				nextQuestion, has := t.questions[questionReference.Key]
				if !has {
					// Report the missing question key as well as the branch that
					// referenced it.
					return fmt.Errorf(
						"question %s not found for branch: %v", questionReference.Key, selectedValue)
				}

				nextQuestion.State = question.State
				if err := t.askQuestion(ctx, nextQuestion, selectedValue); err != nil {
					return fmt.Errorf("failed to ask question: %w", err)
				}

				if err := mergo.Merge(&question.State, nextQuestion.State, mergo.WithOverride); err != nil {
					return fmt.Errorf("failed to merge question states: %w", err)
				}
			}
		}
	}

	// After processing branches, continue with any follow-up questions. Ranging
	// over an empty Next is a no-op.
	for _, nextQuestionRef := range question.Next {
		nextQuestion, has := t.questions[nextQuestionRef.Key]
		if !has {
			return fmt.Errorf("next question not found: %s", nextQuestionRef.Key)
		}

		nextQuestion.State = question.State
		if nextQuestionRef.State != nil {
			// nextQuestion.State aliases question.State, so this also writes the
			// reference's values into the shared state.
			if err := mergo.Merge(&nextQuestion.State, nextQuestionRef.State, mergo.WithOverride); err != nil {
				return fmt.Errorf("failed to merge question states: %w", err)
			}
		}

		if err := t.askQuestion(ctx, nextQuestion, response); err != nil {
			return fmt.Errorf("failed to ask next question: %w", err)
		}

		if err := mergo.Merge(&question.State, nextQuestion.State, mergo.WithOverride); err != nil {
			return fmt.Errorf("failed to merge question states: %w", err)
		}
	}

	return nil
}