Advanced Flow Control
Recursion
One very common programming technique, and one you will see in the wild, is recursion: calling a function from within itself. Perhaps the most common example is the Fibonacci numbers:
def fibonacci(n):
    if n==0:
        return 0
    if n==1:
        return 1
    return fibonacci(n-1)+fibonacci(n-2)
Already this should look familiar to you - recursion, philosophically, is just programmer-speak for induction. The main concerns are all there:
- A Base Case - all recursive functions should eventually terminate in some non-recursive state.
- An Induction Step - where we call the function on some data that is in some sense "closer" to the base case.
Recursion can provide some excellent benefits, which should be easy to imagine for a mathematician: it allows you to write just one step in a longer stream, and can allow you to store the values of central computations to which less common values converge.
However, there are some problems when translating this technique to the real world.
Repeated Computations
Note how our fibonacci program calls fibonacci twice for each $n>1$. This means that, naively, each step is run a Fibonacci number of times rather than once: for the 5th Fibonacci number, the call counts are:
n | calls |
---|---|
5 | 1 |
4 | 1 |
3 | 2 |
2 | 3 |
1 | 5 |
0 | 3 |
This is manageable, but what about the 11th Fibonacci number? At that point we are calling fibonacci(1) $89$ times.
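One way to check these counts is to instrument the naive function. The sketch below is only an illustration: the names counted_fibonacci and call_counts are made up for it and are not part of the course code. It tallies how often each argument is seen.

```python
from collections import Counter

# Hypothetical helper: tallies how many times each argument is seen.
call_counts = Counter()

def counted_fibonacci(n):
    call_counts[n] += 1
    if n <= 1:
        return n
    return counted_fibonacci(n - 1) + counted_fibonacci(n - 2)

counted_fibonacci(5)
for n in sorted(call_counts, reverse=True):
    print(n, call_counts[n])
```

Clearing call_counts and running counted_fibonacci(11) should leave call_counts[1] at 89, matching the count quoted above.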
In the programming world, we have to worry about efficiency. For recursively defined functions, that often means a technique called memoization.
Memoization For Recursion
The provided example - that of the Fibonacci numbers - is remarkably inefficient; we can dramatically increase its efficiency by caching results.
This kind of caching is called memoization, and there are countless ways to do it.
Fundamentally, the idea is:
# An external, persistent cache.
fibonacci_cache = {}
def fibonacci(n):
    if n <= 1: return n
    # Checking the cache for stored values.
    if n in fibonacci_cache:
        return fibonacci_cache[n]
    result = fibonacci(n-1)+fibonacci(n-2)
    fibonacci_cache[n] = result
    return result
This is a dramatically faster implementation, trading some storage space and lookup time in exchange for an implementation which calls fibonacci at most twice for each value of n - as expected.
With some rational thought (that's what we're trained for!), we can realize that we don't actually need to store all the Fibonacci numbers - just the last two.
In Python, there is a standard library tool for this: the decorator lru_cache.
Decorators are functions on the space of functions (everything is an object!) which change their behavior.
In this case, lru_cache stores the results for the most recently used distinct inputs, up to a configurable limit, maxsize, between 0 (the degenerate case) and infinity (represented by a limit of None). This is a tradeoff between memory and speed; often, however, you will have a "main sequence" which is most important, and a reasonably sized lru_cache will keep most of what is needed without keeping much of the chaff.
There are other caches, and you can read about them on your own, but lru_cache (along with functools.cache, a shorthand for an unbounded lru_cache added in Python 3.9) is the only one in core Python as of this writing.
This can lead to a very easy modification:
from functools import lru_cache
@lru_cache(maxsize=2)
def fibonacci(n):
    if n <= 1: return n
    return fibonacci(n-2)+fibonacci(n-1)
And now we have a very efficient recursive fibonacci implementation.
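To check that claim rather than take it on faith, functions wrapped by lru_cache expose a cache_info() method that reports hits and misses. The following is a small sketch that repeats the definition above so it runs on its own; the choice of 30 as an input is arbitrary.

```python
from functools import lru_cache

@lru_cache(maxsize=2)
def fibonacci(n):
    if n <= 1:
        return n
    # With maxsize=2 the call order matters: n-2 goes first so that
    # both needed values are still cached when n-1 is computed.
    return fibonacci(n - 2) + fibonacci(n - 1)

print(fibonacci(30))           # 832040
print(fibonacci.cache_info())  # hits, misses, maxsize and current size
```

The total of hits and misses should stay proportional to n, instead of growing like the Fibonacci numbers themselves.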
Unfortunately, for high intensity computing, this is still wrong.
Stack Limits
In a lower level language like C, recursive implementations can be very efficient even for large numbers. This is often due to compile-time shenanigans converting them to loops.
Python functions are objects, however, and thus contain a lot of overhead. Each time we call a function it adds another layer to the Call Stack. To prevent Stack Overflow, the Call Stack is limited by a Maximum Recursion Depth.
Like a lot of system information, it can be accessed through the sys module:
import sys
sys.getrecursionlimit()
If your function is well within that bound, and efficiently designed, then recursion is a fine answer. However, many mathematicians use recursively defined objects with orders of magnitude more steps.
We can see that our relatively efficient fibonacci algorithm can handle values approaching that limit with ease, but it isn't as fast as it could be and has a seemingly arbitrary cutoff.
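To see the cutoff in action, here is a minimal sketch using a made-up function, countdown, which uses one stack frame per step. It is not part of the course code; it only shows that Python raises a RecursionError once the depth passes the limit.

```python
import sys

def countdown(n):
    # Hypothetical example: one stack frame per remaining step.
    if n == 0:
        return 0
    return 1 + countdown(n - 1)

limit = sys.getrecursionlimit()
print(countdown(limit // 2))  # comfortably inside the limit

try:
    countdown(limit + 10)
except RecursionError as error:
    print("Too deep:", error)
```

The limit can be raised with sys.setrecursionlimit, but that only moves the cutoff; it does not remove the per-call overhead.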
This is where we need to turn our thinking around. Instead of constructing fibonacci from the top down, we can build it from the ground up using iteration:
def fibonacci(n):
    # A quick check for base cases is often nice.
    if n<=1: return n
    # Python allows tuple assignment of several variables at once.
    current,previous = 1,0
    for _ in range(n-1): # _ is the convention for an unused loop variable.
        # The right-hand side is evaluated first, so values can be
        # used to redefine *themselves* consistently.
        current,previous = current+previous, current
    return current
Now we can construct Fibonacci numbers into the hundreds of thousands with ease, and even further.
For efficient computation of higher Fibonacci numbers - getting into 10 digits and beyond - you should consider going back to the theory. There are often surprising and sneaky ways of getting the results of recursion without going through all the intermediate junk.
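As one example of going back to the theory, the identities $F_{2k} = F_k(2F_{k+1} - F_k)$ and $F_{2k+1} = F_k^2 + F_{k+1}^2$ let us reach $F_n$ in roughly $\log_2 n$ steps. The sketch below (a technique often called fast doubling) is an illustration, not part of the original notes; the names fibonacci_pair and fibonacci_fast are made up for it.

```python
def fibonacci_pair(n):
    """Return the pair (F(n), F(n+1)) using the doubling identities."""
    if n == 0:
        return (0, 1)
    a, b = fibonacci_pair(n // 2)  # a = F(k), b = F(k+1), with k = n // 2
    c = a * (2 * b - a)            # F(2k)
    d = a * a + b * b              # F(2k+1)
    if n % 2 == 0:
        return (c, d)
    return (d, c + d)

def fibonacci_fast(n):
    return fibonacci_pair(n)[0]

print(fibonacci_fast(10))  # 55
```

This is still recursive, but the depth is only about $\log_2 n$, so the recursion limit is not a concern for any realistic input.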
Creating generators with yield
When constructing a sequence, you often want to use each value in turn - while always having the next value available to you.
This is one area where Python shines, but it can be a bit daunting. Python functions can return multiple answers, one at a time, with yield. The function keeps track of its own internal state, and continues execution when asked for a new value, until it hits another yield or runs out of things to do.
Play around with the following generator code to get a feel for how generators work:
def fibonacci_generator():
    yield 0
    current,previous = 1,0
    while True:
        yield current
        current,previous = current+previous, current

# Calling the function creates the generator object.
fibgen = fibonacci_generator()
next(fibgen)
next(fibgen)
[i for _,i in zip(range(5),fibgen)]
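Because this generator never stops, anything that tries to consume all of it (such as list) will hang. One standard way to take a finite piece is itertools.islice; the sketch below repeats the generator so that it runs on its own.

```python
from itertools import islice

def fibonacci_generator():
    yield 0
    current, previous = 1, 0
    while True:
        yield current
        current, previous = current + previous, current

# islice stops after ten values, so the infinite loop is never exhausted.
first_ten = list(islice(fibonacci_generator(), 10))
print(first_ten)  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
```

Each call to fibonacci_generator() creates a fresh generator with its own state, so this does not disturb any generator you have already started.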
Errors - What To Do When Things Go Wrong
We have already talked about how you can throw errors to warn people of bad inputs, but errors are far more versatile. In your mathematics, you may encounter degenerate cases.
For example, consider this integer discrete logarithm implementation:
def discrete_logarithm(base,value,modulus):
    r""" Computes a Discrete Logarithm $\ell$ of a given value $v$ with respect to a given base $b$ and modulus $m$ - such that $b^\ell \mod m = v$.

    Parameters
    ----------
    base: int
    value: int
    modulus: int

    Examples
    --------
    >>> discrete_logarithm(2,4,11)
    2
    >>> discrete_logarithm(2,5,11)
    4
    """
    base = base%modulus
    value = value%modulus
    logarithm = 0
    candidate = 1
    while True:
        if candidate == value:
            return logarithm
        logarithm += 1
        candidate *= base
        candidate %= modulus
This works well when a logarithm exists - but runs forever in other cases. That is bad.
However, no polynomial-time method is known for deciding whether a number is a power of $b$ for general $m$ before doing this computation. We need some way to detect when things have gone wrong.
A good idea is to never iterate forever. Come up with a reasonable bound for the number of iterations - using your mathematical brain - and stop if you exceed it. For example, in this case, the Fermat-Euler theorem states that, for $b$ coprime to $m$, the powers of $b$ cycle in no more than $\phi(m) \leq m-1$ values. Therefore we can modify the function:
def discrete_logarithm(base,value,modulus):
    base = base%modulus
    value = value%modulus
    logarithm = 0
    candidate = 1
    for _ in range(modulus-1):
        if candidate == value:
            return logarithm
        logarithm += 1
        candidate *= base
        candidate %= modulus
Now it doesn't run forever, but it also doesn't return a value or warn us if things go wrong. Let's add some information:
def discrete_logarithm(base,value,modulus):
    base = base%modulus
    value = value%modulus
    candidate = 1
    for logarithm in range(modulus-1):
        if candidate == value:
            return logarithm
        candidate *= base
        candidate %= modulus
    raise ValueError("{value} is not a power of {base} mod {modulus}".format(
        base=base,value=value, modulus=modulus))
Now we can add in a test case for this:
"""
>>> discrete_logarithm(2,15,21)
Traceback (most recent call last):
...
ValueError: 15 is not a power of 2 mod 21
"""
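For completeness, here is one way such docstring tests might be run, using the standard library doctest module. This is a sketch: it repeats the function so that it is self-contained, and it assumes the code lives in a file that is run directly as a script.

```python
import doctest

def discrete_logarithm(base, value, modulus):
    """
    >>> discrete_logarithm(2, 4, 11)
    2
    >>> discrete_logarithm(2, 15, 21)
    Traceback (most recent call last):
        ...
    ValueError: 15 is not a power of 2 mod 21
    """
    base = base % modulus
    value = value % modulus
    candidate = 1
    for logarithm in range(modulus - 1):
        if candidate == value:
            return logarithm
        candidate *= base
        candidate %= modulus
    raise ValueError("{value} is not a power of {base} mod {modulus}".format(
        base=base, value=value, modulus=modulus))

if __name__ == "__main__":
    # Runs every docstring example in this module and reports any failures.
    print(doctest.testmod())
```

The same examples can also be run from the command line with python -m doctest followed by the file name.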
For most composite $m$, we are going to want to stop far before $m-1$ iterations, so we can use a set to check if we have looped:
def discrete_logarithm(base,value,modulus):
    base = base%modulus
    value = value%modulus
    candidate = 1
    # A set gives fast membership checks.
    previous_values = set()
    for logarithm in range(modulus-1):
        if candidate == value:
            return logarithm
        if candidate in previous_values:
            break
        previous_values.add(candidate)
        candidate *= base
        candidate %= modulus
    raise ValueError("{value} is not a power of {base} mod {modulus}".format(
        base=base,value=value, modulus=modulus))
This is a slight tradeoff - and managing those tradeoffs is an informed skill in computing - but can lead to much faster rejection of logarithms that do not work.
Modern Discrete Logarithms
For the general discrete logarithm problem, without factoring (or even, in strange cases, without knowing) the modulus, the current standard is Baby Step Giant Step. Read up on this - perhaps try to implement it - to see how it embraces the concept of the "space-time tradeoff".
With factoring the modulus, there are a wide variety of spectacular algorithms. See chapter 5 of Crandall, R., & Pomerance, C. B. (2006). Prime numbers: a computational perspective (Vol. 182). Springer Science & Business Media.
Catching Errors
Say for example one step of an algorithm is to find the first base $b$, in a list, to which a given number is a power mod $m$.
We can do that with try/except blocks:
candidate_bases = [2,3,5]
target = 15
modulus = 21
for base in candidate_bases:
    try:
        dlog = discrete_logarithm(base,target,modulus)
    except:
        continue
    break # Stop at the first base that works.
This will ignore all errors - we can catch just the ValueError we threw:
candidate_bases = [2,3,5]
target = 15
modulus = 21
for base in candidate_bases:
    try:
        dlog = discrete_logarithm(base,target,modulus)
    except ValueError:
        continue
    break # Stop at the first base that works.
Custom Errors
But there are lots of sources of value errors. We can be more specific with a custom error.
Errors are classes - so we can subclass ValueError to get a new error:
class LogarithmBasisError(ValueError):
    pass
and now we can throw - and catch - our new, more specific error.
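Putting the pieces together might look like the sketch below. The helper first_working_base is a made-up name for illustration, and the logarithm function is repeated (raising the custom error) so that the example runs on its own.

```python
class LogarithmBasisError(ValueError):
    """Raised when the value is not a power of the base."""

def discrete_logarithm(base, value, modulus):
    base = base % modulus
    value = value % modulus
    candidate = 1
    for logarithm in range(modulus - 1):
        if candidate == value:
            return logarithm
        candidate *= base
        candidate %= modulus
    raise LogarithmBasisError(
        "{value} is not a power of {base} mod {modulus}".format(
            base=base, value=value, modulus=modulus))

def first_working_base(candidate_bases, target, modulus):
    # Hypothetical helper: returns (base, logarithm) for the first base
    # that works, or None if none of them do.
    for base in candidate_bases:
        try:
            return base, discrete_logarithm(base, target, modulus)
        except LogarithmBasisError:
            continue
    return None

print(first_working_base([2, 3, 5], 15, 21))  # (3, 6)
```

Because LogarithmBasisError subclasses ValueError, older code that catches ValueError keeps working, while new code can catch only this specific failure and let unrelated value errors surface.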
Worksheet
Move on to today's worksheet for more practice with recursion, and with not using it.