Source code for bluesearch.mining.relation

"""Classes and functions for relation extraction."""

# Blue Brain Search is a text mining toolbox focused on scientific use cases.
#
# Copyright (C) 2020  Blue Brain Project, EPFL.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from abc import ABC, abstractmethod
from collections import defaultdict

import pandas as pd


class REModel(ABC):
    """Abstract interface for relationship extraction models.

    Inspired by SciBERT.
    """

    @property
    @abstractmethod
    def classes(self):
        """Names of supported relation classes.

        Returns
        -------
        list of str
            Names of supported relation classes.
        """

    @abstractmethod
    def predict_probs(self, annotated_sentence):
        """Relation probabilities between subject and object.

        Predict per-class probabilities for the relation between subject and
        object in an annotated sentence.

        Parameters
        ----------
        annotated_sentence : str
            Sentence with exactly 2 entities being annotated accordingly.
            For example "<< Cytarabine >> inhibits [[ DNA polymerase ]]."

        Returns
        -------
        relation_probs : pd.Series
            Per-class probability vector. The index contains the class names,
            the values are the probabilities.
        """

    def predict(self, annotated_sentence, return_prob=False):
        """Predict most likely relation between subject and object.

        Parameters
        ----------
        annotated_sentence : str
            Sentence with exactly 2 entities being annotated accordingly.
            For example "<< Cytarabine >> inhibits [[ DNA polymerase ]]."
        return_prob : bool, optional
            If True also returns the confidence of the predicted relation.

        Returns
        -------
        relation : str
            Relation type.
        prob : float, optional
            Confidence of the predicted relation.
        """
        probs = self.predict_probs(annotated_sentence)
        relation = probs.idxmax()

        if return_prob:
            prob = probs.max()
            return relation, prob
        else:
            return relation

    @property
    @abstractmethod
    def symbols(self):
        """Generate dictionary mapping the two entity types to their annotation symbols.

        General structure: {'ENTITY_TYPE': ('SYMBOL_LEFT', 'SYMBOL_RIGHT')}

        Specific example: {'GGP': ('[[ ', ' ]]'), 'CHEBI': ('<< ', ' >>')}

        Make sure that left and right symbols are not identical.
        """
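
# Illustrative sketch (not part of the original module): a minimal concrete
# ``REModel`` showing how ``classes``, ``symbols`` and ``predict_probs`` plug
# into the inherited ``predict``. The class name, relation labels and the toy
# heuristic below are made up for demonstration.
def _remodel_interface_example():
    """Run a toy ``REModel`` subclass on a hand-annotated sentence."""

    class IsNegatedModel(REModel):
        """Toy model: does the sentence contain the word "not"?"""

        @property
        def classes(self):
            return ["NEGATED", "NOT_NEGATED"]

        @property
        def symbols(self):
            return defaultdict(lambda: ("[[ ", " ]]"))

        def predict_probs(self, annotated_sentence):
            negated = " not " in annotated_sentence.lower()
            return pd.Series(
                [1.0, 0.0] if negated else [0.0, 1.0], index=self.classes
            )

    model = IsNegatedModel()
    # Returns "NEGATED" because the sentence contains "not".
    return model.predict("[[ Cytarabine ]] does not inhibit [[ DNA polymerase ]].")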

def annotate(doc, sent, ent_1, ent_2, etype_symbols):
    """Annotate sentence given two entities.

    Parameters
    ----------
    doc : spacy.tokens.Doc
        The entire document (input text). Note that spacy uses it for
        absolute referencing.
    sent : spacy.tokens.Span
        One sentence from the `doc` where we look for relations.
    ent_1 : spacy.tokens.Span
        The first entity in the sentence. One can get its type by using the
        `label_` attribute.
    ent_2 : spacy.tokens.Span
        The second entity in the sentence. One can get its type by using the
        `label_` attribute.
    etype_symbols : dict or defaultdict
        Keys represent different entity types ("GGP", "CHEBI") and the values
        are tuples of size 2. Each of these tuples represents the starting
        and ending symbol to wrap the recognized entity with. Each
        ``REModel`` has the `symbols` property that encodes how its inputs
        should be annotated.

    Returns
    -------
    result : str
        String representing an annotated sentence created out of the
        original one.

    Notes
    -----
    The implementation is non-trivial because an entity can span multiple
    words.
    """
    # checks
    if ent_1 == ent_2:
        raise ValueError("One needs to provide two separate entities.")

    if not (
        sent.start <= ent_1.start <= ent_1.end <= sent.end
        and sent.start <= ent_2.start <= ent_2.end <= sent.end
    ):
        raise ValueError("The provided entities are outside of the given sentence.")

    etype_1 = ent_1.label_
    etype_2 = ent_2.label_

    if not isinstance(etype_symbols, defaultdict) and not (
        etype_1 in etype_symbols and etype_2 in etype_symbols
    ):
        raise ValueError(
            "Please specify the special symbols for both of the entity types."
        )

    tokens = []
    i = sent.start
    while i < sent.end:
        new_tkn = " "  # hack to keep the punctuation nice
        if ent_1.start == i:
            start, end = ent_1.start, ent_1.end
            new_tkn += (
                etype_symbols[etype_1][0]
                + doc[start:end].text
                + etype_symbols[etype_1][1]
            )
        elif ent_2.start == i:
            start, end = ent_2.start, ent_2.end
            new_tkn += (
                etype_symbols[etype_2][0]
                + doc[start:end].text
                + etype_symbols[etype_2][1]
            )
        else:
            start, end = i, i + 1
            new_tkn = doc[i].text if doc[i].is_punct else new_tkn + doc[i].text

        tokens.append(new_tkn)
        i += end - start

    return "".join(tokens).strip()
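
# Illustrative sketch (not part of the original module): one way to call
# ``annotate`` on a spaCy doc. The sentence, the token offsets and the entity
# labels below are made up for demonstration.
def _annotate_example():
    """Annotate a toy sentence with the GGP/CHEBI symbols used above."""
    import spacy
    from spacy.tokens import Span

    nlp = spacy.blank("en")
    doc = nlp("Cytarabine inhibits DNA polymerase.")
    sent = doc[:]  # treat the whole doc as a single sentence
    ent_chemical = Span(doc, 0, 1, label="CHEBI")  # "Cytarabine"
    ent_protein = Span(doc, 2, 4, label="GGP")  # "DNA polymerase"

    etype_symbols = {"GGP": ("[[ ", " ]]"), "CHEBI": ("<< ", " >>")}
    # Expected result: "<< Cytarabine >> inhibits [[ DNA polymerase ]]."
    return annotate(doc, sent, ent_chemical, ent_protein, etype_symbols)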

class ChemProt(REModel):
    """Pretrained model extracting 13 relations between chemicals and proteins.

    This model supports the following entity types:

        - "GGP"
        - "CHEBI"

    Attributes
    ----------
    model_ : allennlp.predictors.text_classifier.TextClassifierPredictor
        The actual model in the backend.

    Notes
    -----
    This model depends on a package named `scibert` which is not specified in
    the `setup.py` since it introduces dependency conflicts. One can install
    it manually with the following command.

    .. code-block:: bash

        pip install git+https://github.com/allenai/scibert

    Note that `import scibert` has a side effect of registering the
    "text_classifier" model with `allennlp`. This is done via applying a
    decorator to a class. For more details see
    https://github.com/allenai/scibert/blob/06793f77d7278898159ed50da30d173cdc8fdea9/scibert/models/text_classifier.py#L14
    """

    def __init__(self, model_path):
        # Note: SciBERT is imported but unused. This is because the import
        # has a side-effect of registering the SciBERT model, which we use
        # later on.
        import scibert  # NOQA
        from allennlp.predictors import Predictor

        self.model_ = Predictor.from_path(model_path, predictor_name="text_classifier")

    @property
    def classes(self):
        """Names of supported relation classes."""
        return [
            "INHIBITOR",
            "SUBSTRATE",
            "INDIRECT-DOWNREGULATOR",
            "INDIRECT-UPREGULATOR",
            "ACTIVATOR",
            "ANTAGONIST",
            "PRODUCT-OF",
            "AGONIST",
            "DOWNREGULATOR",
            "UPREGULATOR",
            "AGONIST-ACTIVATOR",
            "SUBSTRATE_PRODUCT-OF",
            "AGONIST-INHIBITOR",
        ]

    @property
    def symbols(self):
        """Symbols for annotation."""
        return {"GGP": ("[[ ", " ]]"), "CHEBI": ("<< ", " >>")}

    def predict_probs(self, annotated_sentence):
        """Predict probabilities for the relation."""
        return pd.Series(
            self.model_.predict(sentence=annotated_sentence)["class_probs"],
            index=self.classes,
        )
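
# Illustrative sketch (not part of the original module): how ``ChemProt``
# could be used end to end. The path below is hypothetical and must point to
# a trained ChemProt model archive; `scibert` and `allennlp` must be
# installed as described in the class docstring.
def _chemprot_example(model_path="path/to/chemprot_model.tar.gz"):
    """Predict the relation for an annotated sentence with ``ChemProt``."""
    model = ChemProt(model_path)
    annotated = "<< Cytarabine >> inhibits [[ DNA polymerase ]]."
    relation, prob = model.predict(annotated, return_prob=True)
    return relation, prob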

class StartWithTheSameLetter(REModel):
    """Check whether two entities start with the same letter (case insensitive).

    This relation is symmetric and works on any entity type.
    """

    @property
    def classes(self):
        """Names of supported relation classes."""
        return ["START_WITH_SAME_LETTER", "START_WITH_DIFFERENT_LETTER"]

    def predict_probs(self, annotated_sentence):
        """Predict probabilities for the relation."""
        left_symbol, _ = self.symbols["anything"]
        s_len = len(left_symbol)

        ent_1_first_letter_ix = annotated_sentence.find(left_symbol) + s_len
        ent_2_first_letter_ix = (
            annotated_sentence.find(left_symbol, ent_1_first_letter_ix + 1) + s_len
        )

        ent_1_first_letter = annotated_sentence[ent_1_first_letter_ix]
        ent_2_first_letter = annotated_sentence[ent_2_first_letter_ix]

        if ent_1_first_letter.lower() == ent_2_first_letter.lower():
            return pd.Series([1, 0], index=self.classes)
        else:
            return pd.Series([0, 1], index=self.classes)

    @property
    def symbols(self):
        """Symbols for annotation."""
        return defaultdict(lambda: ("[[ ", " ]]"))
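
# Illustrative sketch (not part of the original module): running the toy
# ``StartWithTheSameLetter`` model on an already annotated sentence. The
# sentence below is made up for demonstration.
def _start_with_the_same_letter_example():
    """Compare the first letters of two annotated entities."""
    model = StartWithTheSameLetter()
    annotated = "[[ Cytarabine ]] inhibits [[ DNA polymerase ]]."
    # "C" vs "D" -> "START_WITH_DIFFERENT_LETTER" with probability 1.
    return model.predict(annotated, return_prob=True)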