Path: blob/main_old/src/tests/test_utils/ConstantFoldingTest.h
1693 views
//1// Copyright 2016 The ANGLE Project Authors. All rights reserved.2// Use of this source code is governed by a BSD-style license that can be3// found in the LICENSE file.4//5// ConstantFoldingTest.h:6// Utilities for constant folding tests.7//89#ifndef TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_10#define TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_1112#include <vector>1314#include "common/mathutil.h"15#include "compiler/translator/tree_util/FindMain.h"16#include "compiler/translator/tree_util/FindSymbolNode.h"17#include "compiler/translator/tree_util/IntermTraverse.h"18#include "tests/test_utils/ShaderCompileTreeTest.h"1920namespace sh21{2223class TranslatorESSL;2425template <typename T>26class ConstantFinder : public TIntermTraverser27{28public:29ConstantFinder(const std::vector<T> &constantVector)30: TIntermTraverser(true, false, false),31mConstantVector(constantVector),32mFaultTolerance(T()),33mFound(false)34{}3536ConstantFinder(const std::vector<T> &constantVector, const T &faultTolerance)37: TIntermTraverser(true, false, false),38mConstantVector(constantVector),39mFaultTolerance(faultTolerance),40mFound(false)41{}4243ConstantFinder(const T &value)44: TIntermTraverser(true, false, false), mFaultTolerance(T()), mFound(false)45{46mConstantVector.push_back(value);47}4849void visitConstantUnion(TIntermConstantUnion *node)50{51if (node->getType().getObjectSize() == mConstantVector.size())52{53bool found = true;54for (size_t i = 0; i < mConstantVector.size(); i++)55{56if (!isEqual(node->getConstantValue()[i], mConstantVector[i]))57{58found = false;59break;60}61}62if (found)63{64mFound = found;65}66}67}6869bool found() const { return mFound; }7071private:72bool isEqual(const TConstantUnion &node, const float &value) const73{74if (node.getType() != EbtFloat)75{76return false;77}78if (value == std::numeric_limits<float>::infinity())79{80return gl::isInf(node.getFConst()) && node.getFConst() > 0;81}82else if (value == -std::numeric_limits<float>::infinity())83{84return gl::isInf(node.getFConst()) && node.getFConst() < 0;85}86else if (gl::isNaN(value))87{88// All NaNs are treated as equal.89return gl::isNaN(node.getFConst());90}91return mFaultTolerance >= fabsf(node.getFConst() - value);92}9394bool isEqual(const TConstantUnion &node, const int &value) const95{96if (node.getType() != EbtInt)97{98return false;99}100ASSERT(mFaultTolerance < std::numeric_limits<int>::max());101// abs() returns 0 at least on some platforms when the minimum int value is passed in (it102// doesn't have a positive counterpart).103return mFaultTolerance >= abs(node.getIConst() - value) &&104(node.getIConst() - value) != std::numeric_limits<int>::min();105}106107bool isEqual(const TConstantUnion &node, const unsigned int &value) const108{109if (node.getType() != EbtUInt)110{111return false;112}113ASSERT(mFaultTolerance < static_cast<unsigned int>(std::numeric_limits<int>::max()));114return static_cast<int>(mFaultTolerance) >=115abs(static_cast<int>(node.getUConst() - value)) &&116static_cast<int>(node.getUConst() - value) != std::numeric_limits<int>::min();117}118119bool isEqual(const TConstantUnion &node, const bool &value) const120{121if (node.getType() != EbtBool)122{123return false;124}125return node.getBConst() == value;126}127128std::vector<T> mConstantVector;129T mFaultTolerance;130bool mFound;131};132133class ConstantFoldingTest : public ShaderCompileTreeTest134{135public:136ConstantFoldingTest() {}137138protected:139::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; }140ShShaderSpec getShaderSpec() const override { return SH_GLES3_1_SPEC; }141142template <typename T>143bool constantFoundInAST(T constant)144{145ConstantFinder<T> finder(constant);146mASTRoot->traverse(&finder);147return finder.found();148}149150template <typename T>151bool constantVectorFoundInAST(const std::vector<T> &constantVector)152{153ConstantFinder<T> finder(constantVector);154mASTRoot->traverse(&finder);155return finder.found();156}157158template <typename T>159bool constantColumnMajorMatrixFoundInAST(const std::vector<T> &constantMatrix)160{161return constantVectorFoundInAST(constantMatrix);162}163164template <typename T>165bool constantVectorNearFoundInAST(const std::vector<T> &constantVector, const T &faultTolerance)166{167ConstantFinder<T> finder(constantVector, faultTolerance);168mASTRoot->traverse(&finder);169return finder.found();170}171172bool symbolFoundInAST(const char *symbolName)173{174return FindSymbolNode(mASTRoot, ImmutableString(symbolName)) != nullptr;175}176177bool symbolFoundInMain(const char *symbolName)178{179return FindSymbolNode(FindMain(mASTRoot), ImmutableString(symbolName)) != nullptr;180}181};182183class ConstantFoldingExpressionTest : public ConstantFoldingTest184{185public:186ConstantFoldingExpressionTest() {}187188void evaluateFloat(const std::string &floatExpression);189void evaluateInt(const std::string &intExpression);190void evaluateUint(const std::string &uintExpression);191};192193} // namespace sh194195#endif // TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_196197198