Path: blob/master/examples/cppwin/TensorflowTTSCppInference/tfg2p.cpp
1559 views
#include "tfg2p.h"
#include <stdexcept>

// Constructs an unloaded G2P wrapper; call Initialize() before DoInference().
TFG2P::TFG2P()
    : G2P(nullptr)
{
}

// Constructs the wrapper and immediately tries to load the SavedModel at
// SavedModelFolder. The load result is not reported here: on failure G2P stays
// null and DoInference() will throw std::invalid_argument.
TFG2P::TFG2P(const std::string &SavedModelFolder)
    : G2P(nullptr)
{
    Initialize(SavedModelFolder);
}

// Loads (or reloads) the G2P SavedModel from SavedModelFolder.
//
// @param SavedModelFolder  Path handed to the Model constructor.
// @return true if the model was constructed, false if construction threw.
//         On failure the object is left in the unloaded state (G2P == nullptr).
bool TFG2P::Initialize(const std::string &SavedModelFolder)
{
    // Release any previously loaded model first; overwriting the pointer
    // directly would leak it when Initialize() is called more than once.
    // (delete on a null pointer is a no-op.)
    delete G2P;
    G2P = nullptr;

    try {
        G2P = new Model(SavedModelFolder);
    }
    catch (...) {
        // Model construction threw: report failure and remain unloaded.
        G2P = nullptr;
        return false;
    }
    return true;
}

// Runs the G2P model on a sequence of input IDs.
//
// @param InputIDs     Input ID sequence fed to "serving_default_input_ids".
//                     Its length is also fed to "serving_default_input_len".
// @param Temperature  Scalar fed to "serving_default_input_temperature".
// @return Copy of the "StatefulPartitionedCall" output tensor as int32 data.
// @throws std::invalid_argument if no model is loaded.
TFTensor<int32_t> TFG2P::DoInference(const std::vector<int32_t> &InputIDs, float Temperature)
{
    if (!G2P)
        throw std::invalid_argument("Tried to do inference on unloaded or invalid model!");

    // Convenience reference so that we don't have to constantly derefer pointers.
    Model& Mdl = *G2P;

    Tensor input_ids{ Mdl, "serving_default_input_ids" };
    Tensor input_len{ Mdl, "serving_default_input_len" };
    Tensor input_temp{ Mdl, "serving_default_input_temperature" };

    // The second argument to set_data is presumably the tensor shape
    // ({sequence length}); TODO confirm against the Tensor API in use.
    input_ids.set_data(InputIDs, std::vector<int64_t>{ static_cast<int64_t>(InputIDs.size()) });
    input_len.set_data(std::vector<int32_t>{ static_cast<int32_t>(InputIDs.size()) });
    input_temp.set_data(std::vector<float>{ Temperature });

    std::vector<Tensor*> Inputs{ &input_ids, &input_len, &input_temp };
    Tensor out_ids{ Mdl, "StatefulPartitionedCall" };

    Mdl.run(Inputs, out_ids);

    return VoxUtil::CopyTensor<int32_t>(out_ids);
}

// Releases the loaded model, if any.
// NOTE(review): G2P is a raw owning pointer. Unless the header deletes the copy
// constructor/assignment, copying a TFG2P would double-delete the model;
// consider std::unique_ptr<Model> in the class declaration.
TFG2P::~TFG2P()
{
    delete G2P;  // safe when G2P is nullptr
}