Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
PojavLauncherTeam
GitHub Repository: PojavLauncherTeam/angle
Path: blob/main_old/src/tests/test_utils/ConstantFoldingTest.h
1693 views
1
//
2
// Copyright 2016 The ANGLE Project Authors. All rights reserved.
3
// Use of this source code is governed by a BSD-style license that can be
4
// found in the LICENSE file.
5
//
6
// ConstantFoldingTest.h:
7
// Utilities for constant folding tests.
8
//
9
10
#ifndef TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
11
#define TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
12
13
#include <cmath>
#include <limits>
#include <vector>

#include "common/mathutil.h"
#include "compiler/translator/tree_util/FindMain.h"
#include "compiler/translator/tree_util/FindSymbolNode.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
#include "tests/test_utils/ShaderCompileTreeTest.h"
20
21
namespace sh
22
{
23
24
class TranslatorESSL;
25
26
template <typename T>
27
class ConstantFinder : public TIntermTraverser
28
{
29
public:
30
ConstantFinder(const std::vector<T> &constantVector)
31
: TIntermTraverser(true, false, false),
32
mConstantVector(constantVector),
33
mFaultTolerance(T()),
34
mFound(false)
35
{}
36
37
ConstantFinder(const std::vector<T> &constantVector, const T &faultTolerance)
38
: TIntermTraverser(true, false, false),
39
mConstantVector(constantVector),
40
mFaultTolerance(faultTolerance),
41
mFound(false)
42
{}
43
44
ConstantFinder(const T &value)
45
: TIntermTraverser(true, false, false), mFaultTolerance(T()), mFound(false)
46
{
47
mConstantVector.push_back(value);
48
}
49
50
void visitConstantUnion(TIntermConstantUnion *node)
51
{
52
if (node->getType().getObjectSize() == mConstantVector.size())
53
{
54
bool found = true;
55
for (size_t i = 0; i < mConstantVector.size(); i++)
56
{
57
if (!isEqual(node->getConstantValue()[i], mConstantVector[i]))
58
{
59
found = false;
60
break;
61
}
62
}
63
if (found)
64
{
65
mFound = found;
66
}
67
}
68
}
69
70
bool found() const { return mFound; }
71
72
private:
73
bool isEqual(const TConstantUnion &node, const float &value) const
74
{
75
if (node.getType() != EbtFloat)
76
{
77
return false;
78
}
79
if (value == std::numeric_limits<float>::infinity())
80
{
81
return gl::isInf(node.getFConst()) && node.getFConst() > 0;
82
}
83
else if (value == -std::numeric_limits<float>::infinity())
84
{
85
return gl::isInf(node.getFConst()) && node.getFConst() < 0;
86
}
87
else if (gl::isNaN(value))
88
{
89
// All NaNs are treated as equal.
90
return gl::isNaN(node.getFConst());
91
}
92
return mFaultTolerance >= fabsf(node.getFConst() - value);
93
}
94
95
bool isEqual(const TConstantUnion &node, const int &value) const
96
{
97
if (node.getType() != EbtInt)
98
{
99
return false;
100
}
101
ASSERT(mFaultTolerance < std::numeric_limits<int>::max());
102
// abs() returns 0 at least on some platforms when the minimum int value is passed in (it
103
// doesn't have a positive counterpart).
104
return mFaultTolerance >= abs(node.getIConst() - value) &&
105
(node.getIConst() - value) != std::numeric_limits<int>::min();
106
}
107
108
bool isEqual(const TConstantUnion &node, const unsigned int &value) const
109
{
110
if (node.getType() != EbtUInt)
111
{
112
return false;
113
}
114
ASSERT(mFaultTolerance < static_cast<unsigned int>(std::numeric_limits<int>::max()));
115
return static_cast<int>(mFaultTolerance) >=
116
abs(static_cast<int>(node.getUConst() - value)) &&
117
static_cast<int>(node.getUConst() - value) != std::numeric_limits<int>::min();
118
}
119
120
bool isEqual(const TConstantUnion &node, const bool &value) const
121
{
122
if (node.getType() != EbtBool)
123
{
124
return false;
125
}
126
return node.getBConst() == value;
127
}
128
129
std::vector<T> mConstantVector;
130
T mFaultTolerance;
131
bool mFound;
132
};
133
134
class ConstantFoldingTest : public ShaderCompileTreeTest
135
{
136
public:
137
ConstantFoldingTest() {}
138
139
protected:
140
::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; }
141
ShShaderSpec getShaderSpec() const override { return SH_GLES3_1_SPEC; }
142
143
template <typename T>
144
bool constantFoundInAST(T constant)
145
{
146
ConstantFinder<T> finder(constant);
147
mASTRoot->traverse(&finder);
148
return finder.found();
149
}
150
151
template <typename T>
152
bool constantVectorFoundInAST(const std::vector<T> &constantVector)
153
{
154
ConstantFinder<T> finder(constantVector);
155
mASTRoot->traverse(&finder);
156
return finder.found();
157
}
158
159
template <typename T>
160
bool constantColumnMajorMatrixFoundInAST(const std::vector<T> &constantMatrix)
161
{
162
return constantVectorFoundInAST(constantMatrix);
163
}
164
165
template <typename T>
166
bool constantVectorNearFoundInAST(const std::vector<T> &constantVector, const T &faultTolerance)
167
{
168
ConstantFinder<T> finder(constantVector, faultTolerance);
169
mASTRoot->traverse(&finder);
170
return finder.found();
171
}
172
173
bool symbolFoundInAST(const char *symbolName)
174
{
175
return FindSymbolNode(mASTRoot, ImmutableString(symbolName)) != nullptr;
176
}
177
178
bool symbolFoundInMain(const char *symbolName)
179
{
180
return FindSymbolNode(FindMain(mASTRoot), ImmutableString(symbolName)) != nullptr;
181
}
182
};
183
184
class ConstantFoldingExpressionTest : public ConstantFoldingTest
185
{
186
public:
187
ConstantFoldingExpressionTest() {}
188
189
void evaluateFloat(const std::string &floatExpression);
190
void evaluateInt(const std::string &intExpression);
191
void evaluateUint(const std::string &uintExpression);
192
};
193
194
} // namespace sh
195
196
#endif // TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
197
198