go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/google/descutil/printer/printer.go (about)

     1  // Copyright 2021 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package printer
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"sort"
    21  	"strings"
    22  	"unicode"
    23  
    24  	txtpb "github.com/protocolbuffers/txtpbfmt/parser"
    25  	"google.golang.org/protobuf/encoding/prototext"
    26  	"google.golang.org/protobuf/reflect/protoreflect"
    27  	"google.golang.org/protobuf/types/descriptorpb"
    28  
    29  	"go.chromium.org/luci/common/data/text/indented"
    30  	"go.chromium.org/luci/common/proto/google/descutil"
    31  )
    32  
    33  // Printer prints a proto3 definition from a description.
    34  type Printer struct {
    35  	file           *descriptorpb.FileDescriptorProto
    36  	sourceCodeInfo map[any]*descriptorpb.SourceCodeInfo_Location
    37  
    38  	Out indented.Writer
    39  
    40  	// Err is not nil if writing to Out failed.
    41  	Err error
    42  }
    43  
    44  // NewPrinter creates a new Printer which will output protobuf definition text
    45  // (i.e. ".proto" file) to the given writer.
    46  func NewPrinter(out io.Writer) *Printer {
    47  	return &Printer{Out: indented.Writer{Writer: out}}
    48  }
    49  
    50  // SetFile specifies the file containing the descriptors being printed.
    51  // Used to relativize names and print comments.
    52  func (p *Printer) SetFile(f *descriptorpb.FileDescriptorProto) error {
    53  	p.file = f
    54  	var err error
    55  	p.sourceCodeInfo, err = descutil.IndexSourceCodeInfo(f)
    56  	return err
    57  }
    58  
    59  // Printf prints to p.Out unless there was an error.
    60  func (p *Printer) Printf(format string, a ...any) {
    61  	if p.Err == nil {
    62  		_, p.Err = fmt.Fprintf(&p.Out, format, a...)
    63  	}
    64  }
    65  
    66  // Package prints package declaration.
    67  func (p *Printer) Package(name string) {
    68  	p.Printf("package %s;\n", name)
    69  }
    70  
    71  // open prints a string, followed by " {\n" and increases indentation level.
    72  // Returns a function that decreases indentation level and closes the brace
    73  // followed by a newline.
    74  // Usage: defer open("package x")()
    75  func (p *Printer) open(format string, a ...any) func() {
    76  	p.Printf(format, a...)
    77  	p.Printf(" {\n")
    78  	p.Out.Level++
    79  	return func() {
    80  		p.Out.Level--
    81  		p.Printf("}\n")
    82  	}
    83  }
    84  
    85  // MaybeLeadingComments prints leading comments of the descriptorpb proto
    86  // if found.
    87  func (p *Printer) MaybeLeadingComments(ptr any) {
    88  	comments := p.sourceCodeInfo[ptr].GetLeadingComments()
    89  	// print comments, but insert "//" before each newline.
    90  	for len(comments) > 0 {
    91  		var toPrint string
    92  		if lineEnd := strings.Index(comments, "\n"); lineEnd >= 0 {
    93  			toPrint = comments[:lineEnd+1] // includes newline
    94  			comments = comments[lineEnd+1:]
    95  		} else {
    96  			// actually this does not happen, because comments always end with
    97  			// newline, but just in case.
    98  			toPrint = comments + "\n"
    99  			comments = ""
   100  		}
   101  		p.Printf("//%s", toPrint)
   102  	}
   103  }
   104  
   105  // AppendLeadingComments allows adding additional leading comments to any printable
   106  // descriptorpb object associated with this printer.
   107  //
   108  // Each line will be prepended with " " and appended with "\n".
   109  //
   110  // e.g.
   111  //
   112  //	p := NewPrinter(os.Stdout)
   113  //	p.AppendLeadingComments(protodesc.ToDescriptorProto(myMsg.ProtoReflect()), []string{
   114  //	  "This is a line.",
   115  //	  "This is the next line.",
   116  //	})
   117  func (p *Printer) AppendLeadingComments(ptr any, lines []string) {
   118  	loc, ok := p.sourceCodeInfo[ptr]
   119  	if !ok {
   120  		loc = &descriptorpb.SourceCodeInfo_Location{}
   121  		p.sourceCodeInfo[ptr] = loc
   122  	}
   123  	bld := strings.Builder{}
   124  	for _, line := range lines {
   125  		bld.WriteRune(' ')
   126  		bld.WriteString(line)
   127  		bld.WriteRune('\n')
   128  	}
   129  	comments := loc.GetLeadingComments() + bld.String()
   130  	loc.LeadingComments = &comments
   131  }
   132  
   133  // shorten removes leading "." and trims package name if it matches p.file.
   134  func (p *Printer) shorten(name string) string {
   135  	name = strings.TrimPrefix(name, ".")
   136  	if p.file.GetPackage() != "" {
   137  		name = strings.TrimPrefix(name, p.file.GetPackage()+".")
   138  	}
   139  	return name
   140  }
   141  
   142  // Service prints a service definition.
   143  // If methodIndex != -1, only one method is printed.
   144  // If serviceIndex != -1, leading comments are printed if found.
   145  func (p *Printer) Service(service *descriptorpb.ServiceDescriptorProto, methodIndex int) {
   146  	p.MaybeLeadingComments(service)
   147  	defer p.open("service %s", service.GetName())()
   148  
   149  	if methodIndex < 0 {
   150  		for i := range service.Method {
   151  			p.Method(service.Method[i])
   152  		}
   153  	} else {
   154  		p.Method(service.Method[methodIndex])
   155  		if len(service.Method) > 1 {
   156  			p.Printf("// other methods were omitted.\n")
   157  		}
   158  	}
   159  }
   160  
   161  // Method prints a service method definition.
   162  func (p *Printer) Method(method *descriptorpb.MethodDescriptorProto) {
   163  	p.MaybeLeadingComments(method)
   164  	p.Printf(
   165  		"rpc %s(%s) returns (%s) {};\n",
   166  		method.GetName(),
   167  		p.shorten(method.GetInputType()),
   168  		p.shorten(method.GetOutputType()),
   169  	)
   170  }
   171  
   172  var fieldTypeName = map[descriptorpb.FieldDescriptorProto_Type]string{
   173  	descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:   "double",
   174  	descriptorpb.FieldDescriptorProto_TYPE_FLOAT:    "float",
   175  	descriptorpb.FieldDescriptorProto_TYPE_INT64:    "int64",
   176  	descriptorpb.FieldDescriptorProto_TYPE_UINT64:   "uint64",
   177  	descriptorpb.FieldDescriptorProto_TYPE_INT32:    "int32",
   178  	descriptorpb.FieldDescriptorProto_TYPE_FIXED64:  "fixed64",
   179  	descriptorpb.FieldDescriptorProto_TYPE_FIXED32:  "fixed32",
   180  	descriptorpb.FieldDescriptorProto_TYPE_BOOL:     "bool",
   181  	descriptorpb.FieldDescriptorProto_TYPE_STRING:   "string",
   182  	descriptorpb.FieldDescriptorProto_TYPE_BYTES:    "bytes",
   183  	descriptorpb.FieldDescriptorProto_TYPE_UINT32:   "uint32",
   184  	descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: "sfixed32",
   185  	descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: "sfixed64",
   186  	descriptorpb.FieldDescriptorProto_TYPE_SINT32:   "sint32",
   187  	descriptorpb.FieldDescriptorProto_TYPE_SINT64:   "sint64",
   188  }
   189  
   190  // Field prints a field definition.
   191  func (p *Printer) Field(field *descriptorpb.FieldDescriptorProto) {
   192  	p.MaybeLeadingComments(field)
   193  	if descutil.Repeated(field) {
   194  		p.Printf("repeated ")
   195  	}
   196  
   197  	typeName := fieldTypeName[field.GetType()]
   198  	if typeName == "" {
   199  		typeName = p.shorten(field.GetTypeName())
   200  	}
   201  	if typeName == "" {
   202  		typeName = "<unsupported type>"
   203  	}
   204  	p.Printf("%s %s = %d", typeName, field.GetName(), field.GetNumber())
   205  
   206  	p.fieldOptions(field)
   207  
   208  	p.Printf(";\n")
   209  }
   210  
   211  // converts snake_case to camelCase.
   212  func camel(snakeCase string) string {
   213  	prev := 'x'
   214  	return strings.Map(
   215  		func(r rune) rune {
   216  			if prev == '_' {
   217  				prev = r
   218  				return unicode.ToTitle(r)
   219  			}
   220  			prev = r
   221  			if r == '_' {
   222  				return -1
   223  			}
   224  			return r
   225  		}, snakeCase)
   226  }
   227  
   228  func (p *Printer) optionValue(ed protoreflect.EnumDescriptor, v protoreflect.Value) {
   229  	switch x := v.Interface().(type) {
   230  	case bool:
   231  		p.Printf(" %t", x)
   232  	case int32, int64, uint32, uint64:
   233  		p.Printf(" %d", v.Interface())
   234  	case float32, float64:
   235  		p.Printf(" %f", v.Interface())
   236  	case string, []byte:
   237  		p.Printf(" %q", v.Interface())
   238  	case protoreflect.EnumNumber:
   239  		p.Printf(" %s", ed.Values().ByNumber(x).Name())
   240  	case protoreflect.Message:
   241  		p.Printf(" {\n")
   242  		p.Out.Level++
   243  		defer func() {
   244  			p.Out.Level--
   245  			p.Printf("}")
   246  		}()
   247  
   248  		data, err := prototext.MarshalOptions{Indent: "\t"}.Marshal(x.Interface())
   249  		if err != nil {
   250  			panic(err)
   251  		}
   252  		// ensure textproto output is stable.
   253  		data, err = txtpb.Format(data)
   254  		if err != nil {
   255  			panic(err)
   256  		}
   257  
   258  		p.Printf("%s", data)
   259  	default:
   260  		panic("unknown protoreflect.Value type")
   261  	}
   262  
   263  	return
   264  }
   265  
   266  type optField struct {
   267  	fieldNum       protoreflect.FieldNumber
   268  	renderedName   string
   269  	enumDescriptor protoreflect.EnumDescriptor
   270  	val            protoreflect.Value
   271  }
   272  
   273  type optFields []optField
   274  
   275  func (p *Printer) collectOptions(m protoreflect.ProtoMessage, extra optFields) optFields {
   276  	toPrint := append(optFields{}, extra...)
   277  
   278  	m.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   279  		renderedName := fd.TextName()
   280  		if fd.IsExtension() {
   281  			renderedName = fmt.Sprintf("(%s)", p.shorten(string(fd.FullName())))
   282  		}
   283  		toPrint = append(toPrint, optField{fd.Number(), renderedName, fd.Enum(), v})
   284  		return true
   285  	})
   286  	sort.Slice(toPrint, func(i, j int) bool {
   287  		return toPrint[i].fieldNum < toPrint[j].fieldNum
   288  	})
   289  	return toPrint
   290  }
   291  
   292  func (of optFields) write(p *Printer, prefix string, afterEach func(last bool)) {
   293  	for i, f := range of {
   294  		p.Printf("%s%s =", prefix, f.renderedName)
   295  		p.optionValue(f.enumDescriptor, f.val)
   296  		if afterEach != nil {
   297  			afterEach(i == len(of)-1)
   298  		}
   299  	}
   300  }
   301  
   302  func (p *Printer) fieldOptions(field *descriptorpb.FieldDescriptorProto) {
   303  	var extra optFields
   304  	if field.GetJsonName() != camel(field.GetName()) {
   305  		extra = optFields{{
   306  			0, "json_name", nil, protoreflect.ValueOfString(field.GetJsonName()),
   307  		}}
   308  	}
   309  	toPrint := p.collectOptions(field.Options, extra)
   310  
   311  	if len(toPrint) == 0 {
   312  		return
   313  	}
   314  	if len(toPrint) == 1 {
   315  		p.Printf(" [")
   316  		toPrint.write(p, "", nil)
   317  		p.Printf("]")
   318  		return
   319  	}
   320  
   321  	p.Printf(" [\n")
   322  	p.Out.Level++
   323  	defer func() {
   324  		p.Out.Level--
   325  		p.Printf("]")
   326  	}()
   327  
   328  	nl := func(hasNext bool) {
   329  		if hasNext {
   330  			p.Printf(",\n")
   331  		} else {
   332  			p.Printf("\n")
   333  		}
   334  	}
   335  
   336  	toPrint.write(p, "", func(last bool) {
   337  		nl(!last)
   338  	})
   339  }
   340  
   341  // Message prints a message definition.
   342  func (p *Printer) Message(msg *descriptorpb.DescriptorProto) {
   343  	p.MaybeLeadingComments(msg)
   344  	defer p.open("message %s", msg.GetName())()
   345  
   346  	p.collectOptions(msg.Options, nil).write(p, "option ", func(last bool) {
   347  		p.Printf(";\n")
   348  	})
   349  
   350  	for _, name := range msg.ReservedName {
   351  		p.Printf("reserved %q;\n", name)
   352  	}
   353  
   354  	for _, rng := range msg.ReservedRange {
   355  		if rng.GetStart() == rng.GetEnd()-1 {
   356  			p.Printf("reserved %d;\n", rng.GetStart())
   357  		} else {
   358  			p.Printf("reserved %d to %d;\n", rng.GetStart(), rng.GetEnd())
   359  		}
   360  	}
   361  
   362  	for i := range msg.GetOneofDecl() {
   363  		p.OneOf(msg, i)
   364  	}
   365  
   366  	for i, f := range msg.Field {
   367  		if f.OneofIndex == nil {
   368  			p.Field(msg.Field[i])
   369  		}
   370  	}
   371  }
   372  
   373  // OneOf prints a oneof definition.
   374  func (p *Printer) OneOf(msg *descriptorpb.DescriptorProto, oneOfIndex int) {
   375  	of := msg.GetOneofDecl()[oneOfIndex]
   376  	p.MaybeLeadingComments(of)
   377  	defer p.open("oneof %s", of.GetName())()
   378  
   379  	for i, f := range msg.Field {
   380  		if f.OneofIndex != nil && int(f.GetOneofIndex()) == oneOfIndex {
   381  			p.Field(msg.Field[i])
   382  		}
   383  	}
   384  }
   385  
   386  // Enum prints an enum definition.
   387  func (p *Printer) Enum(enum *descriptorpb.EnumDescriptorProto) {
   388  	p.MaybeLeadingComments(enum)
   389  	defer p.open("enum %s", enum.GetName())()
   390  
   391  	for _, v := range enum.Value {
   392  		p.EnumValue(v)
   393  	}
   394  }
   395  
   396  // EnumValue prints an enum value definition.
   397  func (p *Printer) EnumValue(v *descriptorpb.EnumValueDescriptorProto) {
   398  	p.MaybeLeadingComments(v)
   399  	p.Printf("%s = %d;\n", v.GetName(), v.GetNumber())
   400  }