#pragma once
#include <cassert>
#include <cstdint>
#include <fstream>
#include <sstream>
#include <zlib.h>
#include <memory>
#include <iostream>
#include "strict_fstream.hpp"
#if defined(__GNUC__) && !defined(__clang__)
#if (__GNUC__ > 5) || (__GNUC__ == 5 && __GNUC_MINOR__>0)
#define CAN_MOVE_IOSTREAM
#endif
#else
#define CAN_MOVE_IOSTREAM
#endif
namespace zstr
{
static const std::size_t default_buff_size = static_cast<std::size_t>(1 << 20);
class Exception
: public std::ios_base::failure
{
public:
static std::string error_to_message(z_stream * zstrm_p, int ret)
{
std::string msg = "zlib: ";
switch (ret)
{
case Z_STREAM_ERROR:
msg += "Z_STREAM_ERROR: ";
break;
case Z_DATA_ERROR:
msg += "Z_DATA_ERROR: ";
break;
case Z_MEM_ERROR:
msg += "Z_MEM_ERROR: ";
break;
case Z_VERSION_ERROR:
msg += "Z_VERSION_ERROR: ";
break;
case Z_BUF_ERROR:
msg += "Z_BUF_ERROR: ";
break;
default:
std::ostringstream oss;
oss << ret;
msg += "[" + oss.str() + "]: ";
break;
}
if (zstrm_p->msg) {
msg += zstrm_p->msg;
}
msg += " ("
"next_in: " +
std::to_string(uintptr_t(zstrm_p->next_in)) +
", avail_in: " +
std::to_string(uintptr_t(zstrm_p->avail_in)) +
", next_out: " +
std::to_string(uintptr_t(zstrm_p->next_out)) +
", avail_out: " +
std::to_string(uintptr_t(zstrm_p->avail_out)) +
")";
return msg;
}
Exception(z_stream * zstrm_p, int ret)
: std::ios_base::failure(error_to_message(zstrm_p, ret))
{
}
};
namespace detail
{
class z_stream_wrapper
: public z_stream
{
public:
z_stream_wrapper(bool _is_input, int _level, int _window_bits)
: is_input(_is_input)
{
this->zalloc = nullptr;
this->zfree = nullptr;
this->opaque = nullptr;
int ret;
if (is_input)
{
this->avail_in = 0;
this->next_in = nullptr;
ret = inflateInit2(this, _window_bits ? _window_bits : 15+32);
}
else
{
ret = deflateInit2(this, _level, Z_DEFLATED, _window_bits ? _window_bits : 15+16, 8, Z_DEFAULT_STRATEGY);
}
if (ret != Z_OK) throw Exception(this, ret);
}
~z_stream_wrapper()
{
if (is_input)
{
inflateEnd(this);
}
else
{
deflateEnd(this);
}
}
private:
bool is_input;
};
}
class istreambuf
: public std::streambuf
{
public:
istreambuf(std::streambuf * _sbuf_p,
std::size_t _buff_size = default_buff_size, bool _auto_detect = true, int _window_bits = 0)
: sbuf_p(_sbuf_p),
in_buff(),
in_buff_start(nullptr),
in_buff_end(nullptr),
out_buff(),
zstrm_p(nullptr),
buff_size(_buff_size),
auto_detect(_auto_detect),
auto_detect_run(false),
is_text(false),
window_bits(_window_bits)
{
assert(sbuf_p);
in_buff = std::unique_ptr<char[]>(new char[buff_size]);
in_buff_start = in_buff.get();
in_buff_end = in_buff.get();
out_buff = std::unique_ptr<char[]>(new char[buff_size]);
setg(out_buff.get(), out_buff.get(), out_buff.get());
}
istreambuf(const istreambuf &) = delete;
istreambuf & operator = (const istreambuf &) = delete;
pos_type seekoff(off_type off, std::ios_base::seekdir dir,
std::ios_base::openmode which) override
{
if (off != 0 || dir != std::ios_base::cur) {
return std::streambuf::seekoff(off, dir, which);
}
if (!zstrm_p) {
return 0;
}
return static_cast<long int>(zstrm_p->total_out - static_cast<uLong>(in_avail()));
}
std::streambuf::int_type underflow() override
{
if (this->gptr() == this->egptr())
{
char * out_buff_free_start = out_buff.get();
int tries = 0;
do
{
if (++tries > 1000) {
throw std::ios_base::failure("Failed to fill buffer after 1000 tries");
}
if (in_buff_start == in_buff_end)
{
in_buff_start = in_buff.get();
std::streamsize sz = sbuf_p->sgetn(in_buff.get(), static_cast<std::streamsize>(buff_size));
in_buff_end = in_buff_start + sz;
if (in_buff_end == in_buff_start) break;
}
if (auto_detect && ! auto_detect_run)
{
auto_detect_run = true;
unsigned char b0 = *reinterpret_cast< unsigned char * >(in_buff_start);
unsigned char b1 = *reinterpret_cast< unsigned char * >(in_buff_start + 1);
is_text = ! (in_buff_start + 2 <= in_buff_end
&& ((b0 == 0x1F && b1 == 0x8B)
|| (b0 == 0x78 && (b1 == 0x01
|| b1 == 0x9C
|| b1 == 0xDA))));
}
if (is_text)
{
assert(in_buff_start == in_buff.get());
std::swap(in_buff, out_buff);
out_buff_free_start = in_buff_end;
in_buff_start = in_buff.get();
in_buff_end = in_buff.get();
}
else
{
if (! zstrm_p) zstrm_p = std::unique_ptr<detail::z_stream_wrapper>(new detail::z_stream_wrapper(true, Z_DEFAULT_COMPRESSION, window_bits));
zstrm_p->next_in = reinterpret_cast< decltype(zstrm_p->next_in) >(in_buff_start);
zstrm_p->avail_in = uint32_t(in_buff_end - in_buff_start);
zstrm_p->next_out = reinterpret_cast< decltype(zstrm_p->next_out) >(out_buff_free_start);
zstrm_p->avail_out = uint32_t((out_buff.get() + buff_size) - out_buff_free_start);
int ret = inflate(zstrm_p.get(), Z_NO_FLUSH);
if (ret != Z_OK && ret != Z_STREAM_END) throw Exception(zstrm_p.get(), ret);
in_buff_start = reinterpret_cast< decltype(in_buff_start) >(zstrm_p->next_in);
in_buff_end = in_buff_start + zstrm_p->avail_in;
out_buff_free_start = reinterpret_cast< decltype(out_buff_free_start) >(zstrm_p->next_out);
assert(out_buff_free_start + zstrm_p->avail_out == out_buff.get() + buff_size);
if (ret == Z_STREAM_END) {
zstrm_p.reset();
}
}
} while (out_buff_free_start == out_buff.get());
this->setg(out_buff.get(), out_buff.get(), out_buff_free_start);
}
return this->gptr() == this->egptr()
? traits_type::eof()
: traits_type::to_int_type(*this->gptr());
}
private:
std::streambuf * sbuf_p;
std::unique_ptr<char[]> in_buff;
char * in_buff_start;
char * in_buff_end;
std::unique_ptr<char[]> out_buff;
std::unique_ptr<detail::z_stream_wrapper> zstrm_p;
std::size_t buff_size;
bool auto_detect;
bool auto_detect_run;
bool is_text;
int window_bits;
};
class ostreambuf
: public std::streambuf
{
public:
ostreambuf(std::streambuf * _sbuf_p,
std::size_t _buff_size = default_buff_size, int _level = Z_DEFAULT_COMPRESSION, int _window_bits = 0)
: sbuf_p(_sbuf_p),
in_buff(),
out_buff(),
zstrm_p(new detail::z_stream_wrapper(false, _level, _window_bits)),
buff_size(_buff_size)
{
assert(sbuf_p);
in_buff = std::unique_ptr<char[]>(new char[buff_size]);
out_buff = std::unique_ptr<char[]>(new char[buff_size]);
setp(in_buff.get(), in_buff.get() + buff_size);
}
ostreambuf(const ostreambuf &) = delete;
ostreambuf & operator = (const ostreambuf &) = delete;
int deflate_loop(int flush)
{
while (true)
{
zstrm_p->next_out = reinterpret_cast< decltype(zstrm_p->next_out) >(out_buff.get());
zstrm_p->avail_out = uint32_t(buff_size);
int ret = deflate(zstrm_p.get(), flush);
if (ret != Z_OK && ret != Z_STREAM_END && ret != Z_BUF_ERROR) {
failed = true;
throw Exception(zstrm_p.get(), ret);
}
std::streamsize sz = sbuf_p->sputn(out_buff.get(), reinterpret_cast< decltype(out_buff.get()) >(zstrm_p->next_out) - out_buff.get());
if (sz != reinterpret_cast< decltype(out_buff.get()) >(zstrm_p->next_out) - out_buff.get())
{
return -1;
}
if (ret == Z_STREAM_END || ret == Z_BUF_ERROR || sz == 0)
{
break;
}
}
return 0;
}
virtual ~ostreambuf()
{
if (!failed) try {
sync();
} catch (...) {}
}
std::streambuf::int_type overflow(std::streambuf::int_type c = traits_type::eof()) override
{
zstrm_p->next_in = reinterpret_cast< decltype(zstrm_p->next_in) >(pbase());
zstrm_p->avail_in = uint32_t(pptr() - pbase());
while (zstrm_p->avail_in > 0)
{
int r = deflate_loop(Z_NO_FLUSH);
if (r != 0)
{
setp(nullptr, nullptr);
return traits_type::eof();
}
}
setp(in_buff.get(), in_buff.get() + buff_size);
return traits_type::eq_int_type(c, traits_type::eof()) ? traits_type::eof() : sputc(char_type(c));
}
int sync() override
{
overflow();
if (! pptr()) return -1;
zstrm_p->next_in = nullptr;
zstrm_p->avail_in = 0;
if (deflate_loop(Z_FINISH) != 0) return -1;
deflateReset(zstrm_p.get());
return 0;
}
private:
std::streambuf * sbuf_p = nullptr;
std::unique_ptr<char[]> in_buff;
std::unique_ptr<char[]> out_buff;
std::unique_ptr<detail::z_stream_wrapper> zstrm_p;
std::size_t buff_size;
bool failed = false;
};
class istream
: public std::istream
{
public:
istream(std::istream & is,
std::size_t _buff_size = default_buff_size, bool _auto_detect = true, int _window_bits = 0)
: std::istream(new istreambuf(is.rdbuf(), _buff_size, _auto_detect, _window_bits))
{
exceptions(std::ios_base::badbit);
}
explicit istream(std::streambuf * sbuf_p)
: std::istream(new istreambuf(sbuf_p))
{
exceptions(std::ios_base::badbit);
}
virtual ~istream()
{
delete rdbuf();
}
};
class ostream
: public std::ostream
{
public:
ostream(std::ostream & os,
std::size_t _buff_size = default_buff_size, int _level = Z_DEFAULT_COMPRESSION, int _window_bits = 0)
: std::ostream(new ostreambuf(os.rdbuf(), _buff_size, _level, _window_bits))
{
exceptions(std::ios_base::badbit);
}
explicit ostream(std::streambuf * sbuf_p)
: std::ostream(new ostreambuf(sbuf_p))
{
exceptions(std::ios_base::badbit);
}
virtual ~ostream()
{
delete rdbuf();
}
};
namespace detail
{
template < typename FStream_Type >
struct strict_fstream_holder
{
strict_fstream_holder(const std::string& filename, std::ios_base::openmode mode = std::ios_base::in)
: _fs(filename, mode)
{}
strict_fstream_holder() = default;
FStream_Type _fs {};
};
}
class ifstream
: private detail::strict_fstream_holder< strict_fstream::ifstream >,
public std::istream
{
public:
explicit ifstream(const std::string filename, std::ios_base::openmode mode = std::ios_base::in, size_t buff_size = default_buff_size)
: detail::strict_fstream_holder< strict_fstream::ifstream >(filename, mode),
std::istream(new istreambuf(_fs.rdbuf(), buff_size))
{
exceptions(std::ios_base::badbit);
}
explicit ifstream(): detail::strict_fstream_holder< strict_fstream::ifstream >(), std::istream(new istreambuf(_fs.rdbuf())){}
void close() {
_fs.close();
}
#ifdef CAN_MOVE_IOSTREAM
void open(const std::string filename, std::ios_base::openmode mode = std::ios_base::in) {
_fs.open(filename, mode);
std::istream::operator=(std::istream(new istreambuf(_fs.rdbuf())));
}
#endif
bool is_open() const {
return _fs.is_open();
}
virtual ~ifstream()
{
if (_fs.is_open()) close();
if (rdbuf()) delete rdbuf();
}
std::streampos compressed_tellg()
{
return _fs.tellg();
}
};
class ofstream
: private detail::strict_fstream_holder< strict_fstream::ofstream >,
public std::ostream
{
public:
explicit ofstream(const std::string filename, std::ios_base::openmode mode = std::ios_base::out,
int level = Z_DEFAULT_COMPRESSION, size_t buff_size = default_buff_size)
: detail::strict_fstream_holder< strict_fstream::ofstream >(filename, mode | std::ios_base::binary),
std::ostream(new ostreambuf(_fs.rdbuf(), buff_size, level))
{
exceptions(std::ios_base::badbit);
}
explicit ofstream(): detail::strict_fstream_holder< strict_fstream::ofstream >(), std::ostream(new ostreambuf(_fs.rdbuf())){}
void close() {
std::ostream::flush();
_fs.close();
}
#ifdef CAN_MOVE_IOSTREAM
void open(const std::string filename, std::ios_base::openmode mode = std::ios_base::out, int level = Z_DEFAULT_COMPRESSION) {
flush();
_fs.open(filename, mode | std::ios_base::binary);
std::ostream::operator=(std::ostream(new ostreambuf(_fs.rdbuf(), default_buff_size, level)));
}
#endif
bool is_open() const {
return _fs.is_open();
}
ofstream& flush() {
std::ostream::flush();
_fs.flush();
return *this;
}
virtual ~ofstream()
{
if (_fs.is_open()) close();
if (rdbuf()) delete rdbuf();
}
std::streampos compressed_tellp()
{
return _fs.tellp();
}
};
}