github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/analysis/facts/purity/purity.go (about)

     1  package purity
     2  
     3  // TODO(dh): we should split this into two facts, one tracking actual purity, and one tracking side-effects. A function
     4  // that returns a heap allocation isn't pure, but it may be free of side effects.
     5  
     6  import (
     7  	"go/types"
     8  	"reflect"
     9  
    10  	"github.com/amarpal/go-tools/go/ir"
    11  	"github.com/amarpal/go-tools/go/ir/irutil"
    12  	"github.com/amarpal/go-tools/internal/passes/buildir"
    13  
    14  	"golang.org/x/tools/go/analysis"
    15  )
    16  
    17  type IsPure struct{}
    18  
    19  func (*IsPure) AFact()           {}
    20  func (d *IsPure) String() string { return "is pure" }
    21  
    22  type Result map[*types.Func]*IsPure
    23  
    24  var Analyzer = &analysis.Analyzer{
    25  	Name:       "fact_purity",
    26  	Doc:        "Mark pure functions",
    27  	Run:        purity,
    28  	Requires:   []*analysis.Analyzer{buildir.Analyzer},
    29  	FactTypes:  []analysis.Fact{(*IsPure)(nil)},
    30  	ResultType: reflect.TypeOf(Result{}),
    31  }
    32  
    33  var pureStdlib = map[string]struct{}{
    34  	"errors.New":                      {},
    35  	"fmt.Errorf":                      {},
    36  	"fmt.Sprintf":                     {},
    37  	"fmt.Sprint":                      {},
    38  	"sort.Reverse":                    {},
    39  	"strings.Map":                     {},
    40  	"strings.Repeat":                  {},
    41  	"strings.Replace":                 {},
    42  	"strings.Title":                   {},
    43  	"strings.ToLower":                 {},
    44  	"strings.ToLowerSpecial":          {},
    45  	"strings.ToTitle":                 {},
    46  	"strings.ToTitleSpecial":          {},
    47  	"strings.ToUpper":                 {},
    48  	"strings.ToUpperSpecial":          {},
    49  	"strings.Trim":                    {},
    50  	"strings.TrimFunc":                {},
    51  	"strings.TrimLeft":                {},
    52  	"strings.TrimLeftFunc":            {},
    53  	"strings.TrimPrefix":              {},
    54  	"strings.TrimRight":               {},
    55  	"strings.TrimRightFunc":           {},
    56  	"strings.TrimSpace":               {},
    57  	"strings.TrimSuffix":              {},
    58  	"(*net/http.Request).WithContext": {},
    59  	"time.Now":                        {},
    60  	"time.Parse":                      {},
    61  	"time.ParseInLocation":            {},
    62  	"time.Unix":                       {},
    63  	"time.UnixMicro":                  {},
    64  	"time.UnixMilli":                  {},
    65  	"(time.Time).Add":                 {},
    66  	"(time.Time).AddDate":             {},
    67  	"(time.Time).After":               {},
    68  	"(time.Time).Before":              {},
    69  	"(time.Time).Clock":               {},
    70  	"(time.Time).Compare":             {},
    71  	"(time.Time).Date":                {},
    72  	"(time.Time).Day":                 {},
    73  	"(time.Time).Equal":               {},
    74  	"(time.Time).Format":              {},
    75  	"(time.Time).GoString":            {},
    76  	"(time.Time).GobEncode":           {},
    77  	"(time.Time).Hour":                {},
    78  	"(time.Time).ISOWeek":             {},
    79  	"(time.Time).In":                  {},
    80  	"(time.Time).IsDST":               {},
    81  	"(time.Time).IsZero":              {},
    82  	"(time.Time).Local":               {},
    83  	"(time.Time).Location":            {},
    84  	"(time.Time).MarshalBinary":       {},
    85  	"(time.Time).MarshalJSON":         {},
    86  	"(time.Time).MarshalText":         {},
    87  	"(time.Time).Minute":              {},
    88  	"(time.Time).Month":               {},
    89  	"(time.Time).Nanosecond":          {},
    90  	"(time.Time).Round":               {},
    91  	"(time.Time).Second":              {},
    92  	"(time.Time).String":              {},
    93  	"(time.Time).Sub":                 {},
    94  	"(time.Time).Truncate":            {},
    95  	"(time.Time).UTC":                 {},
    96  	"(time.Time).Unix":                {},
    97  	"(time.Time).UnixMicro":           {},
    98  	"(time.Time).UnixMilli":           {},
    99  	"(time.Time).UnixNano":            {},
   100  	"(time.Time).Weekday":             {},
   101  	"(time.Time).Year":                {},
   102  	"(time.Time).YearDay":             {},
   103  	"(time.Time).Zone":                {},
   104  	"(time.Time).ZoneBounds":          {},
   105  }
   106  
   107  func purity(pass *analysis.Pass) (interface{}, error) {
   108  	seen := map[*ir.Function]struct{}{}
   109  	irpkg := pass.ResultOf[buildir.Analyzer].(*buildir.IR).Pkg
   110  	var check func(fn *ir.Function) (ret bool)
   111  	check = func(fn *ir.Function) (ret bool) {
   112  		if fn.Object() == nil {
   113  			// TODO(dh): support closures
   114  			return false
   115  		}
   116  		if pass.ImportObjectFact(fn.Object(), new(IsPure)) {
   117  			return true
   118  		}
   119  		if fn.Pkg != irpkg {
   120  			// Function is in another package but wasn't marked as
   121  			// pure, ergo it isn't pure
   122  			return false
   123  		}
   124  		// Break recursion
   125  		if _, ok := seen[fn]; ok {
   126  			return false
   127  		}
   128  
   129  		seen[fn] = struct{}{}
   130  		defer func() {
   131  			if ret {
   132  				pass.ExportObjectFact(fn.Object(), &IsPure{})
   133  			}
   134  		}()
   135  
   136  		if irutil.IsStub(fn) {
   137  			return false
   138  		}
   139  
   140  		if _, ok := pureStdlib[fn.Object().(*types.Func).FullName()]; ok {
   141  			return true
   142  		}
   143  
   144  		if fn.Signature.Results().Len() == 0 {
   145  			// A function with no return values is empty or is doing some
   146  			// work we cannot see (for example because of build tags);
   147  			// don't consider it pure.
   148  			return false
   149  		}
   150  
   151  		var isBasic func(typ types.Type) bool
   152  		isBasic = func(typ types.Type) bool {
   153  			switch u := typ.Underlying().(type) {
   154  			case *types.Basic:
   155  				return true
   156  			case *types.Struct:
   157  				for i := 0; i < u.NumFields(); i++ {
   158  					if !isBasic(u.Field(i).Type()) {
   159  						return false
   160  					}
   161  				}
   162  				return true
   163  			default:
   164  				return false
   165  			}
   166  		}
   167  
   168  		for _, param := range fn.Params {
   169  			// TODO(dh): this may not be strictly correct. pure code can, to an extent, operate on non-basic types.
   170  			if !isBasic(param.Type()) {
   171  				return false
   172  			}
   173  		}
   174  
   175  		// Don't consider external functions pure.
   176  		if fn.Blocks == nil {
   177  			return false
   178  		}
   179  		checkCall := func(common *ir.CallCommon) bool {
   180  			if common.IsInvoke() {
   181  				return false
   182  			}
   183  			builtin, ok := common.Value.(*ir.Builtin)
   184  			if !ok {
   185  				if common.StaticCallee() != fn {
   186  					if common.StaticCallee() == nil {
   187  						return false
   188  					}
   189  					if !check(common.StaticCallee()) {
   190  						return false
   191  					}
   192  				}
   193  			} else {
   194  				switch builtin.Name() {
   195  				case "len", "cap":
   196  				default:
   197  					return false
   198  				}
   199  			}
   200  			return true
   201  		}
   202  
   203  		var isStackAddr func(ir.Value) bool
   204  		isStackAddr = func(v ir.Value) bool {
   205  			switch v := v.(type) {
   206  			case *ir.Alloc:
   207  				return !v.Heap
   208  			case *ir.FieldAddr:
   209  				return isStackAddr(v.X)
   210  			default:
   211  				return false
   212  			}
   213  		}
   214  		for _, b := range fn.Blocks {
   215  			for _, ins := range b.Instrs {
   216  				switch ins := ins.(type) {
   217  				case *ir.Call:
   218  					if !checkCall(ins.Common()) {
   219  						return false
   220  					}
   221  				case *ir.Defer:
   222  					if !checkCall(&ins.Call) {
   223  						return false
   224  					}
   225  				case *ir.Select:
   226  					return false
   227  				case *ir.Send:
   228  					return false
   229  				case *ir.Go:
   230  					return false
   231  				case *ir.Panic:
   232  					return false
   233  				case *ir.Store:
   234  					if !isStackAddr(ins.Addr) {
   235  						return false
   236  					}
   237  				case *ir.FieldAddr:
   238  					if !isStackAddr(ins.X) {
   239  						return false
   240  					}
   241  				case *ir.Alloc:
   242  					// TODO(dh): make use of proper escape analysis
   243  					if ins.Heap {
   244  						return false
   245  					}
   246  				case *ir.Load:
   247  					if !isStackAddr(ins.X) {
   248  						return false
   249  					}
   250  				}
   251  			}
   252  		}
   253  		return true
   254  	}
   255  	for _, fn := range pass.ResultOf[buildir.Analyzer].(*buildir.IR).SrcFuncs {
   256  		check(fn)
   257  	}
   258  
   259  	out := Result{}
   260  	for _, fact := range pass.AllObjectFacts() {
   261  		out[fact.Object.(*types.Func)] = fact.Fact.(*IsPure)
   262  	}
   263  	return out, nil
   264  }