PyTorch中部分方法介紹

1.torchvision.transforms.Normalize(mean, std)

mean參數:給定序列的均值,形式爲(R,G,B)(並不一定是三維)

std參數:給定序列的平均標準偏差(標準差),形式爲(R,G,B)(並不一定是三維)

功能:將給定的Tensor正則化,也就是按照如下公式計算:

\frac{Tensor-mean}{std}


2.torch.utils.data.DataLoader()


3.torch.nn.ReLu(flag)

   torch.nn.Tanh()


4.a.data.cpu()和a.cpu().data一樣,都是將數據取出放到cpu,準備放到cpu計算


5.torch.clamp(input, min, max, out=None),將input的tensor限定在min~max之間,如果小於這個範圍,取min;大於這個範圍,去max,亦可input.clamp(min,max)這樣使用


6.nn.MSELoss(reduce),使用是傳入兩個量x,y,計算(x-y)^2,如果reduce=False,直接返回向量形式的 loss;如果reduce=True,返回標量

7.nn.MaxPool2d(2)中,默認步長是2

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章