tomlev

C# Trie (improved)

Aug 2nd, 2011
142
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5.  
  6. namespace MyCollections
  7. {
  8.     public class Trie<TValue> : IDictionary<string, TValue>
  9.     {
  10.         private int _count;
  11.         private readonly TrieNode _rootNode;
  12.  
  13.         public Trie()
  14.         {
  15.             _rootNode = new TrieNode('\0', null);
  16.         }
  17.  
  18.         #region Core implementation
  19.  
  20.         private IEnumerable<KeyValuePair<string, TValue>> Enumerate(TrieNode root, string prefix)
  21.         {
  22.             string key;
  23.             if (root == _rootNode)
  24.                 key = prefix + string.Empty;
  25.             else
  26.                 key = prefix + root.PartialKey;
  27.  
  28.             if (root.HasValue)
  29.                 yield return new KeyValuePair<string, TValue>(key, root.Value);
  30.  
  31.             foreach (var kvp in root.Children)
  32.             {
  33.                 foreach (var entry in Enumerate(kvp.Value, key))
  34.                 {
  35.                     yield return entry;
  36.                 }
  37.             }
  38.         }
  39.  
  40.         private static TrieNode FindNode(TrieNode root, char[] key, int depth, bool create, out bool created)
  41.         {
  42.             created = false;
  43.  
  44.             if (depth == key.Length)
  45.                 return root;
  46.  
  47.             TrieNode node;
  48.             if (root.Children.TryGetValue(key[depth], out node))
  49.                 return FindNode(node, key, depth + 1, create, out created);
  50.  
  51.             if (create)
  52.             {
  53.                 var current = root;
  54.                 node = null;
  55.                 for (int i = depth; i < key.Length; i++)
  56.                 {
  57.                     node = new TrieNode(key[i], current);
  58.                     current.Children.Add(key[i], node);
  59.                     current = node;
  60.                 }
  61.                 created = true;
  62.                 return node;
  63.             }
  64.  
  65.             return null;
  66.         }
  67.  
  68.         private TrieNode FindNode(string key, bool create, out bool created)
  69.         {
  70.             return FindNode(_rootNode, key.ToCharArray(), 0, create, out created);
  71.         }
  72.  
  73.         private TrieNode FindNode(string key, bool create)
  74.         {
  75.             bool created;
  76.             return FindNode(_rootNode, key.ToCharArray(), 0, create, out created);
  77.         }
  78.  
  79.         private static void Prune(TrieNode removedNode)
  80.         {
  81.             var current = removedNode;
  82.             while (current != null)
  83.             {
  84.                 if (current.Children.Any() || current.HasValue)
  85.                     break;
  86.                 if (current.Parent != null)
  87.                 {
  88.                     current.Parent.Children.Remove(current.PartialKey);
  89.                 }
  90.                 current = current.Parent;
  91.             }
  92.         }
  93.  
  94.         class TrieNode
  95.         {
  96.             private readonly char _partialKey;
  97.             private TValue _value;
  98.             private bool _hasValue;
  99.             private readonly TrieNode _parent;
  100.             private readonly SortedDictionary<char, TrieNode> _children;
  101.  
  102.             public TrieNode(char partialKey, TrieNode parent)
  103.             {
  104.                 _partialKey = partialKey;
  105.                 _parent = parent;
  106.                 _children = new SortedDictionary<char, TrieNode>();
  107.             }
  108.  
  109.             public TrieNode Parent
  110.             {
  111.                 get { return _parent; }
  112.             }
  113.  
  114.             public char PartialKey
  115.             {
  116.                 get { return _partialKey; }
  117.             }
  118.  
  119.             public bool HasValue
  120.             {
  121.                 get { return _hasValue; }
  122.             }
  123.  
  124.             public TValue Value
  125.             {
  126.                 get { return _value; }
  127.                 set
  128.                 {
  129.                     _value = value;
  130.                     _hasValue = true;
  131.                 }
  132.             }
  133.  
  134.             public void ClearValue()
  135.             {
  136.                 _hasValue = false;
  137.             }
  138.  
  139.             public IDictionary<char, TrieNode> Children
  140.             {
  141.                 get { return _children; }
  142.             }
  143.         }
  144.  
  145.         #endregion
  146.  
  147.         #region IDictionary<string, TValue> implementation
  148.  
  149.         public IEnumerator<KeyValuePair<string, TValue>> GetEnumerator()
  150.         {
  151.             return Enumerate(_rootNode, null).GetEnumerator();
  152.         }
  153.  
  154.         IEnumerator IEnumerable.GetEnumerator()
  155.         {
  156.             return GetEnumerator();
  157.         }
  158.  
  159.         void ICollection<KeyValuePair<string, TValue>>.Add(KeyValuePair<string, TValue> item)
  160.         {
  161.             Add(item.Key, item.Value);
  162.         }
  163.  
  164.         public void Clear()
  165.         {
  166.             _rootNode.ClearValue();
  167.             _rootNode.Children.Clear();
  168.         }
  169.  
  170.         bool ICollection<KeyValuePair<string, TValue>>.Contains(KeyValuePair<string, TValue> item)
  171.         {
  172.             TValue value;
  173.             if (TryGetValue(item.Key, out value))
  174.                 return Equals(value, item.Value);
  175.             return false;
  176.         }
  177.  
  178.         void ICollection<KeyValuePair<string, TValue>>.CopyTo(KeyValuePair<string, TValue>[] array, int arrayIndex)
  179.         {
  180.             if (_count + arrayIndex > array.Length)
  181.                 throw new ArgumentException("The destination array is not large enough");
  182.             int index = arrayIndex;
  183.             foreach (var item in Enumerate(_rootNode, null))
  184.             {
  185.                 array[index] = item;
  186.                 index++;
  187.             }
  188.         }
  189.  
  190.         bool ICollection<KeyValuePair<string, TValue>>.Remove(KeyValuePair<string, TValue> item)
  191.         {
  192.             var node = FindNode(item.Key, false);
  193.             if (node == null)
  194.                 return false;
  195.             if (node.HasValue && Equals(node.Value, item.Value))
  196.             {
  197.                 node.ClearValue();
  198.                 _count--;
  199.                 return true;
  200.             }
  201.             return false;
  202.         }
  203.  
  204.         public int Count
  205.         {
  206.             get { return _count; }
  207.         }
  208.  
  209.         public bool IsReadOnly
  210.         {
  211.             get { return false; }
  212.         }
  213.  
  214.         public bool ContainsKey(string key)
  215.         {
  216.             var node = FindNode(key, false);
  217.             if (node != null)
  218.                 return node.HasValue;
  219.             return false;
  220.         }
  221.  
  222.         public void Add(string key, TValue value)
  223.         {
  224.             var node = FindNode(key, true);
  225.             if (node.HasValue)
  226.                 throw new ArgumentException("An element with the same key already exists in the trie.");
  227.             node.Value = value;
  228.             _count++;
  229.         }
  230.  
  231.         public bool Remove(string key)
  232.         {
  233.             var node = FindNode(key, false);
  234.             if (node == null)
  235.                 return false;
  236.  
  237.             if (node.HasValue)
  238.             {
  239.                 node.ClearValue();
  240.                 Prune(node);
  241.                 _count--;
  242.             }
  243.             return false;
  244.         }
  245.  
  246.         public bool TryGetValue(string key, out TValue value)
  247.         {
  248.             value = default(TValue);
  249.             var node = FindNode(key, false);
  250.             if (node == null)
  251.                 return false;
  252.             if (node.HasValue)
  253.             {
  254.                 value = node.Value;
  255.                 return true;
  256.             }
  257.             return false;
  258.         }
  259.  
  260.         public TValue this[string key]
  261.         {
  262.             get
  263.             {
  264.                 TValue value;
  265.                 if (TryGetValue(key, out value))
  266.                     return value;
  267.                 throw new KeyNotFoundException();
  268.             }
  269.             set
  270.             {
  271.                 bool created;
  272.                 var node = FindNode(key, true, out created);
  273.                 node.Value = value;
  274.                 if (created)
  275.                     _count++;
  276.             }
  277.         }
  278.  
  279.         public ICollection<string> Keys
  280.         {
  281.             get { return Enumerate(_rootNode, null).Select(kvp => kvp.Key).ToList(); }
  282.         }
  283.  
  284.         public ICollection<TValue> Values
  285.         {
  286.             get { return Enumerate(_rootNode, null).Select(kvp => kvp.Value).ToList(); }
  287.         }
  288.  
  289.         #endregion
  290.  
  291.         #region Public trie-specific methods
  292.  
  293.         public bool ContainsPrefix(string prefix)
  294.         {
  295.             var node = FindNode(prefix, false);
  296.             return (node.HasValue || node.Children.Any());
  297.         }
  298.  
  299.         public bool RemovePrefix(string prefix)
  300.         {
  301.             var node = FindNode(prefix, false);
  302.             if (node == null)
  303.                 return false;
  304.  
  305.             int count = Enumerate(node, null).Count();
  306.             if (count == 0)
  307.                 return false;
  308.  
  309.             node.Children.Clear();
  310.             node.ClearValue();
  311.             Prune(node);
  312.             _count -= count;
  313.  
  314.             return true;
  315.         }
  316.  
  317.         public IEnumerable<KeyValuePair<string, TValue>> FindPrefix(string prefix)
  318.         {
  319.             var node = FindNode(prefix, false);
  320.             if (node == null)
  321.                 yield break;
  322.  
  323.             string prefix2 = null;
  324.             if (prefix.Length > 0)
  325.                 prefix2 = prefix.Substring(0, prefix.Length - 1);
  326.             foreach (var kvp in Enumerate(node, prefix2))
  327.             {
  328.                 yield return kvp;
  329.             }
  330.         }
  331.  
  332.         #endregion
  333.     }
  334. }
RAW Paste Data