github.com/apache/arrow/go/v16@v16.1.0/arrow/compute/internal/kernels/_lib/base_arithmetic.cc (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one
     2  // or more contributor license agreements.  See the NOTICE file
     3  // distributed with this work for additional information
     4  // regarding copyright ownership.  The ASF licenses this file
     5  // to you under the Apache License, Version 2.0 (the
     6  // "License"); you may not use this file except in compliance
     7  // with the License.  You may obtain a copy of the License at
     8  //
     9  // http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  #include <arch.h>
    18  #include <math.h>
    19  #include <stdint.h>
    20  #include <limits.h>
    21  #include "types.h"
    22  #include "vendored/safe-math.h"
    23  
    24  // Corresponds to equivalent ArithmeticOp enum in base_arithmetic.go
    25  // for passing across which operation to perform. This allows simpler
    26  // implementation at the cost of having to pass the extra int8 and
    27  // perform a switch.
    28  //
    29  // In cases of small arrays, this is completely negligible. In cases
    30  // of large arrays, the time saved by using SIMD here is significantly
    31  // worth the cost.
    32  enum class optype : int8_t {
    33      ADD,
    34      SUB,
    35      MUL,
    36      DIV,
    37      ABSOLUTE_VALUE,
    38      NEGATE,
    39      SQRT,
    40      POWER,
    41      SIN,
    42      COS,
    43      TAN,
    44      ASIN,
    45      ACOS,
    46      ATAN,
    47      ATAN2,
    48      LN,
    49      LOG10,
    50      LOG2,
    51      LOG1P,
    52      LOGB,
    53      SIGN,
    54  
    55      // this impl doesn't actually perform any overflow checks as we need
    56      // to only run overflow checks on non-null entries
    57      ADD_CHECKED,
    58      SUB_CHECKED,
    59      MUL_CHECKED,
    60      DIV_CHECKED,
    61      ABSOLUTE_VALUE_CHECKED,
    62      NEGATE_CHECKED,
    63      SQRT_CHECKED,
    64      POWER_CHECKED,
    65      SIN_CHECKED,
    66      COS_CHECKED,
    67      TAN_CHECKED,
    68      ASIN_CHECKED,
    69      ACOS_CHECKED,    
    70      LN_CHECKED,
    71      LOG10_CHECKED,
    72      LOG2_CHECKED,
    73      LOG1P_CHECKED,
    74      LOGB_CHECKED,
    75  };
    76  
    77  struct Add {
    78      template <typename T, typename Arg0, typename Arg1>
    79      static constexpr T Call(Arg0 left, Arg1 right) {
    80          if constexpr (is_arithmetic_v<T>)
    81              return left + right;
    82      }
    83  };
    84  
    85  struct Sub {
    86      template <typename T, typename Arg0, typename Arg1>
    87      static constexpr T Call(Arg0 left, Arg1 right) {
    88          if constexpr (is_arithmetic_v<T>)
    89              return left - right;
    90      }
    91  };
    92  
    93  struct AddChecked {
    94      template <typename T, typename Arg0, typename Arg1>
    95      static constexpr T Call(Arg0 left, Arg1 right) {
    96          static_assert(is_same<T, Arg0>::value && is_same<T, Arg1>::value, "");
    97          if constexpr(is_arithmetic_v<T>) {
    98              return left + right;
    99          }
   100      }
   101  };
   102  
   103  
   104  struct SubChecked {
   105      template <typename T, typename Arg0, typename Arg1>
   106      static constexpr T Call(Arg0 left, Arg1 right) {
   107          static_assert(is_same<T, Arg0>::value && is_same<T, Arg1>::value, "");
   108          if constexpr(is_arithmetic_v<T>) {
   109              return left - right;
   110          }
   111      }
   112  };
   113  
   114  template <typename T>
   115  using maybe_make_unsigned = conditional_t<is_integral_v<T> && !is_same_v<T, bool>, make_unsigned_t<T>, T>;
   116  
   117  template <typename T, typename Unsigned = maybe_make_unsigned<T>>
   118  constexpr Unsigned to_unsigned(T signed_) {
   119      return static_cast<Unsigned>(signed_);
   120  }
   121  
   122  struct Multiply {
   123      static_assert(is_same_v<decltype(int8_t() * int8_t()), int32_t>, "");
   124      static_assert(is_same_v<decltype(uint8_t() * uint8_t()), int32_t>, "");
   125      static_assert(is_same_v<decltype(int16_t() * int16_t()), int32_t>, "");
   126      static_assert(is_same_v<decltype(uint16_t() * uint16_t()), int32_t>, "");
   127      static_assert(is_same_v<decltype(int32_t() * int32_t()), int32_t>, "");
   128      static_assert(is_same_v<decltype(uint32_t() * uint32_t()), uint32_t>, "");
   129      static_assert(is_same_v<decltype(int64_t() * int64_t()), int64_t>, "");
   130      static_assert(is_same_v<decltype(uint64_t() * uint64_t()), uint64_t>, "");
   131  
   132      template <typename T, typename Arg0, typename Arg1>
   133      static constexpr T Call(Arg0 left, Arg1 right) {
   134          static_assert(is_same_v<T, Arg0> && is_same_v<T, Arg1>, "");
   135          if constexpr(is_floating_point_v<T>) {
   136              return left * right;
   137          } else if constexpr(is_unsigned_v<T> && !is_same_v<T, uint16_t>) {
   138              return left * right;
   139          } else if constexpr(is_signed_v<T> && !is_same_v<T, int16_t>) {
   140              return to_unsigned(left) * to_unsigned(right);
   141          } else if constexpr(is_same_v<T, int16_t> || is_same_v<T, uint16_t>) {
   142              // multiplication of 16 bit integer types implicitly promotes to
   143              // signed 32 bit integer. However, some inputs may overflow (which
   144              // triggers undefined behavior). Therefore we first cast to 32 bit
   145              // unsigned integers where overflow is well defined.
   146              return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
   147          }
   148      }
   149  };
   150  
   151  struct MultiplyChecked {
   152      template <typename T, typename Arg0, typename Arg1>
   153      static constexpr T Call(Arg0 left, Arg1 right) {
   154          static_assert(is_same_v<T, Arg0> && is_same_v<T, Arg1>, "");
   155          if constexpr(is_arithmetic_v<T>) {
   156              return left * right;
   157          }
   158      }
   159  };
   160  
   161  struct AbsoluteValue {
   162      template <typename T, typename Arg>
   163      static constexpr T Call(Arg input) {
   164          if constexpr(is_same_v<Arg, float>) {
   165              *(((int*)&input)+0) &= 0x7fffffff;
   166              return input;
   167          } else if constexpr(is_same_v<Arg, double>) {
   168              *(((int*)&input)+1) &= 0x7fffffff;
   169              return input;
   170          } else if constexpr(is_unsigned_v<Arg>) {
   171              return input;
   172          } else {
   173              const auto mask = input >> (sizeof(Arg) * CHAR_BIT - 1);
   174              return (input + mask) ^ mask;
   175          }
   176      }
   177  };
   178  
   179  struct AbsoluteValueChecked {
   180      template <typename T, typename Arg>
   181      static constexpr T Call(Arg input) {
   182          if constexpr(is_same_v<Arg, float>) {
   183              *(((int*)&input)+0) &= 0x7fffffff;
   184              return input;
   185          } else if constexpr(is_same_v<Arg, double>) {
   186              *(((int*)&input)+1) &= 0x7fffffff;
   187              return input;
   188          } else if constexpr(is_unsigned_v<Arg>) {
   189              return input;
   190          } else {
   191              const auto mask = input >> (sizeof(Arg) * CHAR_BIT - 1);
   192              return (input + mask) ^ mask;
   193          }
   194      }
   195  };
   196  
   197  struct Negate {
   198      template <typename T, typename Arg>
   199      static constexpr T Call(Arg input) {
   200          if constexpr(is_floating_point_v<Arg>) {
   201              return -input;
   202          } else if constexpr(is_unsigned_v<Arg>) {
   203              return ~input + 1;
   204          } else {
   205              return -input;
   206          }
   207      }
   208  };
   209  
   210  struct NegateChecked {
   211      template <typename T, typename Arg>
   212      static constexpr T Call(Arg input) {
   213          static_assert(is_same_v<T, Arg>, "");
   214          if constexpr(is_floating_point_v<Arg>) {
   215              return -input;
   216          } else if constexpr(is_unsigned_v<Arg>) {
   217              return 0;
   218          } else {
   219              return -input;
   220          }
   221      }
   222  };
   223  
   224  struct Sign {
   225      template <typename T, typename Arg>
   226      static constexpr T Call(Arg input) {
   227          if constexpr(is_floating_point_v<Arg>) {
   228              return isnan(input) ? input : ((input == 0) ? 0 : (signbit(input) ? -1 : 1));
   229          } else if constexpr(is_unsigned_v<Arg>) {
   230              return input > 0 ? 1 : 0;
   231          } else if constexpr(is_signed_v<Arg>) {
   232              return input > 0 ? 1 : (input ? -1 : 0);
   233          }
   234      }
   235  };
   236  
   237  template <typename T, typename Op, typename OutT = T>
   238  struct arithmetic_op_arr_arr_impl {
   239      static inline void exec(const void* in_left, const void* in_right, void* out, const int len) {
   240          const T* left = reinterpret_cast<const T*>(in_left);
   241          const T* right = reinterpret_cast<const T*>(in_right);
   242          OutT* output = reinterpret_cast<OutT*>(out);
   243  
   244          for (int i = 0; i < len; ++i) {
   245              output[i] = Op::template Call<OutT, T, T>(left[i], right[i]);
   246          }
   247      }
   248  };
   249  
   250  template <typename T, typename Op, typename OutT = T>
   251  struct arithmetic_op_arr_scalar_impl {
   252      static inline void exec(const void* in_left, const void* scalar_right, void* out, const int len) {
   253          const T* left = reinterpret_cast<const T*>(in_left);
   254          const T right = *reinterpret_cast<const T*>(scalar_right);
   255          OutT* output = reinterpret_cast<OutT*>(out);
   256  
   257          for (int i = 0; i < len; ++i) {
   258              output[i] = Op::template Call<OutT, T, T>(left[i], right);
   259          }
   260      }
   261  };
   262  
   263  template <typename T, typename Op, typename OutT = T>
   264  struct arithmetic_op_scalar_arr_impl {
   265      static inline void exec(const void* scalar_left, const void* in_right, void* out, const int len) {
   266          const T left = *reinterpret_cast<const T*>(scalar_left);
   267          const T* right = reinterpret_cast<const T*>(in_right);
   268          OutT* output = reinterpret_cast<OutT*>(out);
   269  
   270          for (int i = 0; i < len; ++i) {
   271              output[i] = Op::template Call<OutT, T, T>(left, right[i]);
   272          }
   273      }
   274  };
   275  
   276  template <typename T, typename Op, typename OutT = T>
   277  struct arithmetic_unary_op_impl {
   278      static inline void exec(const void* arg, void* out, const int len) {
   279          const T* input = reinterpret_cast<const T*>(arg);
   280          OutT* output = reinterpret_cast<OutT*>(out);
   281  
   282          for (int i = 0; i < len; ++i) {
   283              output[i] = Op::template Call<OutT, T>(input[i]);
   284          }
   285      }
   286  };
   287  
   288  template <typename Op, template<typename...> typename Impl>
   289  static inline void arithmetic_op(const int type, const void* in_left, const void* in_right, void* output, const int len) {
   290      const auto intype = static_cast<arrtype>(type);
   291  
   292      switch (intype) {
   293      case arrtype::UINT8:
   294          return Impl<uint8_t, Op>::exec(in_left, in_right, output, len);
   295      case arrtype::INT8:
   296          return Impl<int8_t, Op>::exec(in_left, in_right, output, len);
   297      case arrtype::UINT16:
   298          return Impl<uint16_t, Op>::exec(in_left, in_right, output, len);
   299      case arrtype::INT16:
   300          return Impl<int16_t, Op>::exec(in_left, in_right, output, len);
   301      case arrtype::UINT32:
   302          return Impl<uint32_t, Op>::exec(in_left, in_right, output, len);
   303      case arrtype::INT32:
   304          return Impl<int32_t, Op>::exec(in_left, in_right, output, len);
   305      case arrtype::UINT64:
   306          return Impl<uint64_t, Op>::exec(in_left, in_right, output, len);
   307      case arrtype::INT64:
   308          return Impl<int64_t, Op>::exec(in_left, in_right, output, len);
   309      case arrtype::FLOAT32:
   310          return Impl<float, Op>::exec(in_left, in_right, output, len);
   311      case arrtype::FLOAT64:
   312          return Impl<double, Op>::exec(in_left, in_right, output, len);
   313      default:
   314          break;
   315      }
   316  }
   317  
   318  template <typename Op, template <typename...> typename Impl, typename Input>
   319  static inline void arithmetic_op(const int otype, const void* input, void* output, const int len) {
   320      const auto outtype = static_cast<arrtype>(otype);
   321  
   322      switch (outtype) {
   323      case arrtype::UINT8:
   324          return Impl<Input, Op, uint8_t>::exec(input, output, len);
   325      case arrtype::INT8:
   326          return Impl<Input, Op, int8_t>::exec(input, output, len);
   327      case arrtype::UINT16:
   328          return Impl<Input, Op, uint16_t>::exec(input, output, len);
   329      case arrtype::INT16:
   330          return Impl<Input, Op, int16_t>::exec(input, output, len);
   331      case arrtype::UINT32:
   332          return Impl<Input, Op, uint32_t>::exec(input, output, len);
   333      case arrtype::INT32:
   334          return Impl<Input, Op, int32_t>::exec(input, output, len);
   335      case arrtype::UINT64:
   336          return Impl<Input, Op, uint64_t>::exec(input, output, len);
   337      case arrtype::INT64:
   338          return Impl<Input, Op, int64_t>::exec(input, output, len);
   339      case arrtype::FLOAT32:
   340          return Impl<Input, Op, float>::exec(input, output, len);
   341      case arrtype::FLOAT64:
   342          return Impl<Input, Op, double>::exec(input, output, len);
   343      default:
   344          break;
   345      }
   346  }
   347  
   348  
   349  template <typename Op, template <typename...> typename Impl>
   350  static inline void arithmetic_op(const int type, const void* input, void* output, const int len) {
   351      const auto intype = static_cast<arrtype>(type);
   352  
   353      switch (intype) {
   354      case arrtype::UINT8:
   355          return Impl<uint8_t, Op>::exec(input, output, len);
   356      case arrtype::INT8:
   357          return Impl<int8_t, Op>::exec(input, output, len);
   358      case arrtype::UINT16:
   359          return Impl<uint16_t, Op>::exec(input, output, len);
   360      case arrtype::INT16:
   361          return Impl<int16_t, Op>::exec(input, output, len);
   362      case arrtype::UINT32:
   363          return Impl<uint32_t, Op>::exec(input, output, len);
   364      case arrtype::INT32:
   365          return Impl<int32_t, Op>::exec(input, output, len);
   366      case arrtype::UINT64:
   367          return Impl<uint64_t, Op>::exec(input, output, len);
   368      case arrtype::INT64:
   369          return Impl<int64_t, Op>::exec(input, output, len);
   370      case arrtype::FLOAT32:
   371          return Impl<float, Op>::exec(input, output, len);
   372      case arrtype::FLOAT64:
   373          return Impl<double, Op>::exec(input, output, len);
   374      default:
   375          break;
   376      }
   377  }
   378  
   379  template <typename Op, template <typename...> typename Impl>
   380  static inline void arithmetic_op(const int itype, const int otype, const void* input, void* output, const int len) {
   381      const auto intype = static_cast<arrtype>(itype);
   382  
   383      switch (intype) {
   384      case arrtype::UINT8:
   385          return arithmetic_op<Op, Impl, uint8_t>(otype, input, output, len);
   386      case arrtype::INT8:
   387          return arithmetic_op<Op, Impl, int8_t>(otype, input, output, len);
   388      case arrtype::UINT16:
   389          return arithmetic_op<Op, Impl, uint16_t>(otype, input, output, len);
   390      case arrtype::INT16:
   391          return arithmetic_op<Op, Impl, int16_t>(otype, input, output, len);
   392      case arrtype::UINT32:
   393          return arithmetic_op<Op, Impl, uint32_t>(otype, input, output, len);
   394      case arrtype::INT32:
   395          return arithmetic_op<Op, Impl, int32_t>(otype, input, output, len);
   396      case arrtype::UINT64:
   397          return arithmetic_op<Op, Impl, uint64_t>(otype, input, output, len);
   398      case arrtype::INT64:
   399          return arithmetic_op<Op, Impl, int64_t>(otype, input, output, len);
   400      case arrtype::FLOAT32:
   401          return arithmetic_op<Op, Impl, float>(otype, input, output, len);
   402      case arrtype::FLOAT64:
   403          return arithmetic_op<Op, Impl, double>(otype, input, output, len);
   404      default:
   405          break;
   406      }
   407  }
   408  
   409  template <template <typename...> class Impl>
   410  static inline void arithmetic_unary_impl_same_types(const int type, const int8_t op, const void* input, void* output, const int len) {
   411      const auto opt = static_cast<optype>(op);
   412  
   413      switch (opt) {
   414      case optype::ABSOLUTE_VALUE:
   415          return arithmetic_op<AbsoluteValue, Impl>(type, input, output, len);
   416      case optype::ABSOLUTE_VALUE_CHECKED:
   417          return arithmetic_op<AbsoluteValueChecked, Impl>(type, input, output, len);
   418      case optype::NEGATE:
   419          return arithmetic_op<Negate, Impl>(type, input, output, len);
   420      case optype::NEGATE_CHECKED:
   421          return arithmetic_op<NegateChecked, Impl>(type, input, output, len);
   422      case optype::SIGN:
   423          return arithmetic_op<Sign, Impl>(type, input, output, len);
   424      default:
   425          break;
   426      }
   427  }
   428  
   429  
   430  template <template <typename...> class Impl>
   431  static inline void arithmetic_unary_impl(const int itype, const int otype, const int8_t op, const void* input, void* output, const int len) {
   432      const auto opt = static_cast<optype>(op);
   433  
   434      switch (opt) {
   435      case optype::SIGN:
   436          return arithmetic_op<Sign, Impl>(itype, otype, input, output, len);
   437      default:
   438          break;
   439      }
   440  }
   441  
   442  template <template <typename...> class Impl>
   443  static inline void arithmetic_binary_impl(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) {
   444      const auto opt = static_cast<optype>(op);
   445  
   446      switch (opt) {
   447      case optype::ADD:
   448          return arithmetic_op<Add, Impl>(type, in_left, in_right, out, len);
   449      case optype::ADD_CHECKED:
   450          return arithmetic_op<AddChecked, Impl>(type, in_left, in_right, out, len);
   451      case optype::SUB:
   452          return arithmetic_op<Sub, Impl>(type, in_left, in_right, out, len);
   453      case optype::SUB_CHECKED:
   454          return arithmetic_op<SubChecked, Impl>(type, in_left, in_right, out, len);
   455      case optype::MUL:
   456          return arithmetic_op<Multiply, Impl>(type, in_left, in_right, out, len);
   457      case optype::MUL_CHECKED:
   458          return arithmetic_op<MultiplyChecked, Impl>(type, in_left, in_right, out, len);
   459      default:
   460          // don't implement divide here as we can only divide on non-null entries
   461          // so we can avoid dividing by zero
   462          break;
   463      }
   464  }
   465  
   466  extern "C" void FULL_NAME(arithmetic_binary)(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) {
   467      arithmetic_binary_impl<arithmetic_op_arr_arr_impl>(type, op, in_left, in_right, out, len);
   468  }
   469  
   470  extern "C" void FULL_NAME(arithmetic_arr_scalar)(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) {
   471      arithmetic_binary_impl<arithmetic_op_arr_scalar_impl>(type, op, in_left, in_right, out, len);
   472  }
   473  
   474  extern "C" void FULL_NAME(arithmetic_scalar_arr)(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) {
   475      arithmetic_binary_impl<arithmetic_op_scalar_arr_impl>(type, op, in_left, in_right, out, len);
   476  }
   477  
   478  extern "C" void FULL_NAME(arithmetic_unary_same_types)(const int type, const int8_t op, const void* input, void* output, const int len) {
   479      arithmetic_unary_impl_same_types<arithmetic_unary_op_impl>(type, op, input, output, len);
   480  }
   481  
   482  extern "C" void FULL_NAME(arithmetic_unary_diff_type)(const int itype, const int otype, const int8_t op, const void* input, void* output, const int len) {
   483      arithmetic_unary_impl<arithmetic_unary_op_impl>(itype, otype, op, input, output, len);
   484  }