Source code for mchem.forcefield.base

"""Force field definition from XML: atom types, residues, and force generators."""

from typing import List, Dict, Any
from dataclasses import dataclass
from itertools import product, combinations
import os
import xml
import xml.etree.ElementTree as ET
import warnings

from ..system import System
from ..topology import Topology
from ..template import TEMPLATES
from ..terms import TermList, Particle


[docs] def str2float(string): """Convert string to float.""" return float(string)
[docs] def str2int(string): """Convert string to int.""" return int(string)
[docs] def float2str(number): """Format number for XML (float or scientific).""" if number == 0.0 or abs(number) > 1e-5: return f"{number:.10f}" else: return f"{number:.10e}"
[docs] def str2bool(string): """Convert string to bool (True/False, case-insensitive).""" if string == "True" or string == "true": return True elif string == "False" or string == 'false': return False else: raise ValueError(f"Could not convert string to bool: {string}")
[docs] def xmlele2str(xmlele: ET.Element): """Serialize XML element to string (unicode, stripped).""" uglystr = ET.tostring(xmlele, "unicode") return uglystr.strip()
[docs] @dataclass class AtomType: """Force-field atom type with name and class.""" name: str atomClass: str
# Parsers for force XML elements Parsers = {}
[docs] class ForceField: """ Load and apply a force field from one or more XML files. Parses AtomTypes, Residues, and force elements (e.g. AmoebaBondForce); uses :data:`Parsers` to create generators that produce terms for a :class:`Topology`. """
[docs] def __init__(self, *files): """ Load force field from one or more XML files. Parameters ---------- *files : str or os.PathLike XML paths; resolved via :meth:`processFileNames` if not found. """ self.files = self.processFileNames(files) self.trees = [ET.parse(f) for f in self.files] self.atomTypes: Dict[str, AtomType] = {} self.atomClasses: Dict[str, List[AtomType]] = {} self._generators: List[Generator] = [] self._forces: List[str] = [] self.loadAtomTypes() self.loadAtomTypeDefs() self.loadForces()
[docs] def processFileNames(self, files): """Resolve file paths; search next to this module if not found. Returns list of paths.""" dirname = os.path.dirname(__file__) files = list(files) if isinstance(files, tuple) else [files] for i in range(len(files)): if not os.path.exists(files[i]): trial = os.path.join(dirname, files[i]) if not os.path.isfile(trial): raise FileNotFoundError() else: files[i] = trial return files
@property def generators(self): """List of force generators (one per force type from XML).""" return self._generators
[docs] def addGenerator(self, generator): """Append a force generator.""" self._generators.append(generator)
[docs] def getGeneratorWithClass(self, generatorClass): """Return the first generator that is an instance of `generatorClass`, or None.""" idx = None for i in range(len(self.generators)): if isinstance(self.generators[i], generatorClass): idx = i break if idx is None: return None else: return self.generators[i]
[docs] def addGeneratorWithClass(self, generatorClass): """Get or create a generator of the given class and return it.""" generator = self.getGeneratorWithClass(generatorClass) if generator is None: generator = generatorClass(self) self.addGenerator(generator) return generator
[docs] def addAtomType(self, atomTypeElement: ET.Element): """Register an atom type from an XML ``Type`` element (name, class).""" name = atomTypeElement.get("name") aclass = atomTypeElement.get("class", "") atype = AtomType(name, aclass) if name in self.atomTypes: raise Exception(f"Duplicated atom type: {name}") self.atomTypes[name] = atype cls2type = self.atomClasses.get(aclass, []) cls2type.append(atype) self.atomClasses[aclass] = cls2type
[docs] def loadAtomTypes(self): """Load all AtomTypes from parsed XML trees.""" for tree in self.trees: atomTypes = tree.getroot().find("AtomTypes") for atomTypeElement in atomTypes.findall("Type"): self.addAtomType(atomTypeElement)
[docs] def loadAtomTypeDefs(self): """Map residue atom names to atom types from Residues section (templates must exist).""" for tree in self.trees: residues = tree.getroot().find("Residues") for res in residues.findall("Residue"): ## TODO: GET ALL RESIDUES COMPLETED if res.get("name") not in TEMPLATES: continue template = TEMPLATES[res.get("name")] for atom in res.findall("Atom"): name = atom.get("name") atype = atom.get('type') template.setAtomType(name, atype)
[docs] def loadForces(self): """Parse force elements using :data:`Parsers` and register generators.""" for tree in self.trees: for child in tree.getroot(): if child.tag in Parsers: self._forces.append(child.tag) if isinstance(Parsers[child.tag], list): for parser in Parsers[child.tag]: parser.parseElement(child, self) else: Parsers[child.tag].parseElement(child, self) elif child.tag in ["Info", "Residues", "AtomTypes"]: pass else: pass
# raise ValueError(f"{child.tag} is not supported")
[docs] def assignAtomTypes(self, topology: Topology): """Set atom type and class on each atom from residue templates.""" for res in topology.residues: if res.stdName not in TEMPLATES: raise KeyError(f"ResidueTemplate {res.stdName} not defined") template = TEMPLATES[res.stdName] for atom in res.atoms: atype = template.getAtomType(atom.name) atom.setAtomType(atype) aclass = self.atomTypes[atype].atomClass atom.setAtomClass(aclass)
[docs] def findAtomTypes(self, element: ET.Element, numAtoms: int): """Resolve type/class attributes to list of atom type name tuples (supports wildcards).""" useType = any(key.startswith("type") for key in element.attrib.keys()) useClass = any(key.startswith("class") for key in element.attrib.keys()) if useType and useClass: raise ValueError(f"Specified both a type and a class for the same atom: {element.attrib}") elif (not useType) and (not useClass): raise ValueError(f"Either a type or class has to be specified for: {element.attrib}") atypes = [] for i in range(numAtoms): suffix = "" if numAtoms == 1 else str(i+1) if useType: atype = element.get(f"type{suffix}") if atype == "": # handle wild card atypes.append(["*"]) else: atypes.append([self.atomTypes[atype].name]) else: aclass = element.get(f"class{suffix}") if aclass == "": # handle wild card atypes.append(["*"]) else: atypes.append([atype.name for atype in self.atomClasses[aclass]]) atypes = list(product(*atypes)) return atypes
[docs] def createSystem(self, topology: Topology, **kwargs): """Build a :class:`System` with particles and all generator terms for the given topology.""" self.assignAtomTypes(topology) system = System() system.addMeta("name", topology.name) particles = TermList(Particle) for atom in topology.atoms(): particles.append(Particle( atom.idx, atom.name, atom.symbol, atom.mass, f"{atom.residue.number}{atom.residue.insertionCode}", atom.residue.name, atom.xx, atom.xy, atom.xz )) system.addTerms(particles) for generator in self.generators: for key, value in generator._meta.items(): system.addMeta(key, value) terms = generator.createTerms(topology, **kwargs) if isinstance(terms, tuple): for t in terms: system.addTerms(t) else: system.addTerms(terms) return system
[docs] def getParameters(self, asJaxNumpy: bool = True): params = {} for generator in self.generators: name = generator.__class__.__name__ params[name] = generator.getParameters(asJaxNumpy=asJaxNumpy) self._parameters = params return self._parameters
[docs] def updateParameters(self, param: Dict[str, Any]): """Update each generator's parameters from a dict keyed by generator class name.""" for generator in self.generators: name = generator.__class__.__name__ generator.updateParameters(param[name])
[docs] def exportAtomTypes(self): """Return XML string for AtomTypes section.""" strs = [] for tree in self.trees: atomTypes = tree.getroot().find("AtomTypes") for atomTypeElement in atomTypes.findall("Type"): strs.append(f'\t\t{xmlele2str(atomTypeElement)}') atypestr = '\t<AtomTypes>\n{}\n\t</AtomTypes>'.format('\n'.join(strs)) return atypestr
[docs] def exportAtomTypeDefs(self): """Return XML string for Residues section (atom type defs).""" strs = [] for tree in self.trees: residues = tree.getroot().find("Residues") for res in residues.findall("Residue"): restr = '\t\t<Residue name="{}">\n{}\n\t\t</Residue>'.format( res.get("name"), '\n'.join(f'\t\t\t{xmlele2str(atomEle)}' for atomEle in res.findall("Atom")) ) strs.append(restr) return '\t<Residues>\n{}\n\t</Residues>'.format('\n'.join(strs))
[docs] def save(self, path: os.PathLike): """Write force field XML to file (atom types, residue defs, all force sections).""" forcestrs = [] for force in self._forces: if isinstance(Parsers[force], list): forcestr = "\n".join([ self.getGeneratorWithClass(gencls).exportParameterToStr() for gencls in Parsers[force] ]) else: forcestr = self.getGeneratorWithClass(Parsers[force]).exportParameterToStr() forcestrs.append(f"\t<{force}>\n{forcestr}\n\t</{force}>") ffstr = "<ForceField>\n{}\n{}\n{}\n</ForceField>".format( self.exportAtomTypes(), self.exportAtomTypeDefs(), '\n'.join(forcestrs) ) with open(path, 'w') as f: f.write(ffstr)
[docs] class Generator: """ Base class for a force generator: holds parameters keyed by atom type/SMIRKS and creates terms for a topology. """
[docs] def __init__(self, ff: ForceField, paramFields: List[str] = [], raiseError: bool = True): self.ff = ff # parameters self._with_atom_types = {} self._with_smirks = {} self._parameters: Dict[str, List[Any]] = {} self._meta = {} self._raiseError = raiseError for fd in paramFields: self.addParameterField(fd)
[docs] def getMetadata(self, key: str): return self._meta[key]
[docs] def setMetadata(self, key: str, value: Any): self._meta[key] = value
@property def paramFields(self) -> List[str]: return sorted(list(self._parameters.keys())) @property def numParameters(self) -> int: key = list(self._parameters.keys())[0] return len(self._parameters[key])
[docs] def addParameterField(self, name: str): if name not in self._parameters: self._parameters[name] = []
[docs] def checkParameter(self, paramDict: Dict[str, Any]): keys = sorted(list(paramDict.keys())) assert keys == self.paramFields, f"Parameter fields does not match {keys} != {self.paramFields}"
[docs] def addParameterWithAtomTypes(self, typeOrtypes, paramDict: Dict[str, Any]): types = typeOrtypes if isinstance(typeOrtypes, list) else [typeOrtypes] for typ in types: self._with_atom_types[typ] = self.numParameters self.setParameterWithIdx(paramDict, self.numParameters)
[docs] def addParameterWithSmirks(self, smirks: str, paramDict: Dict[str, Any]): self._with_smirks[smirks] = self.numParameters self.setParameterWithIdx(paramDict, self.numParameters)
[docs] def setParameterWithIdx(self, paramDict: Dict[str, Any], paramIdx: int): self.checkParameter(paramDict) if paramIdx >= self.numParameters: for k, v in paramDict.items(): self._parameters[k].append(v) else: for k, v in paramDict.items(): self._parameters[k][paramIdx] = v
[docs] def getParameterIdxWithAtomType(self, typeQuery): paramIdx = self._with_atom_types.get(typeQuery, None) if paramIdx is None: # try wildcard match for numWildCard in range(1, 1+len(typeQuery)): for wildCardPos in combinations(range(len(typeQuery)), numWildCard): typeQueryWithWildCard = list(typeQuery) for p in wildCardPos: typeQueryWithWildCard[p] = "*" paramIdx = self._with_atom_types.get(tuple(typeQueryWithWildCard), None) if paramIdx is not None: break if paramIdx is not None: break if (paramIdx is None) and self._raiseError: raise KeyError(typeQuery) return paramIdx
[docs] def getParameterWithAtomType(self, typeQuery): paramIdx = self.getParameterIdxWithAtomType(typeQuery) if paramIdx is None: return None else: return self.getParameterWithIdx(paramIdx)
[docs] def getParameterIdxWithSmirks(self, smirksQuery: str): if self._raiseError: paramIdx = self._with_smirks[smirksQuery] else: paramIdx = self._with_smirks.get(smirksQuery, None) return paramIdx
[docs] def getParameterWithSmirks(self, smirksQuery: str): paramIdx = self.getParameterIdxWithSmirks(smirksQuery) if paramIdx is None: return None else: return self.getParameterWithIdx(paramIdx)
[docs] def getParameterWithIdx(self, idx: int): return {k: self._parameters[k][idx] for k in self._parameters.keys()}
[docs] def getParameters(self, asJaxNumpy: bool = False): if asJaxNumpy: import jax.numpy as jnp jaxParam = {} for key, value in self._parameters.items(): try: jaxParam[key] = jnp.array(value) except TypeError as e: continue return jaxParam else: return self._parameters
[docs] def updateParameters(self, param: Dict[str, Any]): self._parameters.update(param)
[docs] def createTerms(self, topology: Topology, **kwargs): """Build term list(s) for the given topology. Override in subclasses. Returns TermList or tuple of TermLists.""" raise NotImplementedError()
[docs] def raise_exception(self, msg: str, raiseError: bool = True): if raiseError: raise Exception(msg) else: warnings.warn(msg)
[docs] def exportParameterToStr(self): return ""
# raise NotImplementedError()