github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/range_heap_iter.go (about) 1 package rowexec 2 3 import ( 4 "container/heap" 5 "errors" 6 "io" 7 "reflect" 8 9 "go.opentelemetry.io/otel/attribute" 10 "go.opentelemetry.io/otel/trace" 11 12 "github.com/dolthub/go-mysql-server/sql" 13 "github.com/dolthub/go-mysql-server/sql/plan" 14 ) 15 16 func newRangeHeapJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) { 17 var leftName, rightName string 18 if leftTable, ok := j.Left().(sql.Nameable); ok { 19 leftName = leftTable.Name() 20 } else { 21 leftName = reflect.TypeOf(j.Left()).String() 22 } 23 24 if rightTable, ok := j.Right().(sql.Nameable); ok { 25 rightName = rightTable.Name() 26 } else { 27 rightName = reflect.TypeOf(j.Right()).String() 28 } 29 30 span, ctx := ctx.Span("plan.rangeHeapJoinIter", trace.WithAttributes( 31 attribute.String("left", leftName), 32 attribute.String("right", rightName), 33 )) 34 35 l, err := b.Build(ctx, j.Left(), row) 36 if err != nil { 37 span.End() 38 return nil, err 39 } 40 41 rhp, ok := j.Right().(*plan.RangeHeap) 42 if !ok { 43 return nil, errors.New("right side of join must be a range heap") 44 } 45 46 return sql.NewSpanIter(span, &rangeHeapJoinIter{ 47 parentRow: row, 48 primary: l, 49 cond: j.Filter, 50 joinType: j.Op, 51 rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()), 52 scopeLen: j.ScopeLen, 53 b: b, 54 rangeHeapPlan: rhp, 55 }), nil 56 } 57 58 // joinIter is an iterator that iterates over every row in the primary table and performs an index lookup in 59 // the secondary table for each value 60 type rangeHeapJoinIter struct { 61 parentRow sql.Row 62 primary sql.RowIter 63 primaryRow sql.Row 64 secondary sql.RowIter 65 cond sql.Expression 66 joinType plan.JoinType 67 68 foundMatch bool 69 rowSize int 70 scopeLen int 71 b sql.NodeExecBuilder 72 73 rangeHeapPlan *plan.RangeHeap 74 childRowIter sql.RowIter 75 pendingRow sql.Row 76 77 activeRanges []sql.Row 78 err error 79 } 80 81 func (iter *rangeHeapJoinIter) loadPrimary(ctx *sql.Context) error { 82 if iter.primaryRow == nil { 83 r, err := iter.primary.Next(ctx) 84 if err != nil { 85 return err 86 } 87 88 iter.primaryRow = iter.parentRow.Append(r) 89 iter.foundMatch = false 90 91 err = iter.initializeHeap(ctx, iter.b, iter.primaryRow) 92 if err != nil { 93 return err 94 } 95 } 96 97 return nil 98 } 99 100 func (iter *rangeHeapJoinIter) loadSecondary(ctx *sql.Context) (sql.Row, error) { 101 if iter.secondary == nil { 102 rowIter, err := iter.getActiveRanges(ctx, iter.b, iter.primaryRow) 103 104 if err != nil { 105 return nil, err 106 } 107 if plan.IsEmptyIter(rowIter) { 108 return nil, plan.ErrEmptyCachedResult 109 } 110 iter.secondary = rowIter 111 } 112 113 secondaryRow, err := iter.secondary.Next(ctx) 114 if err != nil { 115 if err == io.EOF { 116 err = iter.secondary.Close(ctx) 117 iter.secondary = nil 118 if err != nil { 119 return nil, err 120 } 121 iter.primaryRow = nil 122 return nil, io.EOF 123 } 124 return nil, err 125 } 126 127 return secondaryRow, nil 128 } 129 130 func (iter *rangeHeapJoinIter) Next(ctx *sql.Context) (sql.Row, error) { 131 for { 132 if err := iter.loadPrimary(ctx); err != nil { 133 return nil, err 134 } 135 136 primary := iter.primaryRow 137 secondary, err := iter.loadSecondary(ctx) 138 if err != nil { 139 if errors.Is(err, io.EOF) { 140 if !iter.foundMatch && iter.joinType.IsLeftOuter() { 141 iter.primaryRow = nil 142 row := iter.buildRow(primary, nil) 143 return iter.removeParentRow(row), nil 144 } 145 continue 146 } else if errors.Is(err, plan.ErrEmptyCachedResult) { 147 if !iter.foundMatch && iter.joinType.IsLeftOuter() { 148 iter.primaryRow = nil 149 row := iter.buildRow(primary, nil) 150 return iter.removeParentRow(row), nil 151 } 152 153 return nil, io.EOF 154 } 155 return nil, err 156 } 157 158 row := iter.buildRow(primary, secondary) 159 res, err := iter.cond.Eval(ctx, row) 160 matches := res == true 161 if err != nil { 162 return nil, err 163 } 164 165 if res == nil && iter.joinType.IsExcludeNulls() { 166 err = iter.secondary.Close(ctx) 167 iter.secondary = nil 168 if err != nil { 169 return nil, err 170 } 171 iter.primaryRow = nil 172 continue 173 } 174 175 if !matches { 176 continue 177 } 178 179 iter.foundMatch = true 180 return iter.removeParentRow(row), nil 181 } 182 } 183 184 func (iter *rangeHeapJoinIter) removeParentRow(r sql.Row) sql.Row { 185 copy(r[iter.scopeLen:], r[len(iter.parentRow):]) 186 r = r[:len(r)-len(iter.parentRow)+iter.scopeLen] 187 return r 188 } 189 190 // buildRow builds the result set row using the rows from the primary and secondary tables 191 func (iter *rangeHeapJoinIter) buildRow(primary, secondary sql.Row) sql.Row { 192 row := make(sql.Row, iter.rowSize) 193 194 copy(row, primary) 195 copy(row[len(primary):], secondary) 196 197 return row 198 } 199 200 func (iter *rangeHeapJoinIter) Close(ctx *sql.Context) (err error) { 201 if iter.primary != nil { 202 if err = iter.primary.Close(ctx); err != nil { 203 if iter.secondary != nil { 204 _ = iter.secondary.Close(ctx) 205 } 206 return err 207 } 208 } 209 210 if iter.secondary != nil { 211 err = iter.secondary.Close(ctx) 212 iter.secondary = nil 213 } 214 215 return err 216 } 217 218 func (iter *rangeHeapJoinIter) initializeHeap(ctx *sql.Context, builder sql.NodeExecBuilder, primaryRow sql.Row) (err error) { 219 iter.childRowIter, err = builder.Build(ctx, iter.rangeHeapPlan.Child, primaryRow) 220 if err != nil { 221 return err 222 } 223 iter.activeRanges = nil 224 iter.rangeHeapPlan.ComparisonType = iter.rangeHeapPlan.Schema()[iter.rangeHeapPlan.MaxColumnIndex].Type 225 226 iter.pendingRow, err = iter.childRowIter.Next(ctx) 227 if err == io.EOF { 228 iter.pendingRow = nil 229 return nil 230 } 231 return err 232 } 233 234 func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecBuilder, row sql.Row) (sql.RowIter, error) { 235 // Remove rows from the heap if we've advanced beyond their max value. 236 for iter.Len() > 0 { 237 maxValue := iter.Peek() 238 compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], maxValue) 239 if err != nil { 240 return nil, err 241 } 242 if (iter.rangeHeapPlan.RangeIsClosedAbove && compareResult > 0) || (!iter.rangeHeapPlan.RangeIsClosedAbove && compareResult >= 0) { 243 heap.Pop(iter) 244 if iter.err != nil { 245 err = iter.err 246 iter.err = nil 247 return nil, err 248 } 249 } else { 250 break 251 } 252 } 253 254 // Advance the child iterator until we encounter a row whose min value is beyond the range. 255 for iter.pendingRow != nil { 256 minValue := iter.pendingRow[iter.rangeHeapPlan.MinColumnIndex] 257 compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], minValue) 258 if err != nil { 259 return nil, err 260 } 261 262 if (iter.rangeHeapPlan.RangeIsClosedBelow && compareResult < 0) || (!iter.rangeHeapPlan.RangeIsClosedBelow && compareResult <= 0) { 263 break 264 } else { 265 heap.Push(iter, iter.pendingRow) 266 if iter.err != nil { 267 err = iter.err 268 iter.err = nil 269 return nil, err 270 } 271 } 272 273 iter.pendingRow, err = iter.childRowIter.Next(ctx) 274 if err != nil { 275 if errors.Is(err, io.EOF) { 276 // We've already imported every range into the priority queue. 277 iter.pendingRow = nil 278 break 279 } 280 return nil, err 281 } 282 } 283 284 // Every active row must match the accepted row. 285 return sql.RowsToRowIter(iter.activeRanges...), nil 286 } 287 288 // When managing the heap, consider all NULLs to come before any non-NULLS. 289 // This is consistent with the order received if either child node is an index. 290 // Note: We could get the same behavior by simply excluding values and ranges containing NULL, 291 // but this is forward compatible if we ever want to convert joins with null-safe conditions into RangeHeapJoins. 292 func compareNullsFirst(comparisonType sql.Type, a, b interface{}) (int, error) { 293 if a == nil { 294 if b == nil { 295 return 0, nil 296 } else { 297 return -1, nil 298 } 299 } 300 if b == nil { 301 return 1, nil 302 } 303 return comparisonType.Compare(a, b) 304 } 305 306 func (iter rangeHeapJoinIter) Len() int { return len(iter.activeRanges) } 307 308 func (iter *rangeHeapJoinIter) Less(i, j int) bool { 309 lhs := iter.activeRanges[i][iter.rangeHeapPlan.MaxColumnIndex] 310 rhs := iter.activeRanges[j][iter.rangeHeapPlan.MaxColumnIndex] 311 // compareResult will be 0 if lhs==rhs, -1 if lhs < rhs, and +1 if lhs > rhs. 312 compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, lhs, rhs) 313 if iter.err == nil && err != nil { 314 iter.err = err 315 } 316 return compareResult < 0 317 } 318 319 func (iter *rangeHeapJoinIter) Swap(i, j int) { 320 iter.activeRanges[i], iter.activeRanges[j] = iter.activeRanges[j], iter.activeRanges[i] 321 } 322 323 func (iter *rangeHeapJoinIter) Push(x any) { 324 item := x.(sql.Row) 325 iter.activeRanges = append(iter.activeRanges, item) 326 } 327 328 func (iter *rangeHeapJoinIter) Pop() any { 329 n := len(iter.activeRanges) 330 x := iter.activeRanges[n-1] 331 iter.activeRanges = iter.activeRanges[0 : n-1] 332 return x 333 } 334 335 func (iter *rangeHeapJoinIter) Peek() interface{} { 336 n := len(iter.activeRanges) 337 return iter.activeRanges[n-1][iter.rangeHeapPlan.MaxColumnIndex] 338 }