diff --git a/examples/train-gnn-dihedrals/.gitignore b/examples/train-gnn-dihedrals/.gitignore new file mode 100644 index 00000000..b8f2783d --- /dev/null +++ b/examples/train-gnn-dihedrals/.gitignore @@ -0,0 +1,3 @@ +lightning_logs +tmp.pkl +api.qcarchive.molssi.org_443 \ No newline at end of file diff --git a/examples/train-gnn-dihedrals/train-gnn-central-bond.ipynb b/examples/train-gnn-dihedrals/train-gnn-central-bond.ipynb new file mode 100644 index 00000000..d18e72e3 --- /dev/null +++ b/examples/train-gnn-dihedrals/train-gnn-central-bond.ipynb @@ -0,0 +1,7194 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "94ff1817-f641-4a11-ae4a-91c7a6c73a7e", + "metadata": {}, + "source": [ + "# Train a GNN directly to predict torsion energies\n", + "\n", + "**Note: this is an *experimental* notebook only to demonstrate a proof-of-concept.\n", + "While some parts of this notebook may eventually be fully supported by OpenFF-NAGL, the general conclusion arrived at is that it is likely easiest to work on this solution outside of the NAGL framework.**\n", + "\n", + "To execute this example fully, the following packages are required.\n", + "\n", + "* openff-nagl\n", + "* openff-recharge\n", + "* openff-qcsubmit\n", + "* psi4\n", + "\n", + "However, if you wish to just follow along with the training part without first creating the training datasets yourself, you can get away with just `openff-nagl` installed and simply load the training/validation data from the provided `.parquet` files. The commands are provided at the end of the \"Generate and format training data\" section, but commented out." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "522bc606-c1aa-4cc7-89da-87440c61f8ba", + "metadata": {}, + "outputs": [], + "source": [ + "import collections\n", + "import tqdm\n", + "\n", + "from qcportal import PortalClient\n", + "from openff.units import unit\n", + "\n", + "from openff.toolkit import Molecule, ForceField\n", + "from openff.qcsubmit.results import BasicResultCollection\n", + "from openff.recharge.esp.storage import MoleculeESPRecord\n", + "from openff.recharge.esp.qcresults import from_qcportal_results\n", + "from openff.recharge.grids import MSKGridSettings\n", + "from openff.recharge.utilities.geometry import compute_vector_field\n", + "\n", + "import pyarrow as pa\n", + "import pyarrow.parquet as pq\n", + "import numpy as np\n", + "\n", + "import torch\n", + "import MDAnalysis as mda" + ] + }, + { + "cell_type": "markdown", + "id": "097895ad-685b-4825-8c22-1a0f79467e9d", + "metadata": {}, + "source": [ + "## Generate and format training data\n", + "\n", + "\n", + "### Downloading from QCArchive\n", + "First, we will create training data. We'll download a smaller training set for the purposes of this example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "964318ce-9e08-4551-9af4-dc49ea89d027", + "metadata": {}, + "outputs": [], + "source": [ + "# qc_client = PortalClient(\"https://api.qcarchive.molssi.org:443\", cache_dir=\".\")\n", + "\n", + "# # download dataset from QCPortal\n", + "# br_esps_collection = BasicResultCollection.from_server(\n", + "# client=qc_client,\n", + "# datasets=\"OpenFF multi-Br ESP Fragment Conformers v1.1\",\n", + "# spec_name=\"HF/6-31G*\",\n", + "# )\n", + "\n", + "# records_and_molecules = br_esps_collection.to_records()" + ] + }, + { + "cell_type": "markdown", + "id": "9dc86d0a-c59e-4f8f-93c7-3b74adb6636e", + "metadata": {}, + "source": [ + "### Convert to PyArrow dataset\n", + "\n", + "NAGL reads in and trains to data from [PyArrow tables](https://arrow.apache.org/docs/python/getstarted.html#creating-arrays-and-tables). Below we create some easy data by using an existing force field to assign torsion energies to each individual torsion, which then gets summed over the central bond.\n", + "\n", + "*Note: the maths could probably use double-checking.*" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a0690063-59fa-4863-9ce5-85a7f511d20d", + "metadata": {}, + "outputs": [], + "source": [ + "def calc_torsion_energy(angle, parameter):\n", + " angle = (angle * unit.degrees).m_as(unit.radians)\n", + " total = 0 * unit.kilojoules_per_mole\n", + " for k, phase, periodicity in zip(parameter.k, parameter.phase, parameter.periodicity):\n", + " phase = phase.m_as(unit.radians)\n", + " subtotal = k * (1 + np.cos(periodicity * angle - phase))\n", + " total += subtotal\n", + " return total\n", + "\n", + "\n", + "def get_central_bond_torsions(mol, forcefield):\n", + " labels = forcefield.label_molecules(mol.to_topology())[0][\"ProperTorsions\"]\n", + " u = mda.Universe(mol.to_rdkit())\n", + "\n", + " bonds_to_dihedrals = collections.defaultdict(list)\n", + " for key in labels:\n", + " center = 
tuple(sorted(key[1:3]))\n", + " bonds_to_dihedrals[center].append(key)\n", + "\n", + " # sort bonds...\n", + " energies = []\n", + " keys = sorted([\n", + " tuple(sorted([bond.atom1_index, bond.atom2_index]))\n", + " for bond in mol.bonds\n", + " ])\n", + " for key in keys:\n", + " dihedrals = bonds_to_dihedrals[key]\n", + " \n", + " energy = 0 * unit.kilojoules_per_mole\n", + " for dihedral_indices in dihedrals:\n", + " dihedral_parameter = labels[dihedral_indices]\n", + " angle = u.atoms[list(dihedral_indices)].dihedral.value()\n", + " energy += calc_torsion_energy(angle, dihedral_parameter)\n", + "\n", + " energies.append(energy.m_as(unit.kilojoules_per_mole))\n", + " return energies" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "70634680-8249-49f6-ad24-bc5178851360", + "metadata": {}, + "outputs": [], + "source": [ + "pyarrow_entries = []\n", + "forcefield = ForceField(\"openff-2.1.0.offxml\")\n", + "# for record, molecule in tqdm.tqdm(records_and_molecules):\n", + "# central_torsion_energies = get_central_bond_torsions(molecule, forcefield)\n", + "# entry = {\n", + "# # hopefully this preserves bond order\n", + "# \"mapped_smiles\": molecule.to_smiles(mapped=True),\n", + "# \"torsion_energies\": central_torsion_energies,\n", + "# \"conformer\": molecule.conformers[0].m_as(unit.angstrom).flatten().tolist(),\n", + "# }\n", + "# pyarrow_entries.append(entry)\n", + "\n", + "\n", + "# # arbitrarily split into training and validation datasets\n", + "# training_pyarrow_entries = pyarrow_entries[:-10]\n", + "# validation_pyarrow_entries = pyarrow_entries[-10:]\n", + "\n", + "# training_table = pa.Table.from_pylist(training_pyarrow_entries)\n", + "# validation_table = pa.Table.from_pylist(validation_pyarrow_entries)\n", + "# training_table" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "72a9cdce-5411-4fe3-b2ea-46fb63e09d2c", + "metadata": {}, + "outputs": [], + "source": [ + "# pq.write_table(training_table, 
\"training_dataset_table.parquet\")\n", + "# pq.write_table(validation_table, \"validation_dataset_table.parquet\")\n", + "\n", + "# to read back in -- note, the files saved here give the full dataset, not the 50 record subset\n", + "training_table = pq.read_table(\"training_dataset_table.parquet\")\n", + "validation_table = pq.read_table(\"validation_dataset_table.parquet\")" + ] + }, + { + "cell_type": "markdown", + "id": "4682e8cb-fba8-4cec-97b0-4d996e24ecf9", + "metadata": {}, + "source": [ + "## Set up for training a GNN" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3d3c909f-c1b9-44e0-86e0-d4ccf6ceb996", + "metadata": {}, + "outputs": [], + "source": [ + "from openff.nagl.config import (\n", + " TrainingConfig,\n", + " OptimizerConfig,\n", + " ModelConfig,\n", + " DataConfig\n", + ")\n", + "from openff.nagl.config.model import (\n", + " ConvolutionModule, ReadoutModule,\n", + " ConvolutionLayer, ForwardLayer,\n", + ")\n", + "from openff.nagl.config.data import DatasetConfig\n", + "from openff.nagl.training.training import TrainingGNNModel\n", + "from openff.nagl.features.atoms import (\n", + " AtomicElement,\n", + " AtomConnectivity,\n", + " AtomInRingOfSize,\n", + " AtomAverageFormalCharge,\n", + ")\n", + "\n", + "from openff.nagl.training.loss import GeneralLinearFitTarget" + ] + }, + { + "cell_type": "markdown", + "id": "88be032d-0101-479e-a468-2c5db34600fc", + "metadata": {}, + "source": [ + "### Defining the training config\n", + "\n", + "#### Defining a ModelConfig\n", + "\n", + "First we define a ModelConfig. 
This is done in Python so we can define custom PostprocessLayers to compute c_ij coefficients, and a custom bond feature pooling layer that takes alpha as an input.\n", + "\n", + "Caveats:\n", + "- both these custom layers use *new features* implemented in this branch which are currently unsupported by OpenFF NAGL proper.\n", + "- everything assumes that bonds/angles/torsions etc are properly sorted (which is how it's implemented in the branch). Anything else would require more accounting\n", + "- The current implementation in NAGL doesn't allow for multiple molecules at the moment.\n", + "\n", + "*Also note, again the maths below could probably use double-checking.*" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ee293637-0653-4f17-9a23-784bda10132c", + "metadata": {}, + "outputs": [], + "source": [ + "from openff.nagl.nn.postprocess import PostprocessLayer\n", + "from openff.nagl.nn._pooling import PoolProperTorsionFeatures, PoolBondFeatures\n", + "from collections import defaultdict\n", + "\n", + "class ComputeSCoefficients(PostprocessLayer):\n", + " \"\"\"Computes c_ij\"\"\"\n", + "\n", + " name: str = \"compute_s_coefficients\"\n", + " n_features: int = 1\n", + "\n", + " def __init__(self, pooling_layer = None):\n", + " super().__init__()\n", + " self._pooling_layer = pooling_layer\n", + "\n", + " def forward(\n", + " self,\n", + " molecule,\n", + " inputs: torch.Tensor,\n", + " **kwargs\n", + " ):\n", + " c_ij = inputs[:, 0] # (n_torsions, 1)\n", + " c_ij = torch.flatten(c_ij) # (n_torsions,)\n", + "\n", + " d_ij = self._pooling_layer._calculate_internal_coordinates(molecule)\n", + " s_ij = torch.empty((2, *d_ij.shape), dtype=d_ij.dtype)\n", + " s_ij[0, :] = torch.cos(d_ij)\n", + " s_ij[1, :] = torch.sin(d_ij)\n", + " s_ij = s_ij.T # (n_torsions, 2)\n", + "\n", + " # filter by central bond\n", + " proper_torsion_indices_T = molecule._pooling_representations[\"proper_torsion\"]\n", + "\n", + " bond_indices = defaultdict(list)\n", + " for 
i, atom_2 in enumerate(\n", + " proper_torsion_indices_T[1]\n", + " ):\n", + " atom_3 = proper_torsion_indices_T[2][i]\n", + " bond = tuple(sorted([atom_2.item(), atom_3.item()]))\n", + " bond_indices[bond].append(i)\n", + "\n", + " a_dict = {}\n", + " \n", + " for bond, indices in bond_indices.items():\n", + " s = torch.sum(\n", + " c_ij[indices].reshape((-1, 1)) * s_ij[indices],\n", + " dim=0\n", + " )\n", + " s_norm = torch.norm(s)\n", + " a = torch.arctan2(*(s / s_norm))\n", + " a_dict[bond] = a\n", + "\n", + " # ... set bonds that aren't central bonds in torsions to 0?\n", + " all_bond_indices = molecule._get_bonds()\n", + " for key in all_bond_indices:\n", + " if key not in a_dict:\n", + " a_dict[key] = torch.tensor(0)\n", + "\n", + " # sort\n", + " sorted_keys = sorted(a_dict)\n", + " alphas = torch.empty((len(a_dict),)).flatten()\n", + " for i, key in enumerate(sorted_keys):\n", + " alphas[i] = a_dict[key].item()\n", + "\n", + " # note if molecules are batched, bond indices will be cumulative\n", + " return alphas\n", + "\n", + "\n", + "class InjectablePoolBondFeatures(PoolBondFeatures):\n", + " name = \"injectable_bond\"\n", + " def _get_final_representations(self, molecule, readouts=None, **kwargs):\n", + " representations = self._get_pooled_representations(molecule)\n", + "\n", + " # assume bonds are properly sorted\n", + " alphas = readouts[\"alpha\"].reshape((-1, 1))\n", + "\n", + " representations = [\n", + " torch.cat([representation, alphas], dim=1)\n", + " for representation in representations\n", + " ]\n", + " return representations" + ] + }, + { + "cell_type": "markdown", + "id": "21d605e5-77fa-41cb-8e18-12523b7f5b61", + "metadata": {}, + "source": [ + "Now the normal definition of a model. 
Note this uses the same features and general architecture as the NAGL model used for AM1-BCC partial charges.\n", + "\n", + "The readout modules calculate *two* properties: 1) alpha and 2) energies.\n", + "\n", + "The GNNModel version is '0.2' to be incompatible with what is currently supported in OpenFF NAGL." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f7ddfbde-1349-42ec-a2aa-0f9ead097e99", + "metadata": {}, + "outputs": [], + "source": [ + "atom_features = [\n", + " AtomicElement(categories=[\"H\", \"C\", \"N\", \"O\", \"F\", \"Br\", \"S\", \"P\", \"I\"]),\n", + " AtomConnectivity(categories=[1, 2, 3, 4, 5, 6]),\n", + " AtomInRingOfSize(ring_size=3),\n", + " AtomInRingOfSize(ring_size=4),\n", + " AtomInRingOfSize(ring_size=5),\n", + " AtomInRingOfSize(ring_size=6),\n", + " AtomAverageFormalCharge(),\n", + "]\n", + "\n", + "# define our convolution module\n", + "convolution_module = ConvolutionModule(\n", + " architecture=\"SAGEConv\",\n", + " # construct 6 layers with dropout 0 (default),\n", + " # hidden feature size 512, and ReLU activation function\n", + " # these layers can also be individually specified,\n", + " # but we just duplicate the layer 6 times for identical layers\n", + " layers=[\n", + " ConvolutionLayer(\n", + " hidden_feature_size=512,\n", + " activation_function=\"ReLU\",\n", + " aggregator_type=\"mean\"\n", + " )\n", + " ] * 6,\n", + ")\n", + "\n", + "# define our readout module/s\n", + "# multiple are allowed; here we define one for alpha and one for energies\n", + "readout_modules = {\n", + " # key is the name of output property, any naming is allowed\n", + " \"alpha\": ReadoutModule(\n", + " pooling=\"proper_torsion\",\n", + " postprocess=ComputeSCoefficients(),\n", + " # 2 layers\n", + " layers=[\n", + " ForwardLayer(\n", + " hidden_feature_size=512,\n", + " activation_function=\"ReLU\",\n", + " )\n", + " ] * 2,\n", + " ),\n", + " \"energies\": ReadoutModule(\n", + " pooling=InjectablePoolBondFeatures,\n", + " layers=[\n", + " 
ForwardLayer(\n", + " hidden_feature_size=512,\n", + " activation_function=\"ReLU\",\n", + " ),\n", + " ForwardLayer(\n", + " hidden_feature_size=512,\n", + " activation_function=\"ReLU\",\n", + " ),\n", + " ForwardLayer(\n", + " hidden_feature_size=1,\n", + " activation_function=\"Identity\",\n", + " )\n", + " ]\n", + " )\n", + "}\n", + "\n", + "# bring it all together\n", + "model_config = ModelConfig(\n", + " version=\"0.2\",\n", + " atom_features=atom_features,\n", + " convolution=convolution_module,\n", + " readouts=readout_modules,\n", + " include_xyz=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a68f7b57-a1f7-43f4-ba2d-9579dd9562db", + "metadata": {}, + "source": [ + "#### Defining a DataConfig\n", + "\n", + "We can then define our dataset configs. Here we also have to specify our training targets." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "bf524fad-42e5-4020-a9d9-90a56343b7a1", + "metadata": {}, + "outputs": [], + "source": [ + "from openff.nagl.training.loss import ReadoutTarget\n", + "\n", + "target = ReadoutTarget(\n", + " # what we're using to evaluate loss\n", + " target_label=\"torsion_energies\",\n", + " # the output of the GNN we use to evaluate loss\n", + " prediction_label=\"energies\",\n", + " # how we want to evaluate loss, e.g. 
RMSE, MSE, ...\n", + " metric=\"rmse\",\n", + " # how much to weight this target\n", + " # helps with scaling in multi-target optimizations\n", + " weight=1,\n", + " denominator=1,\n", + ")\n", + "\n", + "training_to_torsions = DatasetConfig(\n", + " sources=[\"training_dataset_table.parquet\"],\n", + " targets=[target],\n", + " batch_size=100,\n", + ")\n", + "validating_to_torsions = DatasetConfig(\n", + " sources=[\"validation_dataset_table.parquet\"],\n", + " targets=[target],\n", + " batch_size=100,\n", + ")\n", + "\n", + "# bringing it together\n", + "data_config = DataConfig(\n", + " training=training_to_torsions,\n", + " validation=validating_to_torsions\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "cfd8bba9-241c-4c67-8f88-a89ae608ec89", + "metadata": {}, + "source": [ + "#### Defining an OptimizerConfig\n", + "\n", + "The optimizer config is relatively simple; the only moving part here currently is the learning rate." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1759ed0f-2e89-4d22-80ae-2c0ba04f37a0", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer_config = OptimizerConfig(optimizer=\"Adam\", learning_rate=0.001)" + ] + }, + { + "cell_type": "markdown", + "id": "7c784871-e834-470e-a346-bde893326fee", + "metadata": {}, + "source": [ + "#### Creating a TrainingConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "f5958818-03ef-4580-9cbf-c588cbb3141f", + "metadata": {}, + "outputs": [], + "source": [ + "training_config = TrainingConfig(\n", + " model=model_config,\n", + " data=data_config,\n", + " optimizer=optimizer_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "17d61258-fe16-4993-9c19-57118c9d0a1a", + "metadata": {}, + "source": [ + "### Creating a TrainingGNNModel\n", + "\n", + "Now we can create a `TrainingGNNModel`, which allows easy training of a `GNNModel`. The `GNNModel` can be accessed through `TrainingGNNModel.model`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3eefcea4-2e9a-4e4a-b3fc-939f3f733fb8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrainingGNNModel(\n", + " (model): GNNModel(\n", + " (convolution_module): ConvolutionModule(\n", + " (gcn_layers): SAGEConvStack(\n", + " (0): SAGEConv(\n", + " (feat_drop): Dropout(p=0.0, inplace=False)\n", + " (activation): ReLU()\n", + " (fc_neigh): Linear(in_features=20, out_features=512, bias=False)\n", + " (fc_self): Linear(in_features=20, out_features=512, bias=True)\n", + " )\n", + " (1-5): 5 x SAGEConv(\n", + " (feat_drop): Dropout(p=0.0, inplace=False)\n", + " (activation): ReLU()\n", + " (fc_neigh): Linear(in_features=512, out_features=512, bias=False)\n", + " (fc_self): Linear(in_features=512, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (readout_modules): ModuleDict(\n", + " (alpha): ReadoutModule(\n", + " (pooling_layer): PoolProperTorsionFeatures(\n", + " (layers): SequentialLayers(\n", + " (0): Linear(in_features=2049, out_features=512, bias=True)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.0, inplace=False)\n", + " (3): Linear(in_features=512, out_features=512, bias=True)\n", + " (4): ReLU()\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): Linear(in_features=512, out_features=1, bias=True)\n", + " (7): Identity()\n", + " (8): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (postprocess_layer): ComputeSCoefficients()\n", + " )\n", + " (energies): ReadoutModule(\n", + " (pooling_layer): InjectablePoolBondFeatures(\n", + " (layers): SequentialLayers(\n", + " (0): Linear(in_features=1025, out_features=512, bias=True)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.0, inplace=False)\n", + " (3): Linear(in_features=512, out_features=512, bias=True)\n", + " (4): ReLU()\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): Linear(in_features=512, out_features=1, bias=True)\n", + " (7): Identity()\n", + " (8): Dropout(p=0.0, 
inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "training_model = TrainingGNNModel(training_config)\n", + "training_model" + ] + }, + { + "cell_type": "markdown", + "id": "d696dae6-93ce-4f0a-ab0f-ebb7e72fa176", + "metadata": {}, + "source": [ + "We can look at the initial capabilities of the model by comparing its energies to reference data." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "89a473ae-cc08-4fac-a355-45d9c5def612", + "metadata": {}, + "outputs": [], + "source": [ + "torsion_layer = training_model.model.readout_modules[\"alpha\"].pooling_layer\n", + "# very hacky assignment to make current model work\n", + "# NAGL doesn't elegantly allow for passing this in during model creation\n", + "training_model.model.readout_modules[\"alpha\"].postprocess_layer._pooling_layer = torsion_layer" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "adb99158-fa97-412d-972b-89f6f563b84a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/lily/pydev/openff-nagl/openff/nagl/utils/_tensors.py:214: UserWarning: Using torch.cross without specifying the dim arg is deprecated.\n", + "Please either pass the dim explicitly or simply use torch.linalg.cross.\n", + "The default value of dim will change to agree with that of linalg.cross in a future release. 
(Triggered internally at /Users/runner/miniforge3/conda-bld/libtorch_1719361031659/work/aten/src/ATen/native/Cross.cpp:66.)\n", + " normal1 = torch.cross(ba, bc)\n", + "/Users/lily/pydev/openff-nagl/openff/nagl/nn/_models.py:240: UserWarning: TODO: currently non-atom-wise properties are not properly handled!!! We just assume they are **strictly sequential**. In general we don't recommend using multiple molecules!!\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "array([-0.19341571, -0.22279456, -0.22279456, -0.22279456, 7.53086858,\n", + " -0.20101197, -0.20101197, -0.20901833, -0.21364443, -0.21364446])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_molecule = Molecule.from_smiles(\"CCCBr\")\n", + "test_molecule.generate_conformers(n_conformers=1)\n", + "reference_energies = get_central_bond_torsions(test_molecule, forcefield)\n", + "\n", + "# switch to eval mode\n", + "training_model.model.eval()\n", + "\n", + "with torch.no_grad():\n", + " energies_1 = training_model.model.compute_properties(\n", + " test_molecule,\n", + " as_numpy=True\n", + " )[\"energies\"]\n", + "\n", + "# switch back to training mode\n", + "training_model.model.train()\n", + "\n", + "# compare energies\n", + "differences = reference_energies - energies_1\n", + "differences" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e6dae76b-d938-4124-a5f3-0a6448d1559e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.030007618485268825, 0, 0, 0, 7.7472088522614175, 0, 0, 0, 0, 0]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reference_energies" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "d79fb091-8190-4eda-8b12-5d0c4a2bb0a9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.22342333, 0.22279456, 0.22279456, 0.22279456, 0.21634027,\n", + " 
0.20101197, 0.20101197, 0.20901833, 0.21364443, 0.21364446])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "energies_1" + ] + }, + { + "cell_type": "markdown", + "id": "79cb1d56-3d70-4307-93f0-0d5956fdb32e", + "metadata": {}, + "source": [ + "### Training the model\n", + "\n", + "We use PyTorch Lightning to train." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "5bfcbce1-8323-4403-a607-89292681375b", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks import TQDMProgressBar" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "c8cb0a08-eea6-41e4-9fd6-158204e26e79", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], + "source": [ + "trainer = pl.Trainer(\n", + " max_epochs=100,\n", + " callbacks=[TQDMProgressBar()], # add progress bar\n", + " accelerator=\"cpu\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "f8846142-77f2-4d60-a74e-c3bc3c3ed3ee", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = training_model.create_data_module(verbose=False)" + ] + }, + { + "cell_type": "markdown", + "id": "cda70524-d4a6-4c0e-a0b2-393c72d61941", + "metadata": {}, + "source": [ + "Currently there is an abundance of warnings about conformer geometries being in angstrom." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "4e0174d7-a57b-49d9-89a3-5c756adc0d8c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Featurizing dataset: 0it [00:00, ?it/s]\n", + "Featurizing batch: 0%| | 0/640 [00:00 List[DGLMolecule]: diff --git a/openff/nagl/molecule/_dgl/molecule.py b/openff/nagl/molecule/_dgl/molecule.py index c22d0db0..96432a84 100644 --- a/openff/nagl/molecule/_dgl/molecule.py +++ b/openff/nagl/molecule/_dgl/molecule.py @@ -7,7 +7,7 @@ from openff.nagl.features.bonds import BondFeature from openff.nagl.molecule._base import NAGLMoleculeBase, MoleculeMixin from .utils import ( - FORWARD, + FORWARD, REVERSE, dgl_heterograph_to_homograph, openff_molecule_to_dgl_graph, ) @@ -23,6 +23,59 @@ def to_homogenous(self): def to(self, device: str): graph = self.graph.to(device) return type(self)(graph) + + def _get_bonds(self) -> list[tuple[int, int]]: + a, b = self.graph.edges(etype=FORWARD) + bonds = list(zip(a.tolist(), b.tolist())) + # sort + bonds = [tuple(sorted(bond)) for bond in bonds] + return sorted(bonds) + + def _get_angles(self) -> list[tuple[int, int, int]]: + angles = set() + bonds = self._get_bonds() + bonds += [(b, a) for a, b in bonds] + for atom1, atom2 in bonds: + atom3s = self.graph.successors(atom2, etype=FORWARD).tolist() + # TODO: necessary? + atom3s += self.graph.predecessors(atom2, etype=FORWARD).tolist() + for atom3 in atom3s: + if atom1 == atom3: + continue + if atom1 < atom3: + angles.add((atom1, atom2, atom3)) + else: + angles.add((atom3, atom2, atom1)) + return sorted(list(angles)) + + def _get_dihedrals(self) -> list[tuple[int, int, int, int]]: + dihedrals = set() + angles = self._get_angles() + angles += [(c, b, a) for a, b, c in angles] + for atom1, atom2, atom3 in angles: + atom4s = self.graph.successors(atom3, etype=FORWARD).tolist() + # TODO: necessary? 
+ atom4s += self.graph.predecessors(atom3, etype=FORWARD).tolist() + for atom4 in atom4s: + if atom2 == atom4: + continue + if atom1 < atom4: + dihedrals.add((atom1, atom2, atom3, atom4)) + else: + dihedrals.add((atom4, atom3, atom2, atom1)) + + atom0s = self.graph.successors(atom1, etype=FORWARD).tolist() + atom0s += self.graph.predecessors(atom1, etype=FORWARD).tolist() + for atom0 in atom0s: + if atom2 == atom0: + continue + if atom3 < atom0: + dihedrals.add((atom3, atom2, atom1, atom0)) + else: + dihedrals.add((atom0, atom1, atom2, atom3)) + return sorted(list(dihedrals)) + + class DGLMolecule(MoleculeMixin, DGLBase): @@ -39,6 +92,29 @@ def n_graph_nodes(self): @property def n_graph_edges(self): return int(self.graph.number_of_edges(FORWARD)) + + @classmethod + def from_openff_config( + cls, + molecule, + model_config, + atom_feature_tensor: Optional[torch.Tensor] = None, + bond_feature_tensor: Optional[torch.Tensor] = None, + model=None, + ): + return cls.from_openff( + molecule, + atom_features=model_config.atom_features, + bond_features=model_config.bond_features, + atom_feature_tensor=atom_feature_tensor, + bond_feature_tensor=bond_feature_tensor, + enumerate_resonance_forms=model_config.enumerate_resonance_forms, + lowest_energy_only=True, + max_path_length=None, + include_all_transfer_pathways=False, + include_xyz=model_config.include_xyz, + model=model, + ) @classmethod @requires_package("dgl") @@ -53,6 +129,8 @@ def from_openff( lowest_energy_only: bool = True, max_path_length: Optional[int] = None, include_all_transfer_pathways: bool = False, + include_xyz: bool = False, + model=None, ): import dgl from openff.nagl.utils.resonance import ResonanceEnumerator @@ -93,6 +171,7 @@ def from_openff( bond_feature_tensor=bond_feature_tensor, forward=cls._graph_forward_edge_type, reverse=cls._graph_backward_edge_type, + include_xyz=include_xyz ) for offmol in offmols ] @@ -107,8 +186,26 @@ def from_openff( mapped_smiles = offmols[0].to_smiles(mapped=True) - 
return cls( + obj = cls( graph=graph, n_representations=len(offmols), mapped_smiles=mapped_smiles ) + # n_atoms = len(offmols[0].atoms) + # if model is not None: + # all_pooling_layers = [ + # readout.pooling_layer + # for readout in model.readout_modules.values() + # ] + # for pooling_layer in all_pooling_layers: + # if pooling_layer.name == "atom": + # continue + # indices = pooling_layer._generate_transposed_pooling_representation( + # molecule + # ) + # all_indices = [] + # for i in range(len(offmols)): + # all_indices.append(indices + (i * n_atoms)) + # indices = torch.cat(all_indices, dim=1) + # obj._pooling_representations[pooling_layer.name] = indices + return obj diff --git a/openff/nagl/molecule/_dgl/utils.py b/openff/nagl/molecule/_dgl/utils.py index fdfb76a6..6fc2b8d2 100644 --- a/openff/nagl/molecule/_dgl/utils.py +++ b/openff/nagl/molecule/_dgl/utils.py @@ -1,7 +1,9 @@ +import warnings from typing import Dict, List, TYPE_CHECKING, Optional import torch import numpy as np +from openff.units import unit from openff.utilities import requires_package from openff.nagl.features.atoms import AtomFeature @@ -14,8 +16,6 @@ import dgl from openff.toolkit.topology.molecule import Molecule - - @requires_package("dgl") def openff_molecule_to_base_dgl_graph( molecule: "Molecule", @@ -54,8 +54,12 @@ def openff_molecule_to_dgl_graph( bond_feature_tensor: Optional[torch.Tensor] = None, forward: str = FORWARD, reverse: str = REVERSE, + include_xyz: bool = False, ) -> "dgl.DGLHeteroGraph": - from openff.nagl.molecule._utils import _get_openff_molecule_information + from openff.nagl.molecule._utils import ( + _get_openff_molecule_information, + _add_xyz_information + ) if len(atom_features) and atom_feature_tensor is not None: raise ValueError( @@ -76,6 +80,11 @@ def openff_molecule_to_dgl_graph( reverse=reverse, ) + # add coordinates + if include_xyz: + _add_xyz_information(molecule, molecule_graph) + + # add atom features if len(atom_features): atom_featurizer = 
AtomFeaturizer(atom_features) diff --git a/openff/nagl/molecule/_graph/_graph.py b/openff/nagl/molecule/_graph/_graph.py index 351b04c7..db1887b2 100644 --- a/openff/nagl/molecule/_graph/_graph.py +++ b/openff/nagl/molecule/_graph/_graph.py @@ -2,6 +2,7 @@ import copy from collections import defaultdict from typing import List, Tuple, TYPE_CHECKING +import warnings from openff.nagl.features.atoms import AtomFeature from openff.nagl.features.bonds import BondFeature @@ -313,13 +314,21 @@ def from_openff( molecule: "Molecule", atom_features: Tuple[AtomFeature, ...] = tuple(), bond_features: Tuple[BondFeature, ...] = tuple(), + include_xyz: bool = False, ): - from openff.nagl.molecule._utils import _get_openff_molecule_information + from openff.nagl.molecule._utils import ( + _get_openff_molecule_information, + _add_xyz_information + ) nx_graph = openff_molecule_to_base_nx_graph(molecule) molecule_graph = cls(nx_graph) + # add coordinates + if include_xyz: + _add_xyz_information(molecule, molecule_graph) + if len(atom_features): atom_featurizer = AtomFeaturizer(atom_features) atom_features = atom_featurizer.featurize(molecule) diff --git a/openff/nagl/molecule/_graph/molecule.py b/openff/nagl/molecule/_graph/molecule.py index fd71d5b1..7e8da157 100644 --- a/openff/nagl/molecule/_graph/molecule.py +++ b/openff/nagl/molecule/_graph/molecule.py @@ -23,6 +23,22 @@ def n_graph_nodes(self): @property def n_graph_edges(self): return int(self.graph.graph.number_of_edges()) + + @classmethod + def from_openff_config( + cls, + molecule, + model_config, + ): + return cls.from_openff( + molecule, + atom_features=model_config.atom_features, + bond_features=model_config.bond_features, + enumerate_resonance_forms=model_config.enumerate_resonance_forms, + lowest_energy_only=True, + max_path_length=None, + include_all_transfer_pathways=False, + ) @classmethod def from_openff( @@ -34,6 +50,7 @@ def from_openff( lowest_energy_only: bool = True, max_path_length: Optional[int] = None, 
include_all_transfer_pathways: bool = False, + include_xyz: bool = False, ): from openff.nagl.utils.resonance import ResonanceEnumerator @@ -51,6 +68,7 @@ def from_openff( offmol, atom_features=atom_features, bond_features=bond_features, + include_xyz=include_xyz ) for offmol in offmols ] diff --git a/openff/nagl/molecule/_utils.py b/openff/nagl/molecule/_utils.py index 0e8ad907..6f5112e2 100644 --- a/openff/nagl/molecule/_utils.py +++ b/openff/nagl/molecule/_utils.py @@ -1,4 +1,6 @@ from typing import Dict, TYPE_CHECKING +import warnings + if TYPE_CHECKING: from openff.toolkit.topology.molecule import Molecule @@ -24,3 +26,20 @@ def _get_openff_molecule_information( "formal_charge": torch.tensor(charges, dtype=torch.int8), "atomic_number": torch.tensor(atomic_numbers, dtype=torch.int8), } + +def _add_xyz_information( + molecule, molecule_graph +): + from openff.units import unit + import torch + + if not molecule.conformers: + raise ValueError("Molecule does not have coordinates.") + if len(molecule.conformers) > 1: + warnings.warn( + "Molecule has multiple conformers. Using the first one." 
+ ) + molecule_graph.ndata["xyz"] = torch.tensor( + molecule.conformers[0].m_as(unit.angstrom), + dtype=torch.float32, + ) diff --git a/openff/nagl/nn/_containers.py b/openff/nagl/nn/_containers.py index 04f1e0da..3ff21489 100644 --- a/openff/nagl/nn/_containers.py +++ b/openff/nagl/nn/_containers.py @@ -1,4 +1,5 @@ import copy +import typing from typing import List, Optional, Union, Tuple, Callable import torch @@ -8,7 +9,7 @@ from openff.nagl.nn.activation import ActivationFunction from openff.nagl.nn.gcn._base import _GCNStackMeta, BaseConvModule from openff.nagl.nn._sequential import SequentialLayers -from openff.nagl.nn._pooling import PoolingLayer, get_pooling_layer +from openff.nagl.nn._pooling import PoolingLayer, get_pooling_layer, get_pooling_layer_type from openff.nagl.nn.postprocess import PostprocessLayer, _PostprocessLayerMeta @@ -111,6 +112,7 @@ def __init__( pooling_layer: PoolingLayer, readout_layers: SequentialLayers, postprocess_layer: Optional[PostprocessLayer] = None, + pooling_kwargs: dict[str, typing.Any] = None, ): """ @@ -127,31 +129,45 @@ def __init__( super().__init__() - self.pooling_layer = get_pooling_layer(pooling_layer) - self.readout_layers = readout_layers + if pooling_kwargs is None: + pooling_kwargs = {} + + self.pooling_layer = get_pooling_layer( + pooling_layer, + layers=readout_layers, + **pooling_kwargs + ) if postprocess_layer is not None: if not isinstance(postprocess_layer, PostprocessLayer): postprocess_layer = _PostprocessLayerMeta._get_object(postprocess_layer) self.postprocess_layer = postprocess_layer - def forward(self, molecule: Union[DGLMolecule, DGLMoleculeBatch]) -> torch.Tensor: - x = self._forward_unpostprocessed(molecule) + @property + def readout_layers(self): + return self.pooling_layer.layers + + def forward( + self, + molecule: Union[DGLMolecule, DGLMoleculeBatch], + **kwargs + ) -> torch.Tensor: + x = self._forward_unpostprocessed(molecule, **kwargs) if self.postprocess_layer is not None: - x = 
self.postprocess_layer.forward(molecule, x) + x = self.postprocess_layer.forward(molecule, x, **kwargs) return x def _forward_unpostprocessed( - self, molecule: Union[DGLMolecule, DGLMoleculeBatch] + self, molecule: Union[DGLMolecule, DGLMoleculeBatch], + **kwargs ) -> torch.Tensor: """ Forward pass without postprocessing the readout modules. This is quality-of-life method for debugging and testing. It is *not* intended for public use. """ - x = self.pooling_layer.forward(molecule) - x = self.readout_layers.forward(x) + x = self.pooling_layer.forward(molecule, **kwargs) return x def copy(self, copy_weights: bool = False): @@ -167,9 +183,14 @@ def copy(self, copy_weights: bool = False): def from_config( cls, readout_config, - n_input_features: int + n_input_features: int = None ): pooling_layer = readout_config.pooling + pooling_layer_type = get_pooling_layer_type(pooling_layer) + pooling_kwargs = {} + for k in ["include_internal_coordinates"]: + pooling_kwargs[k] = getattr(readout_config, k) + hidden_feature_sizes = [ layer.hidden_feature_size for layer in readout_config.layers ] @@ -179,20 +200,24 @@ def from_config( layer_dropout = [ layer.dropout for layer in readout_config.layers ] + postprocess_layer = None if readout_config.postprocess is not None: postprocess_layer = _PostprocessLayerMeta._get_object(readout_config.postprocess) hidden_feature_sizes.append(postprocess_layer.n_features) layer_activation_functions.append(ActivationFunction.Identity) layer_dropout.append(0.0) + n_dense_input_features = pooling_layer_type.get_n_feature_columns(n_input_features) readout_layers = SequentialLayers.with_layers( - n_input_features, + n_dense_input_features, hidden_feature_sizes, layer_activation_functions, layer_dropout, ) + return cls( pooling_layer, readout_layers, - postprocess_layer + postprocess_layer, + pooling_kwargs=pooling_kwargs ) \ No newline at end of file diff --git a/openff/nagl/nn/_dataset.py b/openff/nagl/nn/_dataset.py index 2fe03d3e..3419bbeb 100644 
--- a/openff/nagl/nn/_dataset.py +++ b/openff/nagl/nn/_dataset.py @@ -9,12 +9,14 @@ import pickle import tempfile import typing +import warnings import tqdm import torch from openff.utilities import requires_package from torch.utils.data import Dataset, DataLoader, ConcatDataset +from openff.units import unit from openff.nagl._base.base import ImmutableModel from openff.nagl.config.training import TrainingConfig from openff.nagl.features.atoms import AtomFeature @@ -127,7 +129,9 @@ def from_openff( bond_features: typing.List[BondFeature], atom_feature_tensor: typing.Optional[torch.Tensor] = None, bond_feature_tensor: typing.Optional[torch.Tensor] = None, - + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, + model=None, ): dglmol = DGLMolecule.from_openff( openff_molecule, @@ -135,6 +139,9 @@ def from_openff( bond_features=bond_features, atom_feature_tensor=atom_feature_tensor, bond_feature_tensor=bond_feature_tensor, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms, + model=model, ) labels_ = {} @@ -160,8 +167,11 @@ def from_mapped_smiles( labels: typing.Dict[str, typing.Any], atom_features: typing.List[AtomFeature], bond_features: typing.List[BondFeature], + conformer: typing.Optional[np.ndarray] = None, atom_feature_tensor: typing.Optional[torch.Tensor] = None, bond_feature_tensor: typing.Optional[torch.Tensor] = None, + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, ): """ Create a dataset entry from a mapped SMILES string. 
@@ -198,6 +208,13 @@ def from_mapped_smiles( mapped_smiles, allow_undefined_stereo=True, ) + if conformer is not None: + warnings.warn( + "Conformer geometry is expected to be in angstrom" + ) + molecule._conformers = [ + np.array(conformer).reshape((-1, 3)) * unit.angstrom + ] return cls.from_openff( molecule, labels, @@ -205,6 +222,8 @@ def from_mapped_smiles( bond_features, atom_feature_tensor, bond_feature_tensor, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms, ) @classmethod @@ -214,14 +233,24 @@ def _from_unfeaturized_pyarrow_row( atom_features: typing.List[AtomFeature], bond_features: typing.List[BondFeature], smiles_column: str = "mapped_smiles", + conformer_column: typing.Optional[str] = "conformer", + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, ): labels = dict(row) mapped_smiles = labels.pop(smiles_column) + if conformer_column is not None: + conformer = labels.pop(conformer_column, None) + else: + conformer = None return cls.from_mapped_smiles( mapped_smiles, labels, atom_features, bond_features, + conformer=conformer, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms, ) @classmethod @@ -231,6 +260,8 @@ def _from_featurized_pyarrow_row( atom_feature_column: str, bond_feature_column: str, smiles_column: str = "mapped_smiles", + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, ): from openff.toolkit import Molecule @@ -260,6 +291,8 @@ def _from_featurized_pyarrow_row( bond_features=[], atom_feature_tensor=atom_features, bond_feature_tensor=bond_features, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms ) @@ -320,6 +353,8 @@ def from_arrow_dataset( cache_directory: typing.Optional[pathlib.Path] = None, use_cached_data: bool = True, n_processes: int = 0, + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, ): import pyarrow as pa import pyarrow.dataset as ds @@ -360,6 +395,8 @@ def 
from_arrow_dataset( atom_features=atom_features, bond_features=bond_features, smiles_column=smiles_column, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms ) else: converter = functools.partial( @@ -367,6 +404,8 @@ def from_arrow_dataset( atom_feature_column=atom_feature_column, bond_feature_column=bond_feature_column, smiles_column=smiles_column, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms ) if columns is not None and atom_feature_column not in columns: columns.append(atom_feature_column) @@ -396,12 +435,16 @@ def _pickle_entry_from_unfeaturized_row( atom_features=None, bond_features=None, smiles_column="mapped_smiles", + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, ): entry = DGLMoleculeDatasetEntry._from_unfeaturized_pyarrow_row( row, atom_features=atom_features, bond_features=bond_features, smiles_column=smiles_column, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms, ) f = io.BytesIO() pickle.dump(entry, f) @@ -413,12 +456,16 @@ def _pickle_entry_from_featurized_row( atom_feature_column: str = "atom_features", bond_feature_column: str = "bond_features", smiles_column: str = "mapped_smiles", + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, ): entry = DGLMoleculeDatasetEntry._from_featurized_pyarrow_row( row, atom_feature_column=atom_feature_column, bond_feature_column=bond_feature_column, smiles_column=smiles_column, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms, ) f = io.BytesIO() pickle.dump(entry, f) @@ -454,8 +501,11 @@ def from_arrow_dataset( atom_feature_column: typing.Optional[str] = None, bond_feature_column: typing.Optional[str] = None, smiles_column: str = "mapped_smiles", + conformer_column: str = "conformer", columns: typing.Optional[typing.List[str]] = None, n_processes: int = 0, + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, ): import 
pyarrow.dataset as ds @@ -463,6 +513,8 @@ def from_arrow_dataset( columns = list(columns) if smiles_column not in columns: columns.append(smiles_column) + if conformer_column not in columns: + columns.append(conformer_column) if atom_feature_column is None and bond_feature_column is None: converter = functools.partial( @@ -470,6 +522,8 @@ def from_arrow_dataset( atom_features=atom_features, bond_features=bond_features, smiles_column=smiles_column, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms, ) else: converter = functools.partial( @@ -477,6 +531,8 @@ def from_arrow_dataset( atom_feature_column=atom_feature_column, bond_feature_column=bond_feature_column, smiles_column=smiles_column, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms, ) if columns is not None and atom_feature_column not in columns: columns.append(atom_feature_column) @@ -512,6 +568,9 @@ def from_openff( label_function: typing.Optional[ typing.Callable[["Molecule"], typing.Dict[str, typing.Any]] ] = None, + include_xyz: bool = False, + enumerate_resonance_forms: bool = False, + model=None ): if labels is None: labels = [{} for _ in molecules] @@ -552,6 +611,9 @@ def from_openff( bond_features=bond_features, atom_feature_tensor=atom_tensor, bond_feature_tensor=bond_tensor, + include_xyz=include_xyz, + enumerate_resonance_forms=enumerate_resonance_forms, + model=model ) for molecule, atom_tensor, bond_tensor, label in zip( molecules, atom_feature_tensors, bond_feature_tensors, labels diff --git a/openff/nagl/nn/_models.py b/openff/nagl/nn/_models.py index 5a05de2f..d276cd93 100644 --- a/openff/nagl/nn/_models.py +++ b/openff/nagl/nn/_models.py @@ -36,10 +36,13 @@ def forward( ) -> Dict[str, torch.Tensor]: self.convolution_module(molecule) - readouts: Dict[str, torch.Tensor] = { - readout_type: readout_module(molecule) - for readout_type, readout_module in self.readout_modules.items() - } + readouts: Dict[str, torch.Tensor] = {} + for 
readout_type, readout_module in self.readout_modules.items(): + readouts[readout_type] = readout_module( + molecule, + model=self, + readouts=readouts, + ) return readouts def _forward_unpostprocessed(self, molecule: "DGLMoleculeOrBatch"): @@ -216,29 +219,42 @@ def compute_properties( tensor = np.empty else: tensor = torch.empty - for property_name, value in results[0].items(): - combined_results[property_name] = tensor( - molecule.n_atoms, - dtype=value.dtype - ) + for property_name in results[0].keys(): + n_values = sum(len(result[property_name]) for result in results) + combined_results[property_name] = tensor(n_values) + seen_indices = collections.defaultdict(set) - for result, indices in zip(results, all_indices): + for i, (result, indices) in enumerate(zip(results, all_indices)): for property_name, value in result.items(): - combined_results[property_name][indices] = value - if seen_indices[property_name] & set(indices): - raise ValueError( - "Overlapping indices in the fragments" + j = 0 + if self.readout_modules[property_name].pooling_layer.name == "atom": + combined_results[property_name][indices] = value + if seen_indices[property_name] & set(indices): + raise ValueError( + "Overlapping indices in the fragments" + ) + seen_indices[property_name].update(indices) + else: + warnings.warn( + "TODO: currently non-atom-wise properties " + "are not properly handled!!! " + "We just assume they are **strictly sequential**. " + "In general we don't recommend using multiple molecules!!" 
) - seen_indices[property_name].update(indices) + combined_results[property_name][j : j+ len(value)] = value + j += len(value) + + expected_indices = list(range(molecule.n_atoms)) for property_name, seen_indices in seen_indices.items(): - assert sorted(seen_indices) == expected_indices, ( - f"Missing indices for property {property_name}: " - f"{set(expected_indices) - seen_indices}" - ) + if self.readout_modules[property_name].pooling_layer.name == "atom": + assert sorted(seen_indices) == expected_indices, ( + f"Missing indices for property {property_name}: " + f"{set(expected_indices) - seen_indices}" + ) return combined_results @@ -320,7 +336,8 @@ def _compute_properties( try: values = self._compute_properties_dgl(molecule) - except (MissingOptionalDependencyError, TypeError): + except (MissingOptionalDependencyError, TypeError) as e: + raise e values = self._compute_properties_nagl(molecule) @@ -417,10 +434,9 @@ def compute_property( def _compute_properties_nagl(self, molecule: "Molecule") -> "torch.Tensor": from openff.nagl.molecule._graph.molecule import GraphMolecule - nxmol = GraphMolecule.from_openff( + nxmol = GraphMolecule.from_openff_config( molecule, - atom_features=self.config.atom_features, - bond_features=self.config.bond_features, + self.config, ) model = self if self._is_dgl: @@ -436,10 +452,10 @@ def _compute_properties_dgl(self, molecule: "Molecule") -> "torch.Tensor": "and cannot be used to compute properties with the DGL backend" ) - dglmol = DGLMolecule.from_openff( + dglmol = DGLMolecule.from_openff_config( molecule, - atom_features=self.config.atom_features, - bond_features=self.config.bond_features, + self.config, + model=self, ) return self.forward(dglmol) @@ -448,16 +464,15 @@ def _convert_to_nagl_molecule(self, molecule: "Molecule"): if self._is_dgl: from openff.nagl.molecule._dgl.molecule import DGLMolecule - return DGLMolecule.from_openff( + return DGLMolecule.from_openff_config( molecule, - atom_features=self.config.atom_features, - 
bond_features=self.config.bond_features, + self.config, + model=self ) - return GraphMolecule.from_openff( + return GraphMolecule.from_openff_config( molecule, - atom_features=self.config.atom_features, - bond_features=self.config.bond_features, + self.config, ) @classmethod diff --git a/openff/nagl/nn/_pooling.py b/openff/nagl/nn/_pooling.py index 1be0b51d..279bbe1e 100644 --- a/openff/nagl/nn/_pooling.py +++ b/openff/nagl/nn/_pooling.py @@ -1,32 +1,88 @@ +""" +Pooling layers +============== + +A pooling layer is a layer that takes the output of a graph convolutional layer and +produces a single feature vector for each molecule. This is typically done by +aggregating the node features produced by the graph convolutional layer. + +In NAGL, pooling layers are implemented as subclasses of `PoolingLayer`. +They are invoked at various stages of the model. + + +""" + import abc import functools +import logging from typing import ClassVar, Dict, Union, TYPE_CHECKING, Iterable +import torch import torch.nn from openff.nagl.molecule._dgl import DGLMolecule, DGLMoleculeBatch, DGLMoleculeOrBatch from openff.nagl.nn._sequential import SequentialLayers +# TODO: make toolkit-agnostic +from rdkit.Chem import rdMolTransforms + if TYPE_CHECKING: import dgl + from openff.toolkit import Molecule + + +logger = logging.getLogger(__name__) + + +def _append_internal_coordinate(pooling_layer): + """A decorator to append internal coordinates to the pooling layer.""" + + def wrapper(representations, molecule): + if pooling_layer._include_internal_coordinates: + internal_coordinates = pooling_layer._calculate_internal_coordinates(molecule) + internal_coordinates = internal_coordinates.reshape((-1, 1)) + representations = [ + torch.cat([representation, internal_coordinates], dim=1) + for representation in representations + ] + return representations + + return wrapper + + + class PoolingLayer(torch.nn.Module, abc.ABC): """A convenience class for pooling together node feature vectors produced 
by a graph convolutional layer. """ - n_feature_columns: ClassVar[int] = 0 + def __init__( + self, + layers: SequentialLayers = None, + pooling_function: callable = torch.add, + ): + super().__init__() + self.layers = layers + self._pooling_function = pooling_function - @abc.abstractmethod - def forward(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor: + def forward(self, molecule: DGLMoleculeOrBatch, **kwargs) -> torch.Tensor: """Returns the pooled feature vector.""" - + representations = self._get_final_pooled_representations(molecule, **kwargs) + # apply layers + forwarded = [self.layers(h) for h in representations] + return self._pooling_function(*forwarded) @abc.abstractmethod def get_nvalues_per_molecule(self, molecule: DGLMoleculeOrBatch) -> Iterable[int]: """Returns the number of values per molecule.""" + @classmethod + @abc.abstractmethod + def get_n_feature_columns(cls, n_input_features: int) -> int: + raise NotImplementedError + class PoolAtomFeatures(PoolingLayer): """A convenience class for pooling the node feature vectors produced by a graph convolutional layer. @@ -34,86 +90,241 @@ class PoolAtomFeatures(PoolingLayer): This class simply returns the features "h" from the graphs node data. 
""" - n_feature_columns: ClassVar[int] = 1 + name: ClassVar[str] = "atom" - def forward(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor: + def forward(self, molecule: DGLMoleculeOrBatch, **kwargs) -> torch.Tensor: return molecule.graph.ndata[molecule._graph_feature_name] def get_nvalues_per_molecule(self, molecule: DGLMoleculeOrBatch) -> Iterable[int]: return molecule.n_atoms_per_molecule + + @classmethod + def get_n_feature_columns(cls, n_input_features: int) -> int: + return n_input_features + + +class _SymmetricPoolingLayer(PoolingLayer): + name: ClassVar[str] = "" + n_atoms: ClassVar[int] = 0 + + def __init__( + self, + layers: SequentialLayers = None, + include_internal_coordinates: bool = False, + ): + super().__init__(layers) + self._include_internal_coordinates = include_internal_coordinates + + def _get_final_pooled_representations(self, molecule: DGLMoleculeOrBatch, **kwargs) -> torch.Tensor: + representations = self._get_pooled_representations(molecule) + if self._include_internal_coordinates: + internal_coordinates = self._calculate_internal_coordinates(molecule) + internal_coordinates = internal_coordinates.reshape((-1, 1)) + else: + internal_coordinates = torch.zeros( + (representations[0].shape[0], 1), dtype=torch.float32 + ) + representations = [ + torch.cat([representation, internal_coordinates], dim=1) + for representation in representations + ] + + return representations + + + def _generate_transposed_pooling_representation(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor: + indices = self._generate_single_pooling_representation(molecule) + transposed = [] + if indices: + n_params = len(indices[0]) + for i in range(n_params): + transposed.append([index[i] for index in indices]) + t = torch.tensor(transposed, dtype=torch.long) + molecule._pooling_representations[self.name] = t + + def _generate_single_pooling_representation(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor: + raise NotImplementedError + + + def 
_get_pooled_representations(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor: + h_data = molecule.graph.ndata[molecule._graph_feature_name] + + representations = [] + if not self.name in molecule._pooling_representations: + self._generate_transposed_pooling_representation(molecule) + forward_indices = molecule._pooling_representations[self.name] + if forward_indices.shape[1] > 0: + for row in forward_indices: + representations.append(h_data[row]) + + h_forward = torch.cat(representations, dim=1) + h_reverse = torch.cat(representations[::-1], dim=1) + return [h_forward, h_reverse] + + + def get_nvalues_per_molecule(self, molecule: DGLMoleculeOrBatch) -> Iterable[int]: + return molecule._n_pooling_representations_per_molecule[self.name] + + def _calculate_internal_coordinates_general(self, molecule: DGLMoleculeOrBatch, calculate_function): + forward_indices = molecule._pooling_representations[self.name] + arrays = [] + xyz_data = molecule.graph.ndata["xyz"] + if forward_indices.shape[1] > 0: + for row in forward_indices: + arrays.append(xyz_data[row]) + return calculate_function(*arrays) + @classmethod + def get_n_feature_columns(cls, n_input_features: int) -> int: + return (n_input_features * cls.n_atoms) + 1 -class PoolBondFeatures(PoolingLayer): + + +class PoolBondFeatures(_SymmetricPoolingLayer): """A convenience class for pooling the node feature vectors produced by a graph convolutional layer into a set of symmetric bond (edge) features. 
""" - n_feature_columns: ClassVar[int] = 2 + name: ClassVar[str] = "bond" + n_atoms: ClassVar[int] = 2 - def __init__(self, layers: SequentialLayers): - super().__init__() - self.layers = layers + # def _generate_single_pooling_representation(self, molecule: "Molecule"): + # bond_indices = sorted([ + # tuple(sorted([bond.atom1_index, bond.atom2_index])) + # for bond in molecule.bonds + # ]) + # return bond_indices + + def _generate_single_pooling_representation(self, molecule: DGLMoleculeOrBatch): + return molecule._get_bonds() + + + + def _calculate_internal_coordinates(self, molecule: DGLMoleculeOrBatch): + from openff.nagl.utils._tensors import calculate_distances + return self._calculate_internal_coordinates_general( + molecule, calculate_distances + ) - @staticmethod - def _apply_edges( - edges: "dgl.udf.EdgeBatch", feature_name: str = "h" - ) -> Dict[str, torch.Tensor]: - h_u = edges.src[feature_name] - h_v = edges.dst[feature_name] - return {feature_name: torch.cat([h_u, h_v], 1)} - - # def _directionwise_forward( - # self, - # molecule: DGLMoleculeOrBatch, - # edge_type: str = "forward", - # ): - # graph = molecule.graph - # apply_edges = functools.partial( - # self._apply_edges, - # feature_name=molecule._graph_feature_name, - # ) - # with graph.local_scope(): - # graph.apply_edges(apply_edges, etype=edge_type) - # edges = graph.edges[edge_type].data[molecule._graph_feature_name] - # return self.layers(edges) - - def forward(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor: - graph = molecule.graph - node = molecule._graph_feature_name - apply_edges = functools.partial( - self._apply_edges, - feature_name=node, +class PoolAngleFeatures(_SymmetricPoolingLayer): + """A convenience class for pooling the node feature vectors produced by + a graph convolutional layer into a set of symmetric angle features. 
+ """ + + name: ClassVar[str] = "angle" + n_atoms: ClassVar[int] = 3 + + # def _generate_single_pooling_representation(self, molecule: "Molecule"): + # # molecule.angles is just a set of tuples of atoms + # angle_indices = [] + # for angle in molecule.angles: + # indices = ( + # angle[0].molecule_atom_index, + # angle[1].molecule_atom_index, + # angle[2].molecule_atom_index, + # ) + # if indices[-1] < indices[0]: + # indices = indices[::-1] + # angle_indices.append(indices) + # angle_indices = sorted(angle_indices) + # return angle_indices + + def _generate_single_pooling_representation(self, molecule: DGLMoleculeOrBatch): + return molecule._get_angles() + + def _calculate_internal_coordinates(self, molecule: DGLMoleculeOrBatch): + from openff.nagl.utils._tensors import calculate_angles + return self._calculate_internal_coordinates_general( + molecule, calculate_angles ) + - with graph.local_scope(): - graph.apply_edges(apply_edges, etype=molecule._graph_forward_edge_type) - h_forward = graph.edges[molecule._graph_forward_edge_type].data[node] - with graph.local_scope(): - graph.apply_edges(apply_edges, etype=molecule._graph_backward_edge_type) - h_reverse = graph.edges[molecule._graph_backward_edge_type].data[node] +class PoolProperTorsionFeatures(_SymmetricPoolingLayer): + """A convenience class for pooling the node feature vectors produced by + a graph convolutional layer into a set of symmetric proper torsion features. 
+ """ - # h_forward = self._directionwise_forward( - # molecule, - # molecule._graph_forward_edge_type, - # ) - # h_reverse = self._directionwise_forward( - # molecule, - # molecule._graph_backward_edge_type, - # ) - return self.layers(h_forward) + self.layers(h_reverse) + name: ClassVar[str] = "proper_torsion" + n_atoms: ClassVar[int] = 4 - def get_nvalues_per_molecule(self, molecule: DGLMoleculeOrBatch) -> Iterable[int]: - return molecule.n_bonds_per_molecule + # def _generate_single_pooling_representation(self, molecule: "Molecule"): + # proper_torsion_indices = [] + # for torsion in molecule.propers: + # indices = ( + # torsion[0].molecule_atom_index, + # torsion[1].molecule_atom_index, + # torsion[2].molecule_atom_index, + # torsion[3].molecule_atom_index, + # ) + # if indices[-1] < indices[0]: + # indices = indices[::-1] + # proper_torsion_indices.append(indices) + # proper_torsion_indices = sorted(proper_torsion_indices) + # return proper_torsion_indices + + def _generate_single_pooling_representation(self, molecule: DGLMoleculeOrBatch): + return molecule._get_dihedrals() + def _calculate_internal_coordinates(self, molecule: DGLMoleculeOrBatch): + from openff.nagl.utils._tensors import calculate_dihedrals + return self._calculate_internal_coordinates_general( + molecule, calculate_dihedrals + ) -def get_pooling_layer(layer: Union[str, PoolingLayer]) -> PoolingLayer: - if isinstance(layer, PoolingLayer): + +class PoolOneFourFeatures(_SymmetricPoolingLayer): + name: ClassVar[str] = "one_four" + + # def _generate_single_pooling_representation(self, molecule: "Molecule"): + # one_four_indices = [] + # for torsion in molecule.propers: + # indices = ( + # torsion[0].molecule_atom_index, + # torsion[3].molecule_atom_index, + # ) + # if indices[-1] < indices[0]: + # indices = indices[::-1] + # one_four_indices.append(indices) + # one_four_indices = sorted(set(one_four_indices)) + # return one_four_indices + + def _calculate_internal_coordinates(self, molecule: 
DGLMoleculeOrBatch): + from openff.nagl.utils._tensors import calculate_distances + return self._calculate_internal_coordinates_general( + molecule, calculate_distances + ) + + +def get_pooling_layer_type(layer: Union[str, type]) -> type: + if isinstance(layer, type) and issubclass(layer, PoolingLayer): return layer + + LAYER_TYPES = { + "atom": PoolAtomFeatures, + "bond": PoolBondFeatures, + "angle": PoolAngleFeatures, + "proper_torsion": PoolProperTorsionFeatures, + "one_four": PoolOneFourFeatures, + } + if isinstance(layer, str): - if layer.lower() in {"atom", "atoms"}: - return PoolAtomFeatures() - if layer.lower() in {"bond", "bonds"}: - return PoolBondFeatures() + if layer.endswith("s"): # remove plural + layer = layer[:-1] + if layer in LAYER_TYPES: + return LAYER_TYPES[layer] + + raise NotImplementedError(f"Unsupported pooling layer '{layer}'.") + + +def get_pooling_layer(layer: Union[str, PoolingLayer], **kwargs) -> PoolingLayer: + if isinstance(layer, PoolingLayer): + return layer + elif isinstance(layer, str): + layer = get_pooling_layer_type(layer) + + if isinstance(layer, type) and issubclass(layer, PoolingLayer): + return layer(**kwargs) + raise NotImplementedError(f"Unsupported pooling layer '{layer}'.") \ No newline at end of file diff --git a/openff/nagl/nn/postprocess.py b/openff/nagl/nn/postprocess.py index f125b132..57898610 100644 --- a/openff/nagl/nn/postprocess.py +++ b/openff/nagl/nn/postprocess.py @@ -78,6 +78,7 @@ def forward( self, molecule: Union[DGLMolecule, DGLMoleculeBatch], inputs: torch.Tensor, + **kwargs ) -> torch.Tensor: electronegativity = inputs[:, 0] hardness = inputs[:, 1] @@ -156,6 +157,7 @@ def forward( self, molecule: Union[DGLMolecule, DGLMoleculeBatch], inputs: torch.Tensor, + **kwargs, ) -> torch.Tensor: charge_priors = inputs[:, 0] electronegativity = inputs[:, 1] diff --git a/openff/nagl/tests/training/test_training.py b/openff/nagl/tests/training/test_training.py index 603561ee..44dd493e 100644 --- 
import numpy as np
import pytorch_lightning as pl

from openff.nagl.config.data import DatasetConfig, DataConfig
from openff.nagl.config.optimizer import OptimizerConfig
from openff.nagl.config.training import TrainingConfig
from openff.nagl.config.model import (
    ForwardLayer,
    ReadoutModule,
    ModelConfig,
    ConvolutionLayer,
    ConvolutionModule
)
from openff.nagl.features import atoms
from openff.nagl.training.metrics import RMSEMetric
from openff.nagl.training.loss import ReadoutTarget
from openff.nagl.training.training import DGLMoleculeDataModule, DataHash, TrainingGNNModel
from openff.nagl.nn._models import GNNModel


@pytest.fixture()
def forward_layer():
    """Provide one readout ``ForwardLayer``: 128 hidden features, ReLU, no dropout."""
    return ForwardLayer(
        hidden_feature_size=128,
        activation_function="ReLU",
        dropout=0.0,
    )


@pytest.fixture()
def convolution_layer():
    """Provide one ``ConvolutionLayer``: 128 hidden features, mean aggregation, ReLU, no dropout."""
    return ConvolutionLayer(
        hidden_feature_size=128,
        aggregator_type="mean",
        activation_function="ReLU",
        dropout=0.0,
    )


@pytest.fixture()
def convolution_module(convolution_layer):
    """Provide a ``SAGEConv`` module that lists the same layer config three times."""
    return ConvolutionModule(
        architecture="SAGEConv",
        layers=[convolution_layer] * 3,
    )


def test_no_postprocess_layer(
    convolution_module,
    forward_layer,
    tmpdir
):
    """Smoke test: a ``TrainingGNNModel`` can be built when the readout has ``postprocess=None``.

    The test makes no assertions; it passes as long as construction
    does not raise.
    """
    element_features = [atoms.AtomicElement(categories=["C", "H"])]

    charge_readout = ReadoutModule(
        pooling="atoms",
        layers=[forward_layer] * 4,  # the same layer config listed four times
        postprocess=None
    )

    gnn_config = ModelConfig(
        version="0.1",
        atom_features=element_features,
        bond_features=[],
        convolution=convolution_module,
        readouts={
            "predicted-am1bcc-charges": charge_readout
        }
    )

    # NOTE(review): the source indentation was lost, so the exact extent of
    # this ``with`` body is assumed -- everything below is kept inside it.
    with tmpdir.as_cwd():
        # Copy the example dataset into the temporary working directory.
        shutil.copytree(
            EXAMPLE_UNFEATURIZED_PARQUET_DATASET_SHORT.resolve(),
            EXAMPLE_UNFEATURIZED_PARQUET_DATASET_SHORT.stem
        )

        source_name = "example-data-labelled-unfeaturized-short"

        rmse_target = ReadoutTarget(
            metric=RMSEMetric(),
            target_label="am1bcc_charges",
            prediction_label="predicted-am1bcc-charges",
            denominator=1.0,
            weight=1.0,
        )

        train_split = DatasetConfig(
            sources=[source_name],
            targets=[rmse_target],
            batch_size=1000,
        )

        # Validation and test deliberately share a single config object.
        held_out_split = DatasetConfig(
            sources=[source_name],
            targets=[rmse_target],
            batch_size=1000,
        )

        splits = DataConfig(
            training=train_split,
            validation=held_out_split,
            test=held_out_split
        )

        adam = OptimizerConfig(
            optimizer="Adam",
            learning_rate=0.001,
        )

        full_config = TrainingConfig(
            model=gnn_config,
            data=splits,
            optimizer=adam
        )
        TrainingGNNModel(full_config)
nx.convert_node_labels_to_integers(graph.subgraph(ix)) fragment = molecule_from_networkx(subgraph) + if molecule.conformers: + new_conformers = [] + for conformer in molecule.conformers: + new_conformers.append(conformer[ix]) + fragment._conformers = new_conformers fragments.append(fragment) if return_indices: diff --git a/openff/nagl/training/training.py b/openff/nagl/training/training.py index 001dcea0..4b0b3715 100644 --- a/openff/nagl/training/training.py +++ b/openff/nagl/training/training.py @@ -115,6 +115,8 @@ def _torch_optimizer(self): def create_data_module(self, n_processes: int = 0, verbose: bool = True): return DGLMoleculeDataModule(self.config, n_processes=n_processes, verbose=verbose) + + class DGLMoleculeDataModule(pl.LightningDataModule): @@ -176,6 +178,8 @@ def _get_dgl_molecule_dataset( format="parquet", atom_features=self.config.model.atom_features, bond_features=self.config.model.bond_features, + include_xyz=self.config.model.include_xyz, + enumerate_resonance_forms=self.config.model.enumerate_resonance_forms, columns=columns, cache_directory=cache_dir, use_cached_data=config.use_cached_data, @@ -187,6 +191,8 @@ def _get_dgl_molecule_dataset( format="parquet", atom_features=self.config.model.atom_features, bond_features=self.config.model.bond_features, + include_xyz=self.config.model.include_xyz, + enumerate_resonance_forms=self.config.model.enumerate_resonance_forms, columns=columns, n_processes=self.n_processes, ) diff --git a/openff/nagl/utils/_tensors.py b/openff/nagl/utils/_tensors.py new file mode 100644 index 00000000..570714f5 --- /dev/null +++ b/openff/nagl/utils/_tensors.py @@ -0,0 +1,270 @@ +""" +Module of package-agnostic maths utilities. + +This module contains utility functions for working +with PyTorch tensors and numpy arrays. 
import functools
from typing import Callable, Union

import numpy as np
import torch

# Spelled with ``Union`` so the alias can also be evaluated on interpreters
# that predate the ``X | Y`` syntax for types.
TensorType = Union[np.ndarray, torch.Tensor]

__all__ = [
    "calculate_distances",
    "calculate_angles",
    "calculate_dihedrals",
]


def _switch_backend_function_wrapper(
    torch_function: Callable,
    numpy_function: Callable,
):
    """Wrap a function to switch between PyTorch and numpy backends.

    The backend is chosen from the first argument of the call: the first
    positional argument if there is one, otherwise the first keyword
    argument.

    Parameters
    ----------
    torch_function : callable
        The function to call if the input is a PyTorch tensor.
    numpy_function : callable
        The function to call if the input is a numpy array.

    Returns
    -------
    callable
        The wrapped function. It carries the name and docstring of
        ``torch_function`` and raises ``NotImplementedError`` for inputs
        that are neither numpy arrays nor PyTorch tensors.
    """

    @functools.wraps(torch_function)
    def wrapped_function(*args, **kwargs):
        if args:
            first = args[0]
        elif kwargs:
            # Allow keyword-only calls such as ``f(source=..., destination=...)``.
            first = next(iter(kwargs.values()))
        else:
            raise TypeError(
                f"{torch_function.__name__}() requires at least one array argument"
            )

        if isinstance(first, np.ndarray):
            return numpy_function(*args, **kwargs)
        # isinstance instead of inspecting ``__module__``: plain lists and
        # other builtins have no ``__module__`` attribute on the instance.
        if isinstance(first, torch.Tensor):
            return torch_function(*args, **kwargs)
        raise NotImplementedError(
            f"Function not implemented for type {type(first)}"
        )

    return wrapped_function


def _calculate_distances_torch(
    source: torch.Tensor,
    destination: torch.Tensor,
):
    """
    Calculate the Euclidean distance between two sets of points.

    Parameters
    ----------
    source : torch.Tensor
        The source points, one point per row.
    destination : torch.Tensor
        The destination points, one point per row.

    Returns
    -------
    torch.Tensor
        The Euclidean distances between the source and destination points,
        reduced over axis 1.
    """
    return torch.norm(source - destination, dim=1)


def _calculate_distances_numpy(
    source: np.ndarray,
    destination: np.ndarray,
):
    """
    Calculate the Euclidean distance between two sets of points.

    Parameters
    ----------
    source : np.ndarray
        The source points, one point per row.
    destination : np.ndarray
        The destination points, one point per row.

    Returns
    -------
    np.ndarray
        The Euclidean distances between the source and destination points,
        reduced over axis 1.
    """
    return np.linalg.norm(source - destination, axis=1)


# Single definition: a second ``def calculate_distances`` used to follow this
# assignment and silently replaced it.
calculate_distances = _switch_backend_function_wrapper(
    _calculate_distances_torch,
    _calculate_distances_numpy,
)


def _calculate_angles_torch(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
):
    """
    Calculate the angle between three sets of points.

    Parameters
    ----------
    a : torch.Tensor
        The first set of points, one point per row.
    b : torch.Tensor
        The second set of points (the vertex of each angle).
    c : torch.Tensor
        The third set of points.

    Returns
    -------
    torch.Tensor
        The angles a-b-c in radians.
    """
    ba = a - b
    bc = c - b
    cosine_angle = torch.sum(ba * bc, dim=1) / (
        torch.norm(ba, dim=1) * torch.norm(bc, dim=1)
    )
    # Rounding can push the cosine marginally outside [-1, 1] for (anti)parallel
    # vectors, which would make acos return NaN.
    return torch.acos(torch.clamp(cosine_angle, -1.0, 1.0))


def _calculate_angles_numpy(
    a: np.ndarray,
    b: np.ndarray,
    c: np.ndarray,
):
    """
    Calculate the angle between three sets of points.

    Parameters
    ----------
    a : np.ndarray
        The first set of points, one point per row.
    b : np.ndarray
        The second set of points (the vertex of each angle).
    c : np.ndarray
        The third set of points.

    Returns
    -------
    np.ndarray
        The angles a-b-c in radians.
    """
    ba = a - b
    bc = c - b
    cosine_angle = np.sum(ba * bc, axis=1) / (
        np.linalg.norm(ba, axis=1) * np.linalg.norm(bc, axis=1)
    )
    # Same guard as the torch backend: keep arccos inside its domain.
    return np.arccos(np.clip(cosine_angle, -1.0, 1.0))


calculate_angles = _switch_backend_function_wrapper(
    _calculate_angles_torch,
    _calculate_angles_numpy,
)


def _calculate_dihedrals_torch(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    d: torch.Tensor,
):
    """
    Calculate the dihedral angle between four sets of points.

    The angle is the signed torsion about the b-c axis: 0 when ``a`` and
    ``d`` eclipse (cis), +/-pi when they are opposite (trans).

    Parameters
    ----------
    a : torch.Tensor
        The first set of points, one 3-D point per row.
    b : torch.Tensor
        The second set of points.
    c : torch.Tensor
        The third set of points.
    d : torch.Tensor
        The fourth set of points.

    Returns
    -------
    torch.Tensor
        The dihedral angles in radians, in the range [-pi, pi].
    """
    ab = b - a
    bc = c - b
    cd = d - c

    # ``dim=1`` is explicit: without it torch.cross picks the first axis of
    # size 3, which is the wrong axis for a batch of exactly three rows.
    normal1 = torch.cross(ab, bc, dim=1)
    normal2 = torch.cross(bc, cd, dim=1)

    # atan2(|bc| * ab . (bc x cd), (ab x bc) . (bc x cd))
    x = torch.sum(normal1 * normal2, dim=1)
    y = torch.norm(bc, dim=1) * torch.sum(ab * normal2, dim=1)

    return torch.atan2(y, x)


def _calculate_dihedrals_numpy(
    a: np.ndarray,
    b: np.ndarray,
    c: np.ndarray,
    d: np.ndarray,
):
    """
    Calculate the dihedral angle between four sets of points.

    The angle is the signed torsion about the b-c axis: 0 when ``a`` and
    ``d`` eclipse (cis), +/-pi when they are opposite (trans).

    Parameters
    ----------
    a : np.ndarray
        The first set of points, one 3-D point per row.
    b : np.ndarray
        The second set of points.
    c : np.ndarray
        The third set of points.
    d : np.ndarray
        The fourth set of points.

    Returns
    -------
    np.ndarray
        The dihedral angles in radians, in the range [-pi, pi].
    """
    ab = b - a
    bc = c - b
    cd = d - c

    normal1 = np.cross(ab, bc)
    normal2 = np.cross(bc, cd)

    # atan2(|bc| * ab . (bc x cd), (ab x bc) . (bc x cd))
    x = np.sum(normal1 * normal2, axis=1)
    y = np.linalg.norm(bc, axis=1) * np.sum(ab * normal2, axis=1)

    return np.arctan2(y, x)


calculate_dihedrals = _switch_backend_function_wrapper(
    _calculate_dihedrals_torch,
    _calculate_dihedrals_numpy,
)