Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
allendowney
GitHub Repository: allendowney/thinkbayes2
Path: blob/master/scripts/species.py
1901 views
1
"""This file contains code used in "Think Bayes",
2
by Allen B. Downey, available from greenteapress.com
3
4
Copyright 2012 Allen B. Downey
5
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
6
"""
7
8
from __future__ import print_function
9
10
import matplotlib.pyplot as pyplot
11
import thinkplot
12
import numpy
13
14
import csv
15
import random
16
import shelve
17
import sys
18
import time
19
20
import thinkbayes2
21
22
import warnings
23
24
# Turn every RuntimeWarning into an exception for the whole process.
# NOTE(review): numpy reports floating-point problems such as division by
# zero as RuntimeWarning by default, so those become errors here instead of
# silently producing inf/nan — confirm this is still the intended behavior.
warnings.simplefilter('error', RuntimeWarning)
25
26
27
# File formats passed as `formats=` to thinkplot.Save by the plotting code
# in this module.
FORMATS = ['pdf', 'eps', 'png']
28
29
30
class Locker(object):
    """Encapsulates a shelf for storing key-value pairs.

    Keys are passed through str() before they are stored or looked up,
    so Add(1, v) and Lookup('1') refer to the same entry.
    """

    def __init__(self, shelf_file):
        """Opens the shelf.

        shelf_file: string filename passed to shelve.open
        """
        self.shelf = shelve.open(shelf_file)

    def Close(self):
        """Closes the shelf."""
        self.shelf.close()

    def Add(self, key, value):
        """Stores value under str(key), replacing any previous value."""
        self.shelf[str(key)] = value

    def Lookup(self, key):
        """Returns the value stored under str(key), or None if absent."""
        return self.shelf.get(str(key))

    def Keys(self):
        """Returns an iterator of (string) keys."""
        # Shelf.iterkeys() only exists on Python 2; iter(keys()) works on
        # both Python 2 and Python 3.
        return iter(self.shelf.keys())

    def Read(self):
        """Returns the contents of the shelf as a dict."""
        return dict(self.shelf)
56
57
58
class Subject(object):
    """Represents a subject from the belly button study."""

    def __init__(self, code):
        """Creates a subject with no species.

        code: string ID

        Species are added with Add; self.species holds a list of
        (int count, string species) pairs.
        """
        self.code = code
        self.species = []
        self.suite = None
        self.num_reads = None
        self.num_species = None
        self.total_reads = None
        self.total_species = None
        self.prev_unseen = None
        self.pmf_n = None
        self.pmf_q = None
        self.pmf_l = None

    def Add(self, species, count):
        """Add a species-count pair.

        It is up to the caller to ensure that species names are unique.

        species: string species/genus name
        count: int number of individuals
        """
        self.species.append((count, species))

    def Done(self, reverse=False, clean_param=0):
        """Called when we are done adding species counts.

        Optionally removes bogus species (see Clean), sorts the
        (count, species) pairs and records num_species and num_reads.

        reverse: if True, sort in decreasing order of count
        clean_param: if nonzero, passed to Clean before sorting
        """
        if clean_param:
            self.Clean(clean_param)

        self.species.sort(reverse=reverse)
        counts = self.GetCounts()
        self.num_species = len(counts)
        self.num_reads = sum(counts)

    def Clean(self, clean_param=50):
        """Identifies and removes bogus data.

        Each species with count k is dropped at random with probability
        (1 - clean_param/r) ** k, where r is the total number of reads;
        for 0 < clean_param < r, rare species are the most likely to go.

        clean_param: parameter that controls the number of legit species
        """
        def prob_bogus(k, r):
            """Compute the probability that a species is bogus."""
            q = clean_param / r
            p = (1-q) ** k
            return p

        print(self.code, clean_param)

        counts = self.GetCounts()
        # float total so clean_param / r is true division on Python 2
        r = 1.0 * sum(counts)

        species_seq = []
        for k, species in sorted(self.species):
            if random.random() < prob_bogus(k, r):
                continue
            species_seq.append((k, species))
        self.species = species_seq

    def GetM(self):
        """Gets number of observed species."""
        return len(self.species)

    def GetCounts(self):
        """Gets the list of species counts.

        The order matches self.species; after Done() with the default
        reverse=False the counts are in increasing order.
        """
        return [count for count, _ in self.species]

    def MakeCdf(self):
        """Makes a CDF of total prevalence vs rank (rank 0 = most common)."""
        counts = self.GetCounts()
        counts.sort(reverse=True)
        cdf = thinkbayes2.Cdf(dict(enumerate(counts)))
        return cdf

    def GetNames(self):
        """Gets the names of the seen species, in the order of self.species."""
        return [name for _, name in self.species]

    def PrintCounts(self):
        """Prints the counts and species names, last entry first."""
        for count, name in reversed(self.species):
            print(count, name)

    def GetSpecies(self, index):
        """Gets the count and name of the indicated species.

        index: position in self.species

        Returns: count-species pair
        """
        return self.species[index]

    def GetCdf(self):
        """Returns cumulative prevalence vs number of species.

        Maps each species index (in the current order of self.species)
        to its count, the same representation MakeCdf uses, but without
        re-sorting.
        """
        counts = self.GetCounts()
        # Cdf treats a plain iterable as a sequence of values; a dict maps
        # each index to its weight, which is what is wanted here.
        cdf = thinkbayes2.Cdf(dict(enumerate(counts)))
        return cdf

    def GetPrevalences(self):
        """Returns a numpy array of prevalences (counts normalized to sum 1)."""
        counts = self.GetCounts()
        total = sum(counts)
        # numpy.float was removed in NumPy 1.24; builtin float is float64.
        prevalences = numpy.array(counts, dtype=float) / total
        return prevalences

    def Process(self, low=None, high=500, conc=1, iters=100):
        """Computes the posterior distribution of n and the prevalences.

        Sets attribute: self.suite (a Species5 updated with the counts)

        low: minimum number of species (default max(m, 2), where m is
             the number of observed species)
        high: maximum number of species
        conc: concentration parameter
        iters: number of iterations to use in the estimator
        """
        counts = self.GetCounts()
        m = len(counts)
        if low is None:
            low = max(m, 2)
        ns = range(low, high+1)

        self.suite = Species5(ns, conc=conc, iters=iters)
        self.suite.Update(counts)

    def MakePrediction(self, num_sims=100):
        """Make predictions for the given subject.

        Precondition: Process has run, and total_reads has been set
        (see Match).

        num_sims: how many simulations to run for predictions

        Adds attributes
        pmf_l: predictive distribution of additional species
        """
        add_reads = self.total_reads - self.num_reads
        curves = self.RunSimulations(num_sims, add_reads)
        self.pmf_l = self.MakePredictive(curves)

    def MakeQuickPrediction(self, num_sims=100):
        """Make predictions for the given subject.

        Like MakePrediction, but only counts the new species at the end of
        each simulation instead of building whole curves.

        Precondition: Process has run, and total_reads has been set
        (see Match).

        num_sims: how many simulations to run for predictions

        Adds attribute:
        pmf_l: predictive distribution of additional species
        """
        add_reads = self.total_reads - self.num_reads
        pmf = thinkbayes2.Pmf()
        _, seen = self.GetSeenSpecies()

        for _ in range(num_sims):
            _, observations = self.GenerateObservations(add_reads)
            all_seen = seen.union(observations)
            l = len(all_seen) - len(seen)
            pmf.Incr(l)

        pmf.Normalize()
        self.pmf_l = pmf

    def DistL(self):
        """Returns the distribution of additional species, l."""
        return self.pmf_l

    def MakeFigures(self):
        """Makes figures showing distribution of n and the prevalences."""
        self.PlotDistN()
        self.PlotPrevalences()

    def PlotDistN(self):
        """Plots distribution of n."""
        pmf = self.suite.DistN()
        print('90% CI for N:', pmf.CredibleInterval(90))
        pmf.label = self.code

        thinkplot.Clf()
        thinkplot.PrePlot(num=1)

        thinkplot.Pmf(pmf)

        root = 'species-ndist-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Number of species',
                       ylabel='Prob',
                       formats=FORMATS,
                       )

    def PlotPrevalences(self, num=5):
        """Plots dist of prevalence for several species.

        num: how many species (starting with the highest prevalence)
        """
        thinkplot.Clf()
        # one plot line per species, so size the color cycle to num
        thinkplot.PrePlot(num=num)

        for rank in range(1, num+1):
            self.PlotPrevalence(rank)

        root = 'species-prev-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Prevalence',
                       ylabel='Prob',
                       formats=FORMATS,
                       axis=[0, 0.3, 0, 1],
                       )

    def PlotPrevalence(self, rank=1, cdf_flag=True):
        """Plots dist of prevalence for one species.

        rank: rank order of the species to plot (1 = last in self.species)
        cdf_flag: whether to plot the CDF
        """
        # convert rank to index
        index = self.GetM() - rank

        _, mix = self.suite.DistOfPrevalence(index)
        count, _ = self.GetSpecies(index)
        mix.label = '%d (%d)' % (rank, count)

        print('90%% CI for prevalence of species %d:' % rank, end=' ')
        print(mix.CredibleInterval(90))

        if cdf_flag:
            cdf = mix.MakeCdf()
            thinkplot.Cdf(cdf)
        else:
            thinkplot.Pmf(mix)

    def PlotMixture(self, rank=1):
        """Plots dist of prevalence for all n, and the mix.

        rank: rank order of the species to plot
        """
        # convert rank to index
        index = self.GetM() - rank

        print(self.GetSpecies(index))
        print(self.GetCounts()[index])

        metapmf, mix = self.suite.DistOfPrevalence(index)

        thinkplot.Clf()
        for pmf in metapmf.Values():
            thinkplot.Pmf(pmf, color='blue', alpha=0.2, linewidth=0.5)

        thinkplot.Pmf(mix, color='blue', alpha=0.9, linewidth=2)

        root = 'species-mix-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Prevalence',
                       ylabel='Prob',
                       formats=FORMATS,
                       axis=[0, 0.3, 0, 0.3],
                       legend=False)

    def GetSeenSpecies(self):
        """Makes a set of the names of seen species.

        Returns: number of species, set of string species names
        """
        names = self.GetNames()
        m = len(names)
        seen = set(SpeciesGenerator(names, m))
        return m, seen

    def GenerateObservations(self, num_reads):
        """Generates a series of random observations.

        Draws n and the prevalences from the posterior, then samples
        num_reads species names; unseen species are named by
        SpeciesGenerator.

        num_reads: number of reads to generate

        Returns: number of species, sequence of string species names
        """
        n, prevalences = self.suite.SamplePosterior()

        names = self.GetNames()
        name_iter = SpeciesGenerator(names, n)

        items = zip(name_iter, prevalences)

        cdf = thinkbayes2.Cdf(dict(items))
        observations = cdf.Sample(num_reads)

        return n, observations

    def Resample(self, num_reads):
        """Choose a random subset of the data (without replacement).

        num_reads: number of reads in the subset

        Returns: new Subject with the same code, already Done()
        """
        t = []
        for count, species in self.species:
            t.extend([species]*count)

        random.shuffle(t)
        reads = t[:num_reads]

        subject = Subject(self.code)
        hist = thinkbayes2.Hist(reads)
        for species, count in hist.Items():
            subject.Add(species, count)

        subject.Done()
        return subject

    def Match(self, match):
        """Match up a rarefied subject with a complete subject.

        match: complete Subject

        Assigns attributes:
        total_reads: num_reads of the complete subject
        total_species: num_species of the complete subject
        prev_unseen: fraction of the complete subject's reads that belong
                     to species not seen in this subject
        """
        self.total_reads = match.num_reads
        self.total_species = match.num_species

        # compute the prevalence of unseen species (at least approximately,
        # based on all species counts in match)
        _, seen = self.GetSeenSpecies()

        seen_total = 0.0
        unseen_total = 0.0
        for count, species in match.species:
            if species in seen:
                seen_total += count
            else:
                unseen_total += count

        self.prev_unseen = unseen_total / (seen_total + unseen_total)

    def RunSimulation(self, num_reads, frac_flag=False, jitter=0.01):
        """Simulates additional observations and returns a rarefaction curve.

        k is the number of additional observations
        num_new is the number of new species seen

        num_reads: how many new reads to simulate
        frac_flag: whether to convert to fraction of species seen
        jitter: size of jitter added if frac_flag is true

        Returns: list of (k, num_new) pairs, or (k, frac_seen) pairs if
                 frac_flag is true
        """
        m, seen = self.GetSeenSpecies()
        n, observations = self.GenerateObservations(num_reads)

        curve = []
        for i, obs in enumerate(observations):
            seen.add(obs)

            if frac_flag:
                frac_seen = len(seen) / float(n)
                frac_seen += random.uniform(-jitter, jitter)
                curve.append((i+1, frac_seen))
            else:
                num_new = len(seen) - m
                curve.append((i+1, num_new))

        return curve

    def RunSimulations(self, num_sims, num_reads, frac_flag=False):
        """Runs simulations and returns a list of curves.

        Each curve is a sequence of (k, num_new) pairs.

        num_sims: how many simulations to run
        num_reads: how many samples to generate in each simulation
        frac_flag: whether to convert num_new to fraction of total
        """
        curves = [self.RunSimulation(num_reads, frac_flag)
                  for _ in range(num_sims)]
        return curves

    def MakePredictive(self, curves):
        """Makes a predictive distribution of additional species.

        curves: list of (k, num_new) curves

        Returns: Pmf of num_new at the end of each curve
        """
        pred = thinkbayes2.Pmf(label=self.code)
        for curve in curves:
            _, last_num_new = curve[-1]
            pred.Incr(last_num_new)
        pred.Normalize()
        return pred
465
466
467
def MakeConditionals(curves, ks):
    """Builds one Cdf of num_new for each requested value of k.

    Prints a 90% credible interval for each k as it goes.

    curves: list of curves, each a sequence of (k, num_new) pairs
    ks: sequence of k values to condition on

    Returns: list of Cdfs, in the same order as ks
    """
    joint = MakeJointPredictive(curves)

    def conditional_cdf(k_value):
        # distribution of variable 1 (num_new) given variable 0 (k) == k_value
        cond_pmf = joint.Conditional(1, 0, k_value)
        cond_pmf.label = 'k=%d' % k_value
        cond_cdf = cond_pmf.MakeCdf()
        print('90%% credible interval for %d' % k_value, end=' ')
        print(cond_cdf.CredibleInterval(90))
        return cond_cdf

    return [conditional_cdf(k_value) for k_value in ks]
486
487
488
def MakeJointPredictive(curves):
    """Builds the joint distribution of (k, num_new) over all curves.

    curves: list of curves, each a sequence of (k, num_new) pairs

    Returns: normalized thinkbayes2.Joint keyed by (k, num_new) tuples
    """
    joint = thinkbayes2.Joint()
    all_points = (point for curve in curves for point in curve)
    for k, num_new in all_points:
        joint.Incr((k, num_new))
    joint.Normalize()
    return joint
501
502
503
def MakeFracCdfs(curves, ks):
    """Makes Cdfs of the fraction of species seen.

    curves: list of curves, each a sequence of (k, frac) pairs, as made
            by Subject.RunSimulation with frac_flag=True
    ks: collection of k values to keep; other values of k are ignored

    Returns: dict mapping each k in ks that appears in some curve to a
             Cdf of the fractions observed at that k
    """
    # Build the lookup set once: list membership costs O(len(ks)) per
    # test, and a one-shot iterator would be consumed by the first tests.
    wanted = set(ks)

    d = {}
    for curve in curves:
        for k, frac in curve:
            if k in wanted:
                d.setdefault(k, []).append(frac)

    cdfs = {}
    for k, fracs in d.items():
        cdfs[k] = thinkbayes2.Cdf(fracs)

    return cdfs
522
523
def SpeciesGenerator(names, num):
    """Yields the given names, then placeholder names up to num in total.

    Placeholders are 'unseen-<i>', where i continues counting from the
    number of given names.  If names already has num or more entries,
    all of them are yielded and no placeholders are added.

    names: list of strings
    num: total number of species names to generate

    Returns: string iterator
    """
    emitted = 0
    for emitted, name in enumerate(names, 1):
        yield name

    for serial in range(emitted, num):
        yield 'unseen-%d' % serial
541
542
543
def ReadRarefactedData(filename='journal.pone.0047712.s001.csv',
                       clean_param=0):
    """Reads a data file and returns a map of Subjects.

    Data from http://www.plosone.org/article/
    info%3Adoi%2F10.1371%2Fjournal.pone.0047712#s4

    The first row is skipped as a header.  Each remaining row is read as
    (subject code, species name, count).  Rows for one subject are
    expected to be contiguous: a code that reappears later starts a new
    Subject that replaces the earlier one in the map.

    filename: string filename to read
    clean_param: parameter passed to Clean

    Returns: map from code to Subject
    """
    subject = Subject('')
    subject_map = {}

    # the with statement closes the file even if a row fails to parse
    with open(filename) as fp:
        reader = csv.reader(fp)
        next(reader)

        for i, t in enumerate(reader):
            code = t[0]
            if code != subject.code:
                # start a new subject
                subject = Subject(code)
                subject_map[code] = subject

            # append a number to the species names so they're unique
            species = '%s-%d' % (t[1], i)
            count = int(t[2])
            subject.Add(species, count)

    for subject in subject_map.values():
        subject.Done(clean_param=clean_param)

    return subject_map
583
584
585
def ReadCompleteDataset(filename='BBB_data_from_Rob.csv', clean_param=0):
    """Reads a data file and returns a map of Subjects plus a pooled subject.

    Data from personal correspondence with Rob Dunn, received 2-7-13.
    Converted from xlsx to csv.

    Layout as read here:
    - two header rows; the second holds the subject codes between its
      first and last columns (each code gets a 'B' prefix);
    - then one row per OTU: OTU code first, one count per subject, and a
      ';'-separated taxonomy string last.  Rows with an empty OTU code
      are skipped.

    filename: string filename to read
    clean_param: parameter passed to Clean

    Returns: (subject_map, uber_subject), where subject_map maps code to
             Subject and uber_subject holds each OTU's count summed over
             all subjects
    """
    with open(filename) as fp:
        reader = csv.reader(fp)
        next(reader)
        header = next(reader)

        subject_codes = ['B' + code for code in header[1:-1]]

        # create the subject map
        uber_subject = Subject('uber')
        subject_map = dict((code, Subject(code)) for code in subject_codes)

        # read lines
        i = 0
        for t in reader:
            otu_code = t[0]
            if otu_code == '':
                continue

            # pull out a species name and give it a number
            taxons = t[-1].split(';')
            species = '%s-%d' % (taxons[-1], i)
            i += 1

            counts = [int(x) for x in t[1:-1]]

            total = 0
            for code, count in zip(subject_codes, counts):
                if count > 0:
                    subject_map[code].Add(species, count)
                    total += count

            # Subject.Add requires unique species names, so pool this OTU
            # into the uber subject once, with its total count, instead of
            # adding it once per subject.
            if total > 0:
                uber_subject.Add(species, total)

    uber_subject.Done(clean_param=clean_param)
    for subject in subject_map.values():
        subject.Done(clean_param=clean_param)

    return subject_map, uber_subject
638
639
640
def JoinSubjects():
    """Reads both datasets and matches up subjects that appear in both.

    For every subject in the rarefacted dataset whose code also appears
    in the complete dataset, calls Subject.Match with the complete
    subject.  That records total_reads (the complete read count;
    num_reads of the rarefacted subject is normally 400), total_species
    and prev_unseen.  Unmatched subjects are returned unchanged.

    Returns: map from code to Subject (the rarefacted subjects)
    """
    sampled_subjects = ReadRarefactedData()
    all_subjects, _ = ReadCompleteDataset()

    shared_codes = [code for code in sampled_subjects if code in all_subjects]
    for code in shared_codes:
        sampled_subjects[code].Match(all_subjects[code])

    return sampled_subjects
664
665
666
def JitterCurve(curve, dx=0.2, dy=0.3):
    """Returns a copy of curve with uniform random noise added.

    For each point, the x noise is drawn from U(-dx, dx) first, then the
    y noise from U(-dy, dy).

    curve: sequence of (x, y) pairs
    dx, dy: amplitude of the noise in each dimension

    Returns: new list of (x, y) pairs
    """
    jittered = []
    for x, y in curve:
        new_x = x + random.uniform(-dx, dx)
        new_y = y + random.uniform(-dy, dy)
        jittered.append((new_x, new_y))
    return jittered
674
675
676
def OffsetCurve(curve, i, n, dx=0.3, dy=0.3):
    """Shifts a curve by a deterministic offset that depends on its index.

    The offsets are evenly spaced from (-dx, -dy) for curve 0 to
    (+dx, +dy) for curve n-1.  A lone curve (n <= 1) is not shifted.

    curve: sequence of (x, y) pairs
    i: index of the curve, 0 <= i < n
    n: number of curves
    dx, dy: maximum offset in each dimension

    Returns: new list of (x, y) pairs
    """
    if n <= 1:
        # the spacing below divides by n-1; a single curve has nothing to
        # be spread apart from, so leave it where it is
        return [(x, y) for x, y in curve]

    # 2.0 keeps this true division on Python 2 even if dx/dy are ints
    xoff = -dx + 2.0 * dx * i / (n-1)
    yoff = -dy + 2.0 * dy * i / (n-1)
    return [(x+xoff, y+yoff) for x, y in curve]
688
689
690
def PlotCurves(curves, root='species-rare'):
    """Plots a set of curves and saves the figure.

    Each curve is shifted by OffsetCurve according to its position in
    the list before it is drawn.

    curves: list of curves; each curve is a list of (x, y) pairs
    root: string filename root passed to thinkplot.Save
    """
    thinkplot.Clf()
    line_color = '#225EA8'

    num_curves = len(curves)
    for position, curve in enumerate(curves):
        shifted = OffsetCurve(curve, position, num_curves)
        xs, ys = zip(*shifted)
        thinkplot.Plot(xs, ys, color=line_color, alpha=0.3, linewidth=0.5)

    thinkplot.Save(root=root,
                   xlabel='# samples',
                   ylabel='# species',
                   formats=FORMATS,
                   legend=False)
709
710
711
def PlotConditionals(cdfs, root='species-cond'):
    """Plots CDFs of num_new conditioned on k and saves the figure.

    cdfs: list of Cdf (for example, the output of MakeConditionals)
    root: string filename root
    """
    thinkplot.Clf()
    num_lines = len(cdfs)
    thinkplot.PrePlot(num=num_lines)
    thinkplot.Cdfs(cdfs)

    save_options = dict(xlabel='# new species',
                        ylabel='Prob',
                        formats=FORMATS)
    thinkplot.Save(root=root, **save_options)
726
727
728
def PlotFracCdfs(cdfs, root='species-frac'):
    """Plots complementary CDFs of the fraction of species seen.

    For each k, draws 1 - CDF and writes the label k on that curve at
    x = 0.9, then saves the figure.

    cdfs: map from k to CDF of fraction of species seen after k samples
    root: string filename root
    """
    thinkplot.Clf()
    line_color = '#225EA8'

    for k, cdf in cdfs.items():
        xs, ps = cdf.Render()
        complement = [1 - p for p in ps]
        thinkplot.Plot(xs, complement, color=line_color, linewidth=1)

        label_x = 0.9
        label_y = 1 - cdf.Prob(label_x)
        pyplot.text(label_x, label_y, str(k), fontsize=9, color=line_color,
                    horizontalalignment='center',
                    verticalalignment='center',
                    bbox=dict(facecolor='white', edgecolor='none'))

    thinkplot.Save(root=root,
                   xlabel='Fraction of species seen',
                   ylabel='Probability',
                   formats=FORMATS,
                   legend=False)
753
754
755
class Species(thinkbayes2.Suite):
    """Represents hypotheses about the number of species.

    Each hypothesis is a thinkbayes2.Dirichlet over n species, one per
    candidate n; the Suite holds the probability of each hypothesis.
    """

    def __init__(self, ns, conc=1, iters=1000):
        """Makes one Dirichlet hypothesis per candidate n.

        ns: sequence of candidate numbers of species
        conc: concentration parameter for each Dirichlet
        iters: number of sampled likelihoods summed in Likelihood
        """
        dirichlets = [thinkbayes2.Dirichlet(n, conc) for n in ns]
        thinkbayes2.Suite.__init__(self, dirichlets)
        self.iters = iters

    def Update(self, data):
        """Updates the distribution of n, then every Dirichlet.

        data: list of observed frequencies
        """
        # the parent class Update calls Likelihood for each hypothesis
        thinkbayes2.Suite.Update(self, data)

        # then update the next level of the hierarchy
        for dirichlet in self.Values():
            dirichlet.Update(data)

    def Likelihood(self, data, hypo):
        """Computes the likelihood of the data under this hypothesis.

        hypo: Dirichlet object
        data: list of observed frequencies

        Returns: sum of self.iters sampled likelihoods, times the number
                 of ways to choose the observed species from hypo.n
        """
        sampled_total = sum(hypo.Likelihood(data) for _ in range(self.iters))

        # correct for the number of ways the observed species
        # might have been chosen from all species
        num_observed = len(data)
        return sampled_total * thinkbayes2.BinomialCoef(hypo.n, num_observed)

    def DistN(self):
        """Returns a Pmf mapping each n to its probability."""
        pmf = thinkbayes2.Pmf()
        for dirichlet, prob in self.Items():
            pmf.Set(dirichlet.n, prob)
        return pmf
802
803
804
class Species2(object):
    """Represents hypotheses about the number of species.

    Combines two layers of the hierarchy into one object.

    ns and probs represent the distribution of N

    params represents the parameters of the Dirichlet distributions: one
    array of length ns[-1], where the distribution for a given n uses
    the first n entries.
    """

    def __init__(self, ns, conc=1, iters=1000):
        """Initializes a uniform prior over ns.

        ns: sequence of candidate numbers of species; the last element is
            assumed to be the largest (it sizes self.params)
        conc: concentration parameter (initial value of every param)
        iters: number of samples used to estimate likelihoods
        """
        self.ns = ns
        self.conc = conc
        # numpy.float was removed in NumPy 1.24; builtin float is float64
        self.probs = numpy.ones(len(ns), dtype=float)
        self.params = numpy.ones(self.ns[-1], dtype=float) * conc
        self.iters = iters
        # total reads and number of distinct species seen so far
        self.num_reads = 0
        self.m = 0

    def Preload(self, data):
        """Change the initial parameters to fit the data better.

        Just an experiment. Doesn't work.
        """
        m = len(data)
        singletons = data.count(1)
        num = m - singletons
        print(m, singletons, num)
        addend = numpy.ones(num, dtype=float) * 1
        print(len(addend))
        print(len(self.params[singletons:m]))
        self.params[singletons:m] += addend
        print('Preload', num)

    def Update(self, data):
        """Updates the distribution based on data.

        data: numpy array (or list) of counts
        """
        self.num_reads += sum(data)

        like = numpy.zeros(len(self.ns), dtype=float)
        for _ in range(self.iters):
            like += self.SampleLikelihood(data)

        self.probs *= like
        self.probs /= self.probs.sum()

        self.m = len(data)
        self.params[:self.m] += data

    def SampleLikelihood(self, data):
        """Computes the likelihood of the data for all values of n.

        Draws one sample from the distribution of prevalences.

        data: sequence of observed counts

        Returns: numpy array of likelihoods, one per value in self.ns
        """
        gammas = numpy.random.gamma(self.params)

        m = len(data)
        row = gammas[:m]
        col = numpy.cumsum(gammas)

        log_likes = numpy.zeros(len(self.ns), dtype=float)
        for j, n in enumerate(self.ns):
            # normalize the first m gammas by the sum of the first n
            ps = row / col[n-1]
            terms = numpy.log(ps) * data
            log_likes[j] = terms.sum()

        # before exponentiating, scale into a reasonable range
        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        # correct for the number of ways we could see m species out of n
        coefs = [thinkbayes2.BinomialCoef(n, m) for n in self.ns]
        likes *= coefs

        return likes

    def DistN(self):
        """Computes the distribution of n.

        Returns: new Pmf object
        """
        pmf = thinkbayes2.Pmf(dict(zip(self.ns, self.probs)))
        return pmf

    def RandomN(self):
        """Returns a random value of n."""
        return self.DistN().Random()

    def DistQ(self, iters=100):
        """Computes the distribution of q based on distribution of n.

        iters: number of values of n to sample

        Returns: pmf of q
        """
        cdf_n = self.DistN().MakeCdf()
        sample_n = cdf_n.Sample(iters)

        pmf = thinkbayes2.Pmf()
        for n in sample_n:
            q = self.RandomQ(n)
            pmf.Incr(q)

        pmf.Normalize()
        return pmf

    def RandomQ(self, n):
        """Returns a random value of q.

        Based on n, self.num_reads and self.conc.

        n: number of species

        Returns: q, the total prevalence of species not seen in a simulated
                 sample of self.num_reads reads
        """
        # generate random prevalences
        dirichlet = thinkbayes2.Dirichlet(n, conc=self.conc)
        prevalences = dirichlet.Random()

        # generate a simulated sample
        pmf = thinkbayes2.Pmf(dict(enumerate(prevalences)))
        cdf = pmf.MakeCdf()
        sample = cdf.Sample(self.num_reads)
        seen = set(sample)

        # add up the prevalence of unseen species
        q = 0
        for species, prev in enumerate(prevalences):
            if species not in seen:
                q += prev

        return q

    def MarginalBeta(self, n, index):
        """Computes the conditional distribution of the indicated species.

        n: conditional number of species
        index: which species

        Returns: Beta object representing a distribution of prevalence.
        """
        alpha0 = self.params[:n].sum()
        alpha = self.params[index]
        return thinkbayes2.Beta(alpha, alpha0-alpha)

    def DistOfPrevalence(self, index):
        """Computes the distribution of prevalence for the indicated species.

        index: which species

        Returns: (metapmf, mix) where metapmf is a MetaPmf and mix is a Pmf
        """
        metapmf = thinkbayes2.Pmf()

        for n, prob in zip(self.ns, self.probs):
            beta = self.MarginalBeta(n, index)
            pmf = beta.MakePmf()
            metapmf.Set(pmf, prob)

        mix = thinkbayes2.MakeMixture(metapmf)
        return metapmf, mix

    def SamplePosterior(self):
        """Draws random n and prevalences.

        Returns: (n, prevalences)
        """
        n = self.RandomN()
        prevalences = self.SamplePrevalences(n)
        return n, prevalences

    def SamplePrevalences(self, n):
        """Draws a sample of prevalences given n.

        n: the number of species assumed in the conditional

        Returns: numpy array of n prevalences
        """
        if n == 1:
            # a single species has all the prevalence
            return numpy.ones(1)

        q_desired = self.RandomQ(n)
        # keep q strictly positive so Unbias never zeroes the unseen params
        q_desired = max(q_desired, 1e-6)

        params = self.Unbias(n, self.m, q_desired)

        gammas = numpy.random.gamma(params)
        gammas /= gammas.sum()
        return gammas

    def Unbias(self, n, m, q_desired):
        """Adjusts the parameters to achieve desired prev_unseen (q).

        Scales the first m params by f and the rest by g so that the total
        is unchanged and the unseen share of the total equals q_desired.

        n: number of species
        m: seen species
        q_desired: prevalence of unseen species

        Returns: new array of n params (self.params is not modified)
        """
        params = self.params[:n].copy()

        if n == m:
            return params

        x = sum(params[:m])
        y = sum(params[m:])
        a = x + y

        g = q_desired * a / y
        f = (a - g * y) / x
        params[:m] *= f
        params[m:] *= g

        return params
1025
1026
1027
class Species3(Species2):
    """Represents hypotheses about the number of species.

    Same model as Species2, but SampleLikelihood evaluates every value of
    n at once with numpy broadcasting instead of a Python loop.
    """

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observed species counts
        """
        # sample the likelihoods and add them up
        like = numpy.zeros(len(self.ns), dtype=float)
        for _ in range(self.iters):
            like += self.SampleLikelihood(data)

        self.probs *= like
        self.probs /= self.probs.sum()

        # Record what has been seen, as Species2.Update does: the inherited
        # SamplePosterior uses num_reads (via RandomQ) and m (via Unbias);
        # with m left at 0, Unbias would divide by zero.
        self.num_reads += sum(data)
        self.m = len(data)
        self.params[:self.m] += data

    def SampleLikelihood(self, data):
        """Computes the likelihood of the data under all hypotheses.

        data: list of observed species counts

        Returns: numpy array of likelihoods, one per value in self.ns
        """
        # get a random sample
        gammas = numpy.random.gamma(self.params)

        # row is just the first m elements of gammas
        m = len(data)
        row = gammas[:m]

        # col is the cumulative sum of gammas, starting at n = ns[0].
        # NOTE: this lines up with self.ns only if ns is a contiguous
        # range ending at len(self.params) (as Species2.__init__ sizes it).
        col = numpy.cumsum(gammas)[self.ns[0]-1:]

        # each row of the array is a set of ps, normalized
        # for each hypothetical value of n
        array = row / col[:, numpy.newaxis]

        # computing the multinomial PDF under a log transform
        # take the log of the ps and multiply by the data
        terms = numpy.log(array) * data

        # add up the rows
        log_likes = terms.sum(axis=1)

        # before exponentiating, scale into a reasonable range
        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        # correct for the number of ways we could see m species
        # out of a possible n
        coefs = [thinkbayes2.BinomialCoef(n, m) for n in self.ns]
        likes *= coefs

        return likes
1082
1083
1084
class Species4(Species):
    """Represents hypotheses about the number of species.

    Feeds the observed species to Species.Update one at a time.
    """

    def Update(self, data):
        """Updates the suite based on the data, one species at a time.

        For the i-th species, passes Species.Update a length-(i+1) array
        that is zero except for data[i] in its last position.

        data: list of observed frequencies
        """
        for position in range(len(data)):
            single = numpy.zeros(position + 1)
            single[-1] = data[position]
            Species.Update(self, single)

    def Likelihood(self, data, hypo):
        """Computes the likelihood of the data under this hypothesis.

        Note: this only works correctly if we update one species at a time.

        hypo: Dirichlet object
        data: list of observed frequencies
        """
        sampled_total = sum(hypo.Likelihood(data) for _ in range(self.iters))

        # the newly seen species could have been any of the
        # n - (m - 1) species not yet seen
        num_unseen = hypo.n - len(data) + 1
        return sampled_total * num_unseen
1122
1123
1124
class Species5(Species2):
    """Represents hypotheses about the number of species.

    Combines two layers of the hierarchy into one object.

    ns and probs represent the distribution of N

    params represents the parameters of the Dirichlet distributions

    Unlike Species2, the observed species are processed one at a time.
    """

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observed frequencies in increasing order
        """
        # loop through the species and update one at a time
        m = len(data)
        for i in range(m):
            self.UpdateOne(i+1, data[i])
            self.params[i] += data[i]

    def UpdateOne(self, i, count):
        """Updates the distribution of n given the i-th observed species.

        Evaluates the likelihood for all values of n.  Also records i as
        the number of species seen so far (self.m) and adds count to
        self.num_reads; when self.iters is 0 only this bookkeeping is done.

        i: which species was observed (1..n)
        count: how many were observed
        """
        # how many species have we seen so far
        self.m = i

        # how many reads have we seen
        self.num_reads += count

        if self.iters == 0:
            return

        # sample the likelihoods and add them up
        # (builtin float: numpy.float was removed in NumPy 1.24)
        likes = numpy.zeros(len(self.ns), dtype=float)
        for _ in range(self.iters):
            likes += self.SampleLikelihood(i, count)

        # correct for the number of unseen species the new one
        # could have been
        unseen_species = [n-i+1 for n in self.ns]
        likes *= unseen_species

        # multiply the priors by the likelihoods and renormalize
        self.probs *= likes
        self.probs /= self.probs.sum()

    def SampleLikelihood(self, i, count):
        """Computes the likelihood of the data under all hypotheses.

        i: which species was observed
        count: how many were observed

        Returns: numpy array of likelihoods, one per value in self.ns
        """
        # get a random sample of p
        gammas = numpy.random.gamma(self.params)

        # sums is the cumulative sum of p, for each value of n.
        # NOTE: this lines up with self.ns only if ns is a contiguous
        # range ending at len(self.params).
        sums = numpy.cumsum(gammas)[self.ns[0]-1:]

        # get p for the ith species, for each value of n
        ps = gammas[i-1] / sums
        log_likes = numpy.log(ps) * count

        # before exponentiating, scale into a reasonable range
        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        return likes
1197
1198
1199
def MakePosterior(constructor, data, ns, conc=1, iters=1000):
    """Builds a suite, updates it with data, and returns it.

    Prints the wall-clock time spent in Update.

    constructor: suite class, called as constructor(ns, conc=..., iters=...)
    data: observed species and their counts
    ns: sequence of hypothetical ns
    conc: concentration parameter
    iters: how many samples to draw

    Returns: posterior suite of the given type
    """
    suite = constructor(ns, conc=conc, iters=iters)

    started = time.time()
    suite.Update(data)
    elapsed = time.time() - started
    print('Processing time', elapsed)

    return suite
1220
1221
1222
def PlotAllVersions():
    """Plots the posterior distribution of N from each Species variant.

    Uses a small dataset of three species and saves the figure with
    root 'species3'.
    """
    data = [1, 2, 3]
    ns = range(len(data), 20)

    variants = [Species, Species2, Species3, Species4, Species5]
    for variant in variants:
        dist_n = MakePosterior(variant, data, ns).DistN()
        dist_n.label = '%s' % (variant.__name__)
        thinkplot.Pmf(dist_n)

    thinkplot.Save(root='species3',
                   xlabel='Number of species',
                   ylabel='Prob')
1238
1239
1240
def PlotMedium():
    """Shows posterior distributions of N from each Species variant.

    Uses a medium-sized dataset of eight species and displays the figure
    instead of saving it.
    """
    data = [1, 1, 1, 1, 2, 3, 5, 9]
    ns = range(len(data), 20)

    variants = [Species, Species2, Species3, Species4, Species5]
    for variant in variants:
        dist_n = MakePosterior(variant, data, ns).DistN()
        dist_n.label = '%s' % (variant.__name__)
        thinkplot.Pmf(dist_n)

    thinkplot.Show()
1254
1255
1256
def SimpleDirichletExample():
    """Plots posterior distributions of prevalence for three known species.

    This is the case where we know there are exactly three species. Prints
    the mean prevalence of each species before and after updating with the
    observed counts, then plots each posterior marginal.
    """
    thinkplot.Clf()
    thinkplot.PrePlot(3)

    names = ['lions', 'tigers', 'bears']
    data = [3, 2, 1]

    dirichlet = thinkbayes2.Dirichlet(3)

    # prior means
    for index, name in enumerate(names):
        print('mean', name, dirichlet.MarginalBeta(index).Mean())

    dirichlet.Update(data)

    # posterior means and marginal distributions
    for index, name in enumerate(names):
        marginal = dirichlet.MarginalBeta(index)
        print('mean', name, marginal.Mean())
        thinkplot.Pmf(marginal.MakePmf(label=name))

    thinkplot.Save(root='species1',
                   xlabel='Prevalence',
                   ylabel='Prob',
                   formats=FORMATS,
                   )
1285
1286
1287
def HierarchicalExample():
    """Plots the posterior distribution of n for lions, tigers and bears.

    Hypotheses range over 3..29 species; the figure is saved with root
    'species2'.
    """
    suite = Species(range(3, 30), iters=8000)
    suite.Update([3, 2, 1])

    thinkplot.Clf()
    thinkplot.PrePlot(num=1)

    thinkplot.Pdf(suite.DistN())
    thinkplot.Save(root='species2',
                   xlabel='Number of species',
                   ylabel='Prob',
                   formats=FORMATS,
                   )
1306
1307
1308
def CompareHierarchicalExample():
    """Shows posterior distributions of N from Species and Species5.

    Each constructor is run with its own number of iterations.
    """
    data = [3, 2, 1]
    m = len(data)
    n = 30
    ns = range(m, n)

    constructors = [Species, Species5]
    iters_list = [1000, 100]

    for constructor, iters in zip(constructors, iters_list):
        # pass iters by keyword: the fourth positional parameter of
        # MakePosterior is conc, so passing it positionally would set the
        # concentration to 1000/100 and leave iters at its default
        suite = MakePosterior(constructor, data, ns, iters=iters)
        pmf = suite.DistN()
        pmf.label = '%s' % (constructor.__name__)
        thinkplot.Pmf(pmf)

    thinkplot.Show()
1325
1326
1327
def ProcessSubjects(codes):
    """Processes subjects with the given codes and plots their posteriors.

    If at least two codes are given, also prints the probability that the
    first subject's number of species is greater (and less) than the
    second's.

    codes: sequence of string codes
    """
    thinkplot.Clf()
    thinkplot.PrePlot(len(codes))

    subjects = ReadRarefactedData()
    pmfs = []
    for code in codes:
        subject = subjects[code]

        subject.Process()
        pmf = subject.suite.DistN()
        pmf.label = subject.code
        thinkplot.Pmf(pmf)

        pmfs.append(pmf)

    # the comparison needs two subjects; skip it instead of raising
    # IndexError when fewer codes are given
    if len(pmfs) >= 2:
        print('ProbGreater', thinkbayes2.PmfProbGreater(pmfs[0], pmfs[1]))
        print('ProbLess', thinkbayes2.PmfProbLess(pmfs[0], pmfs[1]))

    thinkplot.Save(root='species4',
                   xlabel='Number of species',
                   ylabel='Prob',
                   formats=FORMATS,
                   )
1355
1356
1357
def RunSubject(code, conc=1, high=500):
    """Runs the full analysis for the subject with the given code.

    Processes the subject, prints a summary and the prediction for the
    number of additional species, then makes figures for rarefaction
    curves, conditional distributions and fraction-of-species CDFs.

    code: string code
    conc: concentration parameter passed to Subject.Process
    high: passed to Subject.Process as `high`
    """
    subject = JoinSubjects()[code]

    subject.Process(conc=conc, high=high, iters=300)
    subject.MakeQuickPrediction()

    PrintSummary(subject)
    actual_l = subject.total_species - subject.num_species
    PrintPrediction(subject.DistL().MakeCdf(), actual_l)

    subject.MakeFigures()

    # rarefaction curves: 100 simulations of 400 reads
    curves = subject.RunSimulations(100, 400)
    PlotCurves(curves, root='species-rare-%s' % subject.code)

    # conditional distributions: 500 simulations of 800 reads
    curves = subject.RunSimulations(500, 800)
    cdfs = MakeConditionals(curves, [100, 200, 400, 800])
    PlotConditionals(cdfs, root='species-cond-%s' % subject.code)

    # fraction of species: 500 simulations of 1000 reads
    curves = subject.RunSimulations(500, 1000, frac_flag=True)
    cdfs = MakeFracCdfs(curves, [10, 100, 200, 400, 600, 800, 1000])
    PlotFracCdfs(cdfs, root='species-frac-%s' % subject.code)
1393
1394
1395
def PrintSummary(subject):
    """Prints observed and total counts for a subject and its prediction of n.

    subject: Subject
    """
    print(subject.code)
    found = (subject.num_species, subject.num_reads)
    print('found %d species in %d reads' % found)

    total = (subject.total_species, subject.total_reads)
    print('total %d species in %d reads' % total)

    cdf_n = subject.suite.DistN().MakeCdf()
    print('n')
    PrintPrediction(cdf_n, 'unknown')
1410
1411
1412
def PrintPrediction(cdf, actual):
    """Prints the median and 75% credible interval of a prediction.

    Then prints the actual value for comparison.

    cdf: predictive distribution (needs Percentile and CredibleInterval)
    actual: actual value, or a placeholder such as 'unknown'
    """
    mid = cdf.Percentile(50)
    lo, hi = cdf.CredibleInterval(75)

    summary = 'predicted %0.2f (%0.2f %0.2f)' % (mid, lo, hi)
    print(summary)
    print('actual', actual)
1423
1424
1425
def RandomSeed(x):
    """Seeds both the `random` module and `numpy.random` with the same value.

    x: int seed
    """
    for seeder in (random.seed, numpy.random.seed):
        seeder(x)
1432
1433
1434
def GenerateFakeSample(n, r, tr, conc=1):
    """Simulates a sample of reads and a random subsample of it.

    Prevalences for n species are drawn from thinkbayes2.Dirichlet(n,
    conc=conc) and sorted; then tr reads are drawn according to those
    prevalences.

    n: number of species
    r: number of reads in subsample
    tr: total number of reads
    conc: concentration parameter

    Returns: tuple (hist, subhist, prev_unseen). hist counts all tr reads
             per species. subhist counts r randomly chosen reads, or is hist
             itself when tr <= r. prev_unseen is the total prevalence of
             species absent from subhist.
    """
    dirichlet = thinkbayes2.Dirichlet(n, conc=conc)
    prevalences = dirichlet.Random()
    prevalences.sort()

    # each simulated read is labelled with the index of its species
    cdf = thinkbayes2.Pmf(dict(enumerate(prevalences))).MakeCdf()
    sample = cdf.Sample(tr)
    hist = thinkbayes2.Hist(sample)

    if tr <= r:
        subhist = hist
    else:
        # shuffle in place, then keep a prefix: a random subsample of size r
        random.shuffle(sample)
        subhist = thinkbayes2.Hist(sample[:r])

    prev_unseen = sum(prev for species, prev in enumerate(prevalences)
                      if species not in subhist)

    return hist, subhist, prev_unseen
1472
1473
1474
def PlotActualPrevalences():
    """Compares actual max prevalences with max prevalences from a model.

    Prints (code, m, max prevalence) for every subject with at least two
    species. For subjects with more than 50 species, accumulates the actual
    max prevalence and a simulated one (Dirichlet with conc=0.06), then
    shows the CDFs of both.
    """
    subject_map, _ = ReadCompleteDataset()

    pmf_actual = thinkbayes2.Pmf()
    pmf_sim = thinkbayes2.Pmf()

    # concentration parameter used in the simulation
    conc = 0.06

    for code, subject in subject_map.items():
        prevalences = subject.GetPrevalences()
        m = len(prevalences)
        if m < 2:
            continue

        actual_max = max(prevalences)
        print(code, m, actual_max)

        if m <= 50:
            continue
        pmf_actual.Incr(actual_max)
        pmf_sim.Incr(SimulateMaxPrev(m, conc))

    cdfs = [pmf_actual.MakeCdf(label='actual'),
            pmf_sim.MakeCdf(label='sim')]
    thinkplot.Cdfs(cdfs)
    thinkplot.Show()
1509
1510
1511
def ScatterPrevalences(ms, actual):
    """Plots actual max prevalences against expected values.

    Draws one curve of expected max prevalence per concentration parameter
    (1, 0.5, 0.2, 0.1), overlays the actual values as a scatter plot, and
    shows the figure with a log x-axis.

    ms: sorted sequence of m values (number of species)
    actual: sequence of actual max prevalences, parallel to ms
    """
    for conc in (1, 0.5, 0.2, 0.1):
        thinkplot.Plot(ms, [ExpectedMaxPrev(m, conc) for m in ms])

    thinkplot.Scatter(ms, actual)
    thinkplot.Show(xscale='log')
1523
1524
1525
def SimulateMaxPrev(m, conc=1):
    """Draws one set of m prevalences from a Dirichlet and returns the largest.

    m: int number of species
    conc: concentration parameter of the Dirichlet distribution

    Returns: float max of m prevalences
    """
    return max(thinkbayes2.Dirichlet(m, conc).Random())
1536
1537
1538
def ExpectedMaxPrev(m, conc=1, iters=100):
    """Estimates the expected max prevalence by simulation.

    m: number of species
    conc: concentration parameter
    iters: how many random draws to average over

    Returns: mean of the max prevalence over iters draws
    """
    dirichlet = thinkbayes2.Dirichlet(m, conc)
    maxes = [max(dirichlet.Random()) for _ in range(iters)]
    return numpy.mean(maxes)
1555
1556
1557
class Calibrator(object):
    """Encapsulates the calibration process.

    For each of n (number of species), q (prevalence of unseen species) and
    l (number of additional species), this class does two things. It
    accumulates a score vector recording how often the actual value falls
    inside credible intervals of several sizes. It also records
    (actual, posterior mean) pairs for scatter plots.
    """

    def __init__(self, conc=0.1):
        """Initializes empty accumulators.

        conc: concentration parameter, used both to generate fake samples
              and to process subjects
        """
        self.conc = conc

        # credible-interval sizes to check, in percent
        self.ps = range(10, 100, 10)
        self.total_n = numpy.zeros(len(self.ps))
        self.total_q = numpy.zeros(len(self.ps))
        self.total_l = numpy.zeros(len(self.ps))

        # (actual, posterior mean) pairs, one per run
        self.n_seq = []
        self.q_seq = []
        self.l_seq = []

    def _ScaleTotals(self, num_runs):
        """Converts summed score vectors into percentages.

        num_runs: number of runs that contributed to the totals; if it is
                  not positive the totals are left unchanged (avoids
                  dividing by zero)
        """
        if num_runs <= 0:
            return
        factor = 100.0 / num_runs
        self.total_n *= factor
        self.total_q *= factor
        self.total_l *= factor

    def Calibrate(self, num_runs=100, n_low=30, n_high=400, r=400, tr=1200):
        """Runs calibrations on simulated data.

        num_runs: how many runs; run i uses i as its random seed
        n_low, n_high: inclusive range for the number of species
        r: number of reads in each subsample
        tr: total number of reads in each simulated sample
        """
        for seed in range(num_runs):
            self.RunCalibration(seed, n_low, n_high, r, tr)

        self._ScaleTotals(num_runs)

    def Validate(self, num_runs=100, clean_param=0):
        """Runs validations on real subjects.

        Subjects with fewer than 400 reads are skipped. Each remaining
        subject is resampled to 100 reads, matched against the original
        with Subject.Match, and analyzed.

        num_runs: maximum number of subjects to validate
        clean_param: parameter passed to ReadCompleteDataset
        """
        subject_map, _ = ReadCompleteDataset(clean_param=clean_param)

        num_done = 0
        # values() rather than itervalues(), which does not exist in Python 3
        for match in subject_map.values():
            if num_done >= num_runs:
                break
            if match.num_reads < 400:
                continue
            num_reads = 100

            print('Validate', match.code)
            subject = match.Resample(num_reads)
            subject.Match(match)

            # the true number of species is unknown for real data; Score
            # turns None into -1, which marks total_n as invalid
            n_actual = None
            q_actual = subject.prev_unseen
            l_actual = subject.total_species - subject.num_species
            self.RunSubject(subject, n_actual, q_actual, l_actual)

            num_done += 1

        # normalize by the number of subjects actually validated, which is
        # fewer than num_runs when the dataset runs out of eligible subjects
        self._ScaleTotals(num_done)

    def PlotN(self, root='species-n'):
        """Makes a scatter plot of actual vs predicted (posterior mean) n.

        Does nothing if there is no data, or if any actual n is unknown
        (None), as happens in validation runs.
        """
        if not self.n_seq:
            return
        xs, ys = zip(*self.n_seq)
        if None in xs:
            return

        high = max(xs + ys)

        thinkplot.Plot([0, high], [0, high], color='gray')
        thinkplot.Scatter(xs, ys)
        thinkplot.Save(root=root,
                       xlabel='Actual n',
                       ylabel='Predicted')

    def PlotQ(self, root='species-q'):
        """Makes a scatter plot of actual vs predicted prev_unseen (q).

        Does nothing if there is no data.
        """
        if not self.q_seq:
            return
        thinkplot.Plot([0, 0.2], [0, 0.2], color='gray')
        xs, ys = zip(*self.q_seq)
        thinkplot.Scatter(xs, ys)
        thinkplot.Save(root=root,
                       xlabel='Actual q',
                       ylabel='Predicted')

    def PlotL(self, root='species-n'):
        """Makes a scatter plot of actual vs predicted l.

        Does nothing if there is no data.
        """
        # NOTE(review): the default root duplicates PlotN's default
        # ('species-n'); callers in this file always pass root explicitly.
        if not self.l_seq:
            return
        thinkplot.Plot([0, 20], [0, 20], color='gray')
        xs, ys = zip(*self.l_seq)
        thinkplot.Scatter(xs, ys)
        thinkplot.Save(root=root,
                       xlabel='Actual l',
                       ylabel='Predicted')

    def PlotCalibrationCurves(self, root='species5'):
        """Plots calibration curves (percent of hits vs interval size)."""
        print(self.total_n)
        print(self.total_q)
        print(self.total_l)

        thinkplot.Plot([0, 100], [0, 100], color='gray', alpha=0.2)

        # total_n is negative when actual n was unknown (validation runs
        # score None as -1), so the n curve is only meaningful otherwise
        if self.total_n[0] >= 0:
            thinkplot.Plot(self.ps, self.total_n, label='n')

        thinkplot.Plot(self.ps, self.total_q, label='q')
        thinkplot.Plot(self.ps, self.total_l, label='l')

        thinkplot.Save(root=root,
                       axis=[0, 100, 0, 100],
                       xlabel='Ideal percentages',
                       ylabel='Predictive distributions',
                       formats=FORMATS,
                       )

    def RunCalibration(self, seed, n_low, n_high, r, tr):
        """Runs a single calibration run.

        Chooses the number of species uniformly from [n_low, n_high] and
        generates prevalences and simulated reads. Then runs the analysis
        and scores the posterior distributions against the known values.

        seed: int random seed
        n_low, n_high: inclusive range for the number of species
        r: number of reads in the subsample
        tr: total number of reads
        """
        # generate a random number of species and their prevalences
        # (from a Dirichlet distribution with alpha_i = conc for all i)
        RandomSeed(seed)
        n_actual = random.randrange(n_low, n_high + 1)

        hist, subhist, q_actual = GenerateFakeSample(
            n_actual,
            r,
            tr,
            self.conc)

        l_actual = len(hist) - len(subhist)
        print('Run low, high, conc', n_low, n_high, self.conc)
        print('Run r, tr', r, tr)
        print('Run n, q, l', n_actual, q_actual, l_actual)

        # extract the data
        data = [count for species, count in subhist.Items()]
        data.sort()
        print('data', data)

        # make a Subject and process
        subject = Subject('simulated')
        subject.num_reads = r
        subject.total_reads = tr

        for species, count in subhist.Items():
            subject.Add(species, count)
        subject.Done()

        self.RunSubject(subject, n_actual, q_actual, l_actual)

    def RunSubject(self, subject, n_actual, q_actual, l_actual):
        """Runs the analysis for a subject and accumulates the scores.

        subject: Subject
        n_actual: number of species (None if unknown)
        q_actual: prevalence of unseen species
        l_actual: number of new species
        """
        # process and make prediction
        subject.Process(conc=self.conc, iters=100)
        subject.MakeQuickPrediction()

        # extract the posterior suite
        suite = subject.suite

        # check the distribution of n
        pmf_n = suite.DistN()
        print('n')
        self.total_n += self.CheckDistribution(pmf_n, n_actual, self.n_seq)

        # check the distribution of q
        pmf_q = suite.DistQ()
        print('q')
        self.total_q += self.CheckDistribution(pmf_q, q_actual, self.q_seq)

        # check the distribution of additional species
        pmf_l = subject.DistL()
        print('l')
        self.total_l += self.CheckDistribution(pmf_l, l_actual, self.l_seq)

    def CheckDistribution(self, pmf, actual, seq):
        """Checks a predictive distribution and returns a score vector.

        pmf: predictive distribution
        actual: actual value
        seq: list onto which (actual, mean) is appended

        Returns: numpy array of scores, one per interval size in self.ps
        """
        mean = pmf.Mean()
        seq.append((actual, mean))

        cdf = pmf.MakeCdf()
        PrintPrediction(cdf, actual)

        sv = ScoreVector(cdf, self.ps, actual)
        return sv
1760
1761
1762
def ScoreVector(cdf, ps, actual):
    """Scores the actual value against a credible interval of each size.

    cdf: predictive distribution
    ps: interval sizes to check, as percentages (0-100)
    actual: actual value

    Returns: numpy array with one Score result per element of ps
    """
    def _score_for(p):
        low, high = cdf.CredibleInterval(p)
        return Score(low, high, actual)

    return numpy.array([_score_for(p) for p in ps])
1778
1779
1780
def Score(low, high, n):
    """Scores whether the actual value falls in the range [low, high].

    A value strictly inside scores 1; hitting either post scores 0.5;
    outside scores 0; an unknown value (None) is flagged as -1.

    low: low end of range
    high: high end of range
    n: actual value

    Returns: -1, 0, 0.5 or 1
    """
    if n is None:
        return -1
    if low < n < high:
        return 1
    on_post = n == low or n == high
    return 0.5 if on_post else 0
1799
1800
1801
def FakeSubject(n=300, conc=0.1, num_reads=400, prevalences=None):
    """Makes a fake Subject from a simulated sample of reads.

    If prevalences is provided, n and conc are ignored.

    n: number of species
    conc: concentration parameter
    num_reads: number of reads
    prevalences: numpy array of prevalences (overrides n and conc)

    Returns: Subject named 'simulated' with one Add call per species
             observed in the simulated sample
    """
    # generate random prevalences (sorted) unless they were supplied
    if prevalences is None:
        dirichlet = thinkbayes2.Dirichlet(n, conc=conc)
        prevalences = dirichlet.Random()
        prevalences.sort()

    # simulate reads; each read is labelled with its species index
    pmf = thinkbayes2.Pmf(dict(enumerate(prevalences)))
    cdf = pmf.MakeCdf()
    sample = cdf.Sample(num_reads)

    # count reads per species
    hist = thinkbayes2.Hist(sample)

    # make a Subject from the counts
    subject = Subject('simulated')
    for species, count in hist.Items():
        subject.Add(species, count)
    subject.Done()

    return subject
1837
1838
1839
def PlotSubjectCdf(code=None, clean_param=0):
    """Checks whether the Dirichlet model can replicate the data.

    Plots the subject's CDF (thick blue line) together with the CDFs of 10
    fake subjects (thin green lines). The fake subjects are built from
    prevalences drawn with subject.suite.SamplePrevalences. The figure is
    saved with root 'species-cdf-<code>'.

    code: subject code; None picks a random subject, 'uber' uses the
          uber_subject returned by ReadCompleteDataset
    clean_param: parameter passed to ReadCompleteDataset
    """
    subject_map, uber_subject = ReadCompleteDataset(clean_param=clean_param)

    if code is None:
        # random.choice needs a sequence; in Python 3 dict.values() is a
        # view, which cannot be indexed
        subjects = list(subject_map.values())
        subject = random.choice(subjects)
        code = subject.code
    elif code == 'uber':
        subject = uber_subject
    else:
        subject = subject_map[code]

    print(subject.code)

    m = subject.GetM()

    subject.Process(high=m, conc=0.1, iters=0)
    print(subject.suite.params[:m])

    # plot the cdf
    options = dict(linewidth=3, color='blue', alpha=0.5)
    cdf = subject.MakeCdf()
    thinkplot.Cdf(cdf, **options)

    options = dict(linewidth=1, color='green', alpha=0.5)

    # generate fake subjects and plot their CDFs
    for _ in range(10):
        prevalences = subject.suite.SamplePrevalences(m)
        fake = FakeSubject(prevalences=prevalences)
        cdf = fake.MakeCdf()
        thinkplot.Cdf(cdf, **options)

    root = 'species-cdf-%s' % code
    thinkplot.Save(root=root,
                   xlabel='rank',
                   ylabel='CDF',
                   xscale='log',
                   formats=FORMATS,
                   )
1881
1882
1883
def RunCalibration(flag='cal', num_runs=100, clean_param=50):
    """Runs either the calibration or the validation process and plots results.

    flag: 'val' runs validation; any other value runs calibration
    num_runs: how many runs
    clean_param: data-cleaning parameter (only used for validation)
    """
    cal = Calibrator(conc=0.1)

    if flag != 'val':
        cal.Calibrate(num_runs=num_runs)
    else:
        cal.Validate(num_runs=num_runs, clean_param=clean_param)

    plots = [
        (cal.PlotN, 'species-n-%s'),
        (cal.PlotQ, 'species-q-%s'),
        (cal.PlotL, 'species-l-%s'),
        (cal.PlotCalibrationCurves, 'species5-%s'),
    ]
    for plot, template in plots:
        plot(root=template % flag)
1901
1902
1903
def RunTests():
    """Runs validation and calibration, then plots two subject CDFs."""
    for flag in ('val', 'cal'):
        RunCalibration(flag=flag)

    PlotSubjectCdf('B1558.G', clean_param=50)
    PlotSubjectCdf(None)
1910
1911
1912
def main(script):
    """Runs the book examples, reseeding the RNGs with 17 before each one.

    script: script name from the command line (unused)
    """
    examples = [
        lambda: RunSubject('B1242', conc=1, high=100),
        SimpleDirichletExample,
        HierarchicalExample,
    ]
    for example in examples:
        RandomSeed(17)
        example()
1921
1922
1923
if '__main__' == __name__:
    # forward the command line (script name first) to main
    main(*sys.argv)
1925
1926