Skip to content

Commit

Permalink
change unordered_set API, no hash, equal fields
Browse files Browse the repository at this point in the history
use static methods, prefixed by the type T.
This enables inlining the hot hashtable parts, and
disallows corrupting the table with changed hash or equal methods.
They really need to be declared and defined statically, just as with C++, where
we need to declare it for the template.

Fixes GH #21
  • Loading branch information
rurban committed Feb 16, 2024
1 parent df057a1 commit 293ee85
Show file tree
Hide file tree
Showing 25 changed files with 244 additions and 208 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,8 @@ List of added, changed. removed features:
* algorithm: Added shuffle, iter_swap, reverse, reverse_range,
lexicographical_compare, is_sorted, is_sorted_until.
Requires now INCLUDE_ALGORITHM
* unordered_set and children: removed hash and equal init args, and fields.
They must be now declared statically beforehand as `T_hash` and `T_equal`.
* array: Added difference, intersection, symmetric_difference, assign_range.
* set: Added includes, includes_range.
* string: Added find_if, find_if_not, find_if_range, find_if_not_range, includes,
Expand Down
2 changes: 1 addition & 1 deletion api.lst
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ ctl/unordered_set.h: max_bucket_count (A *self)
ctl/unordered_set.h: load_factor (A *self)
ctl/unordered_set.h: _reserve (A *self, const size_t new_size)
ctl/unordered_set.h: reserve (A *self, size_t desired_count)
ctl/unordered_set.h: init (size_t (*_hash)(T *), int (*_equal)(T *, T *))
ctl/unordered_set.h: init (void)
ctl/unordered_set.h: init_from (A *copy)
ctl/unordered_set.h: rehash (A *self, size_t desired_count)
ctl/unordered_set.h: _rehash (A *self, size_t count)
Expand Down
97 changes: 48 additions & 49 deletions ctl/bits/integral.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/* Type utilities, to apply default equal, compare, hash methods for intergral types.
/* Type utilities, to apply default equal, compare for integral types.
And hash methods.
See MIT LICENSE.
*/

Expand All @@ -14,7 +15,47 @@ _define_integral_compare(long)
#undef _define_integral_compare
*/

#include <string.h>

#ifndef CTL_HASH_DEFAULTS
#define CTL_HASH_DEFAULTS
static inline uint32_t ctl_int32_hash(uint32_t key)
{
key = ((key >> 16) ^ key) * 0x45d9f3b;
key = ((key >> 16) ^ key) * 0x45d9f3b;
key = (key >> 16) ^ key;
return key;
}
/* FNV1a. Eventually wyhash or o1hash */
static inline size_t ctl_string_hash(const char* key)
{
size_t h;
h = 2166136261u;
for (unsigned i = 0; i < strlen((char *)key); i++)
{
h ^= (unsigned char)key[i];
h *= 16777619;
}
return h;
}

#if defined(POD) && !defined(NOT_INTEGRAL)
static inline int JOIN(T, equal)(T *a, T *b)
{
return *a == *b;
}
#endif

#endif //CTL_HASH_DEFAULTS

#if defined(POD) && !defined(NOT_INTEGRAL)

#ifdef CTL_USET
static inline size_t _JOIN(A, _default_integral_hash)(T *a)
{
return ctl_int32_hash((uint32_t)*a);
}
#endif //USET

static inline int _JOIN(A, _default_integral_compare3)(T *a, T *b)
{
Expand All @@ -34,30 +75,6 @@ static inline int _JOIN(A, _default_integral_equal)(T *a, T *b)
*/
}

static inline size_t _JOIN(A, _default_integral_hash)(T *a)
{
return (size_t)*a;
}

#include <string.h>

#if defined str || defined u8string || defined charp || defined u8ident || defined ucharp

static inline size_t _JOIN(A, _default_string_hash)(T *key)
{
size_t h;
/* FNV1a, not wyhash */
h = 2166136261u;
for (unsigned i = 0; i < strlen((char *)key); i++)
{
h ^= (unsigned char)key[i];
h *= 16777619;
}
return h;
}

#endif

#define CTL_STRINGIFY_HELPER(n) #n
#define CTL_STRINGIFY(n) CTL_STRINGIFY_HELPER(n)
#define _strEQcc(s1c, s2c) !strcmp(s1c "", s2c "")
Expand All @@ -83,47 +100,29 @@ static inline bool _JOIN(A, _type_is_integral)(void)
_strEQcc(CTL_STRINGIFY(T), "llong");
}

// not C++
#ifndef __cplusplus
#define __set_str_hash(self, t) \
{ \
typeof(t) tmp = (x); \
if (__builtin_types_compatible_p(typeof(t), char *)) \
self->hash = _JOIN(A, _default_string_hash); \
else if (__builtin_types_compatible_p(typeof(t), unsigned char *)) \
self->hash = _JOIN(A, _default_string_hash); \
}
#else
#define __set_str_hash(self, t) self->hash = _JOIN(A, _default_string_hash)
#endif

static inline void _JOIN(A, _set_default_methods)(A *self)
{
#if !defined CTL_STR
#if defined str || defined u8string || defined charp || defined u8ident || defined ucharp
{
#ifdef CTL_USET
if (!self->hash)
__set_str_hash(self, T);
#else
#ifndef CTL_USET
if (!self->compare)
self->compare = str_key_compare;
#endif
if (!self->equal)
self->equal = str_equal;
#endif
}
else
#endif
#endif
#ifdef CTL_USET
if (!self->hash)
self->hash = _JOIN(A, _default_integral_hash);
#else
#ifndef CTL_USET
if (!self->compare)
self->compare = _JOIN(A, _default_integral_compare);
#endif
if (!self->equal)
self->equal = _JOIN(A, _default_integral_equal);
#else
(void)self;
#endif
}

#else
Expand Down
62 changes: 26 additions & 36 deletions ctl/unordered_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ typedef struct A
float max_load_factor;
void (*free)(T *);
T (*copy)(T *);
size_t (*hash)(T *);
int (*equal)(T *, T *);
#if CTL_USET_SECURITY_COLLCOUNTING == 4
bool is_sorted_vector;
#elif CTL_USET_SECURITY_COLLCOUNTING == 5
Expand All @@ -144,11 +142,11 @@ static inline size_t JOIN(A, bucket_count)(A *self)
static inline size_t JOIN(I, index)(A *self, T value)
{
#ifdef CTL_USET_GROWTH_POWER2
return self->hash(&value) & self->bucket_max;
return JOIN(T, hash)(&value) & self->bucket_max;
#elif __WORDSIZE == 127
return ((uint64_t) self->hash(&value) * ((uint64_t) self->bucket_max + 1)) >> 32;
return ((uint64_t) JOIN(T, hash)(&value) * ((uint64_t) self->bucket_max + 1)) >> 32;
#else
return self->hash(&value) % (self->bucket_max + 1);
return JOIN(T, hash)(&value) % (self->bucket_max + 1);
#endif
}

Expand Down Expand Up @@ -322,10 +320,12 @@ JOIN(I, range)(A* container, I* begin, I* end)
}
*/

// needed for algorithm
static inline int JOIN(A, _equal)(A *self, T *a, T *b)
{
ASSERT(self->equal || !"equal undefined");
return self->equal(a, b);
//ASSERT(JOIN(T, equal) || !"equal undefined");
(void)self;
return JOIN(T, equal)(a, b);
}

static inline A JOIN(A, init_from)(A *copy);
Expand Down Expand Up @@ -518,15 +518,15 @@ static inline B **JOIN(A, _bucket_hash)(A *self, size_t hash)
static inline B **JOIN(A, _bucket)(A *self, T value)
{
const size_t hash = JOIN(I, index)(self, value);
//LOG ("_bucket %lx %% %lu => %zu\n", self->hash(&value), self->bucket_max + 1, hash);
//LOG ("_bucket %lx %% %lu => %zu\n", JOIN(T, hash)(&value), self->bucket_max + 1, hash);
return &self->buckets[hash];
}
#endif

static inline size_t JOIN(A, bucket)(A *self, T value)
{
const size_t hash = JOIN(I, index)(self, value);
//LOG ("bucket %lx %% %lu => %zu\n", self->hash(&value), self->bucket_max + 1, hash);
//LOG ("bucket %lx %% %lu => %zu\n", JOIN(T, hash)(&value), self->bucket_max + 1, hash);
return hash;
}

Expand Down Expand Up @@ -613,12 +613,10 @@ static inline void JOIN(A, reserve)(A *self, size_t desired_count)
JOIN(A, _rehash)(self, new_size);
}

static inline A JOIN(A, init)(size_t (*_hash)(T *), int (*_equal)(T *, T *))
static inline A JOIN(A, init)(void)
{
static A zero;
A self = zero;
self.hash = _hash;
self.equal = _equal;
#ifdef POD
self.copy = JOIN(A, implicit_copy);
_JOIN(A, _set_default_methods)(&self);
Expand All @@ -633,24 +631,16 @@ static inline A JOIN(A, init)(size_t (*_hash)(T *), int (*_equal)(T *, T *))

static inline A JOIN(A, init_from)(A *copy)
{
static A zero;
A self = zero;
#ifdef POD
self.copy = JOIN(A, implicit_copy);
#else
self.free = JOIN(T, free);
self.copy = JOIN(T, copy);
#endif
self.hash = copy->hash;
self.equal = copy->equal;
A self = JOIN(A, init)();
JOIN(A, _reserve)(&self, copy->bucket_max + 1);
return self;
}

static inline void JOIN(A, rehash)(A *self, size_t desired_count)
{
if (desired_count == (self->bucket_max + 1))
return;
A rehashed = JOIN(A, init)(self->hash, self->equal);
A rehashed = JOIN(A, init)();
JOIN(A, reserve)(&rehashed, desired_count);
if (LIKELY(self->buckets && self->size)) // if desired_count 0
{
Expand Down Expand Up @@ -681,7 +671,7 @@ static inline void JOIN(A, _rehash)(A *self, size_t count)
// we do allow shrink here
if (count == self->bucket_max + 1)
return;
A rehashed = JOIN(A, init)(self->hash, self->equal);
A rehashed = JOIN(A, init)();
//LOG("_rehash %zu => %zu\n", self->size, count);
JOIN(A, _reserve)(&rehashed, count);

Expand Down Expand Up @@ -714,7 +704,7 @@ static inline B *JOIN(A, find_node)(A *self, T value)
if (self->size)
{
#ifdef CTL_USET_CACHED_HASH
size_t hash = self->hash(&value);
size_t hash = JOIN(T, hash)(&value);
B **buckets = JOIN(A, _bucket_hash)(self, hash);
#else
B **buckets = JOIN(A, _bucket)(self, value);
Expand All @@ -739,7 +729,7 @@ static inline B *JOIN(A, find_node)(A *self, T value)
if (n->cached_hash != hash)
continue;
#endif
if (self->equal(&value, &n->value))
if (JOIN(T, equal)(&value, &n->value))
{
#if 0 // not yet
// speedup subsequent read accesses?
Expand Down Expand Up @@ -802,7 +792,7 @@ static inline B **JOIN(A, push_cached)(A *self, T *value)
#endif

#ifdef CTL_USET_CACHED_HASH
size_t hash = self->hash(value);
size_t hash = JOIN(T, hash)(value);
B **buckets = JOIN(A, _bucket_hash)(self, hash);
JOIN(B, push)(buckets, JOIN(B, init_cached)(*value, hash));
#else
Expand Down Expand Up @@ -899,7 +889,7 @@ static inline I JOIN(A, emplace_hint)(I *pos, T *value)
if (!JOIN(I, done)(pos))
{
#ifdef CTL_USET_CACHED_HASH
size_t hash = self->hash(value);
size_t hash = JOIN(T, hash)(value);
B **buckets = JOIN(A, _bucket_hash)(self, hash);
#else
B **buckets = JOIN(A, _bucket)(self, *value);
Expand All @@ -924,7 +914,7 @@ static inline I JOIN(A, emplace_hint)(I *pos, T *value)
if (n->cached_hash != hash)
continue;
#endif
if (self->equal(value, &n->value))
if (JOIN(T, equal)(value, &n->value))
{
FREE_VALUE(self, *value);
return JOIN(I, iter)(self, n);
Expand Down Expand Up @@ -1060,7 +1050,7 @@ static inline void JOIN(A, _linked_erase)(A *self, B **bucket, B *n, B *prev, B
static inline void JOIN(A, erase)(A *self, T value)
{
#ifdef CTL_USET_CACHED_HASH
size_t hash = self->hash(&value);
size_t hash = JOIN(T, hash)(&value);
B **buckets = JOIN(A, _bucket_hash)(self, hash);
#else
B **buckets = JOIN(A, _bucket)(self, value);
Expand All @@ -1078,7 +1068,7 @@ static inline void JOIN(A, erase)(A *self, T value)
continue;
}
#endif
if (self->equal(&value, &n->value))
if (JOIN(T, equal)(&value, &n->value))
{
JOIN(A, _linked_erase)(self, buckets, n, prev, next);
break;
Expand Down Expand Up @@ -1115,7 +1105,7 @@ static inline size_t JOIN(A, erase_if)(A *self, int (*_match)(T *))
static inline A JOIN(A, copy)(A *self)
{
// LOG ("copy\norig size: %lu\n", self->size);
A other = JOIN(A, init)(self->hash, self->equal);
A other = JOIN(A, init)();
JOIN(A, _reserve)(&other, self->bucket_max + 1);
foreach (A, self, it)
{
Expand Down Expand Up @@ -1154,7 +1144,7 @@ static inline void JOIN(A, erase_generic)(A* self, GI *range)

static inline A JOIN(A, union)(A *a, A *b)
{
A self = JOIN(A, init)(a->hash, a->equal);
A self = JOIN(A, init)();
JOIN(A, _reserve)(&self, 1 + MAX(a->bucket_max, b->bucket_max));
foreach (A, a, it1)
JOIN(A, insert)(&self, self.copy(it1.ref));
Expand Down Expand Up @@ -1182,7 +1172,7 @@ static inline A JOIN(A, union_range)(I *r1, GI *r2)

static inline A JOIN(A, intersection)(A *a, A *b)
{
A self = JOIN(A, init)(a->hash, a->equal);
A self = JOIN(A, init)();
foreach (A, a, it)
if (JOIN(A, find_node)(b, *it.ref))
JOIN(A, insert)(&self, self.copy(it.ref));
Expand All @@ -1192,7 +1182,7 @@ static inline A JOIN(A, intersection)(A *a, A *b)
static inline A JOIN(A, intersection_range)(I *r1, GI *r2)
{
A *a = r1->container;
A self = JOIN(A, init)(a->hash, a->equal);
A self = JOIN(A, init)();
void (*next2)(struct I*) = r2->vtable.next;
T* (*ref2)(struct I*) = r2->vtable.ref;
int (*done2)(struct I*) = r2->vtable.done;
Expand All @@ -1214,7 +1204,7 @@ static inline A JOIN(A, intersection_range)(I *r1, GI *r2)

static inline A JOIN(A, difference)(A *a, A *b)
{
A self = JOIN(A, init)(a->hash, a->equal);
A self = JOIN(A, init)();
foreach (A, a, it)
if (!JOIN(A, find_node)(b, *it.ref))
JOIN(A, insert)(&self, self.copy(it.ref));
Expand Down
Loading

0 comments on commit 293ee85

Please sign in to comment.