📔DiGress模型分子生成

type
status
date
slug
summary
tags
category
icon
password

论文

Abstract

DiGress模型,离散去噪扩散模型,用于生成带有分类节点和边属性的图,区别于依赖高斯噪声的连续扩散模型(使得图稀疏)。DiGress模型通过逐步修改图的结构来保持其主要结构,之后使用图变换器网络来逆转扩散过程。
引入了边缘分布保持、辅助结构特征和条件生成机制等创新

Introduction

Background

传统的连续扩散模型在保留图的核心性质方面有所不足

Motivation

现有模型难以有效处理

Challenges

  1. 图的特性的保持
  1. 分类属性(节点与边属性)
  1. 可扩展性(GuacaMol等复杂大规模数据集)
 

Method

  1. 离散扩散过程: 马尔可夫过程添加噪声,增删边、修改边点属性 定制噪声转移矩阵实现噪声转移,确保与原始图的边缘分布一致
  1. 图变换器网络: 置换等变得网络用于逆转扩散,从噪声图预测原始图 图生成—转化为—节点和边分类
  1. 算法改进: 边缘分布保持:调整噪声匹配图数据分布 结构特征增强:图论特征和光谱特征,弥补标准图神经网络表示能力的不足。 条件生成:实现基于目标特性的图生成

Results

  1. 通用图生成: 在SBM和平面图数据集上,DiGress 的性能优于现有方法。
  1. 分子生成: QM9、MOSES 和 GuacaMol,DiGress 的表现与当前最先进的自回归和基于片段的模型相当或更好 第一个无需依赖分子特定表示就能扩展到 GuacaMol 数据集的离散扩散模型
  1. 条件生成: 生成了具有特定属性(如分子偶极矩和轨道能级)的图,与未加条件的方法相比,平均绝对误差显著降低
  1. 消融实验: 边缘分布保持和结构特征增强对模型性能提升显著

离散扩散过程

  • 扩散从一个干净的图 G0 开始,通过引入噪声逐步生成 G1,G2,…,GT。
  • 噪声添加
    • 随机修改图的节点类型、边类型或图的连接方式:
      节点类别的转换:随机改变节点的分类标签。
      边的编辑:增加或删除边,或改变边的属性。
  • 使用 马尔可夫链模型
    • 每一步的噪声过程是条件独立的,仅依赖于前一状态 Gt−1。
      转移概率用转移矩阵 Qt 定义,其中 Qt[i,j] 表示从状态 i 到状态 j 的概率。
 
噪声只对图的节点和边进行离散修改,不会生成完全连接的噪声图。(why?)(相对应的是高斯噪声的加噪会导致可能出现全连接的噪声图,因为其在不同边缘分布上噪声概率相同,而真实的分布是绝对稀疏的,会导致模型难以学习)
离散加噪相较于均匀加噪更加“自然”且符合图数据的结构特点,避免了过多无关的边连接,保持了图的基本拓扑结构。
 

模型训练与分子生成过程

训练很简单,就是模型根据分子的边、点特征矩阵来进行学习,学习边、点的分布。DiGress默认的是没有任何额外特征的加入。
在生成的时候,普通的非条件生成中,模型会根据测试集中给出的分子骨架作为基准,在其余部分进行加噪后进行去噪生成,不会破坏原有的骨架结构;条件生成中,则是训练了一个regressor的回归器,该回归器会从噪声图中预测性质。
 

DiGress代码总结

DiscreteDenoisingDiffusion类

init

设置模型参数,维度信息
初始化训练,验证,测试指标
创建组件:
  • GraphTransformer:核心模型
  • noise_schedule:噪声调度器
  • transition_model:转移模型

1. 训练流程相关函数

  • training_step(): 训练步骤
    • 调用apply_noise()添加噪声
    • 调用compute_extra_data()计算额外特征
    • 调用forward()进行前向传播
    • 调用train_loss计算损失
  • apply_noise(): 向数据添加噪声
    • 使用noise_schedule计算噪声参数
    • 使用transition_model计算转移概率
    • 调用sample_discrete_features()采样噪声特征
  • compute_extra_data(): 计算额外特征
    • 调用extra_featuresdomain_features计算额外特征
  • forward(): 前向传播
    • 拼接噪声数据和额外特征
    • 调用model(GraphTransformer)进行预测

2. 验证/测试相关函数

  • validation_step()/test_step(): 验证/测试步骤
    • 类似训练步骤,但额外调用compute_val_loss()
  • compute_val_loss(): 计算验证损失
    • 调用kl_prior()计算KL散度
    • 调用compute_Lt()计算扩散损失
    • 调用reconstruction_logp()计算重建损失

3. 采样生成相关函数

  • sample_batch(): 生成样本批次
    • 调用sample_p_zs_given_zt()逐步去噪
    • 处理生成的链和最终样本
  • sample_p_zs_given_zt(): 单步去噪
    • 调用compute_extra_data()forward()预测去噪分布
    • 计算后验分布并采样

4. 辅助函数

  • kl_prior(): 计算先验KL散度
  • compute_Lt(): 计算扩散损失
  • reconstruction_logp(): 计算重建对数概率
上一篇
卷积神经网络CNN
下一篇
千山有雪,明月泠绛
Loading...