Path: blob/master/thirdparty/glslang/SPIRV/SPVRemapper.cpp
9903 views
//1// Copyright (C) 2015 LunarG, Inc.2//3// All rights reserved.4//5// Redistribution and use in source and binary forms, with or without6// modification, are permitted provided that the following conditions7// are met:8//9// Redistributions of source code must retain the above copyright10// notice, this list of conditions and the following disclaimer.11//12// Redistributions in binary form must reproduce the above13// copyright notice, this list of conditions and the following14// disclaimer in the documentation and/or other materials provided15// with the distribution.16//17// Neither the name of 3Dlabs Inc. Ltd. nor the names of its18// contributors may be used to endorse or promote products derived19// from this software without specific prior written permission.20//21// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS22// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT23// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS24// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE25// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,26// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,27// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;28// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER29// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT30// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN31// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE32// POSSIBILITY OF SUCH DAMAGE.33//3435#include "SPVRemapper.h"36#include "doc.h"3738#include <algorithm>39#include <cassert>40#include "../glslang/Include/Common.h"4142namespace spv {4344// By default, just abort on error. Can be overridden via RegisterErrorHandler45spirvbin_t::errorfn_t spirvbin_t::errorHandler = [](const std::string&) { exit(5); };46// By default, eat log messages. Can be overridden via RegisterLogHandler47spirvbin_t::logfn_t spirvbin_t::logHandler = [](const std::string&) { };4849// This can be overridden to provide other message behavior if needed50void spirvbin_t::msg(int minVerbosity, int indent, const std::string& txt) const51{52if (verbose >= minVerbosity)53logHandler(std::string(indent, ' ') + txt);54}5556// hash opcode, with special handling for OpExtInst57std::uint32_t spirvbin_t::asOpCodeHash(unsigned word)58{59const spv::Op opCode = asOpCode(word);6061std::uint32_t offset = 0;6263switch (opCode) {64case spv::OpExtInst:65offset += asId(word + 4); break;66default:67break;68}6970return opCode * 19 + offset; // 19 = small prime71}7273spirvbin_t::range_t spirvbin_t::literalRange(spv::Op opCode) const74{75static const int maxCount = 1<<30;7677switch (opCode) {78case spv::OpTypeFloat: // fall through...79case spv::OpTypePointer: return range_t(2, 3);80case spv::OpTypeInt: return range_t(2, 4);81// TODO: case spv::OpTypeImage:82// TODO: case spv::OpTypeSampledImage:83case spv::OpTypeSampler: return range_t(3, 8);84case spv::OpTypeVector: // fall through85case spv::OpTypeMatrix: // ...86case spv::OpTypePipe: return range_t(3, 4);87case spv::OpConstant: return range_t(3, maxCount);88default: return range_t(0, 0);89}90}9192spirvbin_t::range_t spirvbin_t::typeRange(spv::Op opCode) const93{94static const int maxCount = 1<<30;9596if (isConstOp(opCode))97return range_t(1, 2);9899switch (opCode) {100case spv::OpTypeVector: // fall through101case spv::OpTypeMatrix: // ...102case spv::OpTypeSampler: // ...103case spv::OpTypeArray: // ...104case spv::OpTypeRuntimeArray: // ...105case spv::OpTypePipe: return range_t(2, 3);106case spv::OpTypeStruct: // fall through107case spv::OpTypeFunction: return range_t(2, maxCount);108case spv::OpTypePointer: return range_t(3, 4);109default: return range_t(0, 0);110}111}112113spirvbin_t::range_t spirvbin_t::constRange(spv::Op opCode) const114{115static const int maxCount = 1<<30;116117switch (opCode) {118case spv::OpTypeArray: // fall through...119case spv::OpTypeRuntimeArray: return range_t(3, 4);120case spv::OpConstantComposite: return range_t(3, maxCount);121default: return range_t(0, 0);122}123}124125// Return the size of a type in 32-bit words. This currently only126// handles ints and floats, and is only invoked by queries which must be127// integer types. If ever needed, it can be generalized.128unsigned spirvbin_t::typeSizeInWords(spv::Id id) const129{130const unsigned typeStart = idPos(id);131const spv::Op opCode = asOpCode(typeStart);132133if (errorLatch)134return 0;135136switch (opCode) {137case spv::OpTypeInt: // fall through...138case spv::OpTypeFloat: return (spv[typeStart+2]+31)/32;139default:140return 0;141}142}143144// Looks up the type of a given const or variable ID, and145// returns its size in 32-bit words.146unsigned spirvbin_t::idTypeSizeInWords(spv::Id id) const147{148const auto tid_it = idTypeSizeMap.find(id);149if (tid_it == idTypeSizeMap.end()) {150error("type size for ID not found");151return 0;152}153154return tid_it->second;155}156157// Is this an opcode we should remove when using --strip?158bool spirvbin_t::isStripOp(spv::Op opCode, unsigned start) const159{160switch (opCode) {161case spv::OpSource:162case spv::OpSourceExtension:163case spv::OpName:164case spv::OpMemberName:165case spv::OpLine :166{167const std::string name = literalString(start + 2);168169std::vector<std::string>::const_iterator it;170for (it = stripWhiteList.begin(); it < stripWhiteList.end(); it++)171{172if (name.find(*it) != std::string::npos) {173return false;174}175}176177return true;178}179default :180return false;181}182}183184// Return true if this opcode is flow control185bool spirvbin_t::isFlowCtrl(spv::Op opCode) const186{187switch (opCode) {188case spv::OpBranchConditional:189case spv::OpBranch:190case spv::OpSwitch:191case spv::OpLoopMerge:192case spv::OpSelectionMerge:193case spv::OpLabel:194case spv::OpFunction:195case spv::OpFunctionEnd: return true;196default: return false;197}198}199200// Return true if this opcode defines a type201bool spirvbin_t::isTypeOp(spv::Op opCode) const202{203switch (opCode) {204case spv::OpTypeVoid:205case spv::OpTypeBool:206case spv::OpTypeInt:207case spv::OpTypeFloat:208case spv::OpTypeVector:209case spv::OpTypeMatrix:210case spv::OpTypeImage:211case spv::OpTypeSampler:212case spv::OpTypeArray:213case spv::OpTypeRuntimeArray:214case spv::OpTypeStruct:215case spv::OpTypeOpaque:216case spv::OpTypePointer:217case spv::OpTypeFunction:218case spv::OpTypeEvent:219case spv::OpTypeDeviceEvent:220case spv::OpTypeReserveId:221case spv::OpTypeQueue:222case spv::OpTypeSampledImage:223case spv::OpTypePipe: return true;224default: return false;225}226}227228// Return true if this opcode defines a constant229bool spirvbin_t::isConstOp(spv::Op opCode) const230{231switch (opCode) {232case spv::OpConstantSampler:233error("unimplemented constant type");234return true;235236case spv::OpConstantNull:237case spv::OpConstantTrue:238case spv::OpConstantFalse:239case spv::OpConstantComposite:240case spv::OpConstant:241return true;242243default:244return false;245}246}247248const auto inst_fn_nop = [](spv::Op, unsigned) { return false; };249const auto op_fn_nop = [](spv::Id&) { };250251// g++ doesn't like these defined in the class proper in an anonymous namespace.252// Dunno why. Also MSVC doesn't like the constexpr keyword. Also dunno why.253// Defining them externally seems to please both compilers, so, here they are.254const spv::Id spirvbin_t::unmapped = spv::Id(-10000);255const spv::Id spirvbin_t::unused = spv::Id(-10001);256const int spirvbin_t::header_size = 5;257258spv::Id spirvbin_t::nextUnusedId(spv::Id id)259{260while (isNewIdMapped(id)) // search for an unused ID261++id;262263return id;264}265266spv::Id spirvbin_t::localId(spv::Id id, spv::Id newId)267{268//assert(id != spv::NoResult && newId != spv::NoResult);269270if (id > bound()) {271error(std::string("ID out of range: ") + std::to_string(id));272return spirvbin_t::unused;273}274275if (id >= idMapL.size())276idMapL.resize(id+1, unused);277278if (newId != unmapped && newId != unused) {279if (isOldIdUnused(id)) {280error(std::string("ID unused in module: ") + std::to_string(id));281return spirvbin_t::unused;282}283284if (!isOldIdUnmapped(id)) {285error(std::string("ID already mapped: ") + std::to_string(id) + " -> "286+ std::to_string(localId(id)));287288return spirvbin_t::unused;289}290291if (isNewIdMapped(newId)) {292error(std::string("ID already used in module: ") + std::to_string(newId));293return spirvbin_t::unused;294}295296msg(4, 4, std::string("map: ") + std::to_string(id) + " -> " + std::to_string(newId));297setMapped(newId);298largestNewId = std::max(largestNewId, newId);299}300301return idMapL[id] = newId;302}303304// Parse a literal string from the SPIR binary and return it as an std::string305// Due to C++11 RValue references, this doesn't copy the result string.306std::string spirvbin_t::literalString(unsigned word) const307{308std::string literal;309const spirword_t * pos = spv.data() + word;310311literal.reserve(16);312313do {314spirword_t word = *pos;315for (int i = 0; i < 4; i++) {316char c = word & 0xff;317if (c == '\0')318return literal;319literal += c;320word >>= 8;321}322pos++;323} while (true);324}325326void spirvbin_t::applyMap()327{328msg(3, 2, std::string("Applying map: "));329330// Map local IDs through the ID map331process(inst_fn_nop, // ignore instructions332[this](spv::Id& id) {333id = localId(id);334335if (errorLatch)336return;337338assert(id != unused && id != unmapped);339}340);341}342343// Find free IDs for anything we haven't mapped344void spirvbin_t::mapRemainder()345{346msg(3, 2, std::string("Remapping remainder: "));347348spv::Id unusedId = 1; // can't use 0: that's NoResult349spirword_t maxBound = 0;350351for (spv::Id id = 0; id < idMapL.size(); ++id) {352if (isOldIdUnused(id))353continue;354355// Find a new mapping for any used but unmapped IDs356if (isOldIdUnmapped(id)) {357localId(id, unusedId = nextUnusedId(unusedId));358if (errorLatch)359return;360}361362if (isOldIdUnmapped(id)) {363error(std::string("old ID not mapped: ") + std::to_string(id));364return;365}366367// Track max bound368maxBound = std::max(maxBound, localId(id) + 1);369370if (errorLatch)371return;372}373374bound(maxBound); // reset header ID bound to as big as it now needs to be375}376377// Mark debug instructions for stripping378void spirvbin_t::stripDebug()379{380// Strip instructions in the stripOp set: debug info.381process(382[&](spv::Op opCode, unsigned start) {383// remember opcodes we want to strip later384if (isStripOp(opCode, start))385stripInst(start);386return true;387},388op_fn_nop);389}390391// Mark instructions that refer to now-removed IDs for stripping392void spirvbin_t::stripDeadRefs()393{394process(395[&](spv::Op opCode, unsigned start) {396// strip opcodes pointing to removed data397switch (opCode) {398case spv::OpName:399case spv::OpMemberName:400case spv::OpDecorate:401case spv::OpMemberDecorate:402if (idPosR.find(asId(start+1)) == idPosR.end())403stripInst(start);404break;405default:406break; // leave it alone407}408409return true;410},411op_fn_nop);412413strip();414}415416// Update local maps of ID, type, etc positions417void spirvbin_t::buildLocalMaps()418{419msg(2, 2, std::string("build local maps: "));420421mapped.clear();422idMapL.clear();423// preserve nameMap, so we don't clear that.424fnPos.clear();425fnCalls.clear();426typeConstPos.clear();427idPosR.clear();428entryPoint = spv::NoResult;429largestNewId = 0;430431idMapL.resize(bound(), unused);432433int fnStart = 0;434spv::Id fnRes = spv::NoResult;435436// build local Id and name maps437process(438[&](spv::Op opCode, unsigned start) {439unsigned word = start+1;440spv::Id typeId = spv::NoResult;441442if (spv::InstructionDesc[opCode].hasType())443typeId = asId(word++);444445// If there's a result ID, remember the size of its type446if (spv::InstructionDesc[opCode].hasResult()) {447const spv::Id resultId = asId(word++);448idPosR[resultId] = start;449450if (typeId != spv::NoResult) {451const unsigned idTypeSize = typeSizeInWords(typeId);452453if (errorLatch)454return false;455456if (idTypeSize != 0)457idTypeSizeMap[resultId] = idTypeSize;458}459}460461if (opCode == spv::Op::OpName) {462const spv::Id target = asId(start+1);463const std::string name = literalString(start+2);464nameMap[name] = target;465466} else if (opCode == spv::Op::OpFunctionCall) {467++fnCalls[asId(start + 3)];468} else if (opCode == spv::Op::OpEntryPoint) {469entryPoint = asId(start + 2);470} else if (opCode == spv::Op::OpFunction) {471if (fnStart != 0) {472error("nested function found");473return false;474}475476fnStart = start;477fnRes = asId(start + 2);478} else if (opCode == spv::Op::OpFunctionEnd) {479assert(fnRes != spv::NoResult);480if (fnStart == 0) {481error("function end without function start");482return false;483}484485fnPos[fnRes] = range_t(fnStart, start + asWordCount(start));486fnStart = 0;487} else if (isConstOp(opCode)) {488if (errorLatch)489return false;490491assert(asId(start + 2) != spv::NoResult);492typeConstPos.insert(start);493} else if (isTypeOp(opCode)) {494assert(asId(start + 1) != spv::NoResult);495typeConstPos.insert(start);496}497498return false;499},500501[this](spv::Id& id) { localId(id, unmapped); }502);503}504505// Validate the SPIR header506void spirvbin_t::validate() const507{508msg(2, 2, std::string("validating: "));509510if (spv.size() < header_size) {511error("file too short: ");512return;513}514515if (magic() != spv::MagicNumber) {516error("bad magic number");517return;518}519520// field 1 = version521// field 2 = generator magic522// field 3 = result <id> bound523524if (schemaNum() != 0) {525error("bad schema, must be 0");526return;527}528}529530int spirvbin_t::processInstruction(unsigned word, instfn_t instFn, idfn_t idFn)531{532const auto instructionStart = word;533const unsigned wordCount = asWordCount(instructionStart);534const int nextInst = word++ + wordCount;535spv::Op opCode = asOpCode(instructionStart);536537if (nextInst > int(spv.size())) {538error("spir instruction terminated too early");539return -1;540}541542// Base for computing number of operands; will be updated as more is learned543unsigned numOperands = wordCount - 1;544545if (instFn(opCode, instructionStart))546return nextInst;547548// Read type and result ID from instruction desc table549if (spv::InstructionDesc[opCode].hasType()) {550idFn(asId(word++));551--numOperands;552}553554if (spv::InstructionDesc[opCode].hasResult()) {555idFn(asId(word++));556--numOperands;557}558559// Extended instructions: currently, assume everything is an ID.560// TODO: add whatever data we need for exceptions to that561if (opCode == spv::OpExtInst) {562563idFn(asId(word)); // Instruction set is an ID that also needs to be mapped564565word += 2; // instruction set, and instruction from set566numOperands -= 2;567568for (unsigned op=0; op < numOperands; ++op)569idFn(asId(word++)); // ID570571return nextInst;572}573574// Circular buffer so we can look back at previous unmapped values during the mapping pass.575static const unsigned idBufferSize = 4;576spv::Id idBuffer[idBufferSize];577unsigned idBufferPos = 0;578579// Store IDs from instruction in our map580for (int op = 0; numOperands > 0; ++op, --numOperands) {581// SpecConstantOp is special: it includes the operands of another opcode which is582// given as a literal in the 3rd word. We will switch over to pretending that the583// opcode being processed is the literal opcode value of the SpecConstantOp. See the584// SPIRV spec for details. This way we will handle IDs and literals as appropriate for585// the embedded op.586if (opCode == spv::OpSpecConstantOp) {587if (op == 0) {588opCode = asOpCode(word++); // this is the opcode embedded in the SpecConstantOp.589--numOperands;590}591}592593switch (spv::InstructionDesc[opCode].operands.getClass(op)) {594case spv::OperandId:595case spv::OperandScope:596case spv::OperandMemorySemantics:597idBuffer[idBufferPos] = asId(word);598idBufferPos = (idBufferPos + 1) % idBufferSize;599idFn(asId(word++));600break;601602case spv::OperandVariableIds:603for (unsigned i = 0; i < numOperands; ++i)604idFn(asId(word++));605return nextInst;606607case spv::OperandVariableLiterals:608// for clarity609// if (opCode == spv::OpDecorate && asDecoration(word - 1) == spv::DecorationBuiltIn) {610// ++word;611// --numOperands;612// }613// word += numOperands;614return nextInst;615616case spv::OperandVariableLiteralId: {617if (opCode == OpSwitch) {618// word-2 is the position of the selector ID. OpSwitch Literals match its type.619// In case the IDs are currently being remapped, we get the word[-2] ID from620// the circular idBuffer.621const unsigned literalSizePos = (idBufferPos+idBufferSize-2) % idBufferSize;622const unsigned literalSize = idTypeSizeInWords(idBuffer[literalSizePos]);623const unsigned numLiteralIdPairs = (nextInst-word) / (1+literalSize);624625if (errorLatch)626return -1;627628for (unsigned arg=0; arg<numLiteralIdPairs; ++arg) {629word += literalSize; // literal630idFn(asId(word++)); // label631}632} else {633assert(0); // currentely, only OpSwitch uses OperandVariableLiteralId634}635636return nextInst;637}638639case spv::OperandLiteralString: {640const int stringWordCount = literalStringWords(literalString(word));641word += stringWordCount;642numOperands -= (stringWordCount-1); // -1 because for() header post-decrements643break;644}645646case spv::OperandVariableLiteralStrings:647return nextInst;648649// Execution mode might have extra literal operands. Skip them.650case spv::OperandExecutionMode:651return nextInst;652653// Single word operands we simply ignore, as they hold no IDs654case spv::OperandLiteralNumber:655case spv::OperandSource:656case spv::OperandExecutionModel:657case spv::OperandAddressing:658case spv::OperandMemory:659case spv::OperandStorage:660case spv::OperandDimensionality:661case spv::OperandSamplerAddressingMode:662case spv::OperandSamplerFilterMode:663case spv::OperandSamplerImageFormat:664case spv::OperandImageChannelOrder:665case spv::OperandImageChannelDataType:666case spv::OperandImageOperands:667case spv::OperandFPFastMath:668case spv::OperandFPRoundingMode:669case spv::OperandLinkageType:670case spv::OperandAccessQualifier:671case spv::OperandFuncParamAttr:672case spv::OperandDecoration:673case spv::OperandBuiltIn:674case spv::OperandSelect:675case spv::OperandLoop:676case spv::OperandFunction:677case spv::OperandMemoryAccess:678case spv::OperandGroupOperation:679case spv::OperandKernelEnqueueFlags:680case spv::OperandKernelProfilingInfo:681case spv::OperandCapability:682case spv::OperandCooperativeMatrixOperands:683++word;684break;685686default:687assert(0 && "Unhandled Operand Class");688break;689}690}691692return nextInst;693}694695// Make a pass over all the instructions and process them given appropriate functions696spirvbin_t& spirvbin_t::process(instfn_t instFn, idfn_t idFn, unsigned begin, unsigned end)697{698// For efficiency, reserve name map space. It can grow if needed.699nameMap.reserve(32);700701// If begin or end == 0, use defaults702begin = (begin == 0 ? header_size : begin);703end = (end == 0 ? unsigned(spv.size()) : end);704705// basic parsing and InstructionDesc table borrowed from SpvDisassemble.cpp...706unsigned nextInst = unsigned(spv.size());707708for (unsigned word = begin; word < end; word = nextInst) {709nextInst = processInstruction(word, instFn, idFn);710711if (errorLatch)712return *this;713}714715return *this;716}717718// Apply global name mapping to a single module719void spirvbin_t::mapNames()720{721static const std::uint32_t softTypeIdLimit = 3011; // small prime. TODO: get from options722static const std::uint32_t firstMappedID = 3019; // offset into ID space723724for (const auto& name : nameMap) {725std::uint32_t hashval = 1911;726for (const char c : name.first)727hashval = hashval * 1009 + c;728729if (isOldIdUnmapped(name.second)) {730localId(name.second, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));731if (errorLatch)732return;733}734}735}736737// Map fn contents to IDs of similar functions in other modules738void spirvbin_t::mapFnBodies()739{740static const std::uint32_t softTypeIdLimit = 19071; // small prime. TODO: get from options741static const std::uint32_t firstMappedID = 6203; // offset into ID space742743// Initial approach: go through some high priority opcodes first and assign them744// hash values.745746spv::Id fnId = spv::NoResult;747std::vector<unsigned> instPos;748instPos.reserve(unsigned(spv.size()) / 16); // initial estimate; can grow if needed.749750// Build local table of instruction start positions751process(752[&](spv::Op, unsigned start) { instPos.push_back(start); return true; },753op_fn_nop);754755if (errorLatch)756return;757758// Window size for context-sensitive canonicalization values759// Empirical best size from a single data set. TODO: Would be a good tunable.760// We essentially perform a little convolution around each instruction,761// to capture the flavor of nearby code, to hopefully match to similar762// code in other modules.763static const unsigned windowSize = 2;764765for (unsigned entry = 0; entry < unsigned(instPos.size()); ++entry) {766const unsigned start = instPos[entry];767const spv::Op opCode = asOpCode(start);768769if (opCode == spv::OpFunction)770fnId = asId(start + 2);771772if (opCode == spv::OpFunctionEnd)773fnId = spv::NoResult;774775if (fnId != spv::NoResult) { // if inside a function776if (spv::InstructionDesc[opCode].hasResult()) {777const unsigned word = start + (spv::InstructionDesc[opCode].hasType() ? 2 : 1);778const spv::Id resId = asId(word);779std::uint32_t hashval = fnId * 17; // small prime780781for (unsigned i = entry-1; i >= entry-windowSize; --i) {782if (asOpCode(instPos[i]) == spv::OpFunction)783break;784hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime785}786787for (unsigned i = entry; i <= entry + windowSize; ++i) {788if (asOpCode(instPos[i]) == spv::OpFunctionEnd)789break;790hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime791}792793if (isOldIdUnmapped(resId)) {794localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));795if (errorLatch)796return;797}798799}800}801}802803spv::Op thisOpCode(spv::OpNop);804std::unordered_map<int, int> opCounter;805int idCounter(0);806fnId = spv::NoResult;807808process(809[&](spv::Op opCode, unsigned start) {810switch (opCode) {811case spv::OpFunction:812// Reset counters at each function813idCounter = 0;814opCounter.clear();815fnId = asId(start + 2);816break;817818case spv::OpImageSampleImplicitLod:819case spv::OpImageSampleExplicitLod:820case spv::OpImageSampleDrefImplicitLod:821case spv::OpImageSampleDrefExplicitLod:822case spv::OpImageSampleProjImplicitLod:823case spv::OpImageSampleProjExplicitLod:824case spv::OpImageSampleProjDrefImplicitLod:825case spv::OpImageSampleProjDrefExplicitLod:826case spv::OpDot:827case spv::OpCompositeExtract:828case spv::OpCompositeInsert:829case spv::OpVectorShuffle:830case spv::OpLabel:831case spv::OpVariable:832833case spv::OpAccessChain:834case spv::OpLoad:835case spv::OpStore:836case spv::OpCompositeConstruct:837case spv::OpFunctionCall:838++opCounter[opCode];839idCounter = 0;840thisOpCode = opCode;841break;842default:843thisOpCode = spv::OpNop;844}845846return false;847},848849[&](spv::Id& id) {850if (thisOpCode != spv::OpNop) {851++idCounter;852const std::uint32_t hashval =853// Explicitly cast operands to unsigned int to avoid integer854// promotion to signed int followed by integer overflow,855// which would result in undefined behavior.856static_cast<unsigned int>(opCounter[thisOpCode])857* thisOpCode858* 50047859+ idCounter860+ static_cast<unsigned int>(fnId) * 117;861862if (isOldIdUnmapped(id))863localId(id, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));864}865});866}867868// EXPERIMENTAL: forward IO and uniform load/stores into operands869// This produces invalid Schema-0 SPIRV870void spirvbin_t::forwardLoadStores()871{872idset_t fnLocalVars; // set of function local vars873idmap_t idMap; // Map of load result IDs to what they load874875// EXPERIMENTAL: Forward input and access chain loads into consumptions876process(877[&](spv::Op opCode, unsigned start) {878// Add inputs and uniforms to the map879if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&880(spv[start+3] == spv::StorageClassUniform ||881spv[start+3] == spv::StorageClassUniformConstant ||882spv[start+3] == spv::StorageClassInput))883fnLocalVars.insert(asId(start+2));884885if (opCode == spv::OpAccessChain && fnLocalVars.count(asId(start+3)) > 0)886fnLocalVars.insert(asId(start+2));887888if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {889idMap[asId(start+2)] = asId(start+3);890stripInst(start);891}892893return false;894},895896[&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }897);898899if (errorLatch)900return;901902// EXPERIMENTAL: Implicit output stores903fnLocalVars.clear();904idMap.clear();905906process(907[&](spv::Op opCode, unsigned start) {908// Add inputs and uniforms to the map909if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&910(spv[start+3] == spv::StorageClassOutput))911fnLocalVars.insert(asId(start+2));912913if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {914idMap[asId(start+2)] = asId(start+1);915stripInst(start);916}917918return false;919},920op_fn_nop);921922if (errorLatch)923return;924925process(926inst_fn_nop,927[&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }928);929930if (errorLatch)931return;932933strip(); // strip out data we decided to eliminate934}935936// optimize loads and stores937void spirvbin_t::optLoadStore()938{939idset_t fnLocalVars; // candidates for removal (only locals)940idmap_t idMap; // Map of load result IDs to what they load941blockmap_t blockMap; // Map of IDs to blocks they first appear in942int blockNum = 0; // block count, to avoid crossing flow control943944// Find all the function local pointers stored at most once, and not via access chains945process(946[&](spv::Op opCode, unsigned start) {947const int wordCount = asWordCount(start);948949// Count blocks, so we can avoid crossing flow control950if (isFlowCtrl(opCode))951++blockNum;952953// Add local variables to the map954if ((opCode == spv::OpVariable && spv[start+3] == spv::StorageClassFunction && asWordCount(start) == 4)) {955fnLocalVars.insert(asId(start+2));956return true;957}958959// Ignore process vars referenced via access chain960if ((opCode == spv::OpAccessChain || opCode == spv::OpInBoundsAccessChain) && fnLocalVars.count(asId(start+3)) > 0) {961fnLocalVars.erase(asId(start+3));962idMap.erase(asId(start+3));963return true;964}965966if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {967const spv::Id varId = asId(start+3);968969// Avoid loads before stores970if (idMap.find(varId) == idMap.end()) {971fnLocalVars.erase(varId);972idMap.erase(varId);973}974975// don't do for volatile references976if (wordCount > 4 && (spv[start+4] & spv::MemoryAccessVolatileMask)) {977fnLocalVars.erase(varId);978idMap.erase(varId);979}980981// Handle flow control982if (blockMap.find(varId) == blockMap.end()) {983blockMap[varId] = blockNum; // track block we found it in.984} else if (blockMap[varId] != blockNum) {985fnLocalVars.erase(varId); // Ignore if crosses flow control986idMap.erase(varId);987}988989return true;990}991992if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {993const spv::Id varId = asId(start+1);994995if (idMap.find(varId) == idMap.end()) {996idMap[varId] = asId(start+2);997} else {998// Remove if it has more than one store to the same pointer999fnLocalVars.erase(varId);1000idMap.erase(varId);1001}10021003// don't do for volatile references1004if (wordCount > 3 && (spv[start+3] & spv::MemoryAccessVolatileMask)) {1005fnLocalVars.erase(asId(start+3));1006idMap.erase(asId(start+3));1007}10081009// Handle flow control1010if (blockMap.find(varId) == blockMap.end()) {1011blockMap[varId] = blockNum; // track block we found it in.1012} else if (blockMap[varId] != blockNum) {1013fnLocalVars.erase(varId); // Ignore if crosses flow control1014idMap.erase(varId);1015}10161017return true;1018}10191020return false;1021},10221023// If local var id used anywhere else, don't eliminate1024[&](spv::Id& id) {1025if (fnLocalVars.count(id) > 0) {1026fnLocalVars.erase(id);1027idMap.erase(id);1028}1029}1030);10311032if (errorLatch)1033return;10341035process(1036[&](spv::Op opCode, unsigned start) {1037if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0)1038idMap[asId(start+2)] = idMap[asId(start+3)];1039return false;1040},1041op_fn_nop);10421043if (errorLatch)1044return;10451046// Chase replacements to their origins, in case there is a chain such as:1047// 2 = store 11048// 3 = load 21049// 4 = store 31050// 5 = load 41051// We want to replace uses of 5 with 1.1052for (const auto& idPair : idMap) {1053spv::Id id = idPair.first;1054while (idMap.find(id) != idMap.end()) // Chase to end of chain1055id = idMap[id];10561057idMap[idPair.first] = id; // replace with final result1058}10591060// Remove the load/store/variables for the ones we've discovered1061process(1062[&](spv::Op opCode, unsigned start) {1063if ((opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) ||1064(opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) ||1065(opCode == spv::OpVariable && fnLocalVars.count(asId(start+2)) > 0)) {10661067stripInst(start);1068return true;1069}10701071return false;1072},10731074[&](spv::Id& id) {1075if (idMap.find(id) != idMap.end()) id = idMap[id];1076}1077);10781079if (errorLatch)1080return;10811082strip(); // strip out data we decided to eliminate1083}10841085// remove bodies of uncalled functions1086void spirvbin_t::dceFuncs()1087{1088msg(3, 2, std::string("Removing Dead Functions: "));10891090// TODO: There are more efficient ways to do this.1091bool changed = true;10921093while (changed) {1094changed = false;10951096for (auto fn = fnPos.begin(); fn != fnPos.end(); ) {1097if (fn->first == entryPoint) { // don't DCE away the entry point!1098++fn;1099continue;1100}11011102const auto call_it = fnCalls.find(fn->first);11031104if (call_it == fnCalls.end() || call_it->second == 0) {1105changed = true;1106stripRange.push_back(fn->second);11071108// decrease counts of called functions1109process(1110[&](spv::Op opCode, unsigned start) {1111if (opCode == spv::Op::OpFunctionCall) {1112const auto call_it = fnCalls.find(asId(start + 3));1113if (call_it != fnCalls.end()) {1114if (--call_it->second <= 0)1115fnCalls.erase(call_it);1116}1117}11181119return true;1120},1121op_fn_nop,1122fn->second.first,1123fn->second.second);11241125if (errorLatch)1126return;11271128fn = fnPos.erase(fn);1129} else ++fn;1130}1131}1132}11331134// remove unused function variables + decorations1135void spirvbin_t::dceVars()1136{1137msg(3, 2, std::string("DCE Vars: "));11381139std::unordered_map<spv::Id, int> varUseCount;11401141// Count function variable use1142process(1143[&](spv::Op opCode, unsigned start) {1144if (opCode == spv::OpVariable) {1145++varUseCount[asId(start+2)];1146return true;1147} else if (opCode == spv::OpEntryPoint) {1148const int wordCount = asWordCount(start);1149for (int i = 4; i < wordCount; i++) {1150++varUseCount[asId(start+i)];1151}1152return true;1153} else1154return false;1155},11561157[&](spv::Id& id) { if (varUseCount[id]) ++varUseCount[id]; }1158);11591160if (errorLatch)1161return;11621163// Remove single-use function variables + associated decorations and names1164process(1165[&](spv::Op opCode, unsigned start) {1166spv::Id id = spv::NoResult;1167if (opCode == spv::OpVariable)1168id = asId(start+2);1169if (opCode == spv::OpDecorate || opCode == spv::OpName)1170id = asId(start+1);11711172if (id != spv::NoResult && varUseCount[id] == 1)1173stripInst(start);11741175return true;1176},1177op_fn_nop);1178}11791180// remove unused types1181void spirvbin_t::dceTypes()1182{1183std::vector<bool> isType(bound(), false);11841185// for speed, make O(1) way to get to type query (map is log(n))1186for (const auto typeStart : typeConstPos)1187isType[asTypeConstId(typeStart)] = true;11881189std::unordered_map<spv::Id, int> typeUseCount;11901191// This is not the most efficient algorithm, but this is an offline tool, and1192// it's easy to write this way. Can be improved opportunistically if needed.1193bool changed = true;1194while (changed) {1195changed = false;1196strip();1197typeUseCount.clear();11981199// Count total type usage1200process(inst_fn_nop,1201[&](spv::Id& id) { if (isType[id]) ++typeUseCount[id]; }1202);12031204if (errorLatch)1205return;12061207// Remove single reference types1208for (const auto typeStart : typeConstPos) {1209const spv::Id typeId = asTypeConstId(typeStart);1210if (typeUseCount[typeId] == 1) {1211changed = true;1212--typeUseCount[typeId];1213stripInst(typeStart);1214}1215}12161217if (errorLatch)1218return;1219}1220}12211222#ifdef NOTDEF1223bool spirvbin_t::matchType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt, spv::Id gt) const1224{1225// Find the local type id "lt" and global type id "gt"1226const auto lt_it = typeConstPosR.find(lt);1227if (lt_it == typeConstPosR.end())1228return false;12291230const auto typeStart = lt_it->second;12311232// Search for entry in global table1233const auto gtype = globalTypes.find(gt);1234if (gtype == globalTypes.end())1235return false;12361237const auto& gdata = gtype->second;12381239// local wordcount and opcode1240const int wordCount = asWordCount(typeStart);1241const spv::Op opCode = asOpCode(typeStart);12421243// no type match if opcodes don't match, or operand count doesn't match1244if (opCode != opOpCode(gdata[0]) || wordCount != opWordCount(gdata[0]))1245return false;12461247const unsigned numOperands = wordCount - 2; // all types have a result12481249const auto cmpIdRange = [&](range_t range) {1250for (int x=range.first; x<std::min(range.second, wordCount); ++x)1251if (!matchType(globalTypes, asId(typeStart+x), gdata[x]))1252return false;1253return true;1254};12551256const auto cmpConst = [&]() { return cmpIdRange(constRange(opCode)); };1257const auto cmpSubType = [&]() { return cmpIdRange(typeRange(opCode)); };12581259// Compare literals in range [start,end)1260const auto cmpLiteral = [&]() {1261const auto range = literalRange(opCode);1262return std::equal(spir.begin() + typeStart + range.first,1263spir.begin() + typeStart + std::min(range.second, wordCount),1264gdata.begin() + range.first);1265};12661267assert(isTypeOp(opCode) || isConstOp(opCode));12681269switch (opCode) {1270case spv::OpTypeOpaque: // TODO: disable until we compare the literal strings.1271case spv::OpTypeQueue: return false;1272case spv::OpTypeEvent: // fall through...1273case spv::OpTypeDeviceEvent: // ...1274case spv::OpTypeReserveId: return false;1275// for samplers, we don't handle the optional parameters yet1276case spv::OpTypeSampler: return cmpLiteral() && cmpConst() && cmpSubType() && wordCount == 8;1277default: return cmpLiteral() && cmpConst() && cmpSubType();1278}1279}12801281// Look for an equivalent type in the globalTypes map1282spv::Id spirvbin_t::findType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt) const1283{1284// Try a recursive type match on each in turn, and return a match if we find one1285for (const auto& gt : globalTypes)1286if (matchType(globalTypes, lt, gt.first))1287return gt.first;12881289return spv::NoType;1290}1291#endif // NOTDEF12921293// Return start position in SPV of given Id. error if not found.1294unsigned spirvbin_t::idPos(spv::Id id) const1295{1296const auto tid_it = idPosR.find(id);1297if (tid_it == idPosR.end()) {1298error("ID not found");1299return 0;1300}13011302return tid_it->second;1303}13041305// Hash types to canonical values. This can return ID collisions (it's a bit1306// inevitable): it's up to the caller to handle that gracefully.1307std::uint32_t spirvbin_t::hashType(unsigned typeStart) const1308{1309const unsigned wordCount = asWordCount(typeStart);1310const spv::Op opCode = asOpCode(typeStart);13111312switch (opCode) {1313case spv::OpTypeVoid: return 0;1314case spv::OpTypeBool: return 1;1315case spv::OpTypeInt: return 3 + (spv[typeStart+3]);1316case spv::OpTypeFloat: return 5;1317case spv::OpTypeVector:1318return 6 + hashType(idPos(spv[typeStart+2])) * (spv[typeStart+3] - 1);1319case spv::OpTypeMatrix:1320return 30 + hashType(idPos(spv[typeStart+2])) * (spv[typeStart+3] - 1);1321case spv::OpTypeImage:1322return 120 + hashType(idPos(spv[typeStart+2])) +1323spv[typeStart+3] + // dimensionality1324spv[typeStart+4] * 8 * 16 + // depth1325spv[typeStart+5] * 4 * 16 + // arrayed1326spv[typeStart+6] * 2 * 16 + // multisampled1327spv[typeStart+7] * 1 * 16; // format1328case spv::OpTypeSampler:1329return 500;1330case spv::OpTypeSampledImage:1331return 502;1332case spv::OpTypeArray:1333return 501 + hashType(idPos(spv[typeStart+2])) * spv[typeStart+3];1334case spv::OpTypeRuntimeArray:1335return 5000 + hashType(idPos(spv[typeStart+2]));1336case spv::OpTypeStruct:1337{1338std::uint32_t hash = 10000;1339for (unsigned w=2; w < wordCount; ++w)1340hash += w * hashType(idPos(spv[typeStart+w]));1341return hash;1342}13431344case spv::OpTypeOpaque: return 6000 + spv[typeStart+2];1345case spv::OpTypePointer: return 100000 + hashType(idPos(spv[typeStart+3]));1346case spv::OpTypeFunction:1347{1348std::uint32_t hash = 200000;1349for (unsigned w=2; w < wordCount; ++w)1350hash += w * hashType(idPos(spv[typeStart+w]));1351return hash;1352}13531354case spv::OpTypeEvent: return 300000;1355case spv::OpTypeDeviceEvent: return 300001;1356case spv::OpTypeReserveId: return 300002;1357case spv::OpTypeQueue: return 300003;1358case spv::OpTypePipe: return 300004;1359case spv::OpConstantTrue: return 300007;1360case spv::OpConstantFalse: return 300008;1361case spv::OpConstantComposite:1362{1363std::uint32_t hash = 300011 + hashType(idPos(spv[typeStart+1]));1364for (unsigned w=3; w < wordCount; ++w)1365hash += w * hashType(idPos(spv[typeStart+w]));1366return hash;1367}1368case spv::OpConstant:1369{1370std::uint32_t hash = 400011 + hashType(idPos(spv[typeStart+1]));1371for (unsigned w=3; w < wordCount; ++w)1372hash += w * spv[typeStart+w];1373return hash;1374}1375case spv::OpConstantNull:1376{1377std::uint32_t hash = 500009 + hashType(idPos(spv[typeStart+1]));1378return hash;1379}1380case spv::OpConstantSampler:1381{1382std::uint32_t hash = 600011 + hashType(idPos(spv[typeStart+1]));1383for (unsigned w=3; w < wordCount; ++w)1384hash += w * spv[typeStart+w];1385return hash;1386}13871388default:1389error("unknown type opcode");1390return 0;1391}1392}13931394void spirvbin_t::mapTypeConst()1395{1396globaltypes_t globalTypeMap;13971398msg(3, 2, std::string("Remapping Consts & Types: "));13991400static const std::uint32_t softTypeIdLimit = 3011; // small prime. TODO: get from options1401static const std::uint32_t firstMappedID = 8; // offset into ID space14021403for (auto& typeStart : typeConstPos) {1404const spv::Id resId = asTypeConstId(typeStart);1405const std::uint32_t hashval = hashType(typeStart);14061407if (errorLatch)1408return;14091410if (isOldIdUnmapped(resId)) {1411localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));1412if (errorLatch)1413return;1414}1415}1416}14171418// Strip a single binary by removing ranges given in stripRange1419void spirvbin_t::strip()1420{1421if (stripRange.empty()) // nothing to do1422return;14231424// Sort strip ranges in order of traversal1425std::sort(stripRange.begin(), stripRange.end());14261427// Allocate a new binary big enough to hold old binary1428// We'll step this iterator through the strip ranges as we go through the binary1429auto strip_it = stripRange.begin();14301431int strippedPos = 0;1432for (unsigned word = 0; word < unsigned(spv.size()); ++word) {1433while (strip_it != stripRange.end() && word >= strip_it->second)1434++strip_it;14351436if (strip_it == stripRange.end() || word < strip_it->first || word >= strip_it->second)1437spv[strippedPos++] = spv[word];1438}14391440spv.resize(strippedPos);1441stripRange.clear();14421443buildLocalMaps();1444}14451446// Strip a single binary by removing ranges given in stripRange1447void spirvbin_t::remap(std::uint32_t opts)1448{1449options = opts;14501451// Set up opcode tables from SpvDoc1452spv::Parameterize();14531454validate(); // validate header1455buildLocalMaps(); // build ID maps14561457msg(3, 4, std::string("ID bound: ") + std::to_string(bound()));14581459if (options & STRIP) stripDebug();1460if (errorLatch) return;14611462strip(); // strip out data we decided to eliminate1463if (errorLatch) return;14641465if (options & OPT_LOADSTORE) optLoadStore();1466if (errorLatch) return;14671468if (options & OPT_FWD_LS) forwardLoadStores();1469if (errorLatch) return;14701471if (options & DCE_FUNCS) dceFuncs();1472if (errorLatch) return;14731474if (options & DCE_VARS) dceVars();1475if (errorLatch) return;14761477if (options & DCE_TYPES) dceTypes();1478if (errorLatch) return;14791480strip(); // strip out data we decided to eliminate1481if (errorLatch) return;14821483stripDeadRefs(); // remove references to things we DCEed1484if (errorLatch) return;14851486// after the last strip, we must clean any debug info referring to now-deleted data14871488if (options & MAP_TYPES) mapTypeConst();1489if (errorLatch) return;14901491if (options & MAP_NAMES) mapNames();1492if (errorLatch) return;14931494if (options & MAP_FUNCS) mapFnBodies();1495if (errorLatch) return;14961497if (options & MAP_ALL) {1498mapRemainder(); // map any unmapped IDs1499if (errorLatch) return;15001501applyMap(); // Now remap each shader to the new IDs we've come up with1502if (errorLatch) return;1503}1504}15051506// remap from a memory image1507void spirvbin_t::remap(std::vector<std::uint32_t>& in_spv, const std::vector<std::string>& whiteListStrings,1508std::uint32_t opts)1509{1510stripWhiteList = whiteListStrings;1511spv.swap(in_spv);1512remap(opts);1513spv.swap(in_spv);1514}15151516// remap from a memory image - legacy interface without white list1517void spirvbin_t::remap(std::vector<std::uint32_t>& in_spv, std::uint32_t opts)1518{1519stripWhiteList.clear();1520spv.swap(in_spv);1521remap(opts);1522spv.swap(in_spv);1523}15241525} // namespace SPV1526152715281529