Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
hrydgard
GitHub Repository: hrydgard/ppsspp
Path: blob/master/Common/Math/fast/fast_matrix.h
5693 views
1
#pragma once
2
3
#include "ppsspp_config.h"
4
5
#include "Common/Math/SIMDHeaders.h"
6
#include "Common/Common.h"
7
8
#if PPSSPP_ARCH(SSE2)
9
10
// SSE2 path: dest = a * b for 4x4 float matrices, each stored as four
// consecutive columns of four floats (a[4*k .. 4*k+3] is column k of a).
// Column j of dest is the sum of a's columns, each scaled by the matching
// entry of column j of b. All of a is held in registers before anything is
// stored, and each column of b is read before the matching column of dest is
// written, so dest may be the same pointer as a and/or b.
inline void fast_matrix_mul_4x4(float *dest, const float *a, const float *b) {
	const __m128 col0 = _mm_loadu_ps(a + 0);
	const __m128 col1 = _mm_loadu_ps(a + 4);
	const __m128 col2 = _mm_loadu_ps(a + 8);
	const __m128 col3 = _mm_loadu_ps(a + 12);

	for (int col = 0; col < 4; ++col) {
		const float *weights = b + 4 * col;
		__m128 sum = _mm_mul_ps(col0, _mm_set1_ps(weights[0]));
		sum = _mm_add_ps(sum, _mm_mul_ps(col1, _mm_set1_ps(weights[1])));
		sum = _mm_add_ps(sum, _mm_mul_ps(col2, _mm_set1_ps(weights[2])));
		sum = _mm_add_ps(sum, _mm_mul_ps(col3, _mm_set1_ps(weights[3])));
		_mm_storeu_ps(dest + 4 * col, sum);
	}
}
25
26
#elif PPSSPP_ARCH(LOONGARCH64_LSX)
27
28
// Broadcasts val into all four float lanes of an LSX vector. The float's bit
// pattern is reinterpreted as an int32 through a union and replicated with the
// integer broadcast __lsx_vreplgr2vr_w; the result is then viewed as __m128.
// NOTE(review): identifiers starting with "__" are reserved for the
// implementation; confirm the toolchain's lsxintrin.h does not already provide
// an intrinsic with this name.
inline __m128 __lsx_vreplfr2vr_s(float val) {
	union {
		float as_float;
		int32_t as_bits;
	} pun;
	pun.as_float = val;
	return (__m128)__lsx_vreplgr2vr_w(pun.as_bits);
}
36
37
inline void fast_matrix_mul_4x4(float *dest, const float *a, const float *b) {
38
__m128 a_col_1 = (__m128)__lsx_vld(a, 0);
39
__m128 a_col_2 = (__m128)__lsx_vld(a + 4, 0);
40
__m128 a_col_3 = (__m128)__lsx_vld(a + 8, 0);
41
__m128 a_col_4 = (__m128)__lsx_vld(a + 12, 0);
42
43
for (int i = 0; i < 16; i += 4) {
44
45
__m128 b1 = __lsx_vreplfr2vr_s(b[i]);
46
__m128 b2 = __lsx_vreplfr2vr_s(b[i + 1]);
47
__m128 b3 = __lsx_vreplfr2vr_s(b[i + 2]);
48
__m128 b4 = __lsx_vreplfr2vr_s(b[i + 3]);
49
50
__m128 result = __lsx_vfmul_s(a_col_1, b1);
51
result = __lsx_vfmadd_s(a_col_2, b2, result);
52
result = __lsx_vfmadd_s(a_col_3, b3, result);
53
result = __lsx_vfmadd_s(a_col_4, b4, result);
54
55
__lsx_vst(result, &dest[i], 0);
56
}
57
}
58
59
#elif PPSSPP_ARCH(ARM_NEON)
60
61
#ifdef B0
62
#undef B0
63
#endif
64
65
// From https://developer.arm.com/documentation/102467/0100/Matrix-multiplication-example
66
inline void fast_matrix_mul_4x4(float *C, const float *A, const float *B) {
67
// these are the columns A
68
float32x4_t A0;
69
float32x4_t A1;
70
float32x4_t A2;
71
float32x4_t A3;
72
73
// these are the columns B
74
float32x4_t B0;
75
float32x4_t B1;
76
float32x4_t B2;
77
float32x4_t B3;
78
79
// these are the columns C
80
float32x4_t C0;
81
float32x4_t C1;
82
float32x4_t C2;
83
float32x4_t C3;
84
85
A0 = vld1q_f32(A);
86
A1 = vld1q_f32(A + 4);
87
A2 = vld1q_f32(A + 8);
88
A3 = vld1q_f32(A + 12);
89
90
// Multiply accumulate in 4x1 blocks, i.e. each column in C
91
B0 = vld1q_f32(B);
92
B1 = vld1q_f32(B + 4);
93
B2 = vld1q_f32(B + 8);
94
B3 = vld1q_f32(B + 12);
95
96
C0 = vmulq_laneq_f32(A0, B0, 0);
97
C0 = vfmaq_laneq_f32(C0, A1, B0, 1);
98
C0 = vfmaq_laneq_f32(C0, A2, B0, 2);
99
C0 = vfmaq_laneq_f32(C0, A3, B0, 3);
100
vst1q_f32(C, C0);
101
102
C1 = vmulq_laneq_f32(A0, B1, 0);
103
C1 = vfmaq_laneq_f32(C1, A1, B1, 1);
104
C1 = vfmaq_laneq_f32(C1, A2, B1, 2);
105
C1 = vfmaq_laneq_f32(C1, A3, B1, 3);
106
vst1q_f32(C + 4, C1);
107
108
C2 = vmulq_laneq_f32(A0, B2, 0);
109
C2 = vfmaq_laneq_f32(C2, A1, B2, 1);
110
C2 = vfmaq_laneq_f32(C2, A2, B2, 2);
111
C2 = vfmaq_laneq_f32(C2, A3, B2, 3);
112
vst1q_f32(C + 8, C2);
113
114
C3 = vmulq_laneq_f32(A0, B3, 0);
115
C3 = vfmaq_laneq_f32(C3, A1, B3, 1);
116
C3 = vfmaq_laneq_f32(C3, A2, B3, 2);
117
C3 = vfmaq_laneq_f32(C3, A3, B3, 3);
118
vst1q_f32(C + 12, C3);
119
}
120
121
#else
122
123
// Portable scalar fallback used when none of the SIMD paths apply.
//
// Computes dest = a * b for 4x4 float matrices, each stored as four consecutive
// columns of four floats (a[4*k .. 4*k+3] is column k of a; likewise for b and
// dest):
//
//   dest[4*j + r] = b[4*j + 0] * a[r]     + b[4*j + 1] * a[4 + r]
//                 + b[4*j + 2] * a[8 + r] + b[4*j + 3] * a[12 + r]
//
// The parameters are intentionally NOT restrict-qualified. dest may be the same
// pointer as a and/or b (in-place multiply), which the SIMD variants also
// support. To make that safe, all of a is copied up front, and each column of b
// is read in full before the matching column of dest is written. Partially
// overlapping buffers are not supported.
inline void fast_matrix_mul_4x4(float *dest, const float *a, const float *b) {
	float a_copy[16];
	for (int k = 0; k < 16; k++)
		a_copy[k] = a[k];

	for (int j = 0; j < 16; j += 4) {
		const float b0 = b[j];
		const float b1 = b[j + 1];
		const float b2 = b[j + 2];
		const float b3 = b[j + 3];

		// Finish the whole column before storing it, in case dest aliases b.
		// Summation order matches the previous explicit expressions.
		float col[4];
		for (int r = 0; r < 4; r++)
			col[r] = b0 * a_copy[r] + b1 * a_copy[4 + r] + b2 * a_copy[8 + r] + b3 * a_copy[12 + r];
		for (int r = 0; r < 4; r++)
			dest[j + r] = col[r];
	}
}
144
145
#endif
146
147