张一极
For a fully connected neural network with a 100x100 RGB image as input, each channel contributes 100x100 values, so the three channels give 100x100x3 = 30,000 inputs. If the first layer has 1,000 neurons, that layer alone needs 30,000x1,000 weights. Convolution is introduced to deliberately give up some of these weights: a convolution can be viewed as a DNN that drops part of its weights, so its backpropagation only needs to skip updating those parameters in the final update step. The concrete implementation will be recorded later.
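As a quick check on these numbers, here is a minimal standalone sketch (my own illustration, not part of the library code below; the 1,000 output channels and 3x3 kernel for the convolutional case are assumptions):

#include <iostream>

int main()
{
    // Fully connected: every input value connects to every neuron in the first layer.
    long long inputs  = 100LL * 100 * 3;        // 100x100 image, 3 channels = 30000 inputs
    long long neurons = 1000;                   // assumed size of the first layer
    long long fc_weights = inputs * neurons;    // 30000 * 1000 = 3e7 weights

    // Convolution: weights are shared across positions, so only
    // kernel_size^2 * in_channels * out_channels weights remain (illustrative numbers).
    long long kernel = 3, in_ch = 3, out_ch = 1000;
    long long conv_weights = kernel * kernel * in_ch * out_ch;   // 3*3*3*1000 = 27000

    std::cout << "fully connected: " << fc_weights   << " weights\n";
    std::cout << "convolutional  : " << conv_weights << " weights\n";
    return 0;
}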
Convolution is a form of fitting: it transforms each region of the image matrix into a new value, and its core idea is packaging the weights into a kernel:
The filter's four values are obtained by repeatedly updating them until the most suitable four values are found:
The concrete computation works as follows (stride = 1):
Each convolution step uses a patch of the input with the same size as the kernel, i.e. every pixel in the highlighted (yellow) region is multiplied by the kernel parameter at the corresponding position:
output:
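To make the sliding-window computation concrete, here is a minimal self-contained sketch of one stride-1 pass of a 2x2 kernel over a 3x3 input (plain arrays rather than the Matrix type used later; the input and kernel values are made up for illustration):

#include <iostream>

int main()
{
    // 3x3 input and a 2x2 kernel: the output is (3-2+1) x (3-2+1) = 2x2.
    double img[3][3]    = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
    double kernel[2][2] = {{1, 0}, {0, 1}};

    for (int r = 0; r < 2; r++)
    {
        for (int c = 0; c < 2; c++)
        {
            // Multiply the 2x2 patch at (r, c) element-wise with the kernel and sum.
            double acc = 0.0;
            for (int i = 0; i < 2; i++)
                for (int j = 0; j < 2; j++)
                    acc += img[r + i][c + j] * kernel[i][j];
            std::cout << acc << (c == 1 ? "\n" : " ");
        }
    }
    // Prints:
    // 6 8
    // 12 14
    return 0;
}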
This is not the only convolution mode used in CNNs; the other two modes are:
1. (mode = full) convolution starts as soon as the first element of the kernel touches an element of the input, i.e. with maximum zero padding:
2. (mode = same) the input is padded so that the output keeps the same size as the input:
The third mode, valid convolution (the case computed above), is the most common one; the output sizes of all three are sketched below:
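The three modes differ only in how much zero padding is applied before the same sliding window runs. A small sketch of the resulting output side lengths (stride 1, n x n input, k x k kernel; output_size is my own helper, not part of the library code below):

#include <iostream>

// Output side length for an n x n input and k x k kernel at stride 1.
// full : pad by k-1 on each side      -> n + k - 1
// same : pad by (k-1)/2 (exact for odd k) -> n
// valid: no padding                   -> n - k + 1
int output_size(int n, int k, int pad)
{
    return n + 2 * pad - k + 1;
}

int main()
{
    int n = 6, k = 3;
    std::cout << "full : " << output_size(n, k, k - 1)       << "\n";  // 8
    std::cout << "same : " << output_size(n, k, (k - 1) / 2) << "\n";  // 6
    std::cout << "valid: " << output_size(n, k, 0)           << "\n";  // 4
    return 0;
}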
Matrix conv_element(Matrix mid1,Matrix kernel,int kernel_size = 2,int stride = 1)
{
    // Valid convolution: output size is ((input - kernel) / stride) + 1 per dimension.
    Matrix conv_result = CreateMatrix(((mid1.row-kernel_size)/stride)+1,((mid1.col-kernel_size)/stride)+1);
    // x_ / y_ step over the input in image coordinates; the result is written at (x_/stride, y_/stride).
    for(int x_ = 0;x_<=mid1.row-kernel_size;x_+=stride)
    {
        for(int y_ = 0;y_<=mid1.col-kernel_size;y_+=stride)
        {
            // Crop a kernel-sized patch, multiply it element-wise with the kernel and sum.
            Matrix crop_pic = iloc(mid1,x_,x_+kernel.row,y_,y_+kernel.col);
            change_va(conv_result,x_/stride,y_/stride,matrix_sum(mul_simple(crop_pic,kernel)));
        }
    }
    // cout<<"row: "<<conv_result.row<<" , "<<"col: "<<conv_result.col<<endl;
    // cout_mat(conv_result);
    return conv_result;
}
/*
parameter:
    Matrix mid1,
    int input_dim = 3
    int output_channels = 3
    int stride = 1
    int kernel_size = 2
    int mode = 0
    int padding = 0
*/
double conv_test(Matrix mid1,int input_dim = 3,int output_channels = 3,int stride = 1,int kernel_size = 2,int mode = 0,int padding = 0)
{
    // cout_mat(mid1);
    // One random matrix per input channel, standing in for the R/G/B planes.
    // (Variable-length arrays are a compiler extension; this builds under g++.)
    Matrix mid_rgb[input_dim];
    for(int rgb_idx = 0;rgb_idx<input_dim;rgb_idx++)
    {
        mid_rgb[rgb_idx] = CreateRandMat(mid1.row,mid1.col);
        cout<<"---------rgb: "<<rgb_idx<<"---------"<<endl;
        cout_mat(mid_rgb[rgb_idx]);
    }
    // One kernel per (output channel, input channel) pair, initialized to all ones.
    // First index is the output channel, matching the read below.
    Matrix filters[output_channels][input_dim];
    for(int channel_index = 0;channel_index<input_dim;channel_index++)
    {
        for(int filter_index = 0;filter_index<output_channels;filter_index++)
        {
            Matrix kernel = ones(kernel_size,kernel_size);
            filters[filter_index][channel_index] = kernel;
            // cout<<"---------"<<endl;
            // cout<<"channel: "<<channel_index<<", index: "<<filter_index<<endl;
            // cout_mat(filters[filter_index][channel_index]);
        }
    }
    if(mode == 0)
    {
        cout<<"input_img:"<<endl;
        cout_mat(mid1);
        Matrix conv_result = CreateMatrix(((mid1.row-kernel_size)/stride)+1,((mid1.col-kernel_size)/stride)+1);
        Matrix kernel = ones(kernel_size,kernel_size);
        cout<<"--------- kernels: 3x3--------"<<endl;
        cout_mat(kernel);
        cout<<"--------- mid1 ---------"<<endl;
        cout<<"row: "<<mid1.row<<" , "<<"col: "<<mid1.col<<endl;
        cout_mat(mid1);
        cout<<"--------- output: ---------"<<endl;
        // For each output channel: convolve every input channel with its own kernel
        // and accumulate the per-channel results into one feature map.
        Matrix feature_maps[output_channels];
        for(int filter_idx = 0;filter_idx<output_channels;filter_idx++)
        {
            Matrix sum_rgb = CreateMatrix(((mid1.row-kernel_size)/stride)+1,((mid1.col-kernel_size)/stride)+1);
            for(int channel_idx=0;channel_idx<input_dim;channel_idx++)
            {
                sum_rgb = add(sum_rgb,conv_element(mid_rgb[channel_idx],filters[filter_idx][channel_idx],kernel_size,stride),0);
                cout<<"sum_rgb"<<"filters_index: "<<filter_idx<<" "<<endl;
                cout_mat(sum_rgb);
            }
            feature_maps[filter_idx]=sum_rgb;
        }
        for(int i = 0;i < output_channels;i++)
        {
            cout<<"==========filter: "<<i<<"========="<<endl;
            cout_mat(feature_maps[i]);
        }
    }
    return 0.0;
}
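A possible driver for conv_test might look like the sketch below; the header name "matrix.h", the 6x7 input size and the kernel_size = 3 call are assumptions chosen to match the sample run that follows, not part of the library:

// Hypothetical driver; assumes the Matrix type and helpers above live in one header.
#include "matrix.h"

int main()
{
    Matrix img = CreateRandMat(6, 7);   // one 6x7 plane, used for the size and printouts
    conv_test(img,
              /*input_dim=*/3, /*output_channels=*/3,
              /*stride=*/1, /*kernel_size=*/3,
              /*mode=*/0, /*padding=*/0);
    return 0;
}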
Suppose the randomly generated RGB channels are:
---------rgb: 0---------
0.6868,1.358,-0.0419,0.3864,-0.87,0.7874,0.1393,
0.7593,0.1114,-0.7965,-0.7005,1.5068,-0.3997,1.1945,
-0.6814,-0.6164,1.7819,0.933,1.0805,0.2337,-0.7265,
1.6941,0.5209,-0.9397,-0.3,0.9019,-0.5348,-0.9147,
0.257,1.2136,1.9702,0.996,-0.394,-0.0892,-0.6112,
0.9649,0.2556,-0.6445,-0.3106,1.3732,-0.4526,1.5519,
---------rgb: 1---------
0.5726,-0.2106,-0.6804,-0.6533,-0.7418,1.4237,-0.0423,
0.3982,1.0687,-0.2758,-0.7519,0.9336,-0.4004,0.7577,
1.0035,1.8065,0.7624,1.0708,0.0301,-0.0468,-0.3465,
1.5645,-0.397,1.6417,1.148,-0.1405,-0.3185,-0.7553,
1.5715,1.2818,-0.9162,0.0648,1.2894,0.8675,1.2892,
1.9186,1.0597,-0.7541,1.4489,-0.5107,1.0148,0.0982,
---------rgb: 2---------
1.3683,0.3772,1.4713,1.8638,1.019,-0.9855,0.2794,
-0.1709,0.7684,1.7359,1.7779,1.4403,0.0131,1.4975,
0.2598,1.6541,1.6398,0.9777,-0.0517,0.0652,-0.0643,
1.0337,0.4624,1.9443,1.9263,0.4125,-0.9113,0.2605,
-0.656,-0.7297,-0.398,0.2574,-0.157,1.5341,0.8177,
-0.7898,1.1526,0.2964,0.1901,1.4331,-0.9055,-0.8993,
The kernels are 3x3 matrices of all ones. With the 6x7 channels above and stride 1, each feature map is (6-3+1) x (7-3+1) = 4x5. The output is as follows:
--------- output: ---------
sum_rgbfilters_index: 0
2.5612,2.4154,3.2797,2.9576,2.946,
1.8336,-0.00589997,3.4674,2.7209,2.3417,
5.2002,5.5595,6.0298,2.8271,-1.0543,
5.2921,2.7615,2.6525,1.1899,0.8305,
sum_rgbfilters_index: 0
7.0063,4.5518,2.9734,3.8216,4.5133,
9.4063,6.0675,7.8858,4.2453,2.0551,
13.5189,12.0223,10.9803,6.7919,0.8143,
12.2626,7.3391,5.9238,6.0536,3.6646,
sum_rgbfilters_index: 0
16.1102,16.8179,14.8474,9.9414,7.7263,
18.7338,18.9543,19.6888,9.8953,4.7169,
18.7293,19.7566,17.5316,10.8451,2.72,
14.5785,12.4409,11.8289,9.8333,5.2494,
sum_rgbfilters_index: 1
2.5612,2.4154,3.2797,2.9576,2.946,
1.8336,-0.00589997,3.4674,2.7209,2.3417,
5.2002,5.5595,6.0298,2.8271,-1.0543,
5.2921,2.7615,2.6525,1.1899,0.8305,
sum_rgbfilters_index: 1
7.0063,4.5518,2.9734,3.8216,4.5133,
9.4063,6.0675,7.8858,4.2453,2.0551,
13.5189,12.0223,10.9803,6.7919,0.8143,
12.2626,7.3391,5.9238,6.0536,3.6646,
sum_rgbfilters_index: 1
16.1102,16.8179,14.8474,9.9414,7.7263,
18.7338,18.9543,19.6888,9.8953,4.7169,
18.7293,19.7566,17.5316,10.8451,2.72,
14.5785,12.4409,11.8289,9.8333,5.2494,
sum_rgbfilters_index: 2
2.5612,2.4154,3.2797,2.9576,2.946,
1.8336,-0.00589997,3.4674,2.7209,2.3417,
5.2002,5.5595,6.0298,2.8271,-1.0543,
5.2921,2.7615,2.6525,1.1899,0.8305,
sum_rgbfilters_index: 2
7.0063,4.5518,2.9734,3.8216,4.5133,
9.4063,6.0675,7.8858,4.2453,2.0551,
13.5189,12.0223,10.9803,6.7919,0.8143,
12.2626,7.3391,5.9238,6.0536,3.6646,
sum_rgbfilters_index: 2
16.1102,16.8179,14.8474,9.9414,7.7263,
18.7338,18.9543,19.6888,9.8953,4.7169,
18.7293,19.7566,17.5316,10.8451,2.72,
14.5785,12.4409,11.8289,9.8333,5.2494,
==========filter: 0=========
16.1102,16.8179,14.8474,9.9414,7.7263,
18.7338,18.9543,19.6888,9.8953,4.7169,
18.7293,19.7566,17.5316,10.8451,2.72,
14.5785,12.4409,11.8289,9.8333,5.2494,
==========filter: 1=========
16.1102,16.8179,14.8474,9.9414,7.7263,
18.7338,18.9543,19.6888,9.8953,4.7169,
18.7293,19.7566,17.5316,10.8451,2.72,
14.5785,12.4409,11.8289,9.8333,5.2494,
==========filter: 2=========
16.1102,16.8179,14.8474,9.9414,7.7263,
18.7338,18.9543,19.6888,9.8953,4.7169,
18.7293,19.7566,17.5316,10.8451,2.72,
14.5785,12.4409,11.8289,9.8333,5.2494,
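Because every kernel is filled with ones, each output element is just the sum of the corresponding 3x3 patch over the three channels. That is why the running printouts accumulate channel by channel (top-left element: 2.5612 after rgb 0, 7.0063 after rgb 1, 16.1102 after rgb 2) and why all three final feature maps are identical: the three filters share the same all-ones weights.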