# Path: blob/master/deprecated/scripts/bigram_hinton_diagram.py
# 1192 views
# Character-level n-gram statistics of the "timemachine.txt" corpus.
# Downloads the text, counts character unigrams/bigrams/trigrams, and saves
# Hinton diagrams of the bigram counts and of the row-normalised bigram
# probabilities.

import superimport

import collections
import re
import os
import numpy as np
import matplotlib.pyplot as plt
import pyprobml_utils as pml
import requests


url = 'https://raw.githubusercontent.com/probml/probml-data/main/data/timemachine.txt'
# Bound the wait and fail loudly on an HTTP error; otherwise the statistics
# below would silently be computed over an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
data = response.text
lines = [s + '\n' for s in data.split("\n")]
# Replace every run of non-letters by a space, lower-case, split into words.
raw_dataset = [re.sub('[^A-Za-z]+', ' ', st).lower().split() for st in lines]

# Print first few lines
for sentence in raw_dataset[:10]:
    print(sentence)

# Concat sentences into single string of chars
# skip blank lines
sentences = [' '.join(s) for s in raw_dataset if s]

# concat into single long string
# Lines are joined with a space: joining with '' would glue the last word of
# one line to the first word of the next and create bigrams/trigrams that
# never occur in the text.
charseq = ' '.join(sentences)


# Unigrams
wseq = charseq
print('First 10 unigrams\n', wseq[:10])

# Bigrams
word_pairs = list(zip(wseq[:-1], wseq[1:]))
print('First 10 bigrams\n', word_pairs[:10])

# Trigrams
word_triples = list(zip(wseq[:-2], wseq[1:-1], wseq[2:]))
print('First 10 trigrams\n', word_triples[:10])

# ngram statistics
counter = collections.Counter(wseq)
counter_pairs = collections.Counter(word_pairs)
counter_triples = collections.Counter(word_triples)

print('Most common unigrams\n', counter.most_common(10))
print('Most common bigrams\n', counter_pairs.most_common(10))
print('Most common trigrams\n', counter_triples.most_common(10))


# convert [(('t', 'h', 'e'), 3126), ...] to {'the': 3126, ...}
def make_dict(lst, min_count=1):
    """Turn (ngram, count) pairs into a {joined ngram string: count} dict.

    Args:
        lst: iterable of ``(ngram, count)`` pairs, where ``ngram`` is a
            string or a tuple of strings, e.g. ``Counter.most_common()``.
        min_count: exclusive threshold; only pairs whose count is strictly
            greater than ``min_count`` are kept.

    Returns:
        dict mapping ``''.join(ngram)`` to its count, in input order.
    """
    return {''.join(s): c for s, c in lst if c > min_count}


unigram_dict = make_dict(counter.most_common())
alphabet = list(unigram_dict.keys())
alpha_size = len(alphabet)

bigram_dict = make_dict(counter_pairs.most_common())

# bigram_count[i, j] = number of times alphabet[i] is followed by alphabet[j]
bigram_count = np.zeros((alpha_size, alpha_size))
for k, v in bigram_dict.items():
    code0 = alphabet.index(k[0])
    code1 = alphabet.index(k[1])
    bigram_count[code0, code1] += v

# Normalise every row by its own total. keepdims=True keeps the sums as a
# column vector; without it the (n,) vector broadcasts across columns and
# entry [i, j] would be divided by the total of row j instead of row i.
# The 1e-10 avoids a division by zero for rows without any counts.
bigram_prob = bigram_count / (1e-10 + np.sum(bigram_count, axis=1, keepdims=True))


# https://matplotlib.org/3.1.1/gallery/specialty_plots/hinton_demo.html
def hinton_diagram(matrix, max_weight=None, ax=None):
    """Draw Hinton diagram for visualizing a weight matrix.

    One square is drawn per matrix entry, white for positive values and
    black otherwise, with an area proportional to ``abs(value) / max_weight``.
    The first matrix index is plotted along x and the second along y.
    Tick positions and labels are taken from the module-level ``alpha_size``
    and ``alphabet``.

    Args:
        matrix: 2-D array of weights.
        max_weight: value that maps to a full-size square. When falsy, the
            largest absolute entry rounded up to a power of two is used.
        ax: matplotlib axes to draw on; defaults to the current axes.
    """
    ax = ax if ax is not None else plt.gca()

    if not max_weight:
        peak = np.abs(matrix).max()
        # log2 is exact at powers of two, so ceil() cannot overshoot there.
        # An all-zero matrix falls back to 1 instead of producing log(0)
        # and NaN square sizes.
        max_weight = 2 ** np.ceil(np.log2(peak)) if peak > 0 else 1.0

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        # Centre the square on the integer grid point (x, y).
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()
    ax.axis('on')
    ax.set_xticks(range(alpha_size))
    ax.set_xticklabels(alphabet)
    ax.set_yticks(range(alpha_size))
    ax.set_yticklabels(alphabet)


# The matrices are transposed so that the first character of a bigram runs
# along the y axis and the second along the x axis.
plt.figure(figsize=(8, 8))
hinton_diagram(bigram_count.T)
pml.savefig('bigram-count.pdf')
plt.show()

plt.figure(figsize=(8, 8))
hinton_diagram(bigram_prob.T)
pml.savefig('bigram-prob.pdf')
plt.show()