// Path: cranelift/frontend/src/switch.rs (blob/main)
// Scraped page metadata: 1691 views
use super::HashMap;1use crate::frontend::FunctionBuilder;2use alloc::vec::Vec;3use cranelift_codegen::ir::condcodes::IntCC;4use cranelift_codegen::ir::*;56type EntryIndex = u128;78/// Unlike with `br_table`, `Switch` cases may be sparse or non-0-based.9/// They emit efficient code using branches, jump tables, or a combination of both.10///11/// # Example12///13/// ```rust14/// # use cranelift_codegen::ir::types::*;15/// # use cranelift_codegen::ir::{UserFuncName, Function, Signature, InstBuilder};16/// # use cranelift_codegen::isa::CallConv;17/// # use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Switch};18/// #19/// # let mut sig = Signature::new(CallConv::SystemV);20/// # let mut fn_builder_ctx = FunctionBuilderContext::new();21/// # let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);22/// # let mut builder = FunctionBuilder::new(&mut func, &mut fn_builder_ctx);23/// #24/// # let entry = builder.create_block();25/// # builder.switch_to_block(entry);26/// #27/// let block0 = builder.create_block();28/// let block1 = builder.create_block();29/// let block2 = builder.create_block();30/// let fallback = builder.create_block();31///32/// let val = builder.ins().iconst(I32, 1);33///34/// let mut switch = Switch::new();35/// switch.set_entry(0, block0);36/// switch.set_entry(1, block1);37/// switch.set_entry(7, block2);38/// switch.emit(&mut builder, val, fallback);39/// ```40#[derive(Debug, Default)]41pub struct Switch {42cases: HashMap<EntryIndex, Block>,43}4445impl Switch {46/// Create a new empty switch47pub fn new() -> Self {48Self {49cases: HashMap::new(),50}51}5253/// Set a switch entry54pub fn set_entry(&mut self, index: EntryIndex, block: Block) {55let prev = self.cases.insert(index, block);56assert!(prev.is_none(), "Tried to set the same entry {index} twice");57}5859/// Get a reference to all existing entries60pub fn entries(&self) -> &HashMap<EntryIndex, Block> {61&self.cases62}6364/// Turn the `cases` `HashMap` into 
a list of `ContiguousCaseRange`s.65///66/// # Postconditions67///68/// * Every entry will be represented.69/// * The `ContiguousCaseRange`s will not overlap.70/// * Between two `ContiguousCaseRange`s there will be at least one entry index.71/// * No `ContiguousCaseRange`s will be empty.72fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {73log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);74let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();75cases.sort_by_key(|&(index, _)| index);7677let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];78let mut last_index = None;79for (index, block) in cases {80match last_index {81None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),82Some(last_index) => {83if index > last_index + 1 {84contiguous_case_ranges.push(ContiguousCaseRange::new(index));85}86}87}88contiguous_case_ranges89.last_mut()90.unwrap()91.blocks92.push(block);93last_index = Some(index);94}9596log::trace!("build_contiguous_case_ranges after: {contiguous_case_ranges:#?}");9798contiguous_case_ranges99}100101/// Binary search for the right `ContiguousCaseRange`.102fn build_search_tree<'a>(103bx: &mut FunctionBuilder,104val: Value,105otherwise: Block,106contiguous_case_ranges: &'a [ContiguousCaseRange],107) {108// If no switch cases were added to begin with, we can just emit `jump otherwise`.109if contiguous_case_ranges.is_empty() {110bx.ins().jump(otherwise, &[]);111return;112}113114// Avoid allocation in the common case115if contiguous_case_ranges.len() <= 3 {116Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);117return;118}119120let mut stack = Vec::new();121stack.push((None, contiguous_case_ranges));122123while let Some((block, contiguous_case_ranges)) = stack.pop() {124if let Some(block) = block {125bx.switch_to_block(block);126}127128if contiguous_case_ranges.len() <= 3 {129Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);130} else {131let 
split_point = contiguous_case_ranges.len() / 2;132let (left, right) = contiguous_case_ranges.split_at(split_point);133134let left_block = bx.create_block();135let right_block = bx.create_block();136137let first_index = right[0].first_index;138let should_take_right_side =139icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);140bx.ins()141.brif(should_take_right_side, right_block, &[], left_block, &[]);142143bx.seal_block(left_block);144bx.seal_block(right_block);145146stack.push((Some(left_block), left));147stack.push((Some(right_block), right));148}149}150}151152/// Linear search for the right `ContiguousCaseRange`.153fn build_search_branches<'a>(154bx: &mut FunctionBuilder,155val: Value,156otherwise: Block,157contiguous_case_ranges: &'a [ContiguousCaseRange],158) {159for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {160let alternate = if ix == 0 {161otherwise162} else {163bx.create_block()164};165166if range.first_index == 0 {167assert_eq!(alternate, otherwise);168169if let Some(block) = range.single_block() {170bx.ins().brif(val, otherwise, &[], block, &[]);171} else {172Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);173}174} else {175if let Some(block) = range.single_block() {176let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);177bx.ins().brif(is_good_val, block, &[], alternate, &[]);178} else {179let is_good_val = icmp_imm_u128(180bx,181IntCC::UnsignedGreaterThanOrEqual,182val,183range.first_index,184);185let jt_block = bx.create_block();186bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);187bx.seal_block(jt_block);188bx.switch_to_block(jt_block);189Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);190}191}192193if alternate != otherwise {194bx.seal_block(alternate);195bx.switch_to_block(alternate);196}197}198}199200fn build_jump_table(201bx: &mut FunctionBuilder,202val: Value,203otherwise: Block,204first_index: EntryIndex,205blocks: &[Block],206) 
{207// There are currently no 128bit systems supported by rustc, but once we do ensure that208// we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.209assert!(210u32::try_from(blocks.len()).is_ok(),211"Jump tables bigger than 2^32-1 are not yet supported"212);213214let jt_data = JumpTableData::new(215bx.func.dfg.block_call(otherwise, &[]),216&blocks217.iter()218.map(|block| bx.func.dfg.block_call(*block, &[]))219.collect::<Vec<_>>(),220);221let jump_table = bx.create_jump_table(jt_data);222223let discr = if first_index == 0 {224val225} else {226if let Ok(first_index) = u64::try_from(first_index) {227bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())228} else {229let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);230let lsb = bx.ins().iconst(types::I64, lsb as i64);231let msb = bx.ins().iconst(types::I64, msb as i64);232let index = bx.ins().iconcat(lsb, msb);233bx.ins().isub(val, index)234}235};236237let discr = match bx.func.dfg.value_type(discr).bits() {238bits if bits > 32 => {239// Check for overflow of cast to u32. 
This is the max supported jump table entries.240let new_block = bx.create_block();241let bigger_than_u32 =242bx.ins()243.icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);244bx.ins()245.brif(bigger_than_u32, otherwise, &[], new_block, &[]);246bx.seal_block(new_block);247bx.switch_to_block(new_block);248249// Cast to i32, as br_table is not implemented for i64/i128250bx.ins().ireduce(types::I32, discr)251}252bits if bits < 32 => bx.ins().uextend(types::I32, discr),253_ => discr,254};255256bx.ins().br_table(discr, jump_table);257}258259/// Build the switch260///261/// # Arguments262///263/// * The function builder to emit to264/// * The value to switch on265/// * The default block266pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {267// Validate that the type of `val` is sufficiently wide to address all cases.268let max = self.cases.keys().max().copied().unwrap_or(0);269let val_ty = bx.func.dfg.value_type(val);270let val_ty_max = val_ty.bounds(false).1;271if max > val_ty_max {272panic!("The index type {val_ty} does not fit the maximum switch entry of {max}");273}274275let contiguous_case_ranges = self.collect_contiguous_case_ranges();276Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);277}278}279280fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {281if bx.func.dfg.value_type(x) != types::I128 {282assert!(u64::try_from(y).is_ok());283bx.ins().icmp_imm(cond, x, y as i64)284} else if let Ok(index) = i64::try_from(y) {285bx.ins().icmp_imm(cond, x, index)286} else {287let (lsb, msb) = (y as u64, (y >> 64) as u64);288let lsb = bx.ins().iconst(types::I64, lsb as i64);289let msb = bx.ins().iconst(types::I64, msb as i64);290let index = bx.ins().iconcat(lsb, msb);291bx.ins().icmp(cond, x, index)292}293}294295/// This represents a contiguous range of cases to switch on.296///297/// For example 10 => block1, 11 => block2, 12 => block7 will be represented as:298///299/// ```plain300/// 
ContiguousCaseRange {301/// first_index: 10,302/// blocks: vec![Block::from_u32(1), Block::from_u32(2), Block::from_u32(7)]303/// }304/// ```305#[derive(Debug)]306struct ContiguousCaseRange {307/// The entry index of the first case. Eg. 10 when the entry indexes are 10, 11, 12 and 13.308first_index: EntryIndex,309310/// The blocks to jump to sorted in ascending order of entry index.311blocks: Vec<Block>,312}313314impl ContiguousCaseRange {315fn new(first_index: EntryIndex) -> Self {316Self {317first_index,318blocks: Vec::new(),319}320}321322/// Returns `Some` block when there is only a single block in this range.323fn single_block(&self) -> Option<Block> {324if self.blocks.len() == 1 {325Some(self.blocks[0])326} else {327None328}329}330}331332#[cfg(test)]333mod tests {334use super::*;335use crate::frontend::FunctionBuilderContext;336use alloc::string::ToString;337338macro_rules! setup {339($default:expr, [$($index:expr,)*]) => {{340let mut func = Function::new();341let mut func_ctx = FunctionBuilderContext::new();342{343let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);344let block = bx.create_block();345bx.switch_to_block(block);346let val = bx.ins().iconst(types::I8, 0);347let mut switch = Switch::new();348let _ = &mut switch;349$(350let block = bx.create_block();351switch.set_entry($index, block);352)*353switch.emit(&mut bx, val, Block::with_number($default).unwrap());354}355func356.to_string()357.trim_start_matches("function u0:0() fast {\n")358.trim_end_matches("\n}\n")359.to_string()360}};361}362363#[test]364fn switch_empty() {365let func = setup!(42, []);366assert_eq_output!(367func,368"block0:369v0 = iconst.i8 0370jump block42"371);372}373374#[test]375fn switch_zero() {376let func = setup!(0, [0,]);377assert_eq_output!(378func,379"block0:380v0 = iconst.i8 0381brif v0, block0, block1 ; v0 = 0"382);383}384385#[test]386fn switch_single() {387let func = setup!(0, [1,]);388assert_eq_output!(389func,390"block0:391v0 = iconst.i8 0392v1 = icmp_imm eq v0, 
1 ; v0 = 0393brif v1, block1, block0"394);395}396397#[test]398fn switch_bool() {399let func = setup!(0, [0, 1,]);400assert_eq_output!(401func,402"block0:403v0 = iconst.i8 0404v1 = uextend.i32 v0 ; v0 = 0405br_table v1, block0, [block1, block2]"406);407}408409#[test]410fn switch_two_gap() {411let func = setup!(0, [0, 2,]);412assert_eq_output!(413func,414"block0:415v0 = iconst.i8 0416v1 = icmp_imm eq v0, 2 ; v0 = 0417brif v1, block2, block3418419block3:420brif.i8 v0, block0, block1 ; v0 = 0"421);422}423424#[test]425fn switch_many() {426let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);427assert_eq_output!(428func,429"block0:430v0 = iconst.i8 0431v1 = icmp_imm uge v0, 7 ; v0 = 0432brif v1, block9, block8433434block9:435v2 = icmp_imm.i8 uge v0, 10 ; v0 = 0436brif v2, block11, block10437438block11:439v3 = iadd_imm.i8 v0, -10 ; v0 = 0440v4 = uextend.i32 v3441br_table v4, block0, [block5, block6, block7]442443block10:444v5 = icmp_imm.i8 eq v0, 7 ; v0 = 0445brif v5, block4, block0446447block8:448v6 = icmp_imm.i8 eq v0, 5 ; v0 = 0449brif v6, block3, block12450451block12:452v7 = uextend.i32 v0 ; v0 = 0453br_table v7, block0, [block1, block2]"454);455}456457#[test]458fn switch_min_index_value() {459let func = setup!(0, [i8::MIN as u8 as u128, 1,]);460assert_eq_output!(461func,462"block0:463v0 = iconst.i8 0464v1 = icmp_imm eq v0, -128 ; v0 = 0465brif v1, block1, block3466467block3:468v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0469brif v2, block2, block0"470);471}472473#[test]474fn switch_max_index_value() {475let func = setup!(0, [i8::MAX as u8 as u128, 1,]);476assert_eq_output!(477func,478"block0:479v0 = iconst.i8 0480v1 = icmp_imm eq v0, 127 ; v0 = 0481brif v1, block1, block3482483block3:484v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0485brif v2, block2, block0"486)487}488489#[test]490fn switch_optimal_codegen() {491let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);492assert_eq_output!(493func,494"block0:495v0 = iconst.i8 0496v1 = icmp_imm eq v0, -1 ; v0 = 0497brif v1, block1, 
block4498499block4:500v2 = uextend.i32 v0 ; v0 = 0501br_table v2, block0, [block2, block3]"502);503}504505#[test]506#[should_panic(507expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"508)]509fn switch_rejects_small_inputs() {510// This is a regression test for a bug that we found where we would emit a cmp511// with a type that was not able to fully represent a large index.512//513// See: https://github.com/bytecodealliance/wasmtime/pull/4502#issuecomment-1191961677514setup!(1, [0x4100_0000_00bf_d470,]);515}516517#[test]518fn switch_seal_generated_blocks() {519let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];520521for case in cases {522for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {523eprintln!("Testing {typ:?} with keys: {case:?}");524do_case(case, *typ);525}526}527528fn do_case(keys: &[u128], typ: Type) {529let mut func = Function::new();530let mut builder_ctx = FunctionBuilderContext::new();531let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);532533let root_block = builder.create_block();534let default_block = builder.create_block();535let mut switch = Switch::new();536537let case_blocks = keys538.iter()539.map(|key| {540let block = builder.create_block();541switch.set_entry(*key, block);542block543})544.collect::<Vec<_>>();545546builder.seal_block(root_block);547builder.switch_to_block(root_block);548549let val = builder.ins().iconst(typ, 1);550switch.emit(&mut builder, val, default_block);551552for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {553builder.seal_block(block);554builder.switch_to_block(block);555builder.ins().return_(&[]);556}557558builder.finalize(); // Will panic if some blocks are not sealed559}560}561562#[test]563fn switch_64bit() {564let mut func = Function::new();565let mut func_ctx = FunctionBuilderContext::new();566{567let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);568let block0 = 
bx.create_block();569bx.switch_to_block(block0);570let val = bx.ins().iconst(types::I64, 0);571let mut switch = Switch::new();572let block1 = bx.create_block();573switch.set_entry(1, block1);574let block2 = bx.create_block();575switch.set_entry(0, block2);576let block3 = bx.create_block();577switch.emit(&mut bx, val, block3);578}579let func = func580.to_string()581.trim_start_matches("function u0:0() fast {\n")582.trim_end_matches("\n}\n")583.to_string();584assert_eq_output!(585func,586"block0:587v0 = iconst.i64 0588v1 = icmp_imm ugt v0, 0xffff_ffff ; v0 = 0589brif v1, block3, block4590591block4:592v2 = ireduce.i32 v0 ; v0 = 0593br_table v2, block3, [block2, block1]"594);595}596597#[test]598fn switch_128bit() {599let mut func = Function::new();600let mut func_ctx = FunctionBuilderContext::new();601{602let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);603let block0 = bx.create_block();604bx.switch_to_block(block0);605let val = bx.ins().iconst(types::I64, 0);606let val = bx.ins().uextend(types::I128, val);607let mut switch = Switch::new();608let block1 = bx.create_block();609switch.set_entry(1, block1);610let block2 = bx.create_block();611switch.set_entry(0, block2);612let block3 = bx.create_block();613switch.emit(&mut bx, val, block3);614}615let func = func616.to_string()617.trim_start_matches("function u0:0() fast {\n")618.trim_end_matches("\n}\n")619.to_string();620assert_eq_output!(621func,622"block0:623v0 = iconst.i64 0624v1 = uextend.i128 v0 ; v0 = 0625v2 = icmp_imm ugt v1, 0xffff_ffff626brif v2, block3, block4627628block4:629v3 = ireduce.i32 v1630br_table v3, block3, [block2, block1]"631);632}633634#[test]635fn switch_128bit_max_u64() {636let mut func = Function::new();637let mut func_ctx = FunctionBuilderContext::new();638{639let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);640let block0 = bx.create_block();641bx.switch_to_block(block0);642let val = bx.ins().iconst(types::I64, 0);643let val = bx.ins().uextend(types::I128, val);644let mut 
switch = Switch::new();645let block1 = bx.create_block();646switch.set_entry(u64::MAX.into(), block1);647let block2 = bx.create_block();648switch.set_entry(0, block2);649let block3 = bx.create_block();650switch.emit(&mut bx, val, block3);651}652let func = func653.to_string()654.trim_start_matches("function u0:0() fast {\n")655.trim_end_matches("\n}\n")656.to_string();657assert_eq_output!(658func,659"block0:660v0 = iconst.i64 0661v1 = uextend.i128 v0 ; v0 = 0662v2 = iconst.i64 -1663v3 = iconst.i64 0664v4 = iconcat v2, v3 ; v2 = -1, v3 = 0665v5 = icmp eq v1, v4666brif v5, block1, block4667668block4:669brif.i128 v1, block3, block2"670);671}672}673674675