github.com/hamba/slices@v0.2.1-0.20220316050741-75c057d92699/internal/gen/gen.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"go/format"
     6  	"log"
     7  	"os"
     8  	"text/template"
     9  )
    10  
    11  var tmpl = `package slices
    12  
    13  // Code generated by 'gen.go'. DO NOT EDIT.
    14  
    15  import "unsafe"
    16  
    17  // Contains
    18  {{- range $i, $type := .contains}}
    19  func {{ $type }}Contains(sptr, vptr unsafe.Pointer) bool {
    20  	v := *(*{{ $type }})(vptr)
    21  	for _, vv := range *(*[]{{ $type }})(sptr) {
    22  		if vv == v {
    23  			return true
    24  		}
    25  	}
    26  	return false
    27  }
    28  {{end -}}
    29  
    30  // LesserOf
    31  {{- range $i, $type := .lesserOf}}
    32  func {{ $type }}Lesser(ptr unsafe.Pointer) func(i, j int) bool {
    33  	return func(i, j int) bool {
    34  		v := *(*[]{{ $type }})(ptr)
    35  		return v[i] < v[j]
    36  	}
    37  }
    38  {{end -}}
    39  
    40  // GreaterOf
    41  {{- range $i, $type := .greaterOf}}
    42  func {{ $type }}Greater(ptr unsafe.Pointer) func(i, j int) bool {
    43  	return func(i, j int) bool {
    44  		v := *(*[]{{ $type }})(ptr)
    45  		return v[i] > v[j]
    46  	}
    47  }
    48  {{end -}}
    49  
    50  // Intersect
    51  {{- range $i, $type := .intersect}}
    52  func {{ $type }}Intersect(sptr, optr unsafe.Pointer) interface{} {
    53  	slice := make([]{{ $type }}, len(*(*[]{{ $type }})(sptr)))
    54  	copy(slice, *(*[]{{ $type }})(sptr))
    55  	for i := 0; i < len(slice); i++ {
    56  		found := false
    57  		for _, v := range *(*[]{{ $type }})(optr) {
    58  			if v == slice[i] {
    59  				found = true
    60  				break
    61  			}
    62  		}
    63  		if !found {
    64  			slice = append(slice[:i], slice[i+1:]...)
    65  			i--
    66  		}
    67  	}
    68  	return slice
    69  }
    70  {{end -}}
    71  
    72  // Except
    73  {{- range $i, $type := .except}}
    74  func {{ $type }}Except(sptr, optr unsafe.Pointer) interface{} {
    75  	s := make([]{{ $type }}, len(*(*[]{{ $type }})(sptr)))
    76  	copy(s, *(*[]{{ $type }})(sptr))
    77  	for i := 0; i < len(s); i++ {
    78  		for _, v := range  *(*[]{{ $type }})(optr) {
    79  			if v == s[i] {
    80  				s = append(s[:i], s[i+1:]...)
    81  				i--
    82  				break
    83  			}
    84  		}
    85  	}
    86  	return s
    87  }
    88  {{end -}}
    89  
    90  `
    91  
    92  func main() {
    93  	parse, err := template.New("gen").Parse(tmpl)
    94  	if err != nil {
    95  		log.Fatal(err)
    96  	}
    97  
    98  	ops := map[string][]string{
    99  		"contains":  {"bool", "string", "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64"},
   100  		"lesserOf":  {"string", "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64"},
   101  		"greaterOf": {"string", "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64"},
   102  		"intersect": {"string", "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64"},
   103  		"except":    {"string", "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64"},
   104  	}
   105  
   106  	b := &bytes.Buffer{}
   107  	if err = parse.Execute(b, ops); err != nil {
   108  		log.Fatal(err)
   109  	}
   110  
   111  	source, err := format.Source(b.Bytes())
   112  	if err != nil {
   113  		log.Fatal(err)
   114  	}
   115  
   116  	if err = os.WriteFile("ops.gen.go", source, 0600); err != nil {
   117  		log.Fatal(err)
   118  	}
   119  }