gorgonia.org/gorgonia@v0.9.17/cmd/genapi/main.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/parser"
     7  	"go/token"
     8  	"io"
     9  	"log"
    10  	"os"
    11  	"os/user"
    12  	"path"
    13  	"strings"
    14  	"text/template"
    15  )
    16  
    17  const genmsg = "// Code generated by genapi, which is a API generation tool for Gorgonia. DO NOT EDIT."
    18  
    19  const (
    20  	apigenOut = "api_gen.go"
    21  	unOpOut   = "operatorPointwise_unary_gen.go"
    22  
    23  	// broadcastOpOut = "operations_broadcast.go"
    24  	unaryOps  = "operatorPointwise_unary_const.go"
    25  	binaryOps = "operatorPointwise_binary_const.go"
    26  )
    27  
    28  var (
    29  	gopath, gorgonialoc, golgiloc string
    30  )
    31  
    32  var funcmap = template.FuncMap{
    33  	"lower": strings.ToLower,
    34  }
    35  
    36  var (
    37  	unaryTemplate          *template.Template
    38  	binaryTemplate         *template.Template
    39  	broadcastTemplate      *template.Template
    40  	maybeBroadcastTemplate *template.Template
    41  )
    42  
    43  const unaryTemplateRaw = ` // {{.FnName}} performs a pointwise {{lower .FnName}}.
    44  func {{.FnName}}(a *Node) (*Node, error) { return unaryOpNode(newElemUnaryOp({{.OpType}}, a), a) }
    45  `
    46  
    47  const binaryTemplateRaw = `// {{.FnName}} performs a pointwise {{lower .FnName}} operation.
    48  {{if .AsSame -}}// retSame indicates if the data type of the return value should be the same as the input data type. It defaults to Bool otherwise.
    49  {{end -}}
    50  func {{.FnName}}(a, b *Node{{if .AsSame}}, retSame bool{{end}}) (*Node, error) { {{if not .AsSame -}}return binOpNode(newElemBinOp({{.OpType}}, a, b), a, b) {{else -}}
    51  	op := newElemBinOp({{.OpType}}, a, b)
    52  	op.retSame = retSame
    53  	return binOpNode(op, a, b)
    54  {{end -}}
    55  }
    56  `
    57  
    58  const broadcastTemplateRaw = `// Broadcast{{.FnName}} performs a {{lower .FnName}}. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
    59  func Broadcast{{.FnName}}(a, b *Node{{if .AsSame}}, retSame bool{{end}}, leftPattern, rightPattern []byte)(*Node, error) {
    60  	a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	return {{.FnName}}(a2, b2{{if .AsSame}}, retSame{{end}})
    65  }
    66  `
    67  
    68  // maybeBroadcast is the set of Broadcast functions in Golgi
    69  const maybeBroadcastTemplateRaw = `// Broadcast{{.FnName}} performs a {{lower .FnName}}. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
    70  func Broadcast{{.FnName}}(a, b *G.Node{{if .AsSame}}, retSame bool{{end}}, leftPattern, rightPattern []byte)(*G.Node, error) {
    71  	if a.Shape().Eq(b.Shape()){
    72  		return G.{{.FnName}}(a, b{{if .AsSame}}, retSame{{end}})
    73  	}
    74  	a2, b2, err := G.Broadcast(a, b, G.NewBroadcastPattern(leftPattern, rightPattern))
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	return G.{{.FnName}}(a2, b2{{if .AsSame}}, retSame{{end}})
    79  }
    80  `
    81  
    82  func init() {
    83  	gopath = os.Getenv("GOPATH")
    84  	// now that go can have a default gopath, this checks that path
    85  	if gopath == "" {
    86  		usr, err := user.Current()
    87  		if err != nil {
    88  			log.Fatal(err)
    89  		}
    90  		gopath = path.Join(usr.HomeDir, "go")
    91  		stat, err := os.Stat(gopath)
    92  		if err != nil {
    93  			log.Fatal(err)
    94  		}
    95  		if !stat.IsDir() {
    96  			log.Fatal("You need to define a $GOPATH")
    97  		}
    98  	}
    99  	gorgonialoc = path.Join(gopath, "src/gorgonia.org/gorgonia")
   100  	golgiloc = path.Join(gopath, "src/gorgonia.org/golgi")
   101  	unaryTemplate = template.Must(template.New("Unary").Funcs(funcmap).Parse(unaryTemplateRaw))
   102  	binaryTemplate = template.Must(template.New("Binary").Funcs(funcmap).Parse(binaryTemplateRaw))
   103  	broadcastTemplate = template.Must(template.New("Broadcast").Funcs(funcmap).Parse(broadcastTemplateRaw))
   104  	maybeBroadcastTemplate = template.Must(template.New("MaybeBroadcast").Funcs(funcmap).Parse(maybeBroadcastTemplateRaw))
   105  }
   106  
   107  func generateUnary(outFile io.Writer) {
   108  	// parse operator_unary_const.go
   109  	filename := path.Join(gorgonialoc, unaryOps)
   110  	fset := token.NewFileSet()
   111  	file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
   112  	if err != nil {
   113  		log.Fatal(err)
   114  	}
   115  
   116  	unaryNames := constTypes(file.Decls, "ʘUnaryOperatorType", "maxʘUnaryOperator")
   117  	for _, v := range unaryNames {
   118  		apiName := strings.Title(strings.TrimSuffix(v, "OpType"))
   119  		// legacy issue
   120  		if apiName == "Ln" {
   121  			apiName = "Log"
   122  		}
   123  		data := struct{ FnName, OpType string }{apiName, v}
   124  		unaryTemplate.Execute(outFile, data)
   125  	}
   126  
   127  }
   128  
   129  func generateBinary(outFile io.Writer) {
   130  	// parse operator_binary_const.go
   131  	filename := path.Join(gorgonialoc, binaryOps)
   132  	fset := token.NewFileSet()
   133  	file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
   134  	if err != nil {
   135  		log.Fatal(err)
   136  	}
   137  
   138  	binaryNames := constTypes(file.Decls, "ʘBinaryOperatorType", "maxʘBinaryOpType")
   139  	log.Printf("%v", binaryNames)
   140  	for _, v := range binaryNames {
   141  		apiName := strings.Title(strings.TrimSuffix(v, "OpType"))
   142  		// legacy issue
   143  		switch apiName {
   144  		case "Mul":
   145  			apiName = "HadamardProd"
   146  		case "Div":
   147  			apiName = "HadamardDiv"
   148  		}
   149  		data := struct {
   150  			FnName, OpType string
   151  			AsSame         bool
   152  		}{apiName, v, false}
   153  		switch apiName {
   154  		case "Lt", "Gt", "Lte", "Gte", "Eq", "Ne":
   155  			data.AsSame = true
   156  		}
   157  		binaryTemplate.Execute(outFile, data)
   158  	}
   159  }
   160  
   161  func generateBroadcastBinOps(tmpl *template.Template, outFile io.Writer) {
   162  	// parse operator_binary_const.go
   163  	filename := path.Join(gorgonialoc, binaryOps)
   164  	fset := token.NewFileSet()
   165  	file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
   166  	if err != nil {
   167  		log.Fatal(err)
   168  	}
   169  
   170  	binaryNames := constTypes(file.Decls, "ʘBinaryOperatorType", "maxʘBinaryOpType")
   171  	log.Printf("%v", binaryNames)
   172  	for _, v := range binaryNames {
   173  		apiName := strings.Title(strings.TrimSuffix(v, "OpType"))
   174  		// legacy issue
   175  		switch apiName {
   176  		case "Mul":
   177  			apiName = "HadamardProd"
   178  		case "Div":
   179  			apiName = "HadamardDiv"
   180  		}
   181  		data := struct {
   182  			FnName, OpType string
   183  			AsSame         bool
   184  		}{apiName, v, false}
   185  		switch apiName {
   186  		case "Lt", "Gt", "Lte", "Gte", "Eq", "Ne":
   187  			data.AsSame = true
   188  		}
   189  		tmpl.Execute(outFile, data)
   190  	}
   191  }
   192  
   193  func constTypes(decls []ast.Decl, accept, max string) (names []string) {
   194  	for i, decl := range decls {
   195  		log.Printf("DECL %d: %T", i, decl)
   196  		switch d := decl.(type) {
   197  		case *ast.GenDecl:
   198  			if d.Tok.IsKeyword() && d.Tok.String() == "const" {
   199  				log.Printf("\t%v", d.Tok.String())
   200  
   201  				// get the type
   202  				if len(d.Specs) == 0 {
   203  					continue
   204  				}
   205  
   206  				var typename string
   207  				typ := d.Specs[0].(*ast.ValueSpec).Type
   208  				if typ == nil {
   209  					continue
   210  				}
   211  				if id, ok := typ.(*ast.Ident); ok {
   212  					typename = id.Name
   213  				}
   214  				if typename != accept {
   215  					continue
   216  				}
   217  
   218  				for _, spec := range d.Specs {
   219  					name := spec.(*ast.ValueSpec).Names[0].Name
   220  					if name == max {
   221  						continue
   222  					}
   223  					names = append(names, name)
   224  				}
   225  			}
   226  		default:
   227  		}
   228  	}
   229  	return
   230  }
   231  
   232  func generateAPI() {
   233  	outFileName := path.Join(gorgonialoc, apigenOut)
   234  	outFile, err := os.OpenFile(outFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
   235  	if err != nil {
   236  		log.Fatal(err)
   237  	}
   238  	defer outFile.Close()
   239  	fmt.Fprintf(outFile, "package gorgonia\n\n%v\n\n", genmsg)
   240  	generateUnary(outFile)
   241  	generateBinary(outFile)
   242  	generateBroadcastBinOps(broadcastTemplate, outFile)
   243  }
   244  
   245  func generateInterfaces() {
   246  	outFileName := path.Join(gorgonialoc, unOpOut)
   247  	outFile, err := os.OpenFile(outFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
   248  	if err != nil {
   249  		log.Fatal(err)
   250  	}
   251  	defer outFile.Close()
   252  	fmt.Fprintf(outFile, "package gorgonia\n\n%v\n\n", genmsg)
   253  	generateUnaryInterface(outFile)
   254  }
   255  
   256  func generateGolgiAPI() {
   257  	outFileName := path.Join(golgiloc, apigenOut)
   258  	outFile, err := os.OpenFile(outFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
   259  	if err != nil {
   260  		log.Fatal(err)
   261  	}
   262  	defer outFile.Close()
   263  	fmt.Fprintf(outFile, "package golgi\n\n%v\n\n", genmsg)
   264  	generateBroadcastBinOps(maybeBroadcastTemplate, outFile)
   265  }
   266  
   267  func main() {
   268  	// generateAPI()
   269  	// generateInterfaces()
   270  	// functionSignatures()
   271  	generateGolgiAPI()
   272  }