github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/importers/imports.go (about)

     1  // Package importers helps with dynamic imports for templating
     2  package importers
     3  
     4  import (
     5  	"bytes"
     6  	"fmt"
     7  	"sort"
     8  	"strings"
     9  
    10  	"github.com/spf13/cast"
    11  
    12  	"github.com/friendsofgo/errors"
    13  	"github.com/volatiletech/strmangle"
    14  )
    15  
// Collection of imports for various templating purposes.
// Drivers add to any and all of these, and are completely responsible
// for populating BasedOnType.
type Collection struct {
	// All is the import set shared by the non-test templates
	// (NewDefaultImports fills it with database/sql, boil, queries, ...).
	All  Set `toml:"all" json:"all,omitempty"`
	// Test is the import set shared by the test templates.
	Test Set `toml:"test" json:"test,omitempty"`

	// Singleton holds extra imports keyed by a name such as "boil_queries"
	// or "boil_types" — presumably the singleton template name; TODO confirm.
	Singleton     Map `toml:"singleton" json:"singleton,omitempty"`
	// TestSingleton is the test counterpart of Singleton (keys such as
	// "boil_main_test" in NewDefaultImports).
	TestSingleton Map `toml:"test_singleton" json:"test_singleton,omitempty"`

	// BasedOnType maps a type name to the imports that type needs.
	// NOTE(review): presumably the typeMap handed to AddTypeImports — confirm.
	BasedOnType Map `toml:"based_on_type" json:"based_on_type,omitempty"`
}
    28  
// Set defines the optional standard imports and
// thirdParty imports (from github for example).
//
// Each entry is written verbatim into generated Go source by Format, so it
// must be a complete import spec such as `"fmt"` (quoted path, optionally
// preceded by a name such as `_`).
type Set struct {
	// Standard holds standard-library import specs; Format emits them first.
	Standard   List `toml:"standard"`
	// ThirdParty holds non-stdlib import specs; Format emits them after
	// Standard, separated by a blank line when both are non-empty.
	ThirdParty List `toml:"third_party"`
}
    35  
    36  // Format the set into Go syntax (compatible with go imports)
    37  func (s Set) Format() []byte {
    38  	stdlen, thirdlen := len(s.Standard), len(s.ThirdParty)
    39  	if stdlen+thirdlen < 1 {
    40  		return []byte{}
    41  	}
    42  
    43  	if stdlen+thirdlen == 1 {
    44  		var imp string
    45  		if stdlen == 1 {
    46  			imp = s.Standard[0]
    47  		} else {
    48  			imp = s.ThirdParty[0]
    49  		}
    50  		return []byte(fmt.Sprintf("import %s", imp))
    51  	}
    52  
    53  	buf := &bytes.Buffer{}
    54  	buf.WriteString("import (")
    55  	for _, std := range s.Standard {
    56  		fmt.Fprintf(buf, "\n\t%s", std)
    57  	}
    58  	if stdlen != 0 && thirdlen != 0 {
    59  		buf.WriteString("\n")
    60  	}
    61  	for _, third := range s.ThirdParty {
    62  		fmt.Fprintf(buf, "\n\t%s", third)
    63  	}
    64  	buf.WriteString("\n)\n")
    65  
    66  	return buf.Bytes()
    67  }
    68  
    69  // SetFromInterface creates a set from a theoretical map[string]interface{}.
    70  // This is to load from a loosely defined configuration file.
    71  func SetFromInterface(intf interface{}) (Set, error) {
    72  	s := Set{}
    73  
    74  	setIntf, ok := intf.(map[string]interface{})
    75  	if !ok {
    76  		return s, errors.New("import set should be map[string]interface{}")
    77  	}
    78  
    79  	standardIntf, ok := setIntf["standard"]
    80  	if ok {
    81  		standardsIntf, ok := standardIntf.([]interface{})
    82  		if !ok {
    83  			return s, errors.New("import set standards must be an slice")
    84  		}
    85  
    86  		s.Standard = List{}
    87  		for i, intf := range standardsIntf {
    88  			str, ok := intf.(string)
    89  			if !ok {
    90  				return s, errors.Errorf("import set standard slice element %d (%+v) must be string", i, s)
    91  			}
    92  			s.Standard = append(s.Standard, str)
    93  		}
    94  	}
    95  
    96  	thirdPartyIntf, ok := setIntf["third_party"]
    97  	if ok {
    98  		thirdPartysIntf, ok := thirdPartyIntf.([]interface{})
    99  		if !ok {
   100  			return s, errors.New("import set third_party must be an slice")
   101  		}
   102  
   103  		s.ThirdParty = List{}
   104  		for i, intf := range thirdPartysIntf {
   105  			str, ok := intf.(string)
   106  			if !ok {
   107  				return s, errors.Errorf("import set third party slice element %d (%+v) must be string", i, intf)
   108  			}
   109  			s.ThirdParty = append(s.ThirdParty, str)
   110  		}
   111  	}
   112  
   113  	return s, nil
   114  }
   115  
// Map of file/type -> imports.
// Map's consumers do not understand windows paths. Always specify paths
// using forward slash (/).
//
// Keys are names such as "boil_types" (see NewDefaultImports). Each value is
// the import Set associated with that name.
type Map map[string]Set
   120  
   121  // MapFromInterface creates a Map from a theoretical map[string]interface{}
   122  // or []map[string]interface{}
   123  // This is to load from a loosely defined configuration file.
   124  func MapFromInterface(intf interface{}) (Map, error) {
   125  	m := Map{}
   126  
   127  	iter := func(i interface{}, fn func(string, interface{}) error) error {
   128  		switch toIter := intf.(type) {
   129  		case []interface{}:
   130  			for _, intf := range toIter {
   131  				obj := cast.ToStringMap(intf)
   132  				name := obj["name"].(string)
   133  				if err := fn(name, intf); err != nil {
   134  					return err
   135  				}
   136  			}
   137  		case map[string]interface{}:
   138  			for k, v := range toIter {
   139  				if err := fn(k, v); err != nil {
   140  					return err
   141  				}
   142  			}
   143  		default:
   144  			panic("import map should be map[string]interface or []map[string]interface{}")
   145  		}
   146  
   147  		return nil
   148  	}
   149  
   150  	err := iter(intf, func(name string, value interface{}) error {
   151  		s, err := SetFromInterface(value)
   152  		if err != nil {
   153  			return err
   154  		}
   155  
   156  		m[name] = s
   157  		return nil
   158  	})
   159  
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	return m, nil
   165  }
   166  
   167  // List of imports
   168  type List []string
   169  
   170  // Len implements sort.Interface.Len
   171  func (l List) Len() int {
   172  	return len(l)
   173  }
   174  
   175  // Swap implements sort.Interface.Swap
   176  func (l List) Swap(i, j int) {
   177  	l[i], l[j] = l[j], l[i]
   178  }
   179  
   180  // Less implements sort.Interface.Less
   181  func (l List) Less(i, j int) bool {
   182  	res := strings.Compare(strings.TrimLeft(l[i], "_ "), strings.TrimLeft(l[j], "_ "))
   183  	if res <= 0 {
   184  		return true
   185  	}
   186  
   187  	return false
   188  }
   189  
   190  // NewDefaultImports returns a default Imports struct.
   191  func NewDefaultImports() Collection {
   192  	var col Collection
   193  
   194  	col.All = Set{
   195  		Standard: List{
   196  			`"database/sql"`,
   197  			`"fmt"`,
   198  			`"reflect"`,
   199  			`"strings"`,
   200  			`"sync"`,
   201  			`"time"`,
   202  		},
   203  		ThirdParty: List{
   204  			`"github.com/friendsofgo/errors"`,
   205  			`"github.com/volatiletech/sqlboiler/v4/boil"`,
   206  			`"github.com/volatiletech/sqlboiler/v4/queries"`,
   207  			`"github.com/volatiletech/sqlboiler/v4/queries/qm"`,
   208  			`"github.com/volatiletech/sqlboiler/v4/queries/qmhelper"`,
   209  			`"github.com/volatiletech/strmangle"`,
   210  		},
   211  	}
   212  
   213  	col.Singleton = Map{
   214  		"boil_queries": {
   215  			ThirdParty: List{
   216  				`"github.com/volatiletech/sqlboiler/v4/drivers"`,
   217  				`"github.com/volatiletech/sqlboiler/v4/queries"`,
   218  				`"github.com/volatiletech/sqlboiler/v4/queries/qm"`,
   219  			},
   220  		},
   221  		"boil_types": {
   222  			Standard: List{
   223  				`"strconv"`,
   224  			},
   225  			ThirdParty: List{
   226  				`"github.com/friendsofgo/errors"`,
   227  				`"github.com/volatiletech/sqlboiler/v4/boil"`,
   228  				`"github.com/volatiletech/strmangle"`,
   229  			},
   230  		},
   231  	}
   232  
   233  	col.Test = Set{
   234  		Standard: List{
   235  			`"bytes"`,
   236  			`"reflect"`,
   237  			`"testing"`,
   238  		},
   239  		ThirdParty: List{
   240  			`"github.com/volatiletech/sqlboiler/v4/boil"`,
   241  			`"github.com/volatiletech/sqlboiler/v4/queries"`,
   242  			`"github.com/volatiletech/randomize"`,
   243  			`"github.com/volatiletech/strmangle"`,
   244  		},
   245  	}
   246  
   247  	col.TestSingleton = Map{
   248  		"boil_main_test": {
   249  			Standard: List{
   250  				`"database/sql"`,
   251  				`"flag"`,
   252  				`"fmt"`,
   253  				`"math/rand"`,
   254  				`"os"`,
   255  				`"path/filepath"`,
   256  				`"strings"`,
   257  				`"testing"`,
   258  				`"time"`,
   259  			},
   260  			ThirdParty: List{
   261  				`"github.com/spf13/viper"`,
   262  				`"github.com/volatiletech/sqlboiler/v4/boil"`,
   263  			},
   264  		},
   265  		"boil_queries_test": {
   266  			Standard: List{
   267  				`"bytes"`,
   268  				`"fmt"`,
   269  				`"io"`,
   270  				`"math/rand"`,
   271  				`"regexp"`,
   272  			},
   273  			ThirdParty: List{
   274  				`"github.com/volatiletech/sqlboiler/v4/boil"`,
   275  			},
   276  		},
   277  		"boil_suites_test": {
   278  			Standard: List{
   279  				`"testing"`,
   280  			},
   281  		},
   282  	}
   283  
   284  	return col
   285  }
   286  
   287  // NullableEnumImports returns imports collection for nullable enum types.
   288  func NullableEnumImports() Collection {
   289  	var col Collection
   290  
   291  	col.Singleton = Map{
   292  		"boil_types": {
   293  			Standard: List{
   294  				`"bytes"`,
   295  				`"database/sql/driver"`,
   296  				`"encoding/json"`,
   297  			},
   298  			ThirdParty: List{
   299  				`"github.com/volatiletech/null/v8"`,
   300  				`"github.com/volatiletech/null/v8/convert"`,
   301  			},
   302  		},
   303  	}
   304  
   305  	return col
   306  }
   307  
   308  // AddTypeImports takes a set of imports 'a', a type -> import mapping 'typeMap'
   309  // and a set of column types that are currently in use and produces a new set
   310  // including both the old standard/third party, as well as the imports required
   311  // for the types in use.
   312  func AddTypeImports(a Set, typeMap map[string]Set, columnTypes []string) Set {
   313  	tmpImp := Set{
   314  		Standard:   make(List, len(a.Standard)),
   315  		ThirdParty: make(List, len(a.ThirdParty)),
   316  	}
   317  
   318  	copy(tmpImp.Standard, a.Standard)
   319  	copy(tmpImp.ThirdParty, a.ThirdParty)
   320  
   321  	for _, typ := range columnTypes {
   322  		for key, imp := range typeMap {
   323  			if typ == key {
   324  				tmpImp.Standard = append(tmpImp.Standard, imp.Standard...)
   325  				tmpImp.ThirdParty = append(tmpImp.ThirdParty, imp.ThirdParty...)
   326  			}
   327  		}
   328  	}
   329  
   330  	tmpImp.Standard = strmangle.RemoveDuplicates(tmpImp.Standard)
   331  	tmpImp.ThirdParty = strmangle.RemoveDuplicates(tmpImp.ThirdParty)
   332  
   333  	sort.Sort(tmpImp.Standard)
   334  	sort.Sort(tmpImp.ThirdParty)
   335  
   336  	return tmpImp
   337  }
   338  
   339  // Merge takes two collections and creates a new one
   340  // with the de-duplication contents of both.
   341  func Merge(a, b Collection) Collection {
   342  	var c Collection
   343  
   344  	c.All = mergeSet(a.All, b.All)
   345  	c.Test = mergeSet(a.Test, b.Test)
   346  
   347  	c.Singleton = mergeMap(a.Singleton, b.Singleton)
   348  	c.TestSingleton = mergeMap(a.TestSingleton, b.TestSingleton)
   349  
   350  	c.BasedOnType = mergeMap(a.BasedOnType, b.BasedOnType)
   351  
   352  	return c
   353  }
   354  
   355  func mergeSet(a, b Set) Set {
   356  	var c Set
   357  
   358  	c.Standard = strmangle.RemoveDuplicates(combineStringSlices(a.Standard, b.Standard))
   359  	c.ThirdParty = strmangle.RemoveDuplicates(combineStringSlices(a.ThirdParty, b.ThirdParty))
   360  
   361  	sort.Sort(c.Standard)
   362  	sort.Sort(c.ThirdParty)
   363  
   364  	return c
   365  }
   366  
   367  func mergeMap(a, b Map) Map {
   368  	m := make(Map)
   369  
   370  	for k, v := range a {
   371  		m[k] = v
   372  	}
   373  
   374  	for k, toMerge := range b {
   375  		exist, ok := m[k]
   376  		if !ok {
   377  			m[k] = toMerge
   378  		}
   379  
   380  		m[k] = mergeSet(exist, toMerge)
   381  	}
   382  
   383  	return m
   384  }
   385  
   386  func combineStringSlices(a, b []string) []string {
   387  	c := make([]string, len(a)+len(b))
   388  	if len(a) > 0 {
   389  		copy(c, a)
   390  	}
   391  	if len(b) > 0 {
   392  		copy(c[len(a):], b)
   393  	}
   394  
   395  	return c
   396  }