# Source: GitHub repository allendowney/thinkbayes2
# Path: blob/master/scripts/kidney.py
# (Retrieved through the CoCalc share viewer; site navigation text removed.)
1
"""This file contains code for use with "Think Bayes",
by Allen B. Downey, available from greenteapress.com

Copyright 2012 Allen B. Downey
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
"""
7
8
from __future__ import print_function, division
9
10
import math
11
import numpy
12
import random
13
import sys
14
15
import correlation
16
import thinkplot
17
import matplotlib.pyplot as pyplot
18
import thinkbayes2
19
20
21
# Default simulation time step, in years (used as the `interval` default in
# Calculator.MakeSequence); presumably 245 days expressed as a fraction of
# a 365-day year.
INTERVAL = 245 / 365.0

# File formats passed to thinkplot.Save for every figure.
FORMATS = ['pdf', 'eps']

# Lower and upper bounds on tumor diameter, in cm, used for plot axes and
# as the size at which a simulated tumor stops growing.
MINSIZE = 0.2
MAXSIZE = 20

# Scale between the natural log of a diameter and its bucket number:
# bucket = round(BUCKET_FACTOR * ln(cm)).
BUCKET_FACTOR = 10
26
27
28
def log2(x, denom=math.log(2)):
    """Returns the base-2 logarithm of x.

    x: positive number
    denom: natural log of the base; the default, ln(2), is evaluated once
           when the function is defined
    """
    natural_log = math.log(x)
    return natural_log / denom
31
32
33
def SimpleModel():
    """Prints back-of-the-envelope estimates based on a simple growth model.

    Scenario 1 assumes a fixed doubling time and works backward from the
    diameter at diagnosis to the diameter at discharge.

    Scenario 2 assumes an initial diameter of 0.1 cm, computes the doubling
    time and RDT that would be needed to reach the diagnosed size, and looks
    up that RDT in the Cdf returned by MakeCdf.
    """
    # time between discharge and diagnosis, in days
    elapsed_days = 3291.0

    # diameter at diagnosis, in cm
    final_cm = 15.5

    # Scenario 1: doubling time in linear measure is the doubling time in
    # volume (811 days) times 3
    linear_dt = 811.0 * 3
    num_doublings = elapsed_days / linear_dt
    # how big was the tumor at time of discharge (diameter in cm)
    start_cm = final_cm / 2.0 ** num_doublings

    first_report = [
        ('interval (days)', elapsed_days),
        ('interval (years)', elapsed_days / 365),
        ('dt', linear_dt),
        ('doublings', num_doublings),
        ('d1', final_cm),
        ('d0', start_cm),
    ]
    for entry in first_report:
        print(entry)

    # Scenario 2: assume an initial linear measure of 0.1 cm
    start_cm = 0.1

    # how many doublings it takes to get from the initial to the final
    # size, and the linear doubling time that implies
    num_doublings = log2(final_cm / start_cm)
    linear_dt = elapsed_days / num_doublings

    print(('doublings', num_doublings))
    print(('dt', linear_dt))

    # volumetric doubling time (days) and RDT (doublings per year)
    volume_dt = linear_dt / 3
    rdt = 365 / volume_dt

    print(('vdt', volume_dt))
    print(('rdt', rdt))

    # probability of an RDT at least this large under the empirical CDF
    below = MakeCdf().Prob(rdt)
    print(('Prob{RDT > 2.4}', 1-below))
79
80
81
def MakeCdf():
    """Builds the empirical CDF of RDT from the data of Zhang et al.

    Returns: thinkbayes2.Cdf with eight points
    """
    total = 53.0

    # cumulative counts at each of the eight RDT values
    cumulative = [0, 2, 31, 42, 48, 51, 52, 53]
    probs = [count / total for count in cumulative]

    # RDT values -1.5, -0.5, ..., 5.5
    values = numpy.arange(-1.5, 6.5, 1.0)

    return thinkbayes2.Cdf(values, probs)
90
91
92
def PlotCdf(cdf):
    """Saves two figures comparing the data with the fitted model.

    'kidney1' shows the complementary CDF of the data on a log-y scale;
    'kidney2' shows the data CDF together with the model from ModelCdf.

    cdf: CDF object
    """
    values, probs = cdf.xs, cdf.ps
    complement = [1-prob for prob in probs]

    # CCDF on logy scale: shows exponential behavior
    thinkplot.Clf()
    thinkplot.Plot(values, complement, 'bo-')
    ccdf_options = dict(root='kidney1',
                        formats=FORMATS,
                        xlabel='RDT',
                        ylabel='CCDF (log scale)',
                        yscale='log')
    thinkplot.Save(**ccdf_options)

    # CDF, model and data
    thinkplot.Clf()
    thinkplot.PrePlot(num=2)
    model_xs, model_ys = ModelCdf()
    thinkplot.Plot(model_xs, model_ys, label='model', linestyle='dashed')
    thinkplot.Plot(values, probs, 'gs', label='data')

    cdf_options = dict(root='kidney2',
                       formats=FORMATS,
                       xlabel='RDT (volume doublings per year)',
                       ylabel='CDF',
                       title='Distribution of RDT',
                       axis=[-2, 7, 0, 1],
                       loc=4)
    thinkplot.Save(**cdf_options)
124
125
126
def QQPlot(cdf, fit):
    """Saves a Q-Q plot of actual values against values from the model.

    For each cumulative probability in cdf, the model value at the same
    probability is plotted against the actual value; a reference line
    y = x is drawn from -1.5 to 5.5.

    cdf: actual Cdf of RDT
    fit: model
    """
    thinkplot.Clf()

    # reference line where model and data agree
    diagonal = [-1.5, 5.5]
    thinkplot.Plot(diagonal, diagonal, 'b-')

    actual = cdf.xs
    modeled = [fit.Value(prob) for prob in cdf.ps]
    thinkplot.Plot(actual, modeled, 'gs')

    thinkplot.Save(root='kidney3',
                   formats=FORMATS,
                   xlabel='Actual',
                   ylabel='Model')
144
145
146
def FitCdf(cdf):
    """Fits a line to the log CCDF and returns the negated slope.

    cdf: Cdf of RDT

    Returns: the slope of the least-squares line, with its sign flipped
    """
    # The first and last points are left out of the fit; a final
    # cumulative probability of 1 would make the CCDF 0, whose log is
    # undefined.
    inner_xs = cdf.xs[1:-1]
    inner_ps = cdf.ps[1:-1]
    log_ccdf = [math.log(1-prob) for prob in inner_ps]

    _, slope = correlation.LeastSquares(inner_xs, log_ccdf)
    return -slope
159
160
161
def CorrelatedGenerator(cdf, rho):
    """Generates an endless sequence of serially correlated values from cdf.

    Draws a correlated standard Normal series, then maps each Normal
    variate through the Normal CDF and the inverse of cdf.

    cdf: distribution to choose from
    rho: target coefficient of correlation
    """
    def to_cdf_value(z):
        """Maps from a Normal variate to a variate with the given CDF."""
        return cdf.Value(thinkbayes2.NormalCdf(z))

    # the first value comes from an unconditioned standard Normal
    z = random.gauss(0, 1)
    yield to_cdf_value(z)

    # each later value comes from the conditional distribution given the
    # previous one: mean z * rho, standard deviation sqrt(1 - rho^2)
    spread = math.sqrt(1 - rho**2)
    while True:
        z = random.gauss(z * rho, spread)
        yield to_cdf_value(z)
186
187
188
def UncorrelatedGenerator(cdf, _rho=None):
    """Generates an endless sequence of independent values from cdf.

    The second parameter exists only so that this function can be called
    like CorrelatedGenerator.

    cdf: distribution to choose from; must provide Random()
    _rho: ignored
    """
    while True:
        yield cdf.Random()
200
201
202
def RdtGenerator(cdf, rho):
    """Returns a generator of values from cdf with the given correlation.

    Uses UncorrelatedGenerator when rho is exactly 0, and
    CorrelatedGenerator otherwise.

    cdf: Cdf object
    rho: coefficient of correlation
    """
    if rho != 0.0:
        return CorrelatedGenerator(cdf, rho)
    return UncorrelatedGenerator(cdf)
212
213
214
def GenerateRdt(pc, lam1, lam2):
    """Draws one RDT from a mixture of two exponential distributions.

    pc: probability of generating a negative value
    lam1: exponential parameter used for positive values
    lam2: exponential parameter used for negative values

    Returns: float, negative with probability pc
    """
    shrinking = random.random() < pc
    if shrinking:
        return -random.expovariate(lam2)
    return random.expovariate(lam1)
224
225
226
def GenerateSample(n, pc, lam1, lam2):
    """Generates a sample of RDTs by calling GenerateRdt n times.

    n: sample size
    pc: probability of negative growth
    lam1: exponential parameter of positive growth
    lam2: exponential parameter of negative growth

    Returns: list of random variates
    """
    sample = []
    for _ in range(n):
        sample.append(GenerateRdt(pc, lam1, lam2))
    return sample
238
239
240
def GenerateCdf(n=1000, pc=0.35, lam1=0.79, lam2=5.0):
    """Generates a sample of RDTs and returns its CDF.

    n: sample size
    pc: probability of negative growth
    lam1: exponential parameter of positive growth
    lam2: exponential parameter of negative growth

    Returns: Cdf of generated sample
    """
    sample = GenerateSample(n, pc, lam1, lam2)
    return thinkbayes2.MakeCdfFromList(sample)
253
254
255
def ModelCdf(pc=0.35, lam1=0.79, lam2=5.0):
    """Evaluates the CDF of the two-sided exponential mixture model.

    The CDF is evaluated at steps of 0.1 from -2 up to (not including) 7.

    pc: probability of negative growth
    lam1: exponential parameter of positive growth
    lam2: exponential parameter of negative growth

    Returns: list of xs, list of ys
    """
    exp_cdf = thinkbayes2.EvalExponentialCdf

    neg_xs = numpy.arange(-2, 0, 0.1)
    pos_xs = numpy.arange(0, 7, 0.1)

    # below zero: the share pc of negative values, times the tail of the
    # exponential with parameter lam2 evaluated at the mirrored point
    neg_ys = [pc * (1 - exp_cdf(-x, lam2)) for x in neg_xs]

    # from zero up: all of pc, plus the positive share times the
    # exponential CDF with parameter lam1
    pos_ys = [pc + (1-pc) * exp_cdf(x, lam1) for x in pos_xs]

    return list(neg_xs) + list(pos_xs), neg_ys+pos_ys
270
271
272
def BucketToCm(y, factor=BUCKET_FACTOR):
    """Computes the linear dimension for a given bucket.

    y: bucket number
    factor: multiplicative factor relating bucket numbers to log size

    Returns: linear dimension in cm, exp(y / factor)
    """
    log_cm = y / factor
    return math.exp(log_cm)
281
282
283
def CmToBucket(x, factor=BUCKET_FACTOR):
    """Computes the bucket for a given linear dimension.

    x: linear dimension in cm
    factor: multiplicative factor relating bucket numbers to log size

    Returns: bucket number, factor * ln(x) rounded to the nearest whole
             number
    """
    scaled_log = factor * math.log(x)
    return round(scaled_log)
292
293
294
def Diameter(volume, factor=3/math.pi/4, exp=1/3.0):
    """Converts the volume of a sphere to its diameter.

    d = 2r = 2 * (3/4/pi V)^1/3

    volume: volume of the sphere
    factor: constant multiplying the volume, 3 / (4 pi) by default
    exp: exponent applied to the scaled volume, 1/3 by default

    Returns: diameter
    """
    radius = (factor * volume) ** exp
    return 2 * radius
300
301
302
def Volume(diameter, factor=4*math.pi/3):
    """Converts the diameter of a sphere to its volume.

    V = 4/3 pi (d/2)^3

    diameter: diameter of the sphere
    factor: constant multiplying the cubed radius, 4 pi / 3 by default

    Returns: volume
    """
    radius = diameter / 2.0
    return factor * radius**3
308
309
310
class Cache(object):
    """Records each observation point for each tumor."""

    def __init__(self):
        """Initializes the cache.

        joint: map from (age, bucket) to frequency
        sequences: map from bucket to a list of sequences
        initial_rdt: sequence of (V0, rdt) pairs
        """
        self.joint = thinkbayes2.Joint()
        self.sequences = {}
        self.initial_rdt = []

    def GetBuckets(self):
        """Returns an iterator for the keys in the cache."""
        # iter() works on both Python 2 and 3; dict.iterkeys() was removed
        # in Python 3.
        return iter(self.sequences)

    def GetSequence(self, bucket):
        """Looks up a bucket in the cache.

        bucket: bucket number

        Returns: list of sequences that ended in the bucket

        Raises: KeyError if no sequence has been recorded for the bucket
        """
        return self.sequences[bucket]

    def ConditionalCdf(self, bucket, name=''):
        """Forms the cdf of ages for a given bucket.

        bucket: int bucket number
        name: string

        Returns: Cdf made from the conditional Pmf of the joint distribution
        """
        pmf = self.joint.Conditional(0, 1, bucket, name=name)
        cdf = pmf.MakeCdf()
        return cdf

    def ProbOlder(self, cm, age):
        """Computes the probability of exceeding age, given size.

        cm: size in cm
        age: age in years

        Returns: 1 minus the conditional cdf of age evaluated at age
        """
        bucket = CmToBucket(cm)
        cdf = self.ConditionalCdf(bucket)
        p = cdf.Prob(age)
        return 1-p

    def GetDistAgeSize(self, size_thresh=MAXSIZE):
        """Gets the joint distribution of age and size.

        Maps from (age, log10 of size in cm) to 10 times the natural log
        of the frequency. Entries whose size exceeds size_thresh are
        left out.

        size_thresh: largest size to include, in cm

        Returns: new Joint object
        """
        joint = thinkbayes2.Joint()

        for val, freq in self.joint.Items():
            age, bucket = val
            cm = BucketToCm(bucket)
            if cm > size_thresh:
                continue
            log_cm = math.log10(cm)
            joint.Set((age, log_cm), math.log(freq) * 10)

        return joint

    def Add(self, age, seq, rdt):
        """Adds this observation point to the cache.

        age: age of the tumor in years
        seq: sequence of volumes; must contain at least two values
        rdt: RDT during this interval
        """
        # the last volume determines the bucket this observation falls in
        final = seq[-1]
        cm = Diameter(final)
        bucket = CmToBucket(cm)
        self.joint.Incr((age, bucket))

        self.sequences.setdefault(bucket, []).append(seq)

        # the second-to-last volume is the size at the start of the interval
        initial = seq[-2]
        self.initial_rdt.append((initial, rdt))

    def Print(self):
        """Prints the size (cm) for each bucket, and the number of sequences."""
        for bucket in sorted(self.GetBuckets()):
            ss = self.GetSequence(bucket)
            diameter = BucketToCm(bucket)
            print((diameter, len(ss)))

    def Correlation(self):
        """Computes the correlation between log volumes and rdts."""
        vs, rdts = zip(*self.initial_rdt)
        lvs = [math.log(v) for v in vs]
        return correlation.Corr(lvs, rdts)
401
402
403
class Calculator(object):
    """Encapsulates the state of the computation.

    Simulated observation points are recorded in self.cache, which the
    plotting methods query.
    """

    def __init__(self):
        """Initializes the cache."""
        self.cache = Cache()

    def MakeSequences(self, n, rho, cdf):
        """Returns a list of sequences of volumes.

        Prints the sequence index every 100 sequences as a progress
        indicator.

        n: number of sequences to make
        rho: serial correlation
        cdf: Cdf of rdts

        Returns: list of n sequences of volumes
        """
        sequences = []
        for i in range(n):
            rdt_seq = RdtGenerator(cdf, rho)
            seq = self.MakeSequence(rdt_seq)
            sequences.append(seq)

            if i % 100 == 0:
                print(i)

        return sequences

    def MakeSequence(self, rdt_seq, v0=0.01, interval=INTERVAL,
                     vmax=Volume(MAXSIZE)):
        """Simulate the growth of a tumor.

        Stops as soon as the volume exceeds vmax, or when rdt_seq is
        exhausted.

        rdt_seq: sequence of rdts
        v0: initial volume in mL (cm^3)
        interval: timestep in years
        vmax: volume to stop at

        Returns: sequence of volumes
        """
        seq = (v0,)
        age = 0

        for rdt in rdt_seq:
            age += interval
            final, seq = self.ExtendSequence(age, seq, rdt, interval)
            if final > vmax:
                break

        return seq

    def ExtendSequence(self, age, seq, rdt, interval):
        """Computes the next volume and adds it to the end of seq.

        Side-effect: adds sub-sequences to the cache.

        age: age of tumor at the end of this interval
        seq: sequence of values so far
        rdt: reciprocal doubling time in doublings per year
        interval: timestep in years

        Returns: final volume, extended sequence
        """
        initial = seq[-1]
        doublings = rdt * interval
        final = initial * 2**doublings
        new_seq = seq + (final,)
        self.cache.Add(age, new_seq, rdt)

        return final, new_seq

    def PlotBucket(self, bucket, color='blue'):
        """Plots the set of sequences for the given bucket.

        Each sequence is drawn so that it ends at time 0.

        bucket: int bucket number
        color: string
        """
        sequences = self.cache.GetSequence(bucket)
        for seq in sequences:
            n = len(seq)
            age = n * INTERVAL
            ts = numpy.linspace(-age, 0, n)
            PlotSequence(ts, seq, color)

    def PlotBuckets(self):
        """Plots the set of sequences that ended in a given bucket."""
        # Bucket 23.0 corresponds to about 9.97 cm. Buckets 7.0 and 16.0
        # (about 2.01 cm and 4.95 cm) can be added to compare sizes.
        buckets = [23.0]
        colors = ['blue', 'green', 'red', 'cyan']

        thinkplot.Clf()
        for bucket, color in zip(buckets, colors):
            self.PlotBucket(bucket, color)

        thinkplot.Save(root='kidney5',
                       formats=FORMATS,
                       title='History of simulated tumors',
                       axis=[-40, 1, MINSIZE, 12],
                       xlabel='years',
                       ylabel='diameter (cm, log scale)',
                       yscale='log')

    def PlotJointDist(self):
        """Makes a pcolor plot of the age-size joint distribution."""
        thinkplot.Clf()

        joint = self.cache.GetDistAgeSize()
        thinkplot.Contour(joint, contour=False, pcolor=True)

        thinkplot.Save(root='kidney8',
                       formats=FORMATS,
                       axis=[0, 41, -0.7, 1.31],
                       yticks=MakeLogTicks([0.2, 0.5, 1, 2, 5, 10, 20]),
                       xlabel='ages',
                       ylabel='diameter (cm, log scale)')

    def PlotConditionalCdfs(self):
        """Plots the cdf of ages for each bucket."""
        buckets = [7.0, 16.0, 23.0, 27.0]
        # 2.01, 4.95 cm, 9.97 cm, 14.879 cm
        names = ['2 cm', '5 cm', '10 cm', '15 cm']
        cdfs = []

        for bucket, name in zip(buckets, names):
            cdf = self.cache.ConditionalCdf(bucket, name)
            cdfs.append(cdf)

        thinkplot.Clf()
        thinkplot.PrePlot(num=len(cdfs))
        thinkplot.Cdfs(cdfs)
        thinkplot.Save(root='kidney6',
                       title='Distribution of age for several diameters',
                       formats=FORMATS,
                       xlabel='tumor age (years)',
                       ylabel='CDF',
                       loc=4)

    def PlotCredibleIntervals(self, xscale='linear'):
        """Plots the credible interval of age for each bucket.

        Also writes the percentiles to the LaTeX file 'kidney_table.tex'.

        xscale: scale of the x axis, passed through to thinkplot.Save
        """
        xs = []
        ts = []
        percentiles = [95, 75, 50, 25, 5]
        min_size = 0.3
        max_size = 20.0

        # loop through the buckets, accumulate
        # xs: sequence of sizes in cm
        # ts: sequence of percentile tuples
        for bucket in sorted(self.cache.GetBuckets()):
            cm = BucketToCm(bucket)
            if cm < min_size or cm > max_size:
                continue
            xs.append(cm)
            cdf = self.cache.ConditionalCdf(bucket)
            ps = [cdf.Percentile(p) for p in percentiles]
            ts.append(ps)

        # dump the results into a table; the context manager closes the
        # file even if PrintTable raises
        with open('kidney_table.tex', 'w') as fp:
            PrintTable(fp, xs, ts)

        # make the figure
        linewidths = [1, 2, 3, 2, 1]
        alphas = [0.3, 0.5, 1, 0.5, 0.3]
        labels = ['95th', '75th', '50th', '25th', '5th']

        # transpose the ts so we have sequences for each percentile rank
        thinkplot.Clf()
        yys = zip(*ts)

        for ys, linewidth, alpha, label in zip(yys, linewidths, alphas, labels):
            options = dict(color='blue', linewidth=linewidth,
                           alpha=alpha, label=label, markersize=2)

            # plot the data points
            thinkplot.Plot(xs, ys, 'bo', **options)

            # plot the fit lines
            fxs = [min_size, max_size]
            fys = FitLine(xs, ys, fxs)

            thinkplot.Plot(fxs, fys, **options)

            # put a label at the end of each line
            x, y = fxs[-1], fys[-1]
            pyplot.text(x*1.05, y, label, color='blue',
                        horizontalalignment='left',
                        verticalalignment='center')

        # make the figure
        thinkplot.Save(root='kidney7',
                       formats=FORMATS,
                       title='Credible interval for age vs diameter',
                       xlabel='diameter (cm, log scale)',
                       ylabel='tumor age (years)',
                       xscale=xscale,
                       xticks=MakeTicks([0.5, 1, 2, 5, 10, 20]),
                       axis=[0.25, 35, 0, 45],
                       legend=False,
                       )
602
603
604
def PlotSequences(sequences):
    """Saves a figure ('kidney4') of linear measurement vs time.

    Each sequence starts at time 0 and is drawn by PlotSequence; a dashed
    gray reference line is drawn at 10 on the y axis.

    sequences: list of sequences of volumes
    """
    thinkplot.Clf()

    # horizontal reference line
    reference_style = dict(color='gray', linewidth=1, linestyle='dashed')
    thinkplot.Plot([0, 40], [10, 10], **reference_style)

    for volumes in sequences:
        count = len(volumes)
        times = numpy.linspace(0, count * INTERVAL, count)
        PlotSequence(times, volumes)

    save_options = dict(root='kidney4',
                        formats=FORMATS,
                        axis=[0, 40, MINSIZE, 20],
                        title='Simulations of tumor growth',
                        xlabel='tumor age (years)',
                        yticks=MakeTicks([0.2, 0.5, 1, 2, 5, 10, 20]),
                        ylabel='diameter (cm, log scale)',
                        yscale='log')
    thinkplot.Save(**save_options)
628
629
630
def PlotSequence(ts, seq, color='blue'):
    """Plots a time series of linear measurements.

    The volumes are converted to diameters before plotting.

    ts: sequence of times in years
    seq: sequence of volumes
    color: color string
    """
    diameters = [Diameter(volume) for volume in seq]
    thinkplot.Plot(ts, diameters, color=color, linewidth=1, alpha=0.2)
641
642
643
def PrintCI(fp, cm, ps):
    """Writes one row of the LaTeX table.

    The diameter comes first, followed by the percentiles in reverse
    order, each rounded to one decimal place.

    fp: file pointer
    cm: diameter in cm
    ps: sequence of percentiles
    """
    row_end = r'\\' '\n'

    fp.write('%0.1f' % round(cm, 1))
    for value in reversed(ps):
        fp.write(' & %0.1f ' % round(value, 1))
    fp.write(row_end)
654
655
656
def PrintTable(fp, xs, ts):
    """Writes the data in a LaTeX table.

    Only every third row (starting with the first) is written.

    fp: file pointer
    xs: diameters in cm
    ts: sequence of tuples of percentiles
    """
    header = [
        r'\begin{tabular}{|r||r|r|r|r|r|}',
        r'\hline',
        r'Diameter & \multicolumn{5}{c|}{Percentiles of age} \\',
        r'(cm) & 5th & 25th & 50th & 75th & 95th \\',
        r'\hline',
    ]
    footer = [
        r'\hline',
        r'\end{tabular}',
    ]

    for line in header:
        fp.write(line + '\n')

    for index, (cm, ps) in enumerate(zip(xs, ts)):
        if index % 3:
            continue
        PrintCI(fp, cm, ps)

    for line in footer:
        fp.write(line + '\n')
676
677
678
def FitLine(xs, ys, fxs):
    """Fits a line to ys versus log(xs) and evaluates it at fxs.

    The natural log is applied to both xs and fxs.

    xs: diameter in cm
    ys: age in years
    fxs: diameter in cm

    Returns: list of fitted ages, one for each value in fxs
    """
    log_xs = [math.log(x) for x in xs]
    inter, slope = correlation.LeastSquares(log_xs, ys)

    return [inter + slope * math.log(fx) for fx in fxs]
695
696
697
def MakeTicks(xs):
    """Makes a pair of sequences for use as pyplot ticks.

    xs: sequence of floats

    Returns (xs, labels), where xs is the sequence passed in and labels is
    a list with the string form of each value.
    """
    return xs, list(map(str, xs))
706
707
708
def MakeLogTicks(xs):
    """Makes tick positions and labels for an axis in log10 units.

    xs: sequence of positive floats

    Returns (lxs, labels), where lxs holds the base-10 log of each value
    and labels holds the string form of each original value.
    """
    positions = []
    labels = []
    for x in xs:
        positions.append(math.log10(x))
    for x in xs:
        labels.append(str(x))
    return positions, labels
718
719
720
def TestCorrelation(cdf):
    """Tests the correlated generator.

    Makes sure that the sequence has the right distribution and correlation:
    prints the target and measured serial correlation, then shows the
    given cdf next to the cdf of the generated values.

    cdf: distribution to choose from
    """
    n = 10000
    rho = 0.4

    rdt_seq = CorrelatedGenerator(cdf, rho)
    # the builtin next() works on Python 2.6+ and 3; generators have no
    # .next() method in Python 3
    xs = [next(rdt_seq) for _ in range(n)]

    rho2 = correlation.SerialCorr(xs)
    print((rho, rho2))
    cdf2 = thinkbayes2.MakeCdfFromList(xs)

    thinkplot.Cdfs([cdf, cdf2])
    thinkplot.Show()
737
738
739
def main(script):
    """Runs the analysis and saves the figures and the LaTeX table.

    script: name of the script (unused)
    """
    for size in (1, 5, 10):
        bucket = CmToBucket(size)
        print(('Size, bucket', size, bucket))

    SimpleModel()

    # seed the generator so the simulation is repeatable
    random.seed(17)

    # fit the positive-growth parameter to the data, then sample from
    # the fitted model
    cdf = MakeCdf()
    fit = GenerateCdf(lam1=FitCdf(cdf))

    # optional diagnostic: TestCorrelation(fit)

    PlotCdf(cdf)
    # optional diagnostic: QQPlot(cdf, fit)

    calc = Calculator()
    rho = 0.0

    # the first 100 simulated tumors are plotted
    PlotSequences(calc.MakeSequences(100, rho, fit))

    calc.PlotBuckets()

    # 1900 more are simulated only to fill the cache
    calc.MakeSequences(1900, rho, fit)
    print(('V0-RDT correlation', calc.cache.Correlation()))

    print(('15.5 Probability age > 8 year', calc.cache.ProbOlder(15.5, 8)))
    print(('6.0 Probability age > 8 year', calc.cache.ProbOlder(6.0, 8)))

    calc.PlotConditionalCdfs()
    calc.PlotCredibleIntervals(xscale='log')
    calc.PlotJointDist()
776
777
778
# Run the analysis only when this file is executed as a script; the
# command-line arguments (starting with the script name) are passed on
# to main.
if __name__ == '__main__':
    main(*sys.argv)
780
781
782
783