github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/importers/imports_test.go (about)

     1  package importers
     2  
     3  import (
     4  	"reflect"
     5  	"sort"
     6  	"strings"
     7  	"testing"
     8  )
     9  
    10  func TestSetFromInterface(t *testing.T) {
    11  	t.Parallel()
    12  
    13  	setIntf := map[string]interface{}{
    14  		"standard": []interface{}{
    15  			"hello",
    16  			"there",
    17  		},
    18  		"third_party": []interface{}{
    19  			"there",
    20  			"hello",
    21  		},
    22  	}
    23  
    24  	set, err := SetFromInterface(setIntf)
    25  	if err != nil {
    26  		t.Error(err)
    27  	}
    28  
    29  	if set.Standard[0] != "hello" {
    30  		t.Error("set was wrong:", set.Standard[0])
    31  	}
    32  	if set.Standard[1] != "there" {
    33  		t.Error("set was wrong:", set.Standard[1])
    34  	}
    35  	if set.ThirdParty[0] != "there" {
    36  		t.Error("set was wrong:", set.ThirdParty[0])
    37  	}
    38  	if set.ThirdParty[1] != "hello" {
    39  		t.Error("set was wrong:", set.ThirdParty[1])
    40  	}
    41  }
    42  
    43  func TestMapFromInterface(t *testing.T) {
    44  	t.Parallel()
    45  
    46  	mapIntf := map[string]interface{}{
    47  		"test_main": map[string]interface{}{
    48  			"standard": []interface{}{
    49  				"hello",
    50  				"there",
    51  			},
    52  			"third_party": []interface{}{
    53  				"there",
    54  				"hello",
    55  			},
    56  		},
    57  	}
    58  
    59  	mp, err := MapFromInterface(mapIntf)
    60  	if err != nil {
    61  		t.Error(err)
    62  	}
    63  
    64  	set, ok := mp["test_main"]
    65  	if !ok {
    66  		t.Error("could not find set 'test_main'")
    67  	}
    68  
    69  	if set.Standard[0] != "hello" {
    70  		t.Error("set was wrong:", set.Standard[0])
    71  	}
    72  	if set.Standard[1] != "there" {
    73  		t.Error("set was wrong:", set.Standard[1])
    74  	}
    75  	if set.ThirdParty[0] != "there" {
    76  		t.Error("set was wrong:", set.ThirdParty[0])
    77  	}
    78  	if set.ThirdParty[1] != "hello" {
    79  		t.Error("set was wrong:", set.ThirdParty[1])
    80  	}
    81  }
    82  
    83  func TestMapFromInterfaceAltSyntax(t *testing.T) {
    84  	t.Parallel()
    85  
    86  	mapIntf := []interface{}{
    87  		map[string]interface{}{
    88  			"name": "test_main",
    89  			"standard": []interface{}{
    90  				"hello",
    91  				"there",
    92  			},
    93  			"third_party": []interface{}{
    94  				"there",
    95  				"hello",
    96  			},
    97  		},
    98  	}
    99  
   100  	mp, err := MapFromInterface(mapIntf)
   101  	if err != nil {
   102  		t.Error(err)
   103  	}
   104  
   105  	set, ok := mp["test_main"]
   106  	if !ok {
   107  		t.Error("could not find set 'test_main'")
   108  	}
   109  
   110  	if set.Standard[0] != "hello" {
   111  		t.Error("set was wrong:", set.Standard[0])
   112  	}
   113  	if set.Standard[1] != "there" {
   114  		t.Error("set was wrong:", set.Standard[1])
   115  	}
   116  	if set.ThirdParty[0] != "there" {
   117  		t.Error("set was wrong:", set.ThirdParty[0])
   118  	}
   119  	if set.ThirdParty[1] != "hello" {
   120  		t.Error("set was wrong:", set.ThirdParty[1])
   121  	}
   122  }
   123  
   124  func TestImportsSort(t *testing.T) {
   125  	t.Parallel()
   126  
   127  	a1 := List{
   128  		`"fmt"`,
   129  		`"errors"`,
   130  	}
   131  	a2 := List{
   132  		`_ "github.com/lib/pq"`,
   133  		`_ "github.com/gorilla/n"`,
   134  		`"github.com/gorilla/mux"`,
   135  		`"github.com/gorilla/websocket"`,
   136  	}
   137  
   138  	a1Expected := List{`"errors"`, `"fmt"`}
   139  	a2Expected := List{
   140  		`"github.com/gorilla/mux"`,
   141  		`_ "github.com/gorilla/n"`,
   142  		`"github.com/gorilla/websocket"`,
   143  		`_ "github.com/lib/pq"`,
   144  	}
   145  
   146  	sort.Sort(a1)
   147  	if !reflect.DeepEqual(a1, a1Expected) {
   148  		t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
   149  	}
   150  
   151  	for i, v := range a1 {
   152  		if v != a1Expected[i] {
   153  			t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
   154  		}
   155  	}
   156  
   157  	sort.Sort(a2)
   158  	if !reflect.DeepEqual(a2, a2Expected) {
   159  		t.Errorf("Expected a2 to match a2expected, got: %v", a2)
   160  	}
   161  
   162  	for i, v := range a2 {
   163  		if v != a2Expected[i] {
   164  			t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
   165  		}
   166  	}
   167  }
   168  
   169  func TestAddTypeImports(t *testing.T) {
   170  	t.Parallel()
   171  
   172  	imports1 := Set{
   173  		Standard: List{
   174  			`"errors"`,
   175  			`"fmt"`,
   176  		},
   177  		ThirdParty: List{
   178  			`"github.com/volatiletech/sqlboiler/v4/boil"`,
   179  		},
   180  	}
   181  
   182  	importsExpected := Set{
   183  		Standard: List{
   184  			`"errors"`,
   185  			`"fmt"`,
   186  			`"time"`,
   187  		},
   188  		ThirdParty: List{
   189  			`"github.com/volatiletech/null/v8"`,
   190  			`"github.com/volatiletech/sqlboiler/v4/boil"`,
   191  		},
   192  	}
   193  
   194  	types := []string{
   195  		"null.Time",
   196  		"null.Time",
   197  		"time.Time",
   198  	}
   199  
   200  	imps := NewDefaultImports()
   201  
   202  	imps.BasedOnType = Map{
   203  		"null.Time": Set{ThirdParty: List{`"github.com/volatiletech/null/v8"`}},
   204  		"time.Time": Set{Standard: List{`"time"`}},
   205  	}
   206  
   207  	res1 := AddTypeImports(imports1, imps.BasedOnType, types)
   208  
   209  	if !reflect.DeepEqual(res1, importsExpected) {
   210  		t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
   211  	}
   212  
   213  	imports2 := Set{
   214  		Standard: List{
   215  			`"errors"`,
   216  			`"fmt"`,
   217  			`"time"`,
   218  		},
   219  		ThirdParty: List{
   220  			`"github.com/volatiletech/null/v8"`,
   221  			`"github.com/volatiletech/sqlboiler/v4/boil"`,
   222  		},
   223  	}
   224  
   225  	res2 := AddTypeImports(imports2, imps.BasedOnType, types)
   226  
   227  	if !reflect.DeepEqual(res2, importsExpected) {
   228  		t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
   229  	}
   230  }
   231  
   232  func TestMergeSet(t *testing.T) {
   233  	t.Parallel()
   234  
   235  	a := Set{
   236  		Standard:   List{"fmt"},
   237  		ThirdParty: List{"github.com/volatiletech/sqlboiler/v4", "github.com/volatiletech/null/v8"},
   238  	}
   239  	b := Set{
   240  		Standard:   List{"os"},
   241  		ThirdParty: List{"github.com/volatiletech/sqlboiler/v4"},
   242  	}
   243  
   244  	c := mergeSet(a, b)
   245  
   246  	if c.Standard[0] != "fmt" && c.Standard[1] != "os" {
   247  		t.Errorf("Wanted: fmt, os got: %#v", c.Standard)
   248  	}
   249  	if c.ThirdParty[0] != "github.com/volatiletech/null/v8" && c.ThirdParty[1] != "github.com/volatiletech/sqlboiler/v4" {
   250  		t.Errorf("Wanted: github.com/volatiletech/sqlboiler, github.com/volatiletech/null/v8 got: %#v", c.ThirdParty)
   251  	}
   252  }
   253  
   254  func TestCombineStringSlices(t *testing.T) {
   255  	t.Parallel()
   256  
   257  	var a, b []string
   258  	slice := combineStringSlices(a, b)
   259  	if ln := len(slice); ln != 0 {
   260  		t.Error("Len was wrong:", ln)
   261  	}
   262  
   263  	a = []string{"1", "2"}
   264  	slice = combineStringSlices(a, b)
   265  	if ln := len(slice); ln != 2 {
   266  		t.Error("Len was wrong:", ln)
   267  	} else if slice[0] != a[0] || slice[1] != a[1] {
   268  		t.Errorf("Slice mismatch: %#v %#v", a, slice)
   269  	}
   270  
   271  	b = a
   272  	a = nil
   273  	slice = combineStringSlices(a, b)
   274  	if ln := len(slice); ln != 2 {
   275  		t.Error("Len was wrong:", ln)
   276  	} else if slice[0] != b[0] || slice[1] != b[1] {
   277  		t.Errorf("Slice mismatch: %#v %#v", b, slice)
   278  	}
   279  
   280  	a = b
   281  	b = []string{"3", "4"}
   282  	slice = combineStringSlices(a, b)
   283  	if ln := len(slice); ln != 4 {
   284  		t.Error("Len was wrong:", ln)
   285  	} else if slice[0] != a[0] || slice[1] != a[1] || slice[2] != b[0] || slice[3] != b[1] {
   286  		t.Errorf("Slice mismatch: %#v + %#v != #%v", a, b, slice)
   287  	}
   288  }
   289  
   290  func TestMerge(t *testing.T) {
   291  	var a, b Collection
   292  
   293  	a.All = Set{Standard: List{"aa"}, ThirdParty: List{"aa"}}
   294  	a.Test = Set{Standard: List{"at"}, ThirdParty: List{"at"}}
   295  	a.Singleton = Map{
   296  		"a": {Standard: List{"as"}, ThirdParty: List{"as"}},
   297  		"c": {Standard: List{"as"}, ThirdParty: List{"as"}},
   298  	}
   299  	a.TestSingleton = Map{
   300  		"a": {Standard: List{"at"}, ThirdParty: List{"at"}},
   301  		"c": {Standard: List{"at"}, ThirdParty: List{"at"}},
   302  	}
   303  	a.BasedOnType = Map{
   304  		"a": {Standard: List{"abot"}, ThirdParty: List{"abot"}},
   305  		"c": {Standard: List{"abot"}, ThirdParty: List{"abot"}},
   306  	}
   307  
   308  	b.All = Set{Standard: List{"bb"}, ThirdParty: List{"bb"}}
   309  	b.Test = Set{Standard: List{"bt"}, ThirdParty: List{"bt"}}
   310  	b.Singleton = Map{
   311  		"b": {Standard: List{"bs"}, ThirdParty: List{"bs"}},
   312  		"c": {Standard: List{"bs"}, ThirdParty: List{"bs"}},
   313  	}
   314  	b.TestSingleton = Map{
   315  		"b": {Standard: List{"bt"}, ThirdParty: List{"bt"}},
   316  		"c": {Standard: List{"bt"}, ThirdParty: List{"bt"}},
   317  	}
   318  	b.BasedOnType = Map{
   319  		"b": {Standard: List{"bbot"}, ThirdParty: List{"bbot"}},
   320  		"c": {Standard: List{"bbot"}, ThirdParty: List{"bbot"}},
   321  	}
   322  
   323  	c := Merge(a, b)
   324  
   325  	setHas := func(s Set, a, b string) {
   326  		t.Helper()
   327  		if s.Standard[0] != a {
   328  			t.Error("standard index 0, want:", a, "got:", s.Standard[0])
   329  		}
   330  		if s.Standard[1] != b {
   331  			t.Error("standard index 1, want:", a, "got:", s.Standard[1])
   332  		}
   333  		if s.ThirdParty[0] != a {
   334  			t.Error("third party index 0, want:", a, "got:", s.ThirdParty[0])
   335  		}
   336  		if s.ThirdParty[1] != b {
   337  			t.Error("third party index 1, want:", a, "got:", s.ThirdParty[1])
   338  		}
   339  	}
   340  	mapHas := func(m Map, key, a, b string) {
   341  		t.Helper()
   342  		setHas(m[key], a, b)
   343  	}
   344  
   345  	setHas(c.All, "aa", "bb")
   346  	setHas(c.Test, "at", "bt")
   347  	mapHas(c.Singleton, "c", "as", "bs")
   348  	mapHas(c.TestSingleton, "c", "at", "bt")
   349  	mapHas(c.BasedOnType, "c", "abot", "bbot")
   350  
   351  	if t.Failed() {
   352  		t.Logf("%#v\n", c)
   353  	}
   354  }
   355  
   356  var testImportStringExpect = `import (
   357  	"fmt"
   358  
   359  	"github.com/friendsofgo/errors"
   360  )`
   361  
   362  func TestSetFormat(t *testing.T) {
   363  	t.Parallel()
   364  
   365  	s := Set{
   366  		Standard: List{
   367  			`"fmt"`,
   368  		},
   369  		ThirdParty: List{
   370  			`"github.com/friendsofgo/errors"`,
   371  		},
   372  	}
   373  
   374  	got := strings.TrimSpace(string(s.Format()))
   375  	if got != testImportStringExpect {
   376  		t.Error("want:\n", testImportStringExpect, "\ngot:\n", got)
   377  	}
   378  }