← 返回主页

第11课: 装饰器与生成器

装饰器 (Decorator)

装饰器是一种设计模式,用于在不修改原函数代码的情况下增加额外功能。

函数装饰器基础

def my_decorator(func):
    def wrapper():
        print("函数执行前")
        func()
        print("函数执行后")
    return wrapper

@my_decorator
def say_hello():
    print("Hello!")

say_hello()
# 输出:
# 函数执行前
# Hello!
# 函数执行后

带参数的装饰器

def my_decorator(func):
    def wrapper(*args, **kwargs):
        print(f"调用函数: {func.__name__}")
        result = func(*args, **kwargs)
        print(f"函数返回: {result}")
        return result
    return wrapper

@my_decorator
def add(a, b):
    return a + b

result = add(3, 5)  # 8

保留函数元信息

from functools import wraps

def my_decorator(func):
    @wraps(func)  # 保留原函数的名称和文档字符串
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@my_decorator
def greet(name):
    """问候函数"""
    return f"Hello, {name}"

print(greet.__name__)  # greet
print(greet.__doc__)   # 问候函数

常用装饰器示例

计时装饰器

import time
from functools import wraps

def timer(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"{func.__name__}执行时间: {end - start:.4f}秒")
        return result
    return wrapper

@timer
def slow_function():
    time.sleep(1)
    return "完成"

slow_function()

日志装饰器

def log(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        print(f"调用 {func.__name__},参数: {args}, {kwargs}")
        result = func(*args, **kwargs)
        print(f"{func.__name__} 返回: {result}")
        return result
    return wrapper

@log
def multiply(a, b):
    return a * b

multiply(3, 4)

缓存装饰器

from functools import lru_cache

@lru_cache(maxsize=128)
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

print(fibonacci(100))  # 快速计算
print(fibonacci.cache_info())  # 查看缓存信息

带参数的装饰器

def repeat(times):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for _ in range(times):
                result = func(*args, **kwargs)
            return result
        return wrapper
    return decorator

@repeat(3)
def greet(name):
    print(f"Hello, {name}!")

greet("张三")
# 输出3次: Hello, 张三!

类装饰器

class Counter:
    def __init__(self, func):
        self.func = func
        self.count = 0

    def __call__(self, *args, **kwargs):
        self.count += 1
        print(f"调用次数: {self.count}")
        return self.func(*args, **kwargs)

@Counter
def say_hello():
    print("Hello!")

say_hello()  # 调用次数: 1
say_hello()  # 调用次数: 2

生成器 (Generator)

生成器是一种特殊的迭代器,使用yield关键字返回值,可以暂停和恢复执行。

基本生成器

def count_up_to(n):
    count = 1
    while count <= n:
        yield count
        count += 1

# 使用生成器
for num in count_up_to(5):
    print(num)  # 1, 2, 3, 4, 5

生成器表达式

# 列表推导式(立即创建列表)
squares_list = [x**2 for x in range(10)]

# 生成器表达式(惰性求值)
squares_gen = (x**2 for x in range(10))

# 遍历生成器
for square in squares_gen:
    print(square)

生成器的优势

# 内存效率:处理大数据
def read_large_file(file_path):
    with open(file_path, 'r') as file:
        for line in file:
            yield line.strip()

# 只在需要时读取一行,不会一次性加载整个文件
for line in read_large_file('large_file.txt'):
    process(line)

无限生成器

def infinite_sequence():
    num = 0
    while True:
        yield num
        num += 1

# 使用
gen = infinite_sequence()
print(next(gen))  # 0
print(next(gen))  # 1
print(next(gen))  # 2

生成器方法

def echo():
    while True:
        value = yield
        if value is not None:
            print(f"收到: {value}")

gen = echo()
next(gen)  # 启动生成器
gen.send("Hello")  # 发送值
gen.send("World")
gen.close()  # 关闭生成器

yield from

def generator1():
    yield 1
    yield 2

def generator2():
    yield 3
    yield 4

def combined():
    yield from generator1()
    yield from generator2()

for num in combined():
    print(num)  # 1, 2, 3, 4

实用示例

斐波那契数列生成器

def fibonacci():
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b

# 获取前10个斐波那契数
fib = fibonacci()
for _ in range(10):
    print(next(fib))

数据管道

def read_data():
    for i in range(10):
        yield i

def filter_even(numbers):
    for num in numbers:
        if num % 2 == 0:
            yield num

def square(numbers):
    for num in numbers:
        yield num ** 2

# 组合生成器
pipeline = square(filter_even(read_data()))
print(list(pipeline))  # [0, 4, 16, 36, 64]

批处理生成器

def batch(iterable, size):
    batch_list = []
    for item in iterable:
        batch_list.append(item)
        if len(batch_list) == size:
            yield batch_list
            batch_list = []
    if batch_list:
        yield batch_list

# 使用
data = range(10)
for batch_data in batch(data, 3):
    print(batch_data)
# [0, 1, 2]
# [3, 4, 5]
# [6, 7, 8]
# [9]

装饰器与生成器结合

def generator_timer(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        gen = func(*args, **kwargs)
        for item in gen:
            yield item
        end = time.time()
        print(f"生成器执行时间: {end - start:.4f}秒")
    return wrapper

@generator_timer
def count_to(n):
    for i in range(n):
        yield i

list(count_to(1000000))

练习

  1. 创建一个装饰器,统计函数被调用的次数
  2. 编写一个生成器,生成指定范围内的质数
  3. 创建一个装饰器,在函数执行前后打印日志
  4. 编写一个生成器,逐行读取文件并过滤空行
练习答案:
# 练习1:调用计数装饰器
def call_counter(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.count += 1
        print(f"调用次数: {wrapper.count}")
        return func(*args, **kwargs)
    wrapper.count = 0
    return wrapper

@call_counter
def test():
    print("测试函数")

test()
test()

# 练习2:质数生成器
def prime_generator(start, end):
    for num in range(start, end + 1):
        if num < 2:
            continue
        is_prime = True
        for i in range(2, int(num ** 0.5) + 1):
            if num % i == 0:
                is_prime = False
                break
        if is_prime:
            yield num

for prime in prime_generator(1, 20):
    print(prime)

# 练习3:日志装饰器
def logger(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        print(f"[开始] 调用 {func.__name__}")
        result = func(*args, **kwargs)
        print(f"[结束] {func.__name__} 返回 {result}")
        return result
    return wrapper

@logger
def add(a, b):
    return a + b

add(3, 5)

# 练习4:文件行生成器
def read_non_empty_lines(filename):
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                yield line

for line in read_non_empty_lines('data.txt'):
    print(line)