github.com/mithrandie/csvq@v1.18.1/lib/query/join.go (about) 1 package query 2 3 import ( 4 "context" 5 "math" 6 "sync" 7 8 "github.com/mithrandie/csvq/lib/parser" 9 "github.com/mithrandie/csvq/lib/value" 10 11 "github.com/mithrandie/ternary" 12 ) 13 14 func ParseJoinCondition(join parser.Join, view *View, joinView *View) (parser.QueryExpression, []parser.FieldReference, []parser.FieldReference, error) { 15 if join.Natural.IsEmpty() && join.Condition == nil { 16 return nil, nil, nil, nil 17 } 18 19 var using []parser.QueryExpression 20 21 if !join.Natural.IsEmpty() { 22 for _, field := range view.Header { 23 if field.Column == InternalIdColumn { 24 continue 25 } 26 ref := parser.FieldReference{BaseExpr: parser.NewBaseExpr(join.Natural), Column: parser.Identifier{Literal: field.Column}} 27 if _, err := joinView.Header.SearchIndex(ref); err != nil { 28 if err == errFieldAmbiguous { 29 return nil, nil, nil, NewFieldAmbiguousError(ref) 30 } 31 continue 32 } 33 using = append(using, parser.Identifier{BaseExpr: parser.NewBaseExpr(join.Natural), Literal: field.Column}) 34 } 35 } else { 36 cond := join.Condition.(parser.JoinCondition) 37 if cond.On != nil { 38 return cond.On, nil, nil, nil 39 } 40 41 using = cond.Using 42 } 43 44 if len(using) < 1 { 45 return nil, nil, nil, nil 46 } 47 48 usingFields := make([]string, len(using)) 49 for i, v := range using { 50 usingFields[i] = v.(parser.Identifier).Literal 51 } 52 53 includeFields := make([]parser.FieldReference, len(using)) 54 excludeFields := make([]parser.FieldReference, len(using)) 55 56 comps := make([]parser.Comparison, len(using)) 57 for i, v := range using { 58 var lhs parser.FieldReference 59 var rhs parser.FieldReference 60 fieldref := parser.FieldReference{BaseExpr: v.GetBaseExpr(), Column: v.(parser.Identifier)} 61 62 lhsidx, err := view.FieldIndex(fieldref) 63 if err != nil { 64 return nil, nil, nil, err 65 } 66 lhs = parser.FieldReference{BaseExpr: v.GetBaseExpr(), View: parser.Identifier{Literal: view.Header[lhsidx].View}, Column: v.(parser.Identifier)} 67 68 rhsidx, err := joinView.FieldIndex(fieldref) 69 if err != nil { 70 return nil, nil, nil, err 71 } 72 rhs = parser.FieldReference{BaseExpr: v.GetBaseExpr(), View: parser.Identifier{Literal: joinView.Header[rhsidx].View}, Column: v.(parser.Identifier)} 73 74 comps[i] = parser.Comparison{ 75 LHS: lhs, 76 RHS: rhs, 77 Operator: parser.Token{Token: parser.COMPARISON_OP, Literal: "="}, 78 } 79 80 if join.Direction.Token == parser.RIGHT { 81 includeFields[i] = rhs 82 excludeFields[i] = lhs 83 } else { 84 includeFields[i] = lhs 85 excludeFields[i] = rhs 86 } 87 } 88 89 if len(comps) == 1 { 90 return comps[0], includeFields, excludeFields, nil 91 } 92 93 logic := parser.Logic{ 94 LHS: comps[0], 95 RHS: comps[1], 96 Operator: parser.Token{Token: parser.AND, Literal: parser.TokenLiteral(parser.AND)}, 97 } 98 for i := 2; i < len(comps); i++ { 99 logic = parser.Logic{ 100 LHS: logic, 101 RHS: comps[i], 102 Operator: parser.Token{Token: parser.AND, Literal: parser.TokenLiteral(parser.AND)}, 103 } 104 } 105 return logic, includeFields, excludeFields, nil 106 } 107 108 func CrossJoin(ctx context.Context, scope *ReferenceScope, view *View, joinView *View) error { 109 mergedHeader := view.Header.Merge(joinView.Header) 110 records := make(RecordSet, view.RecordLen()*joinView.RecordLen()) 111 112 if err := NewGoroutineTaskManager(view.RecordLen(), CalcMinimumRequired(view.RecordLen(), joinView.RecordLen(), MinimumRequiredPerCPUCore), scope.Tx.Flags.CPU).Run(ctx, func(index int) error { 113 start := index * joinView.RecordLen() 114 for i := 0; i < joinView.RecordLen(); i++ { 115 records[start+i] = view.RecordSet[index].Merge(joinView.RecordSet[i], nil) 116 } 117 return nil 118 }); err != nil { 119 return err 120 } 121 122 view.Header = mergedHeader 123 view.RecordSet = records 124 view.FileInfo = nil 125 return nil 126 } 127 128 func InnerJoin(ctx context.Context, scope *ReferenceScope, view *View, joinView *View, condition parser.QueryExpression) error { 129 if condition == nil { 130 return CrossJoin(ctx, scope, view, joinView) 131 } 132 133 var recordPool = &sync.Pool{ 134 New: func() interface{} { 135 return make(Record, view.FieldLen()+joinView.FieldLen()) 136 }, 137 } 138 139 mergedHeader := view.Header.Merge(joinView.Header) 140 141 gm := NewGoroutineTaskManager(view.RecordLen(), CalcMinimumRequired(view.RecordLen(), joinView.RecordLen(), MinimumRequiredPerCPUCore), scope.Tx.Flags.CPU) 142 recordsList := make([]RecordSet, gm.Number) 143 144 var joinFn = func(thIdx int) { 145 defer func() { 146 if !gm.HasError() { 147 if panicReport := recover(); panicReport != nil { 148 gm.SetError(NewFatalError(panicReport)) 149 } 150 } 151 152 if 1 < gm.Number { 153 gm.Done() 154 } 155 }() 156 157 ctx := ctx 158 start, end := gm.RecordRange(thIdx) 159 records := make(RecordSet, 0, end-start) 160 seqScope := scope.CreateScopeForRecordEvaluation( 161 &View{ 162 Header: mergedHeader, 163 RecordSet: make(RecordSet, 1), 164 }, 165 0, 166 ) 167 168 InnerJoinLoop: 169 for i := start; i < end; i++ { 170 for j := 0; j < joinView.RecordLen(); j++ { 171 if gm.HasError() { 172 break InnerJoinLoop 173 } 174 if i&15 == 0 && ctx.Err() != nil { 175 break InnerJoinLoop 176 } 177 178 mergedRecord := view.RecordSet[i].Merge(joinView.RecordSet[j], recordPool) 179 seqScope.Records[0].view.RecordSet[0] = mergedRecord 180 181 primary, e := Evaluate(ctx, seqScope, condition) 182 if e != nil { 183 gm.SetError(e) 184 break InnerJoinLoop 185 } 186 if primary.Ternary() == ternary.TRUE { 187 records = append(records, mergedRecord) 188 } else { 189 for i := range mergedRecord { 190 mergedRecord[i] = nil 191 } 192 recordPool.Put(mergedRecord) 193 } 194 } 195 } 196 197 recordsList[thIdx] = records 198 } 199 200 if 1 < gm.Number { 201 for i := 0; i < gm.Number; i++ { 202 gm.Add() 203 go joinFn(i) 204 } 205 gm.Wait() 206 } else { 207 joinFn(0) 208 } 209 210 if gm.HasError() { 211 return gm.Err() 212 } 213 if ctx.Err() != nil { 214 return ConvertContextError(ctx.Err()) 215 } 216 217 view.Header = mergedHeader 218 view.RecordSet = MergeRecordSetList(recordsList) 219 view.FileInfo = nil 220 return nil 221 } 222 223 func OuterJoin(ctx context.Context, scope *ReferenceScope, view *View, joinView *View, condition parser.QueryExpression, direction int) error { 224 if direction == parser.TokenUndefined { 225 direction = parser.LEFT 226 } 227 228 var recordPool = &sync.Pool{ 229 New: func() interface{} { 230 return make(Record, view.FieldLen()+joinView.FieldLen()) 231 }, 232 } 233 234 mergedHeader := view.Header.Merge(joinView.Header) 235 236 if direction == parser.RIGHT { 237 view, joinView = joinView, view 238 } 239 240 gm := NewGoroutineTaskManager(view.RecordLen(), CalcMinimumRequired(view.RecordLen(), joinView.RecordLen(), MinimumRequiredPerCPUCore), scope.Tx.Flags.CPU) 241 242 recordsList := make([]RecordSet, gm.Number+1) 243 joinViewMatchesList := make([][]bool, gm.Number) 244 245 var joinFn = func(thIdx int) { 246 defer func() { 247 if !gm.HasError() { 248 if panicReport := recover(); panicReport != nil { 249 gm.SetError(NewFatalError(panicReport)) 250 } 251 } 252 253 if 1 < gm.Number { 254 gm.Done() 255 } 256 }() 257 258 ctx := ctx 259 start, end := gm.RecordRange(thIdx) 260 records := make(RecordSet, 0, end-start) 261 seqScope := scope.CreateScopeForRecordEvaluation( 262 &View{ 263 Header: mergedHeader, 264 RecordSet: make(RecordSet, 1), 265 }, 266 0, 267 ) 268 269 joinViewMatches := make([]bool, joinView.RecordLen()) 270 var leftViewFieldLen int 271 if direction == parser.RIGHT { 272 leftViewFieldLen = joinView.FieldLen() 273 } else { 274 leftViewFieldLen = view.FieldLen() 275 } 276 277 OuterJoinLoop: 278 for i := start; i < end; i++ { 279 match := false 280 for j := 0; j < joinView.RecordLen(); j++ { 281 if gm.HasError() { 282 break OuterJoinLoop 283 } 284 if i&15 == 0 && ctx.Err() != nil { 285 break OuterJoinLoop 286 } 287 288 var mergedRecord Record 289 switch direction { 290 case parser.RIGHT: 291 mergedRecord = joinView.RecordSet[j].Merge(view.RecordSet[i], recordPool) 292 default: 293 mergedRecord = view.RecordSet[i].Merge(joinView.RecordSet[j], recordPool) 294 } 295 seqScope.Records[0].view.RecordSet[0] = mergedRecord 296 297 primary, e := Evaluate(ctx, seqScope, condition) 298 if e != nil { 299 gm.SetError(e) 300 break OuterJoinLoop 301 } 302 if primary.Ternary() == ternary.TRUE { 303 if direction == parser.FULL && !joinViewMatches[j] { 304 joinViewMatches[j] = true 305 } 306 records = append(records, mergedRecord) 307 match = true 308 } else { 309 for i := range mergedRecord { 310 mergedRecord[i] = nil 311 } 312 recordPool.Put(mergedRecord) 313 } 314 } 315 316 if !match { 317 record := recordPool.Get().(Record) 318 switch direction { 319 case parser.RIGHT: 320 for k := 0; k < leftViewFieldLen; k++ { 321 record[k] = NewCell(value.NewNull()) 322 } 323 for k := range view.RecordSet[i] { 324 record[k+leftViewFieldLen] = view.RecordSet[i][k] 325 } 326 default: 327 for k := range view.RecordSet[i] { 328 record[k] = view.RecordSet[i][k] 329 } 330 for k := 0; k < joinView.FieldLen(); k++ { 331 record[k+leftViewFieldLen] = NewCell(value.NewNull()) 332 } 333 } 334 records = append(records, record) 335 336 } 337 } 338 339 recordsList[thIdx] = records 340 joinViewMatchesList[thIdx] = joinViewMatches 341 } 342 343 if 1 < gm.Number { 344 for i := 0; i < gm.Number; i++ { 345 gm.Add() 346 go joinFn(i) 347 } 348 gm.Wait() 349 } else { 350 joinFn(0) 351 } 352 353 if gm.HasError() { 354 return gm.Err() 355 } 356 if ctx.Err() != nil { 357 return ConvertContextError(ctx.Err()) 358 } 359 360 if direction == parser.FULL { 361 appendIndices := make([]int, 0, joinView.RecordLen()) 362 363 for i := 0; i < joinView.RecordLen(); i++ { 364 match := false 365 for _, joinViewMatches := range joinViewMatchesList { 366 if joinViewMatches[i] { 367 match = true 368 break 369 } 370 } 371 if !match { 372 appendIndices = append(appendIndices, i) 373 } 374 } 375 376 recordsListIdx := len(recordsList) - 1 377 recordsList[recordsListIdx] = make(RecordSet, len(appendIndices)) 378 viewFieldLen := view.FieldLen() 379 for i, idx := range appendIndices { 380 record := recordPool.Get().(Record) 381 for k := 0; k < viewFieldLen; k++ { 382 record[k] = NewCell(value.NewNull()) 383 } 384 for k := range joinView.RecordSet[idx] { 385 record[k+viewFieldLen] = joinView.RecordSet[idx][k] 386 } 387 recordsList[recordsListIdx][i] = record 388 389 } 390 } 391 392 if direction == parser.RIGHT { 393 view, joinView = joinView, view 394 } 395 396 view.Header = mergedHeader 397 view.RecordSet = MergeRecordSetList(recordsList) 398 view.FileInfo = nil 399 return nil 400 } 401 402 func CalcMinimumRequired(i1 int, i2 int, defaultMinimumRequired int) int { 403 if i1 < 1 || i2 < 1 { 404 return defaultMinimumRequired 405 } 406 407 p := i1 * i2 408 if p <= defaultMinimumRequired { 409 return defaultMinimumRequired 410 } 411 return int(math.Ceil(float64(i1) / math.Floor(float64(p)/float64(defaultMinimumRequired)))) 412 }