github.com/apache/beam/sdks/v2@v2.48.2/go/cmd/specialize/main.go (about) 1 // Licensed to the Apache Software Foundation (ASF) under one or more 2 // contributor license agreements. See the NOTICE file distributed with 3 // this work for additional information regarding copyright ownership. 4 // The ASF licenses this file to You under the Apache License, Version 2.0 5 // (the "License"); you may not use this file except in compliance with 6 // the License. You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 // specialize is a low-level tool to generate type-specialized code. It is a 17 // convenience wrapper over text/template suitable for go generate. Unlike 18 // many other template tools, it does not parse Go code and allows use of 19 // text/template control within the template itself. 20 package main 21 22 import ( 23 "bytes" 24 "flag" 25 "fmt" 26 "log" 27 "math" 28 "os" 29 "path/filepath" 30 "strings" 31 "text/template" 32 33 "golang.org/x/text/cases" 34 "golang.org/x/text/language" 35 ) 36 37 var ( 38 noheader = flag.Bool("noheader", false, "Omit auto-generated header") 39 pack = flag.String("package", "", "Package name (optional)") 40 imports = flag.String("imports", "", "Comma-separated list of extra imports (optional)") 41 42 x = flag.String("x", "", "Comma-separated list of X types (optional)") 43 y = flag.String("y", "", "Comma-separated list of Y types (optional)") 44 z = flag.String("z", "", "Comma-separated list of Z types (optional)") 45 46 input = flag.String("input", "", "Template file.") 47 output = flag.String("output", "", "Filename for generated code. 
If not provided, a file next to the input is generated.") 48 ) 49 50 // Top is the top-level struct to be passed to the template. 51 type Top struct { 52 // Name is the base form of the filename: "foo/bar.tmpl" -> "bar". 53 Name string 54 // Package is the package name. 55 Package string 56 // Imports is a list of custom imports, if provided. 57 Imports []string 58 // X is the list of X type values. 59 X []*X 60 } 61 62 // X is the concrete type to be iterated over in the user template. 63 type X struct { 64 // Name is the name of X for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice". 65 Name string 66 // Type is the textual type of X: "int", "float32", "foo.Baz". 67 Type string 68 // Y is the list of Y type values for this X. 69 Y []*Y 70 } 71 72 // Y is the concrete type to be iterated over in the user template for each X. 73 // Each combination of X and Y will be present. 74 type Y struct { 75 // Name is the name of Y for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice". 76 Name string 77 // Type is the textual type of Y: "int", "float32", "foo.Baz". 78 Type string 79 // Z is the list of Z type values for this Y. 80 Z []*Z 81 } 82 83 // Z is the concrete type to be iterated over in the user template for each Y. 84 // Each combination of X, Y and Z will be present. 85 type Z struct { 86 // Name is the name of Z for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice". 87 Name string 88 // Type is the textual type of Z: "int", "float32", "foo.Baz". 89 Type string 90 } 91 92 var ( 93 integers = []string{"int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64"} 94 floats = []string{"float32", "float64"} 95 primitives = append(append([]string{"bool", "string"}, integers...), floats...) 
96 97 macros = map[string][]string{ 98 "integers": integers, 99 "floats": floats, 100 "primitives": primitives, 101 "data": append([]string{"[]byte"}, primitives...), 102 "universals": {"typex.T", "typex.U", "typex.V", "typex.W", "typex.X", "typex.Y", "typex.Z"}, 103 } 104 105 packageMacros = map[string][]string{ 106 "typex": {"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"}, 107 } 108 ) 109 110 func usage() { 111 fmt.Fprintf(os.Stderr, "Usage: %v [options] --input=<filename.tmpl --x=<types>\n", filepath.Base(os.Args[0])) 112 flag.PrintDefaults() 113 } 114 115 func main() { 116 flag.Usage = usage 117 flag.Parse() 118 119 log.SetFlags(0) 120 log.SetPrefix("specialize: ") 121 122 if *input == "" { 123 flag.Usage() 124 log.Fatalf("no template file") 125 } 126 127 name := filepath.Base(*input) 128 if index := strings.Index(name, "."); index > 0 { 129 name = name[:index] 130 } 131 if *output == "" { 132 *output = filepath.Join(filepath.Dir(*input), name+".go") 133 } 134 135 top := Top{Name: name, Package: *pack, Imports: expand(packageMacros, *imports)} 136 var ys []*Y 137 if *y != "" { 138 var zs []*Z 139 if *z != "" { 140 for _, zt := range expand(macros, *z) { 141 zs = append(zs, &Z{Name: makeName(zt), Type: zt}) 142 } 143 } 144 for _, yt := range expand(macros, *y) { 145 ys = append(ys, &Y{Name: makeName(yt), Type: yt, Z: zs}) 146 } 147 } 148 for _, xt := range expand(macros, *x) { 149 top.X = append(top.X, &X{Name: makeName(xt), Type: xt, Y: ys}) 150 } 151 152 tmpl, err := template.New(*input).Funcs(funcMap).ParseFiles(*input) 153 if err != nil { 154 log.Fatalf("template parse failed: %v", err) 155 } 156 var buf bytes.Buffer 157 if !*noheader { 158 buf.WriteString("// File generated by specialize. 
Do not edit.\n\n") 159 } 160 if err := tmpl.Funcs(funcMap).Execute(&buf, top); err != nil { 161 log.Fatalf("specialization failed: %v", err) 162 } 163 if err := os.WriteFile(*output, buf.Bytes(), 0644); err != nil { 164 log.Fatalf("write failed: %v", err) 165 } 166 } 167 168 // expand parses, cleans up and expands macros for a comma-separated list. 169 func expand(subst map[string][]string, list string) []string { 170 var ret []string 171 for _, xt := range strings.Split(list, ",") { 172 xt = strings.TrimSpace(xt) 173 if xt == "" { 174 continue 175 } 176 if exp, ok := subst[strings.ToLower(xt)]; ok { 177 for _, t := range exp { 178 ret = append(ret, t) 179 } 180 continue 181 } 182 ret = append(ret, xt) 183 } 184 return ret 185 } 186 187 // makeName creates a capitalized identifier from a type. 188 func makeName(t string) string { 189 if strings.HasPrefix(t, "[]") { 190 return makeName(t[2:] + "Slice") 191 } 192 193 t = strings.Replace(t, ".", "_", -1) 194 t = strings.Replace(t, "[", "_", -1) 195 t = strings.Replace(t, "]", "_", -1) 196 return cases.Title(language.Und, cases.NoLower).String(t) 197 } 198 199 // Useful template functions 200 201 var funcMap template.FuncMap = map[string]any{ 202 "join": strings.Join, 203 "upto": upto, 204 "mkargs": mkargs, 205 "mktuple": mktuple, 206 "mktuplef": mktuplef, 207 "add": add, 208 "mult": mult, 209 "dict": dict, 210 "list": list, 211 "genericTypingRepresentation": genericTypingRepresentation, 212 "possibleBundleLifecycleParameterCombos": possibleBundleLifecycleParameterCombos, 213 } 214 215 // mkargs(n, type) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)> type". 216 // If n is 0, it returns the empty string. 217 func mkargs(n int, format, typ string) string { 218 if n == 0 { 219 return "" 220 } 221 return fmt.Sprintf("%v %v", mktuplef(n, format), typ) 222 } 223 224 // mktuple(n, v) returns "v, v, ..., v". 
func mktuple(n int, v string) string {
	// A non-positive n produces no elements, hence the empty string.
	if n <= 0 {
		return ""
	}
	ret := make([]string, n)
	for i := range ret {
		ret[i] = v
	}
	return strings.Join(ret, ", ")
}

// mktuplef(n, format) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)>".
// If n is not positive, it returns the empty string.
func mktuplef(n int, format string) string {
	if n <= 0 {
		return ""
	}
	ret := make([]string, n)
	for i := range ret {
		ret[i] = fmt.Sprintf(format, i)
	}
	return strings.Join(ret, ", ")
}

// upto(n) returns []int{0, 1, .., n-1}. If n is not positive, it returns nil.
func upto(i int) []int {
	if i <= 0 {
		return nil
	}
	ret := make([]int, i)
	for k := range ret {
		ret[k] = k
	}
	return ret
}

// add returns i + j.
func add(i int, j int) int {
	return i + j
}

// mult returns i * j.
func mult(i int, j int) int {
	return i * j
}

// dict builds a map from alternating key/value arguments:
// dict("a", 1, "b", 2) returns map[string]any{"a": 1, "b": 2}.
// It panics if the number of arguments is odd or if a key is not a string.
func dict(values ...any) map[string]any {
	// Validate before allocating so a malformed call does no work.
	if len(values)%2 != 0 {
		panic("Invalid dictionary call")
	}
	ret := make(map[string]any, len(values)/2)
	for i := 0; i < len(values); i += 2 {
		key, ok := values[i].(string)
		if !ok {
			// Report which argument is wrong instead of surfacing a bare
			// runtime type-assertion error.
			panic(fmt.Sprintf("Invalid dictionary call: key at argument %d has type %T, want string", i, values[i]))
		}
		ret[key] = values[i+1]
	}

	return ret
}

// list returns its arguments as a slice.
func list(values ...string) []string {
	return values
}

// genericTypingRepresentation returns a bracketed type parameter list made of
// in input names "I0".."I<in-1>" followed by out return names "R0".."R<out-1>",
// for example "[I0, I1, R0]". If includeType is set, " any" is placed before
// the closing bracket: "[I0, I1, R0 any]". It returns the empty string when
// there are no names at all.
func genericTypingRepresentation(in int, out int, includeType bool) string {
	var names []string
	for i := 0; i < in; i++ {
		names = append(names, fmt.Sprintf("I%v", i))
	}
	for i := 0; i < out; i++ {
		names = append(names, fmt.Sprintf("R%v", i))
	}
	if len(names) == 0 {
		return ""
	}

	typing := "[" + strings.Join(names, ", ")
	if includeType {
		typing += " any"
	}
	return typing + "]"
}

// possibleBundleLifecycleParameterCombos returns the candidate parameter lists
// of length numIn for a bundle lifecycle function. Each list is a subset of the
// known parameter options, in their fixed order, followed by the trailing
// ProcessElement inputs "I<k>".."I<processElementIn-1>" needed to reach numIn
// entries. Subsets that are longer than numIn, or that would need more
// ProcessElement inputs than processElementIn, are skipped.
//
// Both arguments are declared as any but must hold an int; the function panics
// with a descriptive message otherwise.
func possibleBundleLifecycleParameterCombos(numInInterface any, processElementInInterface any) [][]string {
	numIn, ok := numInInterface.(int)
	if !ok {
		panic(fmt.Sprintf("possibleBundleLifecycleParameterCombos: numIn has type %T, want int", numInInterface))
	}
	processElementIn, ok := processElementInInterface.(int)
	if !ok {
		panic(fmt.Sprintf("possibleBundleLifecycleParameterCombos: processElementIn has type %T, want int", processElementInInterface))
	}

	orderedKnownParameterOptions := []string{"context.Context", "typex.PaneInfo", "[]typex.Window", "typex.EventTime", "typex.BundleFinalization"}
	// Because of how Bundle lifecycle functions are invoked, all known parameters must precede unknown options and be in order.
	// Once we hit an unknown option, all remaining unknown options must be included since all iters/emitters must be included.
	// Therefore, we can generate a powerset of the known options and fill out any remaining parameters with an ordered set of remaining unknown options.
	pSetSize := int(math.Pow(2, float64(len(orderedKnownParameterOptions))))
	combos := make([][]string, 0, pSetSize)

	for index := 0; index < pSetSize; index++ {
		var subSet []string

		for j, elem := range orderedKnownParameterOptions {
			// And with the bit representation to get this iteration of the powerset.
			if index&(1<<uint(j)) != 0 {
				subSet = append(subSet, elem)
			}
		}
		// Fill out any remaining parameter slots with consecutive parameters from ProcessElement if there are enough options.
		if len(subSet) <= numIn && numIn-len(subSet) <= processElementIn {
			for len(subSet) < numIn {
				// subSet grows on every pass, so nextElement counts up to
				// processElementIn-1.
				nextElement := processElementIn - (numIn - len(subSet))
				subSet = append(subSet, fmt.Sprintf("I%v", nextElement))
			}
			combos = append(combos, subSet)
		}
	}

	return combos
}