Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/sys/contrib/zstd/doc/educational_decoder/zstd_decompress.c
48378 views
1
/*
2
* Copyright (c) Facebook, Inc.
3
* All rights reserved.
4
*
5
* This source code is licensed under both the BSD-style license (found in the
6
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
7
* in the COPYING file in the root directory of this source tree).
8
* You may select, at your option, one of the above-listed licenses.
9
*/
10
11
/// Zstandard educational decoder implementation
12
/// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
13
14
#include <stdint.h> // uint8_t, etc.
15
#include <stdlib.h> // malloc, free, exit
16
#include <stdio.h> // fprintf
17
#include <string.h> // memset, memcpy
18
#include "zstd_decompress.h"
19
20
21
/******* IMPORTANT CONSTANTS *********************************************/

// Zstandard frame
// "Magic_Number
// 4 Bytes, little-endian format. Value : 0xFD2FB528"
#define ZSTD_MAGIC_NUMBER 0xFD2FB528U

// The size of `Block_Content` is limited by `Block_Maximum_Size`; this constant
// is the 128 KB upper bound on that limit.
#define ZSTD_BLOCK_SIZE_MAX ((size_t)128 * 1024)

// literal blocks can't be larger than their block
#define MAX_LITERALS_SIZE ZSTD_BLOCK_SIZE_MAX

/******* UTILITY MACROS AND TYPES *********************************************/
// NOTE: MAX and MIN evaluate their arguments more than once, so only pass
// expressions without side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Diagnostics go to stderr unless ZDEC_NO_MESSAGE is defined. The leading ""
// concatenates with the first argument, so that argument must be a string
// literal (the format string).
#if defined(ZDEC_NO_MESSAGE)
#define MESSAGE(...)
#else
#define MESSAGE(...) fprintf(stderr, "" __VA_ARGS__)
#endif

/// This decoder calls exit(1) when it encounters an error, however a production
/// library should propagate error codes
#define ERROR(s)                                                               \
    do {                                                                       \
        MESSAGE("Error: %s\n", s);                                             \
        exit(1);                                                               \
    } while (0)
#define INP_SIZE()                                                             \
    ERROR("Input buffer smaller than it should be or input is "                \
          "corrupted")
#define OUT_SIZE() ERROR("Output buffer too small for output")
#define CORRUPTION() ERROR("Corruption detected while decompressing")
#define BAD_ALLOC() ERROR("Memory allocation error")
#define IMPOSSIBLE() ERROR("An impossibility has occurred")

// Short fixed-width aliases used throughout the decoder
typedef uint8_t u8;
typedef uint16_t u16;
typedef uint32_t u32;
typedef uint64_t u64;

typedef int8_t i8;
typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
/******* END UTILITY MACROS AND TYPES *****************************************/

/******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
72
/// The implementations for these functions can be found at the bottom of this
73
/// file. They implement low-level functionality needed for the higher level
74
/// decompression functions.
75
76
/*** IO STREAM OPERATIONS *************/
77
78
/// ostream_t/istream_t are used to wrap the pointers/length data passed into
79
/// ZSTD_decompress, so that all IO operations are safely bounds checked
80
/// They are written/read forward, and reads are treated as little-endian
81
/// They should be used opaquely to ensure safety
82
typedef struct {
83
u8 *ptr;
84
size_t len;
85
} ostream_t;
86
87
typedef struct {
88
const u8 *ptr;
89
size_t len;
90
91
// Input often reads a few bits at a time, so maintain an internal offset
92
int bit_offset;
93
} istream_t;
94
95
/// The following two functions are the only ones that allow the istream to be
96
/// non-byte aligned
97
98
/// Reads `num` bits from a bitstream, and updates the internal offset
99
static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
100
/// Backs-up the stream by `num` bits so they can be read again
101
static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
102
/// If the remaining bits in a byte will be unused, advance to the end of the
103
/// byte
104
static inline void IO_align_stream(istream_t *const in);
105
106
/// Write the given byte into the output stream
107
static inline void IO_write_byte(ostream_t *const out, u8 symb);
108
109
/// Returns the number of bytes left to be read in this stream. The stream must
110
/// be byte aligned.
111
static inline size_t IO_istream_len(const istream_t *const in);
112
113
/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
114
/// was skipped. The stream must be byte aligned.
115
static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
116
/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
117
/// was skipped so it can be written to.
118
static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
119
120
/// Advance the inner state by `len` bytes. The stream must be byte aligned.
121
static inline void IO_advance_input(istream_t *const in, size_t len);
122
123
/// Returns an `ostream_t` constructed from the given pointer and length.
124
static inline ostream_t IO_make_ostream(u8 *out, size_t len);
125
/// Returns an `istream_t` constructed from the given pointer and length.
126
static inline istream_t IO_make_istream(const u8 *in, size_t len);
127
128
/// Returns an `istream_t` with the same base as `in`, and length `len`.
129
/// Then, advance `in` to account for the consumed bytes.
130
/// `in` must be byte aligned.
131
static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
132
/*** END IO STREAM OPERATIONS *********/
133
134
/*** BITSTREAM OPERATIONS *************/
135
/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits,
136
/// and return them interpreted as a little-endian unsigned integer.
137
static inline u64 read_bits_LE(const u8 *src, const int num_bits,
138
const size_t offset);
139
140
/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so
141
/// it updates `offset` to `offset - bits`, and then reads `bits` bits from
142
/// `src + offset`. If the offset becomes negative, the extra bits at the
143
/// bottom are filled in with `0` bits instead of reading from before `src`.
144
static inline u64 STREAM_read_bits(const u8 *src, const int bits,
145
i64 *const offset);
146
/*** END BITSTREAM OPERATIONS *********/
147
148
/*** BIT COUNTING OPERATIONS **********/
149
/// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
150
static inline int highest_set_bit(const u64 num);
151
/*** END BIT COUNTING OPERATIONS ******/
152
153
/*** HUFFMAN PRIMITIVES ***************/
154
// Table decode method uses exponential memory, so we need to limit depth
155
#define HUF_MAX_BITS (16)
156
157
// Limit the maximum number of symbols to 256 so we can store a symbol in a byte
158
#define HUF_MAX_SYMBS (256)
159
160
/// Structure containing all tables necessary for efficient Huffman decoding
161
typedef struct {
162
u8 *symbols;
163
u8 *num_bits;
164
int max_bits;
165
} HUF_dtable;
166
167
/// Decode a single symbol and read in enough bits to refresh the state
168
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
169
u16 *const state, const u8 *const src,
170
i64 *const offset);
171
/// Read in a full state's worth of bits to initialize it
172
static inline void HUF_init_state(const HUF_dtable *const dtable,
173
u16 *const state, const u8 *const src,
174
i64 *const offset);
175
176
/// Decompresses a single Huffman stream, returns the number of bytes decoded.
177
/// `src_len` must be the exact length of the Huffman-coded block.
178
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
179
ostream_t *const out, istream_t *const in);
180
/// Same as previous but decodes 4 streams, formatted as in the Zstandard
181
/// specification.
182
/// `src_len` must be the exact length of the Huffman-coded block.
183
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
184
ostream_t *const out, istream_t *const in);
185
186
/// Initialize a Huffman decoding table using the table of bit counts provided
187
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
188
const int num_symbs);
189
/// Initialize a Huffman decoding table using the table of weights provided
190
/// Weights follow the definition provided in the Zstandard specification
191
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
192
const u8 *const weights,
193
const int num_symbs);
194
195
/// Free the malloc'ed parts of a decoding table
196
static void HUF_free_dtable(HUF_dtable *const dtable);
197
/*** END HUFFMAN PRIMITIVES ***********/
198
199
/*** FSE PRIMITIVES *******************/
200
/// For more description of FSE see
201
/// https://github.com/Cyan4973/FiniteStateEntropy/
202
203
// FSE table decoding uses exponential memory, so limit the maximum accuracy
204
#define FSE_MAX_ACCURACY_LOG (15)
205
// Limit the maximum number of symbols so they can be stored in a single byte
206
#define FSE_MAX_SYMBS (256)
207
208
/// The tables needed to decode FSE encoded streams
209
typedef struct {
210
u8 *symbols;
211
u8 *num_bits;
212
u16 *new_state_base;
213
int accuracy_log;
214
} FSE_dtable;
215
216
/// Return the symbol for the current state
217
static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
218
const u16 state);
219
/// Read the number of bits necessary to update state, update, and shift offset
220
/// back to reflect the bits read
221
static inline void FSE_update_state(const FSE_dtable *const dtable,
222
u16 *const state, const u8 *const src,
223
i64 *const offset);
224
225
/// Combine peek and update: decode a symbol and update the state
226
static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
227
u16 *const state, const u8 *const src,
228
i64 *const offset);
229
230
/// Read bits from the stream to initialize the state and shift offset back
231
static inline void FSE_init_state(const FSE_dtable *const dtable,
232
u16 *const state, const u8 *const src,
233
i64 *const offset);
234
235
/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
236
/// using an FSE decoding table. `src_len` must be the exact length of the
237
/// block.
238
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
239
ostream_t *const out,
240
istream_t *const in);
241
242
/// Initialize a decoding table using normalized frequencies.
243
static void FSE_init_dtable(FSE_dtable *const dtable,
244
const i16 *const norm_freqs, const int num_symbs,
245
const int accuracy_log);
246
247
/// Decode an FSE header as defined in the Zstandard format specification and
248
/// use the decoded frequencies to initialize a decoding table.
249
static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
250
const int max_accuracy_log);
251
252
/// Initialize an FSE table that will always return the same symbol and consume
253
/// 0 bits per symbol, to be used for RLE mode in sequence commands
254
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
255
256
/// Free the malloc'ed parts of a decoding table
257
static void FSE_free_dtable(FSE_dtable *const dtable);
258
/*** END FSE PRIMITIVES ***************/
259
260
/******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
261
262
/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
263
264
/// A small structure that can be reused in various places that need to access
265
/// frame header information
266
typedef struct {
267
// The size of window that we need to be able to contiguously store for
268
// references
269
size_t window_size;
270
// The total output size of this compressed frame
271
size_t frame_content_size;
272
273
// The dictionary id if this frame uses one
274
u32 dictionary_id;
275
276
// Whether or not the content of this frame has a checksum
277
int content_checksum_flag;
278
// Whether or not the output for this frame is in a single segment
279
int single_segment_flag;
280
} frame_header_t;
281
282
/// The context needed to decode blocks in a frame
283
typedef struct {
284
frame_header_t header;
285
286
// The total amount of data available for backreferences, to determine if an
287
// offset too large to be correct
288
size_t current_total_output;
289
290
const u8 *dict_content;
291
size_t dict_content_len;
292
293
// Entropy encoding tables so they can be repeated by future blocks instead
294
// of retransmitting
295
HUF_dtable literals_dtable;
296
FSE_dtable ll_dtable;
297
FSE_dtable ml_dtable;
298
FSE_dtable of_dtable;
299
300
// The last 3 offsets for the special "repeat offsets".
301
u64 previous_offsets[3];
302
} frame_context_t;
303
304
/// The decoded contents of a dictionary so that it doesn't have to be repeated
305
/// for each frame that uses it
306
struct dictionary_s {
307
// Entropy tables
308
HUF_dtable literals_dtable;
309
FSE_dtable ll_dtable;
310
FSE_dtable ml_dtable;
311
FSE_dtable of_dtable;
312
313
// Raw content for backreferences
314
u8 *content;
315
size_t content_size;
316
317
// Offset history to prepopulate the frame's history
318
u64 previous_offsets[3];
319
320
u32 dictionary_id;
321
};
322
323
/// A tuple containing the parts necessary to decode and execute a ZSTD sequence
324
/// command
325
typedef struct {
326
u32 literal_length;
327
u32 match_length;
328
u32 offset;
329
} sequence_command_t;
330
331
/// The decoder works top-down, starting at the high level like Zstd frames, and
332
/// working down to lower more technical levels such as blocks, literals, and
333
/// sequences. The high-level functions roughly follow the outline of the
334
/// format specification:
335
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
336
337
/// Before the implementation of each high-level function declared here, the
338
/// prototypes for their helper functions are defined and explained
339
340
/// Decode a single Zstd frame, or error if the input is not a valid frame.
341
/// Accepts a dict argument, which may be NULL indicating no dictionary.
342
/// See
343
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
344
static void decode_frame(ostream_t *const out, istream_t *const in,
345
const dictionary_t *const dict);
346
347
// Decode data in a compressed block
348
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
349
istream_t *const in);
350
351
// Decode the literals section of a block
352
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
353
u8 **const literals);
354
355
// Decode the sequences part of a block
356
static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
357
sequence_command_t **const sequences);
358
359
// Execute the decoded sequences on the literals block
360
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
361
const u8 *const literals,
362
const size_t literals_len,
363
const sequence_command_t *const sequences,
364
const size_t num_sequences);
365
366
// Copies literals and returns the total literal length that was copied
367
static u32 copy_literals(const size_t seq, istream_t *litstream,
368
ostream_t *const out);
369
370
// Given an offset code from a sequence command (either an actual offset value
371
// or an index for previous offset), computes the correct offset and updates
372
// the offset history
373
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
374
375
// Given an offset, match length, and total output, as well as the frame
376
// context for the dictionary, determines if the dictionary is used and
377
// executes the copy operation
378
static void execute_match_copy(frame_context_t *const ctx, size_t offset,
379
size_t match_length, size_t total_output,
380
ostream_t *const out);
381
382
/******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
383
384
size_t ZSTD_decompress(void *const dst, const size_t dst_len,
385
const void *const src, const size_t src_len) {
386
dictionary_t* const uninit_dict = create_dictionary();
387
size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
388
src_len, uninit_dict);
389
free_dictionary(uninit_dict);
390
return decomp_size;
391
}
392
393
size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
394
const void *const src, const size_t src_len,
395
dictionary_t* parsed_dict) {
396
397
istream_t in = IO_make_istream(src, src_len);
398
ostream_t out = IO_make_ostream(dst, dst_len);
399
400
// "A content compressed by Zstandard is transformed into a Zstandard frame.
401
// Multiple frames can be appended into a single file or stream. A frame is
402
// totally independent, has a defined beginning and end, and a set of
403
// parameters which tells the decoder how to decompress it."
404
405
/* this decoder assumes decompression of a single frame */
406
decode_frame(&out, &in, parsed_dict);
407
408
return (size_t)(out.ptr - (u8 *)dst);
409
}
410
411
/******* FRAME DECODING ******************************************************/
412
413
static void decode_data_frame(ostream_t *const out, istream_t *const in,
414
const dictionary_t *const dict);
415
static void init_frame_context(frame_context_t *const context,
416
istream_t *const in,
417
const dictionary_t *const dict);
418
static void free_frame_context(frame_context_t *const context);
419
static void parse_frame_header(frame_header_t *const header,
420
istream_t *const in);
421
static void frame_context_apply_dict(frame_context_t *const ctx,
422
const dictionary_t *const dict);
423
424
static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
425
istream_t *const in);
426
427
static void decode_frame(ostream_t *const out, istream_t *const in,
428
const dictionary_t *const dict) {
429
const u32 magic_number = (u32)IO_read_bits(in, 32);
430
if (magic_number == ZSTD_MAGIC_NUMBER) {
431
// ZSTD frame
432
decode_data_frame(out, in, dict);
433
434
return;
435
}
436
437
// not a real frame or a skippable frame
438
ERROR("Tried to decode non-ZSTD frame");
439
}
440
441
/// Decode a frame that contains compressed data. Not all frames do as there
442
/// are skippable frames.
443
/// See
444
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
445
static void decode_data_frame(ostream_t *const out, istream_t *const in,
446
const dictionary_t *const dict) {
447
frame_context_t ctx;
448
449
// Initialize the context that needs to be carried from block to block
450
init_frame_context(&ctx, in, dict);
451
452
if (ctx.header.frame_content_size != 0 &&
453
ctx.header.frame_content_size > out->len) {
454
OUT_SIZE();
455
}
456
457
decompress_data(&ctx, out, in);
458
459
free_frame_context(&ctx);
460
}
461
462
/// Takes the information provided in the header and dictionary, and initializes
463
/// the context for this frame
464
static void init_frame_context(frame_context_t *const context,
465
istream_t *const in,
466
const dictionary_t *const dict) {
467
// Most fields in context are correct when initialized to 0
468
memset(context, 0, sizeof(frame_context_t));
469
470
// Parse data from the frame header
471
parse_frame_header(&context->header, in);
472
473
// Set up the offset history for the repeat offset commands
474
context->previous_offsets[0] = 1;
475
context->previous_offsets[1] = 4;
476
context->previous_offsets[2] = 8;
477
478
// Apply details from the dict if it exists
479
frame_context_apply_dict(context, dict);
480
}
481
482
/// Release the malloc'ed parts of every decoding table held by `context`, then
/// zero the whole context.
static void free_frame_context(frame_context_t *const context) {
    FSE_dtable *const fse_tables[] = {
        &context->ll_dtable,
        &context->ml_dtable,
        &context->of_dtable,
    };

    HUF_free_dtable(&context->literals_dtable);
    for (size_t i = 0; i < sizeof(fse_tables) / sizeof(fse_tables[0]); i++) {
        FSE_free_dtable(fse_tables[i]);
    }

    memset(context, 0, sizeof(*context));
}

static void parse_frame_header(frame_header_t *const header,
493
istream_t *const in) {
494
// "The first header's byte is called the Frame_Header_Descriptor. It tells
495
// which other fields are present. Decoding this byte is enough to tell the
496
// size of Frame_Header.
497
//
498
// Bit number Field name
499
// 7-6 Frame_Content_Size_flag
500
// 5 Single_Segment_flag
501
// 4 Unused_bit
502
// 3 Reserved_bit
503
// 2 Content_Checksum_flag
504
// 1-0 Dictionary_ID_flag"
505
const u8 descriptor = (u8)IO_read_bits(in, 8);
506
507
// decode frame header descriptor into flags
508
const u8 frame_content_size_flag = descriptor >> 6;
509
const u8 single_segment_flag = (descriptor >> 5) & 1;
510
const u8 reserved_bit = (descriptor >> 3) & 1;
511
const u8 content_checksum_flag = (descriptor >> 2) & 1;
512
const u8 dictionary_id_flag = descriptor & 3;
513
514
if (reserved_bit != 0) {
515
CORRUPTION();
516
}
517
518
header->single_segment_flag = single_segment_flag;
519
header->content_checksum_flag = content_checksum_flag;
520
521
// decode window size
522
if (!single_segment_flag) {
523
// "Provides guarantees on maximum back-reference distance that will be
524
// used within compressed data. This information is important for
525
// decoders to allocate enough memory.
526
//
527
// Bit numbers 7-3 2-0
528
// Field name Exponent Mantissa"
529
u8 window_descriptor = (u8)IO_read_bits(in, 8);
530
u8 exponent = window_descriptor >> 3;
531
u8 mantissa = window_descriptor & 7;
532
533
// Use the algorithm from the specification to compute window size
534
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
535
size_t window_base = (size_t)1 << (10 + exponent);
536
size_t window_add = (window_base / 8) * mantissa;
537
header->window_size = window_base + window_add;
538
}
539
540
// decode dictionary id if it exists
541
if (dictionary_id_flag) {
542
// "This is a variable size field, which contains the ID of the
543
// dictionary required to properly decode the frame. Note that this
544
// field is optional. When it's not present, it's up to the caller to
545
// make sure it uses the correct dictionary. Format is little-endian."
546
const int bytes_array[] = {0, 1, 2, 4};
547
const int bytes = bytes_array[dictionary_id_flag];
548
549
header->dictionary_id = (u32)IO_read_bits(in, bytes * 8);
550
} else {
551
header->dictionary_id = 0;
552
}
553
554
// decode frame content size if it exists
555
if (single_segment_flag || frame_content_size_flag) {
556
// "This is the original (uncompressed) size. This information is
557
// optional. The Field_Size is provided according to value of
558
// Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
559
// present), 1, 2, 4 or 8 bytes. Format is little-endian."
560
//
561
// if frame_content_size_flag == 0 but single_segment_flag is set, we
562
// still have a 1 byte field
563
const int bytes_array[] = {1, 2, 4, 8};
564
const int bytes = bytes_array[frame_content_size_flag];
565
566
header->frame_content_size = IO_read_bits(in, bytes * 8);
567
if (bytes == 2) {
568
// "When Field_Size is 2, the offset of 256 is added."
569
header->frame_content_size += 256;
570
}
571
} else {
572
header->frame_content_size = 0;
573
}
574
575
if (single_segment_flag) {
576
// "The Window_Descriptor byte is optional. It is absent when
577
// Single_Segment_flag is set. In this case, the maximum back-reference
578
// distance is the content size itself, which can be any value from 1 to
579
// 2^64-1 bytes (16 EB)."
580
header->window_size = header->frame_content_size;
581
}
582
}
583
584
/// Decompress the data from a frame block by block
585
static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
586
istream_t *const in) {
587
// "A frame encapsulates one or multiple blocks. Each block can be
588
// compressed or not, and has a guaranteed maximum content size, which
589
// depends on frame parameters. Unlike frames, each block depends on
590
// previous blocks for proper decoding. However, each block can be
591
// decompressed without waiting for its successor, allowing streaming
592
// operations."
593
int last_block = 0;
594
do {
595
// "Last_Block
596
//
597
// The lowest bit signals if this block is the last one. Frame ends
598
// right after this block.
599
//
600
// Block_Type and Block_Size
601
//
602
// The next 2 bits represent the Block_Type, while the remaining 21 bits
603
// represent the Block_Size. Format is little-endian."
604
last_block = (int)IO_read_bits(in, 1);
605
const int block_type = (int)IO_read_bits(in, 2);
606
const size_t block_len = IO_read_bits(in, 21);
607
608
switch (block_type) {
609
case 0: {
610
// "Raw_Block - this is an uncompressed block. Block_Size is the
611
// number of bytes to read and copy."
612
const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
613
u8 *const write_ptr = IO_get_write_ptr(out, block_len);
614
615
// Copy the raw data into the output
616
memcpy(write_ptr, read_ptr, block_len);
617
618
ctx->current_total_output += block_len;
619
break;
620
}
621
case 1: {
622
// "RLE_Block - this is a single byte, repeated N times. In which
623
// case, Block_Size is the size to regenerate, while the
624
// "compressed" block is just 1 byte (the byte to repeat)."
625
const u8 *const read_ptr = IO_get_read_ptr(in, 1);
626
u8 *const write_ptr = IO_get_write_ptr(out, block_len);
627
628
// Copy `block_len` copies of `read_ptr[0]` to the output
629
memset(write_ptr, read_ptr[0], block_len);
630
631
ctx->current_total_output += block_len;
632
break;
633
}
634
case 2: {
635
// "Compressed_Block - this is a Zstandard compressed block,
636
// detailed in another section of this specification. Block_Size is
637
// the compressed size.
638
639
// Create a sub-stream for the block
640
istream_t block_stream = IO_make_sub_istream(in, block_len);
641
decompress_block(ctx, out, &block_stream);
642
break;
643
}
644
case 3:
645
// "Reserved - this is not a block. This value cannot be used with
646
// current version of this specification."
647
CORRUPTION();
648
break;
649
default:
650
IMPOSSIBLE();
651
}
652
} while (!last_block);
653
654
if (ctx->header.content_checksum_flag) {
655
// This program does not support checking the checksum, so skip over it
656
// if it's present
657
IO_advance_input(in, 4);
658
}
659
}
660
/******* END FRAME DECODING ***************************************************/
661
662
/******* BLOCK DECOMPRESSION **************************************************/
663
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
664
istream_t *const in) {
665
// "A compressed block consists of 2 sections :
666
//
667
// Literals_Section
668
// Sequences_Section"
669
670
671
// Part 1: decode the literals block
672
u8 *literals = NULL;
673
const size_t literals_size = decode_literals(ctx, in, &literals);
674
675
// Part 2: decode the sequences block
676
sequence_command_t *sequences = NULL;
677
const size_t num_sequences =
678
decode_sequences(ctx, in, &sequences);
679
680
// Part 3: combine literals and sequence commands to generate output
681
execute_sequences(ctx, out, literals, literals_size, sequences,
682
num_sequences);
683
free(literals);
684
free(sequences);
685
}
686
/******* END BLOCK DECOMPRESSION **********************************************/
687
688
/******* LITERALS DECODING ****************************************************/
689
static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
690
const int block_type,
691
const int size_format);
692
static size_t decode_literals_compressed(frame_context_t *const ctx,
693
istream_t *const in,
694
u8 **const literals,
695
const int block_type,
696
const int size_format);
697
static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
698
static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
699
int *const num_symbs);
700
701
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
702
u8 **const literals) {
703
// "Literals can be stored uncompressed or compressed using Huffman prefix
704
// codes. When compressed, an optional tree description can be present,
705
// followed by 1 or 4 streams."
706
//
707
// "Literals_Section_Header
708
//
709
// Header is in charge of describing how literals are packed. It's a
710
// byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
711
// little-endian convention."
712
//
713
// "Literals_Block_Type
714
//
715
// This field uses 2 lowest bits of first byte, describing 4 different block
716
// types"
717
//
718
// size_format takes between 1 and 2 bits
719
int block_type = (int)IO_read_bits(in, 2);
720
int size_format = (int)IO_read_bits(in, 2);
721
722
if (block_type <= 1) {
723
// Raw or RLE literals block
724
return decode_literals_simple(in, literals, block_type,
725
size_format);
726
} else {
727
// Huffman compressed literals
728
return decode_literals_compressed(ctx, in, literals, block_type,
729
size_format);
730
}
731
}
732
733
/// Decodes literals blocks in raw or RLE form
734
static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
735
const int block_type,
736
const int size_format) {
737
size_t size;
738
switch (size_format) {
739
// These cases are in the form ?0
740
// In this case, the ? bit is actually part of the size field
741
case 0:
742
case 2:
743
// "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
744
IO_rewind_bits(in, 1);
745
size = IO_read_bits(in, 5);
746
break;
747
case 1:
748
// "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
749
size = IO_read_bits(in, 12);
750
break;
751
case 3:
752
// "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
753
size = IO_read_bits(in, 20);
754
break;
755
default:
756
// Size format is in range 0-3
757
IMPOSSIBLE();
758
}
759
760
if (size > MAX_LITERALS_SIZE) {
761
CORRUPTION();
762
}
763
764
*literals = malloc(size);
765
if (!*literals) {
766
BAD_ALLOC();
767
}
768
769
switch (block_type) {
770
case 0: {
771
// "Raw_Literals_Block - Literals are stored uncompressed."
772
const u8 *const read_ptr = IO_get_read_ptr(in, size);
773
memcpy(*literals, read_ptr, size);
774
break;
775
}
776
case 1: {
777
// "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
778
const u8 *const read_ptr = IO_get_read_ptr(in, 1);
779
memset(*literals, read_ptr[0], size);
780
break;
781
}
782
default:
783
IMPOSSIBLE();
784
}
785
786
return size;
787
}
788
789
/// Decodes Huffman compressed literals
790
static size_t decode_literals_compressed(frame_context_t *const ctx,
791
istream_t *const in,
792
u8 **const literals,
793
const int block_type,
794
const int size_format) {
795
size_t regenerated_size, compressed_size;
796
// Only size_format=0 has 1 stream, so default to 4
797
int num_streams = 4;
798
switch (size_format) {
799
case 0:
800
// "A single stream. Both Compressed_Size and Regenerated_Size use 10
801
// bits (0-1023)."
802
num_streams = 1;
803
// Fall through as it has the same size format
804
/* fallthrough */
805
case 1:
806
// "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
807
// (0-1023)."
808
regenerated_size = IO_read_bits(in, 10);
809
compressed_size = IO_read_bits(in, 10);
810
break;
811
case 2:
812
// "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
813
// (0-16383)."
814
regenerated_size = IO_read_bits(in, 14);
815
compressed_size = IO_read_bits(in, 14);
816
break;
817
case 3:
818
// "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
819
// (0-262143)."
820
regenerated_size = IO_read_bits(in, 18);
821
compressed_size = IO_read_bits(in, 18);
822
break;
823
default:
824
// Impossible
825
IMPOSSIBLE();
826
}
827
if (regenerated_size > MAX_LITERALS_SIZE) {
828
CORRUPTION();
829
}
830
831
*literals = malloc(regenerated_size);
832
if (!*literals) {
833
BAD_ALLOC();
834
}
835
836
ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
837
istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
838
839
if (block_type == 2) {
840
// Decode the provided Huffman table
841
// "This section is only present when Literals_Block_Type type is
842
// Compressed_Literals_Block (2)."
843
844
HUF_free_dtable(&ctx->literals_dtable);
845
decode_huf_table(&ctx->literals_dtable, &huf_stream);
846
} else {
847
// If the previous Huffman table is being repeated, ensure it exists
848
if (!ctx->literals_dtable.symbols) {
849
CORRUPTION();
850
}
851
}
852
853
size_t symbols_decoded;
854
if (num_streams == 1) {
855
symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
856
} else {
857
symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
858
}
859
860
if (symbols_decoded != regenerated_size) {
861
CORRUPTION();
862
}
863
864
return regenerated_size;
865
}
866
867
// Decode the Huffman table description
868
static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
869
// "All literal values from zero (included) to last present one (excluded)
870
// are represented by Weight with values from 0 to Max_Number_of_Bits."
871
872
// "This is a single byte value (0-255), which describes how to decode the list of weights."
873
const u8 header = IO_read_bits(in, 8);
874
875
u8 weights[HUF_MAX_SYMBS];
876
memset(weights, 0, sizeof(weights));
877
878
int num_symbs;
879
880
if (header >= 128) {
881
// "This is a direct representation, where each Weight is written
882
// directly as a 4 bits field (0-15). The full representation occupies
883
// ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
884
// even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
885
// 127"
886
num_symbs = header - 127;
887
const size_t bytes = (num_symbs + 1) / 2;
888
889
const u8 *const weight_src = IO_get_read_ptr(in, bytes);
890
891
for (int i = 0; i < num_symbs; i++) {
892
// "They are encoded forward, 2
893
// weights to a byte with the first weight taking the top four bits
894
// and the second taking the bottom four (e.g. the following
895
// operations could be used to read the weights: Weight[0] =
896
// (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
897
if (i % 2 == 0) {
898
weights[i] = weight_src[i / 2] >> 4;
899
} else {
900
weights[i] = weight_src[i / 2] & 0xf;
901
}
902
}
903
} else {
904
// The weights are FSE encoded, decode them before we can construct the
905
// table
906
istream_t fse_stream = IO_make_sub_istream(in, header);
907
ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
908
fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
909
}
910
911
// Construct the table using the decoded weights
912
HUF_init_dtable_usingweights(dtable, weights, num_symbs);
913
}
914
915
/// Decode the FSE-compressed Huffman weights held in `in` into `weights`,
/// storing how many weights were produced in `*num_symbs`.
static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
                                  int *const num_symbs) {
    // "An FSE bitstream starts by a header, describing probabilities
    // distribution. It will create a Decoding Table. For a list of Huffman
    // weights, maximum accuracy is 7 bits."
    enum { HUF_WEIGHTS_MAX_ACCURACY_LOG = 7 };

    FSE_dtable weight_table;
    FSE_decode_header(&weight_table, in, HUF_WEIGHTS_MAX_ACCURACY_LOG);

    // The decoder's return value is the number of weights written
    *num_symbs = FSE_decompress_interleaved2(&weight_table, weights, in);

    // The table is only needed for this call
    FSE_free_dtable(&weight_table);
}
931
/******* END LITERALS DECODING ************************************************/
932
933
/******* SEQUENCE DECODING ****************************************************/
934
/// The combination of FSE states needed to decode sequences
935
typedef struct {
936
FSE_dtable ll_table;
937
FSE_dtable of_table;
938
FSE_dtable ml_table;
939
940
u16 ll_state;
941
u16 of_state;
942
u16 ml_state;
943
} sequence_states_t;
944
945
/// Which sequence field a table belongs to. The values index the per-part
/// constant arrays (e.g. SEQ_MAX_CODES and those inside decode_seq_table).
typedef enum {
    seq_literal_length = 0, // Literals_Lengths
    seq_offset = 1,         // Offsets
    seq_match_length = 2,   // Match_Lengths
} seq_part_t;

/// How decode_seq_table obtains a table; matches the 2-bit mode fields of
/// the "Symbol compression modes" byte.
typedef enum {
    seq_predefined = 0, // Predefined_Mode
    seq_rle = 1,        // RLE_Mode
    seq_fse = 2,        // FSE_Compressed_Mode
    seq_repeat = 3,     // Repeat_Mode
} seq_mode_t;
958
959
/// The predefined FSE distribution tables for `seq_predefined` mode
960
static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
961
4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2,
962
2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
963
static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
964
1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
965
1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
966
static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
967
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
968
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
969
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
970
971
/// The sequence decoding baseline and number of additional bits to read/add
972
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
973
static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
974
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
975
12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40,
976
48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
977
static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
978
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
979
1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
980
981
static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
982
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
983
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
984
31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83,
985
99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
986
static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
987
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
988
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
989
2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
990
991
/// Offset decoding is simpler so we just need a maximum code value
992
static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52};
993
994
static void decompress_sequences(frame_context_t *const ctx,
995
istream_t *const in,
996
sequence_command_t *const sequences,
997
const size_t num_sequences);
998
static sequence_command_t decode_sequence(sequence_states_t *const state,
999
const u8 *const src,
1000
i64 *const offset);
1001
static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1002
const seq_part_t type, const seq_mode_t mode);
1003
1004
static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
1005
sequence_command_t **const sequences) {
1006
// "A compressed block is a succession of sequences . A sequence is a
1007
// literal copy command, followed by a match copy command. A literal copy
1008
// command specifies a length. It is the number of bytes to be copied (or
1009
// extracted) from the literal section. A match copy command specifies an
1010
// offset and a length. The offset gives the position to copy from, which
1011
// can be within a previous block."
1012
1013
size_t num_sequences;
1014
1015
// "Number_of_Sequences
1016
//
1017
// This is a variable size field using between 1 and 3 bytes. Let's call its
1018
// first byte byte0."
1019
u8 header = IO_read_bits(in, 8);
1020
if (header == 0) {
1021
// "There are no sequences. The sequence section stops there.
1022
// Regenerated content is defined entirely by literals section."
1023
*sequences = NULL;
1024
return 0;
1025
} else if (header < 128) {
1026
// "Number_of_Sequences = byte0 . Uses 1 byte."
1027
num_sequences = header;
1028
} else if (header < 255) {
1029
// "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
1030
num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
1031
} else {
1032
// "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
1033
num_sequences = IO_read_bits(in, 16) + 0x7F00;
1034
}
1035
1036
*sequences = malloc(num_sequences * sizeof(sequence_command_t));
1037
if (!*sequences) {
1038
BAD_ALLOC();
1039
}
1040
1041
decompress_sequences(ctx, in, *sequences, num_sequences);
1042
return num_sequences;
1043
}
1044
1045
/// Decompress the FSE encoded sequence commands
1046
static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
1047
sequence_command_t *const sequences,
1048
const size_t num_sequences) {
1049
// "The Sequences_Section regroup all symbols required to decode commands.
1050
// There are 3 symbol types : literals lengths, offsets and match lengths.
1051
// They are encoded together, interleaved, in a single bitstream."
1052
1053
// "Symbol compression modes
1054
//
1055
// This is a single byte, defining the compression mode of each symbol
1056
// type."
1057
//
1058
// Bit number : Field name
1059
// 7-6 : Literals_Lengths_Mode
1060
// 5-4 : Offsets_Mode
1061
// 3-2 : Match_Lengths_Mode
1062
// 1-0 : Reserved
1063
u8 compression_modes = IO_read_bits(in, 8);
1064
1065
if ((compression_modes & 3) != 0) {
1066
// Reserved bits set
1067
CORRUPTION();
1068
}
1069
1070
// "Following the header, up to 3 distribution tables can be described. When
1071
// present, they are in this order :
1072
//
1073
// Literals lengths
1074
// Offsets
1075
// Match Lengths"
1076
// Update the tables we have stored in the context
1077
decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
1078
(compression_modes >> 6) & 3);
1079
1080
decode_seq_table(&ctx->of_dtable, in, seq_offset,
1081
(compression_modes >> 4) & 3);
1082
1083
decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
1084
(compression_modes >> 2) & 3);
1085
1086
1087
sequence_states_t states;
1088
1089
// Initialize the decoding tables
1090
{
1091
states.ll_table = ctx->ll_dtable;
1092
states.of_table = ctx->of_dtable;
1093
states.ml_table = ctx->ml_dtable;
1094
}
1095
1096
const size_t len = IO_istream_len(in);
1097
const u8 *const src = IO_get_read_ptr(in, len);
1098
1099
// "After writing the last bit containing information, the compressor writes
1100
// a single 1-bit and then fills the byte with 0-7 0 bits of padding."
1101
const int padding = 8 - highest_set_bit(src[len - 1]);
1102
// The offset starts at the end because FSE streams are read backwards
1103
i64 bit_offset = (i64)(len * 8 - (size_t)padding);
1104
1105
// "The bitstream starts with initial state values, each using the required
1106
// number of bits in their respective accuracy, decoded previously from
1107
// their normalized distribution.
1108
//
1109
// It starts by Literals_Length_State, followed by Offset_State, and finally
1110
// Match_Length_State."
1111
FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
1112
FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
1113
FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
1114
1115
for (size_t i = 0; i < num_sequences; i++) {
1116
// Decode sequences one by one
1117
sequences[i] = decode_sequence(&states, src, &bit_offset);
1118
}
1119
1120
if (bit_offset != 0) {
1121
CORRUPTION();
1122
}
1123
}
1124
1125
// Decode a single sequence and update the state
1126
static sequence_command_t decode_sequence(sequence_states_t *const states,
1127
const u8 *const src,
1128
i64 *const offset) {
1129
// "Each symbol is a code in its own context, which specifies Baseline and
1130
// Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
1131
// additional bits in the same bitstream."
1132
1133
// Decode symbols, but don't update states
1134
const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
1135
const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
1136
const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
1137
1138
// Offset doesn't need a max value as it's not decoded using a table
1139
if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
1140
ml_code > SEQ_MAX_CODES[seq_match_length]) {
1141
CORRUPTION();
1142
}
1143
1144
// Read the interleaved bits
1145
sequence_command_t seq;
1146
// "Decoding starts by reading the Number_of_Bits required to decode Offset.
1147
// It then does the same for Match_Length, and then for Literals_Length."
1148
seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
1149
1150
seq.match_length =
1151
SEQ_MATCH_LENGTH_BASELINES[ml_code] +
1152
STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
1153
1154
seq.literal_length =
1155
SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
1156
STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
1157
1158
// "If it is not the last sequence in the block, the next operation is to
1159
// update states. Using the rules pre-calculated in the decoding tables,
1160
// Literals_Length_State is updated, followed by Match_Length_State, and
1161
// then Offset_State."
1162
// If the stream is complete don't read bits to update state
1163
if (*offset != 0) {
1164
FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
1165
FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
1166
FSE_update_state(&states->of_table, &states->of_state, src, offset);
1167
}
1168
1169
return seq;
1170
}
1171
1172
/// Given a sequence part and table mode, decode the FSE distribution
1173
/// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
1174
static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1175
const seq_part_t type, const seq_mode_t mode) {
1176
// Constant arrays indexed by seq_part_t
1177
const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
1178
SEQ_OFFSET_DEFAULT_DIST,
1179
SEQ_MATCH_LENGTH_DEFAULT_DIST};
1180
const size_t default_distribution_lengths[] = {36, 29, 53};
1181
const size_t default_distribution_accuracies[] = {6, 5, 6};
1182
1183
const size_t max_accuracies[] = {9, 8, 9};
1184
1185
if (mode != seq_repeat) {
1186
// Free old one before overwriting
1187
FSE_free_dtable(table);
1188
}
1189
1190
switch (mode) {
1191
case seq_predefined: {
1192
// "Predefined_Mode : uses a predefined distribution table."
1193
const i16 *distribution = default_distributions[type];
1194
const size_t symbs = default_distribution_lengths[type];
1195
const size_t accuracy_log = default_distribution_accuracies[type];
1196
1197
FSE_init_dtable(table, distribution, symbs, accuracy_log);
1198
break;
1199
}
1200
case seq_rle: {
1201
// "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
1202
const u8 symb = IO_get_read_ptr(in, 1)[0];
1203
FSE_init_dtable_rle(table, symb);
1204
break;
1205
}
1206
case seq_fse: {
1207
// "FSE_Compressed_Mode : standard FSE compression. A distribution table
1208
// will be present "
1209
FSE_decode_header(table, in, max_accuracies[type]);
1210
break;
1211
}
1212
case seq_repeat:
1213
// "Repeat_Mode : re-use distribution table from previous compressed
1214
// block."
1215
// Nothing to do here, table will be unchanged
1216
if (!table->symbols) {
1217
// This mode is invalid if we don't already have a table
1218
CORRUPTION();
1219
}
1220
break;
1221
default:
1222
// Impossible, as mode is from 0-3
1223
IMPOSSIBLE();
1224
break;
1225
}
1226
1227
}
1228
/******* END SEQUENCE DECODING ************************************************/
1229
1230
/******* SEQUENCE EXECUTION ***************************************************/
1231
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
1232
const u8 *const literals,
1233
const size_t literals_len,
1234
const sequence_command_t *const sequences,
1235
const size_t num_sequences) {
1236
istream_t litstream = IO_make_istream(literals, literals_len);
1237
1238
u64 *const offset_hist = ctx->previous_offsets;
1239
size_t total_output = ctx->current_total_output;
1240
1241
for (size_t i = 0; i < num_sequences; i++) {
1242
const sequence_command_t seq = sequences[i];
1243
{
1244
const u32 literals_size = copy_literals(seq.literal_length, &litstream, out);
1245
total_output += literals_size;
1246
}
1247
1248
size_t const offset = compute_offset(seq, offset_hist);
1249
1250
size_t const match_length = seq.match_length;
1251
1252
execute_match_copy(ctx, offset, match_length, total_output, out);
1253
1254
total_output += match_length;
1255
}
1256
1257
// Copy any leftover literals
1258
{
1259
size_t len = IO_istream_len(&litstream);
1260
copy_literals(len, &litstream, out);
1261
total_output += len;
1262
}
1263
1264
ctx->current_total_output = total_output;
1265
}
1266
1267
static u32 copy_literals(const size_t literal_length, istream_t *litstream,
1268
ostream_t *const out) {
1269
// If the sequence asks for more literals than are left, the
1270
// sequence must be corrupted
1271
if (literal_length > IO_istream_len(litstream)) {
1272
CORRUPTION();
1273
}
1274
1275
u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
1276
const u8 *const read_ptr =
1277
IO_get_read_ptr(litstream, literal_length);
1278
// Copy literals to output
1279
memcpy(write_ptr, read_ptr, literal_length);
1280
1281
return literal_length;
1282
}
1283
1284
/// Turn a sequence's Offset_Value into an actual match distance and update the
/// three-entry repeat-offset history `offset_hist` (most recent first).
/// A resulting distance of 0 is never valid and is reported as corruption.
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
    size_t offset;
    // Offsets are special, we need to handle the repeat offsets
    if (seq.offset <= 3) {
        // "The first 3 values define a repeated offset and we will call
        // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
        // They are sorted in recency order, with Repeated_Offset1 meaning
        // 'most recent one'".

        // Use 0 indexing for the array
        u32 idx = seq.offset - 1;
        if (seq.literal_length == 0) {
            // "There is an exception though, when current sequence's
            // literals length is 0. In this case, repeated offsets are
            // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
            // Repeated_Offset2 becomes Repeated_Offset3, and
            // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
            idx++;
        }

        if (idx == 0) {
            offset = offset_hist[0];
        } else {
            // If idx == 3 then literal length was 0 and the offset was 3,
            // as per the exception listed above
            offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;

            // If idx == 1 we don't need to modify offset_hist[2], since
            // we're using the second-most recent code
            if (idx > 1) {
                offset_hist[2] = offset_hist[1];
            }
            offset_hist[1] = offset_hist[0];
            offset_hist[0] = offset;
        }
    } else {
        // When it's not a repeat offset:
        // "if (Offset_Value > 3) offset = Offset_Value - 3;"
        offset = seq.offset - 3;

        // Shift back history
        offset_hist[2] = offset_hist[1];
        offset_hist[1] = offset_hist[0];
        offset_hist[0] = offset;
    }

    // A zero distance would make the match copy read the very byte it is
    // about to write. It can arise from "Repeated_Offset1 - 1" when
    // Repeated_Offset1 is 1, or from a 0 already present in the history.
    if (offset == 0) {
        CORRUPTION();
    }
    return offset;
}
1331
1332
static void execute_match_copy(frame_context_t *const ctx, size_t offset,
1333
size_t match_length, size_t total_output,
1334
ostream_t *const out) {
1335
u8 *write_ptr = IO_get_write_ptr(out, match_length);
1336
if (total_output <= ctx->header.window_size) {
1337
// In this case offset might go back into the dictionary
1338
if (offset > total_output + ctx->dict_content_len) {
1339
// The offset goes beyond even the dictionary
1340
CORRUPTION();
1341
}
1342
1343
if (offset > total_output) {
1344
// "The rest of the dictionary is its content. The content act
1345
// as a "past" in front of data to compress or decompress, so it
1346
// can be referenced in sequence commands."
1347
const size_t dict_copy =
1348
MIN(offset - total_output, match_length);
1349
const size_t dict_offset =
1350
ctx->dict_content_len - (offset - total_output);
1351
1352
memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
1353
write_ptr += dict_copy;
1354
match_length -= dict_copy;
1355
}
1356
} else if (offset > ctx->header.window_size) {
1357
CORRUPTION();
1358
}
1359
1360
// We must copy byte by byte because the match length might be larger
1361
// than the offset
1362
// ex: if the output so far was "abc", a command with offset=3 and
1363
// match_length=6 would produce "abcabcabc" as the new output
1364
for (size_t j = 0; j < match_length; j++) {
1365
*write_ptr = *(write_ptr - offset);
1366
write_ptr++;
1367
}
1368
}
1369
/******* END SEQUENCE EXECUTION ***********************************************/
1370
1371
/******* OUTPUT SIZE COUNTING *************************************************/
1372
/// Get the decompressed size of an input stream so memory can be allocated in
1373
/// advance.
1374
/// This implementation assumes `src` points to a single ZSTD-compressed frame
1375
size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
1376
istream_t in = IO_make_istream(src, src_len);
1377
1378
// get decompressed size from ZSTD frame header
1379
{
1380
const u32 magic_number = (u32)IO_read_bits(&in, 32);
1381
1382
if (magic_number == ZSTD_MAGIC_NUMBER) {
1383
// ZSTD frame
1384
frame_header_t header;
1385
parse_frame_header(&header, &in);
1386
1387
if (header.frame_content_size == 0 && !header.single_segment_flag) {
1388
// Content size not provided, we can't tell
1389
return (size_t)-1;
1390
}
1391
1392
return header.frame_content_size;
1393
} else {
1394
// not a real frame or skippable frame
1395
ERROR("ZSTD frame magic number did not match");
1396
}
1397
}
1398
}
1399
/******* END OUTPUT SIZE COUNTING *********************************************/
1400
1401
/******* DICTIONARY PARSING ***************************************************/
1402
dictionary_t* create_dictionary() {
1403
dictionary_t* const dict = calloc(1, sizeof(dictionary_t));
1404
if (!dict) {
1405
BAD_ALLOC();
1406
}
1407
return dict;
1408
}
1409
1410
/// Free an allocated dictionary
1411
void free_dictionary(dictionary_t *const dict) {
1412
HUF_free_dtable(&dict->literals_dtable);
1413
FSE_free_dtable(&dict->ll_dtable);
1414
FSE_free_dtable(&dict->of_dtable);
1415
FSE_free_dtable(&dict->ml_dtable);
1416
1417
free(dict->content);
1418
1419
memset(dict, 0, sizeof(dictionary_t));
1420
1421
free(dict);
1422
}
1423
1424
1425
#if !defined(ZDEC_NO_DICTIONARY)
1426
#define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
1427
#define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src");
1428
1429
static void init_dictionary_content(dictionary_t *const dict,
1430
istream_t *const in);
1431
1432
/// Parse the dictionary in `src` (`src_len` bytes) into `dict`, which is
/// zeroed first.
///
/// If the buffer does not start with the dictionary magic number 0xEC30A437,
/// the whole buffer becomes raw dictionary content. Otherwise the dictionary
/// ID, entropy tables, recent offsets and content are read in turn.
/// Errors (via ERROR) on a NULL `src`, a `src_len` below 8, or recent offsets
/// that cannot be valid.
void parse_dictionary(dictionary_t *const dict, const void *src,
                      size_t src_len) {
    const u8 *byte_src = (const u8 *)src;
    memset(dict, 0, sizeof(dictionary_t));
    if (src == NULL) { /* cannot initialize dictionary with null src */
        NULL_SRC();
    }
    if (src_len < 8) {
        DICT_SIZE_ERROR();
    }

    istream_t in = IO_make_istream(byte_src, src_len);

    const u32 magic_number = IO_read_bits(&in, 32);
    if (magic_number != 0xEC30A437) {
        // raw content dict: un-read the 4 bytes and keep everything as content
        IO_rewind_bits(&in, 32);
        init_dictionary_content(dict, &in);
        return;
    }

    dict->dictionary_id = IO_read_bits(&in, 32);

    // "Entropy_Tables : following the same format as the tables in compressed
    // blocks. They are stored in following order : Huffman tables for literals,
    // FSE table for offsets, FSE table for match lengths, and FSE table for
    // literals lengths. It's finally followed by 3 offset values, populating
    // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
    // little-endian each, for a total of 12 bytes. Each recent offset must have
    // a value < dictionary size."
    decode_huf_table(&dict->literals_dtable, &in);
    decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
    decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
    decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);

    // Read in the previous offset history
    dict->previous_offsets[0] = IO_read_bits(&in, 32);
    dict->previous_offsets[1] = IO_read_bits(&in, 32);
    dict->previous_offsets[2] = IO_read_bits(&in, 32);

    // Validate the recent offsets the same way the reference library does.
    // A value of 0 is never a valid match distance. A value larger than the
    // dictionary content could not be resolved at the start of a frame. Since
    // the content is smaller than the whole dictionary, this also enforces
    // the "< dictionary size" rule quoted above.
    const size_t content_size = IO_istream_len(&in);
    for (int i = 0; i < 3; i++) {
        const u64 rep = dict->previous_offsets[i];
        if (rep == 0 || rep > content_size) {
            ERROR("Dictionary corrupted");
        }
    }

    // "Content : The rest of the dictionary is its content. The content act as
    // a "past" in front of data to compress or decompress, so it can be
    // referenced in sequence commands."
    init_dictionary_content(dict, &in);
}
1485
1486
static void init_dictionary_content(dictionary_t *const dict,
1487
istream_t *const in) {
1488
// Copy in the content
1489
dict->content_size = IO_istream_len(in);
1490
dict->content = malloc(dict->content_size);
1491
if (!dict->content) {
1492
BAD_ALLOC();
1493
}
1494
1495
const u8 *const content = IO_get_read_ptr(in, dict->content_size);
1496
1497
memcpy(dict->content, content, dict->content_size);
1498
}
1499
1500
/// Deep-copy the Huffman decoding table `src` into `dst`, allocating new
/// `symbols` and `num_bits` arrays of 2^max_bits bytes each.
/// An empty source (max_bits == 0) produces an all-zero `dst`.
static void HUF_copy_dtable(HUF_dtable *const dst,
                            const HUF_dtable *const src) {
    if (src->max_bits != 0) {
        const size_t entries = (size_t)1 << src->max_bits;
        dst->max_bits = src->max_bits;

        dst->symbols = malloc(entries);
        dst->num_bits = malloc(entries);
        if (dst->symbols == NULL || dst->num_bits == NULL) {
            BAD_ALLOC();
        }

        memcpy(dst->symbols, src->symbols, entries);
        memcpy(dst->num_bits, src->num_bits, entries);
    } else {
        memset(dst, 0, sizeof(HUF_dtable));
    }
}
1519
1520
/// Deep-copy the FSE decoding table `src` into `dst`. Fresh `symbols`,
/// `num_bits` (2^accuracy_log bytes each) and `new_state_base`
/// (2^accuracy_log u16 values) arrays are allocated.
/// An empty source (accuracy_log == 0) produces an all-zero `dst`.
static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
    if (src->accuracy_log != 0) {
        const size_t entries = (size_t)1 << src->accuracy_log;
        const size_t base_bytes = entries * sizeof(u16);
        dst->accuracy_log = src->accuracy_log;

        dst->symbols = malloc(entries);
        dst->num_bits = malloc(entries);
        dst->new_state_base = malloc(base_bytes);
        if (dst->symbols == NULL || dst->num_bits == NULL ||
            dst->new_state_base == NULL) {
            BAD_ALLOC();
        }

        memcpy(dst->symbols, src->symbols, entries);
        memcpy(dst->num_bits, src->num_bits, entries);
        memcpy(dst->new_state_base, src->new_state_base, base_bytes);
    } else {
        memset(dst, 0, sizeof(FSE_dtable));
    }
}
1540
1541
/// A dictionary acts as initializing values for the frame context before
1542
/// decompression, so we implement it by applying it's predetermined
1543
/// tables and content to the context before beginning decompression
1544
static void frame_context_apply_dict(frame_context_t *const ctx,
1545
const dictionary_t *const dict) {
1546
// If the content pointer is NULL then it must be an empty dict
1547
if (!dict || !dict->content)
1548
return;
1549
1550
// If the requested dictionary_id is non-zero, the correct dictionary must
1551
// be present
1552
if (ctx->header.dictionary_id != 0 &&
1553
ctx->header.dictionary_id != dict->dictionary_id) {
1554
ERROR("Wrong dictionary provided");
1555
}
1556
1557
// Copy the dict content to the context for references during sequence
1558
// execution
1559
ctx->dict_content = dict->content;
1560
ctx->dict_content_len = dict->content_size;
1561
1562
// If it's a formatted dict copy the precomputed tables in so they can
1563
// be used in the table repeat modes
1564
if (dict->dictionary_id != 0) {
1565
// Deep copy the entropy tables so they can be freed independently of
1566
// the dictionary struct
1567
HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
1568
FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
1569
FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
1570
FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
1571
1572
// Copy the repeated offsets
1573
memcpy(ctx->previous_offsets, dict->previous_offsets,
1574
sizeof(ctx->previous_offsets));
1575
}
1576
}
1577
1578
#else // ZDEC_NO_DICTIONARY is defined
1579
1580
static void frame_context_apply_dict(frame_context_t *const ctx,
1581
const dictionary_t *const dict) {
1582
(void)ctx;
1583
if (dict && dict->content) ERROR("dictionary not supported");
1584
}
1585
1586
#endif
1587
/******* END DICTIONARY PARSING ***********************************************/
1588
1589
/******* IO STREAM OPERATIONS *************************************************/
1590
1591
/// Reads `num` bits from a bitstream, and updates the internal offset
1592
static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
1593
if (num_bits > 64 || num_bits <= 0) {
1594
ERROR("Attempt to read an invalid number of bits");
1595
}
1596
1597
const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
1598
const size_t full_bytes = (num_bits + in->bit_offset) / 8;
1599
if (bytes > in->len) {
1600
INP_SIZE();
1601
}
1602
1603
const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
1604
1605
in->bit_offset = (num_bits + in->bit_offset) % 8;
1606
in->ptr += full_bytes;
1607
in->len -= full_bytes;
1608
1609
return result;
1610
}
1611
1612
/// If a non-zero number of bits have been read from the current byte, advance
1613
/// the offset to the next byte
1614
static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
1615
if (num_bits < 0) {
1616
ERROR("Attempting to rewind stream by a negative number of bits");
1617
}
1618
1619
// move the offset back by `num_bits` bits
1620
const int new_offset = in->bit_offset - num_bits;
1621
// determine the number of whole bytes we have to rewind, rounding up to an
1622
// integer number (e.g. if `new_offset == -5`, `bytes == 1`)
1623
const i64 bytes = -(new_offset - 7) / 8;
1624
1625
in->ptr -= bytes;
1626
in->len += bytes;
1627
// make sure the resulting `bit_offset` is positive, as mod in C does not
1628
// convert numbers from negative to positive (e.g. -22 % 8 == -6)
1629
in->bit_offset = ((new_offset % 8) + 8) % 8;
1630
}
1631
1632
/// If the remaining bits in a byte will be unused, advance to the end of the
1633
/// byte
1634
static inline void IO_align_stream(istream_t *const in) {
1635
if (in->bit_offset != 0) {
1636
if (in->len == 0) {
1637
INP_SIZE();
1638
}
1639
in->ptr++;
1640
in->len--;
1641
in->bit_offset = 0;
1642
}
1643
}
1644
1645
/// Write the given byte into the output stream
1646
static inline void IO_write_byte(ostream_t *const out, u8 symb) {
1647
if (out->len == 0) {
1648
OUT_SIZE();
1649
}
1650
1651
out->ptr[0] = symb;
1652
out->ptr++;
1653
out->len--;
1654
}
1655
1656
/// Returns the number of bytes left to be read in this stream. The stream must
1657
/// be byte aligned.
1658
static inline size_t IO_istream_len(const istream_t *const in) {
1659
return in->len;
1660
}
1661
1662
/// Returns a pointer where `len` bytes can be read, and advances the internal
1663
/// state. The stream must be byte aligned.
1664
static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
1665
if (len > in->len) {
1666
INP_SIZE();
1667
}
1668
if (in->bit_offset != 0) {
1669
ERROR("Attempting to operate on a non-byte aligned stream");
1670
}
1671
const u8 *const ptr = in->ptr;
1672
in->ptr += len;
1673
in->len -= len;
1674
1675
return ptr;
1676
}
1677
/// Returns a pointer to write `len` bytes to, and advances the internal state
1678
static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
1679
if (len > out->len) {
1680
OUT_SIZE();
1681
}
1682
u8 *const ptr = out->ptr;
1683
out->ptr += len;
1684
out->len -= len;
1685
1686
return ptr;
1687
}
1688
1689
/// Advance the inner state by `len` bytes
1690
static inline void IO_advance_input(istream_t *const in, size_t len) {
1691
if (len > in->len) {
1692
INP_SIZE();
1693
}
1694
if (in->bit_offset != 0) {
1695
ERROR("Attempting to operate on a non-byte aligned stream");
1696
}
1697
1698
in->ptr += len;
1699
in->len -= len;
1700
}
1701
1702
/// Returns an `ostream_t` constructed from the given pointer and length
1703
static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
1704
return (ostream_t) { out, len };
1705
}
1706
1707
/// Returns an `istream_t` constructed from the given pointer and length
1708
static inline istream_t IO_make_istream(const u8 *in, size_t len) {
1709
return (istream_t) { in, len, 0 };
1710
}
1711
1712
/// Returns an `istream_t` with the same base as `in`, and length `len`
1713
/// Then, advance `in` to account for the consumed bytes
1714
/// `in` must be byte aligned
1715
static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
1716
// Consume `len` bytes of the parent stream
1717
const u8 *const ptr = IO_get_read_ptr(in, len);
1718
1719
// Make a substream using the pointer to those `len` bytes
1720
return IO_make_istream(ptr, len);
1721
}
1722
/******* END IO STREAM OPERATIONS *********************************************/
1723
1724
/******* BITSTREAM OPERATIONS *************************************************/
1725
/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
1726
static inline u64 read_bits_LE(const u8 *src, const int num_bits,
1727
const size_t offset) {
1728
if (num_bits > 64) {
1729
ERROR("Attempt to read an invalid number of bits");
1730
}
1731
1732
// Skip over bytes that aren't in range
1733
src += offset / 8;
1734
size_t bit_offset = offset % 8;
1735
u64 res = 0;
1736
1737
int shift = 0;
1738
int left = num_bits;
1739
while (left > 0) {
1740
u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
1741
// Read the next byte, shift it to account for the offset, and then mask
1742
// out the top part if we don't need all the bits
1743
res += (((u64)*src++ >> bit_offset) & mask) << shift;
1744
shift += 8 - bit_offset;
1745
left -= 8 - bit_offset;
1746
bit_offset = 0;
1747
}
1748
1749
return res;
1750
}
1751
1752
/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so
1753
/// it updates `offset` to `offset - bits`, and then reads `bits` bits from
1754
/// `src + offset`. If the offset becomes negative, the extra bits at the
1755
/// bottom are filled in with `0` bits instead of reading from before `src`.
1756
static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
1757
i64 *const offset) {
1758
*offset = *offset - bits;
1759
size_t actual_off = *offset;
1760
size_t actual_bits = bits;
1761
// Don't actually read bits from before the start of src, so if `*offset <
1762
// 0` fix actual_off and actual_bits to reflect the quantity to read
1763
if (*offset < 0) {
1764
actual_bits += *offset;
1765
actual_off = 0;
1766
}
1767
u64 res = read_bits_LE(src, actual_bits, actual_off);
1768
1769
if (*offset < 0) {
1770
// Fill in the bottom "overflowed" bits with 0's
1771
res = -*offset >= 64 ? 0 : (res << -*offset);
1772
}
1773
return res;
1774
}
1775
/******* END BITSTREAM OPERATIONS *********************************************/
1776
1777
/******* BIT COUNTING OPERATIONS **********************************************/
1778
/// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
1779
/// `num`, or `-1` if `num == 0`.
1780
static inline int highest_set_bit(const u64 num) {
1781
for (int i = 63; i >= 0; i--) {
1782
if (((u64)1 << i) <= num) {
1783
return i;
1784
}
1785
}
1786
return -1;
1787
}
1788
/******* END BIT COUNTING OPERATIONS ******************************************/
1789
1790
/******* HUFFMAN PRIMITIVES ***************************************************/
1791
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
1792
u16 *const state, const u8 *const src,
1793
i64 *const offset) {
1794
// Look up the symbol and number of bits to read
1795
const u8 symb = dtable->symbols[*state];
1796
const u8 bits = dtable->num_bits[*state];
1797
const u16 rest = STREAM_read_bits(src, bits, offset);
1798
// Shift `bits` bits out of the state, keeping the low order bits that
1799
// weren't necessary to determine this symbol. Then add in the new bits
1800
// read from the stream.
1801
*state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1);
1802
1803
return symb;
1804
}
1805
1806
static inline void HUF_init_state(const HUF_dtable *const dtable,
1807
u16 *const state, const u8 *const src,
1808
i64 *const offset) {
1809
// Read in a full `dtable->max_bits` bits to initialize the state
1810
const u8 bits = dtable->max_bits;
1811
*state = STREAM_read_bits(src, bits, offset);
1812
}
1813
1814
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
1815
ostream_t *const out,
1816
istream_t *const in) {
1817
const size_t len = IO_istream_len(in);
1818
if (len == 0) {
1819
INP_SIZE();
1820
}
1821
const u8 *const src = IO_get_read_ptr(in, len);
1822
1823
// "Each bitstream must be read backward, that is starting from the end down
1824
// to the beginning. Therefore it's necessary to know the size of each
1825
// bitstream.
1826
//
1827
// It's also necessary to know exactly which bit is the latest. This is
1828
// detected by a final bit flag : the highest bit of latest byte is a
1829
// final-bit-flag. Consequently, a last byte of 0 is not possible. And the
1830
// final-bit-flag itself is not part of the useful bitstream. Hence, the
1831
// last byte contains between 0 and 7 useful bits."
1832
const int padding = 8 - highest_set_bit(src[len - 1]);
1833
1834
// Offset starts at the end because HUF streams are read backwards
1835
i64 bit_offset = len * 8 - padding;
1836
u16 state;
1837
1838
HUF_init_state(dtable, &state, src, &bit_offset);
1839
1840
size_t symbols_written = 0;
1841
while (bit_offset > -dtable->max_bits) {
1842
// Iterate over the stream, decoding one symbol at a time
1843
IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
1844
symbols_written++;
1845
}
1846
// "The process continues up to reading the required number of symbols per
1847
// stream. If a bitstream is not entirely and exactly consumed, hence
1848
// reaching exactly its beginning position with all bits consumed, the
1849
// decoding process is considered faulty."
1850
1851
// When all symbols have been decoded, the final state value shouldn't have
1852
// any data from the stream, so it should have "read" dtable->max_bits from
1853
// before the start of `src`
1854
// Therefore `offset`, the edge to start reading new bits at, should be
1855
// dtable->max_bits before the start of the stream
1856
if (bit_offset != -dtable->max_bits) {
1857
CORRUPTION();
1858
}
1859
1860
return symbols_written;
1861
}
1862
1863
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
1864
ostream_t *const out, istream_t *const in) {
1865
// "Compressed size is provided explicitly : in the 4-streams variant,
1866
// bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
1867
// value represents the compressed size of one stream, in order. The last
1868
// stream size is deducted from total compressed size and from previously
1869
// decoded stream sizes"
1870
const size_t csize1 = IO_read_bits(in, 16);
1871
const size_t csize2 = IO_read_bits(in, 16);
1872
const size_t csize3 = IO_read_bits(in, 16);
1873
1874
istream_t in1 = IO_make_sub_istream(in, csize1);
1875
istream_t in2 = IO_make_sub_istream(in, csize2);
1876
istream_t in3 = IO_make_sub_istream(in, csize3);
1877
istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
1878
1879
size_t total_output = 0;
1880
// Decode each stream independently for simplicity
1881
// If we wanted to we could decode all 4 at the same time for speed,
1882
// utilizing more execution units
1883
total_output += HUF_decompress_1stream(dtable, out, &in1);
1884
total_output += HUF_decompress_1stream(dtable, out, &in2);
1885
total_output += HUF_decompress_1stream(dtable, out, &in3);
1886
total_output += HUF_decompress_1stream(dtable, out, &in4);
1887
1888
return total_output;
1889
}
1890
1891
/// Initializes a Huffman table using canonical Huffman codes
1892
/// For more explanation on canonical Huffman codes see
1893
/// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
1894
/// Codes within a level are allocated in symbol order (i.e. smaller symbols get
1895
/// earlier codes)
1896
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
1897
const int num_symbs) {
1898
memset(table, 0, sizeof(HUF_dtable));
1899
if (num_symbs > HUF_MAX_SYMBS) {
1900
ERROR("Too many symbols for Huffman");
1901
}
1902
1903
u8 max_bits = 0;
1904
u16 rank_count[HUF_MAX_BITS + 1];
1905
memset(rank_count, 0, sizeof(rank_count));
1906
1907
// Count the number of symbols for each number of bits, and determine the
1908
// depth of the tree
1909
for (int i = 0; i < num_symbs; i++) {
1910
if (bits[i] > HUF_MAX_BITS) {
1911
ERROR("Huffman table depth too large");
1912
}
1913
max_bits = MAX(max_bits, bits[i]);
1914
rank_count[bits[i]]++;
1915
}
1916
1917
const size_t table_size = 1 << max_bits;
1918
table->max_bits = max_bits;
1919
table->symbols = malloc(table_size);
1920
table->num_bits = malloc(table_size);
1921
1922
if (!table->symbols || !table->num_bits) {
1923
free(table->symbols);
1924
free(table->num_bits);
1925
BAD_ALLOC();
1926
}
1927
1928
// "Symbols are sorted by Weight. Within same Weight, symbols keep natural
1929
// order. Symbols with a Weight of zero are removed. Then, starting from
1930
// lowest weight, prefix codes are distributed in order."
1931
1932
u32 rank_idx[HUF_MAX_BITS + 1];
1933
// Initialize the starting codes for each rank (number of bits)
1934
rank_idx[max_bits] = 0;
1935
for (int i = max_bits; i >= 1; i--) {
1936
rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
1937
// The entire range takes the same number of bits so we can memset it
1938
memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
1939
}
1940
1941
if (rank_idx[0] != table_size) {
1942
CORRUPTION();
1943
}
1944
1945
// Allocate codes and fill in the table
1946
for (int i = 0; i < num_symbs; i++) {
1947
if (bits[i] != 0) {
1948
// Allocate a code for this symbol and set its range in the table
1949
const u16 code = rank_idx[bits[i]];
1950
// Since the code doesn't care about the bottom `max_bits - bits[i]`
1951
// bits of state, it gets a range that spans all possible values of
1952
// the lower bits
1953
const u16 len = 1 << (max_bits - bits[i]);
1954
memset(&table->symbols[code], i, len);
1955
rank_idx[bits[i]] += len;
1956
}
1957
}
1958
}
1959
1960
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
1961
const u8 *const weights,
1962
const int num_symbs) {
1963
// +1 because the last weight is not transmitted in the header
1964
if (num_symbs + 1 > HUF_MAX_SYMBS) {
1965
ERROR("Too many symbols for Huffman");
1966
}
1967
1968
u8 bits[HUF_MAX_SYMBS];
1969
1970
u64 weight_sum = 0;
1971
for (int i = 0; i < num_symbs; i++) {
1972
// Weights are in the same range as bit count
1973
if (weights[i] > HUF_MAX_BITS) {
1974
CORRUPTION();
1975
}
1976
weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
1977
}
1978
1979
// Find the first power of 2 larger than the sum
1980
const int max_bits = highest_set_bit(weight_sum) + 1;
1981
const u64 left_over = ((u64)1 << max_bits) - weight_sum;
1982
// If the left over isn't a power of 2, the weights are invalid
1983
if (left_over & (left_over - 1)) {
1984
CORRUPTION();
1985
}
1986
1987
// left_over is used to find the last weight as it's not transmitted
1988
// by inverting 2^(weight - 1) we can determine the value of last_weight
1989
const int last_weight = highest_set_bit(left_over) + 1;
1990
1991
for (int i = 0; i < num_symbs; i++) {
1992
// "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0"
1993
bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0;
1994
}
1995
bits[num_symbs] =
1996
max_bits + 1 - last_weight; // Last weight is always non-zero
1997
1998
HUF_init_dtable(table, bits, num_symbs + 1);
1999
}
2000
2001
/// Releases both lookup arrays of a Huffman table and resets the table to all
/// zero bytes, so it no longer refers to the freed memory.
static void HUF_free_dtable(HUF_dtable *const dtable) {
    u8 *const syms = dtable->symbols;
    u8 *const lens = dtable->num_bits;
    memset(dtable, 0, sizeof *dtable);
    free(syms);
    free(lens);
}
2006
/******* END HUFFMAN PRIMITIVES ***********************************************/
2007
2008
/******* FSE PRIMITIVES *******************************************************/
2009
/// For more description of FSE see
2010
/// https://github.com/Cyan4973/FiniteStateEntropy/
2011
2012
/// Allow a symbol to be decoded without updating state
2013
static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
2014
const u16 state) {
2015
return dtable->symbols[state];
2016
}
2017
2018
/// Consumes bits from the input and uses the current state to determine the
2019
/// next state
2020
static inline void FSE_update_state(const FSE_dtable *const dtable,
2021
u16 *const state, const u8 *const src,
2022
i64 *const offset) {
2023
const u8 bits = dtable->num_bits[*state];
2024
const u16 rest = STREAM_read_bits(src, bits, offset);
2025
*state = dtable->new_state_base[*state] + rest;
2026
}
2027
2028
/// Decodes a single FSE symbol and updates the offset
2029
static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
2030
u16 *const state, const u8 *const src,
2031
i64 *const offset) {
2032
const u8 symb = FSE_peek_symbol(dtable, *state);
2033
FSE_update_state(dtable, state, src, offset);
2034
return symb;
2035
}
2036
2037
static inline void FSE_init_state(const FSE_dtable *const dtable,
2038
u16 *const state, const u8 *const src,
2039
i64 *const offset) {
2040
// Read in a full `accuracy_log` bits to initialize the state
2041
const u8 bits = dtable->accuracy_log;
2042
*state = STREAM_read_bits(src, bits, offset);
2043
}
2044
2045
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
2046
ostream_t *const out,
2047
istream_t *const in) {
2048
const size_t len = IO_istream_len(in);
2049
if (len == 0) {
2050
INP_SIZE();
2051
}
2052
const u8 *const src = IO_get_read_ptr(in, len);
2053
2054
// "Each bitstream must be read backward, that is starting from the end down
2055
// to the beginning. Therefore it's necessary to know the size of each
2056
// bitstream.
2057
//
2058
// It's also necessary to know exactly which bit is the latest. This is
2059
// detected by a final bit flag : the highest bit of latest byte is a
2060
// final-bit-flag. Consequently, a last byte of 0 is not possible. And the
2061
// final-bit-flag itself is not part of the useful bitstream. Hence, the
2062
// last byte contains between 0 and 7 useful bits."
2063
const int padding = 8 - highest_set_bit(src[len - 1]);
2064
i64 offset = len * 8 - padding;
2065
2066
u16 state1, state2;
2067
// "The first state (State1) encodes the even indexed symbols, and the
2068
// second (State2) encodes the odd indexes. State1 is initialized first, and
2069
// then State2, and they take turns decoding a single symbol and updating
2070
// their state."
2071
FSE_init_state(dtable, &state1, src, &offset);
2072
FSE_init_state(dtable, &state2, src, &offset);
2073
2074
// Decode until we overflow the stream
2075
// Since we decode in reverse order, overflowing the stream is offset going
2076
// negative
2077
size_t symbols_written = 0;
2078
while (1) {
2079
// "The number of symbols to decode is determined by tracking bitStream
2080
// overflow condition: If updating state after decoding a symbol would
2081
// require more bits than remain in the stream, it is assumed the extra
2082
// bits are 0. Then, the symbols for each of the final states are
2083
// decoded and the process is complete."
2084
IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
2085
symbols_written++;
2086
if (offset < 0) {
2087
// There's still a symbol to decode in state2
2088
IO_write_byte(out, FSE_peek_symbol(dtable, state2));
2089
symbols_written++;
2090
break;
2091
}
2092
2093
IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
2094
symbols_written++;
2095
if (offset < 0) {
2096
// There's still a symbol to decode in state1
2097
IO_write_byte(out, FSE_peek_symbol(dtable, state1));
2098
symbols_written++;
2099
break;
2100
}
2101
}
2102
2103
return symbols_written;
2104
}
2105
2106
static void FSE_init_dtable(FSE_dtable *const dtable,
2107
const i16 *const norm_freqs, const int num_symbs,
2108
const int accuracy_log) {
2109
if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
2110
ERROR("FSE accuracy too large");
2111
}
2112
if (num_symbs > FSE_MAX_SYMBS) {
2113
ERROR("Too many symbols for FSE");
2114
}
2115
2116
dtable->accuracy_log = accuracy_log;
2117
2118
const size_t size = (size_t)1 << accuracy_log;
2119
dtable->symbols = malloc(size * sizeof(u8));
2120
dtable->num_bits = malloc(size * sizeof(u8));
2121
dtable->new_state_base = malloc(size * sizeof(u16));
2122
2123
if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2124
BAD_ALLOC();
2125
}
2126
2127
// Used to determine how many bits need to be read for each state,
2128
// and where the destination range should start
2129
// Needs to be u16 because max value is 2 * max number of symbols,
2130
// which can be larger than a byte can store
2131
u16 state_desc[FSE_MAX_SYMBS];
2132
2133
// "Symbols are scanned in their natural order for "less than 1"
2134
// probabilities. Symbols with this probability are being attributed a
2135
// single cell, starting from the end of the table. These symbols define a
2136
// full state reset, reading Accuracy_Log bits."
2137
int high_threshold = size;
2138
for (int s = 0; s < num_symbs; s++) {
2139
// Scan for low probability symbols to put at the top
2140
if (norm_freqs[s] == -1) {
2141
dtable->symbols[--high_threshold] = s;
2142
state_desc[s] = 1;
2143
}
2144
}
2145
2146
// "All remaining symbols are sorted in their natural order. Starting from
2147
// symbol 0 and table position 0, each symbol gets attributed as many cells
2148
// as its probability. Cell allocation is spread, not linear."
2149
// Place the rest in the table
2150
const u16 step = (size >> 1) + (size >> 3) + 3;
2151
const u16 mask = size - 1;
2152
u16 pos = 0;
2153
for (int s = 0; s < num_symbs; s++) {
2154
if (norm_freqs[s] <= 0) {
2155
continue;
2156
}
2157
2158
state_desc[s] = norm_freqs[s];
2159
2160
for (int i = 0; i < norm_freqs[s]; i++) {
2161
// Give `norm_freqs[s]` states to symbol s
2162
dtable->symbols[pos] = s;
2163
// "A position is skipped if already occupied, typically by a "less
2164
// than 1" probability symbol."
2165
do {
2166
pos = (pos + step) & mask;
2167
} while (pos >=
2168
high_threshold);
2169
// Note: no other collision checking is necessary as `step` is
2170
// coprime to `size`, so the cycle will visit each position exactly
2171
// once
2172
}
2173
}
2174
if (pos != 0) {
2175
CORRUPTION();
2176
}
2177
2178
// Now we can fill baseline and num bits
2179
for (size_t i = 0; i < size; i++) {
2180
u8 symbol = dtable->symbols[i];
2181
u16 next_state_desc = state_desc[symbol]++;
2182
// Fills in the table appropriately, next_state_desc increases by symbol
2183
// over time, decreasing number of bits
2184
dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
2185
// Baseline increases until the bit threshold is passed, at which point
2186
// it resets to 0
2187
dtable->new_state_base[i] =
2188
((u16)next_state_desc << dtable->num_bits[i]) - size;
2189
}
2190
}
2191
2192
/// Decode an FSE header as defined in the Zstandard format specification and
2193
/// use the decoded frequencies to initialize a decoding table.
2194
static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
2195
const int max_accuracy_log) {
2196
// "An FSE distribution table describes the probabilities of all symbols
2197
// from 0 to the last present one (included) on a normalized scale of 1 <<
2198
// Accuracy_Log .
2199
//
2200
// It's a bitstream which is read forward, in little-endian fashion. It's
2201
// not necessary to know its exact size, since it will be discovered and
2202
// reported by the decoding process.
2203
if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
2204
ERROR("FSE accuracy too large");
2205
}
2206
2207
// The bitstream starts by reporting on which scale it operates.
2208
// Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
2209
// and match lengths is 9, and for offsets is 8. Higher values are
2210
// considered errors."
2211
const int accuracy_log = 5 + IO_read_bits(in, 4);
2212
if (accuracy_log > max_accuracy_log) {
2213
ERROR("FSE accuracy too large");
2214
}
2215
2216
// "Then follows each symbol value, from 0 to last present one. The number
2217
// of bits used by each field is variable. It depends on :
2218
//
2219
// Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
2220
// and presuming 100 probabilities points have already been distributed, the
2221
// decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
2222
// Therefore, it must read log2sup(156) == 8 bits.
2223
//
2224
// Value decoded : small values use 1 less bit : example : Presuming values
2225
// from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
2226
// in an 8-bits field. They are used this way : first 99 values (hence from
2227
// 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
2228
2229
i32 remaining = 1 << accuracy_log;
2230
i16 frequencies[FSE_MAX_SYMBS];
2231
2232
int symb = 0;
2233
while (remaining > 0 && symb < FSE_MAX_SYMBS) {
2234
// Log of the number of possible values we could read
2235
int bits = highest_set_bit(remaining + 1) + 1;
2236
2237
u16 val = IO_read_bits(in, bits);
2238
2239
// Try to mask out the lower bits to see if it qualifies for the "small
2240
// value" threshold
2241
const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
2242
const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
2243
2244
if ((val & lower_mask) < threshold) {
2245
IO_rewind_bits(in, 1);
2246
val = val & lower_mask;
2247
} else if (val > lower_mask) {
2248
val = val - threshold;
2249
}
2250
2251
// "Probability is obtained from Value decoded by following formula :
2252
// Proba = value - 1"
2253
const i16 proba = (i16)val - 1;
2254
2255
// "It means value 0 becomes negative probability -1. -1 is a special
2256
// probability, which means "less than 1". Its effect on distribution
2257
// table is described in next paragraph. For the purpose of calculating
2258
// cumulated distribution, it counts as one."
2259
remaining -= proba < 0 ? -proba : proba;
2260
2261
frequencies[symb] = proba;
2262
symb++;
2263
2264
// "When a symbol has a probability of zero, it is followed by a 2-bits
2265
// repeat flag. This repeat flag tells how many probabilities of zeroes
2266
// follow the current one. It provides a number ranging from 0 to 3. If
2267
// it is a 3, another 2-bits repeat flag follows, and so on."
2268
if (proba == 0) {
2269
// Read the next two bits to see how many more 0s
2270
int repeat = IO_read_bits(in, 2);
2271
2272
while (1) {
2273
for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
2274
frequencies[symb++] = 0;
2275
}
2276
if (repeat == 3) {
2277
repeat = IO_read_bits(in, 2);
2278
} else {
2279
break;
2280
}
2281
}
2282
}
2283
}
2284
IO_align_stream(in);
2285
2286
// "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
2287
// is complete. If the last symbol makes cumulated total go above 1 <<
2288
// Accuracy_Log, distribution is considered corrupted."
2289
if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
2290
CORRUPTION();
2291
}
2292
2293
// Initialize the decoding table using the determined weights
2294
FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
2295
}
2296
2297
/// Builds a one-entry FSE table for RLE mode. The only state (0) always
/// yields `symb`, reads zero bits and transitions back to state 0.
/// Allocates `symbols`, `num_bits` and `new_state_base`; FSE_free_dtable
/// releases them.
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
    dtable->symbols = malloc(sizeof *dtable->symbols);
    dtable->num_bits = malloc(sizeof *dtable->num_bits);
    dtable->new_state_base = malloc(sizeof *dtable->new_state_base);

    if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
        // Release whatever did get allocated, matching HUF_init_dtable
        free(dtable->symbols);
        free(dtable->num_bits);
        free(dtable->new_state_base);
        BAD_ALLOC();
    }

    // This setup will always have a state of 0, always return symbol `symb`,
    // and never consume any bits
    dtable->symbols[0] = symb;
    dtable->num_bits[0] = 0;
    dtable->new_state_base[0] = 0;
    dtable->accuracy_log = 0;
}
2313
2314
/// Releases the three lookup arrays of an FSE table, then resets the table to
/// all zero bytes so no freed pointer is left behind.
static void FSE_free_dtable(FSE_dtable *const dtable) {
    void *const owned[] = {
        dtable->symbols,
        dtable->num_bits,
        dtable->new_state_base,
    };
    for (size_t k = 0; k < sizeof owned / sizeof owned[0]; k++) {
        free(owned[k]);
    }
    memset(dtable, 0, sizeof *dtable);
}
2320
/******* END FSE PRIMITIVES ***************************************************/
2321
2322