【论文源码学习】GAT代码运行踩坑 - Pytorch版
@TOC
前言
Hello!
非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出~
自我介绍 ଘ(੭ˊᵕˋ)੭
昵称:海轰
标签:程序猿|C++选手|学生
简介:因C语言结识编程,随后转入计算机专业,获得过国家奖学金,有幸在竞赛中拿过一些国奖、省奖…已保研。
学习经验:扎实基础 + 多做笔记 + 多敲代码 + 多思考 + 学好英语!
唯有努力💪
知其然 知其所以然!
本文仅记录自己感兴趣的内容
简介
-
原文:GRAPH ATTENTION NETWORKS
-
会议: ICLR 2018 (CCF-A)
笔记
步骤1
下载源码,解压
在Pycharm中打开此项目
步骤2
利用conda配置环境
这里是作者给出的配置文件(environment.yml)
name: pytorch-gat
channels:
- defaults
- pytorch
dependencies:
- python==3.8.5
- pip
- pytorch==1.7.0
- pip:
- matplotlib==3.3.3
- GitPython==3.1.2
- jupyter==1.0.0
- numpy==1.19.2
- scipy==1.5.4
- scikit-learn==0.24.0
- tensorboard==2.2.2
- networkx==2.5
- python-igraph==0.8.3
- pycairo==1.20.0
但是我利用这个yml文件创建环境失败,有些包安装不上
现在回过头来看,可能当时安装方法错了
这里贴出我配置的一个环境
python版本为3.6(其实应该安装3.8的)
name: gat_pytorch
channels:
- defaults
dependencies:
- _pytorch_select=0.1=cpu_0
- absl-py=0.15.0=pyhd3eb1b0_0
- aiohttp=3.7.4.post0=py36h9ed2024_2
- async-timeout=3.0.1=py36hecd8cb5_0
- attrs=21.4.0=pyhd3eb1b0_0
- blas=1.0=mkl
- blinker=1.4=py36hecd8cb5_0
- brotlipy=0.7.0=py36h9ed2024_1003
- bzip2=1.0.8=h1de35cc_0
- c-ares=1.18.1=hca72f7f_0
- ca-certificates=2022.07.19=hecd8cb5_0
- cachetools=4.2.2=pyhd3eb1b0_0
- cairo=1.16.0=h691a603_2
- certifi=2021.5.30=py36hecd8cb5_0
- cffi=1.14.6=py36h2125817_0
- chardet=4.0.0=py36hecd8cb5_1003
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- click=8.0.3=pyhd3eb1b0_0
- coverage=5.5=py36h9ed2024_2
- cryptography=3.4.7=py36h2fd3fbb_0
- curl=7.84.0=hca72f7f_0
- cycler=0.11.0=pyhd3eb1b0_0
- cython=0.29.24=py36h23ab428_0
- dataclasses=0.8=pyh4f3eec9_6
- decorator=4.4.2=pyhd3eb1b0_0
- expat=2.4.4=he9d5cce_0
- fontconfig=2.13.1=ha9ee91d_0
- freetype=2.11.0=hd8bbffd_0
- gettext=0.21.0=h7535e17_0
- git=2.34.1=pl5262h74264fa_0
- gitdb=4.0.7=pyhd3eb1b0_0
- gitpython=3.1.18=pyhd3eb1b0_1
- glib=2.69.1=h8346a28_1
- google-auth=2.6.0=pyhd3eb1b0_0
- google-auth-oauthlib=0.4.1=py_2
- grpcio=1.36.1=py36h97de6d8_1
- icu=58.2=h0a44026_3
- idna=3.3=pyhd3eb1b0_0
- idna_ssl=1.1.0=py36hecd8cb5_0
- importlib-metadata=4.8.1=py36hecd8cb5_0
- intel-openmp=2022.0.0=hecd8cb5_3615
- jpeg=9e=hca72f7f_0
- kiwisolver=1.3.1=py36h23ab428_0
- krb5=1.19.2=hcd88c3b_0
- lcms2=2.12=hf1fd2bf_0
- lerc=3.0=he9d5cce_0
- libcurl=7.84.0=h6dfd666_0
- libcxx=12.0.0=h2f01273_0
- libdeflate=1.8=h9ed2024_5
- libedit=3.1.20210910=hca72f7f_0
- libev=4.33=h9ed2024_1
- libffi=3.3=hb1e8313_2
- libgfortran=3.0.1=h93005f0_2
- libiconv=1.16=hca72f7f_2
- libnghttp2=1.46.0=ha29bfda_0
- libpng=1.6.37=ha441bb4_0
- libprotobuf=3.17.2=h2842e9f_1
- libssh2=1.10.0=h0a4fc7d_0
- libtiff=4.4.0=h2ef1027_0
- libwebp-base=1.2.2=hca72f7f_0
- libxml2=2.9.14=hbf8cd5e_0
- llvm-openmp=12.0.0=h0dcd299_1
- lz4-c=1.9.3=h23ab428_1
- markdown=3.3.4=py36hecd8cb5_0
- matplotlib=3.3.4=py36hecd8cb5_0
- matplotlib-base=3.3.4=py36h8b3ea08_0
- mkl=2019.4=233
- mkl-service=2.3.0=py36h9ed2024_0
- mkl_fft=1.3.0=py36ha059aab_0
- mkl_random=1.1.1=py36h959d312_0
- multidict=5.1.0=py36h9ed2024_2
- ncurses=6.3=hca72f7f_3
- networkx=2.5.1=pyhd3eb1b0_0
- ninja=1.10.2=hecd8cb5_5
- ninja-base=1.10.2=haf03e11_5
- numpy=1.19.2=py36h456fd55_0
- numpy-base=1.19.2=py36hcfb5961_0
- oauthlib=3.2.0=pyhd3eb1b0_1
- olefile=0.46=py36_0
- openjpeg=2.4.0=h66ea3da_0
- openssl=1.1.1q=hca72f7f_0
- pcre=8.45=h23ab428_0
- pcre2=10.37=he7042d7_1
- perl=5.26.2=h4e221da_0
- pillow=8.3.1=py36ha4cf6ea_0
- pip=21.2.2=py36hecd8cb5_0
- pixman=0.40.0=h9ed2024_1
- protobuf=3.17.2=py36h23ab428_0
- pyasn1=0.4.8=pyhd3eb1b0_0
- pyasn1-modules=0.2.8=py_0
- pycairo=1.19.1=py36h06c6e95_0
- pycparser=2.21=pyhd3eb1b0_0
- pyjwt=2.1.0=py36hecd8cb5_0
- pyopenssl=21.0.0=pyhd3eb1b0_1
- pyparsing=3.0.4=pyhd3eb1b0_0
- pysocks=1.7.1=py36hecd8cb5_0
- python=3.6.13=h88f2d9e_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- pytorch=1.4.0=cpu_py36hf9bb1df_0
- readline=8.1.2=hca72f7f_1
- requests=2.27.1=pyhd3eb1b0_0
- requests-oauthlib=1.3.0=py_0
- rsa=4.7.2=pyhd3eb1b0_1
- scipy=1.5.2=py36h912ce22_0
- setuptools=58.0.4=py36hecd8cb5_0
- six=1.16.0=pyhd3eb1b0_1
- smmap=4.0.0=pyhd3eb1b0_0
- sqlite=3.39.2=h707629a_0
- tensorboard=2.6.0=py_1
- tensorboard-data-server=0.6.0=py36h5896577_0
- tensorboard-plugin-wit=1.6.0=py_0
- tk=8.6.12=h5d9f67b_0
- tornado=6.1=py36h9ed2024_0
- typing-extensions=4.1.1=hd3eb1b0_0
- typing_extensions=4.1.1=pyh06a4308_0
- urllib3=1.26.8=pyhd3eb1b0_0
- werkzeug=2.0.3=pyhd3eb1b0_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.5=hca72f7f_1
- yarl=1.6.3=py36h9ed2024_0
- zipp=3.6.0=pyhd3eb1b0_0
- zlib=1.2.12=h4dc903c_2
- zstd=1.5.2=hcb37349_0
- pip:
- igraph==0.9.11
- joblib==1.1.0
- scikit-learn==0.24.2
- sklearn==0.0
- texttable==1.6.4
- threadpoolctl==3.1.0
步骤3
运行py文件即可
一些踩坑
import igraph as ig报错
解决办法:pip install igraph(使用conda install igraph提示失败)
utils中utils.py中下面两句代码可注释(git可以不用)
- import git
- “commit_hash”: git.Repo(search_parent_directories=True).head.object.hexsha,
建议运行两个训练文件
因为python版本自己安装为3.6 pytorch版本太低了
运行playground.py提示版本pytorch版本过低
运行结果
训练cora数据集
训练ppi数据集
结语
文章仅作为个人学习笔记记录,记录从0到1的一个过程
希望对您有一点点帮助,如有错误欢迎小伙伴指正
- 点赞
- 收藏
- 关注作者
评论(0)