[1]:
# NBVAL_SKIP
import sys
sys.path.append('../')
import logging

# Quieten matplotlib logging; keep Graphein at INFO
logging.getLogger('matplotlib').setLevel(logging.CRITICAL)
logging.getLogger('graphein').setLevel(logging.INFO)
PSCDB - Baselines#
[2]:
# NBVAL_SKIP
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import networkx as nx
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import f1_score
import warnings
warnings.filterwarnings("ignore")
RDKit WARNING: [23:36:27] Enabling RDKit 2019.09.3 jupyter extensions
Load dataset#
[3]:
# NBVAL_SKIP
# Load the PSCDB structural rearrangement dataset
df = pd.read_csv("../datasets/pscdb/structural_rearrangement_data.csv")
pdbs = df["Free PDB"]

# One-hot encode the motion type, then convert each row to an integer class label
y = [torch.argmax(torch.Tensor(lab)).type(torch.LongTensor) for lab in LabelBinarizer().fit_transform(df.motion_type)]
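To check which integer label corresponds to which motion type, you can inspect the fitted binarizer's class ordering (an optional sanity check; the lb name below is ours, not part of the original notebook):

lb = LabelBinarizer().fit(df.motion_type)
print(lb.classes_)  # class i in y corresponds to lb.classes_[i]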
Constructing ML-ready Datasets from Raw Structures with Graphein#
[4]:
# NBVAL_SKIP
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.edges.distance import add_hydrogen_bond_interactions, add_peptide_bonds, add_k_nn_edges
from graphein.protein.graphs import construct_graph
from functools import partial
# Override the default config: connect each residue to its 3 nearest neighbours
constructors = {
    "edge_construction_functions": [partial(add_k_nn_edges, k=3, long_interaction_threshold=0)],
    #"edge_construction_functions": [add_hydrogen_bond_interactions, add_peptide_bonds],
    #"node_metadata_functions": [add_dssp_feature]
}
config = ProteinGraphConfig(**constructors)
print(config.dict())
# Make graphs, skipping any structures that fail to download or parse
graph_list = []
y_list = []
for idx, pdb in enumerate(tqdm(pdbs)):
    try:
        graph_list.append(construct_graph(pdb_code=pdb, config=config))
        y_list.append(y[idx])
    except Exception:
        print(f"{idx} processing error...")
{'granularity': 'CA', 'keep_hets': False, 'insertions': False, 'pdb_dir': PosixPath('../examples/pdbs'), 'verbose': False, 'exclude_waters': True, 'deprotonate': False, 'protein_df_processing_functions': None, 'edge_construction_functions': [functools.partial(<function add_k_nn_edges at 0x7fb4b22bbf70>, k=3, long_interaction_threshold=0)], 'node_metadata_functions': [<function meiler_embedding at 0x7fb4b22c7430>], 'edge_metadata_functions': None, 'graph_metadata_functions': None, 'get_contacts_config': None, 'dssp_config': None}
URL Error [Errno 110] Connection timed out
274 processing error...
666 processing error...
677 processing error...
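Each entry in graph_list is a standard NetworkX graph, so it can be inspected with the usual NetworkX API before conversion (an optional check):

g = graph_list[0]
print(g.number_of_nodes(), g.number_of_edges())
print(list(g.nodes(data=True))[:1])  # per-residue metadata attached by Graphein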
[44]:
# NBVAL_SKIP
# Inspect the PDB codes that failed to process above
pdbs[274]
#pdbs[666]
#pdbs[677]
[44]:
'3e59'
Convert NetworkX graphs to PyTorch Geometric#
[8]:
# NBVAL_SKIP
from graphein.ml.conversion import GraphFormatConvertor

format_convertor = GraphFormatConvertor('nx', 'pyg',
                                        verbose='gnn',
                                        columns=None)
Using backend: pytorch
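If you only need a subset of attributes on the PyTorch Geometric side, the convertor's columns argument restricts which fields are carried over to the Data objects. A sketch, using the field names that appear in the converted graphs below:

format_convertor = GraphFormatConvertor('nx', 'pyg',
                                        verbose='gnn',
                                        columns=["coords", "edge_index", "node_id", "name"])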
[9]:
# NBVAL_SKIP
pyg_list = [format_convertor(graph) for graph in tqdm(graph_list)]
[10]:
# NBVAL_SKIP
# Attach labels and per-node coordinates to each PyG Data object
for idx, g in enumerate(pyg_list):
    g.y = y_list[idx]
    g.coords = torch.FloatTensor(g.coords[0])
[11]:
# NBVAL_SKIP
# Drop graphs whose coordinate array does not align with their node list.
# Iterate over a copy: removing items from a list while iterating over it skips elements.
for i in list(pyg_list):
    if i.coords.shape[0] != len(i.node_id):
        print(i)
        pyg_list.remove(i)
Data(coords=[10112, 3], dist_mat=[1], edge_index=[2, 120], name=[1], node_id=[1264], y=1)
Data(coords=[820, 3], dist_mat=[1], edge_index=[2, 1431], name=[1], node_id=[808], y=2)
Data(coords=[668, 3], dist_mat=[1], edge_index=[2, 1166], name=[1], node_id=[666], y=4)
Data(coords=[2720, 3], dist_mat=[1], edge_index=[2, 3], name=[1], node_id=[340], y=5)
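A quick assertion confirms that every remaining graph now has one coordinate row per node:

assert all(g.coords.shape[0] == len(g.node_id) for g in pyg_list)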
Model Configuration#
[32]:
# NBVAL_SKIP
# Hyperparameters shared by the baseline models
config_default = dict(
    n_hid=8,
    n_out=8,
    batch_size=4,
    dropout=0.5,
    lr=0.001,
    num_heads=32,
    num_att_dim=64,
    model_name='GCN',
)

# Small helper exposing the dict entries as attributes (config.n_hid etc.)
class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)

config = Struct(**config_default)
model_name = config.model_name
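To run the GAT or GraphSAGE baselines instead, change model_name before defining and instantiating the model, e.g.:

config_default['model_name'] = 'GAT'  # or 'GraphSAGE'
config = Struct(**config_default)
model_name = config.model_name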
Construct DataLoaders#
[33]:
# NBVAL_SKIP
np.random.seed(42)

# Shuffle indices, then split 80/10/10 into train/validation/test
idx_all = np.arange(len(pyg_list))
np.random.shuffle(idx_all)
train_idx, valid_idx, test_idx = np.split(idx_all, [int(.8 * len(pyg_list)), int(.9 * len(pyg_list))])
train, valid, test = [pyg_list[i] for i in train_idx], [pyg_list[i] for i in valid_idx], [pyg_list[i] for i in test_idx]

from torch_geometric.data import DataLoader
train_loader = DataLoader(train, batch_size=config.batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid, batch_size=32)
test_loader = DataLoader(test, batch_size=32)
[34]:
# NBVAL_SKIP
pyg_list[0]
[34]:
Data(coords=[635, 3], dist_mat=[1], edge_index=[2, 1118], name=[1], node_id=[635], y=1)
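In contrast to the single Data object above, the DataLoader collates each mini-batch of four graphs into one disconnected graph with a batch vector mapping nodes to graphs. Pulling one batch is a useful check that the loaders produce what the model expects:

batch = next(iter(train_loader))
print(batch.num_graphs)   # 4
print(batch.batch.shape)  # one entry per node, giving its graph index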
Define Model#
[35]:
# NBVAL_SKIP
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_add_pool
from torch.nn.functional import softmax, cross_entropy
from torch.nn import functional as F
from pytorch_lightning.metrics.functional import accuracy
[36]:
# NBVAL_SKIP
class GraphNets(pl.LightningModule):
    def __init__(self):
        super().__init__()
        if model_name == 'GCN':
            self.layer1 = GCNConv(in_channels=3, out_channels=config.n_hid)
            self.layer2 = GCNConv(in_channels=config.n_hid, out_channels=config.n_out)
        elif model_name == 'GAT':
            self.layer1 = GATConv(3, config.num_att_dim, heads=config.num_heads, dropout=config.dropout)
            self.layer2 = GATConv(config.num_att_dim * config.num_heads, out_channels=config.n_out, heads=1,
                                  concat=False, dropout=config.dropout)
        elif model_name == 'GraphSAGE':
            self.layer1 = SAGEConv(3, config.n_hid)
            self.layer2 = SAGEConv(config.n_hid, config.n_out)

        self.decoder = nn.Linear(config.n_out, 7)

    def forward(self, g):
        # Raw C-alpha coordinates serve as the input node features
        x = g.coords
        x = F.dropout(x, p=config.dropout, training=self.training)
        x = F.elu(self.layer1(x, g.edge_index))
        x = F.dropout(x, p=config.dropout, training=self.training)
        x = self.layer2(x, g.edge_index)
        # Sum node embeddings into one graph-level embedding per graph
        x = global_add_pool(x, batch=g.batch)
        x = self.decoder(x)
        # Return raw logits: cross_entropy applies log-softmax internally,
        # so applying softmax here as well would squash the gradients
        return x

    def training_step(self, batch, batch_idx):
        x = batch
        y = x.y
        y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        acc = accuracy(softmax(y_hat, dim=-1), y)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch
        y = x.y
        y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        acc = accuracy(softmax(y_hat, dim=-1), y)
        self.log("valid_loss", loss)
        self.log("valid_acc", acc)

    def test_step(self, batch, batch_idx):
        x = batch
        y = x.y
        y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        acc = accuracy(softmax(y_hat, dim=-1), y)
        # argmax over logits is equivalent to argmax over (log-)softmax outputs
        y_pred_tags = torch.argmax(y_hat, dim=1)
        f1 = f1_score(y.detach().cpu().numpy(), y_pred_tags.detach().cpu().numpy(), average='weighted')
        self.log("test_loss", loss)
        self.log("test_acc", acc)
        self.log("test_f1", f1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=config.lr)
[37]:
# NBVAL_SKIP
GraphNets()
[37]:
GraphNets(
(layer1): GCNConv(3, 8)
(layer2): GCNConv(8, 8)
(decoder): Linear(in_features=8, out_features=7, bias=True)
)
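Before training, a quick smoke test (optional) confirms that a forward pass over one mini-batch yields one 7-way logit vector per graph:

model = GraphNets()
batch = next(iter(train_loader))
print(model(batch).shape)  # torch.Size([4, 7]): batch_size graphs, 7 motion-type classes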
[40]:
# NBVAL_SKIP
from pytorch_lightning.callbacks import ModelCheckpoint
import os

file_path = './graphein_model'
os.makedirs(file_path, exist_ok=True)

checkpoint_callback = ModelCheckpoint(
    monitor="valid_loss",
    dirpath=file_path,
    filename="model-{epoch:02d}-{valid_loss:.2f}",  # the template must reference a logged metric
    save_top_k=1,
    mode="min",
)
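Optionally, an EarlyStopping callback can be added alongside the checkpoint callback to halt training once the validation loss stops improving (a sketch; the patience value is arbitrary):

from pytorch_lightning.callbacks import EarlyStopping
early_stop_callback = EarlyStopping(monitor="valid_loss", patience=20, mode="min")
# then pass callbacks=[checkpoint_callback, early_stop_callback] to pl.Trainer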
Train!#
[41]:
# NBVAL_SKIP
# Train model
model = GraphNets()
trainer = pl.Trainer(max_epochs=200, gpus=-1, callbacks=[checkpoint_callback])  # gpus=-1 uses every available GPU
trainer.fit(model, train_loader, valid_loader)

# Evaluate the checkpoint with the best validation loss on the test set
best_model = GraphNets.load_from_checkpoint(checkpoint_callback.best_model_path)
out_best_test = trainer.test(best_model, test_loader)[0]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/atj39/anaconda3/envs/graphein-dev/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/atj39/anaconda3/envs/graphein-dev/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'GraphNets' on <module '__main__' (built-in)>
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipykernel_3426530/3685191286.py in <module>
----> 4 trainer.fit(model, train_loader, valid_loader)
(stack trace truncated: trainer.fit -> accelerator.start_training -> ddp_spawn -> torch.multiprocessing.spawn -> multiprocessing.Popen)
KeyboardInterrupt:
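The crash above is a known interaction between notebooks and multi-GPU training: gpus=-1 selected both GPUs, so this version of PyTorch Lightning fell back to the ddp_spawn strategy, which cannot pickle classes defined in a notebook's __main__ (hence the AttributeError), and the run was then interrupted manually. A minimal workaround, sketched here rather than re-run, is to train on a single GPU so that no worker processes need to be spawned:

# Single-GPU training avoids the ddp_spawn pickling problem in notebooks
model = GraphNets()
trainer = pl.Trainer(max_epochs=200, gpus=1, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader, valid_loader)

# Evaluate the checkpoint with the best validation loss on the test set
best_model = GraphNets.load_from_checkpoint(checkpoint_callback.best_model_path)
out_best_test = trainer.test(best_model, test_loader)[0]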