Explainer for black box models that predict molecule properties

Explaining why that molecule

exmol is a package to explain black-box predictions of molecular properties. It uses model-agnostic explanations to help users understand why a molecule is predicted to have a property.

Install

pip install exmol

Counterfactual Generation

Our package implements Model Agnostic Counterfactual Compounds with STONED (MACCS) to generate counterfactuals. A counterfactual can explain a prediction by showing what would have to change in the molecule to change its predicted class. Here is an example of a counterfactual:

This package is not popular. If the package had a logo, it would be popular.

In addition to having a changed prediction, a molecular counterfactual must be as similar as possible to its base molecule. Here is an example of a molecular counterfactual:

counterfactual demo

The counterfactual shows that if the carboxylic acid were an ester, the molecule would be active. It is up to the user to translate this set of structures into a meaningful sentence.
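How "similar" is measured matters when reading a counterfactual. The sketch below assumes Tanimoto similarity over sets of fingerprint bits, a common choice for molecules; see the paper for the exact measure exmol uses. The function name tanimoto is hypothetical (not part of exmol) and the bit sets are made-up values for illustration only.

```python
def tanimoto(bits_a: set, bits_b: set) -> float:
    """Tanimoto (Jaccard) similarity: shared bits divided by bits in either set."""
    if not bits_a and not bits_b:
        # Two empty fingerprints are treated as identical by convention here
        return 1.0
    return len(bits_a & bits_b) / len(bits_a | bits_b)


# Made-up fingerprint bits, NOT computed from real molecules
base_bits = {1, 4, 7, 9, 12}
candidate_bits = {1, 4, 7, 9, 15}

# 4 shared bits out of 6 distinct bits in total
print(tanimoto(base_bits, candidate_bits))
```

A value close to 1 means the candidate shares most of its features with the base, which is what makes a counterfactual easy to interpret.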

Usage

Let's assume you have a deep learning model my_model(s) that takes in one SMILES string and outputs a predicted binary class. To generate counterfactuals, we need to wrap our function so that it can take both SMILES and SELFIES, but it only needs to use one.

We first expand chemical space around the prediction of interest:

import exmol

# mol of interest
base = 'CCCO'

samples = exmol.sample_space(base, lambda smi, sel: my_model(smi), batched=False)

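Here we use a lambda to wrap our function so that it accepts both a SMILES and a SELFIES string but passes only the SMILES to my_model. Passing batched=False indicates that our function takes one SMILES string at a time, not a list of them. Now we select counterfactuals from that space and plot them.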
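The wrapping step can also be written with named functions, which is easier to debug than a lambda. This is a minimal sketch: my_model here is a toy stand-in (not a real trained model), and the batched variant assumes that with batched=True the function receives a list of SMILES and a list of SELFIES; check the API docs before relying on that.

```python
from typing import List


def my_model(smiles: str) -> int:
    """Toy stand-in for a trained classifier: 1 if the SMILES contains an oxygen."""
    return int("O" in smiles)


def single_wrapper(smiles: str, selfies: str) -> int:
    # Accept both representations, but use only the SMILES string
    return my_model(smiles)


def batched_wrapper(smiles: List[str], selfies: List[str]) -> List[int]:
    # Assumed batched form: one prediction per input molecule
    return [my_model(s) for s in smiles]


print(single_wrapper("CCCO", "[C][C][C][O]"))
print(batched_wrapper(["CCCO", "CCC"], ["[C][C][C][O]", "[C][C][C]"]))
```

single_wrapper would be passed to exmol.sample_space with batched=False in place of the lambda above.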

cfs = exmol.cf_explain(samples)
exmol.plot_cf(cfs)

set of counterfactuals

We can also plot the space around the counterfactual. This is computed via PCA of the affinity matrix -- the similarity with the base molecule. Due to how similarity is calculated, the base is going to be the farthest from all other molecules. Thus your base should fall on the left (or right) extreme of your plot.

cfs = exmol.cf_explain(samples)
exmol.plot_space(samples, cfs)

chemical space

Each counterfactual is a Python dataclass with information allowing it to be used in your own analysis:

print(cfs[0])
Examples(
  smiles='CCOC(=O)c1ccc(N=CN(Cl)c2ccccc2)cc1',
  selfies='[C][C][O][C][Branch1_2][C][=O][C][=C][C][=C][Branch1_1][#C][N][=C][N][Branch1_1][C][Cl][C][=C][C][=C][C][=C][Ring1][Branch1_2][C][=C][Ring1][S]',
  similarity=0.8181818181818182,
  yhat=-5.459493637084961,
  index=1807,
  position=array([-6.11371691,  1.24629293]),
  is_origin=False,
  cluster=26,
  label='Counterfactual')
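As a sketch of such an analysis, the fields can be filtered and sorted like any other Python object. To keep the example runnable without exmol, it uses a hypothetical stand-in dataclass (ExampleStandIn) carrying a subset of the fields printed above, filled with made-up values.

```python
from dataclasses import dataclass


@dataclass
class ExampleStandIn:
    """Hypothetical stand-in with a subset of the fields shown above."""
    smiles: str
    similarity: float
    yhat: float
    is_origin: bool = False
    label: str = ""


# Made-up values for illustration, not real model output
examples = [
    ExampleStandIn("CCCO", 1.0, 0.9, is_origin=True, label="Base"),
    ExampleStandIn("CCCN", 0.6, -0.4, label="Counterfactual"),
    ExampleStandIn("CCCOC", 0.8, -1.2, label="Counterfactual"),
]

# Drop the base molecule, then rank the rest by similarity to it
counterfactuals = [e for e in examples if not e.is_origin]
ranked = sorted(counterfactuals, key=lambda e: e.similarity, reverse=True)

for e in ranked:
    print(f"{e.smiles}\tsimilarity={e.similarity:.2f}\tyhat={e.yhat:.2f}")
```

The same attribute access should work on the objects returned by exmol.cf_explain, since they expose these fields by name.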

Chemical Space

When calling exmol.sample_space you can pass preset=<preset>, which can be one of the following:

  • 'narrow': Only one change to molecular structure, reduced set of possible bonds/elements
  • 'medium': Default. One or two changes to molecular structure, reduced set of possible bonds/elements
  • 'wide': One through five changes to molecular structure, large set of possible bonds/elements
  • 'chemed': A restrictive set where only pubchem molecules are considered. Experimental

You can also pass num_samples as a "request" for the number of samples. You will typically end up with fewer due to degenerate molecules. See the API for a complete description.
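A small helper can catch a mistyped preset name before an expensive sampling run starts. This is a sketch only: sample_with_preset and PRESETS are hypothetical names, not part of exmol, and the list covers just the four presets described above. exmol is imported lazily so the validation can be used on its own.

```python
PRESETS = ("narrow", "medium", "wide", "chemed")


def sample_with_preset(base, model, preset="medium", num_samples=None):
    """Validate the preset name, then delegate to exmol.sample_space."""
    if preset not in PRESETS:
        raise ValueError(f"Unknown preset {preset!r}; expected one of {PRESETS}")
    import exmol  # imported lazily, only needed once the arguments are valid

    kwargs = {"preset": preset, "batched": False}
    if num_samples is not None:
        # num_samples is a request; fewer samples may come back
        kwargs["num_samples"] = num_samples
    return exmol.sample_space(base, model, **kwargs)
```

For example, sample_with_preset('CCCO', wrapped_model, preset='narrow', num_samples=500) would request roughly 500 single-change molecules around the base, where wrapped_model is your own wrapped prediction function.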

SVG

Molecules are drawn as PNGs by default. If you would like them drawn as SVGs, call insert_svg after calling plot_space or plot_cf:

import skunk
# exps is the list of examples to draw, e.g. the output of exmol.cf_explain
exmol.plot_cf(exps)
svg = exmol.insert_svg(exps, mol_fontsize=16)

# for Jupyter Notebook
skunk.display(svg)

# To save to file
with open('myplot.svg', 'w') as f:
    f.write(svg)

This is done with the skunk 🦨 library.

API and Docs

Read API here. You should also read the paper (see below) for a more exact description of the methods and implementation.

Citation

Please cite Wellawatte et al.

 @article{wellawatte_seshadri_white_2021,
 place={Cambridge},
 title={Model agnostic generation of counterfactual explanations for molecules},
 DOI={10.33774/chemrxiv-2021-4qkg8},
 journal={ChemRxiv},
 publisher={Cambridge Open Engage},
 author={Wellawatte, Geemi P and Seshadri, Aditi and White, Andrew D},
 year={2021}}

This content is a preprint and has not been peer-reviewed.

Comments
  • Add LIME explanations

    This is a big PR!

    • [x] Document LIME function
    • [x] Compute t-stats using examples that have non-zero weights
    • [x] Add plotting code for descriptors - needs SMARTS annotations for MACCS keys (166 files)
    • [x] Add plotting code for chemical space and fit
    • [x] Description in readme
    • [x] Clean up notebooks and add documentation
    • [x] Remove extra files
    • [x] Add LIME notebooks to CI?
    opened by hgandhi2411 11
  • Error while plotting counterfactuals using plot_cf()

    plot_cf() function errors out with the following error. This behavior is also consistent across all notebooks in paper/.

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-10-b6c8ed26216e> in <module>
          1 fkw = {"figsize": (8, 6)}
          2 mpl.rc("axes", titlesize=12)
    ----> 3 exmol.plot_cf(exps, figure_kwargs=fkw, mol_size=(450, 400), nrows=1)
          4 
          5 plt.savefig("rf-simple.png", dpi=180)
    
    /gpfs/fs2/scratch/hgandhi/exmol/exmol/exmol.py in plot_cf(exps, fig, figure_kwargs, mol_size, mol_fontsize, nrows, ncols)
        682         title += f"\nf(x) = {e.yhat:.3f}"
        683         axs[i].set_title(title)
    --> 684         axs[i].imshow(np.asarray(img), gid=f"rdkit-img-{i}")
        685         axs[i].axis("off")
        686     for j in range(i, C * R):
    
    ~/.local/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
       1359     def inner(ax, *args, data=None, **kwargs):
       1360         if data is None:
    -> 1361             return func(ax, *map(sanitize_sequence, args), **kwargs)
       1362 
       1363         bound = new_sig.bind(ax, *args, **kwargs)
    
    ~/.local/lib/python3.7/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, filternorm, filterrad, resample, url, **kwargs)
       5607                               resample=resample, **kwargs)
       5608 
    -> 5609         im.set_data(X)
       5610         im.set_alpha(alpha)
       5611         if im.get_clip_path() is None:
    
    ~/.local/lib/python3.7/site-packages/matplotlib/image.py in set_data(self, A)
        699                 not np.can_cast(self._A.dtype, float, "same_kind")):
        700             raise TypeError("Image data of dtype {} cannot be converted to "
    --> 701                             "float".format(self._A.dtype))
        702 
        703         if self._A.ndim == 3 and self._A.shape[-1] == 1:
    
    TypeError: Image data of dtype <U14622 cannot be converted to float
    
    opened by hgandhi2411 6
  • Error after installation

    Hi,

    First of all, thank you for your work! I have a problem with your library, or rather, when I run "import exmol" I get an error: "No module named 'dataclasses'".

    I have installed as: pip install exmol...

    Thanks!

    opened by PARODBE 6
  • CODEX Example

    While messing around with CODEX, I noticed it wants to compute ECFP4 fingerprints using a different method and this gives slightly different similarities. @geemi725 could you double-check the ECFP4 implementation we have is correct, or is the CODEX one correct?

    image

    opened by whitead 6
  • Object has no attribute '__code__'

    Hi there, I noticed that sample_space does not seem to work with class instances, because they do not have a __code__ attribute:

    import exmol
    class A:
        pass
    exmol.sample_space('C', A(), batched=True)
    
    AttributeError: 'A' object has no attribute '__code__'
    

    Is there any way around this other than forcing the call to a separate function?

    opened by oiao 5
  • The module 'exmol' has no attribute 'lime_explain'

    In the notebook RF-lime.ipynb, the command

    exmol.lime_explain(space, descriptor_type=descriptor_type)

    gives the error: module 'exmol' has no attribute 'lime_explain'

    Please, let me know how to fix this error. Thanks.

    opened by andresilvapimentel 5
  • Easier usage of explain

    Working through some examples, I've noted the following things:

    1. Descriptor type should have a default - maybe MACCS since the plots will show up
    2. Maybe we should only save SVGs, rather than return unless prompted
    3. We should do string comparison for descriptor types using lowercase strings, so that classic and Classic and ecfp are valid.
    4. We probably shouldn't save without a filename - it is unexpected
    opened by whitead 4
  • Allow using custom list of molecules

    Hello @whitead, this is a very nice package!

    I found the new chemed option very useful and thought extending it to any list of molecules would make sense.

    Here is the main change to the API:

    explanation = exmol.sample_space(
          "CCCC",
          model,
          preset="custom", #use custom preset
          batched=False,
          data=data, # provide list of smiles or molecules
    )
    

    Let me know if this PR makes sense.

    opened by maclandrol 4
  • Target molecule frequently on the edge of sample space visualization

    In your example provided in the code, the target molecule is on the edge of the sampled distribution (in the PCA plot). I also find this happens very frequently with my experiments on my model. I think this suggests that the sampling produces molecules that are not evenly distributed around the target. I just want to verify that this is a property of the STONED sampling algorithm, and not an artifact of the visualization code (which it does not seem to be). I've attached an example of my own, for both "narrow" and "medium" presets.

    preset="narrow", nmols=10

    explain_narrow_0 05_10

    preset="medium", nmols=10

    explain_medium_0 05_10

    opened by adamoyoung 3
  • Sanitizing SMILES removes chirality information

    On this line of sample_space(), chirality information of origin_smiles is removed. The output is then unsuitable as input to a chirality-aware ML model, e.g. to distinguish L vs. D amino acids which are important in models of binding affinity. Could the option to skip this sanitization step be provided to the user?

    PS: Great code base and beautiful visualizations! We're finding it very useful in explaining our Gaussian Process models. The future of SAR ←→ ML looks exciting.

    opened by tianyu-lu 2
  • Release 0.5.0 on pypi

    Are you planning to release 0.5.0 on pypi? I am maintaining the conda package of exmol and I would like to bump it to 0.5.0. See https://github.com/conda-forge/exmol-feedstock

    Thanks!

    opened by hadim 2
  • run_STONED couldn't generate SMILES after 30 minutes

    For certain SMILES, run_STONED() fails to generate anything even after running for a long time. So far, one SMILES known to cause this issue is

    [Na+].[Na+].[Na+].[Na+].[Na+].[O-][S](=O)(=O)OCC[S](=O)(=O)c1cccc(Nc2nc(Cl)nc(Nc3cc(cc4C=C(\C(=N/Nc5ccc6c(cccc6[S]([O-])(=O)=O)c5[S]([O-])(=O)=O)C(=O)c34)[S]([O-])(=O)=O)[S]([O-])(=O)=O)n2)c1

    Here is how I use the function: exmol.run_stoned(smiles, num_samples=10, max_mutations=1).

    opened by qcampbel 2