github.com/mithrandie/csvq@v1.18.1/lib/query/join.go (about)

     1  package query
     2  
     3  import (
     4  	"context"
     5  	"math"
     6  	"sync"
     7  
     8  	"github.com/mithrandie/csvq/lib/parser"
     9  	"github.com/mithrandie/csvq/lib/value"
    10  
    11  	"github.com/mithrandie/ternary"
    12  )
    13  
    14  func ParseJoinCondition(join parser.Join, view *View, joinView *View) (parser.QueryExpression, []parser.FieldReference, []parser.FieldReference, error) {
    15  	if join.Natural.IsEmpty() && join.Condition == nil {
    16  		return nil, nil, nil, nil
    17  	}
    18  
    19  	var using []parser.QueryExpression
    20  
    21  	if !join.Natural.IsEmpty() {
    22  		for _, field := range view.Header {
    23  			if field.Column == InternalIdColumn {
    24  				continue
    25  			}
    26  			ref := parser.FieldReference{BaseExpr: parser.NewBaseExpr(join.Natural), Column: parser.Identifier{Literal: field.Column}}
    27  			if _, err := joinView.Header.SearchIndex(ref); err != nil {
    28  				if err == errFieldAmbiguous {
    29  					return nil, nil, nil, NewFieldAmbiguousError(ref)
    30  				}
    31  				continue
    32  			}
    33  			using = append(using, parser.Identifier{BaseExpr: parser.NewBaseExpr(join.Natural), Literal: field.Column})
    34  		}
    35  	} else {
    36  		cond := join.Condition.(parser.JoinCondition)
    37  		if cond.On != nil {
    38  			return cond.On, nil, nil, nil
    39  		}
    40  
    41  		using = cond.Using
    42  	}
    43  
    44  	if len(using) < 1 {
    45  		return nil, nil, nil, nil
    46  	}
    47  
    48  	usingFields := make([]string, len(using))
    49  	for i, v := range using {
    50  		usingFields[i] = v.(parser.Identifier).Literal
    51  	}
    52  
    53  	includeFields := make([]parser.FieldReference, len(using))
    54  	excludeFields := make([]parser.FieldReference, len(using))
    55  
    56  	comps := make([]parser.Comparison, len(using))
    57  	for i, v := range using {
    58  		var lhs parser.FieldReference
    59  		var rhs parser.FieldReference
    60  		fieldref := parser.FieldReference{BaseExpr: v.GetBaseExpr(), Column: v.(parser.Identifier)}
    61  
    62  		lhsidx, err := view.FieldIndex(fieldref)
    63  		if err != nil {
    64  			return nil, nil, nil, err
    65  		}
    66  		lhs = parser.FieldReference{BaseExpr: v.GetBaseExpr(), View: parser.Identifier{Literal: view.Header[lhsidx].View}, Column: v.(parser.Identifier)}
    67  
    68  		rhsidx, err := joinView.FieldIndex(fieldref)
    69  		if err != nil {
    70  			return nil, nil, nil, err
    71  		}
    72  		rhs = parser.FieldReference{BaseExpr: v.GetBaseExpr(), View: parser.Identifier{Literal: joinView.Header[rhsidx].View}, Column: v.(parser.Identifier)}
    73  
    74  		comps[i] = parser.Comparison{
    75  			LHS:      lhs,
    76  			RHS:      rhs,
    77  			Operator: parser.Token{Token: parser.COMPARISON_OP, Literal: "="},
    78  		}
    79  
    80  		if join.Direction.Token == parser.RIGHT {
    81  			includeFields[i] = rhs
    82  			excludeFields[i] = lhs
    83  		} else {
    84  			includeFields[i] = lhs
    85  			excludeFields[i] = rhs
    86  		}
    87  	}
    88  
    89  	if len(comps) == 1 {
    90  		return comps[0], includeFields, excludeFields, nil
    91  	}
    92  
    93  	logic := parser.Logic{
    94  		LHS:      comps[0],
    95  		RHS:      comps[1],
    96  		Operator: parser.Token{Token: parser.AND, Literal: parser.TokenLiteral(parser.AND)},
    97  	}
    98  	for i := 2; i < len(comps); i++ {
    99  		logic = parser.Logic{
   100  			LHS:      logic,
   101  			RHS:      comps[i],
   102  			Operator: parser.Token{Token: parser.AND, Literal: parser.TokenLiteral(parser.AND)},
   103  		}
   104  	}
   105  	return logic, includeFields, excludeFields, nil
   106  }
   107  
   108  func CrossJoin(ctx context.Context, scope *ReferenceScope, view *View, joinView *View) error {
   109  	mergedHeader := view.Header.Merge(joinView.Header)
   110  	records := make(RecordSet, view.RecordLen()*joinView.RecordLen())
   111  
   112  	if err := NewGoroutineTaskManager(view.RecordLen(), CalcMinimumRequired(view.RecordLen(), joinView.RecordLen(), MinimumRequiredPerCPUCore), scope.Tx.Flags.CPU).Run(ctx, func(index int) error {
   113  		start := index * joinView.RecordLen()
   114  		for i := 0; i < joinView.RecordLen(); i++ {
   115  			records[start+i] = view.RecordSet[index].Merge(joinView.RecordSet[i], nil)
   116  		}
   117  		return nil
   118  	}); err != nil {
   119  		return err
   120  	}
   121  
   122  	view.Header = mergedHeader
   123  	view.RecordSet = records
   124  	view.FileInfo = nil
   125  	return nil
   126  }
   127  
   128  func InnerJoin(ctx context.Context, scope *ReferenceScope, view *View, joinView *View, condition parser.QueryExpression) error {
   129  	if condition == nil {
   130  		return CrossJoin(ctx, scope, view, joinView)
   131  	}
   132  
   133  	var recordPool = &sync.Pool{
   134  		New: func() interface{} {
   135  			return make(Record, view.FieldLen()+joinView.FieldLen())
   136  		},
   137  	}
   138  
   139  	mergedHeader := view.Header.Merge(joinView.Header)
   140  
   141  	gm := NewGoroutineTaskManager(view.RecordLen(), CalcMinimumRequired(view.RecordLen(), joinView.RecordLen(), MinimumRequiredPerCPUCore), scope.Tx.Flags.CPU)
   142  	recordsList := make([]RecordSet, gm.Number)
   143  
   144  	var joinFn = func(thIdx int) {
   145  		defer func() {
   146  			if !gm.HasError() {
   147  				if panicReport := recover(); panicReport != nil {
   148  					gm.SetError(NewFatalError(panicReport))
   149  				}
   150  			}
   151  
   152  			if 1 < gm.Number {
   153  				gm.Done()
   154  			}
   155  		}()
   156  
   157  		ctx := ctx
   158  		start, end := gm.RecordRange(thIdx)
   159  		records := make(RecordSet, 0, end-start)
   160  		seqScope := scope.CreateScopeForRecordEvaluation(
   161  			&View{
   162  				Header:    mergedHeader,
   163  				RecordSet: make(RecordSet, 1),
   164  			},
   165  			0,
   166  		)
   167  
   168  	InnerJoinLoop:
   169  		for i := start; i < end; i++ {
   170  			for j := 0; j < joinView.RecordLen(); j++ {
   171  				if gm.HasError() {
   172  					break InnerJoinLoop
   173  				}
   174  				if i&15 == 0 && ctx.Err() != nil {
   175  					break InnerJoinLoop
   176  				}
   177  
   178  				mergedRecord := view.RecordSet[i].Merge(joinView.RecordSet[j], recordPool)
   179  				seqScope.Records[0].view.RecordSet[0] = mergedRecord
   180  
   181  				primary, e := Evaluate(ctx, seqScope, condition)
   182  				if e != nil {
   183  					gm.SetError(e)
   184  					break InnerJoinLoop
   185  				}
   186  				if primary.Ternary() == ternary.TRUE {
   187  					records = append(records, mergedRecord)
   188  				} else {
   189  					for i := range mergedRecord {
   190  						mergedRecord[i] = nil
   191  					}
   192  					recordPool.Put(mergedRecord)
   193  				}
   194  			}
   195  		}
   196  
   197  		recordsList[thIdx] = records
   198  	}
   199  
   200  	if 1 < gm.Number {
   201  		for i := 0; i < gm.Number; i++ {
   202  			gm.Add()
   203  			go joinFn(i)
   204  		}
   205  		gm.Wait()
   206  	} else {
   207  		joinFn(0)
   208  	}
   209  
   210  	if gm.HasError() {
   211  		return gm.Err()
   212  	}
   213  	if ctx.Err() != nil {
   214  		return ConvertContextError(ctx.Err())
   215  	}
   216  
   217  	view.Header = mergedHeader
   218  	view.RecordSet = MergeRecordSetList(recordsList)
   219  	view.FileInfo = nil
   220  	return nil
   221  }
   222  
   223  func OuterJoin(ctx context.Context, scope *ReferenceScope, view *View, joinView *View, condition parser.QueryExpression, direction int) error {
   224  	if direction == parser.TokenUndefined {
   225  		direction = parser.LEFT
   226  	}
   227  
   228  	var recordPool = &sync.Pool{
   229  		New: func() interface{} {
   230  			return make(Record, view.FieldLen()+joinView.FieldLen())
   231  		},
   232  	}
   233  
   234  	mergedHeader := view.Header.Merge(joinView.Header)
   235  
   236  	if direction == parser.RIGHT {
   237  		view, joinView = joinView, view
   238  	}
   239  
   240  	gm := NewGoroutineTaskManager(view.RecordLen(), CalcMinimumRequired(view.RecordLen(), joinView.RecordLen(), MinimumRequiredPerCPUCore), scope.Tx.Flags.CPU)
   241  
   242  	recordsList := make([]RecordSet, gm.Number+1)
   243  	joinViewMatchesList := make([][]bool, gm.Number)
   244  
   245  	var joinFn = func(thIdx int) {
   246  		defer func() {
   247  			if !gm.HasError() {
   248  				if panicReport := recover(); panicReport != nil {
   249  					gm.SetError(NewFatalError(panicReport))
   250  				}
   251  			}
   252  
   253  			if 1 < gm.Number {
   254  				gm.Done()
   255  			}
   256  		}()
   257  
   258  		ctx := ctx
   259  		start, end := gm.RecordRange(thIdx)
   260  		records := make(RecordSet, 0, end-start)
   261  		seqScope := scope.CreateScopeForRecordEvaluation(
   262  			&View{
   263  				Header:    mergedHeader,
   264  				RecordSet: make(RecordSet, 1),
   265  			},
   266  			0,
   267  		)
   268  
   269  		joinViewMatches := make([]bool, joinView.RecordLen())
   270  		var leftViewFieldLen int
   271  		if direction == parser.RIGHT {
   272  			leftViewFieldLen = joinView.FieldLen()
   273  		} else {
   274  			leftViewFieldLen = view.FieldLen()
   275  		}
   276  
   277  	OuterJoinLoop:
   278  		for i := start; i < end; i++ {
   279  			match := false
   280  			for j := 0; j < joinView.RecordLen(); j++ {
   281  				if gm.HasError() {
   282  					break OuterJoinLoop
   283  				}
   284  				if i&15 == 0 && ctx.Err() != nil {
   285  					break OuterJoinLoop
   286  				}
   287  
   288  				var mergedRecord Record
   289  				switch direction {
   290  				case parser.RIGHT:
   291  					mergedRecord = joinView.RecordSet[j].Merge(view.RecordSet[i], recordPool)
   292  				default:
   293  					mergedRecord = view.RecordSet[i].Merge(joinView.RecordSet[j], recordPool)
   294  				}
   295  				seqScope.Records[0].view.RecordSet[0] = mergedRecord
   296  
   297  				primary, e := Evaluate(ctx, seqScope, condition)
   298  				if e != nil {
   299  					gm.SetError(e)
   300  					break OuterJoinLoop
   301  				}
   302  				if primary.Ternary() == ternary.TRUE {
   303  					if direction == parser.FULL && !joinViewMatches[j] {
   304  						joinViewMatches[j] = true
   305  					}
   306  					records = append(records, mergedRecord)
   307  					match = true
   308  				} else {
   309  					for i := range mergedRecord {
   310  						mergedRecord[i] = nil
   311  					}
   312  					recordPool.Put(mergedRecord)
   313  				}
   314  			}
   315  
   316  			if !match {
   317  				record := recordPool.Get().(Record)
   318  				switch direction {
   319  				case parser.RIGHT:
   320  					for k := 0; k < leftViewFieldLen; k++ {
   321  						record[k] = NewCell(value.NewNull())
   322  					}
   323  					for k := range view.RecordSet[i] {
   324  						record[k+leftViewFieldLen] = view.RecordSet[i][k]
   325  					}
   326  				default:
   327  					for k := range view.RecordSet[i] {
   328  						record[k] = view.RecordSet[i][k]
   329  					}
   330  					for k := 0; k < joinView.FieldLen(); k++ {
   331  						record[k+leftViewFieldLen] = NewCell(value.NewNull())
   332  					}
   333  				}
   334  				records = append(records, record)
   335  
   336  			}
   337  		}
   338  
   339  		recordsList[thIdx] = records
   340  		joinViewMatchesList[thIdx] = joinViewMatches
   341  	}
   342  
   343  	if 1 < gm.Number {
   344  		for i := 0; i < gm.Number; i++ {
   345  			gm.Add()
   346  			go joinFn(i)
   347  		}
   348  		gm.Wait()
   349  	} else {
   350  		joinFn(0)
   351  	}
   352  
   353  	if gm.HasError() {
   354  		return gm.Err()
   355  	}
   356  	if ctx.Err() != nil {
   357  		return ConvertContextError(ctx.Err())
   358  	}
   359  
   360  	if direction == parser.FULL {
   361  		appendIndices := make([]int, 0, joinView.RecordLen())
   362  
   363  		for i := 0; i < joinView.RecordLen(); i++ {
   364  			match := false
   365  			for _, joinViewMatches := range joinViewMatchesList {
   366  				if joinViewMatches[i] {
   367  					match = true
   368  					break
   369  				}
   370  			}
   371  			if !match {
   372  				appendIndices = append(appendIndices, i)
   373  			}
   374  		}
   375  
   376  		recordsListIdx := len(recordsList) - 1
   377  		recordsList[recordsListIdx] = make(RecordSet, len(appendIndices))
   378  		viewFieldLen := view.FieldLen()
   379  		for i, idx := range appendIndices {
   380  			record := recordPool.Get().(Record)
   381  			for k := 0; k < viewFieldLen; k++ {
   382  				record[k] = NewCell(value.NewNull())
   383  			}
   384  			for k := range joinView.RecordSet[idx] {
   385  				record[k+viewFieldLen] = joinView.RecordSet[idx][k]
   386  			}
   387  			recordsList[recordsListIdx][i] = record
   388  
   389  		}
   390  	}
   391  
   392  	if direction == parser.RIGHT {
   393  		view, joinView = joinView, view
   394  	}
   395  
   396  	view.Header = mergedHeader
   397  	view.RecordSet = MergeRecordSetList(recordsList)
   398  	view.FileInfo = nil
   399  	return nil
   400  }
   401  
   402  func CalcMinimumRequired(i1 int, i2 int, defaultMinimumRequired int) int {
   403  	if i1 < 1 || i2 < 1 {
   404  		return defaultMinimumRequired
   405  	}
   406  
   407  	p := i1 * i2
   408  	if p <= defaultMinimumRequired {
   409  		return defaultMinimumRequired
   410  	}
   411  	return int(math.Ceil(float64(i1) / math.Floor(float64(p)/float64(defaultMinimumRequired))))
   412  }