nostrdb: filter: add initial custom filtering logic

This adds some helpers for adding custom filtering logic
to nostr filters. These are just a callback and a closure.
There can only be one custom callback filter per filter.

Fixes: https://github.com/damus-io/nostrdb/issues/33
Signed-off-by: William Casarin <jb55@jb55.com>
This commit is contained in:
William Casarin
2025-04-08 16:29:16 -07:00
committed by Daniel D’Aquino
parent 0b8090cb28
commit 64c16e7cc8
2 changed files with 88 additions and 1 deletions

View File

@@ -671,6 +671,12 @@ ndb_filter_elements_data(const struct ndb_filter *filter, int offset)
return data; return data;
} }
// Fetch the custom callback record stored for a custom filter field.
//
// Only the first element slot of `els` is consulted: its value is handed
// to ndb_filter_elements_data() and the returned data pointer is
// reinterpreted as a `struct ndb_filter_custom`.
struct ndb_filter_custom *
ndb_filter_get_custom_element(const struct ndb_filter *filter, const struct ndb_filter_elements *els)
{
	struct ndb_filter_custom *custom;

	custom = (struct ndb_filter_custom *)
		ndb_filter_elements_data(filter, els->elements[0]);

	return custom;
}
unsigned char * unsigned char *
ndb_filter_get_id_element(const struct ndb_filter *filter, const struct ndb_filter_elements *els, int index) ndb_filter_get_id_element(const struct ndb_filter *filter, const struct ndb_filter_elements *els, int index)
{ {
@@ -754,6 +760,7 @@ static const char *ndb_filter_field_name(enum ndb_filter_fieldtype field)
case NDB_FILTER_LIMIT: return "limit"; case NDB_FILTER_LIMIT: return "limit";
case NDB_FILTER_SEARCH: return "search"; case NDB_FILTER_SEARCH: return "search";
case NDB_FILTER_RELAYS: return "relays"; case NDB_FILTER_RELAYS: return "relays";
case NDB_FILTER_CUSTOM: return "custom";
} }
return "unknown"; return "unknown";
@@ -821,6 +828,10 @@ static int ndb_filter_add_element(struct ndb_filter *filter, union ndb_filter_el
offset = filter->data_buf.p - filter->data_buf.start; offset = filter->data_buf.p - filter->data_buf.start;
switch (current->field.type) { switch (current->field.type) {
case NDB_FILTER_CUSTOM:
if (!cursor_push(&filter->data_buf, (unsigned char *)&el, sizeof(el)))
return 0;
break;
case NDB_FILTER_IDS: case NDB_FILTER_IDS:
case NDB_FILTER_AUTHORS: case NDB_FILTER_AUTHORS:
if (!cursor_push(&filter->data_buf, (unsigned char *)el.id, 32)) if (!cursor_push(&filter->data_buf, (unsigned char *)el.id, 32))
@@ -861,6 +872,7 @@ static int ndb_filter_add_element(struct ndb_filter *filter, union ndb_filter_el
case NDB_ELEMENT_INT: case NDB_ELEMENT_INT:
// ints are not allowed in tag filters // ints are not allowed in tag filters
case NDB_ELEMENT_UNKNOWN: case NDB_ELEMENT_UNKNOWN:
case NDB_ELEMENT_CUSTOM:
return 0; return 0;
} }
// push a pointer of the string in the databuf as an element // push a pointer of the string in the databuf as an element
@@ -925,6 +937,7 @@ int ndb_filter_add_str_element_len(struct ndb_filter *filter, const char *str, i
case NDB_FILTER_IDS: case NDB_FILTER_IDS:
case NDB_FILTER_AUTHORS: case NDB_FILTER_AUTHORS:
case NDB_FILTER_KINDS: case NDB_FILTER_KINDS:
case NDB_FILTER_CUSTOM:
return 0; return 0;
case NDB_FILTER_SEARCH: case NDB_FILTER_SEARCH:
if (current->count == 1) { if (current->count == 1) {
@@ -951,6 +964,41 @@ int ndb_filter_add_str_element(struct ndb_filter *filter, const char *str)
return ndb_filter_add_str_element_len(filter, str, strlen(str)); return ndb_filter_add_str_element_len(filter, str, strlen(str));
} }
// Attach a custom callback to the filter field currently being built.
//
// filter: the filter under construction; its current field must be of
//         type NDB_FILTER_CUSTOM.
// cb:     predicate that the matcher calls as cb(ctx, note). Must not be
//         NULL, since the matcher invokes it without checking.
// ctx:    opaque pointer stored next to `cb` and passed back as the
//         callback's first argument.
//
// Returns 0 if `cb` is NULL, there is no current field, the current field
// is not a custom field, the field already holds a custom element, or the
// element type could not be set. Otherwise returns the result of
// ndb_filter_add_element().
int ndb_filter_add_custom_filter_element(struct ndb_filter *filter, ndb_filter_callback_fn *cb, void *ctx)
{
	union ndb_filter_element el;
	struct ndb_filter_elements *current;

	// a NULL callback would be dereferenced when the filter is matched
	if (cb == NULL)
		return 0;

	if (!(current = ndb_filter_current_element(filter)))
		return 0;

	switch (current->field.type) {
	case NDB_FILTER_CUSTOM:
		break;
	case NDB_FILTER_IDS:
	case NDB_FILTER_AUTHORS:
	case NDB_FILTER_TAGS:
	case NDB_FILTER_SEARCH:
	case NDB_FILTER_RELAYS:
	case NDB_FILTER_KINDS:
	case NDB_FILTER_SINCE:
	case NDB_FILTER_UNTIL:
	case NDB_FILTER_LIMIT:
		return 0;
	}

	// only the first element of a custom field is ever read back, so a
	// second callback would be silently ignored; reject it instead
	if (current->count != 0)
		return 0;

	if (!ndb_filter_set_elem_type(filter, NDB_ELEMENT_CUSTOM))
		return 0;

	el.custom_filter.cb = cb;
	el.custom_filter.ctx = ctx;

	return ndb_filter_add_element(filter, el);
}
int ndb_filter_add_int_element(struct ndb_filter *filter, uint64_t integer) int ndb_filter_add_int_element(struct ndb_filter *filter, uint64_t integer)
{ {
union ndb_filter_element el; union ndb_filter_element el;
@@ -964,6 +1012,7 @@ int ndb_filter_add_int_element(struct ndb_filter *filter, uint64_t integer)
case NDB_FILTER_TAGS: case NDB_FILTER_TAGS:
case NDB_FILTER_SEARCH: case NDB_FILTER_SEARCH:
case NDB_FILTER_RELAYS: case NDB_FILTER_RELAYS:
case NDB_FILTER_CUSTOM:
return 0; return 0;
case NDB_FILTER_KINDS: case NDB_FILTER_KINDS:
case NDB_FILTER_SINCE: case NDB_FILTER_SINCE:
@@ -996,6 +1045,7 @@ int ndb_filter_add_id_element(struct ndb_filter *filter, const unsigned char *id
case NDB_FILTER_KINDS: case NDB_FILTER_KINDS:
case NDB_FILTER_SEARCH: case NDB_FILTER_SEARCH:
case NDB_FILTER_RELAYS: case NDB_FILTER_RELAYS:
case NDB_FILTER_CUSTOM:
return 0; return 0;
case NDB_FILTER_IDS: case NDB_FILTER_IDS:
case NDB_FILTER_AUTHORS: case NDB_FILTER_AUTHORS:
@@ -1079,6 +1129,7 @@ static int ndb_tag_filter_matches(struct ndb_filter *filter,
case NDB_ELEMENT_INT: case NDB_ELEMENT_INT:
// int elements int tag queries are not supported // int elements int tag queries are not supported
case NDB_ELEMENT_UNKNOWN: case NDB_ELEMENT_UNKNOWN:
case NDB_ELEMENT_CUSTOM:
return 0; return 0;
} }
} }
@@ -1166,6 +1217,7 @@ static int ndb_filter_matches_with(struct ndb_filter *filter,
int i, j; int i, j;
struct ndb_filter_elements *els; struct ndb_filter_elements *els;
struct search_id_state state; struct search_id_state state;
struct ndb_filter_custom *custom;
state.filter = filter; state.filter = filter;
@@ -1244,6 +1296,12 @@ static int ndb_filter_matches_with(struct ndb_filter *filter,
// the search index will be walked for these kinds // the search index will be walked for these kinds
// of queries. // of queries.
continue; continue;
case NDB_FILTER_CUSTOM:
custom = ndb_filter_get_custom_element(filter, els);
if (custom->cb(custom->ctx, note))
continue;
break;
case NDB_FILTER_LIMIT: case NDB_FILTER_LIMIT:
cont: cont:
continue; continue;
@@ -1297,6 +1355,7 @@ static int ndb_filter_field_eq(struct ndb_filter *a_filt,
const char *a_str, *b_str; const char *a_str, *b_str;
unsigned char *a_id, *b_id; unsigned char *a_id, *b_id;
uint64_t a_int, b_int; uint64_t a_int, b_int;
struct ndb_filter_custom *a_custom, *b_custom;
if (a_field->count != b_field->count) if (a_field->count != b_field->count)
return 0; return 0;
@@ -1318,6 +1377,11 @@ static int ndb_filter_field_eq(struct ndb_filter *a_filt,
for (i = 0; i < a_field->count; i++) { for (i = 0; i < a_field->count; i++) {
switch (a_field->field.elem_type) { switch (a_field->field.elem_type) {
case NDB_ELEMENT_CUSTOM:
a_custom = ndb_filter_get_custom_element(a_filt, a_field);
b_custom = ndb_filter_get_custom_element(b_filt, b_field);
if (memcmp(a_custom, b_custom, sizeof(*a_custom)))
return 0;
case NDB_ELEMENT_UNKNOWN: case NDB_ELEMENT_UNKNOWN:
return 0; return 0;
case NDB_ELEMENT_STRING: case NDB_ELEMENT_STRING:
@@ -1373,6 +1437,7 @@ void ndb_filter_end_field(struct ndb_filter *filter)
// TODO: generic tag search sorting // TODO: generic tag search sorting
break; break;
case NDB_FILTER_SINCE: case NDB_FILTER_SINCE:
case NDB_FILTER_CUSTOM:
case NDB_FILTER_UNTIL: case NDB_FILTER_UNTIL:
case NDB_FILTER_LIMIT: case NDB_FILTER_LIMIT:
case NDB_FILTER_SEARCH: case NDB_FILTER_SEARCH:
@@ -6412,6 +6477,9 @@ static int cursor_push_json_elem_array(struct cursor *cur,
for (i = 0; i < elems->count; i++) { for (i = 0; i < elems->count; i++) {
switch (elems->field.elem_type) { switch (elems->field.elem_type) {
case NDB_ELEMENT_CUSTOM:
// can't serialize custom functions
break;
case NDB_ELEMENT_STRING: case NDB_ELEMENT_STRING:
str = ndb_filter_get_string_element(filter, elems, i); str = ndb_filter_get_string_element(filter, elems, i);
if (!cursor_push_jsonstr(cur, str)) if (!cursor_push_jsonstr(cur, str))
@@ -6464,6 +6532,9 @@ int ndb_filter_json(const struct ndb_filter *filter, char *buf, int buflen)
for (i = 0; i < filter->num_elements; i++) { for (i = 0; i < filter->num_elements; i++) {
elems = ndb_filter_get_elements(filter, i); elems = ndb_filter_get_elements(filter, i);
switch (elems->field.type) { switch (elems->field.type) {
case NDB_FILTER_CUSTOM:
// nothing to encode these as
break;
case NDB_FILTER_IDS: case NDB_FILTER_IDS:
if (!cursor_push_str(c, "\"ids\":")) if (!cursor_push_str(c, "\"ids\":"))
return 0; return 0;
@@ -7325,6 +7396,9 @@ static int ndb_filter_parse_json(struct ndb_json_parser *parser,
// we parsed a top-level field // we parsed a top-level field
switch(field) { switch(field) {
case NDB_FILTER_CUSTOM:
// can't really parse these yet
break;
case NDB_FILTER_AUTHORS: case NDB_FILTER_AUTHORS:
case NDB_FILTER_IDS: case NDB_FILTER_IDS:
if (!ndb_filter_parse_json_ids(parser, filter)) { if (!ndb_filter_parse_json_ids(parser, filter)) {

View File

@@ -2,6 +2,7 @@
#define NOSTRDB_H #define NOSTRDB_H
#include <inttypes.h> #include <inttypes.h>
#include <stdbool.h>
#include "win.h" #include "win.h"
#include "cursor.h" #include "cursor.h"
@@ -48,6 +49,7 @@ struct ndb_t {
}; };
struct ndb_str { struct ndb_str {
// NDB_PACKED_STR, NDB_PACKED_ID
unsigned char flag; unsigned char flag;
union { union {
const char *str; const char *str;
@@ -163,8 +165,9 @@ enum ndb_filter_fieldtype {
NDB_FILTER_LIMIT = 7, NDB_FILTER_LIMIT = 7,
NDB_FILTER_SEARCH = 8, NDB_FILTER_SEARCH = 8,
NDB_FILTER_RELAYS = 9, NDB_FILTER_RELAYS = 9,
NDB_FILTER_CUSTOM = 10,
}; };
#define NDB_NUM_FILTERS 7 #define NDB_NUM_FILTERS 10
// when matching generic tags, we need to know if we're dealing with // when matching generic tags, we need to know if we're dealing with
// a pointer to a 32-byte ID or a null terminated string // a pointer to a 32-byte ID or a null terminated string
@@ -173,6 +176,7 @@ enum ndb_generic_element_type {
NDB_ELEMENT_STRING = 1, NDB_ELEMENT_STRING = 1,
NDB_ELEMENT_ID = 2, NDB_ELEMENT_ID = 2,
NDB_ELEMENT_INT = 3, NDB_ELEMENT_INT = 3,
NDB_ELEMENT_CUSTOM = 4,
}; };
enum ndb_search_order { enum ndb_search_order {
@@ -250,10 +254,18 @@ struct ndb_filter_string {
int len; int len;
}; };
// Signature of a user-supplied predicate for custom filter fields. The
// filter matcher calls it with the stored context pointer and the note
// under test. A true result lets matching continue with the next field;
// a false result presumably rejects the note (the code after the switch
// is not visible here; confirm in ndb_filter_matches_with).
typedef bool ndb_filter_callback_fn(void *, struct ndb_note *);
// A custom filter element: a callback paired with its context pointer.
// This record is copied by value into the filter's data buffer and is
// compared bytewise when two filters are checked for equality.
struct ndb_filter_custom {
// opaque pointer handed back to `cb` as its first argument
void *ctx;
// predicate invoked for each candidate note
ndb_filter_callback_fn *cb;
};
union ndb_filter_element { union ndb_filter_element {
struct ndb_filter_string string; struct ndb_filter_string string;
const unsigned char *id; const unsigned char *id;
uint64_t integer; uint64_t integer;
struct ndb_filter_custom custom_filter;
}; };
struct ndb_filter_field { struct ndb_filter_field {
@@ -551,6 +563,7 @@ int ndb_filter_init_with(struct ndb_filter *filter, int pages);
int ndb_filter_add_id_element(struct ndb_filter *, const unsigned char *id); int ndb_filter_add_id_element(struct ndb_filter *, const unsigned char *id);
int ndb_filter_add_int_element(struct ndb_filter *, uint64_t integer); int ndb_filter_add_int_element(struct ndb_filter *, uint64_t integer);
int ndb_filter_add_str_element(struct ndb_filter *, const char *str); int ndb_filter_add_str_element(struct ndb_filter *, const char *str);
int ndb_filter_add_custom_filter_element(struct ndb_filter *filter, ndb_filter_callback_fn *cb, void *ctx);
int ndb_filter_eq(const struct ndb_filter *, const struct ndb_filter *); int ndb_filter_eq(const struct ndb_filter *, const struct ndb_filter *);
/// is `a` a subset of `b` /// is `a` a subset of `b`