github.com/haalcala/mattermost-server-change-repo@v0.0.0-20210713015153-16753fbeee5f/plugin/interface_generator/main.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/printer"
    12  	"go/token"
    13  	"io/ioutil"
    14  	"log"
    15  	"os/exec"
    16  	"path/filepath"
    17  	"strings"
    18  	"text/template"
    19  
    20  	"github.com/pkg/errors"
    21  	"golang.org/x/tools/imports"
    22  )
    23  
    24  type IHookEntry struct {
    25  	FuncName string
    26  	Args     *ast.FieldList
    27  	Results  *ast.FieldList
    28  }
    29  
    30  type PluginInterfaceInfo struct {
    31  	Hooks   []IHookEntry
    32  	API     []IHookEntry
    33  	FileSet *token.FileSet
    34  }
    35  
    36  func FieldListToFuncList(fieldList *ast.FieldList, fileset *token.FileSet) string {
    37  	result := []string{}
    38  	if fieldList == nil || len(fieldList.List) == 0 {
    39  		return "()"
    40  	}
    41  	for _, field := range fieldList.List {
    42  		typeNameBuffer := &bytes.Buffer{}
    43  		err := printer.Fprint(typeNameBuffer, fileset, field.Type)
    44  		if err != nil {
    45  			panic(err)
    46  		}
    47  		typeName := typeNameBuffer.String()
    48  		names := []string{}
    49  		for _, name := range field.Names {
    50  			names = append(names, name.Name)
    51  		}
    52  		result = append(result, strings.Join(names, ", ")+" "+typeName)
    53  	}
    54  
    55  	return "(" + strings.Join(result, ", ") + ")"
    56  }
    57  
    58  func FieldListToNames(fieldList *ast.FieldList, variadicForm bool) string {
    59  	result := []string{}
    60  	if fieldList == nil || len(fieldList.List) == 0 {
    61  		return ""
    62  	}
    63  	for _, field := range fieldList.List {
    64  		for _, name := range field.Names {
    65  			paramName := name.Name
    66  			if _, ok := field.Type.(*ast.Ellipsis); ok && variadicForm {
    67  				paramName = fmt.Sprintf("%s...", paramName)
    68  			}
    69  			result = append(result, paramName)
    70  		}
    71  	}
    72  
    73  	return strings.Join(result, ", ")
    74  }
    75  
    76  func FieldListToEncodedErrors(structPrefix string, fieldList *ast.FieldList, fileset *token.FileSet) string {
    77  	result := []string{}
    78  	if fieldList == nil {
    79  		return ""
    80  	}
    81  
    82  	nextLetter := 'A'
    83  	for _, field := range fieldList.List {
    84  		typeNameBuffer := &bytes.Buffer{}
    85  		err := printer.Fprint(typeNameBuffer, fileset, field.Type)
    86  		if err != nil {
    87  			panic(err)
    88  		}
    89  
    90  		if typeNameBuffer.String() != "error" {
    91  			nextLetter++
    92  			continue
    93  		}
    94  
    95  		name := ""
    96  		if len(field.Names) == 0 {
    97  			name = string(nextLetter)
    98  			nextLetter++
    99  		} else {
   100  			for range field.Names {
   101  				name += string(nextLetter)
   102  				nextLetter++
   103  			}
   104  		}
   105  
   106  		result = append(result, structPrefix+name+" = encodableError("+structPrefix+name+")")
   107  
   108  	}
   109  
   110  	return strings.Join(result, "\n")
   111  }
   112  
   113  func FieldListDestruct(structPrefix string, fieldList *ast.FieldList, fileset *token.FileSet) string {
   114  	result := []string{}
   115  	if fieldList == nil || len(fieldList.List) == 0 {
   116  		return ""
   117  	}
   118  	nextLetter := 'A'
   119  	for _, field := range fieldList.List {
   120  		typeNameBuffer := &bytes.Buffer{}
   121  		err := printer.Fprint(typeNameBuffer, fileset, field.Type)
   122  		if err != nil {
   123  			panic(err)
   124  		}
   125  		typeName := typeNameBuffer.String()
   126  		suffix := ""
   127  		if strings.HasPrefix(typeName, "...") {
   128  			suffix = "..."
   129  		}
   130  		if len(field.Names) == 0 {
   131  			result = append(result, structPrefix+string(nextLetter)+suffix)
   132  			nextLetter++
   133  		} else {
   134  			for range field.Names {
   135  				result = append(result, structPrefix+string(nextLetter)+suffix)
   136  				nextLetter++
   137  			}
   138  		}
   139  	}
   140  
   141  	return strings.Join(result, ", ")
   142  }
   143  
   144  func FieldListToRecordSuccess(structPrefix string, fieldList *ast.FieldList) string {
   145  	if fieldList == nil || len(fieldList.List) == 0 {
   146  		return "true"
   147  	}
   148  
   149  	result := ""
   150  	nextLetter := 'A'
   151  	for _, field := range fieldList.List {
   152  		typeName := baseTypeName(field.Type)
   153  		if typeName == "error" || typeName == "AppError" {
   154  			result = structPrefix + string(nextLetter)
   155  			break
   156  		}
   157  		nextLetter++
   158  	}
   159  
   160  	if result == "" {
   161  		return "true"
   162  	}
   163  	return fmt.Sprintf("%s == nil", result)
   164  }
   165  
   166  func FieldListToStructList(fieldList *ast.FieldList, fileset *token.FileSet) string {
   167  	result := []string{}
   168  	if fieldList == nil || len(fieldList.List) == 0 {
   169  		return ""
   170  	}
   171  	nextLetter := 'A'
   172  	for _, field := range fieldList.List {
   173  		typeNameBuffer := &bytes.Buffer{}
   174  		err := printer.Fprint(typeNameBuffer, fileset, field.Type)
   175  		if err != nil {
   176  			panic(err)
   177  		}
   178  		typeName := typeNameBuffer.String()
   179  		if strings.HasPrefix(typeName, "...") {
   180  			typeName = strings.Replace(typeName, "...", "[]", 1)
   181  		}
   182  		if len(field.Names) == 0 {
   183  			result = append(result, string(nextLetter)+" "+typeName)
   184  			nextLetter++
   185  		} else {
   186  			for range field.Names {
   187  				result = append(result, string(nextLetter)+" "+typeName)
   188  				nextLetter++
   189  			}
   190  		}
   191  	}
   192  
   193  	return strings.Join(result, "\n\t")
   194  }
   195  
   196  func baseTypeName(x ast.Expr) string {
   197  	switch t := x.(type) {
   198  	case *ast.Ident:
   199  		return t.Name
   200  	case *ast.SelectorExpr:
   201  		if _, ok := t.X.(*ast.Ident); ok {
   202  			// only possible for qualified type names;
   203  			// assume type is imported
   204  			return t.Sel.Name
   205  		}
   206  	case *ast.ParenExpr:
   207  		return baseTypeName(t.X)
   208  	case *ast.StarExpr:
   209  		return baseTypeName(t.X)
   210  	}
   211  	return ""
   212  }
   213  
   214  func goList(dir string) ([]string, error) {
   215  	cmd := exec.Command("go", "list", "-f", "{{.Dir}}", dir)
   216  	bytes, err := cmd.Output()
   217  	if err != nil {
   218  		return nil, errors.Wrap(err, "Can't list packages")
   219  	}
   220  
   221  	return strings.Fields(string(bytes)), nil
   222  }
   223  
   224  func (info *PluginInterfaceInfo) addHookMethod(method *ast.Field) {
   225  	info.Hooks = append(info.Hooks, IHookEntry{
   226  		FuncName: method.Names[0].Name,
   227  		Args:     method.Type.(*ast.FuncType).Params,
   228  		Results:  method.Type.(*ast.FuncType).Results,
   229  	})
   230  }
   231  
   232  func (info *PluginInterfaceInfo) addAPIMethod(method *ast.Field) {
   233  	info.API = append(info.API, IHookEntry{
   234  		FuncName: method.Names[0].Name,
   235  		Args:     method.Type.(*ast.FuncType).Params,
   236  		Results:  method.Type.(*ast.FuncType).Results,
   237  	})
   238  }
   239  
   240  func (info *PluginInterfaceInfo) makeHookInspector() func(node ast.Node) bool {
   241  	return func(node ast.Node) bool {
   242  		if typeSpec, ok := node.(*ast.TypeSpec); ok {
   243  			if typeSpec.Name.Name == "Hooks" {
   244  				for _, method := range typeSpec.Type.(*ast.InterfaceType).Methods.List {
   245  					info.addHookMethod(method)
   246  				}
   247  				return false
   248  			} else if typeSpec.Name.Name == "API" {
   249  				for _, method := range typeSpec.Type.(*ast.InterfaceType).Methods.List {
   250  					info.addAPIMethod(method)
   251  				}
   252  				return false
   253  			}
   254  		}
   255  		return true
   256  	}
   257  }
   258  
   259  func getPluginInfo(dir string) (*PluginInterfaceInfo, error) {
   260  	pluginInfo := &PluginInterfaceInfo{
   261  		Hooks:   make([]IHookEntry, 0),
   262  		FileSet: token.NewFileSet(),
   263  	}
   264  
   265  	packages, err := parser.ParseDir(pluginInfo.FileSet, dir, nil, parser.ParseComments)
   266  	if err != nil {
   267  		log.Println("Parser error in dir "+dir+": ", err)
   268  		return nil, err
   269  	}
   270  
   271  	for _, pkg := range packages {
   272  		if pkg.Name != "plugin" {
   273  			continue
   274  		}
   275  
   276  		for _, file := range pkg.Files {
   277  			ast.Inspect(file, pluginInfo.makeHookInspector())
   278  		}
   279  	}
   280  
   281  	return pluginInfo, nil
   282  }
   283  
// hooksTemplate is the text/template source that generateHooksGlue renders
// into client_rpc_generated.go.
//
// For each entry in .HooksMethods it emits:
//   - an init func that registers the hook in hookNameToId;
//   - Args/Returns structs, named via the "obscure" func and filled with
//     lettered fields from "structStyle";
//   - a hooksRPCClient method that makes the RPC call only when
//     g.implemented[<Name>Id] is set;
//   - a hooksRPCServer method that dispatches to s.impl through an anonymous
//     single-method interface assertion.
//
// For each entry in .APIMethods it emits the same Args/Returns structs, plus
// an apiRPCClient/apiRPCServer method pair (the client side has no
// "implemented" check).
//
// The template has no import block. generateHooksGlue runs the output through
// imports.Process, which is expected to supply imports such as fmt, log and
// mlog — NOTE(review): confirm the import resolution if package paths move.
var hooksTemplate = `// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

// Code generated by "make pluginapi"
// DO NOT EDIT

package plugin

{{range .HooksMethods}}

func init() {
	hookNameToId["{{.Name}}"] = {{.Name}}Id
}

type {{.Name | obscure}}Args struct {
	{{structStyle .Params}}
}

type {{.Name | obscure}}Returns struct {
	{{structStyle .Return}}
}

func (g *hooksRPCClient) {{.Name}}{{funcStyle .Params}} {{funcStyle .Return}} {
	_args := &{{.Name | obscure}}Args{ {{valuesOnly .Params}} }
	_returns := &{{.Name | obscure}}Returns{}
	if g.implemented[{{.Name}}Id] {
		if err := g.client.Call("Plugin.{{.Name}}", _args, _returns); err != nil {
			g.log.Error("RPC call {{.Name}} to plugin failed.", mlog.Err(err))
		}
	}
	{{ if .Return }} return {{destruct "_returns." .Return}} {{ end }}
}

func (s *hooksRPCServer) {{.Name}}(args *{{.Name | obscure}}Args, returns *{{.Name | obscure}}Returns) error {
	if hook, ok := s.impl.(interface {
		{{.Name}}{{funcStyle .Params}} {{funcStyle .Return}}
	}); ok {
		{{if .Return}}{{destruct "returns." .Return}} = {{end}}hook.{{.Name}}({{destruct "args." .Params}})
		{{if .Return}}{{encodeErrors "returns." .Return}}{{end -}}
	} else {
		return encodableError(fmt.Errorf("Hook {{.Name}} called but not implemented."))
	}
	return nil
}
{{end}}

{{range .APIMethods}}

type {{.Name | obscure}}Args struct {
	{{structStyle .Params}}
}

type {{.Name | obscure}}Returns struct {
	{{structStyle .Return}}
}

func (g *apiRPCClient) {{.Name}}{{funcStyle .Params}} {{funcStyle .Return}} {
	_args := &{{.Name | obscure}}Args{ {{valuesOnly .Params}} }
	_returns := &{{.Name | obscure}}Returns{}
	if err := g.client.Call("Plugin.{{.Name}}", _args, _returns); err != nil {
		log.Printf("RPC call to {{.Name}} API failed: %s", err.Error())
	}
	{{ if .Return }} return {{destruct "_returns." .Return}} {{ end }}
}

func (s *apiRPCServer) {{.Name}}(args *{{.Name | obscure}}Args, returns *{{.Name | obscure}}Returns) error {
	if hook, ok := s.impl.(interface {
		{{.Name}}{{funcStyle .Params}} {{funcStyle .Return}}
	}); ok {
		{{if .Return}}{{destruct "returns." .Return}} = {{end}}hook.{{.Name}}({{destruct "args." .Params}})
		{{if .Return}}{{encodeErrors "returns." .Return}}{{end -}}
	} else {
		return encodableError(fmt.Errorf("API {{.Name}} called but not implemented."))
	}
	return nil
}
{{end}}
`
   362  
// apiTimerLayerTemplate is the text/template source that
// generatePluginTimerLayer renders into api_timer_layer_generated.go.
//
// It declares apiTimerLayer, which wraps an API implementation. For each
// entry in .APIMethods it emits a method that:
//   - calls api.apiImpl with the same arguments (variadics are spread,
//     because generatePluginTimerLayer uses the variadic "valuesOnly" form);
//   - captures the results into _returnsA, _returnsB, ...;
//   - reports the elapsed seconds through metrics.ObservePluginApiDuration,
//     with the success flag computed by "shouldRecordSuccess";
//   - returns the captured results.
//
// recordTime is a no-op when metrics is nil.
var apiTimerLayerTemplate = `// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

// Code generated by "make pluginapi"
// DO NOT EDIT

package plugin

import (
	"io"
	"net/http"
	timePkg "time"

	"github.com/mattermost/mattermost-server/v5/einterfaces"
	"github.com/mattermost/mattermost-server/v5/model"
)

type apiTimerLayer struct {
	pluginID string
	apiImpl  API
	metrics  einterfaces.MetricsInterface
}

func (api *apiTimerLayer) recordTime(startTime timePkg.Time, name string, success bool) {
	if api.metrics != nil {
		elapsedTime := float64(timePkg.Since(startTime)) / float64(timePkg.Second)
		api.metrics.ObservePluginApiDuration(api.pluginID, name, success, elapsedTime)
	}
}

{{range .APIMethods}}

func (api *apiTimerLayer) {{.Name}}{{funcStyle .Params}} {{funcStyle .Return}} {
	startTime := timePkg.Now()
	{{ if .Return }} {{destruct "_returns" .Return}} := {{ end }} api.apiImpl.{{.Name}}({{valuesOnly .Params}})
	api.recordTime(startTime, "{{.Name}}", {{ shouldRecordSuccess "_returns" .Return }})
	{{ if .Return }} return {{destruct "_returns" .Return}} {{ end -}}
}

{{end}}
`
   404  
// hooksTimerLayerTemplate is the text/template source that
// generatePluginTimerLayer renders into hooks_timer_layer_generated.go.
//
// It declares hooksTimerLayer, which wraps a Hooks implementation. For each
// entry in .HooksMethods it emits a method that:
//   - calls hooks.hooksImpl with the same arguments;
//   - captures the results into _returnsA, _returnsB, ...;
//   - reports the elapsed seconds through metrics.ObservePluginHookDuration,
//     with the success flag computed by "shouldRecordSuccess";
//   - returns the captured results.
//
// recordTime is a no-op when metrics is nil.
var hooksTimerLayerTemplate = `// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

// Code generated by "make pluginapi"
// DO NOT EDIT

package plugin

import (
	"io"
	"net/http"
	timePkg "time"

	"github.com/mattermost/mattermost-server/v5/einterfaces"
	"github.com/mattermost/mattermost-server/v5/model"
)

type hooksTimerLayer struct {
	pluginID  string
	hooksImpl Hooks
	metrics   einterfaces.MetricsInterface
}

func (hooks *hooksTimerLayer) recordTime(startTime timePkg.Time, name string, success bool) {
	if hooks.metrics != nil {
		elapsedTime := float64(timePkg.Since(startTime)) / float64(timePkg.Second)
		hooks.metrics.ObservePluginHookDuration(hooks.pluginID, name, success, elapsedTime)
	}
}

{{range .HooksMethods}}

func (hooks *hooksTimerLayer) {{.Name}}{{funcStyle .Params}} {{funcStyle .Return}} {
	startTime := timePkg.Now()
	{{ if .Return }} {{destruct "_returns" .Return}} := {{ end }} hooks.hooksImpl.{{.Name}}({{valuesOnly .Params}})
	hooks.recordTime(startTime, "{{.Name}}", {{ shouldRecordSuccess "_returns" .Return }})
	{{ if .Return }} return {{destruct "_returns" .Return}} {{end -}}
}

{{end}}
`
   446  
   447  type MethodParams struct {
   448  	Name   string
   449  	Params *ast.FieldList
   450  	Return *ast.FieldList
   451  }
   452  
   453  type HooksTemplateParams struct {
   454  	HooksMethods []MethodParams
   455  	APIMethods   []MethodParams
   456  }
   457  
   458  func generateHooksGlue(info *PluginInterfaceInfo) {
   459  	templateFunctions := map[string]interface{}{
   460  		"funcStyle":   func(fields *ast.FieldList) string { return FieldListToFuncList(fields, info.FileSet) },
   461  		"structStyle": func(fields *ast.FieldList) string { return FieldListToStructList(fields, info.FileSet) },
   462  		"valuesOnly":  func(fields *ast.FieldList) string { return FieldListToNames(fields, false) },
   463  		"encodeErrors": func(structPrefix string, fields *ast.FieldList) string {
   464  			return FieldListToEncodedErrors(structPrefix, fields, info.FileSet)
   465  		},
   466  		"destruct": func(structPrefix string, fields *ast.FieldList) string {
   467  			return FieldListDestruct(structPrefix, fields, info.FileSet)
   468  		},
   469  		"shouldRecordSuccess": func(structPrefix string, fields *ast.FieldList) string {
   470  			return FieldListToRecordSuccess(structPrefix, fields)
   471  		},
   472  		"obscure": func(name string) string {
   473  			return "Z_" + name
   474  		},
   475  	}
   476  
   477  	hooksTemplate, err := template.New("hooks").Funcs(templateFunctions).Parse(hooksTemplate)
   478  	if err != nil {
   479  		panic(err)
   480  	}
   481  
   482  	templateParams := HooksTemplateParams{}
   483  	for _, hook := range info.Hooks {
   484  		templateParams.HooksMethods = append(templateParams.HooksMethods, MethodParams{
   485  			Name:   hook.FuncName,
   486  			Params: hook.Args,
   487  			Return: hook.Results,
   488  		})
   489  	}
   490  	for _, api := range info.API {
   491  		templateParams.APIMethods = append(templateParams.APIMethods, MethodParams{
   492  			Name:   api.FuncName,
   493  			Params: api.Args,
   494  			Return: api.Results,
   495  		})
   496  	}
   497  	templateResult := &bytes.Buffer{}
   498  	hooksTemplate.Execute(templateResult, &templateParams)
   499  
   500  	formatted, err := imports.Process("", templateResult.Bytes(), nil)
   501  	if err != nil {
   502  		panic(err)
   503  	}
   504  
   505  	if err := ioutil.WriteFile(filepath.Join(getPluginPackageDir(), "client_rpc_generated.go"), formatted, 0664); err != nil {
   506  		panic(err)
   507  	}
   508  }
   509  
   510  func generatePluginTimerLayer(info *PluginInterfaceInfo) {
   511  	templateFunctions := map[string]interface{}{
   512  		"funcStyle":   func(fields *ast.FieldList) string { return FieldListToFuncList(fields, info.FileSet) },
   513  		"structStyle": func(fields *ast.FieldList) string { return FieldListToStructList(fields, info.FileSet) },
   514  		"valuesOnly":  func(fields *ast.FieldList) string { return FieldListToNames(fields, true) },
   515  		"destruct": func(structPrefix string, fields *ast.FieldList) string {
   516  			return FieldListDestruct(structPrefix, fields, info.FileSet)
   517  		},
   518  		"shouldRecordSuccess": func(structPrefix string, fields *ast.FieldList) string {
   519  			return FieldListToRecordSuccess(structPrefix, fields)
   520  		},
   521  	}
   522  
   523  	// Prepare template params
   524  	templateParams := HooksTemplateParams{}
   525  	for _, hook := range info.Hooks {
   526  		templateParams.HooksMethods = append(templateParams.HooksMethods, MethodParams{
   527  			Name:   hook.FuncName,
   528  			Params: hook.Args,
   529  			Return: hook.Results,
   530  		})
   531  	}
   532  	for _, api := range info.API {
   533  		templateParams.APIMethods = append(templateParams.APIMethods, MethodParams{
   534  			Name:   api.FuncName,
   535  			Params: api.Args,
   536  			Return: api.Results,
   537  		})
   538  	}
   539  
   540  	pluginTemplates := map[string]string{
   541  		"api_timer_layer_generated.go":   apiTimerLayerTemplate,
   542  		"hooks_timer_layer_generated.go": hooksTimerLayerTemplate,
   543  	}
   544  
   545  	for fileName, presetTemplate := range pluginTemplates {
   546  		parsedTemplate, err := template.New("hooks").Funcs(templateFunctions).Parse(presetTemplate)
   547  		if err != nil {
   548  			panic(err)
   549  		}
   550  
   551  		templateResult := &bytes.Buffer{}
   552  		parsedTemplate.Execute(templateResult, &templateParams)
   553  
   554  		formatted, err := imports.Process("", templateResult.Bytes(), nil)
   555  		if err != nil {
   556  			panic(err)
   557  		}
   558  
   559  		if err := ioutil.WriteFile(filepath.Join(getPluginPackageDir(), fileName), formatted, 0664); err != nil {
   560  			panic(err)
   561  		}
   562  	}
   563  }
   564  
   565  func getPluginPackageDir() string {
   566  	dirs, err := goList("github.com/mattermost/mattermost-server/v5/plugin")
   567  	if err != nil {
   568  		panic(err)
   569  	} else if len(dirs) != 1 {
   570  		panic("More than one package dir, or no dirs!")
   571  	}
   572  
   573  	return dirs[0]
   574  }
   575  
   576  func removeExcluded(info *PluginInterfaceInfo) *PluginInterfaceInfo {
   577  	toBeExcluded := func(item string) bool {
   578  		excluded := []string{
   579  			"FileWillBeUploaded",
   580  			"Implemented",
   581  			"LoadPluginConfiguration",
   582  			"InstallPlugin",
   583  			"LogDebug",
   584  			"LogError",
   585  			"LogInfo",
   586  			"LogWarn",
   587  			"MessageWillBePosted",
   588  			"MessageWillBeUpdated",
   589  			"OnActivate",
   590  			"PluginHTTP",
   591  			"ServeHTTP",
   592  		}
   593  		for _, exclusion := range excluded {
   594  			if exclusion == item {
   595  				return true
   596  			}
   597  		}
   598  		return false
   599  	}
   600  	hooksResult := make([]IHookEntry, 0, len(info.Hooks))
   601  	for _, hook := range info.Hooks {
   602  		if !toBeExcluded(hook.FuncName) {
   603  			hooksResult = append(hooksResult, hook)
   604  		}
   605  	}
   606  	info.Hooks = hooksResult
   607  
   608  	apiResult := make([]IHookEntry, 0, len(info.API))
   609  	for _, api := range info.API {
   610  		if !toBeExcluded(api.FuncName) {
   611  			apiResult = append(apiResult, api)
   612  		}
   613  	}
   614  	info.API = apiResult
   615  
   616  	return info
   617  }
   618  
   619  func main() {
   620  	pluginPackageDir := getPluginPackageDir()
   621  
   622  	log.Println("Generating plugin hooks glue")
   623  	forRPC, err := getPluginInfo(pluginPackageDir)
   624  	if err != nil {
   625  		fmt.Println("Unable to get plugin info: " + err.Error())
   626  	}
   627  	generateHooksGlue(removeExcluded(forRPC))
   628  
   629  	// Generate plugin timer layers
   630  	log.Println("Generating plugin timer glue")
   631  	forPlugins, err := getPluginInfo(pluginPackageDir)
   632  	if err != nil {
   633  		fmt.Println("Unable to get plugin info: " + err.Error())
   634  	}
   635  	generatePluginTimerLayer(forPlugins)
   636  }