longchute

about

Multipattern Search

03 Dec 2013

This is a simple multipattern search algorithm. It isn't particularly speedy (grep -Ff words target, which uses Aho-Corasick, will outperform this algorithm), but it is quick to implement. It reads a list of words, sorts them into sublists grouped by length, and converts each sublist into a trie. Then, for each trie, it slides a window of that length (the trie's depth) over the input and walks the trie with the contents of the window at each position. Matches are stored and output at the end.

Approximately 1 MB of random sample data may be generated with openssl rand -out target -base64 $(( 2**20 * 3/4 )).

Usage: ./mps.py words target, where words is the name of a file with one word per line and target is the name of the file to be searched. If debug = True, the search terms, search trees, and progress meters will be displayed. If debug = False, only matches will be printed. Matches are displayed one to a line: the match location in bytes from the beginning, a space, and the word matched.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python

from pprint import pprint as pp
import sys

debug = True

#   Input data
#   Read one search term per line from the file named by the first
#   command-line argument, pairing each term with its length.
words = []
with open(sys.argv[1], 'r') as f:
    for a_word in f:
        a_word = a_word[:-1] if a_word.endswith('\n') else a_word
        #   Skip blank lines: an empty term cannot match anything, and a
        #   zero-length entry used to crash the grouping loop below
        #   (IndexError on sorted_words[-1] while the list was still empty).
        if not a_word:
            continue
        words.append((a_word, len(a_word)))

#   Length sort (stable, so terms of equal length keep their file order)
words.sort(key=lambda w: w[1])

#   Split into lists by length
#   current_length starts as None so that the very first term always opens
#   a new sublist, whatever its length is.
current_length  = None
sorted_words    = []

for a_word, word_length in words:
    if current_length != word_length:
        sorted_words.append([])
        current_length = word_length
    sorted_words[-1].append(a_word)

if debug:
    pp(sorted_words)

#   Trie class to store search terms
class Trie(dict):
    """Trie stored as nested dicts.

    Each key is one element of an inserted sequence (one character for a
    string) and each value is the child Trie holding the remainder of
    every sequence that shares the prefix leading to it.  No end-of-word
    marker is stored, so a node with no children is simply an empty Trie.
    """

    def add(self, data):
        """Insert the sequence ``data``, creating child nodes as needed.

        A falsy ``data`` (empty sequence or None) is a no-op.  The walk is
        iterative so that long sequences neither hit the recursion limit
        nor pay for a fresh slice at every level.
        """
        if not data:
            return

        node = self
        for item in data:
            try:
                node = node[item]
            except KeyError:
                child = Trie()
                node[item] = child
                node = child

    def find(self, data):
        """Return True if ``data`` can be walked from this node.

        Because there is no end-of-word marker, this is a prefix test: it
        returns True for any prefix of an inserted sequence, and for a
        falsy ``data``.  It returns False as soon as an element has no
        matching child.
        """
        if not data:
            return True

        node = self
        for item in data:
            try:
                node = node[item]
            except KeyError:
                return False
        return True

    def depth(self, n):
        """Return ``n`` plus the number of levels below this node.

        Only the first child is followed at each level, so the result is
        the length of that one branch; it equals the depth of the whole
        trie when every inserted sequence has the same length.
        """
        node = self
        while node:
            #   dict.values() is a non-indexable view on Python 3, so take
            #   the first child through an iterator instead of values()[0].
            node = next(iter(node.values()))
            n += 1
        return n

#   Build one trie per sublist of pre-sorted words, keeping the sublist order
tries = []

for a_sublist in sorted_words:
    sublist_trie = Trie()
    for a_word in a_sublist:
        sublist_trie.add(a_word)
    tries.append(sublist_trie)

if debug:
    pp(tries)

#   Make a pass for each trie, sliding a window of that trie's depth
#   over the input and walking the trie for matches.
#   Each match is recorded as "<byte offset> <matched text>".
found = []

for a_trie in tries:
    current_depth   = a_trie.depth(0)

    if debug: print("Current depth: %d" % current_depth)

    #   seeks over the file to avoid reading it all in at once
    with open(sys.argv[2], 'rb') as f:
        f.seek(0, 2)
        endpoint = f.tell()
        #   The tries were built from length-sorted sublists, so once one
        #   is deeper than the file every remaining one is as well.
        if endpoint < current_depth: break

        #   progress bar: redraw roughly every 0.01% of the file.
        #   Floor division keeps `per` an integer on Python 3 too; with
        #   true division the modulo test below almost never hits zero.
        per     = (endpoint // 10000) or 1
        last_i  = 5     # width of the '0.00%' written just below

        if debug:
            sys.stdout.write('0.00%')
            sys.stdout.flush()

        #   main search loop (sliding window)
        #   `position` is the byte offset of the current window; the last
        #   window that still fits starts at endpoint - current_depth.
        last_start  = endpoint - current_depth
        position    = 0
        while position <= last_start:
            f.seek(position)
            current_read = f.read(current_depth)
            #   An empty read means a zero-depth trie or a file that
            #   shrank while being scanned; either way there is nothing
            #   left to match for this trie.
            if not current_read: break

            #   The word list is read in text mode, so the trie keys are
            #   str.  On Python 3 a binary read returns bytes, whose
            #   elements are ints and would never equal those keys, so map
            #   each byte to the character with the same code point.  On
            #   Python 2 the read is already a str and is left untouched.
            if not isinstance(current_read, str):
                current_read = current_read.decode('latin-1')

            if a_trie.find(current_read):
                found.append("%d %s" % (position, current_read))

            #   progress bar: backspace over the previous figure, rewrite
            if debug and (position % per) == 0:
                i_percent = "%.2f" % (position * 100.0 / endpoint)
                sys.stdout.write("%s%s%%" % (('\b' * last_i), i_percent))
                sys.stdout.flush()
                last_i = len(i_percent) + 1

            position += 1

        if debug:
            sys.stdout.write("%s100.00%%\n" % ('\b' * last_i))
            sys.stdout.flush()

#   Final results
for an_item in found:
    print(an_item)