github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/walk/walk.go (about)

     1  // Package walk provides helper functions for traversing all elements in a
     2  // protobuf file descriptor. There are versions both for traversing "rich"
     3  // descriptors (protoreflect.Descriptor) and for traversing the underlying
     4  // "raw" descriptor protos.
     5  //
     6  // Enter And Exit
     7  //
     8  // This package includes variants of the functions that accept two callback
     9  // functions. These variants have names ending with "EnterAndExit". One function
    10  // is called as each element is visited ("enter") and the other is called after
    11  // the element and all of its descendants have been visited ("exit"). This
    12  // can be useful when you need to track state that is scoped to the visitation
    13  // of a single element.
    14  //
    15  // Source Path
    16  //
    17  // When traversing raw descriptor protos, this package includes variants whose
    18  // callback accepts a protoreflect.SourcePath. These variants have names that
    19  // include "WithPath". This path can be used to locate corresponding data in the
    20  // file's source code info (if present).
    21  package walk
    22  
    23  import (
    24  	"google.golang.org/protobuf/proto"
    25  	"google.golang.org/protobuf/reflect/protoreflect"
    26  	"google.golang.org/protobuf/types/descriptorpb"
    27  
    28  	"github.com/jhump/protocompile/internal"
    29  )
    30  
    31  // Descriptors walks all descriptors in the given file using a depth-first
    32  // traversal, calling the given function for each descriptor in the hierarchy.
    33  // The walk ends when traversal is complete or when the function returns an
    34  // error. If the function returns an error, that is returned as the result of the
    35  // walk operation.
    36  //
    37  // Descriptors are visited using a pre-order traversal, where the function is
    38  // called for a descriptor before it is called for any of its descendants.
    39  func Descriptors(file protoreflect.FileDescriptor, fn func(protoreflect.Descriptor) error) error {
    40  	return DescriptorsEnterAndExit(file, fn, nil)
    41  }
    42  
    43  // DescriptorsEnterAndExit walks all descriptors in the given file using a
    44  // depth-first traversal, calling the given functions on entry and on exit
    45  // for each descriptor in the hierarchy. The walk ends when traversal is
    46  // complete or when a function returns an error. If a function returns an error,
    47  // that is returned as the result of the walk operation.
    48  //
    49  // The enter function is called using a pre-order traversal, where the function
    50  // is called for a descriptor before it is called for any of its descendants.
    51  // The exit function is called using a post-order traversal, where the function
    52  // is called for a descriptor only after it is called for any descendants.
    53  func DescriptorsEnterAndExit(file protoreflect.FileDescriptor, enter, exit func(protoreflect.Descriptor) error) error {
    54  	for i := 0; i < file.Messages().Len(); i++ {
    55  		msg := file.Messages().Get(i)
    56  		if err := messageDescriptor(msg, enter, exit); err != nil {
    57  			return err
    58  		}
    59  	}
    60  	for i := 0; i < file.Enums().Len(); i++ {
    61  		en := file.Enums().Get(i)
    62  		if err := enumDescriptor(en, enter, exit); err != nil {
    63  			return err
    64  		}
    65  	}
    66  	for i := 0; i < file.Extensions().Len(); i++ {
    67  		ext := file.Extensions().Get(i)
    68  		if err := enter(ext); err != nil {
    69  			return err
    70  		}
    71  		if exit != nil {
    72  			if err := exit(ext); err != nil {
    73  				return err
    74  			}
    75  		}
    76  	}
    77  	for i := 0; i < file.Services().Len(); i++ {
    78  		svc := file.Services().Get(i)
    79  		if err := enter(svc); err != nil {
    80  			return err
    81  		}
    82  		for i := 0; i < svc.Methods().Len(); i++ {
    83  			mtd := svc.Methods().Get(i)
    84  			if err := enter(mtd); err != nil {
    85  				return err
    86  			}
    87  			if exit != nil {
    88  				if err := exit(mtd); err != nil {
    89  					return err
    90  				}
    91  			}
    92  		}
    93  		if exit != nil {
    94  			if err := exit(svc); err != nil {
    95  				return err
    96  			}
    97  		}
    98  	}
    99  	return nil
   100  }
   101  
   102  func messageDescriptor(msg protoreflect.MessageDescriptor, enter, exit func(protoreflect.Descriptor) error) error {
   103  	if err := enter(msg); err != nil {
   104  		return err
   105  	}
   106  	for i := 0; i < msg.Fields().Len(); i++ {
   107  		fld := msg.Fields().Get(i)
   108  		if err := enter(fld); err != nil {
   109  			return err
   110  		}
   111  		if exit != nil {
   112  			if err := exit(fld); err != nil {
   113  				return err
   114  			}
   115  		}
   116  	}
   117  	for i := 0; i < msg.Oneofs().Len(); i++ {
   118  		oo := msg.Oneofs().Get(i)
   119  		if err := enter(oo); err != nil {
   120  			return err
   121  		}
   122  		if exit != nil {
   123  			if err := exit(oo); err != nil {
   124  				return err
   125  			}
   126  		}
   127  	}
   128  	for i := 0; i < msg.Messages().Len(); i++ {
   129  		nested := msg.Messages().Get(i)
   130  		if err := messageDescriptor(nested, enter, exit); err != nil {
   131  			return err
   132  		}
   133  	}
   134  	for i := 0; i < msg.Enums().Len(); i++ {
   135  		en := msg.Enums().Get(i)
   136  		if err := enumDescriptor(en, enter, exit); err != nil {
   137  			return err
   138  		}
   139  	}
   140  	for i := 0; i < msg.Extensions().Len(); i++ {
   141  		ext := msg.Extensions().Get(i)
   142  		if err := enter(ext); err != nil {
   143  			return err
   144  		}
   145  		if exit != nil {
   146  			if err := exit(ext); err != nil {
   147  				return err
   148  			}
   149  		}
   150  	}
   151  	if exit != nil {
   152  		if err := exit(msg); err != nil {
   153  			return err
   154  		}
   155  	}
   156  	return nil
   157  }
   158  
   159  func enumDescriptor(en protoreflect.EnumDescriptor, enter, exit func(protoreflect.Descriptor) error) error {
   160  	if err := enter(en); err != nil {
   161  		return err
   162  	}
   163  	for i := 0; i < en.Values().Len(); i++ {
   164  		enVal := en.Values().Get(i)
   165  		if err := enter(enVal); err != nil {
   166  			return err
   167  		}
   168  		if exit != nil {
   169  			if err := exit(enVal); err != nil {
   170  				return err
   171  			}
   172  		}
   173  	}
   174  	if exit != nil {
   175  		if err := exit(en); err != nil {
   176  			return err
   177  		}
   178  	}
   179  	return nil
   180  }
   181  
   182  // DescriptorProtosWithPath walks all descriptor protos in the given file using
   183  // a depth-first traversal. This is the same as DescriptorProtos except that the
   184  // callback function, fn, receives a protoreflect.SourcePath, that indicates the
   185  // path for the element in the file's source code info.
   186  func DescriptorProtosWithPath(file *descriptorpb.FileDescriptorProto, fn func(protoreflect.FullName, protoreflect.SourcePath, proto.Message) error) error {
   187  	return DescriptorProtosWithPathEnterAndExit(file, fn, nil)
   188  }
   189  
   190  // DescriptorProtosWithPathEnterAndExit walks all descriptor protos in the given
   191  // file using a depth-first traversal. This is the same as
   192  // DescriptorProtosEnterAndExit except that the callback function, fn, receives
   193  // a protoreflect.SourcePath, that indicates the path for the element in the
   194  // file's source code info.
   195  func DescriptorProtosWithPathEnterAndExit(file *descriptorpb.FileDescriptorProto, enter, exit func(protoreflect.FullName, protoreflect.SourcePath, proto.Message) error) error {
   196  	w := &protoWalker{usePath: true, enter: enter, exit: exit}
   197  	return w.walkDescriptorProtos(file)
   198  }
   199  
   200  // DescriptorProtos walks all descriptor protos in the given file using a
   201  // depth-first traversal, calling the given function for each descriptor proto
   202  // in the hierarchy. The walk ends when traversal is complete or when the
   203  // function returns an error. If the function returns an error, that is
   204  // returned as the result of the walk operation.
   205  //
   206  // Descriptor protos are visited using a pre-order traversal, where the function
   207  // is called for a descriptor before it is called for any of its descendants.
   208  func DescriptorProtos(file *descriptorpb.FileDescriptorProto, fn func(protoreflect.FullName, proto.Message) error) error {
   209  	return DescriptorProtosEnterAndExit(file, fn, nil)
   210  }
   211  
   212  // DescriptorProtosEnterAndExit walks all descriptor protos in the given file
   213  // using a depth-first traversal, calling the given functions on entry and on
   214  // exit for each descriptor in the hierarchy. The walk ends when traversal is
   215  // complete or when a function returns an error. If a function returns an error,
   216  // that is returned as the result of the walk operation.
   217  //
   218  // The enter function is called using a pre-order traversal, where the function
   219  // is called for a descriptor proto before it is called for any of its
   220  // descendants. The exit function is called using a post-order traversal, where
   221  // the function is called for a descriptor proto only after it is called for any
   222  // descendants.
   223  func DescriptorProtosEnterAndExit(file *descriptorpb.FileDescriptorProto, enter, exit func(protoreflect.FullName, proto.Message) error) error {
   224  	enterWithPath := func(n protoreflect.FullName, p protoreflect.SourcePath, m proto.Message) error {
   225  		return enter(n, m)
   226  	}
   227  	var exitWithPath func(n protoreflect.FullName, p protoreflect.SourcePath, m proto.Message) error
   228  	if exit != nil {
   229  		exitWithPath = func(n protoreflect.FullName, p protoreflect.SourcePath, m proto.Message) error {
   230  			return exit(n, m)
   231  		}
   232  	}
   233  	w := &protoWalker{
   234  		enter: enterWithPath,
   235  		exit:  exitWithPath,
   236  	}
   237  	return w.walkDescriptorProtos(file)
   238  }
   239  
   240  type protoWalker struct {
   241  	usePath     bool
   242  	enter, exit func(protoreflect.FullName, protoreflect.SourcePath, proto.Message) error
   243  }
   244  
   245  func (w *protoWalker) walkDescriptorProtos(file *descriptorpb.FileDescriptorProto) error {
   246  	prefix := file.GetPackage()
   247  	if prefix != "" {
   248  		prefix = prefix + "."
   249  	}
   250  	var path protoreflect.SourcePath
   251  	for i, msg := range file.MessageType {
   252  		var p protoreflect.SourcePath
   253  		if w.usePath {
   254  			p = append(path, internal.File_messagesTag, int32(i))
   255  		}
   256  		if err := w.walkDescriptorProto(prefix, p, msg); err != nil {
   257  			return err
   258  		}
   259  	}
   260  	for i, en := range file.EnumType {
   261  		var p protoreflect.SourcePath
   262  		if w.usePath {
   263  			p = append(path, internal.File_enumsTag, int32(i))
   264  		}
   265  		if err := w.walkEnumDescriptorProto(prefix, p, en); err != nil {
   266  			return err
   267  		}
   268  	}
   269  	for i, ext := range file.Extension {
   270  		var p protoreflect.SourcePath
   271  		if w.usePath {
   272  			p = append(path, internal.File_extensionsTag, int32(i))
   273  		}
   274  		fqn := prefix + ext.GetName()
   275  		if err := w.enter(protoreflect.FullName(fqn), p, ext); err != nil {
   276  			return err
   277  		}
   278  		if w.exit != nil {
   279  			if err := w.exit(protoreflect.FullName(fqn), p, ext); err != nil {
   280  				return err
   281  			}
   282  		}
   283  	}
   284  	for i, svc := range file.Service {
   285  		var p protoreflect.SourcePath
   286  		if w.usePath {
   287  			p = append(path, internal.File_servicesTag, int32(i))
   288  		}
   289  		fqn := prefix + svc.GetName()
   290  		if err := w.enter(protoreflect.FullName(fqn), p, svc); err != nil {
   291  			return err
   292  		}
   293  		for j, mtd := range svc.Method {
   294  			var mp protoreflect.SourcePath
   295  			if w.usePath {
   296  				mp = append(p, internal.Service_methodsTag, int32(j))
   297  			}
   298  			mtdFqn := fqn + "." + mtd.GetName()
   299  			if err := w.enter(protoreflect.FullName(mtdFqn), mp, mtd); err != nil {
   300  				return err
   301  			}
   302  			if w.exit != nil {
   303  				if err := w.exit(protoreflect.FullName(mtdFqn), mp, mtd); err != nil {
   304  					return err
   305  				}
   306  			}
   307  		}
   308  		if w.exit != nil {
   309  			if err := w.exit(protoreflect.FullName(fqn), p, svc); err != nil {
   310  				return err
   311  			}
   312  		}
   313  	}
   314  	return nil
   315  }
   316  
   317  func (w *protoWalker) walkDescriptorProto(prefix string, path protoreflect.SourcePath, msg *descriptorpb.DescriptorProto) error {
   318  	fqn := prefix + msg.GetName()
   319  	if err := w.enter(protoreflect.FullName(fqn), path, msg); err != nil {
   320  		return err
   321  	}
   322  	prefix = fqn + "."
   323  	for i, fld := range msg.Field {
   324  		var p protoreflect.SourcePath
   325  		if w.usePath {
   326  			p = append(path, internal.Message_fieldsTag, int32(i))
   327  		}
   328  		fqn := prefix + fld.GetName()
   329  		if err := w.enter(protoreflect.FullName(fqn), p, fld); err != nil {
   330  			return err
   331  		}
   332  		if w.exit != nil {
   333  			if err := w.exit(protoreflect.FullName(fqn), p, fld); err != nil {
   334  				return err
   335  			}
   336  		}
   337  	}
   338  	for i, oo := range msg.OneofDecl {
   339  		var p protoreflect.SourcePath
   340  		if w.usePath {
   341  			p = append(path, internal.Message_oneOfsTag, int32(i))
   342  		}
   343  		fqn := prefix + oo.GetName()
   344  		if err := w.enter(protoreflect.FullName(fqn), p, oo); err != nil {
   345  			return err
   346  		}
   347  		if w.exit != nil {
   348  			if err := w.exit(protoreflect.FullName(fqn), p, oo); err != nil {
   349  				return err
   350  			}
   351  		}
   352  	}
   353  	for i, nested := range msg.NestedType {
   354  		var p protoreflect.SourcePath
   355  		if w.usePath {
   356  			p = append(path, internal.Message_nestedMessagesTag, int32(i))
   357  		}
   358  		if err := w.walkDescriptorProto(prefix, p, nested); err != nil {
   359  			return err
   360  		}
   361  	}
   362  	for i, en := range msg.EnumType {
   363  		var p protoreflect.SourcePath
   364  		if w.usePath {
   365  			p = append(path, internal.Message_enumsTag, int32(i))
   366  		}
   367  		if err := w.walkEnumDescriptorProto(prefix, p, en); err != nil {
   368  			return err
   369  		}
   370  	}
   371  	for i, ext := range msg.Extension {
   372  		var p protoreflect.SourcePath
   373  		if w.usePath {
   374  			p = append(path, internal.Message_extensionsTag, int32(i))
   375  		}
   376  		fqn := prefix + ext.GetName()
   377  		if err := w.enter(protoreflect.FullName(fqn), p, ext); err != nil {
   378  			return err
   379  		}
   380  		if w.exit != nil {
   381  			if err := w.exit(protoreflect.FullName(fqn), p, ext); err != nil {
   382  				return err
   383  			}
   384  		}
   385  	}
   386  	if w.exit != nil {
   387  		if err := w.exit(protoreflect.FullName(fqn), path, msg); err != nil {
   388  			return err
   389  		}
   390  	}
   391  	return nil
   392  }
   393  
   394  func (w *protoWalker) walkEnumDescriptorProto(prefix string, path protoreflect.SourcePath, en *descriptorpb.EnumDescriptorProto) error {
   395  	fqn := prefix + en.GetName()
   396  	if err := w.enter(protoreflect.FullName(fqn), path, en); err != nil {
   397  		return err
   398  	}
   399  	for i, val := range en.Value {
   400  		var p protoreflect.SourcePath
   401  		if w.usePath {
   402  			p = append(path, internal.Enum_valuesTag, int32(i))
   403  		}
   404  		fqn := prefix + val.GetName()
   405  		if err := w.enter(protoreflect.FullName(fqn), p, val); err != nil {
   406  			return err
   407  		}
   408  		if w.exit != nil {
   409  			if err := w.exit(protoreflect.FullName(fqn), p, val); err != nil {
   410  				return err
   411  			}
   412  		}
   413  	}
   414  	if w.exit != nil {
   415  		if err := w.exit(protoreflect.FullName(fqn), path, en); err != nil {
   416  			return err
   417  		}
   418  	}
   419  	return nil
   420  }