github.com/nikron/prototool@v1.3.0/internal/extract/getter.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package extract
    22  
    23  import (
    24  	"fmt"
    25  	"strings"
    26  
    27  	"github.com/golang/protobuf/protoc-gen-go/descriptor"
    28  	"go.uber.org/zap"
    29  )
    30  
    31  type getter struct {
    32  	logger *zap.Logger
    33  }
    34  
    35  func newGetter(options ...GetterOption) *getter {
    36  	getter := &getter{
    37  		logger: zap.NewNop(),
    38  	}
    39  	for _, option := range options {
    40  		option(getter)
    41  	}
    42  	return getter
    43  }
    44  
    45  func (g *getter) GetField(fileDescriptorSets []*descriptor.FileDescriptorSet, path string) (*Field, error) {
    46  	if len(path) == 0 {
    47  		return nil, fmt.Errorf("empty path")
    48  	}
    49  	if path[0] == '.' {
    50  		path = path[1:]
    51  	}
    52  	split := strings.Split(path, ".")
    53  	if len(split) < 2 {
    54  		return nil, fmt.Errorf("no field for path %s", path)
    55  	}
    56  	message, err := g.GetMessage(fileDescriptorSets, strings.Join(split[0:len(split)-1], "."))
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	var foundFieldDescriptorProto *descriptor.FieldDescriptorProto
    61  	for _, fieldDescriptorProto := range append(message.GetField(), message.GetExtension()...) {
    62  		wantName := split[len(split)-1]
    63  		if fieldDescriptorProto.GetName() == wantName {
    64  			if foundFieldDescriptorProto != nil {
    65  				return nil, fmt.Errorf("duplicate fields for path %s", path)
    66  			}
    67  			foundFieldDescriptorProto = fieldDescriptorProto
    68  		}
    69  	}
    70  	if foundFieldDescriptorProto == nil {
    71  		return nil, fmt.Errorf("no field for path %s", path)
    72  	}
    73  	return &Field{
    74  		FieldDescriptorProto: foundFieldDescriptorProto,
    75  		FullyQualifiedPath:   "." + path,
    76  		DescriptorProto:      message.DescriptorProto,
    77  		FileDescriptorProto:  message.FileDescriptorProto,
    78  		FileDescriptorSet:    message.FileDescriptorSet,
    79  	}, nil
    80  }
    81  
    82  func (g *getter) GetMessage(fileDescriptorSets []*descriptor.FileDescriptorSet, path string) (*Message, error) {
    83  	if len(path) == 0 {
    84  		return nil, fmt.Errorf("empty path")
    85  	}
    86  	if path[0] == '.' {
    87  		path = path[1:]
    88  	}
    89  	var descriptorProto *descriptor.DescriptorProto
    90  	var fileDescriptorProto *descriptor.FileDescriptorProto
    91  	var fileDescriptorSet *descriptor.FileDescriptorSet
    92  	for _, iFileDescriptorSet := range fileDescriptorSets {
    93  		for _, iFileDescriptorProto := range iFileDescriptorSet.File {
    94  			iDescriptorProto, err := findDescriptorProto(path, iFileDescriptorProto)
    95  			if err != nil {
    96  				return nil, err
    97  			}
    98  			if iDescriptorProto != nil {
    99  				if descriptorProto != nil {
   100  					return nil, fmt.Errorf("duplicate messages for path %s", path)
   101  				}
   102  				descriptorProto = iDescriptorProto
   103  				fileDescriptorProto = iFileDescriptorProto
   104  			}
   105  		}
   106  		// return first fileDescriptorSet that matches
   107  		// as opposed to duplicate check within fileDescriptorSet, we easily could
   108  		// have multiple fileDescriptorSets that match
   109  		if descriptorProto != nil {
   110  			fileDescriptorSet = iFileDescriptorSet
   111  			break
   112  		}
   113  	}
   114  	if descriptorProto == nil {
   115  		return nil, fmt.Errorf("no message for path %s", path)
   116  	}
   117  	return &Message{
   118  		DescriptorProto:     descriptorProto,
   119  		FullyQualifiedPath:  "." + path,
   120  		FileDescriptorProto: fileDescriptorProto,
   121  		FileDescriptorSet:   fileDescriptorSet,
   122  	}, nil
   123  }
   124  
   125  func (g *getter) GetService(fileDescriptorSets []*descriptor.FileDescriptorSet, path string) (*Service, error) {
   126  	if len(path) == 0 {
   127  		return nil, fmt.Errorf("empty path")
   128  	}
   129  	if path[0] == '.' {
   130  		path = path[1:]
   131  	}
   132  	var serviceDescriptorProto *descriptor.ServiceDescriptorProto
   133  	var fileDescriptorProto *descriptor.FileDescriptorProto
   134  	var fileDescriptorSet *descriptor.FileDescriptorSet
   135  	for _, iFileDescriptorSet := range fileDescriptorSets {
   136  		for _, iFileDescriptorProto := range iFileDescriptorSet.File {
   137  			iServiceDescriptorProto, err := findServiceDescriptorProto(path, iFileDescriptorProto)
   138  			if err != nil {
   139  				return nil, err
   140  			}
   141  			if iServiceDescriptorProto != nil {
   142  				if serviceDescriptorProto != nil {
   143  					return nil, fmt.Errorf("duplicate services for path %s", path)
   144  				}
   145  				serviceDescriptorProto = iServiceDescriptorProto
   146  				fileDescriptorProto = iFileDescriptorProto
   147  			}
   148  		}
   149  		// return first fileDescriptorSet that matches
   150  		// as opposed to duplicate check within fileDescriptorSet, we easily could
   151  		// have multiple fileDescriptorSets that match
   152  		if serviceDescriptorProto != nil {
   153  			fileDescriptorSet = iFileDescriptorSet
   154  			break
   155  		}
   156  	}
   157  	if serviceDescriptorProto == nil {
   158  		return nil, fmt.Errorf("no service for path %s", path)
   159  	}
   160  	return &Service{
   161  		ServiceDescriptorProto: serviceDescriptorProto,
   162  		FullyQualifiedPath:     "." + path,
   163  		FileDescriptorProto:    fileDescriptorProto,
   164  		FileDescriptorSet:      fileDescriptorSet,
   165  	}, nil
   166  }
   167  
   168  // TODO: we don't actually do full path resolution per the descriptor.proto spec
   169  // https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto#L185
   170  
   171  func findDescriptorProto(path string, fileDescriptorProto *descriptor.FileDescriptorProto) (*descriptor.DescriptorProto, error) {
   172  	if fileDescriptorProto.GetPackage() == "" {
   173  		return nil, fmt.Errorf("no package on FileDescriptorProto")
   174  	}
   175  	if !strings.HasPrefix(path, fileDescriptorProto.GetPackage()) {
   176  		return nil, nil
   177  	}
   178  	return findDescriptorProtoInSlice(path, fileDescriptorProto.GetPackage(), fileDescriptorProto.GetMessageType())
   179  }
   180  
   181  func findDescriptorProtoInSlice(path string, nestedName string, descriptorProtos []*descriptor.DescriptorProto) (*descriptor.DescriptorProto, error) {
   182  	var foundDescriptorProto *descriptor.DescriptorProto
   183  	for _, descriptorProto := range descriptorProtos {
   184  		if descriptorProto.GetName() == "" {
   185  			return nil, fmt.Errorf("no name on DescriptorProto")
   186  		}
   187  		fullName := nestedName + "." + descriptorProto.GetName()
   188  		if path == fullName {
   189  			if foundDescriptorProto != nil {
   190  				return nil, fmt.Errorf("duplicate messages for path %s", path)
   191  			}
   192  			foundDescriptorProto = descriptorProto
   193  		}
   194  		nestedFoundDescriptorProto, err := findDescriptorProtoInSlice(path, fullName, descriptorProto.GetNestedType())
   195  		if err != nil {
   196  			return nil, err
   197  		}
   198  		if nestedFoundDescriptorProto != nil {
   199  			if foundDescriptorProto != nil {
   200  				return nil, fmt.Errorf("duplicate messages for path %s", path)
   201  			}
   202  			foundDescriptorProto = nestedFoundDescriptorProto
   203  		}
   204  	}
   205  	return foundDescriptorProto, nil
   206  }
   207  
   208  func findServiceDescriptorProto(path string, fileDescriptorProto *descriptor.FileDescriptorProto) (*descriptor.ServiceDescriptorProto, error) {
   209  	if fileDescriptorProto.GetPackage() == "" {
   210  		return nil, fmt.Errorf("no package on FileDescriptorProto")
   211  	}
   212  	if !strings.HasPrefix(path, fileDescriptorProto.GetPackage()) {
   213  		return nil, nil
   214  	}
   215  	var foundServiceDescriptorProto *descriptor.ServiceDescriptorProto
   216  	for _, serviceDescriptorProto := range fileDescriptorProto.GetService() {
   217  		if fileDescriptorProto.GetPackage()+"."+serviceDescriptorProto.GetName() == path {
   218  			if foundServiceDescriptorProto != nil {
   219  				return nil, fmt.Errorf("duplicate services for path %s", path)
   220  			}
   221  			foundServiceDescriptorProto = serviceDescriptorProto
   222  		}
   223  	}
   224  	return foundServiceDescriptorProto, nil
   225  }