# Source code for manim_chemistry.twoD.graph_molecule

from typing import Dict, Tuple

from manim import *
import numpy as np
import networkx as nx

from ..manim_chemistry_molecule import MCMolecule
from ..molecule import AbstractMolecule


class SimpleLine(Line):
    """Straight line used to draw a single bond.

    Behaves like the parent ``Line`` but, once constructed, points its
    ``sheen_direction`` along the line (from ``start`` towards ``end``).
    """

    def __init__(self, *args, **kwargs):
        """Build the line and align ``sheen_direction`` with it.

        All positional and keyword arguments are forwarded untouched to
        the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.sheen_direction = self._get_unit_vector()

    def _get_unit_vector(self) -> np.ndarray:
        """Return ``self.end - self.start`` normalised to length one.

        Reads the ``start`` and ``end`` attributes of the instance
        (presumably set by the parent ``Line`` constructor).
        """
        direction = self.end - self.start
        length = np.linalg.norm(direction)

        return direction / length


class DoubleLine(ArcBetweenPoints):
    """Double bond drawn as two mirrored arcs sharing the same end points.

    The instance itself is the arc bent by ``+angle``; a second arc bent by
    ``-angle`` is added as a submobject. Both arcs share the same
    ``sheen_direction``, the unit vector from ``start`` to ``end``.

    Args:
        start: Coordinates of the first end of the bond (any 3-element
            sequence or array).
        end: Coordinates of the second end of the bond.
        angle: Angle of the first arc; the second arc uses ``-angle``.
        radius: Optional radius forwarded to both arcs.
        **kwargs: Extra keyword arguments forwarded to both arcs.
    """

    def __init__(
        self,
        start=(-1, 0, 0),
        end=(1, 0, 0),
        angle: float = PI / 4,
        radius: float | None = None,
        **kwargs,
    ):
        # Keep private float-array copies of the ends. Plain lists/tuples
        # (including the defaults) cannot be subtracted in
        # ``_get_unit_vector``, and copying avoids aliasing caller data.
        self.start = np.array(start, dtype=float)
        self.end = np.array(end, dtype=float)
        super().__init__(
            start=self.start, end=self.end, angle=angle, radius=radius, **kwargs
        )
        self.sheen_direction = self._get_unit_vector()

        # Mirror image of this arc: same ends, opposite bend.
        other_arc = ArcBetweenPoints(
            start=self.start, end=self.end, angle=-angle, radius=radius, **kwargs
        )
        other_arc.sheen_direction = self.sheen_direction
        self.add(other_arc)

    def _get_unit_vector(self) -> np.ndarray:
        """Return ``self.end - self.start`` normalised to length one."""
        vector = self.end - self.start

        return vector / np.linalg.norm(vector)

    def get_vector(self) -> np.ndarray:
        """Return the unit vector pointing from ``start`` to ``end``.

        NOTE(review): this is a *unit* vector, whereas manim's
        ``Line.get_vector`` is understood to return the full ``end - start``
        vector -- confirm before relying on its length.
        """
        return self._get_unit_vector()

    def set_points_by_ends(self, start, end, *args, **kwargs) -> None:
        """Move the bond so that it spans ``start`` -> ``end``.

        Extra positional and keyword arguments are accepted and ignored so
        the method can be called with the same arguments as a ``Line``.
        """
        self.put_start_and_end_on(start=start, end=end)
        # Keep the cached ends in sync with the geometry so that
        # ``get_vector`` reflects the current orientation of the bond.
        self.start = np.array(start, dtype=float)
        self.end = np.array(end, dtype=float)


class TripleLine(DoubleLine):
    """Triple bond: the two arcs of a ``DoubleLine`` plus a straight line.

    Args:
        start: Coordinates of the first end of the bond.
        end: Coordinates of the second end of the bond.
        angle: Bend angle of the two outer arcs.
        *args: Extra positional arguments, forwarded to ``DoubleLine`` after
            ``angle`` (i.e. the first one is its ``radius``).
        **kwargs: Extra keyword arguments. All of them go to ``DoubleLine``;
            all except ``radius`` also go to the middle ``Line``.
    """

    def __init__(
        self, start=(-1, 0, 0), end=(1, 0, 0), angle: float = PI / 4, *args, **kwargs
    ):
        # Convert once so the parent can do vector arithmetic on the ends
        # even when plain sequences (such as the defaults) are given.
        start = np.array(start, dtype=float)
        end = np.array(end, dtype=float)

        # ``angle`` must be forwarded explicitly, otherwise the parent falls
        # back to its own default and a custom angle is silently ignored.
        # Passing everything positionally keeps ``*args`` from colliding
        # with ``start``.
        super().__init__(start, end, angle, *args, **kwargs)

        # ``radius`` only makes sense for the arcs; a straight line does not
        # take it, so strip it before building the middle line.
        line_kwargs = {key: value for key, value in kwargs.items() if key != "radius"}
        middle_line = Line(start=start, end=end, **line_kwargs)
        middle_line.sheen_direction = self.sheen_direction
        self.add(middle_line)


class GraphMolecule(Graph, AbstractMolecule):
    """Represents a molecule like a Graph from Manim.

    Atoms are the vertices of the graph and bonds are its edges. Each bond
    is drawn with the mobject class registered for its bond order in
    ``SUPPORTED_BOND_TYPES``.

    Examples
    ---------
    .. manim:: GraphMoleculeFromFile

        from manim_chemistry import *

        class GraphMoleculeFromFile(Scene):
            def construct(self):
                molecule = GraphMolecule.molecule_from_file(
                    "../examples/molecule_files/mol_files/acetone_2d.mol"
                )
                self.wait()
                self.play(Write(molecule))
                self.wait()

    .. manim:: GraphMoleculeFromFileWithHydrogens

        from manim_chemistry import *

        class GraphMoleculeFromFileWithHydrogens(Scene):
            def construct(self):
                molecule = GraphMolecule.molecule_from_file(
                    "../examples/molecule_files/mol_files/acetone_2d.mol",
                    ignore_hydrogens=False
                )
                self.wait()
                self.play(Write(molecule))
                self.wait()

    .. manim:: GraphMoleculeFromPubChem

        from manim_chemistry import *

        class GraphMoleculeFromPubChem(Scene):
            def construct(self):
                molecule = GraphMolecule.molecule_from_pubchem(name="acetone")
                self.wait()
                self.play(Write(molecule))
                self.wait()

    .. manim:: GraphMoleculeFromPubChemThreeD

        from manim_chemistry import *

        class GraphMoleculeFromPubChemThreeD(Scene):
            def construct(self):
                molecule = GraphMolecule.molecule_from_pubchem(
                    name="acetone",
                    three_d=True,
                    ignore_hydrogens=False
                )
                self.wait()
                self.play(Write(molecule))
                self.wait()
    """

    # Bond order -> mobject class used to draw that bond.
    SUPPORTED_BOND_TYPES = {
        1: SimpleLine,
        2: DoubleLine,
        3: TripleLine,
    }

    def __init__(
        self,
        vertices_dict: dict,
        edges_dict: dict,
        label: bool = False,
        numeric_label: bool = False,
        label_color: str = BLACK,
        *args,
        **kwargs,
    ):
        """Build the molecule graph and centre it on the origin.

        Args:
            vertices_dict (dict): ``{atom_index: atom}``. Each atom must
                expose ``coords`` and ``element`` (with ``color`` and
                ``symbol``).
            edges_dict (dict): ``{(from_index, to_index): bond}``. Each bond
                must expose ``bond_type``, ``from_atom`` and ``to_atom``.
            label (bool): Label every atom with its element symbol.
            numeric_label (bool): Label every atom with its index instead.
                Takes precedence over ``label``.
            label_color (str): Colour of the labels.
            *args: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor.
        """
        # Must be set before the parent constructor runs:
        # ``_populate_edge_dict`` (presumably invoked while the parent
        # builds the graph) looks bond orders up in it.
        self.edges_dict = edges_dict

        labels = False
        if label or numeric_label:
            labels = self.make_labels(vertices_dict, numeric_label, label_color)

        super().__init__(
            vertices=list(vertices_dict),
            edges=list(edges_dict),
            vertex_config=self.make_vertex_config(vertices_dict),
            edge_config=self.make_edge_config(edges_dict),
            layout=self.make_layout(vertices_dict),
            labels=labels,
            *args,
            **kwargs,
        )
        self.move_to(ORIGIN)

        # Chemistry-flavoured aliases for the graph containers.
        self.atoms = self.vertices
        self.bonds = self.edges

    def _populate_edge_dict(self, edges, _):
        """Create one bond mobject per edge and store them in ``self.edges``.

        The second argument is ignored: the bond class is chosen from the
        bond order of each edge rather than from a single edge type.
        """
        self.edges = {}
        for u, v in edges:
            bond_type = self.select_bond_from_edge((u, v))
            self.edges[(u, v)] = bond_type(
                start=self[u].get_center(),
                end=self[v].get_center(),
                z_index=-1,
                **self._edge_config[(u, v)],
            )

    def select_bond_from_edge(self, edge):
        """Return the bond class for ``edge`` based on its ``bond_type``."""
        return self.select_bond_type(self.edges_dict[edge].bond_type)

    def select_bond_type(self, bond_type: int):
        """Return the mobject class registered for a bond order.

        Args:
            bond_type (int): Bond order; anything ``int()`` accepts.

        Raises:
            ValueError: If the bond order is not in ``SUPPORTED_BOND_TYPES``.
        """
        bond = self.SUPPORTED_BOND_TYPES.get(int(bond_type))
        if bond is None:
            raise ValueError(
                f"{bond_type} is an unknown type of bond. Options are 1, 2 or 3"
            )

        return bond

    def make_layout(self, vertices_dict: dict):
        """Return ``{atom_index: coords}`` to be used as the graph layout."""
        return {index: vertex.coords for index, vertex in vertices_dict.items()}

    def make_vertex_config(self, vertices_dict: dict):
        """Return the per-vertex config: element colour and a 0.2 radius."""
        return {
            vertex_index: {
                "color": mc_atom.element.color,
                "radius": 0.2,
            }
            for vertex_index, mc_atom in vertices_dict.items()
        }

    def make_edge_config(self, edges: dict):
        """Return the per-edge config.

        Each bond gets a two-colour stroke gradient going from the colour
        of its ``from_atom`` element to the colour of its ``to_atom``
        element.
        """
        return {
            edge_key: {
                "stroke_color": color_gradient(
                    [
                        edge.from_atom.element.color,
                        edge.to_atom.element.color,
                    ],
                    length_of_output=2,
                )
            }
            for edge_key, edge in edges.items()
        }

    def make_labels(self, vertices_dict: dict, numeric_label: bool, label_color: str):
        """Return ``{atom_index: Text}`` labels for every atom.

        Args:
            vertices_dict (dict): ``{atom_index: atom}``.
            numeric_label (bool): Use the atom index as label text when
                true, the element symbol otherwise.
            label_color (str): Colour applied to every label.
        """
        if numeric_label:
            return {
                index: Text(str(index), color=label_color).scale(0.5)
                for index in vertices_dict
            }

        # Honour ``label_color`` for symbol labels too (this used to be
        # hard-coded to BLACK, which is still the constructor default).
        return {
            index: Text(vertex.element.symbol, color=label_color).scale(0.5)
            for index, vertex in vertices_dict.items()
        }

    def get_connected_atoms(self, from_atom_index: int, to_atom_index: int) -> list:
        """
        Given two atoms index, returns the connected atoms to the second
        atom after the bond between the first atom and the second one.

        Example:

        Index:  0 1 2 3 4 5
        Atoms:  C-O-C-N-C-H

        This function applied to atoms 2 (C) and 3 (N), would return
        [3, 4, 5], the indices of atoms N, C and H.

        Raises:
            ValueError: If the two atoms are not bonded to each other.
        """
        if not self._graph.has_edge(from_atom_index, to_atom_index):
            raise ValueError(
                f"There is no bond between atoms {from_atom_index} and {to_atom_index}"
            )

        # Marking the origin atom as visited keeps the search from walking
        # back across the bond it started from.
        return self.depth_first_search(
            self._graph,
            atom=to_atom_index,
            visited_atoms={from_atom_index},
            connected_atoms=[to_atom_index],
        )

    def get_atoms_vgroup_from_index(self, atoms_indices):
        """Return a ``VGroup`` with the atom mobjects at ``atoms_indices``."""
        return VGroup(*[self.vertices[atom_index] for atom_index in atoms_indices])

    def get_connected_atoms_v_group(
        self, from_atom_index: int, to_atom_index: int
    ) -> VGroup:
        """Return a ``VGroup`` with the atoms found by ``get_connected_atoms``."""
        connected_atoms = self.get_connected_atoms(
            from_atom_index=from_atom_index, to_atom_index=to_atom_index
        )

        return self.get_atoms_vgroup_from_index(connected_atoms)

    def get_bonds_from_atoms_indices(
        self, connected_atoms: list, excluded_atom: int = 0
    ) -> list:
        """
        Given a list of atoms, returns the bonds related to those atoms.

        Args:
            connected_atoms (list): Atom indices whose bonds are wanted.
            excluded_atom (int): Bonds touching this atom are left out.
                NOTE(review): the default of 0 excludes bonds of an atom
                with index 0 -- presumably indices start at 1; confirm.

        Returns:
            list: Bond tuples, each one listed once, in the order in which
            they are first found while walking ``connected_atoms``.
        """
        # Filter on the excluded atom once instead of once per atom.
        candidate_bonds = [bond for bond in self.edges if excluded_atom not in bond]

        bonds = []
        seen = set()
        for atom in connected_atoms:
            for bond in candidate_bonds:
                # A bond whose two atoms are both in ``connected_atoms``
                # would otherwise be reported twice.
                if atom in bond and bond not in seen:
                    seen.add(bond)
                    bonds.append(bond)

        return bonds

    def get_bonds_vgroup_from_index(self, connected_atoms: list, excluded_atom: int):
        """Return a ``VGroup`` with the bond mobjects of ``connected_atoms``."""
        bonds = self.get_bonds_from_atoms_indices(connected_atoms, excluded_atom)

        return VGroup(*[self.edges[bond_atoms] for bond_atoms in bonds])

    def get_connected_atoms_and_bonds_group_from_index(
        self, connected_atoms: list, excluded_atom: int = 0
    ):
        """Return ``VGroup(atoms_group, bonds_group)`` for ``connected_atoms``."""
        return VGroup(
            self.get_atoms_vgroup_from_index(atoms_indices=connected_atoms),
            self.get_bonds_vgroup_from_index(
                connected_atoms=connected_atoms, excluded_atom=excluded_atom
            ),
        )

    def get_connected_atoms_and_bonds(self, from_atom: int, to_atom: int) -> VGroup:
        """Return the atoms and bonds hanging from ``to_atom``.

        The group covers everything reached from ``to_atom`` without
        crossing back to ``from_atom``; bonds touching ``from_atom`` are
        excluded.
        """
        connected_atoms = self.get_connected_atoms(
            from_atom_index=from_atom, to_atom_index=to_atom
        )

        return self.get_connected_atoms_and_bonds_group_from_index(
            connected_atoms=connected_atoms, excluded_atom=from_atom
        )

    def find_atom_position_by_index(self, atom_index: int) -> np.ndarray:
        """
        Returns the position of a single atom given its index.

        Example:
        ```
        molecule = GraphMolecule.molecule_from_file("examples/molecule_files/mol_files/dimethylpropane.mol")
        print(molecule.find_atom_position_by_index(1))
        >>> array([ 0.9397, -0.7497,  0.    ])
        ```

        Args:
            atom_index (int): Index of the atom inside the molecule.

        Returns:
            np.ndarray: Array with the [x, y, z] coordinates of the atom.

        Raises:
            KeyError: If ``atom_index`` is not an atom of this molecule.
        """
        try:
            atom = self.atoms[atom_index]
        except KeyError as key_error:
            # ``self.atoms`` is a plain mapping, so list its keys directly.
            raise KeyError(
                f"Atom index {atom_index} is not valid for molecule {self}. "
                f"Valid indices are: {list(self.atoms.keys())}"
            ) from key_error

        return atom.get_center()

    def find_atoms_position_by_index(self, atoms_index_list: list) -> list:
        """Returns the position of multiple atoms given their indices.

        Example:
        ```
        molecule = GraphMolecule.molecule_from_file("examples/molecule_files/mol_files/dimethylpropane.mol")
        print(molecule.find_atoms_position_by_index([1,2,3]))
        >>> [array([ 0.0713, -0.0263,  0.    ]), array([-1.2754,  0.3464,  0.    ]), array([0.9674, 1.2186, 0.    ])]
        ```

        Args:
            atoms_index_list (list): List of atoms indices to be gotten.

        Returns:
            list: List of the atoms positions.
        """
        return [
            self.find_atom_position_by_index(atom_index=atom_index)
            for atom_index in atoms_index_list
        ]

    def find_bond_center_by_index(self, bond_index: tuple) -> np.ndarray:
        """Returns the [x, y, z] coordinates of a bond given a bond tuple.
        The bond tuple corresponds to the indices of the atoms in the bond.

        Example:
        ```
        molecule = GraphMolecule.molecule_from_file("examples/molecule_files/mol_files/dimethylpropane.mol")
        print(molecule.find_bond_center_by_index((1, 2)))
        >>> array([0.51935, 0.59615, 0.     ])
        ```

        Args:
            bond_index (tuple): index of the bond

        Returns:
            np.ndarray: [x, y, z] coordinates of bond center.

        Raises:
            KeyError: If ``bond_index`` is not a bond of this molecule.
        """
        try:
            bond = self.bonds[bond_index]
        except KeyError as key_error:
            raise KeyError(
                f"Bond tuple {bond_index} is not valid for molecule {self}"
            ) from key_error

        return bond.get_center()

    def find_position_along_bond_axis(
        self, bond_tuple: tuple, position_buff: float
    ) -> np.ndarray:
        """Returns a position along the bond axis given a bond index and
        depending on a position_buff.

        A value of 1 will return one end of the bond (the one given by the
        bond's ``get_end()``). A value of -1 will return the other end.
        All values in between return positions at some point of the middle
        of the bond, being 0 the center. Values bigger or lower than 1 and
        -1 will return positions outside the bond.

        Args:
            bond_tuple (tuple): Tuple of the bond
            position_buff (float): Position buff

        Returns:
            np.ndarray: [x, y, z] coordinates of the final position selected.

        Raises:
            KeyError: If ``bond_tuple`` is not a bond of this molecule.
        """
        try:
            bond = self.bonds[bond_tuple]
        except KeyError as key_error:
            # ``self.bonds`` is a plain dict, so list its keys directly.
            raise KeyError(
                f"Bond index {bond_tuple} is not valid for molecule {self}. "
                f"Valid indices are: {list(self.bonds.keys())}"
            ) from key_error

        # Use the real end-to-end vector. ``get_vector()`` is a unit vector
        # for double and triple bonds, which would make +/-1 miss the ends
        # of the bond unless it happens to have length one.
        bond_vector = bond.get_end() - bond.get_start()

        return bond.get_center() + bond_vector * position_buff * 0.5

    def find_bonds_center_by_index(self, bonds_tuples_list: list) -> list:
        """Returns the position of multiple bonds given their tuples.

        Example:
        ```
        molecule = GraphMolecule.molecule_from_file("examples/molecule_files/mol_files/dimethylpropane.mol")
        centers = molecule.find_bonds_center_by_index([(1, 2), (1, 3)])
        ```

        Args:
            bonds_tuples_list (list): List of bond tuples to be gotten.

        Returns:
            list: List of the bonds centers, in the same order as the input.
        """
        return [
            self.find_bond_center_by_index(bond_index=bond_index)
            for bond_index in bonds_tuples_list
        ]

    def find_all_atoms_positions(self) -> dict:
        """Return ``{atom_index: position}`` for every atom of the molecule."""
        return {
            atom_index: self.find_atom_position_by_index(atom_index=atom_index)
            for atom_index in self.atoms.keys()
        }

    def find_all_bonds_centers(self) -> dict:
        """Return ``{bond_tuple: center}`` for every bond of the molecule."""
        return {
            bond_tuple: self.find_bond_center_by_index(bond_index=bond_tuple)
            for bond_tuple in self.bonds
        }

    @classmethod
    def mc_molecule_to_atoms_and_bonds(
        cls, mc_molecule: MCMolecule
    ) -> Tuple[Dict, Dict]:
        """
        Transforms the structure of a mc_molecule to a (vertices, edges)
        tuple with the following structure:

        - Vertices: {<atom_index>: MCAtom}
        - Edges: {(<from_atom_index>, <to_atom_index>): MCBond}

        Args:
            mc_molecule (MCMolecule): The origin MCMolecule

        Returns:
            Tuple[Dict, Dict]: See above.
        """
        vertices = mc_molecule.atoms_by_index
        edges = {
            (bond.from_atom.molecule_index, bond.to_atom.molecule_index): bond
            for bond in mc_molecule.bonds
        }

        return vertices, edges