# Copyright 2025 Emmanuel Cortes. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import shutil
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Generic, Literal, overload
import numpy as np
import torch
from huggingface_hub import HfApi, ModelHubMixin
from huggingface_hub.constants import HF_HUB_CACHE
from optimum.utils.file_utils import find_files_matching_pattern
from transformers import (
AutoConfig,
AutoModel,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoTokenizer,
PretrainedConfig,
)
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.utils import logging
from transformers.utils.doc import add_end_docstrings, add_start_docstrings
from transformers.utils.hub import cached_file, is_offline_mode
from typing_extensions import Unpack
from .configuration import RKNNConfig
from .constants import (
RKNN_FILE_PATTERN,
RKNN_WEIGHTS_NAME,
CoreMaskType,
PlatformType,
)
from .modeling_utils import MODEL_OUTPUT_T, PreTrainedModel, RKNNRuntime, TENSOR_Ts
from .utils.docs import (
FROM_PRETRAINED_START_DOCSTRING,
RKNN_MODEL_END_DOCSTRING,
TEXT_INPUTS_DOCSTRING,
TOKENIZER_FOR_DOC,
add_start_docstrings_to_model_forward,
)
from .utils.import_utils import (
is_rknn_toolkit_lite_available,
)
from .utils.logging_utils import suppress_output
logger = logging.get_logger(__name__)
class RKModel(
RKNNRuntime,
PreTrainedModel,
ModelHubMixin,
Generic[MODEL_OUTPUT_T, Unpack[TENSOR_Ts]],
library_name="rk-transformers",
tags=["rknn", "rockchip", "npu"],
):
"""Base class for RKNN-backed text models integrated with the Hugging Face Hub."""
model_type: str = "rknn_model"
auto_model_class = AutoModel
def __init__(
self,
*,
model_id: str | None = None,
config: PretrainedConfig | None = None,
model_path: str | Path,
platform: PlatformType | None = None,
core_mask: CoreMaskType = "auto",
rknn_config: RKNNConfig | None = None,
max_seq_length: int = 512,
batch_size: int = 1,
) -> None:
if config is None:
raise ValueError("A Hugging Face config is required to build an RKModel.")
super().__init__(model_path=model_path, platform=platform, core_mask=core_mask, rknn_config=rknn_config)
self.model_id = model_id
self.config = config
# Set defaults for input_names, batch_size, and max_seq_length
self.input_names = ["input_ids", "attention_mask"]
if getattr(config, "type_vocab_size", 1) > 1:
self.input_names.append("token_type_ids")
self.batch_size = batch_size
self.max_seq_length = max_seq_length
if self.rknn_config:
if hasattr(self.rknn_config, "model_input_names") and self.rknn_config.model_input_names:
self.input_names = self.rknn_config.model_input_names
if hasattr(self.rknn_config, "max_seq_length") and self.rknn_config.max_seq_length is not None:
self.max_seq_length = self.rknn_config.max_seq_length
if hasattr(self.rknn_config, "batch_size"):
self.batch_size = self.rknn_config.batch_size
self.pad_token_id = 0
self.pad_token_type_id = 0
        self.pad_attention_mask = 0  # Hugging Face Transformers uses 0 to mark padded positions in the attention mask
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
if hasattr(self.tokenizer, "pad_token_id") and self.tokenizer.pad_token_id is not None:
self.pad_token_id = self.tokenizer.pad_token_id
if hasattr(self.tokenizer, "pad_token_type_id") and self.tokenizer.pad_token_type_id is not None:
self.pad_token_type_id = self.tokenizer.pad_token_type_id
except Exception:
logger.warning("Failed to load tokenizer. Using default padding IDs (0).")
# From optimum.onnxruntime.modeling.ORTModel
AutoConfig.register(self.model_type, AutoConfig)
if hasattr(self.auto_model_class, "register"):
self.auto_model_class.register(AutoConfig, self.__class__)
@overload
def __call__(self, *args: Any, return_dict: Literal[False], **kwargs: Any) -> tuple[Unpack[TENSOR_Ts]]: ... # pyright: ignore[reportOverlappingOverload]
@overload
def __call__(self, *args: Any, return_dict: Literal[True], **kwargs: Any) -> MODEL_OUTPUT_T: ...
@overload
def __call__(self, *args: Any, **kwargs: Any) -> MODEL_OUTPUT_T: ...
# return_dict omitted -> MODEL_OUTPUT_T
# return_dict=True -> MODEL_OUTPUT_T
# return_dict=False -> tuple[Unpack[TENSOR_Ts]]
def __call__(
self,
*args: Any,
return_dict: bool = True,
**kwargs: Any,
) -> MODEL_OUTPUT_T | tuple[Unpack[TENSOR_Ts]]:
return self.forward(*args, return_dict=return_dict, **kwargs)
def forward(self, *args: Any, **kwargs: Any) -> MODEL_OUTPUT_T | tuple[Unpack[TENSOR_Ts]]:
"""Define the computation performed at every call.
Should be overridden by all subclasses.
"""
raise NotImplementedError(f'Module [{type(self).__name__}] is missing the required "forward" function')
@property
def device(self) -> torch.device:
"""Return the device on which the model is stored."""
return torch.device("cpu")
def to(self, device: torch.device | str) -> "RKModel":
"""No-op for RKModel. For compatibility with Hugging Face Transformers Pipelines."""
return self
def _tensor_to_numpy(self, tensor: torch.Tensor | np.ndarray, dtype: np.dtype[Any]) -> np.ndarray:
if tensor is None:
raise ValueError("Input tensor is required for RKNN inference.")
if isinstance(tensor, torch.Tensor):
array = tensor.detach().cpu().numpy()
elif isinstance(tensor, np.ndarray):
array = tensor
else:
array = np.asarray(tensor)
if array.dtype != dtype:
array = array.astype(dtype, copy=False)
return array
def _torch_if_needed(self, use_torch: bool, array: np.ndarray) -> torch.Tensor | np.ndarray:
if use_torch:
contiguous = np.ascontiguousarray(array)
return torch.from_numpy(contiguous)
return array
def _ones_like(self, reference: torch.Tensor | np.ndarray, use_torch: bool) -> torch.Tensor | np.ndarray:
if use_torch:
if not isinstance(reference, torch.Tensor): # pragma: no cover - defensive
reference = torch.from_numpy(np.asarray(reference))
return torch.ones_like(reference)
return np.ones_like(np.asarray(reference))
def _zeros_like(self, reference: torch.Tensor | np.ndarray, use_torch: bool) -> torch.Tensor | np.ndarray:
if use_torch:
if not isinstance(reference, torch.Tensor): # pragma: no cover - defensive
reference = torch.from_numpy(np.asarray(reference))
return torch.zeros_like(reference)
return np.zeros_like(np.asarray(reference))
def _pad_to_model_input_dimensions(
self,
tensor: torch.Tensor | np.ndarray,
padding_id: int,
use_torch: bool,
target_shape: tuple[int, ...] | None = None,
) -> torch.Tensor | np.ndarray:
"""Pad tensor to match model's expected input dimensions.
Handles arbitrary tensor ranks (2D for standard tasks, 3D for multiple-choice, etc.)
by padding each dimension independently to match the target shape.
Args:
tensor: Input tensor to pad (e.g., shape: [batch, seq_len] or [batch, num_choices, seq_len])
padding_id: Value to use for padding
use_torch: Whether to use PyTorch or NumPy for padding
target_shape: Target shape for the tensor. If None, defaults to 2D padding behavior
using self.batch_size and self.max_seq_length.
Returns:
Padded tensor with shape matching target_shape
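
        Example (illustrative sketch; assumes an instantiated ``RKModel`` named ``model``)::

            >>> import numpy as np
            >>> ids = np.array([[101, 2023, 102]])  # shape (1, 3)
            >>> padded = model._pad_to_model_input_dimensions(
            ...     ids, padding_id=0, use_torch=False, target_shape=(1, 8)
            ... )
            >>> padded.shape
            (1, 8)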
"""
# Default to 2D padding for backward compatibility
if target_shape is None:
target_shape = (getattr(self, "batch_size", tensor.shape[0]), self.max_seq_length)
if len(target_shape) != len(tensor.shape):
raise ValueError(
f"Target shape rank ({len(target_shape)}) must match tensor rank ({len(tensor.shape)}). "
f"Got target_shape={target_shape}, tensor.shape={tensor.shape}"
)
needs_padding = any(current < target for current, target in zip(tensor.shape, target_shape, strict=True))
if not needs_padding:
return tensor
# Calculate padding for each dimension
# Padding goes at the "end" of each dimension (right/bottom)
if use_torch:
# PyTorch pad format: (dim_n_before, dim_n_after, ..., dim_0_before, dim_0_after)
# We only pad at the end, so all "before" values are 0
pad_values: list[int] = []
for current_dim, target_dim in reversed(list(zip(tensor.shape, target_shape, strict=True))):
pad_values.extend([0, max(0, target_dim - current_dim)]) # (before, after)
tensor = torch.nn.functional.pad(tensor, tuple(pad_values), value=padding_id) # type: ignore
else:
# NumPy pad format: ((dim_0_before, dim_0_after), (dim_1_before, dim_1_after), ...)
pad_width = [
(0, max(0, target - current)) for current, target in zip(tensor.shape, target_shape, strict=True)
]
tensor = np.pad(tensor, pad_width, constant_values=padding_id)
return tensor
def _prepare_text_inputs(
self,
input_ids: torch.Tensor | np.ndarray,
attention_mask: torch.Tensor | np.ndarray | None,
token_type_ids: torch.Tensor | np.ndarray | None,
input_shape: tuple[int, ...] | None = None,
) -> tuple[bool, dict[str, torch.Tensor | np.ndarray | None], tuple[int, ...]]:
"""Prepare text inputs for RKNN inference with padding.
Args:
input_ids: Input token IDs
attention_mask: Attention mask
token_type_ids: Token type IDs (optional)
input_shape: Expected input shape (e.g., [batch_size, seq_len] for 2D,
[batch_size, num_choices, seq_len] for 3D). If None, defaults
to 2D shape using self.batch_size and self.max_seq_length.
Returns:
Tuple of (use_torch, model_inputs, original_shape)
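
        Example (illustrative sketch; assumes ``model.batch_size == 1`` and ``model.max_seq_length == 8``)::

            >>> import numpy as np
            >>> ids = np.array([[101, 2023, 102]])  # shape (1, 3)
            >>> use_torch, inputs, original_shape = model._prepare_text_inputs(ids, None, None)
            >>> inputs["input_ids"].shape, original_shape
            ((1, 8), (1, 3))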
"""
if input_ids is None:
raise ValueError("`input_ids` is required for RKModel text inference.")
use_torch = isinstance(input_ids, torch.Tensor)
original_shape = tuple(input_ids.shape)
# Calculate target shape
if input_shape is None:
# Default 2D behavior: [batch_size, seq_len]
target_shape = (getattr(self, "batch_size", original_shape[0]), self.max_seq_length)
else:
# Use provided input_shape, filling in dimensions as needed
target_shape = input_shape
# Pad inputs to target shape
input_ids = self._pad_to_model_input_dimensions(
input_ids, padding_id=self.pad_token_id, use_torch=use_torch, target_shape=target_shape
)
if attention_mask is None:
attention_mask = self._ones_like(input_ids, use_torch)
attention_mask = self._pad_to_model_input_dimensions(
attention_mask, padding_id=self.pad_attention_mask, use_torch=use_torch, target_shape=target_shape
)
if "token_type_ids" in self.input_names:
if token_type_ids is None:
token_type_ids = self._zeros_like(input_ids, use_torch) # Use padded input_ids as reference
else:
token_type_ids = self._pad_to_model_input_dimensions(
token_type_ids, padding_id=self.pad_token_type_id, use_torch=use_torch, target_shape=target_shape
)
return (
use_torch,
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
},
original_shape,
)
def _run_text_model(
self,
use_torch: bool,
model_inputs: dict[str, torch.Tensor | np.ndarray | None],
expected_outputs: Sequence[str],
) -> dict[str, torch.Tensor | np.ndarray]:
ordered_inputs: list[np.ndarray] = []
for name in self.input_names:
tensor = model_inputs.get(name)
if tensor is None:
continue
ordered_inputs.append(self._tensor_to_numpy(tensor, np.dtype(np.int16)))
if self.rknn is None:
raise RuntimeError("RKNN runtime has been released and can no longer run inference.")
# Suppress RKNN inference logs
with suppress_output():
if is_rknn_toolkit_lite_available():
                # data_type: int8 | uint8 | int16 | float16 | float32 - a limitation of the RKNN MM API/hardware.
                # This is an issue for models with embeddings, since they expect int64 inputs.
outputs = self.rknn.inference(inputs=ordered_inputs, data_type="int16") # type: ignore
else:
outputs = self.rknn.inference(inputs=ordered_inputs)
if outputs is None:
input_summaries = [f"shape={arr.shape}, dtype={arr.dtype}" for arr in ordered_inputs]
raise RuntimeError(
"RKNN inference returned None. "
"This is likely due to a mismatch between model input shapes and the given inputs. "
f"Input summary: {input_summaries}"
)
if len(outputs) < len(expected_outputs):
logger.error(
"RKNN inference output mismatch: expected %d outputs (%s), got %d outputs",
len(expected_outputs),
expected_outputs,
len(outputs),
)
raise RuntimeError("RKNN inference did not return the expected number of outputs.")
prepared: dict[str, torch.Tensor | np.ndarray] = {}
for idx, name in enumerate(expected_outputs):
prepared[name] = self._torch_if_needed(use_torch, np.asarray(outputs[idx]))
return prepared
def _warn_on_unhandled_inputs(self, kwargs: dict[str, Any]) -> None:
if kwargs:
logger.warning_once( # type: ignore - transformers logger util
"%s received unsupported arguments: %s",
self.__class__.__name__,
", ".join(kwargs.keys()),
)
def _save_pretrained(self, save_directory: Path) -> None:
target = save_directory / RKNN_WEIGHTS_NAME
shutil.copyfile(self.model_path, target)
@staticmethod
def _cached_file(
path_or_repo_id: str | Path,
filename: str,
subfolder: str = "",
revision: str | None = "main",
force_download: bool = False,
local_files_only: bool = False,
token: bool | str | None = None,
cache_dir: str | Path = HF_HUB_CACHE,
proxies: dict | None = None,
) -> Path:
cached_path = cached_file(
path_or_repo_id,
filename=filename,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
proxies=proxies,
)
if cached_path is None:
raise FileNotFoundError(f"Unable to cache RKNN artifact `{filename}` from {path_or_repo_id}.")
return Path(cached_path)
@staticmethod
def _infer_file_path(
pattern: str,
candidate_files: list[Path],
standard_file_name: str,
target_file_name: str | None = None,
) -> Path:
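        """Pick a single RKNN file from ``candidate_files``: prefer an exact ``target_file_name``
        match, then ``standard_file_name``, then the first file matching ``pattern``."""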
if target_file_name is not None:
specific = [file for file in candidate_files if file.name == target_file_name]
if not specific:
raise FileNotFoundError(
f"Could not find any RKNN files with target file name {target_file_name}. "
f"Candidates: {candidate_files}."
)
if len(specific) > 1:
logger.warning(
"Found multiple RKNN files named %s, using %s.",
target_file_name,
specific[0].name,
)
return specific[0]
standard = [file for file in candidate_files if file.name == standard_file_name]
if len(standard) == 1:
return standard[0]
if len(standard) > 1:
logger.warning(
"Found multiple RKNN files named %s, using %s.",
standard_file_name,
standard[0].name,
)
return standard[0]
pattern_files = [path for path in candidate_files if re.search(pattern, str(path))]
if not pattern_files:
raise FileNotFoundError(
f"Could not find an RKNN artifact matching pattern {pattern}. Candidates: {candidate_files}."
)
if len(pattern_files) > 1:
logger.warning(
"Found multiple RKNN files matching pattern %s, using %s.",
pattern,
pattern_files[0].name,
)
return pattern_files[0]
@staticmethod
def _list_repo_rknn_files(
model_id: str,
*,
revision: str | None,
token: str | bool | None,
subfolder: str,
) -> list[Path]:
"""Enumerate RKNN files from a remote Hugging Face repository."""
api = HfApi(token=token if isinstance(token, str) else None)
try:
repo_files = api.list_repo_files(model_id, revision=revision, repo_type="model")
except Exception as exc: # pragma: no cover - network errors
logger.debug("Failed to list repo files for %s: %s", model_id, exc)
return []
filtered: list[Path] = []
for file_path in repo_files:
if not re.search(RKNN_FILE_PATTERN, file_path):
continue
path_obj = Path(file_path)
if subfolder:
try:
path_obj.relative_to(subfolder)
except ValueError:
continue
filtered.append(path_obj)
return filtered
@classmethod
def _resolve_config(
cls,
model_id: str,
config: Any | None,
*,
revision: str | None,
cache_dir: str | Path | None,
force_download: bool,
local_files_only: bool,
token: str | bool | None,
trust_remote_code: bool,
proxies: dict | None = None,
) -> PretrainedConfig:
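        """Coerce ``config`` into a ``PretrainedConfig``, loading it with ``AutoConfig`` when
        necessary and falling back to a generic config if loading fails."""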
if isinstance(config, PretrainedConfig):
return config
if isinstance(config, dict):
return PretrainedConfig.from_dict(config)
try:
return AutoConfig.from_pretrained(
model_id,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
trust_remote_code=trust_remote_code,
proxies=proxies,
)
except Exception as exc:
logger.warning(
"Falling back to a generic config for %s because AutoConfig loading failed: %s",
model_id,
exc,
)
fallback_model_type = getattr(getattr(cls.auto_model_class, "config_class", None), "model_type", None)
return PretrainedConfig(
model_type=fallback_model_type or cls.model_type,
name_or_path=model_id,
)
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
config: PretrainedConfig | None,
# rknn options
platform: PlatformType | None = None,
core_mask: CoreMaskType = "auto",
# hub options
subfolder: str = "",
revision: str | None = None,
force_download: bool = False,
resume_download: bool | None = False,
proxies: dict | None = None,
local_files_only: bool = False,
trust_remote_code: bool = False,
cache_dir: str | Path | None,
token: str | bool | None,
# file options
file_name: str | None = None,
**model_kwargs: Any,
):
cache_dir = cache_dir or HF_HUB_CACHE
if is_offline_mode() and not local_files_only:
local_files_only = True
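        # Resolve the RKNN artifact in this order:
        #   1. `model_id` is itself a local .rknn file -> use it directly.
        #   2. An explicit `file_name` was given -> download/cache exactly that file.
        #   3. Otherwise, enumerate candidate *.rknn files (local directory or Hub repo)
        #      and pick one, preferring the standard weights file name.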
if os.path.isfile(model_id):
model_path = Path(model_id)
elif file_name is not None:
model_path = cls._cached_file(
model_id,
filename=file_name,
subfolder=subfolder,
local_files_only=local_files_only,
force_download=force_download,
cache_dir=cache_dir,
revision=revision,
token=token,
proxies=proxies,
)
else:
candidate_files = find_files_matching_pattern(
model_id,
pattern=RKNN_FILE_PATTERN,
glob_pattern="**/*.rknn",
subfolder=subfolder,
token=token,
revision=revision,
)
if not candidate_files:
candidate_files = cls._list_repo_rknn_files(
model_id,
revision=revision,
token=token,
subfolder=subfolder,
)
if not candidate_files:
raise FileNotFoundError(f"Could not find any RKNN model file in {model_id}.")
if Path(model_id).is_dir():
candidate_files = [path.relative_to(model_id) for path in candidate_files]
resolved_file = cls._infer_file_path(
RKNN_FILE_PATTERN,
candidate_files,
standard_file_name=RKNN_WEIGHTS_NAME,
target_file_name=file_name,
)
subfolder_to_use = resolved_file.parent.as_posix()
if subfolder_to_use == ".":
subfolder_to_use = ""
model_path = cls._cached_file(
model_id,
filename=resolved_file.name,
subfolder=subfolder_to_use,
local_files_only=local_files_only,
force_download=force_download,
cache_dir=cache_dir,
revision=revision,
token=token,
proxies=proxies,
)
resolved_config = cls._resolve_config(
model_id,
config,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
trust_remote_code=trust_remote_code,
proxies=proxies,
)
# Try to get rknn config from the resolved config object
root_rknn_config = {}
if hasattr(resolved_config, "rknn"):
root_rknn_config = resolved_config.rknn
elif isinstance(resolved_config, dict) and "rknn" in resolved_config:
root_rknn_config = resolved_config["rknn"]
model_rknn_config = None
if root_rknn_config:
# Match model filename to keys in rknn config (e.g. "rknn/model.rknn")
filename = model_path.name
for key, conf in root_rknn_config.items():
if key.endswith(filename):
try:
model_rknn_config = RKNNConfig.from_dict(conf)
logger.info(f"Loaded RKNN config for {filename}")
break
except Exception as e:
logger.warning(f"Failed to parse RKNN config for {key}: {e}")
if not model_rknn_config:
logger.warning("RKNN config not found in config.json. Use default batch_size=1 and max_seq_length=512.")
return cls(
model_id=model_id,
config=resolved_config,
model_path=model_path,
platform=platform,
core_mask=core_mask,
rknn_config=model_rknn_config,
**model_kwargs,
)
@classmethod
@add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING)
def from_pretrained(
cls,
pretrained_model_name_or_path: str | Path,
*,
config: PretrainedConfig | None = None,
# rknn options
platform: PlatformType | None = None,
core_mask: CoreMaskType = "auto",
# hub options
subfolder: str = "",
revision: str | None = None,
force_download: bool = False,
resume_download: bool | None = False,
proxies: dict | None = None,
token: str | bool | None = None,
local_files_only: bool = False,
trust_remote_code: bool = False,
cache_dir: str | Path | None = None,
# file options
file_name: str | None = None,
**model_kwargs: Any,
):
return super().from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
config=config,
platform=platform,
core_mask=core_mask,
subfolder=subfolder,
revision=revision,
force_download=force_download,
proxies=proxies,
token=token,
local_files_only=local_files_only,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir,
file_name=file_name,
**model_kwargs,
)
FEATURE_EXTRACTION_EXAMPLE = r"""
Example of feature extraction:
.. code-block:: python
from transformers import {processor_class}
from rktransformers.modeling import {model_class}
import torch
tokenizer = {processor_class}.from_pretrained("{checkpoint}")
model = {model_class}.from_pretrained("{checkpoint}")
inputs = tokenizer("My name is Philipp and I live in Germany.", return_tensors="np")
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
list(last_hidden_state.shape)
# [1, 12, 384]
"""
MASKED_LM_EXAMPLE = r"""
Example of masked language modeling:
.. code-block:: python
from transformers import {processor_class}
from rktransformers.modeling import {model_class}
import torch
tokenizer = {processor_class}.from_pretrained("{checkpoint}")
model = {model_class}.from_pretrained("{checkpoint}")
inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
outputs = model(**inputs)
logits = outputs.logits
list(logits.shape)
# [1, 512, 30522]
"""
SEQUENCE_CLASSIFICATION_EXAMPLE = r"""
Example of single-label classification:
.. code-block:: python
from transformers import {processor_class}
from rktransformers.modeling import {model_class}
import torch
tokenizer = {processor_class}.from_pretrained("{checkpoint}")
model = {model_class}.from_pretrained("{checkpoint}")
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)
logits = outputs.logits
list(logits.shape)
# [1, 2]
"""
@add_end_docstrings(RKNN_MODEL_END_DOCSTRING)
class RKModelForSequenceClassification(RKModel[SequenceClassifierOutput, torch.Tensor | np.ndarray]):
"""RKNN model for sequence classification/regression tasks."""
auto_model_class = AutoModelForSequenceClassification
@add_start_docstrings_to_model_forward(
TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ SEQUENCE_CLASSIFICATION_EXAMPLE.format(
processor_class=TOKENIZER_FOR_DOC,
model_class="RKModelForSequenceClassification",
checkpoint="rk-transformers/distilbert-base-uncased-finetuned-sst-2-english",
)
)
def forward(
self,
input_ids: torch.Tensor | np.ndarray,
attention_mask: torch.Tensor | np.ndarray | None = None,
token_type_ids: torch.Tensor | np.ndarray | None = None,
*,
return_dict: bool = True,
**kwargs: Any,
):
self._warn_on_unhandled_inputs(kwargs)
use_torch, model_inputs, original_shape = self._prepare_text_inputs(input_ids, attention_mask, token_type_ids)
outputs = self._run_text_model(use_torch, model_inputs, ["logits"])
logits = outputs["logits"][: original_shape[0]]
if not return_dict:
return (logits,)
return SequenceClassifierOutput(logits=logits) # type: ignore[arg-type]
QUESTION_ANSWERING_EXAMPLE = r"""
Example of question answering:
.. code-block:: python
from transformers import {processor_class}
from rktransformers.modeling import {model_class}
import torch
tokenizer = {processor_class}.from_pretrained("{checkpoint}")
model = {model_class}.from_pretrained("{checkpoint}")
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="np")
outputs = model(**inputs)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
list(start_logits.shape)
# [1, 512]
list(end_logits.shape)
# [1, 512]
"""
@add_end_docstrings(RKNN_MODEL_END_DOCSTRING)
class RKModelForQuestionAnswering(
RKModel[QuestionAnsweringModelOutput, torch.Tensor | np.ndarray, torch.Tensor | np.ndarray]
):
"""RKNN Model with a QuestionAnsweringModelOutput for extractive question-answering tasks like SQuAD."""
auto_model_class = AutoModelForQuestionAnswering
@add_start_docstrings_to_model_forward(
TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ QUESTION_ANSWERING_EXAMPLE.format(
processor_class=TOKENIZER_FOR_DOC,
model_class="RKModelForQuestionAnswering",
checkpoint="rk-transformers/distilbert-base-cased-distilled-squad",
)
)
def forward(
self,
input_ids: torch.Tensor | np.ndarray,
attention_mask: torch.Tensor | np.ndarray | None = None,
token_type_ids: torch.Tensor | np.ndarray | None = None,
*,
return_dict: bool = True,
**kwargs: Any,
):
self._warn_on_unhandled_inputs(kwargs)
use_torch, model_inputs, original_shape = self._prepare_text_inputs(input_ids, attention_mask, token_type_ids)
outputs = self._run_text_model(use_torch, model_inputs, ["start_logits", "end_logits"])
start_logits = outputs["start_logits"][: original_shape[0]]
end_logits = outputs["end_logits"][: original_shape[0]]
if not return_dict:
return (start_logits, end_logits)
return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits) # type: ignore[arg-type]
TOKEN_CLASSIFICATION_EXAMPLE = r"""
Example of token classification:
.. code-block:: python
from transformers import {processor_class}
from rktransformers.modeling import {model_class}
import torch
tokenizer = {processor_class}.from_pretrained("{checkpoint}")
model = {model_class}.from_pretrained("{checkpoint}")
inputs = tokenizer("My name is Philipp and I live in Germany.", return_tensors="np")
outputs = model(**inputs)
logits = outputs.logits
list(logits.shape)
# [1, 512, 9]
"""
@add_end_docstrings(RKNN_MODEL_END_DOCSTRING)
class RKModelForTokenClassification(RKModel[TokenClassifierOutput, torch.Tensor | np.ndarray]):
"""RKNN Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.""" # noqa: E501
auto_model_class = AutoModelForTokenClassification
@add_start_docstrings_to_model_forward(
TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ TOKEN_CLASSIFICATION_EXAMPLE.format(
processor_class=TOKENIZER_FOR_DOC,
model_class="RKModelForTokenClassification",
checkpoint="rk-transformers/bert-base-NER",
)
)
def forward(
self,
input_ids: torch.Tensor | np.ndarray,
attention_mask: torch.Tensor | np.ndarray | None = None,
token_type_ids: torch.Tensor | np.ndarray | None = None,
*,
return_dict: bool = True,
**kwargs: Any,
):
self._warn_on_unhandled_inputs(kwargs)
use_torch, model_inputs, original_shape = self._prepare_text_inputs(input_ids, attention_mask, token_type_ids)
outputs = self._run_text_model(use_torch, model_inputs, ["logits"])
logits = outputs["logits"][: original_shape[0]]
if not return_dict:
return (logits,)
return TokenClassifierOutput(logits=logits) # type: ignore[arg-type]
MULTIPLE_CHOICE_EXAMPLE = r"""
Example of multiple choice:
.. code-block:: python
from transformers import {processor_class}
from rktransformers.modeling import {model_class}
import torch
tokenizer = {processor_class}.from_pretrained("{checkpoint}")
model = {model_class}.from_pretrained("{checkpoint}")
prompt = "In Italy, pizza is served in slices."
choice0 = "It is eaten with a fork and knife."
choice1 = "It is eaten while held in the hand."
choice2 = "It is blended into a smoothie."
choice3 = "It is folded into a taco."
labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;))
encoding = tokenizer([prompt, prompt, prompt, prompt], [choice0, choice1, choice2, choice3], return_tensors="np", padding=True)
inputs = {{k: np.expand_dims(v, 0) for k, v in encoding.items()}}
outputs = model(**inputs)
logits = outputs.logits
list(logits.shape)
# [1, 4]
""" # noqa: E501