github.com/apache/beam/sdks/v2@v2.48.2/go/cmd/specialize/main.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one or more
     2  // contributor license agreements.  See the NOTICE file distributed with
     3  // this work for additional information regarding copyright ownership.
     4  // The ASF licenses this file to You under the Apache License, Version 2.0
     5  // (the "License"); you may not use this file except in compliance with
     6  // the License.  You may obtain a copy of the License at
     7  //
     8  //    http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
    16  // specialize is a low-level tool to generate type-specialized code. It is a
    17  // convenience wrapper over text/template suitable for go generate. Unlike
    18  // many other template tools, it does not parse Go code and allows use of
    19  // text/template control within the template itself.
    20  package main
    21  
    22  import (
    23  	"bytes"
    24  	"flag"
    25  	"fmt"
    26  	"log"
    27  	"math"
    28  	"os"
    29  	"path/filepath"
    30  	"strings"
    31  	"text/template"
    32  
    33  	"golang.org/x/text/cases"
    34  	"golang.org/x/text/language"
    35  )
    36  
// Command-line flags. main expands the x, y and z lists with the macros
// table and the imports list with the packageMacros table, so each entry may
// be either a literal value or one of those tables' keys.
var (
	noheader = flag.Bool("noheader", false, "Omit auto-generated header")
	pack     = flag.String("package", "", "Package name (optional)")
	imports  = flag.String("imports", "", "Comma-separated list of extra imports (optional)")

	x = flag.String("x", "", "Comma-separated list of X types (optional)")
	y = flag.String("y", "", "Comma-separated list of Y types (optional)")
	z = flag.String("z", "", "Comma-separated list of Z types (optional)")

	// input is required: main exits with an error when it is empty.
	input  = flag.String("input", "", "Template file.")
	output = flag.String("output", "", "Filename for generated code. If not provided, a file next to the input is generated.")
)
    49  
// Top is the top-level struct to be passed to the template.
type Top struct {
	// Name is the base form of the filename, cut at its first dot:
	// "foo/bar.tmpl" -> "bar".
	Name string
	// Package is the package name. It is empty when none was provided.
	Package string
	// Imports is a list of custom imports, if provided.
	Imports []string
	// X is the list of X type values.
	X []*X
}
    61  
// X is the concrete type to be iterated over in the user template.
type X struct {
	// Name is the name of X for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice".
	Name string
	// Type is the textual type of X: "int", "float32", "foo.Baz".
	Type string
	// Y is the list of Y type values for this X. It is empty when no Y types
	// were requested.
	Y []*Y
}
    71  
// Y is the concrete type to be iterated over in the user template for each X.
// Each combination of X and Y will be present.
type Y struct {
	// Name is the name of Y for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice".
	Name string
	// Type is the textual type of Y: "int", "float32", "foo.Baz".
	Type string
	// Z is the list of Z type values for this Y. It is empty when no Z types
	// were requested.
	Z []*Z
}
    82  
// Z is the concrete type to be iterated over in the user template for each Y.
// Each combination of X, Y and Z will be present. Z is the innermost level:
// it carries no further nested type list.
type Z struct {
	// Name is the name of Z for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice".
	Name string
	// Type is the textual type of Z: "int", "float32", "foo.Baz".
	Type string
}
    91  
var (
	// integers, floats and primitives are the building blocks of the type
	// macros. primitives is "bool" and "string" followed by every entry of
	// integers and then of floats.
	integers   = []string{"int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64"}
	floats     = []string{"float32", "float64"}
	primitives = append(append([]string{"bool", "string"}, integers...), floats...)

	// macros maps a macro name to the list of types it stands for. main uses
	// it to expand the x, y and z flag values, so a single entry such as
	// "integers" yields one specialization per listed type.
	macros = map[string][]string{
		"integers":   integers,
		"floats":     floats,
		"primitives": primitives,
		"data":       append([]string{"[]byte"}, primitives...),
		"universals": {"typex.T", "typex.U", "typex.V", "typex.W", "typex.X", "typex.Y", "typex.Z"},
	}

	// packageMacros maps a short package name to the import path(s) it stands
	// for. main uses it to expand the imports flag value.
	packageMacros = map[string][]string{
		"typex": {"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"},
	}
)
   109  
   110  func usage() {
   111  	fmt.Fprintf(os.Stderr, "Usage: %v [options] --input=<filename.tmpl --x=<types>\n", filepath.Base(os.Args[0]))
   112  	flag.PrintDefaults()
   113  }
   114  
   115  func main() {
   116  	flag.Usage = usage
   117  	flag.Parse()
   118  
   119  	log.SetFlags(0)
   120  	log.SetPrefix("specialize: ")
   121  
   122  	if *input == "" {
   123  		flag.Usage()
   124  		log.Fatalf("no template file")
   125  	}
   126  
   127  	name := filepath.Base(*input)
   128  	if index := strings.Index(name, "."); index > 0 {
   129  		name = name[:index]
   130  	}
   131  	if *output == "" {
   132  		*output = filepath.Join(filepath.Dir(*input), name+".go")
   133  	}
   134  
   135  	top := Top{Name: name, Package: *pack, Imports: expand(packageMacros, *imports)}
   136  	var ys []*Y
   137  	if *y != "" {
   138  		var zs []*Z
   139  		if *z != "" {
   140  			for _, zt := range expand(macros, *z) {
   141  				zs = append(zs, &Z{Name: makeName(zt), Type: zt})
   142  			}
   143  		}
   144  		for _, yt := range expand(macros, *y) {
   145  			ys = append(ys, &Y{Name: makeName(yt), Type: yt, Z: zs})
   146  		}
   147  	}
   148  	for _, xt := range expand(macros, *x) {
   149  		top.X = append(top.X, &X{Name: makeName(xt), Type: xt, Y: ys})
   150  	}
   151  
   152  	tmpl, err := template.New(*input).Funcs(funcMap).ParseFiles(*input)
   153  	if err != nil {
   154  		log.Fatalf("template parse failed: %v", err)
   155  	}
   156  	var buf bytes.Buffer
   157  	if !*noheader {
   158  		buf.WriteString("// File generated by specialize. Do not edit.\n\n")
   159  	}
   160  	if err := tmpl.Funcs(funcMap).Execute(&buf, top); err != nil {
   161  		log.Fatalf("specialization failed: %v", err)
   162  	}
   163  	if err := os.WriteFile(*output, buf.Bytes(), 0644); err != nil {
   164  		log.Fatalf("write failed: %v", err)
   165  	}
   166  }
   167  
   168  // expand parses, cleans up and expands macros for a comma-separated list.
   169  func expand(subst map[string][]string, list string) []string {
   170  	var ret []string
   171  	for _, xt := range strings.Split(list, ",") {
   172  		xt = strings.TrimSpace(xt)
   173  		if xt == "" {
   174  			continue
   175  		}
   176  		if exp, ok := subst[strings.ToLower(xt)]; ok {
   177  			for _, t := range exp {
   178  				ret = append(ret, t)
   179  			}
   180  			continue
   181  		}
   182  		ret = append(ret, xt)
   183  	}
   184  	return ret
   185  }
   186  
   187  // makeName creates a capitalized identifier from a type.
   188  func makeName(t string) string {
   189  	if strings.HasPrefix(t, "[]") {
   190  		return makeName(t[2:] + "Slice")
   191  	}
   192  
   193  	t = strings.Replace(t, ".", "_", -1)
   194  	t = strings.Replace(t, "[", "_", -1)
   195  	t = strings.Replace(t, "]", "_", -1)
   196  	return cases.Title(language.Und, cases.NoLower).String(t)
   197  }
   198  
   199  // Useful template functions
   200  
   201  var funcMap template.FuncMap = map[string]any{
   202  	"join":                                   strings.Join,
   203  	"upto":                                   upto,
   204  	"mkargs":                                 mkargs,
   205  	"mktuple":                                mktuple,
   206  	"mktuplef":                               mktuplef,
   207  	"add":                                    add,
   208  	"mult":                                   mult,
   209  	"dict":                                   dict,
   210  	"list":                                   list,
   211  	"genericTypingRepresentation":            genericTypingRepresentation,
   212  	"possibleBundleLifecycleParameterCombos": possibleBundleLifecycleParameterCombos,
   213  }
   214  
   215  // mkargs(n, type) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)> type".
   216  // If n is 0, it returns the empty string.
   217  func mkargs(n int, format, typ string) string {
   218  	if n == 0 {
   219  		return ""
   220  	}
   221  	return fmt.Sprintf("%v %v", mktuplef(n, format), typ)
   222  }
   223  
   224  // mktuple(n, v) returns "v, v, ..., v".
   225  func mktuple(n int, v string) string {
   226  	var ret []string
   227  	for i := 0; i < n; i++ {
   228  		ret = append(ret, v)
   229  	}
   230  	return strings.Join(ret, ", ")
   231  }
   232  
   233  // mktuplef(n, format) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)>"
   234  func mktuplef(n int, format string) string {
   235  	var ret []string
   236  	for i := 0; i < n; i++ {
   237  		ret = append(ret, fmt.Sprintf(format, i))
   238  	}
   239  	return strings.Join(ret, ", ")
   240  }
   241  
   242  // upto(n) returns []int{0, 1, .., n-1}.
   243  func upto(i int) []int {
   244  	var ret []int
   245  	for k := 0; k < i; k++ {
   246  		ret = append(ret, k)
   247  	}
   248  	return ret
   249  }
   250  
   251  func add(i int, j int) int {
   252  	return i + j
   253  }
   254  
   255  func mult(i int, j int) int {
   256  	return i * j
   257  }
   258  
   259  func dict(values ...any) map[string]any {
   260  	dict := make(map[string]any, len(values)/2)
   261  	if len(values)%2 != 0 {
   262  		panic("Invalid dictionary call")
   263  	}
   264  	for i := 0; i < len(values); i += 2 {
   265  		dict[values[i].(string)] = values[i+1]
   266  	}
   267  
   268  	return dict
   269  }
   270  
   271  func list(values ...string) []string {
   272  	return values
   273  }
   274  
   275  func genericTypingRepresentation(in int, out int, includeType bool) string {
   276  	seenElements := false
   277  	typing := ""
   278  	if in > 0 {
   279  		typing += fmt.Sprintf("[I%v", 0)
   280  		for i := 1; i < in; i++ {
   281  			typing += fmt.Sprintf(", I%v", i)
   282  		}
   283  		seenElements = true
   284  	}
   285  	if out > 0 {
   286  		i := 0
   287  		if !seenElements {
   288  			typing += fmt.Sprintf("[R%v", 0)
   289  			i++
   290  		}
   291  		for i < out {
   292  			typing += fmt.Sprintf(", R%v", i)
   293  			i++
   294  		}
   295  		seenElements = true
   296  	}
   297  
   298  	if seenElements {
   299  		if includeType {
   300  			typing += " any"
   301  		}
   302  		typing += "]"
   303  	}
   304  
   305  	return typing
   306  }
   307  
   308  func possibleBundleLifecycleParameterCombos(numInInterface any, processElementInInterface any) [][]string {
   309  	numIn := numInInterface.(int)
   310  	processElementIn := processElementInInterface.(int)
   311  	orderedKnownParameterOptions := []string{"context.Context", "typex.PaneInfo", "[]typex.Window", "typex.EventTime", "typex.BundleFinalization"}
   312  	// Because of how Bundle lifecycle functions are invoked, all known parameters must precede unknown options and be in order.
   313  	// Once we hit an unknown options, all remaining unknown options must be included since all iters/emitters must be included
   314  	// Therefore, we can generate a powerset of the known options and fill out any remaining parameters with an ordered set of remaining unknown options
   315  	pSetSize := int(math.Pow(2, float64(len(orderedKnownParameterOptions))))
   316  	combos := make([][]string, 0, pSetSize)
   317  
   318  	for index := 0; index < pSetSize; index++ {
   319  		var subSet []string
   320  
   321  		for j, elem := range orderedKnownParameterOptions {
   322  			// And with the bit representation to get this iteration of the powerset.
   323  			if index&(1<<uint(j)) > 0 {
   324  				subSet = append(subSet, elem)
   325  			}
   326  		}
   327  		// Fill out any remaining parameter slots with consecutive parameters from ProcessElement if there are enough options
   328  		if len(subSet) <= numIn && numIn-len(subSet) <= processElementIn {
   329  			for len(subSet) < numIn {
   330  				nextElement := processElementIn - (numIn - len(subSet))
   331  				subSet = append(subSet, fmt.Sprintf("I%v", nextElement))
   332  			}
   333  			combos = append(combos, subSet)
   334  		}
   335  	}
   336  
   337  	return combos
   338  }