go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/tests/testutil_test.go (about)

     1  package tests
     2  
     3  import (
     4  	"fmt"
     5  	"go-ml.dev/pkg/base/fu"
     6  	"go-ml.dev/pkg/base/tables"
     7  	"gotest.tools/assert"
     8  	"gotest.tools/assert/cmp"
     9  	"strings"
    10  	"testing"
    11  )
    12  
    13  type TR struct {
    14  	Name string
    15  	Age  int
    16  	Rate float32
    17  }
    18  
    19  func PrepareTable(t *testing.T) *tables.Table {
    20  	q := tables.New([]TR{
    21  		{"Ivanov", 32, 1.2},
    22  		{"Petrov", 44, 1.5}})
    23  	assert.DeepEqual(t, q.Names(), []string{"Name", "Age", "Rate"})
    24  	assert.Assert(t, q.Len() == 2)
    25  	assert.DeepEqual(t, fu.MapInterface(q.Row(0)),
    26  		map[string]interface{}{
    27  			"Name": "Ivanov",
    28  			"Age":  32,
    29  			"Rate": float32(1.2),
    30  		})
    31  	assert.DeepEqual(t, fu.MapInterface(q.Row(1)),
    32  		map[string]interface{}{
    33  			"Name": "Petrov",
    34  			"Age":  44,
    35  			"Rate": float32(1.5),
    36  		})
    37  
    38  	return q
    39  }
    40  
    41  var trList = []TR{
    42  	{"Ivanov", 32, 1.2},
    43  	{"Petrov", 44, 1.5},
    44  	{"Sidorov", 55, 1.8},
    45  	{"Gavrilov", 20, 0.9},
    46  	{"Popova", 28, 1.0},
    47  	{"Kozlov", 42, 1.3},
    48  }
    49  
    50  func TrTable() *tables.Table {
    51  	return tables.New(trList)
    52  }
    53  
    54  func assertTrData(t *testing.T, q *tables.Table) {
    55  	assert.Assert(t, q.Len() == len(trList))
    56  	for i, r := range trList {
    57  		assert.DeepEqual(t, fu.MapInterface(q.Row(i)),
    58  			map[string]interface{}{
    59  				"Name": r.Name,
    60  				"Age":  r.Age,
    61  				"Rate": r.Rate,
    62  			})
    63  	}
    64  }
    65  
    66  func PanicWith(text string, f func()) cmp.Comparison {
    67  	return func() (result cmp.Result) {
    68  		defer func() {
    69  			if err := recover(); err != nil {
    70  				s := fmt.Sprint(err)
    71  				if strings.Index(s, text) >= 0 {
    72  					result = cmp.ResultSuccess
    73  					return
    74  				}
    75  				result = cmp.ResultFailure("panic `" + s + "` does not contain `" + text + "`")
    76  			}
    77  		}()
    78  		f()
    79  		return cmp.ResultFailure("did not panic")
    80  	}
    81  
    82  }