go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/projects/nodes/cmd/nodegen/main.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package main
     9  
    10  import (
    11  	"bytes"
    12  	"flag"
    13  	"fmt"
    14  	"os"
    15  	"slices"
    16  	"strings"
    17  	"text/template"
    18  	"unicode"
    19  
    20  	"golang.org/x/tools/imports"
    21  
    22  	"go.charczuk.com/projects/nodes/pkg/incrutil"
    23  	"go.charczuk.com/sdk/iter"
    24  )
    25  
    26  func init() {
    27  	flag.Parse()
    28  	inputTypeToOutputTypes = generateOutputTypesForInputType()
    29  }
    30  
    31  func main() {
    32  	var inputTypesFunc func() []string
    33  	var outputTypesFunc any
    34  	var templateBody string
    35  	switch *flagNodeMode {
    36  	case "both": // emit instances for both input and output types
    37  		templateBody = templateBoth
    38  		outputTypesFunc = getOutputTypesForInput
    39  		inputTypesFunc = getInputTypes
    40  	case "input": // emit instances just for input types
    41  		templateBody = templateInput
    42  		outputTypesFunc = func() []string { return nil }
    43  		inputTypesFunc = getInputTypes
    44  	case "output": // emit instances just for output types
    45  		templateBody = templateOutput
    46  		outputTypesFunc = getOutputTypes
    47  		inputTypesFunc = getInputTypes
    48  	case "sort": // a special case for the sort node type
    49  		templateBody = templateSort
    50  		outputTypesFunc = func() []string { return nil }
    51  		inputTypesFunc = getSortInputTypes
    52  	case "none":
    53  		templateBody = templateNone
    54  		outputTypesFunc = func() []string { return nil }
    55  		inputTypesFunc = func() []string { return nil }
    56  	default:
    57  		fmt.Fprintf(os.Stderr, "invalid node mode %v\n", *flagNodeMode)
    58  		os.Exit(1)
    59  	}
    60  
    61  	// if *flagDebug {
    62  	// 	fmt.Printf("using input-types: %s\n", strings.Join(inputTypesFunc(), ","))
    63  	// 	if typed, ok := outputTypesFunc.(func() []string); ok {
    64  	// 		fmt.Printf("using output-types: %s\n", strings.Join(typed(), ","))
    65  	// 	} else if typed, ok := outputTypesFunc.(func(string) []string); ok {
    66  	// 		fullTypes := make(map[string][]string)
    67  	// 		for _, inputType := range inputTypesFunc() {
    68  	// 			fullTypes[inputType] = typed(inputType)
    69  	// 		}
    70  	// 		fmt.Printf("using output-types: %#v\n", fullTypes)
    71  	// 	}
    72  	// 	// fmt.Printf("using template:\n%s\n", templateBody)
    73  	// }
    74  
    75  	t, err := template.New("").Funcs(template.FuncMap{
    76  		"toExported":  toExported,
    77  		"toLower":     toLower,
    78  		"toScalar":    toScalar,
    79  		"toArray":     toArray,
    80  		"apiType":     incrutil.APIValueTypeForGoType,
    81  		"incrFn":      formatIncrFn,
    82  		"incrManyFn":  formatIncrManyFn,
    83  		"incrType":    formatIncrType,
    84  		"isArray":     isArray,
    85  		"outputTypes": outputTypesFunc,
    86  	}).Parse(templateBody)
    87  	if err != nil {
    88  		fmt.Fprintf(os.Stderr, "template parse error: %+v\n", err)
    89  		os.Exit(1)
    90  	}
    91  	buf := new(bytes.Buffer)
    92  	err = t.Execute(buf, map[string]any{
    93  		"inputTypes": inputTypesFunc(),
    94  		"nodeType":   *flagNodeType,
    95  	})
    96  	if err != nil {
    97  		fmt.Fprintf(os.Stderr, "template execution error: %+v\n", err)
    98  		os.Exit(1)
    99  	}
   100  	formatted, err := imports.Process(flag.Args()[0], buf.Bytes(), nil)
   101  	if err != nil {
   102  		// fmt.Fprintln(os.Stdout, buf.String())
   103  		fmt.Fprintf(os.Stderr, "format source error: %+v\n", err)
   104  		os.Exit(1)
   105  	}
   106  	os.WriteFile(flag.Args()[0], formatted, 0644)
   107  }
   108  
   109  var flagDebug = flag.Bool("debug", false, "If debug output should be shown.")
   110  
   111  var flagNodeType = flag.String("node-type", "merge", "The node type (will be a component of the exported function name)")
   112  var flagNodeMode = flag.String("node-mode", "both", "The node mode (e.g. input|output|both)")
   113  
   114  var flagIncrType = flag.String("incr-type", "", "The incremental node type")
   115  var flagIncrFn = flag.String("incr-fn", "", "The incremental function value (can be a template)")
   116  var flagIncrManyFn = flag.String("incr-many-fn", "", "The incremental function name for array types")
   117  
   118  var flagInputTypes = flag.String("i-types", "", "specific input types to include")
   119  var flagInputTypesScalar = flag.Bool("i-scalar", false, "if we should include scalar input types")
   120  var flagInputTypesArray = flag.Bool("i-array", false, "if we should include array input types")
   121  var flagInputTypesTable = flag.Bool("i-table", false, "if we should include the special table input type")
   122  var flagInputTypesSVG = flag.Bool("i-svg", false, "if we should include the special svg input type")
   123  var flagInputTypesAny = flag.Bool("i-any", false, "if we should include the special any input type")
   124  var flagInputTypesAnyArray = flag.Bool("i-any-array", false, "if we should include the special []any input type")
   125  
   126  var flagInputTypesFilterMath = flag.Bool("i-filter-math", false, "if input types should be math only")
   127  
   128  var flagOutputTypes = flag.String("o-types", "", "specific output types to include")
   129  var flagOutputTypesScalar = flag.Bool("o-scalar", false, "if we should include scalar output types")
   130  var flagOutputTypesArray = flag.Bool("o-array", false, "if we should include array output types")
   131  var flagOutputTypesTable = flag.Bool("o-table", false, "if we should include the special table output type")
   132  var flagOutputTypesSVG = flag.Bool("o-svg", false, "if we should include the special svg output type")
   133  var flagOutputTypesAny = flag.Bool("o-any", false, "if we should include the special any output type")
   134  var flagOutputTypesAnyArray = flag.Bool("o-any-array", false, "if we should include the special []any output type")
   135  
   136  var flagOutputTypesFilterMath = flag.Bool("o-filter-math", false, "if output types should be math only")
   137  
   138  var flagOutputTypesMatchInput = flag.Bool("o-match-input", false, "if the output type should match the include type")
   139  var flagOutputTypesMatchInputAsScalar = flag.Bool("o-match-input-scalar", false, "if the output type should match the include type as a scalar type")
   140  var flagOutputTypesMatchInputAsArray = flag.Bool("o-match-input-array", false, "if the output type should match the include type as an array type")
   141  
   142  var inputTypeToOutputTypes map[string][]string
   143  
   144  func generateOutputTypesForInputType() map[string][]string {
   145  	var inputToOutputTypes = make(map[string][]string)
   146  
   147  	if *flagInputTypes != "" {
   148  		for _, t := range iter.Apply(strings.Split(trimQuotes(*flagInputTypes), ","), trimQuotes) {
   149  			inputToOutputTypes[t] = []string{}
   150  		}
   151  	} else {
   152  		if *flagInputTypesScalar {
   153  			for _, t := range incrutil.ValueScalarTypes {
   154  				if !*flagInputTypesFilterMath || (*flagInputTypesFilterMath && incrutil.ValueTypeIsMath(t)) {
   155  					inputToOutputTypes[t] = []string{}
   156  				}
   157  			}
   158  		}
   159  		if *flagInputTypesArray {
   160  			for _, t := range incrutil.ValueArrayTypes {
   161  				if !*flagInputTypesFilterMath || (*flagInputTypesFilterMath && incrutil.ValueTypeIsMath(t)) {
   162  					inputToOutputTypes[t] = []string{}
   163  				}
   164  			}
   165  		}
   166  		if *flagInputTypesTable {
   167  			inputToOutputTypes[incrutil.ValueTypeTable] = []string{}
   168  		}
   169  		if *flagInputTypesSVG {
   170  			inputToOutputTypes[incrutil.ValueTypeSVG] = []string{}
   171  		}
   172  		if *flagInputTypesAny {
   173  			inputToOutputTypes[incrutil.ValueTypeAny] = []string{}
   174  		}
   175  		if *flagInputTypesAnyArray {
   176  			inputToOutputTypes[incrutil.ValueTypeAnyArray] = []string{}
   177  		}
   178  	}
   179  
   180  	for t := range inputToOutputTypes {
   181  		if *flagOutputTypes != "" {
   182  			inputToOutputTypes[t] = iter.Apply(strings.Split(trimQuotes(*flagOutputTypes), ","), trimQuotes)
   183  		} else {
   184  			if *flagOutputTypesMatchInput {
   185  				if !*flagOutputTypesFilterMath || (*flagOutputTypesFilterMath && incrutil.ValueTypeIsMath(t)) {
   186  					inputToOutputTypes[t] = []string{t}
   187  				}
   188  			} else if *flagOutputTypesMatchInputAsArray {
   189  				if !*flagOutputTypesFilterMath || (*flagOutputTypesFilterMath && incrutil.ValueTypeIsMath(t)) {
   190  					if isArray(t) {
   191  						inputToOutputTypes[t] = []string{t}
   192  					} else {
   193  						inputToOutputTypes[t] = []string{"[]" + t}
   194  					}
   195  				}
   196  			} else if *flagOutputTypesMatchInputAsScalar {
   197  				if !*flagOutputTypesFilterMath || (*flagOutputTypesFilterMath && incrutil.ValueTypeIsMath(t)) {
   198  					if isArray(t) {
   199  						inputToOutputTypes[t] = []string{toScalar(t)}
   200  					} else {
   201  						inputToOutputTypes[t] = []string{t}
   202  					}
   203  				}
   204  			} else {
   205  				if *flagOutputTypesScalar {
   206  					if *flagOutputTypesFilterMath {
   207  						inputToOutputTypes[t] = iter.Filter(incrutil.ValueScalarTypes, incrutil.ValueTypeIsMath)
   208  					} else {
   209  						inputToOutputTypes[t] = incrutil.ValueScalarTypes
   210  					}
   211  				}
   212  				if *flagOutputTypesArray {
   213  					if *flagOutputTypesFilterMath {
   214  						inputToOutputTypes[t] = append(inputToOutputTypes[t], iter.Filter(incrutil.ValueArrayTypes, incrutil.ValueTypeIsMath)...)
   215  					} else {
   216  						inputToOutputTypes[t] = append(inputToOutputTypes[t], incrutil.ValueArrayTypes...)
   217  					}
   218  				}
   219  				if *flagOutputTypesTable {
   220  					inputToOutputTypes[t] = append(inputToOutputTypes[t], incrutil.ValueTypeTable)
   221  				}
   222  				if *flagOutputTypesSVG {
   223  					inputToOutputTypes[t] = append(inputToOutputTypes[t], incrutil.ValueTypeSVG)
   224  				}
   225  				if *flagOutputTypesAny {
   226  					inputToOutputTypes[t] = append(inputToOutputTypes[t], incrutil.ValueTypeAny)
   227  				}
   228  				if *flagOutputTypesAnyArray {
   229  					inputToOutputTypes[t] = append(inputToOutputTypes[t], incrutil.ValueTypeAnyArray)
   230  				}
   231  			}
   232  		}
   233  	}
   234  	return inputToOutputTypes
   235  }
   236  
   237  func getOutputTypesForInput(inputType string) []string {
   238  	outputTypes, _ := inputTypeToOutputTypes[inputType]
   239  	slices.Sort(outputTypes)
   240  	return outputTypes
   241  }
   242  
   243  func getInputTypes() []string {
   244  	var inputTypes []string
   245  	for k := range inputTypeToOutputTypes {
   246  		inputTypes = append(inputTypes, k)
   247  	}
   248  	slices.Sort(inputTypes)
   249  	return inputTypes
   250  }
   251  
   252  func getOutputTypes() []string {
   253  	var outputTypes []string
   254  	if *flagOutputTypes != "" {
   255  		outputTypes = iter.Apply(strings.Split(trimQuotes(*flagOutputTypes), ","), trimQuotes)
   256  	} else {
   257  		if *flagOutputTypesScalar {
   258  			outputTypes = append(outputTypes, incrutil.ValueScalarTypes...)
   259  		}
   260  		if *flagOutputTypesArray {
   261  			outputTypes = append(outputTypes, incrutil.ValueArrayTypes...)
   262  		}
   263  		if *flagOutputTypesTable {
   264  			outputTypes = append(outputTypes, incrutil.ValueTypeTable)
   265  		}
   266  		if *flagOutputTypesSVG {
   267  			outputTypes = append(outputTypes, incrutil.ValueTypeSVG)
   268  		}
   269  		if *flagOutputTypesAny {
   270  			outputTypes = append(outputTypes, incrutil.ValueTypeAny)
   271  		}
   272  		if *flagOutputTypesFilterMath {
   273  			outputTypes = iter.Filter(outputTypes, incrutil.ValueTypeIsMath)
   274  		}
   275  	}
   276  	slices.Sort(outputTypes)
   277  	return outputTypes
   278  }
   279  
   280  func renderFunctionSnippet(tmpl, inputType, outputType string) (string, error) {
   281  	t, err := template.New("").Funcs(template.FuncMap{
   282  		"toScalar":   toScalar,
   283  		"toArray":    toArray,
   284  		"toLower":    toLower,
   285  		"toExported": toExported,
   286  	}).Parse(trimQuotes(tmpl))
   287  	if err != nil {
   288  		return "", err
   289  	}
   290  	buf := new(bytes.Buffer)
   291  	err = t.Execute(buf, map[string]any{
   292  		"inputType":  inputType,
   293  		"outputType": outputType,
   294  	})
   295  	if err != nil {
   296  		return "", err
   297  	}
   298  	return buf.String(), nil
   299  }
   300  
   301  func formatIncrFn(inputType, outputType string) (string, error) {
   302  	return renderFunctionSnippet(*flagIncrFn, inputType, outputType)
   303  }
   304  
   305  func formatIncrManyFn(inputType, outputType string) (string, error) {
   306  	var fn string
   307  	if flagIncrManyFn != nil && *flagIncrManyFn != "" {
   308  		fn = *flagIncrManyFn
   309  	} else {
   310  		fn = *flagIncrFn
   311  	}
   312  	return renderFunctionSnippet(fn, inputType, outputType)
   313  }
   314  
   315  func formatIncrType(inputType, outputType string) (string, error) {
   316  	return renderFunctionSnippet(*flagIncrType, inputType, outputType)
   317  }
   318  
   319  func isArray(typeName string) bool {
   320  	return strings.HasPrefix(strings.TrimSpace(typeName), "[]")
   321  }
   322  
   323  func getSortInputTypes() []string {
   324  	return iter.Filter(incrutil.ValueTypes, func(v string) bool { return v != "[]bool" && v != "bool" })
   325  }
   326  
// templateBoth is the text/template body used for -node-mode=both.
//
// It renders a <NodeType>ForNode function that switches on
// n.Metadata.InputType and then on n.Metadata.OutputType, with one inner case
// per output type returned by the "outputTypes" template function for that
// input type. Array input types are wired with incrManyFn; all others use
// incrFn. Unknown input or output types produce an error. The import list is
// a superset; main passes the rendered source through imports.Process, which
// removes unused imports.
const templateBoth = `
// File generated by nodegen. DO NOT EDIT.
package gen

import (
	"fmt"
	"time"

	"github.com/wcharczuk/go-incr"
	"go.charczuk.com/projects/nodes/pkg/incrutil"
	"go.charczuk.com/projects/nodes/pkg/types"
	"go.charczuk.com/projects/nodes/pkg/funcs"
)

{{ $vm := . }}
func {{ .nodeType | toExported }}ForNode(graph *incr.Graph, n *types.Node) (output incr.INode, err error) {
	switch n.Metadata.InputType {
	{{- range $inputIndex, $inputType := $vm.inputTypes }}
	case "{{ $inputType | apiType }}":
		switch n.Metadata.OutputType {
		{{- range $outputIndex, $outputType := outputTypes $inputType }}
		case "{{ $outputType | apiType }}":
		{{- if isArray $inputType }}
			output = {{ incrType $inputType $outputType }}(graph, {{ incrManyFn $inputType $outputType }})
			return
		{{- else }}
			output = {{ incrType $inputType $outputType }}(graph, {{ incrFn $inputType $outputType }})
			return
		{{- end }}
		{{- end }}
		default:
			err = fmt.Errorf("invalid {{ $vm.nodeType | toLower }} output type %v for input type %v", n.Metadata.OutputType, n.Metadata.InputType)
			return
		}
	{{ end }}
	default:
		err = fmt.Errorf("invalid {{ $vm.nodeType | toLower }} input type %v", n.Metadata.InputType)
		return
	}
}
`
   368  
// templateSort is the text/template body used for -node-mode=sort.
//
// For every input type, the only accepted output type is the input type
// itself. time.Time and []time.Time are sorted with funcs.SortComparer /
// funcs.SortComparerMany using funcs.TimestampAsc. Other array types use
// funcs.SortMany and other scalar types use funcs.Sort, each wrapped in
// incrutil.MapN. Unknown types produce an error.
const templateSort = `
// File generated by nodegen. DO NOT EDIT.
package gen

import (
	"fmt"
	"time"

	"github.com/wcharczuk/go-incr"
	"go.charczuk.com/sdk/iter"
	"go.charczuk.com/projects/nodes/pkg/incrutil"
	"go.charczuk.com/projects/nodes/pkg/types"
	"go.charczuk.com/projects/nodes/pkg/funcs"
)

{{ $vm := . }}
func {{ .nodeType | toExported }}ForNode(graph *incr.Graph, n *types.Node) (output incr.INode, err error) {
	switch n.Metadata.InputType {
	{{- range $inputIndex, $inputType := $vm.inputTypes }}
	case "{{ $inputType | apiType }}":
		switch n.Metadata.OutputType {
		case "{{ $inputType | apiType }}":
		{{- if $inputType | eq "[]time.Time" }}
			output = incrutil.MapN[{{ $inputType }}](graph, funcs.SortComparerMany[{{ $inputType | toScalar }}](iter.SorterComparerFunc[{{ $inputType | toScalar }}](funcs.TimestampAsc)))
		{{- else if $inputType | eq "time.Time" }}
			output = incrutil.MapN[{{ $inputType }}](graph, funcs.SortComparer[{{ $inputType | toScalar }}](iter.SorterComparerFunc[{{ $inputType | toScalar }}](funcs.TimestampAsc)))
		{{- else }}
		{{- if isArray $inputType }}
			output = incrutil.MapN[{{ $inputType }}](graph, funcs.SortMany[{{ $inputType | toScalar }}])
		{{- else }}
			output = incrutil.MapN[{{ $inputType }}](graph, funcs.Sort[{{ $inputType | toScalar }}])
		{{- end }}
		{{- end }}
			return
		default:
			err = fmt.Errorf("invalid {{ $vm.nodeType | toLower }} output type %v for input type %v", n.Metadata.OutputType, n.Metadata.InputType)
			return
		}
	{{ end }}
	default:
		err = fmt.Errorf("invalid {{ $vm.nodeType | toLower }} input type %v", n.Metadata.InputType)
		return
	}
}
`
   414  
// templateInput is the text/template body used for -node-mode=input.
//
// It renders a <NodeType>ForNode function that switches only on
// n.Metadata.InputType, emitting one case per input type. The incrType and
// incrFn helpers are called with an empty output type. Unknown input types
// produce an error.
const templateInput = `
// File generated by nodegen. DO NOT EDIT.
package gen

import (
	"fmt"
	"time"

	"github.com/wcharczuk/go-incr"
	"go.charczuk.com/projects/nodes/pkg/incrutil"
	"go.charczuk.com/projects/nodes/pkg/types"
	"go.charczuk.com/projects/nodes/pkg/funcs"
)

{{ $vm := . }}
func {{ .nodeType | toExported }}ForNode(graph *incr.Graph, n *types.Node) (output incr.INode, err error) {
	switch n.Metadata.InputType {
	{{- range $inputIndex, $inputType := .inputTypes }}
	case "{{ $inputType | apiType }}":
		output = {{ incrType $inputType "" }}(graph, {{ incrFn $inputType "" }})
		return
	{{- end }}
	default:
		err = fmt.Errorf("invalid {{ $vm.nodeType | toLower }} input type %v", n.Metadata.InputType)
		return
	}
}`
// templateOutput is the text/template body used for -node-mode=output.
//
// It renders a <NodeType>ForNode function that switches only on
// n.Metadata.OutputType, with one case per type returned by the zero-argument
// "outputTypes" template function. Array output types are wired with
// incrManyFn and all others with incrFn; the input type passed to the helpers
// is empty. Unknown output types produce an error.
const templateOutput = `
// File generated by nodegen. DO NOT EDIT.
package gen

import (
	"fmt"
	"time"

	"github.com/wcharczuk/go-incr"
	"go.charczuk.com/projects/nodes/pkg/incrutil"
	"go.charczuk.com/projects/nodes/pkg/types"
	"go.charczuk.com/projects/nodes/pkg/funcs"
)

{{ $vm := . }}
func {{ .nodeType | toExported }}ForNode(graph *incr.Graph, n *types.Node) (output incr.INode, err error) {
	switch n.Metadata.OutputType {
	{{ range $outputIndex, $outputType := outputTypes }}
	case "{{ $outputType | apiType }}":
	{{- if isArray $outputType }}
		output = {{ incrType "" $outputType }}(graph, {{ incrManyFn "" $outputType }})
		return
	{{- else }}
		output = {{ incrType "" $outputType }}(graph, {{ incrFn "" $outputType }})
		return
	{{- end }}
	{{- end }}
	default:
		err = fmt.Errorf("invalid {{ $vm.nodeType | toLower }} output type %v", n.Metadata.OutputType)
		return
	}
}`
   474  
// templateNone is the text/template body used for -node-mode=none.
//
// It renders a <NodeType>ForNode function that ignores the node's metadata
// and always builds the node from incrType/incrFn called with empty input and
// output types.
const templateNone = `
// File generated by nodegen. DO NOT EDIT.
package gen

import (
	"fmt"
	"time"

	"github.com/wcharczuk/go-incr"
	"go.charczuk.com/projects/nodes/pkg/incrutil"
	"go.charczuk.com/projects/nodes/pkg/types"
	"go.charczuk.com/projects/nodes/pkg/funcs"
)

{{ $vm := . }}
func {{ .nodeType | toExported }}ForNode(graph *incr.Graph, n *types.Node) (output incr.INode, err error) {
	output = {{ incrType "" "" }}(graph, {{ incrFn "" "" }})
	return
}`
   494  
   495  func toExported(input string) string {
   496  	if input == "" {
   497  		return ""
   498  	}
   499  	inputRunes := []rune(input)
   500  	inputRunes[0] = unicode.ToUpper(inputRunes[0])
   501  	return string(inputRunes)
   502  }
   503  
   504  func toLower(input string) string {
   505  	return strings.ToLower(input)
   506  }
   507  
   508  func toScalar(typeName string) string {
   509  	return strings.TrimPrefix(strings.TrimSpace(typeName), "[]")
   510  }
   511  
   512  func toArray(typeName string) string {
   513  	if strings.HasPrefix(typeName, "[]") {
   514  		return typeName
   515  	}
   516  	return "[]" + typeName
   517  }
   518  
   519  func trimQuotes(str string) string {
   520  	str = strings.TrimSpace(str)
   521  	str = strings.TrimPrefix(str, `'`)
   522  	str = strings.TrimPrefix(str, `"`)
   523  	str = strings.TrimSuffix(str, `'`)
   524  	str = strings.TrimSuffix(str, `"`)
   525  	str = strings.TrimSpace(str)
   526  	return str
   527  }