# Loading Corpus

In [74]:
import nltk
from nltk.corpus import gutenberg

nltk.download('gutenberg')
austen_emma = gutenberg.raw('austen-emma.txt')


[nltk_data] Downloading package gutenberg to /root/nltk_data...
[nltk_data]   Package gutenberg is already up-to-date!


In [75]:
print(austen_emma[:200])

[Emma by Jane Austen 1816]

VOLUME I

CHAPTER I


Emma Woodhouse, handsome, clever, and rich, with a comfortable home
and happy disposition, seemed to unite some of the best blessings
of existence; an


# Creating Frequency Tables

In [76]:
import string
from collections import Counter

austen_emma_cleaned = austen_emma.upper()
austen_emma_cleaned = ''.join(c for c in austen_emma_cleaned if c in string.ascii_uppercase + ' ')

n = 6

ngrams = [''.join(ngram) for ngram in zip(*(austen_emma_cleaned[i:-(n-i)] for i in range(n)))]
freqs = Counter(ngrams)
freqs

english_freqs = Counter(austen_emma_cleaned)

# Computing the marginal frequencies

In [77]:
def marginal(prefix, counts):
  marginal_freqs = {ngram[-1]: count for ngram, count in freqs.items() if ngram[:-1] == prefix}
  if not marginal_freqs:
    marginal_freqs = english_freqs

  characters, counts = zip(*marginal_freqs.items())
  sum_counts = sum(counts)
  probabilities = [count / sum_counts for count in counts]
  return characters, probabilities

# Sampling from our markov model

In [78]:
import random

s = 'EMMA '

for i in range(100):
  characters, probabilities = marginal(s[-(n-1):], freqs)
  s += random.choices(characters, weights=probabilities, k=1)[0]

s

'EMMA ASSISTERSTHE BUSINESSTO BE TO LEAVE THE LADY SCEPTIONS HINTSHIS INIMITABLE CIRCULARLY PERSONS YOU HA'

# Shannon Coding

In [79]:
import math
from tqdm.notebook import tqdm
from numba import jit

@jit(nopython=True)
def float_to_binary_numba(number, length):
    output = ""
    for i in range(1, length + 1):
        digit = int(number >= 0.5)
        output += str(digit)
        number = 2 * (number - 0.5 * digit)
    return output

def float_to_binary(number, length):
  output = ""
  for i in range(1, length + 1):
    digit = int(number >= 0.5)
    output += str(digit)
    number = 2 * (number - 0.5 * digit)

  return output

def shannon_encoding(freq_table):
  items = sorted(freq_table.items(), key=lambda x: x[1], reverse=True)
  encoding = {}
  Pi = 0.

  for char, prob in items:
    encoding[char] = float_to_binary_numba(Pi, math.ceil(math.log2(1/prob)))
    Pi += prob

  return encoding

def encode(string, encoding):
  return ''.join(encoding[char] for char in string)

def expected_length(freq_table, encoding):
  return sum(freq_table[char] * len(encoding[char]) for char in freq_table)

def shannon_entropy(freq_table):
    return sum(-freq_table[char] * math.log2(freq_table[char]) for char in freq_table)


In [80]:
freq_table = {'a': 0.5, 'b': 0.25, 'c': 0.25}
encoding = shannon_encoding(freq_table)
encoding

{'a': '0', 'b': '10', 'c': '11'}

In [81]:
print(f'expected length: {expected_length(freq_table, encoding)}')
print(f'shannon entropy: {shannon_entropy(freq_table)}')

expected length: 1.5
shannon entropy: 1.5


In [82]:
def decode(text, encoding):
  output = ""
  while text:
    token, i = decode_token(text, encoding)
    output += token
    text = text[i:]
  return output

def decode_token(text, encoding):
  inverted_encoding = {value: key for key, value in encoding.items()}
  for i in range(len(text)):
    if text[:i+1] in inverted_encoding:
      return inverted_encoding[text[:i+1]], i+1

In [83]:
decode(encode('aabc', encoding), encoding)

'aabc'

# Combining with NanoGPT

In [84]:
%pip install -q torch numpy transformers datasets tiktoken wandb tqdm

In [85]:
!git clone https://github.com/karpathy/nanoGPT
%cd nanoGPT

Cloning into 'nanoGPT'...
remote: Enumerating objects: 671, done.[K
remote: Counting objects: 100% (22/22), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 671 (delta 5), reused 12 (delta 2), pack-reused 649[K
Receiving objects: 100% (671/671), 949.28 KiB | 767.00 KiB/s, done.
Resolving deltas: 100% (376/376), done.
/content/nanoGPT/nanoGPT


In [86]:
from model import GPT
import tiktoken
import torch

gpt2 = GPT.from_pretrained('gpt2', dict(dropout=0.0)).to(dtype=torch.float64)
enc = tiktoken.get_encoding("gpt2")
tokenize = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
detokenize = lambda l: enc.decode(l)

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
overriding dropout rate to 0.0
number of parameters: 123.65M


In [87]:
import sys

max_context = 128

def compress(text):
  tokens = tokenize('\n' + text)
  tokens = torch.tensor(tokens, dtype=torch.long)
  output = ''

  for i in tqdm(range(1, len(tokens))):
    # get marginal from model
    cond_tokens = tokens[None,:i] if i <= max_context else tokens[None,i-max_context:i]
    probs = gpt2(cond_tokens)[0].flatten().softmax(0)
    prob_next = probs[tokens[i]] # probability of next token
    # sort probabilties
    probs, indices = torch.sort(probs, descending=True)
    idx = torch.argwhere(indices == tokens[i]).item()
    # Pi is the cumulative sum up until idx (not including idx)
    zero_tensor = torch.tensor([0])
    probs = torch.cat((zero_tensor, probs), dim=0)
    Pi = probs.cumsum(0)[idx].item()
    # encode as a binary string using Shannon's method
    encoded_token = float_to_binary_numba(Pi, math.ceil(math.log2(1/prob_next)))
    output += encoded_token

  return output

def decompress(text):
  tokens = tokenize('\n')
  tokens = torch.tensor(tokens, dtype=torch.long)
  output = []
  processed = 0
  text_length = len(text)

  while text:
    # feed in tokens and truncate if number of tokens exceeds max_context
    cond_tokens = tokens[None,:] if len(tokens) <= max_context else tokens[None,-max_context:]
    probs = gpt2(cond_tokens)[0].flatten().softmax(0)
    # construct frequency table
    freq_table = {i:prob.item() for i, prob in enumerate(probs)}
    encoding = shannon_encoding(freq_table)
    token, i = decode_token(text, encoding)
    # decompress the rest of the text
    text = text[i:]
    output.append(token)
    tokens = torch.cat((tokens, torch.tensor([token], dtype=torch.long)), dim=0)
    processed += i
    print(f'\r Decompressed {processed/text_length*100:.2f}% of text...', end='')

  return detokenize(output)

# Benchmarking our code

In [88]:
text = austen_emma[:10000]
print(text)

[Emma by Jane Austen 1816]

VOLUME I

CHAPTER I


Emma Woodhouse, handsome, clever, and rich, with a comfortable home
and happy disposition, seemed to unite some of the best blessings
of existence; and had lived nearly twenty-one years in the world
with very little to distress or vex her.

She was the youngest of the two daughters of a most affectionate,
indulgent father; and had, in consequence of her sister's marriage,
been mistress of his house from a very early period.  Her mother
had died too long ago for her to have more than an indistinct
remembrance of her caresses; and her place had been supplied
by an excellent woman as governess, who had fallen little short
of a mother in affection.

Sixteen years had Miss Taylor been in Mr. Woodhouse's family,
less as a governess than a friend, very fond of both daughters,
but particularly of Emma.  Between _them_ it was more the intimacy
of sisters.  Even before Miss Taylor had ceased to hold the nominal
office of governess, the mildness o

In [89]:
%%time
compressed_string = compress(text)
compressed_string

  0%|          | 0/2468 [00:00<?, ?it/s]

CPU times: user 34min 28s, sys: 32.6 s, total: 35min
Wall time: 5min 49s


'011100111100111011100101101100100110100010101011001011111011101010001101101011000011110111000110110010000000110000011011010101011111101100010000110100100111011000111001100001101111100000010110010000011100101000110010100010111111111111111111111001110111000111111001010000110000011001001101000011001101010110000100010111010010110011110000111111111111111001010011011011101100000101010001010101000111010011101000101110000000100100010011111111111111111111001110010101000111000000000111110111101110111100100100101000000000000010101011100000010110011000001010001000101000111001010011101011111111111111111111101111110000000001100001010000010100110000101010000001011000100001010001111111111111111111111101110101101111110000110110010110101100101111001111010000000011101000111110010010000111101011000000000101001111111111111110100010110101011110000011110110101101110000000010101011010000101101011101100010010011111111010110110000000111110101011000100100100000110101000111100001100101011010101110011100111001110

In [90]:
%%time
decompressed_string = decompress(compressed_string)
assert decompressed_string == text, 'Decompress(Compress(X)) != X'

 Decompressed 100.00% of text...CPU times: user 49min 15s, sys: 25.9 s, total: 49min 41s
Wall time: 21min 44s


In [94]:
print(f'Size of uncompressed text: {len(text) * 8}')

Size of uncompressed text: 80000


In [95]:
import zlib
print(f'Size of compressed text using zlib: {len(zlib.compress(text.encode())) * 8}')

Size of compressed text using zlib: 35784


In [96]:
print(f'Size of compressed text using our compression algorithm: {len(compressed_string)}')

Size of compressed text using our compression algorithm: 15179
