github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/protoc-gen-toolkit/gentoolkit/gentoolkit.go (about)

     1  /*
     2  Copyright [2014] - [2023] The Last.Backend authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package gentoolkit
    18  
    19  import (
    20  	"fmt"
    21  	"go/format"
    22  	"io"
    23  	"os"
    24  	"path"
    25  	"path/filepath"
    26  	"strings"
    27  
    28  	"github.com/lastbackend/toolkit/protoc-gen-toolkit/descriptor"
    29  	toolkit_annotattions "github.com/lastbackend/toolkit/protoc-gen-toolkit/toolkit/options"
    30  	"google.golang.org/protobuf/proto"
    31  	"google.golang.org/protobuf/types/pluginpb"
    32  )
    33  
const (
	// defaultRepoRootPath is the import-path root for plugin packages: the
	// lowercased plugin name is appended to it to form the plugin import path.
	defaultRepoRootPath = "github.com/lastbackend/toolkit-plugins"
)
    37  
// Generator turns parsed proto file descriptors into generated response files.
type Generator interface {
	// Generate returns the response files produced for targets, or an error.
	Generate(targets []*descriptor.File) ([]*descriptor.ResponseFile, error)
}
    41  
// generator is the Generator implementation returned by New.
type generator struct {
	// desc is used to reserve Go package aliases while preparing imports.
	desc        *descriptor.Descriptor
	// baseImports are appended to the import list of every generated service file.
	// NOTE(review): nothing in this file assigns the field — confirm whether it is set elsewhere.
	baseImports []descriptor.GoPackage
}
    46  
    47  func New(desc *descriptor.Descriptor) Generator {
    48  	return &generator{
    49  		desc: desc,
    50  	}
    51  }
    52  
    53  func (g *generator) Generate(files []*descriptor.File) ([]*descriptor.ResponseFile, error) {
    54  	contentFiles := make([]*descriptor.ResponseFile, 0)
    55  
    56  	for _, file := range files {
    57  		if len(file.Services) == 0 {
    58  			continue
    59  		}
    60  
    61  		dir := filepath.Dir(file.GeneratedFilenamePrefix)
    62  		name := filepath.Base(file.GeneratedFilenamePrefix)
    63  
    64  		// Generate service
    65  		filename := filepath.Join(dir, name+"_service.pb.toolkit.go")
    66  		genFiles, err := g.generate(filename, file, g.generateService)
    67  		if err != nil {
    68  			return nil, err
    69  		}
    70  		if files != nil {
    71  			contentFiles = append(contentFiles, genFiles...)
    72  		}
    73  
    74  		if g.hasServiceMethods(file) {
    75  			// Generate rpc client
    76  			filename = filepath.Join(dir, "client", name+".pb.toolkit.rpc.go")
    77  			genFiles, err = g.generate(filename, file, g.generateClient)
    78  			if err != nil {
    79  				return nil, err
    80  			}
    81  			if files != nil {
    82  				contentFiles = append(contentFiles, genFiles...)
    83  			}
    84  
    85  			// Generate mockery
    86  			if proto.HasExtension(file.Options, toolkit_annotattions.E_TestsSpec) {
    87  				filename = filepath.Join(dir, "tests", name+".pb.toolkit.mockery.go")
    88  				genFiles, err = g.generate(filename, file, g.generateTestStubs)
    89  				if err != nil {
    90  					return nil, err
    91  				}
    92  				if files != nil {
    93  					contentFiles = append(contentFiles, genFiles...)
    94  				}
    95  			}
    96  		}
    97  	}
    98  
    99  	return contentFiles, nil
   100  }
   101  
// genFunc renders the content of one generated file for the given descriptor.
// A nil result together with a nil error means no file is emitted.
type genFunc func(file *descriptor.File) ([]byte, error)
   103  
   104  func (g *generator) generate(filename string, file *descriptor.File, fn genFunc) ([]*descriptor.ResponseFile, error) {
   105  	files := make([]*descriptor.ResponseFile, 0)
   106  	content, err := fn(file)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	if content != nil {
   111  		files = append(files, &descriptor.ResponseFile{
   112  			GoPkg: file.GoPkg,
   113  			CodeGeneratorResponse_File: &pluginpb.CodeGeneratorResponse_File{
   114  				Name:    proto.String(filename),
   115  				Content: proto.String(string(content)),
   116  			},
   117  		})
   118  	}
   119  	return files, nil
   120  }
   121  
   122  func (g *generator) generateService(file *descriptor.File) ([]byte, error) {
   123  
   124  	var pluginImportsExists = make(map[string]bool, 0)
   125  	var clientImportsExists = make(map[string]bool, 0)
   126  	var plugins = make(map[string][]*descriptor.Plugin, 0)
   127  	var definitionPlugins = make(map[string][]*descriptor.Plugin, 0)
   128  	var clients = make(map[string]*Client, 0)
   129  	var imports = g.prepareImports([]string{
   130  		"context",
   131  		"encoding/json",
   132  		"io",
   133  		"net/http",
   134  		"client github.com/lastbackend/toolkit/pkg/client",
   135  		"runtime github.com/lastbackend/toolkit/pkg/runtime",
   136  		"controller github.com/lastbackend/toolkit/pkg/runtime/controller",
   137  		"tk_http github.com/lastbackend/toolkit/pkg/server/http",
   138  		"tk_ws github.com/lastbackend/toolkit/pkg/server/http/websockets",
   139  		"toolkit github.com/lastbackend/toolkit",
   140  		"errors github.com/lastbackend/toolkit/pkg/server/http/errors",
   141  		"emptypb google.golang.org/protobuf/types/known/emptypb",
   142  		"empty github.com/golang/protobuf/ptypes/empty",
   143  	})
   144  
   145  	// checkers for conflicts and duplicates
   146  	var pkgExists = make(map[string]bool, 0)
   147  	var globalPlgExists = make(map[string]bool, 0)
   148  	var globalDuplicatePrefix = make(map[string]bool, 0)
   149  	var conflictPluginPrefix = make(map[string]string, 0)
   150  
   151  	for _, pkg := range g.baseImports {
   152  		pkgExists[pkg.Path] = true
   153  		imports = append(imports, pkg)
   154  	}
   155  
   156  	if file.Options != nil && proto.HasExtension(file.Options, toolkit_annotattions.E_Plugins) {
   157  		ePlugins := proto.GetExtension(file.Options, toolkit_annotattions.E_Plugins)
   158  		if ePlugins != nil {
   159  			plgs := ePlugins.([]*toolkit_annotattions.Plugin)
   160  			for _, props := range plgs {
   161  				if _, ok := plugins[props.Plugin]; !ok {
   162  					plugins[props.Plugin] = make([]*descriptor.Plugin, 0)
   163  				}
   164  
   165  				key := fmt.Sprintf("%s/%s", props.Plugin, props.Prefix)
   166  				if item, ok := conflictPluginPrefix[props.Prefix]; ok && item != key {
   167  					return nil, fmt.Errorf("conflict toolkit.runtime.plugins prefix with another plugin type: '%s'", props.Prefix)
   168  				}
   169  
   170  				if _, ok := globalDuplicatePrefix[props.Prefix]; ok {
   171  					return nil, fmt.Errorf("duplicate toolkit.plugins prefix: '%s'", props.Prefix)
   172  				}
   173  
   174  				if _, ok := globalPlgExists[key]; ok {
   175  					continue
   176  				}
   177  
   178  				if _, ok := pluginImportsExists[props.Plugin]; !ok {
   179  					imports = append(imports, descriptor.GoPackage{
   180  						Path: fmt.Sprintf("%s/%s", defaultRepoRootPath, strings.ToLower(props.Plugin)),
   181  						Name: path.Base(fmt.Sprintf("%s/%s", defaultRepoRootPath, strings.ToLower(props.Plugin))),
   182  					})
   183  				}
   184  
   185  				p := &descriptor.Plugin{
   186  					Plugin:   props.Plugin,
   187  					Prefix:   props.Prefix,
   188  					Pkg:      strings.ToLower(props.Plugin),
   189  					IsGlobal: true,
   190  				}
   191  
   192  				definePlugins(definitionPlugins, props.Plugin, p)
   193  
   194  				plugins[props.Plugin] = append(plugins[props.Plugin], p)
   195  
   196  				globalPlgExists[key] = true
   197  				globalDuplicatePrefix[props.Prefix] = true
   198  				conflictPluginPrefix[props.Prefix] = key
   199  			}
   200  		}
   201  	}
   202  
   203  	for _, svc := range file.Services {
   204  		var servicePlgExists = make(map[string]bool, 0)
   205  		var serviceDuplicatePrefix = make(map[string]bool, 0)
   206  		for _, m := range svc.Methods {
   207  			pkg := m.RequestType.File.GoPkg
   208  			if pkg == file.GoPkg || pkgExists[pkg.Path] {
   209  				continue
   210  			}
   211  			pkgExists[pkg.Path] = true
   212  			imports = append(imports, pkg)
   213  		}
   214  
   215  		svc.Plugins = make(map[string][]*descriptor.Plugin, 0)
   216  
   217  		if svc.Options != nil && proto.HasExtension(svc.Options, toolkit_annotattions.E_Runtime) {
   218  			eService := proto.GetExtension(svc.Options, toolkit_annotattions.E_Runtime)
   219  			if eService != nil {
   220  				ss := eService.(*toolkit_annotattions.Runtime)
   221  				if ss.Plugins != nil {
   222  					for _, props := range ss.Plugins {
   223  
   224  						key := fmt.Sprintf("%s/%s", props.Plugin, props.Prefix)
   225  						if item, ok := conflictPluginPrefix[props.Prefix]; ok && item != key {
   226  							return nil, fmt.Errorf("conflict toolkit.runtime.plugins prefix with another plugin type: '%s'", props.Prefix)
   227  						}
   228  
   229  						_, gOk := globalPlgExists[key]
   230  
   231  						if _, ok := globalDuplicatePrefix[props.Prefix]; ok && !gOk {
   232  							return nil, fmt.Errorf("duplicate toolkit.runtime.plugins prefix: '%s'", props.Prefix)
   233  						}
   234  						if _, ok := serviceDuplicatePrefix[props.Prefix]; ok {
   235  							return nil, fmt.Errorf("duplicate toolkit.runtime.plugins prefix: '%s'", props.Prefix)
   236  						}
   237  
   238  						if _, ok := servicePlgExists[key]; ok {
   239  							continue
   240  						}
   241  
   242  						if _, ok := svc.Plugins[props.Plugin]; !ok {
   243  							svc.Plugins[props.Plugin] = make([]*descriptor.Plugin, 0)
   244  						}
   245  
   246  						if _, ok := pluginImportsExists[props.Plugin]; !ok {
   247  							imports = append(imports, descriptor.GoPackage{
   248  								Path: fmt.Sprintf("%s/%s", defaultRepoRootPath, strings.ToLower(props.Plugin)),
   249  								Name: path.Base(fmt.Sprintf("%s/%s", defaultRepoRootPath, strings.ToLower(props.Plugin))),
   250  							})
   251  						}
   252  
   253  						p := &descriptor.Plugin{
   254  							Plugin: props.Plugin,
   255  							Prefix: props.Prefix,
   256  							Pkg:    strings.ToLower(props.Plugin),
   257  						}
   258  
   259  						definePlugins(definitionPlugins, props.Plugin, p)
   260  
   261  						if _, ok := globalPlgExists[key]; !ok {
   262  							svc.Plugins[props.Plugin] = append(svc.Plugins[props.Plugin], p)
   263  						}
   264  
   265  						servicePlgExists[key] = true
   266  						serviceDuplicatePrefix[props.Prefix] = true
   267  						conflictPluginPrefix[props.Prefix] = key
   268  					}
   269  				}
   270  			}
   271  		}
   272  	}
   273  
   274  	if file.Options != nil && proto.HasExtension(file.Options, toolkit_annotattions.E_Services) {
   275  		eClients := proto.GetExtension(file.Options, toolkit_annotattions.E_Services)
   276  		if eClients != nil {
   277  			clnts := eClients.([]*toolkit_annotattions.Service)
   278  			for _, value := range clnts {
   279  				if _, ok := clientImportsExists[value.Service]; !ok {
   280  					imports = append(imports, descriptor.GoPackage{
   281  						Alias: strings.ToLower(value.Service),
   282  						Path:  value.Package,
   283  					})
   284  				}
   285  				clients[value.Service] = &Client{
   286  					Service: value.Service,
   287  					Pkg:     value.Package,
   288  				}
   289  			}
   290  		}
   291  	}
   292  
   293  	to := tplServiceOptions{
   294  		File:              file,
   295  		Imports:           imports,
   296  		Clients:           clients,
   297  		Plugins:           plugins,
   298  		DefinitionPlugins: definitionPlugins,
   299  	}
   300  
   301  	content, err := applyServiceTemplate(to)
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  
   306  	return format.Source([]byte(content))
   307  }
   308  
   309  func (g *generator) generateClient(file *descriptor.File) ([]byte, error) {
   310  
   311  	pkgImports := []string{
   312  		"context context",
   313  		"client github.com/lastbackend/toolkit/pkg/client",
   314  		"emptypb google.golang.org/protobuf/types/known/emptypb",
   315  		"empty github.com/golang/protobuf/ptypes/empty",
   316  	}
   317  
   318  	var imports = g.prepareImports(pkgImports)
   319  
   320  	for _, svc := range file.Services {
   321  		for _, m := range svc.Methods {
   322  			if m.IsWebsocket || m.IsWebsocketProxy {
   323  				continue
   324  			}
   325  
   326  			pkg := m.RequestType.File.GoPkg
   327  			if strings.HasPrefix(m.RequestType.File.GoPkg.Path, "./") {
   328  				pkg.Path = filepath.Join(file.GoPkg.Name, m.RequestType.File.GoPkg.Path)
   329  			}
   330  
   331  			imports = append(imports, pkg)
   332  		}
   333  	}
   334  
   335  	var clients = make(map[string]*Client, 0)
   336  
   337  	if file.Options != nil && proto.HasExtension(file.Options, toolkit_annotattions.E_Services) {
   338  		eClients := proto.GetExtension(file.Options, toolkit_annotattions.E_Services)
   339  		if eClients != nil {
   340  			clnts := eClients.([]*toolkit_annotattions.Service)
   341  			for _, value := range clnts {
   342  				clients[value.Service] = &Client{
   343  					Service: value.Service,
   344  					Pkg:     value.Package,
   345  				}
   346  			}
   347  		}
   348  	}
   349  
   350  	to := tplClientOptions{
   351  		File:    file,
   352  		Imports: imports,
   353  		Clients: clients,
   354  	}
   355  
   356  	content, err := applyClientTemplate(to)
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  
   361  	return format.Source([]byte(content))
   362  }
   363  
   364  func (g *generator) generateTestStubs(file *descriptor.File) ([]byte, error) {
   365  	ext := proto.GetExtension(file.Options, toolkit_annotattions.E_TestsSpec)
   366  	opts, ok := ext.(*toolkit_annotattions.TestSpec)
   367  	if !ok {
   368  		return nil, nil
   369  	}
   370  
   371  	baseImports := []string{
   372  		"context context",
   373  		"client github.com/lastbackend/toolkit/pkg/client",
   374  		"emptypb google.golang.org/protobuf/types/known/emptypb",
   375  		"empty github.com/golang/protobuf/ptypes/empty",
   376  		fmt.Sprintf("servicepb %s/client", filepath.Dir(file.GeneratedFilenamePrefix)),
   377  	}
   378  
   379  	if len(opts.Mockery.Package) == 0 {
   380  		opts.Mockery.Package = "github.com/dummy/dummy"
   381  	}
   382  
   383  	var dirErr error
   384  	dir := filepath.Join(os.Getenv("GOPATH"), "src", opts.Mockery.Package)
   385  	if ok, _ := existsFileOrDir(dir); !ok {
   386  		dirErr = fmt.Errorf("directory %s does not exist", dir)
   387  	}
   388  	if ok, _ := dirIsEmpty(dir); ok {
   389  		dirErr = fmt.Errorf("directory %s is empty", dir)
   390  	}
   391  	if dirErr != nil {
   392  		content, err := applyTemplateWithMessage(tplMessageOptions{
   393  			File:    file,
   394  			Message: "Warning: You have no mock in provided directory. Please check mockery docs for mocks generation.",
   395  		})
   396  		if err != nil {
   397  			return nil, err
   398  		}
   399  
   400  		return format.Source([]byte(content))
   401  	}
   402  
   403  	var imports = g.prepareImports(baseImports)
   404  
   405  	imports = append(imports, descriptor.GoPackage{
   406  		Path:  fmt.Sprintf(opts.Mockery.Package),
   407  		Name:  path.Base(opts.Mockery.Package),
   408  		Alias: "service_mocks",
   409  	})
   410  
   411  	for _, svc := range file.Services {
   412  		for _, m := range svc.Methods {
   413  			imports = append(imports, m.RequestType.File.GoPkg)
   414  		}
   415  	}
   416  
   417  	content, err := applyTestTemplate(tplMockeryTestOptions{
   418  		File:    file,
   419  		Imports: imports,
   420  	})
   421  	if err != nil {
   422  		return nil, err
   423  	}
   424  
   425  	return format.Source([]byte(content))
   426  }
   427  
   428  func (g *generator) prepareImports(importList []string) []descriptor.GoPackage {
   429  	var imports []descriptor.GoPackage
   430  	for _, pkgPath := range importList {
   431  		var pkg descriptor.GoPackage
   432  
   433  		match := strings.Split(pkgPath, " ")
   434  		if len(match) == 2 {
   435  			pkg = descriptor.GoPackage{
   436  				Path:  match[1],
   437  				Name:  path.Base(match[1]),
   438  				Alias: match[0],
   439  			}
   440  		} else {
   441  			pkg = descriptor.GoPackage{
   442  				Path: pkgPath,
   443  				Name: path.Base(pkgPath),
   444  			}
   445  		}
   446  
   447  		if len(pkg.Alias) == 0 {
   448  			if err := g.desc.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
   449  				for i := 0; ; i++ {
   450  					alias := fmt.Sprintf("%s_%d", pkg.Name, i)
   451  					if err := g.desc.ReserveGoPackageAlias(alias, pkg.Path); err != nil {
   452  						continue
   453  					}
   454  					pkg.Alias = alias
   455  					break
   456  				}
   457  			}
   458  		}
   459  		imports = append(imports, pkg)
   460  	}
   461  
   462  	return imports
   463  }
   464  
   465  func (g *generator) hasServiceMethods(file *descriptor.File) bool {
   466  	for _, service := range file.Services {
   467  		if len(service.Methods) > 0 {
   468  			return true
   469  		}
   470  	}
   471  	return false
   472  }
   473  
   474  func definePlugins(def map[string][]*descriptor.Plugin, name string, plugin *descriptor.Plugin) {
   475  	pItems := def[name]
   476  	exists := false
   477  	for _, p := range pItems {
   478  		if plugin.Prefix == p.Prefix && plugin.Plugin == p.Plugin {
   479  			exists = true
   480  		}
   481  	}
   482  	if !exists {
   483  		if _, ok := def[name]; !ok {
   484  			def[name] = make([]*descriptor.Plugin, 0)
   485  		}
   486  		def[name] = append(def[name], plugin)
   487  	}
   488  }
   489  
   490  func existsFileOrDir(path string) (bool, error) {
   491  	_, err := os.Stat(path)
   492  	if err == nil {
   493  		return true, nil
   494  	}
   495  	if os.IsNotExist(err) {
   496  		return false, nil
   497  	}
   498  	return false, err
   499  }
   500  
   501  func dirIsEmpty(name string) (bool, error) {
   502  	f, err := os.Open(name)
   503  	if err != nil {
   504  		return false, err
   505  	}
   506  	defer f.Close()
   507  
   508  	_, err = f.Readdirnames(1)
   509  	if err == io.EOF {
   510  		return true, nil
   511  	}
   512  	return false, err
   513  }