圖神經網絡框架DGL學習 102——圖、節點、邊及其特徵賦值

101(入門)以後就是開始具體逐項學習圖神經網絡的各個細節。下面介紹:
1.如何構建圖
2.將特徵賦給節點或者邊,及查詢方法
這算是圖神經網絡最基礎最基礎的部分了。

一、如何構建圖

DGL中創建的圖的方法有:
1. 通過(u, v),u和v分別爲起始節點和終止節點的列表,可以是numpy矩陣也可以是tensor
2. scipy中的稀疏矩陣,該稀疏矩陣儲存這圖的鄰接矩陣
3. networkx圖對象轉化
4. 逐步添加節點與邊

先導入所有可能需要的模塊

import networkx as nx
import dgl
import torch as th
import numpy as np
import scipy.sparse as spp

1. 通過起始節點u和終止節點v的列表構建圖

u = th.tensor([0,0,0,0,0]) #起始節點
v = th.tensor([1,2,3,4,5]) #終止節點
star1 = dgl.DGLGraph((u,v))
nx.draw(star1.to_networkx(), with_labels=True) #可視化
# plt.show()

結果如下圖:
在這裏插入圖片描述
如果u、v之一是標量,那麼DGL會自動使用boadscat機制,適應數組的長度

plt.clf()
u = th.tensor(0)
v = th.tensor([1,2,3,4,5])
star2= dgl.DGLGraph((u,v))
nx.draw(star2.to_networkx(), with_labels=True)

結果與上圖類似

2. scipy中的稀疏矩陣,該稀疏矩陣儲存這圖的鄰接矩

稀疏矩陣是圖的鄰接矩陣

u = th.tensor([0,0,0,0,0])
v = th.tensor([1,2,3,4,5])
adj = spp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy()))) #稀疏矩陣
star3 = dgl.DGLGraph(adj)
nx.draw(star3.to_networkx(), with_labels=True)
plt.show()

結果如上圖

3.networkx圖對象轉化

DGL中的圖與networkx中的圖是可以相互轉化的

g_nx = nx.petersen_graph() #networkx內置的彼得森圖
g_dgl = dgl.DGLGraph(g_nx) #生成dgl的圖

import matplotlib.pyplot as plt
plt.subplot(121)
nx.draw(g_nx, with_labels=True)
plt.title('Networkx')
plt.subplot(122)
nx.draw(g_dgl.to_networkx(), with_labels=True)
plt.title('DGL')
plt.show()

結果如下圖,可以看出二者等價。
在這裏插入圖片描述

4. 逐步添加節點與邊

plt.clf()
g = dgl.DGLGraph()
g.add_nodes(10)
for i in range(0,5):
    g.add_edge(i, 0) #逐個添加邊
src = list(range(5,8))
dst = th.tensor([0]*3)
g.add_edges(src, dst) #根據邊起點,終點,一次性添加多條邊
src = th.tensor([8,9])
dst = th.tensor([0,0])
g.add_edges(src, dst)
nx.draw(g.to_networkx(), with_labels=True)
plt.show()

結果如下:
在這裏插入圖片描述

二、將特徵賦給節點或者邊,及查詢方法

先構建一個圖

g = dgl.DGLGraph()
g.add_nodes(10)
for i in range(0,10):
    g.add_edge(i, 0) #逐個添加邊

將特徵分配到節點或者邊,使用字典的格式存儲 {名字:特徵張量},稱之爲fields。
ndata是訪問圖中節點數據的語法,類似於數組,通過切片訪問
edata是訪問圖中邊數據的語法,類似於數組,通過切片訪問

1. 特徵賦給節點或者邊,及查詢方法

x = th.randn(10, 3)
print(x)
g.ndata['x'] = x
print('第0個節點:', g.ndata['x'][0])
print('第1, 2個節點:', g.ndata['x'][[1, 2]])
print('第1, 2個節點:', g.ndata['x'][th.tensor([1, 2])])

print('圖中邊的數量:', len(g.edata)) #沒有賦值之前是空的
g.edata['w'] = th.randn(10, 2)
print(g.edata['w'])
print('第0個邊:', g.ndata['x'][0])
print('第1, 2個邊:', g.ndata['x'][[1, 2]])
print('第1, 2個邊:', g.ndata['x'][th.tensor([1, 2])])
#通過邊的起始節點和終止節點的ID進行訪問
print('起始節點爲1,終止節點爲0的邊:', g.edata['w'][g.edge_id(1, 0)])
print('起始節點分別爲1,2,3,終止節點均爲0的邊:', g.edata['w'][g.edge_ids([1,2,3,], 0)])
print('起始節點分別爲1,2,3,終止節點均爲0的邊:', g.edata['w'][g.edge_ids([1,2,3,], [0,0,0])])

g.ndata['feats'] = th.zeros((10, 4))
print(g.node_attr_schemes()) #輸出圖中每一個特徵shape等信息

#刪除節點或者邊的數據
g.ndata.pop('feats')
print(g.node_attr_schemes()) #輸出圖中每一個特徵shape等信息

2.存在重複的邊的情況

有的圖可能存在兩條重合的邊,例如:1->0

lt.clf()
g_mutil = dgl.DGLGraph()
g_mutil.add_nodes(10)
g_mutil.ndata['x'] = th.randn(10, 2)
g_mutil.add_edges(list(range(1,10)),0) #圖中只有9條邊
g_mutil.add_edge(1,0) #再次添加一條1->0的邊,此時有兩條邊了, 分別爲第一條邊和最後一條邊
# 此時圖中共計10條邊
g_mutil.edata['w'] = th.randn(10, 2)
print('第一條邊數據:', g_mutil.edges[0].data['w'])
print('最後一條邊數據', g_mutil.edges[9].data['w'])

g_mutil.edges[0].data['w'] = th.zeros(1, 2) #修改第一條邊的數據
print('第一條邊數據:', g_mutil.edges[0].data['w'])
print('第一條邊數據:', g_mutil.edges[0].data)

一些節點和邊的調用方法:

#所有的邊
print(g_mutil.edges())

#所有的邊
print(g_mutil.nodes())

#根據節點找邊
eid_10 = g_mutil.edge_id(1, 0, return_array=True)
print(eid_10)
g_mutil.edges[eid_10].data['w'] = th.ones(len(eid_10), 2)
print(g_mutil.edata['w'])

nx.draw(g_mutil.to_networkx(), with_labels=True)
plt.show()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章