github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/migrate/validation.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package migrate
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"runtime"
    22  	"strings"
    23  	"time"
    24  	"unicode"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    28  	"github.com/dolthub/vitess/go/vt/proto/query"
    29  	"golang.org/x/sync/errgroup"
    30  
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
    34  	"github.com/dolthub/dolt/go/store/types"
    35  )
    36  
    37  func validateBranchMapping(ctx context.Context, old, new *doltdb.DoltDB) error {
    38  	branches, err := old.GetBranches(ctx)
    39  	if err != nil {
    40  		return err
    41  	}
    42  
    43  	var ok bool
    44  	for _, bref := range branches {
    45  		_, ok, err = new.HasBranch(ctx, bref.GetPath())
    46  		if err != nil {
    47  			return err
    48  		}
    49  		if !ok {
    50  			return fmt.Errorf("failed to map branch %s", bref.GetPath())
    51  		}
    52  	}
    53  	return nil
    54  }
    55  
    56  func validateRootValue(ctx context.Context, oldParent, old, new doltdb.RootValue) error {
    57  	names, err := old.GetTableNames(ctx, doltdb.DefaultSchemaName)
    58  	if err != nil {
    59  		return err
    60  	}
    61  	for _, name := range names {
    62  		o, ok, err := old.GetTable(ctx, doltdb.TableName{Name: name})
    63  		if err != nil {
    64  			return err
    65  		}
    66  		if !ok {
    67  			h, _ := old.HashOf()
    68  			return fmt.Errorf("expected to find table %s in root value (%s)", name, h.String())
    69  		}
    70  
    71  		// Skip tables that haven't changed
    72  		op, ok, err := oldParent.GetTable(ctx, doltdb.TableName{Name: name})
    73  		if err != nil {
    74  			return err
    75  		}
    76  		if ok {
    77  			oldHash, err := o.HashOf()
    78  			if err != nil {
    79  				return err
    80  			}
    81  			oldParentHash, err := op.HashOf()
    82  			if err != nil {
    83  				return err
    84  			}
    85  			if oldHash.Equal(oldParentHash) {
    86  				continue
    87  			}
    88  		}
    89  
    90  		n, ok, err := new.GetTable(ctx, doltdb.TableName{Name: name})
    91  		if err != nil {
    92  			return err
    93  		}
    94  		if !ok {
    95  			h, _ := new.HashOf()
    96  			return fmt.Errorf("expected to find table %s in root value (%s)", name, h.String())
    97  		}
    98  
    99  		if err = validateTableData(ctx, name, o, n); err != nil {
   100  			return err
   101  		}
   102  	}
   103  	return nil
   104  }
   105  
   106  func validateTableData(ctx context.Context, name string, old, new *doltdb.Table) error {
   107  	parts, err := partitionTable(ctx, old)
   108  	if err != nil {
   109  		return err
   110  	} else if len(parts) == 0 {
   111  		return nil
   112  	}
   113  
   114  	eg, ctx := errgroup.WithContext(ctx)
   115  	for i := range parts {
   116  		start, end := parts[i][0], parts[i][1]
   117  		eg.Go(func() error {
   118  			return validateTableDataPartition(ctx, name, old, new, start, end)
   119  		})
   120  	}
   121  
   122  	return eg.Wait()
   123  }
   124  
   125  func validateTableDataPartition(ctx context.Context, name string, old, new *doltdb.Table, start, end uint64) error {
   126  	sctx := sql.NewContext(ctx)
   127  	_, oldIter, err := sqle.DoltTablePartitionToRowIter(sctx, name, old, start, end)
   128  	if err != nil {
   129  		return err
   130  	}
   131  	newSch, newIter, err := sqle.DoltTablePartitionToRowIter(sctx, name, new, start, end)
   132  	if err != nil {
   133  		return err
   134  	}
   135  
   136  	var o, n sql.Row
   137  	for {
   138  		o, err = oldIter.Next(sctx)
   139  		if err == io.EOF {
   140  			break
   141  		} else if err != nil {
   142  			return err
   143  		}
   144  
   145  		n, err = newIter.Next(sctx)
   146  		if err != nil {
   147  			return err
   148  		}
   149  
   150  		ok, err := equalRows(o, n, newSch)
   151  		if err != nil {
   152  			return err
   153  		} else if !ok {
   154  			return fmt.Errorf("differing rows for table %s (%s != %s)",
   155  				name, sql.FormatRow(o), sql.FormatRow(n))
   156  		}
   157  	}
   158  
   159  	// validated that newIter is also exhausted
   160  	_, err = newIter.Next(sctx)
   161  	if err != io.EOF {
   162  		return fmt.Errorf("differing number of rows for table %s", name)
   163  	}
   164  	return nil
   165  }
   166  
   167  func equalRows(old, new sql.Row, sch sql.Schema) (bool, error) {
   168  	if len(new) != len(old) || len(new) != len(sch) {
   169  		return false, nil
   170  	}
   171  
   172  	var err error
   173  	var cmp int
   174  	for i := range new {
   175  
   176  		// special case string comparisons
   177  		if s, ok := old[i].(string); ok {
   178  			old[i] = strings.TrimRightFunc(s, unicode.IsSpace)
   179  		}
   180  		if s, ok := new[i].(string); ok {
   181  			new[i] = strings.TrimRightFunc(s, unicode.IsSpace)
   182  		}
   183  
   184  		// special case time comparison to account
   185  		// for precision changes between formats
   186  		if _, ok := old[i].(time.Time); ok {
   187  			var o, n interface{}
   188  			if o, _, err = gmstypes.Int64.Convert(old[i]); err != nil {
   189  				return false, err
   190  			}
   191  			if n, _, err = gmstypes.Int64.Convert(new[i]); err != nil {
   192  				return false, err
   193  			}
   194  			if cmp, err = gmstypes.Int64.Compare(o, n); err != nil {
   195  				return false, err
   196  			}
   197  		} else {
   198  			if cmp, err = sch[i].Type.Compare(old[i], new[i]); err != nil {
   199  				return false, err
   200  			}
   201  		}
   202  		if cmp != 0 {
   203  			return false, nil
   204  		}
   205  	}
   206  	return true, nil
   207  }
   208  
   209  func validateSchema(existing schema.Schema) error {
   210  	for _, c := range existing.GetAllCols().GetColumns() {
   211  		qt := c.TypeInfo.ToSqlType().Type()
   212  		err := assertNomsKind(c.Kind, nomsKindsFromQueryTypes(qt)...)
   213  		if err != nil {
   214  			return err
   215  		}
   216  	}
   217  	return nil
   218  }
   219  
   220  func nomsKindsFromQueryTypes(qt query.Type) []types.NomsKind {
   221  	switch qt {
   222  	case query.Type_UINT8:
   223  		return []types.NomsKind{types.UintKind, types.BoolKind}
   224  
   225  	case query.Type_UINT16, query.Type_UINT24,
   226  		query.Type_UINT32, query.Type_UINT64:
   227  		return []types.NomsKind{types.UintKind}
   228  
   229  	case query.Type_INT8:
   230  		return []types.NomsKind{types.IntKind, types.BoolKind}
   231  
   232  	case query.Type_INT16, query.Type_INT24,
   233  		query.Type_INT32, query.Type_INT64:
   234  		return []types.NomsKind{types.IntKind}
   235  
   236  	case query.Type_YEAR, query.Type_TIME:
   237  		return []types.NomsKind{types.IntKind}
   238  
   239  	case query.Type_FLOAT32, query.Type_FLOAT64:
   240  		return []types.NomsKind{types.FloatKind}
   241  
   242  	case query.Type_TIMESTAMP, query.Type_DATE, query.Type_DATETIME:
   243  		return []types.NomsKind{types.TimestampKind}
   244  
   245  	case query.Type_DECIMAL:
   246  		return []types.NomsKind{types.DecimalKind}
   247  
   248  	case query.Type_TEXT, query.Type_BLOB:
   249  		return []types.NomsKind{
   250  			types.BlobKind,
   251  			types.StringKind,
   252  		}
   253  
   254  	case query.Type_VARCHAR, query.Type_CHAR:
   255  		return []types.NomsKind{types.StringKind}
   256  
   257  	case query.Type_VARBINARY, query.Type_BINARY:
   258  		return []types.NomsKind{types.InlineBlobKind}
   259  
   260  	case query.Type_BIT, query.Type_ENUM, query.Type_SET:
   261  		return []types.NomsKind{types.UintKind}
   262  
   263  	case query.Type_GEOMETRY:
   264  		return []types.NomsKind{
   265  			types.GeometryKind,
   266  			types.PointKind,
   267  			types.LineStringKind,
   268  			types.PolygonKind,
   269  			types.MultiPointKind,
   270  			types.MultiLineStringKind,
   271  			types.MultiPolygonKind,
   272  			types.GeometryCollectionKind,
   273  		}
   274  
   275  	case query.Type_JSON:
   276  		return []types.NomsKind{types.JSONKind}
   277  
   278  	default:
   279  		panic(fmt.Sprintf("unexpect query.Type %s", qt.String()))
   280  	}
   281  }
   282  
   283  func assertNomsKind(kind types.NomsKind, candidates ...types.NomsKind) error {
   284  	for _, c := range candidates {
   285  		if kind == c {
   286  			return nil
   287  		}
   288  	}
   289  
   290  	cs := make([]string, len(candidates))
   291  	for i, c := range candidates {
   292  		cs[i] = types.KindToString[c]
   293  	}
   294  	return fmt.Errorf("expected NomsKind to be one of (%s), got NomsKind (%s)",
   295  		strings.Join(cs, ", "), types.KindToString[kind])
   296  }
   297  
   298  func partitionTable(ctx context.Context, tbl *doltdb.Table) ([][2]uint64, error) {
   299  	idx, err := tbl.GetRowData(ctx)
   300  	if err != nil {
   301  		return nil, err
   302  	}
   303  
   304  	c, err := idx.Count()
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  	if c == 0 {
   309  		return nil, nil
   310  	}
   311  	n := runtime.NumCPU() * 2
   312  	szc, err := idx.Count()
   313  	if err != nil {
   314  		return nil, err
   315  	}
   316  	sz := int(szc) / n
   317  
   318  	parts := make([][2]uint64, n)
   319  
   320  	parts[0][0] = 0
   321  	parts[n-1][1], err = idx.Count()
   322  	if err != nil {
   323  		return nil, err
   324  	}
   325  
   326  	for i := 1; i < len(parts); i++ {
   327  		parts[i-1][1] = uint64(i * sz)
   328  		parts[i][0] = uint64(i * sz)
   329  	}
   330  
   331  	return parts, nil
   332  }
   333  
   334  func assertTrue(b bool) {
   335  	if !b {
   336  		panic("expected true")
   337  	}
   338  }