測試用數據
輸入:
卷積核爲:
padding爲SAME
使用OpenGL ES的imageArray存儲輸入輸出數據,紋理格式爲rgba16f。爲了使紋理座標與輸入矩陣座標一一對應,需要先對輸入進行補零:先將輸入矩陣拉平爲一個一維向量,再在這個一維向量的每個數字後補3個零,然後傳入到一個5x5的GL紋理上,這樣紋理座標就與輸入座標一一對應了。
對卷積核先做預計算,然後將卷積預計算得到的 GgGt矩陣用紋理存儲
/**
 * Pre-computes the Winograd kernel transform {@code G * g * Gt} for every kernel/channel pair,
 * packs the results four channels per RGBA texel, appends a bias texel, and uploads every
 * packed row to {@code mKennelTex}.
 *
 * <p>Packed layout of one row (68 floats = 17 RGBA texels):
 * <ul>
 *   <li>texel {@code w + h * 4} (0..15) holds element {@code [h][w]} of the 4x4 GgGt matrix,
 *       with channel {@code c} stored in lane {@code c % 4} of layer {@code c / 4};</li>
 *   <li>texel 16 is the bias slot; only its first lane in layer 0 is written.</li>
 * </ul>
 *
 * <p>Does nothing when there are no kernels or no channels (the original indexed
 * {@code mKennels[0]} and layer 0 unconditionally and would throw on such input).
 *
 * @param mKennels kernels indexed as {@code [kernel][channel][row][col]}; each
 *     {@code [row][col]} slice is multiplied as {@code G(4x3) * g * Gt(3x4)}, so it is
 *     expected to be 3x3. The channel count is taken from the first kernel.
 */
private void transferKernelToGgGt(float[][][][] mKennels) {
    // Nothing to transform or upload; also avoids indexing mKennels[0] / layer 0 below.
    if (mKennels.length == 0 || mKennels[0].length == 0) {
        return;
    }

    final int tileSize = 4;                           // GgGt is a 4x4 matrix
    final int lanes = 4;                              // channels packed per RGBA texel
    final int matrixTexels = tileSize * tileSize;     // 16 texels of GgGt data
    final int rowTexels = matrixTexels + 1;           // plus one texel for the bias
    final int biasIndex = matrixTexels * lanes;       // first lane of the bias texel

    final int kernelCount = mKennels.length;
    final int channelCount = mKennels[0].length;

    final float[][] G = new float[][]{{1, 0, 0}, {0.5f, 0.5f, 0.5f}, {0.5f, -0.5f, 0.5f}, {0, 0, 1}};
    final float[][] Gt = Numpy.transpose(G);

    // Presumably rounds the channel count up to a multiple of 4 — confirm against Utils.alignBy4.
    final int layerCount = Utils.alignBy4(channelCount) / lanes;
    final float[][][] packed = new float[kernelCount][layerCount][rowTexels * lanes];

    for (int a = 0; a < kernelCount; a++) {
        for (int c = 0; c < channelCount; c++) {
            // Transform and pack in one pass; no intermediate 4-D array is needed.
            final float[][] transformed = Numpy.dot(Numpy.dot(G, mKennels[a][c]), Gt);
            final float[] layer = packed[a][c / lanes];
            final int lane = c % lanes;
            for (int h = 0; h < tileSize; h++) {
                for (int w = 0; w < tileSize; w++) {
                    layer[(w + h * tileSize) * lanes + lane] = transformed[h][w];
                }
            }
        }
        // Bias term: first lane of the last texel, layer 0 only.
        // NOTE(review): the bias is synthesized as 0.01 * kernel index, which looks like
        // test data rather than a trained value — confirm before reuse.
        packed[a][0][biasIndex] = 0.01f * a;
    }

    // Upload to the texture: one 17x1 row per (kernel, layer).
    // NOTE(review): the trailing arguments look like (level/x offset 0, y offset = kernel
    // index, z offset = layer, width 17, height 1, depth 1) — confirm against
    // Render.transferToTextureArrayFloat.
    for (int a = 0; a < kernelCount; a++) {
        final float[][] layers = packed[a];
        for (int c = 0; c < layers.length; c++) {
            Render.transferToTextureArrayFloat(FloatBuffer.wrap(layers[c]), mKennelTex, 0, a, c, rowTexels, 1, 1);
        }
    }
}
每個卷積核GgGt矩陣都只用紋理的一行存儲,存儲shape爲16x1,如果需要在卷積時加入bias,可再加一列存儲bias,並將存儲shape改爲17x1。
所有卷積核的GgGt均存儲在一張紋理上,按卷積核順序,在紋理上由上到下依次排列。
輸出爲5x5,可以將其轉化爲3x3個2x2的輸出(將5x5補零爲6x6),以便應用Winograd算法。每個shader只處理一個2x2的輸出,以便進行並行計算。
shader的主要邏輯
// Main shader logic (excerpt): accumulates the four outputs of one 2x2 tile
// (result00, result01, result10, result11) over all input layers.
// pad_w, pad_h, start_x, start_y, in_depth, z_0, result00..result11, sum(),
// input_image and kernel_image are declared outside this excerpt.
//
// Coordinates of the 4x4 input window: four consecutive columns (x0..x3)
// and four consecutive rows (y0..y3) starting at (pad + start).
int conv_x0 = pad_w + start_x;
int conv_y0 = pad_h + start_y;
int conv_x1 = conv_x0 + 1;
int conv_y1 = conv_y0 + 1;
int conv_x2 = conv_x1 + 1;
int conv_y2 = conv_y1 + 1;
int conv_x3 = conv_x2 + 1;
int conv_y3 = conv_y2 + 1;
for (int c = 0; c < in_depth; c++) {
// Read the 4x4 window from layer c of the input image.
// Naming: dataRC is row R (y index), column C (x index) of the window.
vec4 data00 = imageLoad(input_image, ivec3(conv_x0, conv_y0, c));
vec4 data01 = imageLoad(input_image, ivec3(conv_x1, conv_y0, c));
vec4 data02 = imageLoad(input_image, ivec3(conv_x2, conv_y0, c));
vec4 data03 = imageLoad(input_image, ivec3(conv_x3, conv_y0, c));
vec4 data10 = imageLoad(input_image, ivec3(conv_x0, conv_y1, c));
vec4 data11 = imageLoad(input_image, ivec3(conv_x1, conv_y1, c));
vec4 data12 = imageLoad(input_image, ivec3(conv_x2, conv_y1, c));
vec4 data13 = imageLoad(input_image, ivec3(conv_x3, conv_y1, c));
vec4 data20 = imageLoad(input_image, ivec3(conv_x0, conv_y2, c));
vec4 data21 = imageLoad(input_image, ivec3(conv_x1, conv_y2, c));
vec4 data22 = imageLoad(input_image, ivec3(conv_x2, conv_y2, c));
vec4 data23 = imageLoad(input_image, ivec3(conv_x3, conv_y2, c));
vec4 data30 = imageLoad(input_image, ivec3(conv_x0, conv_y3, c));
vec4 data31 = imageLoad(input_image, ivec3(conv_x1, conv_y3, c));
vec4 data32 = imageLoad(input_image, ivec3(conv_x2, conv_y3, c));
vec4 data33 = imageLoad(input_image, ivec3(conv_x3, conv_y3, c));
// Hoist the shared sub-expressions: per-column combinations of window rows
// (row0 - row2, row1 + row2, row2 - row1, row1 - row3), each reused by
// several BtdB entries below.
vec4 d00_20_sub = data00 - data20;
vec4 d01_21_sub = data01 - data21;
vec4 d02_22_sub = data02 - data22;
vec4 d03_23_sub = data03 - data23;
vec4 d10_20_add = data10 + data20;
vec4 d11_21_add = data11 + data21;
vec4 d12_22_add = data12 + data22;
vec4 d13_23_add = data13 + data23;
vec4 d20_10_sub = data20 - data10;
vec4 d21_11_sub = data21 - data11;
vec4 d22_12_sub = data22 - data12;
vec4 d23_13_sub = data23 - data13;
vec4 d10_30_sub = data10 - data30;
vec4 d11_31_sub = data11 - data31;
vec4 d12_32_sub = data12 - data32;
vec4 d13_33_sub = data13 - data33;
// Compute the 4x4 BtdB matrix: within each row of the intermediate above,
// combine columns as (col0 - col2, col1 + col2, col2 - col1, col1 - col3).
vec4 BtdB00 = d00_20_sub - d02_22_sub;
vec4 BtdB01 = d01_21_sub + d02_22_sub;
vec4 BtdB02 = d02_22_sub - d01_21_sub;
vec4 BtdB03 = d01_21_sub - d03_23_sub;
vec4 BtdB10 = d10_20_add - d12_22_add;
vec4 BtdB11 = d11_21_add + d12_22_add;
vec4 BtdB12 = d12_22_add - d11_21_add;
vec4 BtdB13 = d11_21_add - d13_23_add;
vec4 BtdB20 = d20_10_sub - d22_12_sub;
vec4 BtdB21 = d21_11_sub + d22_12_sub;
vec4 BtdB22 = d22_12_sub - d21_11_sub;
vec4 BtdB23 = d21_11_sub - d23_13_sub;
vec4 BtdB30 = d10_30_sub - d12_32_sub;
vec4 BtdB31 = d11_31_sub + d12_32_sub;
vec4 BtdB32 = d12_32_sub - d11_31_sub;
vec4 BtdB33 = d11_31_sub - d13_33_sub;
// Four iterations per input layer, one per element i of the result accumulators;
// z_0 + i is used as the y coordinate (row) of the kernel image.
for (int i = 0; i < 4; i++) {
int z_i = z_0 + i;
// Read the GgGt matrix from the kernel texture: texel x = 4 * row + col of
// the 4x4 matrix, at row z_i, layer c.
vec4 GgGt00_0 = imageLoad(kernel_image, ivec3(0, z_i, c));
vec4 GgGt01_0 = imageLoad(kernel_image, ivec3(1, z_i, c));
vec4 GgGt02_0 = imageLoad(kernel_image, ivec3(2, z_i, c));
vec4 GgGt03_0 = imageLoad(kernel_image, ivec3(3, z_i, c));
vec4 GgGt10_0 = imageLoad(kernel_image, ivec3(4, z_i, c));
vec4 GgGt11_0 = imageLoad(kernel_image, ivec3(5, z_i, c));
vec4 GgGt12_0 = imageLoad(kernel_image, ivec3(6, z_i, c));
vec4 GgGt13_0 = imageLoad(kernel_image, ivec3(7, z_i, c));
vec4 GgGt20_0 = imageLoad(kernel_image, ivec3(8, z_i, c));
vec4 GgGt21_0 = imageLoad(kernel_image, ivec3(9, z_i, c));
vec4 GgGt22_0 = imageLoad(kernel_image, ivec3(10, z_i, c));
vec4 GgGt23_0 = imageLoad(kernel_image, ivec3(11, z_i, c));
vec4 GgGt30_0 = imageLoad(kernel_image, ivec3(12, z_i, c));
vec4 GgGt31_0 = imageLoad(kernel_image, ivec3(13, z_i, c));
vec4 GgGt32_0 = imageLoad(kernel_image, ivec3(14, z_i, c));
vec4 GgGt33_0 = imageLoad(kernel_image, ivec3(15, z_i, c));
// Element-wise product of the transformed input and the transformed kernel.
vec4 m00_0 = BtdB00 * GgGt00_0;
vec4 m01_0 = BtdB01 * GgGt01_0;
vec4 m02_0 = BtdB02 * GgGt02_0;
vec4 m03_0 = BtdB03 * GgGt03_0;
vec4 m10_0 = BtdB10 * GgGt10_0;
vec4 m11_0 = BtdB11 * GgGt11_0;
vec4 m12_0 = BtdB12 * GgGt12_0;
vec4 m13_0 = BtdB13 * GgGt13_0;
vec4 m20_0 = BtdB20 * GgGt20_0;
vec4 m21_0 = BtdB21 * GgGt21_0;
vec4 m22_0 = BtdB22 * GgGt22_0;
vec4 m23_0 = BtdB23 * GgGt23_0;
vec4 m30_0 = BtdB30 * GgGt30_0;
vec4 m31_0 = BtdB31 * GgGt31_0;
vec4 m32_0 = BtdB32 * GgGt32_0;
vec4 m33_0 = BtdB33 * GgGt33_0;
// Hoist the partial sums shared between the four outputs.
vec4 m01_11_21_0 = m01_0 + m11_0 + m21_0;
vec4 m02_12_22_0 = m02_0 + m12_0 + m22_0;
vec4 m11_21_31_0 = m11_0 - m21_0 - m31_0;
vec4 m12_22_32_0 = m12_0 - m22_0 - m32_0;
// Combine into the 2x2 outputs and accumulate into element i of each result.
// NOTE(review): sum() is defined outside this excerpt; presumably it adds the
// four components of the vec4 (i.e. the packed input channels) — confirm.
result00[i] += sum(m00_0 + m10_0 + m20_0 + m01_11_21_0 + m02_12_22_0);
result01[i] += sum(m01_11_21_0 - m02_12_22_0 - m03_0 - m13_0 - m23_0);
result10[i] += sum(m10_0 - m20_0 - m30_0 + m11_21_31_0 + m12_22_32_0);
result11[i] += sum(m11_21_31_0 - m12_22_32_0 - m13_0 + m23_0 + m33_0);
}
}
測試結果:
輸出結果爲: