github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/sem/tree/casts_test.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tree
    12  
    13  import (
    14  	"encoding/csv"
    15  	"fmt"
    16  	"io"
    17  	"os"
    18  	"path/filepath"
    19  	"strconv"
    20  	"testing"
    21  
    22  	"github.com/cockroachdb/cockroach/pkg/sql/oidext"
    23  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    24  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    25  	"github.com/lib/pq/oid"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  // TestCastsVolatilityMatchesPostgres checks that our defined casts match
    30  // Postgres' casts for Volatility.
    31  //
    32  // Dump command below:
    33  // COPY (
    34  //   SELECT c.castsource, c.casttarget, p.provolatile, p.proleakproof
    35  //   FROM pg_cast c JOIN pg_proc p ON (c.castfunc = p.oid)
    36  // ) TO STDOUT WITH CSV DELIMITER '|' HEADER;
    37  func TestCastsVolatilityMatchesPostgres(t *testing.T) {
    38  	defer leaktest.AfterTest(t)()
    39  	csvPath := filepath.Join("testdata", "pg_cast_provolatile_dump.csv")
    40  	f, err := os.Open(csvPath)
    41  	require.NoError(t, err)
    42  
    43  	defer f.Close()
    44  
    45  	reader := csv.NewReader(f)
    46  	reader.Comma = '|'
    47  
    48  	// Read header row
    49  	_, err = reader.Read()
    50  	require.NoError(t, err)
    51  
    52  	type pgCast struct {
    53  		from, to   oid.Oid
    54  		volatility Volatility
    55  	}
    56  	var pgCasts []pgCast
    57  
    58  	for {
    59  		line, err := reader.Read()
    60  		if err == io.EOF {
    61  			break
    62  		}
    63  		require.NoError(t, err)
    64  		require.Len(t, line, 4)
    65  
    66  		fromOid, err := strconv.Atoi(line[0])
    67  		require.NoError(t, err)
    68  
    69  		toOid, err := strconv.Atoi(line[1])
    70  		require.NoError(t, err)
    71  
    72  		provolatile := line[2]
    73  		require.Len(t, provolatile, 1)
    74  		proleakproof := line[3]
    75  		require.Len(t, proleakproof, 1)
    76  
    77  		v, err := VolatilityFromPostgres(provolatile, proleakproof[0] == 't')
    78  		require.NoError(t, err)
    79  
    80  		pgCasts = append(pgCasts, pgCast{
    81  			from:       oid.Oid(fromOid),
    82  			to:         oid.Oid(toOid),
    83  			volatility: v,
    84  		})
    85  	}
    86  
    87  	oidToFamily := func(o oid.Oid) (_ types.Family, ok bool) {
    88  		t, ok := types.OidToType[o]
    89  		if !ok {
    90  			return 0, false
    91  		}
    92  		return t.Family(), true
    93  	}
    94  
    95  	oidStr := func(o oid.Oid) string {
    96  		res, ok := oidext.TypeName(o)
    97  		if !ok {
    98  			res = fmt.Sprintf("%d", o)
    99  		}
   100  		return res
   101  	}
   102  
   103  	for _, c := range validCasts {
   104  		if c.volatility == 0 {
   105  			t.Errorf("cast %s::%s has no volatility set", c.from.Name(), c.to.Name())
   106  
   107  		}
   108  		if c.ignoreVolatilityCheck {
   109  			continue
   110  		}
   111  
   112  		// Look through all pg casts and find any where the Oids map to these
   113  		// families.
   114  		found := false
   115  		for i := range pgCasts {
   116  			fromFamily, fromOk := oidToFamily(pgCasts[i].from)
   117  			toFamily, toOk := oidToFamily(pgCasts[i].to)
   118  			if fromOk && toOk && fromFamily == c.from && toFamily == c.to {
   119  				found = true
   120  				if c.volatility != pgCasts[i].volatility {
   121  					t.Errorf("cast %s::%s has volatility %s; corresponding pg cast %s::%s has volatility %s",
   122  						c.from.Name(), c.to.Name(), c.volatility,
   123  						oidStr(pgCasts[i].from), oidStr(pgCasts[i].to), pgCasts[i].volatility,
   124  					)
   125  				}
   126  			}
   127  		}
   128  		if !found && testing.Verbose() {
   129  			t.Logf("cast %s::%s has no corresponding pg cast", c.from.Name(), c.to.Name())
   130  		}
   131  	}
   132  }