Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/cranelift/frontend/src/switch.rs
1691 views
1
use super::HashMap;
2
use crate::frontend::FunctionBuilder;
3
use alloc::vec::Vec;
4
use cranelift_codegen::ir::condcodes::IntCC;
5
use cranelift_codegen::ir::*;
6
7
/// Index type for switch entries. It is `u128`, so entry indices can cover the full unsigned
/// range of every integer type up to 128 bits wide.
type EntryIndex = u128;
8
9
/// Unlike with `br_table`, `Switch` cases may be sparse or non-0-based.
10
/// They emit efficient code using branches, jump tables, or a combination of both.
11
///
12
/// # Example
13
///
14
/// ```rust
15
/// # use cranelift_codegen::ir::types::*;
16
/// # use cranelift_codegen::ir::{UserFuncName, Function, Signature, InstBuilder};
17
/// # use cranelift_codegen::isa::CallConv;
18
/// # use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Switch};
19
/// #
20
/// # let mut sig = Signature::new(CallConv::SystemV);
21
/// # let mut fn_builder_ctx = FunctionBuilderContext::new();
22
/// # let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);
23
/// # let mut builder = FunctionBuilder::new(&mut func, &mut fn_builder_ctx);
24
/// #
25
/// # let entry = builder.create_block();
26
/// # builder.switch_to_block(entry);
27
/// #
28
/// let block0 = builder.create_block();
29
/// let block1 = builder.create_block();
30
/// let block2 = builder.create_block();
31
/// let fallback = builder.create_block();
32
///
33
/// let val = builder.ins().iconst(I32, 1);
34
///
35
/// let mut switch = Switch::new();
36
/// switch.set_entry(0, block0);
37
/// switch.set_entry(1, block1);
38
/// switch.set_entry(7, block2);
39
/// switch.emit(&mut builder, val, fallback);
40
/// ```
41
#[derive(Debug, Default)]
42
pub struct Switch {
43
cases: HashMap<EntryIndex, Block>,
44
}
45
46
impl Switch {
47
/// Create a new empty switch
48
pub fn new() -> Self {
49
Self {
50
cases: HashMap::new(),
51
}
52
}
53
54
/// Set a switch entry
55
pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
56
let prev = self.cases.insert(index, block);
57
assert!(prev.is_none(), "Tried to set the same entry {index} twice");
58
}
59
60
/// Get a reference to all existing entries
61
pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
62
&self.cases
63
}
64
65
/// Turn the `cases` `HashMap` into a list of `ContiguousCaseRange`s.
66
///
67
/// # Postconditions
68
///
69
/// * Every entry will be represented.
70
/// * The `ContiguousCaseRange`s will not overlap.
71
/// * Between two `ContiguousCaseRange`s there will be at least one entry index.
72
/// * No `ContiguousCaseRange`s will be empty.
73
fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
74
log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
75
let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
76
cases.sort_by_key(|&(index, _)| index);
77
78
let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
79
let mut last_index = None;
80
for (index, block) in cases {
81
match last_index {
82
None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
83
Some(last_index) => {
84
if index > last_index + 1 {
85
contiguous_case_ranges.push(ContiguousCaseRange::new(index));
86
}
87
}
88
}
89
contiguous_case_ranges
90
.last_mut()
91
.unwrap()
92
.blocks
93
.push(block);
94
last_index = Some(index);
95
}
96
97
log::trace!("build_contiguous_case_ranges after: {contiguous_case_ranges:#?}");
98
99
contiguous_case_ranges
100
}
101
102
/// Binary search for the right `ContiguousCaseRange`.
103
fn build_search_tree<'a>(
104
bx: &mut FunctionBuilder,
105
val: Value,
106
otherwise: Block,
107
contiguous_case_ranges: &'a [ContiguousCaseRange],
108
) {
109
// If no switch cases were added to begin with, we can just emit `jump otherwise`.
110
if contiguous_case_ranges.is_empty() {
111
bx.ins().jump(otherwise, &[]);
112
return;
113
}
114
115
// Avoid allocation in the common case
116
if contiguous_case_ranges.len() <= 3 {
117
Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
118
return;
119
}
120
121
let mut stack = Vec::new();
122
stack.push((None, contiguous_case_ranges));
123
124
while let Some((block, contiguous_case_ranges)) = stack.pop() {
125
if let Some(block) = block {
126
bx.switch_to_block(block);
127
}
128
129
if contiguous_case_ranges.len() <= 3 {
130
Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
131
} else {
132
let split_point = contiguous_case_ranges.len() / 2;
133
let (left, right) = contiguous_case_ranges.split_at(split_point);
134
135
let left_block = bx.create_block();
136
let right_block = bx.create_block();
137
138
let first_index = right[0].first_index;
139
let should_take_right_side =
140
icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
141
bx.ins()
142
.brif(should_take_right_side, right_block, &[], left_block, &[]);
143
144
bx.seal_block(left_block);
145
bx.seal_block(right_block);
146
147
stack.push((Some(left_block), left));
148
stack.push((Some(right_block), right));
149
}
150
}
151
}
152
153
/// Linear search for the right `ContiguousCaseRange`.
154
fn build_search_branches<'a>(
155
bx: &mut FunctionBuilder,
156
val: Value,
157
otherwise: Block,
158
contiguous_case_ranges: &'a [ContiguousCaseRange],
159
) {
160
for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
161
let alternate = if ix == 0 {
162
otherwise
163
} else {
164
bx.create_block()
165
};
166
167
if range.first_index == 0 {
168
assert_eq!(alternate, otherwise);
169
170
if let Some(block) = range.single_block() {
171
bx.ins().brif(val, otherwise, &[], block, &[]);
172
} else {
173
Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
174
}
175
} else {
176
if let Some(block) = range.single_block() {
177
let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
178
bx.ins().brif(is_good_val, block, &[], alternate, &[]);
179
} else {
180
let is_good_val = icmp_imm_u128(
181
bx,
182
IntCC::UnsignedGreaterThanOrEqual,
183
val,
184
range.first_index,
185
);
186
let jt_block = bx.create_block();
187
bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
188
bx.seal_block(jt_block);
189
bx.switch_to_block(jt_block);
190
Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
191
}
192
}
193
194
if alternate != otherwise {
195
bx.seal_block(alternate);
196
bx.switch_to_block(alternate);
197
}
198
}
199
}
200
201
fn build_jump_table(
202
bx: &mut FunctionBuilder,
203
val: Value,
204
otherwise: Block,
205
first_index: EntryIndex,
206
blocks: &[Block],
207
) {
208
// There are currently no 128bit systems supported by rustc, but once we do ensure that
209
// we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
210
assert!(
211
u32::try_from(blocks.len()).is_ok(),
212
"Jump tables bigger than 2^32-1 are not yet supported"
213
);
214
215
let jt_data = JumpTableData::new(
216
bx.func.dfg.block_call(otherwise, &[]),
217
&blocks
218
.iter()
219
.map(|block| bx.func.dfg.block_call(*block, &[]))
220
.collect::<Vec<_>>(),
221
);
222
let jump_table = bx.create_jump_table(jt_data);
223
224
let discr = if first_index == 0 {
225
val
226
} else {
227
if let Ok(first_index) = u64::try_from(first_index) {
228
bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
229
} else {
230
let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
231
let lsb = bx.ins().iconst(types::I64, lsb as i64);
232
let msb = bx.ins().iconst(types::I64, msb as i64);
233
let index = bx.ins().iconcat(lsb, msb);
234
bx.ins().isub(val, index)
235
}
236
};
237
238
let discr = match bx.func.dfg.value_type(discr).bits() {
239
bits if bits > 32 => {
240
// Check for overflow of cast to u32. This is the max supported jump table entries.
241
let new_block = bx.create_block();
242
let bigger_than_u32 =
243
bx.ins()
244
.icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
245
bx.ins()
246
.brif(bigger_than_u32, otherwise, &[], new_block, &[]);
247
bx.seal_block(new_block);
248
bx.switch_to_block(new_block);
249
250
// Cast to i32, as br_table is not implemented for i64/i128
251
bx.ins().ireduce(types::I32, discr)
252
}
253
bits if bits < 32 => bx.ins().uextend(types::I32, discr),
254
_ => discr,
255
};
256
257
bx.ins().br_table(discr, jump_table);
258
}
259
260
/// Build the switch
261
///
262
/// # Arguments
263
///
264
/// * The function builder to emit to
265
/// * The value to switch on
266
/// * The default block
267
pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
268
// Validate that the type of `val` is sufficiently wide to address all cases.
269
let max = self.cases.keys().max().copied().unwrap_or(0);
270
let val_ty = bx.func.dfg.value_type(val);
271
let val_ty_max = val_ty.bounds(false).1;
272
if max > val_ty_max {
273
panic!("The index type {val_ty} does not fit the maximum switch entry of {max}");
274
}
275
276
let contiguous_case_ranges = self.collect_contiguous_case_ranges();
277
Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
278
}
279
}
280
281
fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
282
if bx.func.dfg.value_type(x) != types::I128 {
283
assert!(u64::try_from(y).is_ok());
284
bx.ins().icmp_imm(cond, x, y as i64)
285
} else if let Ok(index) = i64::try_from(y) {
286
bx.ins().icmp_imm(cond, x, index)
287
} else {
288
let (lsb, msb) = (y as u64, (y >> 64) as u64);
289
let lsb = bx.ins().iconst(types::I64, lsb as i64);
290
let msb = bx.ins().iconst(types::I64, msb as i64);
291
let index = bx.ins().iconcat(lsb, msb);
292
bx.ins().icmp(cond, x, index)
293
}
294
}
295
296
/// This represents a contiguous range of cases to switch on.
297
///
298
/// For example 10 => block1, 11 => block2, 12 => block7 will be represented as:
299
///
300
/// ```plain
301
/// ContiguousCaseRange {
302
/// first_index: 10,
303
/// blocks: vec![Block::from_u32(1), Block::from_u32(2), Block::from_u32(7)]
304
/// }
305
/// ```
306
#[derive(Debug)]
307
struct ContiguousCaseRange {
308
/// The entry index of the first case. Eg. 10 when the entry indexes are 10, 11, 12 and 13.
309
first_index: EntryIndex,
310
311
/// The blocks to jump to sorted in ascending order of entry index.
312
blocks: Vec<Block>,
313
}
314
315
impl ContiguousCaseRange {
316
fn new(first_index: EntryIndex) -> Self {
317
Self {
318
first_index,
319
blocks: Vec::new(),
320
}
321
}
322
323
/// Returns `Some` block when there is only a single block in this range.
324
fn single_block(&self) -> Option<Block> {
325
if self.blocks.len() == 1 {
326
Some(self.blocks[0])
327
} else {
328
None
329
}
330
}
331
}
332
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::FunctionBuilderContext;
    use alloc::string::ToString;

    // NOTE(review): the expected CLIF strings in this module were recovered from a copy whose
    // leading whitespace had been stripped. Instruction lines are indented by four spaces here
    // to match the function printer; confirm against actual output.

    // Builds a function whose entry block (`block0`) switches on the constant `iconst.i8 0`.
    // One fresh block is created and registered per listed index, so the first case gets
    // `block1`, the second `block2`, and so on. `block$default` is the fallback. Returns the
    // printed function with its `function u0:0() fast {` header and closing brace trimmed.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                let mut switch = Switch::new();
                // Presumably keeps `mut switch` used when the index list is empty.
                let _ = &mut switch;
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }

    // With no cases, the switch is a plain jump to the fallback block.
    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }

    // A single case at index 0 becomes a `brif` on the value itself: nonzero -> fallback.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1 ; v0 = 0"
        );
    }

    // A single non-zero case becomes an equality test.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1 ; v0 = 0
    brif v1, block1, block0"
        );
    }

    // Cases 0 and 1 form one range starting at 0: a `br_table` with no rebasing, after
    // zero-extending the `i8` value to `i32`.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0 ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }

    // Two non-contiguous cases are tested linearly, highest range first.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2 ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1 ; v0 = 0"
        );
    }

    // Four ranges (0-1, 5, 7, 10-12) exceed the linear-search limit of three, so the ranges
    // are first split in half by a `uge 7` test.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7 ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10 ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10 ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7 ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5 ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0 ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }

    // Index 128 (the bits of `i8::MIN`) is printed as the immediate -128.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -128 ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
    brif v2, block2, block0"
        );
    }

    // Index 127 (`i8::MAX`) is printed as the immediate 127.
    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127 ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
    brif v2, block2, block0"
        )
    }

    // Index 255 (the bits of `-1i8`) is tested on its own, while 0 and 1 share a jump table.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -1 ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0 ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }

    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        // This is a regression test for a bug that we found where we would emit a cmp
        // with a type that was not able to fully represent a large index.
        //
        // See: https://github.com/bytecodealliance/wasmtime/pull/4502#issuecomment-1191961677
        setup!(1, [0x4100_0000_00bf_d470,]);
    }

    // Checks that every block created by `emit` gets sealed, for several key sets and every
    // integer width (`finalize` panics on unsealed blocks).
    #[test]
    fn switch_seal_generated_blocks() {
        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];

        for case in cases {
            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
                eprintln!("Testing {typ:?} with keys: {case:?}");
                do_case(case, *typ);
            }
        }

        // Emits a switch over `keys` on a `typ` constant, terminates every case block and the
        // default block with `return`, then finalizes the function.
        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let case_blocks = keys
                .iter()
                .map(|key| {
                    let block = builder.create_block();
                    switch.set_entry(*key, block);
                    block
                })
                .collect::<Vec<_>>();

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            builder.finalize(); // Will panic if some blocks are not sealed
        }
    }

    // An `i64` discriminant is checked against `u32::MAX` and reduced to `i32` before the
    // `br_table`.
    #[test]
    fn switch_64bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0 ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }

    // Same as `switch_64bit`, but with an `i128` discriminant.
    #[test]
    fn switch_128bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0 ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }

    // The index `u64::MAX` does not fit in an `i64` immediate. For an `i128` operand, the
    // constant is therefore built with `iconcat` and compared with `icmp`.
    #[test]
    fn switch_128bit_max_u64() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(u64::MAX.into(), block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0 ; v0 = 0
    v2 = iconst.i64 -1
    v3 = iconst.i64 0
    v4 = iconcat v2, v3 ; v2 = -1, v3 = 0
    v5 = icmp eq v1, v4
    brif v5, block1, block4

block4:
    brif.i128 v1, block3, block2"
        );
    }
}
674
675