From 127d0b49e8d3a6d60b7bbcaea29c4ab123d58ff9 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Wed, 29 Nov 2023 02:39:54 +0300 Subject: [PATCH 1/2] Fix handling invalid images --- src/image_embedder.cpp | 20 ++++++++++++++++++-- src/image_processor.cpp | 3 +-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/image_embedder.cpp b/src/image_embedder.cpp index a347811d..0b6b5213 100644 --- a/src/image_embedder.cpp +++ b/src/image_embedder.cpp @@ -10,7 +10,10 @@ embedding_res_t CLIPImageEmbedder::embed(const std::string& encoded_image) { auto processed_image_op = image_processor_.process_image(encoded_image); if (!processed_image_op.ok()) { - return embedding_res_t(processed_image_op.code(), processed_image_op.error()); + nlohmann::json error_json; + error_json["error"] = processed_image_op.error(); + results[i] = embedding_res_t(processed_image_op.code(), error_json); + return embedding_res_t(processed_image_op.code(), error_json); } auto processed_image = processed_image_op.get(); @@ -58,7 +61,9 @@ std::vector CLIPImageEmbedder::batch_embed(const std::vector CLIPImageEmbedder::batch_embed(const std::vector result_vector(inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + result_vector[i] = results[i]; + } + + return result_vector; + } + // create input tensor std::vector input_shape = {static_cast(processed_images.size()), 3, 224, 224}; std::vector input_names = {"input_ids", "pixel_values", "attention_mask"}; diff --git a/src/image_processor.cpp b/src/image_processor.cpp index 707031d9..e9a942b3 100644 --- a/src/image_processor.cpp +++ b/src/image_processor.cpp @@ -36,8 +36,7 @@ Option CLIPImageProcessor::process_image(const std::string& i LOG(INFO) << "Running image processor"; try { output_tensors = session_->Run(Ort::RunOptions{nullptr}, input_names.data(), &input_tensor, 1, output_names.data(), output_names.size()); - } catch (const std::exception& e) { - LOG(INFO) << "Error while running image processor: " << e.what(); + } catch (...) { return Option(400, "Error while processing image"); } From 42511a05be3f0af2448af10678d60ed86c75385d Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Wed, 29 Nov 2023 02:52:08 +0300 Subject: [PATCH 2/2] Add test --- src/image_embedder.cpp | 1 - test/collection_vector_search_test.cpp | 30 +++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/image_embedder.cpp b/src/image_embedder.cpp index 0b6b5213..4ac5c75c 100644 --- a/src/image_embedder.cpp +++ b/src/image_embedder.cpp @@ -12,7 +12,6 @@ embedding_res_t CLIPImageEmbedder::embed(const std::string& encoded_image) { if (!processed_image_op.ok()) { nlohmann::json error_json; error_json["error"] = processed_image_op.error(); - results[i] = embedding_res_t(processed_image_op.code(), error_json); return embedding_res_t(processed_image_op.code(), error_json); } diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index fc1cf4e4..243a0594 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -2987,7 +2987,6 @@ TEST_F(CollectionVectorTest, TestImageEmbedding) { auto coll = collection_create_op.get(); - LOG(INFO) << "Adding image to collection"; auto add_op = coll->add(R"({ "name": "dog", @@ -3045,4 +3044,33 @@ TEST_F(CollectionVectorTest, TryAddingMultipleImageFieldToEmbedFrom) { ASSERT_FALSE(collection_create_op.ok()); ASSERT_EQ(collection_create_op.error(), "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field."); +} + +TEST_F(CollectionVectorTest, TestInvalidImage) { + auto schema_json = + R"({ + "name": "Images", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "image", "type": "image", "store": false}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["image"], "model_config": {"model_name": "ts/clip-vit-b-p32"}}} + ] + })"_json; + + EmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + auto coll = collection_create_op.get(); + + auto add_op = coll->add(R"({ + "name": "teddy bear", + "image": "invalid" + })"_json.dump()); + + ASSERT_FALSE(add_op.ok()); + + ASSERT_EQ(add_op.error(), "Error while processing image"); + } \ No newline at end of file