图神经网络

越来越多的新生理论出现,进而出现更先进的算法,AI也是如此。最近兴起的GNN(Graph Neural Network) 进入了视野。这个东西背后的数学原理是啥,和CNN有啥不同?在讲解图神经网络之前,我想引入一个大家或多或少都比较熟悉的问题:旅行商 问题。

这个问题的描述是这样的:

要求商人从某个城市出发,经过且只经过所有城市一次,求路程最短的路径。

这个问题很经典,有时候也经常会遇到。我们很显然的会想到一个方法,将它用矩阵来描述:
a0b7e98d-4188-4f5b-85a3-b8a69e8763bb-image.png

这个问题是可以通过动态规划的方式来解决的。假设你从s出发,此时你之前做过的路径都是最优的,也就是so far so good, 那么你只要求出剩下的城市到s的最短路径即可。

当然图神经网络与这些都没有太大的联系。

GNN能解决的问题

简单来说,CNN可以解决图像问题,RNN可以解决序列与序列之间的依赖问题,那么GNN能解决啥?比如说这些东西:

0a192d2f-fe6e-4871-8170-59b62ec83b91-image.png

你是否相信未来可能用一个神经网络来解决旅行商问题. 那么这就对了,GNN正是用于处理结构性问题的。
这么讲很多人还是无法理解的,说到底,到底是什么问题需要图神经网络来解决的呢?简单的举两个例子:

  • 上面的旅行者问题,我用GNN求解可能是这样的,直接输入城市与与城市的网络图,GNN负责预测出制定出发点开始的最短路径,这是很可行的;
  • 上述用途还有点抽象,再举一个很直白的例子,人物百科,每个人物与另外一个人物都会有关系,GNN可以用来输入一个人物(节点),直接输出与他相关的人物(其他节点)。

对于第二个应用,有人会问了,我通过图数据库查询也能做到,要GAN干啥。这种问题其实很早之前就讨论过,神经网络求最优值也可以通过暴力枚举来求解,对话机器人又可以通过数据库查询来求解,但是问题是二者存在本质的区别:神经网络是通过学习进行预测的。

总结之,本质上GNN解决的问题是:图中的每一个节点都和一个label相关,GNN就是我们希望的工具,它能预测出每个节点的label。
有了label,不就有了训练神经网络的可能?

说道这里,我们就需要对GNN从入门到高级进行定位了,首先很直接的,它能解决的问题是:

  • 节点分类;

负责预测图上的没一个节点的类别。说了这么多,GNN的应用并不像图片分类,目标检测,序列预测,文本生成那么直观。但是在点云领域却很有用途!!,比如寻找到点云中每个点之间的关系,找到meshs表面之间的关系,对于3D的应用可能会比2D图像更有效。这也是未来自动驾驶,机器人领域以及AI技术落地必不可少的一环。

12fc0ef8-3071-4c9b-8c4c-e777fc99a564-image.png

Graph Nerual Networks 实例

下面来上手万一玩:

0ce9f438-a0ad-4f9d-862e-c11108a02eba-image.png

首先我们安装一下 torch-geometric 这个包:

$ pip install --upgrade torch-scatter
$ pip install --upgrade torch-sparse
$ pip install --upgrade torch-cluster
$ pip install --upgrade torch-spline-conv (optional)
$ pip install torch-geometric

非常简单的一段代码:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)

x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
print(data)