Wednesday, October 8, 2008

Deep enumeration

Over the course of my programming career I've found myself working with tree data structures quite a bit. There are many examples of APIs that use trees, like tree controls, file system directories, etc. It is often desirable to iterate over the contents of all the nodes in the tree and do a pattern match. Unfortunately, there isn't a straightforward way to do this in C# without writing some custom code. I don't like writing a custom function each time I need to do something like this, and the great thing about generics is that they allow enough abstraction over types and functions to generalize a pattern like this.

Let us define a basic abstraction of a tree element:

    public interface IElement
    {
        // Each element exposes its direct children; the hierarchy implies the tree.
        IEnumerable<IElement> Children { get; }
    }

In a nutshell, that's all you need. Each element has children that can be iterated over, and the hierarchy of all the elements defines the tree. This is of course a generalization: we could add constraints such as each element appearing only once in the tree, but I'll leave it up to the developer to enforce them. I'm merely trying to iterate over the tree nodes. (Note that the enumerator below does not check for this either: an element reachable by two paths is returned twice, and a cycle makes the enumeration run forever.)

With that in mind, we can define a function that maps an element to its children:

    // IElement -> IEnumerable<IElement>
    Func<IElement, IEnumerable<IElement>> childrenMappingFunction;
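As a minimal sketch, here is a concrete element together with the mapping function, which turns out to be nothing more than a lambda that reads Children. The Element class, its Add method and the MappingExample holder are hypothetical, invented only for illustration.

```csharp
using System;
using System.Collections.Generic;

public interface IElement
{
    IEnumerable<IElement> Children { get; }
}

// Hypothetical concrete element used only for this example.
public sealed class Element : IElement
{
    private readonly List<IElement> children = new List<IElement>();

    public Element(string name)
    {
        Name = name;
    }

    public string Name { get; }

    public IEnumerable<IElement> Children => children;

    // Adds a child and returns this element so calls can be chained.
    public Element Add(Element child)
    {
        children.Add(child);
        return this;
    }
}

public static class MappingExample
{
    // IElement -> IEnumerable<IElement>: simply read the Children property.
    public static Func<IElement, IEnumerable<IElement>> ChildrenMappingFunction { get; } =
        element => element.Children;
}
```

Because the enumerator takes a Func rather than requiring IElement, the element type does not have to implement any interface at all; a directory walk, for instance, could pass d => d.GetDirectories() for a System.IO.DirectoryInfo.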

There are two common ways to walk a recursive structure: ordinary recursion, where a function calls itself for each child, or an explicit stack, where we keep the items still waiting to be evaluated on a Stack<T> and process them in a loop. We'll use the explicit stack because it tends to scale better: the depth of the tree no longer translates into depth of the call stack, so a very deep tree cannot cause a StackOverflowException.
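For comparison, here is a minimal sketch of the recursive alternative (the RecursiveEnumerator class and EnumerateRecursive name are hypothetical, not part of the class below). With C# iterators each level of the tree adds another nested iterator, so handing back an item at depth d runs through d nested MoveNext calls, and the call stack grows with the depth of the tree.

```csharp
using System;
using System.Collections.Generic;

public static class RecursiveEnumerator
{
    // Pre-order walk by recursion: yield the item, then everything under each child.
    public static IEnumerable<T> EnumerateRecursive<T>(
        T item,
        Func<T, IEnumerable<T>> childrenMappingFunction)
    {
        yield return item;

        foreach (T child in childrenMappingFunction(item))
        {
            // Each recursive call is a new iterator nested inside this one.
            foreach (T descendant in EnumerateRecursive(child, childrenMappingFunction))
            {
                yield return descendant;
            }
        }
    }
}
```

One thing the recursive version does get for free is that children come back in their original order; as discussed below, the stack version needs extra work for that.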

Without further ado, let's get to the meat of the code:

    public static IEnumerable<T> Enumerate<T>(
        T rootItem,
        Func<T, IEnumerable<T>> childrenMappingFunction,
        bool returnRoot)
    {
        // Items still waiting to be returned; the root is the first one.
        Stack<T> workStack = new Stack<T>();
        workStack.Push(rootItem);

        bool isRoot = true;

        while (workStack.Count > 0)
        {
            T item = workStack.Pop();

            if (isRoot)
            {
                // The root is only returned when the caller asks for it.
                if (returnRoot)
                {
                    yield return item;
                }

                isRoot = false;
            }
            else
            {
                yield return item;
            }

            IEnumerable<T> children = childrenMappingFunction(item);

            // Push the children so later iterations of the loop process them.
            foreach (T child in children)
            {
                workStack.Push(child);
            }
        }
    }

Let’s look at what the code does. It first creates a stack to hold the objects that still need to be evaluated, then pushes the root item onto it. The loop is quite simple: while the stack has items in it, pop an item off the stack and yield it using the C# iterator pattern (skipping the root if returnRoot is false), then call childrenMappingFunction to get the item's children as an IEnumerable<T>. Each child is pushed onto the stack so that it is evaluated on a later iteration of the while loop. That’s pretty much it. Because C# iterators are lazy, each step runs only when the caller asks for the next item, and it’s easy to use this function to flatten a tree into an IEnumerable<T> without ever having to concern ourselves with the details again.
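To make the visiting order concrete, here is a stripped-down copy of the loop above (returnRoot removed; the StackWalkDemo, Walk and Children names are hypothetical) run on a small tree of integers in which node n has children 2n and 2n+1 while n < 4.

```csharp
using System;
using System.Collections.Generic;

public static class StackWalkDemo
{
    // Same stack-based loop as Enumerate above, minus the returnRoot flag.
    public static IEnumerable<T> Walk<T>(T rootItem, Func<T, IEnumerable<T>> childrenMappingFunction)
    {
        var workStack = new Stack<T>();
        workStack.Push(rootItem);

        while (workStack.Count > 0)
        {
            T item = workStack.Pop();
            yield return item;

            // The last child pushed is the first one popped,
            // so siblings come back in reverse order.
            foreach (T child in childrenMappingFunction(item))
            {
                workStack.Push(child);
            }
        }
    }

    // Example tree: 1 -> {2, 3}, 2 -> {4, 5}, 3 -> {6, 7}; 4..7 are leaves.
    public static IEnumerable<int> Children(int n)
    {
        return n < 4 ? new[] { 2 * n, 2 * n + 1 } : new int[0];
    }
}
```

StackWalkDemo.Walk(1, n => StackWalkDemo.Children(n)) yields 1, 3, 7, 6, 2, 5, 4: every node still comes before its descendants, but siblings are visited last to first. That is the behavior the leftToRight option described next works around.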

I’ve added some additional overloads and parameters to allow customizing how we get the data back. The only one that might need a bit of explanation is leftToRight. What I discovered when first using this class is that, because the last child pushed onto the stack is the first one popped off, siblings are returned in the reverse of the order in which each node’s IEnumerable produced them. This is not always a problem, but it certainly can be in some use cases, so I added a modifier, leftToRight, that pushes each node’s children onto the stack in reverse order; when they are popped off, they come back in the order the IEnumerable returned them. I recognize it’s not the most efficient way to do this, since it copies every child sequence into a List<T> first, but I don’t use that option most of the time, and in the cases where I did use it performance was not a major concern. If anyone wants to provide a cleaner implementation, I welcome feedback. Here is the class in its entirety:

    using System;
    using System.Collections.Generic;
    using System.Linq; // for ToList()

    public static class DeepEnumerator
    {
        // Deep, left-to-right enumeration of several roots, roots included.
        public static IEnumerable<T> Enumerate<T>(
            IEnumerable<T> rootItems,
            Func<T, IEnumerable<T>> childrenMappingFunction)
        {
            return Enumerate(rootItems, childrenMappingFunction, true, true);
        }

        public static IEnumerable<T> Enumerate<T>(
            IEnumerable<T> rootItems,
            Func<T, IEnumerable<T>> childrenMappingFunction,
            bool returnRootItems,
            bool enumerateDeep)
        {
            // Enumerate each root in turn and concatenate the results.
            foreach (T rootItem in rootItems)
            {
                foreach (T item in Enumerate(rootItem, childrenMappingFunction, returnRootItems, enumerateDeep, true))
                {
                    yield return item;
                }
            }
        }

        // Deep, left-to-right enumeration of a single root, root included.
        public static IEnumerable<T> Enumerate<T>(
            T rootItem,
            Func<T, IEnumerable<T>> childrenMappingFunction)
        {
            return Enumerate(rootItem, childrenMappingFunction, true, true, true);
        }

        public static IEnumerable<T> Enumerate<T>(
            T rootItem,
            Func<T, IEnumerable<T>> childrenMappingFunction,
            bool returnRoot,
            bool enumerateDeep,
            bool leftToRight)
        {
            if (enumerateDeep)
            {
                Stack<T> workStack = new Stack<T>();
                workStack.Push(rootItem);

                bool isRoot = true;

                while (workStack.Count > 0)
                {
                    T item = workStack.Pop();

                    if (isRoot)
                    {
                        if (returnRoot)
                        {
                            yield return item;
                        }

                        isRoot = false;
                    }
                    else
                    {
                        yield return item;
                    }

                    IEnumerable<T> children = childrenMappingFunction(item);

                    if (leftToRight)
                    {
                        // Push in reverse so the children pop off in their original order.
                        children = EnumerateListBackwards(children.ToList());
                    }

                    foreach (T child in children)
                    {
                        workStack.Push(child);
                    }
                }
            }
            else
            {
                // Shallow mode: the root (only if requested) followed by its direct children.
                if (returnRoot)
                {
                    yield return rootItem;
                }

                foreach (T item in childrenMappingFunction(rootItem))
                {
                    yield return item;
                }
            }
        }

        private static IEnumerable<T> EnumerateListBackwards<T>(List<T> list)
        {
            for (int index = list.Count - 1; index >= 0; index--)
            {
                yield return list[index];
            }
        }
    }
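In response to the request above for a cleaner leftToRight implementation, one possible approach is to keep a stack of IEnumerator<T> objects, one per level of the tree, instead of a stack of items. Each child sequence is then consumed lazily in its natural order, so no ToList() copy or reversal is needed. This is a sketch rather than a tested drop-in replacement, and the EnumeratorStackWalker and EnumerateLeftToRight names are hypothetical.

```csharp
using System;
using System.Collections.Generic;

public static class EnumeratorStackWalker
{
    // Deep, left-to-right pre-order walk without buffering any child list.
    public static IEnumerable<T> EnumerateLeftToRight<T>(
        T rootItem,
        Func<T, IEnumerable<T>> childrenMappingFunction,
        bool returnRoot)
    {
        if (returnRoot)
        {
            yield return rootItem;
        }

        // One enumerator per level: the top of the stack is the level being read.
        var stack = new Stack<IEnumerator<T>>();
        stack.Push(childrenMappingFunction(rootItem).GetEnumerator());

        try
        {
            while (stack.Count > 0)
            {
                IEnumerator<T> current = stack.Peek();
                if (current.MoveNext())
                {
                    T item = current.Current;
                    yield return item;

                    // Descend: the next items come from this item's children.
                    stack.Push(childrenMappingFunction(item).GetEnumerator());
                }
                else
                {
                    // This level is exhausted: dispose it and resume its parent level.
                    stack.Pop().Dispose();
                }
            }
        }
        finally
        {
            // Also runs if the caller stops early (e.g. First() or Any() disposes
            // the iterator), closing any enumerators still on the stack.
            while (stack.Count > 0)
            {
                stack.Pop().Dispose();
            }
        }
    }
}
```

With this design the stack grows with the depth of the tree rather than with the number of siblings still waiting to be visited. The trade-off is that it relies on every child sequence being safe to hold open while its descendants are enumerated.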

Let’s look at a use case:

Given a .NET Windows Forms TreeView, find all nodes in the entire tree whose text starts with "E".

Because TreeNode.Nodes is a TreeNodeCollection, which implements only the non-generic IEnumerable and not IEnumerable<TreeNode>, we need a small helper function to use it with the deep enumerator.

    // Adapts the non-generic TreeNodeCollection to IEnumerable<TreeNode>.
    private static IEnumerable<TreeNode> IterateNodes(TreeNodeCollection treeNodes)
    {
        foreach (TreeNode childNode in treeNodes)
        {
            yield return childNode;
        }
    }

    TreeView treeView = new TreeView();
    foreach (TreeNode treeNode in
        DeepEnumerator.Enumerate(
            IterateNodes(treeView.Nodes),
            node => IterateNodes(node.Nodes)))
    {
        if (treeNode.Text.StartsWith("E"))
        {
            // Do something with the matching node.
        }
    }
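As an aside, LINQ's Enumerable.Cast<TResult>() adapts any non-generic IEnumerable, so node => node.Nodes.Cast<TreeNode>() could take the place of the IterateNodes helper. The sketch below shows the same idea without a Windows Forms dependency, using ArrayList as a stand-in for TreeNodeCollection (neither implements IEnumerable<T>); the CastAdapterDemo and AsGeneric names are hypothetical.

```csharp
using System.Collections;
using System.Collections.Generic;
using System.Linq;

public static class CastAdapterDemo
{
    // Does the same job as IterateNodes for a non-generic collection:
    // Cast<T> lazily converts each element to T as it is enumerated
    // (and throws InvalidCastException for an element that is not a T).
    public static IEnumerable<string> AsGeneric(ArrayList legacyItems)
    {
        return legacyItems.Cast<string>();
    }
}
```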

For a second example, let's define a slightly cleaner interface:

    private interface ITextNode
    {
        string Text { get; }
        IEnumerable<ITextNode> Children { get; }
    }

What if we wanted to find all nodes in a List<ITextNode> where the node itself or at least one of its descendants has text starting with the letter E?

Using LINQ, this is pretty straightforward:

    List<ITextNode> textNodes = new List<ITextNode>();
    var nodesWithOneOrMoreChildrenStartingWithLetterE =
        from textNode in textNodes
        let children = DeepEnumerator.Enumerate(textNode, node => node.Children)
        let countWithLetterE = children.Where(node => node.Text.StartsWith("E")).Count()
        where countWithLetterE > 0
        select textNode;
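Because the enumeration is lazy, the query can be tightened: Any() stops at the first match, whereas Where(...).Count() walks every node of every tree. The sketch below is self-contained, so it repeats the ITextNode interface (made public) and uses a hypothetical TextNode class and a minimal Walk function in place of DeepEnumerator.Enumerate; the Visited counter is there only to show how many nodes actually get touched.

```csharp
using System.Collections.Generic;
using System.Linq;

// Copy of the ITextNode interface above, made public so the sketch compiles on its own.
public interface ITextNode
{
    string Text { get; }
    IEnumerable<ITextNode> Children { get; }
}

// Hypothetical concrete node, used only for this example.
public sealed class TextNode : ITextNode
{
    public TextNode(string text, params ITextNode[] children)
    {
        Text = text;
        Children = children;
    }

    public string Text { get; }
    public IEnumerable<ITextNode> Children { get; }
}

public static class TextNodeQueries
{
    // Number of nodes popped off the stack so far; only here to show that Any() stops early.
    public static int Visited;

    // Minimal stand-in for DeepEnumerator.Enumerate(root, node => node.Children).
    private static IEnumerable<ITextNode> Walk(ITextNode root)
    {
        var workStack = new Stack<ITextNode>();
        workStack.Push(root);

        while (workStack.Count > 0)
        {
            ITextNode node = workStack.Pop();
            Visited++;
            yield return node;

            foreach (ITextNode child in node.Children)
            {
                workStack.Push(child);
            }
        }
    }

    // Any() returns as soon as one node matches, so the lazy walk of that tree
    // is abandoned instead of being run to the end as Where(...).Count() would.
    public static List<ITextNode> NodesWithNodeStartingWithE(IEnumerable<ITextNode> textNodes)
    {
        return textNodes
            .Where(textNode => Walk(textNode).Any(node => node.Text.StartsWith("E")))
            .ToList();
    }
}
```

For a tree whose root is "Elm", Any() succeeds on the very first node, so none of its children are ever pushed or visited.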

These are just a few examples. I hope you find this class useful. Deep tree enumeration is something I find myself doing fairly often, and this class has eliminated a lot of redundant code for me.
