github.com/secoba/wails/v2@v2.6.4/internal/binding/binding.go (about)

     1  package binding
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"os"
     8  	"path/filepath"
     9  	"reflect"
    10  	"runtime"
    11  	"sort"
    12  	"strings"
    13  
    14  	"github.com/secoba/wails/v2/internal/typescriptify"
    15  
    16  	"github.com/leaanthony/slicer"
    17  	"github.com/secoba/wails/v2/internal/logger"
    18  )
    19  
// Bindings tracks the Go methods exposed to the frontend and the structs for
// which TypeScript models should be generated.
type Bindings struct {
	// db receives every bound method via AddMethod, together with its
	// package, struct and method names (see Add).
	db *DB
	// logger is the "Bindings" custom logger created in NewBindings.
	logger logger.CustomLogger
	// exemptions holds the runtime names of the functions passed as
	// exemptions to NewBindings, with any "-fm" method-value suffix removed.
	exemptions slicer.StringSlicer

	// structsToGenerateTS maps package name -> struct name -> a value of that
	// struct type; GenerateModels emits one TypeScript namespace per package.
	structsToGenerateTS map[string]map[string]interface{}
	// tsPrefix and tsSuffix are handed to the TypeScript generator through
	// WithPrefix/WithSuffix; they are set by SetTsPrefix/SetTsSuffix.
	tsPrefix string
	tsSuffix string
	// obfuscate is copied from NewBindings' argument; no code visible in this
	// file reads it.
	obfuscate bool
}
    30  
    31  // NewBindings returns a new Bindings object
    32  func NewBindings(logger *logger.Logger, structPointersToBind []interface{}, exemptions []interface{}, obfuscate bool) *Bindings {
    33  	result := &Bindings{
    34  		db:                  newDB(),
    35  		logger:              logger.CustomLogger("Bindings"),
    36  		structsToGenerateTS: make(map[string]map[string]interface{}),
    37  		obfuscate:           obfuscate,
    38  	}
    39  
    40  	for _, exemption := range exemptions {
    41  		if exemption == nil {
    42  			continue
    43  		}
    44  		name := runtime.FuncForPC(reflect.ValueOf(exemption).Pointer()).Name()
    45  		// Yuk yuk yuk! Is there a better way?
    46  		name = strings.TrimSuffix(name, "-fm")
    47  		result.exemptions.Add(name)
    48  	}
    49  
    50  	// Add the structs to bind
    51  	for _, ptr := range structPointersToBind {
    52  		err := result.Add(ptr)
    53  		if err != nil {
    54  			logger.Fatal("Error during binding: " + err.Error())
    55  		}
    56  	}
    57  
    58  	return result
    59  }
    60  
    61  // Add the given struct methods to the Bindings
    62  func (b *Bindings) Add(structPtr interface{}) error {
    63  	methods, err := b.getMethods(structPtr)
    64  	if err != nil {
    65  		return fmt.Errorf("cannot bind value to app: %s", err.Error())
    66  	}
    67  
    68  	for _, method := range methods {
    69  		splitName := strings.Split(method.Name, ".")
    70  		packageName := splitName[0]
    71  		structName := splitName[1]
    72  		methodName := splitName[2]
    73  
    74  		// Add it as a regular method
    75  		b.db.AddMethod(packageName, structName, methodName, method)
    76  	}
    77  	return nil
    78  }
    79  
    80  func (b *Bindings) DB() *DB {
    81  	return b.db
    82  }
    83  
    84  func (b *Bindings) ToJSON() (string, error) {
    85  	return b.db.ToJSON()
    86  }
    87  
    88  func (b *Bindings) GenerateModels() ([]byte, error) {
    89  	models := map[string]string{}
    90  	var seen slicer.StringSlicer
    91  	allStructNames := b.getAllStructNames()
    92  	allStructNames.Sort()
    93  	for packageName, structsToGenerate := range b.structsToGenerateTS {
    94  		thisPackageCode := ""
    95  		w := typescriptify.New()
    96  		w.WithPrefix(b.tsPrefix)
    97  		w.WithSuffix(b.tsSuffix)
    98  		w.Namespace = packageName
    99  		w.WithBackupDir("")
   100  		w.KnownStructs = allStructNames
   101  		// sort the structs
   102  		var structNames []string
   103  		for structName := range structsToGenerate {
   104  			structNames = append(structNames, structName)
   105  		}
   106  		sort.Strings(structNames)
   107  		for _, structName := range structNames {
   108  			fqstructname := packageName + "." + structName
   109  			if seen.Contains(fqstructname) {
   110  				continue
   111  			}
   112  			structInterface := structsToGenerate[structName]
   113  			w.Add(structInterface)
   114  		}
   115  		str, err := w.Convert(nil)
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  		thisPackageCode += str
   120  		seen.AddSlice(w.GetGeneratedStructs())
   121  		models[packageName] = thisPackageCode
   122  	}
   123  
   124  	// Sort the package names first to make the output deterministic
   125  	sortedPackageNames := make([]string, 0)
   126  	for packageName := range models {
   127  		sortedPackageNames = append(sortedPackageNames, packageName)
   128  	}
   129  	sort.Strings(sortedPackageNames)
   130  
   131  	var modelsData bytes.Buffer
   132  	for _, packageName := range sortedPackageNames {
   133  		modelData := models[packageName]
   134  		if strings.TrimSpace(modelData) == "" {
   135  			continue
   136  		}
   137  		modelsData.WriteString("export namespace " + packageName + " {\n")
   138  		sc := bufio.NewScanner(strings.NewReader(modelData))
   139  		for sc.Scan() {
   140  			modelsData.WriteString("\t" + sc.Text() + "\n")
   141  		}
   142  		modelsData.WriteString("\n}\n\n")
   143  	}
   144  	return modelsData.Bytes(), nil
   145  }
   146  
   147  func (b *Bindings) WriteModels(modelsDir string) error {
   148  	modelsData, err := b.GenerateModels()
   149  	if err != nil {
   150  		return err
   151  	}
   152  	// Don't write if we don't have anything
   153  	if len(modelsData) == 0 {
   154  		return nil
   155  	}
   156  
   157  	filename := filepath.Join(modelsDir, "models.ts")
   158  	err = os.WriteFile(filename, modelsData, 0o755)
   159  	if err != nil {
   160  		return err
   161  	}
   162  
   163  	return nil
   164  }
   165  
   166  func (b *Bindings) AddStructToGenerateTS(packageName string, structName string, s interface{}) {
   167  	if b.structsToGenerateTS[packageName] == nil {
   168  		b.structsToGenerateTS[packageName] = make(map[string]interface{})
   169  	}
   170  	if b.structsToGenerateTS[packageName][structName] != nil {
   171  		return
   172  	}
   173  	b.structsToGenerateTS[packageName][structName] = s
   174  
   175  	// Iterate this struct and add any struct field references
   176  	structType := reflect.TypeOf(s)
   177  	if hasElements(structType) {
   178  		structType = structType.Elem()
   179  	}
   180  
   181  	for i := 0; i < structType.NumField(); i++ {
   182  		field := structType.Field(i)
   183  		if field.Anonymous {
   184  			continue
   185  		}
   186  		kind := field.Type.Kind()
   187  		if kind == reflect.Struct {
   188  			if !field.IsExported() {
   189  				continue
   190  			}
   191  			fqname := field.Type.String()
   192  			sNameSplit := strings.Split(fqname, ".")
   193  			if len(sNameSplit) < 2 {
   194  				continue
   195  			}
   196  			sName := sNameSplit[1]
   197  			pName := getPackageName(fqname)
   198  			a := reflect.New(field.Type)
   199  			if b.hasExportedJSONFields(field.Type) {
   200  				s := reflect.Indirect(a).Interface()
   201  				b.AddStructToGenerateTS(pName, sName, s)
   202  			}
   203  		} else if hasElements(field.Type) && field.Type.Elem().Kind() == reflect.Struct {
   204  			if !field.IsExported() {
   205  				continue
   206  			}
   207  			fqname := field.Type.Elem().String()
   208  			sNameSplit := strings.Split(fqname, ".")
   209  			if len(sNameSplit) < 2 {
   210  				continue
   211  			}
   212  			sName := sNameSplit[1]
   213  			pName := getPackageName(fqname)
   214  			typ := field.Type.Elem()
   215  			a := reflect.New(typ)
   216  			if b.hasExportedJSONFields(typ) {
   217  				s := reflect.Indirect(a).Interface()
   218  				b.AddStructToGenerateTS(pName, sName, s)
   219  			}
   220  		}
   221  	}
   222  }
   223  
   224  func (b *Bindings) SetTsPrefix(prefix string) *Bindings {
   225  	b.tsPrefix = prefix
   226  	return b
   227  }
   228  
   229  func (b *Bindings) SetTsSuffix(postfix string) *Bindings {
   230  	b.tsSuffix = postfix
   231  	return b
   232  }
   233  
   234  func (b *Bindings) getAllStructNames() *slicer.StringSlicer {
   235  	var result slicer.StringSlicer
   236  	for packageName, structsToGenerate := range b.structsToGenerateTS {
   237  		for structName := range structsToGenerate {
   238  			result.Add(packageName + "." + structName)
   239  		}
   240  	}
   241  	return &result
   242  }
   243  
   244  func (b *Bindings) hasExportedJSONFields(typeOf reflect.Type) bool {
   245  	for i := 0; i < typeOf.NumField(); i++ {
   246  		jsonFieldName := ""
   247  		f := typeOf.Field(i)
   248  		jsonTag := f.Tag.Get("json")
   249  		if len(jsonTag) == 0 {
   250  			continue
   251  		}
   252  		jsonTagParts := strings.Split(jsonTag, ",")
   253  		if len(jsonTagParts) > 0 {
   254  			jsonFieldName = jsonTagParts[0]
   255  		}
   256  		for _, t := range jsonTagParts {
   257  			if t == "-" {
   258  				continue
   259  			}
   260  		}
   261  		if jsonFieldName != "" {
   262  			return true
   263  		}
   264  	}
   265  	return false
   266  }