github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/vector/generate/naive.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/format"
     7  	"strings"
     8  )
     9  
    10  type Naive struct {
    11  	Config
    12  
    13  	Files []*NaiveFile
    14  }
    15  
    16  func NewNaive(config Config) *Naive {
    17  	return &Naive{Config: config}
    18  }
    19  
    20  type NaiveFile struct {
    21  	Config
    22  
    23  	Path    string
    24  	Content bytes.Buffer
    25  }
    26  
    27  func (ctx *Naive) In(path string) File {
    28  	for _, file := range ctx.Files {
    29  		if file.Path == path {
    30  			return file
    31  		}
    32  	}
    33  
    34  	file := &NaiveFile{Path: path, Config: ctx.Config}
    35  	ctx.Files = append(ctx.Files, file)
    36  
    37  	file.emitHeader()
    38  
    39  	return file
    40  }
    41  
    42  func (ctx *NaiveFile) Formatted() ([]byte, error) {
    43  	formatted, err := format.Source(ctx.Content.Bytes())
    44  	if err != nil {
    45  		return ctx.Content.Bytes(), err
    46  	}
    47  	return formatted, nil
    48  }
    49  
    50  func (ctx *NaiveFile) emitHeader() {
    51  	ctx.Printf("package %s\n", ctx.Config.Package)
    52  }
    53  
    54  func (ctx *NaiveFile) Func(signature string, template Template) {
    55  	switch t := template.(type) {
    56  	case Iterate:
    57  		ctx.Iterate(signature, t)
    58  	default:
    59  		panic(fmt.Sprintf("unhandled %T", t))
    60  	}
    61  }
    62  
    63  func (ctx *NaiveFile) Iterate(signature string, body Iterate) {
    64  	pf := ctx.Printf
    65  
    66  	pf("\n")
    67  	pf("func %s {\n", ctx.specializeSignature(signature))
    68  	defer pf("}\n")
    69  
    70  	// determine the primary iterator
    71  	itCandidates := []It{}
    72  	for _, it := range body.Ranges {
    73  		if it.Count.Expr != "" {
    74  			itCandidates = append(itCandidates, it)
    75  		}
    76  	}
    77  
    78  	// maybe generate a primary iterator based on the slices we have
    79  	firstRangeIsPrimary := false
    80  	if len(itCandidates) == 0 {
    81  		firstRangeIsPrimary = true
    82  		first := body.Ranges[0]
    83  
    84  		ensure(first.Inc.Const == 1 && first.Inc.Expr == "")
    85  		ensure(first.Start.Const == 0 && first.Start.Expr == "")
    86  
    87  		// generate a range based on the first item in the slice
    88  		pf("n := len(%s)\n", first.Name)
    89  		r := Range("i", 0, "n")
    90  		body.Ranges = append(body.Ranges, r)
    91  		itCandidates = []It{r}
    92  	}
    93  	ensure(len(itCandidates) == 1)
    94  
    95  	// generate boundary checks for the iteration
    96  	prim := itCandidates[0]
    97  
    98  	for i, it := range body.Ranges {
    99  		if i == 0 && firstRangeIsPrimary {
   100  			// skip the one we used to calculate `n`
   101  			continue
   102  		}
   103  		if !it.Count.Derived {
   104  			continue
   105  		}
   106  
   107  		size := Var("len(" + it.Name + ")")
   108  		// TODO: handle multiplication overflow
   109  		pf("if %s < int(%s) { panic(\"%s is too small\") }\n", sub(size, it.Start), mul(prim.Count, it.Inc), it.Name)
   110  	}
   111  
   112  	// generate iterators if necessary
   113  	if ctx.Pointer {
   114  		for i, it := range body.Ranges {
   115  			if i == 0 && firstRangeIsPrimary || it == prim {
   116  				continue
   117  			}
   118  
   119  			pf("p%s := unsafe.Pointer(&%s[%s])\n", it.Name, it.Name, it.Start)
   120  		}
   121  	} else if ctx.Counter {
   122  		for i, it := range body.Ranges {
   123  			if i == 0 && firstRangeIsPrimary || it == prim {
   124  				continue
   125  			}
   126  
   127  			pf("i%s := %s\n", it.Name, it.Start)
   128  		}
   129  	}
   130  
   131  	if ctx.Unroll <= 1 {
   132  		// TODO: simplify increment
   133  		pf("for %s := %s ; %s < %s; %s {\n", prim.Name, prim.Start, prim.Name, prim.Count, increment(prim.Name, prim.Inc))
   134  		pf("	%s\n", ctx.specializeBody(body.For, prim, 0, body.Ranges))
   135  		pf("	%s\n", ctx.advanceIterators(firstRangeIsPrimary, prim, body.Ranges, 1))
   136  		pf("}\n")
   137  	} else {
   138  		pf("%s := %s\n", prim.Name, prim.Start)
   139  		pf("%s_unroll := %s - %s %% %v\n", prim.Count, prim.Count, prim.Count, ctx.Unroll)
   140  
   141  		pf("for ; %s < %s_unroll; %s {\n", prim.Name, prim.Count, increment(prim.Name, mul(Const(ctx.Unroll), prim.Inc)))
   142  		for i := 0; i < ctx.Unroll; i++ {
   143  			pf("	%s\n", ctx.specializeBody(body.For, prim, i, body.Ranges))
   144  		}
   145  		pf("	%s\n", ctx.advanceIterators(firstRangeIsPrimary, prim, body.Ranges, ctx.Unroll))
   146  		pf("}\n")
   147  
   148  		pf("for ; %s < %s; %s {\n", prim.Name, prim.Count, increment(prim.Name, prim.Inc))
   149  		pf("	%s\n", ctx.specializeBody(body.For, prim, 0, body.Ranges))
   150  		pf("	%s\n", ctx.advanceIterators(firstRangeIsPrimary, prim, body.Ranges, 1))
   151  		pf("}\n")
   152  	}
   153  }
   154  
   155  func (ctx *NaiveFile) specializeSignature(signature string) string {
   156  	return strings.ReplaceAll(signature, "$Type", ctx.Config.Type.Name)
   157  }
   158  
   159  func (ctx *NaiveFile) specializeBody(body string, prim It, primOffset int, ranges []It) string {
   160  	if ctx.Pointer {
   161  		return ctx.specializePointerAccess(body, prim, primOffset, ranges)
   162  	} else if ctx.Counter {
   163  		return ctx.specializeCounterAccess(body, prim, primOffset, ranges)
   164  	} else {
   165  		return ctx.specializeDirectAccess(body, prim, primOffset, ranges)
   166  	}
   167  }
   168  
   169  func (ctx *NaiveFile) advanceIterators(firstRangeIsPrimary bool, prim It, ranges []It, count int) (code string) {
   170  	if ctx.Pointer {
   171  		for _, it := range ranges {
   172  			if it == prim {
   173  				continue
   174  			}
   175  			code += fmt.Sprintf("p%s = unsafe.Add(%s,%s)\n",
   176  				it.Name,
   177  				it.Name,
   178  				mul(mul(it.Inc, Const(count)), Const(ctx.Type.Size)),
   179  			)
   180  		}
   181  	} else if ctx.Counter {
   182  		for i, it := range ranges {
   183  			if i == 0 && firstRangeIsPrimary || it == prim {
   184  				continue
   185  			}
   186  
   187  			code += increment("i"+it.Name, mul(it.Inc, Const(count))) + "\n"
   188  		}
   189  	}
   190  	return strings.TrimSpace(code)
   191  }
   192  
   193  func (ctx *NaiveFile) specializeDirectAccess(body string, prim It, primOffset int, ranges []It) string {
   194  	return rxVariable.ReplaceAllStringFunc(body, func(ref string) string {
   195  		ref = ref[1:]
   196  
   197  		for _, it := range ranges {
   198  			if it.Name == ref {
   199  				return fmt.Sprintf("%s[%s]", ref,
   200  					add(it.Start,
   201  						mul(
   202  							add(Var(prim.Name), Const(primOffset)),
   203  							it.Inc,
   204  						),
   205  					),
   206  				)
   207  			}
   208  		}
   209  
   210  		panic("did not find " + ref)
   211  	})
   212  }
   213  
   214  func (ctx *NaiveFile) specializeCounterAccess(body string, prim It, primOffset int, ranges []It) string {
   215  	return rxVariable.ReplaceAllStringFunc(body, func(ref string) string {
   216  		ref = ref[1:]
   217  
   218  		for _, it := range ranges {
   219  			if it.Name == ref {
   220  				at := it.Name + "[i" + it.Name
   221  				if primOffset > 0 {
   222  					at += "+" + mul(Const(primOffset), it.Inc).String()
   223  				}
   224  				at += "]"
   225  				return at
   226  			}
   227  		}
   228  
   229  		panic("did not find " + ref)
   230  	})
   231  }
   232  
   233  func (ctx *NaiveFile) specializePointerAccess(body string, prim It, primOffset int, ranges []It) string {
   234  	return rxVariable.ReplaceAllStringFunc(body, func(ref string) string {
   235  		ref = ref[1:]
   236  
   237  		for _, it := range ranges {
   238  			if it.Name == ref {
   239  
   240  				if primOffset > 0 {
   241  					return fmt.Sprintf("*(*%s)(unsafe.Add(p%s, %s))",
   242  						ctx.Type.Name, ref,
   243  						mul(mul(Const(primOffset), it.Inc), Const(ctx.Type.Size)))
   244  				} else {
   245  					return fmt.Sprintf("*(*%s)(p%s)", ctx.Type.Name, ref)
   246  				}
   247  			}
   248  		}
   249  
   250  		panic("did not find " + ref)
   251  	})
   252  }
   253  
   254  func (ctx *NaiveFile) Printf(format string, args ...any) {
   255  	fmt.Fprintf(&ctx.Content, format, args...)
   256  }
   257  
   258  func ensure(v bool) {
   259  	if !v {
   260  		panic("unexpected")
   261  	}
   262  }