Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/cppwin/TensorflowTTSCppInference/ext/CppFlow/src/Model.cpp
1558 views
1
//
2
// Created by sergio on 12/05/19.
3
//
4
5
#include "../include/Model.h"
6
7
// Construct a Model by loading a TensorFlow SavedModel from disk.
//
// model_filename: path to the SavedModel directory.
// config_options: optional serialized ConfigProto bytes used to configure the
//                 session (ignored when empty).
// Throws std::runtime_error (via status_check) if configuration or loading fails.
Model::Model(const std::string& model_filename, const std::vector<uint8_t>& config_options) {
    this->status = TF_NewStatus();
    this->graph = TF_NewGraph();

    // Create the session options, applying the serialized ConfigProto if provided.
    TF_SessionOptions* sess_opts = TF_NewSessionOptions();
    if (!config_options.empty()) {
        TF_SetConfig(sess_opts, static_cast<const void*>(config_options.data()), config_options.size(), this->status);
        this->status_check(true);
    }

    // Load the SavedModel with the standard "serve" tag.
    const char* tags = "serve";
    int ntags = 1;
    this->session = TF_LoadSessionFromSavedModel(sess_opts, nullptr, model_filename.c_str(), &tags, ntags, this->graph, nullptr, this->status);
    if (TF_GetCode(this->status) == TF_OK) {
        printf("TF_LoadSessionFromSavedModel OK\n");
    } else {
        printf("%s", TF_Message(this->status));
    }
    TF_DeleteSessionOptions(sess_opts);

    // Throw if loading failed. (Removed: unused local `TF_Graph* g` and a
    // redundant second status_check that re-checked the same status.)
    this->status_check(true);
}
45
46
// Release the TensorFlow session, graph, and status object.
// Destructors must not throw: a failure while deleting the session is logged
// instead of raised (the original called status_check(true), which could throw
// during stack unwinding and terminate the program).
Model::~Model() {
    TF_DeleteSession(this->session, this->status);
    if (!this->status_check(false)) {
        printf("%s", TF_Message(this->status));
    }
    TF_DeleteGraph(this->graph);
    TF_DeleteStatus(this->status);
}
52
53
54
void Model::init() {
55
TF_Operation* init_op[1] = {TF_GraphOperationByName(this->graph, "init")};
56
57
this->error_check(init_op[0]!= nullptr, "Error: No operation named \"init\" exists");
58
59
TF_SessionRun(this->session, nullptr, nullptr, nullptr, 0, nullptr, nullptr, 0, init_op, 1, nullptr, this->status);
60
this->status_check(true);
61
}
62
63
void Model::save(const std::string &ckpt) {
64
// Encode file_name to tensor
65
size_t size = 8 + TF_StringEncodedSize(ckpt.length());
66
TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);
67
char* data = static_cast<char *>(TF_TensorData(t));
68
for (int i=0; i<8; i++) {data[i]=0;}
69
TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, status);
70
71
memset(data, 0, 8); // 8-byte offset of first string.
72
TF_StringEncode(ckpt.c_str(), ckpt.length(), (char*)(data + 8), size - 8, status);
73
74
// Check errors
75
if (!this->status_check(false)) {
76
TF_DeleteTensor(t);
77
std::cerr << "Error during filename " << ckpt << " encoding" << std::endl;
78
this->status_check(true);
79
}
80
81
TF_Output output_file;
82
output_file.oper = TF_GraphOperationByName(this->graph, "save/Const");
83
output_file.index = 0;
84
TF_Output inputs[1] = {output_file};
85
86
TF_Tensor* input_values[1] = {t};
87
const TF_Operation* restore_op[1] = {TF_GraphOperationByName(this->graph, "save/control_dependency")};
88
if (!restore_op[0]) {
89
TF_DeleteTensor(t);
90
this->error_check(false, "Error: No operation named \"save/control_dependencyl\" exists");
91
}
92
93
94
TF_SessionRun(this->session, nullptr, inputs, input_values, 1, nullptr, nullptr, 0, restore_op, 1, nullptr, this->status);
95
TF_DeleteTensor(t);
96
97
this->status_check(true);
98
}
99
100
// Restore a model from a SavedModel directory.
// NOTE(review): not implemented — the body is intentionally empty and the
// argument is unused. Kept to preserve the public interface.
void Model::restore_savedmodel(const std::string & savedmdl)
{
}
106
107
void Model::restore(const std::string& ckpt) {
108
109
// Encode file_name to tensor
110
size_t size = 8 + TF_StringEncodedSize(ckpt.size());
111
TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);
112
char* data = static_cast<char *>(TF_TensorData(t));
113
for (int i=0; i<8; i++) {data[i]=0;}
114
TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, status);
115
116
// Check errors
117
if (!this->status_check(false)) {
118
TF_DeleteTensor(t);
119
std::cerr << "Error during filename " << ckpt << " encoding" << std::endl;
120
this->status_check(true);
121
}
122
123
TF_Output output_file;
124
output_file.oper = TF_GraphOperationByName(this->graph, "save/Const");
125
output_file.index = 0;
126
TF_Output inputs[1] = {output_file};
127
128
TF_Tensor* input_values[1] = {t};
129
const TF_Operation* restore_op[1] = {TF_GraphOperationByName(this->graph, "save/restore_all")};
130
if (!restore_op[0]) {
131
TF_DeleteTensor(t);
132
this->error_check(false, "Error: No operation named \"save/restore_all\" exists");
133
}
134
135
136
137
TF_SessionRun(this->session, nullptr, inputs, input_values, 1, nullptr, nullptr, 0, restore_op, 1, nullptr, this->status);
138
TF_DeleteTensor(t);
139
140
this->status_check(true);
141
}
142
143
// Read an entire binary file into a newly allocated TF_Buffer.
// Returns nullptr (after logging to stderr) if the file cannot be opened or
// fully read. The caller owns the returned buffer (free with TF_DeleteBuffer).
TF_Buffer *Model::read(const std::string& filename) {
    // Open at the end (ios::ate) so tellg() immediately yields the file size.
    std::ifstream file(filename, std::ios::binary | std::ios::ate);
    if (!file.is_open()) {
        std::cerr << "Unable to open file: " << filename << std::endl;
        return nullptr;
    }

    const auto size = file.tellg();
    file.seekg(0, std::ios::beg);

    // std::vector owns the storage, so nothing leaks on the early return below
    // (the original used raw new[] and leaked it on the read-failure path).
    std::vector<char> data(static_cast<size_t>(size));
    file.read(data.data(), size);
    if (!file) {
        std::cerr << "Unable to read the full file: " << filename << std::endl;
        return nullptr;
    }

    // TF_NewBufferFromString copies the bytes, so the local vector may be freed.
    return TF_NewBufferFromString(data.data(), data.size());
}
179
180
std::vector<std::string> Model::get_operations() const {
181
std::vector<std::string> result;
182
size_t pos = 0;
183
TF_Operation* oper;
184
185
// Iterate through the operations of a graph
186
while ((oper = TF_GraphNextOperation(this->graph, &pos)) != nullptr) {
187
result.emplace_back(TF_OperationName(oper));
188
}
189
190
return result;
191
}
192
193
// Execute the session: feed `inputs`, fetch results into `outputs`.
// Preconditions (checked): every input Tensor is full (flag == 1) and every
// output Tensor is valid (flag != -1).
// Side effects: outputs are overwritten and marked full; inputs are cleaned
// (marked empty) after the run.
// Throws std::runtime_error on precondition failure or TensorFlow error.
void Model::run(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {

    this->error_check(std::all_of(inputs.begin(), inputs.end(), [](const Tensor* i){return i->flag == 1;}),
                      "Error: Not all elements from the inputs are full");

    this->error_check(std::all_of(outputs.begin(), outputs.end(), [](const Tensor* o){return o->flag != -1;}),
                      "Error: Not all outputs Tensors are valid");

    // Clean previous stored outputs
    std::for_each(outputs.begin(), outputs.end(), [](Tensor* o){o->clean();});

    // Get input operations
    std::vector<TF_Output> io(inputs.size());
    std::transform(inputs.begin(), inputs.end(), io.begin(), [](const Tensor* i) {return i->op;});

    // Get input values
    std::vector<TF_Tensor*> iv(inputs.size());
    std::transform(inputs.begin(), inputs.end(), iv.begin(), [](const Tensor* i) {return i->val;});

    // Get output operations
    std::vector<TF_Output> oo(outputs.size());
    std::transform(outputs.begin(), outputs.end(), oo.begin(), [](const Tensor* o) {return o->op;});

    // Prepare output recipients. std::vector instead of raw new[]: the original
    // leaked the array if status_check(true) threw after TF_SessionRun.
    std::vector<TF_Tensor*> ov(outputs.size());

    TF_SessionRun(this->session, nullptr, io.data(), iv.data(), inputs.size(), oo.data(), ov.data(), outputs.size(), nullptr, 0, nullptr, this->status);
    this->status_check(true);

    // Save results on outputs and mark as full
    for (std::size_t i = 0; i < outputs.size(); i++) {
        outputs[i]->val = ov[i];
        outputs[i]->flag = 1;
        outputs[i]->deduce_shape();
    }

    // Mark inputs as empty
    std::for_each(inputs.begin(), inputs.end(), [](Tensor* i) {i->clean();});
}
235
236
// Convenience overload: single input and single output, taken by reference.
void Model::run(Tensor &input, Tensor &output) {
    Tensor* in = &input;
    Tensor* out = &output;
    this->run(in, out);
}
239
240
// Convenience overload: many inputs, single output taken by reference.
void Model::run(const std::vector<Tensor*> &inputs, Tensor &output) {
    Tensor* out = &output;
    this->run(inputs, out);
}
243
244
// Convenience overload: single input taken by reference, many outputs.
void Model::run(Tensor &input, const std::vector<Tensor*> &outputs) {
    Tensor* in = &input;
    this->run(in, outputs);
}
247
248
// Convenience overload: single input and single output, taken by pointer.
// Wraps each in a one-element vector and delegates to the vector/vector overload.
void Model::run(Tensor *input, Tensor *output) {
    std::vector<Tensor*> in{input};
    std::vector<Tensor*> out{output};
    this->run(in, out);
}
251
252
// Convenience overload: many inputs, single output taken by pointer.
void Model::run(const std::vector<Tensor*> &inputs, Tensor *output) {
    std::vector<Tensor*> out{output};
    this->run(inputs, out);
}
255
256
// Convenience overload: single input taken by pointer, many outputs.
void Model::run(Tensor *input, const std::vector<Tensor*> &outputs) {
    std::vector<Tensor*> in{input};
    this->run(in, outputs);
}
259
260
bool Model::status_check(bool throw_exc) const {
261
262
if (TF_GetCode(this->status) != TF_OK) {
263
if (throw_exc) {
264
const char* errmsg = TF_Message(status);
265
printf(errmsg);
266
throw std::runtime_error(errmsg);
267
} else {
268
return false;
269
}
270
}
271
return true;
272
}
273
274
// Throw std::runtime_error carrying `error` when `condition` is false;
// otherwise do nothing.
void Model::error_check(bool condition, const std::string &error) const {
    if (condition) {
        return;
    }
    throw std::runtime_error(error);
}
279
280