github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/migrate/validation.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package migrate 16 17 import ( 18 "context" 19 "fmt" 20 "io" 21 "runtime" 22 "strings" 23 "time" 24 "unicode" 25 26 "github.com/dolthub/go-mysql-server/sql" 27 gmstypes "github.com/dolthub/go-mysql-server/sql/types" 28 "github.com/dolthub/vitess/go/vt/proto/query" 29 "golang.org/x/sync/errgroup" 30 31 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 32 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 33 "github.com/dolthub/dolt/go/libraries/doltcore/sqle" 34 "github.com/dolthub/dolt/go/store/types" 35 ) 36 37 func validateBranchMapping(ctx context.Context, old, new *doltdb.DoltDB) error { 38 branches, err := old.GetBranches(ctx) 39 if err != nil { 40 return err 41 } 42 43 var ok bool 44 for _, bref := range branches { 45 _, ok, err = new.HasBranch(ctx, bref.GetPath()) 46 if err != nil { 47 return err 48 } 49 if !ok { 50 return fmt.Errorf("failed to map branch %s", bref.GetPath()) 51 } 52 } 53 return nil 54 } 55 56 func validateRootValue(ctx context.Context, oldParent, old, new doltdb.RootValue) error { 57 names, err := old.GetTableNames(ctx, doltdb.DefaultSchemaName) 58 if err != nil { 59 return err 60 } 61 for _, name := range names { 62 o, ok, err := old.GetTable(ctx, doltdb.TableName{Name: name}) 63 if err != nil { 64 return err 65 } 66 if !ok { 67 h, _ := old.HashOf() 68 return fmt.Errorf("expected to find table %s in root value (%s)", name, h.String()) 69 } 70 71 // Skip tables that haven't changed 72 op, ok, err := oldParent.GetTable(ctx, doltdb.TableName{Name: name}) 73 if err != nil { 74 return err 75 } 76 if ok { 77 oldHash, err := o.HashOf() 78 if err != nil { 79 return err 80 } 81 oldParentHash, err := op.HashOf() 82 if err != nil { 83 return err 84 } 85 if oldHash.Equal(oldParentHash) { 86 continue 87 } 88 } 89 90 n, ok, err := new.GetTable(ctx, doltdb.TableName{Name: name}) 91 if err != nil { 92 return err 93 } 94 if !ok { 95 h, _ := new.HashOf() 96 return fmt.Errorf("expected to find table %s in root value (%s)", name, h.String()) 97 } 98 99 if err = validateTableData(ctx, name, o, n); err != nil { 100 return err 101 } 102 } 103 return nil 104 } 105 106 func validateTableData(ctx context.Context, name string, old, new *doltdb.Table) error { 107 parts, err := partitionTable(ctx, old) 108 if err != nil { 109 return err 110 } else if len(parts) == 0 { 111 return nil 112 } 113 114 eg, ctx := errgroup.WithContext(ctx) 115 for i := range parts { 116 start, end := parts[i][0], parts[i][1] 117 eg.Go(func() error { 118 return validateTableDataPartition(ctx, name, old, new, start, end) 119 }) 120 } 121 122 return eg.Wait() 123 } 124 125 func validateTableDataPartition(ctx context.Context, name string, old, new *doltdb.Table, start, end uint64) error { 126 sctx := sql.NewContext(ctx) 127 _, oldIter, err := sqle.DoltTablePartitionToRowIter(sctx, name, old, start, end) 128 if err != nil { 129 return err 130 } 131 newSch, newIter, err := sqle.DoltTablePartitionToRowIter(sctx, name, new, start, end) 132 if err != nil { 133 return err 134 } 135 136 var o, n sql.Row 137 for { 138 o, err = oldIter.Next(sctx) 139 if err == io.EOF { 140 break 141 } else if err != nil { 142 return err 143 } 144 145 n, err = newIter.Next(sctx) 146 if err != nil { 147 return err 148 } 149 150 ok, err := equalRows(o, n, newSch) 151 if err != nil { 152 return err 153 } else if !ok { 154 return fmt.Errorf("differing rows for table %s (%s != %s)", 155 name, sql.FormatRow(o), sql.FormatRow(n)) 156 } 157 } 158 159 // validated that newIter is also exhausted 160 _, err = newIter.Next(sctx) 161 if err != io.EOF { 162 return fmt.Errorf("differing number of rows for table %s", name) 163 } 164 return nil 165 } 166 167 func equalRows(old, new sql.Row, sch sql.Schema) (bool, error) { 168 if len(new) != len(old) || len(new) != len(sch) { 169 return false, nil 170 } 171 172 var err error 173 var cmp int 174 for i := range new { 175 176 // special case string comparisons 177 if s, ok := old[i].(string); ok { 178 old[i] = strings.TrimRightFunc(s, unicode.IsSpace) 179 } 180 if s, ok := new[i].(string); ok { 181 new[i] = strings.TrimRightFunc(s, unicode.IsSpace) 182 } 183 184 // special case time comparison to account 185 // for precision changes between formats 186 if _, ok := old[i].(time.Time); ok { 187 var o, n interface{} 188 if o, _, err = gmstypes.Int64.Convert(old[i]); err != nil { 189 return false, err 190 } 191 if n, _, err = gmstypes.Int64.Convert(new[i]); err != nil { 192 return false, err 193 } 194 if cmp, err = gmstypes.Int64.Compare(o, n); err != nil { 195 return false, err 196 } 197 } else { 198 if cmp, err = sch[i].Type.Compare(old[i], new[i]); err != nil { 199 return false, err 200 } 201 } 202 if cmp != 0 { 203 return false, nil 204 } 205 } 206 return true, nil 207 } 208 209 func validateSchema(existing schema.Schema) error { 210 for _, c := range existing.GetAllCols().GetColumns() { 211 qt := c.TypeInfo.ToSqlType().Type() 212 err := assertNomsKind(c.Kind, nomsKindsFromQueryTypes(qt)...) 213 if err != nil { 214 return err 215 } 216 } 217 return nil 218 } 219 220 func nomsKindsFromQueryTypes(qt query.Type) []types.NomsKind { 221 switch qt { 222 case query.Type_UINT8: 223 return []types.NomsKind{types.UintKind, types.BoolKind} 224 225 case query.Type_UINT16, query.Type_UINT24, 226 query.Type_UINT32, query.Type_UINT64: 227 return []types.NomsKind{types.UintKind} 228 229 case query.Type_INT8: 230 return []types.NomsKind{types.IntKind, types.BoolKind} 231 232 case query.Type_INT16, query.Type_INT24, 233 query.Type_INT32, query.Type_INT64: 234 return []types.NomsKind{types.IntKind} 235 236 case query.Type_YEAR, query.Type_TIME: 237 return []types.NomsKind{types.IntKind} 238 239 case query.Type_FLOAT32, query.Type_FLOAT64: 240 return []types.NomsKind{types.FloatKind} 241 242 case query.Type_TIMESTAMP, query.Type_DATE, query.Type_DATETIME: 243 return []types.NomsKind{types.TimestampKind} 244 245 case query.Type_DECIMAL: 246 return []types.NomsKind{types.DecimalKind} 247 248 case query.Type_TEXT, query.Type_BLOB: 249 return []types.NomsKind{ 250 types.BlobKind, 251 types.StringKind, 252 } 253 254 case query.Type_VARCHAR, query.Type_CHAR: 255 return []types.NomsKind{types.StringKind} 256 257 case query.Type_VARBINARY, query.Type_BINARY: 258 return []types.NomsKind{types.InlineBlobKind} 259 260 case query.Type_BIT, query.Type_ENUM, query.Type_SET: 261 return []types.NomsKind{types.UintKind} 262 263 case query.Type_GEOMETRY: 264 return []types.NomsKind{ 265 types.GeometryKind, 266 types.PointKind, 267 types.LineStringKind, 268 types.PolygonKind, 269 types.MultiPointKind, 270 types.MultiLineStringKind, 271 types.MultiPolygonKind, 272 types.GeometryCollectionKind, 273 } 274 275 case query.Type_JSON: 276 return []types.NomsKind{types.JSONKind} 277 278 default: 279 panic(fmt.Sprintf("unexpect query.Type %s", qt.String())) 280 } 281 } 282 283 func assertNomsKind(kind types.NomsKind, candidates ...types.NomsKind) error { 284 for _, c := range candidates { 285 if kind == c { 286 return nil 287 } 288 } 289 290 cs := make([]string, len(candidates)) 291 for i, c := range candidates { 292 cs[i] = types.KindToString[c] 293 } 294 return fmt.Errorf("expected NomsKind to be one of (%s), got NomsKind (%s)", 295 strings.Join(cs, ", "), types.KindToString[kind]) 296 } 297 298 func partitionTable(ctx context.Context, tbl *doltdb.Table) ([][2]uint64, error) { 299 idx, err := tbl.GetRowData(ctx) 300 if err != nil { 301 return nil, err 302 } 303 304 c, err := idx.Count() 305 if err != nil { 306 return nil, err 307 } 308 if c == 0 { 309 return nil, nil 310 } 311 n := runtime.NumCPU() * 2 312 szc, err := idx.Count() 313 if err != nil { 314 return nil, err 315 } 316 sz := int(szc) / n 317 318 parts := make([][2]uint64, n) 319 320 parts[0][0] = 0 321 parts[n-1][1], err = idx.Count() 322 if err != nil { 323 return nil, err 324 } 325 326 for i := 1; i < len(parts); i++ { 327 parts[i-1][1] = uint64(i * sz) 328 parts[i][0] = uint64(i * sz) 329 } 330 331 return parts, nil 332 } 333 334 func assertTrue(b bool) { 335 if !b { 336 panic("expected true") 337 } 338 }