main.cpp

  1. #include "onnxruntime_cxx_api.h"
  2. #include <iostream>
  3. #include <assert.h>
  4. using namespace std;
  5. using namespace Ort;
  6. #include <fstream>
  7. #include <vector>
  8. #include <cstdint>
  9. #include <list>
  10. #include <string>
struct WavHeader {
    char riff[4];            // "RIFF"
    uint32_t chunkSize;
    char wave[4];            // "WAVE"
    char fmt[4];             // "fmt "
    uint32_t subchunk1Size;
    uint16_t audioFormat;
    uint16_t numChannels;
    uint32_t sampleRate;
    uint32_t byteRate;
    uint16_t blockAlign;
    uint16_t bitsPerSample;
    char data[4];            // "data"
    uint32_t subchunk2Size;
};
// Write a mono float buffer as a 32-bit IEEE-float WAV file.
void floatToWav(const std::vector<float>& data, const std::string& filename,
                uint32_t sampleRate = 22050, uint16_t numChannels = 1) {
    std::ofstream file(filename, std::ios::binary);
    if (!file) {
        std::cerr << "Failed to open file for writing." << std::endl;
        return;
    }

    WavHeader header;
    std::memcpy(header.riff, "RIFF", 4);  // copy the 4-byte tag without a NUL terminator
    header.chunkSize = sizeof(WavHeader) - 8 + data.size() * sizeof(float);
    std::memcpy(header.wave, "WAVE", 4);
    std::memcpy(header.fmt, "fmt ", 4);
    header.subchunk1Size = 16;
    header.audioFormat = 3;               // 3 = IEEE float
    header.numChannels = numChannels;
    header.sampleRate = sampleRate;
    header.bitsPerSample = 32;            // 32-bit float samples
    header.byteRate = header.sampleRate * header.numChannels * header.bitsPerSample / 8;
    header.blockAlign = header.numChannels * header.bitsPerSample / 8;
    std::memcpy(header.data, "data", 4);
    header.subchunk2Size = data.size() * sizeof(float);

    file.write(reinterpret_cast<char*>(&header), sizeof(header));
    file.write(reinterpret_cast<const char*>(data.data()), data.size() * sizeof(float));
    if (!file) {
        std::cerr << "Error writing to file." << std::endl;
    }
}
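
// Optional variant, not part of the original program: some players handle
// 16-bit integer PCM WAV (audioFormat = 1) more reliably than 32-bit float
// WAV. A minimal sketch reusing the WavHeader above, clamping samples to
// [-1, 1] before scaling to int16_t.
void floatToPcm16Wav(const std::vector<float>& data, const std::string& filename,
                     uint32_t sampleRate = 22050, uint16_t numChannels = 1) {
    std::ofstream file(filename, std::ios::binary);
    if (!file) {
        std::cerr << "Failed to open file for writing." << std::endl;
        return;
    }

    WavHeader header;
    std::memcpy(header.riff, "RIFF", 4);
    std::memcpy(header.wave, "WAVE", 4);
    std::memcpy(header.fmt, "fmt ", 4);
    std::memcpy(header.data, "data", 4);
    header.subchunk1Size = 16;
    header.audioFormat = 1;               // 1 = integer PCM
    header.numChannels = numChannels;
    header.sampleRate = sampleRate;
    header.bitsPerSample = 16;
    header.byteRate = header.sampleRate * header.numChannels * header.bitsPerSample / 8;
    header.blockAlign = header.numChannels * header.bitsPerSample / 8;
    header.subchunk2Size = data.size() * sizeof(int16_t);
    header.chunkSize = sizeof(WavHeader) - 8 + header.subchunk2Size;

    file.write(reinterpret_cast<char*>(&header), sizeof(header));
    for (float sample : data) {
        float clamped = sample;
        if (clamped > 1.0f) clamped = 1.0f;
        if (clamped < -1.0f) clamped = -1.0f;
        int16_t s = static_cast<int16_t>(clamped * 32767.0f);
        file.write(reinterpret_cast<const char*>(&s), sizeof(s));
    }
}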
// Dump a float vector to a text file, one value per line (useful for debugging).
void writeVectorToFile(const std::vector<float>& data, const std::string& filename) {
    std::ofstream outFile(filename);
    if (!outFile) {
        std::cerr << "Error opening file for writing: " << filename << std::endl;
        return;
    }
    for (const auto& item : data) {
        outFile << item << "\n";
    }
}
int main() {
    // Create the ONNX Runtime environment.
    Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "test");

    // Create a session and load the model.
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);
    const char* model_path = "./model_184000_audio_len.onnx";
    Ort::Session session(env, model_path, session_options);
    Ort::AllocatorWithDefaultOptions allocator;

    // Model info.
    // Number of inputs the model declares; a network whose only input is an
    // image typically reports 1.
    size_t num_input_nodes = session.GetInputCount();
    // For a multi-output network this is the number of outputs.
    size_t num_output_nodes = session.GetOutputCount();
    printf("Number of inputs = %zu\n", num_input_nodes);
    printf("Number of outputs = %zu\n", num_output_nodes);

    // Query the tensor shapes automatically.
    auto input_dims = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
    auto output_dims = session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
    std::cout << "input_dims: " << input_dims[0] << std::endl;
    std::cout << "output_dims: " << output_dims[0] << std::endl;

    std::vector<const char*> input_node_names = {"input", "input_lengths", "scales", "sid"};
    std::vector<const char*> output_node_names = {"output", "output_lengths"};
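
    // Sketch, not in the original program: the node names above are hardcoded,
    // but they can also be queried from the session. This assumes the
    // GetInputNameAllocated/GetOutputNameAllocated API available in recent
    // ONNX Runtime releases (roughly 1.13 and later).
    for (size_t i = 0; i < num_input_nodes; ++i) {
        auto name = session.GetInputNameAllocated(i, allocator);
        std::cout << "input " << i << ": " << name.get() << std::endl;
    }
    for (size_t i = 0; i < num_output_nodes; ++i) {
        auto name = session.GetOutputNameAllocated(i, allocator);
        std::cout << "output " << i << ": " << name.get() << std::endl;
    }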
  87. // printf("inputs init\n");
  88. // Input text
  89. string text = "一号哨,发生犯人爆狱!";
  90. // python预处理之后的 输入数据----------------------------------------------------------------------------------------------------------
  91. int64_t input_data[] = {0, 51, 0, 198, 0, 66, 0, 96, 0, 162, 0, 196, 0, 16,
  92. 0, 61, 0, 96, 0, 162, 0, 196, 0, 3, 0, 16, 0, 48,
  93. 0, 43, 0, 198, 0, 61, 0, 110, 0, 139, 0, 198, 0, 16,
  94. 0, 48, 0, 43, 0, 56, 0, 196, 0, 150, 0, 110, 0, 56,
  95. 0, 197, 0, 16, 0, 58, 0, 96, 0, 162, 0, 196, 0, 126,
  96. 0, 196, 0, 5, 0};
  97. std::vector<int64_t> input_node_dims = {1, sizeof(input_data)/sizeof(input_data[0])};
  98. size_t input_tensor_size = sizeof(input_data)/sizeof(input_data[0]);
  99. std::vector<int32_t> input_tensor_values(input_tensor_size);
    // Prepare the input tensors.
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
    // Create the input tensor object from the token ID buffer.
    Ort::Value input_tensor = Ort::Value::CreateTensor<int64_t>(
        memory_info, input_data, input_tensor_size,
        input_node_dims.data(), input_node_dims.size());
    assert(input_tensor.IsTensor());

    std::vector<int64_t> input_length_dims = {1};
    int64_t input_lengths[] = {static_cast<int64_t>(input_tensor_size)};
    auto lengths_memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::Value input_lengths_tensor = Ort::Value::CreateTensor<int64_t>(
        lengths_memory_info, input_lengths, 1,
        input_length_dims.data(), input_length_dims.size());
    assert(input_lengths_tensor.IsTensor());

    // Inference scale parameters expected by the exported model.
    std::vector<int64_t> scales_dims = {3};
    std::vector<float> scales_data = {0.667f, 0.8f, 1.0f};
    auto scales_memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::Value scales_tensor = Ort::Value::CreateTensor<float>(
        scales_memory_info, scales_data.data(), scales_data.size(),
        scales_dims.data(), scales_dims.size());
    assert(scales_tensor.IsTensor());

    // Speaker id.
    std::vector<int64_t> sid_dims = {1};
    int64_t sid[] = {25};
    auto sid_memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::Value sid_tensor = Ort::Value::CreateTensor<int64_t>(
        sid_memory_info, sid, 1, sid_dims.data(), sid_dims.size());
    assert(sid_tensor.IsTensor());

    std::vector<Ort::Value> ort_inputs;
    ort_inputs.push_back(std::move(input_tensor));
    ort_inputs.push_back(std::move(input_lengths_tensor));
    ort_inputs.push_back(std::move(scales_tensor));
    ort_inputs.push_back(std::move(sid_tensor));
    // Run the model.
    auto output_tensors = session.Run(Ort::RunOptions{nullptr},
                                      input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
                                      output_node_names.data(), output_node_names.size());

    // Get pointers to the output tensors.
    float* audio = output_tensors[0].GetTensorMutableData<float>();
    int32_t* audio_lengths = output_tensors[1].GetTensorMutableData<int32_t>();

    // Convert the reported length (in frames) to samples using a hop size of 256.
    int len = audio_lengths[0] * 256;
    printf("audio_length: %d\n", len);

    std::vector<float> audioData(audio, audio + len);
    // writeVectorToFile(audioData, "datas_cpp.txt");
    floatToWav(audioData, "output.wav");
    return 0;
}
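
// Build sketch (paths are placeholders, not taken from the original project;
// adjust them to wherever ONNX Runtime is installed on your system):
//   g++ -std=c++17 main.cpp \
//       -I/path/to/onnxruntime/include \
//       -L/path/to/onnxruntime/lib -lonnxruntime \
//       -o vits_onnx_demo
//   LD_LIBRARY_PATH=/path/to/onnxruntime/lib ./vits_onnx_demo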