in internal/lsp/source/extract.go [214:647]
// extractFunctionMethod moves the statements selected by rng into a new
// function (or, when isMethod is set, a new method on the enclosing
// function's receiver) and replaces the selection with a call to it.
//
// The returned fix holds a single TextEdit that rewrites the whole enclosing
// function declaration. It keeps the text before and after the selection
// unchanged. The selection itself becomes any hoisted comments, any required
// variable declarations, the call, and, if needed, an if statement that
// propagates a return. The new function is appended after the enclosing
// declaration.
//
// An error is returned if the selection cannot be extracted, if info lacks
// the file or package scope, or if generating/formatting the new code fails.
func extractFunctionMethod(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info, isMethod bool) (*analysis.SuggestedFix, error) {
	errorPrefix := "extractFunction"
	if isMethod {
		errorPrefix = "extractMethod"
	}
	p, ok, methodOk, err := CanExtractFunction(fset, rng, src, file)
	if (!ok && !isMethod) || (!methodOk && isMethod) {
		// err may be nil when only method extraction is rejected, so it is
		// formatted with %v rather than wrapped.
		return nil, fmt.Errorf("%s: cannot extract %s: %v", errorPrefix,
			fset.Position(rng.Start), err)
	}
	// rng is reassigned here to the adjusted selection range computed by
	// CanExtractFunction.
	tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
	fileScope := info.Scopes[file]
	if fileScope == nil {
		return nil, fmt.Errorf("%s: file scope is empty", errorPrefix)
	}
	pkgScope := fileScope.Parent()
	if pkgScope == nil {
		return nil, fmt.Errorf("%s: package scope is empty", errorPrefix)
	}

	// A return statement is non-nested if its parent node is equal to the parent node
	// of the first node in the selection. These cases must be handled separately because
	// non-nested return statements are guaranteed to execute.
	var retStmts []*ast.ReturnStmt
	var hasNonNestedReturn bool
	startParent := findParent(outer, start)
	ast.Inspect(outer, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		if n.Pos() < rng.Start || n.End() > rng.End {
			// Not entirely inside the selection: only descend if the node
			// starts at or before the selection's end and so may contain part of it.
			return n.Pos() <= rng.End
		}
		ret, ok := n.(*ast.ReturnStmt)
		if !ok {
			return true
		}
		if findParent(outer, n) == startParent {
			hasNonNestedReturn = true
		}
		retStmts = append(retStmts, ret)
		return false
	})
	containsReturnStatement := len(retStmts) > 0

	// Now that we have determined the correct range for the selection block,
	// we must determine the signature of the extracted function. We will then replace
	// the block with an assignment statement that calls the extracted function with
	// the appropriate parameters and return values.
	variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
	if err != nil {
		return nil, err
	}

	var (
		receiverUsed bool
		receiver     *ast.Field
		receiverName string
		receiverObj  types.Object
	)
	if isMethod {
		if outer == nil || outer.Recv == nil || len(outer.Recv.List) == 0 {
			return nil, fmt.Errorf("%s: cannot extract need method receiver", errorPrefix)
		}
		receiver = outer.Recv.List[0]
		if len(receiver.Names) == 0 || receiver.Names[0] == nil {
			return nil, fmt.Errorf("%s: cannot extract need method receiver name", errorPrefix)
		}
		recvName := receiver.Names[0]
		receiverName = recvName.Name
		receiverObj = info.ObjectOf(recvName)
	}

	var (
		params, returns         []ast.Expr     // used when calling the extracted function
		paramTypes, returnTypes []*ast.Field   // used in the signature of the extracted function
		uninitialized           []types.Object // vars we will need to initialize before the call
	)
	// Avoid duplicates while traversing vars and uninitialized.
	seenVars := make(map[types.Object]ast.Expr)
	seenUninitialized := make(map[types.Object]struct{})

	// Some variables on the left-hand side of our assignment statement may be free. If our
	// selection begins in the same scope in which the free variable is defined, we can
	// redefine it in our assignment statement. See the following example, where 'b' and
	// 'err' (both free variables) can be redefined in the second funcCall() while maintaining
	// correctness.
	//
	// Not Redefined:
	//
	//	a, err := funcCall()
	//	var b int
	//	b, err = funcCall()
	//
	// Redefined:
	//
	//	a, err := funcCall()
	//	b, err := funcCall()
	//
	// We track the number of free variables that can be redefined to maintain our preference
	// of using "x, y, z := fn()" style assignment statements.
	var canRedefineCount int

	// Each identifier in the selected block must become (1) a parameter to the
	// extracted function, (2) a return value of the extracted function, or (3) a local
	// variable in the extracted function. Determine the outcome(s) for each variable
	// based on whether it is free, altered within the selected block, and used outside
	// of the selected block.
	for _, v := range variables {
		if _, ok := seenVars[v.obj]; ok {
			continue
		}
		if v.obj.Name() == "_" {
			// The blank identifier is always a local variable
			continue
		}
		typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type())
		if typ == nil {
			return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name())
		}
		seenVars[v.obj] = typ
		identifier := ast.NewIdent(v.obj.Name())
		// An identifier must meet three conditions to become a return value of the
		// extracted function. (1) its value must be defined or reassigned within
		// the selection (isAssigned), (2) it must be used at least once after the
		// selection (isUsed), and (3) its first use after the selection
		// cannot be its own reassignment or redefinition (varOverridden).
		if v.obj.Parent() == nil {
			return nil, fmt.Errorf("parent nil")
		}
		isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj)
		if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) {
			returnTypes = append(returnTypes, &ast.Field{Type: typ})
			returns = append(returns, identifier)
			if !v.free {
				uninitialized = append(uninitialized, v.obj)
			} else if v.obj.Parent().Pos() == startParent.Pos() {
				canRedefineCount++
			}
		}
		// An identifier must meet two conditions to become a parameter of the
		// extracted function. (1) it must be free (isFree), and (2) its first
		// use within the selection cannot be its own definition (isDefined).
		if v.free && !v.defined {
			// The method's receiver is not passed as a parameter; it becomes
			// the receiver of the extracted method instead.
			if isMethod && v.obj == receiverObj {
				receiverUsed = true
				continue
			}
			params = append(params, identifier)
			paramTypes = append(paramTypes, &ast.Field{
				Names: []*ast.Ident{identifier},
				Type:  typ,
			})
		}
	}

	// Find the function literal that encloses the selection. The enclosing function literal
	// may not be the enclosing function declaration (i.e. 'outer'). For example, in the
	// following block:
	//
	//	func main() {
	//		ast.Inspect(node, func(n ast.Node) bool {
	//			v := 1 // this line extracted
	//			return true
	//		})
	//	}
	//
	// 'outer' is main(). However, the extracted selection most directly belongs to
	// the anonymous function literal, the second argument of ast.Inspect(). We use the
	// enclosing function literal to determine the proper return types for return statements
	// within the selection. We still need the enclosing function declaration because this is
	// the top-level declaration. We inspect the top-level declaration to look for variables
	// as well as for code replacement.
	enclosing := outer.Type
	for _, p := range path {
		if p == enclosing {
			break
		}
		if fl, ok := p.(*ast.FuncLit); ok {
			enclosing = fl.Type
			break
		}
	}

	// We put the selection in a constructed file. We can then traverse and edit
	// the extracted selection without modifying the original AST.
	startOffset, err := Offset(tok, rng.Start)
	if err != nil {
		return nil, err
	}
	endOffset, err := Offset(tok, rng.End)
	if err != nil {
		return nil, err
	}
	selection := src[startOffset:endOffset]
	extractedBlock, err := parseBlockStmt(fset, selection)
	if err != nil {
		return nil, err
	}

	// We need to account for return statements in the selected block, as they will complicate
	// the logical flow of the extracted function. See the following example, where ** denotes
	// the range to be extracted.
	//
	// Before:
	//
	//	func _() int {
	//		a := 1
	//		b := 2
	//		**if a == b {
	//			return a
	//		}**
	//		...
	//	}
	//
	// After:
	//
	//	func _() int {
	//		a := 1
	//		b := 2
	//		cond0, ret0 := x0(a, b)
	//		if cond0 {
	//			return ret0
	//		}
	//		...
	//	}
	//
	//	func x0(a int, b int) (bool, int) {
	//		if a == b {
	//			return true, a
	//		}
	//		return false, 0
	//	}
	//
	// We handle returns by adding an additional boolean return value to the extracted function.
	// This bool reports whether the original function would have returned. Because the
	// extracted selection contains a return statement, we must also add the types in the
	// return signature of the enclosing function to the return signature of the
	// extracted function. We then add an extra if statement checking this boolean value
	// in the original function. If the condition is met, the original function should
	// return a value, mimicking the functionality of the original return statement(s)
	// in the selection.
	//
	// If there is a return that is guaranteed to execute (hasNonNestedReturn=true), then
	// we don't need to include this additional condition check and can simply return.
	//
	// Before:
	//
	//	func _() int {
	//		a := 1
	//		b := 2
	//		**if a == b {
	//			return a
	//		}
	//		return b**
	//	}
	//
	// After:
	//
	//	func _() int {
	//		a := 1
	//		b := 2
	//		return x0(a, b)
	//	}
	//
	//	func x0(a int, b int) int {
	//		if a == b {
	//			return a
	//		}
	//		return b
	//	}
	var retVars []*returnVariable
	var ifReturn *ast.IfStmt
	if containsReturnStatement {
		if !hasNonNestedReturn {
			// The selected block contained return statements, so we have to modify the
			// signature of the extracted function as described above. Adjust all of
			// the return statements in the extracted function to reflect this change in
			// signature.
			if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
				pkg, extractedBlock); err != nil {
				return nil, err
			}
		}
		// Collect the additional return values and types needed to accommodate return
		// statements in the selection. Update the type signature of the extracted
		// function and construct the if statement that will be inserted in the enclosing
		// function.
		retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start, hasNonNestedReturn)
		if err != nil {
			return nil, err
		}
	}

	// Add a return statement to the end of the new function. This return statement must include
	// the values for the types of the original extracted function signature and (if a return
	// statement is present in the selection) enclosing function signature.
	// This only needs to be done if the selection does not have a non-nested return, otherwise
	// it already terminates with a return statement.
	hasReturnValues := len(returns)+len(retVars) > 0
	if hasReturnValues && !hasNonNestedReturn {
		// returns was grown by append and may have spare capacity. Clip it with a
		// full slice expression so this append copies into a fresh array instead
		// of writing into memory that the call-site append below would reuse
		// (and overwrite) before the new function is formatted.
		results := append(returns[:len(returns):len(returns)], getZeroVals(retVars)...)
		extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{
			Results: results,
		})
	}

	// Construct the appropriate call to the extracted function.
	// We must meet two conditions to use ":=" instead of '='. (1) there must be at least
	// one variable on the lhs that is uninitialized (non-free) prior to the assignment.
	// (2) all of the initialized (free) variables on the lhs must be able to be redefined.
	sym := token.ASSIGN
	canDefineCount := len(uninitialized) + canRedefineCount
	canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns)
	if canDefine {
		sym = token.DEFINE
	}
	var name, funName string
	if isMethod {
		name = "newMethod"
		// TODO(suzmue): generate a name that does not conflict for "newMethod".
		funName = name
	} else {
		name = "newFunction"
		funName, _ = generateAvailableIdentifier(rng.Start, file, path, info, name, 0)
	}
	// Clip again so the call's LHS never shares a backing array with the
	// return statement built above.
	callResults := append(returns[:len(returns):len(returns)], getNames(retVars)...)
	extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params,
		callResults, funName, sym, receiverName)

	// Build the extracted function.
	newFunc := &ast.FuncDecl{
		Name: ast.NewIdent(funName),
		Type: &ast.FuncType{
			Params:  &ast.FieldList{List: paramTypes},
			Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
		},
		Body: extractedBlock,
	}
	if isMethod {
		// The receiver is only named in the new method if the selection used it.
		var names []*ast.Ident
		if receiverUsed {
			names = append(names, ast.NewIdent(receiverName))
		}
		newFunc.Recv = &ast.FieldList{
			List: []*ast.Field{{
				Names: names,
				Type:  receiver.Type,
			}},
		}
	}

	// Create variable declarations for any identifiers that need to be initialized prior to
	// calling the extracted function. We do not manually initialize variables if every return
	// value is uninitialized. We can use := to initialize the variables in this situation.
	var declarations []ast.Stmt
	if canDefineCount != len(returns) {
		declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
	}

	var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer
	if err := format.Node(&declBuf, fset, declarations); err != nil {
		return nil, err
	}
	if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil {
		return nil, err
	}
	if ifReturn != nil {
		if err := format.Node(&ifBuf, fset, ifReturn); err != nil {
			return nil, err
		}
	}
	if err := format.Node(&newFuncBuf, fset, newFunc); err != nil {
		return nil, err
	}

	// Find all the comments within the range and print them to be put somewhere.
	// TODO(suzmue): print these in the extracted function at the correct place.
	for _, cg := range file.Comments {
		if cg.Pos().IsValid() && cg.Pos() < rng.End && cg.Pos() >= rng.Start {
			for _, c := range cg.List {
				fmt.Fprintln(&commentBuf, c.Text)
			}
		}
	}

	// We're going to replace the whole enclosing function,
	// so preserve the text before and after the selected block.
	outerStart, err := Offset(tok, outer.Pos())
	if err != nil {
		return nil, err
	}
	outerEnd, err := Offset(tok, outer.End())
	if err != nil {
		return nil, err
	}
	before := src[outerStart:startOffset]
	after := src[endOffset:outerEnd]
	indent, err := calculateIndentation(src, tok, start)
	if err != nil {
		return nil, err
	}
	newLineIndent := "\n" + indent

	var fullReplacement strings.Builder
	fullReplacement.Write(before)
	if commentBuf.Len() > 0 {
		comments := strings.ReplaceAll(commentBuf.String(), "\n", newLineIndent)
		fullReplacement.WriteString(comments)
	}
	if declBuf.Len() > 0 { // add any initializations, if needed
		initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
			newLineIndent
		fullReplacement.WriteString(initializations)
	}
	fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function
	if ifBuf.Len() > 0 { // add the if statement below the function call, if needed
		ifstatement := newLineIndent +
			strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent)
		fullReplacement.WriteString(ifstatement)
	}
	fullReplacement.Write(after)
	fullReplacement.WriteString("\n\n")       // add newlines after the enclosing function
	fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function

	return &analysis.SuggestedFix{
		TextEdits: []analysis.TextEdit{{
			Pos:     outer.Pos(),
			End:     outer.End(),
			NewText: []byte(fullReplacement.String()),
		}},
	}, nil
}