Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
godotengine
GitHub Repository: godotengine/godot
Path: blob/master/thirdparty/basis_universal/encoder/basisu_math.h
9903 views
1
// File: basisu_math.h
2
#pragma once
3
4
// TODO: Would prefer this in the basisu namespace, but to avoid collisions with the existing vec/matrix classes I'm placing this in "bu_math".
5
namespace bu_math
6
{
7
// Cross-platform 1.0f/sqrtf(x) approximation. See https://en.wikipedia.org/wiki/Fast_inverse_square_root#cite_note-37.
8
// Would prefer using SSE1 etc. but that would require implementing multiple versions and platform divergence (needing more testing).
9
BASISU_FORCE_INLINE float inv_sqrt(float v)
10
{
11
union
12
{
13
float flt;
14
uint32_t ui;
15
} un;
16
17
un.flt = v;
18
un.ui = 0x5F1FFFF9UL - (un.ui >> 1);
19
20
return 0.703952253f * un.flt * (2.38924456f - v * (un.flt * un.flt));
21
}
22
23
inline float smoothstep(float edge0, float edge1, float x)
24
{
25
assert(edge1 != edge0);
26
27
// Scale, and clamp x to 0..1 range
28
x = basisu::saturate((x - edge0) / (edge1 - edge0));
29
30
return x * x * (3.0f - 2.0f * x);
31
}
32
33
template <uint32_t N, typename T>
34
class vec : public basisu::rel_ops<vec<N, T> >
35
{
36
public:
37
typedef T scalar_type;
38
enum
39
{
40
num_elements = N
41
};
42
43
inline vec()
44
{
45
}
46
47
inline vec(basisu::eClear)
48
{
49
clear();
50
}
51
52
inline vec(const vec& other)
53
{
54
for (uint32_t i = 0; i < N; i++)
55
m_s[i] = other.m_s[i];
56
}
57
58
template <uint32_t O, typename U>
59
inline vec(const vec<O, U>& other)
60
{
61
set(other);
62
}
63
64
template <uint32_t O, typename U>
65
inline vec(const vec<O, U>& other, T w)
66
{
67
*this = other;
68
m_s[N - 1] = w;
69
}
70
71
template <typename... Args>
72
inline explicit vec(Args... args)
73
{
74
static_assert(sizeof...(args) <= N);
75
set(args...);
76
}
77
78
inline void clear()
79
{
80
if (N > 4)
81
memset(m_s, 0, sizeof(m_s));
82
else
83
{
84
for (uint32_t i = 0; i < N; i++)
85
m_s[i] = 0;
86
}
87
}
88
89
template <uint32_t ON, typename OT>
90
inline vec& set(const vec<ON, OT>& other)
91
{
92
if ((void*)this == (void*)&other)
93
return *this;
94
const uint32_t m = basisu::minimum(N, ON);
95
uint32_t i;
96
for (i = 0; i < m; i++)
97
m_s[i] = static_cast<T>(other[i]);
98
for (; i < N; i++)
99
m_s[i] = 0;
100
return *this;
101
}
102
103
inline vec& set_component(uint32_t index, T val)
104
{
105
assert(index < N);
106
m_s[index] = val;
107
return *this;
108
}
109
110
inline vec& set_all(T val)
111
{
112
for (uint32_t i = 0; i < N; i++)
113
m_s[i] = val;
114
return *this;
115
}
116
117
template <typename... Args>
118
inline vec& set(Args... args)
119
{
120
static_assert(sizeof...(args) <= N);
121
122
// Initialize using parameter pack expansion
123
T values[] = { static_cast<T>(args)... };
124
125
// Special case if setting with a scalar
126
if (sizeof...(args) == 1)
127
{
128
set_all(values[0]);
129
}
130
else
131
{
132
// Copy the values into the vector
133
for (std::size_t i = 0; i < sizeof...(args); ++i)
134
{
135
m_s[i] = values[i];
136
}
137
138
// Zero-initialize the remaining elements (if any)
139
if (sizeof...(args) < N)
140
{
141
std::fill(m_s + sizeof...(args), m_s + N, T{});
142
}
143
}
144
145
return *this;
146
}
147
148
inline vec& set(const T* pValues)
149
{
150
for (uint32_t i = 0; i < N; i++)
151
m_s[i] = pValues[i];
152
return *this;
153
}
154
155
template <uint32_t ON, typename OT>
156
inline vec& swizzle_set(const vec<ON, OT>& other, uint32_t i)
157
{
158
return set(static_cast<T>(other[i]));
159
}
160
161
template <uint32_t ON, typename OT>
162
inline vec& swizzle_set(const vec<ON, OT>& other, uint32_t i, uint32_t j)
163
{
164
return set(static_cast<T>(other[i]), static_cast<T>(other[j]));
165
}
166
167
template <uint32_t ON, typename OT>
168
inline vec& swizzle_set(const vec<ON, OT>& other, uint32_t i, uint32_t j, uint32_t k)
169
{
170
return set(static_cast<T>(other[i]), static_cast<T>(other[j]), static_cast<T>(other[k]));
171
}
172
173
template <uint32_t ON, typename OT>
174
inline vec& swizzle_set(const vec<ON, OT>& other, uint32_t i, uint32_t j, uint32_t k, uint32_t l)
175
{
176
return set(static_cast<T>(other[i]), static_cast<T>(other[j]), static_cast<T>(other[k]), static_cast<T>(other[l]));
177
}
178
179
inline vec& operator=(const vec& rhs)
180
{
181
if (this != &rhs)
182
{
183
for (uint32_t i = 0; i < N; i++)
184
m_s[i] = rhs.m_s[i];
185
}
186
return *this;
187
}
188
189
template <uint32_t O, typename U>
190
inline vec& operator=(const vec<O, U>& other)
191
{
192
if ((void*)this == (void*)&other)
193
return *this;
194
195
uint32_t s = basisu::minimum(N, O);
196
197
uint32_t i;
198
for (i = 0; i < s; i++)
199
m_s[i] = static_cast<T>(other[i]);
200
201
for (; i < N; i++)
202
m_s[i] = 0;
203
204
return *this;
205
}
206
207
inline bool operator==(const vec& rhs) const
208
{
209
for (uint32_t i = 0; i < N; i++)
210
if (!(m_s[i] == rhs.m_s[i]))
211
return false;
212
return true;
213
}
214
215
inline bool operator<(const vec& rhs) const
216
{
217
for (uint32_t i = 0; i < N; i++)
218
{
219
if (m_s[i] < rhs.m_s[i])
220
return true;
221
else if (!(m_s[i] == rhs.m_s[i]))
222
return false;
223
}
224
225
return false;
226
}
227
228
inline T operator[](uint32_t i) const
229
{
230
assert(i < N);
231
return m_s[i];
232
}
233
234
inline T& operator[](uint32_t i)
235
{
236
assert(i < N);
237
return m_s[i];
238
}
239
240
template <uint32_t index>
241
inline uint64_t get_component_bits_as_uint() const
242
{
243
static_assert(index < N);
244
static_assert((sizeof(T) == sizeof(uint16_t)) || (sizeof(T) == sizeof(uint32_t)) || (sizeof(T) == sizeof(uint64_t)), "Unsupported type");
245
246
if (sizeof(T) == sizeof(uint16_t))
247
return *reinterpret_cast<const uint16_t*>(&m_s[index]);
248
else if (sizeof(T) == sizeof(uint32_t))
249
return *reinterpret_cast<const uint32_t*>(&m_s[index]);
250
else if (sizeof(T) == sizeof(uint64_t))
251
return *reinterpret_cast<const uint64_t*>(&m_s[index]);
252
else
253
{
254
assert(0);
255
return 0;
256
}
257
}
258
259
inline T get_x(void) const
260
{
261
return m_s[0];
262
}
263
inline T get_y(void) const
264
{
265
static_assert(N >= 2);
266
return m_s[1];
267
}
268
inline T get_z(void) const
269
{
270
static_assert(N >= 3);
271
return m_s[2];
272
}
273
inline T get_w(void) const
274
{
275
static_assert(N >= 4);
276
return m_s[3];
277
}
278
279
inline vec get_x_vector() const
280
{
281
return broadcast<0>();
282
}
283
inline vec get_y_vector() const
284
{
285
return broadcast<1>();
286
}
287
inline vec get_z_vector() const
288
{
289
return broadcast<2>();
290
}
291
inline vec get_w_vector() const
292
{
293
return broadcast<3>();
294
}
295
296
inline T get_component(uint32_t i) const
297
{
298
return (*this)[i];
299
}
300
301
inline vec& set_x(T v)
302
{
303
m_s[0] = v;
304
return *this;
305
}
306
inline vec& set_y(T v)
307
{
308
static_assert(N >= 2);
309
m_s[1] = v;
310
return *this;
311
}
312
inline vec& set_z(T v)
313
{
314
static_assert(N >= 3);
315
m_s[2] = v;
316
return *this;
317
}
318
inline vec& set_w(T v)
319
{
320
static_assert(N >= 4);
321
m_s[3] = v;
322
return *this;
323
}
324
325
inline const T* get_ptr() const
326
{
327
return reinterpret_cast<const T*>(&m_s[0]);
328
}
329
inline T* get_ptr()
330
{
331
return reinterpret_cast<T*>(&m_s[0]);
332
}
333
334
inline vec as_point() const
335
{
336
vec result(*this);
337
result[N - 1] = 1;
338
return result;
339
}
340
341
inline vec as_dir() const
342
{
343
vec result(*this);
344
result[N - 1] = 0;
345
return result;
346
}
347
348
inline vec<2, T> select2(uint32_t i, uint32_t j) const
349
{
350
assert((i < N) && (j < N));
351
return vec<2, T>(m_s[i], m_s[j]);
352
}
353
354
inline vec<3, T> select3(uint32_t i, uint32_t j, uint32_t k) const
355
{
356
assert((i < N) && (j < N) && (k < N));
357
return vec<3, T>(m_s[i], m_s[j], m_s[k]);
358
}
359
360
inline vec<4, T> select4(uint32_t i, uint32_t j, uint32_t k, uint32_t l) const
361
{
362
assert((i < N) && (j < N) && (k < N) && (l < N));
363
return vec<4, T>(m_s[i], m_s[j], m_s[k], m_s[l]);
364
}
365
366
inline bool is_dir() const
367
{
368
return m_s[N - 1] == 0;
369
}
370
inline bool is_vector() const
371
{
372
return is_dir();
373
}
374
inline bool is_point() const
375
{
376
return m_s[N - 1] == 1;
377
}
378
379
inline vec project() const
380
{
381
vec result(*this);
382
if (result[N - 1])
383
result /= result[N - 1];
384
return result;
385
}
386
387
inline vec broadcast(unsigned i) const
388
{
389
return vec((*this)[i]);
390
}
391
392
template <uint32_t i>
393
inline vec broadcast() const
394
{
395
return vec((*this)[i]);
396
}
397
398
inline vec swizzle(uint32_t i, uint32_t j) const
399
{
400
return vec((*this)[i], (*this)[j]);
401
}
402
403
inline vec swizzle(uint32_t i, uint32_t j, uint32_t k) const
404
{
405
return vec((*this)[i], (*this)[j], (*this)[k]);
406
}
407
408
inline vec swizzle(uint32_t i, uint32_t j, uint32_t k, uint32_t l) const
409
{
410
return vec((*this)[i], (*this)[j], (*this)[k], (*this)[l]);
411
}
412
413
inline vec operator-() const
414
{
415
vec result;
416
for (uint32_t i = 0; i < N; i++)
417
result.m_s[i] = -m_s[i];
418
return result;
419
}
420
421
inline vec operator+() const
422
{
423
return *this;
424
}
425
426
inline vec& operator+=(const vec& other)
427
{
428
for (uint32_t i = 0; i < N; i++)
429
m_s[i] += other.m_s[i];
430
return *this;
431
}
432
433
inline vec& operator-=(const vec& other)
434
{
435
for (uint32_t i = 0; i < N; i++)
436
m_s[i] -= other.m_s[i];
437
return *this;
438
}
439
440
inline vec& operator*=(const vec& other)
441
{
442
for (uint32_t i = 0; i < N; i++)
443
m_s[i] *= other.m_s[i];
444
return *this;
445
}
446
447
inline vec& operator/=(const vec& other)
448
{
449
for (uint32_t i = 0; i < N; i++)
450
m_s[i] /= other.m_s[i];
451
return *this;
452
}
453
454
inline vec& operator*=(T s)
455
{
456
for (uint32_t i = 0; i < N; i++)
457
m_s[i] *= s;
458
return *this;
459
}
460
461
inline vec& operator/=(T s)
462
{
463
for (uint32_t i = 0; i < N; i++)
464
m_s[i] /= s;
465
return *this;
466
}
467
468
friend inline vec operator*(const vec& lhs, T val)
469
{
470
vec result;
471
for (uint32_t i = 0; i < N; i++)
472
result.m_s[i] = lhs.m_s[i] * val;
473
return result;
474
}
475
476
friend inline vec operator*(T val, const vec& rhs)
477
{
478
vec result;
479
for (uint32_t i = 0; i < N; i++)
480
result.m_s[i] = val * rhs.m_s[i];
481
return result;
482
}
483
484
friend inline vec operator/(const vec& lhs, const vec& rhs)
485
{
486
vec result;
487
for (uint32_t i = 0; i < N; i++)
488
result.m_s[i] = lhs.m_s[i] / rhs.m_s[i];
489
return result;
490
}
491
492
friend inline vec operator/(const vec& lhs, T val)
493
{
494
vec result;
495
for (uint32_t i = 0; i < N; i++)
496
result.m_s[i] = lhs.m_s[i] / val;
497
return result;
498
}
499
500
friend inline vec operator+(const vec& lhs, const vec& rhs)
501
{
502
vec result;
503
for (uint32_t i = 0; i < N; i++)
504
result.m_s[i] = lhs.m_s[i] + rhs.m_s[i];
505
return result;
506
}
507
508
friend inline vec operator-(const vec& lhs, const vec& rhs)
509
{
510
vec result;
511
for (uint32_t i = 0; i < N; i++)
512
result.m_s[i] = lhs.m_s[i] - rhs.m_s[i];
513
return result;
514
}
515
516
static inline vec<3, T> cross2(const vec& a, const vec& b)
517
{
518
static_assert(N >= 2);
519
return vec<3, T>(0, 0, a[0] * b[1] - a[1] * b[0]);
520
}
521
522
inline vec<3, T> cross2(const vec& b) const
523
{
524
return cross2(*this, b);
525
}
526
527
static inline vec<3, T> cross3(const vec& a, const vec& b)
528
{
529
static_assert(N >= 3);
530
return vec<3, T>(a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]);
531
}
532
533
inline vec<3, T> cross3(const vec& b) const
534
{
535
return cross3(*this, b);
536
}
537
538
// Cross product dispatcher: 2-component vectors go through cross2(), all larger
// vectors through cross3().
// The branch is chosen at compile time so that a 2-component vector never instantiates
// cross3(), whose static_assert(N >= 3) would otherwise make this function uncompilable.
static inline vec<3, T> cross(const vec& a, const vec& b)
{
	static_assert(N >= 2);

	if constexpr (N == 2)
		return cross2(a, b);
	else
		return cross3(a, b);
}
547
548
inline vec<3, T> cross(const vec& b) const
549
{
550
static_assert(N >= 2);
551
return cross(*this, b);
552
}
553
554
inline T dot(const vec& rhs) const
555
{
556
return dot(*this, rhs);
557
}
558
559
inline vec dot_vector(const vec& rhs) const
560
{
561
return vec(dot(*this, rhs));
562
}
563
564
// Dot product: the sum of the component-wise products of lhs and rhs.
static inline T dot(const vec& lhs, const vec& rhs)
{
	T acc = lhs.m_s[0] * rhs.m_s[0];

	uint32_t c = 1;
	while (c < N)
	{
		acc += lhs.m_s[c] * rhs.m_s[c];
		++c;
	}

	return acc;
}
571
572
inline T dot2(const vec& rhs) const
573
{
574
static_assert(N >= 2);
575
return m_s[0] * rhs.m_s[0] + m_s[1] * rhs.m_s[1];
576
}
577
578
inline T dot3(const vec& rhs) const
579
{
580
static_assert(N >= 3);
581
return m_s[0] * rhs.m_s[0] + m_s[1] * rhs.m_s[1] + m_s[2] * rhs.m_s[2];
582
}
583
584
inline T dot4(const vec& rhs) const
585
{
586
static_assert(N >= 4);
587
return m_s[0] * rhs.m_s[0] + m_s[1] * rhs.m_s[1] + m_s[2] * rhs.m_s[2] + m_s[3] * rhs.m_s[3];
588
}
589
590
inline T norm(void) const
591
{
592
T sum = m_s[0] * m_s[0];
593
for (uint32_t i = 1; i < N; i++)
594
sum += m_s[i] * m_s[i];
595
return sum;
596
}
597
598
inline T length(void) const
599
{
600
return sqrt(norm());
601
}
602
603
inline T squared_distance(const vec& rhs) const
604
{
605
T dist2 = 0;
606
for (uint32_t i = 0; i < N; i++)
607
{
608
T d = m_s[i] - rhs.m_s[i];
609
dist2 += d * d;
610
}
611
return dist2;
612
}
613
614
// Accumulates the squared component differences between this vector and rhs,
// stopping as soon as the running sum exceeds early_out. After an early exit the
// returned value is the partial sum, not the full squared distance.
inline T squared_distance(const vec& rhs, T early_out) const
{
	T accum = 0;

	uint32_t c = 0;
	while (c < N)
	{
		const T delta = m_s[c] - rhs.m_s[c];
		accum += delta * delta;

		if (accum > early_out)
			return accum;

		c++;
	}

	return accum;
}
626
627
// Distance between this vector and rhs: the square root of the summed squared
// component differences.
inline T distance(const vec& rhs) const
{
	const T dist2 = squared_distance(rhs);
	return sqrt(dist2);
}
637
638
// Returns a vector holding the reciprocal of each component. Components that are
// zero stay zero rather than being divided.
inline vec inverse() const
{
	vec result;

	for (uint32_t c = 0; c < N; c++)
	{
		if (m_s[c])
			result[c] = 1.0f / m_s[c];
		else
			result[c] = 0;
	}

	return result;
}
645
646
// returns squared length (norm)
647
inline double normalize(const vec* pDefaultVec = NULL)
648
{
649
double n = m_s[0] * m_s[0];
650
for (uint32_t i = 1; i < N; i++)
651
n += m_s[i] * m_s[i];
652
653
if (n != 0)
654
*this *= static_cast<T>(1.0f / sqrt(n));
655
else if (pDefaultVec)
656
*this = *pDefaultVec;
657
return n;
658
}
659
660
inline double normalize3(const vec* pDefaultVec = NULL)
661
{
662
static_assert(N >= 3);
663
664
double n = m_s[0] * m_s[0] + m_s[1] * m_s[1] + m_s[2] * m_s[2];
665
666
if (n != 0)
667
*this *= static_cast<T>((1.0f / sqrt(n)));
668
else if (pDefaultVec)
669
*this = *pDefaultVec;
670
return n;
671
}
672
673
inline vec& normalize_in_place(const vec* pDefaultVec = NULL)
674
{
675
normalize(pDefaultVec);
676
return *this;
677
}
678
679
inline vec& normalize3_in_place(const vec* pDefaultVec = NULL)
680
{
681
normalize3(pDefaultVec);
682
return *this;
683
}
684
685
inline vec get_normalized(const vec* pDefaultVec = NULL) const
686
{
687
vec result(*this);
688
result.normalize(pDefaultVec);
689
return result;
690
}
691
692
inline vec get_normalized3(const vec* pDefaultVec = NULL) const
693
{
694
vec result(*this);
695
result.normalize3(pDefaultVec);
696
return result;
697
}
698
699
inline vec& clamp(T l, T h)
700
{
701
for (uint32_t i = 0; i < N; i++)
702
m_s[i] = static_cast<T>(basisu::clamp(m_s[i], l, h));
703
return *this;
704
}
705
706
inline vec& saturate()
707
{
708
return clamp(0.0f, 1.0f);
709
}
710
711
inline vec& clamp(const vec& l, const vec& h)
712
{
713
for (uint32_t i = 0; i < N; i++)
714
m_s[i] = static_cast<T>(basisu::clamp(m_s[i], l[i], h[i]));
715
return *this;
716
}
717
718
inline bool is_within_bounds(const vec& l, const vec& h) const
719
{
720
for (uint32_t i = 0; i < N; i++)
721
if ((m_s[i] < l[i]) || (m_s[i] > h[i]))
722
return false;
723
724
return true;
725
}
726
727
inline bool is_within_bounds(T l, T h) const
728
{
729
for (uint32_t i = 0; i < N; i++)
730
if ((m_s[i] < l) || (m_s[i] > h))
731
return false;
732
733
return true;
734
}
735
736
inline uint32_t get_major_axis(void) const
737
{
738
T m = fabs(m_s[0]);
739
uint32_t r = 0;
740
for (uint32_t i = 1; i < N; i++)
741
{
742
const T c = fabs(m_s[i]);
743
if (c > m)
744
{
745
m = c;
746
r = i;
747
}
748
}
749
return r;
750
}
751
752
inline uint32_t get_minor_axis(void) const
753
{
754
T m = fabs(m_s[0]);
755
uint32_t r = 0;
756
for (uint32_t i = 1; i < N; i++)
757
{
758
const T c = fabs(m_s[i]);
759
if (c < m)
760
{
761
m = c;
762
r = i;
763
}
764
}
765
return r;
766
}
767
768
inline void get_projection_axes(uint32_t& u, uint32_t& v) const
769
{
770
const int axis = get_major_axis();
771
if (m_s[axis] < 0.0f)
772
{
773
v = basisu::next_wrap<uint32_t>(axis, N);
774
u = basisu::next_wrap<uint32_t>(v, N);
775
}
776
else
777
{
778
u = basisu::next_wrap<uint32_t>(axis, N);
779
v = basisu::next_wrap<uint32_t>(u, N);
780
}
781
}
782
783
inline T get_absolute_minimum(void) const
784
{
785
T result = fabs(m_s[0]);
786
for (uint32_t i = 1; i < N; i++)
787
result = basisu::minimum(result, fabs(m_s[i]));
788
return result;
789
}
790
791
inline T get_absolute_maximum(void) const
792
{
793
T result = fabs(m_s[0]);
794
for (uint32_t i = 1; i < N; i++)
795
result = basisu::maximum(result, fabs(m_s[i]));
796
return result;
797
}
798
799
inline T get_minimum(void) const
800
{
801
T result = m_s[0];
802
for (uint32_t i = 1; i < N; i++)
803
result = basisu::minimum(result, m_s[i]);
804
return result;
805
}
806
807
inline T get_maximum(void) const
808
{
809
T result = m_s[0];
810
for (uint32_t i = 1; i < N; i++)
811
result = basisu::maximum(result, m_s[i]);
812
return result;
813
}
814
815
inline vec& remove_unit_direction(const vec& dir)
816
{
817
*this -= (dot(dir) * dir);
818
return *this;
819
}
820
821
inline vec get_remove_unit_direction(const vec& dir) const
822
{
823
return *this - (dot(dir) * dir);
824
}
825
826
inline bool all_less(const vec& b) const
827
{
828
for (uint32_t i = 0; i < N; i++)
829
if (m_s[i] >= b.m_s[i])
830
return false;
831
return true;
832
}
833
834
inline bool all_less_equal(const vec& b) const
835
{
836
for (uint32_t i = 0; i < N; i++)
837
if (m_s[i] > b.m_s[i])
838
return false;
839
return true;
840
}
841
842
inline bool all_greater(const vec& b) const
843
{
844
for (uint32_t i = 0; i < N; i++)
845
if (m_s[i] <= b.m_s[i])
846
return false;
847
return true;
848
}
849
850
inline bool all_greater_equal(const vec& b) const
851
{
852
for (uint32_t i = 0; i < N; i++)
853
if (m_s[i] < b.m_s[i])
854
return false;
855
return true;
856
}
857
858
inline vec negate_xyz() const
859
{
860
vec ret;
861
862
ret[0] = -m_s[0];
863
if (N >= 2)
864
ret[1] = -m_s[1];
865
if (N >= 3)
866
ret[2] = -m_s[2];
867
868
for (uint32_t i = 3; i < N; i++)
869
ret[i] = m_s[i];
870
871
return ret;
872
}
873
874
inline vec& invert()
875
{
876
for (uint32_t i = 0; i < N; i++)
877
if (m_s[i] != 0.0f)
878
m_s[i] = 1.0f / m_s[i];
879
return *this;
880
}
881
882
inline scalar_type perp_dot(const vec& b) const
883
{
884
static_assert(N == 2);
885
return m_s[0] * b.m_s[1] - m_s[1] * b.m_s[0];
886
}
887
888
inline vec perp() const
889
{
890
static_assert(N == 2);
891
return vec(-m_s[1], m_s[0]);
892
}
893
894
inline vec get_floor() const
895
{
896
vec result;
897
for (uint32_t i = 0; i < N; i++)
898
result[i] = floor(m_s[i]);
899
return result;
900
}
901
902
inline vec get_ceil() const
903
{
904
vec result;
905
for (uint32_t i = 0; i < N; i++)
906
result[i] = ceil(m_s[i]);
907
return result;
908
}
909
910
inline T get_total() const
911
{
912
T res = m_s[0];
913
for (uint32_t i = 1; i < N; i++)
914
res += m_s[i];
915
return res;
916
}
917
918
// static helper methods
919
920
static inline vec mul_components(const vec& lhs, const vec& rhs)
921
{
922
vec result;
923
for (uint32_t i = 0; i < N; i++)
924
result[i] = lhs.m_s[i] * rhs.m_s[i];
925
return result;
926
}
927
928
static inline vec mul_add_components(const vec& a, const vec& b, const vec& c)
929
{
930
vec result;
931
for (uint32_t i = 0; i < N; i++)
932
result[i] = a.m_s[i] * b.m_s[i] + c.m_s[i];
933
return result;
934
}
935
936
static inline vec make_axis(uint32_t i)
937
{
938
vec result;
939
result.clear();
940
result[i] = 1;
941
return result;
942
}
943
944
static inline vec equals_mask(const vec& a, const vec& b)
945
{
946
vec ret;
947
for (uint32_t i = 0; i < N; i++)
948
ret[i] = (a[i] == b[i]);
949
return ret;
950
}
951
952
static inline vec not_equals_mask(const vec& a, const vec& b)
953
{
954
vec ret;
955
for (uint32_t i = 0; i < N; i++)
956
ret[i] = (a[i] != b[i]);
957
return ret;
958
}
959
960
static inline vec less_mask(const vec& a, const vec& b)
961
{
962
vec ret;
963
for (uint32_t i = 0; i < N; i++)
964
ret[i] = (a[i] < b[i]);
965
return ret;
966
}
967
968
static inline vec less_equals_mask(const vec& a, const vec& b)
969
{
970
vec ret;
971
for (uint32_t i = 0; i < N; i++)
972
ret[i] = (a[i] <= b[i]);
973
return ret;
974
}
975
976
static inline vec greater_equals_mask(const vec& a, const vec& b)
977
{
978
vec ret;
979
for (uint32_t i = 0; i < N; i++)
980
ret[i] = (a[i] >= b[i]);
981
return ret;
982
}
983
984
static inline vec greater_mask(const vec& a, const vec& b)
985
{
986
vec ret;
987
for (uint32_t i = 0; i < N; i++)
988
ret[i] = (a[i] > b[i]);
989
return ret;
990
}
991
992
static inline vec component_max(const vec& a, const vec& b)
993
{
994
vec ret;
995
for (uint32_t i = 0; i < N; i++)
996
ret.m_s[i] = basisu::maximum(a.m_s[i], b.m_s[i]);
997
return ret;
998
}
999
1000
static inline vec component_min(const vec& a, const vec& b)
1001
{
1002
vec ret;
1003
for (uint32_t i = 0; i < N; i++)
1004
ret.m_s[i] = basisu::minimum(a.m_s[i], b.m_s[i]);
1005
return ret;
1006
}
1007
1008
static inline vec lerp(const vec& a, const vec& b, float t)
1009
{
1010
vec ret;
1011
for (uint32_t i = 0; i < N; i++)
1012
ret.m_s[i] = a.m_s[i] + (b.m_s[i] - a.m_s[i]) * t;
1013
return ret;
1014
}
1015
1016
static inline bool equal_tol(const vec& a, const vec& b, float t)
1017
{
1018
for (uint32_t i = 0; i < N; i++)
1019
if (!basisu::equal_tol(a.m_s[i], b.m_s[i], t))
1020
return false;
1021
return true;
1022
}
1023
1024
inline bool equal_tol(const vec& b, float t) const
1025
{
1026
return equal_tol(*this, b, t);
1027
}
1028
1029
static inline vec make_random(basisu::rand& r, float l, float h)
1030
{
1031
vec result;
1032
for (uint32_t i = 0; i < N; i++)
1033
result[i] = r.frand(l, h);
1034
return result;
1035
}
1036
1037
static inline vec make_random(basisu::rand& r, const vec& l, const vec& h)
1038
{
1039
vec result;
1040
for (uint32_t i = 0; i < N; i++)
1041
result[i] = r.frand(l[i], h[i]);
1042
return result;
1043
}
1044
1045
void print() const
1046
{
1047
for (uint32_t c = 0; c < N; c++)
1048
printf("%3.3f ", (*this)[c]);
1049
printf("\n");
1050
}
1051
1052
protected:
1053
T m_s[N];
1054
};
1055
1056
typedef vec<1, double> vec1D;
1057
typedef vec<2, double> vec2D;
1058
typedef vec<3, double> vec3D;
1059
typedef vec<4, double> vec4D;
1060
1061
typedef vec<1, float> vec1F;
1062
1063
typedef vec<2, float> vec2F;
1064
typedef basisu::vector<vec2F> vec2F_array;
1065
1066
typedef vec<3, float> vec3F;
1067
typedef basisu::vector<vec3F> vec3F_array;
1068
1069
typedef vec<4, float> vec4F;
1070
typedef basisu::vector<vec4F> vec4F_array;
1071
1072
typedef vec<2, uint32_t> vec2U;
1073
typedef vec<3, uint32_t> vec3U;
1074
typedef vec<2, int> vec2I;
1075
typedef vec<3, int> vec3I;
1076
typedef vec<4, int> vec4I;
1077
1078
typedef vec<2, int16_t> vec2I16;
1079
typedef vec<3, int16_t> vec3I16;
1080
1081
// Rotates the 2D point p about the origin by rad radians, i.e. applies
// (x, y) -> (x*cos - y*sin, x*sin + y*cos).
inline vec2F rotate_point_2D(const vec2F& p, float rad)
{
	const float cos_a = cosf(rad);
	const float sin_a = sinf(rad);

	const float px = p[0];
	const float py = p[1];

	return vec2F(px * cos_a - py * sin_a, px * sin_a + py * cos_a);
}
1091
1092
//--------------------------------------------------------------
1093
1094
// Matrix/vector cheat sheet, because confusingly, depending on how matrices are stored in memory people can use opposite definitions of "rows", "cols", etc.
1095
// See http://www.mindcontrol.org/~hplus/graphics/matrix-layout.html
1096
//
1097
// So in this simple row-major general matrix class:
1098
// matrix=[NumRows][NumCols] or [R][C], i.e. a 3x3 matrix stored in memory will appear as: R0C0, R0C1, R0C2, R1C0, R1C1, R1C2, etc.
1099
// Matrix multiplication: [R0,C0]*[R1,C1]=[R0,C1], C0 must equal R1
1100
//
1101
// In this class:
1102
// A "row vector" type is a vector of size # of matrix cols, 1xC. It's the vector type that is used to store the matrix rows.
1103
// A "col vector" type is a vector of size # of matrix rows, Rx1. It's a vector type large enough to hold each matrix column.
1104
//
1105
// Subrow/col vectors: last component is assumed to be either 0 (a "vector") or 1 (a "point")
1106
// "subrow vector": vector/point of size # cols-1, 1x(C-1)
1107
// "subcol vector": vector/point of size # rows-1, (R-1)x1
1108
//
1109
// D3D style:
1110
// vec*matrix, row vector on left (vec dotted against columns)
1111
// [1,4]*[4,4]=[1,4]
1112
// abcd * A B C D
1113
// A B C D
1114
// A B C D
1115
// A B C D
1116
// = e f g h
1117
//
1118
// Now confusingly, in the matrix transform method for vec*matrix below the vector's type is "col_vec", because col_vec will have the proper size for non-square matrices. But the vector on the left is written as row vector, argh.
1119
//
1120
//
1121
// OGL style:
1122
// matrix*vec, col vector on right (vec dotted against rows):
1123
// [4,4]*[4,1]=[4,1]
1124
//
1125
// A B C D * e = e
1126
// A B C D f f
1127
// A B C D g g
1128
// A B C D h h
1129
1130
// Computes result = lhs * rhs for matrix-like types that expose num_rows, num_cols,
// scalar_type and operator()(row, col). Shapes are checked at compile time;
// result must not be the same object as either operand (asserted).
template <class X, class Y, class Z>
Z& matrix_mul_helper(Z& result, const X& lhs, const Y& rhs)
{
	static_assert((int)Z::num_rows == (int)X::num_rows);
	static_assert((int)Z::num_cols == (int)Y::num_cols);
	static_assert((int)X::num_cols == (int)Y::num_rows);
	assert(((void*)&result != (void*)&lhs) && ((void*)&result != (void*)&rhs));

	for (int row = 0; row < X::num_rows; row++)
	{
		for (int col = 0; col < Y::num_cols; col++)
		{
			// Row `row` of lhs dotted with column `col` of rhs.
			typename Z::scalar_type acc = lhs(row, 0) * rhs(0, col);
			for (uint32_t k = 1; k < X::num_cols; k++)
				acc += lhs(row, k) * rhs(k, col);

			result(row, col) = acc;
		}
	}

	return result;
}
1147
1148
// Computes result = transpose(lhs) * rhs without forming the transpose: lhs is
// indexed as (i, row) instead of (row, i). Shapes are checked at compile time;
// result must not be the same object as either operand (asserted).
template <class X, class Y, class Z>
Z& matrix_mul_helper_transpose_lhs(Z& result, const X& lhs, const Y& rhs)
{
	static_assert((int)Z::num_rows == (int)X::num_cols);
	static_assert((int)Z::num_cols == (int)Y::num_cols);
	static_assert((int)X::num_rows == (int)Y::num_rows);
	assert(((void*)&result != (void*)&lhs) && ((void*)&result != (void*)&rhs));

	for (int row = 0; row < X::num_cols; row++)
	{
		for (int col = 0; col < Y::num_cols; col++)
		{
			// Column `row` of lhs dotted with column `col` of rhs.
			typename Z::scalar_type acc = lhs(0, row) * rhs(0, col);
			for (uint32_t k = 1; k < X::num_rows; k++)
				acc += lhs(k, row) * rhs(k, col);

			result(row, col) = acc;
		}
	}

	return result;
}
1165
1166
// Computes result = lhs * transpose(rhs) without forming the transpose: rhs is
// indexed as (col, i) instead of (i, col). Shapes are checked at compile time;
// result must not be the same object as either operand (asserted).
template <class X, class Y, class Z>
Z& matrix_mul_helper_transpose_rhs(Z& result, const X& lhs, const Y& rhs)
{
	static_assert((int)Z::num_rows == (int)X::num_rows);
	static_assert((int)Z::num_cols == (int)Y::num_rows);
	static_assert((int)X::num_cols == (int)Y::num_cols);
	assert(((void*)&result != (void*)&lhs) && ((void*)&result != (void*)&rhs));

	for (int row = 0; row < X::num_rows; row++)
	{
		for (int col = 0; col < Y::num_rows; col++)
		{
			// Row `row` of lhs dotted with row `col` of rhs.
			typename Z::scalar_type acc = lhs(row, 0) * rhs(col, 0);
			for (uint32_t k = 1; k < X::num_cols; k++)
				acc += lhs(row, k) * rhs(col, k);

			result(row, col) = acc;
		}
	}

	return result;
}
1183
1184
template <uint32_t R, uint32_t C, typename T>
1185
class matrix
1186
{
1187
public:
1188
typedef T scalar_type;
1189
enum
1190
{
1191
num_rows = R,
1192
num_cols = C
1193
};
1194
1195
typedef vec<R, T> col_vec;
1196
typedef vec < (R > 1) ? (R - 1) : 0, T > subcol_vec;
1197
1198
typedef vec<C, T> row_vec;
1199
typedef vec < (C > 1) ? (C - 1) : 0, T > subrow_vec;
1200
1201
inline matrix()
1202
{
1203
}
1204
1205
inline matrix(basisu::eClear)
1206
{
1207
clear();
1208
}
1209
1210
inline matrix(basisu::eIdentity)
1211
{
1212
set_identity_matrix();
1213
}
1214
1215
inline matrix(const T* p)
1216
{
1217
set(p);
1218
}
1219
1220
inline matrix(const matrix& other)
1221
{
1222
for (uint32_t i = 0; i < R; i++)
1223
m_rows[i] = other.m_rows[i];
1224
}
1225
1226
inline matrix& operator=(const matrix& rhs)
1227
{
1228
if (this != &rhs)
1229
for (uint32_t i = 0; i < R; i++)
1230
m_rows[i] = rhs.m_rows[i];
1231
return *this;
1232
}
1233
1234
inline matrix(T val00, T val01,
1235
T val10, T val11)
1236
{
1237
set(val00, val01, val10, val11);
1238
}
1239
1240
inline matrix(T val00, T val01,
1241
T val10, T val11,
1242
T val20, T val21)
1243
{
1244
set(val00, val01, val10, val11, val20, val21);
1245
}
1246
1247
inline matrix(T val00, T val01, T val02,
1248
T val10, T val11, T val12,
1249
T val20, T val21, T val22)
1250
{
1251
set(val00, val01, val02, val10, val11, val12, val20, val21, val22);
1252
}
1253
1254
inline matrix(T val00, T val01, T val02, T val03,
1255
T val10, T val11, T val12, T val13,
1256
T val20, T val21, T val22, T val23,
1257
T val30, T val31, T val32, T val33)
1258
{
1259
set(val00, val01, val02, val03, val10, val11, val12, val13, val20, val21, val22, val23, val30, val31, val32, val33);
1260
}
1261
1262
inline matrix(T val00, T val01, T val02, T val03,
1263
T val10, T val11, T val12, T val13,
1264
T val20, T val21, T val22, T val23)
1265
{
1266
set(val00, val01, val02, val03, val10, val11, val12, val13, val20, val21, val22, val23);
1267
}
1268
1269
inline void set(const float* p)
1270
{
1271
for (uint32_t i = 0; i < R; i++)
1272
{
1273
m_rows[i].set(p);
1274
p += C;
1275
}
1276
}
1277
1278
inline void set(T val00, T val01,
1279
T val10, T val11)
1280
{
1281
m_rows[0].set(val00, val01);
1282
if (R >= 2)
1283
{
1284
m_rows[1].set(val10, val11);
1285
1286
for (uint32_t i = 2; i < R; i++)
1287
m_rows[i].clear();
1288
}
1289
}
1290
1291
inline void set(T val00, T val01,
1292
T val10, T val11,
1293
T val20, T val21)
1294
{
1295
m_rows[0].set(val00, val01);
1296
if (R >= 2)
1297
{
1298
m_rows[1].set(val10, val11);
1299
1300
if (R >= 3)
1301
{
1302
m_rows[2].set(val20, val21);
1303
1304
for (uint32_t i = 3; i < R; i++)
1305
m_rows[i].clear();
1306
}
1307
}
1308
}
1309
1310
inline void set(T val00, T val01, T val02,
1311
T val10, T val11, T val12,
1312
T val20, T val21, T val22)
1313
{
1314
m_rows[0].set(val00, val01, val02);
1315
if (R >= 2)
1316
{
1317
m_rows[1].set(val10, val11, val12);
1318
if (R >= 3)
1319
{
1320
m_rows[2].set(val20, val21, val22);
1321
1322
for (uint32_t i = 3; i < R; i++)
1323
m_rows[i].clear();
1324
}
1325
}
1326
}
1327
1328
inline void set(T val00, T val01, T val02, T val03,
1329
T val10, T val11, T val12, T val13,
1330
T val20, T val21, T val22, T val23,
1331
T val30, T val31, T val32, T val33)
1332
{
1333
m_rows[0].set(val00, val01, val02, val03);
1334
if (R >= 2)
1335
{
1336
m_rows[1].set(val10, val11, val12, val13);
1337
if (R >= 3)
1338
{
1339
m_rows[2].set(val20, val21, val22, val23);
1340
1341
if (R >= 4)
1342
{
1343
m_rows[3].set(val30, val31, val32, val33);
1344
1345
for (uint32_t i = 4; i < R; i++)
1346
m_rows[i].clear();
1347
}
1348
}
1349
}
1350
}
1351
1352
inline void set(T val00, T val01, T val02, T val03,
1353
T val10, T val11, T val12, T val13,
1354
T val20, T val21, T val22, T val23)
1355
{
1356
m_rows[0].set(val00, val01, val02, val03);
1357
if (R >= 2)
1358
{
1359
m_rows[1].set(val10, val11, val12, val13);
1360
if (R >= 3)
1361
{
1362
m_rows[2].set(val20, val21, val22, val23);
1363
1364
for (uint32_t i = 3; i < R; i++)
1365
m_rows[i].clear();
1366
}
1367
}
1368
}
1369
1370
inline uint32_t get_num_rows() const
1371
{
1372
return num_rows;
1373
}
1374
1375
inline uint32_t get_num_cols() const
1376
{
1377
return num_cols;
1378
}
1379
1380
inline uint32_t get_total_elements() const
1381
{
1382
return num_rows * num_cols;
1383
}
1384
1385
inline T operator()(uint32_t r, uint32_t c) const
1386
{
1387
assert((r < R) && (c < C));
1388
return m_rows[r][c];
1389
}
1390
1391
inline T& operator()(uint32_t r, uint32_t c)
1392
{
1393
assert((r < R) && (c < C));
1394
return m_rows[r][c];
1395
}
1396
1397
inline const row_vec& operator[](uint32_t r) const
1398
{
1399
assert(r < R);
1400
return m_rows[r];
1401
}
1402
1403
inline row_vec& operator[](uint32_t r)
1404
{
1405
assert(r < R);
1406
return m_rows[r];
1407
}
1408
1409
inline const row_vec& get_row(uint32_t r) const
1410
{
1411
return (*this)[r];
1412
}
1413
1414
inline row_vec& get_row(uint32_t r)
1415
{
1416
return (*this)[r];
1417
}
1418
1419
inline void set_row(uint32_t r, const row_vec& v)
1420
{
1421
(*this)[r] = v;
1422
}
1423
1424
// Returns a copy of column c as a vector with one entry per matrix row.
// c must be a valid column index (asserted).
inline col_vec get_col(uint32_t c) const
{
	assert(c < C);

	col_vec column;

	uint32_t r = 0;
	while (r < R)
	{
		column[r] = m_rows[r][c];
		r++;
	}

	return column;
}
1432
1433
inline void set_col(uint32_t c, const col_vec& col)
1434
{
1435
assert(c < C);
1436
for (uint32_t i = 0; i < R; i++)
1437
m_rows[i][c] = col[i];
1438
}
1439
1440
inline void set_col(uint32_t c, const subcol_vec& col)
1441
{
1442
assert(c < C);
1443
for (uint32_t i = 0; i < (R - 1); i++)
1444
m_rows[i][c] = col[i];
1445
1446
m_rows[R - 1][c] = 0.0f;
1447
}
1448
1449
inline const row_vec& get_translate() const
1450
{
1451
return m_rows[R - 1];
1452
}
1453
1454
inline matrix& set_translate(const row_vec& r)
1455
{
1456
m_rows[R - 1] = r;
1457
return *this;
1458
}
1459
1460
inline matrix& set_translate(const subrow_vec& r)
1461
{
1462
m_rows[R - 1] = row_vec(r).as_point();
1463
return *this;
1464
}
1465
1466
inline const T* get_ptr() const
1467
{
1468
return reinterpret_cast<const T*>(&m_rows[0]);
1469
}
1470
inline T* get_ptr()
1471
{
1472
return reinterpret_cast<T*>(&m_rows[0]);
1473
}
1474
1475
inline matrix& operator+=(const matrix& other)
1476
{
1477
for (uint32_t i = 0; i < R; i++)
1478
m_rows[i] += other.m_rows[i];
1479
return *this;
1480
}
1481
1482
inline matrix& operator-=(const matrix& other)
1483
{
1484
for (uint32_t i = 0; i < R; i++)
1485
m_rows[i] -= other.m_rows[i];
1486
return *this;
1487
}
1488
1489
inline matrix& operator*=(T val)
1490
{
1491
for (uint32_t i = 0; i < R; i++)
1492
m_rows[i] *= val;
1493
return *this;
1494
}
1495
1496
inline matrix& operator/=(T val)
1497
{
1498
for (uint32_t i = 0; i < R; i++)
1499
m_rows[i] /= val;
1500
return *this;
1501
}
1502
1503
inline matrix& operator*=(const matrix& other)
1504
{
1505
matrix result;
1506
matrix_mul_helper(result, *this, other);
1507
*this = result;
1508
return *this;
1509
}
1510
1511
friend inline matrix operator+(const matrix& lhs, const matrix& rhs)
1512
{
1513
matrix result;
1514
for (uint32_t i = 0; i < R; i++)
1515
result[i] = lhs.m_rows[i] + rhs.m_rows[i];
1516
return result;
1517
}
1518
1519
friend inline matrix operator-(const matrix& lhs, const matrix& rhs)
1520
{
1521
matrix result;
1522
for (uint32_t i = 0; i < R; i++)
1523
result[i] = lhs.m_rows[i] - rhs.m_rows[i];
1524
return result;
1525
}
1526
1527
friend inline matrix operator*(const matrix& lhs, T val)
1528
{
1529
matrix result;
1530
for (uint32_t i = 0; i < R; i++)
1531
result[i] = lhs.m_rows[i] * val;
1532
return result;
1533
}
1534
1535
friend inline matrix operator/(const matrix& lhs, T val)
1536
{
1537
matrix result;
1538
for (uint32_t i = 0; i < R; i++)
1539
result[i] = lhs.m_rows[i] / val;
1540
return result;
1541
}
1542
1543
friend inline matrix operator*(T val, const matrix& rhs)
1544
{
1545
matrix result;
1546
for (uint32_t i = 0; i < R; i++)
1547
result[i] = val * rhs.m_rows[i];
1548
return result;
1549
}
1550
1551
#if 0
1552
template<uint32_t R0, uint32_t C0, uint32_t R1, uint32_t C1, typename T>
1553
friend inline matrix operator*(const matrix<R0, C0, T>& lhs, const matrix<R1, C1, T>& rhs)
1554
{
1555
matrix<R0, C1, T> result;
1556
return matrix_mul_helper(result, lhs, rhs);
1557
}
1558
#endif
1559
// Matrix product (lhs * rhs) of two same-typed matrices, returned by value.
friend inline matrix operator*(const matrix& lhs, const matrix& rhs)
{
	matrix product;
	// The helper's return value is what gets copied out, exactly as before.
	return matrix_mul_helper(product, lhs, rhs);
}
1564
1565
// Vector-on-the-left product: forwards to the static transform(a, b) of this matrix type.
friend inline row_vec operator*(const col_vec& a, const matrix& b)
{
	return matrix::transform(a, b);
}
1569
1570
// Unary plus: returns an unmodified copy of this matrix.
inline matrix operator+() const
{
	matrix copy(*this);
	return copy;
}
1574
1575
// Unary minus: returns a matrix whose rows are the negated rows of this matrix.
inline matrix operator-() const
{
	matrix negated;
	for (uint32_t row = 0; row < R; ++row)
	{
		negated[row] = -m_rows[row];
	}
	return negated;
}
1582
1583
// Clears every row vector of the matrix. Returns *this.
inline matrix& clear()
{
	for (row_vec& row : m_rows)
		row.clear();
	return *this;
}
1589
1590
// Same operation as clear(); provided under a more descriptive name. Returns *this.
inline matrix& set_zero_matrix()
{
	return clear();
}
1595
1596
// Sets this matrix to identity: every row is cleared, then the main diagonal is set to 1.
// For non-square matrices only the diagonal entries that actually exist (column index < C) are written,
// so a matrix with more rows than columns no longer writes past the end of a row.
// Returns *this.
inline matrix& set_identity_matrix()
{
	for (uint32_t i = 0; i < R; i++)
	{
		m_rows[i].clear();
		if (i < C)
			m_rows[i][i] = 1.0f;
	}
	return *this;
}
1605
1606
// Builds a uniform scale matrix: clears the matrix, writes 's' on the diagonal of all rows except the last,
// and puts 1 in the bottom-right element. Returns *this.
inline matrix& set_scale_matrix(float s)
{
	matrix& self = clear();

	const uint32_t last_row = R - 1;
	for (int d = 0; d < last_row; d++)
		self.m_rows[d][d] = s;

	self.m_rows[last_row][C - 1] = 1.0f;
	return self;
}
1614
1615
// Builds a scale matrix from the components of 's': the matrix is cleared and s[i] is written to the diagonal.
// Only min(R, C) diagonal entries exist, so the loop is bounded by that count; this avoids indexing past the
// end of both 's' (C components) and the row when the matrix has more rows than columns.
// Returns *this.
inline matrix& set_scale_matrix(const row_vec& s)
{
	clear();

	const uint32_t diag_len = (R < C) ? R : C;
	for (uint32_t i = 0; i < diag_len; i++)
		m_rows[i][i] = s[i];

	return *this;
}
1622
1623
// Sets identity, then stores 'x' in the X slot of row 0 and 'y' in the Y slot of row 1. Returns *this.
inline matrix& set_scale_matrix(float x, float y)
{
	matrix& self = set_identity_matrix();
	self.m_rows[0].set_x(x);
	self.m_rows[1].set_y(y);
	return self;
}
1630
1631
// Sets identity, then stores 'x', 'y' and 'z' in the X/Y/Z slots of rows 0, 1 and 2 respectively. Returns *this.
inline matrix& set_scale_matrix(float x, float y, float z)
{
	matrix& self = set_identity_matrix();
	self.m_rows[0].set_x(x);
	self.m_rows[1].set_y(y);
	self.m_rows[2].set_z(z);
	return self;
}
1639
1640
// Resets this matrix to identity and then applies set_translate(s). Returns *this.
inline matrix& set_translate_matrix(const row_vec& s)
{
	matrix& self = set_identity_matrix();
	self.set_translate(s);
	return self;
}
1646
1647
// Resets this matrix to identity and then applies a translation built from (x, y),
// converted with as_point() before being handed to set_translate(). Returns *this.
inline matrix& set_translate_matrix(float x, float y)
{
	set_identity_matrix();

	row_vec offset(x, y);
	set_translate(offset.as_point());

	return *this;
}
1653
1654
// Resets this matrix to identity and then applies a translation built from (x, y, z),
// converted with as_point() before being handed to set_translate(). Returns *this.
inline matrix& set_translate_matrix(float x, float y, float z)
{
	set_identity_matrix();

	row_vec offset(x, y, z);
	set_translate(offset.as_point());

	return *this;
}
1660
1661
// Returns the transpose of this matrix. Square matrices only (enforced at compile time).
inline matrix get_transposed() const
{
	static_assert(R == C);

	matrix transposed;
	for (uint32_t dst_row = 0; dst_row < R; ++dst_row)
	{
		for (uint32_t dst_col = 0; dst_col < C; ++dst_col)
		{
			transposed.m_rows[dst_row][dst_col] = m_rows[dst_col][dst_row];
		}
	}
	return transposed;
}
1671
1672
// Returns the transpose as a C x R matrix; works for any shape.
inline matrix<C, R, T> get_transposed_nonsquare() const
{
	matrix<C, R, T> transposed;
	for (uint32_t src_row = 0; src_row < R; ++src_row)
	{
		for (uint32_t src_col = 0; src_col < C; ++src_col)
		{
			transposed[src_col][src_row] = m_rows[src_row][src_col];
		}
	}
	return transposed;
}
1680
1681
// Transposes this matrix in place and returns *this.
// Only meaningful for square matrices: the element swap below reads m_rows[j][i] with j < C and i < R,
// which would index out of range otherwise. The same compile-time check get_transposed() uses is applied here.
inline matrix& transpose_in_place()
{
	static_assert(R == C);

	// Build the transpose in a scratch matrix so no element is overwritten before it has been read.
	matrix result;
	for (uint32_t i = 0; i < R; i++)
		for (uint32_t j = 0; j < C; j++)
			result.m_rows[i][j] = m_rows[j][i];

	*this = result;
	return *this;
}
1690
1691
// Frobenius Norm
1692
T get_norm() const
1693
{
1694
T result = 0;
1695
1696
for (uint32_t i = 0; i < R; i++)
1697
for (uint32_t j = 0; j < C; j++)
1698
result += m_rows[i][j] * m_rows[i][j];
1699
1700
return static_cast<T>(sqrt(result));
1701
}
1702
1703
// Returns a matrix whose elements are the elements of this matrix each raised to the power 'p'
// (element-wise pow(), not a matrix power).
inline matrix get_power(T p) const
{
	matrix powered;

	for (uint32_t row = 0; row < R; ++row)
	{
		for (uint32_t col = 0; col < C; ++col)
		{
			powered[row][col] = static_cast<T>(pow(m_rows[row][col], p));
		}
	}

	return powered;
}
1713
1714
// Dots every row of this R x C matrix against the single row of 'b' (1 x C).
// Returns the R dot products packed into a 1 x R matrix.
inline matrix<1, R, T> numpy_dot(const matrix<1, C, T>& b) const
{
	matrix<1, R, T> dots;

	for (uint32_t row = 0; row < R; ++row)
	{
		T acc = 0;
		for (uint32_t col = 0; col < C; ++col)
		{
			acc += m_rows[row][col] * b[0][col];
		}

		dots[0][row] = static_cast<T>(acc);
	}

	return dots;
}
1729
1730
// Computes the inverse of this square matrix into 'result' using Gauss-Jordan elimination with row pivoting.
// Returns true on success. If a column has no non-zero pivot, 'result' is set to identity and false is returned.
// 'result' may be the same object as *this: the source is copied before 'result' is touched.
bool invert(matrix& result) const
{
	static_assert(R == C);

	// Take the working copy first so that invert(*this) does not wipe the input when 'result' aliases it.
	matrix mat(*this);

	result.set_identity_matrix();

	for (uint32_t c = 0; c < C; c++)
	{
		// Pick the row at or below the diagonal with the largest magnitude in this column.
		uint32_t max_r = c;
		for (uint32_t r = c + 1; r < R; r++)
			if (fabs(mat[r][c]) > fabs(mat[max_r][c]))
				max_r = r;

		// No usable pivot in this column.
		if (mat[max_r][c] == 0.0f)
		{
			result.set_identity_matrix();
			return false;
		}

		std::swap(mat[c], mat[max_r]);
		std::swap(result[c], result[max_r]);

		// Scale the pivot row so its diagonal entry becomes 1.
		// 'result' is divided first because the divisor is read from 'mat', which the second statement changes.
		result[c] /= mat[c][c];
		mat[c] /= mat[c][c];

		// Remove this column's entry from every other row, mirroring each row operation into 'result'.
		for (uint32_t row = 0; row < R; row++)
		{
			if (row != c)
			{
				const row_vec temp(mat[row][c]);
				mat[row] -= row_vec::mul_components(mat[c], temp);
				result[row] -= row_vec::mul_components(result[c], temp);
			}
		}
	}

	return true;
}
1770
1771
// Replaces this matrix with the output of invert(). The success flag from invert() is not checked here,
// so whatever invert() left in its output (identity on failure) is what gets stored. Returns *this.
matrix& invert_in_place()
{
	matrix inverse;
	invert(inverse);

	*this = inverse;
	return *this;
}
1778
1779
// Returns the output of invert() by value. The success flag from invert() is not checked here.
matrix get_inverse() const
{
	matrix inverse;
	invert(inverse);
	return inverse;
}
1785
1786
// Determinant of this square matrix, computed by det_helper() over the full R x R extent.
T get_det() const
{
	static_assert(R == C);

	const T det = det_helper(*this, R);
	return det;
}
1791
1792
bool equal_tol(const matrix& b, float tol) const
1793
{
1794
for (uint32_t r = 0; r < R; r++)
1795
if (!row_vec::equal_tol(m_rows[r], b.m_rows[r], tol))
1796
return false;
1797
return true;
1798
}
1799
1800
bool is_square() const
1801
{
1802
return R == C;
1803
}
1804
1805
double get_trace() const
1806
{
1807
static_assert(is_square());
1808
1809
T total = 0;
1810
for (uint32_t i = 0; i < R; i++)
1811
total += (*this)(i, i);
1812
1813
return total;
1814
}
1815
1816
void print() const
1817
{
1818
for (uint32_t r = 0; r < R; r++)
1819
{
1820
for (uint32_t c = 0; c < C; c++)
1821
printf("%3.7f ", (*this)(r, c));
1822
printf("\n");
1823
}
1824
}
1825
1826
// This method transforms a vec by a matrix (D3D-style: row vector on left).
1827
// Confusingly, note that the data type is named "col_vec", but mathematically it's actually written as a row vector (of size equal to the # matrix rows, which is why it's called a "col_vec" in this class).
1828
// 1xR * RxC = 1xC
1829
// This dots against the matrix columns.
1830
static inline row_vec transform(const col_vec& a, const matrix& b)
1831
{
1832
row_vec result(b[0] * a[0]);
1833
for (uint32_t r = 1; r < R; r++)
1834
result += b[r] * a[r];
1835
return result;
1836
}
1837
1838
// This method transforms a vec by a matrix (D3D-style: row vector on left).
1839
// Last component of vec is assumed to be 1.
1840
static inline row_vec transform_point(const col_vec& a, const matrix& b)
1841
{
1842
row_vec result(0);
1843
for (int r = 0; r < (R - 1); r++)
1844
result += b[r] * a[r];
1845
result += b[R - 1];
1846
return result;
1847
}
1848
1849
// This method transforms a vec by a matrix (D3D-style: row vector on left).
1850
// Last component of vec is assumed to be 0.
1851
static inline row_vec transform_vector(const col_vec& a, const matrix& b)
1852
{
1853
row_vec result(0);
1854
for (int r = 0; r < (R - 1); r++)
1855
result += b[r] * a[r];
1856
return result;
1857
}
1858
1859
// This method transforms a vec by a matrix (D3D-style: row vector on left).
1860
// Last component of vec is assumed to be 1.
1861
static inline subcol_vec transform_point(const subcol_vec& a, const matrix& b)
1862
{
1863
subcol_vec result(0);
1864
for (int r = 0; r < static_cast<int>(R); r++)
1865
{
1866
const T s = (r < subcol_vec::num_elements) ? a[r] : 1.0f;
1867
for (int c = 0; c < static_cast<int>(C - 1); c++)
1868
result[c] += b[r][c] * s;
1869
}
1870
return result;
1871
}
1872
1873
// This method transforms a vec by a matrix (D3D-style: row vector on left).
1874
// Last component of vec is assumed to be 0.
1875
static inline subcol_vec transform_vector(const subcol_vec& a, const matrix& b)
1876
{
1877
subcol_vec result(0);
1878
for (int r = 0; r < static_cast<int>(R - 1); r++)
1879
{
1880
const T s = a[r];
1881
for (int c = 0; c < static_cast<int>(C - 1); c++)
1882
result[c] += b[r][c] * s;
1883
}
1884
return result;
1885
}
1886
1887
// Like transform() above, but the matrix is effectively transposed before the multiply.
1888
static inline col_vec transform_transposed(const col_vec& a, const matrix& b)
1889
{
1890
static_assert(R == C);
1891
col_vec result;
1892
for (uint32_t r = 0; r < R; r++)
1893
result[r] = b[r].dot(a);
1894
return result;
1895
}
1896
1897
// Like transform() above, but the matrix is effectively transposed before the multiply.
1898
// Last component of vec is assumed to be 0.
1899
static inline col_vec transform_vector_transposed(const col_vec& a, const matrix& b)
1900
{
1901
static_assert(R == C);
1902
col_vec result;
1903
for (uint32_t r = 0; r < R; r++)
1904
{
1905
T s = 0;
1906
for (uint32_t c = 0; c < (C - 1); c++)
1907
s += b[r][c] * a[c];
1908
1909
result[r] = s;
1910
}
1911
return result;
1912
}
1913
1914
// This method transforms a vec by a matrix (D3D-style: row vector on left), but the matrix is effectively transposed before the multiply.
1915
// Last component of vec is assumed to be 1.
1916
static inline subcol_vec transform_point_transposed(const subcol_vec& a, const matrix& b)
1917
{
1918
static_assert(R == C);
1919
subcol_vec result(0);
1920
for (int r = 0; r < R; r++)
1921
{
1922
const T s = (r < subcol_vec::num_elements) ? a[r] : 1.0f;
1923
for (int c = 0; c < (C - 1); c++)
1924
result[c] += b[c][r] * s;
1925
}
1926
return result;
1927
}
1928
1929
// This method transforms a vec by a matrix (D3D-style: row vector on left), but the matrix is effectively transposed before the multiply.
1930
// Last component of vec is assumed to be 0.
1931
static inline subcol_vec transform_vector_transposed(const subcol_vec& a, const matrix& b)
1932
{
1933
static_assert(R == C);
1934
subcol_vec result(0);
1935
for (int r = 0; r < static_cast<int>(R - 1); r++)
1936
{
1937
const T s = a[r];
1938
for (int c = 0; c < static_cast<int>(C - 1); c++)
1939
result[c] += b[c][r] * s;
1940
}
1941
return result;
1942
}
1943
1944
// This method transforms a matrix by a vector (OGL style, col vector on the right).
1945
// Note that the data type is named "row_vec", but mathematically it's actually written as a column vector (of size equal to the # matrix cols).
1946
// RxC * Cx1 = Rx1
1947
// This dots against the matrix rows.
1948
static inline col_vec transform(const matrix& b, const row_vec& a)
1949
{
1950
col_vec result;
1951
for (int r = 0; r < static_cast<int>(R); r++)
1952
result[r] = b[r].dot(a);
1953
return result;
1954
}
1955
1956
// This method transforms a matrix by a vector (OGL style, col vector on the right), except the matrix is effectively transposed before the multiply.
1957
// Note that the data type is named "row_vec", but mathematically it's actually written as a column vector (of size equal to the # matrix cols).
1958
// RxC * Cx1 = Rx1
1959
// This dots against the matrix cols.
1960
static inline col_vec transform_transposed(const matrix& b, const row_vec& a)
1961
{
1962
static_assert(R == C);
1963
row_vec result(b[0] * a[0]);
1964
for (int r = 1; r < static_cast<int>(R); r++)
1965
result += b[r] * a[r];
1966
return col_vec(result);
1967
}
1968
1969
// Component-wise product: result[r] = row_vec::mul_components(lhs[r], rhs[r]) for every row. Returns 'result'.
static inline matrix& mul_components(matrix& result, const matrix& lhs, const matrix& rhs)
{
	for (uint32_t row = 0; row < R; ++row)
	{
		result[row] = row_vec::mul_components(lhs[row], rhs[row]);
	}
	return result;
}
1975
1976
// Concatenates in place: lhs = lhs * rhs. Returns the reference produced by matrix_mul_helper().
// The left operand is always copied because 'lhs' is also the destination. When 'rhs' is the same object
// as 'lhs' (e.g. m.concat_in_place(m)) the right operand is taken from that copy as well, so the
// destination never aliases an input while the product is being written. multiply() in this class takes
// the same precaution.
static inline matrix& concat(matrix& lhs, const matrix& rhs)
{
	const matrix lhs_copy(lhs);

	if (&lhs == &rhs)
		return matrix_mul_helper(lhs, lhs_copy, lhs_copy);

	return matrix_mul_helper(lhs, lhs_copy, rhs);
}
1980
1981
// Member form of concat(): replaces this matrix with (*this * rhs) and returns what concat() returns.
inline matrix& concat_in_place(const matrix& rhs)
{
	matrix& self = *this;
	return concat(self, rhs);
}
1985
1986
// Computes result = lhs * rhs and returns 'result'.
// 'result' may be the same object as either operand: in that case the product is formed in a scratch
// matrix and copied into 'result' afterwards.
static inline matrix& multiply(matrix& result, const matrix& lhs, const matrix& rhs)
{
	const bool dst_aliases_input = (&result == &lhs) || (&result == &rhs);

	if (!dst_aliases_input)
	{
		matrix_mul_helper(result, lhs, rhs);
		return result;
	}

	matrix scratch;
	matrix_mul_helper(scratch, lhs, rhs);
	result = scratch;

	return result;
}
1997
1998
// Factory: returns a matrix on which clear() has been called.
static matrix make_zero_matrix()
{
	matrix zero;
	zero.clear();
	return zero;
}
2004
2005
// Factory: returns a matrix on which set_identity_matrix() has been called.
static matrix make_identity_matrix()
{
	matrix identity;
	identity.set_identity_matrix();
	return identity;
}
2011
2012
// Factory: starts from a matrix constructed with basisu::cIdentity and returns the result of set_translate(t) on it.
static matrix make_translate_matrix(const row_vec& t)
{
	matrix xform(basisu::cIdentity);
	return xform.set_translate(t);
}
2016
2017
// Factory: returns the result of set_translate_matrix(x, y) applied to a matrix constructed with basisu::cIdentity.
static matrix make_translate_matrix(float x, float y)
{
	matrix xform(basisu::cIdentity);
	return xform.set_translate_matrix(x, y);
}
2021
2022
// Factory: returns the result of set_translate_matrix(x, y, z) applied to a matrix constructed with basisu::cIdentity.
static matrix make_translate_matrix(float x, float y, float z)
{
	matrix xform(basisu::cIdentity);
	return xform.set_translate_matrix(x, y, z);
}
2026
2027
// Factory: returns a matrix configured by set_scale_matrix(s) (which clears it first, so no prior state is needed).
static inline matrix make_scale_matrix(float s)
{
	matrix scale;
	scale.set_scale_matrix(s);
	return scale;
}
2031
2032
// Factory: returns a matrix configured by set_scale_matrix(s) (which clears it first, so no prior state is needed).
static inline matrix make_scale_matrix(const row_vec& s)
{
	matrix scale;
	scale.set_scale_matrix(s);
	return scale;
}
2036
2037
// Factory: identity matrix with 'x' and 'y' written to the first two diagonal entries.
// Requires at least a 3x3 matrix (enforced at compile time).
static inline matrix make_scale_matrix(float x, float y)
{
	static_assert(R >= 3 && C >= 3);

	matrix scale = make_identity_matrix();
	scale.m_rows[0][0] = x;
	scale.m_rows[1][1] = y;
	return scale;
}
2046
2047
// Factory: identity matrix with 'x', 'y' and 'z' written to the first three diagonal entries.
// Requires at least a 4x4 matrix (enforced at compile time).
static inline matrix make_scale_matrix(float x, float y, float z)
{
	static_assert(R >= 4 && C >= 4);

	matrix scale = make_identity_matrix();
	scale.m_rows[0][0] = x;
	scale.m_rows[1][1] = y;
	scale.m_rows[2][2] = z;
	return scale;
}
2057
2058
// Helpers derived from Graphics Gems 1 and 2 (Matrices and Transformations, Ronald N. Goldman)
2059
// Factory: rotation by 'ang' (as consumed by cos()/sin()) about an arbitrary axis.
// 'axis' is normalized internally. Only the upper-left 3x3 block is written; everything else stays identity.
// Intermediate math is carried out in double and converted to T on store.
static matrix make_rotate_matrix(const vec<3, T>& axis, T ang)
{
	static_assert(R >= 3 && C >= 3);

	vec<3, T> unit_axis(axis.get_normalized());

	double c = cos(ang);
	double one_minus_c = 1.0f - c;

	double s = sin(ang);

	const T ax = unit_axis[0];
	const T ay = unit_axis[1];
	const T az = unit_axis[2];

	const double ax_sq = unit_axis[0] * unit_axis[0];
	const double ay_sq = unit_axis[1] * unit_axis[1];
	const double az_sq = unit_axis[2] * unit_axis[2];

	matrix rot;
	rot.set_identity_matrix();

	// Row 0
	rot[0][0] = (T)((one_minus_c * ax_sq) + c);
	rot[0][1] = (T)((one_minus_c * ax * ay) - (s * az));
	rot[0][2] = (T)((one_minus_c * ax * az) + (s * ay));

	// Row 1
	rot[1][0] = (T)((one_minus_c * ax * ay) + (s * az));
	rot[1][1] = (T)((one_minus_c * ay_sq) + c);
	rot[1][2] = (T)((one_minus_c * ay * az) - (s * ax));

	// Row 2
	rot[2][0] = (T)((one_minus_c * ax * az) - (s * ay));
	rot[2][1] = (T)((one_minus_c * ay * az) + (s * ax));
	rot[2][2] = (T)((one_minus_c * az_sq) + c);

	return rot;
}
2095
2096
// Factory: 2D rotation by 'ang'. Starts from a matrix constructed with basisu::cIdentity and fills the
// upper-left 2x2 block with [cos, -sin; sin, cos].
static inline matrix make_rotate_matrix(T ang)
{
	static_assert(R >= 2 && C >= 2);

	const T s = static_cast<T>(sin(ang));
	const T c = static_cast<T>(cos(ang));

	matrix rot(basisu::cIdentity);

	rot[0][0] = +c;
	rot[0][1] = -s;
	rot[1][0] = +s;
	rot[1][1] = +c;

	return rot;
}
2112
2113
// Factory: rotation by 'ang' about one of the principal axes, selected by index 'axis'
// (the index is used directly as a component of a 3-vector, so it must be 0, 1 or 2).
static inline matrix make_rotate_matrix(uint32_t axis, T ang)
{
	vec<3, T> unit_axis(basisu::cClear);
	unit_axis[axis] = 1.0f;

	return make_rotate_matrix(unit_axis, ang);
}
2120
2121
// Factory: skew-symmetric matrix built from the components of 'c'. All other elements are zero.
// Requires at least a 3x3 matrix.
static inline matrix make_cross_product_matrix(const vec<3, scalar_type>& c)
{
	static_assert((num_rows >= 3) && (num_cols >= 3));

	matrix skew(basisu::cClear);

	// Row 0: [ 0,  c2, -c1 ]
	skew[0][1] = c[2];
	skew[0][2] = -c[1];

	// Row 1: [ -c2, 0,  c0 ]
	skew[1][0] = -c[2];
	skew[1][2] = c[0];

	// Row 2: [ c1, -c0, 0 ]
	skew[2][0] = c[1];
	skew[2][1] = -c[0];

	return skew;
}
2133
2134
// Factory (4x4 only): I - 2 * (n (x) n), with the translation set to 2 * (q . n) * n.
// Both 'n' and 'q' are asserted to satisfy is_vector().
static inline matrix make_reflection_matrix(const vec<4, scalar_type>& n, const vec<4, scalar_type>& q)
{
	static_assert((num_rows == 4) && (num_cols == 4));
	assert(n.is_vector() && q.is_vector());

	matrix reflect;
	reflect = make_identity_matrix() - 2.0f * make_tensor_product_matrix(n, n);

	reflect.set_translate((2.0f * q.dot(n) * n).as_point());
	return reflect;
}
2143
2144
// Factory: tensor (outer) product. Row r of the result is 'w' multiplied component-wise by v.broadcast(r).
static inline matrix make_tensor_product_matrix(const row_vec& v, const row_vec& w)
{
	matrix outer;
	for (int row = 0; row < num_rows; row++)
	{
		outer[row] = row_vec::mul_components(v.broadcast(row), w);
	}
	return outer;
}
2151
2152
// Factory (4x4 only): c * I, with the translation set to (1 - c) * q.
// 'q' is asserted to satisfy is_vector().
static inline matrix make_uniform_scaling_matrix(const vec<4, scalar_type>& q, scalar_type c)
{
	static_assert((num_rows == 4) && (num_cols == 4));
	assert(q.is_vector());

	matrix scale;
	scale = c * make_identity_matrix();

	scale.set_translate(((1.0f - c) * q).as_point());
	return scale;
}
2161
2162
// Factory (4x4 only): I - (1 - c) * (w (x) w), with the translation set to (1 - c) * (q . w) * w.
// Both 'q' and 'w' are asserted to satisfy is_vector().
static inline matrix make_nonuniform_scaling_matrix(const vec<4, scalar_type>& q, scalar_type c, const vec<4, scalar_type>& w)
{
	static_assert((num_rows == 4) && (num_cols == 4));
	assert(q.is_vector() && w.is_vector());

	matrix scale;
	scale = make_identity_matrix() - (1.0f - c) * make_tensor_product_matrix(w, w);

	scale.set_translate(((1.0f - c) * q.dot(w) * w).as_point());
	return scale;
}
2171
2172
// n = normal of plane, q = point on plane
2173
static inline matrix make_ortho_projection_matrix(const vec<4, scalar_type>& n, const vec<4, scalar_type>& q)
2174
{
2175
assert(n.is_vector() && q.is_vector());
2176
matrix ret;
2177
ret = make_identity_matrix() - make_tensor_product_matrix(n, n);
2178
ret.set_translate((q.dot(n) * n).as_point());
2179
return ret;
2180
}
2181
2182
// Factory: I - (n (x) w) / (w . n), with the translation set to ((q . n) / (w . n)) * w.
// All three inputs are asserted to satisfy is_vector(). No check is made for (w . n) being zero.
static inline matrix make_parallel_projection(const vec<4, scalar_type>& n, const vec<4, scalar_type>& q, const vec<4, scalar_type>& w)
{
	assert(n.is_vector() && q.is_vector() && w.is_vector());

	matrix proj;
	proj = make_identity_matrix() - (make_tensor_product_matrix(n, w) / (w.dot(n)));

	proj.set_translate(((q.dot(n) / w.dot(n)) * w).as_point());
	return proj;
}
2190
2191
protected:
2192
row_vec m_rows[R];
2193
2194
// Recursively computes the determinant of the upper-left n x n block of 'a' by cofactor expansion
// along the first row. Cost grows factorially with n, so this is only suitable for small matrices.
// Algorithm ported from Numerical Recipes in C.
static T det_helper(const matrix& a, uint32_t n)
{
	// A 1x1 determinant is its single element. Without this case the expansion below would recurse
	// with n == 0 and report 0 for every 1x1 matrix.
	if (n == 1)
		return a(0, 0);

	T d;
	matrix m;
	if (n == 2)
		d = a(0, 0) * a(1, 1) - a(1, 0) * a(0, 1);
	else
	{
		d = 0;
		// j1 is the 1-based column of the first-row element being expanded.
		for (uint32_t j1 = 1; j1 <= n; j1++)
		{
			// Build the minor: rows 2..n of 'a' with column j1 removed, packed into the upper-left of 'm'.
			for (uint32_t i = 2; i <= n; i++)
			{
				int j2 = 1;
				for (uint32_t j = 1; j <= n; j++)
				{
					if (j != j1)
					{
						m(i - 2, j2 - 1) = a(i - 1, j - 1);
						j2++;
					}
				}
			}
			// Alternate the cofactor sign: + for odd j1, - for even j1.
			d += (((1 + j1) & 1) ? -1.0f : 1.0f) * a(1 - 1, j1 - 1) * det_helper(m, n - 1);
		}
	}
	return d;
}
2223
};
2224
2225
typedef matrix<2, 2, float> matrix22F;
2226
typedef matrix<2, 2, double> matrix22D;
2227
2228
typedef matrix<3, 3, float> matrix33F;
2229
typedef matrix<3, 3, double> matrix33D;
2230
2231
typedef matrix<4, 4, float> matrix44F;
2232
typedef matrix<4, 4, double> matrix44D;
2233
2234
typedef matrix<8, 8, float> matrix88F;
2235
2236
// These helpers create good old D3D-style matrices.
2237
// Builds an off-center perspective projection matrix from the view-volume extents
// l/r (left/right), b/t (bottom/top) and the near/far distances nz/fz.
inline matrix44F matrix44F_make_perspective_offcenter_lh(float l, float r, float b, float t, float nz, float fz)
{
	const float near2 = 2.0f * nz;
	const float inv_w = 1.0f / (r - l);
	const float inv_h = 1.0f / (t - b);

	matrix44F proj;
	proj[0].set(near2 * inv_w, 0.0f, 0.0f, 0.0f);
	proj[1].set(0.0f, near2 * inv_h, 0.0f, 0.0f);
	proj[2].set(-(l + r) * inv_w, -(t + b) * inv_h, fz / (fz - nz), 1.0f);

	// The Z translation reuses the Z scale that was just stored in row 2.
	const float z_scale = proj[2][2];
	proj[3].set(0.0f, 0.0f, -z_scale * nz, 0.0f);

	return proj;
}
2250
2251
// fov_y: full Y field of view (radians)
2252
// aspect: viewspace width/height
2253
inline matrix44F matrix44F_make_perspective_fov_lh(float fov_y, float aspect, float nz, float fz)
2254
{
2255
double sin_fov = sin(0.5f * fov_y);
2256
double cos_fov = cos(0.5f * fov_y);
2257
2258
float y_scale = static_cast<float>(cos_fov / sin_fov);
2259
float x_scale = static_cast<float>(y_scale / aspect);
2260
2261
matrix44F view_to_proj;
2262
view_to_proj[0].set(x_scale, 0, 0, 0);
2263
view_to_proj[1].set(0, y_scale, 0, 0);
2264
view_to_proj[2].set(0, 0, fz / (fz - nz), 1);
2265
view_to_proj[3].set(0, 0, -nz * fz / (fz - nz), 0);
2266
return view_to_proj;
2267
}
2268
2269
// Builds an off-center orthographic projection matrix from the view-volume extents
// l/r (left/right), b/t (bottom/top) and the near/far distances nz/fz.
inline matrix44F matrix44F_make_ortho_offcenter_lh(float l, float r, float b, float t, float nz, float fz)
{
	matrix44F proj;

	// Rows 0-2 hold the per-axis scales.
	proj[0].set(2.0f / (r - l), 0.0f, 0.0f, 0.0f);
	proj[1].set(0.0f, 2.0f / (t - b), 0.0f, 0.0f);
	proj[2].set(0.0f, 0.0f, 1.0f / (fz - nz), 0.0f);

	// Row 3 holds the per-axis offsets.
	proj[3].set((l + r) / (l - r), (t + b) / (b - t), nz / (nz - fz), 1.0f);

	return proj;
}
2278
2279
// Builds a centered orthographic projection for a view volume of width 'w' and height 'h'
// by forwarding the symmetric extents to matrix44F_make_ortho_offcenter_lh().
inline matrix44F matrix44F_make_ortho_lh(float w, float h, float nz, float fz)
{
	const float left = -w * .5f;
	const float right = w * .5f;
	const float bottom = -h * .5f;
	const float top = h * .5f;

	return matrix44F_make_ortho_offcenter_lh(left, right, bottom, top, nz, fz);
}
2283
2284
// Builds the matrix that maps projection space to a screen rectangle with origin (x, y), size (w, h)
// and depth range [min_z, max_z]. The Y scale is negative, i.e. the Y axis is flipped.
inline matrix44F matrix44F_make_projection_to_screen_d3d(int x, int y, int w, int h, float min_z, float max_z)
{
	matrix44F to_screen;

	to_screen[0].set(w * .5f, 0.0f, 0.0f, 0.0f);
	to_screen[1].set(0, h * -.5f, 0.0f, 0.0f);
	to_screen[2].set(0, 0.0f, max_z - min_z, 0.0f);

	// Offset to the centre of the rectangle and to the start of the depth range.
	to_screen[3].set(x + w * .5f, y + h * .5f, min_z, 1.0f);

	return to_screen;
}
2293
2294
// Builds a look-at view matrix: translate by -camera_pos, rotate into the camera basis, then apply a roll
// of 'camera_roll_ang_in_radians' about axis index 2.
inline matrix44F matrix44F_make_lookat_lh(const vec3F& camera_pos, const vec3F& look_at, const vec3F& camera_up, float camera_roll_ang_in_radians)
{
	// Forward axis: from the camera towards the target. Falls back to (0, 0, 1) when normalize() returns 0.
	vec4F forward(look_at - camera_pos);
	assert(forward.is_vector());
	if (forward.normalize() == 0.0f)
		forward.set(0, 0, 1, 0);

	// Up hint. Replaced by -X when the forward axis has no X or Z component.
	vec4F up(camera_up);
	assert(up.is_vector());
	if (!forward[0] && !forward[2])
		up.set(-1.0f, 0.0f, 0.0f, 0.0f);

	// Replaced by +Y when the hint is (almost) aligned with the forward axis.
	if ((up.dot(forward)) > .9999f)
		up.set(0.0f, 1.0f, 0.0f, 0.0f);

	// Complete the basis: right = up x forward, then recompute up = forward x right.
	vec4F right(vec4F::cross3(up, forward).normalize_in_place());
	up = vec4F::cross3(forward, right).normalize_in_place();

	matrix44F basis(matrix44F::make_identity_matrix());
	basis.set_col(0, right);
	basis.set_col(1, up);
	basis.set_col(2, forward);

	return matrix44F::make_translate_matrix(-camera_pos[0], -camera_pos[1], -camera_pos[2]) * basis * matrix44F::make_rotate_matrix(2, camera_roll_ang_in_radians);
}
2318
2319
template<typename R> R matrix_NxN_create_DCT()
2320
{
2321
assert(R::num_rows == R::num_cols);
2322
2323
const uint32_t N = R::num_cols;
2324
2325
R result;
2326
for (uint32_t k = 0; k < N; k++)
2327
{
2328
for (uint32_t n = 0; n < N; n++)
2329
{
2330
double f;
2331
2332
if (!k)
2333
f = 1.0f / sqrt(float(N));
2334
else
2335
f = sqrt(2.0f / float(N)) * cos((basisu::cPiD * (2.0f * float(n) + 1.0f) * float(k)) / (2.0f * float(N)));
2336
2337
result(k, n) = static_cast<typename R::scalar_type>(f);
2338
}
2339
}
2340
2341
return result;
2342
}
2343
2344
template<typename R> R matrix_NxN_DCT(const R& a, const R& dct)
2345
{
2346
R temp;
2347
matrix_mul_helper<R, R, R>(temp, dct, a);
2348
R result;
2349
matrix_mul_helper_transpose_rhs<R, R, R>(result, temp, dct);
2350
return result;
2351
}
2352
2353
template<typename R> R matrix_NxN_IDCT(const R& b, const R& dct)
2354
{
2355
R temp;
2356
matrix_mul_helper_transpose_lhs<R, R, R>(temp, dct, b);
2357
R result;
2358
matrix_mul_helper<R, R, R>(result, temp, dct);
2359
return result;
2360
}
2361
2362
// Kronecker product of 'a' and 'b': every element a(r, c) is replaced by the block a(r, c) * b.
// The result has (rows(a) * rows(b)) x (cols(a) * cols(b)) elements of X's scalar type.
template<typename X, typename Y> matrix<X::num_rows* Y::num_rows, X::num_cols* Y::num_cols, typename X::scalar_type> matrix_kronecker_product(const X& a, const Y& b)
{
	matrix<X::num_rows* Y::num_rows, X::num_cols* Y::num_cols, typename X::scalar_type> result;

	for (uint32_t a_row = 0; a_row < X::num_rows; a_row++)
	{
		for (uint32_t a_col = 0; a_col < X::num_cols; a_col++)
		{
			// Fill the block of the result that belongs to a(a_row, a_col).
			for (uint32_t b_row = 0; b_row < Y::num_rows; b_row++)
			{
				for (uint32_t b_col = 0; b_col < Y::num_cols; b_col++)
				{
					result(a_row * Y::num_rows + b_row, a_col * Y::num_cols + b_col) = a(a_row, a_col) * b(b_row, b_col);
				}
			}
		}
	}

	return result;
}
2378
2379
// Stacks 'a' on top of 'b' and returns the combined matrix, which has rows(a) + rows(b) rows and
// cols(a) columns of X's scalar type. Both inputs must have the same number of columns.
template<typename X, typename Y> matrix<X::num_rows + Y::num_rows, X::num_cols, typename X::scalar_type> matrix_combine_vertically(const X& a, const Y& b)
{
	// The result only has X::num_cols columns. A wider 'b' would be written out of range and a narrower
	// one would leave elements unset, so mismatched widths are rejected at compile time.
	static_assert(static_cast<uint32_t>(X::num_cols) == static_cast<uint32_t>(Y::num_cols));

	matrix<X::num_rows + Y::num_rows, X::num_cols, typename X::scalar_type> result;

	// Top part: copy 'a'.
	for (uint32_t r = 0; r < X::num_rows; r++)
		for (uint32_t c = 0; c < X::num_cols; c++)
			result(r, c) = a(r, c);

	// Bottom part: copy 'b' below the rows of 'a'.
	for (uint32_t r = 0; r < Y::num_rows; r++)
		for (uint32_t c = 0; c < Y::num_cols; c++)
			result(r + X::num_rows, c) = b(r, c);

	return result;
}
2393
2394
// Builds an 8x8 Haar matrix from a 2x2 Haar matrix by repeated Kronecker products and vertical stacking.
// Rows 2 and 3 of the upper half are scaled by sqrt(2) and the whole lower half by 2.
inline matrix88F get_haar8()
{
	matrix22F haar2(
		1, 1,
		1, -1);
	matrix22F ident2(
		1, 0,
		0, 1);
	matrix44F ident4(
		1, 0, 0, 0,
		0, 1, 0, 0,
		0, 0, 1, 0,
		0, 0, 0, 1);

	// 1x2 "sum" and "difference" rows used to widen each stage.
	matrix<1, 2, float> sum_row; sum_row(0, 0) = 1; sum_row(0, 1) = 1;
	matrix<1, 2, float> diff_row; diff_row(0, 0) = 1.0f; diff_row(0, 1) = -1.0f;

	// 4x4 stage.
	matrix<2, 4, float> haar4_top = matrix_kronecker_product(haar2, sum_row);
	matrix<2, 4, float> haar4_bottom = matrix_kronecker_product(ident2, diff_row);

	matrix<4, 4, float> haar4 = matrix_combine_vertically(haar4_top, haar4_bottom);

	// 8x8 stage.
	matrix<4, 8, float> haar8_top = matrix_kronecker_product(haar4, sum_row);
	matrix<4, 8, float> haar8_bottom = matrix_kronecker_product(ident4, diff_row);

	const float root2 = sqrtf(2);
	haar8_top[2] *= root2;
	haar8_top[3] *= root2;
	haar8_bottom *= 2.0f;

	matrix<8, 8, float> haar8 = matrix_combine_vertically(haar8_top, haar8_bottom);

	return haar8;
}
2427
2428
// Returns a 4x4 Haar matrix given as literal values: two rows of +/-0.5 followed by two rows of +/-0.5*sqrt(2).
inline matrix44F get_haar4()
{
	const float sqrt2 = 1.4142135623730951f;

	const float half = .5f * 1;
	const float neg_half = .5f * -1;
	const float half_root2 = .5f * sqrt2;
	const float neg_half_root2 = .5f * -sqrt2;

	return matrix44F(
		half, half, half, half,
		half, half, neg_half, neg_half,
		half_root2, neg_half_root2, 0, 0,
		0, 0, half_root2, neg_half_root2);
}
2438
2439
template<typename T>
2440
inline matrix<2, 2, T> get_inverse_2x2(const matrix<2, 2, T>& m)
2441
{
2442
double a = m[0][0];
2443
double b = m[0][1];
2444
double c = m[1][0];
2445
double d = m[1][1];
2446
2447
double det = a * d - b * c;
2448
if (det != 0.0f)
2449
det = 1.0f / det;
2450
2451
matrix<2, 2, T> result;
2452
result[0][0] = static_cast<T>(d * det);
2453
result[0][1] = static_cast<T>(-b * det);
2454
result[1][0] = static_cast<T>(-c * det);
2455
result[1][1] = static_cast<T>(a * det);
2456
return result;
2457
}
2458
2459
} // namespace bu_math
2460
2461
namespace basisu
2462
{
2463
// Running statistics over integer samples: keeps the sample count, the sum and the sum of squares,
// from which the average, standard deviation and variance are derived on demand.
class tracked_stat
{
public:
	tracked_stat() { clear(); }

	// Resets the accumulator to the empty state.
	inline void clear() { m_num = 0; m_total = 0; m_total2 = 0; }

	// Adds one sample.
	inline void update(int32_t val)
	{
		// Widen before squaring: val * val evaluated in 32 bits overflows once |val| exceeds 46340.
		const int64_t v = val;

		m_num++;
		m_total += v;
		m_total2 += v * v;
	}

	// Adds one sample. The value is converted to int32_t by update().
	inline tracked_stat& operator += (uint32_t val) { update(val); return *this; }

	inline uint32_t get_number_of_values() const { return m_num; }
	inline uint64_t get_total() const { return m_total; }
	inline uint64_t get_total2() const { return m_total2; }

	// Mean of the samples, or 0 when no samples have been added.
	inline float get_average() const { return m_num ? (float)m_total / m_num : 0.0f; }

	// Standard deviation computed as sqrt(n * sum(x^2) - sum(x)^2) / n, or 0 when no samples have been added.
	inline float get_std_dev() const { return m_num ? sqrtf((float)(m_num * m_total2 - m_total * m_total)) / m_num : 0.0f; }

	// Square of get_std_dev().
	inline float get_variance() const { float s = get_std_dev(); return s * s; }

private:
	uint32_t m_num;		// number of samples
	int64_t m_total;	// sum of samples
	int64_t m_total2;	// sum of squared samples
};
2487
2488
// Running statistics over double samples: keeps the sample count, the sum and the sum of squares,
// from which the average, standard deviation and variance are derived on demand.
class tracked_stat_dbl
{
public:
	tracked_stat_dbl() { clear(); }

	// Resets the accumulator to the empty state.
	inline void clear() { m_num = 0; m_total = 0; m_total2 = 0; }

	// Adds one sample.
	inline void update(double val) { m_num++; m_total += val; m_total2 += val * val; }

	inline tracked_stat_dbl& operator += (double val) { update(val); return *this; }

	inline uint64_t get_number_of_values() const { return m_num; }
	inline double get_total() const { return m_total; }
	inline double get_total2() const { return m_total2; }

	// Mean of the samples, or 0 when no samples have been added.
	inline double get_average() const { return m_num ? m_total / (double)m_num : 0.0f; }

	// Standard deviation computed as sqrt(n * sum(x^2) - sum(x)^2) / n, or 0 when no samples have been added.
	inline double get_std_dev() const
	{
		if (!m_num)
			return 0.0f;

		const double n = (double)m_num;

		// Mathematically this is never negative, but rounding can push it slightly below zero
		// (for example when all samples are equal), and sqrt() of a negative value is NaN.
		const double spread = n * m_total2 - m_total * m_total;
		if (!(spread > 0.0))
			return 0.0;

		return sqrt(spread) / n;
	}

	// Square of get_std_dev().
	inline double get_variance() const { double s = get_std_dev(); return s * s; }

private:
	uint64_t m_num;		// number of samples
	double m_total;		// sum of samples
	double m_total2;	// sum of squared samples
};
2512
2513
template<typename FloatType>
2514
struct stats
2515
{
2516
uint32_t m_n;
2517
FloatType m_total, m_total_sq; // total, total of squares values
2518
FloatType m_avg, m_avg_sq; // mean, mean of the squared values
2519
FloatType m_rms; // sqrt(m_avg_sq)
2520
FloatType m_std_dev, m_var; // population standard deviation and variance
2521
FloatType m_mad; // mean absolute deviation
2522
FloatType m_min, m_max, m_range; // min and max values, and max-min
2523
FloatType m_len; // length of values as a vector (Euclidean norm or L2 norm)
2524
FloatType m_coeff_of_var; // coefficient of variation (std_dev/mean), High CV: Indicates greater variability relative to the mean, meaning the data values are more spread out,
2525
// Low CV : Indicates less variability relative to the mean, meaning the data values are more consistent.
2526
2527
FloatType m_skewness; // Skewness = 0: The data is perfectly symmetric around the mean,
2528
// Skewness > 0: The data is positively skewed (right-skewed),
2529
// Skewness < 0: The data is negatively skewed (left-skewed)
2530
// 0-.5 approx. symmetry, .5-1 moderate skew, >= 1 highly skewed
2531
2532
FloatType m_kurtosis; // Excess Kurtosis: Kurtosis = 0: The distribution has normal kurtosis (mesokurtic)
2533
// Kurtosis > 0: The distribution is leptokurtic, with heavy tails and a sharp peak
2534
// Kurtosis < 0: The distribution is platykurtic, with light tails and a flatter peak
2535
2536
bool m_any_zero;
2537
2538
FloatType m_median;
2539
uint32_t m_median_index;
2540
2541
stats()
2542
{
2543
clear();
2544
}
2545
2546
void clear()
2547
{
2548
m_n = 0;
2549
m_total = 0, m_total_sq = 0;
2550
m_avg = 0, m_avg_sq = 0;
2551
m_rms = 0;
2552
m_std_dev = 0, m_var = 0;
2553
m_mad = 0;
2554
m_min = BIG_FLOAT_VAL, m_max = -BIG_FLOAT_VAL; m_range = 0.0f;
2555
m_len = 0;
2556
m_coeff_of_var = 0;
2557
m_skewness = 0;
2558
m_kurtosis = 0;
2559
m_any_zero = false;
2560
2561
m_median = 0;
2562
m_median_index = 0;
2563
}
2564
2565
template<typename T>
2566
void calc_median(uint32_t n, const T* pVals, uint32_t stride = 1)
2567
{
2568
m_median = 0;
2569
m_median_index = 0;
2570
2571
if (!n)
2572
return;
2573
2574
basisu::vector< std::pair<T, uint32_t> > vals(n);
2575
2576
for (uint32_t i = 0; i < n; i++)
2577
{
2578
vals[i].first = pVals[i * stride];
2579
vals[i].second = i;
2580
}
2581
2582
std::sort(vals.begin(), vals.end(), [](const std::pair<T, uint32_t>& a, const std::pair<T, uint32_t>& b) {
2583
return a.first < b.first;
2584
});
2585
2586
m_median = vals[n / 2].first;
2587
if ((n & 1) == 0)
2588
m_median = (m_median + vals[(n / 2) - 1].first) * .5f;
2589
2590
m_median_index = vals[n / 2].second;
2591
}
2592
2593
template<typename T>
2594
void calc(uint32_t n, const T* pVals, uint32_t stride = 1, bool calc_median_flag = false)
2595
{
2596
clear();
2597
2598
if (!n)
2599
return;
2600
2601
if (calc_median_flag)
2602
calc_median(n, pVals, stride);
2603
2604
m_n = n;
2605
2606
for (uint32_t i = 0; i < n; i++)
2607
{
2608
FloatType v = (FloatType)pVals[i * stride];
2609
2610
if (v == 0.0f)
2611
m_any_zero = true;
2612
2613
m_total += v;
2614
m_total_sq += v * v;
2615
2616
if (!i)
2617
{
2618
m_min = v;
2619
m_max = v;
2620
}
2621
else
2622
{
2623
m_min = minimum(m_min, v);
2624
m_max = maximum(m_max, v);
2625
}
2626
}
2627
2628
m_range = m_max - m_min;
2629
2630
m_len = sqrt(m_total_sq);
2631
2632
const FloatType nd = (FloatType)n;
2633
2634
m_avg = m_total / nd;
2635
m_avg_sq = m_total_sq / nd;
2636
m_rms = sqrt(m_avg_sq);
2637
2638
for (uint32_t i = 0; i < n; i++)
2639
{
2640
FloatType v = (FloatType)pVals[i * stride];
2641
FloatType d = v - m_avg;
2642
2643
const FloatType d2 = d * d;
2644
const FloatType d3 = d2 * d;
2645
const FloatType d4 = d3 * d;
2646
2647
m_var += d2;
2648
m_mad += fabs(d);
2649
m_skewness += d3;
2650
m_kurtosis += d4;
2651
}
2652
2653
m_var /= nd;
2654
m_mad /= nd;
2655
2656
m_std_dev = sqrt(m_var);
2657
2658
m_coeff_of_var = (m_avg != 0.0f) ? (m_std_dev / fabs(m_avg)) : 0.0f;
2659
2660
FloatType k3 = m_std_dev * m_std_dev * m_std_dev;
2661
FloatType k4 = k3 * m_std_dev;
2662
m_skewness = (k3 != 0.0f) ? ((m_skewness / nd) / k3) : 0.0f;
2663
m_kurtosis = (k4 != 0.0f) ? (((m_kurtosis / nd) / k4) - 3.0f) : 0.0f;
2664
}
2665
2666
// Only compute average, variance and standard deviation.
2667
template<typename T>
2668
void calc_simplified(uint32_t n, const T* pVals, uint32_t stride = 1)
2669
{
2670
clear();
2671
2672
if (!n)
2673
return;
2674
2675
m_n = n;
2676
2677
for (uint32_t i = 0; i < n; i++)
2678
{
2679
FloatType v = (FloatType)pVals[i * stride];
2680
2681
m_total += v;
2682
}
2683
2684
const FloatType nd = (FloatType)n;
2685
2686
m_avg = m_total / nd;
2687
2688
for (uint32_t i = 0; i < n; i++)
2689
{
2690
FloatType v = (FloatType)pVals[i * stride];
2691
FloatType d = v - m_avg;
2692
2693
const FloatType d2 = d * d;
2694
2695
m_var += d2;
2696
}
2697
2698
m_var /= nd;
2699
m_std_dev = sqrt(m_var);
2700
}
2701
};
2702
2703
template<typename FloatType>
2704
struct comparative_stats
2705
{
2706
FloatType m_cov; // covariance
2707
FloatType m_pearson; // Pearson Correlation Coefficient (r) [-1,1]
2708
FloatType m_mse; // mean squared error
2709
FloatType m_rmse; // root mean squared error
2710
FloatType m_mae; // mean abs error
2711
FloatType m_rmsle; // root mean squared log error
2712
FloatType m_euclidean_dist; // euclidean distance between values as vectors
2713
FloatType m_cosine_sim; // normalized dot products of values as vectors
2714
FloatType m_min_diff, m_max_diff; // minimum/maximum abs difference between values
2715
2716
comparative_stats()
2717
{
2718
clear();
2719
}
2720
2721
void clear()
2722
{
2723
m_cov = 0;
2724
m_pearson = 0;
2725
m_mse = 0;
2726
m_rmse = 0;
2727
m_mae = 0;
2728
m_rmsle = 0;
2729
m_euclidean_dist = 0;
2730
m_cosine_sim = 0;
2731
m_min_diff = 0;
2732
m_max_diff = 0;
2733
}
2734
2735
template<typename T>
2736
void calc(uint32_t n, const T* pA, const T* pB, uint32_t a_stride = 1, uint32_t b_stride = 1, const stats<FloatType> *pA_stats = nullptr, const stats<FloatType> *pB_stats = nullptr)
2737
{
2738
clear();
2739
if (!n)
2740
return;
2741
2742
stats<FloatType> temp_a_stats;
2743
if (!pA_stats)
2744
{
2745
pA_stats = &temp_a_stats;
2746
temp_a_stats.calc(n, pA, a_stride);
2747
}
2748
2749
stats<FloatType> temp_b_stats;
2750
if (!pB_stats)
2751
{
2752
pB_stats = &temp_b_stats;
2753
temp_b_stats.calc(n, pB, b_stride);
2754
}
2755
2756
for (uint32_t i = 0; i < n; i++)
2757
{
2758
const FloatType fa = (FloatType)pA[i * a_stride];
2759
const FloatType fb = (FloatType)pB[i * b_stride];
2760
2761
if ((pA_stats->m_min >= 0.0f) && (pB_stats->m_min >= 0.0f))
2762
{
2763
const FloatType ld = log(fa + 1.0f) - log(fb + 1.0f);
2764
m_rmsle += ld * ld;
2765
}
2766
2767
const FloatType diff = fa - fb;
2768
const FloatType abs_diff = fabs(diff);
2769
2770
m_mse += diff * diff;
2771
m_mae += abs_diff;
2772
2773
m_min_diff = i ? minimum(m_min_diff, abs_diff) : abs_diff;
2774
m_max_diff = maximum(m_max_diff, abs_diff);
2775
2776
const FloatType da = fa - pA_stats->m_avg;
2777
const FloatType db = fb - pB_stats->m_avg;
2778
m_cov += da * db;
2779
2780
m_cosine_sim += fa * fb;
2781
}
2782
2783
const FloatType nd = (FloatType)n;
2784
2785
m_euclidean_dist = sqrt(m_mse);
2786
2787
m_mse /= nd;
2788
m_rmse = sqrt(m_mse);
2789
2790
m_mae /= nd;
2791
2792
m_cov /= nd;
2793
2794
FloatType dv = (pA_stats->m_std_dev * pB_stats->m_std_dev);
2795
if (dv != 0.0f)
2796
m_pearson = m_cov / dv;
2797
2798
if ((pA_stats->m_min >= 0.0) && (pB_stats->m_min >= 0.0f))
2799
m_rmsle = sqrt(m_rmsle / nd);
2800
2801
FloatType c = pA_stats->m_len * pB_stats->m_len;
2802
if (c != 0.0f)
2803
m_cosine_sim /= c;
2804
else
2805
m_cosine_sim = 0.0f;
2806
}
2807
2808
// Only computes Pearson, cov, mse, rmse, Euclidean distance
2809
template<typename T>
2810
void calc_pearson(uint32_t n, const T* pA, const T* pB, uint32_t a_stride = 1, uint32_t b_stride = 1, const stats<FloatType>* pA_stats = nullptr, const stats<FloatType>* pB_stats = nullptr)
2811
{
2812
clear();
2813
if (!n)
2814
return;
2815
2816
stats<FloatType> temp_a_stats;
2817
if (!pA_stats)
2818
{
2819
pA_stats = &temp_a_stats;
2820
temp_a_stats.calc(n, pA, a_stride);
2821
}
2822
2823
stats<FloatType> temp_b_stats;
2824
if (!pB_stats)
2825
{
2826
pB_stats = &temp_b_stats;
2827
temp_b_stats.calc(n, pB, b_stride);
2828
}
2829
2830
for (uint32_t i = 0; i < n; i++)
2831
{
2832
const FloatType fa = (FloatType)pA[i * a_stride];
2833
const FloatType fb = (FloatType)pB[i * b_stride];
2834
2835
const FloatType diff = fa - fb;
2836
2837
m_mse += diff * diff;
2838
2839
const FloatType da = fa - pA_stats->m_avg;
2840
const FloatType db = fb - pB_stats->m_avg;
2841
m_cov += da * db;
2842
}
2843
2844
const FloatType nd = (FloatType)n;
2845
2846
m_euclidean_dist = sqrt(m_mse);
2847
2848
m_mse /= nd;
2849
m_rmse = sqrt(m_mse);
2850
2851
m_cov /= nd;
2852
2853
FloatType dv = (pA_stats->m_std_dev * pB_stats->m_std_dev);
2854
if (dv != 0.0f)
2855
m_pearson = m_cov / dv;
2856
}
2857
2858
// Only computes MSE, RMSE, eclidiean distance, and covariance.
2859
template<typename T>
2860
void calc_simplified(uint32_t n, const T* pA, const T* pB, uint32_t a_stride = 1, uint32_t b_stride = 1, const stats<FloatType>* pA_stats = nullptr, const stats<FloatType>* pB_stats = nullptr)
2861
{
2862
clear();
2863
if (!n)
2864
return;
2865
2866
stats<FloatType> temp_a_stats;
2867
if (!pA_stats)
2868
{
2869
pA_stats = &temp_a_stats;
2870
temp_a_stats.calc(n, pA, a_stride);
2871
}
2872
2873
stats<FloatType> temp_b_stats;
2874
if (!pB_stats)
2875
{
2876
pB_stats = &temp_b_stats;
2877
temp_b_stats.calc(n, pB, b_stride);
2878
}
2879
2880
for (uint32_t i = 0; i < n; i++)
2881
{
2882
const FloatType fa = (FloatType)pA[i * a_stride];
2883
const FloatType fb = (FloatType)pB[i * b_stride];
2884
2885
const FloatType diff = fa - fb;
2886
2887
m_mse += diff * diff;
2888
2889
const FloatType da = fa - pA_stats->m_avg;
2890
const FloatType db = fb - pB_stats->m_avg;
2891
m_cov += da * db;
2892
}
2893
2894
const FloatType nd = (FloatType)n;
2895
2896
m_euclidean_dist = sqrt(m_mse);
2897
2898
m_mse /= nd;
2899
m_rmse = sqrt(m_mse);
2900
2901
m_cov /= nd;
2902
}
2903
2904
// Only computes covariance.
2905
template<typename T>
2906
void calc_cov(uint32_t n, const T* pA, const T* pB, uint32_t a_stride = 1, uint32_t b_stride = 1, const stats<FloatType>* pA_stats = nullptr, const stats<FloatType>* pB_stats = nullptr)
2907
{
2908
clear();
2909
if (!n)
2910
return;
2911
2912
stats<FloatType> temp_a_stats;
2913
if (!pA_stats)
2914
{
2915
pA_stats = &temp_a_stats;
2916
temp_a_stats.calc(n, pA, a_stride);
2917
}
2918
2919
stats<FloatType> temp_b_stats;
2920
if (!pB_stats)
2921
{
2922
pB_stats = &temp_b_stats;
2923
temp_b_stats.calc(n, pB, b_stride);
2924
}
2925
2926
for (uint32_t i = 0; i < n; i++)
2927
{
2928
const FloatType fa = (FloatType)pA[i * a_stride];
2929
const FloatType fb = (FloatType)pB[i * b_stride];
2930
2931
const FloatType da = fa - pA_stats->m_avg;
2932
const FloatType db = fb - pB_stats->m_avg;
2933
m_cov += da * db;
2934
}
2935
2936
const FloatType nd = (FloatType)n;
2937
2938
m_cov /= nd;
2939
}
2940
};
2941
2942
class stat_history
2943
{
2944
public:
2945
stat_history(uint32_t size)
2946
{
2947
init(size);
2948
}
2949
2950
void init(uint32_t size)
2951
{
2952
clear();
2953
2954
m_samples.reserve(size);
2955
m_samples.resize(0);
2956
m_max_samples = size;
2957
}
2958
2959
inline void clear()
2960
{
2961
m_samples.resize(0);
2962
m_max_samples = 0;
2963
}
2964
2965
inline void update(double val)
2966
{
2967
m_samples.push_back(val);
2968
2969
if (m_samples.size() > m_max_samples)
2970
m_samples.erase_index(0);
2971
}
2972
2973
inline size_t size()
2974
{
2975
return m_samples.size();
2976
}
2977
2978
struct stats
2979
{
2980
double m_avg = 0;
2981
double m_std_dev = 0;
2982
double m_var = 0;
2983
double m_mad = 0;
2984
double m_min_val = 0;
2985
double m_max_val = 0;
2986
2987
void clear()
2988
{
2989
basisu::clear_obj(*this);
2990
}
2991
};
2992
2993
inline void get_stats(stats& s)
2994
{
2995
s.clear();
2996
2997
if (m_samples.empty())
2998
return;
2999
3000
double total = 0, total2 = 0;
3001
3002
for (size_t i = 0; i < m_samples.size(); i++)
3003
{
3004
const double v = m_samples[i];
3005
3006
total += v;
3007
total2 += v * v;
3008
3009
if (!i)
3010
{
3011
s.m_min_val = v;
3012
s.m_max_val = v;
3013
}
3014
else
3015
{
3016
s.m_min_val = basisu::minimum<double>(s.m_min_val, v);
3017
s.m_max_val = basisu::maximum<double>(s.m_max_val, v);
3018
}
3019
}
3020
3021
const double n = (double)m_samples.size();
3022
3023
s.m_avg = total / n;
3024
s.m_std_dev = sqrt((n * total2 - total * total)) / n;
3025
s.m_var = (n * total2 - total * total) / (n * n);
3026
3027
double sc = 0;
3028
for (size_t i = 0; i < m_samples.size(); i++)
3029
{
3030
const double v = m_samples[i];
3031
s.m_mad += fabs(v - s.m_avg);
3032
3033
sc += basisu::square(v - s.m_avg);
3034
}
3035
sc = sqrt(sc / n);
3036
3037
s.m_mad /= n;
3038
}
3039
3040
private:
3041
uint32_t m_max_samples;
3042
basisu::vector<double> m_samples;
3043
};
3044
3045
// bfloat16 helpers, see:
3046
// https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
3047
3048
// Overlay giving access to the raw bit pattern of a 32-bit float.
typedef union
{
	uint32_t u;
	float f;
} float32_union;

// bfloat16 storage type: the upper 16 bits of a 32-bit float
// (1 sign bit, 8 exponent bits, 7 mantissa bits).
typedef uint16_t bfloat16;

// Widens a bfloat16 to float by placing it in the upper 16 bits; the lower 16 bits are zero.
inline float bfloat16_to_float(bfloat16 bfloat16)
{
	float32_union float_union;
	float_union.u = ((uint32_t)bfloat16) << 16;
	return float_union.f;
}

// Converts a float to bfloat16.
//   round_flag = true : round to nearest, ties to even (may round up to infinity).
//   round_flag = false: truncate the lower 16 bits.
// Inputs with a zero exponent field (zeros and denormals, either sign) always return 0x0000.
// Infinity is returned unchanged. NaN always returns a NaN with the input's sign: the truncated
// payload is kept when it is non-zero, otherwise the quiet bit is set.
inline bfloat16 float_to_bfloat16(float input, bool round_flag = true)
{
	float32_union float_union;
	float_union.f = input;

	const uint32_t exponent = (float_union.u >> 23) & 0xFF;

	// Check if the number is denormalized in float32 (exponent == 0)
	if (exponent == 0)
	{
		// Handle denormalized float32 as zero in bfloat16
		return 0x0000;
	}

	// Extract the top 16 bits (sign, exponent, and 7 most significant bits of the mantissa)
	uint32_t upperBits = float_union.u >> 16;

	if (exponent == 0xFF)
	{
		// Inf/NaN must bypass rounding: rounding a NaN up can carry through the exponent and sign bits
		// (producing a zero), and truncating a NaN whose payload lives only in the lower 16 bits would
		// produce an infinity.
		const bool is_nan = (float_union.u & 0x007FFFFF) != 0;
		if (is_nan && ((upperBits & 0x007F) == 0))
			upperBits |= 0x0040; // payload was entirely in the discarded bits; set the quiet bit to stay a NaN

		return (bfloat16)upperBits;
	}

	if (round_flag)
	{
		// Check the most significant bit of the lower 16 bits for rounding
		const uint32_t lowerBits = float_union.u & 0xFFFF;

		// Round to nearest or even
		if ((lowerBits > 0x8000) || ((lowerBits == 0x8000) && (upperBits & 1)))
		{
			// Round up. A mantissa carry increments the exponent; if the largest finite value rounds up,
			// the result is exponent 0xFF with a zero mantissa, which is already the encoding of infinity.
			upperBits += 1;
		}
	}

	return (bfloat16)upperBits;
}

// Unbiased exponent (biased exponent field minus 127).
inline int bfloat16_get_exp(bfloat16 v)
{
	return (int)((v >> 7) & 0xFF) - 127;
}

// The 7 stored mantissa bits (without any implicit leading one).
inline int bfloat16_get_mantissa(bfloat16 v)
{
	return (v & 0x7F);
}

// -1 when the sign bit is set, otherwise 1.
inline int bfloat16_get_sign(bfloat16 v)
{
	return (v & 0x8000) ? -1 : 1;
}

// True when the exponent field is all ones (infinity or NaN).
inline bool bfloat16_is_nan_or_inf(bfloat16 v)
{
	return ((v >> 7) & 0xFF) == 0xFF;
}

// True for +0 and -0.
inline bool bfloat16_is_zero(bfloat16 v)
{
	return (v & 0x7FFF) == 0;
}

// Assembles a bfloat16 from its fields.
//   sign: negative selects a set sign bit, anything else a clear one.
//   exp : unbiased exponent, must be in [-126, 127].
//   mant: the 7 stored mantissa bits, must be in [0, 127].
inline bfloat16 bfloat16_init(int sign, int exp, int mant)
{
	uint16_t res = (sign < 0) ? 0x8000 : 0;

	// Validate the exponent itself (not the partially built result, which already holds the sign bit).
	assert((exp >= -126) && (exp <= 127));
	res |= (uint16_t)((exp + 127) << 7);

	assert((mant >= 0) && (mant < 128));
	res |= (uint16_t)mant;

	return res;
}
3143
3144
3145
} // namespace basisu
3146
3147
3148