Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/metrics/ins_tf13.py
809 views
1
''' Tensorflow inception score code
Derived from https://github.com/openai/improved-gan
Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE
To use this code, place your samples in samples/<run_name>/<type>/npz/samples.npz
(array stored under key 'x', channels-first, pixel values in [0, 255]) and
run this script with --run_name and --type (e.g. real or fake).
This code also computes the pool3 activations used for FID calculation, but
this version of the script does not save them to disk.
'''
9
10
from __future__ import absolute_import
11
from __future__ import division
12
from __future__ import print_function
13
14
import os.path
15
import sys
16
import tarfile
17
import math
18
from argparse import ArgumentParser
19
20
from six.moves import urllib
21
from tqdm import tqdm, trange
22
import tensorflow as tf
23
import numpy as np
24
25
# Directory where the Inception archive is downloaded and extracted.
MODEL_DIR = './inception_model'
# Archive expected to contain classify_image_graph_def.pb (loaded in run()).
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# Graph tensors assigned via `global` by _init_inception() inside run();
# both stay None until the graph has been loaded.
softmax = None
pool3 = None
28
29
30
def prepare_parser():
    """Build the command-line parser for the TF1.3 Inception Score script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--run_name``, ``--type`` and
        ``--batch_size``.
    """
    usage = 'Parser for TF1.3- Inception Score scripts.'
    parser = ArgumentParser(description=usage)
    # --run_name and --type together select samples/<run_name>/<type>/npz/samples.npz
    # (see the path built in run()).
    parser.add_argument('--run_name',
                        type=str,
                        default='',
                        help="Which experiment's samples.npz file to pull and evaluate")
    parser.add_argument('--type', type=str, default='', help='[real, fake]')
    parser.add_argument('--batch_size', type=int, default=500, help='Default overall batchsize (default: %(default)s)')
    return parser
41
42
43
def run(config):
    """Compute the (TF1.3-era) Inception Score of a saved ``samples.npz`` file.

    Loads ``samples/<run_name>/<type>/npz/samples.npz`` (array under key
    ``'x'``), moves axis 1 to the end (channels-first -> channels-last), scores
    the images with the pretrained Inception graph and prints the elapsed time
    together with the score mean and std.

    Args:
        config: mapping with at least the keys ``'run_name'``, ``'type'`` and
            ``'batch_size'`` (as produced by ``prepare_parser``).
    """
    import time

    # Inception with TF1.3 or earlier.
    def get_inception_score(images, splits=10):
        """Return ``(mean_score, std_score, pool3_activations)`` for ``images``.

        Args:
            images: non-empty list of 3-D numpy arrays with pixel values in
                [0, 255] (channels-last after the transpose done in run()).
            splits: number of contiguous chunks the predictions are divided
                into; a score is computed per chunk and summarised by mean/std.

        Returns:
            Mean and std of the per-split scores, and the pool_3 activations
            with the batch axis kept (a 2-D array even for a single image).

        Raises:
            ValueError: if ``images`` is empty, not a list, or its first
                element is not a 3-D array in the expected value range.
        """
        # Explicit checks (instead of ``assert``) so validation survives ``python -O``.
        if not isinstance(images, list) or not images:
            raise ValueError('images must be a non-empty list of numpy arrays')
        first = images[0]
        if not isinstance(first, np.ndarray) or first.ndim != 3:
            raise ValueError('each image must be a 3-D numpy array')
        # Heuristic from the reference code: reject images scaled to [0, 1] / [-1, 1].
        if not (np.max(first) > 10) or not (np.min(first) >= 0.0):
            raise ValueError('images must have pixel values in [0, 255]')

        inps = [np.expand_dims(img.astype(np.float32), 0) for img in images]
        bs = config['batch_size']
        with tf.Session() as sess:
            preds, pools = [], []
            n_batches = int(math.ceil(float(len(inps)) / float(bs)))
            for i in trange(n_batches):
                inp = np.concatenate(inps[(i * bs):min((i + 1) * bs, len(inps))], 0)
                pred, pool = sess.run([softmax, pool3], {'ExpandDims:0': inp})
                preds.append(pred)
                pools.append(pool)
            preds = np.concatenate(preds, 0)
            scores = []
            for i in range(splits):
                part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
                kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
                kl = np.mean(np.sum(kl, 1))
                scores.append(np.exp(kl))
            # Drop singleton non-batch axes only: a plain np.squeeze would also
            # remove the batch axis when exactly one image was scored.
            acts = np.concatenate(pools, 0)
            acts = acts.reshape([acts.shape[0]] + [d for d in acts.shape[1:] if d != 1])
            return np.mean(scores), np.std(scores), acts

    def _init_inception():
        """Load the pretrained Inception graph (downloading it if missing) and
        assign the module-level ``softmax`` and ``pool3`` tensors."""
        global softmax, pool3
        if not os.path.exists(MODEL_DIR):
            os.makedirs(MODEL_DIR)
        filename = DATA_URL.split('/')[-1]
        filepath = os.path.join(MODEL_DIR, filename)
        if not os.path.exists(filepath):

            def _progress(count, block_size, total_size):
                # urlretrieve may report -1 when the size is unknown; avoid a
                # bogus percentage or a division by zero in that case.
                if total_size > 0:
                    percent = float(count * block_size) / float(total_size) * 100.0
                    sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
                    sys.stdout.flush()

            filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
            print()
            statinfo = os.stat(filepath)
            print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
        # Context manager closes the archive handle after extraction.
        # NOTE(review): extractall trusts member paths; the archive comes from DATA_URL.
        with tarfile.open(filepath, 'r:gz') as tar:
            tar.extractall(MODEL_DIR)
        with tf.gfile.FastGFile(os.path.join(MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
        # Works with an arbitrary minibatch size: every op output whose leading
        # dimension is 1 gets that dimension relaxed to None.
        with tf.Session() as sess:
            pool3 = sess.graph.get_tensor_by_name('pool_3:0')
            for op in pool3.graph.get_operations():
                for o in op.outputs:
                    dims = [s.value for s in o.get_shape()]
                    if dims and dims[0] == 1:
                        dims[0] = None
                    o._shape = tf.TensorShape(dims)
            w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
            # Squeeze only axes 1 and 2 (assumed 1x1 spatial axes of pool_3): a bare
            # tf.squeeze also drops the batch axis for a one-image batch, making
            # the matmul operand rank-1 and failing at run time.
            logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
            softmax = tf.nn.softmax(logits)

    # Load the graph once, before any scoring.
    _init_inception()

    fname = '%s/%s/%s/%s/samples.npz' % ("samples", config['run_name'], config['type'], "npz")
    print('loading %s ...' % fname)
    # Context manager closes the underlying NpzFile handle.
    with np.load(fname) as data:
        ims = data['x']
    t0 = time.time()
    # NOTE: with splits=1 there is a single score, so the reported std is always 0.
    inc_mean, inc_std, _pool_activations = get_inception_score(list(ims.swapaxes(1, 2).swapaxes(2, 3)), splits=1)
    t1 = time.time()
    print('Inception took %3f seconds, score of %3f +/- %3f.' % (t1 - t0, inc_mean, inc_std))
129
130
131
def main():
    """Command-line entry point: parse arguments, echo them, then score."""
    args = prepare_parser().parse_args()
    config = vars(args)
    print(config)
    run(config)
137
138
139
if __name__ == '__main__':
    # Only launch the evaluation when executed as a script, not on import.
    main()
141
142