CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
hrydgard

CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!

GitHub Repository: hrydgard/ppsspp
Path: blob/master/Common/Math/fast/fast_matrix.c
Views: 1401
1
#include "ppsspp_config.h"
2
3
#include "Common/Math/CrossSIMD.h"
4
5
#include "fast_matrix.h"
6
7
#if PPSSPP_ARCH(X86) || PPSSPP_ARCH(AMD64)
8
9
#include <emmintrin.h>
10
11
#include "fast_matrix.h"
12
13
/*
 * SSE implementation of dest = a * b for 4x4 float matrices.
 *
 * Each matrix is 16 consecutive floats viewed as four groups of four
 * (x[0..3], x[4..7], x[8..11], x[12..15]). Output group c is
 *
 *     dest[4c..4c+3] = a[0..3]*b[4c] + a[4..7]*b[4c+1]
 *                    + a[8..11]*b[4c+2] + a[12..15]*b[4c+3]
 *
 * accumulated left to right.
 *
 * All of `a` is loaded before anything is stored. Each group of `b` is read
 * before the matching group of `dest` is written. So dest may be the same
 * array as a or b.
 *
 * Loads and stores are unaligned (loadu/storeu), so no 16-byte alignment is
 * required of any pointer.
 */
void fast_matrix_mul_4x4_sse(float *dest, const float *a, const float *b) {
	const __m128 g0 = _mm_loadu_ps(a + 0);
	const __m128 g1 = _mm_loadu_ps(a + 4);
	const __m128 g2 = _mm_loadu_ps(a + 8);
	const __m128 g3 = _mm_loadu_ps(a + 12);

	for (int col = 0; col < 4; ++col) {
		const float *bc = b + 4 * col;
		/* Broadcast each scalar of this b group and scale the matching a group. */
		__m128 acc = _mm_mul_ps(g0, _mm_set1_ps(bc[0]));
		acc = _mm_add_ps(acc, _mm_mul_ps(g1, _mm_set1_ps(bc[1])));
		acc = _mm_add_ps(acc, _mm_mul_ps(g2, _mm_set1_ps(bc[2])));
		acc = _mm_add_ps(acc, _mm_mul_ps(g3, _mm_set1_ps(bc[3])));
		_mm_storeu_ps(dest + 4 * col, acc);
	}
}
28
29
#elif PPSSPP_ARCH(ARM_NEON)
30
31
#if defined(_MSC_VER) && PPSSPP_ARCH(ARM64)
32
#include <arm64_neon.h>
33
#else
34
#include <arm_neon.h>
35
#endif
36
37
#if PPSSPP_ARCH(ARM)
38
/*
 * 32-bit ARM stand-in for the AArch64 intrinsic vfmaq_laneq_f32.
 * It returns _s + _a * _b[lane].
 *
 * vmlaq_lane_f32 only indexes a 2-lane vector and needs a compile-time
 * constant lane. So _b is split into its low and high halves, and each lane
 * is spelled out with a literal index.
 *
 * vmlaq is a non-fused multiply-accumulate, so low-order bits may differ
 * from a true fused multiply-add.
 *
 * A lane outside 0..3 yields an all-zero vector (not _s).
 */
static inline float32x4_t vfmaq_laneq_f32(float32x4_t _s, float32x4_t _a, float32x4_t _b, int lane) {
	const float32x2_t lo = vget_low_f32(_b);
	const float32x2_t hi = vget_high_f32(_b);
	switch (lane) {
	case 0:
		return vmlaq_lane_f32(_s, _a, lo, 0);
	case 1:
		return vmlaq_lane_f32(_s, _a, lo, 1);
	case 2:
		return vmlaq_lane_f32(_s, _a, hi, 0);
	case 3:
		return vmlaq_lane_f32(_s, _a, hi, 1);
	default:
		return vdupq_n_f32(0.f);
	}
}
45
#endif
46
47
// From https://developer.arm.com/documentation/102467/0100/Matrix-multiplication-example
48
/*
 * NEON implementation of C = A * B for 4x4 float matrices.
 *
 * Each matrix is 16 consecutive floats viewed as four groups of four
 * ("columns"). Output column c is
 *
 *     C[4c..4c+3] = A[0..3]*B[4c] + A[4..7]*B[4c+1]
 *                 + A[8..11]*B[4c+2] + A[12..15]*B[4c+3]
 *
 * One multiply starts each column, then three lane-broadcast multiply-adds
 * finish it.
 *
 * All of A is loaded up front. Each column of B is loaded before the
 * matching column of C is stored. So C may be the same array as A or B.
 *
 * On AArch64, vfmaq_laneq_f32 is a fused multiply-add, so low-order bits may
 * differ from the SSE/C paths.
 */
void fast_matrix_mul_4x4_neon(float *C, const float *A, const float *B) {
	const float32x4_t a0 = vld1q_f32(A + 0);
	const float32x4_t a1 = vld1q_f32(A + 4);
	const float32x4_t a2 = vld1q_f32(A + 8);
	const float32x4_t a3 = vld1q_f32(A + 12);

	for (int off = 0; off < 16; off += 4) {
		const float32x4_t bcol = vld1q_f32(B + off);
		/* Lane indices must be literals for the laneq intrinsics. */
		float32x4_t acc = vmulq_laneq_f32(a0, bcol, 0);
		acc = vfmaq_laneq_f32(acc, a1, bcol, 1);
		acc = vfmaq_laneq_f32(acc, a2, bcol, 2);
		acc = vfmaq_laneq_f32(acc, a3, bcol, 3);
		vst1q_f32(C + off, acc);
	}
}
101
102
#else
103
104
/*
 * Portable fallback for dest = a * b on 4x4 float matrices.
 *
 * Each matrix is 16 consecutive floats viewed as four groups of four
 * (x[0..3], x[4..7], x[8..11], x[12..15]). For group c (c = 0, 4, 8, 12)
 * and row r (0..3):
 *
 *     dest[c + r] = b[c]*a[r] + b[c+1]*a[4+r] + b[c+2]*a[8+r] + b[c+3]*a[12+r]
 *
 * summed left to right. This is the same per-element formula as the SSE and
 * NEON versions in this file.
 *
 * Like those versions, dest may be the same array as a and/or b.
 *  - `a` is copied locally before any output is written, because every
 *    output group reads all of a.
 *  - Each group of `b` is read into locals before the matching group of
 *    `dest` is stored.
 * A straight element-by-element unroll that writes dest[0] first would then
 * read already-overwritten inputs when called in place.
 */
void fast_matrix_mul_4x4_c(float *dest, const float *a, const float *b) {
	float acopy[16];
	for (int i = 0; i < 16; i++) {
		acopy[i] = a[i];
	}

	for (int c = 0; c < 16; c += 4) {
		/* Read this b group before dest[c..c+3] (possibly the same memory) is written. */
		const float b0 = b[c + 0];
		const float b1 = b[c + 1];
		const float b2 = b[c + 2];
		const float b3 = b[c + 3];
		for (int r = 0; r < 4; r++) {
			dest[c + r] = b0 * acopy[r] + b1 * acopy[4 + r] + b2 * acopy[8 + r] + b3 * acopy[12 + r];
		}
	}
}
142
143
#endif
144
145