Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
PojavLauncherTeam
GitHub Repository: PojavLauncherTeam/mesa
Path: blob/21.2-virgl/src/compiler/nir/nir_algebraic.py
4545 views
1
#
2
# Copyright (C) 2014 Intel Corporation
3
#
4
# Permission is hereby granted, free of charge, to any person obtaining a
5
# copy of this software and associated documentation files (the "Software"),
6
# to deal in the Software without restriction, including without limitation
7
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
# and/or sell copies of the Software, and to permit persons to whom the
9
# Software is furnished to do so, subject to the following conditions:
10
#
11
# The above copyright notice and this permission notice (including the next
12
# paragraph) shall be included in all copies or substantial portions of the
13
# Software.
14
#
15
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21
# IN THE SOFTWARE.
22
#
23
# Authors:
24
# Jason Ekstrand ([email protected])
25
26
from __future__ import print_function
27
import ast
28
from collections import defaultdict
29
import itertools
30
import struct
31
import sys
32
import mako.template
33
import re
34
import traceback
35
36
from nir_opcodes import opcodes, type_sizes
37
38
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
39
nir_search_max_comm_ops = 8
40
41
# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.
conv_opcode_types = {
   'i2f' : 'float',
   'u2f' : 'float',
   'f2f' : 'float',
   'f2u' : 'uint',
   'f2i' : 'int',
   'u2u' : 'uint',
   'i2i' : 'int',
   'b2f' : 'float',
   'b2i' : 'int',
   'i2b' : 'bool',
   'f2b' : 'bool',
}

def get_c_opcode(op):
   """Return the C enum identifier for an opcode name.

   Generic conversion opcodes exist only inside nir_search and use the
   nir_search_op_ prefix; every other opcode is a real NIR opcode and uses
   the nir_op_ prefix.
   """
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
62
63
64
# Python 2 splits integers into fixed-width "int" and arbitrary-precision
# "long", and has a separate "unicode" string type; Python 3 folds these
# into "int" and "str".  These aliases let the rest of the file stay
# version-agnostic when doing isinstance() checks.
if sys.version_info < (3, 0):
   integer_types = (int, long)
   string_type = unicode

else:
   integer_types = (int, )
   string_type = str
71
72
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit width encoded in a type string such as 'float32'.

   A type without an explicit size (e.g. 'int') yields 0, meaning
   "unsized".  The base type name is required.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return 0 if bits is None else int(bits)
82
83
class VarSet(object):
   """A set of named variables, each mapped to a unique integer id.

   Ids are handed out in first-seen order.  After lock() is called,
   looking up a name that was never seen before is an error — this is
   how new variables are forbidden in replacement expressions.
   """
   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      try:
         return self.names[name]
      except KeyError:
         assert not self.immutable, "Unknown replacement variable: " + name
         index = next(self.ids)
         self.names[name] = index
         return index

   def lock(self):
      """Freeze the set: subsequent lookups of unknown names will assert."""
      self.immutable = True
100
class Value(object):
   """Base class for all values in a search or replace pattern.

   Concrete values are Expression, Variable, or Constant.  Each value
   participates in bit-size inference via a _bit_size field that is either
   None (unconstrained), an int (a concrete bit size), or another Value
   that this one has been unified with (union-find link).
   """

   @staticmethod
   def create(val, name_base, varset):
      """Build the appropriate Value subclass from pattern source.

      Tuples become Expressions, strings become Variables, and
      bool/int/float literals become Constants.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)

   def __init__(self, val, name, type_str):
      self.in_val = str(val)    # original pattern text, returned by __str__
      self.name = name          # name of the generated C struct variable
      self.type_str = type_str  # "expression", "variable", or "constant"

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Follow unification links until we hit a concrete int or the class
      # representative (a Value with _bit_size of None).
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Path compression: point directly at the representative so future
      # lookups are O(1).
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() return
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      # Name of the nir_search_value_type enumerant for this value.
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      # Name of the generated C struct type for this value.
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      # Values whose rendered struct initializers were identical share one
      # C variable; the cache maps our name to the canonical one (see
      # render() below).
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         # "Same bit size as variable i" is encoded as the negative value
         # -(i + 1) so it cannot collide with a real bit size.
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         #   case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   # Mako template emitting the C struct initializer for one value; the
   # branch taken depends on the concrete subclass.
   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      """Render this value as C source, deduplicating identical initializers.

      When the rendered initializer has been seen before, only a comment is
      emitted and the cache records a name remap so later references use the
      canonical variable.
      """
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)
227
228
# A constant may be written either as a bare Python literal or as a string
# of the form "<literal>@<bits>" giving an explicit bit size.
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal constant appearing in a search or replace expression."""

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if not isinstance(val, str):
         self.value = val
         self._bit_size = None
      else:
         parsed = _constant_re.match(val)
         self.value = ast.literal_eval(parsed.group('value'))
         bits = parsed.group('bits')
         self._bit_size = int(bits) if bits else None

      # Booleans are always 1-bit in nir_search.
      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      """Render the constant as the C token used in the generated tables."""
      # Note: bool must be tested before the integer types since Python's
      # bool is a subclass of int.
      v = self.value
      if isinstance(v, bool):
         return 'NIR_TRUE' if v else 'NIR_FALSE'
      if isinstance(v, integer_types):
         return hex(v)
      if isinstance(v, float):
         # Reinterpret the double's bit pattern as a 64-bit unsigned int.
         return hex(struct.unpack('Q', struct.pack('d', v))[0])
      assert False

   def type(self):
      """Return the nir_alu_type name matching this constant's Python type."""
      v = self.value
      if isinstance(v, bool):
         return "nir_type_bool"
      if isinstance(v, integer_types):
         return "nir_type_int"
      if isinstance(v, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      will break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.value == other.value
276
277
# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?"
                          r"$")

class Variable(Value):
   """A named variable in a pattern, optionally annotated.

   Syntax: "#name@type bits(cond).swiz" where '#' marks a variable that
   must be constant, '@...' constrains type and/or bit size, '(...)' is a
   C condition, and '.xyzw' is a swizzle.
   """

   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      match = _var_name_re.match(val)
      assert match and match.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = match.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can me the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = match.group('const') is not None
      self.cond = match.group('cond')
      self.required_type = match.group('type')
      self.swiz = match.group('swiz')

      bits = match.group('bits')
      self._bit_size = int(bits) if bits else None

      if self.required_type == 'bool':
         # Booleans default to 1 bit; any explicit size must be legal for bool.
         if self._bit_size is None:
            self._bit_size = 1
         else:
            assert self._bit_size in type_sizes(self.required_type)

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Map the parsed @type annotation to a nir_alu_type name (or None)."""
      names = {'bool'  : "nir_type_bool",
               'int'   : "nir_type_int",
               'uint'  : "nir_type_int",
               'float' : "nir_type_float"}
      return names.get(self.required_type)

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      will break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.index == other.index

   def swizzle(self):
      """Emit the C array initializer for this variable's swizzle."""
      if self.swiz is None:
         return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

      swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                  'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                  'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                  'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                  'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
      components = (str(swizzles[c]) for c in self.swiz[1:])
      return '{' + ', '.join(components) + '}'
349
350
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An opcode applied to up to four source values.

   Parsed from a tuple whose first element is an opcode string, optionally
   prefixed with '~' (inexact) or '!' (exact) and suffixed with '@bits'
   and/or a parenthesized condition.  The remaining tuple elements become
   the sources (via Value.create).
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and self.cond.find("many-comm-expr") >= 0:
         # Split the condition into a comma-separated list.  Remove
         # "many-comm-expr".  If there is anything left, put it back together.
         c = self.cond[1:-1].split(",")
         c.remove("many-comm-expr")

         self.cond = "({})".format(",".join(c)) if c else None
         self.many_commutative_expressions = True

      # Sources get derived names like "<name_base>_0", "<name_base>_1", ...
      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality. One generally cannot be
      used in place of the other. Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.
      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions.

      Assigns each commutative expression in this subtree a distinct index
      starting at base_idx (stored in comm_expr_idx; -1 for non-commutative
      nodes) and records the subtree's total in comm_exprs.
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check.  The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it.  Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      # Sources must be rendered (and thus declared in C) before the
      # expression that points at them.
      srcs = "\n".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)
449
450
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values.  We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises.  Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize with anything.
   The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression.  It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression.  Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      # One slot per variable id; each slot lazily holds the canonical
      # Variable representing that id's bit-size class (see merge_variables).
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

      We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
      and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            # A variable can take on a concrete size only while matching the
            # search expression.
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         # "a" is an unsized expression or constant: most general, can be
         # specialized to anything.
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of self and other to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Point the less-specialized class at the more-specialized one.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: '
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} '
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) '
                     'of {} may not have the same bit size when building the '
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of '
                     'nir_op_{} it must have {} bits, which may not be the '
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} '
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} '
                     '(bit size of {}) may not have that bit size '
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} '
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      """Check that every replacement value's bit size is deducible: a
      concrete size, a variable's size, or the search expression's size.
      """
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
             bit_size == search.get_bit_size(), \
             'Ambiguous bit size for replacement value {}: ' \
             'it cannot be deduced from a variable, a fixed bit size ' \
             'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)

   def validate(self, search, replace):
      """Run full bit-size validation over one search/replace pair."""
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression '  \
         'bit size {} may not be the same'.format(
            search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)
751
752
_optimization_ids = itertools.count()
753
754
condition_list = ['true']
755
756
class SearchAndReplace(object):
   """One (search, replace[, condition]) transform, parsed and validated.

   Parses the search pattern into an Expression, the replacement into a
   Value, registers the optional condition in the global condition_list,
   and runs BitSizeValidator over the pair.
   """

   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search, replace = transform[0], transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      variables = VarSet()

      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id),
                                  variables)

      # The replacement may not introduce variables the search didn't bind.
      variables.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id),
                                     variables)

      BitSizeValidator(variables).validate(self.search, self.replace)
785
786
class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFA's and DFA's, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA to DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.
   """
   def __init__(self, transforms):
      # Only the search (left-hand) side of each transform is matched.
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      # Debug aids for tuning table sizes; intentionally left disabled.
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))
810
811
   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         self.objects = []   # index -> object, in insertion order
         self.map = {}       # object -> index, for O(1) index()
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         # Unlike list.index(), raises KeyError for an unknown object.
         return self.map[obj]

      def add(self, obj):
         """Insert obj if new; return its (possibly pre-existing) index."""
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
853
854
   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This the set of opcodes for parents of this item. Used to speed up
         # filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)
875
876
def _compute_items(self):
    """Build a set of all possible items, deduplicating them."""
    # This is a map from (opcode, sources) to item.
    self.items = {}

    # The set of all opcodes used by the patterns. Used later to avoid
    # building and emitting all the tables for opcodes that aren't used.
    self.opcodes = self.IndexMap()

    def get_item(opcode, children, pattern=None):
        # Deduplicating constructor: identical (opcode, children) subtrees
        # share one Item via setdefault.
        commutative = len(children) >= 2 \
                and "2src_commutative" in opcodes[opcode].algebraic_properties
        item = self.items.setdefault((opcode, children),
                                     self.Item(opcode, children))
        if commutative:
            # Also alias the item under the first two sources swapped, so
            # that either operand order maps to the same item at runtime.
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
        if pattern is not None:
            item.patterns.append(pattern)
        return item

    # Two special leaf items shared by all patterns: "matches anything"
    # and "matches a constant".
    self.wildcard = get_item("__wildcard", ())
    self.const = get_item("__const", ())

    def process_subpattern(src, pattern=None):
        # Recursively convert one search sub-pattern into an Item tree.
        if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
        elif isinstance(src, Variable):
            if src.is_constant:
                return self.const
            else:
                # Note: we throw away which variable it is here! This special
                # item is equivalent to nu in "Tree Automatons."
                return self.wildcard
        else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
                # Matches that use conversion opcodes with a specific type,
                # like f2b1, are tricky. Either we construct the automaton to
                # match specific NIR opcodes like nir_op_f2b1, in which case we
                # need to create separate items for each possible NIR opcode
                # for patterns that have a generic opcode like f2b, or we
                # construct it to match the search opcode, in which case we
                # need to map f2b1 to f2b when constructing the automaton. Here
                # we do the latter.
                opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for i, child in enumerate(children):
                # Record this opcode as a possible parent of each child, so
                # filtering can discard items that can never match here.
                child.parent_ops.add(opcode)
            return item

    # Root calls: the pattern index is only attached at the root item.
    for i, pattern in enumerate(self.patterns):
        process_subpattern(pattern, i)
def _build_table(self):
    """This is the core algorithm which builds up the transition table. It
    is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
    Comp_a and Filt_{a,i} using integers to identify match sets." It
    simultaneously builds up a list of all possible "match sets" or
    "states", where each match set represents the set of Item's that match a
    given instruction, and builds up the transition table between states.
    """
    # Map from opcode + filtered state indices to transitioned state.
    self.table = defaultdict(dict)
    # Bijection from state to index. q in the original algorithm is
    # len(self.states)
    self.states = self.IndexMap()
    # List of pattern matches for each state index.
    self.state_patterns = []
    # Map from state index to filtered state index for each opcode.
    self.filter = defaultdict(list)
    # Bijections from filtered state to filtered state index for each
    # opcode, called the "representor sets" in the original algorithm.
    # q_{a,j} in the original algorithm is len(self.rep[op]).
    self.rep = defaultdict(self.IndexMap)

    # Everything in self.states with an index at least worklist_index is part
    # of the worklist of newly created states. There is also a worklist of
    # newly filtered states for each opcode, for which worklist_indices
    # serves a similar purpose. worklist_index corresponds to p in the
    # original algorithm, while worklist_indices is p_{a,j} (although since
    # we only filter by opcode/symbol, it's really just p_a).
    self.worklist_index = 0
    worklist_indices = defaultdict(lambda: 0)

    # This is the set of opcodes for which the filtered worklist is non-empty.
    # It's used to avoid scanning opcodes for which there is nothing to
    # process when building the transition table. It corresponds to new_a in
    # the original algorithm.
    new_opcodes = self.IndexMap()

    # Process states on the global worklist, filtering them for each opcode,
    # updating the filter tables, and updating the filtered worklists if any
    # new filtered states are found. Similar to ComputeRepresenterSets() in
    # the original algorithm, although that only processes a single state.
    def process_new_states():
        while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]

            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = list(sorted(p for item in state for p in item.patterns))
            assert len(self.state_patterns) == self.worklist_index
            self.state_patterns.append(patterns)

            # calculate filter table for this state, and update filtered
            # worklists.
            for op in self.opcodes:
                filt = self.filter[op]
                rep = self.rep[op]
                # Keep only the items that can appear as a source of 'op';
                # everything else is irrelevant for transitions under 'op'.
                filtered = frozenset(item for item in state if \
                    op in item.parent_ops)
                if filtered in rep:
                    rep_index = rep.index(filtered)
                else:
                    # A never-before-seen filtered state: put 'op' on the
                    # worklist so the transition table gets extended.
                    rep_index = rep.add(filtered)
                    new_opcodes.add(op)
                assert len(filt) == self.worklist_index
                filt.append(rep_index)
            self.worklist_index += 1

    # There are two start states: one which can only match as a wildcard,
    # and one which can match as a wildcard or constant. These will be the
    # states of intrinsics/other instructions and load_const instructions,
    # respectively. The indices of these must match the definitions of
    # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
    # initialize things correctly.
    self.states.add(frozenset((self.wildcard,)))
    self.states.add(frozenset((self.const,self.wildcard)))
    process_new_states()

    # Fixed-point iteration: keep extending the transition table until no
    # new filtered states (and hence no new states) appear.
    while len(new_opcodes) > 0:
        for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            if op in conv_opcode_types:
                # Conversion search opcodes always take a single source.
                num_srcs = 1
            else:
                num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
                if all(src_idx < op_worklist_index for src_idx in src_indices):
                    continue

                srcs = tuple(rep[src_idx] for src_idx in src_indices)

                # Try all possible pairings of source items and add the
                # corresponding parent items. This is Comp_a from the paper.
                parent = set(self.items[op, item_srcs] for item_srcs in
                    itertools.product(*srcs) if (op, item_srcs) in self.items)

                # We could always start matching something else with a
                # wildcard. This is Cl from the paper.
                parent.add(self.wildcard)

                table[src_indices] = self.states.add(frozenset(parent))
            worklist_indices[op] = len(rep)
        new_opcodes.clear()
        process_new_states()
_algebraic_pass_template = mako.template.Template("""
1047
#include "nir.h"
1048
#include "nir_builder.h"
1049
#include "nir_search.h"
1050
#include "nir_search_helpers.h"
1051
1052
/* What follows is NIR algebraic transform code for the following ${len(xforms)}
1053
* transforms:
1054
% for xform in xforms:
1055
* ${xform.search} => ${xform.replace}
1056
% endfor
1057
*/
1058
1059
<% cache = {} %>
1060
% for xform in xforms:
1061
${xform.search.render(cache)}
1062
${xform.replace.render(cache)}
1063
% endfor
1064
1065
% for state_id, state_xforms in enumerate(automaton.state_patterns):
1066
% if state_xforms: # avoid emitting a 0-length array for MSVC
1067
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
1068
% for i in state_xforms:
1069
{ ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
1070
% endfor
1071
};
1072
% endif
1073
% endfor
1074
1075
static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
1076
% for op in automaton.opcodes:
1077
[${get_c_opcode(op)}] = {
1078
.filter = (uint16_t []) {
1079
% for e in automaton.filter[op]:
1080
${e},
1081
% endfor
1082
},
1083
<%
1084
num_filtered = len(automaton.rep[op])
1085
%>
1086
.num_filtered_states = ${num_filtered},
1087
.table = (uint16_t []) {
1088
<%
1089
num_srcs = len(next(iter(automaton.table[op])))
1090
%>
1091
% for indices in itertools.product(range(num_filtered), repeat=num_srcs):
1092
${automaton.table[op][indices]},
1093
% endfor
1094
},
1095
},
1096
% endfor
1097
};
1098
1099
const struct transform *${pass_name}_transforms[] = {
1100
% for i in range(len(automaton.state_patterns)):
1101
% if automaton.state_patterns[i]:
1102
${pass_name}_state${i}_xforms,
1103
% else:
1104
NULL,
1105
% endif
1106
% endfor
1107
};
1108
1109
const uint16_t ${pass_name}_transform_counts[] = {
1110
% for i in range(len(automaton.state_patterns)):
1111
% if automaton.state_patterns[i]:
1112
(uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
1113
% else:
1114
0,
1115
% endif
1116
% endfor
1117
};
1118
1119
bool
1120
${pass_name}(nir_shader *shader)
1121
{
1122
bool progress = false;
1123
bool condition_flags[${len(condition_list)}];
1124
const nir_shader_compiler_options *options = shader->options;
1125
const shader_info *info = &shader->info;
1126
(void) options;
1127
(void) info;
1128
1129
% for index, condition in enumerate(condition_list):
1130
condition_flags[${index}] = ${condition};
1131
% endfor
1132
1133
nir_foreach_function(function, shader) {
1134
if (function->impl) {
1135
progress |= nir_algebraic_impl(function->impl, condition_flags,
1136
${pass_name}_transforms,
1137
${pass_name}_transform_counts,
1138
${pass_name}_table);
1139
}
1140
}
1141
1142
return progress;
1143
}
1144
""")
1145
1146
1147
class AlgebraicPass(object):
    """A complete nir_algebraic pass: parses and validates a list of
    transforms, groups them by opcode, builds the matching TreeAutomaton,
    and renders the generated C source via the mako template.
    """

    def __init__(self, pass_name, transforms):
        """Parse *transforms*, validate commutative-expression limits, and
        build the automaton.

        Any parse failure or limit violation is reported to stderr and the
        process exits with status 1 once all transforms have been checked.
        """
        self.xforms = []
        self.opcode_xforms = defaultdict(lambda : [])
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except:
                    # Keep going so all broken transforms are reported in
                    # one run, then fail at the end.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print("  " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xforms.append(xform)
            if xform.search.opcode in conv_opcode_types:
                # A generic conversion opcode (e.g. "f2b") stands for every
                # sized variant, so index the transform under each one.
                dst_type = conv_opcode_types[xform.search.opcode]
                for size in type_sizes(dst_type):
                    sized_opcode = xform.search.opcode + str(size)
                    self.opcode_xforms[sized_opcode].append(xform)
            else:
                self.opcode_xforms[xform.search.opcode].append(xform)

            # Check to make sure the search pattern does not unexpectedly
            # contain more commutative expressions than match_expression
            # (nir_search.c) can handle.
            comm_exprs = xform.search.comm_exprs

            if xform.search.many_commutative_expressions:
                if comm_exprs <= nir_search_max_comm_ops:
                    # BUGFIX: this branch previously referenced the
                    # undefined name nir_search_max_comm_op (missing the
                    # trailing 's'), raising NameError instead of printing
                    # the diagnostic.
                    print("Transform expected to have too many commutative " \
                          "expression but did not " \
                          "({} <= {}).".format(comm_exprs,
                                               nir_search_max_comm_ops),
                          file=sys.stderr)
                    print("  " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
            else:
                if comm_exprs > nir_search_max_comm_ops:
                    print("Transformation with too many commutative expressions " \
                          "({} > {}). Modify pattern or annotate with " \
                          "\"many-comm-expr\".".format(comm_exprs,
                                                       nir_search_max_comm_ops),
                          file=sys.stderr)
                    print("  " + str(xform.search), file=sys.stderr)
                    print("{}".format(xform.search.cond), file=sys.stderr)
                    error = True

        self.automaton = TreeAutomaton(self.xforms)

        if error:
            sys.exit(1)


    def render(self):
        """Render and return the C source for this pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xforms=self.xforms,
                                               opcode_xforms=self.opcode_xforms,
                                               condition_list=condition_list,
                                               automaton=self.automaton,
                                               get_c_opcode=get_c_opcode,
                                               itertools=itertools)