// Path: blob/master/examples/cppwin/TensorflowTTSCppInference/ext/CppFlow/src/Model.cpp
//1// Created by sergio on 12/05/19.2//34#include "../include/Model.h"56Model::Model(const std::string& model_filename, const std::vector<uint8_t>& config_options) {7this->status = TF_NewStatus();8this->graph = TF_NewGraph();910// Create the session.11TF_SessionOptions* sess_opts = TF_NewSessionOptions();1213if (!config_options.empty())14{15TF_SetConfig(sess_opts, static_cast<const void*>(config_options.data()), config_options.size(), this->status);16this->status_check(true);17}1819TF_Buffer* RunOpts = NULL;2021const char* tags = "serve";22int ntags = 1;2324this->session = TF_LoadSessionFromSavedModel(sess_opts, RunOpts, model_filename.c_str(), &tags, ntags, this->graph, NULL, this->status);25if (TF_GetCode(this->status) == TF_OK)26{27printf("TF_LoadSessionFromSavedModel OK\n");28}29else30{31printf("%s", TF_Message(this->status));32}33TF_DeleteSessionOptions(sess_opts);3435// Check the status36this->status_check(true);3738// Create the graph39TF_Graph* g = this->graph;404142this->status_check(true);43}4445Model::~Model() {46TF_DeleteSession(this->session, this->status);47TF_DeleteGraph(this->graph);48this->status_check(true);49TF_DeleteStatus(this->status);50}515253void Model::init() {54TF_Operation* init_op[1] = {TF_GraphOperationByName(this->graph, "init")};5556this->error_check(init_op[0]!= nullptr, "Error: No operation named \"init\" exists");5758TF_SessionRun(this->session, nullptr, nullptr, nullptr, 0, nullptr, nullptr, 0, init_op, 1, nullptr, this->status);59this->status_check(true);60}6162void Model::save(const std::string &ckpt) {63// Encode file_name to tensor64size_t size = 8 + TF_StringEncodedSize(ckpt.length());65TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);66char* data = static_cast<char *>(TF_TensorData(t));67for (int i=0; i<8; i++) {data[i]=0;}68TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, status);6970memset(data, 0, 8); // 8-byte offset of first string.71TF_StringEncode(ckpt.c_str(), ckpt.length(), (char*)(data + 
8), size - 8, status);7273// Check errors74if (!this->status_check(false)) {75TF_DeleteTensor(t);76std::cerr << "Error during filename " << ckpt << " encoding" << std::endl;77this->status_check(true);78}7980TF_Output output_file;81output_file.oper = TF_GraphOperationByName(this->graph, "save/Const");82output_file.index = 0;83TF_Output inputs[1] = {output_file};8485TF_Tensor* input_values[1] = {t};86const TF_Operation* restore_op[1] = {TF_GraphOperationByName(this->graph, "save/control_dependency")};87if (!restore_op[0]) {88TF_DeleteTensor(t);89this->error_check(false, "Error: No operation named \"save/control_dependencyl\" exists");90}919293TF_SessionRun(this->session, nullptr, inputs, input_values, 1, nullptr, nullptr, 0, restore_op, 1, nullptr, this->status);94TF_DeleteTensor(t);9596this->status_check(true);97}9899void Model::restore_savedmodel(const std::string & savedmdl)100{101102103104}105106void Model::restore(const std::string& ckpt) {107108// Encode file_name to tensor109size_t size = 8 + TF_StringEncodedSize(ckpt.size());110TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);111char* data = static_cast<char *>(TF_TensorData(t));112for (int i=0; i<8; i++) {data[i]=0;}113TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, status);114115// Check errors116if (!this->status_check(false)) {117TF_DeleteTensor(t);118std::cerr << "Error during filename " << ckpt << " encoding" << std::endl;119this->status_check(true);120}121122TF_Output output_file;123output_file.oper = TF_GraphOperationByName(this->graph, "save/Const");124output_file.index = 0;125TF_Output inputs[1] = {output_file};126127TF_Tensor* input_values[1] = {t};128const TF_Operation* restore_op[1] = {TF_GraphOperationByName(this->graph, "save/restore_all")};129if (!restore_op[0]) {130TF_DeleteTensor(t);131this->error_check(false, "Error: No operation named \"save/restore_all\" exists");132}133134135136TF_SessionRun(this->session, nullptr, inputs, input_values, 1, nullptr, nullptr, 0, 
restore_op, 1, nullptr, this->status);137TF_DeleteTensor(t);138139this->status_check(true);140}141142TF_Buffer *Model::read(const std::string& filename) {143std::ifstream file (filename, std::ios::binary | std::ios::ate);144145// Error opening the file146if (!file.is_open()) {147std::cerr << "Unable to open file: " << filename << std::endl;148return nullptr;149}150151152// Cursor is at the end to get size153auto size = file.tellg();154// Move cursor to the beginning155file.seekg (0, std::ios::beg);156157// Read158auto data = new char [size];159file.seekg (0, std::ios::beg);160file.read (data, size);161162// Error reading the file163if (!file) {164std::cerr << "Unable to read the full file: " << filename << std::endl;165return nullptr;166}167168169// Create tensorflow buffer from read data170TF_Buffer* buffer = TF_NewBufferFromString(data, size);171172// Close file and remove data173file.close();174delete[] data;175176return buffer;177}178179std::vector<std::string> Model::get_operations() const {180std::vector<std::string> result;181size_t pos = 0;182TF_Operation* oper;183184// Iterate through the operations of a graph185while ((oper = TF_GraphNextOperation(this->graph, &pos)) != nullptr) {186result.emplace_back(TF_OperationName(oper));187}188189return result;190}191192void Model::run(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {193194this->error_check(std::all_of(inputs.begin(), inputs.end(), [](const Tensor* i){return i->flag == 1;}),195"Error: Not all elements from the inputs are full");196197this->error_check(std::all_of(outputs.begin(), outputs.end(), [](const Tensor* o){return o->flag != -1;}),198"Error: Not all outputs Tensors are valid");199200201// Clean previous stored outputs202std::for_each(outputs.begin(), outputs.end(), [](Tensor* o){o->clean();});203204// Get input operations205std::vector<TF_Output> io(inputs.size());206std::transform(inputs.begin(), inputs.end(), io.begin(), [](const Tensor* i) {return i->op;});207208// 
Get input values209std::vector<TF_Tensor*> iv(inputs.size());210std::transform(inputs.begin(), inputs.end(), iv.begin(), [](const Tensor* i) {return i->val;});211212// Get output operations213std::vector<TF_Output> oo(outputs.size());214std::transform(outputs.begin(), outputs.end(), oo.begin(), [](const Tensor* o) {return o->op;});215216// Prepare output recipients217auto ov = new TF_Tensor*[outputs.size()];218219TF_SessionRun(this->session, nullptr, io.data(), iv.data(), inputs.size(), oo.data(), ov, outputs.size(), nullptr, 0, nullptr, this->status);220this->status_check(true);221222// Save results on outputs and mark as full223for (std::size_t i=0; i<outputs.size(); i++) {224outputs[i]->val = ov[i];225outputs[i]->flag = 1;226outputs[i]->deduce_shape();227}228229// Mark input as empty230std::for_each(inputs.begin(), inputs.end(), [] (Tensor* i) {i->clean();});231232delete[] ov;233}234235void Model::run(Tensor &input, Tensor &output) {236this->run(&input, &output);237}238239void Model::run(const std::vector<Tensor*> &inputs, Tensor &output) {240this->run(inputs, &output);241}242243void Model::run(Tensor &input, const std::vector<Tensor*> &outputs) {244this->run(&input, outputs);245}246247void Model::run(Tensor *input, Tensor *output) {248this->run(std::vector<Tensor*>({input}), std::vector<Tensor*>({output}));249}250251void Model::run(const std::vector<Tensor*> &inputs, Tensor *output) {252this->run(inputs, std::vector<Tensor*>({output}));253}254255void Model::run(Tensor *input, const std::vector<Tensor*> &outputs) {256this->run(std::vector<Tensor*>({input}), outputs);257}258259bool Model::status_check(bool throw_exc) const {260261if (TF_GetCode(this->status) != TF_OK) {262if (throw_exc) {263const char* errmsg = TF_Message(status);264printf(errmsg);265throw std::runtime_error(errmsg);266} else {267return false;268}269}270return true;271}272273void Model::error_check(bool condition, const std::string &error) const {274if (!condition) {275throw 
std::runtime_error(error);276}277}278279280