From c46d142b66e0ebe4bee71a77288d58c04e10f241 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Tue, 19 Dec 2023 14:51:46 +0300 Subject: [PATCH 1/2] Fix mean pooling function --- include/text_embedder.h | 2 +- src/text_embedder.cpp | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/include/text_embedder.h b/include/text_embedder.h index 817e7cf8..1cf79fa9 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -45,7 +45,7 @@ class TextEmbedder { std::unique_ptr tokenizer_; std::unique_ptr remote_embedder_; std::string vocab_file_name; - static std::vector mean_pooling(const std::vector>& input); + static std::vector mean_pooling(const std::vector>& input, const std::vector& attention_mask); std::string output_tensor_name; size_t num_dim; std::mutex mutex_; diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index b6780505..1955409d 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -91,16 +91,28 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims) } -std::vector TextEmbedder::mean_pooling(const std::vector>& inputs) { +std::vector TextEmbedder::mean_pooling(const std::vector>& inputs, const std::vector& attention_mask) { std::vector pooled_output; for (int i = 0; i < inputs[0].size(); i++) { float sum = 0; for (int j = 0; j < inputs.size(); j++) { - sum += inputs[j][i]; + sum += inputs[j][i] * attention_mask[j]; } - pooled_output.push_back(sum / inputs.size()); + pooled_output.push_back(sum); } + + // get sum of attention mask + float sum_attention_mask = 0; + for(auto& val : attention_mask) { + sum_attention_mask += val; + } + + // divide by sum of attention mask + for(auto& val : pooled_output) { + val /= sum_attention_mask; + } + return pooled_output; } @@ -171,7 +183,7 @@ embedding_res_t TextEmbedder::Embed(const std::string& text, const size_t remote } output.push_back(temp); } - auto pooled_output = mean_pooling(output); + auto pooled_output = mean_pooling(output, encoded_input.attention_mask); return embedding_res_t(pooled_output); } } @@ -277,7 +289,7 @@ std::vector TextEmbedder::batch_embed(const std::vectorget_tokenizer_type() != TokenizerType::clip) { - outputs.push_back(embedding_res_t(mean_pooling(output))); + outputs.push_back(embedding_res_t(mean_pooling(output, encoded_inputs.attention_mask[i]))); } } } From 3972b66d5c68b76c1ebff3fe3cebf3c750edb0cd Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Tue, 19 Dec 2023 15:29:28 +0300 Subject: [PATCH 2/2] Add test --- test/collection_vector_search_test.cpp | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 280760e8..17b75aeb 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -7,6 +7,7 @@ #include #include "conversation_manager.h" #include "conversation_model_manager.h" +#include "index.h" class CollectionVectorTest : public ::testing::Test { protected: @@ -3252,4 +3253,44 @@ TEST_F(CollectionVectorTest, Test0VectorDistance) { ASSERT_EQ(results["hits"].size(), 1); ASSERT_EQ(results["hits"][0].count("vector_distance"), 1); ASSERT_EQ(results["hits"][0]["vector_distance"], 0); +} + + +TEST_F(CollectionVectorTest, TestEmbeddingValues) { + auto schema_json = + R"({ + "name": "test", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/all-MiniLM-L12-v2"}}} + ] + })"_json; + + EmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + + ASSERT_TRUE(collection_create_op.ok()); + + auto coll = collection_create_op.get(); + + auto add_op = coll->add(R"({ + "name": "Elskovsbarnet" + })"_json.dump()); + + ASSERT_TRUE(add_op.ok()); + + std::vector embeddings = add_op.get()["embedding"]; + + std::vector normalized_embeddings(embeddings.size()); + + hnsw_index_t::normalize_vector(embeddings, normalized_embeddings); + + ASSERT_EQ(embeddings.size(), 384); + + std::vector actual_values{-0.07409533113241196, -0.02963513322174549, -0.018120333552360535, 0.012058400548994541, -0.07219868153333664, -0.09295058250427246, 0.018390782177448273, 0.007814675569534302, 0.026419874280691147, 0.037965331226587296, 0.020393727347254753, -0.04090584069490433, 0.03194206580519676, 0.025205004960298538, 0.02059922367334366, 0.026202859356999397, 0.009739107452332973, 0.07967381179332733, -0.006712059490382671, -0.045936256647109985, -0.0280868299305439, -0.028282660990953445, 0.00617704214528203, -0.0756121575832367, -0.009177971631288528, -0.0016412553377449512, -0.040854115039110184, -0.007597113959491253, -0.03225032240152359, -0.015282290056347847, -0.013507066294550896, -0.11270778626203537, 0.12383124977350235, 0.09607065469026566, -0.106889508664608, 0.02146402932703495, 0.061281926929950714, -0.04245373234152794, -0.05668728053569794, 0.02623145468533039, 0.016187654808163643, 0.05603780969977379, 0.0119243822991848, -0.004412775859236717, 0.040246933698654175, 0.07487507909536362, -0.05067175254225731, 0.030055716633796692, 0.014153759926557541, -0.04411328583955765, -0.010018891654908657, -0.08593358099460602, 0.037568483501672745, -0.10012772679328918, 0.029019853100180626, 0.019645709544420242, -0.0639389306306839, 0.02652929536998272, 0.015299974009394646, 0.07286490499973297, 0.029529787600040436, -0.044351380318403244, -0.041604846715927124, 0.06385225802659988, -0.007908550091087818, -0.003856210969388485, -0.03855051472783089, -0.0023078585509210825, -0.04141264036297798, -0.05051504448056221, -0.018076501786708832, -0.017384130507707596, 0.024294942617416382, 0.12094006687402725, 0.01351782027631998, 0.08950492739677429, 0.027889391407370567, -0.03165547922253609, -0.017131352797150612, -0.022714827209711075, 0.048935145139694214, -0.012115311808884144, -0.0575471930205822, -0.019780246540904045, 0.052039679139852524, 0.00199871021322906, -0.010556189343333244, -0.0176922008395195, -0.01899656467139721, -0.005256693810224533, -0.06929342448711395, -0.01906348578631878, 0.10669232159852982, -0.0058551388792693615, 0.011760520748794079, 0.0066625443287193775, 0.0019288291223347187, -0.08495593070983887, 0.03902851417660713, 0.1967391073703766, 0.007772537413984537, -0.04112537205219269, 0.08704622834920883, 0.007129311095923185, -0.07165598124265671, -0.06986088305711746, -0.028463803231716156, -0.02357759326696396, 0.015329649671912193, -0.01065903902053833, -0.09958454966545105, 0.020069725811481476, -0.04014518857002258, -0.0660862997174263, -0.055922750383615494, -0.032036129385232925, 0.01381504163146019, -0.0673903375864029, -0.025027597323060036, 0.021608922630548477, -0.0620601624250412, 0.03505481034517288, -0.054973628371953964, -0.0021920157596468925, -0.01736101694405079, -0.1220683753490448, -0.07779566198587418, 0.0008724227664060891, -0.046745795756578445, 0.06985874474048615, -0.06745105981826782, 0.052744727581739426, 0.03683020919561386, -0.03435657545924187, -0.06987597048282623, 0.00887364149093628, -0.04392600059509277, -0.03942466899752617, -0.057737983763217926, -0.00721937557682395, 0.010713488794863224, 0.03875933587551117, 0.15718387067317963, 0.008935746736824512, -0.06421459466218948, 0.02290276437997818, 0.034633539617061615, -0.06684417277574539, 0.0005746493698097765, -0.028561286628246307, 0.07741032540798187, -0.016047099605202675, 0.07573956996202469, -0.07167335599660873, -0.0015375938965007663, -0.019324950873851776, -0.033263999968767166, 0.014723926782608032, -0.0691518783569336, -0.06772343814373016, 0.0042124162428081036, 0.07307381927967072, 0.03486260399222374, 0.04603007435798645, 0.07130003720521927, -0.02456359565258026, -0.006673890631645918, -0.02338244579732418, 0.011230859905481339, 0.019877653568983078, -0.03518665209412575, 0.0206899493932724, 0.05910487845540047, 0.019732976332306862, 0.04096956551074982, 0.07400382310152054, -0.03024907223880291, -0.015541939064860344, -0.008652037009596825, 0.0935826525092125, -0.049539074301719666, -0.04189642146229744, -0.07915540784597397, 0.030161747708916664, 0.05217037349939346, 0.008498051203787327, -0.02225595712661743, 0.041023027151823044, -0.008676717057824135, 0.03920895606279373, 0.042901333421468735, -0.0509256087243557, 0.03418148308992386, 0.10294827818870544, -0.007491919212043285, -0.04547177255153656, -0.0013863483909517527, -0.016816288232803345, 0.0057535297237336636, 0.04133246839046478, -0.014831697568297386, 0.1096695065498352, -0.02640458010137081, 0.05342832952737808, -0.10505645722150803, -0.069507896900177, -0.04607844352722168, 0.030713962391018867, -0.047581497579813004, 0.07578378170728683, 0.02707124687731266, 0.05470479652285576, 0.01324087381362915, 0.005669544450938702, 0.07757364213466644, -0.027681969106197357, 0.015634633600711823, 0.011706131510436535, -0.11028207093477249, -0.03370887413620949, 0.0342826321721077, 0.052396781742572784, -0.03439828380942345, -9.332131367059089e-33, -0.003496044548228383, -0.0012644683010876179, 0.007245716638863087, 0.08308663219213486, -0.12923602759838104, 0.01113795768469572, -0.015030942857265472, 0.01813196949660778, -0.08993704617023468, 0.056248947978019714, 0.10432837903499603, 0.008380789309740067, 0.08054981380701065, -0.0016472548013553023, 0.0940462201833725, -0.002078677760437131, -0.040112320333719254, -0.022219669073820114, -0.08358576893806458, -0.022520577535033226, 0.026831910014152527, 0.020184528082609177, -0.019914891570806503, 0.11616221070289612, -0.08901996910572052, -0.016575688496232033, 0.027953164651989937, 0.07949092239141464, -0.03504502400755882, -0.04410504922270775, -0.012492713518440723, -0.06611645221710205, -0.020088162273168564, -0.019216760993003845, 0.08393155038356781, 0.11951949447393417, 0.06375068426132202, -0.061182133853435516, -0.09066124260425568, -0.046286359429359436, 0.02162717469036579, -0.02759421616792679, -0.09041713923215866, 0.008177299052476883, -0.006156154442578554, -0.0033287708647549152, -0.004311972297728062, -0.01960325799882412, -0.08414454013109207, -0.0034149065613746643, 0.015856321901082993, -0.0005123159498907626, -0.027074772864580154, 0.03869790956377983, 0.050786130130290985, -0.028933823108673096, -0.07446572184562683, 0.022279445081949234, 0.012226884253323078, -0.01748575083911419, -0.055989284068346024, -0.011646092869341373, -0.0002180236770072952, 0.10100196301937103, 0.02999500371515751, -0.021314362064003944, -0.04096762463450432, 0.05568964406847954, -0.004973178263753653, 0.013144302181899548, 0.022288570180535316, 0.09443598240613937, 0.0018029726343229413, -0.09654559940099716, -0.01457826979458332, 0.04508035257458687, 0.06526371091604233, -0.03033633343875408, 0.009471519850194454, -0.11114948242902756, -0.046912480145692825, -0.10612039268016815, 0.11780810356140137, -0.026177652180194855, 0.0320870615541935, -0.015745604410767555, 0.06458097696304321, 0.048562128096818924, -0.034073326736688614, -0.03065350651741028, 0.06918460875749588, 0.06126512959599495, 0.0058005815371870995, -0.03808598220348358, 0.03678971901535988, 4.168464892362657e-32, -0.0452132411301136, 0.051136620342731476, -0.09363184124231339, -0.032540980726480484, 0.08147275447845459, 0.03507697954773903, 0.04584404081106186, -0.00924444105476141, -0.012075415812432766, 0.0541100800037384, -0.015797585248947144, 0.05510234460234642, -0.04699498042464256, -0.018956895917654037, -0.04772498831152916, 0.05756324902176857, -0.0827300101518631, 0.004980154801160097, 0.024522915482521057, -0.019712436944246292, 0.009034484624862671, -0.012837578542530537, 0.026660654693841934, 0.06716003268957138, -0.05956435948610306, 0.0010818272130563855, -0.018492311239242554, 0.034606318920850754, 0.04679758474230766, -0.020694732666015625, 0.06055215373635292, -0.04266247898340225, 0.008420216850936413, -0.02698715589940548, -0.028203830122947693, 0.029279250651597977, -0.010966592468321323, -0.03348863869905472, -0.07982659339904785, -0.03935334458947182, -0.02174490876495838, -0.04081539437174797, 0.049022793769836426, -0.01604332961142063, -0.0032012134324759245, 0.0893029123544693, -0.0230527613312006, 0.01536057610064745, 0.027288464829325676, -0.01401998195797205, -0.057258568704128265, -0.07299835979938507, 0.032278336584568024, 0.040280167013406754, 0.060383908450603485, -0.0012196602765470743, 0.02501964196562767, -0.03808143362402916, -0.08765897154808044, 0.047424230724573135, -0.04527046158909798, -0.015525433234870434, -0.02020418457686901, -0.06228169426321983}; + + for (int i = 0; i < 384; i++) { + EXPECT_NEAR(normalized_embeddings[i], actual_values[i], 0.00001); + } } \ No newline at end of file