[1]:
# NBVAL_SKIP
import sys
sys.path.append('../')

import logging
logging.getLogger('matplotlib').setLevel(logging.CRITICAL)
logging.getLogger('graphein').setLevel(logging.INFO)

PSCDB - Baselines#
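This notebook trains simple graph neural network baselines (GCN, GAT and GraphSAGE) on the PSCDB (Protein Structural Change Database) motion-type classification task, classifying each protein into one of seven motion classes. Graphein builds residue-level graphs from PDB structures, which are then converted to PyTorch Geometric objects and trained with PyTorch Lightning.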

[2]:
# NBVAL_SKIP
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import networkx as nx
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import f1_score

import warnings
warnings.filterwarnings("ignore")
RDKit WARNING: [23:36:27] Enabling RDKit 2019.09.3 jupyter extensions

Load dataset#

[3]:
# NBVAL_SKIP
df = pd.read_csv("../datasets/pscdb/structural_rearrangement_data.csv")
pdbs = df["Free PDB"]
y = [torch.argmax(torch.Tensor(lab)).type(torch.LongTensor) for lab in LabelBinarizer().fit_transform(df.motion_type)]
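For reference, `LabelBinarizer` one-hot encodes the motion-type strings (one column per class, in sorted order, given three or more classes), and the `argmax` recovers an integer class index. A minimal sketch with hypothetical labels:

# Hypothetical motion-type labels, to illustrate the encoding above.
labels = ["hinge", "shear", "other_motion"]
onehot = LabelBinarizer().fit_transform(labels)
# columns in sorted class order -> [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
y_demo = [torch.argmax(torch.Tensor(row)).long() for row in onehot]
# -> [tensor(0), tensor(2), tensor(1)]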

Constructing ML-ready Datasets from Raw Structures with Graphein#

[4]:
# NBVAL_SKIP
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.edges.distance import add_hydrogen_bond_interactions, add_peptide_bonds, add_k_nn_edges
from graphein.protein.graphs import construct_graph

from functools import partial

# Override config with constructors
constructors = {
    "edge_construction_functions": [partial(add_k_nn_edges, k=3, long_interaction_threshold=0)],
    #"edge_construction_functions": [add_hydrogen_bond_interactions, add_peptide_bonds],
    #"node_metadata_functions": [add_dssp_feature]
}

config = ProteinGraphConfig(**constructors)
print(config.dict())

# Make graphs
graph_list = []
y_list = []
for idx, pdb in enumerate(tqdm(pdbs)):
    try:
        graph_list.append(construct_graph(pdb_code=pdb, config=config))
        y_list.append(y[idx])
    except Exception:
        # Skip structures that fail to download or parse
        print(f"{idx} processing error...")
{'granularity': 'CA', 'keep_hets': False, 'insertions': False, 'pdb_dir': PosixPath('../examples/pdbs'), 'verbose': False, 'exclude_waters': True, 'deprotonate': False, 'protein_df_processing_functions': None, 'edge_construction_functions': [functools.partial(<function add_k_nn_edges at 0x7fb4b22bbf70>, k=3, long_interaction_threshold=0)], 'node_metadata_functions': [<function meiler_embedding at 0x7fb4b22c7430>], 'edge_metadata_functions': None, 'graph_metadata_functions': None, 'get_contacts_config': None, 'dssp_config': None}
URL Error [Errno 110] Connection timed out
274 processing error...
666 processing error...
677 processing error...
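Each entry in `graph_list` is a standard NetworkX graph, so the usual NetworkX API applies. A quick sanity check on the first graph might look like this (a sketch; the attributes are those set by Graphein above):

# Inspect the first constructed graph (a NetworkX object).
g = graph_list[0]
print(g.name, "-", g.number_of_nodes(), "residues,", g.number_of_edges(), "k-NN edges")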
[44]:
# NBVAL_SKIP
pdbs[274]
#pdbs[666]
#pdbs[677]
[44]:
'3e59'

Convert NetworkX graphs to PyTorch Geometric#

[8]:
# NBVAL_SKIP
from graphein.ml.conversion import GraphFormatConvertor

format_convertor = GraphFormatConvertor('nx', 'pyg',
                                        verbose = 'gnn',
                                        columns = None)
Using backend: pytorch
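The `columns` argument (left as `None` above, which keeps everything) can restrict which attributes are copied into the PyTorch Geometric objects. A sketch, assuming the attribute names that appear in the `Data` objects below:

# Hypothetical: convert while keeping only the attributes the models below use.
slim_convertor = GraphFormatConvertor(
    'nx', 'pyg',
    verbose='gnn',
    columns=['coords', 'edge_index', 'node_id', 'name'],
)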
[9]:
# NBVAL_SKIP
pyg_list = [format_convertor(graph) for graph in tqdm(graph_list)]
[10]:
# NBVAL_SKIP
for idx, g in enumerate(pyg_list):
    g.y = y_list[idx]
    g.coords = torch.FloatTensor(g.coords[0])
[11]:
# NBVAL_SKIP
# Filter out graphs whose coordinate array and node list disagree in length.
# Building a new list avoids mutating pyg_list while iterating over it,
# which would silently skip elements.
valid_pyg_list = []
for g in pyg_list:
    if g.coords.shape[0] == len(g.node_id):
        valid_pyg_list.append(g)
    else:
        print(g)
pyg_list = valid_pyg_list
Data(coords=[10112, 3], dist_mat=[1], edge_index=[2, 120], name=[1], node_id=[1264], y=1)
Data(coords=[820, 3], dist_mat=[1], edge_index=[2, 1431], name=[1], node_id=[808], y=2)
Data(coords=[668, 3], dist_mat=[1], edge_index=[2, 1166], name=[1], node_id=[666], y=4)
Data(coords=[2720, 3], dist_mat=[1], edge_index=[2, 3], name=[1], node_id=[340], y=5)

Model Configuration#

[32]:
# NBVAL_SKIP
config_default = dict(
    n_hid = 8,
    n_out = 8,
    batch_size = 4,
    dropout = 0.5,
    lr = 0.001,
    num_heads = 32,
    num_att_dim = 64,
    model_name = 'GCN'
)

class Struct:
    """Expose the entries of a dict as attributes."""
    def __init__(self, **entries):
        self.__dict__.update(entries)

config = Struct(**config_default)

model_name = config.model_name
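To train one of the other baselines, it should suffice to change `model_name` before the model is instantiated, since the model class below branches on it:

# Switch the baseline architecture, e.g. to GAT (or 'GraphSAGE').
config_default['model_name'] = 'GAT'
config = Struct(**config_default)
model_name = config.model_name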

Construct DataLoaders#

[33]:
# NBVAL_SKIP
import numpy as np
np.random.seed(42)
idx_all = np.arange(len(pyg_list))
np.random.shuffle(idx_all)

train_idx, valid_idx, test_idx = np.split(idx_all, [int(.8*len(pyg_list)), int(.9*len(pyg_list))])
train, valid, test = [pyg_list[i] for i in train_idx], [pyg_list[i] for i in valid_idx], [pyg_list[i] for i in test_idx]

from torch_geometric.data import DataLoader
train_loader = DataLoader(train, batch_size=config.batch_size, shuffle = True, drop_last = True)
valid_loader = DataLoader(valid, batch_size=32)
test_loader = DataLoader(test, batch_size=32)
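`np.split` at the 80% and 90% indices yields an 80/10/10 train/validation/test split. `drop_last=True` on the training loader keeps every batch at exactly `batch_size` graphs, avoiding a final undersized batch.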
[34]:
# NBVAL_SKIP
pyg_list[0]
[34]:
Data(coords=[635, 3], dist_mat=[1], edge_index=[2, 1118], name=[1], node_id=[635], y=1)

Define Model#

[35]:
# NBVAL_SKIP
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_add_pool
from torch.nn.functional import mse_loss, nll_loss, relu, softmax, cross_entropy
from torch.nn import functional as F
from pytorch_lightning.metrics.functional import accuracy
[36]:
# NBVAL_SKIP
class GraphNets(pl.LightningModule):
    def __init__(self):
        super().__init__()

        if model_name == 'GCN':
            self.layer1 = GCNConv(in_channels=3, out_channels=config.n_hid)
            self.layer2 = GCNConv(in_channels=config.n_hid, out_channels=config.n_out)

        elif model_name == 'GAT':
            self.layer1 = GATConv(3, config.num_att_dim, heads=config.num_heads, dropout=config.dropout)
            self.layer2 = GATConv(config.num_att_dim * config.num_heads, out_channels = config.n_out, heads=1, concat=False,
                                 dropout=config.dropout)

        elif model_name == 'GraphSAGE':
            self.layer1 = SAGEConv(3, config.n_hid)
            self.layer2 = SAGEConv(config.n_hid, config.n_out)

        self.decoder = nn.Linear(config.n_out, 7)

    def forward(self, g):
        x = g.coords
        x = F.dropout(x, p=config.dropout, training=self.training)
        x = F.elu(self.layer1(x, g.edge_index))
        x = F.dropout(x, p=config.dropout, training=self.training)
        x = self.layer2(x, g.edge_index)
        x = global_add_pool(x, batch=g.batch)
        x = self.decoder(x)
        # Return raw logits: cross_entropy applies log-softmax internally,
        # so applying softmax here would double-apply it.
        return x

    def training_step(self, batch, batch_idx):
        y = batch.y
        logits = self(batch)
        loss = cross_entropy(logits, y)
        acc = accuracy(softmax(logits, dim=-1), y)

        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        y = batch.y
        logits = self(batch)
        loss = cross_entropy(logits, y)
        acc = accuracy(softmax(logits, dim=-1), y)
        self.log("valid_loss", loss)
        self.log("valid_acc", acc)

    def test_step(self, batch, batch_idx):
        y = batch.y
        logits = self(batch)
        loss = cross_entropy(logits, y)
        acc = accuracy(softmax(logits, dim=-1), y)

        # argmax over logits gives the same predictions as over (log-)softmax
        y_pred_tags = torch.argmax(logits, dim=1)
        f1 = f1_score(y.detach().cpu().numpy(), y_pred_tags.detach().cpu().numpy(), average="weighted")

        self.log("test_loss", loss)
        self.log("test_acc", acc)
        self.log("test_f1", f1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        return optimizer
[37]:
# NBVAL_SKIP
GraphNets()
[37]:
GraphNets(
  (layer1): GCNConv(3, 8)
  (layer2): GCNConv(8, 8)
  (decoder): Linear(in_features=8, out_features=7, bias=True)
)
[40]:
# NBVAL_SKIP
from pytorch_lightning.callbacks import ModelCheckpoint
import os

file_path = './graphein_model'
os.makedirs(file_path, exist_ok=True)

checkpoint_callback = ModelCheckpoint(
    monitor="valid_loss",
    dirpath=file_path,
    # the filename template must reference the logged metric name ("valid_loss")
    filename="model-{epoch:02d}-{valid_loss:.2f}",
    save_top_k=1,
    mode="min",
)
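`ModelCheckpoint` tracks the logged `valid_loss` and keeps only the best-scoring checkpoint; after training, `checkpoint_callback.best_model_path` points at it and is reloaded for testing below.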

Train!#

[41]:
# NBVAL_SKIP
# Train Model
model = GraphNets()
trainer = pl.Trainer(max_epochs=200, gpus=-1, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader, valid_loader)

# evaluate on the model with the best validation set
best_model = GraphNets.load_from_checkpoint(checkpoint_callback.best_model_path)
out_best_test = trainer.test(best_model, test_loader)[0]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/atj39/anaconda3/envs/graphein-dev/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/atj39/anaconda3/envs/graphein-dev/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'GraphNets' on <module '__main__' (built-in)>
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipykernel_3426530/3685191286.py in <module>
      2 model = GraphNets()
      3 trainer = pl.Trainer(max_epochs=200, gpus=-1, callbacks=[checkpoint_callback])
----> 4 trainer.fit(model, train_loader, valid_loader)
      5
      6 # evaluate on the model with the best validation set

KeyboardInterrupt:
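The run above was interrupted: with `gpus=-1` on a multi-GPU machine, Lightning used the `ddp_spawn` strategy, which pickles the model into worker processes, and classes defined in a notebook's `__main__` cannot be pickled (hence the `AttributeError` above). A minimal sketch of a workaround, assuming a single GPU is sufficient, is to avoid process spawning altogether:

# NBVAL_SKIP
# Sketch: train on a single GPU so Lightning spawns no worker processes.
model = GraphNets()
trainer = pl.Trainer(max_epochs=200, gpus=1, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader, valid_loader)

# evaluate the checkpoint with the best validation loss
best_model = GraphNets.load_from_checkpoint(checkpoint_callback.best_model_path)
out_best_test = trainer.test(best_model, test_loader)[0]

Alternatively, moving the `GraphNets` class into an importable module would make it picklable and keep multi-GPU spawning viable.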