// github.com/apache/arrow/go/v16@v16.1.0/arrow/compute/internal/kernels/_lib/base_arithmetic.cc
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <arch.h>
#include <math.h>
#include <stdint.h>
#include <limits.h>
#include "types.h"
#include "vendored/safe-math.h"

// Corresponds to equivalent ArithmeticOp enum in base_arithmetic.go
// for passing across which operation to perform. This allows simpler
// implementation at the cost of having to pass the extra int8 and
// perform a switch.
//
// In cases of small arrays, this is completely negligible. In cases
// of large arrays, the time saved by using SIMD here is significantly
// worth the cost.
32 enum class optype : int8_t { 33 ADD, 34 SUB, 35 MUL, 36 DIV, 37 ABSOLUTE_VALUE, 38 NEGATE, 39 SQRT, 40 POWER, 41 SIN, 42 COS, 43 TAN, 44 ASIN, 45 ACOS, 46 ATAN, 47 ATAN2, 48 LN, 49 LOG10, 50 LOG2, 51 LOG1P, 52 LOGB, 53 SIGN, 54 55 // this impl doesn't actually perform any overflow checks as we need 56 // to only run overflow checks on non-null entries 57 ADD_CHECKED, 58 SUB_CHECKED, 59 MUL_CHECKED, 60 DIV_CHECKED, 61 ABSOLUTE_VALUE_CHECKED, 62 NEGATE_CHECKED, 63 SQRT_CHECKED, 64 POWER_CHECKED, 65 SIN_CHECKED, 66 COS_CHECKED, 67 TAN_CHECKED, 68 ASIN_CHECKED, 69 ACOS_CHECKED, 70 LN_CHECKED, 71 LOG10_CHECKED, 72 LOG2_CHECKED, 73 LOG1P_CHECKED, 74 LOGB_CHECKED, 75 }; 76 77 struct Add { 78 template <typename T, typename Arg0, typename Arg1> 79 static constexpr T Call(Arg0 left, Arg1 right) { 80 if constexpr (is_arithmetic_v<T>) 81 return left + right; 82 } 83 }; 84 85 struct Sub { 86 template <typename T, typename Arg0, typename Arg1> 87 static constexpr T Call(Arg0 left, Arg1 right) { 88 if constexpr (is_arithmetic_v<T>) 89 return left - right; 90 } 91 }; 92 93 struct AddChecked { 94 template <typename T, typename Arg0, typename Arg1> 95 static constexpr T Call(Arg0 left, Arg1 right) { 96 static_assert(is_same<T, Arg0>::value && is_same<T, Arg1>::value, ""); 97 if constexpr(is_arithmetic_v<T>) { 98 return left + right; 99 } 100 } 101 }; 102 103 104 struct SubChecked { 105 template <typename T, typename Arg0, typename Arg1> 106 static constexpr T Call(Arg0 left, Arg1 right) { 107 static_assert(is_same<T, Arg0>::value && is_same<T, Arg1>::value, ""); 108 if constexpr(is_arithmetic_v<T>) { 109 return left - right; 110 } 111 } 112 }; 113 114 template <typename T> 115 using maybe_make_unsigned = conditional_t<is_integral_v<T> && !is_same_v<T, bool>, make_unsigned_t<T>, T>; 116 117 template <typename T, typename Unsigned = maybe_make_unsigned<T>> 118 constexpr Unsigned to_unsigned(T signed_) { 119 return static_cast<Unsigned>(signed_); 120 } 121 122 struct Multiply { 123 
static_assert(is_same_v<decltype(int8_t() * int8_t()), int32_t>, ""); 124 static_assert(is_same_v<decltype(uint8_t() * uint8_t()), int32_t>, ""); 125 static_assert(is_same_v<decltype(int16_t() * int16_t()), int32_t>, ""); 126 static_assert(is_same_v<decltype(uint16_t() * uint16_t()), int32_t>, ""); 127 static_assert(is_same_v<decltype(int32_t() * int32_t()), int32_t>, ""); 128 static_assert(is_same_v<decltype(uint32_t() * uint32_t()), uint32_t>, ""); 129 static_assert(is_same_v<decltype(int64_t() * int64_t()), int64_t>, ""); 130 static_assert(is_same_v<decltype(uint64_t() * uint64_t()), uint64_t>, ""); 131 132 template <typename T, typename Arg0, typename Arg1> 133 static constexpr T Call(Arg0 left, Arg1 right) { 134 static_assert(is_same_v<T, Arg0> && is_same_v<T, Arg1>, ""); 135 if constexpr(is_floating_point_v<T>) { 136 return left * right; 137 } else if constexpr(is_unsigned_v<T> && !is_same_v<T, uint16_t>) { 138 return left * right; 139 } else if constexpr(is_signed_v<T> && !is_same_v<T, int16_t>) { 140 return to_unsigned(left) * to_unsigned(right); 141 } else if constexpr(is_same_v<T, int16_t> || is_same_v<T, uint16_t>) { 142 // multiplication of 16 bit integer types implicitly promotes to 143 // signed 32 bit integer. However, some inputs may overflow (which 144 // triggers undefined behavior). Therefore we first cast to 32 bit 145 // unsigned integers where overflow is well defined. 
146 return static_cast<uint32_t>(left) * static_cast<uint32_t>(right); 147 } 148 } 149 }; 150 151 struct MultiplyChecked { 152 template <typename T, typename Arg0, typename Arg1> 153 static constexpr T Call(Arg0 left, Arg1 right) { 154 static_assert(is_same_v<T, Arg0> && is_same_v<T, Arg1>, ""); 155 if constexpr(is_arithmetic_v<T>) { 156 return left * right; 157 } 158 } 159 }; 160 161 struct AbsoluteValue { 162 template <typename T, typename Arg> 163 static constexpr T Call(Arg input) { 164 if constexpr(is_same_v<Arg, float>) { 165 *(((int*)&input)+0) &= 0x7fffffff; 166 return input; 167 } else if constexpr(is_same_v<Arg, double>) { 168 *(((int*)&input)+1) &= 0x7fffffff; 169 return input; 170 } else if constexpr(is_unsigned_v<Arg>) { 171 return input; 172 } else { 173 const auto mask = input >> (sizeof(Arg) * CHAR_BIT - 1); 174 return (input + mask) ^ mask; 175 } 176 } 177 }; 178 179 struct AbsoluteValueChecked { 180 template <typename T, typename Arg> 181 static constexpr T Call(Arg input) { 182 if constexpr(is_same_v<Arg, float>) { 183 *(((int*)&input)+0) &= 0x7fffffff; 184 return input; 185 } else if constexpr(is_same_v<Arg, double>) { 186 *(((int*)&input)+1) &= 0x7fffffff; 187 return input; 188 } else if constexpr(is_unsigned_v<Arg>) { 189 return input; 190 } else { 191 const auto mask = input >> (sizeof(Arg) * CHAR_BIT - 1); 192 return (input + mask) ^ mask; 193 } 194 } 195 }; 196 197 struct Negate { 198 template <typename T, typename Arg> 199 static constexpr T Call(Arg input) { 200 if constexpr(is_floating_point_v<Arg>) { 201 return -input; 202 } else if constexpr(is_unsigned_v<Arg>) { 203 return ~input + 1; 204 } else { 205 return -input; 206 } 207 } 208 }; 209 210 struct NegateChecked { 211 template <typename T, typename Arg> 212 static constexpr T Call(Arg input) { 213 static_assert(is_same_v<T, Arg>, ""); 214 if constexpr(is_floating_point_v<Arg>) { 215 return -input; 216 } else if constexpr(is_unsigned_v<Arg>) { 217 return 0; 218 } else { 219 return 
-input; 220 } 221 } 222 }; 223 224 struct Sign { 225 template <typename T, typename Arg> 226 static constexpr T Call(Arg input) { 227 if constexpr(is_floating_point_v<Arg>) { 228 return isnan(input) ? input : ((input == 0) ? 0 : (signbit(input) ? -1 : 1)); 229 } else if constexpr(is_unsigned_v<Arg>) { 230 return input > 0 ? 1 : 0; 231 } else if constexpr(is_signed_v<Arg>) { 232 return input > 0 ? 1 : (input ? -1 : 0); 233 } 234 } 235 }; 236 237 template <typename T, typename Op, typename OutT = T> 238 struct arithmetic_op_arr_arr_impl { 239 static inline void exec(const void* in_left, const void* in_right, void* out, const int len) { 240 const T* left = reinterpret_cast<const T*>(in_left); 241 const T* right = reinterpret_cast<const T*>(in_right); 242 OutT* output = reinterpret_cast<OutT*>(out); 243 244 for (int i = 0; i < len; ++i) { 245 output[i] = Op::template Call<OutT, T, T>(left[i], right[i]); 246 } 247 } 248 }; 249 250 template <typename T, typename Op, typename OutT = T> 251 struct arithmetic_op_arr_scalar_impl { 252 static inline void exec(const void* in_left, const void* scalar_right, void* out, const int len) { 253 const T* left = reinterpret_cast<const T*>(in_left); 254 const T right = *reinterpret_cast<const T*>(scalar_right); 255 OutT* output = reinterpret_cast<OutT*>(out); 256 257 for (int i = 0; i < len; ++i) { 258 output[i] = Op::template Call<OutT, T, T>(left[i], right); 259 } 260 } 261 }; 262 263 template <typename T, typename Op, typename OutT = T> 264 struct arithmetic_op_scalar_arr_impl { 265 static inline void exec(const void* scalar_left, const void* in_right, void* out, const int len) { 266 const T left = *reinterpret_cast<const T*>(scalar_left); 267 const T* right = reinterpret_cast<const T*>(in_right); 268 OutT* output = reinterpret_cast<OutT*>(out); 269 270 for (int i = 0; i < len; ++i) { 271 output[i] = Op::template Call<OutT, T, T>(left, right[i]); 272 } 273 } 274 }; 275 276 template <typename T, typename Op, typename OutT = T> 277 
struct arithmetic_unary_op_impl { 278 static inline void exec(const void* arg, void* out, const int len) { 279 const T* input = reinterpret_cast<const T*>(arg); 280 OutT* output = reinterpret_cast<OutT*>(out); 281 282 for (int i = 0; i < len; ++i) { 283 output[i] = Op::template Call<OutT, T>(input[i]); 284 } 285 } 286 }; 287 288 template <typename Op, template<typename...> typename Impl> 289 static inline void arithmetic_op(const int type, const void* in_left, const void* in_right, void* output, const int len) { 290 const auto intype = static_cast<arrtype>(type); 291 292 switch (intype) { 293 case arrtype::UINT8: 294 return Impl<uint8_t, Op>::exec(in_left, in_right, output, len); 295 case arrtype::INT8: 296 return Impl<int8_t, Op>::exec(in_left, in_right, output, len); 297 case arrtype::UINT16: 298 return Impl<uint16_t, Op>::exec(in_left, in_right, output, len); 299 case arrtype::INT16: 300 return Impl<int16_t, Op>::exec(in_left, in_right, output, len); 301 case arrtype::UINT32: 302 return Impl<uint32_t, Op>::exec(in_left, in_right, output, len); 303 case arrtype::INT32: 304 return Impl<int32_t, Op>::exec(in_left, in_right, output, len); 305 case arrtype::UINT64: 306 return Impl<uint64_t, Op>::exec(in_left, in_right, output, len); 307 case arrtype::INT64: 308 return Impl<int64_t, Op>::exec(in_left, in_right, output, len); 309 case arrtype::FLOAT32: 310 return Impl<float, Op>::exec(in_left, in_right, output, len); 311 case arrtype::FLOAT64: 312 return Impl<double, Op>::exec(in_left, in_right, output, len); 313 default: 314 break; 315 } 316 } 317 318 template <typename Op, template <typename...> typename Impl, typename Input> 319 static inline void arithmetic_op(const int otype, const void* input, void* output, const int len) { 320 const auto outtype = static_cast<arrtype>(otype); 321 322 switch (outtype) { 323 case arrtype::UINT8: 324 return Impl<Input, Op, uint8_t>::exec(input, output, len); 325 case arrtype::INT8: 326 return Impl<Input, Op, int8_t>::exec(input, 
output, len); 327 case arrtype::UINT16: 328 return Impl<Input, Op, uint16_t>::exec(input, output, len); 329 case arrtype::INT16: 330 return Impl<Input, Op, int16_t>::exec(input, output, len); 331 case arrtype::UINT32: 332 return Impl<Input, Op, uint32_t>::exec(input, output, len); 333 case arrtype::INT32: 334 return Impl<Input, Op, int32_t>::exec(input, output, len); 335 case arrtype::UINT64: 336 return Impl<Input, Op, uint64_t>::exec(input, output, len); 337 case arrtype::INT64: 338 return Impl<Input, Op, int64_t>::exec(input, output, len); 339 case arrtype::FLOAT32: 340 return Impl<Input, Op, float>::exec(input, output, len); 341 case arrtype::FLOAT64: 342 return Impl<Input, Op, double>::exec(input, output, len); 343 default: 344 break; 345 } 346 } 347 348 349 template <typename Op, template <typename...> typename Impl> 350 static inline void arithmetic_op(const int type, const void* input, void* output, const int len) { 351 const auto intype = static_cast<arrtype>(type); 352 353 switch (intype) { 354 case arrtype::UINT8: 355 return Impl<uint8_t, Op>::exec(input, output, len); 356 case arrtype::INT8: 357 return Impl<int8_t, Op>::exec(input, output, len); 358 case arrtype::UINT16: 359 return Impl<uint16_t, Op>::exec(input, output, len); 360 case arrtype::INT16: 361 return Impl<int16_t, Op>::exec(input, output, len); 362 case arrtype::UINT32: 363 return Impl<uint32_t, Op>::exec(input, output, len); 364 case arrtype::INT32: 365 return Impl<int32_t, Op>::exec(input, output, len); 366 case arrtype::UINT64: 367 return Impl<uint64_t, Op>::exec(input, output, len); 368 case arrtype::INT64: 369 return Impl<int64_t, Op>::exec(input, output, len); 370 case arrtype::FLOAT32: 371 return Impl<float, Op>::exec(input, output, len); 372 case arrtype::FLOAT64: 373 return Impl<double, Op>::exec(input, output, len); 374 default: 375 break; 376 } 377 } 378 379 template <typename Op, template <typename...> typename Impl> 380 static inline void arithmetic_op(const int itype, const int 
otype, const void* input, void* output, const int len) { 381 const auto intype = static_cast<arrtype>(itype); 382 383 switch (intype) { 384 case arrtype::UINT8: 385 return arithmetic_op<Op, Impl, uint8_t>(otype, input, output, len); 386 case arrtype::INT8: 387 return arithmetic_op<Op, Impl, int8_t>(otype, input, output, len); 388 case arrtype::UINT16: 389 return arithmetic_op<Op, Impl, uint16_t>(otype, input, output, len); 390 case arrtype::INT16: 391 return arithmetic_op<Op, Impl, int16_t>(otype, input, output, len); 392 case arrtype::UINT32: 393 return arithmetic_op<Op, Impl, uint32_t>(otype, input, output, len); 394 case arrtype::INT32: 395 return arithmetic_op<Op, Impl, int32_t>(otype, input, output, len); 396 case arrtype::UINT64: 397 return arithmetic_op<Op, Impl, uint64_t>(otype, input, output, len); 398 case arrtype::INT64: 399 return arithmetic_op<Op, Impl, int64_t>(otype, input, output, len); 400 case arrtype::FLOAT32: 401 return arithmetic_op<Op, Impl, float>(otype, input, output, len); 402 case arrtype::FLOAT64: 403 return arithmetic_op<Op, Impl, double>(otype, input, output, len); 404 default: 405 break; 406 } 407 } 408 409 template <template <typename...> class Impl> 410 static inline void arithmetic_unary_impl_same_types(const int type, const int8_t op, const void* input, void* output, const int len) { 411 const auto opt = static_cast<optype>(op); 412 413 switch (opt) { 414 case optype::ABSOLUTE_VALUE: 415 return arithmetic_op<AbsoluteValue, Impl>(type, input, output, len); 416 case optype::ABSOLUTE_VALUE_CHECKED: 417 return arithmetic_op<AbsoluteValueChecked, Impl>(type, input, output, len); 418 case optype::NEGATE: 419 return arithmetic_op<Negate, Impl>(type, input, output, len); 420 case optype::NEGATE_CHECKED: 421 return arithmetic_op<NegateChecked, Impl>(type, input, output, len); 422 case optype::SIGN: 423 return arithmetic_op<Sign, Impl>(type, input, output, len); 424 default: 425 break; 426 } 427 } 428 429 430 template <template <typename...> 
class Impl> 431 static inline void arithmetic_unary_impl(const int itype, const int otype, const int8_t op, const void* input, void* output, const int len) { 432 const auto opt = static_cast<optype>(op); 433 434 switch (opt) { 435 case optype::SIGN: 436 return arithmetic_op<Sign, Impl>(itype, otype, input, output, len); 437 default: 438 break; 439 } 440 } 441 442 template <template <typename...> class Impl> 443 static inline void arithmetic_binary_impl(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) { 444 const auto opt = static_cast<optype>(op); 445 446 switch (opt) { 447 case optype::ADD: 448 return arithmetic_op<Add, Impl>(type, in_left, in_right, out, len); 449 case optype::ADD_CHECKED: 450 return arithmetic_op<AddChecked, Impl>(type, in_left, in_right, out, len); 451 case optype::SUB: 452 return arithmetic_op<Sub, Impl>(type, in_left, in_right, out, len); 453 case optype::SUB_CHECKED: 454 return arithmetic_op<SubChecked, Impl>(type, in_left, in_right, out, len); 455 case optype::MUL: 456 return arithmetic_op<Multiply, Impl>(type, in_left, in_right, out, len); 457 case optype::MUL_CHECKED: 458 return arithmetic_op<MultiplyChecked, Impl>(type, in_left, in_right, out, len); 459 default: 460 // don't implement divide here as we can only divide on non-null entries 461 // so we can avoid dividing by zero 462 break; 463 } 464 } 465 466 extern "C" void FULL_NAME(arithmetic_binary)(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) { 467 arithmetic_binary_impl<arithmetic_op_arr_arr_impl>(type, op, in_left, in_right, out, len); 468 } 469 470 extern "C" void FULL_NAME(arithmetic_arr_scalar)(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) { 471 arithmetic_binary_impl<arithmetic_op_arr_scalar_impl>(type, op, in_left, in_right, out, len); 472 } 473 474 extern "C" void FULL_NAME(arithmetic_scalar_arr)(const int type, 
const int8_t op, const void* in_left, const void* in_right, void* out, const int len) { 475 arithmetic_binary_impl<arithmetic_op_scalar_arr_impl>(type, op, in_left, in_right, out, len); 476 } 477 478 extern "C" void FULL_NAME(arithmetic_unary_same_types)(const int type, const int8_t op, const void* input, void* output, const int len) { 479 arithmetic_unary_impl_same_types<arithmetic_unary_op_impl>(type, op, input, output, len); 480 } 481 482 extern "C" void FULL_NAME(arithmetic_unary_diff_type)(const int itype, const int otype, const int8_t op, const void* input, void* output, const int len) { 483 arithmetic_unary_impl<arithmetic_unary_op_impl>(itype, otype, op, input, output, len); 484 }