smb: client: compress: LZ77 code improvements cleanup
authorEnzo Matsumiya <ematsumiya@suse.de>
Fri, 6 Sep 2024 17:41:50 +0000 (14:41 -0300)
committerSteve French <stfrench@microsoft.com>
Sun, 15 Sep 2024 15:42:45 +0000 (10:42 -0500)
- Check data compressibility with some heuristics (copied from
  btrfs):
  - should_compress() final decision is is_compressible(data)

- Cleanup compress/lz77.h leaving only lz77_compress() exposed:
  - Move parts to compress/lz77.c, while removing the rest of it
    because they were either unused, used only once, were
    implemented wrong (thanks to David Howells for the help)

- Updated the compression parameters (still compatible with
  Windows implementation) trading off ~20% compression ratio
  for ~40% performance:
  - min match len: 3 -> 4
  - max distance: 8KiB -> 1KiB
  - hash table type: u32 * -> u64 *

Known bugs:
This implementation currently works fine in general, but breaks with
some payloads used during testing.  Investigation ongoing, to be
fixed in a next commit.

Signed-off-by: Enzo Matsumiya <ematsumiya@suse.de>
Co-developed-by: David Howells <dhowells@redhat.com>
Signed-off-by: David Howells <dhowells@redhat.com>
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/smb/client/compress.c
fs/smb/client/compress.h
fs/smb/client/compress/lz77.c
fs/smb/client/compress/lz77.h

index 4efbccbd40bfb352c557d58716c852848b0af69e..65d3d219e8bc54b672750caacee4e475b9c03ab5 100644 (file)
@@ -15,6 +15,7 @@
 #include <linux/slab.h>
 #include <linux/kernel.h>
 #include <linux/uio.h>
+#include <linux/sort.h>
 
 #include "cifsglob.h"
 #include "../common/smb2pdu.h"
 #include "compress/lz77.h"
 #include "compress.h"
 
-int smb_compress(void *buf, const void *data, size_t *len)
+/*
+ * The heuristic_*() functions below try to determine data compressibility.
+ *
+ * Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing
+ * unused parts.
+ *
+ * Read that file for better and more detailed explanation of the calculations.
+ *
+ * The algorithms are ran in a collected sample of the input (uncompressed) data.
+ * The sample is formed of 2K reads in PAGE_SIZE intervals, with a maximum size of 4M.
+ *
+ * Parsing the sample goes from "low-hanging fruits" (fastest algorithms, likely compressible)
+ * to "need more analysis" (likely uncompressible).
+ */
+
+struct bucket {
+       unsigned int count;
+};
+
+/**
+ * calc_shannon_entropy() - Compute Shannon entropy of the sampled data.
+ * @bkt:       Bytes counts of the sample.
+ * @slen:      Size of the sample.
+ *
+ * Return: true if the level (percentage of number of bits that would be required to
+ *        compress the data) is below the minimum threshold.
+ *
+ * Note:
+ * There _is_ an entropy level here that's > 65 (minimum threshold) that would indicate a
+ * possibility of compression, but compressing, or even further analysing, it would waste so much
+ * resources that it's simply not worth it.
+ *
+ * Also Shannon entropy is the last computed heuristic; if we got this far and ended up
+ * with uncertainty, just stay on the safe side and call it uncompressible.
+ */
+static bool calc_shannon_entropy(struct bucket *bkt, size_t slen)
+{
+       const size_t threshold = 65, max_entropy = 8 * ilog2(16);
+       size_t i, p, p2, len, sum = 0;
+
+#define pow4(n) (n * n * n * n)
+       len = ilog2(pow4(slen));
+
+       for (i = 0; i < 256 && bkt[i].count > 0; i++) {
+               p = bkt[i].count;
+               p2 = ilog2(pow4(p));
+               sum += p * (len - p2);
+       }
+
+       sum /= slen;
+
+       return ((sum * 100 / max_entropy) <= threshold);
+}
+
+/**
+ * calc_byte_distribution() - Compute byte distribution on the sampled data.
+ * @bkt:       Byte counts of the sample.
+ * @slen:      Size of the sample.
+ *
+ * Return:
+ * 1:  High probability (normal (Gaussian) distribution) of the data being compressible.
+ * 0:  A "hard no" for compression -- either a computed uniform distribution of the bytes (e.g.
+ *     random or encrypted data), or calc_shannon_entropy() returned false (see above).
+ * 2:  When computed byte distribution resulted in "low > n < high" grounds.
+ *     calc_shannon_entropy() should be used for a final decision.
+ */
+static int calc_byte_distribution(struct bucket *bkt, size_t slen)
 {
-       struct smb2_compression_hdr *hdr;
-       size_t buf_len, data_len;
+       const size_t low = 64, high = 200, threshold = slen * 90 / 100;
+       size_t sum = 0;
+       int i;
+
+       for (i = 0; i < low; i++)
+               sum += bkt[i].count;
+
+       if (sum > threshold)
+               return i;
+
+       for (; i < high && bkt[i].count > 0; i++) {
+               sum += bkt[i].count;
+               if (sum > threshold)
+                       break;
+       }
+
+       if (i <= low)
+               return 1;
+
+       if (i >= high)
+               return 0;
+
+       return 2;
+}
+
+static bool check_ascii_bytes(const struct bucket *bkt)
+{
+       const size_t threshold = 64;
+       size_t count = 0;
+       int i;
+
+       for (i = 0; i < threshold; i++)
+               if (bkt[i].count > 0)
+                       count++;
+
+       for (; i < 256; i++) {
+               if (bkt[i].count > 0) {
+                       count++;
+                       if (count > threshold)
+                               break;
+               }
+       }
+
+       return (count < threshold);
+}
+
+static bool check_repeated_data(const u8 *sample, size_t len)
+{
+       size_t s = len / 2;
+
+       return (!memcmp(&sample[0], &sample[s], s));
+}
+
+static int cmp_bkt(const void *_a, const void *_b)
+{
+       const struct bucket *a = _a, *b = _b;
+
+       /* Reverse sort. */
+       if (a->count > b->count)
+               return -1;
+
+       return 1;
+}
+
+/*
+ * TODO:
+ * Support other iter types, if required.
+ * Only ITER_XARRAY is supported for now.
+ */
+static int collect_sample(const struct iov_iter *iter, ssize_t max, u8 *sample)
+{
+       struct folio *folios[16], *folio;
+       unsigned int nr, i, j, npages;
+       loff_t start = iter->xarray_start + iter->iov_offset;
+       pgoff_t last, index = start / PAGE_SIZE;
+       size_t len, off, foff;
+       ssize_t ret = 0;
+       void *p;
+       int s = 0;
+
+       last = (start + max - 1) / PAGE_SIZE;
+       do {
+               nr = xa_extract(iter->xarray, (void **)folios, index, last, ARRAY_SIZE(folios),
+                               XA_PRESENT);
+               if (nr == 0)
+                       return -EIO;
+
+               for (i = 0; i < nr; i++) {
+                       folio = folios[i];
+                       npages = folio_nr_pages(folio);
+                       foff = start - folio_pos(folio);
+                       off = foff % PAGE_SIZE;
+
+                       for (j = foff / PAGE_SIZE; j < npages; j++) {
+                               size_t len2;
+
+                               len = min_t(size_t, max, PAGE_SIZE - off);
+                               len2 = min_t(size_t, len, SZ_2K);
+
+                               p = kmap_local_page(folio_page(folio, j));
+                               memcpy(&sample[s], p, len2);
+                               kunmap_local(p);
+
+                               if (ret < 0)
+                                       return ret;
+
+                               s += len2;
+
+                               if (len2 < SZ_2K || s >= max - SZ_2K)
+                                       return s;
+
+                               max -= len;
+                               if (max <= 0)
+                                       return s;
+
+                               start += len;
+                               off = 0;
+                               index++;
+                       }
+               }
+       } while (nr == ARRAY_SIZE(folios));
+
+       return s;
+}
+
+/**
+ * is_compressible() - Determines if a chunk of data is compressible.
+ * @data: Iterator containing uncompressed data.
+ *
+ * Return:
+ * 0:          @data is not compressible
+ * 1:          @data is compressible
+ * -ENOMEM:    failed to allocate memory for sample buffer
+ *
+ * Tests shows that this function is quite reliable in predicting data compressibility,
+ * matching close to 1:1 with the behaviour of LZ77 compression success and failures.
+ */
+static int is_compressible(const struct iov_iter *data)
+{
+       const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M;
+       struct bucket *bkt;
+       int i = 0, ret = 0;
+       size_t len;
+       u8 *sample;
+
+       len = iov_iter_count(data);
+       if (len < read_size)
+               return 0;
+
+       if (len - read_size > max)
+               len = max;
+
+       sample = kvzalloc(len, GFP_KERNEL);
+       if (!sample)
+               return -ENOMEM;
+
+       /* Sample 2K bytes per page of the uncompressed data. */
+       ret = collect_sample(data, len, sample);
+       if (ret < 0)
+               goto out;
+
+       len = ret;
+       ret = 1;
+
+       if (check_repeated_data(sample, len))
+               goto out;
+
+       bkt = kcalloc(bkt_size, sizeof(*bkt), GFP_KERNEL);
+       if (!bkt) {
+               kvfree(sample);
+               return -ENOMEM;
+       }
+
+       for (i = 0; i < len; i++)
+               bkt[sample[i]].count++;
+
+       if (check_ascii_bytes(bkt))
+               goto out;
+
+       /* Sort in descending order */
+       sort(bkt, bkt_size, sizeof(*bkt), cmp_bkt, NULL);
+
+       ret = calc_byte_distribution(bkt, len);
+       if (ret != 2)
+               goto out;
+
+       ret = calc_shannon_entropy(bkt, len);
+out:
+       kvfree(sample);
+       kfree(bkt);
+
+       WARN(ret < 0, "%s: ret=%d\n", __func__, ret);
+
+       return !!ret;
+}
+
+bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
+{
+       const struct smb2_hdr *shdr = rq->rq_iov->iov_base;
+
+       if (unlikely(!tcon || !tcon->ses || !tcon->ses->server))
+               return false;
+
+       if (!tcon->ses->server->compression.enabled)
+               return false;
+
+       if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
+               return false;
+
+       if (shdr->Command == SMB2_WRITE) {
+               const struct smb2_write_req *wreq = rq->rq_iov->iov_base;
+
+               if (wreq->Length < SMB_COMPRESS_MIN_LEN)
+                       return false;
+
+               return is_compressible(&rq->rq_iter);
+       }
+
+       return (shdr->Command == SMB2_READ);
+}
+
+int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn)
+{
+       struct iov_iter iter;
+       u32 slen, dlen;
+       void *src, *dst;
        int ret;
 
-       buf_len = sizeof(struct smb2_write_req);
-       data_len = *len;
-       *len = 0;
-
-       hdr = buf;
-       hdr->ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
-       hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len);
-       hdr->Offset = cpu_to_le32(buf_len);
-       hdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
-       hdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77;
-
-       /* XXX: add other algs here as they're implemented */
-       ret = lz77_compress(data, data_len, buf + SMB_COMPRESS_HDR_LEN + buf_len, &data_len);
-       if (!ret)
-               *len = SMB_COMPRESS_HDR_LEN + buf_len + data_len;
+       if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base)
+               return -EINVAL;
+
+       if (rq->rq_iov->iov_len != sizeof(struct smb2_write_req))
+               return -EINVAL;
+
+       slen = iov_iter_count(&rq->rq_iter);
+       src = kvzalloc(slen, GFP_KERNEL);
+       if (!src) {
+               ret = -ENOMEM;
+               goto err_free;
+       }
+
+       /* Keep the original iter intact. */
+       iter = rq->rq_iter;
+
+       if (!copy_from_iter_full(src, slen, &iter)) {
+               ret = -EIO;
+               goto err_free;
+       }
+
+       /*
+        * This is just overprovisioning, as the algorithm will error out if @dst reaches 7/8
+        * of @slen.
+        */
+       dlen = slen;
+       dst = kvzalloc(dlen, GFP_KERNEL);
+       if (!dst) {
+               ret = -ENOMEM;
+               goto err_free;
+       }
+
+       ret = lz77_compress(src, slen, dst, &dlen);
+       if (!ret) {
+               struct smb2_compression_hdr hdr = { 0 };
+               struct smb_rqst comp_rq = { .rq_nvec = 3, };
+               struct kvec iov[3];
+
+               hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
+               hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen);
+               hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77;
+               hdr.Flags = SMB2_COMPRESSION_FLAG_NONE;
+               hdr.Offset = cpu_to_le32(rq->rq_iov[0].iov_len);
+
+               iov[0].iov_base = &hdr;
+               iov[0].iov_len = sizeof(hdr);
+               iov[1] = rq->rq_iov[0];
+               iov[2].iov_base = dst;
+               iov[2].iov_len = dlen;
+
+               comp_rq.rq_iov = iov;
+
+               ret = send_fn(server, 1, &comp_rq);
+       } else if (ret == -EMSGSIZE || dlen >= slen) {
+               ret = send_fn(server, 1, rq);
+       }
+err_free:
+       kvfree(dst);
+       kvfree(src);
 
        return ret;
 }
index c0dabe0a60d8110af7b71d1876c34ab4c45422a3..f3ed1d3e52fbfd3fc3acda05393a3a1e189da6cb 100644 (file)
 #define SMB_COMPRESS_PAYLOAD_HDR_LEN   8
 #define SMB_COMPRESS_MIN_LEN           PAGE_SIZE
 
-struct smb_compress_ctx {
-       struct TCP_Server_Info *server;
-       struct work_struct work;
-       struct mid_q_entry *mid;
+#ifdef CONFIG_CIFS_COMPRESSION
+typedef int (*compress_send_fn)(struct TCP_Server_Info *, int, struct smb_rqst *);
 
-       void *buf; /* compressed data */
-       void *data; /* uncompressed data */
-       size_t len;
-};
+int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn);
 
-#ifdef CONFIG_CIFS_COMPRESSION
-int smb_compress(void *buf, const void *data, size_t *len);
+/**
+ * should_compress() - Determines if a request (write) or the response to a
+ *                    request (read) should be compressed.
+ * @tcon: tcon of the request is being sent to
+ * @rqst: request to evaluate
+ *
+ * Return: true iff:
+ * - compression was successfully negotiated with server
+ * - server has enabled compression for the share
+ * - it's a read or write request
+ * - (write only) request length is >= SMB_COMPRESS_MIN_LEN
+ * - (write only) is_compressible() returns 1
+ *
+ * Return false otherwise.
+ */
+bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq);
 
 /**
  * smb_compress_alg_valid() - Validate a compression algorithm.
@@ -62,48 +71,20 @@ static __always_inline int smb_compress_alg_valid(__le16 alg, bool valid_none)
 
        return false;
 }
-
-/**
- * should_compress() - Determines if a request (write) or the response to a
- *                    request (read) should be compressed.
- * @tcon: tcon of the request is being sent to
- * @buf: buffer with an SMB2 READ/WRITE request
- *
- * Return: true iff:
- * - compression was successfully negotiated with server
- * - server has enabled compression for the share
- * - it's a read or write request
- * - if write, request length is >= SMB_COMPRESS_MIN_LEN
- *
- * Return false otherwise.
- */
-static __always_inline bool should_compress(const struct cifs_tcon *tcon, const void *buf)
+#else /* !CONFIG_CIFS_COMPRESSION */
+static inline int smb_compress(void *unused1, void *unused2, void *unused3)
 {
-       const struct smb2_hdr *shdr = buf;
-
-       if (!tcon || !tcon->ses || !tcon->ses->server)
-               return false;
-
-       if (!tcon->ses->server->compression.enabled)
-               return false;
-
-       if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
-               return false;
-
-       if (shdr->Command == SMB2_WRITE) {
-               const struct smb2_write_req *req = buf;
+       return -EOPNOTSUPP;
+}
 
-               return (req->Length >= SMB_COMPRESS_MIN_LEN);
-       }
+static inline bool should_compress(void *unused1, void *unused2)
+{
+       return false;
+}
 
-       return (shdr->Command == SMB2_READ);
+static inline int smb_compress_alg_valid(__le16 unused1, bool unused2)
+{
+       return -EOPNOTSUPP;
 }
-/*
- * #else !CONFIG_CIFS_COMPRESSION ...
- * These routines should not be called when CONFIG_CIFS_COMPRESSION disabled
- * #define smb_compress(arg1, arg2, arg3)              (-EOPNOTSUPP)
- * #define smb_compress_alg_valid(arg1, arg2)  (-EOPNOTSUPP)
- * #define should_compress(arg1, arg2)         (false)
- */
 #endif /* !CONFIG_CIFS_COMPRESSION */
 #endif /* _SMB_COMPRESS_H */
index 2b8d548f94928beeaee5aff85f2ddd06b57954f2..553e253ada29d7457482396d3926b8f9f6af8d9b 100644 (file)
@@ -7,14 +7,75 @@
  * Implementation of the LZ77 "plain" compression algorithm, as per MS-XCA spec.
  */
 #include <linux/slab.h>
+#include <linux/sizes.h>
+#include <linux/count_zeros.h>
+#include <asm/unaligned.h>
+
 #include "lz77.h"
 
-static __always_inline u32 hash3(const u8 *ptr)
+/*
+ * Compression parameters.
+ */
+#define LZ77_MATCH_MIN_LEN     4
+#define LZ77_MATCH_MIN_DIST    1
+#define LZ77_MATCH_MAX_DIST    SZ_1K
+#define LZ77_HASH_LOG          15
+#define LZ77_HASH_SIZE         (1 << LZ77_HASH_LOG)
+#define LZ77_STEP_SIZE         sizeof(u64)
+
+static __always_inline u8 lz77_read8(const u8 *ptr)
+{
+       return get_unaligned(ptr);
+}
+
+static __always_inline u64 lz77_read64(const u64 *ptr)
+{
+       return get_unaligned(ptr);
+}
+
+static __always_inline void lz77_write8(u8 *ptr, u8 v)
+{
+       put_unaligned(v, ptr);
+}
+
+static __always_inline void lz77_write16(u16 *ptr, u16 v)
+{
+       put_unaligned_le16(v, ptr);
+}
+
+static __always_inline void lz77_write32(u32 *ptr, u32 v)
+{
+       put_unaligned_le32(v, ptr);
+}
+
+static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, const void *end)
 {
-       return lz77_hash32(lz77_read32(ptr) & 0xffffff, LZ77_HASH_LOG);
+       const void *start = cur;
+       u64 diff;
+
+       /* Safe for a do/while because otherwise we wouldn't reach here from the main loop. */
+       do {
+               diff = lz77_read64(cur) ^ lz77_read64(wnd);
+               if (!diff) {
+                       cur += LZ77_STEP_SIZE;
+                       wnd += LZ77_STEP_SIZE;
+
+                       continue;
+               }
+
+               /* This computes the number of common bytes in @diff. */
+               cur += count_trailing_zeros(diff) >> 3;
+
+               return (cur - start);
+       } while (likely(cur + LZ77_STEP_SIZE < end));
+
+       while (cur < end && lz77_read8(cur++) == lz77_read8(wnd++))
+               ;
+
+       return (cur - start);
 }
 
-static u8 *write_match(u8 *dst, u8 **nib, u32 dist, u32 len)
+static __always_inline void *lz77_write_match(void *dst, void **nib, u32 dist, u32 len)
 {
        len -= 3;
        dist--;
@@ -22,6 +83,7 @@ static u8 *write_match(u8 *dst, u8 **nib, u32 dist, u32 len)
 
        if (len < 7) {
                lz77_write16(dst, dist + len);
+
                return dst + 2;
        }
 
@@ -31,11 +93,13 @@ static u8 *write_match(u8 *dst, u8 **nib, u32 dist, u32 len)
        len -= 7;
 
        if (!*nib) {
+               lz77_write8(dst, umin(len, 15));
                *nib = dst;
-               lz77_write8(dst, min_t(unsigned int, len, 15));
                dst++;
        } else {
-               **nib |= min_t(unsigned int, len, 15) << 4;
+               u8 *b = *nib;
+
+               lz77_write8(b, *b | umin(len, 15) << 4);
                *nib = NULL;
        }
 
@@ -45,15 +109,16 @@ static u8 *write_match(u8 *dst, u8 **nib, u32 dist, u32 len)
        len -= 15;
        if (len < 255) {
                lz77_write8(dst, len);
+
                return dst + 1;
        }
 
        lz77_write8(dst, 0xff);
        dst++;
-
        len += 7 + 15;
        if (len <= 0xffff) {
                lz77_write16(dst, len);
+
                return dst + 2;
        }
 
@@ -64,148 +129,107 @@ static u8 *write_match(u8 *dst, u8 **nib, u32 dist, u32 len)
        return dst + 4;
 }
 
-static u8 *write_literals(u8 *dst, const u8 *dst_end, const u8 *src, size_t count,
-                         struct lz77_flags *flags)
+noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
 {
-       const u8 *end = src + count;
-
-       while (src < end) {
-               size_t c = lz77_min(count, 32 - flags->count);
-
-               if (dst + c >= dst_end)
-                       return ERR_PTR(-EFAULT);
-
-               if (lz77_copy(dst, src, c))
-                       return ERR_PTR(-EFAULT);
-
-               dst += c;
-               src += c;
-               count -= c;
-
-               flags->val <<= c;
-               flags->count += c;
-               if (flags->count == 32) {
-                       lz77_write32(flags->pos, flags->val);
-                       flags->count = 0;
-                       flags->pos = dst;
-                       dst += 4;
-               }
-       }
-
-       return dst;
-}
-
-static __always_inline bool is_valid_match(const u32 dist, const u32 len)
-{
-       return (dist >= LZ77_MATCH_MIN_DIST && dist < LZ77_MATCH_MAX_DIST) &&
-              (len >= LZ77_MATCH_MIN_LEN && len < LZ77_MATCH_MAX_LEN);
-}
-
-static __always_inline const u8 *find_match(u32 *htable, const u8 *base, const u8 *cur,
-                                           const u8 *end, u32 *best_len)
-{
-       const u8 *match;
-       u32 hash;
-       size_t offset;
-
-       hash = hash3(cur);
-       offset = cur - base;
-
-       if (htable[hash] >= offset)
-               return cur;
-
-       match = base + htable[hash];
-       *best_len = lz77_match(match, cur, end);
-       if (is_valid_match(cur - match, *best_len))
-               return match;
-
-       return cur;
-}
-
-int lz77_compress(const u8 *src, size_t src_len, u8 *dst, size_t *dst_len)
-{
-       const u8 *srcp, *src_end, *anchor;
-       struct lz77_flags flags = { 0 };
-       u8 *dstp, *dst_end, *nib;
-       u32 *htable;
-       int ret;
+       const void *srcp, *end;
+       void *dstp, *nib, *flag_pos;
+       u32 flag_count = 0;
+       long flag = 0;
+       u64 *htable;
 
        srcp = src;
-       anchor = srcp;
-       src_end = src + src_len;
-
+       end = src + slen;
        dstp = dst;
-       dst_end = dst + *dst_len;
-       flags.pos = dstp;
        nib = NULL;
-
-       memset(dstp, 0, *dst_len);
+       flag_pos = dstp;
        dstp += 4;
 
-       htable = kvcalloc(LZ77_HASH_SIZE, sizeof(u32), GFP_KERNEL);
+       htable = kvcalloc(LZ77_HASH_SIZE, sizeof(*htable), GFP_KERNEL);
        if (!htable)
                return -ENOMEM;
 
-       /* fill hashtable with invalid offsets */
-       memset(htable, 0xff, LZ77_HASH_SIZE * sizeof(u32));
+       /* Main loop. */
+       do {
+               u32 dist, len = 0;
+               const void *wnd;
+               u64 hash;
 
-       /* from here on, any error is because @dst_len reached >= @src_len */
-       ret = -EMSGSIZE;
+               hash = ((lz77_read64(srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG);
+               wnd = src + htable[hash];
+               htable[hash] = srcp - src;
+               dist = srcp - wnd;
 
-       /* main loop */
-       while (srcp < src_end) {
-               u32 hash, dist, len;
-               const u8 *match;
+               if (dist && dist < LZ77_MATCH_MAX_DIST)
+                       len = lz77_match_len(wnd, srcp, end);
 
-               while (srcp + 3 < src_end) {
-                       len = LZ77_MATCH_MIN_LEN - 1;
-                       match = find_match(htable, src, srcp, src_end, &len);
-                       hash = hash3(srcp);
-                       htable[hash] = srcp - src;
+               if (len < LZ77_MATCH_MIN_LEN) {
+                       lz77_write8(dstp, lz77_read8(srcp));
+
+                       dstp++;
+                       srcp++;
 
-                       if (likely(match < srcp)) {
-                               dist = srcp - match;
-                               break;
+                       flag <<= 1;
+                       flag_count++;
+                       if (flag_count == 32) {
+                               lz77_write32(flag_pos, flag);
+                               flag_count = 0;
+                               flag_pos = dstp;
+                               dstp += 4;
                        }
 
-                       srcp++;
+                       continue;
                }
 
-               dstp = write_literals(dstp, dst_end, anchor, srcp - anchor, &flags);
-               if (IS_ERR(dstp))
-                       goto err_free;
-
-               if (srcp + 3 >= src_end)
-                       goto leftovers;
+               /*
+                * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth
+                * going further.
+                */
+               if (unlikely(dstp - dst >= slen - (slen >> 3))) {
+                       *dlen = slen;
+                       goto out;
+               }
 
-               dstp = write_match(dstp, &nib, dist, len);
+               dstp = lz77_write_match(dstp, &nib, dist, len);
                srcp += len;
-               anchor = srcp;
-
-               flags.val = (flags.val << 1) | 1;
-               flags.count++;
-               if (flags.count == 32) {
-                       lz77_write32(flags.pos, flags.val);
-                       flags.count = 0;
-                       flags.pos = dstp;
+
+               flag = (flag << 1) | 1;
+               flag_count++;
+               if (flag_count == 32) {
+                       lz77_write32(flag_pos, flag);
+                       flag_count = 0;
+                       flag_pos = dstp;
+                       dstp += 4;
+               }
+       } while (likely(srcp + LZ77_STEP_SIZE < end));
+
+       while (srcp < end) {
+               u32 c = umin(end - srcp, 32 - flag_count);
+
+               memcpy(dstp, srcp, c);
+
+               dstp += c;
+               srcp += c;
+
+               flag <<= c;
+               flag_count += c;
+               if (flag_count == 32) {
+                       lz77_write32(flag_pos, flag);
+                       flag_count = 0;
+                       flag_pos = dstp;
                        dstp += 4;
                }
-       }
-leftovers:
-       if (srcp < src_end) {
-               dstp = write_literals(dstp, dst_end, srcp, src_end - srcp, &flags);
-               if (IS_ERR(dstp))
-                       goto err_free;
        }
 
-       flags.val <<= (32 - flags.count);
-       flags.val |= (1 << (32 - flags.count)) - 1;
-       lz77_write32(flags.pos, flags.val);
+       flag <<= (32 - flag_count);
+       flag |= (1 << (32 - flag_count)) - 1;
+       lz77_write32(flag_pos, flag);
 
-       *dst_len = dstp - dst;
-       ret = 0;
-err_free:
+       *dlen = dstp - dst;
+out:
        kvfree(htable);
 
-       return ret;
+       if (*dlen < slen)
+               return 0;
+
+       return -EMSGSIZE;
 }
index 3d0d3eaa8ffbf5bacfbe0403025bd09d6437b3dd..cdcb191b48a23b0795dd400daa709ed804d226d6 100644 (file)
  *
  * Authors: Enzo Matsumiya <ematsumiya@suse.de>
  *
- * Definitions and optmized helpers for LZ77 compression.
+ * Implementation of the LZ77 "plain" compression algorithm, as per MS-XCA spec.
  */
 #ifndef _SMB_COMPRESS_LZ77_H
 #define _SMB_COMPRESS_LZ77_H
 
-#include <linux/uaccess.h>
-#ifdef CONFIG_CIFS_COMPRESSION
-#include <asm/ptrace.h>
 #include <linux/kernel.h>
-#include <linux/string.h>
-#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS
-#include <asm-generic/unaligned.h>
-#endif
 
-#define LZ77_HASH_LOG          13
-#define LZ77_HASH_SIZE         (1 << LZ77_HASH_LOG)
-#define LZ77_HASH_MASK         lz77_hash_mask(LZ77_HASH_LOG)
-
-/* We can increase this for better compression (but worse performance). */
-#define LZ77_MATCH_MIN_LEN     3
-/* From MS-XCA, but it's arbitrarily chosen. */
-#define LZ77_MATCH_MAX_LEN     S32_MAX
-/*
- * Check this to ensure we don't match the current position, which would
- * end up doing a verbatim copy of the input, and actually overflowing
- * the output buffer because of the encoded metadata.
- */
-#define LZ77_MATCH_MIN_DIST    1
-/* How far back in the buffer can we try to find a match (i.e. window size) */
-#define LZ77_MATCH_MAX_DIST    8192
-
-#define LZ77_STEPSIZE_16       sizeof(u16)
-#define LZ77_STEPSIZE_32       sizeof(u32)
-#define LZ77_STEPSIZE_64       sizeof(u64)
-
-struct lz77_flags {
-       u8 *pos;
-       size_t count;
-       long val;
-};
-
-static __always_inline u32 lz77_hash_mask(const unsigned int log2)
-{
-       return ((1 << log2) - 1);
-}
-
-static __always_inline u32 lz77_hash64(const u64 v, const unsigned int log2)
-{
-       const u64 prime5bytes = 889523592379ULL;
-
-       return (u32)(((v << 24) * prime5bytes) >> (64 - log2));
-}
-
-static __always_inline u32 lz77_hash32(const u32 v, const unsigned int log2)
-{
-       return ((v * 2654435769LL) >> (32 - log2)) & lz77_hash_mask(log2);
-}
-
-static __always_inline u32 lz77_log2(unsigned int x)
-{
-       return x ? ((u32)(31 - __builtin_clz(x))) : 0;
-}
-
-#ifdef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS
-static __always_inline u8 lz77_read8(const void *ptr)
-{
-       return *(u8 *)ptr;
-}
-
-static __always_inline u16 lz77_read16(const void *ptr)
-{
-       return *(u16 *)ptr;
-}
-
-static __always_inline u32 lz77_read32(const void *ptr)
-{
-       return *(u32 *)ptr;
-}
-
-static __always_inline u64 lz77_read64(const void *ptr)
-{
-       return *(u64 *)ptr;
-}
-
-static __always_inline void lz77_write8(void *ptr, const u8 v)
-{
-       *(u8 *)ptr = v;
-}
-
-static __always_inline void lz77_write16(void *ptr, const u16 v)
-{
-       *(u16 *)ptr = v;
-}
-
-static __always_inline void lz77_write32(void *ptr, const u32 v)
-{
-       *(u32 *)ptr = v;
-}
-
-static __always_inline void lz77_write64(void *ptr, const u64 v)
-{
-       *(u64 *)ptr = v;
-}
-
-static __always_inline void lz77_write_ptr16(void *ptr, const void *vp)
-{
-       *(u16 *)ptr = *(const u16 *)vp;
-}
-
-static __always_inline void lz77_write_ptr32(void *ptr, const void *vp)
-{
-       *(u32 *)ptr = *(const u32 *)vp;
-}
-
-static __always_inline void lz77_write_ptr64(void *ptr, const void *vp)
-{
-       *(u64 *)ptr = *(const u64 *)vp;
-}
-
-static __always_inline long lz77_copy(u8 *dst, const u8 *src, size_t count)
-{
-       return copy_from_kernel_nofault(dst, src, count);
-}
-#else /* CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS */
-static __always_inline u8 lz77_read8(const void *ptr)
-{
-       return get_unaligned((u8 *)ptr);
-}
-
-static __always_inline u16 lz77_read16(const void *ptr)
-{
-       return lz77_read8(ptr) | (lz77_read8(ptr + 1) << 8);
-}
-
-static __always_inline u32 lz77_read32(const void *ptr)
-{
-       return lz77_read16(ptr) | (lz77_read16(ptr + 2) << 16);
-}
-
-static __always_inline u64 lz77_read64(const void *ptr)
-{
-       return lz77_read32(ptr) | ((u64)lz77_read32(ptr + 4) << 32);
-}
-
-static __always_inline void lz77_write8(void *ptr, const u8 v)
-{
-       put_unaligned(v, (u8 *)ptr);
-}
-
-static __always_inline void lz77_write16(void *ptr, const u16 v)
-{
-       lz77_write8(ptr, v & 0xff);
-       lz77_write8(ptr + 1, (v >> 8) & 0xff);
-}
-
-static __always_inline void lz77_write32(void *ptr, const u32 v)
-{
-       lz77_write16(ptr, v & 0xffff);
-       lz77_write16(ptr + 2, (v >> 16) & 0xffff);
-}
-
-static __always_inline void lz77_write64(void *ptr, const u64 v)
-{
-       lz77_write32(ptr, v & 0xffffffff);
-       lz77_write32(ptr + 4, (v >> 32) & 0xffffffff);
-}
-
-static __always_inline void lz77_write_ptr16(void *ptr, const void *vp)
-{
-       const u16 v = lz77_read16(vp);
-
-       lz77_write16(ptr, v);
-}
-
-static __always_inline void lz77_write_ptr32(void *ptr, const void *vp)
-{
-       const u32 v = lz77_read32(vp);
-
-       lz77_write32(ptr, v);
-}
-
-static __always_inline void lz77_write_ptr64(void *ptr, const void *vp)
-{
-       const u64 v = lz77_read64(vp);
-
-       lz77_write64(ptr, v);
-}
-static __always_inline long lz77_copy(u8 *dst, const u8 *src, size_t count)
-{
-       memcpy(dst, src, count);
-       return 0;
-}
-#endif /* !CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS */
-
-static __always_inline unsigned int __count_common_bytes(const unsigned long diff)
-{
-#ifdef __has_builtin
-#  if __has_builtin(__builtin_ctzll)
-       return (unsigned int)__builtin_ctzll(diff) >> 3;
-#  endif
-#else
-       /* count trailing zeroes */
-       unsigned long bits = 0, i, z = 0;
-
-       bits |= diff;
-       for (i = 0; i < 64; i++) {
-               if (bits[i])
-                       break;
-               z++;
-       }
-
-       return (unsigned int)z >> 3;
-#endif
-}
-
-static __always_inline size_t lz77_match(const u8 *match, const u8 *cur, const u8 *end)
-{
-       const u8 *start = cur;
-
-       if (cur == match)
-               return 0;
-
-       if (likely(cur < end - (LZ77_STEPSIZE_64 - 1))) {
-               u64 const diff = lz77_read64(cur) ^ lz77_read64(match);
-
-               if (!diff) {
-                       cur += LZ77_STEPSIZE_64;
-                       match += LZ77_STEPSIZE_64;
-               } else {
-                       return __count_common_bytes(diff);
-               }
-       }
-
-       while (likely(cur < end - (LZ77_STEPSIZE_64 - 1))) {
-               u64 const diff = lz77_read64(cur) ^ lz77_read64(match);
-
-               if (!diff) {
-                       cur += LZ77_STEPSIZE_64;
-                       match += LZ77_STEPSIZE_64;
-                       continue;
-               }
-
-               cur += __count_common_bytes(diff);
-               return (size_t)(cur - start);
-       }
-
-       if (cur < end - 3 && !(lz77_read32(cur) ^ lz77_read32(match))) {
-               cur += LZ77_STEPSIZE_32;
-               match += LZ77_STEPSIZE_32;
-       }
-
-       if (cur < end - 1 && lz77_read16(cur) == lz77_read16(match)) {
-               cur += LZ77_STEPSIZE_16;
-               match += LZ77_STEPSIZE_16;
-       }
-
-       if (cur < end && *cur == *match)
-               cur++;
-
-       return (size_t)(cur - start);
-}
-
-static __always_inline unsigned long lz77_max(unsigned long a, unsigned long b)
-{
-       int m = (a < b) - 1;
-
-       return (a & m) | (b & ~m);
-}
-
-static __always_inline unsigned long lz77_min(unsigned long a, unsigned long b)
-{
-       int m = (a > b) - 1;
-
-       return (a & m) | (b & ~m);
-}
-
-int lz77_compress(const u8 *src, size_t src_len, u8 *dst, size_t *dst_len);
-/* when CONFIG_CIFS_COMPRESSION not set lz77_compress() is not called */
-#endif /* !CONFIG_CIFS_COMPRESSION */
+int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen);
 #endif /* _SMB_COMPRESS_LZ77_H */