// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <immintrin.h>
#include <stdint.h>

// c[i] += a[i] * b[0] for i in [0, n).
// Main loop uses 512-bit FMA; an AVX2 step handles 8-wide remainders,
// and a scalar loop finishes the tail. All loads/stores are unaligned.
void _mm512_mul_const_add_to(float *a, float *b, float *c, int64_t n)
{
    int64_t epoch = n / 16; // number of full 16-float vectors
    int remain = n % 16;    // leftover elements (< 16)
    for (int64_t i = 0; i < epoch; i++)
    {
        __m512 v1 = _mm512_loadu_ps(a);
        __m512 v2 = _mm512_set1_ps(*b);
        __m512 v3 = _mm512_loadu_ps(c);
        __m512 v = _mm512_fmadd_ps(v1, v2, v3);
        _mm512_storeu_ps(c, v);
        a += 16;
        c += 16;
    }
    if (remain >= 8)
    {
        __m256 v1 = _mm256_loadu_ps(a);
        __m256 v2 = _mm256_broadcast_ss(b);
        __m256 v3 = _mm256_loadu_ps(c);
        __m256 v = _mm256_add_ps(_mm256_mul_ps(v1, v2), v3);
        _mm256_storeu_ps(c, v);
        a += 8;
        c += 8;
        remain -= 8;
    }
    for (int i = 0; i < remain; i++)
    {
        c[i] += a[i] * b[0];
    }
}

// c[i] = a[i] * b[0] for i in [0, n).
void _mm512_mul_const_to(float *a, float *b, float *c, int64_t n)
{
    int64_t epoch = n / 16;
    int remain = n % 16;
    for (int64_t i = 0; i < epoch; i++)
    {
        __m512 v1 = _mm512_loadu_ps(a);
        __m512 v2 = _mm512_set1_ps(*b);
        __m512 v = _mm512_mul_ps(v1, v2);
        _mm512_storeu_ps(c, v);
        a += 16;
        c += 16;
    }
    if (remain >= 8)
    {
        __m256 v1 = _mm256_loadu_ps(a);
        __m256 v2 = _mm256_broadcast_ss(b);
        __m256 v = _mm256_mul_ps(v1, v2);
        _mm256_storeu_ps(c, v);
        a += 8;
        c += 8;
        remain -= 8;
    }
    for (int i = 0; i < remain; i++)
    {
        c[i] = a[i] * b[0];
    }
}

// In-place scale: a[i] *= b[0] for i in [0, n).
void _mm512_mul_const(float *a, float *b, int64_t n)
{
    int64_t epoch = n / 16;
    int remain = n % 16;
    for (int64_t i = 0; i < epoch; i++)
    {
        __m512 v1 = _mm512_loadu_ps(a);
        __m512 v2 = _mm512_set1_ps(*b);
        __m512 v = _mm512_mul_ps(v1, v2);
        _mm512_storeu_ps(a, v);
        a += 16;
    }
    if (remain >= 8)
    {
        __m256 v1 = _mm256_loadu_ps(a);
        __m256 v2 = _mm256_broadcast_ss(b);
        __m256 v = _mm256_mul_ps(v1, v2);
        _mm256_storeu_ps(a, v);
        a += 8;
        remain -= 8;
    }
    for (int i = 0; i < remain; i++)
    {
        a[i] *= b[0];
    }
}

// Element-wise product: c[i] = a[i] * b[i] for i in [0, n).
void _mm512_mul_to(float *a, float *b, float *c, int64_t n)
{
    int64_t epoch = n / 16;
    int remain = n % 16;
    for (int64_t i = 0; i < epoch; i++)
    {
        __m512 v1 = _mm512_loadu_ps(a);
        __m512 v2 = _mm512_loadu_ps(b);
        __m512 v = _mm512_mul_ps(v1, v2);
        _mm512_storeu_ps(c, v);
        a += 16;
        b += 16;
        c += 16;
    }
    if (remain >= 8)
    {
        __m256 v1 = _mm256_loadu_ps(a);
        __m256 v2 = _mm256_loadu_ps(b);
        __m256 v = _mm256_mul_ps(v1, v2);
        _mm256_storeu_ps(c, v);
        a += 8;
        b += 8;
        c += 8;
        remain -= 8;
    }
    for (int i = 0; i < remain; i++)
    {
        c[i] = a[i] * b[i];
    }
}

// Dot product: *ret = sum(a[i] * b[i]) for i in [0, n).
//
// BUG FIX: the accumulator `s` was previously left uninitialized and only
// written when n >= 16, yet the horizontal reduction of `s` into *ret ran
// unconditionally. For n < 16 that read an uninitialized __m512 (undefined
// behavior) and stored garbage in *ret. `s` is now zero-initialized, which
// also lets the main loop be a uniform FMA instead of a peeled first
// iteration plus a separate multiply.
void _mm512_dot(float *a, float *b, int64_t n, float *ret)
{
    int64_t epoch = n / 16;
    int remain = n % 16;
    __m512 s = _mm512_setzero_ps(); // was uninitialized when epoch == 0
    for (int64_t i = 0; i < epoch; i++)
    {
        __m512 v1 = _mm512_loadu_ps(a);
        __m512 v2 = _mm512_loadu_ps(b);
        s = _mm512_fmadd_ps(v1, v2, s);
        a += 16;
        b += 16;
    }
    // Horizontal reduction of the 16 partial sums in `s`:
    // 512 -> 256 -> 128 -> 64 -> 32 bits, pairwise adds each step.
    __m256 sf_e_d_c_b_a_9_8 = _mm512_extractf32x8_ps(s, 1);
    __m256 s7_6_5_4_3_2_1_0 = _mm512_castps512_ps256(s);
    __m256 s7f_6e_5d_4c_3b_2a_19_08 = _mm256_add_ps(sf_e_d_c_b_a_9_8, s7_6_5_4_3_2_1_0);
    __m128 s7f_6e_5d_4c = _mm256_extractf128_ps(s7f_6e_5d_4c_3b_2a_19_08, 1);
    __m128 s3b_2a_19_08 = _mm256_castps256_ps128(s7f_6e_5d_4c_3b_2a_19_08);
    __m128 s37bf_26ae_159d_048c = _mm_add_ps(s7f_6e_5d_4c, s3b_2a_19_08);
    __m128 sxx_159d_048c = s37bf_26ae_159d_048c;
    __m128 sxx_37bf_26ae = _mm_movehl_ps(sxx_159d_048c, s37bf_26ae_159d_048c);
    const __m128 sxx_13579bdf_02468ace = _mm_add_ps(sxx_159d_048c, sxx_37bf_26ae);
    const __m128 sxxx_02468ace = sxx_13579bdf_02468ace;
    const __m128 sxxx_13579bdf = _mm_shuffle_ps(sxx_13579bdf_02468ace, sxx_13579bdf_02468ace, 0x1);
    __m128 sxxx_0123456789abcdef = _mm_add_ss(sxxx_02468ace, sxxx_13579bdf);
    *ret = _mm_cvtss_f32(sxxx_0123456789abcdef);

    // 8-wide AVX2 step for the remainder, reduced the same way.
    if (remain >= 8)
    {
        __m256 t;
        __m256 v1 = _mm256_loadu_ps(a);
        __m256 v2 = _mm256_loadu_ps(b);
        t = _mm256_mul_ps(v1, v2);
        a += 8;
        b += 8;
        __m128 s7_6_5_4 = _mm256_extractf128_ps(t, 1);
        __m128 s3_2_1_0 = _mm256_castps256_ps128(t);
        __m128 s37_26_15_04 = _mm_add_ps(s7_6_5_4, s3_2_1_0);
        __m128 sxx_15_04 = s37_26_15_04;
        __m128 sxx_37_26 = _mm_movehl_ps(s37_26_15_04, s37_26_15_04);
        const __m128 sxx_1357_0246 = _mm_add_ps(sxx_15_04, sxx_37_26);
        const __m128 sxxx_0246 = sxx_1357_0246;
        const __m128 sxxx_1357 = _mm_shuffle_ps(sxx_1357_0246, sxx_1357_0246, 0x1);
        __m128 sxxx_01234567 = _mm_add_ss(sxxx_0246, sxxx_1357);
        *ret += _mm_cvtss_f32(sxxx_01234567);
        remain -= 8;
    }

    // Scalar tail (< 8 elements).
    for (int i = 0; i < remain; i++)
    {
        *ret += a[i] * b[i];
    }
}