github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/cgo/lapacke/generate_lapacke.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build ignore
     6  
     7  // generate_lapacke creates a lapacke.go file from the provided C header file
     8  // with optionally added documentation from the documentation package.
     9  package main
    10  
    11  import (
    12  	"bytes"
    13  	"fmt"
    14  	"go/format"
    15  	"io/ioutil"
    16  	"log"
    17  	"os"
    18  	"strings"
    19  	"text/template"
    20  
    21  	"github.com/cznic/cc"
    22  
    23  	"github.com/gonum/internal/binding"
    24  )
    25  
    26  const (
    27  	header = "lapacke.h"
    28  	target = "lapacke.go"
    29  
    30  	prefix = "LAPACKE_"
    31  	suffix = "_work"
    32  )
    33  
    34  const (
    35  	elideRepeat = true
    36  	noteOrigin  = false
    37  )
    38  
    39  var skip = map[string]bool{
    40  	// Deprecated.
    41  	"LAPACKE_cggsvp_work": true,
    42  	"LAPACKE_dggsvp_work": true,
    43  	"LAPACKE_sggsvp_work": true,
    44  	"LAPACKE_zggsvp_work": true,
    45  	"LAPACKE_cggsvd_work": true,
    46  	"LAPACKE_dggsvd_work": true,
    47  	"LAPACKE_sggsvd_work": true,
    48  	"LAPACKE_zggsvd_work": true,
    49  	"LAPACKE_cgeqpf_work": true,
    50  	"LAPACKE_dgeqpf_work": true,
    51  	"LAPACKE_sgeqpf_work": true,
    52  	"LAPACKE_zgeqpf_work": true,
    53  }
    54  
    55  // needsInt is a list of routines that need to return the integer info value and
    56  // and cannot convert to a success boolean.
    57  var needsInt = map[string]bool{
    58  	"hseqr": true,
    59  	"geev":  true,
    60  	"geevx": true,
    61  }
    62  
    63  // allUplo is a list of routines that allow 'A' for their uplo argument.
    64  // The list keys are truncated by one character to cover all four numeric types.
    65  var allUplo = map[string]bool{
    66  	"lacpy": true,
    67  	"laset": true,
    68  }
    69  
    70  var cToGoType = map[string]string{
    71  	"char":           "byte",
    72  	"int_must":       "int",
    73  	"int_must32":     "int32",
    74  	"int":            "bool",
    75  	"float":          "float32",
    76  	"double":         "float64",
    77  	"float complex":  "complex64",
    78  	"double complex": "complex128",
    79  }
    80  
    81  var cToGoTypeConv = map[string]string{
    82  	"int_must":       "int",
    83  	"int":            "isZero",
    84  	"float":          "float32",
    85  	"double":         "float64",
    86  	"float complex":  "complex64",
    87  	"double complex": "complex128",
    88  }
    89  
    90  var cgoEnums = map[string]*template.Template{}
    91  
    92  var byteTypes = map[string]string{
    93  	"compq": "lapack.Comp",
    94  	"compz": "lapack.Comp",
    95  
    96  	"d": "blas.Diag",
    97  
    98  	"job":    "lapack.Job",
    99  	"joba":   "lapack.Job",
   100  	"jobr":   "lapack.Job",
   101  	"jobp":   "lapack.Job",
   102  	"jobq":   "lapack.Job",
   103  	"jobt":   "lapack.Job",
   104  	"jobu":   "lapack.Job",
   105  	"jobu1":  "lapack.Job",
   106  	"jobu2":  "lapack.Job",
   107  	"jobv":   "lapack.Job",
   108  	"jobv1t": "lapack.Job",
   109  	"jobv2t": "lapack.Job",
   110  	"jobvl":  "lapack.Job",
   111  	"jobvr":  "lapack.Job",
   112  	"jobvt":  "lapack.Job",
   113  	"jobz":   "lapack.Job",
   114  
   115  	"side": "blas.Side",
   116  
   117  	"trans":  "blas.Transpose",
   118  	"trana":  "blas.Transpose",
   119  	"tranb":  "blas.Transpose",
   120  	"transr": "blas.Transpose",
   121  
   122  	"ul": "blas.Uplo",
   123  
   124  	"balanc": "byte",
   125  	"cmach":  "byte",
   126  	"direct": "byte",
   127  	"dist":   "byte",
   128  	"equed":  "byte",
   129  	"eigsrc": "byte",
   130  	"fact":   "byte",
   131  	"howmny": "byte",
   132  	"id":     "byte",
   133  	"initv":  "byte",
   134  	"norm":   "byte",
   135  	"order":  "byte",
   136  	"pack":   "byte",
   137  	"sense":  "byte",
   138  	"signs":  "byte",
   139  	"storev": "byte",
   140  	"sym":    "byte",
   141  	"typ":    "byte",
   142  	"rng":    "byte",
   143  	"vect":   "byte",
   144  	"way":    "byte",
   145  }
   146  
   147  func typeForByte(n string) string {
   148  	t, ok := byteTypes[n]
   149  	if !ok {
   150  		return fmt.Sprintf("<unknown %q>", n)
   151  	}
   152  	return t
   153  }
   154  
   155  var intTypes = map[string]string{
   156  	"forwrd": "int32",
   157  
   158  	"ijob": "lapack.Job",
   159  
   160  	"wantq": "int32",
   161  	"wantz": "int32",
   162  }
   163  
   164  func typeForInt(n string) string {
   165  	t, ok := intTypes[n]
   166  	if !ok {
   167  		return "int"
   168  	}
   169  	return t
   170  }
   171  
   172  // TODO(kortschak): convForInt* are for #define types,
   173  // so they could go away. Kept here now for diff reduction.
   174  
   175  func convForInt(n string) string {
   176  	switch n {
   177  	case "rowMajor":
   178  		return "C.int"
   179  	case "forwrd", "wantq", "wantz":
   180  		return "C.lapack_logical"
   181  	default:
   182  		return "C.lapack_int"
   183  	}
   184  }
   185  
   186  func convForIntSlice(n string) string {
   187  	switch n {
   188  	case "bwork", "tryrac":
   189  		return "*C.lapack_logical"
   190  	default:
   191  		return "*C.lapack_int"
   192  	}
   193  }
   194  
   195  var goTypes = map[binding.TypeKey]*template.Template{
   196  	{Kind: cc.Char}:                           template.Must(template.New("byte").Funcs(map[string]interface{}{"typefor": typeForByte}).Parse("{{typefor .}}")),
   197  	{Kind: cc.Int}:                            template.Must(template.New("int").Funcs(map[string]interface{}{"typefor": typeForInt}).Parse("{{typefor .}}")),
   198  	{Kind: cc.Char, IsPointer: true}:          template.Must(template.New("[]byte").Parse("[]byte")),
   199  	{Kind: cc.Int, IsPointer: true}:           template.Must(template.New("[]int32").Parse("[]int32")),
   200  	{Kind: cc.FloatComplex, IsPointer: true}:  template.Must(template.New("[]complex64").Parse("[]complex64")),
   201  	{Kind: cc.DoubleComplex, IsPointer: true}: template.Must(template.New("[]complex128").Parse("[]complex128")),
   202  }
   203  
   204  var cgoTypes = map[binding.TypeKey]*template.Template{
   205  	{Kind: cc.Char}:                           template.Must(template.New("char").Parse("(C.char)({{.}})")),
   206  	{Kind: cc.Int}:                            template.Must(template.New("int").Funcs(map[string]interface{}{"conv": convForInt}).Parse(`({{conv .}})({{.}})`)),
   207  	{Kind: cc.Float}:                          template.Must(template.New("float").Parse("(C.float)({{.}})")),
   208  	{Kind: cc.Double}:                         template.Must(template.New("double").Parse("(C.double)({{.}})")),
   209  	{Kind: cc.FloatComplex}:                   template.Must(template.New("lapack_complex_float").Parse("(C.lapack_complex_float)({{.}})")),
   210  	{Kind: cc.DoubleComplex}:                  template.Must(template.New("lapack_complex_double").Parse("(C.lapack_complex_double)({{.}})")),
   211  	{Kind: cc.Char, IsPointer: true}:          template.Must(template.New("char*").Parse("(*C.char)(unsafe.Pointer(_{{.}}))")),
   212  	{Kind: cc.Int, IsPointer: true}:           template.Must(template.New("int*").Funcs(map[string]interface{}{"conv": convForIntSlice}).Parse("({{conv .}})(_{{.}})")),
   213  	{Kind: cc.Float, IsPointer: true}:         template.Must(template.New("float").Parse("(*C.float)(_{{.}})")),
   214  	{Kind: cc.Double, IsPointer: true}:        template.Must(template.New("double").Parse("(*C.double)(_{{.}})")),
   215  	{Kind: cc.FloatComplex, IsPointer: true}:  template.Must(template.New("lapack_complex_float*").Parse("(*C.lapack_complex_float)(_{{.}})")),
   216  	{Kind: cc.DoubleComplex, IsPointer: true}: template.Must(template.New("lapack_complex_double*").Parse("(*C.lapack_complex_double)(_{{.}})")),
   217  }
   218  
   219  var names = map[string]string{
   220  	"matrix_layout": "rowMajor",
   221  	"uplo":          "ul",
   222  	"range":         "rng",
   223  	"diag":          "d",
   224  	"select":        "sel",
   225  	"type":          "typ",
   226  }
   227  
   228  func shorten(n string) string {
   229  	s, ok := names[n]
   230  	if ok {
   231  		return s
   232  	}
   233  	return n
   234  }
   235  
   236  func join(a []string) string {
   237  	return strings.Join(a, " ")
   238  }
   239  
   240  func main() {
   241  	decls, err := binding.Declarations(header)
   242  	if err != nil {
   243  		log.Fatal(err)
   244  	}
   245  
   246  	var buf bytes.Buffer
   247  
   248  	h, err := template.New("handwritten").
   249  		Funcs(map[string]interface{}{"join": join}).
   250  		Parse(handwritten)
   251  	if err != nil {
   252  		log.Fatal(err)
   253  	}
   254  	err = h.Execute(&buf, struct {
   255  		Header string
   256  		Lib    []string
   257  	}{
   258  		Header: header,
   259  		Lib:    os.Args[1:],
   260  	})
   261  	if err != nil {
   262  		log.Fatal(err)
   263  	}
   264  
   265  	for _, d := range decls {
   266  		if !strings.HasPrefix(d.Name, prefix) || !strings.HasSuffix(d.Name, suffix) || skip[d.Name] {
   267  			continue
   268  		}
   269  		lapackeName := strings.TrimSuffix(strings.TrimPrefix(d.Name, prefix), suffix)
   270  		switch {
   271  		case strings.HasSuffix(lapackeName, "fsx"):
   272  			continue
   273  		case strings.HasSuffix(lapackeName, "vxx"):
   274  			continue
   275  		case strings.HasSuffix(lapackeName, "rook"):
   276  			continue
   277  		}
   278  		if hasFuncParameter(d) {
   279  			continue
   280  		}
   281  
   282  		goSignature(&buf, d)
   283  		if noteOrigin {
   284  			fmt.Fprintf(&buf, "\t// %s %s %s ...\n\n", d.Position(), d.Return, d.Name)
   285  		}
   286  		parameterChecks(&buf, d, parameterCheckRules)
   287  		buf.WriteByte('\t')
   288  		cgoCall(&buf, d)
   289  		buf.WriteString("}\n")
   290  	}
   291  
   292  	b, err := format.Source(buf.Bytes())
   293  	if err != nil {
   294  		log.Fatal(err)
   295  	}
   296  	err = ioutil.WriteFile(target, b, 0664)
   297  	if err != nil {
   298  		log.Fatal(err)
   299  	}
   300  }
   301  
   302  // This removes select and selctg parameterised functions.
   303  func hasFuncParameter(d binding.Declaration) bool {
   304  	for _, p := range d.Parameters() {
   305  		if p.Kind() != cc.Ptr {
   306  			continue
   307  		}
   308  		if p.Elem().Kind() == cc.Function {
   309  			return true
   310  		}
   311  	}
   312  	return false
   313  }
   314  
   315  func goSignature(buf *bytes.Buffer, d binding.Declaration) {
   316  	lapackeName := strings.TrimSuffix(strings.TrimPrefix(d.Name, prefix), suffix)
   317  	goName := binding.UpperCaseFirst(lapackeName)
   318  
   319  	parameters := d.Parameters()
   320  
   321  	fmt.Fprintf(buf, "\n// See http://www.netlib.org/cgi-bin/netlibfiles.txt?format=txt&filename=/lapack/lapack_routine/%s.f.\n", lapackeName)
   322  	fmt.Fprintf(buf, "func %s(", goName)
   323  	c := 0
   324  	for i, p := range parameters {
   325  		if p.Name() == "matrix_layout" {
   326  			continue
   327  		}
   328  		if c != 0 {
   329  			buf.WriteString(", ")
   330  		}
   331  		c++
   332  
   333  		n := shorten(binding.LowerCaseFirst(p.Name()))
   334  		var this, next string
   335  
   336  		if p.Kind() == cc.Enum {
   337  			this = binding.GoTypeForEnum(p.Type(), n)
   338  		} else {
   339  			this = binding.GoTypeFor(p.Type(), n, goTypes)
   340  		}
   341  
   342  		if elideRepeat && i < len(parameters)-1 && p.Type().Kind() == parameters[i+1].Type().Kind() {
   343  			p := parameters[i+1]
   344  			n := shorten(binding.LowerCaseFirst(p.Name()))
   345  			if p.Kind() == cc.Enum {
   346  				next = binding.GoTypeForEnum(p.Type(), n)
   347  			} else {
   348  				next = binding.GoTypeFor(p.Type(), n, goTypes)
   349  			}
   350  		}
   351  		if next == this {
   352  			buf.WriteString(n)
   353  		} else {
   354  			fmt.Fprintf(buf, "%s %s", n, this)
   355  		}
   356  	}
   357  	if d.Return.Kind() != cc.Void {
   358  		var must string
   359  		if needsInt[lapackeName[1:]] {
   360  			must = "_must"
   361  		}
   362  		fmt.Fprintf(buf, ") %s {\n", cToGoType[d.Return.String()+must])
   363  	} else {
   364  		buf.WriteString(") {\n")
   365  	}
   366  }
   367  
   368  func parameterChecks(buf *bytes.Buffer, d binding.Declaration, rules []func(*bytes.Buffer, binding.Declaration, binding.Parameter) bool) {
   369  	done := make(map[int]bool)
   370  	for _, p := range d.Parameters() {
   371  		for i, r := range rules {
   372  			if done[i] {
   373  				continue
   374  			}
   375  			done[i] = r(buf, d, p)
   376  		}
   377  	}
   378  }
   379  
   380  func cgoCall(buf *bytes.Buffer, d binding.Declaration) {
   381  	if d.Return.Kind() != cc.Void {
   382  		lapackeName := strings.TrimSuffix(strings.TrimPrefix(d.Name, prefix), suffix)
   383  		var must string
   384  		if needsInt[lapackeName[1:]] {
   385  			must = "_must"
   386  		}
   387  		fmt.Fprintf(buf, "return %s(", cToGoTypeConv[d.Return.String()+must])
   388  	}
   389  	fmt.Fprintf(buf, "C.%s(", d.Name)
   390  	for i, p := range d.Parameters() {
   391  		if i != 0 {
   392  			buf.WriteString(", ")
   393  		}
   394  		if p.Type().Kind() == cc.Enum {
   395  			buf.WriteString(binding.CgoConversionForEnum(shorten(binding.LowerCaseFirst(p.Name())), p.Type()))
   396  		} else {
   397  			buf.WriteString(binding.CgoConversionFor(shorten(binding.LowerCaseFirst(p.Name())), p.Type(), cgoTypes))
   398  		}
   399  	}
   400  	if d.Return.Kind() != cc.Void {
   401  		buf.WriteString(")")
   402  	}
   403  	buf.WriteString(")\n")
   404  }
   405  
// parameterCheckRules are the rules that parameterChecks applies, in this
// order, to each parameter of every wrapped routine. For each parameter the
// rules run in slice order, and that order sets the order of the emitted
// checks. A rule that returns true is not applied to later parameters of the
// same routine.
var parameterCheckRules = []func(*bytes.Buffer, binding.Declaration, binding.Parameter) bool{
	uplo,
	diag,
	side,
	trans,
	address,
}
   413  
   414  func uplo(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
   415  	if p.Name() != "uplo" {
   416  		return false
   417  	}
   418  	lapackeName := strings.TrimSuffix(strings.TrimPrefix(d.Name, prefix), suffix)
   419  	if allUplo[lapackeName[1:]] {
   420  		fmt.Fprint(buf, `	switch ul {
   421  	case blas.Upper:
   422  		ul = 'U'
   423  	case blas.Lower:
   424  		ul = 'L'
   425  	default:
   426  		ul = 'A'
   427  	}
   428  `)
   429  	} else {
   430  		fmt.Fprint(buf, `	switch ul {
   431  	case blas.Upper:
   432  		ul = 'U'
   433  	case blas.Lower:
   434  		ul = 'L'
   435  	default:
   436  		panic("lapack: illegal triangle")
   437  	}
   438  `)
   439  	}
   440  	return true
   441  }
   442  
   443  func diag(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
   444  	if p.Name() != "diag" {
   445  		return false
   446  	}
   447  	fmt.Fprint(buf, `	switch d {
   448  	case blas.Unit:
   449  		d = 'U'
   450  	case blas.NonUnit:
   451  		d = 'N'
   452  	default:
   453  		panic("lapack: illegal diagonal")
   454  	}
   455  `)
   456  	return true
   457  }
   458  
   459  func side(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
   460  	if p.Name() != "side" {
   461  		return false
   462  	}
   463  	fmt.Fprint(buf, `	switch side {
   464  	case blas.Left:
   465  		side = 'L'
   466  	case blas.Right:
   467  		side = 'R'
   468  	default:
   469  		panic("lapack: bad side")
   470  	}
   471  `)
   472  	return true
   473  }
   474  
   475  func trans(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
   476  	n := shorten(binding.LowerCaseFirst(p.Name()))
   477  	if !strings.HasPrefix(n, "tran") {
   478  		return false
   479  	}
   480  	fmt.Fprintf(buf, `	switch %[1]s {
   481  	case blas.NoTrans:
   482  		%[1]s = 'N'
   483  	case blas.Trans:
   484  		%[1]s = 'T'
   485  	case blas.ConjTrans:
   486  		%[1]s = 'C'
   487  	default:
   488  		panic("lapack: bad trans")
   489  	}
   490  `, n)
   491  	return false
   492  }
   493  
   494  var addrTypes = map[string]string{
   495  	"char":           "byte",
   496  	"int":            "int32",
   497  	"float":          "float32",
   498  	"double":         "float64",
   499  	"float complex":  "complex64",
   500  	"double complex": "complex128",
   501  }
   502  
   503  func address(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
   504  	n := shorten(binding.LowerCaseFirst(p.Name()))
   505  	if p.Type().Kind() == cc.Ptr {
   506  		t := strings.TrimPrefix(p.Type().Element().String(), "const ")
   507  		fmt.Fprintf(buf, `	var _%[1]s *%[2]s
   508  	if len(%[1]s) > 0 {
   509  		_%[1]s = &%[1]s[0]
   510  	}
   511  `, n, addrTypes[t])
   512  	}
   513  	return false
   514  }
   515  
// handwritten is the text/template for the fixed preamble of the generated
// file. It contains the generated-code marker, license and package
// documentation, the cgo directives and imports, the order type, and the
// isZero helper that cToGoTypeConv uses to turn an int return value into a
// success bool.
//
// Template data: .Header is the C header to #include. .Lib holds the
// generator's command-line arguments, which "join" combines with spaces into
// a "#cgo LDFLAGS" line; that line is omitted when .Lib is empty.
const handwritten = `// Code generated by "go generate github.com/gonum/lapack/cgo/lapacke" from {{.Header}}; DO NOT EDIT.

// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This repository is no longer maintained.
// Development has moved to https://github.com/gonum/netlib.
//
// Package lapacke provides bindings to the LAPACKE C Interface to LAPACK.
//
// Links are provided to the NETLIB fortran implementation/dependencies for each function.
package lapacke

/*
#cgo CFLAGS: -g -O2{{if .Lib}}
#cgo LDFLAGS: {{join .Lib}}{{end}}
#include "{{.Header}}"
*/
import "C"

import (
	"unsafe"

	"github.com/gonum/blas"
	"github.com/gonum/lapack"
)

// Type order is used to specify the matrix storage format. We still interact with
// an API that allows client calls to specify order, so this is here to document that fact.
type order int

const (
	rowMajor order = 101 + iota
	colMajor
)

func isZero(ret C.int) bool { return ret == 0 }
`