This commit is contained in:
Matthijs Douze 2020-06-02 13:59:39 -07:00
parent 833d417db1
commit e8c8db85ae
2 changed files with 39 additions and 3 deletions

View File

@ -87,6 +87,7 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
int64_t ncentroids = -1;
bool use_2layer = false;
int hnsw_M = -1;
for (char *tok = strtok_r (description, " ,", &ptr);
tok;
@ -186,6 +187,8 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
del_coarse_quantizer.release ();
index_ivf->own_fields = true;
index_1 = index_ivf;
} else if (hnsw_M > 0) {
index_1 = new IndexHNSWFlat (d, hnsw_M, metric);
} else {
FAISS_THROW_IF_NOT_MSG (stok != "FlatDedup",
"dedup supported only for IVFFlat");
@ -209,6 +212,8 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
del_coarse_quantizer.release ();
index_ivf->own_fields = true;
index_1 = index_ivf;
} else if (hnsw_M > 0) {
index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
} else {
index_1 = new IndexScalarQuantizer (d, qt, metric);
}
@ -248,6 +253,11 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
index_2l->q1.own_fields = true;
index_1 = index_2l;
}
} else if (hnsw_M > 0) {
IndexHNSWPQ *ipq = new IndexHNSWPQ(d, M, hnsw_M);
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
do_polysemous_training;
index_1 = ipq;
} else {
IndexPQ *index_pq = new IndexPQ (d, M, nbit, metric);
index_pq->do_polysemous_training = do_polysemous_training;
@ -272,13 +282,14 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
} else if (!index &&
sscanf (tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
index_1 = new IndexHNSWPQ (d, pq_m, M);
} else if (!index &&
sscanf (tok, "HNSW%d", &M) == 1) {
index_1 = new IndexHNSWFlat (d, M);
} else if (!index &&
sscanf (tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
pq_m == 8) {
index_1 = new IndexHNSWSQ (d, ScalarQuantizer::QT_8bit, M);
} else if (!index &&
sscanf (tok, "HNSW%d", &M) == 1) {
hnsw_M = M;
// here it is unclear what we want: HNSW flat or HNSWx,Y ?
} else if (!index && (stok == "LSH" || stok == "LSHr" ||
stok == "LSHrt" || stok == "LSHt")) {
bool rotate_data = strstr(tok, "r") != nullptr;
@ -318,6 +329,11 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
}
}
if (!index && hnsw_M > 0) {
index = new IndexHNSWFlat (d, hnsw_M, metric);
del_index.set (index);
}
FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
description_in);

View File

@ -52,6 +52,26 @@ class TestFactory(unittest.TestCase):
index = faiss.index_factory(12, "IVF10,FlatDedup")
assert index.instances is not None
def test_factory_HNSW(self):
index = faiss.index_factory(12, "HNSW32")
assert index.storage.sa_code_size() == 12 * 4
index = faiss.index_factory(12, "HNSW32_SQ8")
assert index.storage.sa_code_size() == 12
index = faiss.index_factory(12, "HNSW32_PQ4")
assert index.storage.sa_code_size() == 4
def test_factory_HNSW_newstyle(self):
index = faiss.index_factory(12, "HNSW32,Flat")
assert index.storage.sa_code_size() == 12 * 4
index = faiss.index_factory(12, "HNSW32,SQ8", faiss.METRIC_INNER_PRODUCT)
assert index.storage.sa_code_size() == 12
assert index.metric_type == faiss.METRIC_INNER_PRODUCT
index = faiss.index_factory(12, "HNSW32,PQ4")
assert index.storage.sa_code_size() == 4
index = faiss.index_factory(12, "HNSW32,PQ4np")
indexpq = faiss.downcast_index(index.storage)
assert not indexpq.do_polysemous_training
class TestCloneSize(unittest.TestCase):