Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/cranelift/codegen/src/egraph/cost.rs
1693 views
1
//! Cost functions for egraph representation.
2
3
use crate::ir::Opcode;
4
5
/// A cost of computing some value in the program.
///
/// Costs are measured in an arbitrary unit that we represent in a
/// `u32`. The ordering is meant to be meaningful, but the value of a
/// single unit is arbitrary (and "not to scale"). We use a collection
/// of heuristics to try to make this approximation at least usable.
///
/// The `u32` packs two fields: the low `DEPTH_BITS` bits hold the
/// expression *depth*, and all bits above them hold the *op cost*. Because
/// the op cost occupies the high bits, comparing the raw integers orders
/// costs by op cost first and breaks ties by depth.
///
/// We start by defining costs for each opcode (see `Cost::of_opcode`
/// below). The cost of computing a pure value is the cost of its opcode
/// plus the sum of the costs of its inputs. Its depth is one more than the
/// deepest input, or 1 when there are no inputs (see `Cost::of_pure_op`).
///
/// Arithmetic on costs is always saturating: we don't want to wrap
/// around and return to a tiny cost when adding the costs of two very
/// expensive operations. It is better to approximate and lose some
/// precision than to lose the ordering by wrapping. Depth likewise
/// saturates at `u8::MAX` instead of wrapping.
///
/// Finally, we reserve the highest value, `u32::MAX`, as a sentinel
/// that means "infinite". An infinite cost is used to represent a value
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
/// Any op cost at or above the largest representable op cost saturates to
/// this sentinel (see `Cost::new`). As a result, adding anything to an
/// infinite cost, or overflowing the op cost of finite ones, yields
/// infinity rather than wrapping.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);
35
36
impl core::fmt::Debug for Cost {
37
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38
if *self == Cost::infinity() {
39
write!(f, "Cost::Infinite")
40
} else {
41
f.debug_struct("Cost::Finite")
42
.field("op_cost", &self.op_cost())
43
.field("depth", &self.depth())
44
.finish()
45
}
46
}
47
}
48
49
impl Ord for Cost {
50
#[inline]
51
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
52
// We make sure that the high bits are the op cost and the low bits are
53
// the depth. This means that we can use normal integer comparison to
54
// order by op cost and then depth.
55
//
56
// We want to break op cost ties with depth (rather than the other way
57
// around). When the op cost is the same, we prefer shallow and wide
58
// expressions to narrow and deep expressions and breaking ties with
59
// `depth` gives us that. For example, `(a + b) + (c + d)` is preferred
60
// to `((a + b) + c) + d`. This is beneficial because it exposes more
61
// instruction-level parallelism and shortens live ranges.
62
self.0.cmp(&other.0)
63
}
64
}
65
66
impl PartialOrd for Cost {
67
#[inline]
68
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
69
Some(self.cmp(other))
70
}
71
}
72
73
impl Cost {
74
const DEPTH_BITS: u8 = 8;
75
const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
76
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
77
const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;
78
79
pub(crate) fn infinity() -> Cost {
80
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
81
// only for heuristics and always saturate so this suffices!)
82
Cost(u32::MAX)
83
}
84
85
pub(crate) fn zero() -> Cost {
86
Cost(0)
87
}
88
89
/// Construct a new `Cost` from the given parts.
90
///
91
/// If the opcode cost is greater than or equal to the maximum representable
92
/// opcode cost, then the resulting `Cost` saturates to infinity.
93
fn new(opcode_cost: u32, depth: u8) -> Cost {
94
if opcode_cost >= Self::MAX_OP_COST {
95
Self::infinity()
96
} else {
97
Cost(opcode_cost << Self::DEPTH_BITS | u32::from(depth))
98
}
99
}
100
101
fn depth(&self) -> u8 {
102
let depth = self.0 & Self::DEPTH_MASK;
103
u8::try_from(depth).unwrap()
104
}
105
106
fn op_cost(&self) -> u32 {
107
(self.0 & Self::OP_COST_MASK) >> Self::DEPTH_BITS
108
}
109
110
/// Return the cost of an opcode.
111
fn of_opcode(op: Opcode) -> Cost {
112
match op {
113
// Constants.
114
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1, 0),
115
116
// Extends/reduces.
117
Opcode::Uextend
118
| Opcode::Sextend
119
| Opcode::Ireduce
120
| Opcode::Iconcat
121
| Opcode::Isplit => Cost::new(1, 0),
122
123
// "Simple" arithmetic.
124
Opcode::Iadd
125
| Opcode::Isub
126
| Opcode::Band
127
| Opcode::Bor
128
| Opcode::Bxor
129
| Opcode::Bnot
130
| Opcode::Ishl
131
| Opcode::Ushr
132
| Opcode::Sshr => Cost::new(3, 0),
133
134
// Everything else.
135
_ => {
136
let mut c = Cost::new(4, 0);
137
if op.can_trap() || op.other_side_effects() {
138
c = c + Cost::new(5, 0);
139
}
140
if op.can_load() {
141
c = c + Cost::new(10, 0);
142
}
143
if op.can_store() {
144
c = c + Cost::new(20, 0);
145
}
146
c
147
}
148
}
149
}
150
151
/// Compute the cost of the operation and its given operands.
152
///
153
/// Caller is responsible for checking that the opcode came from an instruction
154
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
155
pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
156
let c = Self::of_opcode(op) + operand_costs.into_iter().sum();
157
Cost::new(c.op_cost(), c.depth().saturating_add(1))
158
}
159
160
/// Compute the cost of an operation in the side-effectful skeleton.
161
pub(crate) fn of_skeleton_op(op: Opcode, arity: usize) -> Self {
162
Cost::of_opcode(op) + Cost::new(u32::try_from(arity).unwrap(), (arity != 0) as _)
163
}
164
}
165
166
impl std::iter::Sum<Cost> for Cost {
    /// Add up a sequence of costs from left to right, starting at
    /// `Cost::zero()`; an empty sequence sums to zero.
    fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
        let mut total = Self::zero();
        for cost in iter {
            total = total + cost;
        }
        total
    }
}
171
172
impl std::default::Default for Cost {
    /// The default cost is `Cost::zero()`.
    fn default() -> Cost {
        Self::zero()
    }
}
177
178
impl std::ops::Add<Cost> for Cost {
179
type Output = Cost;
180
181
fn add(self, other: Cost) -> Cost {
182
let op_cost = self.op_cost().saturating_add(other.op_cost());
183
let depth = std::cmp::max(self.depth(), other.depth());
184
Cost::new(op_cost, depth)
185
}
186
}
187
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `lhs + rhs` and `rhs + lhs` both equal `expected`.
    fn check_sum(lhs: Cost, rhs: Cost, expected: Cost) {
        assert_eq!(lhs + rhs, expected);
        assert_eq!(rhs + lhs, expected);
    }

    #[test]
    fn add_cost() {
        // Op costs add up; the deeper operand's depth is kept.
        check_sum(Cost::new(5, 2), Cost::new(37, 3), Cost::new(42, 3));
    }

    #[test]
    fn add_infinity() {
        // Infinity absorbs any finite cost.
        check_sum(Cost::new(5, 2), Cost::infinity(), Cost::infinity());
    }

    #[test]
    fn op_cost_saturates_to_infinity() {
        // Pushing the op cost past the representable maximum yields infinity.
        let near_max = Cost::new(Cost::MAX_OP_COST - 10, 2);
        check_sum(near_max, Cost::new(11, 2), Cost::infinity());
    }

    #[test]
    fn depth_saturates_to_max_depth() {
        // An operand already at the maximum depth keeps the result there
        // instead of wrapping, regardless of operand order.
        let deepest = Cost::new(10, u8::MAX);
        let shallow = Cost::new(10, 1);
        let expected = Cost::new(21, u8::MAX);
        for operands in [[deepest, shallow], [shallow, deepest]] {
            assert_eq!(Cost::of_pure_op(Opcode::Iconst, operands), expected);
        }
    }
}
230
231