github.com/dolthub/go-mysql-server@v0.18.0/sql/func_deps_test.go (about) 1 package sql 2 3 import ( 4 "fmt" 5 "sort" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 ) 10 11 func TestFuncDeps_Project(t *testing.T) { 12 t.Run("project const via equiv", func(t *testing.T) { 13 { 14 // a == b, a const, proj(b) => maintian const(b) 15 fds := &FuncDepSet{all: cols(1, 2, 3)} 16 fds.AddConstants(cols(1)) 17 fds.AddEquiv(1, 2) 18 proj := NewProjectFDs(fds, cols(2), false) 19 assert.Equal(t, "(2)", proj.Constants().String()) 20 } 21 }) 22 t.Run("project pk via equiv", func(t *testing.T) { 23 { 24 // pk(a,b), a == c, proj(b,c) => maintain pk(c,b) 25 fds := &FuncDepSet{all: cols(1, 2, 3)} 26 fds.AddEquiv(1, 3) 27 fds.AddStrictKey(cols(1, 2)) 28 proj := NewProjectFDs(fds, cols(2, 3), false) 29 assert.Equal(t, "key(2,3)", proj.String()) 30 } 31 32 }) 33 t.Run("distinct project adds strict key", func(t *testing.T) { 34 fds := &FuncDepSet{all: cols(1, 2, 3)} 35 fds.AddLaxKey(cols(1, 2, 3)) 36 proj := NewProjectFDs(fds, cols(1, 2, 3), true) 37 assert.Equal(t, "key(1-3)", proj.String()) 38 }) 39 t.Run("columns preserved", func(t *testing.T) { 40 // a == b, b == c, proj(a,c) maintains a == c 41 fds := &FuncDepSet{all: cols(1, 2, 3)} 42 fds.AddEquivSet(cols(1, 2, 3)) 43 proj := NewProjectFDs(fds, cols(1, 3), false) 44 assert.Equal(t, "equiv(1,3)", proj.String()) 45 }) 46 t.Run("remove strict determinant constant", func(t *testing.T) { 47 fds := &FuncDepSet{all: cols(1, 2, 3)} 48 fds.AddConstants(cols(1)) 49 fds.AddStrictKey(cols(1, 2)) 50 proj := NewProjectFDs(fds, cols(2), false) 51 assert.Equal(t, "key(2)", proj.String()) 52 }) 53 t.Run("remove lax determinant constant", func(t *testing.T) { 54 fds := &FuncDepSet{all: cols(1, 2, 3)} 55 fds.AddConstants(cols(1)) 56 fds.AddLaxKey(cols(1, 2)) 57 proj := NewProjectFDs(fds, cols(2), false) 58 assert.Equal(t, "lax-key(2)", proj.String()) 59 }) 60 } 61 62 func TestFuncDeps_CrossJoin(t *testing.T) { 63 // create table abcde (a primary key, b int, c int not null, d int not null, e int not null) 64 // create table mnpq (m primary key, n int, p int not null, q int not null) 65 t.Run("cross product", func(t *testing.T) { 66 abcde := &FuncDepSet{} 67 abcde.AddNotNullable(cols(1)) 68 abcde.AddStrictKey(cols(1)) 69 abcde.AddLaxKey(cols(2, 3)) 70 71 mnpq := &FuncDepSet{} 72 mnpq.AddNotNullable(cols(6, 7)) 73 mnpq.AddStrictKey(cols(6, 7)) 74 75 join := NewCrossJoinFDs(abcde, mnpq) 76 assert.Equal(t, "key(1,6,7); fd(1); lax-fd(2,3); fd(6,7)", join.String()) 77 }) 78 t.Run("cross product one-sided equiv", func(t *testing.T) { 79 abcde := &FuncDepSet{} 80 abcde.AddNotNullable(cols(1)) 81 abcde.AddEquiv(1, 5) 82 abcde.AddStrictKey(cols(1)) 83 abcde.AddLaxKey(cols(2, 3)) 84 85 mnpq := &FuncDepSet{} 86 mnpq.AddNotNullable(cols(6, 7)) 87 mnpq.AddStrictKey(cols(6, 7)) 88 89 join := NewCrossJoinFDs(abcde, mnpq) 90 assert.Equal(t, "key(1,6,7); equiv(1,5); fd(1); lax-fd(2,3); fd(6,7)", join.String()) 91 }) 92 } 93 94 func TestFuncDeps_InnerJoin(t *testing.T) { 95 t.Run("abcde X mnpq", func(t *testing.T) { 96 // abcde JOIN mnpq ON a = m WHERE n = 2 97 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 98 abcde.AddNotNullable(cols(1)) 99 abcde.AddStrictKey(cols(1)) 100 abcde.AddLaxKey(cols(2, 3)) 101 102 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 103 mnpq.AddNotNullable(cols(6, 7)) 104 mnpq.AddConstants(cols(7)) 105 mnpq.AddStrictKey(cols(6, 7)) 106 107 join := NewInnerJoinFDs(abcde, mnpq, [][2]ColumnId{{1, 6}}) 108 assert.Equal(t, "key(6); constant(7); equiv(1,6); fd(1)/(1-5); lax-fd(2,3)/(1-5); fd(6)/(6-9)", join.String()) 109 }) 110 111 t.Run("ware X cust", func(t *testing.T) { 112 // create table customer (id primary key, did not null, wid int not null, first varchar(10), last varchar(10)) 113 // create table warehouse (id primary key) 114 115 // c.wid = w.id 116 // SELECT * from cust join ware 117 // ON c_w_id = w_id AND 118 // WHERE w_id = 1 AND c_d_id = 2 AND c_id = 2327 119 120 cust := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 121 cust.AddNotNullable(cols(1, 2, 3)) 122 cust.AddConstants(cols(1, 2)) 123 cust.AddStrictKey(cols(3, 2, 1)) 124 cust.AddLaxKey(cols(3, 2, 4, 5)) 125 126 ware := &FuncDepSet{all: cols(6)} 127 ware.AddNotNullable(cols(6)) 128 ware.AddConstants(cols(6)) 129 ware.AddStrictKey(cols(6)) 130 131 join := NewInnerJoinFDs(cust, ware, [][2]ColumnId{{3, 6}}) 132 assert.Equal(t, "key(); constant(1-3,6); equiv(3,6); fd(3)/(1-5); fd()/(6)", join.String()) 133 }) 134 t.Run("equiv on both sides inner join", func(t *testing.T) { 135 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 136 abcde.AddNotNullable(cols(1)) 137 abcde.AddEquivSet(cols(2, 3, 4)) 138 abcde.AddStrictKey(cols(1)) 139 abcde.AddLaxKey(cols(2, 3)) 140 141 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 142 mnpq.AddNotNullable(cols(6, 7)) 143 mnpq.AddEquivSet(cols(6, 8, 9)) 144 mnpq.AddStrictKey(cols(6, 7)) 145 146 join := NewInnerJoinFDs(mnpq, abcde, [][2]ColumnId{}) 147 assert.Equal(t, "key(1,6,7); equiv(6,8,9); equiv(2-4); fd(6,7)/(6-9); fd(1)/(1-5); lax-fd(3)/(1-5)", join.String()) 148 }) 149 t.Run("max1Row inner join", func(t *testing.T) { 150 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 151 abcde.AddNotNullable(cols(1, 2, 3)) 152 abcde.AddConstants(cols(3)) 153 abcde.AddEquiv(2, 3) 154 abcde.AddStrictKey(cols(1)) 155 abcde.AddLaxKey(cols(2, 3)) 156 157 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 158 mnpq.AddNotNullable(cols(6, 7)) 159 mnpq.AddConstants(cols(6, 7)) 160 mnpq.AddStrictKey(cols(6, 7)) 161 162 join := NewInnerJoinFDs(mnpq, abcde, [][2]ColumnId{{1, 6}, {1, 2}}) 163 assert.Equal(t, "key(); constant(1-9); equiv(1-3,6)", join.String()) 164 }) 165 t.Run("infer constants from max1Row", func(t *testing.T) { 166 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 167 abcde.AddNotNullable(cols(1, 2, 3)) 168 abcde.AddConstants(cols(1)) 169 abcde.AddStrictKey(cols(1)) 170 171 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 172 mnpq.AddNotNullable(cols(6)) 173 mnpq.AddStrictKey(cols(6)) 174 175 join := NewInnerJoinFDs(mnpq, abcde, [][2]ColumnId{{1, 7}}) 176 assert.Equal(t, "key(6); constant(1-5,7); equiv(1,7); fd(6)/(6-9); fd()/(1-5)", join.String()) 177 }) 178 t.Run("simplify cols on join", func(t *testing.T) { 179 // create table t1 (id int primary key, value int) 180 // create table t2 (id int primary key, value int) 181 182 // SELECT * FROM t1 JOIN t2 ON t1.value = t2.id; 183 184 t1 := &FuncDepSet{all: cols(1, 2)} 185 t1.AddNotNullable(cols(1)) 186 t1.AddStrictKey(cols(1)) 187 188 t2 := &FuncDepSet{all: cols(3, 4)} 189 t2.AddNotNullable(cols(3)) 190 t2.AddStrictKey(cols(3)) 191 192 join := NewInnerJoinFDs(t1, t2, [][2]ColumnId{{2, 3}}) 193 assert.Equal(t, "key(1); equiv(2,3); fd(1)/(1,2); fd(3)/(3,4)", join.String()) 194 }) 195 t.Run("simplify cols on join on primary keys", func(t *testing.T) { 196 // create table t1 (id int primary key, value int) 197 // create table t2 (id int primary key, value int) 198 // create table t3 (id int primary key, value int) 199 // create table t4 (id int primary key, value int) 200 201 // SELECT * FROM t1 JOIN t2 ON t1.id = t2.id JOIN t3 ON t2.id = t3.id JOIN t3.id = t4.id; 202 203 t1 := &FuncDepSet{all: cols(1, 2)} 204 t1.AddNotNullable(cols(1)) 205 t1.AddStrictKey(cols(1)) 206 207 t2 := &FuncDepSet{all: cols(3, 4)} 208 t2.AddNotNullable(cols(3)) 209 t2.AddStrictKey(cols(3)) 210 211 t3 := &FuncDepSet{all: cols(5, 6)} 212 t3.AddNotNullable(cols(5)) 213 t3.AddStrictKey(cols(5)) 214 215 t4 := &FuncDepSet{all: cols(7, 8)} 216 t4.AddNotNullable(cols(7)) 217 t4.AddStrictKey(cols(7)) 218 219 join12 := NewInnerJoinFDs(t1, t2, [][2]ColumnId{{1, 3}}) 220 join123 := NewInnerJoinFDs(join12, t3, [][2]ColumnId{{3, 5}}) 221 join1234 := NewInnerJoinFDs(join123, t4, [][2]ColumnId{{5, 7}}) 222 assert.Equal(t, "key(7); equiv(1,3,5,7); fd(5)/(1-6); fd(3)/(1-4); fd(1)/(1,2); fd(3)/(3,4); fd(5)/(5,6); fd(7)/(7,8)", join1234.String()) 223 }) 224 t.Run("simplify cols on bushy join", func(t *testing.T) { 225 // create table t1 (id int primary key, value int) 226 // create table t2 (id int primary key, value int) 227 // create table t3 (id int primary key, value int) 228 // create table t4 (id int primary key, value int) 229 230 // SELECT * FROM (t2 JOIN t3 ON t2.value = t3.id) JOIN (t4 JOIN t1) ON t1.value = t2.id AND t3.value = t4.id; 231 232 t1 := &FuncDepSet{all: cols(1, 2)} 233 t1.AddNotNullable(cols(1)) 234 t1.AddStrictKey(cols(1)) 235 236 t2 := &FuncDepSet{all: cols(3, 4)} 237 t2.AddNotNullable(cols(3)) 238 t2.AddStrictKey(cols(3)) 239 240 t3 := &FuncDepSet{all: cols(5, 6)} 241 t3.AddNotNullable(cols(5)) 242 t3.AddStrictKey(cols(5)) 243 244 t4 := &FuncDepSet{all: cols(7, 8)} 245 t4.AddNotNullable(cols(7)) 246 t4.AddStrictKey(cols(7)) 247 248 join23 := NewInnerJoinFDs(t2, t3, [][2]ColumnId{{4, 5}}) 249 join14 := NewCrossJoinFDs(t1, t4) 250 join := NewInnerJoinFDs(join23, join14, [][2]ColumnId{{2, 3}, {6, 7}}) 251 assert.Equal(t, "key(1); equiv(4,5); equiv(2,3); equiv(6,7); fd(3)/(3-6); fd(3)/(3,4); fd(5)/(5,6); fd(1,7)/(1,2,7,8); fd(1)/(1,2); fd(7)/(7,8)", join.String()) 252 }) 253 t.Run("simplify cols on nested join", func(t *testing.T) { 254 // create table t1 (id int primary key, value int) 255 // create table t2 (id int primary key, value int) 256 // create table t3 (id int primary key, value int) 257 258 // SELECT * FROM (t1 JOIN t3) JOIN t2 ON t1.value = t2.id AND t2.value = t3.id; 259 260 t1 := &FuncDepSet{all: cols(1, 2)} 261 t1.AddNotNullable(cols(1)) 262 t1.AddStrictKey(cols(1)) 263 264 t2 := &FuncDepSet{all: cols(3, 4)} 265 t2.AddNotNullable(cols(3)) 266 t2.AddStrictKey(cols(3)) 267 268 t3 := &FuncDepSet{all: cols(5, 6)} 269 t3.AddNotNullable(cols(5)) 270 t3.AddStrictKey(cols(5)) 271 272 join13 := NewCrossJoinFDs(t1, t3) 273 join := NewInnerJoinFDs(join13, t2, [][2]ColumnId{{2, 3}, {4, 5}}) 274 assert.Equal(t, "key(1); equiv(2,3); equiv(4,5); fd(1,5)/(1,2,5,6); fd(1)/(1,2); fd(5)/(5,6); fd(3)/(3,4)", join.String()) 275 }) 276 } 277 278 func TestFuncDeps_LeftJoin(t *testing.T) { 279 t.Run("preserved-side constants kept", func(t *testing.T) { 280 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 281 abcde.AddNotNullable(cols(1)) 282 abcde.AddStrictKey(cols(1)) 283 abcde.AddLaxKey(cols(2, 3)) 284 285 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 286 mnpq.AddNotNullable(cols(6, 7)) 287 mnpq.AddConstants(cols(8, 9)) 288 mnpq.AddStrictKey(cols(6, 7)) 289 290 join := NewLeftJoinFDs(mnpq, abcde, [][2]ColumnId{}) 291 assert.Equal(t, "key(1,6,7); constant(8,9); lax-fd(2,3)/(1-5)", join.String()) 292 }) 293 t.Run("preserved-side equiv constants kept", func(t *testing.T) { 294 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 295 abcde.AddNotNullable(cols(1)) 296 abcde.AddStrictKey(cols(1)) 297 abcde.AddLaxKey(cols(2, 3)) 298 299 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 300 mnpq.AddNotNullable(cols(6, 7)) 301 mnpq.AddEquiv(8, 9) 302 mnpq.AddConstants(cols(8)) 303 mnpq.AddStrictKey(cols(6, 7)) 304 305 join := NewLeftJoinFDs(mnpq, abcde, [][2]ColumnId{}) 306 assert.Equal(t, "key(1,6,7); constant(8,9); equiv(8,9); lax-fd(2,3)/(1-5)", join.String()) 307 }) 308 t.Run("preserved-side key constants kept", func(t *testing.T) { 309 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 310 abcde.AddNotNullable(cols(1)) 311 abcde.AddStrictKey(cols(1)) 312 abcde.AddLaxKey(cols(2, 3)) 313 314 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 315 mnpq.AddNotNullable(cols(6, 7)) 316 mnpq.AddConstants(cols(6, 8, 9)) 317 mnpq.AddStrictKey(cols(6, 7)) 318 319 join := NewLeftJoinFDs(mnpq, abcde, [][2]ColumnId{}) 320 assert.Equal(t, "key(1,7); constant(6,8,9); lax-fd(2,3)/(1-5)", join.String()) 321 }) 322 t.Run("null-project side constants removed", func(t *testing.T) { 323 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 324 abcde.AddNotNullable(cols(1)) 325 abcde.AddStrictKey(cols(1)) 326 abcde.AddConstants(cols(3, 4)) 327 abcde.AddLaxKey(cols(2, 3)) 328 329 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 330 mnpq.AddNotNullable(cols(6, 7)) 331 mnpq.AddStrictKey(cols(6, 7)) 332 333 join := NewLeftJoinFDs(mnpq, abcde, [][2]ColumnId{}) 334 assert.Equal(t, "key(1,6,7); lax-fd(2)/(1-5)", join.String()) 335 }) 336 t.Run("equiv on both sides left join", func(t *testing.T) { 337 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 338 abcde.AddNotNullable(cols(1)) 339 abcde.AddEquivSet(cols(2, 3, 4)) 340 abcde.AddStrictKey(cols(1)) 341 abcde.AddLaxKey(cols(2, 3)) 342 343 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 344 mnpq.AddNotNullable(cols(6, 7)) 345 mnpq.AddEquivSet(cols(6, 8, 9)) 346 mnpq.AddStrictKey(cols(6, 7)) 347 348 join := NewLeftJoinFDs(mnpq, abcde, [][2]ColumnId{}) 349 assert.Equal(t, "key(1,6,7); equiv(6,8,9); lax-fd(3)/(1-5)", join.String()) 350 }) 351 t.Run("join filter equiv", func(t *testing.T) { 352 // SELECT * FROM abcde RIGHT OUTER JOIN mnpq ON a=m 353 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 354 abcde.AddNotNullable(cols(1)) 355 abcde.AddStrictKey(cols(1)) 356 abcde.AddLaxKey(cols(2, 3)) 357 358 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 359 mnpq.AddNotNullable(cols(6, 7)) 360 mnpq.AddStrictKey(cols(6, 7)) 361 362 join := NewLeftJoinFDs(mnpq, abcde, [][2]ColumnId{{1, 6}}) 363 assert.Equal(t, "key(6,7); fd(1)/(1-5); lax-fd(2,3)/(1-5)", join.String()) 364 }) 365 t.Run("join filter equiv and null-side rel equiv", func(t *testing.T) { 366 // SELECT * FROM abcde RIGHT OUTER JOIN mnpq ON a=m AND a=b 367 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 368 abcde.AddNotNullable(cols(1)) 369 abcde.AddStrictKey(cols(1)) 370 abcde.AddLaxKey(cols(2, 3)) 371 372 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 373 mnpq.AddNotNullable(cols(6, 7)) 374 mnpq.AddStrictKey(cols(6, 7)) 375 376 join := NewLeftJoinFDs(mnpq, abcde, [][2]ColumnId{{1, 6}, {1, 2}}) 377 assert.Equal(t, "key(6,7); fd(1)/(1-5); lax-fd(2,3)/(1-5)", join.String()) 378 }) 379 t.Run("max1Row left join", func(t *testing.T) { 380 abcde := &FuncDepSet{all: cols(1, 2, 3, 4, 5)} 381 abcde.AddNotNullable(cols(1, 2, 3)) 382 abcde.AddConstants(cols(3)) 383 abcde.AddEquiv(2, 3) 384 abcde.AddStrictKey(cols(1)) 385 abcde.AddLaxKey(cols(2, 3)) 386 387 mnpq := &FuncDepSet{all: cols(6, 7, 8, 9)} 388 mnpq.AddNotNullable(cols(6, 7)) 389 mnpq.AddConstants(cols(6, 7)) 390 mnpq.AddStrictKey(cols(6, 7)) 391 392 join := NewLeftJoinFDs(mnpq, abcde, [][2]ColumnId{{1, 6}, {1, 2}}) 393 assert.Equal(t, "key(); constant(1,6,7)", join.String()) 394 }) 395 } 396 397 func TestEquivSets(t *testing.T) { 398 tests := []struct { 399 name string 400 sets []ColSet 401 exp EquivSets 402 }{ 403 { 404 name: "all overlap", 405 sets: []ColSet{ 406 cols(1, 2), 407 cols(2, 3), 408 cols(3, 4), 409 }, 410 exp: EquivSets{sets: []ColSet{cols(1, 2, 3, 4)}}, 411 }, 412 { 413 name: "no overlap", 414 sets: []ColSet{ 415 cols(1, 2), 416 cols(3, 4), 417 cols(5, 6), 418 }, 419 exp: EquivSets{sets: []ColSet{cols(1, 2), cols(3, 4), cols(5, 6)}}, 420 }, 421 { 422 name: "add merges two previous sets", 423 sets: []ColSet{ 424 cols(1, 2), 425 cols(3, 4), 426 cols(2, 3), 427 }, 428 exp: EquivSets{sets: []ColSet{cols(1, 2, 3, 4)}}, 429 }, 430 { 431 name: "add merges one previous set", 432 sets: []ColSet{ 433 cols(1, 2), 434 cols(3, 4), 435 cols(2, 6), 436 }, 437 exp: EquivSets{sets: []ColSet{cols(1, 2, 6), cols(3, 4)}}, 438 }, 439 } 440 for _, tt := range tests { 441 t.Run(tt.name, func(t *testing.T) { 442 equiv := EquivSets{} 443 for _, set := range tt.sets { 444 equiv.Add(set) 445 } 446 sort.Slice(equiv.sets, func(i, j int) bool { 447 return equiv.sets[i].set.String() < equiv.sets[j].set.String() 448 }) 449 assert.Equal(t, tt.exp, equiv, fmt.Sprintf("exp != found:\n [%s]\n [%s]", tt.exp.String(), equiv.String())) 450 }) 451 } 452 } 453 454 func cols(vals ...ColumnId) ColSet { 455 return NewColSet(vals...) 456 }