class SR_net { public: SR_net(string path, vector<int> input_size, bool fp32, bool cuda = true); private: vector<int64_t> Gdims; int Gfp32; Env env = Env(ORT_LOGGING_LEVEL_ERROR, "RRDB"); SessionOptions session_options = SessionOptions(); Session* Gsession = nullptr; vector<const char*> Ginput_names; vector<const char*> Goutput_names; vector<int> Ginput_size = {}; }; SR_net::SR_net(string path, vector<int> input_size, bool fp32, bool cuda) { this->Ginput_size = input_size; this->Gfp32 = fp32; clock_t startTime_, endTime_; startTime_ = clock(); session_options.SetIntraOpNumThreads(6); if (cuda) { OrtCUDAProviderOptions cuda_option; cuda_option.device_id = 0; cuda_option.arena_extend_strategy = 0; cuda_option.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive; cuda_option.gpu_mem_limit = SIZE_MAX; cuda_option.do_copy_in_default_stream = 1; session_options.AppendExecutionProvider_CUDA(cuda_option); } wstring widestr = wstring(path.begin(), path.end()); this->Gsession = new Session(env, widestr.c_str(), this->session_options); this->session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); AllocatorWithDefaultOptions allocator; this->Ginput_names = { "input" }; this->Goutput_names = { "output" }; endTime_ = clock(); cout << " The model loading time is:" << (double)(endTime_ - startTime_) / CLOCKS_PER_SEC << "s" << endl; } int main() { vector<int> input_shape = {}; SR_net net("E:/prj/SR_C/onnx_file/rrdb_full.onnx", input_shape, true, true); vector<String> files; glob("E:/prj/超分样本/1", files, true); size_t num = files.size(); bool Moos = true; cout << "共读取了" << num << "张图片" << endl; cout << "--------------------------------" << endl; for (int i = 0; i < num; i++) { Mat srcimg = imread(files[i]); Mat SR_image = net.Detect(srcimg, Moos); imshow("input", srcimg); imshow("result", SR_image); imwrite("./output/" + to_string(i + 1) + ".png", SR_image); waitKey(0); } },在这段代码中,我如何把SR_net net("E:/prj/SR_C/onnx_file/rrdb_full.onnx", input_shape, 
true, true);这一行写到主函数的外面?
时间: 2024-01-31 18:04:11 浏览: 152
可以在 SR_net 类的定义中增加一个默认构造函数,并让它在成员初始化列表中通过委托构造(C++11 起支持)调用 SR_net(string path, vector<int> input_size, bool fp32, bool cuda) 构造函数,如下所示:
```c++
// Abridged sketch of SR_net: only the two constructors are shown here; the
// remaining members are elided.
class SR_net {
public:
// Default constructor. Delegates (C++11 delegating constructor) to the
// four-argument constructor below with a hard-coded .onnx path, an empty
// input size, fp32 = true and cuda = true.
SR_net() : SR_net("E:/prj/SR_C/onnx_file/rrdb_full.onnx", {}, true, true) {}
// Four-argument constructor (declared only; defined elsewhere). `cuda`
// defaults to true when omitted.
SR_net(string path, vector<int> input_size, bool fp32, bool cuda = true);
// Other member functions and member variables
};
// Entry point: creates the SR_net object through its default constructor,
// so no arguments need to be spelled out at the point of declaration.
int main() {
SR_net net;
// Other code
return 0;
}
```
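这样,在主函数中就可以直接使用 SR_net net; 创建对象了。如果要把这一行写到主函数的外面,只需把 SR_net net; 这条声明放到 main 之前的全局(命名空间)作用域即可;也可以不加默认构造函数,直接在全局作用域写 SR_net net("E:/prj/SR_C/onnx_file/rrdb_full.onnx", {}, true, true);。需要注意,全局对象通常会在 main 开始执行之前构造,其构造函数抛出的异常无法在 main 中捕获。默认构造函数中的参数可以根据实际情况进行修改。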
阅读全文