func extractFunctionMethod()

in internal/lsp/source/extract.go [214:647]


// extractFunctionMethod builds a suggested fix that moves the statements
// selected by rng out of their enclosing function declaration and into a newly
// created function, replacing the selection with a call to it. When isMethod
// is true the new declaration is a method that reuses the receiver of the
// enclosing method, and the call is made through that receiver.
//
// The returned fix consists of a single text edit that replaces the whole
// enclosing declaration with the rewritten declaration followed by the new
// function.
//
// An error is returned when CanExtractFunction rejects the selection, when
// scope information for the file is missing, when a method extraction is
// requested but the enclosing declaration has no named receiver, or when any
// of the helper steps (free-variable collection, type-expression generation,
// parsing, formatting, offset computation) fails.
func extractFunctionMethod(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info, isMethod bool) (*analysis.SuggestedFix, error) {
	errorPrefix := "extractFunction"
	if isMethod {
		errorPrefix = "extractMethod"
	}
	p, ok, methodOk, err := CanExtractFunction(fset, rng, src, file)
	if (!ok && !isMethod) || (!methodOk && isMethod) {
		return nil, fmt.Errorf("%s: cannot extract %s: %v", errorPrefix,
			fset.Position(rng.Start), err)
	}
	// Note: rng is replaced here by the range computed by CanExtractFunction.
	tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
	fileScope := info.Scopes[file]
	if fileScope == nil {
		return nil, fmt.Errorf("%s: file scope is empty", errorPrefix)
	}
	pkgScope := fileScope.Parent()
	if pkgScope == nil {
		return nil, fmt.Errorf("%s: package scope is empty", errorPrefix)
	}

	// A return statement is non-nested if its parent node is equal to the parent node
	// of the first node in the selection. These cases must be handled separately because
	// non-nested return statements are guaranteed to execute.
	var retStmts []*ast.ReturnStmt
	var hasNonNestedReturn bool
	startParent := findParent(outer, start)
	ast.Inspect(outer, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		// Nodes not fully contained in the selection are only descended into
		// while they start at or before the end of the selection.
		if n.Pos() < rng.Start || n.End() > rng.End {
			return n.Pos() <= rng.End
		}
		ret, ok := n.(*ast.ReturnStmt)
		if !ok {
			return true
		}
		if findParent(outer, n) == startParent {
			hasNonNestedReturn = true
		}
		retStmts = append(retStmts, ret)
		return false
	})
	containsReturnStatement := len(retStmts) > 0

	// Now that we have determined the correct range for the selection block,
	// we must determine the signature of the extracted function. We will then replace
	// the block with an assignment statement that calls the extracted function with
	// the appropriate parameters and return values.
	variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
	if err != nil {
		return nil, err
	}

	var (
		receiverUsed bool
		receiver     *ast.Field
		receiverName string
		receiverObj  types.Object
	)
	if isMethod {
		if outer == nil || outer.Recv == nil || len(outer.Recv.List) == 0 {
			return nil, fmt.Errorf("%s: cannot extract need method receiver", errorPrefix)
		}
		receiver = outer.Recv.List[0]
		if len(receiver.Names) == 0 || receiver.Names[0] == nil {
			return nil, fmt.Errorf("%s: cannot extract need method receiver name", errorPrefix)
		}
		recvName := receiver.Names[0]
		receiverName = recvName.Name
		receiverObj = info.ObjectOf(recvName)
	}

	var (
		params, returns         []ast.Expr     // used when calling the extracted function
		paramTypes, returnTypes []*ast.Field   // used in the signature of the extracted function
		uninitialized           []types.Object // vars we will need to initialize before the call
	)

	// Avoid duplicates while traversing vars and uninitialized.
	seenVars := make(map[types.Object]ast.Expr)
	seenUninitialized := make(map[types.Object]struct{})

	// Some variables on the left-hand side of our assignment statement may be free. If our
	// selection begins in the same scope in which the free variable is defined, we can
	// redefine it in our assignment statement. See the following example, where 'b' and
	// 'err' (both free variables) can be redefined in the second funcCall() while maintaining
	// correctness.
	//
	//
	// Not Redefined:
	//
	// a, err := funcCall()
	// var b int
	// b, err = funcCall()
	//
	// Redefined:
	//
	// a, err := funcCall()
	// b, err := funcCall()
	//
	// We track the number of free variables that can be redefined to maintain our preference
	// of using "x, y, z := fn()" style assignment statements.
	var canRedefineCount int

	// Each identifier in the selected block must become (1) a parameter to the
	// extracted function, (2) a return value of the extracted function, or (3) a local
	// variable in the extracted function. Determine the outcome(s) for each variable
	// based on whether it is free, altered within the selected block, and used outside
	// of the selected block.
	for _, v := range variables {
		if _, ok := seenVars[v.obj]; ok {
			continue
		}
		if v.obj.Name() == "_" {
			// The blank identifier is always a local variable
			continue
		}
		typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type())
		if typ == nil {
			return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name())
		}
		seenVars[v.obj] = typ
		identifier := ast.NewIdent(v.obj.Name())
		// An identifier must meet three conditions to become a return value of the
		// extracted function. (1) its value must be defined or reassigned within
		// the selection (v.assigned), (2) it must be used at least once after the
		// selection (isUsed), and (3) its first use after the selection
		// cannot be its own reassignment or redefinition (varOverridden).
		if v.obj.Parent() == nil {
			return nil, fmt.Errorf("parent nil")
		}
		isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj)
		if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) {
			returnTypes = append(returnTypes, &ast.Field{Type: typ})
			returns = append(returns, identifier)
			if !v.free {
				uninitialized = append(uninitialized, v.obj)
			} else if v.obj.Parent().Pos() == startParent.Pos() {
				canRedefineCount++
			}
		}
		// An identifier must meet two conditions to become a parameter of the
		// extracted function. (1) it must be free (v.free), and (2) its first
		// use within the selection cannot be its own definition (v.defined).
		if v.free && !v.defined {
			// Skip the selector for a method.
			if isMethod && v.obj == receiverObj {
				receiverUsed = true
				continue
			}
			params = append(params, identifier)
			paramTypes = append(paramTypes, &ast.Field{
				Names: []*ast.Ident{identifier},
				Type:  typ,
			})
		}
	}

	// Find the function literal that encloses the selection. The enclosing function literal
	// may not be the enclosing function declaration (i.e. 'outer'). For example, in the
	// following block:
	//
	// func main() {
	//     ast.Inspect(node, func(n ast.Node) bool {
	//         v := 1 // this line extracted
	//         return true
	//     })
	// }
	//
	// 'outer' is main(). However, the extracted selection most directly belongs to
	// the anonymous function literal, the second argument of ast.Inspect(). We use the
	// enclosing function literal to determine the proper return types for return statements
	// within the selection. We still need the enclosing function declaration because this is
	// the top-level declaration. We inspect the top-level declaration to look for variables
	// as well as for code replacement.
	enclosing := outer.Type
	for _, p := range path {
		if p == enclosing {
			break
		}
		if fl, ok := p.(*ast.FuncLit); ok {
			enclosing = fl.Type
			break
		}
	}

	// We put the selection in a constructed file. We can then traverse and edit
	// the extracted selection without modifying the original AST.
	startOffset, err := Offset(tok, rng.Start)
	if err != nil {
		return nil, err
	}
	endOffset, err := Offset(tok, rng.End)
	if err != nil {
		return nil, err
	}
	selection := src[startOffset:endOffset]
	extractedBlock, err := parseBlockStmt(fset, selection)
	if err != nil {
		return nil, err
	}

	// We need to account for return statements in the selected block, as they will complicate
	// the logical flow of the extracted function. See the following example, where ** denotes
	// the range to be extracted.
	//
	// Before:
	//
	// func _() int {
	//     a := 1
	//     b := 2
	//     **if a == b {
	//         return a
	//     }**
	//     ...
	// }
	//
	// After:
	//
	// func _() int {
	//     a := 1
	//     b := 2
	//     cond0, ret0 := x0(a, b)
	//     if cond0 {
	//         return ret0
	//     }
	//     ...
	// }
	//
	// func x0(a int, b int) (bool, int) {
	//     if a == b {
	//         return true, a
	//     }
	//     return false, 0
	// }
	//
	// We handle returns by adding an additional boolean return value to the extracted function.
	// This bool reports whether the original function would have returned. Because the
	// extracted selection contains a return statement, we must also add the types in the
	// return signature of the enclosing function to the return signature of the
	// extracted function. We then add an extra if statement checking this boolean value
	// in the original function. If the condition is met, the original function should
	// return a value, mimicking the functionality of the original return statement(s)
	// in the selection.
	//
	// If there is a return that is guaranteed to execute (hasNonNestedReturn=true), then
	// we don't need to include this additional condition check and can simply return.
	//
	// Before:
	//
	// func _() int {
	//     a := 1
	//     b := 2
	//     **if a == b {
	//         return a
	//     }
	//     return b**
	// }
	//
	// After:
	//
	// func _() int {
	//     a := 1
	//     b := 2
	//     return x0(a, b)
	// }
	//
	// func x0(a int, b int) int {
	//     if a == b {
	//         return a
	//     }
	//     return b
	// }

	var retVars []*returnVariable
	var ifReturn *ast.IfStmt
	if containsReturnStatement {
		if !hasNonNestedReturn {
			// The selected block contained return statements, so we have to modify the
			// signature of the extracted function as described above. Adjust all of
			// the return statements in the extracted function to reflect this change in
			// signature.
			if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
				pkg, extractedBlock); err != nil {
				return nil, err
			}
		}
		// Collect the additional return values and types needed to accommodate return
		// statements in the selection. Update the type signature of the extracted
		// function and construct the if statement that will be inserted in the enclosing
		// function.
		retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start, hasNonNestedReturn)
		if err != nil {
			return nil, err
		}
	}

	// Add a return statement to the end of the new function. This return statement must include
	// the values for the types of the original extracted function signature and (if a return
	// statement is present in the selection) enclosing function signature.
	// This only needs to be done if the selection does not have a non-nested return, otherwise
	// it already terminates with a return statement.
	hasReturnValues := len(returns)+len(retVars) > 0
	if hasReturnValues && !hasNonNestedReturn {
		// Clip returns to its length so that append is forced to copy. The call
		// expression built further down also appends to returns; if the two
		// results shared spare capacity, that later append would overwrite the
		// zero values stored in this return statement before it is formatted.
		extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{
			Results: append(returns[:len(returns):len(returns)], getZeroVals(retVars)...),
		})
	}

	// Construct the appropriate call to the extracted function.
	// We must meet two conditions to use ":=" instead of '='. (1) there must be at least
	// one variable on the lhs that is uninitialized (non-free) prior to the assignment.
	// (2) all of the initialized (free) variables on the lhs must be able to be redefined.
	sym := token.ASSIGN
	canDefineCount := len(uninitialized) + canRedefineCount
	canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns)
	if canDefine {
		sym = token.DEFINE
	}
	var name, funName string
	if isMethod {
		name = "newMethod"
		// TODO(suzmue): generate a name that does not conflict for "newMethod".
		funName = name
	} else {
		name = "newFunction"
		funName, _ = generateAvailableIdentifier(rng.Start, file, path, info, name, 0)
	}
	// As above, clip returns so this slice never shares a backing array with
	// the results of the return statement appended to the extracted block.
	callResults := append(returns[:len(returns):len(returns)], getNames(retVars)...)
	extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params,
		callResults, funName, sym, receiverName)

	// Build the extracted function. returnTypes is clipped as a precaution so
	// the result list owns its backing array.
	newFunc := &ast.FuncDecl{
		Name: ast.NewIdent(funName),
		Type: &ast.FuncType{
			Params:  &ast.FieldList{List: paramTypes},
			Results: &ast.FieldList{List: append(returnTypes[:len(returnTypes):len(returnTypes)], getDecls(retVars)...)},
		},
		Body: extractedBlock,
	}
	if isMethod {
		// Only name the receiver when the extracted body actually refers to it.
		var names []*ast.Ident
		if receiverUsed {
			names = append(names, ast.NewIdent(receiverName))
		}
		newFunc.Recv = &ast.FieldList{
			List: []*ast.Field{{
				Names: names,
				Type:  receiver.Type,
			}},
		}
	}

	// Create variable declarations for any identifiers that need to be initialized prior to
	// calling the extracted function. We do not manually initialize variables if every return
	// value is uninitialized. We can use := to initialize the variables in this situation.
	var declarations []ast.Stmt
	if canDefineCount != len(returns) {
		declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
	}

	var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer
	if err := format.Node(&declBuf, fset, declarations); err != nil {
		return nil, err
	}
	if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil {
		return nil, err
	}
	if ifReturn != nil {
		if err := format.Node(&ifBuf, fset, ifReturn); err != nil {
			return nil, err
		}
	}
	if err := format.Node(&newFuncBuf, fset, newFunc); err != nil {
		return nil, err
	}
	// Find all the comments within the range and print them to be put somewhere.
	// TODO(suzmue): print these in the extracted function at the correct place.
	for _, cg := range file.Comments {
		if cg.Pos().IsValid() && cg.Pos() < rng.End && cg.Pos() >= rng.Start {
			for _, c := range cg.List {
				fmt.Fprintln(&commentBuf, c.Text)
			}
		}
	}

	// We're going to replace the whole enclosing function,
	// so preserve the text before and after the selected block.
	outerStart, err := Offset(tok, outer.Pos())
	if err != nil {
		return nil, err
	}
	outerEnd, err := Offset(tok, outer.End())
	if err != nil {
		return nil, err
	}
	before := src[outerStart:startOffset]
	after := src[endOffset:outerEnd]
	indent, err := calculateIndentation(src, tok, start)
	if err != nil {
		return nil, err
	}
	newLineIndent := "\n" + indent

	// Assemble the replacement text in order: text before the selection,
	// preserved comments, declarations, the call, the optional if statement,
	// text after the selection, and finally the new function.
	var fullReplacement strings.Builder
	fullReplacement.Write(before)
	if commentBuf.Len() > 0 {
		comments := strings.ReplaceAll(commentBuf.String(), "\n", newLineIndent)
		fullReplacement.WriteString(comments)
	}
	if declBuf.Len() > 0 { // add any initializations, if needed
		initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
			newLineIndent
		fullReplacement.WriteString(initializations)
	}
	fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function
	if ifBuf.Len() > 0 {                      // add the if statement below the function call, if needed
		ifstatement := newLineIndent +
			strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent)
		fullReplacement.WriteString(ifstatement)
	}
	fullReplacement.Write(after)
	fullReplacement.WriteString("\n\n")       // add newlines after the enclosing function
	fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function

	return &analysis.SuggestedFix{
		TextEdits: []analysis.TextEdit{{
			Pos:     outer.Pos(),
			End:     outer.End(),
			NewText: []byte(fullReplacement.String()),
		}},
	}, nil
}