Path: blob/main/cranelift/codegen/src/egraph/cost.rs
1693 views
//! Cost functions for egraph representation.12use crate::ir::Opcode;34/// A cost of computing some value in the program.5///6/// Costs are measured in an arbitrary union that we represent in a7/// `u32`. The ordering is meant to be meaningful, but the value of a8/// single unit is arbitrary (and "not to scale"). We use a collection9/// of heuristics to try to make this approximation at least usable.10///11/// We start by defining costs for each opcode (see `pure_op_cost`12/// below). The cost of computing some value, initially, is the cost13/// of its opcode, plus the cost of computing its inputs.14///15/// We then adjust the cost according to loop nests: for each16/// loop-nest level, we multiply by 1024. Because we only have 3217/// bits, we limit this scaling to a loop-level of two (i.e., multiply18/// by 2^20 ~= 1M).19///20/// Arithmetic on costs is always saturating: we don't want to wrap21/// around and return to a tiny cost when adding the costs of two very22/// expensive operations. It is better to approximate and lose some23/// precision than to lose the ordering by wrapping.24///25/// Finally, we reserve the highest value, `u32::MAX`, as a sentinel26/// that means "infinite". This is separate from the finite costs and27/// not reachable by doing arithmetic on them (even when overflowing)28/// -- we saturate just *below* infinity. (This is done by the29/// `finite()` method.) An infinite cost is used to represent a value30/// that cannot be computed, or otherwise serve as a sentinel when31/// performing search for the lowest-cost representation of a value.32#[derive(Clone, Copy, PartialEq, Eq)]33pub(crate) struct Cost(u32);3435impl core::fmt::Debug for Cost {36fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {37if *self == Cost::infinity() {38write!(f, "Cost::Infinite")39} else {40f.debug_struct("Cost::Finite")41.field("op_cost", &self.op_cost())42.field("depth", &self.depth())43.finish()44}45}46}4748impl Ord for Cost {49#[inline]50fn cmp(&self, other: &Self) -> std::cmp::Ordering {51// We make sure that the high bits are the op cost and the low bits are52// the depth. This means that we can use normal integer comparison to53// order by op cost and then depth.54//55// We want to break op cost ties with depth (rather than the other way56// around). When the op cost is the same, we prefer shallow and wide57// expressions to narrow and deep expressions and breaking ties with58// `depth` gives us that. For example, `(a + b) + (c + d)` is preferred59// to `((a + b) + c) + d`. This is beneficial because it exposes more60// instruction-level parallelism and shortens live ranges.61self.0.cmp(&other.0)62}63}6465impl PartialOrd for Cost {66#[inline]67fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {68Some(self.cmp(other))69}70}7172impl Cost {73const DEPTH_BITS: u8 = 8;74const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;75const OP_COST_MASK: u32 = !Self::DEPTH_MASK;76const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;7778pub(crate) fn infinity() -> Cost {79// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`80// only for heuristics and always saturate so this suffices!)81Cost(u32::MAX)82}8384pub(crate) fn zero() -> Cost {85Cost(0)86}8788/// Construct a new `Cost` from the given parts.89///90/// If the opcode cost is greater than or equal to the maximum representable91/// opcode cost, then the resulting `Cost` saturates to infinity.92fn new(opcode_cost: u32, depth: u8) -> Cost {93if opcode_cost >= Self::MAX_OP_COST {94Self::infinity()95} else {96Cost(opcode_cost << Self::DEPTH_BITS | u32::from(depth))97}98}99100fn depth(&self) -> u8 {101let depth = self.0 & Self::DEPTH_MASK;102u8::try_from(depth).unwrap()103}104105fn op_cost(&self) -> u32 {106(self.0 & Self::OP_COST_MASK) >> Self::DEPTH_BITS107}108109/// Return the cost of an opcode.110fn of_opcode(op: Opcode) -> Cost {111match op {112// Constants.113Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1, 0),114115// Extends/reduces.116Opcode::Uextend117| Opcode::Sextend118| Opcode::Ireduce119| Opcode::Iconcat120| Opcode::Isplit => Cost::new(1, 0),121122// "Simple" arithmetic.123Opcode::Iadd124| Opcode::Isub125| Opcode::Band126| Opcode::Bor127| Opcode::Bxor128| Opcode::Bnot129| Opcode::Ishl130| Opcode::Ushr131| Opcode::Sshr => Cost::new(3, 0),132133// Everything else.134_ => {135let mut c = Cost::new(4, 0);136if op.can_trap() || op.other_side_effects() {137c = c + Cost::new(5, 0);138}139if op.can_load() {140c = c + Cost::new(10, 0);141}142if op.can_store() {143c = c + Cost::new(20, 0);144}145c146}147}148}149150/// Compute the cost of the operation and its given operands.151///152/// Caller is responsible for checking that the opcode came from an instruction153/// that satisfies `inst_predicates::is_pure_for_egraph()`.154pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {155let c = Self::of_opcode(op) + operand_costs.into_iter().sum();156Cost::new(c.op_cost(), c.depth().saturating_add(1))157}158159/// Compute the cost of an operation in the side-effectful skeleton.160pub(crate) fn of_skeleton_op(op: Opcode, arity: usize) -> Self {161Cost::of_opcode(op) + Cost::new(u32::try_from(arity).unwrap(), (arity != 0) as _)162}163}164165impl std::iter::Sum<Cost> for Cost {166fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {167iter.fold(Self::zero(), |a, b| a + b)168}169}170171impl std::default::Default for Cost {172fn default() -> Cost {173Cost::zero()174}175}176177impl std::ops::Add<Cost> for Cost {178type Output = Cost;179180fn add(self, other: Cost) -> Cost {181let op_cost = self.op_cost().saturating_add(other.op_cost());182let depth = std::cmp::max(self.depth(), other.depth());183Cost::new(op_cost, depth)184}185}186187#[cfg(test)]188mod tests {189use super::*;190191#[test]192fn add_cost() {193let a = Cost::new(5, 2);194let b = Cost::new(37, 3);195assert_eq!(a + b, Cost::new(42, 3));196assert_eq!(b + a, Cost::new(42, 3));197}198199#[test]200fn add_infinity() {201let a = Cost::new(5, 2);202let b = Cost::infinity();203assert_eq!(a + b, Cost::infinity());204assert_eq!(b + a, Cost::infinity());205}206207#[test]208fn op_cost_saturates_to_infinity() {209let a = Cost::new(Cost::MAX_OP_COST - 10, 2);210let b = Cost::new(11, 2);211assert_eq!(a + b, Cost::infinity());212assert_eq!(b + a, Cost::infinity());213}214215#[test]216fn depth_saturates_to_max_depth() {217let a = Cost::new(10, u8::MAX);218let b = Cost::new(10, 1);219assert_eq!(220Cost::of_pure_op(Opcode::Iconst, [a, b]),221Cost::new(21, u8::MAX)222);223assert_eq!(224Cost::of_pure_op(Opcode::Iconst, [b, a]),225Cost::new(21, u8::MAX)226);227}228}229230231