github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/go/ir/const.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ir
     6  
     7  // This file defines the Const SSA value type.
     8  
     9  import (
    10  	"fmt"
    11  	"go/ast"
    12  	"go/constant"
    13  	"go/types"
    14  	"strconv"
    15  	"strings"
    16  
    17  	"github.com/amarpal/go-tools/go/types/typeutil"
    18  	"golang.org/x/exp/typeparams"
    19  )
    20  
    21  // NewConst returns a new constant of the specified value and type.
    22  // val must be valid according to the specification of Const.Value.
    23  func NewConst(val constant.Value, typ types.Type, source ast.Node) *Const {
    24  	c := &Const{
    25  		register: register{
    26  			typ: typ,
    27  		},
    28  		Value: val,
    29  	}
    30  	c.setSource(source)
    31  	return c
    32  }
    33  
    34  // intConst returns an 'int' constant that evaluates to i.
    35  // (i is an int64 in case the host is narrower than the target.)
    36  func intConst(i int64, source ast.Node) *Const {
    37  	return NewConst(constant.MakeInt64(i), tInt, source)
    38  }
    39  
    40  // nilConst returns a nil constant of the specified type, which may
    41  // be any reference type, including interfaces.
    42  func nilConst(typ types.Type, source ast.Node) *Const {
    43  	return NewConst(nil, typ, source)
    44  }
    45  
    46  // stringConst returns a 'string' constant that evaluates to s.
    47  func stringConst(s string, source ast.Node) *Const {
    48  	return NewConst(constant.MakeString(s), tString, source)
    49  }
    50  
    51  // zeroConst returns a new "zero" constant of the specified type.
    52  func zeroConst(t types.Type, source ast.Node) Constant {
    53  	if _, ok := t.Underlying().(*types.Interface); ok && !typeparams.IsTypeParam(t) {
    54  		// Handle non-generic interface early to simplify following code.
    55  		return nilConst(t, source)
    56  	}
    57  
    58  	tset := typeutil.NewTypeSet(t)
    59  
    60  	switch typ := tset.CoreType().(type) {
    61  	case *types.Struct:
    62  		values := make([]Value, typ.NumFields())
    63  		for i := 0; i < typ.NumFields(); i++ {
    64  			values[i] = zeroConst(typ.Field(i).Type(), source)
    65  		}
    66  		ac := &AggregateConst{
    67  			register: register{typ: t},
    68  			Values:   values,
    69  		}
    70  		ac.setSource(source)
    71  		return ac
    72  	case *types.Tuple:
    73  		values := make([]Value, typ.Len())
    74  		for i := 0; i < typ.Len(); i++ {
    75  			values[i] = zeroConst(typ.At(i).Type(), source)
    76  		}
    77  		ac := &AggregateConst{
    78  			register: register{typ: t},
    79  			Values:   values,
    80  		}
    81  		ac.setSource(source)
    82  		return ac
    83  	}
    84  
    85  	isNillable := func(term *types.Term) bool {
    86  		switch typ := term.Type().Underlying().(type) {
    87  		case *types.Pointer, *types.Slice, *types.Interface, *types.Chan, *types.Map, *types.Signature, *typeutil.Iterator:
    88  			return true
    89  		case *types.Basic:
    90  			switch typ.Kind() {
    91  			case types.UnsafePointer, types.UntypedNil:
    92  				return true
    93  			default:
    94  				return false
    95  			}
    96  		default:
    97  			return false
    98  		}
    99  	}
   100  
   101  	isInfo := func(info types.BasicInfo) func(*types.Term) bool {
   102  		return func(term *types.Term) bool {
   103  			basic, ok := term.Type().Underlying().(*types.Basic)
   104  			if !ok {
   105  				return false
   106  			}
   107  			return (basic.Info() & info) != 0
   108  		}
   109  	}
   110  
   111  	isArray := func(term *types.Term) bool {
   112  		_, ok := term.Type().Underlying().(*types.Array)
   113  		return ok
   114  	}
   115  
   116  	switch {
   117  	case tset.All(isInfo(types.IsNumeric)):
   118  		return NewConst(constant.MakeInt64(0), t, source)
   119  	case tset.All(isInfo(types.IsString)):
   120  		return NewConst(constant.MakeString(""), t, source)
   121  	case tset.All(isInfo(types.IsBoolean)):
   122  		return NewConst(constant.MakeBool(false), t, source)
   123  	case tset.All(isNillable):
   124  		return nilConst(t, source)
   125  	case tset.All(isArray):
   126  		var k ArrayConst
   127  		k.setType(t)
   128  		k.setSource(source)
   129  		return &k
   130  	default:
   131  		var k GenericConst
   132  		k.setType(t)
   133  		k.setSource(source)
   134  		return &k
   135  	}
   136  }
   137  
   138  func (c *Const) RelString(from *types.Package) string {
   139  	var p string
   140  	if c.Value == nil {
   141  		p = "nil"
   142  	} else if c.Value.Kind() == constant.String {
   143  		v := constant.StringVal(c.Value)
   144  		const max = 20
   145  		// TODO(adonovan): don't cut a rune in half.
   146  		if len(v) > max {
   147  			v = v[:max-3] + "..." // abbreviate
   148  		}
   149  		p = strconv.Quote(v)
   150  	} else {
   151  		p = c.Value.String()
   152  	}
   153  	return fmt.Sprintf("Const <%s> {%s}", relType(c.Type(), from), p)
   154  }
   155  
   156  func (c *Const) String() string {
   157  	if c.block == nil {
   158  		// Constants don't have a block till late in the compilation process. But we want to print consts during
   159  		// debugging.
   160  		return c.RelString(nil)
   161  	}
   162  	return c.RelString(c.Parent().pkg())
   163  }
   164  
   165  func (v *ArrayConst) RelString(pkg *types.Package) string {
   166  	return fmt.Sprintf("ArrayConst <%s>", relType(v.Type(), pkg))
   167  }
   168  
   169  func (v *ArrayConst) String() string {
   170  	return v.RelString(v.Parent().pkg())
   171  }
   172  
   173  func (v *AggregateConst) RelString(pkg *types.Package) string {
   174  	values := make([]string, len(v.Values))
   175  	for i, v := range v.Values {
   176  		if v != nil {
   177  			values[i] = v.Name()
   178  		} else {
   179  			values[i] = "nil"
   180  		}
   181  	}
   182  	return fmt.Sprintf("AggregateConst <%s> (%s)", relType(v.Type(), pkg), strings.Join(values, ", "))
   183  }
   184  
   185  func (v *AggregateConst) String() string {
   186  	if v.block == nil {
   187  		return v.RelString(nil)
   188  	}
   189  	return v.RelString(v.Parent().pkg())
   190  }
   191  
   192  func (v *GenericConst) RelString(pkg *types.Package) string {
   193  	return fmt.Sprintf("GenericConst <%s>", relType(v.Type(), pkg))
   194  }
   195  
   196  func (v *GenericConst) String() string {
   197  	return v.RelString(v.Parent().pkg())
   198  }
   199  
   200  // IsNil returns true if this constant represents a typed or untyped nil value.
   201  func (c *Const) IsNil() bool {
   202  	return c.Value == nil
   203  }
   204  
   205  // Int64 returns the numeric value of this constant truncated to fit
   206  // a signed 64-bit integer.
   207  func (c *Const) Int64() int64 {
   208  	switch x := constant.ToInt(c.Value); x.Kind() {
   209  	case constant.Int:
   210  		if i, ok := constant.Int64Val(x); ok {
   211  			return i
   212  		}
   213  		return 0
   214  	case constant.Float:
   215  		f, _ := constant.Float64Val(x)
   216  		return int64(f)
   217  	}
   218  	panic(fmt.Sprintf("unexpected constant value: %T", c.Value))
   219  }
   220  
   221  // Uint64 returns the numeric value of this constant truncated to fit
   222  // an unsigned 64-bit integer.
   223  func (c *Const) Uint64() uint64 {
   224  	switch x := constant.ToInt(c.Value); x.Kind() {
   225  	case constant.Int:
   226  		if u, ok := constant.Uint64Val(x); ok {
   227  			return u
   228  		}
   229  		return 0
   230  	case constant.Float:
   231  		f, _ := constant.Float64Val(x)
   232  		return uint64(f)
   233  	}
   234  	panic(fmt.Sprintf("unexpected constant value: %T", c.Value))
   235  }
   236  
   237  // Float64 returns the numeric value of this constant truncated to fit
   238  // a float64.
   239  func (c *Const) Float64() float64 {
   240  	f, _ := constant.Float64Val(c.Value)
   241  	return f
   242  }
   243  
   244  // Complex128 returns the complex value of this constant truncated to
   245  // fit a complex128.
   246  func (c *Const) Complex128() complex128 {
   247  	re, _ := constant.Float64Val(constant.Real(c.Value))
   248  	im, _ := constant.Float64Val(constant.Imag(c.Value))
   249  	return complex(re, im)
   250  }
   251  
   252  func (c *Const) equal(o Constant) bool {
   253  	// TODO(dh): don't use == for types, this will miss identical pointer types, among others
   254  	oc, ok := o.(*Const)
   255  	if !ok {
   256  		return false
   257  	}
   258  	return c.typ == oc.typ && c.Value == oc.Value && c.source == oc.source
   259  }
   260  
   261  func (c *AggregateConst) equal(o Constant) bool {
   262  	oc, ok := o.(*AggregateConst)
   263  	if !ok {
   264  		return false
   265  	}
   266  	// TODO(dh): don't use == for types, this will miss identical pointer types, among others
   267  	if c.typ != oc.typ {
   268  		return false
   269  	}
   270  	if c.source != oc.source {
   271  		return false
   272  	}
   273  	for i, v := range c.Values {
   274  		if !v.(Constant).equal(oc.Values[i].(Constant)) {
   275  			return false
   276  		}
   277  	}
   278  	return true
   279  }
   280  
   281  func (c *ArrayConst) equal(o Constant) bool {
   282  	oc, ok := o.(*ArrayConst)
   283  	if !ok {
   284  		return false
   285  	}
   286  	// TODO(dh): don't use == for types, this will miss identical pointer types, among others
   287  	return c.typ == oc.typ && c.source == oc.source
   288  }
   289  
   290  func (c *GenericConst) equal(o Constant) bool {
   291  	oc, ok := o.(*GenericConst)
   292  	if !ok {
   293  		return false
   294  	}
   295  	// TODO(dh): don't use == for types, this will miss identical pointer types, among others
   296  	return c.typ == oc.typ && c.source == oc.source
   297  }