github.com/v2fly/tools@v0.100.0/internal/lsp/analysis/simplifycompositelit/simplifycompositelit.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package simplifycompositelit defines an Analyzer that simplifies composite literals.
     6  // https://github.com/golang/go/blob/master/src/cmd/gofmt/simplify.go
     7  // https://golang.org/cmd/gofmt/#hdr-The_simplify_command
     8  package simplifycompositelit
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"go/ast"
    14  	"go/printer"
    15  	"go/token"
    16  	"reflect"
    17  
    18  	"github.com/v2fly/tools/go/analysis"
    19  	"github.com/v2fly/tools/go/analysis/passes/inspect"
    20  	"github.com/v2fly/tools/go/ast/inspector"
    21  )
    22  
    // Doc is the help text for Analyzer: a one-line summary, an example of
// the rewrite it suggests, and a note tying it to "gofmt -s".
const Doc = "check for composite literal simplifications\n" +
	"\n" +
	"An array, slice, or map composite literal of the form:\n" +
	"\t[]T{T{}, T{}}\n" +
	"will be simplified to:\n" +
	"\t[]T{{}, {}}\n" +
	"\n" +
	"This is one of the simplifications that \"gofmt -s\" applies."
    31  
    32  var Analyzer = &analysis.Analyzer{
    33  	Name:     "simplifycompositelit",
    34  	Doc:      Doc,
    35  	Requires: []*analysis.Analyzer{inspect.Analyzer},
    36  	Run:      run,
    37  }
    38  
    39  func run(pass *analysis.Pass) (interface{}, error) {
    40  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    41  	nodeFilter := []ast.Node{(*ast.CompositeLit)(nil)}
    42  	inspect.Preorder(nodeFilter, func(n ast.Node) {
    43  		expr := n.(*ast.CompositeLit)
    44  
    45  		outer := expr
    46  		var keyType, eltType ast.Expr
    47  		switch typ := outer.Type.(type) {
    48  		case *ast.ArrayType:
    49  			eltType = typ.Elt
    50  		case *ast.MapType:
    51  			keyType = typ.Key
    52  			eltType = typ.Value
    53  		}
    54  
    55  		if eltType == nil {
    56  			return
    57  		}
    58  		var ktyp reflect.Value
    59  		if keyType != nil {
    60  			ktyp = reflect.ValueOf(keyType)
    61  		}
    62  		typ := reflect.ValueOf(eltType)
    63  		for _, x := range outer.Elts {
    64  			// look at value of indexed/named elements
    65  			if t, ok := x.(*ast.KeyValueExpr); ok {
    66  				if keyType != nil {
    67  					simplifyLiteral(pass, ktyp, keyType, t.Key)
    68  				}
    69  				x = t.Value
    70  			}
    71  			simplifyLiteral(pass, typ, eltType, x)
    72  		}
    73  	})
    74  	return nil, nil
    75  }
    76  
    77  func simplifyLiteral(pass *analysis.Pass, typ reflect.Value, astType, x ast.Expr) {
    78  	// if the element is a composite literal and its literal type
    79  	// matches the outer literal's element type exactly, the inner
    80  	// literal type may be omitted
    81  	if inner, ok := x.(*ast.CompositeLit); ok && match(typ, reflect.ValueOf(inner.Type)) {
    82  		var b bytes.Buffer
    83  		printer.Fprint(&b, pass.Fset, inner.Type)
    84  		createDiagnostic(pass, inner.Type.Pos(), inner.Type.End(), b.String())
    85  	}
    86  	// if the outer literal's element type is a pointer type *T
    87  	// and the element is & of a composite literal of type T,
    88  	// the inner &T may be omitted.
    89  	if ptr, ok := astType.(*ast.StarExpr); ok {
    90  		if addr, ok := x.(*ast.UnaryExpr); ok && addr.Op == token.AND {
    91  			if inner, ok := addr.X.(*ast.CompositeLit); ok {
    92  				if match(reflect.ValueOf(ptr.X), reflect.ValueOf(inner.Type)) {
    93  					var b bytes.Buffer
    94  					printer.Fprint(&b, pass.Fset, inner.Type)
    95  					// Account for the & by subtracting 1 from typ.Pos().
    96  					createDiagnostic(pass, inner.Type.Pos()-1, inner.Type.End(), "&"+b.String())
    97  				}
    98  			}
    99  		}
   100  	}
   101  }
   102  
   103  func createDiagnostic(pass *analysis.Pass, start, end token.Pos, typ string) {
   104  	pass.Report(analysis.Diagnostic{
   105  		Pos:     start,
   106  		End:     end,
   107  		Message: "redundant type from array, slice, or map composite literal",
   108  		SuggestedFixes: []analysis.SuggestedFix{{
   109  			Message: fmt.Sprintf("Remove '%s'", typ),
   110  			TextEdits: []analysis.TextEdit{{
   111  				Pos:     start,
   112  				End:     end,
   113  				NewText: []byte{},
   114  			}},
   115  		}},
   116  	})
   117  }
   118  
   // match reports whether pattern and val are structurally equal syntax
// trees, compared via reflection. Unlike the gofmt original it was copied
// from, it takes no wildcard map: it is a pure equality check.
//
// Rules, in the order they are applied:
//   - two invalid (zero) Values match each other and nothing else;
//   - Values of different types never match;
//   - *ast.Ident values match by Name only, ignoring their *ast.Object;
//   - *ast.Object pointers and token.Pos values always match, so source
//     positions never affect the result;
//   - *ast.CallExpr values must agree on whether "..." (Ellipsis) is used,
//     then continue through the general rules below;
//   - otherwise pointers are dereferenced, and slices, structs, and
//     interfaces are compared element by element; any other kind is
//     compared with ==.
//
// Adapted from https://github.com/golang/go/blob/26154f31ad6c801d8bad5ef58df1e9263c6beec7/src/cmd/gofmt/rewrite.go#L160
func match(pattern, val reflect.Value) bool {
	// Invalid Values (e.g. reflect.ValueOf of a nil ast.Expr) match only
	// each other.
	if !pattern.IsValid() || !val.IsValid() {
		return !pattern.IsValid() && !val.IsValid()
	}
	if pattern.Type() != val.Type() {
		return false
	}

	// Special cases.
	switch pattern.Type() {
	case identType:
		// For identifiers, only the names need to match
		// (and none of the other *ast.Object information).
		// This is a common case, handle it all here instead
		// of recursing down any further via reflection.
		p := pattern.Interface().(*ast.Ident)
		v := val.Interface().(*ast.Ident)
		return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
	case objectPtrType, positionType:
		// object pointers and token positions always match
		return true
	case callExprType:
		// For calls, the Ellipsis fields (token.Pos) must agree on
		// validity, since that is how f(x) and f(x...) differ; the
		// positions themselves are ignored like all other token.Pos values.
		// Check them here but fall through for the remaining fields.
		p := pattern.Interface().(*ast.CallExpr)
		v := val.Interface().(*ast.CallExpr)
		if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
			return false
		}
	}

	// Dereference pointers; a nil pointer becomes an invalid Value, so two
	// nil pointers match and a nil/non-nil pair does not.
	p := reflect.Indirect(pattern)
	v := reflect.Indirect(val)
	if !p.IsValid() || !v.IsValid() {
		return !p.IsValid() && !v.IsValid()
	}

	switch p.Kind() {
	case reflect.Slice:
		if p.Len() != v.Len() {
			return false
		}
		for i := 0; i < p.Len(); i++ {
			if !match(p.Index(i), v.Index(i)) {
				return false
			}
		}
		return true

	case reflect.Struct:
		for i := 0; i < p.NumField(); i++ {
			if !match(p.Field(i), v.Field(i)) {
				return false
			}
		}
		return true

	case reflect.Interface:
		// Compare the dynamic values held by interface-typed fields
		// (e.g. ast.Expr); a nil interface yields an invalid Value.
		return match(p.Elem(), v.Elem())
	}

	// Handle token integers, etc.
	return p.Interface() == v.Interface()
}
   189  
   // Reflected types that match treats specially: identifiers compare by
// name, object pointers and positions are ignored, and calls also compare
// whether an ellipsis is present.
var (
	identType     = reflect.TypeOf(new(ast.Ident))
	objectPtrType = reflect.TypeOf(new(ast.Object))
	positionType  = reflect.TypeOf(token.Pos(0))
	callExprType  = reflect.TypeOf(new(ast.CallExpr))
)