"""Base Config object for use with Molecule Graph Construction."""
# Graphein
# Author: Arian Jamasb <arian@jamasb.io>
# License: MIT
# Project Website: https://github.com/a-r-j/graphein
# Code Repository: https://github.com/a-r-j/graphein
from __future__ import annotations
from functools import partial
from pathlib import Path
from typing import Any, Callable, List, Optional, Union
from deepdiff import DeepDiff
from pydantic import BaseModel
from typing_extensions import Literal
from graphein.molecule.edges.atomic import add_atom_bonds
from graphein.molecule.edges.distance import (
add_distance_threshold,
add_fully_connected_edges,
add_k_nn_edges,
)
from graphein.molecule.features.nodes.atom_type import atom_type_one_hot
from graphein.utils.config import PartialMatchOperator, PathMatchOperator
# Element symbols accepted as atom (node) types in a molecule graph.
GraphAtoms = Literal[
    "C", "H", "O", "N", "F", "P", "S",
    "Cl", "Br", "I", "B",
]
"""Allowable atom types for nodes in the graph."""
class MoleculeGraphConfig(BaseModel):
    """
    Config Object for Molecule Structure Graph Construction.

    :param verbose: Specifies verbosity of graph creation process.
        Default: ``False``.
    :type verbose: bool
    :param add_hs: Specifies whether hydrogens should be added to the graph.
        Default: ``False``.
    :type add_hs: bool
    :param edge_construction_functions: List of functions (or strings) that
        take an ``nx.Graph`` and return an ``nx.Graph`` with desired edges
        added. Prepared edge constructions can be found in
        :ref:`graphein.molecule.edges`. Defaults to
        ``[add_fully_connected_edges, add_k_nn_edges, add_distance_threshold,
        add_atom_bonds]``.
    :type edge_construction_functions: List[Union[Callable, str]]
    :param node_metadata_functions: List of functions (or strings) used to
        annotate graph nodes with features and metadata. Defaults to
        ``[atom_type_one_hot]``.
    :type node_metadata_functions: List[Union[Callable, str]], optional
    :param edge_metadata_functions: List of functions (or strings) used to
        annotate graph edges with features and metadata. Defaults to ``None``.
    :type edge_metadata_functions: List[Union[Callable, str]], optional
    :param graph_metadata_functions: List of functions that take an
        ``nx.Graph`` and return an ``nx.Graph`` with added graph-level
        features and metadata. Defaults to ``None``.
    :type graph_metadata_functions: List[Callable], optional
    """

    verbose: bool = False
    add_hs: bool = False

    # Graph construction functions
    edge_construction_functions: List[Union[Callable, str]] = [
        add_fully_connected_edges,
        add_k_nn_edges,
        add_distance_threshold,
        add_atom_bonds,
    ]
    node_metadata_functions: Optional[List[Union[Callable, str]]] = [
        atom_type_one_hot
    ]
    edge_metadata_functions: Optional[List[Union[Callable, str]]] = None
    # Unlike the lists above, this annotation admits only callables.
    graph_metadata_functions: Optional[List[Callable]] = None

    def __eq__(self, other: Any) -> bool:
        """Overwrites the BaseModel ``__eq__`` to handle harder-to-compare values.

        Two ``MoleculeGraphConfig`` instances are compared with ``DeepDiff``
        using the custom operators from ``graphein.utils.config`` for
        ``functools.partial`` and ``pathlib.Path`` values. This lets, for
        example, equivalent ``partial`` objects compare equal; plain ``==`` on
        two separately created partials falls back to identity. Any other
        object is compared against ``self.dict()``.

        :param other: Object to compare against.
        :type other: Any
        :return: ``True`` if ``DeepDiff`` reports no differences (config vs.
            config), otherwise the result of ``self.dict() == other``.
        :rtype: bool
        """
        if isinstance(other, MoleculeGraphConfig):
            return (
                DeepDiff(
                    self,
                    other,
                    custom_operators=[
                        PartialMatchOperator(types=[partial]),
                        PathMatchOperator(types=[Path]),
                    ],
                )
                == {}
            )
        return self.dict() == other