# Source code for ds_utils.transformers.sentence_embedding
"""Scikit-learn compatible transformer for sentence-transformers embeddings.

Requires the optional ``nlp`` dependency group::

    pip install data-science-utils[nlp]
"""
from __future__ import annotations
from typing import Any, List, Optional, Sequence, Union
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted
# Type alias used to annotate the ``X`` argument of the transformer's
# ``fit`` / ``transform`` methods and its input-preparation helpers.
ArrayLike = Union[np.ndarray, pd.Series, pd.DataFrame, Sequence[Any]]
def _check_sentence_transformers_installed() -> None:
    """Verify that the optional ``sentence-transformers`` package is importable.

    :raises ImportError: If importing ``sentence_transformers`` fails. The
        message includes installation instructions and the original import
        failure is chained as the cause.
    """
    try:
        import sentence_transformers  # noqa: F401
    except ImportError as import_failure:
        message = (
            "sentence-transformers is required for SentenceEmbeddingTransformer. "
            "Install it with: pip install data-science-utils[nlp]"
        )
        raise ImportError(message) from import_failure
[docs]
class SentenceEmbeddingTransformer(BaseEstimator, TransformerMixin):
    """Wrap a ``sentence-transformers`` model for use in sklearn pipelines.

    Loads a `SentenceTransformer
    <https://sbert.net/docs/package_reference/sentence_transformer/model.html>`_
    model lazily on first :meth:`fit` and produces a dense ``float32`` (or
    quantized) embedding matrix from text inputs.

    The transformer accepts strings, lists of strings, :class:`pandas.Series`,
    :class:`pandas.DataFrame` (single column), and :class:`numpy.ndarray`.
    ``None`` and ``NaN`` values are replaced with empty strings before encoding.

    .. note::
        This transformer requires the optional ``nlp`` extras::

            pip install data-science-utils[nlp]

    :param model_name: Name or path of a ``sentence-transformers`` model
        (default: ``'sentence-transformers/all-MiniLM-L6-v2'``).
    :param batch_size: Batch size for encoding (default: ``32``).
    :param show_progress_bar: Whether to show a progress bar during encoding
        (default: ``False``).
    :param normalize_embeddings: Whether to L2-normalize embeddings to unit
        length (default: ``False``).
    :param device: Device for computation (``'cpu'``, ``'cuda'``, etc.).
        ``None`` lets the library auto-detect (default: ``None``).
    :param precision: Embedding precision — ``'float32'``, ``'int8'``,
        ``'uint8'``, ``'binary'``, or ``'ubinary'`` (default: ``'float32'``).
    :param truncate_dim: Truncate embeddings to this many dimensions. Useful
        for `Matryoshka <https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html>`_
        models (default: ``None`` — no truncation).
    :param prompt_name: Name of a prompt registered in the model's
        ``prompts`` dictionary (default: ``None``).
    :param prompt: Raw prompt string to prepend to every input sentence
        (default: ``None``).

    :ivar model_: The loaded ``SentenceTransformer`` instance (set after :meth:`fit`).
    :ivar embedding_dimension_: Dimensionality of the model's embeddings as
        reported by the model (set after :meth:`fit`).
    :ivar n_features_in_: Number of input features (always ``1``).
    """

    # Values accepted for the ``precision`` parameter.
    _VALID_PRECISIONS = ("float32", "int8", "uint8", "binary", "ubinary")
    # Precisions whose output is bit-packed (eight embedding dims per byte).
    _PACKED_PRECISIONS = ("binary", "ubinary")
    # Constructor parameters that are baked into the loaded model instance;
    # changing any of them requires the model to be loaded again.
    _MODEL_RELOAD_PARAMS = ("model_name", "device", "truncate_dim")
    # Attributes created by ``_load_model`` that must be dropped together.
    _LOADED_MODEL_ATTRS = ("model_", "embedding_dimension_", "_loaded_model_name_", "_loaded_model_config_")

    def __init__(
        self,
        *,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        batch_size: int = 32,
        show_progress_bar: bool = False,
        normalize_embeddings: bool = False,
        device: Optional[str] = None,
        precision: str = "float32",
        truncate_dim: Optional[int] = None,
        prompt_name: Optional[str] = None,
        prompt: Optional[str] = None,
    ) -> None:
        """See class docstring for parameter descriptions."""
        # Parameters are stored verbatim (no validation here) so that
        # ``get_params`` / ``clone`` round-trip them unchanged.
        self.model_name = model_name
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.device = device
        self.precision = precision
        self.truncate_dim = truncate_dim
        self.prompt_name = prompt_name
        self.prompt = prompt

    @classmethod
    def _validate_precision(cls, precision: Any) -> None:
        """Raise ``ValueError`` unless *precision* is one of the supported values."""
        if precision not in cls._VALID_PRECISIONS:
            raise ValueError(f"Invalid precision '{precision}'. Expected one of {cls._VALID_PRECISIONS}.")

    def _current_model_config(self) -> tuple:
        """Return the parameter values that determine which model gets loaded."""
        return tuple(getattr(self, name) for name in self._MODEL_RELOAD_PARAMS)

    def _discard_loaded_model(self) -> None:
        """Drop the cached model and everything derived from it, if present."""
        state = vars(self)
        for attr in self._LOADED_MODEL_ATTRS:
            # ``pop`` with a default tolerates partially-populated state instead
            # of raising ``AttributeError`` when one attribute is missing.
            state.pop(attr, None)

    def _loaded_model_is_stale(self) -> bool:
        """Tell whether the cached model was loaded with different parameters."""
        loaded_config = getattr(self, "_loaded_model_config_", None)
        if loaded_config is not None:
            return loaded_config != self._current_model_config()
        # Instances whose model was loaded before the full configuration was
        # recorded only carry the model name; fall back to comparing that.
        return getattr(self, "_loaded_model_name_", None) != self.model_name

    def set_params(self, **params: Any) -> SentenceEmbeddingTransformer:
        """Set the parameters of this estimator.

        Setting ``model_name``, ``device`` or ``truncate_dim`` discards any
        previously loaded model, so the estimator must be fitted again before
        :meth:`transform` can be used.

        :param params: Estimator parameters to set.
        :return: This estimator.
        :raises ValueError: If ``precision`` is given and is not a supported value.
        """
        if "precision" in params:
            self._validate_precision(params["precision"])
        # Invalidate cached model if model-loading params change
        if params.keys() & set(self._MODEL_RELOAD_PARAMS):
            self._discard_loaded_model()
        return super().set_params(**params)

    def _load_model(self) -> None:
        """Import and instantiate the ``SentenceTransformer`` model."""
        _check_sentence_transformers_installed()
        from sentence_transformers import SentenceTransformer

        # Snapshot the configuration first so the recorded values are exactly
        # the ones the model was constructed with.
        config = self._current_model_config()
        self.model_ = SentenceTransformer(self.model_name, device=self.device, truncate_dim=self.truncate_dim)
        self.embedding_dimension_ = self.model_.get_sentence_embedding_dimension()
        self._loaded_model_name_ = self.model_name
        self._loaded_model_config_ = config

    @staticmethod
    def _is_null(val: Any) -> bool:
        """Robustly check for None, np.nan, or pd.NA without triggering array ambiguity."""
        if val is None or val is pd.NA:
            return True
        # Restrict ``np.isnan`` to float scalars; calling it on strings or
        # other objects would raise ``TypeError``.
        return isinstance(val, (float, np.floating)) and bool(np.isnan(val))

    @staticmethod
    def _validate_shape(X: ArrayLike) -> None:
        """Validate that the input shape is 1D or a single-column 2D array.

        Only :class:`pandas.DataFrame` and :class:`numpy.ndarray` inputs are
        checked; other input types pass through unchecked.

        :raises ValueError: If a DataFrame has other than one column, or an
            array has more than two dimensions or a second axis not of length 1.
        """
        if isinstance(X, pd.DataFrame):
            if X.shape[1] != 1:
                raise ValueError(
                    f"SentenceEmbeddingTransformer expects a single text column; got DataFrame with shape {X.shape}."
                )
        elif isinstance(X, np.ndarray):
            if X.ndim > 1 and (X.ndim > 2 or X.shape[1] != 1):
                raise ValueError(
                    f"SentenceEmbeddingTransformer expects a single text column; got array with shape {X.shape}."
                )

    @classmethod
    def _prepare_texts(cls, X: ArrayLike) -> List[str]:
        """Convert input to a list of strings, replacing ``None``/``NaN`` with ``""``.

        :param X: A string, an iterable of values, a :class:`pandas.Series`, a
            single-column :class:`pandas.DataFrame`, or a :class:`numpy.ndarray`
            (0-D, 1-D, or 2-D with one column).
        :return: One string per sample.
        :raises ValueError: If the input has more than one column.
        """
        cls._validate_shape(X)
        if isinstance(X, pd.DataFrame):
            raw = X.iloc[:, 0].tolist()
        elif isinstance(X, pd.Series):
            raw = X.tolist()
        elif isinstance(X, np.ndarray):
            if X.ndim == 0:
                # ``tolist()`` on a 0-D array yields a bare scalar; iterating a
                # scalar string would split it into characters, so wrap it.
                raw = [X.item()]
            elif X.ndim > 1:
                raw = X[:, 0].tolist()
            else:
                raw = X.tolist()
        elif isinstance(X, str):
            # A lone string is one sample, not an iterable of characters.
            raw = [X]
        else:
            raw = list(X)
        return ["" if cls._is_null(t) else str(t) for t in raw]

    def fit(self, X: ArrayLike, y: Any = None) -> SentenceEmbeddingTransformer:
        """Load the sentence-transformer model and record embedding metadata.

        The model is loaded lazily on the first call to :meth:`fit`. Subsequent
        calls reuse the cached model unless ``model_name``, ``device`` or
        ``truncate_dim`` differ from the values the cached model was loaded with.

        :param X: Text data — array-like of strings, :class:`pandas.Series`,
            single-column :class:`pandas.DataFrame`, or :class:`numpy.ndarray`.
            Only its shape is validated here; it is not encoded.
        :param y: Ignored; present for sklearn API compatibility.
        :return: This estimator, fitted.
        :raises ValueError: If ``precision`` is invalid or ``X`` has more than
            one column.
        """
        self._validate_precision(self.precision)
        self._validate_shape(X)
        # Record fitted state only after the input has been accepted.
        self.n_features_in_ = 1
        if not hasattr(self, "model_") or self._loaded_model_is_stale():
            self._load_model()
        return self

    def transform(self, X: ArrayLike) -> np.ndarray:
        """Encode text inputs into dense embedding vectors.

        :param X: Same accepted forms as :meth:`fit`.
        :return: Embedding matrix of shape ``(n_samples, embedding_dimension_)``.
            The output dtype depends on the ``precision`` parameter (e.g., ``float32`` or ``int8``).
            Note: For ``binary`` or ``ubinary`` precision, the output is a packed ``uint8`` array
            where dimensions represent packed bits rather than individual embedding dims.
        :raises sklearn.exceptions.NotFittedError: If :meth:`fit` has not been called.
        """
        check_is_fitted(self, "model_")
        texts = self._prepare_texts(X)
        embeddings = self.model_.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_numpy=True,
            normalize_embeddings=self.normalize_embeddings,
            precision=self.precision,
            prompt_name=self.prompt_name,
            prompt=self.prompt,
        )
        return np.asarray(embeddings)

    def _n_output_features(self) -> int:
        """Return the number of columns :meth:`transform` is expected to produce."""
        n_dims = self.embedding_dimension_
        if self.precision in self._PACKED_PRECISIONS:
            # Bit-packed precisions store eight embedding dimensions per byte;
            # round up so a dimension that is not a multiple of 8 still counts
            # its final, partially filled byte.
            return -(-n_dims // 8)
        return n_dims

    def get_feature_names_out(self, input_features: Union[None, np.ndarray, List[str]] = None) -> np.ndarray:
        """Return output feature names for this transformation.

        Names follow ``dim_0``, ``dim_1``, …, ``dim_{n-1}``, where ``n`` is
        ``embedding_dimension_`` for ``float32``/``int8``/``uint8`` precision
        and the number of packed bytes for ``binary``/``ubinary`` precision,
        so the names line up with the columns of :meth:`transform`'s output.

        :param input_features: Names for the input column(s), or None. When provided, length must
            match ``n_features_in_``.
        :return: ``numpy.ndarray`` of output feature names, dtype ``object``.
        :raises sklearn.exceptions.NotFittedError: If :meth:`fit` has not been called.
        :raises ValueError: If ``input_features`` has the wrong length.
        """
        check_is_fitted(self, "embedding_dimension_")
        if input_features is not None:
            input_features = np.asarray(input_features, dtype=object)
            if len(input_features) != self.n_features_in_:
                raise ValueError(
                    f"input_features has {len(input_features)} element(s), expected {self.n_features_in_}."
                )
        return np.asarray([f"dim_{i}" for i in range(self._n_output_features())], dtype=object)