Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch.

Overview

SE3 Transformer - Pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class take care of embedding the type 0 features being passed in. Assuming they are atoms

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass it in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates

Edges

To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge type, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)

Caching

By default, the basis vectors are cached. However, if there is ever the need to clear the cache, you simply have to set the environmental flag CLEAR_CACHE to some value on initiating the script

$ CLEAR_CACHE=1 python train.py

Or you can try deleting the cache directory, which should exist at

$ rm -rf ~/.cache.equivariant_attention

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Breaking equivariance

    Breaking equivariance

    Hi, Thanks a lot for your work!!

    I was running some equivariance tests on your implementation of the SE3-Transformer and found that it is not always conserved. It does not break every time and unfortunately I do not know where the bug is.

    I have appended an image with an example of equivariance not being conserved.

    image

    To generate this image I used the following code:

        import numpy as np
        import matplotlib.pyplot as plt
        import torch
        
        from se3_transformer_pytorch import SE3Transformer
    

    Make some data

        zline = np.arange(0, 2, 0.05)
        xline = np.sin(zline * 2 * np.pi) 
        yline = np.cos(zline * 2 * np.pi)
        points = np.array([xline, yline, zline])
        geom = torch.tensor(points.transpose())[None,:].float()
        feat = torch.randint(0, 20, (1, geom.shape[1],1)).float()
        
        def rot_matrix(x):
            # Rotation matrix
            a ,b ,c = 2*np.pi*x
            return np.array([
                [np.cos(a)*np.cos(b), np.cos(a)*np.sin(b)*np.sin(c)- np.sin(a)*np.cos(c), np.cos(a)*np.sin(b)*np.cos(c)+ np.sin(a)*np.sin(c)],
                [np.sin(a)*np.cos(b), np.sin(a)*np.sin(b)*np.sin(c)+ np.cos(a)*np.cos(c), np.sin(a)*np.sin(b)*np.cos(c)- np.cos(a)*np.sin(c)],
                [-np.sin(b)         , np.cos(b)*np.sin(c)                               , np.cos(b)*np.cos(c)                               ]
            ])
    

    Initialize model

        mdl = SE3Transformer(
            dim = 1,
            depth = 3,
            input_degrees = 1,
            num_degrees = 2,
            output_degrees = 2,
            reduce_dim_out = True,
        )
        
        def model(geom,feat):
            return geom + mdl(feat,geom, return_type = 1)
    

    Check Rotation Invariance:

        with torch.no_grad():
            
            Q = torch.tensor(rot_matrix(np.random.random(3))).float()
            prerotated = model(geom @ Q, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) @ Q).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Rotated", "Post-Rotated"])
            plt.show()
    

    Check Translation Invariance:

        with torch.no_grad():
            
            x0 = 1*torch.rand(3)
            prerotated = model(geom + x0, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) + x0).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Translated", "Post-Translated"])
            plt.show()
    

    Hope this helps!!

    opened by brennanaba 18
  • How to populate input variable length data

    How to populate input variable length data

    Hello, thank you for your work. I am using the implementation of your SE3-Transformer as part of my project, but I have encountered some problems when processing variable length data input into your model. I do not know how to fill in features, coordinates and masks to meet the needs of the model.

    The processing of input data in my DataSet is as follows: image I fill the coordinates and features with zeros to ensure that their input dimensions are (246,3) and (246,20) respectively, and fill the mask with true and false bool types. In this case,there are 49 valid values,.So the shapes with valid features and coordinates are (49,20) and (49,3) and the rest are padded to zero. The first 49 of the mask is true and the rest is false

    Then I check through the torch.utils.data.DataLoader get input dimension is no problem: image My network looks like this, A typical binary classifier,I output the result of se3:

    image

    However, the output of se3 is NaN as shown in the figure below. image The first 48 have results, the rest are NaN, not clear what the problem is

    The next batch may all be NaN due to model weight changes image The output of SE3 causes all my losses to be NaN

    I think these NaN are probably caused by my incorrect filling method. I have also tried to fill in "1", but it was also ineffective. Maybe the mask I typed in can't have false, and I'm confused about that. Could you please tell me how to correctly fill coordinates, features and masks or give me some advice on using your SE3 model to handle variable length data. Hope this helps! ! thank you

    opened by zyk19981118 8
  • CPU/CUDA masking error

    CPU/CUDA masking error

    Hi - nice work - I was just testing out your code and ran into the following error, but only in the backward pass:

    RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mask' in call to th_masked_scatter_bool

    When using the nightly Pytorch, the error message is:

    RuntimeError: Tensor for argument #2 'mask' is on CPU, but expected it to be on GPU (while checking arguments for masked_scatter_)

    I'm pretty sure I don't have any tensors in CPU memory, but not sure if this is a bug in your SE3 code or a Pytorch issue. My gut feeling is this is a Pytorch/autograd issue, but I just don't know these particular Pytorch ops well enough to be sure. Tried both 1.7.1 release Pytorch and the latest nightly. Seems like there has been recent work on masked_scatter according to Pytorch issues.

    opened by denjots 8
  • small bug

    small bug

    https://github.com/lucidrains/se3-transformer-pytorch/blob/7c79998e4d84ec6bd6b6d4b916c6bf30b870b75b/se3_transformer_pytorch/se3_transformer_pytorch.py#L301

    should be if isinstance(m,nn.Linear)

    nbd, but thought you might wanna know.

    opened by MattMcPartlon 5
  • INTERNAL ASSERT FAILED

    INTERNAL ASSERT FAILED

    Hi - I'm not sure if this is actually a bug or if I'm just expecting too much, but the following code snippet bombs with a PyTorch internal assert failure:

    RuntimeError: sub_iter.strides(0)[0] == 0INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1618643394934/work/aten/src/ATen/native/cuda/Reduce.cuh":929, please report a bug to PyTorch.

    That's using the latest PyTorch nightly. Tried with both CUDA 10.2 and 11.1.

    Looking at the PyTorch bug reports, this seems to suggest that there is massive internal memory allocation being triggered when summing over a large tensor. It seems to have been sitting open for a year and I'm not sure if they even see it as an actual bug there. Certainly it doesn't seem to be a priority.

    Problem with this is that this isn't a particularly large data input, or a large model (to say the least!), and my GPU has 40 Gb of RAM, so, scalability of this particular transformer model seems very minimal as it currently stands. Changing the dim from 128 to 64 does at least allow it to run.

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(dim = 128, heads = 1, depth = 1, dim_head = 1, num_degrees = 1, input_degrees=1, output_degrees=1).cuda()
    
    feats = torch.randn(1, 200, 128).cuda()
    coords = torch.randn(1, 200, 3).cuda()
    
    out = model(feats, coords)
    
    
    opened by denjots 4
  • question about non scalar output

    question about non scalar output

    Hello,

    I am interested in using your implementation on my dataset. There are two things I want to check with you

    1. The property I want to predict is a 3x3 symmetric PSD matrix.
    2. There are some edge features (one categorical feature, one continuous feature) besides the coordinate difference

    I was wondering does the current se3-transformer can work with such a scenario? Thanks!

    opened by Chen-Cai-OSU 3
  • Reversible flag odd results

    Reversible flag odd results

    Sorry if it's expected behaviour again, but with a slight tweak of your new example code I get unexpected results when your new reversible option is used to return type 1 data. Here's the code I'm running...

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(
        num_tokens = 20,
        dim = 32,
        dim_head = 32,
        heads = 4,
        depth = 12,             # 12 layers
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 2,
        reduce_dim_out = True,
        reversible = True       # set reversible to True
    ).cuda()
    
    atoms = torch.randint(0, 4, (1, 50)).cuda()
    coors = torch.randn(1, 50, 3).cuda()
    mask  = torch.ones(1, 50).bool().cuda()
    
    pred = model(atoms, coors, mask = mask, return_type = 1)
    
    loss = pred.mean()
    print(loss)
    loss.backward()
    
    

    Without the reversible flag, the loss is close to zero as might be expected as the input coords are centered on the origin. However, when reversible is set, a very high value is produced which is going to blow up training if this was for real. Is this an expected side effect of reversible nets? Is some kind of normalization essential in that case? Or am I just doing something wrong here?

    opened by denjots 3
  • faster loop

    faster loop

    My small grain on sand for this project ;) : at least don't deal with python appends which are 2x slower than list comprehension.

    If this is of any help, here are some considerations (there might be misunderstandings on my side due to the decorators and so on):

    • i have my reservations on the utility of the line 62 in utils.py
    • would't make more sense to start the for i modified (line 122 of utils.py) in reverse order, then use the cached calculations for the lpmv() ?
    • same case (loop in reverse order) for the for in line 148 of basis.py?
    • if using the scipy.special.poch (which can deal with np arrays) instead of the custom pochhammer implementation, all operations inside get_spherical_harmonics_element are vectorizeable but the lpmv function call.
      • My sense is that the lpmv, get_spherical_harmonics_element and get_spherical_harmonics could be all wrapped in a single function (lower reusability / extension... so maybe doing the inverse loop order and caching is enough).
    opened by hypnopump 1
  • Question about continuous edge features

    Question about continuous edge features

    Thanks for all of the work!

    I am working on a simple proof of concept with your model. Ideally, I would like to perform multidimensional scaling. i.e. given a distance matrix recover corresponding coordinates.

    I was wondering if distance information could be passed as an edge feature (continuous rather than categorical information). Is it possible to do this with the current implementation?

    Thanks again and I appreciate the help!

    opened by MattMcPartlon 1
  • denoise.py bugfix

    denoise.py bugfix

    Fixes issue related to the constructor

    Traceback (most recent call last):
      File "/home/jcastellanos/projects/se3-transformer-pytorch/denoise.py", line 22, in <module>
        transformer = SE3Transformer(
      File "/home/jcastellanos/projects/se3-transformer-pytorch/se3_transformer_pytorch/se3_transformer_pytorch.py", line 1072, in __init__
        self.num_degrees = num_degrees if exists(num_degrees) else (max(hidden_fiber_dict.keys()) + 1)
    AttributeError: 'NoneType' object has no attribute 'keys'
    
    opened by javierbq 0
  • CUDA out of memory

    CUDA out of memory

    Thanks for your great job!

    The se3-transformer is powerful, but seems to be memory exhaustive.

    I built a model with the following parameters, and got "CUDA out of memory error" when I run it on the GPU(Nvidia V100 / 32G).

    model = SE3Transformer( dim = 20, heads = 4, depth = 2, dim_head = 5, num_degrees = 2, valid_radius = 5 )

    num_points = 512
    feats = torch.randn(1, num_points, 20)
    coors = torch.randn(1, num_points, 3)
    mask = torch.ones(1, num_points).bool()
    

    Does this error relate to the version of pytorch? and how can I fix it?

    opened by PengCheng-NUDT 0
  • SE3Transformer constructor hangs

    SE3Transformer constructor hangs

    I am trying to run an example from the README. The code is:

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    print('Initialising model...')
    model = SE3Transformer(
        dim = 512,
        heads = 8,
        depth = 6,
        dim_head = 64,
        num_degrees = 4,
        valid_radius = 10
    )
    
    print('Running model...')
    feats = torch.randn(1, 1024, 512)
    coors = torch.randn(1, 1024, 3)
    mask  = torch.ones(1, 1024).bool()
    
    out = model(feats, coors, mask) # (1, 1024, 512)
    

    The output hangs on 'Initialising model...' and eventually the kernel dies.

    Any ideas why this would be happening?

    Here is my pip freeze:

    anyio==3.2.1
    argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613036642480/work
    astunparse==1.6.3
    async-generator==1.10
    attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work
    axial-positional-embedding==0.2.1
    Babel==2.9.1
    backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
    biopython==1.79
    bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
    cached-property @ file:///tmp/build/80754af9/cached-property_1600785575025/work
    certifi==2021.5.30
    cffi @ file:///tmp/build/80754af9/cffi_1613246939562/work
    chardet==4.0.0
    click==8.0.1
    configparser==5.0.2
    decorator==4.4.2
    defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
    dgl-cu101==0.4.3.post2
    dgl-cu110==0.6.1
    docker-pycreds==0.4.0
    egnn-pytorch==0.2.6
    einops==0.3.0
    En-transformer==0.3.8
    entrypoints==0.3
    equivariant-attention @ file:///workspace/projects/se3-transformer-public
    filelock==3.0.12
    gitdb==4.0.7
    GitPython==3.1.18
    graph-transformer-pytorch==0.0.1
    h5py @ file:///tmp/build/80754af9/h5py_1622088444809/work
    huggingface-hub==0.0.12
    idna==2.10
    importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617877314848/work
    ipykernel @ file:///tmp/build/80754af9/ipykernel_1596206598566/work/dist/ipykernel-5.3.4-py3-none-any.whl
    ipython @ file:///tmp/build/80754af9/ipython_1617118429768/work
    ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
    ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
    jedi==0.17.0
    Jinja2 @ file:///tmp/build/80754af9/jinja2_1621238361758/work
    joblib==1.0.1
    json5==0.9.6
    jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
    jupyter==1.0.0
    jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
    jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
    jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213308260/work
    jupyter-server==1.8.0
    jupyter-tensorboard==0.2.0
    jupyterlab==3.0.16
    jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
    jupyterlab-server==2.6.0
    jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
    jupytext==1.11.3
    lie-learn @ git+https://github.com/AMLab-Amsterdam/[email protected]
    llvmlite==0.36.0
    local-attention==1.4.1
    markdown-it-py==1.1.0
    MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528142364/work
    mdit-py-plugins==0.2.8
    mdtraj==1.9.6
    mistune @ file:///tmp/build/80754af9/mistune_1594373098390/work
    mkl-fft==1.3.0
    mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work
    mkl-service==2.3.0
    mp-nerf==0.1.11
    nbclassic==0.3.1
    nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
    nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914821128/work
    nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
    nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
    networkx==2.5.1
    notebook @ file:///tmp/build/80754af9/notebook_1621523661196/work
    numba==0.53.1
    numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831194891/work
    packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
    pandas==1.2.4
    pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120451932/work
    parso @ file:///tmp/build/80754af9/parso_1617223946239/work
    pathtools==0.1.2
    performer-pytorch==1.0.11
    pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
    pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
    ProDy==2.0
    prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work
    promise==2.3
    prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
    protobuf==3.17.3
    psutil==5.8.0
    ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
    py3Dmol==0.9.1
    pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
    Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work
    pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
    pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141707582/work
    python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
    pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
    PyYAML==5.4.1
    pyzmq==20.0.0
    qtconsole @ file:///tmp/build/80754af9/qtconsole_1623278325812/work
    QtPy==1.9.0
    regex==2021.4.4
    requests==2.25.1
    sacremoses==0.0.45
    scipy @ file:///tmp/build/80754af9/scipy_1618852618548/work
    se3-transformer-pytorch==0.8.10
    Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
    sentry-sdk==1.1.0
    shortuuid==1.0.1
    sidechainnet==0.6.0
    six @ file:///tmp/build/80754af9/six_1623709665295/work
    smmap==4.0.0
    sniffio==1.2.0
    subprocess32==3.5.4
    terminado==0.9.4
    testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
    tokenizers==0.10.3
    toml==0.10.2
    torch==1.9.0
    tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
    tqdm==4.61.1
    traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
    transformers==4.8.0
    typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
    urllib3==1.26.5
    wandb==0.10.32
    wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
    webencodings==0.5.1
    websocket-client==1.1.0
    widgetsnbextension==3.5.1
    zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
    

    Here is a summary of my system info (lshw -short):

    H/W path    Device  Class      Description
    ==========================================
                        system     Computer
    /0                  bus        Motherboard
    /0/0                memory     59GiB System memory
    /0/1                processor  Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
    /0/100              bridge     440FX - 82441FX PMC [Natoma]
    /0/100/1            bridge     82371SB PIIX3 ISA [Natoma/Triton II]
    /0/100/1.1          storage    82371SB PIIX3 IDE [Natoma/Triton II]
    /0/100/1.3          bridge     82371AB/EB/MB PIIX4 ACPI
    /0/100/2            display    GD 5446
    /0/100/3            network    Elastic Network Adapter (ENA)
    /0/100/1e           display    GK210GL [Tesla K80]
    /0/100/1f           generic    Xen Platform Device
    /1          eth0    network    Ethernet interface
    
    opened by mpdprot 1
  • Whether SE3 needs pre-training

    Whether SE3 needs pre-training

    Thank you for your work. I used your reproduced SE3 as a part of my model, but the current test effect is not very good. I guess it may be because I do not have a good understanding of your model. Here are my questions:

    1. Does your model need pre-training?
    2. Can I train SE3 Transformer with the full connection layer that comes after it? Good advice is also welcome
    opened by zyk19981118 2
  • multiple molecules cases

    multiple molecules cases

    Hi,

    I use normally dataloader from PyG to handle my molecules dataset.

    Can you provide an example to make a real multistep epoch model please ?

    I run your sample code it's working but I need better understand input format as well as output format which for me would be x,y,z ?

    image

    thanks

    opened by thegodone 2
Releases(0.9.0)
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
CT Based COVID 19 Diagnose by Image Processing and Deep Learning

This project proposed the deep learning and image processing method to undertake the diagnosis on 2D CT image and 3D CT volume.

1 Feb 08, 2022
Predict bus arrival time using VertexAI and Nvidia's Jetson Nano

bus_prediction predict bus arrival time using VertexAI and Nvidia's Jetson Nano imagenet the command for imagenet.py look like this python3 /path/to/i

10 Dec 22, 2022
Benchmarks for Object Detection in Aerial Images

Benchmarks for Object Detection in Aerial Images

Jian Ding 691 Dec 30, 2022
Research on Tabular Deep Learning (Python package & papers)

Research on Tabular Deep Learning For paper implementations, see the section "Papers and projects". rtdl is a PyTorch-based package providing a user-f

Yura Gorishniy 510 Dec 30, 2022
A multi-mode modulator for multi-domain few-shot classification (ICCV)

A multi-mode modulator for multi-domain few-shot classification (ICCV)

Yanbin Liu 8 Apr 28, 2022
Iowa Project - My second project done at General Assembly, focused on feature engineering and understanding Linear Regression as a concept

Project 2 - Ames Housing Data and Kaggle Challenge PROBLEM STATEMENT Inferring or Predicting? What's more valuable for a housing model? When creating

Adam Muhammad Klesc 1 Jan 03, 2022
Direct LiDAR Odometry: Fast Localization with Dense Point Clouds

Direct LiDAR Odometry: Fast Localization with Dense Point Clouds DLO is a lightweight and computationally-efficient frontend LiDAR odometry solution w

VECTR at UCLA 369 Dec 30, 2022
Flexible Networks for Learning Physical Dynamics of Deformable Objects (2021)

Flexible Networks for Learning Physical Dynamics of Deformable Objects (2021) By Jinhyung Park, Dohae Lee, In-Kwon Lee from Yonsei University (Seoul,

Jinhyung Park 0 Jan 09, 2022
Sudoku solver - A sudoku solver with python

sudoku_solver A sudoku solver What is Sudoku? Sudoku (Japanese: 数独, romanized: s

Sikai Lu 0 May 22, 2022
Simple Tensorflow implementation of "Adaptive Convolutions for Structure-Aware Style Transfer" (CVPR 2021)

AdaConv — Simple TensorFlow Implementation [Paper] : Adaptive Convolutions for Structure-Aware Style Transfer (CVPR 2021) Note This repository does no

Junho Kim 26 Nov 18, 2022
This is a custom made virus code in python, using tkinter module.

skeleterrorBetaV0.1-Virus-code This is a custom made virus code in python, using tkinter module. This virus is not harmful to the computer, it only ma

AR 0 Nov 21, 2022
Official codebase for ICLR oral paper Unsupervised Vision-Language Grammar Induction with Shared Structure Modeling

CLIORA This is the official codebase for ICLR oral paper: Unsupervised Vision-Language Grammar Induction with Shared Structure Modeling. We introduce

Bo Wan 32 Dec 23, 2022
The BCNet related data and inference model.

BCNet This repository includes the some source code and related dataset of paper BCNet: Learning Body and Cloth Shape from A Single Image, ECCV 2020,

81 Dec 12, 2022
Deep learned, hardware-accelerated 3D object pose estimation

Isaac ROS Pose Estimation Overview This repository provides NVIDIA GPU-accelerated packages for 3D object pose estimation. Using a deep learned pose e

NVIDIA Isaac ROS 41 Dec 18, 2022
Allows including an action inside another action (by preprocessing the Yaml file). This is how composite actions should have worked.

actions-includes Allows including an action inside another action (by preprocessing the Yaml file). Instead of using uses or run in your action step,

Tim Ansell 70 Nov 04, 2022
OpenMMLab Model Deployment Toolset

Introduction English | 简体中文 MMDeploy is an open-source deep learning model deployment toolset. It is a part of the OpenMMLab project. Major features F

OpenMMLab 1.5k Dec 30, 2022
SemiNAS: Semi-Supervised Neural Architecture Search

SemiNAS: Semi-Supervised Neural Architecture Search This repository contains the code used for Semi-Supervised Neural Architecture Search, by Renqian

Renqian Luo 21 Aug 31, 2022
Github Traffic Insights as Prometheus metrics.

github-traffic Github Traffic collects your repository's traffic data and exposes it as Prometheus metrics. Grafana dashboard that displays the metric

Grafana Labs 34 Oct 27, 2022
IDM: An Intermediate Domain Module for Domain Adaptive Person Re-ID,

Intermediate Domain Module (IDM) This repository is the official implementation for IDM: An Intermediate Domain Module for Domain Adaptive Person Re-I

Yongxing Dai 87 Nov 22, 2022
Graph parsing approach to structured sentiment analysis.

Fine-grained Sentiment Analysis as Dependency Graph Parsing This repository contains the code and datasets described in following paper: Fine-grained

Jeremy Barnes 36 Dec 12, 2022