1
#define _CRT_SECURE_NO_WARNINGS
#include <array>
#include <exception>
#include <fstream>
#include <iostream>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
// #include <cuda_provider_factory.h>  // NVIDIA CUDA acceleration
#include <onnxruntime_cxx_api.h>

using namespace cv;
using namespace std;
using namespace Ort;
13
+
14
+ class CodeFormer
15
+ {
16
+ public:
17
+ CodeFormer (string modelpath);
18
+ Mat detect (Mat cv_image);
19
+ private:
20
+ void preprocess (Mat srcimg);
21
+ vector<float > input_image_;
22
+ vector<double > input2_tensor;
23
+ int inpWidth;
24
+ int inpHeight;
25
+ int outWidth;
26
+ int outHeight;
27
+
28
+ float min_max[2 ] = { -1 ,1 };
29
+
30
+ // 存储初始化获得的可执行网络
31
+ Env env = Env(ORT_LOGGING_LEVEL_ERROR, " CodeFormer" );
32
+ Ort::Session *ort_session = nullptr ;
33
+ SessionOptions sessionOptions = SessionOptions();
34
+ vector<char *> input_names;
35
+ vector<char *> output_names;
36
+ vector<vector<int64_t >> input_node_dims; // >=1 outputs
37
+ vector<vector<int64_t >> output_node_dims; // >=1 outputs
38
+ };
39
+
40
+ CodeFormer::CodeFormer (string model_path)
41
+ {
42
+ // OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0); ///nvidia-cuda加速
43
+ sessionOptions.SetGraphOptimizationLevel (ORT_ENABLE_BASIC);
44
+ std::wstring widestr = std::wstring (model_path.begin (), model_path.end ()); // /如果在windows系统就这么写
45
+ ort_session = new Session (env, widestr.c_str (), sessionOptions); // /如果在windows系统就这么写
46
+ // /ort_session = new Session(env, model_path.c_str(), sessionOptions); ///如果在linux系统,就这么写
47
+
48
+ size_t numInputNodes = ort_session->GetInputCount ();
49
+ size_t numOutputNodes = ort_session->GetOutputCount ();
50
+ AllocatorWithDefaultOptions allocator;
51
+ for (int i = 0 ; i < numInputNodes; i++)
52
+ {
53
+ input_names.push_back (ort_session->GetInputName (i, allocator));
54
+ Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo (i);
55
+ auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo ();
56
+ auto input_dims = input_tensor_info.GetShape ();
57
+ input_node_dims.push_back (input_dims);
58
+ }
59
+ for (int i = 0 ; i < numOutputNodes; i++)
60
+ {
61
+ output_names.push_back (ort_session->GetOutputName (i, allocator));
62
+ Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo (i);
63
+ auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo ();
64
+ auto output_dims = output_tensor_info.GetShape ();
65
+ output_node_dims.push_back (output_dims);
66
+ }
67
+
68
+ this ->inpHeight = input_node_dims[0 ][2 ];
69
+ this ->inpWidth = input_node_dims[0 ][3 ];
70
+ this ->outHeight = output_node_dims[0 ][2 ];
71
+ this ->outWidth = output_node_dims[0 ][3 ];
72
+ input2_tensor.push_back (0.5 );
73
+ }
74
+
75
+ void CodeFormer::preprocess (Mat srcimg)
76
+ {
77
+ Mat dstimg;
78
+ cvtColor (srcimg, dstimg, COLOR_BGR2RGB);
79
+ resize (dstimg, dstimg, Size (this ->inpWidth , this ->inpHeight ), INTER_LINEAR);
80
+ this ->input_image_ .resize (this ->inpWidth * this ->inpHeight * dstimg.channels ());
81
+ int k = 0 ;
82
+ for (int c = 0 ; c < 3 ; c++)
83
+ {
84
+ for (int i = 0 ; i < this ->inpHeight ; i++)
85
+ {
86
+ for (int j = 0 ; j < this ->inpWidth ; j++)
87
+ {
88
+ float pix = dstimg.ptr <uchar>(i)[j * 3 + c];
89
+ this ->input_image_ [k] = (pix / 255.0 - 0.5 ) / 0.5 ;
90
+ k++;
91
+ }
92
+ }
93
+ }
94
+ }
95
+
96
+ Mat CodeFormer::detect (Mat srcimg)
97
+ {
98
+ int im_h = srcimg.rows ;
99
+ int im_w = srcimg.cols ;
100
+ this ->preprocess (srcimg);
101
+ array<int64_t , 4 > input_shape_{ 1 , 3 , this ->inpHeight , this ->inpWidth };
102
+ vector<int64_t > input2_shape_ = { 1 };
103
+
104
+ auto allocator_info = MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
105
+ vector<Value> ort_inputs;
106
+ ort_inputs.push_back (Value::CreateTensor<float >(allocator_info, input_image_.data (), input_image_.size (), input_shape_.data (), input_shape_.size ()));
107
+ ort_inputs.push_back (Value::CreateTensor<double >(allocator_info, input2_tensor.data (), input2_tensor.size (), input2_shape_.data (), input2_shape_.size ()));
108
+ vector<Value> ort_outputs = ort_session->Run (RunOptions{ nullptr }, input_names.data (), ort_inputs.data (), ort_inputs.size (), output_names.data (), output_names.size ());
109
+
110
+ // //post_process
111
+ float * pred = ort_outputs[0 ].GetTensorMutableData <float >();
112
+ // ////Mat mask(outHeight, outWidth, CV_32FC3, pred); /////经过试验,直接这样赋值,是不行的
113
+ const unsigned int channel_step = outHeight * outWidth;
114
+ vector<Mat> channel_mats;
115
+ Mat rmat (outHeight, outWidth, CV_32FC1, pred); // R
116
+ Mat gmat (outHeight, outWidth, CV_32FC1, pred + channel_step); // G
117
+ Mat bmat (outHeight, outWidth, CV_32FC1, pred + 2 * channel_step); // B
118
+ channel_mats.push_back (rmat);
119
+ channel_mats.push_back (gmat);
120
+ channel_mats.push_back (bmat);
121
+ Mat mask;
122
+ merge (channel_mats, mask); // CV_32FC3 allocated
123
+
124
+ // /不用for循环遍历Mat里的每个像素值,实现numpy.clip函数
125
+ mask.setTo (this ->min_max [0 ], mask < this ->min_max [0 ]);
126
+ mask.setTo (this ->min_max [1 ], mask > this ->min_max [1 ]); // //也可以用threshold函数,阈值类型THRESH_TOZERO_INV
127
+
128
+ mask = (mask - this ->min_max [0 ]) / (this ->min_max [1 ] - this ->min_max [0 ]);
129
+ mask *= 255.0 ;
130
+ mask.convertTo (mask, CV_8UC3);
131
+ cvtColor (mask, mask, COLOR_BGR2RGB);
132
+ return mask;
133
+ }
134
+
135
+ int main ()
136
+ {
137
+ CodeFormer mynet (" codeformer.onnx" );
138
+ string imgpath = " input.png" ;
139
+ Mat srcimg = imread (imgpath);
140
+ Mat dstimg = mynet.detect (srcimg);
141
+ resize (dstimg, dstimg, Size (srcimg.cols , srcimg.rows ), INTER_LINEAR);
142
+
143
+ // imwrite("result.jpg", dstimg)
144
+ namedWindow (" srcimg" , WINDOW_NORMAL);
145
+ imshow (" srcimg" , srcimg);
146
+ namedWindow (" dstimg" , WINDOW_NORMAL);
147
+ imshow (" dstimg" , dstimg);
148
+ waitKey (0 );
149
+ destroyAllWindows ();
150
+ }
0 commit comments