Source code for bluesearch.entrypoint.embeddings

"""EntryPoint for the computation and saving of the embeddings."""

# Blue Brain Search is a text mining toolbox focused on scientific use cases.
#
# Copyright (C) 2020  Blue Brain Project, EPFL.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

import argparse
import logging
import pathlib
import sys

import numpy as np
import sqlalchemy

from bluesearch.entrypoint._helper import (
    CombinedHelpFormatter,
    configure_logging,
    parse_args_or_environment,
)


[docs]def run_compute_embeddings(argv=None): """Run CLI.""" # CLI setup parser = argparse.ArgumentParser( formatter_class=CombinedHelpFormatter, ) parser.add_argument( "model_name_or_class", type=str, help=""" The name or class of the model for which to compute the embeddings. Recognized model names are: 'BioBERT NLI+STS', 'SBioBERT', 'SBERT'. Recognized model classes are: 'SentTransformer', 'SklearnVectorizer'. See also 'get_embedding_model(...)'. """, ) parser.add_argument( "outfile", type=str, help="The path to where the embeddings are saved (h5 file)", ) parser.add_argument( "--batch-size-inference", default=256, type=int, help="Batch size for embeddings computation", ) parser.add_argument( "--batch-size-transfer", default=1000, type=int, help="Batch size for the concatenation of temp h5 files", ) parser.add_argument( "-c", "--checkpoint", type=str, help=""" If 'model_name_or_class' is the class, the path of the model to load. Otherwise, this argument is ignored. """, ) parser.add_argument( "--db-url", type=str, help=""" URL of the MySQL database. Generally, the scheme part of the URL should be omitted, i.e. the URL should be of the form 'my_sql_server.ch:1234/my_database'. If missing, then the environment variable DB_URL will be read. """, default=argparse.SUPPRESS, ) parser.add_argument( "--gpus", type=str, help=""" Comma separated list of GPU indices for each process. To only run on a CPU leave blank. For example '2,,3,' will use GPU 2 and 3 for the 1st and 3rd process respectively. The processes 2 and 4 will be run on a CPU. By default using CPU for all processes. """, ) parser.add_argument( "--h5-dataset-name", type=str, help=""" The name of the dataset in the H5 file. Otherwise, the value of 'model_name_or_class' is used. """, ) parser.add_argument( "--indices-path", type=str, help=""" Path to a .npy file containing sentence ids to embed. Specifically, it is a 1D numpy array of integers representing the sentence ids. If not specified we embed all sentences in the database. """, ) parser.add_argument( "--log-file", "-l", type=str, metavar="<filepath>", default=None, help="In addition to stderr, log messages to a file.", ) parser.add_argument( "--log-level", type=int, default=20, help=""" The logging level. Possible values: - 50 for CRITICAL - 40 for ERROR - 30 for WARNING - 20 for INFO - 10 for DEBUG - 0 for NOTSET """, ) parser.add_argument( "-n", "--n-processes", default=4, type=int, help="Number of processes to use", ) parser.add_argument( "-s", "--start-method", default="forkserver", choices=["fork", "forkserver", "spawn"], type=str, help=""" Multiprocessing starting method to be used. Note that using "fork" might lead to problems when doing GPU inference. """, ) parser.add_argument( "--temp-dir", type=str, help=""" The path to where temporary h5 files are saved. If not specified then identical to the folder in which the output h5 file is placed. """, ) # Parse CLI arguments env_variable_names = { "db_url": "DB_URL", } args = parse_args_or_environment(parser, env_variable_names, argv=argv) # Configure logging configure_logging(args.log_file, args.log_level) logger = logging.getLogger(__name__) logger.info(" Configuration ".center(80, "-")) for k, v in vars(args).items(): logger.info(f"{k:<32}: {v}") logger.info("-" * 80) # Imports (they are here to make --help quick) logger.info("Loading libraries") from bluesearch.embedding_models import MPEmbedder # Database related logger.info("SQL Alchemy Engine creation ....") full_url = f"mysql+mysqldb://guest:guest@{args.db_url}?charset=utf8mb4" engine = sqlalchemy.create_engine(full_url) # Path preparation and checking out_file = pathlib.Path(args.outfile) temp_dir = None if args.temp_dir is None else pathlib.Path(args.temp_dir) checkpoint_path: pathlib.Path | None = None if args.checkpoint is not None: checkpoint_path = pathlib.Path(args.checkpoint) indices_path = ( None if args.indices_path is None else pathlib.Path(args.indices_path) ) # Parse GPUs if args.gpus is None: gpus = None else: gpus = [None if x == "" else int(x) for x in args.gpus.split(",")] if indices_path is not None: if indices_path.exists(): indices = np.load(str(indices_path)) else: raise FileNotFoundError(f"Indices file {indices_path} does not exist!") else: n_sentences = list(engine.execute("SELECT COUNT(*) FROM sentences"))[0][0] indices = np.arange(1, n_sentences + 1) logger.info("Instantiating MPEmbedder") mpe = MPEmbedder( engine.url, args.model_name_or_class, indices, out_file, batch_size_inference=args.batch_size_inference, batch_size_transfer=args.batch_size_transfer, n_processes=args.n_processes, checkpoint_path=checkpoint_path, gpus=gpus, temp_folder=temp_dir, h5_dataset_name=args.h5_dataset_name, start_method=args.start_method, ) logger.info("Starting embedding") mpe.do_embedding()
if __name__ == "__main__": # pragma: no cover sys.exit(run_compute_embeddings())