vitess.io/vitess@v0.16.2/go/mysql/collations/tools/makecolldata/codegen/codegen.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package codegen
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"log"
    23  	"os"
    24  	"os/exec"
    25  	"path"
    26  	"reflect"
    27  	"sort"
    28  )
    29  
    30  type Generator struct {
    31  	bytes.Buffer
    32  	local    Package
    33  	imported map[Package]bool
    34  }
    35  
    36  func NewGenerator(pkg Package) *Generator {
    37  	return &Generator{
    38  		local:    pkg,
    39  		imported: make(map[Package]bool),
    40  	}
    41  }
    42  
    43  func Merge(gens ...*Generator) *Generator {
    44  	result := NewGenerator(gens[0].local)
    45  
    46  	for i, gen := range gens {
    47  		if gen.local != result.local {
    48  			result.Fail("cannot merge generators with different package names")
    49  		}
    50  
    51  		for pkg, imported := range gen.imported {
    52  			if !result.imported[pkg] {
    53  				result.imported[pkg] = imported
    54  			}
    55  		}
    56  
    57  		if i > 0 {
    58  			result.WriteString("\n\n")
    59  		}
    60  		gen.WriteTo(result)
    61  	}
    62  
    63  	return result
    64  }
    65  
    66  func (g *Generator) WriteToFile(out string) {
    67  	var file, fmtfile bytes.Buffer
    68  	file.Grow(g.Buffer.Len() + 1024)
    69  
    70  	fmt.Fprintf(&file, "// Code generated by %s DO NOT EDIT\n\n", path.Base(os.Args[0]))
    71  	fmt.Fprintf(&file, "package %s\n\n", g.local.Name())
    72  	fmt.Fprintf(&file, "import (\n")
    73  
    74  	var sortedPackages []Package
    75  	for pkg := range g.imported {
    76  		sortedPackages = append(sortedPackages, pkg)
    77  	}
    78  	sort.Slice(sortedPackages, func(i, j int) bool {
    79  		return sortedPackages[i] < sortedPackages[j]
    80  	})
    81  	for _, pkg := range sortedPackages {
    82  		var name = "_"
    83  		if g.imported[pkg] {
    84  			name = pkg.Name()
    85  		}
    86  		fmt.Fprintf(&file, "%s %q\n", name, pkg)
    87  	}
    88  
    89  	fmt.Fprintf(&file, ")\n\n")
    90  	g.Buffer.WriteTo(&file)
    91  
    92  	var stderr bytes.Buffer
    93  	gofmt := exec.Command("gofmt", "-s")
    94  	gofmt.Stdin = &file
    95  	gofmt.Stdout = &fmtfile
    96  	gofmt.Stderr = &stderr
    97  	if err := gofmt.Run(); err != nil {
    98  		g.Fail(fmt.Sprintf("failed to format generated code: %v\n%s", err, stderr.Bytes()))
    99  	}
   100  
   101  	if err := os.WriteFile(out, fmtfile.Bytes(), 0644); err != nil {
   102  		g.Fail(fmt.Sprintf("failed to generate %q: %v", out, err))
   103  	}
   104  
   105  	log.Printf("written %q (%.02fkb)", out, float64(fmtfile.Len())/1024.0)
   106  }
   107  
   108  func (g *Generator) Fail(err string) {
   109  	log.Printf("codegen: error: %v", err)
   110  	os.Exit(1)
   111  }
   112  
   113  func (g *Generator) printArray(iface any) {
   114  	switch ary := iface.(type) {
   115  	case Array8:
   116  		g.WriteString("[...]uint8{")
   117  		for i, v := range ary {
   118  			if i%16 == 0 && len(ary)%16 == 0 {
   119  				g.WriteByte('\n')
   120  			}
   121  			fmt.Fprintf(g, "0x%02x, ", v)
   122  		}
   123  	case Array16:
   124  		g.WriteString("[...]uint16{")
   125  		for i, v := range ary {
   126  			if i%8 == 0 && len(ary)%8 == 0 {
   127  				g.WriteByte('\n')
   128  			}
   129  			fmt.Fprintf(g, "0x%04x, ", v)
   130  		}
   131  	case Array32:
   132  		g.WriteString("[...]uint32{")
   133  		for i, v := range ary {
   134  			if i%8 == 0 && len(ary)%8 == 0 {
   135  				g.WriteByte('\n')
   136  			}
   137  			fmt.Fprintf(g, "0x%08x, ", v)
   138  		}
   139  	}
   140  	g.WriteString("\n}")
   141  }
   142  
   143  func (g *Generator) gatherPackages(tt reflect.Type) {
   144  	if pkg := tt.PkgPath(); pkg != "" {
   145  		g.imported[Package(pkg)] = true
   146  	}
   147  	switch tt.Kind() {
   148  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
   149  		g.gatherPackages(tt.Elem())
   150  	case reflect.Struct:
   151  		for i := 0; i < tt.NumField(); i++ {
   152  			g.gatherPackages(tt.Field(i).Type)
   153  		}
   154  	}
   155  }
   156  
   157  func (g *Generator) UsePackage(pkg Package) {
   158  	g.imported[pkg] = false
   159  }
   160  
   161  func (g *Generator) printAtom(v any) {
   162  	switch v := v.(type) {
   163  	case string:
   164  		g.WriteString(v)
   165  	case *string:
   166  		g.WriteString(*v)
   167  	case bool:
   168  		fmt.Fprint(g, v)
   169  	case *bool:
   170  		fmt.Fprint(g, *v)
   171  	case int:
   172  		fmt.Fprint(g, v)
   173  	case *int32:
   174  		fmt.Fprint(g, *v)
   175  	case *int64:
   176  		fmt.Fprint(g, *v)
   177  	case float64:
   178  		fmt.Fprint(g, v)
   179  	case *float64:
   180  		fmt.Fprint(g, *v)
   181  	case Package:
   182  		g.imported[v] = true
   183  		g.WriteString(v.Name())
   184  	case Quote:
   185  		fmt.Fprintf(g, "%q", v)
   186  	case Array8, Array16, Array32:
   187  		g.printArray(v)
   188  	default:
   189  		g.gatherPackages(reflect.TypeOf(v))
   190  		fmt.Fprintf(g, "%#v", v)
   191  	}
   192  }
   193  
   194  type Array8 []byte
   195  type Array16 []uint16
   196  type Array32 []uint32
   197  type Quote string
   198  type Package string
   199  
   200  func (pkg Package) Name() string {
   201  	return path.Base(string(pkg))
   202  }
   203  
   204  func (g *Generator) P(str ...any) {
   205  	for _, v := range str {
   206  		g.printAtom(v)
   207  	}
   208  	g.WriteByte('\n')
   209  }