diff --git a/internal/util/compressor/compressor.go b/internal/util/compressor/compressor.go index dc424bd8c5..91004e845d 100644 --- a/internal/util/compressor/compressor.go +++ b/internal/util/compressor/compressor.go @@ -6,25 +6,29 @@ import ( "github.com/klauspost/compress/zstd" ) -type CompressType int16 +type CompressType string const ( - Zstd CompressType = iota + 1 + CompressTypeZstd CompressType = "zstd" - DefaultCompressAlgorithm CompressType = Zstd + DefaultCompressAlgorithm CompressType = CompressTypeZstd ) type Compressor interface { Compress(in io.Reader) error + CompressBytes(src, dst []byte) []byte ResetWriter(out io.Writer) // Flush() error Close() error + GetType() CompressType } type Decompressor interface { Decompress(out io.Writer) error + DecompressBytes(src, dst []byte) ([]byte, error) ResetReader(in io.Reader) Close() + GetType() CompressType } var ( @@ -36,6 +40,7 @@ type ZstdCompressor struct { encoder *zstd.Encoder } +// For compressing small blocks, pass nil to the `out` parameter func NewZstdCompressor(out io.Writer, opts ...zstd.EOption) (*ZstdCompressor, error) { encoder, err := zstd.NewWriter(out, opts...) if err != nil { @@ -45,6 +50,7 @@ func NewZstdCompressor(out io.Writer, opts ...zstd.EOption) (*ZstdCompressor, er return &ZstdCompressor{encoder}, nil } +// Use case: compress stream // Call Close() to make sure the data is flushed to the underlying writer // after the last Compress() call func (c *ZstdCompressor) Compress(in io.Reader) error { @@ -57,6 +63,14 @@ func (c *ZstdCompressor) Compress(in io.Reader) error { return nil } +// Use case: compress small blocks +// This compresses the src bytes and appends it to the dst bytes, then return the result +// This can be called concurrently +func (c *ZstdCompressor) CompressBytes(src []byte, dst []byte) []byte { + return c.encoder.EncodeAll(src, dst) +} + +// Reset the writer to reuse the compressor func (c *ZstdCompressor) ResetWriter(out io.Writer) { c.encoder.Reset(out) } @@ -76,10 +90,15 @@ func (c *ZstdCompressor) Close() error { return c.encoder.Close() } +func (c *ZstdCompressor) GetType() CompressType { + return CompressTypeZstd +} + type ZstdDecompressor struct { decoder *zstd.Decoder } +// For compressing small blocks, pass nil to the `in` parameter func NewZstdDecompressor(in io.Reader, opts ...zstd.DOption) (*ZstdDecompressor, error) { decoder, err := zstd.NewReader(in, opts...) if err != nil { @@ -89,6 +108,8 @@ func NewZstdDecompressor(in io.Reader, opts ...zstd.DOption) (*ZstdDecompressor, return &ZstdDecompressor{decoder}, nil } +// Usa case: decompress stream +// Write the decompressed data into `out` func (dec *ZstdDecompressor) Decompress(out io.Writer) error { _, err := io.Copy(out, dec.decoder) if err != nil { @@ -99,6 +120,14 @@ func (dec *ZstdDecompressor) Decompress(out io.Writer) error { return nil } +// Use case: decompress small blocks +// This decompresses the src bytes and appends it to the dst bytes, then return the result +// This can be called concurrently +func (dec *ZstdDecompressor) DecompressBytes(src []byte, dst []byte) ([]byte, error) { + return dec.decoder.DecodeAll(src, dst) +} + +// Reset the reader to reuse the decompressor func (dec *ZstdDecompressor) ResetReader(in io.Reader) { dec.decoder.Reset(in) } @@ -108,7 +137,15 @@ func (dec *ZstdDecompressor) Close() { dec.decoder.Close() } +func (dec *ZstdDecompressor) GetType() CompressType { + return CompressTypeZstd +} + // Global methods + +// Usa case: compress stream, large object only once +// This can be called concurrently +// Try ZstdCompressor for better efficiency if you need compress mutiple streams one by one func ZstdCompress(in io.Reader, out io.Writer, opts ...zstd.EOption) error { enc, err := NewZstdCompressor(out, opts...) if err != nil { @@ -123,6 +160,9 @@ func ZstdCompress(in io.Reader, out io.Writer, opts ...zstd.EOption) error { return enc.Close() } +// Use case: decompress stream, large object only once +// This can be called concurrently +// Try ZstdDecompressor for better efficiency if you need decompress mutiple streams one by one func ZstdDecompress(in io.Reader, out io.Writer, opts ...zstd.DOption) error { dec, err := NewZstdDecompressor(in, opts...) if err != nil { @@ -136,3 +176,20 @@ func ZstdDecompress(in io.Reader, out io.Writer, opts ...zstd.DOption) error { return nil } + +var ( + globalZstdCompressor, _ = zstd.NewWriter(nil) + globalZstdDecompressor, _ = zstd.NewReader(nil) +) + +// Use case: compress small blocks +// This can be called concurrently +func ZstdCompressBytes(src, dst []byte) []byte { + return globalZstdCompressor.EncodeAll(src, dst) +} + +// Use case: decompress small blocks +// This can be called concurrently +func ZstdDecompressBytes(src, dst []byte) ([]byte, error) { + return globalZstdDecompressor.DecodeAll(src, dst) +} diff --git a/internal/util/compressor/compressor_test.go b/internal/util/compressor/compressor_test.go index edc66ab404..cd4386e629 100644 --- a/internal/util/compressor/compressor_test.go +++ b/internal/util/compressor/compressor_test.go @@ -30,13 +30,24 @@ func TestZstdCompress(t *testing.T) { enc.ResetWriter(compressed) testCompress(t, data+": reuse", enc, compressed, origin) + + // Test type + dec, err := NewZstdDecompressor(nil) + assert.NoError(t, err) + assert.Equal(t, enc.GetType(), CompressTypeZstd) + assert.Equal(t, dec.GetType(), CompressTypeZstd) } func testCompress(t *testing.T, data string, enc Compressor, compressed, origin *bytes.Buffer) { + compressedBytes := make([]byte, 0) + originBytes := make([]byte, 0) + err := enc.Compress(strings.NewReader(data)) assert.NoError(t, err) err = enc.Close() assert.NoError(t, err) + compressedBytes = enc.CompressBytes([]byte(data), compressedBytes) + assert.Equal(t, compressed.Bytes(), compressedBytes) // Close() method should satisfy idempotence err = enc.Close() @@ -46,6 +57,9 @@ func testCompress(t *testing.T, data string, enc Compressor, compressed, origin assert.NoError(t, err) err = dec.Decompress(origin) assert.NoError(t, err) + originBytes, err = dec.DecompressBytes(compressedBytes, originBytes) + assert.NoError(t, err) + assert.Equal(t, origin.Bytes(), originBytes) assert.Equal(t, data, origin.String()) @@ -70,21 +84,30 @@ func testCompress(t *testing.T, data string, enc Compressor, compressed, origin func TestGlobalMethods(t *testing.T) { data := "hello zstd algorithm!" compressed := new(bytes.Buffer) + compressedBytes := make([]byte, 0) origin := new(bytes.Buffer) + originBytes := make([]byte, 0) err := ZstdCompress(strings.NewReader(data), compressed) assert.NoError(t, err) + compressedBytes = ZstdCompressBytes([]byte(data), compressedBytes) + assert.Equal(t, compressed.Bytes(), compressedBytes) + err = ZstdDecompress(compressed, origin) assert.NoError(t, err) + originBytes, err = ZstdDecompressBytes(compressedBytes, originBytes) + assert.NoError(t, err) + assert.Equal(t, origin.Bytes(), originBytes) + assert.Equal(t, data, origin.String()) // Mock error reader/writer errReader := &mock.ErrReader{Err: io.ErrUnexpectedEOF} errWriter := &mock.ErrWriter{Err: io.ErrShortWrite} - compressedBytes := compressed.Bytes() + compressedBytes = compressed.Bytes() compressed = bytes.NewBuffer(compressedBytes) // The old compressed buffer is closed err = ZstdCompress(errReader, compressed) assert.ErrorIs(t, err, errReader.Err) @@ -116,16 +139,23 @@ func TestCurrencyGlobalMethods(t *testing.T) { go func(idx int) { defer wg.Done() - buf := new(bytes.Buffer) + compressed := new(bytes.Buffer) + compressedBytes := make([]byte, 0) origin := new(bytes.Buffer) + originBytes := make([]byte, 0) data := prefix + fmt.Sprintf(": %d-th goroutine", idx) - err := ZstdCompress(strings.NewReader(data), buf, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(idx))) + err := ZstdCompress(strings.NewReader(data), compressed, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(idx))) assert.NoError(t, err) + compressedBytes = ZstdCompressBytes([]byte(data), compressedBytes) + assert.Equal(t, compressed.Bytes(), compressedBytes) - err = ZstdDecompress(buf, origin) + err = ZstdDecompress(compressed, origin) assert.NoError(t, err) + originBytes, err = ZstdDecompressBytes(compressedBytes, originBytes) + assert.NoError(t, err) + assert.Equal(t, origin.Bytes(), originBytes) assert.Equal(t, data, origin.String()) }(i)