高阶函数
高阶函数(Higher-Order Function,HOF)是函数式编程中的一个重要概念。一个函数可以被称为高阶函数,表示它或者接受一个或多个函数作为参数;或者返回一个函数。如果高阶函数接收其它函数作为参数,则它的行为在运行时,可以基于传递给它的函数进行动态调整。
Python 中有一些内置的高阶函数,包括函数式编程中经常使用到的 map(), filter(), 和 reduce() 等。其中 map(), filter() 虽然在函数式编程中使用非常频繁,但是在 Python 中,列表(或字典)推导式以及生成器表达式,在功能上基本可以取代 map() 和 filter()。在 Python 程序中,Pythora 星球居民更倾向于使用列表推导式和生成器表达式。但为了保证代码的可读性,列表推导式和生成器表达式一般适用于简单逻辑,复杂功能的实现还是更适合使用这些高阶函数。
map
基本用法
在介绍生成器表达式的时候,我们使用了一个简单的示例:假设我们有一个输入迭代器,其中包含了一组数据,我们希望生成一个新的迭代器,可以产生一组新数据,这组新数据中的每个数,对应原数列中每个数的平方值。使用生成器表达式可以写成如下程序。
numbers = range(10)
squared = (x*x for x in numbers)
for num in squared:
print(num)
使用函数式编程,对于这个问题,还有另一种解决方案:使用 map() 函数。map() 函数接受一个输入函数和一个可迭代对象作为输入,然后返回一个新的迭代对象。返回的迭代对象中的每个元素,是将输入迭代对象中每个元素传递给输入函数运行后的结果。使用 map() 函数改写上面的代码如下:
numbers = range(10)
squares = map(lambda x: x*x, numbers)
for num in squared:
print(num)
在上面的示例中,传递给 map 的参数,numbers 是需要被处理的可迭代对象,匿名函数 lambda x: x*x
则表示需要对可迭代对象中的每个元素做平方。
使用多个可迭代对象
当给 map() 函数提供多个可迭代对象时,它会并行处理这些对象。这意味着它会取每个可迭代对象的第一个元素,然后应用函数;接着取每个可迭代对象的第二个元素,再次应用函数,依此类推。比如:
a = [1, 2, 3]
b = [10, 20, 30]
summed = map(lambda x, y: x + y, a, b)
print(list(summed)) # [11, 22, 33]
在上面这个示例中,lambda 函数接收两个参数,并将它们加在一起。如果传递给 map() 的可迭代对象长度不同,map() 将在最短的可迭代对象结束时停止。
使用列表推导式处理多个可迭代对象时,要借助 zip() 函数,把多个可迭代对象转换成单个可迭代对象再处理。而 map() 函数可以直接处理它们。
实现 map() 函数
我们可以进一步探究一下,我们自己如何能实现一个类似 map() 的函数。我们有如下考虑
- map() 可以接收多个可迭代对象,说明这个函数具有可变数量的参数
- map() 返回一个迭代器,我们可以使用生成器函数来实现
- map() 本身的功能还是比较简单的,把输入参数按顺序传递给输出参数就行了
实现代码如下:
# 自定义的 my_map 函数,旨在模拟内置的 map 函数的功能。
# 它接受一个函数和一个或多个可迭代对象作为参数。
def my_map(func, *iterables):
# 将所有传入的可迭代对象转换为迭代器。
# 这使得我们可以使用 next 函数从它们中提取值。
iterators = [iter(it) for it in iterables]
# 无限循环,直到其中一个迭代器耗尽为止。
while True:
# 使用一个临时列表来存储从各个迭代器中获取的元素。
result = []
# 遍历所有的迭代器。
for it in iterators:
try:
# 从当前迭代器获取下一个元素,将获取到的元素添加到结果列表中。
item = next(it)
result.append(item)
except StopIteration:
# 如果某个迭代器中没有更多的元素可供提取,则退出循环并结束生成。
return
# 使用传入的函数对从迭代器中获取的元素进行操作,
# 然后使用yield返回结果。
yield func(*tuple(result))
# 测试自定义的 my_map 函数。
lst1 = [1, 2, 3]
lst2 = [4, 5, 6]
result = my_map(lambda x, y: x + y, lst1, lst2)
print(list(result)) # 输出: [5, 7, 9]
result = my_map(lambda x: x*x, lst1)
print(list(result)) # 输出: [1, 4, 9]
上面的程序复杂在要处理可变多个可迭代对象,但如果配合使用 zip() 函数,上面的代码可以被一句简单的生成器表达式替代:
def my_map(func, *iterables):
return (func(*items) for items in zip(*iterables))
# 测试
lst1 = [1, 2, 3]
lst2 = [4, 5, 6]
result = my_map(lambda x, y: x + y, lst1, lst2)
print(list(result)) # 输出: [5, 7, 9]
filter
基本用法
filter() 用于从一个可迭代对象中过滤出满足某个条件的元素。基本用法如下:
filter(function, iterable)
它接受一个函数和一个可迭代对象。它返回一个新的迭代对象,其中只包含使输入函数返回 True 的原始元素。
我们在介绍生成器表达式时使用的另一个示例正好可以演示 filter() 函数的用法:假设我们需要从一个列表中选出长度大于 5 的单词,如果使用生成器表达式,代码如下:
result = (word for word in words if len(word) > 5)
这个示例也可以使用 filter() 函数来实现:
words = ["apple", "banana", "cherry", "date", "fig", "kiwi"]
long_words = filter(lambda x: len(x) > 5, words)
print(list(long_words)) # 输出: [banana', 'cherry']
filter() 函数只能接收一个可迭代对象,因此,它的实现比 map() 简单的多,只需要使 用上面的生成器表达式就可以了:
def my_filter(func, iterable):
return (item for item in iterable if func(item))
# 测试
lst = [1, 2, 3, 4, 5, 6, 7, 8, 9]
evens = my_filter(lambda x: x % 2 == 0, lst)
print(list(evens)) # 输出: [2, 4, 6, 8]