| | """Utilities for converting Graphein Networks to Geometric Deep Learning formats. |
| | """ |
| | |
| | |
| | |
| | |
| | |
| | |
| | from __future__ import annotations |
| |
|
| | from typing import List, Optional |
| |
|
| | import networkx as nx |
| | import numpy as np |
| | import torch |
| |
|
| | try: |
| | from graphein.utils.dependencies import import_message |
| | except ImportError: |
| | raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html') |
| |
|
| | try: |
| | import torch_geometric |
| | from torch_geometric.data import Data |
| | except ImportError: |
| | import_message( |
| | submodule="graphein.ml.conversion", |
| | package="torch_geometric", |
| | pip_install=True, |
| | conda_channel="rusty1s", |
| | ) |
| |
|
| | try: |
| | import dgl |
| | except ImportError: |
| | import_message( |
| | submodule="graphein.ml.conversion", |
| | package="dgl", |
| | pip_install=True, |
| | conda_channel="dglteam", |
| | ) |
| |
|
| | try: |
| | import jax.numpy as jnp |
| | except ImportError: |
| | import_message( |
| | submodule="graphein.ml.conversion", |
| | package="jax", |
| | pip_install=True, |
| | conda_channel="conda-forge", |
| | ) |
| | try: |
| | import jraph |
| | except ImportError: |
| | import_message( |
| | submodule="graphein.ml.conversion", |
| | package="jraph", |
| | pip_install=True, |
| | conda_channel="conda-forge", |
| | ) |
| |
|
| |
|
# Keys accepted for ``src_format`` / ``dst_format`` by ``GraphFormatConvertor``.
# The keys are case-sensitive and all lowercase.
SUPPORTED_FORMATS = ["nx", "pyg", "dgl", "jraph"]
"""Supported conversion formats.

``"nx"``: NetworkX graph

``"pyg"``: PyTorch Geometric Data object

``"dgl"``: DGL graph

``"jraph"``: Jraph GraphsTuple
"""

# Keys accepted for ``verbose`` by ``GraphFormatConvertor``; each selects a
# preset list of feature columns to keep during conversion.
SUPPORTED_VERBOSITY = ["gnn", "default", "all_info"]
"""Supported verbosity levels for preserving graph features in conversion."""
| |
|
| |
|
class GraphFormatConvertor:
    """
    Provides conversion utilities between NetworkX Graphs and geometric deep learning library destination formats.
    Currently, we provide support for conversion from ``nx.Graph`` to ``dgl.DGLGraph``,
    ``torch_geometric.data.Data`` and ``jraph.GraphsTuple``, and from ``dgl.DGLGraph`` /
    ``torch_geometric.data.Data`` back to ``nx.Graph``. Supported conversion
    formats can be retrieved from :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`.

    Conversion always goes through NetworkX: the source is first converted with
    ``convert_<src_format>_to_nx`` and the result with ``convert_nx_to_<dst_format>``.
    No ``convert_jraph_to_nx`` method is defined, so ``"jraph"`` passes validation as a
    source format but calling the convertor with it raises ``AttributeError``.

    :param src_format: The type of graph you'd like to convert from. Supported formats are available in :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`
    :type src_format: Literal["nx", "pyg", "dgl", "jraph"]
    :param dst_format: The type of graph format you'd like to convert to. Supported formats are available in:
        ``graphein.ml.conversion.SUPPORTED_FORMATS``
    :type dst_format: Literal["nx", "pyg", "dgl", "jraph"]
    :param verbose: Select from ``"gnn"``, ``"default"``, ``"all_info"`` to determine how much information is preserved (features)
        as some are unsupported by various downstream frameworks. Ignored when ``columns`` is given.
    :type verbose: Literal["gnn", "default", "all_info"]
    :param columns: List of columns in the node features to retain
    :type columns: List[str], optional
    :raises ValueError: If either format is unsupported, or if ``columns`` is ``None`` and ``verbose`` is not a supported level.
    """

    # Preset feature columns retained for each verbosity level.
    _VERBOSITY_COLUMNS = {
        "gnn": (
            "edge_index",
            "coords",
            "dist_mat",
            "name",
            "node_id",
        ),
        "default": (
            "b_factor",
            "chain_id",
            "coords",
            "dist_mat",
            "edge_index",
            "kind",
            "name",
            "node_id",
            "residue_name",
        ),
        "all_info": (
            "atom_type",
            "b_factor",
            "chain_id",
            "chain_ids",
            "config",
            "coords",
            "dist_mat",
            "edge_index",
            "element_symbol",
            "kind",
            "name",
            "node_id",
            "node_type",
            "pdb_df",
            "raw_pdb_df",
            "residue_name",
            "residue_number",
            "rgroup_df",
            "sequence_A",
            "sequence_B",
        ),
    }

    def __init__(
        self,
        src_format: str,
        dst_format: str,
        verbose: str = "gnn",
        columns: Optional[List[str]] = None,
    ):
        if (src_format not in SUPPORTED_FORMATS) or (
            dst_format not in SUPPORTED_FORMATS
        ):
            raise ValueError(
                "Please specify from supported format, "
                + "/".join(SUPPORTED_FORMATS)
            )
        self.src_format = src_format
        self.dst_format = dst_format

        # The verbosity level only matters when no explicit columns are given.
        if (columns is None) and (verbose not in SUPPORTED_VERBOSITY):
            raise ValueError(
                "Please specify the supported verbose mode ("
                + "/".join(SUPPORTED_VERBOSITY)
                + ") or specify column names!"
            )

        if columns is None:
            # Copy so that each instance owns an independent, mutable list.
            columns = list(self._VERBOSITY_COLUMNS[verbose])
        self.columns = columns

        # Declared value type per known feature; used by the DGL conversion
        # to decide which features can be stored as tensors.
        self.type2form = {
            "atom_type": "str",
            "b_factor": "float",
            "chain_id": "str",
            "coords": "np.array",
            "dist_mat": "np.array",
            "element_symbol": "str",
            "node_id": "str",
            "residue_name": "str",
            "residue_number": "int",
            "edge_index": "torch.tensor",
            "kind": "str",
        }

    def _gather_features(self, attr_dicts, flatten: bool = False) -> dict:
        """Collect the retained features from an iterable of attribute dicts.

        :param attr_dicts: Iterable of per-node or per-edge attribute dictionaries.
        :param flatten: If ``True``, iterable values are unpacked into the collected
            list (scalars are appended as-is); if ``False`` every value is appended whole.
        :type flatten: bool
        :return: Mapping of feature name (as ``str``) to the list of collected values,
            restricted to names present in ``self.columns``.
        :rtype: dict
        """
        gathered = {}
        for attrs in attr_dicts:
            for key, value in attrs.items():
                name = str(key)
                if name not in self.columns:
                    continue
                bucket = gathered.setdefault(name, [])
                if not flatten:
                    bucket.append(value)
                    continue
                try:
                    items = list(value)
                except TypeError:
                    # Non-iterable (scalar) value: keep it as a single entry.
                    items = [value]
                bucket.extend(items)
        return gathered

    def _graph_features(self, G: nx.Graph) -> dict:
        """Return the retained graph-level attributes, each wrapped in a one-element list."""
        return {
            str(feat_name): [G.graph[feat_name]]
            for feat_name in G.graph
            if str(feat_name) in self.columns
        }

    @staticmethod
    def _edge_index(G: nx.Graph) -> torch.Tensor:
        """Return the edges of ``G`` as a ``(2, n_edges)`` ``LongTensor``.

        An edgeless graph yields an explicit ``(2, 0)`` tensor so that callers can
        always index rows ``0`` and ``1``.
        """
        edges = list(G.edges())
        if not edges:
            return torch.empty((2, 0), dtype=torch.long)
        return torch.LongTensor(edges).t().contiguous()

    @staticmethod
    def _as_tensor_or_list(values: list):
        """Return ``torch.tensor(values)`` if the values are convertible, else ``values`` unchanged."""
        try:
            return torch.tensor(values)
        except (TypeError, ValueError, RuntimeError):
            # e.g. strings or ragged / non-numeric objects: keep the plain list.
            return values

    def _numeric_tensors(self, features: dict) -> dict:
        """Convert the features declared ``float``/``int`` in ``type2form`` to tensors.

        Features with any other (or no) declared type are left out.
        """
        return {
            name: torch.Tensor(np.array(values))
            for name, values in features.items()
            if self.type2form.get(name) in ("float", "int")
        }

    def convert_nx_to_dgl(self, G: nx.Graph) -> dgl.DGLGraph:
        """
        Converts ``NetworkX`` graph to ``DGL``

        Only tensor-convertible features are transferred: ``coords``, ``dist_mat`` and
        features declared as ``float``/``int`` in ``self.type2form``. Other node/edge
        features and all graph-level attributes are not attached to the returned graph.

        :param G: ``nx.Graph`` to convert to ``DGLGraph``
        :type G: nx.Graph
        :return: ``DGLGraph`` object version of input ``NetworkX`` graph
        :rtype: dgl.DGLGraph
        """
        g = dgl.DGLGraph()
        n_nodes = G.number_of_nodes()
        G = nx.convert_node_labels_to_integers(G)

        # Node features.
        node_feats = self._gather_features(
            attrs for _, attrs in G.nodes(data=True)
        )
        node_data = {}
        for name, values in node_feats.items():
            if name == "coords":
                node_data[name] = torch.Tensor(np.asarray(values)).type(
                    "torch.FloatTensor"
                )
            elif name == "dist_mat":
                # Only the first entry is used; ``.values`` implies it is
                # expected to be a DataFrame-like object.
                node_data[name] = torch.Tensor(
                    np.asarray(values[0].values)
                ).type("torch.FloatTensor")
            elif self.type2form.get(name) in ("float", "int"):
                node_data[name] = torch.Tensor(np.array(values))
        g.add_nodes(n_nodes, node_data)

        # Edge features (built from the edge attributes, not the node ones).
        edge_index = self._edge_index(G)
        edge_feats = self._gather_features(
            (attrs for _, _, attrs in G.edges(data=True)), flatten=True
        )
        g.add_edges(
            edge_index[0], edge_index[1], self._numeric_tensors(edge_feats)
        )

        return g

    def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
        """
        Converts ``NetworkX`` graph to ``pytorch_geometric.data.Data`` object. Requires ``PyTorch Geometric`` (https://pytorch-geometric.readthedocs.io/en/latest/) to be installed.

        Node, edge and graph attributes are copied only if their name is in
        ``self.columns``; edge attribute values are flattened into one list per name.

        :param G: ``nx.Graph`` to convert to PyTorch Geometric ``Data`` object
        :type G: nx.Graph
        :return: ``Data`` object containing networkx graph data
        :rtype: pytorch_geometric.data.Data
        """
        # Original node labels are kept before relabelling to integers.
        data = {"node_id": list(G.nodes())}
        G = nx.convert_node_labels_to_integers(G)

        edge_index = self._edge_index(G)

        # Later sources overwrite earlier ones on a name clash
        # (node -> edge -> graph), as plain dict assignment would.
        data.update(
            self._gather_features(attrs for _, attrs in G.nodes(data=True))
        )
        data.update(
            self._gather_features(
                (attrs for _, _, attrs in G.edges(data=True)), flatten=True
            )
        )
        data.update(self._graph_features(G))

        if "edge_index" in self.columns:
            data["edge_index"] = edge_index

        data = Data.from_dict(data)
        data.num_nodes = G.number_of_nodes()
        return data

    @staticmethod
    def convert_nx_to_nx(G: nx.Graph) -> nx.Graph:
        """
        Converts NetworkX graph (``nx.Graph``) to NetworkX graph (``nx.Graph``) object. Redundant - returns itself.

        :param G: NetworkX Graph
        :type G: nx.Graph
        :return: NetworkX Graph
        :rtype: nx.Graph
        """
        return G

    @staticmethod
    def convert_dgl_to_nx(G: dgl.DGLGraph) -> nx.Graph:
        """
        Converts a DGL Graph (``dgl.DGLGraph``) to a NetworkX (``nx.Graph``) object. Preserves node and edge attributes.

        :param G: ``dgl.DGLGraph`` to convert to ``NetworkX`` graph.
        :type G: dgl.DGLGraph
        :return: NetworkX graph object.
        :rtype: nx.Graph
        """
        node_attrs = G.node_attr_schemes().keys()
        edge_attrs = G.edge_attr_schemes().keys()
        return dgl.to_networkx(G, node_attrs, edge_attrs)

    @staticmethod
    def convert_pyg_to_nx(G: Data) -> nx.Graph:
        """Converts PyTorch Geometric ``Data`` object to NetworkX graph (``nx.Graph``).

        :param G: Pytorch Geometric Data.
        :type G: torch_geometric.data.Data
        :returns: NetworkX graph.
        :rtype: nx.Graph
        """
        return torch_geometric.utils.to_networkx(G)

    def convert_nx_to_jraph(self, G: nx.Graph) -> jraph.GraphsTuple:
        """Converts NetworkX graph (``nx.Graph``) to Jraph GraphsTuple graph. Requires ``jax`` and ``Jraph``.

        Node features are collected over all nodes first and converted with
        ``torch.tensor`` once per feature; features that cannot be converted are
        kept as plain lists. Edge features are flattened lists and graph
        attributes are one-element lists.

        :param G: Networkx graph to convert.
        :type G: nx.Graph
        :return: Jraph GraphsTuple graph.
        :rtype: jraph.GraphsTuple
        """
        G = nx.convert_node_labels_to_integers(G)

        n_node = len(G)
        n_edge = G.number_of_edges()
        edge_list = list(G.edges())
        if edge_list:
            senders, receivers = (
                jnp.array(endpoints) for endpoints in zip(*edge_list)
            )
        else:
            # ``zip(*[])`` yields nothing to unpack; use empty integer arrays.
            senders = receivers = jnp.array([], dtype=int)

        # Convert after collection: concatenating a list onto an
        # already-converted tensor is not a list append.
        node_features = {
            name: self._as_tensor_or_list(values)
            for name, values in self._gather_features(
                attrs for _, attrs in G.nodes(data=True)
            ).items()
        }

        edge_features = self._gather_features(
            (attrs for _, _, attrs in G.edges(data=True)), flatten=True
        )

        global_context = self._graph_features(G)

        return jraph.GraphsTuple(
            nodes=node_features,
            senders=senders,
            receivers=receivers,
            edges=edge_features,
            n_node=n_node,
            n_edge=n_edge,
            globals=global_context,
        )

    def __call__(self, G: nx.Graph):
        """Convert ``G`` from ``self.src_format`` to ``self.dst_format`` via NetworkX.

        :param G: Graph in the source format.
        :return: Graph in the destination format.
        :raises AttributeError: If no ``convert_<src>_to_nx`` or ``convert_nx_to_<dst>``
            method exists for the configured formats.
        """
        # Look the converters up by name instead of evaluating a built string,
        # so the format attributes can never be executed as code.
        to_nx = getattr(self, f"convert_{self.src_format}_to_nx")
        from_nx = getattr(self, f"convert_nx_to_{self.dst_format}")
        return from_nx(to_nx(G))
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
def convert_nx_to_pyg_data(G: nx.Graph) -> Data:
    """Convert a NetworkX graph to a PyTorch Geometric ``Data`` object, keeping every attribute.

    No column filtering is applied: all node, edge and graph attributes are copied.
    Node attributes become one list per attribute name. Edge attributes become
    one entry per edge: the ``distance`` attribute is stored as-is, any other
    iterable value is converted to a list, and non-iterable values are kept unchanged.
    Graph attributes are wrapped in a one-element list. On a name clash later
    sources overwrite earlier ones (node -> edge -> graph -> ``edge_index``).

    :param G: NetworkX graph to convert.
    :type G: nx.Graph
    :return: ``Data`` object with ``node_id`` (original node labels), the copied
        attributes, ``edge_index`` of shape ``(2, n_edges)`` and ``num_nodes``.
    :rtype: torch_geometric.data.Data
    """
    # Original node labels are kept before relabelling to integers.
    data = {"node_id": list(G.nodes())}
    G = nx.convert_node_labels_to_integers(G)

    edges = list(G.edges())
    if edges:
        edge_index = torch.LongTensor(edges).t().contiguous()
    else:
        # An explicit shape avoids the ambiguous ``view(2, -1)`` on 0 elements.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Node features: one value per node that carries the attribute.
    node_features = {}
    for _, feat_dict in G.nodes(data=True):
        for key, value in feat_dict.items():
            node_features.setdefault(str(key), []).append(value)
    data.update(node_features)

    # Edge features: one entry per edge that carries the attribute.
    edge_features = {}
    for _, _, feat_dict in G.edges(data=True):
        for key, value in feat_dict.items():
            if key != "distance":
                try:
                    value = list(value)
                except TypeError:
                    # Scalar attribute: keep the raw value.
                    pass
            edge_features.setdefault(str(key), []).append(value)
    data.update(edge_features)

    # Graph-level features.
    for feat_name in G.graph:
        data[str(feat_name)] = [G.graph[feat_name]]

    data["edge_index"] = edge_index.view(2, edge_index.shape[-1])
    data = Data.from_dict(data)
    data.num_nodes = G.number_of_nodes()

    return data
| |
|