github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/roachpb/gen_batch.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // This file generates batch_generated.go. It can be run via:
    12  //    go run -tags gen-batch gen_batch.go
    13  
    14  // +build gen-batch
    15  
    16  package main
    17  
    18  import (
    19  	"fmt"
    20  	"io"
    21  	"os"
    22  	"reflect"
    23  	"strings"
    24  
    25  	"github.com/cockroachdb/cockroach/pkg/roachpb"
    26  )
    27  
    28  type variantInfo struct {
    29  	// variantType is the name of the variant type that implements
    30  	// the union interface (isRequestUnion_Value,isResponseUnion_Value).
    31  	variantType string
    32  	// variantName is the unique suffix of variantType. It is also
    33  	// the name of the single field in this type.
    34  	variantName string
    35  	// msgType is the name of the variant's corresponding Request/Response
    36  	// type.
    37  	msgType string
    38  }
    39  
    40  var errVariants []variantInfo
    41  var reqVariants []variantInfo
    42  var resVariants []variantInfo
    43  var reqResVariantMapping map[variantInfo]variantInfo
    44  
    45  func initVariant(varInstance interface{}) variantInfo {
    46  	t := reflect.TypeOf(varInstance)
    47  	f := t.Elem().Field(0) // variants always have 1 field
    48  	return variantInfo{
    49  		variantType: t.Elem().Name(),
    50  		variantName: f.Name,
    51  		msgType:     f.Type.Elem().Name(),
    52  	}
    53  }
    54  
    55  func initVariants() {
    56  	_, _, _, errVars := (&roachpb.ErrorDetail{}).XXX_OneofFuncs()
    57  	for _, v := range errVars {
    58  		errInfo := initVariant(v)
    59  		errVariants = append(errVariants, errInfo)
    60  	}
    61  
    62  	_, _, _, resVars := (&roachpb.ResponseUnion{}).XXX_OneofFuncs()
    63  	resVarInfos := make(map[string]variantInfo, len(resVars))
    64  	for _, v := range resVars {
    65  		resInfo := initVariant(v)
    66  		resVariants = append(resVariants, resInfo)
    67  		resVarInfos[resInfo.variantName] = resInfo
    68  	}
    69  
    70  	_, _, _, reqVars := (&roachpb.RequestUnion{}).XXX_OneofFuncs()
    71  	reqResVariantMapping = make(map[variantInfo]variantInfo, len(reqVars))
    72  	for _, v := range reqVars {
    73  		reqInfo := initVariant(v)
    74  		reqVariants = append(reqVariants, reqInfo)
    75  
    76  		// The ResponseUnion variants match those in RequestUnion, with the
    77  		// following exceptions:
    78  		resName := reqInfo.variantName
    79  		switch resName {
    80  		case "TransferLease":
    81  			resName = "RequestLease"
    82  		}
    83  		resInfo, ok := resVarInfos[resName]
    84  		if !ok {
    85  			panic(fmt.Sprintf("unknown response variant %q", resName))
    86  		}
    87  		reqResVariantMapping[reqInfo] = resInfo
    88  	}
    89  }
    90  
    91  func genGetInner(w io.Writer, unionName, variantName string, variants []variantInfo) {
    92  	fmt.Fprintf(w, `
    93  // GetInner returns the %[2]s contained in the union.
    94  func (ru %[1]s) GetInner() %[2]s {
    95  	switch t := ru.GetValue().(type) {
    96  `, unionName, variantName)
    97  
    98  	for _, v := range variants {
    99  		fmt.Fprintf(w, `	case *%s:
   100  		return t.%s
   101  `, v.variantType, v.variantName)
   102  	}
   103  
   104  	fmt.Fprint(w, `	default:
   105  		return nil
   106  	}
   107  }
   108  `)
   109  }
   110  
   111  func genSetInner(w io.Writer, unionName, variantName string, variants []variantInfo) {
   112  	fmt.Fprintf(w, `
   113  // SetInner sets the %[2]s in the union.
   114  func (ru *%[1]s) SetInner(r %[2]s) bool {
   115  	var union is%[1]s_Value
   116  	switch t := r.(type) {
   117  `, unionName, variantName)
   118  
   119  	for _, v := range variants {
   120  		fmt.Fprintf(w, `	case *%s:
   121  		union = &%s{t}
   122  `, v.msgType, v.variantType)
   123  	}
   124  
   125  	fmt.Fprint(w, `	default:
   126  		return false
   127  	}
   128  	ru.Value = union
   129  	return true
   130  }
   131  `)
   132  }
   133  
   134  func main() {
   135  	initVariants()
   136  
   137  	f, err := os.Create("batch_generated.go")
   138  	if err != nil {
   139  		fmt.Fprintln(os.Stderr, "Error opening file: ", err)
   140  		os.Exit(1)
   141  	}
   142  
   143  	// First comment for github/Go; second for reviewable.
   144  	// https://github.com/golang/go/issues/13560#issuecomment-277804473
   145  	// https://github.com/Reviewable/Reviewable/wiki/FAQ#how-do-i-tell-reviewable-that-a-file-is-generated-and-should-not-be-reviewed
   146  	fmt.Fprint(f, `// Code generated by gen_batch.go; DO NOT EDIT.
   147  // GENERATED FILE DO NOT EDIT
   148  
   149  package roachpb
   150  
   151  import (
   152  	"fmt"
   153  	"strconv"
   154  	"strings"
   155  )
   156  `)
   157  
   158  	// Generate GetInner methods.
   159  	genGetInner(f, "ErrorDetail", "error", errVariants)
   160  	genGetInner(f, "RequestUnion", "Request", reqVariants)
   161  	genGetInner(f, "ResponseUnion", "Response", resVariants)
   162  
   163  	// Generate SetInner methods.
   164  	genSetInner(f, "ErrorDetail", "error", errVariants)
   165  	genSetInner(f, "RequestUnion", "Request", reqVariants)
   166  	genSetInner(f, "ResponseUnion", "Response", resVariants)
   167  
   168  	fmt.Fprintf(f, `
   169  type reqCounts [%d]int32
   170  `, len(reqVariants))
   171  
   172  	// Generate getReqCounts function.
   173  	fmt.Fprint(f, `
   174  // getReqCounts returns the number of times each
   175  // request type appears in the batch.
   176  func (ba *BatchRequest) getReqCounts() reqCounts {
   177  	var counts reqCounts
   178  	for _, ru := range ba.Requests {
   179  		switch ru.GetValue().(type) {
   180  `)
   181  
   182  	for i, v := range reqVariants {
   183  		fmt.Fprintf(f, `		case *%s:
   184  			counts[%d]++
   185  `, v.variantType, i)
   186  	}
   187  
   188  	fmt.Fprint(f, `		default:
   189  			panic(fmt.Sprintf("unsupported request: %+v", ru))
   190  		}
   191  	}
   192  	return counts
   193  }
   194  `)
   195  
   196  	// A few shorthands to help make the names more terse.
   197  	shorthands := map[string]string{
   198  		"Delete":      "Del",
   199  		"Range":       "Rng",
   200  		"Transaction": "Txn",
   201  		"Reverse":     "Rev",
   202  		"Admin":       "Adm",
   203  		"Increment":   "Inc",
   204  		"Conditional": "C",
   205  		"Check":       "Chk",
   206  		"Truncate":    "Trunc",
   207  	}
   208  
   209  	// Generate Summary function.
   210  	fmt.Fprintf(f, `
   211  var requestNames = []string{`)
   212  	for _, v := range reqVariants {
   213  		name := v.variantName
   214  		for str, short := range shorthands {
   215  			name = strings.Replace(name, str, short, -1)
   216  		}
   217  		fmt.Fprintf(f, `
   218  	"%s",`, name)
   219  	}
   220  	fmt.Fprint(f, `
   221  }
   222  `)
   223  
   224  	// We don't use Fprint to avoid go vet warnings about
   225  	// formatting directives in string.
   226  	fmt.Fprint(f, `
   227  // Summary prints a short summary of the requests in a batch.
   228  func (ba *BatchRequest) Summary() string {
   229  	var b strings.Builder
   230  	ba.WriteSummary(&b)
   231  	return b.String()
   232  }
   233  
   234  // WriteSummary writes a short summary of the requests in a batch
   235  // to the provided builder.
   236  func (ba *BatchRequest) WriteSummary(b *strings.Builder) {
   237  	if len(ba.Requests) == 0 {
   238  		b.WriteString("empty batch")
   239  		return
   240  	}
   241  	counts := ba.getReqCounts()
   242  	var tmp [10]byte
   243  	var comma bool
   244  	for i, v := range counts {
   245  		if v != 0 {
   246  			if comma {
   247  				b.WriteString(", ")
   248  			}
   249  			comma = true
   250  
   251  			b.Write(strconv.AppendInt(tmp[:0], int64(v), 10))
   252  			b.WriteString(" ")
   253  			b.WriteString(requestNames[i])
   254  		}
   255  	}
   256  }
   257  `)
   258  
   259  	// Generate CreateReply function.
   260  	fmt.Fprint(f, `
   261  // The following types are used to group the allocations of Responses
   262  // and their corresponding isResponseUnion_Value union wrappers together.
   263  `)
   264  	allocTypes := make(map[string]string)
   265  	for _, resV := range resVariants {
   266  		allocName := strings.ToLower(resV.msgType[:1]) + resV.msgType[1:] + "Alloc"
   267  		fmt.Fprintf(f, `type %s struct {
   268  	union %s
   269  	resp  %s
   270  }
   271  `, allocName, resV.variantType, resV.msgType)
   272  		allocTypes[resV.variantName] = allocName
   273  	}
   274  
   275  	fmt.Fprint(f, `
   276  // CreateReply creates replies for each of the contained requests, wrapped in a
   277  // BatchResponse. The response objects are batch allocated to minimize
   278  // allocation overhead.
   279  func (ba *BatchRequest) CreateReply() *BatchResponse {
   280  	br := &BatchResponse{}
   281  	br.Responses = make([]ResponseUnion, len(ba.Requests))
   282  
   283  	counts := ba.getReqCounts()
   284  
   285  `)
   286  
   287  	for i, v := range reqVariants {
   288  		resV, ok := reqResVariantMapping[v]
   289  		if !ok {
   290  			panic(fmt.Sprintf("unknown response variant for %v", v))
   291  		}
   292  		fmt.Fprintf(f, "	var buf%d []%s\n", i, allocTypes[resV.variantName])
   293  	}
   294  
   295  	fmt.Fprint(f, `
   296  	for i, r := range ba.Requests {
   297  		switch r.GetValue().(type) {
   298  `)
   299  
   300  	for i, v := range reqVariants {
   301  		resV, ok := reqResVariantMapping[v]
   302  		if !ok {
   303  			panic(fmt.Sprintf("unknown response variant for %v", v))
   304  		}
   305  
   306  		fmt.Fprintf(f, `		case *%[2]s:
   307  			if buf%[1]d == nil {
   308  				buf%[1]d = make([]%[3]s, counts[%[1]d])
   309  			}
   310  			buf%[1]d[0].union.%[4]s = &buf%[1]d[0].resp
   311  			br.Responses[i].Value = &buf%[1]d[0].union
   312  			buf%[1]d = buf%[1]d[1:]
   313  `, i, v.variantType, allocTypes[resV.variantName], resV.variantName)
   314  	}
   315  
   316  	fmt.Fprintf(f, "%s", `		default:
   317  			panic(fmt.Sprintf("unsupported request: %+v", r))
   318  		}
   319  	}
   320  	return br
   321  }
   322  `)
   323  
   324  	if err := f.Close(); err != nil {
   325  		fmt.Fprintln(os.Stderr, "Error closing file: ", err)
   326  		os.Exit(1)
   327  	}
   328  }