github.com/apache/arrow/go/v14@v14.0.1/arrow/compute/utils.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one
     2  // or more contributor license agreements.  See the NOTICE file
     3  // distributed with this work for additional information
     4  // regarding copyright ownership.  The ASF licenses this file
     5  // to you under the Apache License, Version 2.0 (the
     6  // "License"); you may not use this file except in compliance
     7  // with the License.  You may obtain a copy of the License at
     8  //
     9  // http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  //go:build go1.18
    18  
    19  package compute
    20  
    21  import (
    22  	"fmt"
    23  	"io"
    24  	"math"
    25  	"time"
    26  
    27  	"github.com/apache/arrow/go/v14/arrow"
    28  	"github.com/apache/arrow/go/v14/arrow/bitutil"
    29  	"github.com/apache/arrow/go/v14/arrow/compute/exec"
    30  	"github.com/apache/arrow/go/v14/arrow/compute/internal/kernels"
    31  	"github.com/apache/arrow/go/v14/arrow/internal/debug"
    32  	"github.com/apache/arrow/go/v14/arrow/memory"
    33  	"golang.org/x/xerrors"
    34  )
    35  
    36  type bufferWriteSeeker struct {
    37  	buf *memory.Buffer
    38  	pos int
    39  	mem memory.Allocator
    40  }
    41  
    42  func (b *bufferWriteSeeker) Reserve(nbytes int) {
    43  	if b.buf == nil {
    44  		b.buf = memory.NewResizableBuffer(b.mem)
    45  	}
    46  	newCap := int(math.Max(float64(b.buf.Cap()), 256))
    47  	for newCap < b.pos+nbytes {
    48  		newCap = bitutil.NextPowerOf2(newCap)
    49  	}
    50  	b.buf.Reserve(newCap)
    51  }
    52  
    53  func (b *bufferWriteSeeker) Write(p []byte) (n int, err error) {
    54  	if len(p) == 0 {
    55  		return 0, nil
    56  	}
    57  
    58  	if b.buf == nil {
    59  		b.Reserve(len(p))
    60  	} else if b.pos+len(p) >= b.buf.Cap() {
    61  		b.Reserve(len(p))
    62  	}
    63  
    64  	return b.UnsafeWrite(p)
    65  }
    66  
    67  func (b *bufferWriteSeeker) UnsafeWrite(p []byte) (n int, err error) {
    68  	n = copy(b.buf.Buf()[b.pos:], p)
    69  	b.pos += len(p)
    70  	if b.pos > b.buf.Len() {
    71  		b.buf.ResizeNoShrink(b.pos)
    72  	}
    73  	return
    74  }
    75  
    76  func (b *bufferWriteSeeker) Seek(offset int64, whence int) (int64, error) {
    77  	newpos, offs := 0, int(offset)
    78  	switch whence {
    79  	case io.SeekStart:
    80  		newpos = offs
    81  	case io.SeekCurrent:
    82  		newpos = b.pos + offs
    83  	case io.SeekEnd:
    84  		newpos = b.buf.Len() + offs
    85  	}
    86  	if newpos < 0 {
    87  		return 0, xerrors.New("negative result pos")
    88  	}
    89  	b.pos = newpos
    90  	return int64(newpos), nil
    91  }
    92  
    93  // ensureDictionaryDecoded is used by DispatchBest to determine
    94  // the proper types for promotion. Casting is then performed by
    95  // the executor before continuing execution: see the implementation
    96  // of execInternal in exec.go after calling DispatchBest.
    97  //
    98  // That casting is where actual decoding would be performed for
    99  // the dictionary
   100  func ensureDictionaryDecoded(vals ...arrow.DataType) {
   101  	for i, v := range vals {
   102  		if v.ID() == arrow.DICTIONARY {
   103  			vals[i] = v.(*arrow.DictionaryType).ValueType
   104  		}
   105  	}
   106  }
   107  
   108  func replaceNullWithOtherType(vals ...arrow.DataType) {
   109  	debug.Assert(len(vals) == 2, "should be length 2")
   110  
   111  	if vals[0].ID() == arrow.NULL {
   112  		vals[0] = vals[1]
   113  		return
   114  	}
   115  
   116  	if vals[1].ID() == arrow.NULL {
   117  		vals[1] = vals[0]
   118  		return
   119  	}
   120  }
   121  
   122  func commonTemporalResolution(vals ...arrow.DataType) (arrow.TimeUnit, bool) {
   123  	isTimeUnit := false
   124  	finestUnit := arrow.Second
   125  	for _, v := range vals {
   126  		switch dt := v.(type) {
   127  		case *arrow.Date32Type:
   128  			isTimeUnit = true
   129  			continue
   130  		case *arrow.Date64Type:
   131  			finestUnit = exec.Max(finestUnit, arrow.Millisecond)
   132  			isTimeUnit = true
   133  		case arrow.TemporalWithUnit:
   134  			finestUnit = exec.Max(finestUnit, dt.TimeUnit())
   135  			isTimeUnit = true
   136  		default:
   137  			continue
   138  		}
   139  	}
   140  	return finestUnit, isTimeUnit
   141  }
   142  
   143  func replaceTemporalTypes(unit arrow.TimeUnit, vals ...arrow.DataType) {
   144  	for i, v := range vals {
   145  		switch dt := v.(type) {
   146  		case *arrow.TimestampType:
   147  			dt.Unit = unit
   148  			vals[i] = dt
   149  		case *arrow.Time32Type, *arrow.Time64Type:
   150  			if unit > arrow.Millisecond {
   151  				vals[i] = &arrow.Time64Type{Unit: unit}
   152  			} else {
   153  				vals[i] = &arrow.Time32Type{Unit: unit}
   154  			}
   155  		case *arrow.DurationType:
   156  			dt.Unit = unit
   157  			vals[i] = dt
   158  		case *arrow.Date32Type, *arrow.Date64Type:
   159  			vals[i] = &arrow.TimestampType{Unit: unit}
   160  		}
   161  	}
   162  }
   163  
   164  func replaceTypes(replacement arrow.DataType, vals ...arrow.DataType) {
   165  	for i := range vals {
   166  		vals[i] = replacement
   167  	}
   168  }
   169  
   170  func commonNumeric(vals ...arrow.DataType) arrow.DataType {
   171  	for _, v := range vals {
   172  		if !arrow.IsFloating(v.ID()) && !arrow.IsInteger(v.ID()) {
   173  			// a common numeric type is only possible if all are numeric
   174  			return nil
   175  		}
   176  		if v.ID() == arrow.FLOAT16 {
   177  			// float16 arithmetic is not currently supported
   178  			return nil
   179  		}
   180  	}
   181  
   182  	for _, v := range vals {
   183  		if v.ID() == arrow.FLOAT64 {
   184  			return arrow.PrimitiveTypes.Float64
   185  		}
   186  	}
   187  
   188  	for _, v := range vals {
   189  		if v.ID() == arrow.FLOAT32 {
   190  			return arrow.PrimitiveTypes.Float32
   191  		}
   192  	}
   193  
   194  	maxWidthSigned, maxWidthUnsigned := 0, 0
   195  	for _, v := range vals {
   196  		if arrow.IsUnsignedInteger(v.ID()) {
   197  			maxWidthUnsigned = exec.Max(v.(arrow.FixedWidthDataType).BitWidth(), maxWidthUnsigned)
   198  		} else {
   199  			maxWidthSigned = exec.Max(v.(arrow.FixedWidthDataType).BitWidth(), maxWidthSigned)
   200  		}
   201  	}
   202  
   203  	if maxWidthSigned == 0 {
   204  		switch {
   205  		case maxWidthUnsigned >= 64:
   206  			return arrow.PrimitiveTypes.Uint64
   207  		case maxWidthUnsigned == 32:
   208  			return arrow.PrimitiveTypes.Uint32
   209  		case maxWidthUnsigned == 16:
   210  			return arrow.PrimitiveTypes.Uint16
   211  		default:
   212  			debug.Assert(maxWidthUnsigned == 8, "bad maxWidthUnsigned")
   213  			return arrow.PrimitiveTypes.Uint8
   214  		}
   215  	}
   216  
   217  	if maxWidthSigned <= maxWidthUnsigned {
   218  		maxWidthSigned = bitutil.NextPowerOf2(maxWidthUnsigned + 1)
   219  	}
   220  
   221  	switch {
   222  	case maxWidthSigned >= 64:
   223  		return arrow.PrimitiveTypes.Int64
   224  	case maxWidthSigned == 32:
   225  		return arrow.PrimitiveTypes.Int32
   226  	case maxWidthSigned == 16:
   227  		return arrow.PrimitiveTypes.Int16
   228  	default:
   229  		debug.Assert(maxWidthSigned == 8, "bad maxWidthSigned")
   230  		return arrow.PrimitiveTypes.Int8
   231  	}
   232  }
   233  
   234  func hasDecimal(vals ...arrow.DataType) bool {
   235  	for _, v := range vals {
   236  		if arrow.IsDecimal(v.ID()) {
   237  			return true
   238  		}
   239  	}
   240  
   241  	return false
   242  }
   243  
   244  type decimalPromotion uint8
   245  
   246  const (
   247  	decPromoteNone decimalPromotion = iota
   248  	decPromoteAdd
   249  	decPromoteMultiply
   250  	decPromoteDivide
   251  )
   252  
   253  func castBinaryDecimalArgs(promote decimalPromotion, vals ...arrow.DataType) error {
   254  	left, right := vals[0], vals[1]
   255  	debug.Assert(arrow.IsDecimal(left.ID()) || arrow.IsDecimal(right.ID()), "at least one of the types should be decimal")
   256  
   257  	// decimal + float = float
   258  	if arrow.IsFloating(left.ID()) {
   259  		vals[1] = vals[0]
   260  		return nil
   261  	} else if arrow.IsFloating(right.ID()) {
   262  		vals[0] = vals[1]
   263  		return nil
   264  	}
   265  
   266  	var prec1, scale1, prec2, scale2 int32
   267  	var err error
   268  	// decimal + integer = decimal
   269  	if arrow.IsDecimal(left.ID()) {
   270  		dec := left.(arrow.DecimalType)
   271  		prec1, scale1 = dec.GetPrecision(), dec.GetScale()
   272  	} else {
   273  		debug.Assert(arrow.IsInteger(left.ID()), "floats were already handled, this should be an int")
   274  		if prec1, err = kernels.MaxDecimalDigitsForInt(left.ID()); err != nil {
   275  			return err
   276  		}
   277  	}
   278  	if arrow.IsDecimal(right.ID()) {
   279  		dec := right.(arrow.DecimalType)
   280  		prec2, scale2 = dec.GetPrecision(), dec.GetScale()
   281  	} else {
   282  		debug.Assert(arrow.IsInteger(right.ID()), "float already handled, should be ints")
   283  		if prec2, err = kernels.MaxDecimalDigitsForInt(right.ID()); err != nil {
   284  			return err
   285  		}
   286  	}
   287  
   288  	if scale1 < 0 || scale2 < 0 {
   289  		return fmt.Errorf("%w: decimals with negative scales not supported", arrow.ErrNotImplemented)
   290  	}
   291  
   292  	// decimal128 + decimal256 = decimal256
   293  	castedID := arrow.DECIMAL128
   294  	if left.ID() == arrow.DECIMAL256 || right.ID() == arrow.DECIMAL256 {
   295  		castedID = arrow.DECIMAL256
   296  	}
   297  
   298  	// decimal promotion rules compatible with amazon redshift
   299  	// https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
   300  	var leftScaleup, rightScaleup int32
   301  
   302  	switch promote {
   303  	case decPromoteAdd:
   304  		leftScaleup = exec.Max(scale1, scale2) - scale1
   305  		rightScaleup = exec.Max(scale1, scale2) - scale2
   306  	case decPromoteMultiply:
   307  	case decPromoteDivide:
   308  		leftScaleup = exec.Max(4, scale1+prec2-scale2+1) + scale2 - scale1
   309  	default:
   310  		debug.Assert(false, fmt.Sprintf("invalid DecimalPromotion value %d", promote))
   311  	}
   312  
   313  	vals[0], err = arrow.NewDecimalType(castedID, prec1+leftScaleup, scale1+leftScaleup)
   314  	if err != nil {
   315  		return err
   316  	}
   317  	vals[1], err = arrow.NewDecimalType(castedID, prec2+rightScaleup, scale2+rightScaleup)
   318  	return err
   319  }
   320  
   321  func commonTemporal(vals ...arrow.DataType) arrow.DataType {
   322  	var (
   323  		finestUnit           = arrow.Second
   324  		zone                 *string
   325  		loc                  *time.Location
   326  		sawDate32, sawDate64 bool
   327  	)
   328  
   329  	for _, ty := range vals {
   330  		switch ty.ID() {
   331  		case arrow.DATE32:
   332  			// date32's unit is days, but the coarsest we have is seconds
   333  			sawDate32 = true
   334  		case arrow.DATE64:
   335  			finestUnit = exec.Max(finestUnit, arrow.Millisecond)
   336  			sawDate64 = true
   337  		case arrow.TIMESTAMP:
   338  			ts := ty.(*arrow.TimestampType)
   339  			if ts.TimeZone != "" {
   340  				tz, _ := ts.GetZone()
   341  				if loc != nil && loc != tz {
   342  					return nil
   343  				}
   344  				loc = tz
   345  			}
   346  			zone = &ts.TimeZone
   347  			finestUnit = exec.Max(finestUnit, ts.Unit)
   348  		default:
   349  			return nil
   350  		}
   351  	}
   352  
   353  	switch {
   354  	case zone != nil:
   355  		// at least one timestamp seen
   356  		return &arrow.TimestampType{Unit: finestUnit, TimeZone: *zone}
   357  	case sawDate64:
   358  		return arrow.FixedWidthTypes.Date64
   359  	case sawDate32:
   360  		return arrow.FixedWidthTypes.Date32
   361  	}
   362  	return nil
   363  }
   364  
   365  func commonBinary(vals ...arrow.DataType) arrow.DataType {
   366  	var (
   367  		allUTF8, allOffset32, allFixedWidth = true, true, true
   368  	)
   369  
   370  	for _, ty := range vals {
   371  		switch ty.ID() {
   372  		case arrow.STRING:
   373  			allFixedWidth = false
   374  		case arrow.BINARY:
   375  			allFixedWidth, allUTF8 = false, false
   376  		case arrow.FIXED_SIZE_BINARY:
   377  			allUTF8 = false
   378  		case arrow.LARGE_BINARY:
   379  			allOffset32, allFixedWidth, allUTF8 = false, false, false
   380  		case arrow.LARGE_STRING:
   381  			allOffset32, allFixedWidth = false, false
   382  		default:
   383  			return nil
   384  		}
   385  	}
   386  
   387  	switch {
   388  	case allFixedWidth:
   389  		// at least for the purposes of comparison, no need to cast
   390  		return nil
   391  	case allUTF8:
   392  		if allOffset32 {
   393  			return arrow.BinaryTypes.String
   394  		}
   395  		return arrow.BinaryTypes.LargeString
   396  	case allOffset32:
   397  		return arrow.BinaryTypes.Binary
   398  	}
   399  	return arrow.BinaryTypes.LargeBinary
   400  }