github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/sem/tree/casts_test.go (about) 1 // Copyright 2020 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package tree 12 13 import ( 14 "encoding/csv" 15 "fmt" 16 "io" 17 "os" 18 "path/filepath" 19 "strconv" 20 "testing" 21 22 "github.com/cockroachdb/cockroach/pkg/sql/oidext" 23 "github.com/cockroachdb/cockroach/pkg/sql/types" 24 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 25 "github.com/lib/pq/oid" 26 "github.com/stretchr/testify/require" 27 ) 28 29 // TestCastsVolatilityMatchesPostgres checks that our defined casts match 30 // Postgres' casts for Volatility. 31 // 32 // Dump command below: 33 // COPY ( 34 // SELECT c.castsource, c.casttarget, p.provolatile, p.proleakproof 35 // FROM pg_cast c JOIN pg_proc p ON (c.castfunc = p.oid) 36 // ) TO STDOUT WITH CSV DELIMITER '|' HEADER; 37 func TestCastsVolatilityMatchesPostgres(t *testing.T) { 38 defer leaktest.AfterTest(t)() 39 csvPath := filepath.Join("testdata", "pg_cast_provolatile_dump.csv") 40 f, err := os.Open(csvPath) 41 require.NoError(t, err) 42 43 defer f.Close() 44 45 reader := csv.NewReader(f) 46 reader.Comma = '|' 47 48 // Read header row 49 _, err = reader.Read() 50 require.NoError(t, err) 51 52 type pgCast struct { 53 from, to oid.Oid 54 volatility Volatility 55 } 56 var pgCasts []pgCast 57 58 for { 59 line, err := reader.Read() 60 if err == io.EOF { 61 break 62 } 63 require.NoError(t, err) 64 require.Len(t, line, 4) 65 66 fromOid, err := strconv.Atoi(line[0]) 67 require.NoError(t, err) 68 69 toOid, err := strconv.Atoi(line[1]) 70 require.NoError(t, err) 71 72 provolatile := line[2] 73 require.Len(t, provolatile, 1) 74 proleakproof := line[3] 75 require.Len(t, proleakproof, 1) 76 77 v, err := VolatilityFromPostgres(provolatile, proleakproof[0] == 't') 78 require.NoError(t, err) 79 80 pgCasts = append(pgCasts, pgCast{ 81 from: oid.Oid(fromOid), 82 to: oid.Oid(toOid), 83 volatility: v, 84 }) 85 } 86 87 oidToFamily := func(o oid.Oid) (_ types.Family, ok bool) { 88 t, ok := types.OidToType[o] 89 if !ok { 90 return 0, false 91 } 92 return t.Family(), true 93 } 94 95 oidStr := func(o oid.Oid) string { 96 res, ok := oidext.TypeName(o) 97 if !ok { 98 res = fmt.Sprintf("%d", o) 99 } 100 return res 101 } 102 103 for _, c := range validCasts { 104 if c.volatility == 0 { 105 t.Errorf("cast %s::%s has no volatility set", c.from.Name(), c.to.Name()) 106 107 } 108 if c.ignoreVolatilityCheck { 109 continue 110 } 111 112 // Look through all pg casts and find any where the Oids map to these 113 // families. 114 found := false 115 for i := range pgCasts { 116 fromFamily, fromOk := oidToFamily(pgCasts[i].from) 117 toFamily, toOk := oidToFamily(pgCasts[i].to) 118 if fromOk && toOk && fromFamily == c.from && toFamily == c.to { 119 found = true 120 if c.volatility != pgCasts[i].volatility { 121 t.Errorf("cast %s::%s has volatility %s; corresponding pg cast %s::%s has volatility %s", 122 c.from.Name(), c.to.Name(), c.volatility, 123 oidStr(pgCasts[i].from), oidStr(pgCasts[i].to), pgCasts[i].volatility, 124 ) 125 } 126 } 127 } 128 if !found && testing.Verbose() { 129 t.Logf("cast %s::%s has no corresponding pg cast", c.from.Name(), c.to.Name()) 130 } 131 } 132 }