github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/protoc-gen-toolkit/descriptor/services.go (about)

     1  /*
     2  Copyright [2014] - [2023] The Last.Backend authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package descriptor
    18  
    19  import (
    20  	"fmt"
    21  	"strings"
    22  
    23  	toolkit_annotattions "github.com/lastbackend/toolkit/protoc-gen-toolkit/toolkit/options"
    24  	options "google.golang.org/genproto/googleapis/api/annotations"
    25  	"google.golang.org/protobuf/proto"
    26  	"google.golang.org/protobuf/types/descriptorpb"
    27  )
    28  
    29  func (d *Descriptor) loadServices(file *File) error {
    30  	var services []*Service
    31  	for _, service := range file.GetService() {
    32  		svc := &Service{
    33  			File:                   file,
    34  			ServiceDescriptorProto: service,
    35  			HTTPMiddlewares:        make([]string, 0),
    36  		}
    37  		for _, md := range service.GetMethod() {
    38  			method, err := d.newMethod(svc, md)
    39  			if err != nil {
    40  				return err
    41  			}
    42  			svc.Methods = append(svc.Methods, method)
    43  		}
    44  
    45  		if service.Options != nil && proto.HasExtension(service.Options, toolkit_annotattions.E_Server) {
    46  			eServer := proto.GetExtension(svc.Options, toolkit_annotattions.E_Server)
    47  			if eServer != nil {
    48  				ss := eServer.(*toolkit_annotattions.Server)
    49  				svc.HTTPMiddlewares = ss.Middlewares
    50  			}
    51  		}
    52  		if service.Options != nil && proto.HasExtension(service.Options, toolkit_annotattions.E_Runtime) {
    53  			eService := proto.GetExtension(svc.Options, toolkit_annotattions.E_Runtime)
    54  			if eService != nil {
    55  				ss := eService.(*toolkit_annotattions.Runtime)
    56  				if ss.Servers != nil {
    57  					svc.UseHTTPServer = checkSetServerOption(ss.Servers, toolkit_annotattions.Runtime_HTTP)
    58  					svc.UseWebsocketProxyServer = checkSetServerOption(ss.Servers, toolkit_annotattions.Runtime_WEBSOCKET_PROXY)
    59  					svc.UseWebsocketServer = checkSetServerOption(ss.Servers, toolkit_annotattions.Runtime_WEBSOCKET)
    60  					svc.UseGRPCServer = checkSetServerOption(ss.Servers, toolkit_annotattions.Runtime_GRPC)
    61  				}
    62  			}
    63  		}
    64  
    65  		services = append(services, svc)
    66  	}
    67  
    68  	file.Services = services
    69  
    70  	return nil
    71  }
    72  
    73  func (d *Descriptor) newMethod(svc *Service, md *descriptorpb.MethodDescriptorProto) (*Method, error) {
    74  	requestType, err := d.findMessage(svc.File.GetPackage(), md.GetInputType())
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	responseType, err := d.findMessage(svc.File.GetPackage(), md.GetOutputType())
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	method := &Method{
    84  		Service:               svc,
    85  		MethodDescriptorProto: md,
    86  		RequestType:           requestType,
    87  		ResponseType:          responseType,
    88  	}
    89  
    90  	if method.Options != nil && proto.HasExtension(method.Options, options.E_Http) {
    91  		err = setBindingsToMethod(method)
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  	}
    96  
    97  	return method, nil
    98  }
    99  
   100  func (d *Descriptor) findMessage(location, name string) (*Message, error) {
   101  	if strings.HasPrefix(name, ".") {
   102  		method, ok := d.messageMap[name]
   103  		if !ok {
   104  			return nil, fmt.Errorf("no message found: %s", name)
   105  		}
   106  		return method, nil
   107  	}
   108  
   109  	if !strings.HasPrefix(location, ".") {
   110  		location = fmt.Sprintf(".%s", location)
   111  	}
   112  
   113  	parts := strings.Split(location, ".")
   114  	for len(parts) > 0 {
   115  		messageName := strings.Join(append(parts, name), ".")
   116  		if method, ok := d.messageMap[messageName]; ok {
   117  			return method, nil
   118  		}
   119  		parts = parts[:len(parts)-1]
   120  	}
   121  
   122  	return nil, fmt.Errorf("no message found: %s", name)
   123  }
   124  
   125  func setBindingsToMethod(method *Method) error {
   126  	routeOpts, err := getRouteOptions(method)
   127  	if err != nil {
   128  		return err
   129  	}
   130  	switch true {
   131  	case routeOpts != nil && routeOpts.GetWebsocket():
   132  		opts, err := getHTTPOptions(method)
   133  		if err != nil {
   134  			return err
   135  		}
   136  
   137  		method.IsWebsocket = true
   138  
   139  		binding := &Binding{
   140  			Method:       method,
   141  			Index:        len(method.Bindings),
   142  			HttpMethod:   "http.MethodGet",
   143  			RpcMethod:    method.GetName(),
   144  			HttpPath:     opts.GetGet(),
   145  			HttpParams:   getVariablesFromPath(opts.GetGet()),
   146  			RequestType:  method.RequestType,
   147  			ResponseType: method.ResponseType,
   148  			Websocket:    true,
   149  		}
   150  		method.Bindings = append(method.Bindings, binding)
   151  
   152  	case routeOpts != nil && routeOpts.GetWebsocketProxy() != nil:
   153  		rOpts := routeOpts.GetWebsocketProxy()
   154  
   155  		method.IsWebsocketProxy = true
   156  
   157  		binding := &Binding{
   158  			Method:         method,
   159  			Index:          len(method.Bindings),
   160  			HttpMethod:     "http.MethodGet",
   161  			RpcMethod:      method.GetName(),
   162  			Service:        rOpts.GetService(),
   163  			RpcPath:        rOpts.GetMethod(),
   164  			RequestType:    method.RequestType,
   165  			ResponseType:   method.ResponseType,
   166  			WebsocketProxy: true,
   167  		}
   168  
   169  		method.Bindings = append(method.Bindings, binding)
   170  
   171  	case proto.HasExtension(method.Options, options.E_Http):
   172  		opts, err := getHTTPOptions(method)
   173  		if err != nil {
   174  			return err
   175  		}
   176  		if opts != nil {
   177  			binding, err := newHttpBinding(method, opts, routeOpts, false)
   178  			if err != nil {
   179  				return err
   180  			}
   181  			method.Bindings = append(method.Bindings, binding)
   182  			for _, additional := range opts.GetAdditionalBindings() {
   183  				if len(additional.AdditionalBindings) > 0 {
   184  					continue
   185  				}
   186  				b, err := newHttpBinding(method, additional, routeOpts, true)
   187  				if err != nil {
   188  					continue
   189  				}
   190  				method.Bindings = append(method.Bindings, b)
   191  			}
   192  		}
   193  	}
   194  
   195  	return nil
   196  }
   197  
   198  func getHTTPOptions(method *Method) (*options.HttpRule, error) {
   199  	if method.Options == nil {
   200  		return nil, nil
   201  	}
   202  	if !proto.HasExtension(method.Options, options.E_Http) {
   203  		return nil, nil
   204  	}
   205  	ext := proto.GetExtension(method.Options, options.E_Http)
   206  	opts, ok := ext.(*options.HttpRule)
   207  	if !ok {
   208  		return nil, fmt.Errorf("extension is not an HttpRule")
   209  	}
   210  	return opts, nil
   211  }
   212  
   213  func getRouteOptions(m *Method) (*toolkit_annotattions.Route, error) {
   214  	if m.Options == nil {
   215  		return nil, nil
   216  	}
   217  	if !proto.HasExtension(m.Options, toolkit_annotattions.E_Route) {
   218  		return nil, nil
   219  	}
   220  	ext := proto.GetExtension(m.Options, toolkit_annotattions.E_Route)
   221  	opts, ok := ext.(*toolkit_annotattions.Route)
   222  	if !ok {
   223  		return nil, fmt.Errorf("extension is not an Route")
   224  	}
   225  	return opts, nil
   226  }
   227  
   228  func newHttpBinding(method *Method, opts *options.HttpRule, rOpts *toolkit_annotattions.Route, additionalBinding bool) (*Binding, error) {
   229  	var (
   230  		httpMethod string
   231  		httpPath   string
   232  	)
   233  	switch {
   234  	case opts.GetGet() != "":
   235  		httpMethod = "http.MethodGet"
   236  		httpPath = opts.GetGet()
   237  	case opts.GetPut() != "":
   238  		httpMethod = "http.MethodPut"
   239  		httpPath = opts.GetPut()
   240  		if opts.Body == "" {
   241  			opts.Body = "*"
   242  		}
   243  	case opts.GetPost() != "":
   244  		httpMethod = "http.MethodPost"
   245  		httpPath = opts.GetPost()
   246  		if opts.Body == "" {
   247  			opts.Body = "*"
   248  		}
   249  	case opts.GetDelete() != "":
   250  		httpMethod = "http.MethodDelete"
   251  		httpPath = opts.GetDelete()
   252  	case opts.GetPatch() != "":
   253  		httpMethod = "http.MethodPatch"
   254  		httpPath = opts.GetPatch()
   255  		if opts.Body == "" {
   256  			opts.Body = "*"
   257  		}
   258  	default:
   259  		return nil, fmt.Errorf("not fount method")
   260  	}
   261  
   262  	b := &Binding{
   263  		Method:                   method,
   264  		Index:                    len(method.Bindings),
   265  		RpcMethod:                method.GetName(),
   266  		HttpMethod:               httpMethod,
   267  		HttpPath:                 httpPath,
   268  		HttpParams:               getVariablesFromPath(httpPath),
   269  		RequestType:              method.RequestType,
   270  		ResponseType:             method.ResponseType,
   271  		Stream:                   method.GetClientStreaming(),
   272  		Middlewares:              rOpts.GetMiddlewares(),
   273  		ExcludeGlobalMiddlewares: rOpts.GetExcludeGlobalMiddlewares(),
   274  		RawBody:                  opts.Body,
   275  		AdditionalBinding:        additionalBinding,
   276  	}
   277  	if proxyOpts := rOpts.GetHttpProxy(); proxyOpts != nil {
   278  		b.Service = proxyOpts.GetService()
   279  		b.RpcPath = proxyOpts.GetMethod()
   280  	}
   281  	return b, nil
   282  }
   283  
   284  func getVariablesFromPath(path string) (variables []string) {
   285  	if path == "" {
   286  		return make([]string, 0)
   287  	}
   288  
   289  	for path != "" {
   290  		firstIndex := -1
   291  		lastIndex := -1
   292  
   293  		firstIndex = strings.IndexAny(path, "{")
   294  		lastIndex = strings.IndexAny(path, "}")
   295  
   296  		if firstIndex > -1 && lastIndex > -1 {
   297  			field := path[firstIndex+1 : lastIndex]
   298  			if len(strings.TrimSpace(field)) > 0 {
   299  				variables = append(variables, field)
   300  			}
   301  			path = path[lastIndex+1:]
   302  		}
   303  
   304  		if firstIndex == -1 || lastIndex == -1 {
   305  			path = path[1:]
   306  		}
   307  	}
   308  
   309  	return variables
   310  }
   311  
   312  func checkSetServerOption(servers []toolkit_annotattions.Runtime_Server, server toolkit_annotattions.Runtime_Server) bool {
   313  	for _, val := range servers {
   314  		if val == server {
   315  			return true
   316  		}
   317  	}
   318  	return false
   319  }