blob: 1fed444d79c3b3043a1417d7597f1fcf2f2eec33 [file] [log] [blame]
Yabin Cuice2385d2024-07-10 15:58:38 -07001/*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "ZstdUtil.h"
18
19#include <android-base/logging.h>
20#include <zstd.h>
21
22namespace simpleperf {
23
24namespace {
25
26class CompressionOutBuffer {
27 public:
28 CompressionOutBuffer(size_t min_free_size)
29 : min_free_size_(min_free_size), buffer_(min_free_size) {}
30
31 const char* DataStart() const { return buffer_.data() + data_pos_; }
32 size_t DataSize() const { return data_size_; }
33 char* FreeStart() { return buffer_.data() + data_pos_ + data_size_; }
34 size_t FreeSize() const { return buffer_.size() - data_pos_ - data_size_; }
35
36 void PrepareForInput() {
37 if (data_pos_ > 0) {
38 if (data_size_ == 0) {
39 data_pos_ = 0;
40 } else {
41 memmove(buffer_.data(), buffer_.data() + data_pos_, data_size_);
42 data_pos_ = 0;
43 }
44 }
45 if (FreeSize() < min_free_size_) {
46 buffer_.resize(buffer_.size() * 2);
47 }
48 }
49
50 void ProduceData(size_t size) {
51 data_size_ += size;
52 CHECK_LE(data_pos_ + data_size_, buffer_.size());
53 }
54
55 void ConsumeData(size_t size) {
56 CHECK_LE(size, data_size_);
57 data_pos_ += size;
58 data_size_ -= size;
59 }
60
61 private:
62 const size_t min_free_size_;
63 std::vector<char> buffer_;
64 size_t data_pos_ = 0;
65 size_t data_size_ = 0;
66};
67
68using ZSTD_CCtx_pointer = std::unique_ptr<ZSTD_CCtx, decltype(&ZSTD_freeCCtx)>;
69
70class ZstdCompressor : public Compressor {
71 public:
72 ZstdCompressor(ZSTD_CCtx_pointer cctx)
73 : cctx_(std::move(cctx)), out_buffer_(ZSTD_CStreamOutSize()) {}
74
75 bool AddInputData(const char* data, size_t size) override {
76 ZSTD_inBuffer input = {data, size, 0};
77 while (input.pos < input.size) {
78 out_buffer_.PrepareForInput();
79 ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
80 size_t remaining = ZSTD_compressStream2(cctx_.get(), &output, &input, ZSTD_e_continue);
81 if (ZSTD_isError(remaining)) {
82 LOG(ERROR) << "ZSTD_compressStream2() failed: " << ZSTD_getErrorName(remaining);
83 return false;
84 }
85 out_buffer_.ProduceData(output.pos);
Yabin Cui09136222024-07-16 10:30:21 -070086 total_output_size_ += output.pos;
Yabin Cuice2385d2024-07-10 15:58:38 -070087 }
88 total_input_size_ += size;
89 return true;
90 }
91
92 bool FlushOutputData() override {
Yabin Cui99f599a2024-07-16 15:21:46 -070093 if (flushed_input_size_ == total_input_size_) {
94 return true;
95 }
96 flushed_input_size_ = total_input_size_;
Yabin Cuice2385d2024-07-10 15:58:38 -070097 ZSTD_inBuffer input = {nullptr, 0, 0};
98 size_t remaining = 0;
99 do {
100 out_buffer_.PrepareForInput();
101 ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
102 remaining = ZSTD_compressStream2(cctx_.get(), &output, &input, ZSTD_e_end);
103 if (ZSTD_isError(remaining)) {
104 LOG(ERROR) << "ZSTD_compressStream2() failed: " << ZSTD_getErrorName(remaining);
105 return false;
106 }
107 out_buffer_.ProduceData(output.pos);
108 total_output_size_ += output.pos;
109 } while (remaining != 0);
110 return true;
111 }
112
113 std::string_view GetOutputData() override {
114 return std::string_view(out_buffer_.DataStart(), out_buffer_.DataSize());
115 }
116
117 void ConsumeOutputData(size_t size) override { out_buffer_.ConsumeData(size); }
118
119 private:
120 ZSTD_CCtx_pointer cctx_;
121 CompressionOutBuffer out_buffer_;
Yabin Cui99f599a2024-07-16 15:21:46 -0700122 uint64_t flushed_input_size_ = 0;
Yabin Cuice2385d2024-07-10 15:58:38 -0700123};
124
125using ZSTD_DCtx_pointer = std::unique_ptr<ZSTD_DCtx, decltype(&ZSTD_freeDCtx)>;
126
127class ZstdDecompressor : public Decompressor {
128 public:
129 ZstdDecompressor(ZSTD_DCtx_pointer dctx)
130 : dctx_(std::move(dctx)), out_buffer_(ZSTD_DStreamOutSize()) {}
131
132 bool AddInputData(const char* data, size_t size) override {
133 ZSTD_inBuffer input = {data, size, 0};
134 while (input.pos < input.size) {
135 out_buffer_.PrepareForInput();
136 ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
137 size_t remaining = ZSTD_decompressStream(dctx_.get(), &output, &input);
138 if (ZSTD_isError(remaining)) {
139 LOG(ERROR) << "ZSTD_decompressStream() failed: " << ZSTD_getErrorName(remaining);
140 return false;
141 }
142 out_buffer_.ProduceData(output.pos);
143 }
144 return true;
145 }
146
147 std::string_view GetOutputData() override {
148 return std::string_view(out_buffer_.DataStart(), out_buffer_.DataSize());
149 }
150
151 void ConsumeOutputData(size_t size) override { out_buffer_.ConsumeData(size); }
152
153 private:
154 ZSTD_DCtx_pointer dctx_;
155 CompressionOutBuffer out_buffer_;
156};
157
158} // namespace
159
160Compressor::~Compressor() {}
161
162Decompressor::~Decompressor() {}
163
164std::unique_ptr<Compressor> CreateZstdCompressor(size_t compression_level) {
165 ZSTD_CCtx_pointer cctx(ZSTD_createCCtx(), ZSTD_freeCCtx);
166 if (!cctx) {
167 LOG(ERROR) << "ZSTD_createCCtx() failed";
168 return nullptr;
169 }
170 size_t err = ZSTD_CCtx_setParameter(cctx.get(), ZSTD_c_compressionLevel, compression_level);
171 if (ZSTD_isError(err)) {
172 LOG(ERROR) << "failed to set compression level: " << ZSTD_getErrorName(err);
173 return nullptr;
174 }
175 return std::unique_ptr<Compressor>(new ZstdCompressor(std::move(cctx)));
176}
177
178std::unique_ptr<Decompressor> CreateZstdDecompressor() {
179 ZSTD_DCtx_pointer dctx(ZSTD_createDCtx(), ZSTD_freeDCtx);
180 if (!dctx) {
181 LOG(ERROR) << "ZSTD_createDCtx() failed";
182 return nullptr;
183 }
184 return std::unique_ptr<Decompressor>(new ZstdDecompressor(std::move(dctx)));
185}
186
187} // namespace simpleperf