Aho-Corasick algorithm in Rust

Here's an implementation of the Aho-Corasick algorithm in Rust:

main.rs
use std::collections::VecDeque;
use std::collections::HashMap;

/// A single state of the Aho–Corasick automaton, i.e. one node of the pattern trie.
#[derive(Debug)]
struct AhoNode {
    /// Trie edges: character -> index of the child node in `AhoCorasick::trie`.
    next: HashMap<char, usize>,
    /// Failure link: index of the node spelling the longest proper suffix of this
    /// node's path that is also a path in the trie. `None` for the root, and for
    /// any node added since `build_failure_links` last ran.
    fail: Option<usize>,
    /// Id of the pattern that ends exactly at this node, if any.
    word: Option<usize>,
    /// Dictionary (output) link: the nearest node along the failure chain —
    /// excluding this node and the root — at which some pattern ends. Following
    /// these links enumerates every shorter pattern that is a suffix of this
    /// node's path.
    dict_link: Option<usize>,
}

impl AhoNode {
    /// Creates a node with no edges, no links and no pattern.
    fn new() -> Self {
        AhoNode {
            next: HashMap::new(),
            fail: None,
            word: None,
            dict_link: None,
        }
    }
}

/// Aho–Corasick multi-pattern matcher over `char`s.
///
/// Call `add_word` for every pattern, then `build_failure_links`, then
/// `find_words`. If more words are added after building, call
/// `build_failure_links` again before searching.
#[derive(Debug)]
struct AhoCorasick {
    /// Trie nodes; index 0 is the root.
    trie: Vec<AhoNode>,
    /// Byte length of each pattern, indexed by pattern id.
    word_lens: Vec<usize>,
}

impl AhoCorasick {
    /// Creates an empty matcher containing only the root node.
    fn new() -> Self {
        AhoCorasick {
            trie: vec![AhoNode::new()],
            word_lens: Vec::new(),
        }
    }

    /// Inserts `word` into the trie and returns its pattern id.
    ///
    /// Ids are assigned sequentially from 0. Adding a word that is already
    /// present returns the existing id. The empty word is accepted and gets an
    /// id, but it is never reported by `find_words`.
    fn add_word(&mut self, word: &str) -> usize {
        let mut node_idx = 0;
        for c in word.chars() {
            // Copy the child index out first so the map borrow ends before we push.
            let existing = self.trie[node_idx].next.get(&c).copied();
            node_idx = match existing {
                Some(child_idx) => child_idx,
                None => {
                    let child_idx = self.trie.len();
                    self.trie.push(AhoNode::new());
                    self.trie[node_idx].next.insert(c, child_idx);
                    child_idx
                }
            };
        }

        if let Some(word_id) = self.trie[node_idx].word {
            return word_id;
        }
        let word_id = self.word_lens.len();
        // Stored in bytes so match start offsets can be used to slice `&str`.
        self.word_lens.push(word.len());
        self.trie[node_idx].word = Some(word_id);
        word_id
    }

    /// Computes failure and dictionary links for every node, breadth-first.
    ///
    /// Must be called after the last `add_word` and before `find_words`.
    /// Calling it again recomputes all links from scratch.
    fn build_failure_links(&mut self) {
        let mut queue = VecDeque::new();

        // The root has no failure link and never reports a match; otherwise an
        // empty pattern would "match" at every position.
        self.trie[0].fail = None;
        self.trie[0].dict_link = None;

        // Depth-1 nodes always fail back to the root. They are handled
        // separately because the general rule below would make them point at
        // themselves.
        let root_children: Vec<usize> = self.trie[0].next.values().copied().collect();
        for child_idx in root_children {
            self.trie[child_idx].fail = Some(0);
            self.trie[child_idx].dict_link = None;
            queue.push_back(child_idx);
        }

        // BFS order guarantees each node's failure target (a strictly shallower
        // node) already has its own links finalized.
        while let Some(node_idx) = queue.pop_front() {
            // Snapshot the edges so `self.trie` can be mutated inside the loop.
            let children: Vec<(char, usize)> = self.trie[node_idx]
                .next
                .iter()
                .map(|(&c, &child_idx)| (c, child_idx))
                .collect();
            let parent_fail = self.trie[node_idx].fail.unwrap_or(0);

            for (c, child_idx) in children {
                let fail_idx = self.transition(parent_fail, c);
                self.trie[child_idx].fail = Some(fail_idx);
                self.trie[child_idx].dict_link =
                    if fail_idx != 0 && self.trie[fail_idx].word.is_some() {
                        Some(fail_idx)
                    } else {
                        self.trie[fail_idx].dict_link
                    };
                queue.push_back(child_idx);
            }
        }
    }

    /// Follows trie edges and failure links from `node_idx` on character `c`.
    ///
    /// Returns the resulting node, or the root when no suffix of the current
    /// path can be extended by `c`.
    fn transition(&self, mut node_idx: usize, c: char) -> usize {
        loop {
            if let Some(&next_idx) = self.trie[node_idx].next.get(&c) {
                return next_idx;
            }
            if node_idx == 0 {
                return 0;
            }
            node_idx = self.trie[node_idx].fail.unwrap_or(0);
        }
    }

    /// Finds every occurrence of every added word in `text`, including
    /// overlapping occurrences.
    ///
    /// Returns `(start, word)` pairs, where `start` is the byte offset in
    /// `text` at which the match begins and `word` is the matched slice of
    /// `text`. Matches are ordered by end position. Among matches that end at
    /// the same position, longer words come first.
    /// Requires `build_failure_links` to have been called after the last
    /// `add_word`.
    fn find_words<'t>(&self, text: &'t str) -> Vec<(usize, &'t str)> {
        let mut node_idx = 0;
        let mut result = Vec::new();

        for (pos, c) in text.char_indices() {
            // Byte offset just past this character; every match found here ends at it.
            let end = pos + c.len_utf8();
            node_idx = self.transition(node_idx, c);

            // Report the word ending at this node (if any), then every shorter
            // word reachable through dictionary links.
            let mut hit = if node_idx != 0 && self.trie[node_idx].word.is_some() {
                Some(node_idx)
            } else {
                self.trie[node_idx].dict_link
            };
            while let Some(hit_idx) = hit {
                if let Some(word_id) = self.trie[hit_idx].word {
                    let start = end - self.word_lens[word_id];
                    result.push((start, &text[start..end]));
                }
                hit = self.trie[hit_idx].dict_link;
            }
        }

        result
    }
}

/// Demo: builds a matcher over four overlapping patterns and prints every
/// match found in a short sample text.
fn main() {
    let mut automaton = AhoCorasick::new();
    for pattern in ["he", "she", "his", "hers"] {
        automaton.add_word(pattern);
    }
    automaton.build_failure_links();

    let haystack = "ushers";
    println!("{:?}", automaton.find_words(haystack));
}
3607 chars
121 lines

The implementation above uses a simple `AhoNode` struct to represent each node in the trie. The `AhoCorasick` struct owns the trie as a vector of these nodes, and nodes refer to one another by their index in that vector.

To use the algorithm, create an `AhoCorasick` instance, add each pattern with the `add_word` method, and then call `build_failure_links` to compute the failure links between nodes. Add all words before calling `build_failure_links`; if you add more words later, call it again before searching.

Once the failure links are built, call `find_words` to find all occurrences of the words in a given text. It returns a vector of `(start, word)` tuples, where `start` is the byte offset in the text at which the match begins and `word` is the matched slice of the text.

gistlib by LogSnag