Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/cppwin/TensorflowTTSCppInference/tfg2p.cpp
1559 views
1
#include "tfg2p.h"
2
#include <stdexcept>
3
// Default-constructs a G2P wrapper with no model loaded.
// Call Initialize() before DoInference(); until then DoInference() throws.
TFG2P::TFG2P() : G2P(nullptr)
{
}
8
9
// Constructs the wrapper and immediately tries to load the SavedModel found in
// SavedModelFolder. A load failure is NOT reported here: the model pointer is
// simply left null, and a later DoInference() call will throw. Callers that need
// to know whether loading worked should default-construct and call Initialize().
TFG2P::TFG2P(const std::string &SavedModelFolder) : G2P(nullptr)
{
    // Success/failure is intentionally discarded (see comment above).
    (void)Initialize(SavedModelFolder);
}
15
16
bool TFG2P::Initialize(const std::string &SavedModelFolder)
17
{
18
try {
19
20
G2P = new Model(SavedModelFolder);
21
22
}
23
catch (...) {
24
G2P = nullptr;
25
return false;
26
27
}
28
return true;
29
}
30
31
/**
 * Runs the loaded G2P SavedModel on a sequence of input IDs.
 *
 * @param InputIDs    Input IDs, fed to "serving_default_input_ids" with shape {InputIDs.size()}.
 * @param Temperature Value fed to "serving_default_input_temperature".
 * @return The "StatefulPartitionedCall" output, copied into a TFTensor<int32_t>.
 * @throws std::invalid_argument if no model is loaded (never initialized, or Initialize() failed).
 */
TFTensor<int32_t> TFG2P::DoInference(const std::vector<int32_t> &InputIDs, float Temperature)
{
    if (!G2P)
        throw std::invalid_argument("Tried to do inference on unloaded or invalid model!");

    // Convenience reference so that we don't have to constantly dereference the pointer.
    Model& Mdl = *G2P;

    // Input nodes of the exported serving signature.
    Tensor input_ids{ Mdl, "serving_default_input_ids" };
    Tensor input_len{ Mdl, "serving_default_input_len" };
    Tensor input_temp{ Mdl, "serving_default_input_temperature" };

    const auto SeqLen = InputIDs.size();
    input_ids.set_data(InputIDs, std::vector<int64_t>{ static_cast<int64_t>(SeqLen) });
    input_len.set_data(std::vector<int32_t>{ static_cast<int32_t>(SeqLen) });
    input_temp.set_data(std::vector<float>{ Temperature });

    std::vector<Tensor*> Inputs{ &input_ids, &input_len, &input_temp };
    Tensor out_ids{ Mdl, "StatefulPartitionedCall" };

    Mdl.run(Inputs, out_ids);

    return VoxUtil::CopyTensor<int32_t>(out_ids);
}
63
64
// Releases the loaded model, if any.
// NOTE(review): this class owns a raw Model*; if TFG2P is copyable (header not
// visible here), copies would share and double-delete it — confirm copy
// operations are deleted or implemented.
TFG2P::~TFG2P()
{
    // delete on a null pointer is a no-op, so no explicit check is required.
    delete G2P;
}
70
71