foundationdb/flow/IndexedSet.h
Alex Miller 0ac868ad5d "Simplify" IndexedSet's insert and addMetric API.
The existing code tried to work around the complexities of optionally using
rvalue references' move capabilities if they exist.  As seen in the previous
MapPair, there's a combinatorial explosion of prototypes to declare as the
parameter length increases.  Because of this, addMetric ended up with a strange
API, and there was a wrapper to make a copy for insert.

Instead, we can apply the idiom of using universal/forwarding references and
std::forward to allow the compiler to instantiate the combinations that are
needed.  There's a TagData struct with no copy constructor that validates that
move constructors can be properly called still.

I measured a 12-byte difference between before and after this change, so no
template bloat was introduced.
2017-10-03 20:15:12 -07:00

1115 lines
38 KiB
C++

/*
* IndexedSet.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2018 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLOW_INDEXEDSET_H
#define FLOW_INDEXEDSET_H
#pragma once
#include "Platform.h"
#include "FastAlloc.h"
#include "Trace.h"
#include "Error.h"
#include <deque>
#include <vector>
// IndexedSet<T, Metric> is similar to a std::set<T>, with the following additional features:
// - Each element in the set is associated with a value of type Metric
// - sumTo() and sumRange() can report the sum of the metric values associated with a
// contiguous range of elements in O(lg N) time
// - index() can be used to find an element having a given sumTo() in O(lg N) time
// - Search functions (find(), lower_bound(), etc) can accept a type comparable to T instead of T
// (e.g. StringRef when T is std::string or Standalone<StringRef>). This can save a lot of needless
// copying at query time for read-mostly sets with string keys.
// - iterators are not const; the responsibility of not changing the order lies with the caller
// - the size() function is missing; if the metric being used is a count sumTo(end()) will do instead
// A number of STL compatibility features are missing and should be added as needed.
// T must define operator <, which must define a total order. Unlike std::set,
// a user-defined predicate is not currently supported as a template parameter.
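// Metric is required to have operators + and - and <, and to be implicitly constructible from the integer 0
// (Map additionally uses Metric(1) as its default per-element metric). Behavior is undefined if
// the sum of metrics for all elements of a set overflows the Metric type.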
// Map<Key,Value> is similar to a std::map<Key,Value>, except that it inherits the search key type
// flexibility of IndexedSet<>, uses MapPair<Key,Value> by default instead of pair<Key,Value>
// (use iterator->key instead of iterator->first), and uses FastAllocator for nodes.
template <class T>
class Future;
class Void;
template <class T, class Metric>
struct IndexedSet{
	typedef T value_type;
	typedef T key_type;

private: // Forward-declare IndexedSet::Node because Clang is much stricter about this ordering.
	// Tree node. A node owns its children: deleting a node recursively deletes its entire subtree.
	struct Node : FastAllocated<Node> {
		// Here, and throughout all code that indirectly instantiates a Node, we rely on forwarding
		// references so that we don't need to maintain the set of 2^arity lvalue and rvalue reference
		// combinations, but still take advantage of move constructors when available (or required).
		// Initializers are listed in member declaration order, which is the order in which they run.
		template <class T_, class Metric_>
		Node(T_&& data, Metric_&& m, Node* parent=0) : data(std::forward<T_>(data)), balance(0), total(std::forward<Metric_>(m)), parent(parent) {
			child[0] = child[1] = NULL;
		}
		~Node(){
			delete child[0];
			delete child[1];
		}

		T data;
		signed char balance; // right height - left height
		Metric total; // this + child[0] + child[1]
		Node *child[2]; // left, right
		Node *parent;
	};

public:
	// Iterates in ascending order. end() is represented by a null node pointer.
	struct iterator{
		typename IndexedSet::Node *i;
		iterator() : i(0) {}
		iterator(typename IndexedSet::Node *n) : i(n) {}
		T& operator*() { return i->data; }
		T* operator->() { return &i->data; }
		// Advance to the next larger item; advancing past the largest item yields end().
		void operator++();
		// Step back to the next smaller item. Must not be called on end() (it dereferences the node).
		void decrementNonEnd();
		bool operator == ( const iterator& r ) const { return i == r.i; }
		bool operator != ( const iterator& r ) const { return i != r.i; }
	};

	IndexedSet() : root(NULL) {}
	~IndexedSet() { delete root; }
	IndexedSet(IndexedSet&& r) noexcept(true) : root(r.root) { r.root = NULL; }
	IndexedSet& operator=(IndexedSet&& r) noexcept(true) {
		// Guard self-move-assignment, which would otherwise free the tree and leave the set empty.
		if (this != &r) {
			delete root;
			root = r.root;
			r.root = NULL;
		}
		return *this;
	}

	iterator begin() const;
	iterator end() const { return iterator(); }
	// Returns the item before i; previous(end()) is lastItem().
	iterator previous(iterator i) const;
	// Returns the largest item, or end() if the set is empty.
	iterator lastItem() const;
	bool empty() const { return !root; }
	void clear() { delete root; root = NULL; }
	void swap( IndexedSet& r ) { std::swap( root, r.root ); }

	// Place data in the set with the given metric. If an item equal to data is already in the set and
	// replaceExisting is true, it will be overwritten (and its metric will be replaced).
	// Returns an iterator to the inserted or pre-existing item.
	template <class T_, class Metric_>
	iterator insert(T_ &&data, Metric_ &&metric, bool replaceExisting = true);

	// Insert all items from data into the set, each with the metric paired with it. If an item equal to one being
	// inserted is already in the set and replaceExisting is true, it will be overwritten (and its metric will be
	// replaced). Returns the number of items inserted or replaced.
	int insert(const std::vector<std::pair<T,Metric>>& data, bool replaceExisting = true);

	// Increase the metric for the given item by the given amount. Inserts data into the set if it
	// doesn't exist. Returns the item's new metric.
	template <class T_, class Metric_>
	Metric addMetric( T_ && data, Metric_ && metric );

	// Remove the data item, if any, which is equal to key
	template <class Key>
	void erase(const Key &key) { erase( find(key) ); }

	// Erase the indicated item. No effect if item == end().
	// SOMEDAY: Return ++item
	void erase(iterator item);

	// Erase all data items x for which begin<=x<end
	template <class Key>
	void erase(const Key& begin, const Key& end) { erase( lower_bound(begin), lower_bound(end) ); }

	// Erase data items with a deferred (async) free process. The data structure has the items removed
	// synchronously with the invocation of this method so any subsequent call will see this new state.
	template <class Key>
	Future<Void> eraseAsync(const Key& begin, const Key& end);

	// Erase the items in the indicated range.
	void erase(iterator begin, iterator end);

	// Erase data items with a deferred (async) free process. The data structure has the items removed
	// synchronously with the invocation of this method so any subsequent call will see this new state.
	Future<Void> eraseAsync(iterator begin, iterator end);

	// Returns the number of items equal to key (either 0 or 1)
	template <class Key>
	int count(const Key &key) const { return find(key) != end(); }

	// Returns x such that key==*x, or end()
	template <class Key>
	iterator find(const Key &key) const;

	// Returns the smallest x such that *x>=key, or end()
	template <class Key>
	iterator lower_bound(const Key &key) const;

	// Returns the smallest x such that *x>key, or end()
	template <class Key>
	iterator upper_bound(const Key &key) const;

	// Returns the largest x such that *x<=key, or end()
	template <class Key>
	iterator lastLessOrEqual( const Key &key ) const;

	// Returns smallest x such that sumTo(x+1) > metric, or end()
	template <class M>
	iterator index( M const& metric ) const;

	// Return the metric inserted with item x
	Metric getMetric(iterator x) const;

	// Return the sum of getMetric(x) for begin()<=x<to
	Metric sumTo(iterator to) const;

	// Return the sum of getMetric(x) for begin<=x<end
	Metric sumRange(iterator begin, iterator end) const { return sumTo(end) - sumTo(begin); }

	// Return the sum of getMetric(x) for all x s.t. begin <= *x && *x < end
	template <class Key>
	Metric sumRange(const Key& begin, const Key& end) const { return sumRange(lower_bound(begin), lower_bound(end)); }

	// Return the amount of memory used by an entry in the IndexedSet
	static int getElementBytes() { return sizeof(Node); }

private:
	// Copying is not supported. SOMEDAY: Implement and make public.
	IndexedSet( const IndexedSet& ) = delete;
	IndexedSet& operator=( const IndexedSet& ) = delete;

	Node *root;

	Metric eraseHalf( Node* start, Node* end, int eraseDir, int& heightDelta, std::vector<Node*>& toFree );
	void erase( iterator begin, iterator end, std::vector<Node*>& toFree );

	// Makes newNode (which may be NULL) take oldNode's place under oldNode's parent, or as the root.
	void replacePointer( Node* oldNode, Node* newNode ) {
		if (oldNode->parent)
			oldNode->parent->child[ oldNode->parent->child[1] == oldNode ] = newNode;
		else
			root = newNode;
		if (newNode)
			newNode->parent = oldNode->parent;
	}

	// Moves i to its in-order neighbor: direction 1 = successor (right), direction 0 = predecessor (left).
	// i becomes NULL (i.e. end()) when no such neighbor exists. i itself must not be NULL.
	template <int direction>
	static void moveIterator(Node* &i){
		if (i->child[0^direction]) {
			i = i->child[0^direction];
			while (i->child[1^direction])
				i = i->child[1^direction];
		} else {
			while (i->parent && i->parent->child[0^direction] == i)
				i = i->parent;
			i = i->parent;
		}
	}

public: // but testonly
	std::pair<int, int> testonly_assertBalanced(Node*n=0, int d=0, bool a=true);
};
// Placeholder Metric for sets and maps that do not need one. All values are interchangeable:
// + and - always yield a default NoMetric, and no value compares less than another.
class NoMetric {
public:
	NoMetric() {}
	// Intentionally implicit so that integer literals such as 0 and 1 convert to NoMetric.
	NoMetric(int /*ignored*/) {}

	NoMetric operator+(NoMetric const& /*rhs*/) const { return {}; }
	NoMetric operator-(NoMetric const& /*rhs*/) const { return {}; }
	bool operator<(NoMetric const& /*rhs*/) const { return false; }
};
// Key/value element type used by Map. Ordering and equality consider only the key, so a MapPair
// can be looked up in an IndexedSet by key alone.
template <class Key, class Value>
class MapPair {
public:
	Key key;
	Value value;

	// Forwards each argument into the corresponding member, so rvalues are moved rather than copied.
	template <class Key_, class Value_>
	MapPair( Key_&& key, Value_&& value ) : key(std::forward<Key_>(key)), value(std::forward<Value_>(value)) {}

	MapPair( MapPair const& rhs ) : key(rhs.key), value(rhs.value) {}
	MapPair(MapPair&& r) noexcept(true) : key(std::move(r.key)), value(std::move(r.value)) {}

	// Assignments return *this (the conventional signature) so they compose like built-in assignment
	// and satisfy generic code that uses the result of an assignment.
	MapPair& operator= ( MapPair const& rhs ) { key = rhs.key; value = rhs.value; return *this; }
	MapPair& operator=(MapPair&& r) noexcept(true) { key = std::move(r.key); value = std::move(r.value); return *this; }

	bool operator<(MapPair<Key,Value> const& r) const { return key < r.key; }
	bool operator==(MapPair<Key,Value> const& r) const { return key == r.key; }
	bool operator!=(MapPair<Key,Value> const& r) const { return key != r.key; }
};
template <class Key, class Value>
inline MapPair<typename std::decay<Key>::type, typename std::decay<Value>::type> mapPair(Key&& key, Value&& value) { return MapPair<typename std::decay<Key>::type, typename std::decay<Value>::type>(std::forward<Key>(key), std::forward<Value>(value)); }
template <class Key, class Value, class CompatibleWithKey>
bool operator<(MapPair<Key, Value> const& l, CompatibleWithKey const& r) { return l.key < r; }
template <class Key, class Value, class CompatibleWithKey>
bool operator<(CompatibleWithKey const& l, MapPair<Key, Value> const& r) { return l < r.key; }
template <class Key, class Value, class Pair = MapPair<Key,Value>, class Metric=NoMetric >
class Map {
public:
typedef typename IndexedSet<Pair,Metric>::iterator iterator;
Map() {}
iterator begin() const { return set.begin(); }
iterator end() const { return set.end(); }
iterator lastItem() const { return set.lastItem(); }
iterator previous(iterator i) const { return set.previous(i); }
bool empty() const { return set.empty(); }
Value& operator[]( const Key& key ) {
iterator i = set.insert( Pair(key, Value()), Metric(1), false );
return i->value;
}
Value& get( const Key& key, Metric m = Metric(1) ) {
iterator i = set.insert( Pair(key, Value()), m, false );
return i->value;
}
iterator insert( const Pair& p, bool replaceExisting = true, Metric m = Metric(1) ) { return set.insert(p, m, replaceExisting); }
iterator insert( Pair && p, bool replaceExisting = true, Metric m = Metric(1) ) { return set.insert(std::move(p), m, replaceExisting); }
int insert( const std::vector<std::pair<MapPair<Key,Value>, Metric>>& pairs, bool replaceExisting = true) { return set.insert(pairs, replaceExisting); }
template <class KeyCompatible>
void erase( KeyCompatible const& k ) { set.erase(k); }
void erase( iterator b, iterator e ) { set.erase(b,e); }
void erase( iterator x ) { set.erase(x); }
void clear() { set.clear(); }
Metric size() const {
static_assert(!std::is_same<Metric, NoMetric>::value, "size() on Map with NoMetric is not valid!");
return sumTo(end());
}
template <class KeyCompatible>
iterator find( KeyCompatible const& k ) const { return set.find(k); }
template <class KeyCompatible>
iterator lower_bound( KeyCompatible const& k ) const { return set.lower_bound(k); }
template <class KeyCompatible>
iterator upper_bound( KeyCompatible const& k ) const { return set.upper_bound(k); }
template <class KeyCompatible>
iterator lastLessOrEqual( KeyCompatible const& k ) const { return set.lastLessOrEqual(k); }
template <class M>
iterator index( M const& metric ) const { return set.index(metric); }
Metric getMetric(iterator x) const { return set.getMetric(x); }
Metric sumTo(iterator to) const { return set.sumTo(to); }
Metric sumRange(iterator begin, iterator end) const { return set.sumRange(begin,end); }
template <class KeyCompatible>
Metric sumRange(const KeyCompatible& begin, const KeyCompatible& end) const { return set.sumRange(begin,end); }
static int getElementBytes() { return IndexedSet< Pair, Metric >::getElementBytes(); }
Map(Map&& r) noexcept(true) : set(std::move(r.set)) {}
void operator=(Map&& r) noexcept(true) { set = std::move(r.set); }
private:
Map( Map<Key,Value,Pair> const& ); // unimplemented
void operator=( Map<Key,Value,Pair> const& ); // unimplemented
IndexedSet< Pair, Metric > set;
};
/////////////////////// implementation //////////////////////////
// Moves to the in-order successor. Stepping past the largest item leaves a null node, i.e. end().
template <class T, class Metric>
void IndexedSet<T,Metric>::iterator::operator++() {
	moveIterator<1>(this->i);
}
// Moves to the in-order predecessor. The iterator must not be end(): the current node is dereferenced.
template <class T, class Metric>
void IndexedSet<T,Metric>::iterator::decrementNonEnd() {
	moveIterator<0>(this->i);
}
// Single rotation at oldRootRef. d == 0 rotates left (the right child rises); d == 1 rotates right
// (the left child rises). Subtree totals are kept consistent. Parent links of the moved nodes are
// updated, and oldRootRef (the parent's child slot, or the root pointer) is redirected to the new
// subtree root. Balance factors are NOT changed here; callers set them.
template <class Node>
void ISRotate(Node*& oldRootRef, int d) {
	Node* const top = oldRootRef;
	Node* const pivot = top->child[1 - d];  // becomes the new subtree root
	Node* const inner = pivot->child[d];    // crosses over from pivot to top

	// The new root covers the whole subtree; top keeps its own side plus the crossing subtree.
	auto topTotal = top->total - pivot->total;
	if (inner)
		topTotal = topTotal + inner->total;
	pivot->total = top->total;
	top->total = topTotal;

	// Relink: inner moves under top, top moves under pivot, pivot takes top's place.
	top->child[1 - d] = inner;
	if (inner)
		inner->parent = top;
	pivot->child[d] = top;
	pivot->parent = top->parent;
	top->parent = pivot;
	oldRootRef = pivot;
}
// Assigns balance factors to root, its child on side d, and that child's inner child (child[1-d]),
// chosen from the inner child's current balance relative to bal. insert() calls this just before
// the double rotation (at the child, then at root) that brings the inner child to the top.
template <class Node>
void ISAdjustBalance(Node* root, int d, int bal) {
	Node* const child = root->child[d];
	Node* const grandchild = child->child[1 - d];
	const int grandchildBalance = grandchild->balance;

	int rootBalance = 0;
	int childBalance = 0;
	if (grandchildBalance == bal)
		rootBalance = -bal;
	else if (grandchildBalance != 0)
		childBalance = bal;

	root->balance = rootBalance;
	child->balance = childBalance;
	grandchild->balance = 0;
}
template <class Node>
int ISRebalance( Node*& root ) {
	// Restores the AVL invariant of the subtree rooted at `root` by rotation. `root` is updated in place
	// to point at the new subtree root. Returns the change in the subtree's height (0 when the subtree
	// is empty or already balanced, in which case nothing is modified).
	//
	// Pre: root is a tree having the BST, metric, and balance invariants but not (necessarily) the AVL invariant. root->child[0] and root->child[1] are AVL.
	// Post: root is an AVL tree with the same nodes
	// Returns: the change in height of root
	// rebalance is O(1) if abs(root->balance)<=2, and probably O(log N) otherwise. (The rare "still unbalanced" recursion is hard to analyze)
	//
	// The documentation of this function will be referencing the following tree (where
	// nodes A, C, E, and G represent subtrees of unspecified height). Thus for each node X,
	// we know the value of balance(X), but not height(X).
	//
	// We will assume that balance(F) < 0 (so we will be rotating right).
	// Trees that rotate to the left will perform analogous operations.
	//
	// (Tree diagrams use block comments: a line comment ending in a backslash would splice the next line into it.)
	/*
	         F
	        / \
	       B   G
	      / \
	     A   D
	        / \
	       C   E
	*/

	if (!root || (root->balance >= -1 && root->balance <= +1))
		return 0;

	int rebalanceDir = root->balance<0; // 1 if rotating right, 0 if rotating left
	auto* n = root->child[ 1-rebalanceDir ]; // Node B
	int bal = rebalanceDir ? +1 : -1; // 1 if rotating right, -1 if rotating left
	int rootBal = root->balance; // original imbalance; used below to assert that recursion makes progress

	// Depending on the balance at B, we will be required to do one or two rotations.
	// If balance(B) <= 0, then we do only one rotation (the second of the two).
	//
	// In a tree where balance(B) == +1, we are required to do both rotations.
	// The result of the first rotation will be:
	/*
	         F
	        / \
	       D   G
	      / \
	     B   E
	    / \
	   A   C
	*/

	bool doubleRotation = n->balance == bal;
	if (doubleRotation) {
		int x = n->child[rebalanceDir]->balance; // balance of Node D
		ISRotate( root->child[1-rebalanceDir], 1-rebalanceDir); // Rotate at Node B

		// Change node pointed to by 'n' to prepare for the second rotation
		// After this first rotation, Node D will be the left child of the root
		n = root->child[1-rebalanceDir];

		// Compute the balance at the new root node D' of our rotation
		// We know that height(A) == max(height(C), height(E)) because B had balance of +1
		// If height(E) >= height(C), then height(E) == height(A) and balance(D') = -1
		// Otherwise height(C) == height(E) + 1, and therefore balance(D') = -2
		n->balance = ((x==-bal) ? -2 : -1)*bal;

		// Compute the balance at the old root node B' of our rotation
		// As stated above, height(A) == max(height(C), height(E))
		// If height(C) >= height(E), then height(A) == height(C) and balance(B') = 0
		// Otherwise height(A) == height(E) == height(C) + 1, and therefore balance(B') = -1
		n->child[1-rebalanceDir]->balance = ((x==bal) ? -1 : 0)*bal;
	}

	// At this point, we perform the "second" rotation (which may actually be the first
	// if the "first" rotation was not performed). The rotation that is performed is the
	// same for both trees, but the result will be different depending on which tree we
	// started with:
	/*
	   If unrotated:              If once rotated:

	       B                            D
	      / \                         /   \
	     A   F                       B     F
	        / \                     / \   / \
	       D   G                   A   C E   G
	      / \
	     C   E
	*/
	// The documentation for this second rotation will be based on the unrotated original tree.

	// Compute the balance at the new root node B'.
	// balance(B') = 1 + max(height(D), height(G)) - height(A) = 1 + max(height(D) - height(A), height(G) - height(A))
	// balance(B') = 1 + max(balance(B), height(G) - height(A))
	//
	// Now, we must find height(G) - height(A):
	// If height(A) >= height(D) (i.e. balance(B) <= 0), then
	// height(G) - height(A) = height(G) - height(B) + 1 = balance(F) + 1
	//
	// Otherwise, height(A) = height(D) - balance(B) = height(B) - 1 - balance(B), so
	// height(G) - height(A) = height(G) - height(B) + 1 + balance(B) = balance(F) + 1 + balance(B)
	//
	// balance(B') = 1 + max(balance(B), balance(F) + 1 + max(balance(B), 0))
	//
	int nBal = n->balance * bal; // Direction corrected balance at Node B
	int newRootBalance = bal * (1 + std::max(nBal, bal * root->balance + 1 + std::max(nBal, 0)));

	// Compute the balance at the old root node F' (which becomes a child of the new root).
	// balance(F') = height(G) - height(D)
	//
	// If height(D) >= height(A) (i.e. balance(B) >= 0), then height(D) = height(B) - 1, so
	// balance(F') = height(G) - height(B) + 1 = balance(F) + 1
	//
	// Otherwise, height(D) = height(A) + balance(B) = height(B) - 1 + balance(B), so
	// balance(F') = height(G) - height(B) + 1 - balance(B) = balance(F) + 1 - balance(B)
	//
	// balance(F') = balance(F) + 1 - min(balance(B), 0)
	//
	int newChildBalance = root->balance + bal * (1 - std::min(nBal, 0));

	ISRotate( root, rebalanceDir );
	root->balance = newRootBalance;
	root->child[rebalanceDir]->balance = newChildBalance;

	// If the original tree is very unbalanced, the unbalance may have been "pushed" down into this subtree, so recursively rebalance that if necessary.
	int childHeightChange = ISRebalance(root->child[rebalanceDir]);
	// The child's height change shifts the new root's balance toward (or away from) the rotation side.
	root->balance += childHeightChange * bal;

	// NOTE(review): newRootBalance is not read after this statement.
	newRootBalance *= bal;

	// Compute the change in height at the root
	// We will look at the single and double rotation cases separately
	//
	// If we did a single rotation, then height(A) >= height(D).
	// As a result, height(A) >= height(G) + 1; otherwise the tree would be balanced and we wouldn't do any rotations.
	//
	// Then the original height of the tree is height(A) + 2,
	// and the new height is max(height(D) + 2 + childHeightChange, height(A) + 1), so
	//
	// heightChange_single = max(height(D) + 2 + childHeightChange, height(A) + 1) - (height(A) + 2)
	// heightChange_single = max(height(D) - height(A) + childHeightChange, -1)
	// heightChange_single = max(balance(B) + childHeightChange, -1)
	//
	// If we did a double rotation, then height(D) = height(A) + 1 in the original tree.
	// As a result, height(D) >= height(G) + 1; otherwise the tree would be balanced and we wouldn't do any rotations.
	//
	// Then the original height of the tree is height(D) + 2,
	// and the new height is max(height(A), height(C), height(E), height(G)) + 2
	//
	// balance(B) == 1, so height(A) == max(height(C), height(E)).
	// Also, height(A) = height(D) - 1 >= height(G)
	// Therefore the new height is height(A) + 2
	//
	// heightChange_double = height(A) + 2 - (height(D) + 2)
	// heightChange_double = height(A) - height(D)
	// heightChange_double = -1
	//
	int heightChange = doubleRotation ? -1 : std::max(nBal + childHeightChange, -1);

	// If the root is still unbalanced, then it should at least be more balanced than before. Recursively rebalance the root until we get a balanced tree.
	if (root->balance <-1 || root->balance > +1) {
		ASSERT(abs(root->balance) < abs(rootBal));
		heightChange += ISRebalance(root);
	}

	return heightChange;
}
// Returns the root of the smallest subtree containing both first and last, i.e. their deepest common
// ancestor (a node counts as its own ancestor). Both nodes are assumed to belong to the same tree.
template <class Node>
Node* ISCommonSubtreeRoot(Node* first, Node* last) {
	// Number of nodes on the path from n up to the tree root, inclusive.
	auto depthOf = [](Node* n) {
		int depth = 0;
		while (n) {
			++depth;
			n = n->parent;
		}
		return depth;
	};

	int firstDepth = depthOf(first);
	int lastDepth = depthOf(last);
	Node* a = first;
	Node* b = last;

	// Lift the deeper node until both sit at the same depth.
	for (; firstDepth > lastDepth; --firstDepth)
		a = a->parent;
	for (; lastDepth > firstDepth; --lastDepth)
		b = b->parent;

	// Climb in lockstep until the paths meet.
	while (a != b) {
		a = a->parent;
		b = b->parent;
	}
	return a;
}
// Returns an iterator to the smallest item (the leftmost node), or end() when the set is empty.
template <class T, class Metric>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::begin() const {
	Node* leftmost = root;
	if (leftmost) {
		while (Node* next = leftmost->child[0])
			leftmost = next;
	}
	return iterator(leftmost);
}
// Returns the item before i. end() has no node to step back from, so its predecessor is lastItem()
// (itself end() for an empty set). Stepping back from the smallest item yields end().
template <class T, class Metric>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::previous(typename IndexedSet<T,Metric>::iterator i) const {
	if (i != end()) {
		i.decrementNonEnd();
		return i;
	}
	return lastItem();
}
// Returns an iterator to the largest item (the rightmost node), or end() when the set is empty.
template <class T, class Metric>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::lastItem() const {
	Node* rightmost = root;
	if (rightmost) {
		while (Node* next = rightmost->child[1])
			rightmost = next;
	}
	return iterator(rightmost);
}
// Adds metric to the metric of the item equal to data. If no such item exists, data is inserted with
// exactly that metric. Returns the item's resulting metric.
template <class T, class Metric> template<class T_, class Metric_>
Metric IndexedSet<T,Metric>::addMetric(T_&& data, Metric_&& metric){
	auto i = find( data );
	if (i == end()) {
		// metric is passed as an lvalue (never moved from) because it is also our return value.
		insert( std::forward<T_>(data), metric );
		return metric;
	} else {
		Metric m = metric + getMetric(i);
		// Re-inserting with replaceExisting (the default) overwrites the item's metric with m.
		insert( std::forward<T_>(data), m );
		return m;
	}
}
// Inserts data with the given metric. If an equal item already exists, it is left untouched unless
// replaceExisting is true, in which case its data and metric are overwritten.
// Returns an iterator to the inserted or pre-existing item.
//
// data and metric are forwarding references. data is forwarded exactly once, at the point where it is
// stored, so it is only moved from when the caller passed an rvalue. metric is needed again after a new
// node is created (to update the ancestors' totals), so it is handed to the node as an lvalue and never
// moved from.
template <class T, class Metric> template<class T_, class Metric_>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::insert(T_&& data, Metric_&& metric, bool replaceExisting){
	if (root == NULL){
		// metric is not used again on this path, so it may be forwarded.
		root = new Node(std::forward<T_>(data), std::forward<Metric_>(metric));
		return root;
	}
	Node *t = root;
	int d; // direction
	// traverse to find insert point
	while (true){
		d = t->data < data;
		if (!d && !(data < t->data)) { // t->data == data
			Node *returnNode = t;
			if(replaceExisting) {
				// std::forward, not std::move: data may be an lvalue that still belongs to the caller.
				t->data = std::forward<T_>(data);
				Metric delta = t->total;
				t->total = metric;
				if (t->child[0]) t->total = t->total + t->child[0]->total;
				if (t->child[1]) t->total = t->total + t->child[1]->total;
				delta = t->total - delta;
				// Propagate the change in this node's subtree total to all ancestors.
				while (true) {
					t = t->parent;
					if (!t) break;
					t->total = t->total + delta;
				}
			}
			return returnNode;
		}
		Node *nextT = t->child[d];
		if (!nextT) break;
		t = nextT;
	}

	// metric is passed as an lvalue: it is read again below to update subtree totals.
	Node *newNode = new Node(std::forward<T_>(data), metric, t);
	t->child[d] = newNode;

	// Walk up updating balance factors and totals until the height stops changing or a rotation fixes it.
	while (true){
		t->balance += d ? 1 : -1;
		t->total = t->total + metric;
		if (t->balance == 0)
			break;
		if (t->balance != 1 && t->balance != -1){
			Node** parent = t->parent ? &t->parent->child[t->parent->child[1]==t] : &root;
			//assert( *parent == t );

			Node *n = t->child[d];
			int bal = d ? 1 : -1;
			if (n->balance == bal){
				t->balance = n->balance = 0;
			} else {
				ISAdjustBalance(t, d, bal);
				ISRotate(t->child[d], d);
			}
			ISRotate(*parent, 1-d);
			t = *parent;
			break;
		}
		if (!t->parent) break;
		d = t->parent->child[1] == t;
		t = t->parent;
	}
	// Remaining ancestors only need their totals updated.
	while (true) {
		t = t->parent;
		if (!t) break;
		t->total = t->total + metric;
	}

	return newNode;
}
// Bulk insert of (item, metric) pairs. An item equal to an existing one replaces it only when
// replaceExisting is true. Returns the number of items inserted or replaced.
//
// Optimized for input sorted in ascending order. After each step, blockStart is the node just inserted
// (or matched), and blockEnd is the smallest node known to be greater than it (NULL if none). While the
// next item falls strictly between the two, it must become blockStart's in-order successor, so it is
// placed without searching from the root. Any other item (unsorted or duplicate input) falls back to a
// full search from the root, so the tree stays a valid BST regardless of input order.
//
// Only operator< on T is used, as required by the class contract.
template <class T, class Metric>
int IndexedSet<T,Metric>::insert(const std::vector<std::pair<T,Metric>>& dataVector, bool replaceExisting) {
	int num_inserted = 0;
	Node *blockStart = NULL;
	Node *blockEnd = NULL;

	for (const auto& entry : dataVector) {
		Metric metric = entry.second;
		T data = entry.first; // a copy: dataVector is const, so its elements cannot be moved from
		int d = 1; // direction

		bool canAppend = blockStart != NULL && blockStart->data < data && (blockEnd == NULL || data < blockEnd->data);
		if (!canAppend) {
			blockEnd = NULL;
			if (root == NULL){
				root = new Node(std::move(data), metric);
				num_inserted++;
				blockStart = root;
				continue;
			}
			Node *t = root;
			// Traverse to find the insert point. blockEnd records the last node at which the search went
			// left (or stopped on an equal node), i.e. the in-order successor of the insertion point.
			bool foundNode = false;
			while (true){
				d = t->data < data;
				if (!d)
					blockEnd = t;
				if (!d && !(data < t->data)) { // t->data == data
					Node *returnNode = t;
					if(replaceExisting) {
						num_inserted++;
						t->data = std::move(data);
						Metric delta = t->total;
						t->total = metric;
						if (t->child[0]) t->total = t->total + t->child[0]->total;
						if (t->child[1]) t->total = t->total + t->child[1]->total;
						delta = t->total - delta;
						while (true) {
							t = t->parent;
							if (!t) break;
							t->total = t->total + delta;
						}
					}
					blockStart = returnNode;
					foundNode = true;
					break;
				}
				Node *nextT = t->child[d];
				if (!nextT) {
					blockStart = t;
					break;
				}
				t = nextT;
			}
			if(foundNode)
				continue;
		}

		// Descend to the empty slot. After a search, blockStart->child[d] is already empty. On the fast path
		// (d == 1) this finds the leftmost position in blockStart's right subtree: its successor slot.
		Node *t = blockStart;
		while(t->child[d]) {
			t = t->child[d];
			d = 0;
		}

		Node *newNode = new Node(std::move(data), metric, t);
		num_inserted++;
		t->child[d] = newNode;
		blockStart = newNode;

		// Walk up updating balance factors and totals until the height stops changing or a rotation fixes it.
		while (true){
			t->balance += d ? 1 : -1;
			t->total = t->total + metric;
			if (t->balance == 0)
				break;
			if (t->balance != 1 && t->balance != -1){
				Node** parent = t->parent ? &t->parent->child[t->parent->child[1]==t] : &root;
				//assert( *parent == t );

				Node *n = t->child[d];
				int bal = d ? 1 : -1;
				if (n->balance == bal){
					t->balance = n->balance = 0;
				} else {
					ISAdjustBalance(t, d, bal);
					ISRotate(t->child[d], d);
				}
				ISRotate(*parent, 1-d);
				t = *parent;
				break;
			}
			if (!t->parent) break;
			d = t->parent->child[1] == t;
			t = t->parent;
		}
		// Remaining ancestors only need their totals updated.
		while (true) {
			t = t->parent;
			if (!t) break;
			t->total = t->total + metric;
		}
	}
	return num_inserted;
}
template <class T, class Metric>
Metric IndexedSet<T,Metric>::eraseHalf( Node* start, Node* end, int eraseDir, int& heightDelta, std::vector<Node*>& toFree ) {
	// Helper for range erase: walks from start up to (but not including) end, detaching one "half" of end's subtree.
	//
	// Removes all nodes between start (inclusive) and end (exclusive) from the set, where start is equal to end or one of its descendants
	// eraseDir 1 means erase the right half (nodes >= start) of the left subtree of end. eraseDir 0 means the left half (nodes <= start) of the right subtree
	// toFree is extended with the roots of completely removed subtrees (they are detached here, not deleted)
	// heightDelta will be set to the change in height of the end node
	// Returns the amount that should be subtracted from end node's metric value (and, by extension, the metric values of all ancestors of the end node).
	//
	// The end node may be left unbalanced (AVL invariant broken)
	// The end node may be left with the incorrect metric total (the correct value is end->total - metricDelta, where metricDelta is the returned value)
	// scare quotes in comments mean the values when eraseDir==1 (when eraseDir==0, "left" means right etc)

	// metricDelta measures how much should be subtracted from the current node's metrics
	Metric metricDelta = 0;
	heightDelta = 0;

	// fromDir: which child of the current node the walk arrived from (the first node counts as arriving from the "left")
	int fromDir = 1 - eraseDir;

	// Begin removing nodes at start continuing up until we get to end
	while(start != end) {
		start->total = start->total - metricDelta;

		IndexedSet<T,Metric>::Node *parent = start->parent;

		// Obtain the child pointer to start, which rebalance will update with the new root of the subtree currently rooted at start
		IndexedSet<T,Metric>::Node *& node = parent->child[ parent->child[1] == start ];
		int nextDir = parent->child[1] == start;

		if (fromDir==eraseDir) {
			// The "right" subtree has been half-erased, and the "left" subtree doesn't need to be (nor does node).
			// But this node might be unbalanced by the shrinking "right" subtree. Rebalance and continue up.
			heightDelta += ISRebalance( node );
		} else {
			// The "left" subtree has been half-erased. `start' and its "right" subtree will be completely erased,
			// leaving only the "left" subtree in its place (which is already AVL balanced).
			heightDelta += -1 - std::max<int>(0, node->balance * (eraseDir ? +1 : -1));
			metricDelta = metricDelta + start->total;

			// If there is a surviving subtree of start, then connect it to start->parent
			IndexedSet<T,Metric>::Node *n = node->child[fromDir];
			node = n; // This updates the appropriate child pointer of start->parent
			if (n) {
				metricDelta = metricDelta - n->total;
				n->parent = start->parent;
			}

			// Detach the surviving child so it is not freed along with start (start still owns its "right" subtree)
			start->child[fromDir] = NULL;
			toFree.push_back( start );
		}

		int dir = (nextDir ? +1 : -1);
		int oldBalance = parent->balance;

		// The change in height from removing nodes should never increase our height
		ASSERT(heightDelta <= 0);
		parent->balance += heightDelta * dir;

		// Compute the change in height of start's parent based on its change in balance.
		// Because we can only be (possibly) shrinking one subtree of parent:
		// If we were originally heavier on the shrunken side (oldBalance * dir > 0), then the change in height is at most abs(oldBalance) == oldBalance * dir.
		// If we were lighter on the shrunken side, then height cannot change.
		int maxHeightChange = std::max(oldBalance * dir, 0);
		int balanceChange = (oldBalance - parent->balance) * dir;
		heightDelta = -std::min(maxHeightChange, balanceChange);

		start = parent;
		fromDir = nextDir;
	}

	return metricDelta;
}
template <class T, class Metric>
void IndexedSet<T,Metric>::erase( typename IndexedSet<T,Metric>::iterator begin, typename IndexedSet<T,Metric>::iterator end, std::vector<Node*>& toFree ) {
	// Removes all nodes in [begin, end), i.e. from begin through previous(end), inclusive.
	// toFree is extended with the roots of completely removed subtrees; the caller is responsible for
	// deleting them (the common subtree root is instead deleted immediately via single-item erase).
	// NOTE(review): the ASSERT below uses operator<= on T, while the class contract only requires operator<.

	ASSERT(!end.i || (begin.i && *begin <= *end));

	if(begin == end)
		return;

	IndexedSet<T,Metric>::Node* first = begin.i;
	IndexedSet<T,Metric>::Node* last = previous(end).i;

	// Smallest subtree containing both ends of the range
	IndexedSet<T,Metric>::Node* subRoot = ISCommonSubtreeRoot(first, last);

	Metric metricDelta = 0;
	int leftHeightDelta = 0;
	int rightHeightDelta = 0;

	// Erase all matching nodes that descend from subRoot, by first erasing descendants of subRoot->child[0] and then erasing the descendants of subRoot->child[1]
	// subRoot is not removed from the tree at this time
	metricDelta = metricDelta + eraseHalf( first, subRoot, 1, leftHeightDelta, toFree );
	metricDelta = metricDelta + eraseHalf( last, subRoot, 0, rightHeightDelta, toFree );

	// Change in the height of subRoot due to past activity, before subRoot is rebalanced. subRoot->balance already reflects changes in height to its children.
	int heightDelta = leftHeightDelta + rightHeightDelta;

	// Rebalance and update metrics for all nodes from subRoot up to the root
	for(auto p = subRoot; p != NULL; p = p->parent) {
		p->total = p->total - metricDelta;

		// pc is the slot (parent's child pointer, or root) that holds p; rebalancing may replace the node it points to
		auto& pc = p->parent ? p->parent->child[p->parent->child[1]==p] : root;
		heightDelta += ISRebalance(pc);
		p = pc;

		// Update the balance and compute heightDelta for p->parent
		if (p->parent) {
			int oldb = p->parent->balance;
			int dir = (p->parent->child[1]==p ? +1 : -1);
			p->parent->balance += heightDelta * dir;
			// The parent's height changes only by the change in its taller side's excess over the other side.
			heightDelta = (std::max(p->parent->balance*dir, 0) - std::max(oldb*dir, 0));
		}
	}

	// Erase the subRoot using the single node erase implementation
	erase( IndexedSet<T,Metric>::iterator(subRoot) );
}
// Erases the single element at toErase and deletes its node. If toErase.i is null (as it is for end()),
// this is a no-op. Every ancestor's subtree `total` metric is kept consistent, and the AVL invariant
// (each node's balance in [-1, +1]) is restored. Removing height from a node's right side decrements
// its balance, so balance > 0 means the right subtree is taller.
template <class T, class Metric>
void IndexedSet<T,Metric>::erase(iterator toErase) {
// rebalanceNode: lowest node whose subtree lost height because of the removal.
// rebalanceDir: the side of rebalanceNode that shrank (0 = left, 1 = right). It is only read when rebalanceNode is non-null.
Node* rebalanceNode;
int rebalanceDir;
{
// Find the node to erase
Node* t = toErase.i;
if (!t) return;
if (!t->child[0] || !t->child[1]) {
// Case 1: t has at most one child, so it can be spliced out directly.
// tMetric is t's own metric: its subtree total minus the totals of whichever children exist.
Metric tMetric = t->total;
if (t->child[0]) tMetric = tMetric - t->child[0]->total;
if (t->child[1]) tMetric = tMetric - t->child[1]->total;
// Every ancestor's subtree total loses exactly t's own metric.
for( Node* p = t->parent; p; p = p->parent )
p->total = p->total - tMetric;
rebalanceNode = t->parent;
if (rebalanceNode) rebalanceDir = rebalanceNode->child[1] == t;
int d = !t->child[0]; // Only one child, on this side (or no children!)
// NOTE(review): replacePointer presumably redirects whatever points at t (parent's child slot or root) to t->child[d].
replacePointer(t, t->child[d]);
// Detach the surviving child before deleting t (presumably so Node's destructor does not free the moved subtree).
t->child[d] = 0;
delete t;
} else { // Remove node with two children
// Case 2: replace t with its in-order predecessor, the rightmost node of t's left subtree (it has no right child).
Node* predecessor = t->child[0];
while ( predecessor->child[1] )
predecessor = predecessor->child[1];
// Height is lost where predecessor is unlinked, i.e. below its parent. If that parent is t itself, predecessor
// will occupy t's position and its own left side is what shrinks; the expression below then yields 0 (left),
// since predecessor->child[1] can never equal predecessor.
rebalanceNode = predecessor->parent;
if (rebalanceNode == t) rebalanceNode = predecessor;
if (rebalanceNode) rebalanceDir = rebalanceNode->child[1] == predecessor;
// t's own metric (t has both children here).
Metric tMetric = t->total - t->child[0]->total - t->child[1]->total;
// Reduce predecessor->total to predecessor's own metric (it has no right child, only possibly a left one).
if (predecessor->child[0]) predecessor->total = predecessor->total - predecessor->child[0]->total;
// Nodes strictly between predecessor and t lose predecessor's metric, because predecessor moves above them.
for( Node* p = predecessor->parent; p != t; p = p->parent )
p->total = p->total - predecessor->total;
// t's ancestors lose t's own metric.
for( Node* p = t->parent; p; p = p->parent )
p->total = p->total - tMetric;
// Replace t with predecessor
replacePointer( predecessor, predecessor->child[0] );
replacePointer( t, predecessor );
predecessor->balance = t->balance;
// predecessor adopts t's children. Its total, which currently holds only its own metric, grows by each child's total.
for(int i=0; i<2; i++) {
Node* c = predecessor->child[i] = t->child[i];
if (c) {
c->parent = predecessor;
predecessor->total = predecessor->total + c->total;
t->child[i] = 0;
}
}
delete t;
}
}
if (!rebalanceNode) return;
// Walk upward, applying the height loss on side rebalanceDir until some subtree's height is unchanged.
while (true) {
rebalanceNode->balance += rebalanceDir ? -1 : +1;
if ( rebalanceNode->balance < -1 || rebalanceNode->balance > +1 ) {
// Out of balance: rotate. parent is the pointer slot (parent's child pointer or root) holding rebalanceNode.
Node** parent = rebalanceNode->parent ? &rebalanceNode->parent->child[rebalanceNode->parent->child[1]==rebalanceNode] : &root;
// n is the root of the taller subtree, on the side opposite the one that shrank.
Node* n = rebalanceNode->child[ 1-rebalanceDir ];
int bal = rebalanceDir ? +1 : -1;
if (n->balance == -bal) {
// n leans away from the shrunken side: one rotation; both nodes end balanced and the height drops, so continue upward.
rebalanceNode->balance = n->balance = 0;
ISRotate( *parent, rebalanceDir );
} else if (n->balance == bal) {
// n leans toward the shrunken side: double rotation (first at n, then at rebalanceNode); the height drops, so continue upward.
ISAdjustBalance( rebalanceNode, 1-rebalanceDir, -bal );
ISRotate( rebalanceNode->child[1-rebalanceDir], 1-rebalanceDir);
ISRotate( *parent, rebalanceDir );
} else { // n->balance == 0
// One rotation leaves the subtree's height unchanged, so nothing above is affected.
rebalanceNode->balance = -bal;
n->balance = bal;
ISRotate( *parent, rebalanceDir );
break;
}
// Continue from the new root of the rotated subtree.
rebalanceNode = *parent;
} else if ( rebalanceNode->balance ) // +/- 1, we are done
break;
// balance is now 0 (or a rotation shrank the subtree): this subtree got shorter, so propagate to the parent.
if (!rebalanceNode->parent) break;
rebalanceDir = rebalanceNode->parent->child[1] == rebalanceNode;
rebalanceNode = rebalanceNode->parent;
}
}
// Returns an iterator to the element equivalent to key (neither key < *x nor *x < key), or end() if the
// set holds no such element. Only operator< between the stored data and Key is used.
template <class T, class Metric>
template <class Key>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::find(const Key &key) const {
	for (Node* node = root; node; ) {
		bool goRight = node->data < key;
		// Not less, and key is not less either: the two are equivalent.
		if (!goRight && !(key < node->data))
			return iterator(node);
		node = node->child[goRight];
	}
	return end();
}
// Returns an iterator to the smallest element x with !(*x < key) (i.e. *x >= key), or end() if every
// element is less than key (or the set is empty).
template <class T, class Metric>
template <class Key>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::lower_bound(const Key &key) const {
	if (!root)
		return iterator();
	// Follow key's search path down to its last node: right while *cur < key, left otherwise.
	Node* cur = root;
	while (Node* next = cur->child[cur->data < key])
		cur = next;
	// The last node on the path is either the result or the element immediately before it;
	// in the latter case step forward one element.
	if (cur->data < key)
		moveIterator<1>(cur);
	return iterator(cur);
}
// Returns an iterator to the smallest element x with key < *x (strictly greater than key), or end() if
// no element is greater than key (or the set is empty).
template <class T, class Metric>
template <class Key>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::upper_bound(const Key &key) const {
	if (!root)
		return iterator();
	// Follow key's search path down to its last node: right while !(key < *cur), left otherwise.
	Node* cur = root;
	while (Node* next = cur->child[!(key < cur->data)])
		cur = next;
	// The last node on the path is either the result or the element immediately before it;
	// in the latter case step forward one element.
	if (!(key < cur->data))
		moveIterator<1>(cur);
	return iterator(cur);
}
// Returns an iterator to the largest element x with !(key < *x) (i.e. *x <= key), or end() if every
// element is greater than key.
template <class T, class Metric>
template <class Key>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::lastLessOrEqual(const Key &key) const {
	// The answer is the element just before the first element greater than key.
	iterator firstGreater = upper_bound(key);
	if (firstGreater == begin())
		return end();  // nothing precedes it, so no element is <= key
	return previous(firstGreater);
}
// Returns the first x such that metric < sum(begin(), x+1), i.e. the first element at which the running
// sum of metrics (including x itself) exceeds `metric`, or end() if the descent finds no such element.
// Only operator<, operator-, operator+ and value-initialization M() are used on M.
template <class T, class Metric>
template <class M>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::index( M const& metric ) const
{
	M remaining = metric;
	for (Node* node = root; node; ) {
		Node* left = node->child[0];
		if (left && remaining < left->total) {
			node = left;  // the target lies inside the left subtree
			continue;
		}
		// Skip the left subtree and this node. Their combined metric is node->total minus the right subtree's total.
		remaining = remaining - node->total;
		Node* right = node->child[1];
		if (right)
			remaining = remaining + right->total;
		if (remaining < M())
			return iterator(node);  // this node's own metric pushed the running sum past `metric`
		node = right;
	}
	return end();
}
// Returns the metric of the single element at x: its node's subtree total minus the totals of its child
// subtrees. x.i is dereferenced unconditionally, so x must not be end().
template <class T, class Metric>
Metric IndexedSet<T,Metric>::getMetric(typename IndexedSet<T,Metric>::iterator x) const {
	Node* node = x.i;
	Metric own = node->total;
	if (node->child[0])
		own = own - node->child[0]->total;
	if (node->child[1])
		own = own - node->child[1]->total;
	return own;
}
// Returns the sum of the metrics of all elements strictly before `end`, i.e. over [begin(), end).
// When end is end() (end.i is null) this is the whole set's total, or Metric() for an empty set.
template <class T, class Metric>
Metric IndexedSet<T,Metric>::sumTo(typename IndexedSet<T,Metric>::iterator end) const {
	Node* node = end.i;
	if (!node) {
		if (root)
			return root->total;
		return Metric();
	}
	// Everything in node's left subtree precedes it.
	Metric sum = node->child[0] ? node->child[0]->total : Metric();
	// Each ancestor reached from its right child also precedes node, together with its left subtree.
	// Their combined metric is parent->total - child->total.
	for (Node* child = node; child->parent; child = child->parent) {
		Node* parent = child->parent;
		if (parent->child[1] != child)
			continue;
		sum = sum - child->total;
		sum = sum + parent->total;
	}
	return sum;
}
#include "flow.h"
#include "IndexedSet.actor.h"
// Erases every element in [begin, end). The detached subtrees collected by the three-argument erase are
// passed to ISFreeNodes with `true`, unlike eraseAsync, which passes `false` and returns the resulting future.
// NOTE(review): presumably `true` means the nodes are freed before this call returns; confirm in IndexedSet.actor.h.
template <class T, class Metric>
void IndexedSet<T,Metric>::erase(typename IndexedSet<T,Metric>::iterator begin, typename IndexedSet<T,Metric>::iterator end) {
	std::vector<Node*> detached;  // same type as std::vector<IndexedSet<T,Metric>::Node*>
	erase(begin, end, detached);
	ISFreeNodes(detached, true);
}
// Erases every element x with !(*x < begin) && *x < end (keys in [begin, end)). The key bounds are
// converted to iterators with lower_bound, and the work is delegated to the iterator overload.
template <class T, class Metric>
template <class Key>
Future<Void> IndexedSet<T, Metric>::eraseAsync(const Key &begin, const Key &end) {
	iterator first = lower_bound(begin);
	iterator stop = lower_bound(end);
	return eraseAsync(first, stop);
}
// Unlinks every element in [begin, end) from the set immediately. The detached nodes are then handed to
// ISFreeNodes with `false`, and the future it returns is wrapped in uncancellable().
// NOTE(review): presumably `false` defers the actual deallocation, and uncancellable keeps that work running
// even if the caller drops or cancels the returned future; confirm against IndexedSet.actor.h and flow.h.
template <class T, class Metric>
Future<Void> IndexedSet<T, Metric>::eraseAsync(typename IndexedSet<T,Metric>::iterator begin, typename IndexedSet<T,Metric>::iterator end) {
	std::vector<Node*> detached;  // same type as std::vector<IndexedSet<T,Metric>::Node*>
	erase(begin, end, detached);
	return uncancellable(ISFreeNodes(detached, false));
}
#endif