前回は、networkxを使って、グラフデータを扱う方法を紹介しました。今回は、近年話題のGNN(Graph Neural Network)を扱えるPytorch Geomericというライブラリーを紹介したいと思います。具体的には、Google Golaboratoryでのインストール方法と実際にグラフの作成することで正常に動作するか確認するところまでを紹介したいと思います。

インストール方法

# pytorchのバージョンをチェック
torch.__version__
'1.5.0+cu101'

公式リファレンスより参照

!pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-geometric

ライブラリーの読み込み

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import Data
%matplotlib inline

データの作成

以下のような隣接リストを"graph"に格納しています。隣接リストでは、数字が頂点番号を示していて、"[0, 1]"ならば頂点0から頂点1へ枝があることを意味しています。また、無向グラフを作成したいため、"[1, 0]"という枝も追加しています。
"x"は、頂点それぞれに付随している特徴ベクトルです。頂点の特徴ベクトルとは例えるならば、twitterのようなSNSで頂点をユーザとしたときに、そのユーザのフォロー数やフォロワー数などの情報をまとめたベクトルです。"x"でいうと、頂点0で表されたユーザは10人をフォローしていて、15人にフォローされているといった具合です。

graph = [[0,1],
         [1,0],
         [2,0],
         [0,2],
         [2,5],
         [5,2],
         [2,4],
         [4,2],
         [0,3],
         [3,0],
         [4,5],
         [5,4]]
x = [[10, 15],[11, 12],[13, 15],[18, 20],[10, 20],[17, 21]]

実際にどのようなグラフになるか前回紹介したnetworkxを使って描画してみた結果が以下です。

# グラフの描画
G = nx.DiGraph()
G.add_edges_from(graph)
nx.draw_networkx(G)

次に、graphのような隣接リストの形式ではGeometricでは受け取れないため、COO形式に隣接リストを変換する必要があります。COO形式とは、枝の出発点を集めたリストと枝の到着点を集めたリストに分ける形式です。

# COO形式に変換
edge_index = np.array(graph).T.tolist()
edge_index
[[0, 1, 2, 0, 2, 5, 2, 4, 0, 3, 4, 5], [1, 0, 0, 2, 5, 2, 4, 2, 3, 0, 5, 4]]
# tensor型に変換
x = torch.tensor(x, dtype=torch.float)
edge_index = torch.tensor(edge_index, dtype=torch.long)

tensor型に変換した"x"と"edge_index"をfrom torch_geometric.data import DataでimportしておいたDataに入力することで、Geometricで扱えるデータ型に変換することができます。

data = Data(x=x, edge_index=edge_index)
print(data)
Data(edge_index=[2, 12], x=[6, 2])

上記の出力結果が示す数値の意味は以下の通りです。

  • edge_index = [始点リストと終点リスト(2), 枝数(12=6*2)]
  • x = [頂点数(6), 特徴量次元(2)]

まとめ

今回は、Pytorch GeometricのGoogle Colaboratory上でのインストール方法と簡単な動作チェックまでを行いました。
GNNのライブラリーでは、他にDeep Graph Libraryなどもありますので、気になる方はチェックしてみてください。
次回は、実際にモデルを構築してみたいと思いますので、よろしければご覧ください。