Roger Ngo's Website

Personal thoughts about life and tech written down.

How to Get Good at Recursion

Motivation

Just do it. Do it a lot and it will eventually come. :)

Actually, first we must understand how functions are called and how they are arranged on a stack. Then we must realize that a callee always returns to its caller, provided the function is written correctly.

Okay, so there is lots to keep in mind here, but it is brute force practice. Trace every single step and you'll see.

In this post, I will try to give you some tips on how to understand recursive functions better, and some approaches I have used in the past to make things a bit less confusing.

Warm Up

A recursive function is a function that solves a problem from the top down. What is meant by this is that to obtain the solution, we must repeatedly call the function to solve a smaller (simpler) problem until a base case, a case whose answer is already known, stops these calls. When the base cases are hit and the calls are stopped, the results of the simpler problems are returned back to the callers to be used to construct the full solution.

A recursion stops when a base case is reached. A base case is, in other words, a condition that tells the recursive function to stop calling itself and start popping frames off the stack, which in turn returns each result back to its caller (the previous recursive call). As previously stated, the answer for the base case is already known, so it is defined outside the recurrence.

First Example: Factorial

For example, we define n! to be (n) * (n-1) * (n-2) * ... * (1).

If we wanted to solve 3!, the typical work would look like this:

3 * 2 * 1 = 6

We can generalize this and formulate a solution as follows:

f(0) = 1, f(1) = 1
f(n) = n * f(n - 1)

The base cases of the recursion are f(0) = 1 and f(1) = 1. So if our parameter is 0, our function gives us the result of 1, and if our parameter is 1, our function also gives us 1.

Now, suppose we want to solve f(2). We can use the function along with the base case to see how we get our result:

f(2) = 2 * f(2 - 1)
= 2 * f(1)
= 2 * 1
= 2

We have now shown that we can not only solve f(2), but also use this information to solve f(3):

f(3) = 3 * f(3 - 1)
= 3 * f(2)
= 3 * [2 * f(2 - 1)]
= 3 * [2 * f(1)]
= 3 * 2 * 1
= 6

Our factorial function, when solved iteratively (using loops) can be like this:

public int factorial(int n) {
 int result = 1;
 // multiply by the loop counter i: n, n - 1, ..., 1
 for(int i = n; i > 0; i--) {
  result *= i;
 }
 return result;
}

Tracing the loop for n = 3, we can see it is similar to our definition:

i = 3, result = 1
result = 1 * 3
= 3

i = 2, result = 3
result = 3 * 2
= 6

i = 1, result = 6
result = 6 * 1
= 6

i = 0, break

return: 6

Now, let's rewrite our function in a recursive manner:

public int factorial(int n) {
 if(n <= 1) {
  return 1;
 }
 else {
  return n * factorial(n - 1);
 }
}

Unsurprisingly, given how many times we have solved this, when running the function as factorial(3), we will receive the result of 6.

Here is a trace of what happens:

Initial Call (1): factorial(3)
-----> return 3 * factorial(2)
----------> Sub-Call (1.1): factorial(2)
---------------> return 2 * factorial(1)
--------------------> Sub-Call (1.2): factorial(1)
-------------------------> return 1 (Base Case)
--------------------> Return Call (1.1):
---------------> return 2 * 1
----------> Return Call (1):
-----> return 3 * 2
Return: 6
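If you would rather have the computer draw this trace for you, here is a minimal sketch. The class name FactorialTrace and the extra depth and log parameters are my own additions for illustration; they are not part of the factorial function above. Each call writes one line when it starts and one line when it returns, indented by how deep in the recursion it is.

```java
class FactorialTrace {
    // Hypothetical helper: computes n! recursively and records
    // one line per call and one line per return into the log.
    static int factorial(int n, int depth, StringBuilder log) {
        StringBuilder indent = new StringBuilder();
        for (int k = 0; k < depth; k++) {
            indent.append("  "); // two spaces per level of recursion
        }
        log.append(indent).append("call factorial(").append(n).append(")\n");
        int result;
        if (n <= 1) {
            result = 1; // base case
        } else {
            result = n * factorial(n - 1, depth + 1, log); // recursive case
        }
        log.append(indent).append("return ").append(result).append("\n");
        return result;
    }

    public static void main(String[] args) {
        StringBuilder log = new StringBuilder();
        int answer = factorial(3, 0, log);
        System.out.print(log);
        System.out.println("answer = " + answer);
    }
}
```

Running main for factorial(3) should show the calls nesting inward until the base case, and then the returns unwinding in the opposite order.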

How does this happen?

Before I dive further, let's define some terms:

  • A function that calls another function is the caller.
  • A function being called is the callee.

From here forward, I will use the terms function, method and procedure interchangeably. Although it is not technically correct, it will help in terms of explaining our ideas.

Function calls use the stack to keep track of data. Under the hood, when a function invokes another, a data structure called a stack frame is created with the information about the function being called. This frame contains all the local variables that the function operates on.

Since a stack is a LIFO (last in, first out) data structure, the last element pushed onto the stack is the first to pop off. So, the most recent call to the recursive function is always on top of the call stack.

If we were to call factorial(3), our call stack will initially look like:

factorial(3)

If we look at the code for the recursive factorial function, we can see that another call to factorial itself is made with factorial(2). Now, we push that call onto the call stack.

factorial(2)
factorial(3)

As we keep going, we make another call to factorial with factorial(1).

factorial(1)
factorial(2)
factorial(3)

When a caller calls the callee, a frame is created for the callee and is pushed onto the stack, on top of the current caller's frame.

But as we call factorial(1), we have hit the base case, which in turn runs the line:

return 1;

We no longer call a new function, but the return statement tells us to pop the frame out of the call stack with the value specified. Therefore, the function call for factorial(1) is completed and the result is returned back to factorial(2). When factorial(2) receives the result, the call now looks like: return 2 * 1;

This gives us the result returned to be 2 for this particular call. And now, factorial(3) gets this result and its return statement is essentially:

return 3 * 2;

=> 6

Finally, the call stack is empty and we get the final result of 6.
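To see the push and pop behavior without relying on the real call stack, here is a rough sketch that imitates it with an explicit stack. This is only a model, assuming each "frame" needs to remember nothing but its own n; the class name ExplicitStackFactorial is made up for this example.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class ExplicitStackFactorial {
    static int factorial(int n) {
        Deque<Integer> stack = new ArrayDeque<>();
        // "Call" phase: push one pending frame per call
        // until we reach the base case (n <= 1).
        while (n > 1) {
            stack.push(n);
            n--;
        }
        int result = 1; // the value returned by the base case
        // "Return" phase: pop frames in LIFO order. Each frame
        // multiplies its own n by the value returned to it.
        while (!stack.isEmpty()) {
            result = stack.pop() * result;
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(factorial(3));
    }
}
```

For factorial(3), the stack holds 3 and then 2 before the base case is reached. The frames then pop off as 2 * 1 and 3 * 2, which mirrors the walkthrough above.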

The Base Case

As we can see here, recursion is just a function call which results in the same function call with different parameters over and over again until it hits a base case. Remember, the base case is what stops the recursive sequence and causes results to propagate back to the original caller.

This is no different than the following:

public static String SayHi() {
 return SayHello();
}

public static String SayHello() {
 return SayBye();
}

public static String SayBye() {
 return "No, don't leave just yet!";
}

public static void main(String[] args) {
 System.out.println(SayHi());
}

Can you see what happens above?

We first execute SayHi() in main, which executes SayHello(), which then executes SayBye(). SayBye() then returns a string and NOT a function call, which means SayBye()'s return statement is equivalent to a "base case".

We would then get this output:

No, don't leave just yet!

But what happens if SayBye() didn’t return a string, and instead called SayHi() again?

public static String SayHi() {
 return SayHello();
}

public static String SayHello() {
 return SayBye();
}

public static String SayBye() {
 return SayHi();
}

public static void main(String[] args) {
 System.out.println(SayHi());
}

We would then fall into an infinite loop! What happens here is that SayHi() will run SayHello() which runs SayBye() which then runs SayHi() again, and then executes SayHello() which then runs SayBye(), which then invokes SayHi() again, and then... ok, ok you get the point!

What happens here is that you will get a stack overflow! The stack keeps getting these function calls pushed onto itself and eventually we will run out of stack memory because we will be invoking these functions non-stop to no end!

This is what happens when you do not define a BASE CASE for a recursive function. A function without a base case will result in a stack overflow because there is literally no way of stopping the chain of function calls once it is invoked.

Therefore, when we design a recursive function, we will always need to have at minimum two things:

  1. The base case.
  2. The general (recursive) case.
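If you want to watch a missing base case blow up safely, here is a small sketch. In Java, running out of stack space raises a StackOverflowError. I catch it here purely for demonstration, which is not something you would normally do in real code. The class name MissingBaseCase and the overflows helper are made up for this example.

```java
class MissingBaseCase {
    static String sayHi()    { return sayHello(); }
    static String sayHello() { return sayBye(); }
    static String sayBye()   { return sayHi(); } // no base case: calls sayHi() again

    // Hypothetical helper: reports whether the call chain overflowed the stack.
    static boolean overflows() {
        try {
            sayHi();
            return false; // never reached, the chain of calls never ends on its own
        } catch (StackOverflowError e) {
            return true;  // the call stack ran out of room
        }
    }

    public static void main(String[] args) {
        System.out.println("Stack overflow: " + overflows());
    }
}
```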

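Another noteworthy concept is that the recursive call does not have to appear in the return statement itself. You can store the result of the recursive call in a variable and return that variable at the end of the function. Take for example:

public int SumFirst5(int curr_sum, int i) {
 if(i == 6) {
  // base case, do nothing.
 }
 else {
  curr_sum += i;
  // keep the total computed by the deeper calls
  curr_sum = SumFirst5(curr_sum, i+1);
 }
 return curr_sum;
}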

Could have just been written as:

public int SumFirst5(int curr_sum, int i) {
 if(i == 6) {
  return curr_sum;
 }
 // add i to the running total before recursing
 return SumFirst5(curr_sum + i, i+1);
}

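Either is correct, but the latter is written in tail-recursive form: the recursive call is the last thing the function does. Some languages and compilers can optimize such calls so that they reuse the current stack frame, although Java compilers typically do not.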
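To show why the tail-recursive shape matters, here is a sketch of the same sum written once as a tail call and once as a plain loop. The names SumFirstFive, sumRecursive and sumLoop are mine, and the loop is a hand conversion for illustration, not something the Java compiler does for you.

```java
class SumFirstFive {
    // Tail-recursive form: nothing is left to do after the recursive call returns.
    static int sumRecursive(int currSum, int i) {
        if (i == 6) {
            return currSum; // base case
        }
        return sumRecursive(currSum + i, i + 1);
    }

    // The same computation as a loop: the parameters become
    // local variables that are updated on each pass.
    static int sumLoop() {
        int currSum = 0;
        int i = 1;
        while (i != 6) {
            currSum += i;
            i++;
        }
        return currSum;
    }

    public static void main(String[] args) {
        System.out.println(sumRecursive(0, 1)); // 1 + 2 + 3 + 4 + 5
        System.out.println(sumLoop());
    }
}
```

Both versions add up 1 through 5 and return 15. The loop just does it in a single stack frame.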

The next thing I would like to mention is that you can always formulate an iterative solution from a recursive solution and vice versa. Sometimes you would prefer an iterative solution if you have memory/performance constraints, as recursive solutions will use more memory on the stack.

On the flip side, sometimes it may be worth implementing a problem recursively as an iterative implementation can be much more challenging. There could be more room for bugs.

Know the trade-offs. Depending on your application and your environment, you will need to weigh both options.

The Second Example: Fibonacci Sequence

The Fibonacci sequence is a great example of how expensive some recursive computations can become.

To cut to the chase, here is the recurrence relation:

f(0) = 0
f(n) = 1, for n = 1 and n = 2
f(n) = f(n-1) + f(n-2), for n > 2

With the recurrence relation above, we have defined the Fibonacci number to be 0 when n is 0 and 1 when n is 1 or 2. At a quick glance we can see that these are, in fact, our base cases.

The first ten values of the sequence, f(1) through f(10), are shown below:

1, 1, 2, 3, 5, 8, 13, 21, 34, 55

We can write a recursive function to solve this problem for any n whose result fits in a 64-bit integer:

public long fibonacci(long n) {
 if(n <= 0) {
  return 0; // f(0) = 0
 }
 if(n <= 2) {
  return 1; // f(1) = f(2) = 1
 }
 return fibonacci(n - 1) + fibonacci(n - 2);
}

To flex our tracing muscles, let’s evaluate this function using n = 4.

Initial Call (1): fibonacci(4)
-----> return fibonacci(3) + fibonacci(2)
----------> Call (1.1): fibonacci(3)
---------------> return fibonacci(2) + fibonacci(1)
--------------------> Call (1.1.1): fibonacci(2)
-------------------------> return 1
--------------------> Call (1.1.2): fibonacci(1)
-------------------------> return 1
---------------> Call 1.1 Return: return 1 + 1
----------> Call (1.2): fibonacci(2)
---------------> return 1
---------> Call 1 Return: return 2 + 1
-----> 3

Our result returned from fibonacci(4) is 3.

I highly recommend practicing a few traces and things will begin to click.

Recursive Complexity

Things will get easier over time the more you perform these traces. As you vary your input n, you will soon notice that the running time of our function grows rather quickly as n increases.

Without going into too much Big-O complexity analysis here, I will just point out that if n falls within the interval of [0, 2], then our operations in time are O(1). That is, they are constant.

T(N) = O(1), N = [0, 2]

However for all n > 2, we will then have:

T(N) = T(N - 1) + T(N - 2)

If we take each term and expand it, for let's say T(N - 1) we will see we make more calls:

T(N - 1) = T((N - 1) - 1) + T((N - 1) - 2)

We can deduce the same will happen for T(N-2). But going back to T(N-1)...

We will continue to expand.

T((N - 1) - 1) = T(((N - 1) - 1) - 1) + T(((N - 1) - 1) - 2)

So each call makes two more calls, each of which in turn makes two more calls of its own, and so on until the base cases are reached. If we simply count the calls level by level, we can see that there are at most:

1 + 2 + 4 + ... + 2^(N - 1)

Summing this, we can see that the total number of calls is less than 2^N, so it grows exponentially with N... This puts the recursive fibonacci call within O(2^n) in complexity.

Yikes! That will take a long time for larger values of N! If you really want to see this in demonstration and have a decently modern computer, you can compute fibonacci(48) and see how long that takes you.

Regardless, we now know that recursive solutions alone are not the silver bullet. We have to consider TIME and SPACE. Thankfully, there are techniques to get around this and create much faster recursive solutions, such as dynamic programming, which we will discuss another day. For now, just note that there can be a performance penalty!
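To put rough numbers on that penalty, here is a sketch that counts how many calls the plain recursive version makes, next to a version that caches results it has already computed (memoization, one common dynamic programming technique). The class name FibonacciCost, the counters and the HashMap cache are all my own choices for this illustration, so treat it as one possible approach and not the definitive one.

```java
import java.util.HashMap;
import java.util.Map;

class FibonacciCost {
    static long naiveCalls = 0; // number of calls made by naive()
    static long memoCalls = 0;  // number of calls made by memo()

    // Plain recursion: recomputes the same subproblems many times.
    static long naive(int n) {
        naiveCalls++;
        if (n == 0) return 0;
        if (n <= 2) return 1;
        return naive(n - 1) + naive(n - 2);
    }

    // Memoized recursion: each f(n) is computed once and then looked up.
    static long memo(int n, Map<Integer, Long> cache) {
        memoCalls++;
        if (n == 0) return 0;
        if (n <= 2) return 1;
        Long known = cache.get(n);
        if (known != null) {
            return known; // already solved, no further recursion
        }
        long value = memo(n - 1, cache) + memo(n - 2, cache);
        cache.put(n, value);
        return value;
    }

    public static void main(String[] args) {
        long a = naive(30);
        long b = memo(30, new HashMap<>());
        System.out.println(a + " computed with " + naiveCalls + " naive calls");
        System.out.println(b + " computed with " + memoCalls + " memoized calls");
    }
}
```

For n = 30, both versions return 832040, but the plain version needs more than a million calls to get there while the cached version needs fewer than a hundred.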

Third Example: Permutation of Strings

Our third problem is finding all the permutations of a given string. For example, abc will yield 6 different permutations:

cab
acb
abc
bac
cba
bca

The best way to think about this is to first look for our base case. Our base case should be when our permuted string reaches the length of our "starting string". In the abc example above, our permuted string should be outputted when its length reaches 3.

Now, how do we generate these permutations? A good approach is to start with a blank string and then for each character in our starting string, create a copy and add the next letter in each position in the string.

For example if we start with a blank string, we will need to add the letter a in every position of our currently "built" string. In this case, we only have 1 position to insert into, and that is the first index.

""
"a"

Now, after inserting the character into a position in the built string, we can use this built string to recursively call our permutation function with the next character as the character to insert. Our next character is "b", so the flow would look like:

""
Insert from 0 to 0.
"a"
Insert from 0 to 1
-> "ba"
-> "ab"

Pushing it a bit further, we now recursively call the permute function with "ba" and "ab" as built strings and the index of "c" from the starting string as the character to insert at every position:

-> "ba"
---> "cba"
-----> (return "cba", base case)
---> "bca"
-----> (return "bca", base case)
---> "bac"
-----> (return "bac", base case)
-> "ab"
---> "cab"
-----> (return "cab", base case)
---> "acb"
-----> (return "acb", base case)
---> "abc"
-----> (return "abc", base case)

As we can see here, after inserting the last character, we then recursively call again for the next character in our starting string. Since we have reached the end of the starting string where our built string is now the same length as our starting string, we reach the base case and thus output our result. The function to permute the string can be expressed as this:

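// C#: requires "using System;" and "using System.Text;"
public void PermuteString(string starting, string built, int currStartingIndex) {
 // Base case: the built string is as long as the starting string.
 if(built.Length == starting.Length) {
  Console.WriteLine(built);
  return;
 }
 for(int i = 0; i <= built.Length; i++) {
  // Copy the built string, then insert the next character at position i of the copy.
  StringBuilder sb = new StringBuilder(built);
  sb.Insert(i, starting[currStartingIndex]);
  PermuteString(starting, sb.ToString(), currStartingIndex + 1);
 }
}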

The important thought here is that:

  1. We are inserting a character at every position in the current string we are building.
  2. We use the BUILT string (not the string that was passed into the function) to recursively generate a new string at a deeper level.
  3. The generation will stop when we have reached the base case.

I would suggest drawing the trace of this function out in order to understand how the recursion tree unfolds. It can be tough to wrap your head around at first, but once understood, permutation type problems become easier to understand!
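For anyone following along in Java, here is a sketch of the same insertion approach that collects the permutations into a list instead of printing them, which makes the result easier to inspect. The class name Permutations, the permutations wrapper and the out parameter are assumptions of mine and not part of the function above.

```java
import java.util.ArrayList;
import java.util.List;

class Permutations {
    // Inserts starting.charAt(index) at every position of built, then recurses.
    static void permute(String starting, String built, int index, List<String> out) {
        if (built.length() == starting.length()) {
            out.add(built); // base case: a complete permutation
            return;
        }
        for (int i = 0; i <= built.length(); i++) {
            StringBuilder sb = new StringBuilder(built); // copy of the built string
            sb.insert(i, starting.charAt(index));        // insert into the copy
            permute(starting, sb.toString(), index + 1, out);
        }
    }

    // Hypothetical convenience wrapper that starts from an empty built string.
    static List<String> permutations(String starting) {
        List<String> out = new ArrayList<>();
        permute(starting, "", 0, out);
        return out;
    }

    public static void main(String[] args) {
        System.out.println(permutations("abc"));
        // prints [cba, bca, bac, cab, acb, abc]
    }
}
```

The order of the output follows the recursion tree: everything built from "ba" comes first, then everything built from "ab".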

Closing Thoughts

For most, recursion is not intuitive at first. For me, it is very unnatural to think recursively, but with a lot of patience, practice and time, it becomes much easier to interpret recursive functions and solutions.

Keep in mind that recursion can bring the advantage of having an elegantly written algorithm as compared to an iterative approach. However, with this elegance comes the expense of extra memory because each function call in a recursive flow will use some stack space. Over time, if one is not careful about terminating their recursive call with a "base case", the memory usage can become large enough to cause a stack overflow.

It is up to you to determine, on a case-by-case basis, which approach will work best as a solution to your problem.