Dynamic Graph CNN for Learning on Point Clouds 代碼註解

論文解讀參閱:https://blog.csdn.net/weixin_39373480/article/details/88724518 

                       https://blog.csdn.net/qq_39426225/article/details/101980690

                       https://blog.csdn.net/hongbin_xu/article/details/85258278

下面我們來看論文主要核心部分的代碼:

1、爲點雲變換準備3X3變換矩陣的代碼,knn部分代碼,鄰接矩陣代碼,如下:

def input_transform_net(edge_feature, is_training, bn_decay=None, K=3, is_dist=False):
  """ Input (XYZ) Transform Net, input is BxNx3 gray image
    Return:
      Transformation matrix of size 3xK """
  batch_size = edge_feature.get_shape()[0].value
  num_point = edge_feature.get_shape()[1].value

  # input_image = tf.expand_dims(point_cloud, -1)

  net = tf_util.conv2d(edge_feature, 64, [1, 1],
             padding='VALID', stride=[1, 1],
             bn=True, is_training=is_training,
             scope='tconv1', bn_decay=bn_decay, is_dist=is_dist)
  net = tf_util.conv2d(net, 128, [1, 1],
             padding='VALID', stride=[1, 1],
             bn=True, is_training=is_training,
             scope='tconv2', bn_decay=bn_decay, is_dist=is_dist)
  
  net = tf.reduce_max(net, axis=-2, keep_dims=True)
  
  net = tf_util.conv2d(net, 1024, [1, 1],
             padding='VALID', stride=[1, 1],
             bn=True, is_training=is_training,
             scope='tconv3', bn_decay=bn_decay, is_dist=is_dist)
  net = tf_util.max_pool2d(net, [num_point, 1],
               padding='VALID', scope='tmaxpool')

  net = tf.reshape(net, [batch_size, -1])
  net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                  scope='tfc1', bn_decay=bn_decay, is_dist=is_dist)
  net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                  scope='tfc2', bn_decay=bn_decay, is_dist=is_dist)

  with tf.variable_scope('transform_XYZ') as sc:
    # assert(K==3)
    with tf.device('/cpu:0'):
      weights = tf.get_variable('weights', [256, K*K],
                    initializer=tf.constant_initializer(0.0),
                    dtype=tf.float32)
      biases = tf.get_variable('biases', [K*K],
                   initializer=tf.constant_initializer(0.0),
                   dtype=tf.float32)
    biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)
    transform = tf.matmul(net, weights)
    transform = tf.nn.bias_add(transform, biases)

  transform = tf.reshape(transform, [batch_size, K, K])
  return transform

 得到k鄰居(包括自己)

def knn(adj_matrix, k=20):
  """Get KNN based on the pairwise distance.
  Args:
    pairwise distance: (batch_size, num_points, num_points)
    k: int

  Returns:
    nearest neighbors: (batch_size, num_points, k)
  """

  # 鄰接矩陣取負值,如果原始點雲中兩個點距離越遠,這裏值越小
  # 原始鄰接矩陣中自己指向自己時,距離爲0,所以取負值,然後選k個最大的值,自己會被選出來
  # 選出來的結果是按從大到小排序的,索引對應值從大到小排序
  neg_adj = -adj_matrix  
  value, nn_idx = tf.nn.top_k(neg_adj, k=k)
  return nn_idx

tf.nn.top_k(input, k, name=None)

這個函數的作用是返回 input 中每行最大的 k 個數,並且返回它們所在位置的索引。

# 得到鄰接矩陣
def pairwise_distance(point_cloud):
  """Compute pairwise distance of a point cloud.

  Args:
    point_cloud: tensor (batch_size, num_points, num_dims)

  Returns:
    pairwise distance: (batch_size, num_points, num_points)
  """
  og_batch_size = point_cloud.get_shape().as_list()[0]
  point_cloud = tf.squeeze(point_cloud)
  if og_batch_size == 1:
    point_cloud = tf.expand_dims(point_cloud, 0)
  
  # 將點雲的後兩個維度置換  
  point_cloud_transpose = tf.transpose(point_cloud, perm=[0, 2, 1])
   
  # 類似二維矩陣A和A的轉置相乘
  point_cloud_inner = tf.matmul(point_cloud, point_cloud_transpose) 

  # 這裏爲什麼乘以-2?
  point_cloud_inner = -2*point_cloud_inner

  # 點雲自身的每個特徵平方求和,這個和等於轉置乘以本身的對角線部分
  point_cloud_square = tf.reduce_sum(tf.square(point_cloud), axis=-1, keep_dims=True)

  # 轉置的意義?
  point_cloud_square_tranpose = tf.transpose(point_cloud_square, perm=[0, 2, 1])

  # 這一步操作其實就是(x-y)^2,其中x和y代表向量
  return point_cloud_square + point_cloud_inner + point_cloud_square_tranpose

 2、邊特徵提取的原理部分

 3、邊特徵提取的代碼部分

def get_edge_feature(point_cloud, nn_idx, k=20):
  """Construct edge feature for each point
  Args:
    point_cloud: (batch_size, num_points, 1, num_dims)
    nn_idx: (batch_size, num_points, k)
    k: int

  Returns:
    edge features: (batch_size, num_points, k, num_dims)
  """
  # 得到點雲的 batch_size, num_points, 1, num_dims,放入一個列表[]中,然後*[0]取出第一個元素

  # 刪除point_cloud爲1的維度索引,變爲 (batch_size, num_points, num_dims)
  og_batch_size = point_cloud.get_shape().as_list()[0]  
  point_cloud = tf.squeeze(point_cloud)                 
  if og_batch_size == 1:
    point_cloud = tf.expand_dims(point_cloud, 0)

  # 整個點雲複製給 point_cloud_central (batch_size, num_points, num_dims)
  point_cloud_central = point_cloud                     

  # 得到點雲的 batch_size, num_points, 1, num_dims,放入一個元組()中
  point_cloud_shape = point_cloud.get_shape()           
  batch_size = point_cloud_shape[0].value
  num_points = point_cloud_shape[1].value
  num_dims = point_cloud_shape[2].value

  idx_ = tf.range(batch_size) * num_points
  # batch_size = 32 --> idx_.shape = (batch_size, 1, 1)
  idx_ = tf.reshape(idx_, [batch_size, 1, 1])          

  # 點雲變爲batch_size X num_point行,num_dims列的矩陣
  # 第一個 num_points 行代表第一個點雲,第二個 num_points 行代表第二個點雲
  # ......第batch_size個 num_points 行代表第batch_size個點雲
  point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims]) 

  # 將point_cloud變成point_cloud_flat後(batch_size*num_points行),每隔num_points行取出k行,
  # 要取的行號根據爲nn_idx裏面的索引元素(k個)+batch_size*num_points
  # 所以每個點雲(batch_size個)都會得到自己的k個鄰居
  # 最後得到 point_cloud_neighbors:(batch_size, num_points, k, num_dims)  
  # nn_idx.shape= (batch_size, num_points, k) k爲點的維度    
  point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_)  
  
  # 在維度倒數第二個位置增加一個維度 (batch_size, num_points,1, num_dims)
  point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)
  
  # point_cloud_central 規模變爲 (batch_size, num_points,k, num_dims)
  point_cloud_central = tf.tile(point_cloud_central, [1, 1, k, 1])  

  # 把兩個點雲按最後一個維度連接起來
  edge_feature = tf.concat([point_cloud_central, point_cloud_neighbors-point_cloud_central], axis=-1)
  return edge_feature

我們來圖解看一下,point_cloud_neighbors-point_cloud_central 到底幹了什麼?

# nn_idx.shape= (batch_size, num_points, k) k爲點的維度
point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_)  

# 在維度倒數第二個位置增加一個維度 (batch_size, num_points,1, num_dims)
point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)

# point_cloud_central 規模變爲 (batch_size, num_points,k, num_dims)
point_cloud_central = tf.tile(point_cloud_central, [1, 1, k, 1])

提取到的 edge_feature 作爲 tf_util.py 文件中函數conv2d的輸入   個人理解方程(8)中的θ m和ϕm體現在這段代碼裏

def conv2d(inputs,
           num_output_channels,
           kernel_size,
           scope,
           stride=[1, 1],
           padding='SAME',
           use_xavier=True,
           stddev=1e-3,
           weight_decay=0.0,
           activation_fn=tf.nn.relu,
           bn=False,
           bn_decay=None,
           is_training=None,
           is_dist=False):
  """ 2D convolution with non-linear operation.

  Args:
    inputs: 4-D tensor variable BxHxWxC
    num_output_channels: int
    kernel_size: a list of 2 ints
    scope: string
    stride: a list of 2 ints
    padding: 'SAME' or 'VALID'
    use_xavier: bool, use xavier_initializer if true
    stddev: float, stddev for truncated_normal init
    weight_decay: float
    activation_fn: function
    bn: bool, whether to use batch norm
    bn_decay: float or float tensor variable in [0,1]
    is_training: bool Tensor variable

  Returns:
    Variable tensor
  """
    with tf.variable_scope(scope) as sc:
      kernel_h, kernel_w = kernel_size
      num_in_channels = inputs.get_shape()[-1].value
      kernel_shape = [kernel_h, kernel_w,
                      num_in_channels, num_output_channels]
      kernel = _variable_with_weight_decay('weights',
                                           shape=kernel_shape,
                                           use_xavier=use_xavier,
                                           stddev=stddev,
                                           wd=weight_decay)
      stride_h, stride_w = stride
      outputs = tf.nn.conv2d(inputs, kernel,
                             [1, stride_h, stride_w, 1],
                             padding=padding)
      biases = _variable_on_cpu('biases', [num_output_channels],
                                tf.constant_initializer(0.0))
      outputs = tf.nn.bias_add(outputs, biases)

      if bn:
        outputs = batch_norm_for_conv2d(outputs, is_training,
                                        bn_decay=bn_decay, scope='bn', is_dist=is_dist)

      if activation_fn is not None:
        outputs = activation_fn(outputs)
      return outputs

對於get_edge_feature提取的邊特徵,作者貌似沒有按照(8)ReLU(θm · (xj − xi) + ϕm · xi),那樣進行"+",而代碼裏變成contact操作了,然後進入conv2D進行卷積,這裏沒想明白爲什麼?理論上是一樣的麼?

4、分類整體架構部分的代碼:

def get_model(point_cloud, is_training, bn_decay=None):
  """ Classification PointNet, input is BxNx3, output Bx40 """
  batch_size = point_cloud.get_shape()[0].value
  num_point = point_cloud.get_shape()[1].value
  end_points = {}
  k = 20
  
  # 對最原始的點雲提取鄰接矩陣
  adj_matrix = tf_util.pairwise_distance(point_cloud)

  # knn操作,得到每個點的k近鄰
  nn_idx = tf_util.knn(adj_matrix, k=k)

  # 先提取一個原始點雲的邊特徵,用於估計變換矩陣
  edge_feature = tf_util.get_edge_feature(point_cloud, nn_idx=nn_idx, k=k)

  # 利用邊特徵估計變換矩陣
  with tf.variable_scope('transform_net1') as sc:
    transform = input_transform_net(edge_feature, is_training, bn_decay, K=3)

  # 變換點雲
  point_cloud_transformed = tf.matmul(point_cloud, transform)

  # 變換後點雲的鄰接矩陣
  adj_matrix = tf_util.pairwise_distance(point_cloud_transformed)

  # knn操作,得到每個點的k近鄰
  nn_idx = tf_util.knn(adj_matrix, k=k)

  # 得到變換後點雲的邊的信息 xi||(xj − xi) 對所有的 i
  edge_feature = tf_util.get_edge_feature(point_cloud_transformed, nn_idx=nn_idx, k=k)

  # 通過2D卷積得到邊的特徵
  net = tf_util.conv2d(edge_feature, 64, [1,1],
                       padding='VALID', stride=[1,1],
                       bn=True, is_training=is_training,
                       scope='dgcnn1', bn_decay=bn_decay)
  
  #最大池化操作
  net = tf.reduce_max(net, axis=-2, keep_dims=True)
  net1 = net

  adj_matrix = tf_util.pairwise_distance(net)
  nn_idx = tf_util.knn(adj_matrix, k=k)
  edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)

  net = tf_util.conv2d(edge_feature, 64, [1,1],
                       padding='VALID', stride=[1,1],
                       bn=True, is_training=is_training,
                       scope='dgcnn2', bn_decay=bn_decay)
  net = tf.reduce_max(net, axis=-2, keep_dims=True)
  net2 = net
 
  adj_matrix = tf_util.pairwise_distance(net)
  nn_idx = tf_util.knn(adj_matrix, k=k)
  edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)  

  net = tf_util.conv2d(edge_feature, 64, [1,1],
                       padding='VALID', stride=[1,1],
                       bn=True, is_training=is_training,
                       scope='dgcnn3', bn_decay=bn_decay)
  net = tf.reduce_max(net, axis=-2, keep_dims=True)
  net3 = net

  adj_matrix = tf_util.pairwise_distance(net)
  nn_idx = tf_util.knn(adj_matrix, k=k)
  edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)  
  
  net = tf_util.conv2d(edge_feature, 128, [1,1],
                       padding='VALID', stride=[1,1],
                       bn=True, is_training=is_training,
                       scope='dgcnn4', bn_decay=bn_decay)
  net = tf.reduce_max(net, axis=-2, keep_dims=True)
  net4 = net
  
  # 每個邊特徵提取後的特徵連接起來,然後進行 MLP 處理
  net_out_concat = tf.concat([net1, net2, net3, net4], axis=-1)
  net_out = tf_util.conv2d(net_out_concat, 1024, [1, 1],
                       padding='VALID', stride=[1, 1],
                       bn=True, is_training=is_training,
                       scope='agg', bn_decay=bn_decay)
 
  net_out = tf.reduce_max(net_out, axis=1, keep_dims=True)

  # MLP on global point cloud vector
  net_out = tf.reshape(net_out, [batch_size, -1])
  net_out = tf_util.fully_connected(net_out, 512, bn=True, is_training=is_training,
                                scope='fc1', bn_decay=bn_decay)
  net_out = tf_util.dropout(net_out, keep_prob=0.5, is_training=is_training,
                         scope='dp1')
  net_out = tf_util.fully_connected(net_out, 256, bn=True, is_training=is_training,
                                scope='fc2', bn_decay=bn_decay)
  net_out = tf_util.dropout(net_out, keep_prob=0.5, is_training=is_training,
                        scope='dp2')
  net_out = tf_util.fully_connected(net_out, 40, activation_fn=None, scope='fc3')

  return net_out, end_points

 

下面爲tensorflow的一些相關的操作 

tensorflow筆記:tf.reshape的詳細講解

原文鏈接:https://blog.csdn.net/abc13526222160/article/details/85867777

函數原型:目的是爲了功能改變張量(tensor)的形狀。

tf.reshape(
    tensor,
    shape,
    name=None
)

tensor形參傳入一個tensor。shape傳入一個向量,代表新tensor的維度數和每個維度的長度。如果傳入[3,4,5],就會返回一個內含各分量數值和原傳入張量一模一樣的3*4*5尺寸的張量。

如果shape傳入的向量某一個分量設置爲-1,比如[-1,4,5],那麼這個分量代表的維度尺寸會被自動計算出來。 

用法一,一個尺寸爲1*9的張量轉化爲3*3的張量:

#coding:utf-8
import tensorflow as tf
t=[1,2,3,4,5,6,7,8,9]
with tf.Session() as sess:
    print (sess.run(tf.reshape(t,[3,3])))

輸出結果:
[[1 2 3]
 [4 5 6]
 [7 8 9]]

用法二,一個尺寸爲3 * 2 * 3的張量,轉換爲第二個維度尺寸爲9的張量,即n*9的張量:

#coding:utf-8
import tensorflow as tf
t=[[[1,2],[1,2],[1,2]],[[1,2],[1,2],[1,2]],[[1,2],[1,2],[1,2]]]
with tf.Session() as sess:
    print (sess.run(tf.reshape(t,[-1,9])))

輸出結果:顯然,n被計算爲2。
[[1 2 1 2 1 2 1 2 1]
 [2 1 2 1 2 1 2 1 2]]

用法三,僅含有單個元素的張量轉化爲標量:
t爲張量[7]

#coding:utf-8
import tensorflow as tf
t=[7]
with tf.Session() as sess:
    print (sess.run(tf.reshape(t,[])))

輸出結果:
7

tensorflow筆記:tf.shape()和(tensor)x.get_shape().as_list()

原文鏈接:https://blog.csdn.net/abc13526222160/article/details/85135517

1、Tensorflow中的tf.shape()
先說tf.shape()很顯然這個是獲取張量的大小的,用法無需多說,直接上例子吧!

import tensorflow as tf
import numpy as np
 
a=np.array([[1,2,3],[4,5,6]])
b=[[1,2,3],[4,5,6]]
c=tf.constant([[1,2,3],[4,5,6]])
 
with tf.Session() as sess:
    print(sess.run(tf.shape(a)))
    print(sess.run(tf.shape(b)))
    print(sess.run(tf.shape(c)))

輸出結果:
[2 3]
[2 3]
[2 3]

2、(Tensor)x.get_shape().as_list()
這個簡單說明一下,x.get_shape(),只有tensor(張量)纔可以使用這種方法,返回的是一個元組。

import tensorflow as tf
import numpy as np
 
a=np.array([[1,2,3],[4,5,6]])
b=[[1,2,3],[4,5,6]]
c=tf.constant([[1,2,3],[4,5,6]])
 
print(c.get_shape())                   #這裏返回的是一個元組
print(c.get_shape().as_list())         #返回的元組重新變回一個列表
 
with tf.Session() as sess:
    print(sess.run(tf.shape(a)))
    print(sess.run(tf.shape(b)))
    print(sess.run(tf.shape(c)))

輸出結果:
(2, 3)
[2, 3]
[2 3]
[2 3]
[2 3]
不是張量進行操作的話,會報錯!

下面強調一些注意點:

第一點:tensor.get_shape()返回的是元組不能放到sess.run()裏面,這個裏面只能放operation和tensor;

第二點:tf.shape()返回的是一個tensor。要想知道是多少,必須通過sess.run()

 

tf.range()函數

python中的range()的用法基本一樣,只不過這裏返回的是一個1-D的tensor。有以下兩種形式:

range(limit, delta=1, dtype=None, name='range')
range(start, limit, delta=1, dtype=None, name='range')

該數字序列開始於 start 並且將以 delta 爲增量擴展到不包括 limit 時的最大值結束,類似python的range函數。

#-*-coding:utf-8-*-

import tensorflow as tf

x=tf.range(8.0, 13.0, 2.0)
y=tf.range(10, 15)
z=tf.range(3, 1, -0.5)
w=tf.range(3)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print (sess.run(x))#輸出[  8.  10.  12.]
print (sess.run(y))#輸出[10 11 12 13 14]
print (sess.run(z))#輸出[ 3.   2.5  2.   1.5]
print (sess.run(w))#輸出[0 1 2]

 

直觀的理解tensorflow中的tf.tile()函數

原文鏈接:https://blog.csdn.net/tsyccnh/article/details/82459859

tensorflow中的tile()函數是用來對張量(Tensor)進行擴展的,其特點是對當前張量內的數據進行一定規則的複製。最終的輸出張量維度不變。

函數定義:

tf.tile(
    input,
    multiples,
    name=None
)

input是待擴展的張量,multiples是擴展方法。
假如input是一個3維的張量。那麼mutiples就必須是一個1x3的1維張量。這個張量的三個值依次表示input的第1,第2,第3維數據擴展幾倍。
具體舉一個例子:

import tensorflow as tf

a = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.float32)
a1 = tf.tile(a, [2, 3])
with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(tf.shape(a)))
    print(sess.run(a1))
    print(sess.run(tf.shape(a1)))


[[1. 2.]
 [3. 4.]
 [5. 6.]]
[3 2]
[[1. 2. 1. 2. 1. 2.]
 [3. 4. 3. 4. 3. 4.]
 [5. 6. 5. 6. 5. 6.]
 [1. 2. 1. 2. 1. 2.]
 [3. 4. 3. 4. 3. 4.]
 [5. 6. 5. 6. 5. 6.]]
[6 6]

tf.tile()具體的操作過程如下:
(原始矩陣應爲3*2)


請注意:上面繪圖中第一次擴展後第一維由三個數據變成兩行六個數據,多一行並不是多了一維,數據扔爲順序排列,只是爲了方便繪製而已。

每一維數據的擴展都是將前面的數據進行復制然後直接接在原數據後面。

如果multiples的某一個數據爲1,則表示該維數據保持不變。

就這樣。

 

TensorFlow的tf.concat實例詳細介紹

原文鏈接:https://blog.csdn.net/sinat_29957455/article/details/86100641

tf.concat函數:函數功能比較簡單,主要用於連接兩個數組
參數:

values:需要連接的數組
axis:從哪個維度來連接數組
例子:

一維張量

import tensorflow as tf
a = [1,2,3]
b = [4,5,6]
c = tf.concat([a,b],0)
sess = tf.InteractiveSession()
print(sess.run(c))

輸出: [1 2 3 4 5 6]

注意:axis參數不能超過數組的維度。如果超過數組的維度,如下:

    c = tf.concat([a,b],1)
1
則會報,ValueError: Shape must be at least rank 2 but is rank 1 for 'concat',意思是數組至少是二維,axis才能爲1。

二維張量

    a = [[1,1],[2,2],[3,3]]
    b = [[4,4],[5,5],[6,6]]
    c = tf.concat([a,b],0)
    print(sess.run(c))
    
輸出:
    [[1 1]
     [2 2]
     [3 3]
     [4 4]
     [5 5]
     [6 6]]
    

    c = tf.concat([a,b],1) #等價於tf.concat([a,b],-1)
    print(sess.run(c))
    
輸出:
    [[1 1 4 4]
     [2 2 5 5]
     [3 3 6 6]]

三維張量

    a = [[[1,1],[2,2]],[[3,3],[4,4]]]
    b = [[[5,5]],[[6,6]]]

    c = tf.concat([a,b],1)
    print(sess.run(c))
    """
    [[[1 1]
      [2 2]
      [5 5]]

     [[3 3]
      [4 4]
      [6 6]]]
    """

注意:在使用tf.concat函數連接兩個數組的時候,數組該維度必須是一致的,否則會報錯,如下:

    c = tf.concat([a,b],0)
錯誤提示ValueError: Dimension 0 in both shapes must be equal, but are 2 and 1,意思是a在第1個維度上shape是2,而b在第一個維度上shape是1。
總結:如何來判斷數組是否在該個維度上的shape是相同的呢?其實很簡單,我們根據tf.concat的axis參數來去數組的[],0表示去掉最外面的一層,1去掉兩層,以此類推,下面舉例說明一下。
如:最後一個例子中的c = tf.concat([a,b],1),我們先將a去掉最外面兩層[],變成了[1,1],[2,2]和[3,3],[4,4]],然後再將b去掉最外面兩層[],變成了[5,5]和[6,6],此時再進行concat,可以發現此時的shape是相等的。

tf.gather()用法

原文鏈接:https://blog.csdn.net/Eric_LH/article/details/83794038

原文鏈接:https://blog.csdn.net/LoseInVain/article/details/85339985

tf.gather(
    params,  # 需要被索引的張量
    indices,  # 索引
    validate_indices=None,
    name=None,
    axis=0
)

其作用很簡單,就是根據提供的indices在axis這個軸上對params進行索引,拼接成一個新的張量,其示意圖如下所示:

其中的indices可以是標量張量,向量張量或者是更高階的張量,但是其元素都必須是整數類型,比如int32,int64等,而且注意檢查不要越界了,因爲如果越界了,如果使用的CPU,則會報錯,如果在GPU上進行操作的,那麼相應的輸出值將會被置爲0,而不會報錯,因此認真檢查是否越界。

簡單的說: tf.gather(等待被取元素的張量,索引)

tf.gather根據索引,從輸入張量中依次取元素,構成一個新的張量。

索引的維度可以小於張量的維度。這時,取張量元素時,會把相應的低維當作一個整體取出來。

例如

假設輸入張量 [[1,2,3],[4,5,6],[7,8,9]] 是個二維的

如果只給一個一維索引0. 它就把[1,2,3]整體取出:

如果給兩個一維索引,0和1,它就形成[[1,2,3],[4,5,6]]

發佈了16 篇原創文章 · 獲贊 4 · 訪問量 4496
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章