diff --git a/include/conversation_model.h b/include/conversation_model.h index 61ce16ff..d98ba270 100644 --- a/include/conversation_model.h +++ b/include/conversation_model.h @@ -54,6 +54,7 @@ class CFConversationModel : public ConversationModel { static Option get_standalone_question(const nlohmann::json& conversation_history, const std::string& question, const nlohmann::json& model_config); static Option format_question(const std::string& message); static Option format_answer(const std::string& message); + static Option parse_stream_response(const std::string& response); static const inline std::string STANDALONE_QUESTION_PROMPT = R"( Rewrite the follow-up question on top of a human-assistant conversation history as a standalone question that encompasses all pertinent context.Use 1024 characters maximum. )"; diff --git a/src/conversation_model.cpp b/src/conversation_model.cpp index b482a462..659112e5 100644 --- a/src/conversation_model.cpp +++ b/src/conversation_model.cpp @@ -1,3 +1,5 @@ +#include +#include #include "conversation_model.h" #include "embedder_manager.h" #include "text_embedder_remote.h" @@ -450,31 +452,8 @@ Option CFConversationModel::get_answer(const std::string& context, json_res = json_res["errors"][0]; return Option(400, "Cloudflare API error: " + json_res["message"].get()); } - try { - auto json_res = nlohmann::json::parse(res); - std::string parsed_response = ""; - std::vector lines = json_res["response"].get>(); - for(auto& line : lines) { - while(line.find("data:") != std::string::npos) { - auto substr_line = line.substr(line.find("data:") + 6); - if(substr_line.find("[DONE]") != std::string::npos) { - break; - } - nlohmann::json json_line; - if(substr_line.find("\n") != std::string::npos) { - json_line = nlohmann::json::parse(substr_line.substr(0, substr_line.find("\n"))); - } else { - json_line = nlohmann::json::parse(substr_line); - } - parsed_response += json_line["response"]; - line = substr_line; - } - } - return Option(parsed_response); - } catch (const std::exception& e) { - LOG(ERROR) << e.what(); - return Option(400, "Got malformed response from Cloudflare API."); - } + + return parse_stream_response(res); } Option CFConversationModel::get_standalone_question(const nlohmann::json& conversation_history, @@ -572,4 +551,31 @@ Option CFConversationModel::format_answer(const std::string& mes nlohmann::json json = nlohmann::json::object(); json["assistant"] = message; return Option(json); +} + +Option CFConversationModel::parse_stream_response(const std::string& res) { + try { + auto json_res = nlohmann::json::parse(res); + std::string parsed_response = ""; + std::vector lines = json_res["response"].get>(); + std::regex data_regex("data: (.*?)\\n\\n"); + for(auto& line : lines) { + auto begin = std::sregex_iterator(line.begin(), line.end(), data_regex); + auto end = std::sregex_iterator(); + for (std::sregex_iterator i = begin; i != end; ++i) { + std::string substr_line = i->str().substr(6, i->str().size() - 8); + if(substr_line.find("[DONE]") != std::string::npos) { + break; + } + nlohmann::json json_line; + json_line = nlohmann::json::parse(substr_line); + parsed_response += json_line["response"]; + } + } + return Option(parsed_response); + } catch (const std::exception& e) { + LOG(ERROR) << e.what(); + LOG(ERROR) << "Response: " << res; + return Option(400, "Got malformed response from Cloudflare API."); + } } \ No newline at end of file diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 78dc95cd..5258e61c 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -10,6 +10,7 @@ #include "index.h" #include "core_api.h" #include "vq_model_manager.h" +#include "conversation_model.h" class CollectionVectorTest : public ::testing::Test { protected: @@ -4149,4 +4150,272 @@ TEST_F(CollectionVectorTest, TestUpdatingSameDocument){ results_json = results.get(); ASSERT_EQ(results_json["found"].get(), results_json["hits"].size()); -} \ No newline at end of file +} + +TEST_F(CollectionVectorTest, TestCFModelResponseParsing) { + std::string res = R"( + { + "response": [ + "data: {\"response\":\"0\"}\n\n", + "data: {\"response\":\"0\"}\n\n", + "data: {\"response\":\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"publish\"}\n\n", + "data: {\"response\":\"Date\"}\n\n", + "data: {\"response\":\"Year\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \"}\n\n", + "data: {\"response\":\"2\"}\n\n", + "data: {\"response\":\"0\"}\n\n", + "data: {\"response\":\"1\"}\n\n", + "data: {\"response\":\"1\"}\n\n", + "data: {\"response\":\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"title\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \\\"\"}\n\n", + "data: {\"response\":\"S\"}\n\n", + "data: {\"response\":\"OP\"}\n\n", + "data: {\"response\":\"A\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"top\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" [\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Links\"}\n\n", + "data: {\"response\":\" to\"}\n\n", + "data: {\"response\":\" x\"}\n\n", + "data: {\"response\":\"k\"}\n\n", + "data: {\"response\":\"cd\"}\n\n", + "data: {\"response\":\".\"}\n\n", + "data: {\"response\":\"com\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Apr\"}\n\n", + "data: {\"response\":\"il\"}\n\n", + "data: {\"response\":\" fool\"}\n\n", + "data: {\"response\":\"s\"}\n\n", + "data: {\"response\":\"'\"}\n\n", + "data: {\"response\":\" com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Inter\"}\n\n", + "data: {\"response\":\"active\"}\n\n", + "data: {\"response\":\" com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\" with\"}\n\n", + "data: {\"response\":\" animation\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Dynamic\"}\n\n", + "data: {\"response\":\" com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\" with\"}\n\n", + "data: {\"response\":\" audio\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\" ],\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"trans\"}\n\n", + "data: {\"response\":\"cript\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \\\"\"}\n\n", + "data: {\"response\":\" \\\"\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"},\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"{\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"alt\"}\n\n", + "data: {\"response\":\"Title\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \\\"\"}\n\n", + "data: {\"response\":\"I\"}\n\n", + "data: {\"response\":\"'\"}\n\n", + "data: {\"response\":\"m\"}\n\n", + "data: {\"response\":\" currently\"}\n\n", + "data: {\"response\":\" getting\"}\n\n", + "data: {\"response\":\" totally\"}\n\n", + "data: {\"response\":\" black\"}\n\n", + "data: {\"response\":\"ed\"}\n\n", + "data: {\"response\":\" out\"}\n\n", + "data: {\"response\":\".\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"id\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \\\"\"}\n\n", + "data: {\"response\":\"1\"}\n\n", + "data: {\"response\":\"0\"}\n\n", + "data: {\"response\":\"0\"}\n\n", + "data: {\"response\":\"6\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"image\"}\n\n", + "data: {\"response\":\"Url\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \\\"\"}\n\n", + "data: {\"response\":\"https\"}\n\n", + "data: {\"response\":\"://\"}\n\n", + "data: {\"response\":\"im\"}\n\n", + "data: {\"response\":\"gs\"}\n\n", + "data: {\"response\":\".\"}\n\n", + "data: {\"response\":\"x\"}\n\n", + "data: {\"response\":\"k\"}\n\n", + "data: {\"response\":\"cd\"}\n\n", + "data: {\"response\":\".\"}\n\n", + "data: {\"response\":\"com\"}\n\n", + "data: {\"response\":\"/\"}\n\n", + "data: {\"response\":\"com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"/\"}\n\n", + "data: {\"response\":\"black\"}\n\n", + "data: {\"response\":\"out\"}\n\n", + "data: {\"response\":\".\"}\n\n", + "data: {\"response\":\"png\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"publish\"}\n\n", + "data: {\"response\":\"Date\"}\n\n", + "data: {\"response\":\"Day\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \"}\n\n", + "data: {\"response\":\"1\"}\n\n", + "data: {\"response\":\"8\"}\n\n", + "data: {\"response\":\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"publish\"}\n\n", + "data: {\"response\":\"Date\"}\n\n", + "data: {\"response\":\"Month\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \"}\n\n", + "data: {\"response\":\"1\"}\n\n", + "data: {\"response\":\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"publish\"}\n\n", + "data: {\"response\":\"Date\"}\n\n", + "data: {\"response\":\"Timestamp\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \"}\n\n", + "data: {\"response\":\"1\"}\n\n", + "data: {\"response\":\"3\"}\n\n", + "data: {\"response\":\"2\"}\n\n", + "data: {\"response\":\"6\"}\n\n", + "data: {\"response\":\"8\"}\n\n", + "data: {\"response\":\"6\"}\n\n", + "data: {\"response\":\"6\"}\n\n", + "data: {\"response\":\"4\"}\n\n", + "data: {\"response\":\"0\"}\n\n", + "data: {\"response\":\"0\"}\n\n", + "data: {\"response\":\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"publish\"}\n\n", + "data: {\"response\":\"Date\"}\n\n", + "data: {\"response\":\"Year\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \"}\n\n", + "data: {\"response\":\"2\"}\n\n", + "data: {\"response\":\"0\"}\n\n", + "data: {\"response\":\"1\"}\n\n", + "data: {\"response\":\"1\"}\n\n", + "data: {\"response\":\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"title\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" \\\"\"}\n\n", + "data: {\"response\":\"Black\"}\n\n", + "data: {\"response\":\"out\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"top\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"\\\":\"}\n\n", + "data: {\"response\":\" [\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Links\"}\n\n", + "data: {\"response\":\" to\"}\n\n", + "data: {\"response\":\" x\"}\n\n", + "data: {\"response\":\"k\"}\n\n", + "data: {\"response\":\"cd\"}\n\n", + "data: {\"response\":\".\"}\n\n", + "data: {\"response\":\"com\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Apr\"}\n\n", + "data: {\"response\":\"il\"}\n\n", + "data: {\"response\":\" fool\"}\n\n", + "data: {\"response\":\"s\"}\n\n", + "data: {\"response\":\"'\"}\n\n", + "data: {\"response\":\" com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Inter\"}\n\n", + "data: {\"response\":\"active\"}\n\n", + "data: {\"response\":\" com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\" with\"}\n\n", + "data: {\"response\":\" animation\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Dynamic\"}\n\n", + "data: {\"response\":\" com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\"\\\",\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"Com\"}\n\n", + "data: {\"response\":\"ics\"}\n\n", + "data: {\"response\":\" with\"}\n\n", + "data: {\"response\":\" audio\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\" ],\"}\n\n", + "data: {\"response\":\"\\n\"}\n\n", + "data: {\"response\":\"\\\"\"}\n\n", + "data: {\"response\":\"\"}\n\ndata: [DONE]\n\n" + ] + })"; + auto parsed_string = CFConversationModel::parse_stream_response(res); + ASSERT_TRUE(parsed_string.ok()); + ASSERT_EQ("00,\n\"publishDateYear\": 2011,\n\"title\": \"SOPA\",\n\"topics\": [\n\"Links to xkcd.com\",\n\"April fools' comics\",\n\"Interactive comics\",\n\"Comics with animation\",\n\"Dynamic comics\",\n\"Comics with audio\"\n ],\n\"transcript\": \" \"\n},\n{\n\"altTitle\": \"I'm currently getting totally blacked out.\",\n\"id\": \"1006\",\n\"imageUrl\": \"https://imgs.xkcd.com/comics/blackout.png\",\n\"publishDateDay\": 18,\n\"publishDateMonth\": 1,\n\"publishDateTimestamp\": 1326866400,\n\"publishDateYear\": 2011,\n\"title\": \"Blackout\",\n\"topics\": [\n\"Links to xkcd.com\",\n\"April fools' comics\",\n\"Interactive comics\",\n\"Comics with animation\",\n\"Dynamic comics\",\n\"Comics with audio\"\n ],\n\"", parsed_string.get()); +} +