Path: blob/master/deprecated/scripts/ard_classification_demo.py
# Demo of logistic regression with automatic relevance determination (ARD)
# to eliminate irrelevant features.
# Based on:
# https://github.com/AmazaspShumik/sklearn-bayes/blob/master/ipython_notebooks_tutorials/rvm_ard/ard_classification_demo.ipynb

import superimport

from ard_linreg_logreg import ClassificationARD
from ard_vb_linreg_logreg import VBClassificationARD

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV

from pyprobml_utils import save_fig


def generate_dataset(n_samples=500, n_features=100,
                     cov_class_1=[[0.9, 0.1], [1.5, 0.2]],
                     cov_class_2=[[0.9, 0.1], [1.5, 0.2]],
                     mean_class_1=(-1, 0.4),
                     mean_class_2=(-1, -0.4)):
    '''Generate a binary classification problem in which only the first
    two features are relevant; the remaining n_features - 2 columns are
    pure Gaussian noise. (NB: the default covariance matrices are not
    symmetric, so np.random.multivariate_normal emits a RuntimeWarning
    but still draws samples.)'''
    X = np.random.randn(n_samples, n_features)
    Y = np.ones(n_samples)
    sep = n_samples // 2
    Y[0:sep] = 0
    X[0:sep, 0:2] = np.random.multivariate_normal(
        mean=mean_class_1, cov=cov_class_1, size=sep)
    X[sep:n_samples, 0:2] = np.random.multivariate_normal(
        mean=mean_class_2, cov=cov_class_2, size=n_samples - sep)
    return X, Y


def run_demo(n_samples, n_features):
    np.random.seed(42)
    X, Y = generate_dataset(n_samples, n_features)

    # Scatter plot of the two relevant features.
    plt.figure(figsize=(8, 6))
    plt.plot(X[Y == 0, 0], X[Y == 0, 1], "bo", markersize=3)
    plt.plot(X[Y == 1, 0], X[Y == 1, 1], "ro", markersize=3)
    plt.xlabel('feature 1')
    plt.ylabel('feature 2')
    plt.title("Example of dataset")
    plt.show()

    # Split into training and test data.
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.4)

    models = list()
    names = list()

    models.append(ClassificationARD())
    names.append('logreg-ARD-Laplace')

    models.append(VBClassificationARD())
    names.append('logreg-ARD-VB')

    models.append(LogisticRegressionCV(penalty='l2', cv=3))
    names.append('logreg-CV-L2')

    models.append(LogisticRegressionCV(penalty='l1', solver='liblinear', cv=3))
    names.append('logreg-CV-L1')

    nmodels = len(models)
    for i in range(nmodels):
        print('\nfitting {}'.format(names[i]))
        models[i].fit(Xtrain, Ytrain)

    # Construct a grid over the two relevant features, spanning the test set.
    n_grid = 100
    max_x = np.max(Xtest[:, 0:2], axis=0)
    min_x = np.min(Xtest[:, 0:2], axis=0)
    X1 = np.linspace(min_x[0], max_x[0], n_grid)
    X2 = np.linspace(min_x[1], max_x[1], n_grid)
    x1, x2 = np.meshgrid(X1, X2)
    # Pad the grid with random values for the irrelevant features, so the
    # grid points have the same dimensionality as the training data.
    Xg = np.random.randn(n_grid ** 2, n_features)
    Xg[:, 0] = np.reshape(x1, (n_grid ** 2,))
    Xg[:, 1] = np.reshape(x2, (n_grid ** 2,))

    # Plot the predicted probability surface and the test points for each
    # model, reporting the number of nonzero coefficients (nnz) in the title.
    for i in range(nmodels):
        pred = models[i].predict_proba(Xg)[:, 1]
        fig, ax = plt.subplots()
        ax.contourf(X1, X2, np.reshape(pred, (n_grid, n_grid)), cmap=cm.coolwarm)
        ax.plot(Xtest[Ytest == 0, 0], Xtest[Ytest == 0, 1], "bo", markersize=5)
        ax.plot(Xtest[Ytest == 1, 0], Xtest[Ytest == 1, 1], "ro", markersize=5)
        nnz = np.sum(models[i].coef_ != 0)
        ax.set_title('method {}, N={}, D={}, nnz {}'.format(
            names[i], n_samples, n_features, nnz))
        name = '{}-N{}-D{}.pdf'.format(names[i], n_samples, n_features)
        save_fig(name)
        plt.show()


ndims = [100]
ndata = [100, 200, 500]
for n_samples in ndata:
    for n_features in ndims:
        run_demo(n_samples, n_features)
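
# ----------------------------------------------------------------------
# Optional extension (a sketch, not part of the original demo): the plots
# above only show decision boundaries, so it can be useful to also report
# held-out accuracy for each model. This assumes ClassificationARD and
# VBClassificationARD follow the sklearn estimator API (fit / predict /
# coef_), as the sklearn-bayes models they derive from do.

def compare_accuracy(n_samples=500, n_features=100, seed=0):
    """Fit the same four models and print held-out accuracy and sparsity."""
    np.random.seed(seed)
    X, Y = generate_dataset(n_samples, n_features)
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.4)
    candidates = [
        ('logreg-ARD-Laplace', ClassificationARD()),
        ('logreg-ARD-VB', VBClassificationARD()),
        ('logreg-CV-L2', LogisticRegressionCV(penalty='l2', cv=3)),
        ('logreg-CV-L1',
         LogisticRegressionCV(penalty='l1', solver='liblinear', cv=3)),
    ]
    for name, model in candidates:
        model.fit(Xtrain, Ytrain)
        acc = np.mean(model.predict(Xtest) == Ytest)
        nnz = np.sum(model.coef_ != 0)
        print('{}: accuracy {:.3f}, nonzero coefs {}'.format(name, acc, nnz))

# Example usage: compare_accuracy(n_samples=500, n_features=100)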