
Commit f29da2f

Add files via upload

File tree

3 files changed: +230 -0 lines


input.png

261 KB

main.cpp

+150
@@ -0,0 +1,150 @@
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <fstream>
#include <numeric>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
//#include <cuda_provider_factory.h> /// NVIDIA CUDA acceleration
#include <onnxruntime_cxx_api.h>

using namespace cv;
using namespace std;
using namespace Ort;

class CodeFormer
{
public:
    CodeFormer(string modelpath);
    Mat detect(Mat cv_image);
private:
    void preprocess(Mat srcimg);
    vector<float> input_image_;
    vector<double> input2_tensor;
    int inpWidth;
    int inpHeight;
    int outWidth;
    int outHeight;

    float min_max[2] = { -1, 1 };

    // Holds the session (executable network) created at initialization
    Env env = Env(ORT_LOGGING_LEVEL_ERROR, "CodeFormer");
    Ort::Session *ort_session = nullptr;
    SessionOptions sessionOptions = SessionOptions();
    vector<char*> input_names;
    vector<char*> output_names;
    vector<vector<int64_t>> input_node_dims;  // >=1 inputs
    vector<vector<int64_t>> output_node_dims; // >=1 outputs
};

CodeFormer::CodeFormer(string model_path)
{
    //OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0); /// NVIDIA CUDA acceleration
    sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
    std::wstring widestr = std::wstring(model_path.begin(), model_path.end()); /// On Windows the session takes a wide-string path
    ort_session = new Session(env, widestr.c_str(), sessionOptions); /// Windows
    ///ort_session = new Session(env, model_path.c_str(), sessionOptions); /// On Linux pass the narrow string directly

    size_t numInputNodes = ort_session->GetInputCount();
    size_t numOutputNodes = ort_session->GetOutputCount();
    AllocatorWithDefaultOptions allocator;
    for (int i = 0; i < numInputNodes; i++)
    {
        input_names.push_back(ort_session->GetInputName(i, allocator));
        Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
        auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
        auto input_dims = input_tensor_info.GetShape();
        input_node_dims.push_back(input_dims);
    }
    for (int i = 0; i < numOutputNodes; i++)
    {
        output_names.push_back(ort_session->GetOutputName(i, allocator));
        Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
        auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
        auto output_dims = output_tensor_info.GetShape();
        output_node_dims.push_back(output_dims);
    }

    this->inpHeight = input_node_dims[0][2];
    this->inpWidth = input_node_dims[0][3];
    this->outHeight = output_node_dims[0][2];
    this->outWidth = output_node_dims[0][3];
    input2_tensor.push_back(0.5);
}

void CodeFormer::preprocess(Mat srcimg)
{
    Mat dstimg;
    cvtColor(srcimg, dstimg, COLOR_BGR2RGB);
    resize(dstimg, dstimg, Size(this->inpWidth, this->inpHeight), 0, 0, INTER_LINEAR);
    this->input_image_.resize(this->inpWidth * this->inpHeight * dstimg.channels());
    int k = 0; // flatten HWC uint8 into CHW float, normalized to [-1, 1]
    for (int c = 0; c < 3; c++)
    {
        for (int i = 0; i < this->inpHeight; i++)
        {
            for (int j = 0; j < this->inpWidth; j++)
            {
                float pix = dstimg.ptr<uchar>(i)[j * 3 + c];
                this->input_image_[k] = (pix / 255.0 - 0.5) / 0.5;
                k++;
            }
        }
    }
}

Mat CodeFormer::detect(Mat srcimg)
{
    int im_h = srcimg.rows;
    int im_w = srcimg.cols;
    this->preprocess(srcimg);
    array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };
    vector<int64_t> input2_shape_ = { 1 };

    auto allocator_info = MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    vector<Value> ort_inputs;
    ort_inputs.push_back(Value::CreateTensor<float>(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()));
    ort_inputs.push_back(Value::CreateTensor<double>(allocator_info, input2_tensor.data(), input2_tensor.size(), input2_shape_.data(), input2_shape_.size()));
    vector<Value> ort_outputs = ort_session->Run(RunOptions{ nullptr }, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names.data(), output_names.size());

    //// post-processing
    float* pred = ort_outputs[0].GetTensorMutableData<float>();
    ////// Mat mask(outHeight, outWidth, CV_32FC3, pred); ///// Tested: building the Mat directly from the buffer fails, because the output is planar CHW rather than interleaved
    const unsigned int channel_step = outHeight * outWidth;
    vector<Mat> channel_mats;
    Mat rmat(outHeight, outWidth, CV_32FC1, pred); // R
    Mat gmat(outHeight, outWidth, CV_32FC1, pred + channel_step); // G
    Mat bmat(outHeight, outWidth, CV_32FC1, pred + 2 * channel_step); // B
    channel_mats.push_back(rmat);
    channel_mats.push_back(gmat);
    channel_mats.push_back(bmat);
    Mat mask;
    merge(channel_mats, mask); // CV_32FC3 allocated

    /// Clip to [min, max] without looping over every pixel (the equivalent of numpy.clip)
    mask.setTo(this->min_max[0], mask < this->min_max[0]);
    mask.setTo(this->min_max[1], mask > this->min_max[1]); //// threshold() with THRESH_TOZERO_INV would also work

    mask = (mask - this->min_max[0]) / (this->min_max[1] - this->min_max[0]);
    mask *= 255.0;
    mask.convertTo(mask, CV_8UC3);
    cvtColor(mask, mask, COLOR_BGR2RGB); // swap R and B so the result is in OpenCV's BGR order
    return mask;
}

int main()
{
    CodeFormer mynet("codeformer.onnx");
    string imgpath = "input.png";
    Mat srcimg = imread(imgpath);
    Mat dstimg = mynet.detect(srcimg);
    resize(dstimg, dstimg, Size(srcimg.cols, srcimg.rows), 0, 0, INTER_LINEAR);

    //imwrite("result.jpg", dstimg);
    namedWindow("srcimg", WINDOW_NORMAL);
    imshow("srcimg", srcimg);
    namedWindow("dstimg", WINDOW_NORMAL);
    imshow("dstimg", dstimg);
    waitKey(0);
    destroyAllWindows();
}

main.py

+80
@@ -0,0 +1,80 @@
import argparse
import cv2
import numpy as np
import onnxruntime as ort


class CodeFormer():
    def __init__(self, modelpath):
        # net = cv2.dnn.readNet(modelpath)
        so = ort.SessionOptions()
        so.log_severity_level = 3
        self.session = ort.InferenceSession(modelpath, so)
        model_inputs = self.session.get_inputs()
        self.input_name0 = model_inputs[0].name
        self.input_name1 = model_inputs[1].name
        self.inpheight = model_inputs[0].shape[2]
        self.inpwidth = model_inputs[0].shape[3]

    def post_processing(self, tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
        # tensor: NCHW model output, 3 channels
        _tensor = tensor[0]

        _tensor = _tensor.clip(min_max[0], min_max[1])

        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.ndim

        if n_dim == 3:
            img_np = _tensor
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor
        else:
            raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        return img_np

    def detect(self, srcimg):
        dstimg = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
        dstimg = cv2.resize(dstimg, (self.inpwidth, self.inpheight), interpolation=cv2.INTER_LINEAR)
        dstimg = (dstimg.astype(np.float32) / 255.0 - 0.5) / 0.5
        input_image = np.expand_dims(dstimg.transpose(2, 0, 1), axis=0).astype(np.float32)

        # Inference
        output = self.session.run(None, {self.input_name0: input_image, self.input_name1: np.array([0.5])})[0]
        restored_img = self.post_processing(output, rgb2bgr=True, min_max=(-1, 1))
        return restored_img.astype('uint8')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", type=str, default='input.png', help="image path")
    parser.add_argument("--modelpath", type=str, default='codeformer.onnx', help="onnx model path")
    args = parser.parse_args()

    mynet = CodeFormer(args.modelpath)
    srcimg = cv2.imread(args.imgpath)
    restored_img = mynet.detect(srcimg)
    restored_img = cv2.resize(restored_img, (srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR)

    # if srcimg.shape[0] >= srcimg.shape[1]:
    #     result = np.vstack((srcimg, restored_img))
    # else:
    #     result = np.hstack((srcimg, restored_img))

    # cv2.imwrite('result.jpg', restored_img)
    cv2.namedWindow("srcimg", cv2.WINDOW_NORMAL)
    cv2.imshow("srcimg", srcimg)
    cv2.namedWindow("restored_img", cv2.WINDOW_NORMAL)
    cv2.imshow("restored_img", restored_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
