torch.meshgrid()函數解析
torch.meshgrid()的功能是生成網格,可以用於生成座標。函數輸入兩個數據類型相同的一維張量,兩個輸出張量的行數爲第一個輸入張量的元素個數,列數爲第二個輸入張量的元素個數,當兩個輸入張量數據類型不同或維度不是一維時會報錯。
其中第一個輸出張量填充第一個輸入張量中的元素,各行元素相同;第二個輸出張量填充第二個輸入張量中的元素各列元素相同。
# 【1】
import torch
a = torch.tensor([1, 2, 3, 4])
print(a)
b = torch.tensor([4, 5, 6])
print(b)
x, y = torch.meshgrid(a, b)
print(x)
print(y)
結果顯示:
tensor([1, 2, 3, 4])
tensor([4, 5, 6])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]])
tensor([[4, 5, 6],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
# 【2】
import torch
a = torch.tensor([1, 2, 3, 4, 5, 6])
print(a)
b = torch.tensor([7, 8, 9, 10])
print(b)
x, y = torch.meshgrid(a, b)
print(x)
print(y)
結果顯示:
tensor([1, 2, 3, 4, 5, 6])
tensor([ 7, 8, 9, 10])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4],
[5, 5, 5, 5],
[6, 6, 6, 6]])
tensor([[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10]])
在YOLO V3將圖像劃分爲單元網格的部分就用到了torch.meshgrid()函數,如下所示。
yv, xv = torch.meshgrid([torch.arange(self.ny, device=device),torch.arange(self.nx, device=device)])
self.grid = torch.stack((xv, yv), 2).view((1, 1, self.ny, self.nx, 2)).float()