mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-05-14 18:03:02 +08:00
fix #1213
This commit is contained in:
parent
833d417db1
commit
e8c8db85ae
@ -87,6 +87,7 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
|
||||
|
||||
int64_t ncentroids = -1;
|
||||
bool use_2layer = false;
|
||||
int hnsw_M = -1;
|
||||
|
||||
for (char *tok = strtok_r (description, " ,", &ptr);
|
||||
tok;
|
||||
@ -186,6 +187,8 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
|
||||
del_coarse_quantizer.release ();
|
||||
index_ivf->own_fields = true;
|
||||
index_1 = index_ivf;
|
||||
} else if (hnsw_M > 0) {
|
||||
index_1 = new IndexHNSWFlat (d, hnsw_M, metric);
|
||||
} else {
|
||||
FAISS_THROW_IF_NOT_MSG (stok != "FlatDedup",
|
||||
"dedup supported only for IVFFlat");
|
||||
@ -209,6 +212,8 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
|
||||
del_coarse_quantizer.release ();
|
||||
index_ivf->own_fields = true;
|
||||
index_1 = index_ivf;
|
||||
} else if (hnsw_M > 0) {
|
||||
index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
|
||||
} else {
|
||||
index_1 = new IndexScalarQuantizer (d, qt, metric);
|
||||
}
|
||||
@ -248,6 +253,11 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
|
||||
index_2l->q1.own_fields = true;
|
||||
index_1 = index_2l;
|
||||
}
|
||||
} else if (hnsw_M > 0) {
|
||||
IndexHNSWPQ *ipq = new IndexHNSWPQ(d, M, hnsw_M);
|
||||
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
|
||||
do_polysemous_training;
|
||||
index_1 = ipq;
|
||||
} else {
|
||||
IndexPQ *index_pq = new IndexPQ (d, M, nbit, metric);
|
||||
index_pq->do_polysemous_training = do_polysemous_training;
|
||||
@ -272,13 +282,14 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
|
||||
} else if (!index &&
|
||||
sscanf (tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
|
||||
index_1 = new IndexHNSWPQ (d, pq_m, M);
|
||||
} else if (!index &&
|
||||
sscanf (tok, "HNSW%d", &M) == 1) {
|
||||
index_1 = new IndexHNSWFlat (d, M);
|
||||
} else if (!index &&
|
||||
sscanf (tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
|
||||
pq_m == 8) {
|
||||
index_1 = new IndexHNSWSQ (d, ScalarQuantizer::QT_8bit, M);
|
||||
} else if (!index &&
|
||||
sscanf (tok, "HNSW%d", &M) == 1) {
|
||||
hnsw_M = M;
|
||||
// here it is unclear what we want: HNSW flat or HNSWx,Y ?
|
||||
} else if (!index && (stok == "LSH" || stok == "LSHr" ||
|
||||
stok == "LSHrt" || stok == "LSHt")) {
|
||||
bool rotate_data = strstr(tok, "r") != nullptr;
|
||||
@ -318,6 +329,11 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
|
||||
}
|
||||
}
|
||||
|
||||
if (!index && hnsw_M > 0) {
|
||||
index = new IndexHNSWFlat (d, hnsw_M, metric);
|
||||
del_index.set (index);
|
||||
}
|
||||
|
||||
FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
|
||||
description_in);
|
||||
|
||||
|
@ -52,6 +52,26 @@ class TestFactory(unittest.TestCase):
|
||||
index = faiss.index_factory(12, "IVF10,FlatDedup")
|
||||
assert index.instances is not None
|
||||
|
||||
def test_factory_HNSW(self):
|
||||
index = faiss.index_factory(12, "HNSW32")
|
||||
assert index.storage.sa_code_size() == 12 * 4
|
||||
index = faiss.index_factory(12, "HNSW32_SQ8")
|
||||
assert index.storage.sa_code_size() == 12
|
||||
index = faiss.index_factory(12, "HNSW32_PQ4")
|
||||
assert index.storage.sa_code_size() == 4
|
||||
|
||||
def test_factory_HNSW_newstyle(self):
|
||||
index = faiss.index_factory(12, "HNSW32,Flat")
|
||||
assert index.storage.sa_code_size() == 12 * 4
|
||||
index = faiss.index_factory(12, "HNSW32,SQ8", faiss.METRIC_INNER_PRODUCT)
|
||||
assert index.storage.sa_code_size() == 12
|
||||
assert index.metric_type == faiss.METRIC_INNER_PRODUCT
|
||||
index = faiss.index_factory(12, "HNSW32,PQ4")
|
||||
assert index.storage.sa_code_size() == 4
|
||||
index = faiss.index_factory(12, "HNSW32,PQ4np")
|
||||
indexpq = faiss.downcast_index(index.storage)
|
||||
assert not indexpq.do_polysemous_training
|
||||
|
||||
|
||||
class TestCloneSize(unittest.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user