github.com/apache/arrow/go/v14@v14.0.1/arrow/compute/utils.go (about) 1 // Licensed to the Apache Software Foundation (ASF) under one 2 // or more contributor license agreements. See the NOTICE file 3 // distributed with this work for additional information 4 // regarding copyright ownership. The ASF licenses this file 5 // to you under the Apache License, Version 2.0 (the 6 // "License"); you may not use this file except in compliance 7 // with the License. You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 //go:build go1.18 18 19 package compute 20 21 import ( 22 "fmt" 23 "io" 24 "math" 25 "time" 26 27 "github.com/apache/arrow/go/v14/arrow" 28 "github.com/apache/arrow/go/v14/arrow/bitutil" 29 "github.com/apache/arrow/go/v14/arrow/compute/exec" 30 "github.com/apache/arrow/go/v14/arrow/compute/internal/kernels" 31 "github.com/apache/arrow/go/v14/arrow/internal/debug" 32 "github.com/apache/arrow/go/v14/arrow/memory" 33 "golang.org/x/xerrors" 34 ) 35 36 type bufferWriteSeeker struct { 37 buf *memory.Buffer 38 pos int 39 mem memory.Allocator 40 } 41 42 func (b *bufferWriteSeeker) Reserve(nbytes int) { 43 if b.buf == nil { 44 b.buf = memory.NewResizableBuffer(b.mem) 45 } 46 newCap := int(math.Max(float64(b.buf.Cap()), 256)) 47 for newCap < b.pos+nbytes { 48 newCap = bitutil.NextPowerOf2(newCap) 49 } 50 b.buf.Reserve(newCap) 51 } 52 53 func (b *bufferWriteSeeker) Write(p []byte) (n int, err error) { 54 if len(p) == 0 { 55 return 0, nil 56 } 57 58 if b.buf == nil { 59 b.Reserve(len(p)) 60 } else if b.pos+len(p) >= b.buf.Cap() { 61 b.Reserve(len(p)) 62 } 63 64 return b.UnsafeWrite(p) 65 } 66 67 func (b *bufferWriteSeeker) UnsafeWrite(p []byte) (n int, err error) { 68 n = copy(b.buf.Buf()[b.pos:], p) 69 b.pos += len(p) 70 if b.pos > b.buf.Len() { 71 b.buf.ResizeNoShrink(b.pos) 72 } 73 return 74 } 75 76 func (b *bufferWriteSeeker) Seek(offset int64, whence int) (int64, error) { 77 newpos, offs := 0, int(offset) 78 switch whence { 79 case io.SeekStart: 80 newpos = offs 81 case io.SeekCurrent: 82 newpos = b.pos + offs 83 case io.SeekEnd: 84 newpos = b.buf.Len() + offs 85 } 86 if newpos < 0 { 87 return 0, xerrors.New("negative result pos") 88 } 89 b.pos = newpos 90 return int64(newpos), nil 91 } 92 93 // ensureDictionaryDecoded is used by DispatchBest to determine 94 // the proper types for promotion. Casting is then performed by 95 // the executor before continuing execution: see the implementation 96 // of execInternal in exec.go after calling DispatchBest. 97 // 98 // That casting is where actual decoding would be performed for 99 // the dictionary 100 func ensureDictionaryDecoded(vals ...arrow.DataType) { 101 for i, v := range vals { 102 if v.ID() == arrow.DICTIONARY { 103 vals[i] = v.(*arrow.DictionaryType).ValueType 104 } 105 } 106 } 107 108 func replaceNullWithOtherType(vals ...arrow.DataType) { 109 debug.Assert(len(vals) == 2, "should be length 2") 110 111 if vals[0].ID() == arrow.NULL { 112 vals[0] = vals[1] 113 return 114 } 115 116 if vals[1].ID() == arrow.NULL { 117 vals[1] = vals[0] 118 return 119 } 120 } 121 122 func commonTemporalResolution(vals ...arrow.DataType) (arrow.TimeUnit, bool) { 123 isTimeUnit := false 124 finestUnit := arrow.Second 125 for _, v := range vals { 126 switch dt := v.(type) { 127 case *arrow.Date32Type: 128 isTimeUnit = true 129 continue 130 case *arrow.Date64Type: 131 finestUnit = exec.Max(finestUnit, arrow.Millisecond) 132 isTimeUnit = true 133 case arrow.TemporalWithUnit: 134 finestUnit = exec.Max(finestUnit, dt.TimeUnit()) 135 isTimeUnit = true 136 default: 137 continue 138 } 139 } 140 return finestUnit, isTimeUnit 141 } 142 143 func replaceTemporalTypes(unit arrow.TimeUnit, vals ...arrow.DataType) { 144 for i, v := range vals { 145 switch dt := v.(type) { 146 case *arrow.TimestampType: 147 dt.Unit = unit 148 vals[i] = dt 149 case *arrow.Time32Type, *arrow.Time64Type: 150 if unit > arrow.Millisecond { 151 vals[i] = &arrow.Time64Type{Unit: unit} 152 } else { 153 vals[i] = &arrow.Time32Type{Unit: unit} 154 } 155 case *arrow.DurationType: 156 dt.Unit = unit 157 vals[i] = dt 158 case *arrow.Date32Type, *arrow.Date64Type: 159 vals[i] = &arrow.TimestampType{Unit: unit} 160 } 161 } 162 } 163 164 func replaceTypes(replacement arrow.DataType, vals ...arrow.DataType) { 165 for i := range vals { 166 vals[i] = replacement 167 } 168 } 169 170 func commonNumeric(vals ...arrow.DataType) arrow.DataType { 171 for _, v := range vals { 172 if !arrow.IsFloating(v.ID()) && !arrow.IsInteger(v.ID()) { 173 // a common numeric type is only possible if all are numeric 174 return nil 175 } 176 if v.ID() == arrow.FLOAT16 { 177 // float16 arithmetic is not currently supported 178 return nil 179 } 180 } 181 182 for _, v := range vals { 183 if v.ID() == arrow.FLOAT64 { 184 return arrow.PrimitiveTypes.Float64 185 } 186 } 187 188 for _, v := range vals { 189 if v.ID() == arrow.FLOAT32 { 190 return arrow.PrimitiveTypes.Float32 191 } 192 } 193 194 maxWidthSigned, maxWidthUnsigned := 0, 0 195 for _, v := range vals { 196 if arrow.IsUnsignedInteger(v.ID()) { 197 maxWidthUnsigned = exec.Max(v.(arrow.FixedWidthDataType).BitWidth(), maxWidthUnsigned) 198 } else { 199 maxWidthSigned = exec.Max(v.(arrow.FixedWidthDataType).BitWidth(), maxWidthSigned) 200 } 201 } 202 203 if maxWidthSigned == 0 { 204 switch { 205 case maxWidthUnsigned >= 64: 206 return arrow.PrimitiveTypes.Uint64 207 case maxWidthUnsigned == 32: 208 return arrow.PrimitiveTypes.Uint32 209 case maxWidthUnsigned == 16: 210 return arrow.PrimitiveTypes.Uint16 211 default: 212 debug.Assert(maxWidthUnsigned == 8, "bad maxWidthUnsigned") 213 return arrow.PrimitiveTypes.Uint8 214 } 215 } 216 217 if maxWidthSigned <= maxWidthUnsigned { 218 maxWidthSigned = bitutil.NextPowerOf2(maxWidthUnsigned + 1) 219 } 220 221 switch { 222 case maxWidthSigned >= 64: 223 return arrow.PrimitiveTypes.Int64 224 case maxWidthSigned == 32: 225 return arrow.PrimitiveTypes.Int32 226 case maxWidthSigned == 16: 227 return arrow.PrimitiveTypes.Int16 228 default: 229 debug.Assert(maxWidthSigned == 8, "bad maxWidthSigned") 230 return arrow.PrimitiveTypes.Int8 231 } 232 } 233 234 func hasDecimal(vals ...arrow.DataType) bool { 235 for _, v := range vals { 236 if arrow.IsDecimal(v.ID()) { 237 return true 238 } 239 } 240 241 return false 242 } 243 244 type decimalPromotion uint8 245 246 const ( 247 decPromoteNone decimalPromotion = iota 248 decPromoteAdd 249 decPromoteMultiply 250 decPromoteDivide 251 ) 252 253 func castBinaryDecimalArgs(promote decimalPromotion, vals ...arrow.DataType) error { 254 left, right := vals[0], vals[1] 255 debug.Assert(arrow.IsDecimal(left.ID()) || arrow.IsDecimal(right.ID()), "at least one of the types should be decimal") 256 257 // decimal + float = float 258 if arrow.IsFloating(left.ID()) { 259 vals[1] = vals[0] 260 return nil 261 } else if arrow.IsFloating(right.ID()) { 262 vals[0] = vals[1] 263 return nil 264 } 265 266 var prec1, scale1, prec2, scale2 int32 267 var err error 268 // decimal + integer = decimal 269 if arrow.IsDecimal(left.ID()) { 270 dec := left.(arrow.DecimalType) 271 prec1, scale1 = dec.GetPrecision(), dec.GetScale() 272 } else { 273 debug.Assert(arrow.IsInteger(left.ID()), "floats were already handled, this should be an int") 274 if prec1, err = kernels.MaxDecimalDigitsForInt(left.ID()); err != nil { 275 return err 276 } 277 } 278 if arrow.IsDecimal(right.ID()) { 279 dec := right.(arrow.DecimalType) 280 prec2, scale2 = dec.GetPrecision(), dec.GetScale() 281 } else { 282 debug.Assert(arrow.IsInteger(right.ID()), "float already handled, should be ints") 283 if prec2, err = kernels.MaxDecimalDigitsForInt(right.ID()); err != nil { 284 return err 285 } 286 } 287 288 if scale1 < 0 || scale2 < 0 { 289 return fmt.Errorf("%w: decimals with negative scales not supported", arrow.ErrNotImplemented) 290 } 291 292 // decimal128 + decimal256 = decimal256 293 castedID := arrow.DECIMAL128 294 if left.ID() == arrow.DECIMAL256 || right.ID() == arrow.DECIMAL256 { 295 castedID = arrow.DECIMAL256 296 } 297 298 // decimal promotion rules compatible with amazon redshift 299 // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html 300 var leftScaleup, rightScaleup int32 301 302 switch promote { 303 case decPromoteAdd: 304 leftScaleup = exec.Max(scale1, scale2) - scale1 305 rightScaleup = exec.Max(scale1, scale2) - scale2 306 case decPromoteMultiply: 307 case decPromoteDivide: 308 leftScaleup = exec.Max(4, scale1+prec2-scale2+1) + scale2 - scale1 309 default: 310 debug.Assert(false, fmt.Sprintf("invalid DecimalPromotion value %d", promote)) 311 } 312 313 vals[0], err = arrow.NewDecimalType(castedID, prec1+leftScaleup, scale1+leftScaleup) 314 if err != nil { 315 return err 316 } 317 vals[1], err = arrow.NewDecimalType(castedID, prec2+rightScaleup, scale2+rightScaleup) 318 return err 319 } 320 321 func commonTemporal(vals ...arrow.DataType) arrow.DataType { 322 var ( 323 finestUnit = arrow.Second 324 zone *string 325 loc *time.Location 326 sawDate32, sawDate64 bool 327 ) 328 329 for _, ty := range vals { 330 switch ty.ID() { 331 case arrow.DATE32: 332 // date32's unit is days, but the coarsest we have is seconds 333 sawDate32 = true 334 case arrow.DATE64: 335 finestUnit = exec.Max(finestUnit, arrow.Millisecond) 336 sawDate64 = true 337 case arrow.TIMESTAMP: 338 ts := ty.(*arrow.TimestampType) 339 if ts.TimeZone != "" { 340 tz, _ := ts.GetZone() 341 if loc != nil && loc != tz { 342 return nil 343 } 344 loc = tz 345 } 346 zone = &ts.TimeZone 347 finestUnit = exec.Max(finestUnit, ts.Unit) 348 default: 349 return nil 350 } 351 } 352 353 switch { 354 case zone != nil: 355 // at least one timestamp seen 356 return &arrow.TimestampType{Unit: finestUnit, TimeZone: *zone} 357 case sawDate64: 358 return arrow.FixedWidthTypes.Date64 359 case sawDate32: 360 return arrow.FixedWidthTypes.Date32 361 } 362 return nil 363 } 364 365 func commonBinary(vals ...arrow.DataType) arrow.DataType { 366 var ( 367 allUTF8, allOffset32, allFixedWidth = true, true, true 368 ) 369 370 for _, ty := range vals { 371 switch ty.ID() { 372 case arrow.STRING: 373 allFixedWidth = false 374 case arrow.BINARY: 375 allFixedWidth, allUTF8 = false, false 376 case arrow.FIXED_SIZE_BINARY: 377 allUTF8 = false 378 case arrow.LARGE_BINARY: 379 allOffset32, allFixedWidth, allUTF8 = false, false, false 380 case arrow.LARGE_STRING: 381 allOffset32, allFixedWidth = false, false 382 default: 383 return nil 384 } 385 } 386 387 switch { 388 case allFixedWidth: 389 // at least for the purposes of comparison, no need to cast 390 return nil 391 case allUTF8: 392 if allOffset32 { 393 return arrow.BinaryTypes.String 394 } 395 return arrow.BinaryTypes.LargeString 396 case allOffset32: 397 return arrow.BinaryTypes.Binary 398 } 399 return arrow.BinaryTypes.LargeBinary 400 }