github.com/ovsinc/prototool@v1.3.0/internal/extract/getter.go (about) 1 // Copyright (c) 2018 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package extract 22 23 import ( 24 "fmt" 25 "strings" 26 27 "github.com/golang/protobuf/protoc-gen-go/descriptor" 28 "go.uber.org/zap" 29 ) 30 31 type getter struct { 32 logger *zap.Logger 33 } 34 35 func newGetter(options ...GetterOption) *getter { 36 getter := &getter{ 37 logger: zap.NewNop(), 38 } 39 for _, option := range options { 40 option(getter) 41 } 42 return getter 43 } 44 45 func (g *getter) GetField(fileDescriptorSets []*descriptor.FileDescriptorSet, path string) (*Field, error) { 46 if len(path) == 0 { 47 return nil, fmt.Errorf("empty path") 48 } 49 if path[0] == '.' { 50 path = path[1:] 51 } 52 split := strings.Split(path, ".") 53 if len(split) < 2 { 54 return nil, fmt.Errorf("no field for path %s", path) 55 } 56 message, err := g.GetMessage(fileDescriptorSets, strings.Join(split[0:len(split)-1], ".")) 57 if err != nil { 58 return nil, err 59 } 60 var foundFieldDescriptorProto *descriptor.FieldDescriptorProto 61 for _, fieldDescriptorProto := range append(message.GetField(), message.GetExtension()...) { 62 wantName := split[len(split)-1] 63 if fieldDescriptorProto.GetName() == wantName { 64 if foundFieldDescriptorProto != nil { 65 return nil, fmt.Errorf("duplicate fields for path %s", path) 66 } 67 foundFieldDescriptorProto = fieldDescriptorProto 68 } 69 } 70 if foundFieldDescriptorProto == nil { 71 return nil, fmt.Errorf("no field for path %s", path) 72 } 73 return &Field{ 74 FieldDescriptorProto: foundFieldDescriptorProto, 75 FullyQualifiedPath: "." + path, 76 DescriptorProto: message.DescriptorProto, 77 FileDescriptorProto: message.FileDescriptorProto, 78 FileDescriptorSet: message.FileDescriptorSet, 79 }, nil 80 } 81 82 func (g *getter) GetMessage(fileDescriptorSets []*descriptor.FileDescriptorSet, path string) (*Message, error) { 83 if len(path) == 0 { 84 return nil, fmt.Errorf("empty path") 85 } 86 if path[0] == '.' { 87 path = path[1:] 88 } 89 var descriptorProto *descriptor.DescriptorProto 90 var fileDescriptorProto *descriptor.FileDescriptorProto 91 var fileDescriptorSet *descriptor.FileDescriptorSet 92 for _, iFileDescriptorSet := range fileDescriptorSets { 93 for _, iFileDescriptorProto := range iFileDescriptorSet.File { 94 iDescriptorProto, err := findDescriptorProto(path, iFileDescriptorProto) 95 if err != nil { 96 return nil, err 97 } 98 if iDescriptorProto != nil { 99 if descriptorProto != nil { 100 return nil, fmt.Errorf("duplicate messages for path %s", path) 101 } 102 descriptorProto = iDescriptorProto 103 fileDescriptorProto = iFileDescriptorProto 104 } 105 } 106 // return first fileDescriptorSet that matches 107 // as opposed to duplicate check within fileDescriptorSet, we easily could 108 // have multiple fileDescriptorSets that match 109 if descriptorProto != nil { 110 fileDescriptorSet = iFileDescriptorSet 111 break 112 } 113 } 114 if descriptorProto == nil { 115 return nil, fmt.Errorf("no message for path %s", path) 116 } 117 return &Message{ 118 DescriptorProto: descriptorProto, 119 FullyQualifiedPath: "." + path, 120 FileDescriptorProto: fileDescriptorProto, 121 FileDescriptorSet: fileDescriptorSet, 122 }, nil 123 } 124 125 func (g *getter) GetService(fileDescriptorSets []*descriptor.FileDescriptorSet, path string) (*Service, error) { 126 if len(path) == 0 { 127 return nil, fmt.Errorf("empty path") 128 } 129 if path[0] == '.' { 130 path = path[1:] 131 } 132 var serviceDescriptorProto *descriptor.ServiceDescriptorProto 133 var fileDescriptorProto *descriptor.FileDescriptorProto 134 var fileDescriptorSet *descriptor.FileDescriptorSet 135 for _, iFileDescriptorSet := range fileDescriptorSets { 136 for _, iFileDescriptorProto := range iFileDescriptorSet.File { 137 iServiceDescriptorProto, err := findServiceDescriptorProto(path, iFileDescriptorProto) 138 if err != nil { 139 return nil, err 140 } 141 if iServiceDescriptorProto != nil { 142 if serviceDescriptorProto != nil { 143 return nil, fmt.Errorf("duplicate services for path %s", path) 144 } 145 serviceDescriptorProto = iServiceDescriptorProto 146 fileDescriptorProto = iFileDescriptorProto 147 } 148 } 149 // return first fileDescriptorSet that matches 150 // as opposed to duplicate check within fileDescriptorSet, we easily could 151 // have multiple fileDescriptorSets that match 152 if serviceDescriptorProto != nil { 153 fileDescriptorSet = iFileDescriptorSet 154 break 155 } 156 } 157 if serviceDescriptorProto == nil { 158 return nil, fmt.Errorf("no service for path %s", path) 159 } 160 return &Service{ 161 ServiceDescriptorProto: serviceDescriptorProto, 162 FullyQualifiedPath: "." + path, 163 FileDescriptorProto: fileDescriptorProto, 164 FileDescriptorSet: fileDescriptorSet, 165 }, nil 166 } 167 168 // TODO: we don't actually do full path resolution per the descriptor.proto spec 169 // https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto#L185 170 171 func findDescriptorProto(path string, fileDescriptorProto *descriptor.FileDescriptorProto) (*descriptor.DescriptorProto, error) { 172 if fileDescriptorProto.GetPackage() == "" { 173 return nil, fmt.Errorf("no package on FileDescriptorProto") 174 } 175 if !strings.HasPrefix(path, fileDescriptorProto.GetPackage()) { 176 return nil, nil 177 } 178 return findDescriptorProtoInSlice(path, fileDescriptorProto.GetPackage(), fileDescriptorProto.GetMessageType()) 179 } 180 181 func findDescriptorProtoInSlice(path string, nestedName string, descriptorProtos []*descriptor.DescriptorProto) (*descriptor.DescriptorProto, error) { 182 var foundDescriptorProto *descriptor.DescriptorProto 183 for _, descriptorProto := range descriptorProtos { 184 if descriptorProto.GetName() == "" { 185 return nil, fmt.Errorf("no name on DescriptorProto") 186 } 187 fullName := nestedName + "." + descriptorProto.GetName() 188 if path == fullName { 189 if foundDescriptorProto != nil { 190 return nil, fmt.Errorf("duplicate messages for path %s", path) 191 } 192 foundDescriptorProto = descriptorProto 193 } 194 nestedFoundDescriptorProto, err := findDescriptorProtoInSlice(path, fullName, descriptorProto.GetNestedType()) 195 if err != nil { 196 return nil, err 197 } 198 if nestedFoundDescriptorProto != nil { 199 if foundDescriptorProto != nil { 200 return nil, fmt.Errorf("duplicate messages for path %s", path) 201 } 202 foundDescriptorProto = nestedFoundDescriptorProto 203 } 204 } 205 return foundDescriptorProto, nil 206 } 207 208 func findServiceDescriptorProto(path string, fileDescriptorProto *descriptor.FileDescriptorProto) (*descriptor.ServiceDescriptorProto, error) { 209 if fileDescriptorProto.GetPackage() == "" { 210 return nil, fmt.Errorf("no package on FileDescriptorProto") 211 } 212 if !strings.HasPrefix(path, fileDescriptorProto.GetPackage()) { 213 return nil, nil 214 } 215 var foundServiceDescriptorProto *descriptor.ServiceDescriptorProto 216 for _, serviceDescriptorProto := range fileDescriptorProto.GetService() { 217 if fileDescriptorProto.GetPackage()+"."+serviceDescriptorProto.GetName() == path { 218 if foundServiceDescriptorProto != nil { 219 return nil, fmt.Errorf("duplicate services for path %s", path) 220 } 221 foundServiceDescriptorProto = serviceDescriptorProto 222 } 223 } 224 return foundServiceDescriptorProto, nil 225 }