gorgonia.org/gorgonia@v0.9.17/cmd/genapi/main.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/parser" 7 "go/token" 8 "io" 9 "log" 10 "os" 11 "os/user" 12 "path" 13 "strings" 14 "text/template" 15 ) 16 17 const genmsg = "// Code generated by genapi, which is a API generation tool for Gorgonia. DO NOT EDIT." 18 19 const ( 20 apigenOut = "api_gen.go" 21 unOpOut = "operatorPointwise_unary_gen.go" 22 23 // broadcastOpOut = "operations_broadcast.go" 24 unaryOps = "operatorPointwise_unary_const.go" 25 binaryOps = "operatorPointwise_binary_const.go" 26 ) 27 28 var ( 29 gopath, gorgonialoc, golgiloc string 30 ) 31 32 var funcmap = template.FuncMap{ 33 "lower": strings.ToLower, 34 } 35 36 var ( 37 unaryTemplate *template.Template 38 binaryTemplate *template.Template 39 broadcastTemplate *template.Template 40 maybeBroadcastTemplate *template.Template 41 ) 42 43 const unaryTemplateRaw = ` // {{.FnName}} performs a pointwise {{lower .FnName}}. 44 func {{.FnName}}(a *Node) (*Node, error) { return unaryOpNode(newElemUnaryOp({{.OpType}}, a), a) } 45 ` 46 47 const binaryTemplateRaw = `// {{.FnName}} performs a pointwise {{lower .FnName}} operation. 48 {{if .AsSame -}}// retSame indicates if the data type of the return value should be the same as the input data type. It defaults to Bool otherwise. 49 {{end -}} 50 func {{.FnName}}(a, b *Node{{if .AsSame}}, retSame bool{{end}}) (*Node, error) { {{if not .AsSame -}}return binOpNode(newElemBinOp({{.OpType}}, a, b), a, b) {{else -}} 51 op := newElemBinOp({{.OpType}}, a, b) 52 op.retSame = retSame 53 return binOpNode(op, a, b) 54 {{end -}} 55 } 56 ` 57 58 const broadcastTemplateRaw = `// Broadcast{{.FnName}} performs a {{lower .FnName}}. The operation is precomposed with a broadcast such that the shapes matches before operations commence. 59 func Broadcast{{.FnName}}(a, b *Node{{if .AsSame}}, retSame bool{{end}}, leftPattern, rightPattern []byte)(*Node, error) { 60 a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern)) 61 if err != nil { 62 return nil, err 63 } 64 return {{.FnName}}(a2, b2{{if .AsSame}}, retSame{{end}}) 65 } 66 ` 67 68 // maybeBroadcast is the set of Broadcast functions in Golgi 69 const maybeBroadcastTemplateRaw = `// Broadcast{{.FnName}} performs a {{lower .FnName}}. The operation is precomposed with a broadcast such that the shapes matches before operations commence. 70 func Broadcast{{.FnName}}(a, b *G.Node{{if .AsSame}}, retSame bool{{end}}, leftPattern, rightPattern []byte)(*G.Node, error) { 71 if a.Shape().Eq(b.Shape()){ 72 return G.{{.FnName}}(a, b{{if .AsSame}}, retSame{{end}}) 73 } 74 a2, b2, err := G.Broadcast(a, b, G.NewBroadcastPattern(leftPattern, rightPattern)) 75 if err != nil { 76 return nil, err 77 } 78 return G.{{.FnName}}(a2, b2{{if .AsSame}}, retSame{{end}}) 79 } 80 ` 81 82 func init() { 83 gopath = os.Getenv("GOPATH") 84 // now that go can have a default gopath, this checks that path 85 if gopath == "" { 86 usr, err := user.Current() 87 if err != nil { 88 log.Fatal(err) 89 } 90 gopath = path.Join(usr.HomeDir, "go") 91 stat, err := os.Stat(gopath) 92 if err != nil { 93 log.Fatal(err) 94 } 95 if !stat.IsDir() { 96 log.Fatal("You need to define a $GOPATH") 97 } 98 } 99 gorgonialoc = path.Join(gopath, "src/gorgonia.org/gorgonia") 100 golgiloc = path.Join(gopath, "src/gorgonia.org/golgi") 101 unaryTemplate = template.Must(template.New("Unary").Funcs(funcmap).Parse(unaryTemplateRaw)) 102 binaryTemplate = template.Must(template.New("Binary").Funcs(funcmap).Parse(binaryTemplateRaw)) 103 broadcastTemplate = template.Must(template.New("Broadcast").Funcs(funcmap).Parse(broadcastTemplateRaw)) 104 maybeBroadcastTemplate = template.Must(template.New("MaybeBroadcast").Funcs(funcmap).Parse(maybeBroadcastTemplateRaw)) 105 } 106 107 func generateUnary(outFile io.Writer) { 108 // parse operator_unary_const.go 109 filename := path.Join(gorgonialoc, unaryOps) 110 fset := token.NewFileSet() 111 file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors) 112 if err != nil { 113 log.Fatal(err) 114 } 115 116 unaryNames := constTypes(file.Decls, "ʘUnaryOperatorType", "maxʘUnaryOperator") 117 for _, v := range unaryNames { 118 apiName := strings.Title(strings.TrimSuffix(v, "OpType")) 119 // legacy issue 120 if apiName == "Ln" { 121 apiName = "Log" 122 } 123 data := struct{ FnName, OpType string }{apiName, v} 124 unaryTemplate.Execute(outFile, data) 125 } 126 127 } 128 129 func generateBinary(outFile io.Writer) { 130 // parse operator_binary_const.go 131 filename := path.Join(gorgonialoc, binaryOps) 132 fset := token.NewFileSet() 133 file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors) 134 if err != nil { 135 log.Fatal(err) 136 } 137 138 binaryNames := constTypes(file.Decls, "ʘBinaryOperatorType", "maxʘBinaryOpType") 139 log.Printf("%v", binaryNames) 140 for _, v := range binaryNames { 141 apiName := strings.Title(strings.TrimSuffix(v, "OpType")) 142 // legacy issue 143 switch apiName { 144 case "Mul": 145 apiName = "HadamardProd" 146 case "Div": 147 apiName = "HadamardDiv" 148 } 149 data := struct { 150 FnName, OpType string 151 AsSame bool 152 }{apiName, v, false} 153 switch apiName { 154 case "Lt", "Gt", "Lte", "Gte", "Eq", "Ne": 155 data.AsSame = true 156 } 157 binaryTemplate.Execute(outFile, data) 158 } 159 } 160 161 func generateBroadcastBinOps(tmpl *template.Template, outFile io.Writer) { 162 // parse operator_binary_const.go 163 filename := path.Join(gorgonialoc, binaryOps) 164 fset := token.NewFileSet() 165 file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors) 166 if err != nil { 167 log.Fatal(err) 168 } 169 170 binaryNames := constTypes(file.Decls, "ʘBinaryOperatorType", "maxʘBinaryOpType") 171 log.Printf("%v", binaryNames) 172 for _, v := range binaryNames { 173 apiName := strings.Title(strings.TrimSuffix(v, "OpType")) 174 // legacy issue 175 switch apiName { 176 case "Mul": 177 apiName = "HadamardProd" 178 case "Div": 179 apiName = "HadamardDiv" 180 } 181 data := struct { 182 FnName, OpType string 183 AsSame bool 184 }{apiName, v, false} 185 switch apiName { 186 case "Lt", "Gt", "Lte", "Gte", "Eq", "Ne": 187 data.AsSame = true 188 } 189 tmpl.Execute(outFile, data) 190 } 191 } 192 193 func constTypes(decls []ast.Decl, accept, max string) (names []string) { 194 for i, decl := range decls { 195 log.Printf("DECL %d: %T", i, decl) 196 switch d := decl.(type) { 197 case *ast.GenDecl: 198 if d.Tok.IsKeyword() && d.Tok.String() == "const" { 199 log.Printf("\t%v", d.Tok.String()) 200 201 // get the type 202 if len(d.Specs) == 0 { 203 continue 204 } 205 206 var typename string 207 typ := d.Specs[0].(*ast.ValueSpec).Type 208 if typ == nil { 209 continue 210 } 211 if id, ok := typ.(*ast.Ident); ok { 212 typename = id.Name 213 } 214 if typename != accept { 215 continue 216 } 217 218 for _, spec := range d.Specs { 219 name := spec.(*ast.ValueSpec).Names[0].Name 220 if name == max { 221 continue 222 } 223 names = append(names, name) 224 } 225 } 226 default: 227 } 228 } 229 return 230 } 231 232 func generateAPI() { 233 outFileName := path.Join(gorgonialoc, apigenOut) 234 outFile, err := os.OpenFile(outFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) 235 if err != nil { 236 log.Fatal(err) 237 } 238 defer outFile.Close() 239 fmt.Fprintf(outFile, "package gorgonia\n\n%v\n\n", genmsg) 240 generateUnary(outFile) 241 generateBinary(outFile) 242 generateBroadcastBinOps(broadcastTemplate, outFile) 243 } 244 245 func generateInterfaces() { 246 outFileName := path.Join(gorgonialoc, unOpOut) 247 outFile, err := os.OpenFile(outFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) 248 if err != nil { 249 log.Fatal(err) 250 } 251 defer outFile.Close() 252 fmt.Fprintf(outFile, "package gorgonia\n\n%v\n\n", genmsg) 253 generateUnaryInterface(outFile) 254 } 255 256 func generateGolgiAPI() { 257 outFileName := path.Join(golgiloc, apigenOut) 258 outFile, err := os.OpenFile(outFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) 259 if err != nil { 260 log.Fatal(err) 261 } 262 defer outFile.Close() 263 fmt.Fprintf(outFile, "package golgi\n\n%v\n\n", genmsg) 264 generateBroadcastBinOps(maybeBroadcastTemplate, outFile) 265 } 266 267 func main() { 268 // generateAPI() 269 // generateInterfaces() 270 // functionSignatures() 271 generateGolgiAPI() 272 }