rkllm.h

#ifndef _RKLLM_H_
#define _RKLLM_H_

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/**
 * @typedef LLMHandle
 * @brief A handle used to manage and interact with the large language model.
 */
typedef void* LLMHandle;

/**
 * @enum LLMCallState
 * @brief Describes the possible states of an LLM call.
 */
typedef enum {
    RKLLM_RUN_NORMAL = 0,               /**< The LLM call is in a normal running state. */
    RKLLM_RUN_WAITING = 1,              /**< The LLM call is waiting for a complete UTF-8 encoded character. */
    RKLLM_RUN_FINISH = 2,               /**< The LLM call has finished execution. */
    RKLLM_RUN_ERROR = 3,                /**< An error occurred during the LLM call. */
    RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4 /**< Retrieve the last hidden layer during inference. */
} LLMCallState;

/**
 * @enum RKLLMInputType
 * @brief Defines the types of inputs that can be fed into the LLM.
 */
typedef enum {
    RKLLM_INPUT_PROMPT = 0,     /**< Input is a text prompt. */
    RKLLM_INPUT_TOKEN = 1,      /**< Input is a sequence of tokens. */
    RKLLM_INPUT_EMBED = 2,      /**< Input is an embedding vector. */
    RKLLM_INPUT_MULTIMODAL = 3, /**< Input is multimodal (e.g., text and image). */
} RKLLMInputType;

/**
 * @enum RKLLMInferMode
 * @brief Specifies the inference modes of the LLM.
 */
typedef enum {
    RKLLM_INFER_GENERATE = 0,              /**< The LLM generates text based on input. */
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1, /**< The LLM retrieves the last hidden layer for further processing. */
} RKLLMInferMode;

/**
 * @struct RKLLMExtendParam
 * @brief Extended parameters for configuring an LLM instance.
 */
typedef struct {
    int32_t base_domain_id; /**< Base domain ID. */
    uint8_t reserved[112];  /**< Reserved. */
} RKLLMExtendParam;

/**
 * @struct RKLLMParam
 * @brief Defines the parameters for configuring an LLM instance.
 */
typedef struct {
    const char* model_path;  /**< Path to the model file. */
    int32_t max_context_len; /**< Maximum number of tokens in the context window. */
    int32_t max_new_tokens;  /**< Maximum number of new tokens to generate. */
    int32_t top_k;           /**< Top-K sampling parameter for token generation. */
    float top_p;             /**< Top-P (nucleus) sampling parameter. */
    float temperature;       /**< Sampling temperature, affecting the randomness of token selection. */
    float repeat_penalty;    /**< Penalty for repeating tokens in generation. */
    float frequency_penalty; /**< Penalizes tokens in proportion to how often they have already appeared. */
    float presence_penalty;  /**< Applies a flat penalty to tokens that have already appeared. */
    int32_t mirostat;        /**< Mirostat sampling strategy flag (0 to disable). */
    float mirostat_tau;      /**< Tau parameter for Mirostat sampling. */
    float mirostat_eta;      /**< Eta parameter for Mirostat sampling. */
    bool skip_special_token; /**< Whether to skip special tokens during generation. */
    bool is_async;           /**< Whether to run inference asynchronously. */
    const char* img_start;   /**< Starting position of an image in multimodal input. */
    const char* img_end;     /**< Ending position of an image in multimodal input. */
    const char* img_content; /**< Pointer to the image content. */
    RKLLMExtendParam extend_param; /**< Extended parameters. */
} RKLLMParam;
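
/*
 * Illustrative sketch (not part of the API): obtaining a default RKLLMParam
 * via rkllm_createDefaultParam() (declared below) and overriding a few fields.
 * The model path and sampling values are placeholders.
 *
 *     RKLLMParam param = rkllm_createDefaultParam();
 *     param.model_path      = "/path/to/model.rkllm";
 *     param.max_context_len = 512;
 *     param.max_new_tokens  = 256;
 *     param.top_k           = 1;
 *     param.top_p           = 0.9f;
 *     param.temperature     = 0.8f;
 */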

/**
 * @struct RKLLMLoraAdapter
 * @brief Defines parameters for a Lora adapter used in model fine-tuning.
 */
typedef struct {
    const char* lora_adapter_path; /**< Path to the Lora adapter file. */
    const char* lora_adapter_name; /**< Name of the Lora adapter. */
    float scale;                   /**< Scaling factor for applying the Lora adapter. */
} RKLLMLoraAdapter;

/**
 * @struct RKLLMEmbedInput
 * @brief Represents an embedding input to the LLM.
 */
typedef struct {
    float* embed;    /**< Pointer to the embedding vector (of size n_tokens * n_embed). */
    size_t n_tokens; /**< Number of tokens represented in the embedding. */
} RKLLMEmbedInput;

/**
 * @struct RKLLMTokenInput
 * @brief Represents token input to the LLM.
 */
typedef struct {
    int32_t* input_ids; /**< Array of token IDs. */
    size_t n_tokens;    /**< Number of tokens in the input. */
} RKLLMTokenInput;
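
/*
 * Illustrative sketch: supplying pre-tokenized input. The token IDs below are
 * placeholders; real IDs come from the model's tokenizer.
 *
 *     int32_t ids[] = {1, 15043, 29892};
 *     RKLLMTokenInput tokens;
 *     tokens.input_ids = ids;
 *     tokens.n_tokens  = sizeof(ids) / sizeof(ids[0]);
 */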

/**
 * @struct RKLLMMultiModelInput
 * @brief Represents multimodal input (e.g., text and image).
 */
typedef struct {
    char* prompt;          /**< Text prompt input. */
    float* image_embed;    /**< Embedding of the image (of size n_image_tokens * n_image_embed). */
    size_t n_image_tokens; /**< Number of image tokens. */
} RKLLMMultiModelInput;
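
/*
 * Illustrative sketch: pairing a text prompt with a precomputed image
 * embedding. get_image_embedding() is a hypothetical vision-encoder call, and
 * the token count is a placeholder; both depend on the multimodal pipeline in
 * use.
 *
 *     float* image_embed = get_image_embedding(); // hypothetical helper
 *     RKLLMMultiModelInput mm;
 *     mm.prompt         = (char*)"Describe the image.";
 *     mm.image_embed    = image_embed;
 *     mm.n_image_tokens = 196; // placeholder; depends on the vision encoder
 */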

/**
 * @struct RKLLMInput
 * @brief Represents different types of input to the LLM via a union.
 */
typedef struct {
    RKLLMInputType input_type; /**< Specifies the type of input provided (e.g., prompt, token, embed, multimodal). */
    union {
        const char* prompt_input;              /**< Text prompt input if input_type is RKLLM_INPUT_PROMPT. */
        RKLLMEmbedInput embed_input;           /**< Embedding input if input_type is RKLLM_INPUT_EMBED. */
        RKLLMTokenInput token_input;           /**< Token input if input_type is RKLLM_INPUT_TOKEN. */
        RKLLMMultiModelInput multimodal_input; /**< Multimodal input if input_type is RKLLM_INPUT_MULTIMODAL. */
    };
} RKLLMInput;
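
/*
 * Illustrative sketch: wrapping a plain text prompt in an RKLLMInput. The
 * active union member must match input_type; the prompt text is a placeholder.
 *
 *     RKLLMInput input;
 *     input.input_type   = RKLLM_INPUT_PROMPT;
 *     input.prompt_input = "What is the capital of France?";
 */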

/**
 * @struct RKLLMLoraParam
 * @brief Structure defining parameters for Lora adapters.
 */
typedef struct {
    const char* lora_adapter_name; /**< Name of the Lora adapter. */
} RKLLMLoraParam;

/**
 * @struct RKLLMPromptCacheParam
 * @brief Structure to define parameters for caching prompts.
 */
typedef struct {
    int save_prompt_cache;         /**< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save). */
    const char* prompt_cache_path; /**< Path to the prompt cache file. */
} RKLLMPromptCacheParam;

/**
 * @struct RKLLMInferParam
 * @brief Structure for defining parameters during inference.
 */
typedef struct {
    RKLLMInferMode mode;                        /**< Inference mode (e.g., generate or get last hidden layer). */
    RKLLMLoraParam* lora_params;                /**< Pointer to Lora adapter parameters. */
    RKLLMPromptCacheParam* prompt_cache_params; /**< Pointer to prompt cache parameters. */
} RKLLMInferParam;
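
/*
 * Illustrative sketch: a minimal RKLLMInferParam for plain text generation
 * with no Lora adapter and no prompt cache. Zero-initializing before setting
 * the mode is an assumption, not a documented requirement.
 *
 *     RKLLMInferParam infer_params;
 *     memset(&infer_params, 0, sizeof(infer_params)); // from <string.h>
 *     infer_params.mode = RKLLM_INFER_GENERATE;
 */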

/**
 * @struct RKLLMResultLastHiddenLayer
 * @brief Structure to hold the hidden states from the last layer.
 */
typedef struct {
    const float* hidden_states; /**< Pointer to the hidden states (of size num_tokens * embd_size). */
    int embd_size;              /**< Size of the embedding vector. */
    int num_tokens;             /**< Number of tokens for which hidden states are stored. */
} RKLLMResultLastHiddenLayer;
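
/*
 * Illustrative sketch: reading the last hidden layer inside a result callback
 * (RKLLMResult and LLMResultCallback are declared below) when the call state
 * is RKLLM_RUN_GET_LAST_HIDDEN_LAYER. The buffer layout (num_tokens rows of
 * embd_size floats) follows the field documentation above.
 *
 *     if (state == RKLLM_RUN_GET_LAST_HIDDEN_LAYER) {
 *         const RKLLMResultLastHiddenLayer* h = &result->last_hidden_layer;
 *         const float* last_token =
 *             h->hidden_states + (size_t)(h->num_tokens - 1) * h->embd_size;
 *         // last_token now points at the embedding of the final token
 *     }
 */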

/**
 * @struct RKLLMResult
 * @brief Structure to represent the result of LLM inference.
 */
typedef struct {
    const char* text; /**< Generated text result. */
    int32_t token_id; /**< ID of the generated token. */
    RKLLMResultLastHiddenLayer last_hidden_layer; /**< Hidden states of the last layer (if requested). */
} RKLLMResult;

/**
 * @typedef LLMResultCallback
 * @brief Callback function to handle LLM results.
 * @param result Pointer to the LLM result.
 * @param userdata Pointer to user data for the callback.
 * @param state State of the LLM call (e.g., finished, error).
 */
typedef void (*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state);
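
/*
 * Illustrative sketch: a callback that prints each piece of generated text
 * and reacts to the call state. Text is typically delivered incrementally
 * while the call remains in RKLLM_RUN_NORMAL.
 *
 *     #include <stdio.h>
 *
 *     void my_callback(RKLLMResult* result, void* userdata, LLMCallState state) {
 *         (void)userdata;
 *         if (state == RKLLM_RUN_NORMAL && result->text != NULL) {
 *             printf("%s", result->text);
 *         } else if (state == RKLLM_RUN_FINISH) {
 *             printf("\n[done]\n");
 *         } else if (state == RKLLM_RUN_ERROR) {
 *             fprintf(stderr, "inference error\n");
 *         }
 *     }
 */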

/**
 * @brief Creates a default RKLLMParam structure with preset values.
 * @return A default RKLLMParam structure.
 */
RKLLMParam rkllm_createDefaultParam();

/**
 * @brief Initializes the LLM with the given parameters.
 * @param handle Pointer to the LLM handle.
 * @param param Configuration parameters for the LLM.
 * @param callback Callback function to handle LLM results.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
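
/*
 * Illustrative sketch: initializing a model handle with default parameters and
 * the my_callback function sketched above. The model path is a placeholder.
 *
 *     LLMHandle handle = NULL;
 *     RKLLMParam param = rkllm_createDefaultParam();
 *     param.model_path = "/path/to/model.rkllm";
 *     if (rkllm_init(&handle, &param, my_callback) != 0) {
 *         // initialization failed; handle the error
 *     }
 */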

/**
 * @brief Loads a Lora adapter into the LLM.
 * @param handle LLM handle.
 * @param lora_adapter Pointer to the Lora adapter structure.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
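
/*
 * Illustrative sketch: loading a Lora adapter after initialization. The path,
 * name, and scale are placeholders; the name chosen here is what an
 * RKLLMLoraParam.lora_adapter_name would refer to at inference time.
 *
 *     RKLLMLoraAdapter adapter;
 *     adapter.lora_adapter_path = "/path/to/adapter.rkllm";
 *     adapter.lora_adapter_name = "my_adapter";
 *     adapter.scale = 1.0f;
 *     rkllm_load_lora(handle, &adapter);
 */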

/**
 * @brief Loads a prompt cache from a file.
 * @param handle LLM handle.
 * @param prompt_cache_path Path to the prompt cache file.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
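
/*
 * Illustrative sketch: reusing a previously saved prompt cache. Saving is
 * requested through RKLLMPromptCacheParam in RKLLMInferParam; a later run can
 * restore the cached state with rkllm_load_prompt_cache. The cache path is a
 * placeholder and the fallback behavior is an assumption.
 *
 *     if (rkllm_load_prompt_cache(handle, "/path/to/prompt.cache") != 0) {
 *         // cache missing or incompatible; proceed without it
 *     }
 */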

/**
 * @brief Releases the prompt cache from memory.
 * @param handle LLM handle.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_release_prompt_cache(LLMHandle handle);

/**
 * @brief Destroys the LLM instance and releases resources.
 * @param handle LLM handle.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_destroy(LLMHandle handle);

/**
 * @brief Runs an LLM inference task synchronously.
 * @param handle LLM handle.
 * @param rkllm_input Input data for the LLM.
 * @param rkllm_infer_params Parameters for the inference task.
 * @param userdata Pointer to user data for the callback.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
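
/*
 * Illustrative sketch: a complete synchronous generation pass using the
 * structures above. Because rkllm_run is synchronous, it returns only after
 * the callback passed to rkllm_init has received the final state; generated
 * text arrives through that callback, not through the return value.
 *
 *     RKLLMInput input;
 *     input.input_type   = RKLLM_INPUT_PROMPT;
 *     input.prompt_input = "Explain NPUs in one sentence.";
 *
 *     RKLLMInferParam infer_params;
 *     memset(&infer_params, 0, sizeof(infer_params));
 *     infer_params.mode = RKLLM_INFER_GENERATE;
 *
 *     rkllm_run(handle, &input, &infer_params, NULL);
 *     rkllm_destroy(handle);
 */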

/**
 * @brief Runs an LLM inference task asynchronously.
 * @param handle LLM handle.
 * @param rkllm_input Input data for the LLM.
 * @param rkllm_infer_params Parameters for the inference task.
 * @param userdata Pointer to user data for the callback.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);

/**
 * @brief Aborts an ongoing LLM task.
 * @param handle LLM handle.
 * @return Status code (0 for success, non-zero for failure).
 */
int rkllm_abort(LLMHandle handle);

/**
 * @brief Checks if an LLM task is currently running.
 * @param handle LLM handle.
 * @return Status code (0 if a task is running, non-zero otherwise).
 */
int rkllm_is_running(LLMHandle handle);
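
/*
 * Illustrative sketch: launching an asynchronous run (reusing the input and
 * infer_params from the synchronous sketch above), polling for completion,
 * and aborting past a deadline. The polling pattern is an assumption and
 * deadline_exceeded() is a hypothetical helper; note that rkllm_is_running
 * returns 0 while a task is still running.
 *
 *     rkllm_run_async(handle, &input, &infer_params, NULL);
 *     while (rkllm_is_running(handle) == 0) {
 *         if (deadline_exceeded()) { // hypothetical helper
 *             rkllm_abort(handle);
 *             break;
 *         }
 *         // sleep briefly between polls
 *     }
 */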

#ifdef __cplusplus
}
#endif

#endif