Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/30/newsgroups_visualize.ipynb
1193 views
Kernel: Python 3

Open In Colab

import numpy as np import scipy.io as scio import matplotlib.pyplot as plt from matplotlib.lines import Line2D import pylab %matplotlib inline pylab.rcParams["figure.figsize"] = (15, 10)
import requests
from io import BytesIO
from scipy.io import loadmat

# MATLAB file holding the newsgroup document/word matrix, the word list,
# the per-document group ids and the group names.
url = "https://raw.githubusercontent.com/probml/probml-data/main/data/20news_w100.mat"

# Bound the wait and fail loudly on an HTTP error; otherwise an error page
# would be handed to loadmat and surface as a confusing parse failure.
response = requests.get(url, timeout=30)
response.raise_for_status()

# loadmat reads from a binary file-like object, so wrap the raw bytes.
rawdata = BytesIO(response.content)
data = loadmat(rawdata)
print(data)
{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNX86, Created on: Sat Jun 29 15:07:10 2002', '__version__': '1.0', '__globals__': [], 'documents': <100x16242 sparse matrix of type '<class 'numpy.uint8'>' with 65451 stored elements in Compressed Sparse Column format>, 'wordlist': array([[array(['aids'], dtype='<U4'), array(['baseball'], dtype='<U8'), array(['bible'], dtype='<U5'), array(['bmw'], dtype='<U3'), array(['cancer'], dtype='<U6'), array(['car'], dtype='<U3'), array(['card'], dtype='<U4'), array(['case'], dtype='<U4'), array(['children'], dtype='<U8'), array(['christian'], dtype='<U9'), array(['computer'], dtype='<U8'), array(['course'], dtype='<U6'), array(['data'], dtype='<U4'), array(['dealer'], dtype='<U6'), array(['disease'], dtype='<U7'), array(['disk'], dtype='<U4'), array(['display'], dtype='<U7'), array(['doctor'], dtype='<U6'), array(['dos'], dtype='<U3'), array(['drive'], dtype='<U5'), array(['driver'], dtype='<U6'), array(['earth'], dtype='<U5'), array(['email'], dtype='<U5'), array(['engine'], dtype='<U6'), array(['evidence'], dtype='<U8'), array(['fact'], dtype='<U4'), array(['fans'], dtype='<U4'), array(['files'], dtype='<U5'), array(['food'], dtype='<U4'), array(['format'], dtype='<U6'), array(['ftp'], dtype='<U3'), array(['games'], dtype='<U5'), array(['god'], dtype='<U3'), array(['government'], dtype='<U10'), array(['graphics'], dtype='<U8'), array(['gun'], dtype='<U3'), array(['health'], dtype='<U6'), array(['help'], dtype='<U4'), array(['hit'], dtype='<U3'), array(['hockey'], dtype='<U6'), array(['honda'], dtype='<U5'), array(['human'], dtype='<U5'), array(['image'], dtype='<U5'), array(['insurance'], dtype='<U9'), array(['israel'], dtype='<U6'), array(['jesus'], dtype='<U5'), array(['jews'], dtype='<U4'), array(['launch'], dtype='<U6'), array(['law'], dtype='<U3'), array(['league'], dtype='<U6'), array(['lunar'], dtype='<U5'), array(['mac'], dtype='<U3'), array(['mars'], dtype='<U4'), array(['medicine'], dtype='<U8'), 
array(['memory'], dtype='<U6'), array(['mission'], dtype='<U7'), array(['moon'], dtype='<U4'), array(['msg'], dtype='<U3'), array(['nasa'], dtype='<U4'), array(['nhl'], dtype='<U3'), array(['number'], dtype='<U6'), array(['oil'], dtype='<U3'), array(['orbit'], dtype='<U5'), array(['patients'], dtype='<U8'), array(['pc'], dtype='<U2'), array(['phone'], dtype='<U5'), array(['players'], dtype='<U7'), array(['power'], dtype='<U5'), array(['president'], dtype='<U9'), array(['problem'], dtype='<U7'), array(['program'], dtype='<U7'), array(['puck'], dtype='<U4'), array(['question'], dtype='<U8'), array(['religion'], dtype='<U8'), array(['research'], dtype='<U8'), array(['rights'], dtype='<U6'), array(['satellite'], dtype='<U9'), array(['science'], dtype='<U7'), array(['scsi'], dtype='<U4'), array(['season'], dtype='<U6'), array(['server'], dtype='<U6'), array(['shuttle'], dtype='<U7'), array(['software'], dtype='<U8'), array(['solar'], dtype='<U5'), array(['space'], dtype='<U5'), array(['state'], dtype='<U5'), array(['studies'], dtype='<U7'), array(['system'], dtype='<U6'), array(['team'], dtype='<U4'), array(['technology'], dtype='<U10'), array(['university'], dtype='<U10'), array(['version'], dtype='<U7'), array(['video'], dtype='<U5'), array(['vitamin'], dtype='<U7'), array(['war'], dtype='<U3'), array(['water'], dtype='<U5'), array(['win'], dtype='<U3'), array(['windows'], dtype='<U7'), array(['won'], dtype='<U3'), array(['world'], dtype='<U5')]], dtype=object), 'newsgroups': array([[1, 1, 1, ..., 4, 4, 4]], dtype=uint8), 'groupnames': array([[array(['comp.*'], dtype='<U6'), array(['rec.*'], dtype='<U5'), array(['sci.*'], dtype='<U5'), array(['talk.*'], dtype='<U6')]], dtype=object)}
# Pull the sparse count matrix out of the dictionary returned by loadmat
# and report its container type and shape, one per line.
X = data["documents"]
print(type(X), X.shape, sep="\n")
<class 'scipy.sparse.csc.csc_matrix'> (100, 16242)
# Transpose X so the long axis (16242) becomes the row axis.
# NOTE: re-running this cell transposes X back again.
X = X.transpose()
print(X.shape, type(X))
(16242, 100) <class 'scipy.sparse.csr.csr_matrix'>
# Group id for every document, and the names those ids map to.
y = data["newsgroups"]
classlabels = data["groupnames"]
print(type(classlabels), classlabels.shape)
<class 'numpy.ndarray'> (1, 4)
# Total count per row of X; on a sparse matrix this comes back as a
# column-shaped numpy.matrix rather than a flat array.
nwords = np.sum(X, axis=1)
print(nwords.shape, "\n", nwords[:5], type(nwords))
(16242, 1) [[5] [1] [3] [3] [4]] <class 'numpy.matrix'>
# Rank rows by their total count, largest first.
# nwords is the row sum of a uint8 sparse matrix and may therefore carry an
# unsigned dtype.  Negating unsigned integers wraps around and leaves 0 at 0,
# which would put zero-count rows at the FRONT of the "descending" order.
# Casting to a signed type first makes the negation behave as intended.
word_num_index = np.argsort(-nwords.astype(np.int64), axis=0)
print(word_num_index.shape, type(word_num_index))

# Keep the 1000 highest-ranked rows, converted to a plain ndarray.
index_1000 = np.array(word_num_index[:1000])
print(index_1000.shape, type(index_1000))
(16242, 1) <class 'numpy.matrix'> (1000, 1) <class 'numpy.ndarray'>
# Densify the 1000 selected rows and pick out their group ids.
XX = X[index_1000.ravel()].toarray()
yy = y.T[index_1000.ravel()]
print(type(XX), XX.shape)
print(type(yy), yy.shape)

# Order the rows by group id so that each group forms one contiguous band;
# new_yy holds the ids in that same sorted order.
index_of_yy = np.argsort(yy, axis=0)
new_yy = np.sort(yy, axis=0)
XX = XX[index_of_yy.ravel()]
print(XX.shape)
<class 'numpy.ndarray'> (1000, 100) <class 'numpy.ndarray'> (1000, 1) (1000, 100)
# Distinct group ids present among the selected rows, in ascending order.
yy_unique = np.unique(new_yy)
print(yy_unique)
[1 2 3 4]
# Show the sorted matrix as an image (non-zero entries dark) and draw a red
# horizontal rule at the last row of every group except the final one.
ax = plt.gca()
ax.imshow(XX, cmap=plt.cm.gray_r, aspect="auto")

for label in yy_unique[:-1]:
    # Position of the last row whose group id equals `label`.
    label_index = np.flatnonzero(new_yy.flatten() == label)[-1]
    # The rule spans the full image width at that row.
    line1_xs = (0, XX.shape[1])
    line1_ys = (label_index, label_index)
    ax.add_line(Line2D(line1_xs, line1_ys, linewidth=5, color="red"))
Image in a Jupyter notebook