Understanding Tail recursion in Scala

Tail recursion is a somewhat tricky concept in Scala and takes time to master completely. Before we get into tail recursion, let's look at recursion itself.
A recursive function is a function that calls itself. If some action is repetitive, the function can call the same piece of code again instead of repeating it. Recursion can be applied to many problems that you would otherwise solve with regular loops.
Here is a factorial program written with a regular loop:

[code lang="scala"]
def factorial(n: Int): Int = {
  var fact = 1
  for (i <- 1 to n)
    fact = fact * i // multiply the running product by each i
  fact // the last expression is the return value
}
[/code]

The same function can be rewritten with recursion as follows:

[code lang="scala"]
def factorialWithRecursion(n: Int): Int =
  if (n == 0) 1                              // base case
  else n * factorialWithRecursion(n - 1)     // multiply after the call returns
[/code]

In the recursive approach, the function returns the expression n * factorialWithRecursion(n - 1). The multiplication can only be performed after the recursive call returns, so every call stays pending until the base case n == 0 is reached. This approach looks clean and concise, but it is risky when the input is large: each recursive call pushes one additional frame onto the JVM call stack, so a large enough input exhausts the stack and the program fails with a StackOverflowError.
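A minimal sketch of the problem follows. It uses a hypothetical sumTo function (not from the original post) instead of factorial, because adding Long values keeps the result exact long after an Int factorial would have overflowed; the shape, an operation that waits for the recursive call, is the same as in factorialWithRecursion. The helper trySumTo and the input sizes are illustrative assumptions, and the exact depth at which the stack runs out depends on the JVM's thread stack size (-Xss).

```scala
object StackDepthDemo {
  // Not tail-recursive: `n + ...` can only run after the recursive call
  // returns, so each pending call keeps its own frame on the JVM stack.
  // sumTo(3) evaluates as 3 + (2 + (1 + 0)).
  def sumTo(n: Long): Long =
    if (n == 0) 0L
    else n + sumTo(n - 1)

  // Hypothetical helper: returns None instead of crashing when the stack runs out.
  def trySumTo(n: Long): Option[Long] =
    try Some(sumTo(n))
    catch { case _: StackOverflowError => None }

  def main(args: Array[String]): Unit = {
    println(trySumTo(1000L))     // Some(500500): small input fits on the stack
    println(trySumTo(10000000L)) // very large input: overflows with default stack sizes
  }
}
```

The only difference between the two calls in main is the recursion depth; the code path is identical, which is exactly why the failure only shows up for large inputs.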

To avoid this StackOverflowError, we can use tail recursion.
A function is tail-recursive when the last action it performs is the recursive call. In other words, a recursive function is tail-recursive if the recursive call is the last thing done by the function and its result is returned unchanged. Because nothing remains to be done after the call returns, there is no need to keep a record of the previous state. The Scala compiler takes advantage of this and rewrites such a self-recursive tail call into a loop (the JVM itself does not eliminate tail calls), so a tail-recursive function does not build a new stack frame for each call; all calls execute in a single frame.
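To make the "single frame" idea concrete, the sketch below pairs a tail-recursive sum (a hypothetical example, with an accumulator parameter acc) with a hand-written while loop. The loop is only an approximation of what scalac generates for a self-recursive tail call, not its literal output, and both functions assume n >= 0.

```scala
import scala.annotation.tailrec

object TailCallShape {
  // Tail-recursive: the recursive call is the last action and its result is
  // returned unchanged, so nothing from the current call is needed afterwards.
  // Assumes n >= 0.
  @tailrec
  def sumTo(n: Long, acc: Long = 0L): Long =
    if (n == 0) acc
    else sumTo(n - 1, acc + n)

  // Hand-written loop approximating what scalac generates for sumTo:
  // the parameters are reassigned and control jumps back to the start,
  // reusing a single stack frame.
  def sumToLoop(n0: Long): Long = {
    var n = n0
    var acc = 0L
    while (n != 0) {
      acc += n
      n -= 1
    }
    acc
  }

  def main(args: Array[String]): Unit = {
    println(sumTo(10000000L))     // 50000005000000, no StackOverflowError
    println(sumToLoop(10000000L)) // 50000005000000
  }
}
```

Because the addition now happens before the call (it is folded into acc), the recursive version handles inputs in the millions just like the loop.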

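To implement tail recursion in Scala, you usually need two things:
1. A helper function whose recursive call is, by definition, the last call in the function. It typically carries an accumulator parameter that holds the intermediate result.
2. The @tailrec annotation (scala.annotation.tailrec), which marks the function as tail-recursive. It is not mandatory, but it makes the compiler verify the function: compilation fails if the annotated function is not tail-recursive or cannot be optimized (for example, a public method of a class that can be overridden).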
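As a further illustration of these two ingredients, here is a sketch applying the helper-plus-@tailrec pattern to Fibonacci numbers, which is not part of the original post. The names fib, loop and fibNaive are hypothetical, and BigInt is used so that larger results do not overflow.

```scala
import scala.annotation.tailrec

object FibonacciTailRec {
  // Public entry point keeps a simple signature; the helper carries the state.
  def fib(n: Int): BigInt = {
    // Helper: a and b are the two most recent Fibonacci numbers,
    // remaining counts how many steps are left.
    @tailrec
    def loop(remaining: Int, a: BigInt, b: BigInt): BigInt =
      if (remaining == 0) a
      else loop(remaining - 1, b, a + b) // tail call: nothing left to do afterwards

    loop(n, BigInt(0), BigInt(1))
  }

  // NOT tail-recursive: the addition runs after both recursive calls return.
  // It compiles without the annotation, but adding @tailrec here would make
  // the compiler reject it.
  def fibNaive(n: Int): BigInt =
    if (n < 2) BigInt(n) else fibNaive(n - 1) + fibNaive(n - 2)

  def main(args: Array[String]): Unit = {
    println(fib(10)) // 55
    println(fib(20)) // 6765
  }
}
```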

Below is the factorial function rewritten with the tail-recursion approach:

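[code lang="scala"]
import scala.annotation.tailrec

def factorial(n: Int): Int = {
  // Using tail recursion: acc carries the running product
  @tailrec
  def factorialAcc(acc: Int, n: Int): Int =
    if (n <= 1) acc                      // base case: return the accumulated result
    else factorialAcc(n * acc, n - 1)    // the recursive call is the last action

  factorialAcc(1, n)
}
[/code]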
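One caveat with the example above: it avoids the StackOverflowError, but because it uses Int, the result silently overflows for n >= 13, since 13! = 6227020800 is larger than Int.MaxValue (2147483647). A possible variant, sketched below under the assumption that exact results for large n are wanted, switches the accumulator to BigInt; the object name BigFactorial is hypothetical.

```scala
import scala.annotation.tailrec

object BigFactorial {
  // Same accumulator-based helper as the Int version, but with BigInt so the
  // result does not wrap around once n! exceeds Int.MaxValue (from n = 13 on).
  def factorial(n: Int): BigInt = {
    require(n >= 0, "n must be non-negative")

    @tailrec
    def factorialAcc(acc: BigInt, n: Int): BigInt =
      if (n <= 1) acc
      else factorialAcc(acc * BigInt(n), n - 1)

    factorialAcc(BigInt(1), n)
  }

  def main(args: Array[String]): Unit = {
    println(factorial(13)) // 6227020800
    println(factorial(20)) // 2432902008176640000
  }
}
```

Since the helper is tail-recursive, even factorial(5000) runs in a single stack frame; the cost moves to the size of the BigInt numbers instead.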


An ambivert, music lover, enthusiast, artist, designer, coder, gamer and content writer. He is a professional software developer with hands-on experience in Spark, Kafka, Scala, Python, Hadoop, Hive, Sqoop, Pig, PHP, HTML and CSS. Know more about him at



24 Tutorials