pybind11嘗試編寫

include "chat.h"

include <pybind11/pybind11.h>

include <pybind11/stl.h>

include "models.cpp"

namespace chatllm {

namespace py = pybind11;
using namespace pybind11::literals;

// class PyBaseTokenizer : public BaseTokenizer {
// public:
// using BaseTokenizer::BaseTokenizer;

// std::vector<int> encode(const std::string &text, int max_length) const override {
// PYBIND11_OVERRIDE_PURE(std::vector<int>, BaseTokenizer, encode, text, max_length);
// }
// std::string decode(const std::vector<int> &ids) const override {
// PYBIND11_OVERLOAD_PURE(std::string, BaseTokenizer, decode, ids);
// }
// std::vector<int> encode_messages(const std::vector<ChatMessage> &history, int max_length) const override {
// PYBIND11_OVERLOAD_PURE(std::vector<int>, BaseTokenizer, encode_messages, history, max_length);
// }
// };

// class PyBaseModelForCausalLM : public BaseModelForCausalLM {
// public:
// using BaseModelForCausalLM::BaseModelForCausalLM;

// void load(ModelLoader &loader) override { PYBIND11_OVERLOAD_PURE(void, PyBaseModelForCausalLM, load, loader); }

// ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx,
// bool is_decoding) const override {
// PYBIND11_OVERLOAD_PURE(ggml_tensor *, PyBaseModelForCausalLM, forward, ctx, input_ids, n_past, n_ctx,
// is_decoding)
// }
// };

template <typename T>
static inline std::string to_string(const T &obj) {
std::ostringstream oss;
oss << obj;
return oss.str();
}

PYBIND11_MODULE(_C, m) {
m.doc() = "ChatLLM.cpp python binding";

py::enum_<ModelType>(m, "ModelType")
    .value("MINICPM", ModelType::MODEL_TYPE_MINICPM);

py::class_<minicpm::Config>(m, "MiniCPMConfig")
    // .def_readonly("dtype", &BaseConfig::dtype)
    .def_readonly("vocab_size", &minicpm::Config::vocab_size)
    .def_readonly("hidden_size", &minicpm::Config::hidden_size)
    .def_readonly("num_attention_heads", &minicpm::Config::num_attention_heads)
    .def_readonly("num_hidden_layers", &minicpm::Config::num_hidden_layers)
    .def_readonly("intermediate_size", &minicpm::Config::intermediate_size)
    .def_readonly("max_length", &minicpm::Config::max_length)
    .def_readonly("bos_token_id", &minicpm::Config::bos_token_id)
    .def_readonly("eos_token_id", &minicpm::Config::eos_token_id)
    .def_readonly("pad_token_id", &minicpm::Config::pad_token_id)
    .def_readonly("sep_token_id", &minicpm::Config::sep_token_id)
    .def_readonly("num_key_value_heads", &minicpm::Config::num_key_value_heads)
    .def_readonly("rope_scaling", &minicpm::Config::rope_scaling)
    .def_readonly("rope_theta", &minicpm::Config::rope_theta)
    .def_readonly("scale_depth", &minicpm::Config::scale_depth);

py::class_<GenerationConfig>(m, "GenerationConfig")
    .def(py::init<int, int, bool, int, float, float, int>(), "max_length"_a = 2048,
        "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0,
        "top_p"_a = 0.7, "temperature"_a = 0.95, "num_threads"_a = 0)
    .def_readwrite("max_length", &GenerationConfig::max_length)
    .def_readwrite("max_context_length", &GenerationConfig::max_context_length)
    .def_readwrite("do_sample", &GenerationConfig::do_sample)
    .def_readwrite("top_k", &GenerationConfig::top_k)
    .def_readwrite("top_p", &GenerationConfig::top_p)
    .def_readwrite("temperature", &GenerationConfig::temperature)
    .def_readwrite("num_threads", &GenerationConfig::num_threads);

// py::class_<ChatMessage>(m, "ChatMessage")
//     .def(py::init<std::string, std::string, std::vector<ToolCallMessage>>(), "role"_a, "content"_a,
//          "tool_calls"_a = std::vector<ToolCallMessage>{})
//     .def("__repr__", &to_string<ChatMessage>)
//     .def("__str__", &to_string<ChatMessage>)
//     .def_readonly_static("ROLE_SYSTEM", &ChatMessage::ROLE_SYSTEM)
//     .def_readonly_static("ROLE_USER", &ChatMessage::ROLE_USER)
//     .def_readonly_static("ROLE_ASSISTANT", &ChatMessage::ROLE_ASSISTANT)
//     .def_readonly_static("ROLE_OBSERVATION", &ChatMessage::ROLE_OBSERVATION)
//     .def_readwrite("role", &ChatMessage::role)
//     .def_readwrite("content", &ChatMessage::content)
//     .def_readwrite("tool_calls", &ChatMessage::tool_calls);

// py::class_<minicpm::Tokenizer>(m, "Tokenizer")
//     .def("encode", &minicpm::Tokenizer::encode, py::arg("text"))
//     .def("decode", &minicpm::Tokenizer::decode, "ids"_a);

// py::class_<chatllm::BaseHistoryEncoder>(m, "BaseHistoryEncoder");
// py::class_<chatllm::BaseTokenizer>(m, "BaseTokenizer")
//     .def("load", [](chatllm::BaseTokenizer& tokenizer, const char *buffer, int n_vocab){

//     });
// py::class_<chatllm::BaseStreamer>(m, "BaseStreamer");
// py::class_<chatllm::TextStreamer>(m, "TextStreamer");
    // .def(py::init<chatllm::BaseTokenizer>(), "tokenizer"_a); // 有bug

py::class_<chatllm::BaseTokenizer, minicpm::Tokenizer>(m, "MiniCPMTokenizer")
    .def("encode", [](minicpm::Tokenizer& tokenizer, const std::string& text){
        return tokenizer.encode(text);
    })
    .def("decode", [](minicpm::Tokenizer& tokenizer, const std::vector<int> &ids){
        return tokenizer.decode(ids);
    });
    // .def("load", [](minicpm::Tokenizer& tokenizer, const char *buffer, int n_vocab){
    //     return tokenizer.load(buffer, n_vocab);
    // });

// py::class_<minicpm::ConditionalGeneration>(m, "MiniCPMModel")
//     .def("generate_next_token", &minicpm::ConditionalGeneration::generate_next_token, 
//     "input_ids"_a, "gen_config"_a);

py::class_<minicpm::ConditionalGeneration>(m, "MiniCPMModel")
    .def("generate_next_token", [](minicpm::ConditionalGeneration& generation, const std::vector<int> &input_ids, const GenerationConfig &gen_config) {
        int gen_token = -1;
        if (generation.get_n_past() == 0) {
            gen_token = generation.generate_next_token(input_ids, gen_config);
            generation.set_n_past(generation.get_n_past() + input_ids.size());
        } else {
            int lastElement = input_ids.back();
            const std::vector<int> &lastElementVec = {lastElement};
            gen_token = generation.generate_next_token(lastElementVec, gen_config);
            generation.set_n_past(generation.get_n_past() + 1);
        }
        return gen_token;
    })
    .def("reset_n_past", [](minicpm::ConditionalGeneration& generation){
        generation.set_n_past(0);
    })
    .def_readonly("config", &minicpm::ConditionalGeneration::config);
    // .def("generate", [](minicpm::ConditionalGeneration& generation, const std::vector<int> &input_ids, const GenerationConfig &gen_config,
    //                           const bool continuous,
    //                           bool &completed){
        
    // });

// ===== ChatGLM3 =====

// py::class_<ChatGLM3Tokenizer, BaseTokenizer>(m, "ChatGLM3Tokenizer");

// ===== Pipeline ====

py::class_<Pipeline>(m, "Pipeline")
    .def(py::init<const std::string &>(), "path"_a)
    .def_property_readonly("model", [](const Pipeline &self) { return self.model; })
    .def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer; })
    .def("chat", [](Pipeline& pipeline, std::vector<std::string> &history, const GenerationConfig &gen_config){
        return pipeline.chat(history, gen_config);
    });

}

} // namespace chatglm

from pathlib import Path
import chatllm_cpp._C as _C

class Pipeline(_C.Pipeline):
def init(self, model_path: str) -> None:
if Path(model_path).is_file():
# load ggml model
super().init(str(model_path))
else:
raise RuntimeError("參數錯誤")

def chat(
    self,
    message: str,
    *,
    max_length: int = 2048,
    max_context_length: int = 512,
    do_sample: bool = True,
    top_k: int = 0,
    top_p: float = 0.7,
    temperature: float = 0.95,
    num_threads: int = 0,
    # stream: bool = False,
):
    input_ids = self.tokenizer.encode(message)
    
    gen_config = _C.GenerationConfig(
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        max_context_length=max_context_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_threads=num_threads,
    )
    _C.
    if stream:
        return self._stream_chat(input_ids=input_ids, gen_config=gen_config)
    return self._sync_chat(input_ids=input_ids, gen_config=gen_config)

import _C
pipeline = _C.Pipeline(r"C:\Users\KyoDa\Downloads\chatllm.cpp\quantized_16.bin")
question = "Hello."
ids = pipeline.tokenizer.encode(f" <用戶>{question}<AI>")
config = _C.GenerationConfig()
new_token = 0
pipeline.model.reset_n_past()
print(pipeline.model.config.eos_token_id, "<< id")
while new_token != pipeline.model.config.eos_token_id:
new_token = pipeline.model.generate_next_token(ids, config)
ids.append(new_token);
print(new_token, end=',', flush=True)

print(pipeline.tokenizer.decode(ids))

pipeline.chat(["Hello."], config)

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 227,967評論 6 531
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 98,273評論 3 415
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 175,870評論 0 373
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 62,742評論 1 309
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 71,527評論 6 407
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,010評論 1 322
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,108評論 3 440
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,250評論 0 288
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 48,769評論 1 333
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 40,656評論 3 354
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 42,853評論 1 369
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,371評論 5 358
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,103評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,472評論 0 26
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,717評論 1 281
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,487評論 3 390
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 47,815評論 2 372

推薦閱讀更多精彩內容

  • 記錄源碼編譯Tensorflow的曲折彎路 前言 通過tensorflow訓練深度學習神經網絡模型一般是pytho...
    V_愛一世春秋閱讀 3,625評論 1 0
  • 前言 在之前的pybind11系列實踐中,開發流程大致是這樣的: 第一步: 首先在C/C++ IDE中編寫C/C+...
    俠之大者_7d3f閱讀 14,276評論 4 4
  • CPP 1、在main執行之前和之后執行的代碼可能是什么? main函數執行之前,主要就是初始化系統相關資源: 設...
    voidFan閱讀 1,709評論 1 6
  • 廢話不多說,自己進入今天的主題 1、面向對象的特征有哪些方面? 答:面向對象的特征主要有以下幾個方面: - 抽象:...
    傳奇內服號閱讀 2,369評論 1 31
  • 程序員面試寶典 一、C++ 基礎 1. 位運算 返回x二進制數中的1的個數? 返回x,y的平均值? 返回絕對值?...
    小任同學an閱讀 1,197評論 0 0