使用yield将任意Python递归函数改为非递归执行
递归调用在程序设计中相当常见,然而当使用Python递归调用处理较大规模的问题时,常常会遇到超出递归限制的问题。举个例子:
def recursive_add(x):
return 0 if x == 0 else x + recursive_add(x - 1)
print(recursive_add(1000))
这段代码用递归的方式实现数列的求和(仅用于举例,实际中不推荐用这种方法)。运行这段代码会触发一个异常:
Traceback (most recent call last):
File "D:\Work\Source_Codes\python\main.py", line 19, in <module>
print(recursive_add(1000))
^^^^^^^^^^^^^^^^^^^
File "D:\Work\Source_Codes\python\main.py", line 16, in recursive_add
return 0 if x == 0 else x + recursive_add(x - 1)
^^^^^^^^^^^^^^^^^^^^
File "D:\Work\Source_Codes\python\main.py", line 16, in recursive_add
return 0 if x == 0 else x + recursive_add(x - 1)
^^^^^^^^^^^^^^^^^^^^
File "D:\Work\Source_Codes\python\main.py", line 16, in recursive_add
return 0 if x == 0 else x + recursive_add(x - 1)
^^^^^^^^^^^^^^^^^^^^
[Previous line repeated 996 more times]
RecursionError: maximum recursion depth exceeded
这是因为Python限制了递归调用的层数。这个层数限制数值可以用sys.getrecursionlimit()查询,在很多平台上它的值是1000。
我们可以用sys.setrecursionlimit(limit)的方法来修改这个限制,比如sys.setrecursionlimit(10000)可以把最大递归层数改为10000,此时再运行上面的代码就能正常输出结果了。
但是修改递归调用层数限制并不总是一个可行的办法。如果你在写一个可被重用的库,在库内部代码中修改使用者的环境恐怕不见得是个好主意。
如果想在不修改系统限制的前提下正常运行多层递归的代码,通常需要把代码逻辑修改为非递归的处理方式。但是对于很多需求问题来说,递归的写法会更加简洁易懂,把递归的写法修改为非递归常常需要不小的工作量。
另一种选择是利用尾递归优化(Tail Recursion Elimination, TRE)。不幸的是,CPython解释器默认不进行尾递归优化[1];但另一方面,也有一些Python库使用一些技巧来达成了尾递归优化[2]。但无论如何,对于递归调用并不发生在return语句中的函数,尾递归优化解决不了问题。
本文介绍一种新的解决途径:只对递归代码进行极小改动(添加yield),就能把任何递归调用改为非递归执行。
yield介绍
先通过一个例子来看看yield的作用:
def func():
y = yield 1
return y
f = func()
try:
x = next(f)
print(x)
f.send(x + 1)
except StopIteration as e:
print(e.value)
当一个函数内部包含yield表达式时,这个函数的返回值自动变成一个生成器(generator)。注意f = func()这一行,这里看起来像是执行func函数中的代码,但实际上不会,它只是返回了一个包含了func函数参数及内容的生成器,但不执行func内部的任何一行代码。
这一点非常重要,我们将会利用这个特点来消除递归调用。
在之后的try语句块中执行了next(f),这个时候才会真正开始执行func中的代码。next会使func从入口开始运行,直到遇到yield表达式,此时yield后面的值会被作为next(f)的返回值赋给x,而func函数的执行暂停在了yield的位置。
f.send(x + 1)会把x + 1的值传入给func函数上次暂停的地方,并让func从上次暂停的位置继续往下执行。在func内部,x + 1作为yield表达式的值赋给了y,最后调用return y。
func中的return y并不像一般函数那样返回,而是抛出一个StopIteration异常,这个异常对象中的value属性包含了return后面的值。所以执行return之后会跳到except StopIteration as e:这里,最终打印出return的值。
对递归函数的改造
利用yield的这些功能,我们把前面的递归函数改造为如下:
def recursive_add(x):
return 0 if x == 0 else x + (yield recursive_add(x - 1))
def run_to_finish(gen):
call_stack = [gen]
return_value = None # 最近一层函数的返回值
while call_stack:
try:
if return_value is not None:
inner_call = call_stack[-1].send(return_value) # 将内层函数返回值传给外层
else:
inner_call = next(call_stack[-1]) # 函数刚开始执行或者不需要返回值
call_stack.append(inner_call)
except StopIteration as e: # 内层函数退出
del call_stack[-1]
return_value = e.value # 获取返回值
return return_value
print(run_to_finish(recursive_add(10000)))
这段代码中的recursive_add和第一次出现时唯一的区别是在递归调用的地方加上了一个yield关键字和一对括号(注意括号不能省)。这也是本文对其他任何一个递归函数的通用改造方法:把递归调用的地方改成yield,其他所有代码都不用动。
run_to_finish是用来执行改造后的递归函数的执行器,这个执行器是通用的,run_to_finish的这段代码可用于调用任何一个按本方法改造后的递归函数。你可以把run_to_finish的定义放到某个库中,要用的时候import它。
我们分析一下修改后代码的执行过程:
run_to_finish(recursive_add(10000))把一个生成器传给了run_to_finish执行器。回忆一下前面讲过当recursive_add内部有yield表达式时,调用recursive_add只会返回一个生成器,并不会执行函数内部的代码。
run_to_finish开始运行,它把这个生成器放到call_stack列表中。call_stack是记录函数“逻辑调用链”的数据结构,每层函数调用都会被记录在里面。它代替了一般递归方式中的调用栈。
接下来,只要call_stack不为空,就总是取出call_stack中最后一个生成器并让其继续执行。call_stack中的生成器是按照函数逻辑上的调用顺序存放的,最后一个生成器即为当前“最内层”的执行函数。
由于recursive_add已经被改造成了生成器,所以每次执行到yield recursive_add(x - 1)时,并不会再次进入recursive_add的函数体,而是返回一个新的生成器,这个生成器会被赋给inner_call并保存到call_stack中。
return_value记录了前一个退出的函数的“返回值”(对于生成器来说,实际上是StopIteration异常对象中包含的value),如果这个返回值存在,用send方法将其传给上一层的生成器。
如果捕获到了StopIteration异常,则说明当前最内层的函数“退出”了,把call_stack中对应的生成器删除,并记录“返回值”传给它的“调用”函数。
当recursive_add的入参为0时,recursive_add内部代码并不会执行到yield表达式,但此时recursive_add仍然返回生成器,只不过这个生成器在第一次被调用next的时候就直接抛出异常(返回0)。
当call_stack中的所有生成器都完成执行时,得到的return_value即为递归算法的最终结果。
另一个例子:DFS算法遍历二叉树
下面我们看一个更有实际意义的例子:遍历二叉树的所有节点。如果不用递归的写法,遍历树和图会需要更复杂的实现。在这个例子中,代码逻辑是递归的,但实际执行时却并不产生递归的调用。这样我们同时获得了两方面的优点:简洁、可读性好,同时也不需要大的调用栈。
class Node:
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right
def dfs(root):
if root is None:
return
print(root.value) # 打印本节点
yield dfs(root.left) # 遍历左子树
yield dfs(root.right) # 遍历右子树
bin_tree = Node(1, Node(2, Node(3), Node(4)), Node(5))
run_to_finish(dfs(bin_tree))
尾注
[1] https://neopythonic.blogspot.com/2009/04/tail-recursion-elimination.html
[2] https://github.com/baruchel/tco
- 点赞
- 收藏
- 关注作者
评论(0)