Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/bigram_hinton_diagram.py
1192 views
1
import superimport
2
3
import collections
4
import re
5
import os
6
import numpy as np
7
import matplotlib.pyplot as plt
8
import pyprobml_utils as pml
9
import requests
10
11
12
# Download "The Time Machine" and compute character-level n-gram statistics.
url = 'https://raw.githubusercontent.com/probml/probml-data/main/data/timemachine.txt'
# Use a timeout and check the HTTP status so that a stalled connection or an
# error page fails loudly instead of being tokenised as if it were the book.
response = requests.get(url, timeout=30)
response.raise_for_status()
data = response.text
lines = [s + '\n' for s in data.split("\n")]
# Every run of non-letters becomes one space; the line is then lower-cased and
# split, giving one list of words per line of the text.
raw_dataset = [re.sub(r'[^A-Za-z]+', ' ', st).lower().split() for st in lines]

# Print first few lines
for sentence in raw_dataset[:10]:
    print(sentence)

# Rebuild each line as a space-separated string of words, skipping blank lines
sentences = [' '.join(s) for s in raw_dataset if s]

# Concat into a single long string of chars. Lines are joined with a space:
# joining with '' would fuse the last word of one line with the first word of
# the next and create n-grams that never occur in the text.
charseq = ' '.join(sentences)

# Unigrams (the "words" of this sequence are single characters)
wseq = charseq
print('First 10 unigrams\n', wseq[:10])

# Bigrams: zip stops at the shortest input, so no explicit end trimming needed
word_pairs = list(zip(wseq, wseq[1:]))
print('First 10 bigrams\n', word_pairs[:10])

# Trigrams
word_triples = list(zip(wseq, wseq[1:], wseq[2:]))
print('First 10 trigrams\n', word_triples[:10])

# ngram statistics
counter = collections.Counter(wseq)
counter_pairs = collections.Counter(word_pairs)
counter_triples = collections.Counter(word_triples)

print('Most common unigrams\n', counter.most_common(10))
print('Most common bigrams\n', counter_pairs.most_common(10))
print('Most common trigrams\n', counter_triples.most_common(10))
50
51
52
# convert [(('t', 'h', 'e'), 3126), ...] to {'the': 3126, ...}
53
def make_dict(lst, min_count=1):
    """Turn (n-gram, count) pairs into a mapping from joined n-gram to count.

    Each n-gram (a string or a tuple of strings) is concatenated with
    ``''.join`` to form the key, e.g. ``(('t', 'h', 'e'), 3126)`` becomes
    ``{'the': 3126}``.

    Only pairs whose count is strictly greater than ``min_count`` are kept,
    so with the default of 1 every n-gram seen exactly once is dropped.
    If two n-grams join to the same key, the later pair wins.
    """
    return {
        ''.join(gram): count
        for gram, count in lst
        if count > min_count
    }
61
62
# The alphabet is the set of characters that pass make_dict's count filter,
# in most-common-first order (the order most_common() yields them).
unigram_dict = make_dict(counter.most_common())
alphabet = list(unigram_dict.keys())
alpha_size = len(alphabet)

bigram_dict = make_dict(counter_pairs.most_common())

# Character -> matrix index, built once so each bigram costs two dict lookups
# instead of two linear scans of the alphabet.
char_index = {ch: i for i, ch in enumerate(alphabet)}

# bigram_count[i, j] = number of times alphabet[i] is followed by alphabet[j]
bigram_count = np.zeros((alpha_size, alpha_size))
for k, v in bigram_dict.items():
    code0 = char_index[k[0]]
    code1 = char_index[k[1]]
    bigram_count[code0, code1] += v

# Normalise each row to sum to one. keepdims=True keeps the row totals as a
# column vector; a flat vector would broadcast along the columns and divide
# entry [i, j] by the total of row j instead of row i. The small constant
# avoids 0/0 for a row with no counts.
row_totals = np.sum(bigram_count, axis=1, keepdims=True)
bigram_prob = bigram_count / (1e-10 + row_totals)
76
77
#https://matplotlib.org/3.1.1/gallery/specialty_plots/hinton_demo.html
78
def hinton_diagram(matrix, max_weight=None, ax=None, labels=None):
    """Draw Hinton diagram for visualizing a weight matrix.

    Every entry ``matrix[x, y]`` is drawn as a square centred at ``(x, y)``
    whose area is ``|w| / max_weight``; positive entries are white, all
    others black, on a gray background. The y axis is inverted.

    Args:
        matrix: 2-D array of weights. The tick labelling assumes it is
            square with one label per row/column.
        max_weight: Weight that maps to a full-size square. When falsy it is
            the smallest power of two >= the largest absolute entry.
        ax: Axes to draw on; defaults to the current axes.
        labels: Tick labels for both axes. Defaults to the module-level
            ``alphabet`` so existing callers keep their labelling.
    """
    ax = ax if ax is not None else plt.gca()
    matrix = np.asarray(matrix)

    if not max_weight:
        peak = np.abs(matrix).max() if matrix.size else 0
        # An all-zero (or empty) matrix would give log2(0) = -inf and a zero
        # scale, turning every square size into NaN; fall back to 1 instead.
        # np.log2 is used rather than log(x)/log(2), whose rounding can
        # overshoot on exact powers of two and double the scale after ceil.
        max_weight = 2 ** np.ceil(np.log2(peak)) if peak > 0 else 1.0

    if labels is None:
        labels = alphabet

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        # Side length is the square root so that area scales with |w|.
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()
    ax.axis('on')
    ticks = range(len(labels))
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels)
104
105
106
# Draw and save one Hinton diagram per matrix: first the raw bigram counts,
# then the normalised values. Each matrix is transposed before plotting.
for matrix, filename in ((bigram_count, 'bigram-count.pdf'),
                         (bigram_prob, 'bigram-prob.pdf')):
    plt.figure(figsize=(8, 8))
    hinton_diagram(matrix.T)
    pml.savefig(filename)
    plt.show()
115
116
117