github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/vector/generate/naive.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/format" 7 "strings" 8 ) 9 10 type Naive struct { 11 Config 12 13 Files []*NaiveFile 14 } 15 16 func NewNaive(config Config) *Naive { 17 return &Naive{Config: config} 18 } 19 20 type NaiveFile struct { 21 Config 22 23 Path string 24 Content bytes.Buffer 25 } 26 27 func (ctx *Naive) In(path string) File { 28 for _, file := range ctx.Files { 29 if file.Path == path { 30 return file 31 } 32 } 33 34 file := &NaiveFile{Path: path, Config: ctx.Config} 35 ctx.Files = append(ctx.Files, file) 36 37 file.emitHeader() 38 39 return file 40 } 41 42 func (ctx *NaiveFile) Formatted() ([]byte, error) { 43 formatted, err := format.Source(ctx.Content.Bytes()) 44 if err != nil { 45 return ctx.Content.Bytes(), err 46 } 47 return formatted, nil 48 } 49 50 func (ctx *NaiveFile) emitHeader() { 51 ctx.Printf("package %s\n", ctx.Config.Package) 52 } 53 54 func (ctx *NaiveFile) Func(signature string, template Template) { 55 switch t := template.(type) { 56 case Iterate: 57 ctx.Iterate(signature, t) 58 default: 59 panic(fmt.Sprintf("unhandled %T", t)) 60 } 61 } 62 63 func (ctx *NaiveFile) Iterate(signature string, body Iterate) { 64 pf := ctx.Printf 65 66 pf("\n") 67 pf("func %s {\n", ctx.specializeSignature(signature)) 68 defer pf("}\n") 69 70 // determine the primary iterator 71 itCandidates := []It{} 72 for _, it := range body.Ranges { 73 if it.Count.Expr != "" { 74 itCandidates = append(itCandidates, it) 75 } 76 } 77 78 // maybe generate a primary iterator based on the slices we have 79 firstRangeIsPrimary := false 80 if len(itCandidates) == 0 { 81 firstRangeIsPrimary = true 82 first := body.Ranges[0] 83 84 ensure(first.Inc.Const == 1 && first.Inc.Expr == "") 85 ensure(first.Start.Const == 0 && first.Start.Expr == "") 86 87 // generate a range based on the first item in the slice 88 pf("n := len(%s)\n", first.Name) 89 r := Range("i", 0, "n") 90 body.Ranges = append(body.Ranges, r) 91 itCandidates = []It{r} 92 } 93 ensure(len(itCandidates) == 1) 94 95 // generate boundary checks for the iteration 96 prim := itCandidates[0] 97 98 for i, it := range body.Ranges { 99 if i == 0 && firstRangeIsPrimary { 100 // skip the one we used to calculate `n` 101 continue 102 } 103 if !it.Count.Derived { 104 continue 105 } 106 107 size := Var("len(" + it.Name + ")") 108 // TODO: handle multiplication overflow 109 pf("if %s < int(%s) { panic(\"%s is too small\") }\n", sub(size, it.Start), mul(prim.Count, it.Inc), it.Name) 110 } 111 112 // generate iterators if necessary 113 if ctx.Pointer { 114 for i, it := range body.Ranges { 115 if i == 0 && firstRangeIsPrimary || it == prim { 116 continue 117 } 118 119 pf("p%s := unsafe.Pointer(&%s[%s])\n", it.Name, it.Name, it.Start) 120 } 121 } else if ctx.Counter { 122 for i, it := range body.Ranges { 123 if i == 0 && firstRangeIsPrimary || it == prim { 124 continue 125 } 126 127 pf("i%s := %s\n", it.Name, it.Start) 128 } 129 } 130 131 if ctx.Unroll <= 1 { 132 // TODO: simplify increment 133 pf("for %s := %s ; %s < %s; %s {\n", prim.Name, prim.Start, prim.Name, prim.Count, increment(prim.Name, prim.Inc)) 134 pf(" %s\n", ctx.specializeBody(body.For, prim, 0, body.Ranges)) 135 pf(" %s\n", ctx.advanceIterators(firstRangeIsPrimary, prim, body.Ranges, 1)) 136 pf("}\n") 137 } else { 138 pf("%s := %s\n", prim.Name, prim.Start) 139 pf("%s_unroll := %s - %s %% %v\n", prim.Count, prim.Count, prim.Count, ctx.Unroll) 140 141 pf("for ; %s < %s_unroll; %s {\n", prim.Name, prim.Count, increment(prim.Name, mul(Const(ctx.Unroll), prim.Inc))) 142 for i := 0; i < ctx.Unroll; i++ { 143 pf(" %s\n", ctx.specializeBody(body.For, prim, i, body.Ranges)) 144 } 145 pf(" %s\n", ctx.advanceIterators(firstRangeIsPrimary, prim, body.Ranges, ctx.Unroll)) 146 pf("}\n") 147 148 pf("for ; %s < %s; %s {\n", prim.Name, prim.Count, increment(prim.Name, prim.Inc)) 149 pf(" %s\n", ctx.specializeBody(body.For, prim, 0, body.Ranges)) 150 pf(" %s\n", ctx.advanceIterators(firstRangeIsPrimary, prim, body.Ranges, 1)) 151 pf("}\n") 152 } 153 } 154 155 func (ctx *NaiveFile) specializeSignature(signature string) string { 156 return strings.ReplaceAll(signature, "$Type", ctx.Config.Type.Name) 157 } 158 159 func (ctx *NaiveFile) specializeBody(body string, prim It, primOffset int, ranges []It) string { 160 if ctx.Pointer { 161 return ctx.specializePointerAccess(body, prim, primOffset, ranges) 162 } else if ctx.Counter { 163 return ctx.specializeCounterAccess(body, prim, primOffset, ranges) 164 } else { 165 return ctx.specializeDirectAccess(body, prim, primOffset, ranges) 166 } 167 } 168 169 func (ctx *NaiveFile) advanceIterators(firstRangeIsPrimary bool, prim It, ranges []It, count int) (code string) { 170 if ctx.Pointer { 171 for _, it := range ranges { 172 if it == prim { 173 continue 174 } 175 code += fmt.Sprintf("p%s = unsafe.Add(%s,%s)\n", 176 it.Name, 177 it.Name, 178 mul(mul(it.Inc, Const(count)), Const(ctx.Type.Size)), 179 ) 180 } 181 } else if ctx.Counter { 182 for i, it := range ranges { 183 if i == 0 && firstRangeIsPrimary || it == prim { 184 continue 185 } 186 187 code += increment("i"+it.Name, mul(it.Inc, Const(count))) + "\n" 188 } 189 } 190 return strings.TrimSpace(code) 191 } 192 193 func (ctx *NaiveFile) specializeDirectAccess(body string, prim It, primOffset int, ranges []It) string { 194 return rxVariable.ReplaceAllStringFunc(body, func(ref string) string { 195 ref = ref[1:] 196 197 for _, it := range ranges { 198 if it.Name == ref { 199 return fmt.Sprintf("%s[%s]", ref, 200 add(it.Start, 201 mul( 202 add(Var(prim.Name), Const(primOffset)), 203 it.Inc, 204 ), 205 ), 206 ) 207 } 208 } 209 210 panic("did not find " + ref) 211 }) 212 } 213 214 func (ctx *NaiveFile) specializeCounterAccess(body string, prim It, primOffset int, ranges []It) string { 215 return rxVariable.ReplaceAllStringFunc(body, func(ref string) string { 216 ref = ref[1:] 217 218 for _, it := range ranges { 219 if it.Name == ref { 220 at := it.Name + "[i" + it.Name 221 if primOffset > 0 { 222 at += "+" + mul(Const(primOffset), it.Inc).String() 223 } 224 at += "]" 225 return at 226 } 227 } 228 229 panic("did not find " + ref) 230 }) 231 } 232 233 func (ctx *NaiveFile) specializePointerAccess(body string, prim It, primOffset int, ranges []It) string { 234 return rxVariable.ReplaceAllStringFunc(body, func(ref string) string { 235 ref = ref[1:] 236 237 for _, it := range ranges { 238 if it.Name == ref { 239 240 if primOffset > 0 { 241 return fmt.Sprintf("*(*%s)(unsafe.Add(p%s, %s))", 242 ctx.Type.Name, ref, 243 mul(mul(Const(primOffset), it.Inc), Const(ctx.Type.Size))) 244 } else { 245 return fmt.Sprintf("*(*%s)(p%s)", ctx.Type.Name, ref) 246 } 247 } 248 } 249 250 panic("did not find " + ref) 251 }) 252 } 253 254 func (ctx *NaiveFile) Printf(format string, args ...any) { 255 fmt.Fprintf(&ctx.Content, format, args...) 256 } 257 258 func ensure(v bool) { 259 if !v { 260 panic("unexpected") 261 } 262 }