📔DiGress模型分子生成
type
status
date
slug
summary
tags
category
icon
password
论文
Abstract
DiGress模型,离散去噪扩散模型,用于生成带有分类节点和边属性的图,区别于依赖高斯噪声的连续扩散模型(使得图稀疏)。DiGress模型通过逐步修改图的结构来保持其主要结构,之后使用图变换器网络来逆转扩散过程。
引入了边缘分布保持、辅助结构特征和条件生成机制等创新
Introduction
Background
传统的连续扩散模型在保留图的核心性质方面有所不足
Motivation
现有模型难以有效处理
Challenges
- 图的特性的保持
- 分类属性(节点与边属性)
- 可扩展性(GuacaMol等复杂大规模数据集)
Method
- 离散扩散过程: 马尔可夫过程添加噪声,增删边、修改边点属性 定制噪声转移矩阵实现噪声转移,确保与原始图的边缘分布一致
- 图变换器网络: 置换等变得网络用于逆转扩散,从噪声图预测原始图 图生成—转化为—节点和边分类
- 算法改进: 边缘分布保持:调整噪声匹配图数据分布 结构特征增强:图论特征和光谱特征,弥补标准图神经网络表示能力的不足。 条件生成:实现基于目标特性的图生成
Results
- 通用图生成: 在SBM和平面图数据集上,DiGress 的性能优于现有方法。
- 分子生成: QM9、MOSES 和 GuacaMol,DiGress 的表现与当前最先进的自回归和基于片段的模型相当或更好 第一个无需依赖分子特定表示就能扩展到 GuacaMol 数据集的离散扩散模型
- 条件生成: 生成了具有特定属性(如分子偶极矩和轨道能级)的图,与未加条件的方法相比,平均绝对误差显著降低
- 消融实验: 边缘分布保持和结构特征增强对模型性能提升显著
离散扩散过程
- 扩散从一个干净的图 G0 开始,通过引入噪声逐步生成 G1,G2,…,GT。
- 噪声添加:
随机修改图的节点类型、边类型或图的连接方式:
节点类别的转换:随机改变节点的分类标签。
边的编辑:增加或删除边,或改变边的属性。
- 使用 马尔可夫链模型:
每一步的噪声过程是条件独立的,仅依赖于前一状态 Gt−1。
转移概率用转移矩阵 Qt 定义,其中 Qt[i,j] 表示从状态 i 到状态 j 的概率。
噪声只对图的节点和边进行离散修改,不会生成完全连接的噪声图。(why?)(相对应的是高斯噪声的加噪会导致可能出现全连接的噪声图,因为其在不同边缘分布上噪声概率相同,而真实的分布是绝对稀疏的,会导致模型难以学习)
离散加噪相较于均匀加噪更加“自然”且符合图数据的结构特点,避免了过多无关的边连接,保持了图的基本拓扑结构。
模型训练与分子生成过程
训练很简单,就是模型根据分子的边、点特征矩阵来进行学习,学习边、点的分布。DiGress默认的是没有任何额外特征的加入。
在生成的时候,普通的非条件生成中,模型会根据测试集中给出的分子骨架作为基准,在其余部分进行加噪后进行去噪生成,不会破坏原有的骨架结构;条件生成中,则是训练了一个regressor的回归器,该回归器会从噪声图中预测性质。
DiGress代码总结
DiscreteDenoisingDiffusion类
init
设置模型参数,维度信息
初始化训练,验证,测试指标
创建组件:
- GraphTransformer:核心模型
- noise_schedule:噪声调度器
- transition_model:转移模型
1. 训练流程相关函数
training_step()
: 训练步骤- 调用
apply_noise()
添加噪声 - 调用
compute_extra_data()
计算额外特征 - 调用
forward()
进行前向传播 - 调用
train_loss
计算损失
apply_noise()
: 向数据添加噪声- 使用
noise_schedule
计算噪声参数 - 使用
transition_model
计算转移概率 - 调用
sample_discrete_features()
采样噪声特征
compute_extra_data()
: 计算额外特征- 调用
extra_features
和domain_features
计算额外特征
forward()
: 前向传播- 拼接噪声数据和额外特征
- 调用
model
(GraphTransformer)进行预测
2. 验证/测试相关函数
validation_step()
/test_step()
: 验证/测试步骤- 类似训练步骤,但额外调用
compute_val_loss()
compute_val_loss()
: 计算验证损失- 调用
kl_prior()
计算KL散度 - 调用
compute_Lt()
计算扩散损失 - 调用
reconstruction_logp()
计算重建损失
3. 采样生成相关函数
sample_batch()
: 生成样本批次- 调用
sample_p_zs_given_zt()
逐步去噪 - 处理生成的链和最终样本
sample_p_zs_given_zt()
: 单步去噪- 调用
compute_extra_data()
和forward()
预测去噪分布 - 计算后验分布并采样
4. 辅助函数
kl_prior()
: 计算先验KL散度
compute_Lt()
: 计算扩散损失
reconstruction_logp()
: 计算重建对数概率
上一篇
卷积神经网络CNN
下一篇
千山有雪,明月泠绛
Loading...