github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/internal/gosmith/type.go (about)

     1  package main
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  )
     8  
     9  type TypeClass int
    10  
    11  const (
    12  	ClassBoolean TypeClass = iota
    13  	ClassNumeric
    14  	ClassComplex
    15  	ClassString
    16  	ClassArray
    17  	ClassSlice
    18  	ClassStruct
    19  	ClassPointer
    20  	ClassFunction
    21  	ClassInterface
    22  	ClassMap
    23  	ClassChan
    24  
    25  	TraitAny
    26  	TraitOrdered
    27  	TraitComparable
    28  	TraitIndexable
    29  	TraitReceivable
    30  	TraitSendable
    31  	TraitHashable
    32  	TraitPrintable
    33  	TraitLenCapable
    34  	TraitGlobal
    35  )
    36  
// Type describes a Go type the generator can use. Predefined types are built
// in initTypes; composite types are synthesized by typeLit and by helpers such
// as pointerTo, chanOf, sliceOf, arrayOf and funcOf.
type Type struct {
	// id is the type's Go source spelling (e.g. "int", "[]string",
	// "map[int]bool"); it is spliced verbatim into generated code.
	id             string
	// class is the concrete kind of the type (one of the Class* constants).
	class          TypeClass
	// namedUserType marks a named, user-declared type; dependsOn(t, nil)
	// returns true on reaching one. NOTE(review): assigned outside this file.
	namedUserType  bool
	ktyp           *Type   // map key, chan elem, array elem, slice elem, pointee type
	vtyp           *Type   // map val
	// Each predefined type is its own underlying type (see initTypes); the
	// composite-type constructors in this file leave utyp nil.
	utyp           *Type   // underlying type
	styp           []*Type // function arguments
	rtyp           []*Type // function return values
	// Interface literals produced by typeLit currently leave elems empty.
	elems          []*Var  // struct fields and interface methods
	// literal returns the source of a simple expression of this type
	// (e.g. "false", "[]T{}", "(*T)(nil)").
	literal        func() string
	// complexLiteral, when non-nil, returns a richer expression (a populated
	// composite literal or a function literal). It is nil for pointers,
	// channels, maps, interface literals and all predefined types but string.
	complexLiteral func() string

	// TODO: cache types
	// pointerTo *Type
}
    53  
    54  func (smith *Smith) initTypes() {
    55  	smith.predefinedTypes = []*Type{
    56  		{id: "string", class: ClassString, literal: func() string { return "\"foo\"" }},
    57  		{id: "bool", class: ClassBoolean, literal: func() string { return "false" }},
    58  		{id: "int", class: ClassNumeric, literal: func() string { return "1" }},
    59  		{id: "byte", class: ClassNumeric, literal: func() string { return "byte(0)" }},
    60  		{id: "interface{}", class: ClassInterface, literal: func() string { return "interface{}(nil)" }},
    61  		{id: "rune", class: ClassNumeric, literal: func() string { return "rune(0)" }},
    62  		{id: "float32", class: ClassNumeric, literal: func() string { return "float32(1.0)" }},
    63  		{id: "float64", class: ClassNumeric, literal: func() string { return "1.0" }},
    64  		{id: "complex64", class: ClassComplex, literal: func() string { return "complex64(1i)" }},
    65  		{id: "complex128", class: ClassComplex, literal: func() string { return "1i" }},
    66  
    67  		{id: "uint", class: ClassNumeric, literal: func() string { return "uint(1)" }},
    68  		{id: "uintptr", class: ClassNumeric, literal: func() string { return "uintptr(0)" }},
    69  		{id: "int16", class: ClassNumeric, literal: func() string { return "int16(1)" }},
    70  		{id: "error", class: ClassInterface, literal: func() string { return "error(nil)" }},
    71  	}
    72  	for _, t := range smith.predefinedTypes {
    73  		t.utyp = t
    74  	}
    75  
    76  	smith.stringType = smith.predefinedTypes[0]
    77  	smith.boolType = smith.predefinedTypes[1]
    78  	smith.intType = smith.predefinedTypes[2]
    79  	smith.byteType = smith.predefinedTypes[3]
    80  	smith.efaceType = smith.predefinedTypes[4]
    81  	smith.runeType = smith.predefinedTypes[5]
    82  	smith.float32Type = smith.predefinedTypes[6]
    83  	smith.float64Type = smith.predefinedTypes[7]
    84  	smith.complex64Type = smith.predefinedTypes[8]
    85  	smith.complex128Type = smith.predefinedTypes[9]
    86  
    87  	smith.stringType.complexLiteral = func() string {
    88  		if smith.rndBool() {
    89  			return `"ab\x0acd"`
    90  		}
    91  		return "`abc\\x0acd`"
    92  	}
    93  }
    94  
    95  func fmtTypeList(list []*Type, parens bool) string {
    96  	var buf bytes.Buffer
    97  	if parens || len(list) > 1 {
    98  		buf.Write([]byte{'('})
    99  	}
   100  	for i, t := range list {
   101  		if i != 0 {
   102  			buf.Write([]byte{','})
   103  		}
   104  		fmt.Fprintf(&buf, "%v", t.id)
   105  	}
   106  	if parens || len(list) > 1 {
   107  		buf.Write([]byte{')'})
   108  	}
   109  	return buf.String()
   110  }
   111  
   112  func (smith *Smith) atype(trait TypeClass) *Type {
   113  	smith.typeDepth++
   114  	defer func() {
   115  		smith.typeDepth--
   116  	}()
   117  	for {
   118  		if smith.typeDepth >= 3 || smith.rndBool() {
   119  			var cand []*Type
   120  			for _, t := range smith.types() {
   121  				if smith.satisfiesTrait(t, trait) {
   122  					cand = append(cand, t)
   123  				}
   124  			}
   125  			if len(cand) > 0 {
   126  				return cand[smith.rnd(len(cand))]
   127  			}
   128  		}
   129  		t := smith.typeLit()
   130  		if t != nil && smith.satisfiesTrait(t, trait) {
   131  			return t
   132  		}
   133  	}
   134  }
   135  
// typeLit synthesizes a new, unnamed type literal of a randomly chosen kind:
// array, chan, struct, pointer, interface, slice, function or map. Component
// types come from atype/atypeList, so typeLit and atype are mutually
// recursive.
//
// NOTE(review): every branch draws from smith's random helpers (choice,
// rndBool, rnd, newId, rvalue) in a fixed order. Reordering statements here
// presumably changes which program is generated from the same random stream.
func (smith *Smith) typeLit() *Type {
	switch smith.choice("array", "chan", "struct", "pointer", "interface", "slice", "function", "map") {
	case "array":
		return smith.arrayOf(smith.atype(TraitAny))
	case "chan":
		return smith.chanOf(smith.atype(TraitAny))
	case "struct":
		// Fields are added while coin flips succeed. Each gets a fresh
		// "Field" id and a random type, and is written on its own line
		// inside the braces.
		var elems []*Var
		var buf bytes.Buffer
		fmt.Fprintf(&buf, "struct { ")
		for smith.rndBool() {
			e := &Var{id: smith.newId("Field"), typ: smith.atype(TraitAny)}
			elems = append(elems, e)
			fmt.Fprintf(&buf, "%v %v\n", e.id, e.typ.id)
		}
		fmt.Fprintf(&buf, "}")
		id := buf.String()
		return &Type{
			id:    id,
			class: ClassStruct,
			elems: elems,
			// The zero-value literal is parenthesized, presumably to avoid
			// the composite-literal parsing ambiguity in if/for/switch
			// headers.
			literal: func() string {
				return F("(%v{})", id)
			},
			complexLiteral: func() string {
				if smith.rndBool() {
					// unnamed (positional): one value for every field, in
					// declaration order
					var buf bytes.Buffer
					fmt.Fprintf(&buf, "(%v{", id)
					for i := 0; i < len(elems); i++ {
						fmt.Fprintf(&buf, "%v, ", smith.rvalue(elems[i].typ))
					}
					fmt.Fprintf(&buf, "})")
					return buf.String()
				} else {
					// named (keyed): each field is included on a coin flip
					var buf bytes.Buffer
					fmt.Fprintf(&buf, "(%v{", id)
					for i := 0; i < len(elems); i++ {
						if smith.rndBool() {
							fmt.Fprintf(&buf, "%v: %v, ", elems[i].id, smith.rvalue(elems[i].typ))
						}
					}
					fmt.Fprintf(&buf, "})")
					return buf.String()
				}
			},
		}
	case "pointer":
		return pointerTo(smith.atype(TraitAny))
	case "interface":
		// Methods (fresh "Method" ids with random parameter and result lists)
		// are written only into the type's text. elems is left empty, so
		// other code sees no method information for this type.
		var buf bytes.Buffer
		fmt.Fprintf(&buf, "interface { ")
		for smith.rndBool() {
			fmt.Fprintf(&buf, " %v %v %v\n", smith.newId("Method"),
				fmtTypeList(smith.atypeList(TraitAny), true),
				fmtTypeList(smith.atypeList(TraitAny), false))
		}
		fmt.Fprintf(&buf, "}")
		return &Type{
			id:    buf.String(),
			class: ClassInterface,
			literal: func() string {
				return F("%v(nil)", buf.String())
			},
		}
	case "slice":
		return smith.sliceOf(smith.atype(TraitAny))
	case "function":
		return smith.funcOf(smith.atypeList(TraitAny), smith.atypeList(TraitAny))
	case "map":
		// The key type is drawn with TraitHashable so it is a valid map key
		// according to satisfiesTrait.
		ktyp := smith.atype(TraitHashable)
		vtyp := smith.atype(TraitAny)
		return &Type{
			id:    F("map[%v]%v", ktyp.id, vtyp.id),
			class: ClassMap,
			ktyp:  ktyp,
			vtyp:  vtyp,
			// Either make(...) with an optional size argument or an empty
			// composite literal.
			literal: func() string {
				if smith.rndBool() {
					cap := ""
					if smith.rndBool() {
						cap = "," + smith.rvalue(smith.intType)
					}
					return F("make(map[%v]%v %v)", ktyp.id, vtyp.id, cap)
				} else {
					return F("map[%v]%v{}", ktyp.id, vtyp.id)
				}
			},
		}
	default:
		panic("bad")
	}
}
   230  
   231  func (smith *Smith) satisfiesTrait(t *Type, trait TypeClass) bool {
   232  	if trait < TraitAny {
   233  		return t.class == trait
   234  	}
   235  
   236  	switch trait {
   237  	case TraitAny:
   238  		return true
   239  	case TraitOrdered:
   240  		return t.class == ClassNumeric || t.class == ClassString
   241  	case TraitComparable:
   242  		return t.class == ClassBoolean || t.class == ClassNumeric || t.class == ClassString ||
   243  			t.class == ClassPointer || t.class == ClassChan || t.class == ClassInterface
   244  	case TraitIndexable:
   245  		return t.class == ClassArray || t.class == ClassSlice || t.class == ClassString ||
   246  			t.class == ClassMap
   247  	case TraitReceivable:
   248  		return t.class == ClassChan
   249  	case TraitSendable:
   250  		return t.class == ClassChan
   251  	case TraitHashable:
   252  		if t.class == ClassFunction || t.class == ClassMap || t.class == ClassSlice {
   253  			return false
   254  		}
   255  		if t.class == ClassArray && !smith.satisfiesTrait(t.ktyp, TraitHashable) {
   256  			return false
   257  		}
   258  		if t.class == ClassStruct {
   259  			for _, e := range t.elems {
   260  				if !smith.satisfiesTrait(e.typ, TraitHashable) {
   261  					return false
   262  				}
   263  			}
   264  		}
   265  		return true
   266  	case TraitPrintable:
   267  		return t.class == ClassBoolean || t.class == ClassNumeric || t.class == ClassString ||
   268  			t.class == ClassPointer || t.class == ClassInterface
   269  	case TraitLenCapable:
   270  		return t.class == ClassString || t.class == ClassSlice || t.class == ClassArray ||
   271  			t.class == ClassMap || t.class == ClassChan
   272  	case TraitGlobal:
   273  		for _, t1 := range smith.predefinedTypes {
   274  			if t == t1 {
   275  				return true
   276  			}
   277  		}
   278  		return false
   279  	default:
   280  		panic("bad")
   281  	}
   282  }
   283  
   284  func (smith *Smith) atypeList(trait TypeClass) []*Type {
   285  	n := smith.rnd(4) + 1
   286  	list := make([]*Type, n)
   287  	for i := 0; i < n; i++ {
   288  		list[i] = smith.atype(trait)
   289  	}
   290  	return list
   291  }
   292  
   293  func typeList(t *Type, n int) []*Type {
   294  	list := make([]*Type, n)
   295  	for i := 0; i < n; i++ {
   296  		list[i] = t
   297  	}
   298  	return list
   299  }
   300  
   301  func pointerTo(elem *Type) *Type {
   302  	return &Type{
   303  		id:    F("*%v", elem.id),
   304  		class: ClassPointer,
   305  		ktyp:  elem,
   306  		literal: func() string {
   307  			return F("(*%v)(nil)", elem.id)
   308  		}}
   309  }
   310  
   311  func (smith *Smith) chanOf(elem *Type) *Type {
   312  	return &Type{
   313  		id:    F("chan %v", elem.id),
   314  		class: ClassChan,
   315  		ktyp:  elem,
   316  		literal: func() string {
   317  			cap := ""
   318  			if smith.rndBool() {
   319  				cap = "," + smith.rvalue(smith.intType)
   320  			}
   321  			return F("make(chan %v %v)", elem.id, cap)
   322  		},
   323  	}
   324  }
   325  
   326  func (smith *Smith) sliceOf(elem *Type) *Type {
   327  	return &Type{
   328  		id:    F("[]%v", elem.id),
   329  		class: ClassSlice,
   330  		ktyp:  elem,
   331  		literal: func() string {
   332  			return F("[]%v{}", elem.id)
   333  		},
   334  		complexLiteral: func() string {
   335  			switch smith.choice("normal", "keyed") {
   336  			case "normal":
   337  				return F("[]%v{%v}", elem.id, smith.fmtRvalueList(typeList(elem, smith.rnd(3))))
   338  			case "keyed":
   339  				n := smith.rnd(3)
   340  				var indexes []int
   341  			loop:
   342  				for len(indexes) < n {
   343  					i := smith.rnd(10)
   344  					for _, i1 := range indexes {
   345  						if i1 == i {
   346  							continue loop
   347  						}
   348  					}
   349  					indexes = append(indexes, i)
   350  				}
   351  				var buf bytes.Buffer
   352  				fmt.Fprintf(&buf, "[]%v{", elem.id)
   353  				for i, idx := range indexes {
   354  					if i != 0 {
   355  						fmt.Fprintf(&buf, ",")
   356  					}
   357  					fmt.Fprintf(&buf, "%v: %v", idx, smith.rvalue(elem))
   358  				}
   359  				fmt.Fprintf(&buf, "}")
   360  				return buf.String()
   361  			default:
   362  				panic("bad")
   363  			}
   364  		},
   365  	}
   366  }
   367  
   368  func (smith *Smith) arrayOf(elem *Type) *Type {
   369  	size := smith.rnd(3)
   370  	return &Type{
   371  		id:    F("[%v]%v", size, elem.id),
   372  		class: ClassArray,
   373  		ktyp:  elem,
   374  		literal: func() string {
   375  			return F("[%v]%v{}", size, elem.id)
   376  		},
   377  		complexLiteral: func() string {
   378  			switch smith.choice("normal", "keyed") {
   379  			case "normal":
   380  				return F("[%v]%v{%v}", smith.choice(F("%v", size), "..."), elem.id, smith.fmtRvalueList(typeList(elem, size)))
   381  			case "keyed":
   382  				var buf bytes.Buffer
   383  				fmt.Fprintf(&buf, "[%v]%v{", size, elem.id)
   384  				for i := 0; i < size; i++ {
   385  					if i != 0 {
   386  						fmt.Fprintf(&buf, ",")
   387  					}
   388  					fmt.Fprintf(&buf, "%v: %v", i, smith.rvalue(elem))
   389  				}
   390  				fmt.Fprintf(&buf, "}")
   391  				return buf.String()
   392  			default:
   393  				panic("bad")
   394  			}
   395  		},
   396  	}
   397  }
   398  
   399  func (smith *Smith) funcOf(alist, rlist []*Type) *Type {
   400  	t := &Type{
   401  		id:    F("func%v %v", fmtTypeList(alist, true), fmtTypeList(rlist, false)),
   402  		class: ClassFunction,
   403  		styp:  alist,
   404  		rtyp:  rlist,
   405  	}
   406  	t.literal = func() string {
   407  		return F("((func%v %v)(nil))", fmtTypeList(alist, true), fmtTypeList(rlist, false))
   408  	}
   409  	t.complexLiteral = func() string {
   410  		return smith.genFuncLit(t)
   411  	}
   412  	return t
   413  }
   414  
// genFuncLit generates a complete function literal of function type ft and
// returns its source text. funcOf uses it as the complexLiteral.
//
// The literal's body is produced by the regular statement generators, so the
// generator's position state (curFunc, curBlock, curBlockPos, exprDepth,
// exprCount) is swapped out while it runs and restored by the deferred
// function. Output lines are collected under a fresh Block fb and rendered
// with serializeBlock. The final byte of that rendering is dropped, which is
// presumably a trailing newline (confirm against serializeBlock).
func (smith *Smith) genFuncLit(ft *Type) string {
	//return F("((func%v %v)(nil))", fmtTypeList(ft.styp, true), fmtTypeList(ft.rtyp, false))

	// NOTE(review): appears to ensure that curBlock.sub[curBlockPos] below
	// refers to an existing entry when no position is set yet; confirm
	// against line().
	if smith.curBlockPos == -1 {
		smith.line("")
	}

	// Generate the literal as if it were its own function with ft's
	// parameter and result types, saving the state that will be restored.
	f := &Func{args: ft.styp, rets: ft.rtyp}
	curFunc0 := smith.curFunc
	smith.curFunc = f
	curBlock0 := smith.curBlock
	curBlockPos0 := smith.curBlockPos
	curBlockLen0 := len(smith.curBlock.sub)
	exprDepth0 := smith.exprDepth
	exprCount0 := smith.exprCount
	// Expression depth and count restart from zero inside the literal.
	smith.exprDepth = 0
	smith.exprCount = 0
	defer func() {
		smith.curBlock = curBlock0
		smith.curFunc = curFunc0
		smith.exprDepth = exprDepth0
		smith.exprCount = exprCount0
		// Entries may have been added to the enclosing block while the
		// literal was generated; shift the saved position by that count.
		smith.curBlockPos = curBlockPos0 + (len(smith.curBlock.sub) - curBlockLen0)
	}()

	// fb is a detached block whose parent is the current block. Its
	// serialization below becomes the returned literal text.
	fb := &Block{parent: smith.curBlock, subBlock: smith.curBlock.sub[smith.curBlockPos]}
	smith.curBlock = fb
	smith.enterBlock(true)
	smith.enterBlock(true)
	// Parameters get fresh "Param" ids and are rendered as "id type, ...".
	argIds := make([]string, len(f.args))
	argStr := ""
	for i, a := range f.args {
		argIds[i] = smith.newId("Param")
		if i != 0 {
			argStr += ", "
		}
		argStr += argIds[i] + " " + a.id
	}
	smith.line("func(%v)%v {", argStr, fmtTypeList(f.rets, false))
	// Make the parameters visible as variables to the body generator.
	for i, a := range f.args {
		smith.defineVar(argIds[i], a)
	}
	smith.curBlock.funcBoundary = true
	smith.genBlock()
	smith.leaveBlock()
	// A return statement is emitted after the generated body and before the
	// closing brace.
	smith.stmtReturn()
	smith.line("}")
	smith.leaveBlock()

	//b := curBlock.sub[curBlockPos]
	//copy(curBlock.sub[curBlockPos:], curBlock.sub[curBlockPos+1:])
	//curBlock.sub = curBlock.sub[:len(curBlock.sub)-1]

	var buf bytes.Buffer
	w := bufio.NewWriter(&buf)
	serializeBlock(w, fb, 0)
	w.Flush()
	s := buf.String()
	//fmt.Printf("GEN FUNC:\n%v\n", s)
	return s[:len(s)-1]
}
   476  
   477  func dependsOn(t, t0 *Type) bool {
   478  	if t == nil {
   479  		return false
   480  	}
   481  	if t.class == ClassInterface {
   482  		// We don't know how to walk all types referenced by an interface yet.
   483  		return true
   484  	}
   485  	if t0 == nil && t.namedUserType {
   486  		return true
   487  	}
   488  	if t == t0 {
   489  		return true
   490  	}
   491  	if dependsOn(t.ktyp, t0) {
   492  		return true
   493  	}
   494  	if dependsOn(t.vtyp, t0) {
   495  		return true
   496  	}
   497  	if dependsOn(t.ktyp, t0) {
   498  		return true
   499  	}
   500  	for _, t1 := range t.styp {
   501  		if dependsOn(t1, t0) {
   502  			return true
   503  		}
   504  	}
   505  	for _, t1 := range t.rtyp {
   506  		if dependsOn(t1, t0) {
   507  			return true
   508  		}
   509  	}
   510  	for _, e := range t.elems {
   511  		if dependsOn(e.typ, t0) {
   512  			return true
   513  		}
   514  	}
   515  	return false
   516  }