btllib
counting_bloom_filter.hpp
1#ifndef BTLLIB_COUNTING_BLOOM_FILTER_HPP
2#define BTLLIB_COUNTING_BLOOM_FILTER_HPP
3
4#include "btllib/bloom_filter.hpp"
5#include "btllib/nthash.hpp"
6#include "btllib/status.hpp"
7
8#include "cpptoml.h"
9
10#include <atomic>
11#include <cmath>
12#include <cstdint>
13#include <fstream>
14#include <limits>
15#include <memory>
16#include <string>
17#include <vector>
18
19namespace btllib {
20
/// Bracketed header tag of a CountingBloomFilter file. save() strips the
/// brackets to name the header table; is_bloom_file() checks files against it.
static const char* const COUNTING_BLOOM_FILTER_SIGNATURE{
  "[BTLCountingBloomFilter_v5]"
};

/// Bracketed header tag of a KmerCountingBloomFilter file, used the same way.
static const char* const KMER_COUNTING_BLOOM_FILTER_SIGNATURE{
  "[BTLKmerCountingBloomFilter_v5]"
};

// Declared ahead of its definition so CountingBloomFilter can befriend it.
template<typename T>
class KmerCountingBloomFilter;
28
34template<typename T>
36{
37
38public:
41
49 CountingBloomFilter(size_t bytes,
50 unsigned hash_num,
51 std::string hash_fn = "");
52
58 explicit CountingBloomFilter(const std::string& path);
59
62
63 CountingBloomFilter& operator=(const CountingBloomFilter&) = delete;
64 CountingBloomFilter& operator=(CountingBloomFilter&&) = delete;
65
72 void insert(const uint64_t* hashes);
73
79 void insert(const std::vector<uint64_t>& hashes) { insert(hashes.data()); }
80
89 T contains(const uint64_t* hashes) const;
90
98 T contains(const std::vector<uint64_t>& hashes) const
99 {
100 return contains(hashes.data());
101 }
102
111 T contains_insert(const uint64_t* hashes);
112
120 T contains_insert(const std::vector<uint64_t>& hashes)
121 {
122 return contains_insert(hashes.data());
123 }
124
134 T insert_contains(const uint64_t* hashes);
135
143 T insert_contains(const std::vector<uint64_t>& hashes)
144 {
145 return insert_contains(hashes.data());
146 }
147
159 T insert_thresh_contains(const uint64_t* hashes, T threshold);
160
172 T insert_thresh_contains(const std::vector<uint64_t>& hashes,
173 const T threshold)
174 {
175 return insert_thresh_contains(hashes.data(), threshold);
176 }
177
189 T contains_insert_thresh(const uint64_t* hashes, T threshold);
190
200 T contains_insert_thresh(const std::vector<uint64_t>& hashes,
201 const T threshold)
202 {
203 return contains_insert_thresh(hashes.data(), threshold);
204 }
205
207 size_t get_bytes() const { return bytes; }
209 uint64_t get_pop_cnt() const;
211 double get_occupancy() const;
213 unsigned get_hash_num() const { return hash_num; }
215 double get_fpr() const;
217 const std::string& get_hash_fn() const { return hash_fn; }
218
224 void save(const std::string& path);
225
231 static bool is_bloom_file(const std::string& path)
232 {
233 return btllib::BloomFilter::check_file_signature(
234 path, COUNTING_BLOOM_FILTER_SIGNATURE);
235 }
236
237private:
238 CountingBloomFilter(const std::shared_ptr<BloomFilterInitializer>& bfi);
239
240 void insert(const uint64_t* hashes, T min_val);
241
242 friend class KmerCountingBloomFilter<T>;
243
244 size_t bytes = 0;
245 size_t array_size = 0;
246 unsigned hash_num = 0;
247 std::string hash_fn;
248 std::unique_ptr<std::atomic<T>[]> array;
249};
250
256template<typename T>
258{
259
260public:
263
271 KmerCountingBloomFilter(size_t bytes, unsigned hash_num, unsigned k);
272
278 explicit KmerCountingBloomFilter(const std::string& path);
279
282
283 KmerCountingBloomFilter& operator=(const KmerCountingBloomFilter&) = delete;
285
292 void insert(const char* seq, size_t seq_len);
293
299 void insert(const std::string& seq) { insert(seq.c_str(), seq.size()); }
300
307 void insert(const uint64_t* hashes) { counting_bloom_filter.insert(hashes); }
308
314 void insert(const std::vector<uint64_t>& hashes)
315 {
316 counting_bloom_filter.insert(hashes);
317 }
318
327 uint64_t contains(const char* seq, size_t seq_len) const;
328
336 uint64_t contains(const std::string& seq) const
337 {
338 return contains(seq.c_str(), seq.size());
339 }
340
349 T contains(const uint64_t* hashes) const
350 {
351 return counting_bloom_filter.contains(hashes);
352 }
353
361 T contains(const std::vector<uint64_t>& hashes) const
362 {
363 return counting_bloom_filter.contains(hashes);
364 }
365
374 T contains_insert(const char* seq, size_t seq_len);
375
383 T contains_insert(const std::string& seq)
384 {
385 return contains_insert(seq.c_str(), seq.size());
386 }
387
396 T contains_insert(const uint64_t* hashes)
397 {
398 return counting_bloom_filter.contains_insert(hashes);
399 }
400
408 T contains_insert(const std::vector<uint64_t>& hashes)
409 {
410 return counting_bloom_filter.contains_insert(hashes);
411 }
412
421 T insert_contains(const char* seq, size_t seq_len);
422
430 T insert_contains(const std::string& seq)
431 {
432 return insert_contains(seq.c_str(), seq.size());
433 }
434
444 T insert_contains(const uint64_t* hashes)
445 {
446 return counting_bloom_filter.insert_contains(hashes);
447 }
448
456 T insert_contains(const std::vector<uint64_t>& hashes)
457 {
458 return counting_bloom_filter.insert_contains(hashes);
459 }
460
471 T insert_thresh_contains(const char* seq, size_t seq_len, T threshold);
472
482 T insert_thresh_contains(const std::string& seq, const T threshold)
483 {
484 return insert_thresh_contains(seq.c_str(), seq.size(), threshold);
485 }
486
498 T insert_thresh_contains(const uint64_t* hashes, const T threshold)
499 {
500 return counting_bloom_filter.insert_thresh_contains(hashes, threshold);
501 }
502
514 T insert_thresh_contains(const std::vector<uint64_t>& hashes,
515 const T threshold)
516 {
517 return counting_bloom_filter.insert_thresh_contains(hashes, threshold);
518 }
519
530 T contains_insert_thresh(const char* seq, size_t seq_len, T threshold);
531
541 T contains_insert_thresh(const std::string& seq, const T threshold)
542 {
543 return contains_insert_thresh(seq.c_str(), seq.size(), threshold);
544 }
545
557 T contains_insert_thresh(const uint64_t* hashes, const T threshold)
558 {
559 return counting_bloom_filter.contains_insert_thresh(hashes, threshold);
560 }
561
571 T contains_insert_thresh(const std::vector<uint64_t>& hashes,
572 const T threshold)
573 {
574 return counting_bloom_filter.contains_insert_thresh(hashes, threshold);
575 }
576
578 size_t get_bytes() const { return counting_bloom_filter.get_bytes(); }
580 uint64_t get_pop_cnt() const { return counting_bloom_filter.get_pop_cnt(); }
582 double get_occupancy() const { return counting_bloom_filter.get_occupancy(); }
584 unsigned get_hash_num() const { return counting_bloom_filter.get_hash_num(); }
586 double get_fpr() const { return counting_bloom_filter.get_fpr(); }
588 unsigned get_k() const { return k; }
590 const std::string& get_hash_fn() const
591 {
592 return counting_bloom_filter.get_hash_fn();
593 }
596 {
597 return counting_bloom_filter;
598 }
599
605 void save(const std::string& path);
606
613 static bool is_bloom_file(const std::string& path)
614 {
615 return btllib::BloomFilter::check_file_signature(
616 path, KMER_COUNTING_BLOOM_FILTER_SIGNATURE);
617 }
618
619private:
620 KmerCountingBloomFilter(const std::shared_ptr<BloomFilterInitializer>& bfi);
621
622 unsigned k = 0;
623 CountingBloomFilter<T> counting_bloom_filter;
624};
625
626using CountingBloomFilter8 = CountingBloomFilter<uint8_t>;
627using CountingBloomFilter16 = CountingBloomFilter<uint16_t>;
628using CountingBloomFilter32 = CountingBloomFilter<uint32_t>;
629
630using KmerCountingBloomFilter8 = KmerCountingBloomFilter<uint8_t>;
631using KmerCountingBloomFilter16 = KmerCountingBloomFilter<uint16_t>;
632using KmerCountingBloomFilter32 = KmerCountingBloomFilter<uint32_t>;
633
634template<typename T>
636 unsigned hash_num,
637 std::string hash_fn)
638 : bytes(
639 size_t(std::ceil(double(bytes) / sizeof(uint64_t)) * sizeof(uint64_t)))
640 , array_size(get_bytes() / sizeof(array[0]))
641 , hash_num(hash_num)
642 , hash_fn(std::move(hash_fn))
643 , array(new std::atomic<T>[array_size])
644{
645 check_error(bytes == 0, "CountingBloomFilter: memory budget must be >0!");
646 check_error(hash_num == 0,
647 "CountingBloomFilter: number of hash values must be >0!");
649 hash_num > MAX_HASH_VALUES,
650 "CountingBloomFilter: number of hash values cannot be over 1024!");
651 check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
652 "Atomic primitives take extra memory. CountingBloomFilter will "
653 "have less than " +
654 std::to_string(bytes) + " for bit array.");
655 std::memset((void*)array.get(), 0, array_size * sizeof(array[0]));
656}
657
658/*
659 * Assumes min_count is not std::numeric_limits<T>::max()
660 */
661template<typename T>
662inline void
663CountingBloomFilter<T>::insert(const uint64_t* hashes, T min_val)
664{
665 // Update flag to track if increment is done on at least one counter
666 bool update_done = false;
667 T new_val, tmp_min_val;
668 while (true) {
669 new_val = min_val + 1;
670 for (size_t i = 0; i < hash_num; ++i) {
671 tmp_min_val = min_val;
672 update_done = array[hashes[i] % array_size].compare_exchange_strong(
673 tmp_min_val, new_val);
674 }
675 if (update_done ||
676 (min_val = contains(hashes)) == std::numeric_limits<T>::max()) {
677 break;
678 }
679 }
680}
681
682template<typename T>
683inline void
684CountingBloomFilter<T>::insert(const uint64_t* hashes)
685{
686 contains_insert(hashes);
687}
688
689template<typename T>
690inline T
691CountingBloomFilter<T>::contains(const uint64_t* hashes) const
692{
693 T min = array[hashes[0] % array_size];
694 for (size_t i = 1; i < hash_num; ++i) {
695 const size_t idx = hashes[i] % array_size;
696 if (array[idx] < min) {
697 min = array[idx];
698 }
699 }
700 return min;
701}
702
703template<typename T>
704inline T
706{
707 const auto count = contains(hashes);
708 if (count < std::numeric_limits<T>::max()) {
709 insert(hashes, count);
710 }
711 return count;
712}
713
714template<typename T>
715inline T
717{
718 const auto count = contains(hashes);
719 if (count < std::numeric_limits<T>::max()) {
720 insert(hashes, count);
721 return count + 1;
722 }
723 return std::numeric_limits<T>::max();
724}
725
726template<typename T>
727inline T
729 const T threshold)
730{
731 const auto count = contains(hashes);
732 if (count < threshold) {
733 insert(hashes, count);
734 return count + 1;
735 }
736 return count;
737}
738
739template<typename T>
740inline T
742 const T threshold)
743{
744 const auto count = contains(hashes);
745 if (count < threshold) {
746 insert(hashes, count);
747 }
748 return count;
749}
750
751template<typename T>
752inline uint64_t
754{
755 uint64_t pop_cnt = 0;
756#pragma omp parallel for default(none) reduction(+ : pop_cnt)
757 for (size_t i = 0; i < array_size; ++i) {
758 if (array[i] > 0) {
759 ++pop_cnt;
760 }
761 }
762 return pop_cnt;
763}
764
765template<typename T>
766inline double
768{
769 return double(get_pop_cnt()) / double(array_size);
770}
771
772template<typename T>
773inline double
775{
776 return std::pow(get_occupancy(), double(hash_num));
777}
778
779template<typename T>
780inline CountingBloomFilter<T>::CountingBloomFilter(const std::string& path)
782 std::make_shared<BloomFilterInitializer>(path,
783 COUNTING_BLOOM_FILTER_SIGNATURE))
784{}
785
786template<typename T>
788 const std::shared_ptr<BloomFilterInitializer>& bfi)
789 : bytes(*bfi->table->get_as<decltype(bytes)>("bytes"))
790 , array_size(bytes / sizeof(array[0]))
791 , hash_num(*(bfi->table->get_as<decltype(hash_num)>("hash_num")))
792 , hash_fn(bfi->table->contains("hash_fn")
793 ? *(bfi->table->get_as<decltype(hash_fn)>("hash_fn"))
794 : "")
795 , array(new std::atomic<T>[array_size])
796{
797 check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
798 "Atomic primitives take extra memory. CountingBloomFilter will "
799 "have less than " +
800 std::to_string(bytes) + " for bit array.");
801 const auto loaded_counter_bits =
802 *(bfi->table->get_as<size_t>("counter_bits"));
803 check_error(sizeof(array[0]) * CHAR_BIT != loaded_counter_bits,
804 "CountingBloomFilter" +
805 std::to_string(sizeof(array[0]) * CHAR_BIT) +
806 " tried to load a file of CountingBloomFilter" +
807 std::to_string(loaded_counter_bits));
808 bfi->ifs.read((char*)array.get(),
809 std::streamsize(array_size * sizeof(array[0])));
810}
811
812template<typename T>
813inline void
814CountingBloomFilter<T>::save(const std::string& path)
815{
816 /* Initialize cpptoml root table
817 Note: Tables and fields are unordered
818 Ordering of table is maintained by directing the table
819 to the output stream immediately after completion */
820 auto root = cpptoml::make_table();
821
822 /* Initialize bloom filter section and insert fields
823 and output to ostream */
824 auto header = cpptoml::make_table();
825 header->insert("bytes", get_bytes());
826 header->insert("hash_num", get_hash_num());
827 if (!hash_fn.empty()) {
828 header->insert("hash_fn", hash_fn);
829 }
830 header->insert("counter_bits", size_t(sizeof(array[0]) * CHAR_BIT));
831 std::string header_string = COUNTING_BLOOM_FILTER_SIGNATURE;
832 header_string =
833 header_string.substr(1, header_string.size() - 2); // Remove [ ]
834 root->insert(header_string, header);
835
837 path, *root, (char*)array.get(), array_size * sizeof(array[0]));
838}
839
840template<typename T>
842 unsigned hash_num,
843 unsigned k)
844 : k(k)
845 , counting_bloom_filter(bytes, hash_num, HASH_FN)
846{}
847
848template<typename T>
849inline void
850KmerCountingBloomFilter<T>::insert(const char* seq, size_t seq_len)
851{
852 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
853 while (nthash.roll()) {
854 counting_bloom_filter.insert(nthash.hashes());
855 }
856}
857
858template<typename T>
859inline uint64_t
860KmerCountingBloomFilter<T>::contains(const char* seq, size_t seq_len) const
861{
862 uint64_t sum = 0;
863 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
864 while (nthash.roll()) {
865 sum += counting_bloom_filter.contains(nthash.hashes());
866 }
867 return sum;
868}
869
870template<typename T>
871inline T
872KmerCountingBloomFilter<T>::contains_insert(const char* seq, size_t seq_len)
873{
874 uint64_t sum = 0;
875 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
876 while (nthash.roll()) {
877 sum += counting_bloom_filter.contains_insert(nthash.hashes());
878 }
879 return sum;
880}
881
882template<typename T>
883inline T
884KmerCountingBloomFilter<T>::insert_contains(const char* seq, size_t seq_len)
885{
886 uint64_t sum = 0;
887 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
888 while (nthash.roll()) {
889 sum += counting_bloom_filter.insert_contains(nthash.hashes());
890 }
891 return sum;
892}
893
894template<typename T>
895inline T
897 size_t seq_len,
898 const T threshold)
899{
900 uint64_t sum = 0;
901 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
902 while (nthash.roll()) {
903 sum +=
904 counting_bloom_filter.insert_thresh_contains(nthash.hashes(), threshold);
905 }
906 return sum;
907}
908
909template<typename T>
910inline T
912 size_t seq_len,
913 const T threshold)
914{
915 uint64_t sum = 0;
916 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
917 while (nthash.roll()) {
918 sum +=
919 counting_bloom_filter.contains_insert_thresh(nthash.hashes(), threshold);
920 }
921 return sum;
922}
923
924template<typename T>
926 const std::string& path)
928 std::make_shared<BloomFilterInitializer>(
929 path,
930 KMER_COUNTING_BLOOM_FILTER_SIGNATURE))
931{}
932
933template<typename T>
935 const std::shared_ptr<BloomFilterInitializer>& bfi)
936 : k(*(bfi->table->get_as<decltype(k)>("k")))
937 , counting_bloom_filter(bfi)
938{
939 check_error(counting_bloom_filter.hash_fn != HASH_FN,
940 "KmerCountingBloomFilter: loaded hash function (" +
941 counting_bloom_filter.hash_fn +
942 ") is different from the one used by default (" + HASH_FN +
943 ").");
944}
945
946template<typename T>
947inline void
948KmerCountingBloomFilter<T>::save(const std::string& path)
949{
950 /* Initialize cpptoml root table
951 Note: Tables and fields are unordered
952 Ordering of table is maintained by directing the table
953 to the output stream immediately after completion */
954 auto root = cpptoml::make_table();
955
956 /* Initialize bloom filter section and insert fields
957 and output to ostream */
958 auto header = cpptoml::make_table();
959 header->insert("bytes", get_bytes());
960 header->insert("hash_num", get_hash_num());
961 header->insert("hash_fn", get_hash_fn());
962 header->insert("counter_bits",
963 size_t(sizeof(counting_bloom_filter.array[0]) * CHAR_BIT));
964 header->insert("k", k);
965 std::string header_string = KMER_COUNTING_BLOOM_FILTER_SIGNATURE;
966 header_string =
967 header_string.substr(1, header_string.size() - 2); // Remove [ ]
968 root->insert(header_string, header);
969
971 *root,
972 (char*)counting_bloom_filter.array.get(),
973 counting_bloom_filter.array_size *
974 sizeof(counting_bloom_filter.array[0]));
975}
976
977} // namespace btllib
978
979#endif
void save(const std::string &path)
Definition: counting_bloom_filter.hpp:36
T insert_thresh_contains(const std::vector< uint64_t > &hashes, const T threshold)
Definition: counting_bloom_filter.hpp:172
void insert(const std::vector< uint64_t > &hashes)
Definition: counting_bloom_filter.hpp:79
double get_fpr() const
Definition: counting_bloom_filter.hpp:774
const std::string & get_hash_fn() const
Definition: counting_bloom_filter.hpp:217
T contains_insert(const uint64_t *hashes)
Definition: counting_bloom_filter.hpp:705
static bool is_bloom_file(const std::string &path)
Definition: counting_bloom_filter.hpp:231
T insert_contains(const std::vector< uint64_t > &hashes)
Definition: counting_bloom_filter.hpp:143
void save(const std::string &path)
Definition: counting_bloom_filter.hpp:814
uint64_t get_pop_cnt() const
Definition: counting_bloom_filter.hpp:753
T contains_insert_thresh(const uint64_t *hashes, T threshold)
Definition: counting_bloom_filter.hpp:741
T contains(const std::vector< uint64_t > &hashes) const
Definition: counting_bloom_filter.hpp:98
size_t get_bytes() const
Definition: counting_bloom_filter.hpp:207
unsigned get_hash_num() const
Definition: counting_bloom_filter.hpp:213
T contains(const uint64_t *hashes) const
Definition: counting_bloom_filter.hpp:691
CountingBloomFilter()
Definition: counting_bloom_filter.hpp:40
T insert_contains(const uint64_t *hashes)
Definition: counting_bloom_filter.hpp:716
void insert(const uint64_t *hashes)
Definition: counting_bloom_filter.hpp:684
T contains_insert_thresh(const std::vector< uint64_t > &hashes, const T threshold)
Definition: counting_bloom_filter.hpp:200
T insert_thresh_contains(const uint64_t *hashes, T threshold)
Definition: counting_bloom_filter.hpp:728
T contains_insert(const std::vector< uint64_t > &hashes)
Definition: counting_bloom_filter.hpp:120
double get_occupancy() const
Definition: counting_bloom_filter.hpp:767
Definition: counting_bloom_filter.hpp:258
uint64_t contains(const char *seq, size_t seq_len) const
Definition: counting_bloom_filter.hpp:860
T insert_contains(const std::vector< uint64_t > &hashes)
Definition: counting_bloom_filter.hpp:456
T insert_thresh_contains(const std::vector< uint64_t > &hashes, const T threshold)
Definition: counting_bloom_filter.hpp:514
void save(const std::string &path)
Definition: counting_bloom_filter.hpp:948
CountingBloomFilter< T > & get_counting_bloom_filter()
Definition: counting_bloom_filter.hpp:595
T insert_contains(const char *seq, size_t seq_len)
Definition: counting_bloom_filter.hpp:884
T insert_thresh_contains(const std::string &seq, const T threshold)
Definition: counting_bloom_filter.hpp:482
T contains_insert_thresh(const uint64_t *hashes, const T threshold)
Definition: counting_bloom_filter.hpp:557
T insert_thresh_contains(const char *seq, size_t seq_len, T threshold)
Definition: counting_bloom_filter.hpp:896
T contains(const uint64_t *hashes) const
Definition: counting_bloom_filter.hpp:349
T contains_insert(const uint64_t *hashes)
Definition: counting_bloom_filter.hpp:396
T contains_insert(const std::vector< uint64_t > &hashes)
Definition: counting_bloom_filter.hpp:408
T contains_insert(const std::string &seq)
Definition: counting_bloom_filter.hpp:383
size_t get_bytes() const
Definition: counting_bloom_filter.hpp:578
T contains_insert_thresh(const char *seq, size_t seq_len, T threshold)
Definition: counting_bloom_filter.hpp:911
T insert_thresh_contains(const uint64_t *hashes, const T threshold)
Definition: counting_bloom_filter.hpp:498
T contains_insert_thresh(const std::vector< uint64_t > &hashes, const T threshold)
Definition: counting_bloom_filter.hpp:571
T insert_contains(const std::string &seq)
Definition: counting_bloom_filter.hpp:430
const std::string & get_hash_fn() const
Definition: counting_bloom_filter.hpp:590
double get_occupancy() const
Definition: counting_bloom_filter.hpp:582
void insert(const std::vector< uint64_t > &hashes)
Definition: counting_bloom_filter.hpp:314
void insert(const char *seq, size_t seq_len)
Definition: counting_bloom_filter.hpp:850
T contains_insert(const char *seq, size_t seq_len)
Definition: counting_bloom_filter.hpp:872
unsigned get_hash_num() const
Definition: counting_bloom_filter.hpp:584
T insert_contains(const uint64_t *hashes)
Definition: counting_bloom_filter.hpp:444
uint64_t get_pop_cnt() const
Definition: counting_bloom_filter.hpp:580
unsigned get_k() const
Definition: counting_bloom_filter.hpp:588
void insert(const std::string &seq)
Definition: counting_bloom_filter.hpp:299
uint64_t contains(const std::string &seq) const
Definition: counting_bloom_filter.hpp:336
void insert(const uint64_t *hashes)
Definition: counting_bloom_filter.hpp:307
T contains_insert_thresh(const std::string &seq, const T threshold)
Definition: counting_bloom_filter.hpp:541
static bool is_bloom_file(const std::string &path)
Definition: counting_bloom_filter.hpp:613
KmerCountingBloomFilter()
Definition: counting_bloom_filter.hpp:262
double get_fpr() const
Definition: counting_bloom_filter.hpp:586
T contains(const std::vector< uint64_t > &hashes) const
Definition: counting_bloom_filter.hpp:361
Definition: nthash.hpp:54
Definition: bloom_filter.hpp:16
void check_error(bool condition, const std::string &msg)
void check_warning(bool condition, const std::string &msg)