Path: blob/21.2-virgl/src/compiler/nir/nir_algebraic.py
4545 views
#1# Copyright (C) 2014 Intel Corporation2#3# Permission is hereby granted, free of charge, to any person obtaining a4# copy of this software and associated documentation files (the "Software"),5# to deal in the Software without restriction, including without limitation6# the rights to use, copy, modify, merge, publish, distribute, sublicense,7# and/or sell copies of the Software, and to permit persons to whom the8# Software is furnished to do so, subject to the following conditions:9#10# The above copyright notice and this permission notice (including the next11# paragraph) shall be included in all copies or substantial portions of the12# Software.13#14# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR15# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,16# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL17# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER18# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING19# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS20# IN THE SOFTWARE.21#22# Authors:23# Jason Ekstrand ([email protected])2425from __future__ import print_function26import ast27from collections import defaultdict28import itertools29import struct30import sys31import mako.template32import re33import traceback3435from nir_opcodes import opcodes, type_sizes3637# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c38nir_search_max_comm_ops = 83940# These opcodes are only employed by nir_search. This provides a mapping from41# opcode to destination type.42conv_opcode_types = {43'i2f' : 'float',44'u2f' : 'float',45'f2f' : 'float',46'f2u' : 'uint',47'f2i' : 'int',48'u2u' : 'uint',49'i2i' : 'int',50'b2f' : 'float',51'b2i' : 'int',52'i2b' : 'bool',53'f2b' : 'bool',54}5556def get_c_opcode(op):57if op in conv_opcode_types:58return 'nir_search_op_' + op59else:60return 'nir_op_' + op616263if sys.version_info < (3, 0):64integer_types = (int, long)65string_type = unicode6667else:68integer_types = (int, )69string_type = str7071_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")7273def type_bits(type_str):74m = _type_re.match(type_str)75assert m.group('type')7677if m.group('bits') is None:78return 079else:80return int(m.group('bits'))8182# Represents a set of variables, each with a unique id83class VarSet(object):84def __init__(self):85self.names = {}86self.ids = itertools.count()87self.immutable = False;8889def __getitem__(self, name):90if name not in self.names:91assert not self.immutable, "Unknown replacement variable: " + name92self.names[name] = next(self.ids)9394return self.names[name]9596def lock(self):97self.immutable = True9899class Value(object):100@staticmethod101def create(val, name_base, varset):102if isinstance(val, bytes):103val = val.decode('utf-8')104105if isinstance(val, tuple):106return Expression(val, name_base, varset)107elif isinstance(val, Expression):108return val109elif isinstance(val, string_type):110return Variable(val, name_base, varset)111elif isinstance(val, (bool, float) + integer_types):112return Constant(val, name_base)113114def __init__(self, val, name, type_str):115self.in_val = str(val)116self.name = name117self.type_str = type_str118119def __str__(self):120return self.in_val121122def get_bit_size(self):123"""Get the physical bit-size that has been chosen for this value, or if124there is none, the canonical value which currently represents this125bit-size class. Variables will be preferred, i.e. if there are any126variables in the equivalence class, the canonical value will be a127variable. We do this since we'll need to know which variable each value128is equivalent to when constructing the replacement expression. This is129the "find" part of the union-find algorithm.130"""131bit_size = self132133while isinstance(bit_size, Value):134if bit_size._bit_size is None:135break136bit_size = bit_size._bit_size137138if bit_size is not self:139self._bit_size = bit_size140return bit_size141142def set_bit_size(self, other):143"""Make self.get_bit_size() return what other.get_bit_size() return144before calling this, or just "other" if it's a concrete bit-size. This is145the "union" part of the union-find algorithm.146"""147148self_bit_size = self.get_bit_size()149other_bit_size = other if isinstance(other, int) else other.get_bit_size()150151if self_bit_size == other_bit_size:152return153154self_bit_size._bit_size = other_bit_size155156@property157def type_enum(self):158return "nir_search_value_" + self.type_str159160@property161def c_type(self):162return "nir_search_" + self.type_str163164def __c_name(self, cache):165if cache is not None and self.name in cache:166return cache[self.name]167else:168return self.name169170def c_value_ptr(self, cache):171return "&{0}.value".format(self.__c_name(cache))172173def c_ptr(self, cache):174return "&{0}".format(self.__c_name(cache))175176@property177def c_bit_size(self):178bit_size = self.get_bit_size()179if isinstance(bit_size, int):180return bit_size181elif isinstance(bit_size, Variable):182return -bit_size.index - 1183else:184# If the bit-size class is neither a variable, nor an actual bit-size, then185# - If it's in the search expression, we don't need to check anything186# - If it's in the replace expression, either it's ambiguous (in which187# case we'd reject it), or it equals the bit-size of the search value188# We represent these cases with a 0 bit-size.189return 0190191__template = mako.template.Template("""{192{ ${val.type_enum}, ${val.c_bit_size} },193% if isinstance(val, Constant):194${val.type()}, { ${val.hex()} /* ${val.value} */ },195% elif isinstance(val, Variable):196${val.index}, /* ${val.var_name} */197${'true' if val.is_constant else 'false'},198${val.type() or 'nir_type_invalid' },199${val.cond if val.cond else 'NULL'},200${val.swizzle()},201% elif isinstance(val, Expression):202${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},203${val.comm_expr_idx}, ${val.comm_exprs},204${val.c_opcode()},205{ ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },206${val.cond if val.cond else 'NULL'},207% endif208};""")209210def render(self, cache):211struct_init = self.__template.render(val=self, cache=cache,212Constant=Constant,213Variable=Variable,214Expression=Expression)215if cache is not None and struct_init in cache:216# If it's in the cache, register a name remap in the cache and render217# only a comment saying it's been remapped218cache[self.name] = cache[struct_init]219return "/* {} -> {} in the cache */\n".format(self.name,220cache[struct_init])221else:222if cache is not None:223cache[struct_init] = self.name224return "static const {} {} = {}\n".format(self.c_type, self.name,225struct_init)226227_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")228229class Constant(Value):230def __init__(self, val, name):231Value.__init__(self, val, name, "constant")232233if isinstance(val, (str)):234m = _constant_re.match(val)235self.value = ast.literal_eval(m.group('value'))236self._bit_size = int(m.group('bits')) if m.group('bits') else None237else:238self.value = val239self._bit_size = None240241if isinstance(self.value, bool):242assert self._bit_size is None or self._bit_size == 1243self._bit_size = 1244245def hex(self):246if isinstance(self.value, (bool)):247return 'NIR_TRUE' if self.value else 'NIR_FALSE'248if isinstance(self.value, integer_types):249return hex(self.value)250elif isinstance(self.value, float):251return hex(struct.unpack('Q', struct.pack('d', self.value))[0])252else:253assert False254255def type(self):256if isinstance(self.value, (bool)):257return "nir_type_bool"258elif isinstance(self.value, integer_types):259return "nir_type_int"260elif isinstance(self.value, float):261return "nir_type_float"262263def equivalent(self, other):264"""Check that two constants are equivalent.265266This is check is much weaker than equality. One generally cannot be267used in place of the other. Using this implementation for the __eq__268will break BitSizeValidator.269270"""271if not isinstance(other, type(self)):272return False273274return self.value == other.value275276# The $ at the end forces there to be an error if any part of the string277# doesn't match one of the field patterns.278_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"279r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"280r"(?P<cond>\([^\)]+\))?"281r"(?P<swiz>\.[xyzw]+)?"282r"$")283284class Variable(Value):285def __init__(self, val, name, varset):286Value.__init__(self, val, name, "variable")287288m = _var_name_re.match(val)289assert m and m.group('name') is not None, \290"Malformed variable name \"{}\".".format(val)291292self.var_name = m.group('name')293294# Prevent common cases where someone puts quotes around a literal295# constant. If we want to support names that have numeric or296# punctuation characters, we can me the first assertion more flexible.297assert self.var_name.isalpha()298assert self.var_name != 'True'299assert self.var_name != 'False'300301self.is_constant = m.group('const') is not None302self.cond = m.group('cond')303self.required_type = m.group('type')304self._bit_size = int(m.group('bits')) if m.group('bits') else None305self.swiz = m.group('swiz')306307if self.required_type == 'bool':308if self._bit_size is not None:309assert self._bit_size in type_sizes(self.required_type)310else:311self._bit_size = 1312313if self.required_type is not None:314assert self.required_type in ('float', 'bool', 'int', 'uint')315316self.index = varset[self.var_name]317318def type(self):319if self.required_type == 'bool':320return "nir_type_bool"321elif self.required_type in ('int', 'uint'):322return "nir_type_int"323elif self.required_type == 'float':324return "nir_type_float"325326def equivalent(self, other):327"""Check that two variables are equivalent.328329This is check is much weaker than equality. One generally cannot be330used in place of the other. Using this implementation for the __eq__331will break BitSizeValidator.332333"""334if not isinstance(other, type(self)):335return False336337return self.index == other.index338339def swizzle(self):340if self.swiz is not None:341swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,342'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,343'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,344'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,345'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }346return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'347return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'348349_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"350r"(?P<cond>\([^\)]+\))?")351352class Expression(Value):353def __init__(self, expr, name_base, varset):354Value.__init__(self, expr, name_base, "expression")355assert isinstance(expr, tuple)356357m = _opcode_re.match(expr[0])358assert m and m.group('opcode') is not None359360self.opcode = m.group('opcode')361self._bit_size = int(m.group('bits')) if m.group('bits') else None362self.inexact = m.group('inexact') is not None363self.exact = m.group('exact') is not None364self.cond = m.group('cond')365366assert not self.inexact or not self.exact, \367'Expression cannot be both exact and inexact.'368369# "many-comm-expr" isn't really a condition. It's notification to the370# generator that this pattern is known to have too many commutative371# expressions, and an error should not be generated for this case.372self.many_commutative_expressions = False373if self.cond and self.cond.find("many-comm-expr") >= 0:374# Split the condition into a comma-separated list. Remove375# "many-comm-expr". If there is anything left, put it back together.376c = self.cond[1:-1].split(",")377c.remove("many-comm-expr")378379self.cond = "({})".format(",".join(c)) if c else None380self.many_commutative_expressions = True381382self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)383for (i, src) in enumerate(expr[1:]) ]384385# nir_search_expression::srcs is hard-coded to 4386assert len(self.sources) <= 4387388if self.opcode in conv_opcode_types:389assert self._bit_size is None, \390'Expression cannot use an unsized conversion opcode with ' \391'an explicit size; that\'s silly.'392393self.__index_comm_exprs(0)394395def equivalent(self, other):396"""Check that two variables are equivalent.397398This is check is much weaker than equality. One generally cannot be399used in place of the other. Using this implementation for the __eq__400will break BitSizeValidator.401402This implementation does not check for equivalence due to commutativity,403but it could.404405"""406if not isinstance(other, type(self)):407return False408409if len(self.sources) != len(other.sources):410return False411412if self.opcode != other.opcode:413return False414415return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))416417def __index_comm_exprs(self, base_idx):418"""Recursively count and index commutative expressions419"""420self.comm_exprs = 0421422# A note about the explicit "len(self.sources)" check. The list of423# sources comes from user input, and that input might be bad. Check424# that the expected second source exists before accessing it. Without425# this check, a unit test that does "('iadd', 'a')" will crash.426if self.opcode not in conv_opcode_types and \427"2src_commutative" in opcodes[self.opcode].algebraic_properties and \428len(self.sources) >= 2 and \429not self.sources[0].equivalent(self.sources[1]):430self.comm_expr_idx = base_idx431self.comm_exprs += 1432else:433self.comm_expr_idx = -1434435for s in self.sources:436if isinstance(s, Expression):437s.__index_comm_exprs(base_idx + self.comm_exprs)438self.comm_exprs += s.comm_exprs439440return self.comm_exprs441442def c_opcode(self):443return get_c_opcode(self.opcode)444445def render(self, cache):446srcs = "\n".join(src.render(cache) for src in self.sources)447return srcs + super(Expression, self).render(cache)448449class BitSizeValidator(object):450"""A class for validating bit sizes of expressions.451452NIR supports multiple bit-sizes on expressions in order to handle things453such as fp64. The source and destination of every ALU operation is454assigned a type and that type may or may not specify a bit size. Sources455and destinations whose type does not specify a bit size are considered456"unsized" and automatically take on the bit size of the corresponding457register or SSA value. NIR has two simple rules for bit sizes that are458validated by nir_validator:4594601) A given SSA def or register has a single bit size that is respected by461everything that reads from it or writes to it.4624632) The bit sizes of all unsized inputs/outputs on any given ALU464instruction must match. They need not match the sized inputs or465outputs but they must match each other.466467In order to keep nir_algebraic relatively simple and easy-to-use,468nir_search supports a type of bit-size inference based on the two rules469above. This is similar to type inference in many common programming470languages. If, for instance, you are constructing an add operation and you471know the second source is 16-bit, then you know that the other source and472the destination must also be 16-bit. There are, however, cases where this473inference can be ambiguous or contradictory. Consider, for instance, the474following transformation:475476(('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))477478This transformation can potentially cause a problem because usub_borrow is479well-defined for any bit-size of integer. However, b2i always generates a48032-bit result so it could end up replacing a 64-bit expression with one481that takes two 64-bit values and produces a 32-bit value. As another482example, consider this expression:483484(('bcsel', a, b, 0), ('iand', a, b))485486In this case, in the search expression a must be 32-bit but b can487potentially have any bit size. If we had a 64-bit b value, we would end up488trying to and a 32-bit value with a 64-bit value which would be invalid489490This class solves that problem by providing a validation layer that proves491that a given search-and-replace operation is 100% well-defined before we492generate any code. This ensures that bugs are caught at compile time493rather than at run time.494495Each value maintains a "bit-size class", which is either an actual bit size496or an equivalence class with other values that must have the same bit size.497The validator works by combining bit-size classes with each other according498to the NIR rules outlined above, checking that there are no inconsistencies.499When doing this for the replacement expression, we make sure to never change500the equivalence class of any of the search values. We could make the example501transforms above work by doing some extra run-time checking of the search502expression, but we make the user specify those constraints themselves, to503avoid any surprises. Since the replacement bitsizes can only be connected to504the source bitsize via variables (variables must have the same bitsize in505the source and replacment expressions) or the roots of the expression (the506replacement expression must produce the same bit size as the search507expression), we prevent merging a variable with anything when processing the508replacement expression, or specializing the search bitsize509with anything. The former prevents510511(('bcsel', a, b, 0), ('iand', a, b))512513from being allowed, since we'd have to merge the bitsizes for a and b due to514the 'iand', while the latter prevents515516(('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))517518from being allowed, since the search expression has the bit size of a and b,519which can't be specialized to 32 which is the bitsize of the replace520expression. It also prevents something like:521522(('b2i', ('i2b', a)), ('ineq', a, 0))523524since the bitsize of 'b2i', which can be anything, can't be specialized to525the bitsize of a.526527After doing all this, we check that every subexpression of the replacement528was assigned a constant bitsize, the bitsize of a variable, or the bitsize529of the search expresssion, since those are the things that are known when530constructing the replacement expresssion. Finally, we record the bitsize531needed in nir_search_value so that we know what to do when building the532replacement expression.533"""534535def __init__(self, varset):536self._var_classes = [None] * len(varset.names)537538def compare_bitsizes(self, a, b):539"""Determines which bitsize class is a specialization of the other, or540whether neither is. When we merge two different bitsizes, the541less-specialized bitsize always points to the more-specialized one, so542that calling get_bit_size() always gets you the most specialized bitsize.543The specialization partial order is given by:544- Physical bitsizes are always the most specialized, and a different545bitsize can never specialize another.546- In the search expression, variables can always be specialized to each547other and to physical bitsizes. In the replace expression, we disallow548this to avoid adding extra constraints to the search expression that549the user didn't specify.550- Expressions and constants without a bitsize can always be specialized to551each other and variables, but not the other way around.552553We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,554and None if they are not comparable (neither a <= b nor b <= a).555"""556if isinstance(a, int):557if isinstance(b, int):558return 0 if a == b else None559elif isinstance(b, Variable):560return -1 if self.is_search else None561else:562return -1563elif isinstance(a, Variable):564if isinstance(b, int):565return 1 if self.is_search else None566elif isinstance(b, Variable):567return 0 if self.is_search or a.index == b.index else None568else:569return -1570else:571if isinstance(b, int):572return 1573elif isinstance(b, Variable):574return 1575else:576return 0577578def unify_bit_size(self, a, b, error_msg):579"""Record that a must have the same bit-size as b. If both580have been assigned conflicting physical bit-sizes, call "error_msg" with581the bit-sizes of self and other to get a message and raise an error.582In the replace expression, disallow merging variables with other583variables and physical bit-sizes as well.584"""585a_bit_size = a.get_bit_size()586b_bit_size = b if isinstance(b, int) else b.get_bit_size()587588cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)589590assert cmp_result is not None, \591error_msg(a_bit_size, b_bit_size)592593if cmp_result < 0:594b_bit_size.set_bit_size(a)595elif not isinstance(a_bit_size, int):596a_bit_size.set_bit_size(b)597598def merge_variables(self, val):599"""Perform the first part of type inference by merging all the different600uses of the same variable. We always do this as if we're in the search601expression, even if we're actually not, since otherwise we'd get errors602if the search expression specified some constraint but the replace603expression didn't, because we'd be merging a variable and a constant.604"""605if isinstance(val, Variable):606if self._var_classes[val.index] is None:607self._var_classes[val.index] = val608else:609other = self._var_classes[val.index]610self.unify_bit_size(other, val,611lambda other_bit_size, bit_size:612'Variable {} has conflicting bit size requirements: ' \613'it must have bit size {} and {}'.format(614val.var_name, other_bit_size, bit_size))615elif isinstance(val, Expression):616for src in val.sources:617self.merge_variables(src)618619def validate_value(self, val):620"""Validate the an expression by performing classic Hindley-Milner621type inference on bitsizes. This will detect if there are any conflicting622requirements, and unify variables so that we know which variables must623have the same bitsize. If we're operating on the replace expression, we624will refuse to merge different variables together or merge a variable625with a constant, in order to prevent surprises due to rules unexpectedly626not matching at runtime.627"""628if not isinstance(val, Expression):629return630631# Generic conversion ops are special in that they have a single unsized632# source and an unsized destination and the two don't have to match.633# This means there's no validation or unioning to do here besides the634# len(val.sources) check.635if val.opcode in conv_opcode_types:636assert len(val.sources) == 1, \637"Expression {} has {} sources, expected 1".format(638val, len(val.sources))639self.validate_value(val.sources[0])640return641642nir_op = opcodes[val.opcode]643assert len(val.sources) == nir_op.num_inputs, \644"Expression {} has {} sources, expected {}".format(645val, len(val.sources), nir_op.num_inputs)646647for src in val.sources:648self.validate_value(src)649650dst_type_bits = type_bits(nir_op.output_type)651652# First, unify all the sources. That way, an error coming up because two653# sources have an incompatible bit-size won't produce an error message654# involving the destination.655first_unsized_src = None656for src_type, src in zip(nir_op.input_types, val.sources):657src_type_bits = type_bits(src_type)658if src_type_bits == 0:659if first_unsized_src is None:660first_unsized_src = src661continue662663if self.is_search:664self.unify_bit_size(first_unsized_src, src,665lambda first_unsized_src_bit_size, src_bit_size:666'Source {} of {} must have bit size {}, while source {} ' \667'must have incompatible bit size {}'.format(668first_unsized_src, val, first_unsized_src_bit_size,669src, src_bit_size))670else:671self.unify_bit_size(first_unsized_src, src,672lambda first_unsized_src_bit_size, src_bit_size:673'Sources {} (bit size of {}) and {} (bit size of {}) ' \674'of {} may not have the same bit size when building the ' \675'replacement expression.'.format(676first_unsized_src, first_unsized_src_bit_size, src,677src_bit_size, val))678else:679if self.is_search:680self.unify_bit_size(src, src_type_bits,681lambda src_bit_size, unused:682'{} must have {} bits, but as a source of nir_op_{} '\683'it must have {} bits'.format(684src, src_bit_size, nir_op.name, src_type_bits))685else:686self.unify_bit_size(src, src_type_bits,687lambda src_bit_size, unused:688'{} has the bit size of {}, but as a source of ' \689'nir_op_{} it must have {} bits, which may not be the ' \690'same'.format(691src, src_bit_size, nir_op.name, src_type_bits))692693if dst_type_bits == 0:694if first_unsized_src is not None:695if self.is_search:696self.unify_bit_size(val, first_unsized_src,697lambda val_bit_size, src_bit_size:698'{} must have the bit size of {}, while its source {} ' \699'must have incompatible bit size {}'.format(700val, val_bit_size, first_unsized_src, src_bit_size))701else:702self.unify_bit_size(val, first_unsized_src,703lambda val_bit_size, src_bit_size:704'{} must have {} bits, but its source {} ' \705'(bit size of {}) may not have that bit size ' \706'when building the replacement.'.format(707val, val_bit_size, first_unsized_src, src_bit_size))708else:709self.unify_bit_size(val, dst_type_bits,710lambda dst_bit_size, unused:711'{} must have {} bits, but as a destination of nir_op_{} ' \712'it must have {} bits'.format(713val, dst_bit_size, nir_op.name, dst_type_bits))714715def validate_replace(self, val, search):716bit_size = val.get_bit_size()717assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \718bit_size == search.get_bit_size(), \719'Ambiguous bit size for replacement value {}: ' \720'it cannot be deduced from a variable, a fixed bit size ' \721'somewhere, or the search expression.'.format(val)722723if isinstance(val, Expression):724for src in val.sources:725self.validate_replace(src, search)726727def validate(self, search, replace):728self.is_search = True729self.merge_variables(search)730self.merge_variables(replace)731self.validate_value(search)732733self.is_search = False734self.validate_value(replace)735736# Check that search is always more specialized than replace. Note that737# we're doing this in replace mode, disallowing merging variables.738search_bit_size = search.get_bit_size()739replace_bit_size = replace.get_bit_size()740cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)741742assert cmp_result is not None and cmp_result <= 0, \743'The search expression bit size {} and replace expression ' \744'bit size {} may not be the same'.format(745search_bit_size, replace_bit_size)746747replace.set_bit_size(search)748749self.validate_replace(replace, search)750751_optimization_ids = itertools.count()752753condition_list = ['true']754755class SearchAndReplace(object):756def __init__(self, transform):757self.id = next(_optimization_ids)758759search = transform[0]760replace = transform[1]761if len(transform) > 2:762self.condition = transform[2]763else:764self.condition = 'true'765766if self.condition not in condition_list:767condition_list.append(self.condition)768self.condition_index = condition_list.index(self.condition)769770varset = VarSet()771if isinstance(search, Expression):772self.search = search773else:774self.search = Expression(search, "search{0}".format(self.id), varset)775776varset.lock()777778if isinstance(replace, Value):779self.replace = replace780else:781self.replace = Value.create(replace, "replace{0}".format(self.id), varset)782783BitSizeValidator(varset).validate(self.search, self.replace)784785class TreeAutomaton(object):786"""This class calculates a bottom-up tree automaton to quickly search for787the left-hand sides of tranforms. Tree automatons are a generalization of788classical NFA's and DFA's, where the transition function determines the789state of the parent node based on the state of its children. We construct a790deterministic automaton to match patterns, using a similar algorithm to the791classical NFA to DFA construction. At the moment, it only matches opcodes792and constants (without checking the actual value), leaving more detailed793checking to the search function which actually checks the leaves. The794automaton acts as a quick filter for the search function, requiring only n795+ 1 table lookups for each n-source operation. The implementation is based796on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."797In the language of that reference, this is a frontier-to-root deterministic798automaton using only symbol filtering. The filtering is crucial to reduce799both the time taken to generate the tables and the size of the tables.800"""801def __init__(self, transforms):802self.patterns = [t.search for t in transforms]803self._compute_items()804self._build_table()805#print('num items: {}'.format(len(set(self.items.values()))))806#print('num states: {}'.format(len(self.states)))807#for state, patterns in zip(self.states, self.patterns):808# print('{}: num patterns: {}'.format(state, len(patterns)))809810class IndexMap(object):811"""An indexed list of objects, where one can either lookup an object by812index or find the index associated to an object quickly using a hash813table. Compared to a list, it has a constant time index(). Compared to a814set, it provides a stable iteration order.815"""816def __init__(self, iterable=()):817self.objects = []818self.map = {}819for obj in iterable:820self.add(obj)821822def __getitem__(self, i):823return self.objects[i]824825def __contains__(self, obj):826return obj in self.map827828def __len__(self):829return len(self.objects)830831def __iter__(self):832return iter(self.objects)833834def clear(self):835self.objects = []836self.map.clear()837838def index(self, obj):839return self.map[obj]840841def add(self, obj):842if obj in self.map:843return self.map[obj]844else:845index = len(self.objects)846self.objects.append(obj)847self.map[obj] = index848return index849850def __repr__(self):851return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'852853class Item(object):854"""This represents an "item" in the language of "Tree Automatons." This855is just a subtree of some pattern, which represents a potential partial856match at runtime. We deduplicate them, so that identical subtrees of857different patterns share the same object, and store some extra858information needed for the main algorithm as well.859"""860def __init__(self, opcode, children):861self.opcode = opcode862self.children = children863# These are the indices of patterns for which this item is the root node.864self.patterns = []865# This the set of opcodes for parents of this item. Used to speed up866# filtering.867self.parent_ops = set()868869def __str__(self):870return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'871872def __repr__(self):873return str(self)874875def _compute_items(self):876"""Build a set of all possible items, deduplicating them."""877# This is a map from (opcode, sources) to item.878self.items = {}879880# The set of all opcodes used by the patterns. Used later to avoid881# building and emitting all the tables for opcodes that aren't used.882self.opcodes = self.IndexMap()883884def get_item(opcode, children, pattern=None):885commutative = len(children) >= 2 \886and "2src_commutative" in opcodes[opcode].algebraic_properties887item = self.items.setdefault((opcode, children),888self.Item(opcode, children))889if commutative:890self.items[opcode, (children[1], children[0]) + children[2:]] = item891if pattern is not None:892item.patterns.append(pattern)893return item894895self.wildcard = get_item("__wildcard", ())896self.const = get_item("__const", ())897898def process_subpattern(src, pattern=None):899if isinstance(src, Constant):900# Note: we throw away the actual constant value!901return self.const902elif isinstance(src, Variable):903if src.is_constant:904return self.const905else:906# Note: we throw away which variable it is here! This special907# item is equivalent to nu in "Tree Automatons."908return self.wildcard909else:910assert isinstance(src, Expression)911opcode = src.opcode912stripped = opcode.rstrip('0123456789')913if stripped in conv_opcode_types:914# Matches that use conversion opcodes with a specific type,915# like f2b1, are tricky. Either we construct the automaton to916# match specific NIR opcodes like nir_op_f2b1, in which case we917# need to create separate items for each possible NIR opcode918# for patterns that have a generic opcode like f2b, or we919# construct it to match the search opcode, in which case we920# need to map f2b1 to f2b when constructing the automaton. Here921# we do the latter.922opcode = stripped923self.opcodes.add(opcode)924children = tuple(process_subpattern(c) for c in src.sources)925item = get_item(opcode, children, pattern)926for i, child in enumerate(children):927child.parent_ops.add(opcode)928return item929930for i, pattern in enumerate(self.patterns):931process_subpattern(pattern, i)932933def _build_table(self):934"""This is the core algorithm which builds up the transition table. It935is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .936Comp_a and Filt_{a,i} using integers to identify match sets." It937simultaneously builds up a list of all possible "match sets" or938"states", where each match set represents the set of Item's that match a939given instruction, and builds up the transition table between states.940"""941# Map from opcode + filtered state indices to transitioned state.942self.table = defaultdict(dict)943# Bijection from state to index. q in the original algorithm is944# len(self.states)945self.states = self.IndexMap()946# List of pattern matches for each state index.947self.state_patterns = []948# Map from state index to filtered state index for each opcode.949self.filter = defaultdict(list)950# Bijections from filtered state to filtered state index for each951# opcode, called the "representor sets" in the original algorithm.952# q_{a,j} in the original algorithm is len(self.rep[op]).953self.rep = defaultdict(self.IndexMap)954955# Everything in self.states with a index at least worklist_index is part956# of the worklist of newly created states. There is also a worklist of957# newly fitered states for each opcode, for which worklist_indices958# serves a similar purpose. worklist_index corresponds to p in the959# original algorithm, while worklist_indices is p_{a,j} (although since960# we only filter by opcode/symbol, it's really just p_a).961self.worklist_index = 0962worklist_indices = defaultdict(lambda: 0)963964# This is the set of opcodes for which the filtered worklist is non-empty.965# It's used to avoid scanning opcodes for which there is nothing to966# process when building the transition table. It corresponds to new_a in967# the original algorithm.968new_opcodes = self.IndexMap()969970# Process states on the global worklist, filtering them for each opcode,971# updating the filter tables, and updating the filtered worklists if any972# new filtered states are found. Similar to ComputeRepresenterSets() in973# the original algorithm, although that only processes a single state.974def process_new_states():975while self.worklist_index < len(self.states):976state = self.states[self.worklist_index]977978# Calculate pattern matches for this state. Each pattern is979# assigned to a unique item, so we don't have to worry about980# deduplicating them here. However, we do have to sort them so981# that they're visited at runtime in the order they're specified982# in the source.983patterns = list(sorted(p for item in state for p in item.patterns))984assert len(self.state_patterns) == self.worklist_index985self.state_patterns.append(patterns)986987# calculate filter table for this state, and update filtered988# worklists.989for op in self.opcodes:990filt = self.filter[op]991rep = self.rep[op]992filtered = frozenset(item for item in state if \993op in item.parent_ops)994if filtered in rep:995rep_index = rep.index(filtered)996else:997rep_index = rep.add(filtered)998new_opcodes.add(op)999assert len(filt) == self.worklist_index1000filt.append(rep_index)1001self.worklist_index += 110021003# There are two start states: one which can only match as a wildcard,1004# and one which can match as a wildcard or constant. These will be the1005# states of intrinsics/other instructions and load_const instructions,1006# respectively. The indices of these must match the definitions of1007# WILDCARD_STATE and CONST_STATE below, so that the runtime C code can1008# initialize things correctly.1009self.states.add(frozenset((self.wildcard,)))1010self.states.add(frozenset((self.const,self.wildcard)))1011process_new_states()10121013while len(new_opcodes) > 0:1014for op in new_opcodes:1015rep = self.rep[op]1016table = self.table[op]1017op_worklist_index = worklist_indices[op]1018if op in conv_opcode_types:1019num_srcs = 11020else:1021num_srcs = opcodes[op].num_inputs10221023# Iterate over all possible source combinations where at least one1024# is on the worklist.1025for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):1026if all(src_idx < op_worklist_index for src_idx in src_indices):1027continue10281029srcs = tuple(rep[src_idx] for src_idx in src_indices)10301031# Try all possible pairings of source items and add the1032# corresponding parent items. This is Comp_a from the paper.1033parent = set(self.items[op, item_srcs] for item_srcs in1034itertools.product(*srcs) if (op, item_srcs) in self.items)10351036# We could always start matching something else with a1037# wildcard. This is Cl from the paper.1038parent.add(self.wildcard)10391040table[src_indices] = self.states.add(frozenset(parent))1041worklist_indices[op] = len(rep)1042new_opcodes.clear()1043process_new_states()10441045_algebraic_pass_template = mako.template.Template("""1046#include "nir.h"1047#include "nir_builder.h"1048#include "nir_search.h"1049#include "nir_search_helpers.h"10501051/* What follows is NIR algebraic transform code for the following ${len(xforms)}1052* transforms:1053% for xform in xforms:1054* ${xform.search} => ${xform.replace}1055% endfor1056*/10571058<% cache = {} %>1059% for xform in xforms:1060${xform.search.render(cache)}1061${xform.replace.render(cache)}1062% endfor10631064% for state_id, state_xforms in enumerate(automaton.state_patterns):1065% if state_xforms: # avoid emitting a 0-length array for MSVC1066static const struct transform ${pass_name}_state${state_id}_xforms[] = {1067% for i in state_xforms:1068{ ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },1069% endfor1070};1071% endif1072% endfor10731074static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {1075% for op in automaton.opcodes:1076[${get_c_opcode(op)}] = {1077.filter = (uint16_t []) {1078% for e in automaton.filter[op]:1079${e},1080% endfor1081},1082<%1083num_filtered = len(automaton.rep[op])1084%>1085.num_filtered_states = ${num_filtered},1086.table = (uint16_t []) {1087<%1088num_srcs = len(next(iter(automaton.table[op])))1089%>1090% for indices in itertools.product(range(num_filtered), repeat=num_srcs):1091${automaton.table[op][indices]},1092% endfor1093},1094},1095% endfor1096};10971098const struct transform *${pass_name}_transforms[] = {1099% for i in range(len(automaton.state_patterns)):1100% if automaton.state_patterns[i]:1101${pass_name}_state${i}_xforms,1102% else:1103NULL,1104% endif1105% endfor1106};11071108const uint16_t ${pass_name}_transform_counts[] = {1109% for i in range(len(automaton.state_patterns)):1110% if automaton.state_patterns[i]:1111(uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),1112% else:11130,1114% endif1115% endfor1116};11171118bool1119${pass_name}(nir_shader *shader)1120{1121bool progress = false;1122bool condition_flags[${len(condition_list)}];1123const nir_shader_compiler_options *options = shader->options;1124const shader_info *info = &shader->info;1125(void) options;1126(void) info;11271128% for index, condition in enumerate(condition_list):1129condition_flags[${index}] = ${condition};1130% endfor11311132nir_foreach_function(function, shader) {1133if (function->impl) {1134progress |= nir_algebraic_impl(function->impl, condition_flags,1135${pass_name}_transforms,1136${pass_name}_transform_counts,1137${pass_name}_table);1138}1139}11401141return progress;1142}1143""")114411451146class AlgebraicPass(object):1147def __init__(self, pass_name, transforms):1148self.xforms = []1149self.opcode_xforms = defaultdict(lambda : [])1150self.pass_name = pass_name11511152error = False11531154for xform in transforms:1155if not isinstance(xform, SearchAndReplace):1156try:1157xform = SearchAndReplace(xform)1158except:1159print("Failed to parse transformation:", file=sys.stderr)1160print(" " + str(xform), file=sys.stderr)1161traceback.print_exc(file=sys.stderr)1162print('', file=sys.stderr)1163error = True1164continue11651166self.xforms.append(xform)1167if xform.search.opcode in conv_opcode_types:1168dst_type = conv_opcode_types[xform.search.opcode]1169for size in type_sizes(dst_type):1170sized_opcode = xform.search.opcode + str(size)1171self.opcode_xforms[sized_opcode].append(xform)1172else:1173self.opcode_xforms[xform.search.opcode].append(xform)11741175# Check to make sure the search pattern does not unexpectedly contain1176# more commutative expressions than match_expression (nir_search.c)1177# can handle.1178comm_exprs = xform.search.comm_exprs11791180if xform.search.many_commutative_expressions:1181if comm_exprs <= nir_search_max_comm_ops:1182print("Transform expected to have too many commutative " \1183"expression but did not " \1184"({} <= {}).".format(comm_exprs, nir_search_max_comm_op),1185file=sys.stderr)1186print(" " + str(xform), file=sys.stderr)1187traceback.print_exc(file=sys.stderr)1188print('', file=sys.stderr)1189error = True1190else:1191if comm_exprs > nir_search_max_comm_ops:1192print("Transformation with too many commutative expressions " \1193"({} > {}). Modify pattern or annotate with " \1194"\"many-comm-expr\".".format(comm_exprs,1195nir_search_max_comm_ops),1196file=sys.stderr)1197print(" " + str(xform.search), file=sys.stderr)1198print("{}".format(xform.search.cond), file=sys.stderr)1199error = True12001201self.automaton = TreeAutomaton(self.xforms)12021203if error:1204sys.exit(1)120512061207def render(self):1208return _algebraic_pass_template.render(pass_name=self.pass_name,1209xforms=self.xforms,1210opcode_xforms=self.opcode_xforms,1211condition_list=condition_list,1212automaton=self.automaton,1213get_c_opcode=get_c_opcode,1214itertools=itertools)121512161217