gitee.com/quant1x/num@v0.3.2/asm/src/floats_avx.c (about) 1 // Copyright 2022 gorse Project Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include <immintrin.h> 16 #include <stdint.h> 17 18 void _mm256_mul_const_add_to(float *a, float *b, float *c, int64_t n) 19 { 20 int epoch = n / 8; 21 int remain = n % 8; 22 for (int i = 0; i < epoch; i++) 23 { 24 __m256 v1 = _mm256_loadu_ps(a); 25 __m256 v2 = _mm256_broadcast_ss(b); 26 __m256 v3 = _mm256_loadu_ps(c); 27 __m256 v = _mm256_add_ps(_mm256_mul_ps(v1, v2), v3); 28 _mm256_storeu_ps(c, v); 29 a += 8; 30 c += 8; 31 } 32 for (int i = 0; i < remain; i++) 33 { 34 c[i] += a[i] * b[0]; 35 } 36 } 37 38 void _mm256_mul_const_to(float *a, float *b, float *c, int64_t n) 39 { 40 int epoch = n / 8; 41 int remain = n % 8; 42 for (int i = 0; i < epoch; i++) 43 { 44 __m256 v1 = _mm256_loadu_ps(a); 45 __m256 v2 = _mm256_broadcast_ss(b); 46 __m256 v = _mm256_mul_ps(v1, v2); 47 _mm256_storeu_ps(c, v); 48 a += 8; 49 c += 8; 50 } 51 for (int i = 0; i < remain; i++) 52 { 53 c[i] = a[i] * b[0]; 54 } 55 } 56 57 void _mm256_mul_const(float *a, float *b, int64_t n) 58 { 59 int epoch = n / 8; 60 int remain = n % 8; 61 for (int i = 0; i < epoch; i++) 62 { 63 __m256 v1 = _mm256_loadu_ps(a); 64 __m256 v2 = _mm256_broadcast_ss(b); 65 __m256 v = _mm256_mul_ps(v1, v2); 66 _mm256_storeu_ps(a, v); 67 a += 8; 68 } 69 for (int i = 0; i < remain; i++) 70 { 71 a[i] *= b[0]; 72 } 73 } 74 75 void _mm256_mul_to(float *a, float *b, float *c, int64_t n) 76 { 77 int epoch = n / 8; 78 int remain = n % 8; 79 for (int i = 0; i < epoch; i++) 80 { 81 __m256 v1 = _mm256_loadu_ps(a); 82 __m256 v2 = _mm256_loadu_ps(b); 83 __m256 v = _mm256_mul_ps(v1, v2); 84 _mm256_storeu_ps(c, v); 85 a += 8; 86 b += 8; 87 c += 8; 88 } 89 for (int i = 0; i < remain; i++) 90 { 91 c[i] = a[i] * b[i]; 92 } 93 } 94 95 void _mm256_dot(float *a, float *b, int64_t n, float *ret) 96 { 97 int epoch = n / 8; 98 int remain = n % 8; 99 __m256 s; 100 if (epoch > 0) 101 { 102 __m256 v1 = _mm256_loadu_ps(a); 103 __m256 v2 = _mm256_loadu_ps(b); 104 s = _mm256_mul_ps(v1, v2); 105 a += 8; 106 b += 8; 107 } 108 for (int i = 1; i < epoch; i++) 109 { 110 __m256 v1 = _mm256_loadu_ps(a); 111 __m256 v2 = _mm256_loadu_ps(b); 112 s = _mm256_add_ps(_mm256_mul_ps(v1, v2), s); 113 a += 8; 114 b += 8; 115 } 116 __m128 s7_6_5_4 = _mm256_extractf128_ps(s, 1); 117 __m128 s3_2_1_0 = _mm256_castps256_ps128(s); 118 __m128 s37_26_15_04 = _mm_add_ps(s7_6_5_4, s3_2_1_0); 119 __m128 sxx_15_04 = s37_26_15_04; 120 __m128 sxx_37_26 = _mm_movehl_ps(s37_26_15_04, s37_26_15_04); 121 const __m128 sxx_1357_0246 = _mm_add_ps(sxx_15_04, sxx_37_26); 122 const __m128 sxxx_0246 = sxx_1357_0246; 123 const __m128 sxxx_1357 = _mm_shuffle_ps(sxx_1357_0246, sxx_1357_0246, 0x1); 124 __m128 sxxx_01234567 = _mm_add_ss(sxxx_0246, sxxx_1357); 125 *ret = _mm_cvtss_f32(sxxx_01234567); 126 for (int i = 0; i < remain; i++) 127 { 128 *ret += a[i] * b[i]; 129 } 130 }