"""Word-level SMILES tokenizer built on HuggingFace `tokenizers`.

Tokenization happens in two split passes:
  1. OUTER_REGEX isolates bracket atoms ("[...]") whole, plus every token
     that is unambiguous outside brackets.
  2. INNER_REGEX then splits the isolated bracket-atom spans into their
     component tokens.
"""
import regex as re
from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers, processors

# Control tokens. [PAD] is looked up by id for padding below.
SPECIAL = [
    "[PAD]",
    "[BOS]",
    "[EOS]",
    "[MASK]",
    "[UNK]",
]

# fmt: off
ELEMENTS = [
    'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al',
    'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe',
    'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr',
    'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn',
    'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Hf', 'Ta', 'W', 'Re', 'Os',
    'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra',
    'Ac', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl',
    'Mc', 'Lv', 'Ts', 'Og', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb',
    'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am',
    'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr',
]
ELEMENTS += ["te", "si"]  # These are not 'correct', but RDKit is allowing them, and they show up in PubChem.
# fmt: on

AROMATIC = ["b", "c", "n", "o", "p", "s", "se", "as"]
ORGANIC = ["B", "C", "N", "O", "P", "S", "F", "I", "Cl", "Br"] + AROMATIC
BONDS = ["-", "=", "#", "$", ":", "/", "\\", "."]
CHARGE = ["+", "-"]
CHIRAL = ["@", "@@", "@TH", "@AL", "@SP", "@TB", "@OH"]
BRANCH = ["(", ")", "*"]
RINGS = [str(i) for i in range(10)] + ["%"]  # "%" prefixes two-digit ring ids

TOKENS = (
    SPECIAL
    + ELEMENTS
    + AROMATIC  # will be deduped
    + ORGANIC  # will be deduped
    + BONDS
    + CHARGE
    + CHIRAL
    + BRANCH
    + RINGS
    + ["[", "]"]
)
# Make tokens unique while preserving order (keeps vocab ids stable).
TOKENS = list(dict.fromkeys(TOKENS))
VOCAB = {tok: i for i, tok in enumerate(TOKENS)}

AROMATIC_SINGLE = {"b", "c", "n", "o", "p", "s"}
AROMATIC_MULTI = {"se", "as"}  # two-letter aromatic tokens
AROMATIC_ALL = AROMATIC_SINGLE | AROMATIC_MULTI


def is_ambiguous(elem: str) -> bool:
    """Return True if ``elem`` also reads as two adjacent vocabulary tokens.

    Example: 'Sc' splits into 'S' + 'c' (sulfur + aromatic carbon), so a
    bare 'Sc' outside brackets cannot safely be matched as one token.
    """
    return any(
        elem[:i] in TOKENS and elem[i:] in TOKENS for i in range(1, len(elem))
    )


UNAMBIGUOUS_ELEMENTS = [e for e in ELEMENTS if not is_ambiguous(e)]


def _alternation(tokens) -> str:
    """Join escaped ``tokens`` into a regex alternation, longest raw token first.

    Sorting on the raw (pre-escape) length guarantees longest-match-first
    semantics ('Cl' before 'C', '@@' before '@') even though ``re.escape``
    can change a token's character count. Tokens are deduplicated up front
    (order-preserving) so the pattern stays minimal and deterministic.
    """
    unique = dict.fromkeys(tokens)
    return "|".join(re.escape(tok) for tok in sorted(unique, key=len, reverse=True))


OUTER_TOKENS = ORGANIC + BONDS + CHARGE + CHIRAL + BRANCH + RINGS + UNAMBIGUOUS_ELEMENTS
# The bracket-atom pattern goes first: it is the longest alternative and must
# win so that "[...]" spans are isolated whole before any inner splitting.
OUTER_REGEX = Regex(r"\[[^\[\]]+\]" + "|" + _alternation(OUTER_TOKENS))

# NOTE: RINGS already contains "%", so no extra ["%"] term is needed here
# (the previous duplicate produced an identical-matching pattern).
INNER_TOKENS = ELEMENTS + AROMATIC + BONDS + CHARGE + CHIRAL + RINGS + ["[", "]"]
INNER_REGEX = Regex(_alternation(INNER_TOKENS))

tokenizer = Tokenizer(models.WordLevel(VOCAB, unk_token="[UNK]"))
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
tokenizer.add_special_tokens(SPECIAL)
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
    [
        # Pass 1: isolate bracket atoms and tokens unambiguous outside brackets.
        pre_tokenizers.Split(pattern=OUTER_REGEX, behavior="isolated"),
        # Pass 2: split the bracket-atom spans into their component tokens.
        pre_tokenizers.Split(pattern=INNER_REGEX, behavior="isolated"),
    ]
)  # type: ignore
tokenizer.post_processor = processors.TemplateProcessing(
    single="[BOS] $A [EOS]",
    special_tokens=[("[BOS]", VOCAB["[BOS]"]), ("[EOS]", VOCAB["[EOS]"])],
)  # type: ignore
tokenizer.decoder = decoders.WordPiece(prefix="")  # type: ignore