解決softmax後列和不爲1的bug記錄 :問題原因爲 s爲1維的,來除torch.exp(x)(64x10)時候,維數不對應,需要將s也要轉換爲2維的即維數爲(64x1),纔可以廣播按行對應相除

def softmax(x):
    ## TODO: Implement the softmax function here
    #print("torch.exp(x)=",torch.exp(x))
    s = torch.sum(torch.exp(x),dim=1)
    print("s.size=",s.size())
    print("s.view(-1,1)=",s.view(-1,1).size())

   #有問題的return (torch.exp(x)/s).view(-1,1)
    return torch.exp(x)/(s.view(-1,1))
    """
def softmax(x):
    print("x.shape=",x.shape)
    return torch.exp(x)/torch.sum(torch.exp(x), dim=1).view(-1, 1)
"""
# Here, out should be the output of the network in the previous excercise with shape (64,10)
#print("out=",out)
print("out.shape=",out.shape)
probabilities = softmax(out)
print("probabilities=",probabilities)
# Does it have the right shape? Should be (64, 10)
#print(probabilities.shape)
# Does it sum to 1?
#print(probabilities)
print(probabilities.sum(dim=1))

#result

out.shape= torch.Size([64, 10])
s.size= torch.Size([64])
s.view(-1,1)= torch.Size([64, 1])
probabilities= tensor([[ 3.0988e-16,  1.0000e+00,  1.8007e-11,  1.4650e-11,  8.8916e-09,
          5.3950e-17,  4.7416e-09,  8.0687e-19,  3.1054e-09,  2.2082e-10],
        [ 7.2884e-18,  9.9896e-01,  4.5583e-08,  1.2075e-05,  2.1988e-09,
          1.5377e-13,  1.0279e-03,  2.0932e-10,  2.9357e-07,  4.3231e-11],
        [ 5.2138e-16,  4.6368e-02,  8.9821e-14,  3.0514e-09,  3.1598e-09,
          1.3639e-17,  9.5363e-01,  6.5663e-15,  1.2519e-09,  1.5230e-12],
        [ 1.5445e-11,  9.9793e-01,  2.8143e-08,  1.2875e-04,  1.2082e-04,
          9.7656e-14,  3.9710e-05,  7.6168e-09,  1.7803e-03,  1.4417e-08],
        [ 1.0958e-13,  9.6759e-01,  2.2564e-12,  1.9013e-11,  1.9647e-11,
          1.8859e-12,  3.2378e-02,  5.4169e-08,  2.7937e-05,  3.2464e-13],
        [ 8.1973e-11,  6.9379e-01,  5.9974e-09,  8.5656e-06,  5.1688e-03,
          2.4149e-17,  2.3638e-02,  4.3271e-12,  4.1039e-05,  2.7735e-01],
        [ 3.7563e-10,  9.9727e-01,  2.4940e-06,  8.8124e-04,  3.5414e-05,
          1.8617e-11,  2.0491e-04,  1.7631e-11,  1.6052e-03,  1.5836e-10],
        [ 9.4069e-17,  9.9962e-01,  5.0893e-13,  2.1499e-11,  2.3957e-06,
          3.5678e-20,  3.8035e-04,  1.4381e-14,  7.2088e-07,  2.4388e-08],
        [ 3.5989e-12,  9.9864e-01,  1.6160e-06,  2.3764e-04,  1.2289e-06,
          6.1952e-12,  2.6893e-07,  1.1190e-06,  1.1140e-03,  1.5123e-07],
        [ 7.6218e-18,  9.5789e-01,  2.4285e-11,  1.9727e-09,  4.4734e-13,
          1.1504e-13,  4.2112e-02,  5.4719e-12,  1.1318e-07,  4.2207e-14],
        [ 6.9200e-18,  9.9999e-01,  3.3765e-11,  3.0676e-09,  6.6521e-10,
          1.9209e-17,  9.4535e-06,  1.0995e-13,  1.4980e-07,  1.1143e-08],
        [ 1.1968e-10,  9.9062e-01,  8.3010e-10,  1.9934e-03,  2.6987e-08,
          6.7082e-12,  5.9127e-05,  5.9065e-06,  7.3257e-03,  3.9619e-08],
        [ 2.6697e-20,  1.0000e+00,  6.1491e-14,  5.5949e-13,  3.8020e-12,
          4.6593e-19,  1.3082e-07,  1.8616e-14,  8.4002e-11,  8.4461e-15],
        [ 2.4753e-17,  1.0000e+00,  5.8076e-12,  1.6527e-07,  1.4076e-10,
          3.4725e-13,  8.1565e-07,  2.1852e-11,  2.3215e-07,  4.4056e-12],
        [ 4.7344e-17,  1.0000e+00,  3.3414e-11,  3.1781e-10,  8.6068e-12,
          2.3071e-15,  2.8367e-10,  2.3826e-12,  1.7079e-07,  3.1149e-11],
        [ 4.4352e-11,  9.9746e-01,  2.8773e-09,  7.2385e-11,  3.6884e-08,
          4.4588e-15,  9.4779e-08,  1.2648e-13,  2.5424e-03,  7.2419e-13],
        [ 1.1315e-14,  4.7640e-01,  4.4058e-08,  5.7084e-13,  6.0529e-09,
          7.8344e-15,  5.2360e-01,  6.2813e-09,  5.5826e-14,  3.6299e-14],
        [ 7.4028e-14,  9.7872e-01,  3.0862e-10,  1.8337e-07,  1.6125e-06,
          4.8351e-14,  2.1259e-02,  7.2829e-11,  1.4334e-05,  1.3491e-09],
        [ 1.3201e-13,  9.9998e-01,  1.2696e-10,  1.9295e-08,  4.1318e-11,
          7.8015e-12,  1.6925e-05,  3.8119e-15,  6.8226e-07,  1.3916e-11],
        [ 1.4418e-18,  8.1829e-01,  1.2666e-11,  5.9709e-13,  1.2443e-06,
          2.1665e-16,  1.8171e-01,  6.2544e-13,  3.6078e-13,  9.4016e-15],
        [ 2.1567e-10,  9.6855e-01,  3.3647e-10,  5.3065e-07,  2.0691e-06,
          2.1863e-20,  2.8312e-02,  1.0761e-07,  2.8217e-03,  3.0837e-04],
        [ 7.6143e-12,  9.8924e-01,  2.0398e-06,  8.3411e-06,  1.1166e-03,
          6.4356e-08,  9.6115e-03,  3.7792e-08,  1.4629e-05,  2.7286e-06],
        [ 1.0694e-19,  1.0000e+00,  1.6696e-11,  9.9358e-11,  3.0807e-15,
          1.4382e-16,  4.7081e-09,  6.2134e-14,  8.9735e-11,  4.8432e-13],
        [ 1.6541e-16,  9.0249e-01,  9.2298e-12,  8.2350e-06,  1.0797e-10,
          2.7196e-13,  9.7506e-02,  2.3810e-09,  1.0474e-07,  1.2539e-13],
        [ 3.7895e-15,  9.9998e-01,  3.1908e-08,  4.6845e-15,  6.6953e-07,
          1.8835e-17,  2.2061e-05,  3.7345e-11,  5.3391e-10,  9.8172e-13],
        [ 2.2465e-13,  9.9520e-01,  3.3508e-10,  2.0350e-07,  1.6494e-05,
          1.1619e-11,  1.7151e-05,  2.1420e-09,  4.7628e-03,  4.6928e-13],
        [ 6.8133e-13,  9.9846e-01,  1.0881e-06,  6.1678e-10,  7.9566e-07,
          1.2640e-13,  1.0096e-05,  2.6806e-10,  5.2880e-09,  1.5275e-03],
        [ 4.0346e-06,  9.9922e-01,  5.6760e-14,  5.1681e-04,  2.7113e-10,
          1.9770e-15,  3.4958e-08,  4.6642e-09,  2.5622e-04,  3.5541e-06],
        [ 2.3558e-15,  9.9996e-01,  1.5141e-10,  9.6498e-08,  3.9098e-05,
          1.0603e-15,  1.3403e-08,  4.2760e-15,  1.1000e-09,  4.5716e-11],
        [ 7.2170e-17,  9.9850e-01,  1.0450e-11,  4.0275e-07,  1.7123e-09,
          5.7820e-12,  1.5023e-03,  3.4995e-12,  9.2998e-09,  1.2087e-11],
        [ 7.0592e-17,  9.9999e-01,  3.2517e-09,  2.0088e-13,  2.8803e-12,
          7.7315e-18,  3.9520e-06,  3.0273e-11,  1.4446e-06,  7.1642e-14],
        [ 1.9605e-14,  9.9196e-01,  2.2681e-07,  1.8784e-05,  4.0640e-07,
          2.9341e-11,  1.3631e-03,  1.0106e-06,  6.6461e-03,  1.3264e-05],
        [ 2.0659e-11,  9.9777e-01,  1.1538e-11,  9.0129e-06,  3.1849e-05,
          2.0640e-11,  2.0553e-03,  1.6242e-07,  1.3015e-04,  4.3597e-10],
        [ 4.0858e-13,  9.9590e-01,  7.0598e-13,  4.0908e-03,  1.4616e-06,
          2.7791e-13,  1.0953e-05,  2.4548e-09,  2.6244e-07,  8.3998e-11],
        [ 2.4086e-11,  9.9991e-01,  1.3949e-10,  5.7709e-05,  3.4177e-07,
          9.4874e-16,  7.1716e-06,  7.3639e-13,  2.9284e-05,  6.0737e-13],
        [ 6.1061e-14,  9.3933e-01,  2.6500e-10,  3.2538e-07,  4.3397e-07,
          9.8440e-16,  4.2918e-05,  6.4921e-13,  6.0628e-02,  8.5904e-13],
        [ 4.5025e-17,  1.0000e+00,  1.1380e-15,  1.2371e-17,  3.1707e-15,
          1.0971e-22,  5.5156e-11,  3.0394e-16,  2.0227e-13,  2.7100e-13],
        [ 8.6031e-14,  9.9975e-01,  1.5011e-06,  5.5593e-05,  2.7639e-08,
          4.2199e-15,  3.8121e-07,  1.6479e-04,  2.6434e-05,  7.3974e-10],
        [ 7.8388e-10,  9.9906e-01,  3.6027e-09,  1.8897e-04,  5.1720e-10,
          5.0279e-11,  7.4521e-04,  1.6138e-06,  5.0832e-12,  1.4910e-07],
        [ 8.1989e-15,  9.9968e-01,  3.4793e-14,  2.9727e-04,  9.1857e-11,
          5.3726e-14,  2.2063e-05,  3.3956e-08,  5.2318e-08,  1.4547e-08],
        [ 2.9552e-11,  9.9901e-01,  1.8481e-04,  5.5410e-05,  2.9475e-07,
          3.4597e-12,  7.4382e-04,  4.2870e-09,  7.1338e-06,  1.2287e-08],
        [ 1.4819e-11,  2.0494e-01,  2.6659e-12,  4.1191e-09,  1.4334e-06,
          1.6648e-16,  7.9506e-01,  4.5270e-11,  1.4956e-07,  1.4043e-08],
        [ 4.1190e-11,  9.7229e-01,  3.6178e-04,  1.7179e-02,  8.1722e-08,
          5.6448e-15,  2.3135e-05,  1.6037e-06,  1.0142e-02,  1.6100e-08],
        [ 7.3725e-11,  7.0134e-01,  1.7235e-07,  1.5593e-04,  2.1137e-07,
          4.9391e-15,  2.9849e-01,  1.3844e-06,  8.7998e-06,  3.4835e-06],
        [ 1.4538e-16,  9.9962e-01,  1.0699e-14,  2.9652e-06,  3.3357e-08,
          5.2820e-13,  3.7365e-04,  1.6186e-13,  1.0933e-07,  2.0106e-10],
        [ 2.5675e-18,  1.0000e+00,  6.6527e-11,  3.2654e-10,  6.1968e-11,
          1.0783e-16,  4.2726e-06,  5.4559e-14,  5.1654e-08,  4.5689e-09],
        [ 3.8212e-13,  9.9232e-01,  2.4962e-07,  1.1517e-08,  1.8197e-05,
          5.6938e-10,  6.1286e-08,  8.9154e-11,  7.6590e-03,  8.1098e-07],
        [ 1.0655e-15,  9.9801e-01,  8.3258e-07,  2.4878e-09,  7.7644e-07,
          1.6196e-12,  1.9839e-03,  3.1300e-10,  7.4616e-09,  5.4257e-14],
        [ 4.8779e-16,  9.9999e-01,  8.2279e-12,  7.2976e-11,  4.4960e-13,
          6.2256e-13,  1.3462e-05,  8.8224e-12,  3.6317e-11,  1.3092e-11],
        [ 7.8841e-19,  9.9997e-01,  5.5566e-11,  4.0056e-10,  9.2204e-11,
          1.1939e-11,  2.6435e-05,  2.1659e-09,  9.7906e-08,  2.7225e-17],
        [ 2.9075e-11,  9.9067e-01,  4.3553e-09,  3.9526e-09,  5.5092e-08,
          1.6082e-14,  9.3325e-03,  1.3082e-08,  3.2009e-07,  1.7811e-09],
        [ 6.3138e-09,  9.9999e-01,  4.1571e-07,  2.6991e-07,  3.8994e-06,
          4.2260e-14,  3.8578e-06,  7.5499e-12,  1.2526e-07,  1.3982e-10],
        [ 5.6722e-16,  8.2703e-01,  4.3730e-14,  2.1261e-07,  1.2086e-10,
          9.9402e-16,  1.7297e-01,  3.5612e-12,  2.2576e-08,  4.4894e-11],
        [ 6.6159e-14,  1.0000e+00,  2.1189e-07,  2.3230e-06,  7.5754e-09,
          8.0700e-11,  1.8903e-06,  4.0510e-10,  3.2011e-09,  6.0312e-08],
        [ 5.6496e-13,  9.9997e-01,  7.0953e-10,  7.8982e-06,  7.9838e-07,
          5.1995e-17,  1.3731e-05,  1.6974e-13,  3.9789e-06,  3.2110e-07],
        [ 3.3208e-07,  9.5613e-01,  2.8434e-09,  4.1989e-07,  7.8602e-07,
          6.8962e-14,  2.7200e-04,  1.1510e-09,  4.3544e-02,  5.3730e-05],
        [ 1.8417e-13,  9.9728e-01,  3.9895e-07,  3.4604e-05,  6.7489e-09,
          2.1293e-11,  2.1081e-03,  8.0648e-09,  3.0386e-07,  5.8137e-04],
        [ 3.7823e-12,  1.0000e+00,  3.2145e-10,  2.3663e-10,  4.1574e-07,
          1.8841e-12,  5.5345e-07,  6.2556e-13,  1.5828e-10,  3.6763e-14],
        [ 2.4093e-17,  1.0000e+00,  7.4625e-15,  2.2570e-13,  5.3833e-09,
          7.0629e-20,  3.6674e-10,  2.6192e-16,  1.0119e-09,  1.5468e-21],
        [ 8.6024e-15,  2.1488e-01,  3.2251e-10,  1.0990e-06,  1.4385e-06,
          6.2122e-13,  7.8512e-01,  1.5255e-08,  1.1232e-06,  1.2307e-08],
        [ 3.4641e-09,  9.9518e-01,  5.8583e-10,  1.6316e-04,  1.2406e-03,
          4.2527e-15,  3.3174e-03,  1.1678e-06,  9.8092e-05,  8.7550e-07],
        [ 9.6422e-15,  1.0000e+00,  2.1577e-12,  4.3397e-13,  2.8023e-09,
          1.4402e-18,  2.1583e-07,  1.4629e-14,  1.1131e-10,  5.5922e-12],
        [ 1.4889e-13,  9.9998e-01,  1.3683e-08,  2.6725e-11,  7.6360e-12,
          2.5601e-14,  1.5277e-05,  1.4099e-10,  2.0855e-09,  1.5405e-15],
        [ 2.0768e-16,  9.9753e-01,  5.2618e-07,  2.0046e-06,  6.0347e-07,
          2.9689e-13,  2.4702e-03,  3.1447e-09,  5.1434e-09,  4.9742e-10]])
tensor([ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000])

#問題原因爲 s爲1維的,來除torch.exp(x)(64x10)時候,維數不對應,需要將s也要轉換爲2維的即維數爲(64x1),纔可以廣播按行對應相除

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章