Contact
CoCalc Logo Icon
Store Features Docs Share Support News About Sign Up Sign In
| Download

All published worksheets from http://sagenb.org

Views: 168703
Image: ubuntu2004
# Back-Propagation Neural Networks
#
# Written in Python.  See http://www.python.org/
# Placed in the public domain.
# Neil Schemenauer <[email protected]>
#
# Ported to Python 3: print() calls, raise ValueError(...), range()
# instead of xrange(); dropped the unused `string` import and renamed
# the local accumulator that shadowed the builtin sum().

import math
import random

# Fixed seed so weight initialization and training are reproducible.
random.seed(0)


def rand(a, b):
    """Return a random float x with a <= x < b."""
    return (b - a) * random.random() + a


def makeMatrix(I, J, fill=0.0):
    """Return an I-by-J matrix as a list of row lists filled with `fill`.

    (We could use NumPy to speed this up.)
    """
    return [[fill] * J for _ in range(I)]


def sigmoid(x):
    """Activation function: tanh is a little nicer than the standard
    1/(1+e^-x)."""
    return math.tanh(x)


def dsigmoid(y):
    """Derivative of the sigmoid, in terms of the output y
    (d/dx tanh(x) = 1 - tanh(x)**2)."""
    return 1.0 - y ** 2


class NN:
    """Fully-connected feed-forward network with one hidden layer,
    trained by back-propagation with momentum."""

    def __init__(self, ni, nh, no):
        """ni, nh, no: number of input, hidden, and output nodes."""
        self.ni = ni + 1  # +1 for the always-on bias node
        self.nh = nh
        self.no = no

        # activations for all nodes
        self.ai = [1.0] * self.ni
        self.ah = [1.0] * self.nh
        self.ao = [1.0] * self.no

        # weight matrices, initialized to random values
        self.wi = makeMatrix(self.ni, self.nh)
        self.wo = makeMatrix(self.nh, self.no)
        for i in range(self.ni):
            for j in range(self.nh):
                self.wi[i][j] = rand(-0.2, 0.2)
        for j in range(self.nh):
            for k in range(self.no):
                self.wo[j][k] = rand(-2.0, 2.0)

        # last change in weights, kept for the momentum term
        self.ci = makeMatrix(self.ni, self.nh)
        self.co = makeMatrix(self.nh, self.no)

    def update(self, inputs):
        """Forward pass: set activations from `inputs` and return a copy
        of the output activations.

        Raises ValueError if len(inputs) does not match the input layer.
        """
        if len(inputs) != self.ni - 1:
            raise ValueError('wrong number of inputs')

        # input activations (the bias node keeps its constant 1.0)
        for i in range(self.ni - 1):
            self.ai[i] = inputs[i]

        # hidden activations
        for j in range(self.nh):
            total = 0.0
            for i in range(self.ni):
                total += self.ai[i] * self.wi[i][j]
            self.ah[j] = sigmoid(total)

        # output activations
        for k in range(self.no):
            total = 0.0
            for j in range(self.nh):
                total += self.ah[j] * self.wo[j][k]
            self.ao[k] = sigmoid(total)

        return self.ao[:]

    def backPropagate(self, targets, N, M):
        """One gradient step toward `targets` for the most recent update().

        N is the learning rate, M the momentum factor.  Returns the
        pattern error 0.5 * sum((target - output)**2).

        Raises ValueError if len(targets) does not match the output layer.
        """
        if len(targets) != self.no:
            raise ValueError('wrong number of target values')

        # error terms for output nodes
        output_deltas = [0.0] * self.no
        for k in range(self.no):
            error = targets[k] - self.ao[k]
            output_deltas[k] = dsigmoid(self.ao[k]) * error

        # error terms for hidden nodes (back-propagated from the outputs)
        hidden_deltas = [0.0] * self.nh
        for j in range(self.nh):
            error = 0.0
            for k in range(self.no):
                error = error + output_deltas[k] * self.wo[j][k]
            hidden_deltas[j] = dsigmoid(self.ah[j]) * error

        # update output weights: gradient step plus momentum
        for j in range(self.nh):
            for k in range(self.no):
                change = output_deltas[k] * self.ah[j]
                self.wo[j][k] = self.wo[j][k] + N * change + M * self.co[j][k]
                self.co[j][k] = change

        # update input weights
        for i in range(self.ni):
            for j in range(self.nh):
                change = hidden_deltas[j] * self.ai[i]
                self.wi[i][j] = self.wi[i][j] + N * change + M * self.ci[i][j]
                self.ci[i][j] = change

        # calculate the remaining error for this pattern
        error = 0.0
        for k in range(len(targets)):
            error = error + 0.5 * (targets[k] - self.ao[k]) ** 2
        return error

    def test(self, patterns):
        """Print each pattern's input alongside the network's output."""
        for p in patterns:
            print(p[0], '->', self.update(p[0]))

    def weights(self):
        """Print both weight matrices."""
        print('Input weights:')
        for i in range(self.ni):
            print(self.wi[i])
        print()
        print('Output weights:')
        for j in range(self.nh):
            print(self.wo[j])

    def train(self, patterns, iterations=1000, N=0.5, M=0.1):
        """Train on `patterns` = [[inputs, targets], ...].

        N: learning rate.  M: momentum factor.  Prints the summed error
        every 100 iterations.
        """
        for i in range(iterations):
            error = 0.0
            for p in patterns:
                inputs = p[0]
                targets = p[1]
                self.update(inputs)
                error = error + self.backPropagate(targets, N, M)
            if i % 100 == 0:
                print('error %-14f' % error)


def demo():
    """Teach the network the XOR function."""
    pat = [
        [[0, 0], [0]],
        [[0, 1], [1]],
        [[1, 0], [1]],
        [[1, 1], [0]],
    ]

    # create a network with two input, two hidden, and one output nodes
    n = NN(2, 2, 1)
    # train it with some patterns
    n.train(pat)
    # test it
    n.test(pat)
# Run the XOR demo defined above: trains for the default 1000 iterations,
# printing the summed error every 100, then prints the four XOR test cases.
demo()
error 0.942497 error 0.042867 error 0.003480 error 0.001642 error 0.001062 error 0.000782 error 0.000625 error 0.000527 error 0.000442 error 0.000381 [0, 0] -> [0.0042410815506258902] [0, 1] -> [0.98215080294107482] [1, 0] -> [0.98201293886181207] [1, 1] -> [-0.0011469114721422528]
#sqrt test """ pat=[ [[1],[1]], [[2],[1.4142135623730950488016887242097]], [[3],[1.4142135623730950488016887242097]], [[4],[2]], [[5],[2.2360679774997896964091736687313]], [[6],[2.4494897427831780981972840747059]], [[7],[2.6457513110645905905016157536393]], [[8],[2.8284271247461900976033774484194]], [[9],[3]], [[10],[3.1622776601683793319988935444327]], [[11],[3.3166247903553998491149327366707]], ] """ #print pat pat=[[[i], [RR(sqrt(i))]] for i in [1 .. 11]] #print pat #normalize patterns mini=pat[0][0][0] maxi=pat[0][0][0] mino=pat[0][1][0] maxo=pat[0][1][0] for l in pat: for k in l[0]: if mini>k: mini=k if maxi<k: maxi=k for k in l[1]: if mino>k: mino=k if maxo<k: maxo=k i=0 for l in pat: j=0 for k in l[0]: pat[i][0][j]=(k-mini)/(maxi-mini) j=j+1 j=0 for k in l[1]: pat[i][1][j]=(k-mino)/(maxo-mino) j=j+1 i=i+1 n = NN(1, 11, 1) # train it with some patterns n.train(pat) # test it #n.test(pat) print "Testing patterns:" for s in pat: xnorm=s[0] res=n.update(xnorm) xdenorm=[i*(maxi-mini)+mini for i in xnorm] #denormalize result resdenorm=[i*(maxo-mino)+mino for i in res] print "x=",xdenorm,", result=",resdenorm print "Testing some data:" test=[[1],[3],[9],[4]] for x in test: #normalize input x xnorm=[(i-mini)/(maxi-mini) for i in x] res=n.update(xnorm) #denormalize result resdenorm=[i*(maxo-mino)+mino for i in res] print "x=",x,", result=",resdenorm
error 0.712405 error 0.008678 error 0.004657 error 0.004769 error 0.007354 error 0.009314 error 0.009836 error 0.009616 error 0.009165 error 0.008686 Testing patterns: x= [1] , result= [1.16200475699796] x= [2] , result= [1.41294412509937] x= [3] , result= [1.67271513781252] x= [4] , result= [1.93712460638543] x= [5] , result= [2.19922651903816] x= [6] , result= [2.44906945843978] x= [7] , result= [2.67482027156803] x= [8] , result= [2.86558502953640] x= [9] , result= [3.01493195157074] x= [10] , result= [3.12299301434288] x= [11] , result= [3.19568420320399] Testing some data: x= [1] , result= [1.16200475699796] x= [3] , result= [1.67271513781252] x= [9] , result= [3.01493195157074] x= [4] , result= [1.93712460638543]
#summary test low=-20 high=40 pat=[] for i in [1 .. 1000]: a=random.random()*(high-low)+low b=random.random()*(high-low)+low pat.append([[a,b],[a+b]]) #normalize patterns mini=pat[0][0][0] maxi=pat[0][0][0] mino=pat[0][1][0] maxo=pat[0][1][0] for l in pat: for k in l[0]: if mini>k: mini=k if maxi<k: maxi=k for k in l[1]: if mino>k: mino=k if maxo<k: maxo=k i=0 for l in pat: j=0 for k in l[0]: pat[i][0][j]=(k-mini)/(maxi-mini) j=j+1 j=0 for k in l[1]: pat[i][1][j]=(k-mino)/(maxo-mino) j=j+1 i=i+1 #print pat n = NN(2, 12, 1) # train it with some patterns n.train(pat) print "Testing patterns:" for s in pat: xnorm=s[0] res=n.update(xnorm) xdenorm=[i*(maxi-mini)+mini for i in xnorm] #denormalize result resdenorm=[i*(maxo-mino)+mino for i in res] print "x=",xdenorm,", result=",resdenorm print "Testing some data:" test=[[1,10],[-10,10],[0,0],[20,20]] for x in test: #normalize input x xnorm=[(i-mini)/(maxi-mini) for i in x] res=n.update(xnorm) #denormalize result resdenorm=[i*(maxo-mino)+mino for i in res] print "x=",x,", result=",resdenorm
WARNING: Output truncated!
error 147.504931 error 0.027913 error 0.017318 error 0.013152 error 0.010924 error 0.009531 error 0.008537 error 0.007729 error 0.006995 error 0.006293 Testing patterns: x= [1.4184470027051503, 32.912319283195117] , result= [34.031814794433913] x= [24.152696112411984, 22.986828592371307] , result= [47.118436042770256] x= [0.11032775827911934, -12.891350476788695] , result= [-12.727249651045042] x= [37.767428866384307, 31.276638137441093] , result= [68.437143717169391] x= [4.5320794463547784, 31.793091414169293] , result= [36.031130859864888] x= [33.953026901924474, 0.54841740189920074] , result= [34.203432848979958] x= [10.093689546823022, -0.0926095844217798] , result= [10.121879006712099] x= [21.70945084597788, 34.730038811030525] , result= [56.77745681316631] x= [39.072646233349687, 24.626744488840782] , result= [63.891761478143664] x= [-1.6854587638960226, 32.829597405265531] , result= [30.864151029893591] x= [39.557177742674909, 0.79156982463526404] , result= [40.106897828208162] x= [36.922741146954863, 10.692784327041437] , result= [47.616561841406664] x= [37.878126536354955, 39.751359405949088] , result= [73.846474374036461] x= [28.776525749733789, 21.006222951361064] , result= [49.87803897614306] x= [-10.759131842413751, -19.704963006008249] , result= [-30.567140825193018] x= [15.728251025401658, 22.267594328983218] , result= [37.721589283941569] x= [36.132282710740611, 11.027194011277221] , result= [47.141256656387242] x= [21.810796162165229, 18.841358288260107] , result= [40.422297048218738] x= [-7.7047925014260983, 18.658005566804814] , result= [11.068330901668851] x= [38.903272679501214, -13.328902602190105] , result= [25.367675955425739] x= [21.312594591939284, 16.858307049559325] , result= [37.899872448901363] x= [2.5512834274964078, 27.600865231165766] , result= [29.886907798659422] x= [-19.370848484904247, 33.544697327390253] , result= [14.233249005490634] x= [29.041837180762013, 8.8422898881228278] , result= [37.608521417658494] x= 
[-13.511650707137022, 7.1577133398180557] , result= [-6.201562196183108] x= [15.055173946923215, -4.7669912875254639] , result= [10.402115992425038] x= [9.1918878907357424, 26.543725831232518] , result= [35.442726833029496] x= [35.363907736113845, 13.698701658083888] , result= [49.127291613029897] x= [29.634507102374936, -15.32400722182394] , result= [14.348030002191315] x= [31.382082778807924, 35.248873927865262] , result= [66.472931111570475] x= [-9.9199174229287301, 29.649241705184348] , result= [19.668590684432125] x= [30.973970219559288, 32.719532099545773] , result= [63.87918649626198] x= [11.02837118904354, 16.495254633120606] , result= [27.300643566197444] x= [-7.5150052734380566, 22.487892963142269] , result= [15.016654424684432] x= [4.3010382044885596, -18.729854857576672] , result= [-14.402892571844195] x= [-11.943973189744693, 3.2930818966038444] , result= [-8.5268110606897096] x= [33.110788363648666, 13.89653759613017] , result= [46.983005789475044] x= [34.97542204337482, 35.769030661438578] , result= [69.707276154192144] x= [-14.792300472830011, 15.292924823696136] , result= [0.6952265054922151] x= [0.0716822891364437, 10.407707336582114] , result= [10.598661927120958] x= [7.3314881499585631, 8.7965963572276777] , result= [16.14550933092648] x= [-13.891651510287, 29.989605027809841] , result= [16.118072107319023] x= [9.416797705348813, 18.699253377656998] , result= [27.883020937655346] x= [8.3607251638056113, -9.138897405667322] , result= [-0.6042287326511584] x= [12.460035099531193, -10.427615656351922] , result= [2.2078751987375256] x= [31.130755368855002, 29.896241538288688] , result= [61.363716123060058] x= [-11.381673310583411, -15.869362818114048] , result= [-27.354769893180837] x= [-15.890484983573007, 3.5946414915264775] , result= [-12.216381008609027] ... 
x= [24.385435206067395, 31.384513942247164] , result= [56.094880802240269] x= [28.755733973876129, -12.211408651143607] , result= [16.539226317092776] x= [38.038365175823571, 8.6393936829393638] , result= [46.639020782467739] x= [37.066271748964425, -9.7512495706082625] , result= [27.079603762851363] x= [36.45482730503366, -8.4221897670940891] , result= [27.786407600803713] x= [32.163445770033661, -3.7591464060803439] , result= [28.157019583964995] x= [30.457038062621759, -10.171077435947304] , result= [20.199629081213011] x= [31.285481982659061, 19.681876220534686] , result= [51.116313697639669] x= [-12.716069155367149, 36.211672238163516] , result= [23.345782292331691] x= [6.0413741925097533, -8.9784562068158316] , result= [-2.7726724808446406] x= [30.62015679726046, -16.736089356593851] , result= [13.928420246321906] x= [-17.706000192608816, 24.234569070311178] , result= [6.7013394443791867] x= [-14.607730452928319, 24.784686888465483] , result= [10.305485778658998] x= [-14.331960879390452, 36.379975642836371] , result= [21.930796231843019] x= [-11.75948002317088, 13.776630853055305] , result= [2.2091351149590395] x= [18.765315723017491, 1.9724278598446112] , result= [20.650986146263129] x= [-14.251866321265718, 9.2975040341677726] , result= [-4.7877735623646913] x= [0.13919900661439399, 33.456173772635509] , result= [33.297720092977215] x= [-7.8100808230486507, -6.5732641085280878] , result= [-14.346070089211064] x= [15.831552700739731, 6.639399351051626] , result= [22.347772130372825] x= [-2.78122212196212, -11.579503677024483] , result= [-14.328405009129398] x= [30.388755049767067, -14.267095737578783] , result= [16.123480250642416] x= [34.237923110345676, 36.426568798992079] , result= [69.648841409209624] x= [-12.184864948124179, 35.478217462233715] , result= [23.148524613958052] x= [-4.0573853512292697, 27.548184250164834] , result= [23.345726421148477] x= [-0.52029736229286883, 6.7356595339215701] , result= [6.3816048619298371] x= [22.629563511624436, 
12.774361714738998] , result= [35.112843262440762] x= [25.275217931125919, 3.7784977308741894] , result= [28.802452512989788] x= [-16.261632656181789, 0.60765077257279287] , result= [-15.624064032515449] x= [33.078101090295092, 28.139722942216121] , result= [61.549805218962788] x= [35.570933711896473, 15.351481791097154] , result= [51.0704924399124] x= [-0.30736488155482178, 8.9028618967645272] , result= [8.7390032217535278] x= [9.5914232568164657, -5.2993075324822811] , result= [4.4636852216636598] x= [32.276041678136906, -16.455757910816764] , result= [15.826153325500243] x= [-0.10585270716882178, 38.688600103464438] , result= [38.302437984593993] x= [13.620633606351799, 36.629595812350502] , result= [50.353057179738542] x= [-19.627046529122541, 28.698770979209893] , result= [9.2169585272249748] x= [17.531431984764307, 31.384337458568137] , result= [48.965009034009796] x= [24.076754849169557, 39.991783753346951] , result= [64.213533116197041] x= [0.26475594049011164, 24.639688337223671] , result= [24.730232263712686] x= [30.677141638082457, 21.597015738907597] , result= [52.479485911870199] x= [27.476026440960847, 29.149759996824979] , result= [56.973732853748416] x= [-3.6929092698314676, -3.0060243476293458] , result= [-6.5605153499226425] x= [4.2686029636842129, -18.57353382900062] , result= [-14.277106239513042] x= [34.303003455235626, 28.664183219355177] , result= [63.210687697324147] x= [-12.090493825014921, 29.116642460486485] , result= [17.026343906785122] x= [-0.226451025385078, 16.40876702467305] , result= [16.200625685052934] x= [9.3518655600477949, 6.681141856424901] , result= [16.05156841707101] x= [16.540689762521634, 34.442279259215155] , result= [51.121183150789761] x= [4.8462492857142188, -18.40033845661025] , result= [-13.515226218750168] x= [31.082841982723036, 37.114041258821871] , result= [67.762717990307365] x= [28.430608724789163, -10.599643255059721] , result= [17.799488364176305] x= [5.2707370510428611, 25.412943294015331] , result= 
[30.413231224001919] x= [31.816662752399576, 32.191761798423258] , result= [64.167816188919431] x= [35.418197155386338, 18.496825894102468] , result= [54.18799506820973] Testing some data: x= [1, 10] , result= [11.11133917733892] x= [-10, 10] , result= [0.18973340082840195] x= [0, 0] , result= [0.18194708816279359] x= [20, 20] , result= [39.756772861907166]