Source code for rktransformers.exporters.rknn.convert

# Copyright 2025 Emmanuel Cortes. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import json
import logging
import os
import shutil

from huggingface_hub import HfApi, create_repo, snapshot_download
from optimum.exporters.onnx import main_export
from rknn.api import RKNN

from rktransformers.configuration import RKNNConfig
from rktransformers.constants import (
    ALLOW_MODEL_REPO_FILES,
    DEFAULT_MAX_SEQ_LENGTH,
    IGNORE_MODEL_REPO_FILES,
)
from rktransformers.exporters.rknn.model_card import ModelCardGenerator

from .utils import (
    clean_build_artifacts,
    download_sentence_transformer_modules_weights,
    generate_rknn_output_path,
    get_onnx_input_names,
    has_rknn_config,
    load_model_config,
    prepare_dataset_for_quantization,
    resolve_hub_repo_id,
    update_model_config_with_rknn,
)

logger = logging.getLogger(__name__)


def export_rknn(config: RKNNConfig) -> None:
    """
    Export an ONNX model or a Hugging Face Hub model to RKNN format using the
    provided configuration.

    For Hugging Face models, this function:

    1. Exports to ONNX using Optimum (automatically detects required inputs like token_type_ids)
    2. Inspects the exported ONNX model to determine the actual inputs
    3. Configures the RKNN toolkit with optimization and quantization settings
    4. Loads and converts the ONNX model to RKNN format
    5. Optionally quantizes using a calibration dataset
    6. Exports to RKNN format and saves the complete configuration

    Args:
        config: RKNN configuration object.

    Returns:
        None on success.

    Raises:
        ValueError: If required configuration fields are missing.
        RuntimeError: If any export or conversion step fails.
    """
    if not config.model_name_or_path:
        raise ValueError("model_name_or_path is required in configuration")

    base_model_id = config.model_name_or_path
    is_local_model = config.model_name_or_path.endswith(".onnx") or os.path.exists(config.model_name_or_path)

    if not is_local_model:
        logger.info(f"Model path '{config.model_name_or_path}' not found locally. Treating as Hugging Face Hub ID.")

        # Determine output directory
        if config.output_path:
            if config.output_path.endswith(".rknn"):
                base_output_dir = os.path.dirname(config.output_path)
                rknn_filename = os.path.basename(config.output_path)
            else:
                base_output_dir = config.output_path
                rknn_filename = "model.rknn"
        else:
            base_output_dir = os.getcwd()
            rknn_filename = "model.rknn"

        model_name = config.model_name_or_path.split("/")[-1]
        output_dir = os.path.join(base_output_dir, model_name)
        config.output_path = os.path.join(output_dir, rknn_filename)

        logger.info(f"Exporting model from Hub to directory: {output_dir}")
        try:
            # Preserve existing config.json with RKNN modifications before downloading/exporting
            config_path = os.path.join(output_dir, "config.json")
            preserved_config = None
            if has_rknn_config(output_dir):
                logger.info("Preserving existing config.json with RKNN modifications")
                with open(config_path) as f:
                    preserved_config = json.load(f)

            # Download model files (excluding weights)
            snapshot_download(
                repo_id=config.model_name_or_path,
                local_dir=output_dir,
                ignore_patterns=IGNORE_MODEL_REPO_FILES,
                allow_patterns=ALLOW_MODEL_REPO_FILES,
            )
            download_sentence_transformer_modules_weights(
                repo_id=config.model_name_or_path,
                local_dir=output_dir,
                token=config.hub_token,
            )

            # Export to ONNX (this may overwrite config.json)
            export_kwargs = {
                "model_name_or_path": config.model_name_or_path,
                "output": output_dir,
                "task": config.task,  # onnx resolves "auto" internally
                "opset": config.opset,
                "do_validation": False,
                "no_post_process": True,
                "batch_size": config.batch_size,
                "sequence_length": config.max_seq_length or DEFAULT_MAX_SEQ_LENGTH,
            }
            if config.task_kwargs:
                export_kwargs.update(config.task_kwargs)
            main_export(**export_kwargs)

            # Restore preserved config if it existed
            if preserved_config is not None:
                logger.info("Restoring preserved config.json with RKNN modifications")
                with open(config_path, "w") as f:
                    json.dump(preserved_config, f, indent=2)

            onnx_model_path = os.path.join(output_dir, "model.onnx")
        except Exception as e:
            raise RuntimeError(f"Failed to export model from Hub: {e}") from e
    else:
        # Local ONNX model
        onnx_model_path = config.model_name_or_path
        output_dir = os.path.dirname(os.path.abspath(config.model_name_or_path))

    model_config = load_model_config(output_dir)  # used for auto-detection

    # Auto-detect max_seq_length if not provided
    is_user_specified_seq_len = config.max_seq_length is not None
    if config.max_seq_length is None:
        config.max_seq_length = getattr(model_config, "max_position_embeddings", DEFAULT_MAX_SEQ_LENGTH)
        logger.info(f"Auto-detected max_seq_length: {config.max_seq_length}")

    # Auto-detect type_vocab_size if not provided
    if config.type_vocab_size is None:
        config.type_vocab_size = getattr(model_config, "type_vocab_size", None)
        if config.type_vocab_size:
            logger.info(f"Auto-detected type_vocab_size: {config.type_vocab_size}")

    rknn = RKNN(verbose=True)

    logger.info(f"Configuring RKNN for {config.target_platform}")
    rknn.config(**config.to_dict())

    # Register custom operators between rknn.config() and rknn.load_onnx():
    # rknn.register_custom_op()
    # Docs: 5.5 Custom Operators
    # https://github.com/airockchip/rknn-toolkit2/blob/master/doc/02_Rockchip_RKNPU_User_Guide_RKNN_SDK_V2.3.2_EN.pdf

    logger.info(f"Loading ONNX model: {onnx_model_path}")
    sequence_length = config.max_seq_length
    batch_size = config.batch_size
    assert sequence_length is not None
    assert sequence_length >= 1, "max_seq_length must be at least 1"
    assert batch_size >= 1, "batch_size must be at least 1"

    if config.model_input_names:
        inputs = config.model_input_names
        logger.info(f"Using user-specified inputs: {inputs}")
    else:
        inputs = get_onnx_input_names(onnx_model_path)
        if inputs is None:
            # Fall back to a heuristic based on type_vocab_size if ONNX inspection fails
            logger.warning("Failed to extract inputs from ONNX model, falling back to heuristic")
            if config.type_vocab_size and config.type_vocab_size > 1:
                inputs = ["input_ids", "attention_mask", "token_type_ids"]
                logger.info(f"Auto-detected inputs (with token_type_ids): {inputs}")
            else:
                inputs = ["input_ids", "attention_mask"]
                logger.info(f"Auto-detected inputs: {inputs}")
        config.model_input_names = inputs

    # For multiple-choice tasks, inputs are 3D: [batch_size, num_choices, sequence_length]
    if config.task == "multiple-choice":
        num_choices = config.task_kwargs.get("num_choices", 4) if config.task_kwargs else 4
        input_size_list: list[list[int]] = [[batch_size, num_choices, sequence_length]] * len(inputs)
    else:
        input_size_list: list[list[int]] = [[batch_size, sequence_length]] * len(inputs)

    if config.dynamic_input is not None and (input_size_list not in config.dynamic_input):
        config.dynamic_input.append(input_size_list)

    logger.info(f"Loading ONNX model into RKNN with inputs: {inputs} and sizes: {input_size_list}")
    ret = rknn.load_onnx(
        model=onnx_model_path,
        inputs=inputs,
        input_size_list=input_size_list,
    )
    if ret != 0:
        raise RuntimeError("Failed to load ONNX model!")

    dataset_file = None
    actual_columns = None
    actual_splits = None
    if config.quantization.do_quantization and config.quantization.dataset_name:
        model_dir = os.path.dirname(onnx_model_path)
        dataset_file, actual_columns, actual_splits = prepare_dataset_for_quantization(
            config.quantization.dataset_name,
            config.quantization.dataset_size,
            model_dir,
            inputs,
            config.quantization.dataset_split,
            config.quantization.dataset_subset,
            config.quantization.dataset_columns,
            batch_size,
            sequence_length,
        )
        if actual_columns and not config.quantization.dataset_columns:
            config.quantization.dataset_columns = actual_columns
        if actual_splits and not config.quantization.dataset_split:
            config.quantization.dataset_split = actual_splits

    logger.info("Building RKNN model")
    ret = rknn.build(do_quantization=config.quantization.do_quantization, dataset=dataset_file)
    if ret != 0:
        clean_build_artifacts(output_dir)
        raise RuntimeError("Failed to build RKNN model!")

    # Determine output path based on configuration
    if config.output_path:
        if config.output_path.endswith(".rknn"):
            model_dir = os.path.dirname(config.output_path)
            model_name = os.path.splitext(os.path.basename(config.output_path))[0]
        else:
            model_dir = config.output_path
            model_name = os.path.splitext(os.path.basename(onnx_model_path.rstrip(os.sep)))[0]
    else:
        if is_local_model:
            model_dir = os.path.dirname(os.path.abspath(onnx_model_path))
            model_name = os.path.splitext(os.path.basename(onnx_model_path))[0]
        else:
            model_dir = os.getcwd()
            model_name = "model"

    output_path, rknn_key = generate_rknn_output_path(
        config, model_dir, model_name, batch_size, sequence_length, is_user_specified_seq_len
    )
    config.output_path = output_path

    logger.info(f"Exporting RKNN model to {config.output_path}")
    ret = rknn.export_rknn(config.output_path)
    if ret != 0:
        clean_build_artifacts(output_dir)
        raise RuntimeError("Failed to export RKNN model!")

    # Resolve hub_model_id early so the model card uses the resolved ID
    if config.push_to_hub:
        if not config.hub_model_id:
            clean_build_artifacts(output_dir)
            raise ValueError("hub_model_id is required when push_to_hub is True")
        config.hub_model_id = resolve_hub_repo_id(config.hub_model_id, config.hub_token)

    # An updated config.json is required for model card generation
    if config.output_path:
        update_model_config_with_rknn(config, model_dir, rknn_key, model_config)
        generator = ModelCardGenerator(pretrained_config=model_config)
        readme_path = generator.generate(config, model_dir, base_model_id)
        if readme_path:
            logger.info(f"Generated model card at {readme_path}")

    logger.info("Done!")
    rknn.release()

    if config.push_to_hub:
        assert config.hub_model_id is not None, "hub_model_id should be set here"
        logger.info(f"Pushing to Hub: {config.hub_model_id}")
        api = HfApi(token=config.hub_token)

        # Create the repo if it doesn't exist
        try:
            create_repo(
                repo_id=config.hub_model_id,
                token=config.hub_token,
                private=config.hub_private_repo,
                exist_ok=True,
            )
        except Exception as e:
            logger.warning(f"Failed to create/check repo: {e}")

        # Upload files
        try:
            logger.info(f"Uploading directory {model_dir} to Hub...")
            # Match sentence-transformers module directories like 1_Pooling/*, 2_Dense/*, etc.
            upload_allow_patterns = ALLOW_MODEL_REPO_FILES.copy()
            upload_allow_patterns.extend(
                [
                    "[0-9]_*/**",  # Match module directories and all their contents
                ]
            )
            api.upload_folder(
                repo_id=config.hub_model_id,
                folder_path=model_dir,
                token=config.hub_token,
                repo_type="model",
                ignore_patterns=IGNORE_MODEL_REPO_FILES,
                allow_patterns=upload_allow_patterns,
                create_pr=config.hub_create_pr,
            )
            logger.info(f"Successfully pushed to Hub: {config.hub_model_id}")
        except Exception as e:
            logger.warning(f"Error uploading to Hub: {e}")
        finally:
            # Clean up the dataset file and temp directory
            if dataset_file and os.path.exists(dataset_file):
                with contextlib.suppress(Exception):
                    os.remove(dataset_file)
                    temp_dir = os.path.dirname(dataset_file)
                    if os.path.exists(temp_dir):
                        shutil.rmtree(temp_dir)

    if dataset_file and os.path.exists(dataset_file):
        with contextlib.suppress(Exception):
            os.remove(dataset_file)
            temp_dir = os.path.dirname(dataset_file)
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)

    clean_build_artifacts(output_dir)
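A minimal usage sketch follows. The keyword arguments mirror the attributes export_rknn reads from its config (model_name_or_path, task, target_platform, batch_size, max_seq_length, output_path), but whether RKNNConfig accepts them directly in its constructor is an assumption not confirmed by this module, and the model ID and target platform shown are placeholder values.

from rktransformers.configuration import RKNNConfig
from rktransformers.exporters.rknn.convert import export_rknn

# Hypothetical configuration: the field names come from the attributes that
# export_rknn reads, but the constructor signature itself is assumed.
config = RKNNConfig(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",  # Hub ID or local .onnx path
    task="feature-extraction",
    target_platform="rk3588",  # placeholder Rockchip NPU target
    batch_size=1,
    max_seq_length=128,  # omit to auto-detect from max_position_embeddings
    output_path="./exported",  # a directory, or a path ending in .rknn
)

export_rknn(config)  # raises RuntimeError if any conversion step fails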