Integrating a Custom Model
Implementing the model wrapper
To integrate a custom model we need to wrap it into the shared model interface.
from syntheseus import BackwardReactionModel
As a demonstration, we'll integrate a dummy model which only accepts molecules that are chains of carbon atoms (CC...C) and predicts "reactions" that split that chain into two parts. The only method we need to implement is _get_reactions; we split it into a few helper methods below for readability.
from typing import Sequence
from syntheseus import Bag, Molecule, SingleProductReaction
class ToyModel(BackwardReactionModel):
    def _get_reactions(
        self, inputs: list[Molecule], num_results: int
    ) -> list[Sequence[SingleProductReaction]]:
        return [
            self._get_reactions_single(mol)[:num_results]
            for mol in inputs
        ]

    def _get_reaction_score(self, i: int, n_atoms: int) -> float:
        # Give higher score to reactions which break the input into
        # equal-sized pieces.
        return float(min(i, n_atoms - i))

    def _get_reactions_single(
        self, mol: Molecule
    ) -> Sequence[SingleProductReaction]:
        n = len(mol.smiles)
        # Only chains of carbon atoms are supported.
        if mol.smiles != n * "C":
            return []

        # Normalize the scores of all possible splits into probabilities.
        scores = [self._get_reaction_score(i, n) for i in range(1, n)]
        score_total = sum(scores)
        probs = [score / score_total for score in scores]

        reactions = []
        for i, prob in zip(range(1, n), probs):
            reactant_1 = Molecule(i * "C")
            reactant_2 = Molecule((n - i) * "C")
            reactions.append(
                SingleProductReaction(
                    reactants=Bag([reactant_1, reactant_2]),
                    product=mol,
                    metadata={"probability": prob},
                )
            )

        # Return the most probable reactions first.
        return sorted(
            reactions,
            key=lambda r: r.metadata["probability"],
            reverse=True,
        )
Let's make sure this works. Note that we implement _get_reactions but call the model using __call__; this allows syntheseus to inject extra processing such as deduplication or caching.
model = ToyModel()
def print_predictions(model, smiles: str):
    [reactions] = model([Molecule(smiles)])
    for reaction in reactions:
        probability = reaction.metadata["probability"]
        print(f"{reaction} (probability: {probability:.3f})")

print_predictions(model, "CCCC")
CC.CC>>CCCC (probability: 0.500)
C.CCC>>CCCC (probability: 0.250)
The model is working as expected. CCC.C>>CCCC is not returned because the order of reactants in a Bag doesn't matter, so it is the same reaction as C.CCC>>CCCC. However, note that syntheseus currently only removes duplicated reactions; it does not add the probabilities of all duplicates together (and search algorithms generally do not depend on the probabilities summing up to 1). If you prefer to sum the probabilities of duplicate reactions instead, you can implement this behaviour yourself by overriding filter_reactions (or even in _get_reactions directly).
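The merging logic itself can be sketched in plain Python, independently of syntheseus. The helpers below (chain_split_probabilities and merge_duplicates) are hypothetical names introduced for illustration; they operate on plain SMILES tuples rather than SingleProductReaction objects, and wiring the idea into filter_reactions or _get_reactions is not shown here.

```python
from collections import defaultdict


def chain_split_probabilities(n):
    """Return (reactants, probability) for every split of a chain of n carbons.

    Mirrors the scoring used by the toy model: score = min(i, n - i),
    normalized over all splits.
    """
    scores = [float(min(i, n - i)) for i in range(1, n)]
    total = sum(scores)
    return [
        (("C" * i, "C" * (n - i)), score / total)
        for i, score in zip(range(1, n), scores)
    ]


def merge_duplicates(predictions):
    """Sum the probabilities of predictions whose reactants match up to order."""
    merged = defaultdict(float)
    for reactants, prob in predictions:
        # Sorting the reactants makes the key independent of their order.
        merged[tuple(sorted(reactants))] += prob
    return sorted(merged.items(), key=lambda item: item[1], reverse=True)


for reactants, prob in merge_duplicates(chain_split_probabilities(4)):
    print(".".join(reactants), f"(probability: {prob:.3f})")
```

With merging, the two orderings of the C/CCC split for CCCC contribute 0.25 each, so both unique splits end up with probability 0.5 and the probabilities sum to 1 again.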
Running search
As in the "Quick Start" tutorial, we will now proceed to running multi-step search using our newly integrated model. This time we will use a proper search algorithm (Retro*) instead of BFS, so that it takes into account the single-step probabilities.
from syntheseus.search.analysis.route_extraction import (
    iter_routes_time_order,
)
from syntheseus.search.mol_inventory import SmilesListInventory
from syntheseus.search.algorithms.best_first.retro_star import (
    RetroStarSearch,
)
from syntheseus.search.node_evaluation.common import (
    ConstantNodeEvaluator,
    ReactionModelLogProbCost,
)

def get_routes(model):
    search_algorithm = RetroStarSearch(
        reaction_model=model,
        mol_inventory=SmilesListInventory(smiles_list=["C"]),
        limit_iterations=100,  # max number of algorithm iterations
        limit_reaction_model_calls=100,  # max number of model calls
        time_limit_s=60.0,  # max runtime in seconds
        value_function=ConstantNodeEvaluator(0.0),
        and_node_cost_fn=ReactionModelLogProbCost(),
    )

    output_graph, _ = search_algorithm.run_from_mol(
        Molecule("CCCCCCCC")
    )
    routes = list(
        iter_routes_time_order(output_graph, max_routes=100)
    )

    print(f"Found {len(routes)} routes")
    return output_graph, routes

model = ToyModel(use_cache=True)
output_graph, routes = get_routes(model)
Found 22 routes
Let's see how many times the reaction model was actually called.
model.num_calls()
7
This makes sense: even though there are many more nodes in the search graph, the search only encountered 7 unique non-purchasable products (chains with lengths between 2 and 8); as we set use_cache=True, the model was called on each of these products exactly once. We can pass count_cache=True to get the number of calls including those for which the answer was already cached.
model.num_calls(count_cache=True)
64
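To illustrate the difference between the two counts, here is a minimal plain-Python sketch of a memoizing wrapper that tracks both numbers. CountingCache and split_chain are hypothetical names made up for this example; this is not how syntheseus implements caching internally, only the general idea.

```python
class CountingCache:
    """Hypothetical helper: memoize a function and count both kinds of calls."""

    def __init__(self, fn):
        self._fn = fn
        self._cache = {}
        self.total_calls = 0     # every query, cached or not
        self.uncached_calls = 0  # queries that actually ran the function

    def __call__(self, smiles):
        self.total_calls += 1
        if smiles not in self._cache:
            self.uncached_calls += 1
            self._cache[smiles] = self._fn(smiles)
        return self._cache[smiles]


def split_chain(smiles):
    """Stand-in for an expensive model call: enumerate all two-way splits."""
    n = len(smiles)
    return [("C" * i, "C" * (n - i)) for i in range(1, n)]


cached_split = CountingCache(split_chain)
for query in ["CCCC", "CC", "CCCC", "CCC", "CC"]:
    cached_split(query)

print(cached_split.uncached_calls, cached_split.total_calls)
```

Five queries over three distinct inputs run the underlying function three times, which mirrors how repeated products during search only cost one model call each when caching is enabled.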
Let's take a look at the routes that were found. To make sure they were explored starting with higher probability steps, we plot the first and last route found.
from syntheseus.search.visualization import visualize_andor
for name, idx in [("first", 0), ("last", -1)]:
    visualize_andor(
        output_graph, filename=f"route_{name}.pdf", nodes=routes[idx]
    )
The contents of the files route_{first, last}.pdf should look like the following. Search only considers unique reactants for a given reaction step; even though our model always returns two reactants, if these are the same then search will create a reaction with only a single child node. Given that our probabilities were set up to prefer splitting the input into equal-sized chunks, the first route found halves the input SMILES in each reaction step, while the last route always splits off a single atom.
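A rough way to see why the balanced route comes first is to compare the total costs of the two routes by hand. The sketch below is plain Python and rests on stated assumptions rather than on the syntheseus implementation: each reaction costs the negative log of its (unmerged) probability, a route's cost is the sum over its reactions, and identical reactants are expanded only once. The helper names split_probability and route_cost are made up for this example.

```python
import math


def split_probability(i, n):
    """Probability the toy model assigns to splitting a chain of n into i and n - i."""
    total = sum(min(j, n - j) for j in range(1, n))
    return min(i, n - i) / total


def route_cost(splits):
    """Sum of -log(probability) over a route given as (chain_length, split_point) pairs."""
    return sum(-math.log(split_probability(i, n)) for n, i in splits)


# Balanced route: 8 -> 4 + 4, 4 -> 2 + 2, 2 -> 1 + 1 (each unique product expanded once).
balanced = [(8, 4), (4, 2), (2, 1)]
# Maximally unbalanced route: split off one atom from chains of length 8 down to 2.
unbalanced = [(n, 1) for n in range(8, 1, -1)]

print(f"balanced route cost:   {route_cost(balanced):.2f}")
print(f"unbalanced route cost: {route_cost(unbalanced):.2f}")
```

Under these assumptions the balanced route needs three reactions with probabilities 0.25, 0.5 and 1, for a total cost of log(8), while the unbalanced route needs seven lower-probability reactions and is considerably more expensive.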
In the case above, search had an easy job finding the best route, as the higher probability steps also led to reaching building block molecules sooner. In general, algorithms will be implicitly biased towards not only higher probability steps but also taking fewer steps overall. However, we can modify our toy model to strongly prefer unbalanced splits, and verify that the order of routes is then roughly reversed.
class ToyModelUnbalanced(ToyModel):
    def _get_reaction_score(self, i: int, n_atoms: int) -> float:
        # Invert the base score so that unbalanced splits score highest.
        score = super()._get_reaction_score(i, n_atoms)
        return (1.0 / score) ** 4.0

print_predictions(ToyModelUnbalanced(), "CCCCCCCC")
C.CCCCCCC>>CCCCCCCC (probability: 0.464)
CC.CCCCCC>>CCCCCCCC (probability: 0.029)
CCC.CCCCC>>CCCCCCCC (probability: 0.006)
CCCC.CCCC>>CCCCCCCC (probability: 0.002)
output_graph, routes = get_routes(
    ToyModelUnbalanced(use_cache=True)
)
visualize_andor(
    output_graph,
    filename="route_first_unbalanced.pdf",
    nodes=routes[0],
)
Found 22 routes
Indeed, the first route found during this search is the "maximally unbalanced" one, which was the last route found previously.