反模式:不必要的调用 ray.get 会影响性能
Contents
反模式:不必要的调用 ray.get 会影响性能#
TLDR: 避免在中间步骤中不必要地调用 ray.get()。直接使用对象引用进行操作,只在最后一步调用 ray.get() 获取最终结果。
当调用 ray.get() 时,对象必须传输到调用 ray.get() 的 worker/node。如果你不需要操作对象,你可能不需要调用 ray.get()!
通常,最好在调用 ray.get() 之前等待尽可能长的时间,甚至设计程序以避免完全调用 ray.get()。
代码示例#
反模式:
import ray
import numpy as np
ray.init()
@ray.remote
def generate_rollout():
return np.ones((10000, 10000))
@ray.remote
def reduce(rollout):
return np.sum(rollout)
# `ray.get()` downloads the result here.
rollout = ray.get(generate_rollout.remote())
# Now we have to reupload `rollout`
reduced = ray.get(reduce.remote(rollout))
更好的方法:
# Don't need ray.get here.
rollout_obj_ref = generate_rollout.remote()
# Rollout object is passed by reference.
reduced = ray.get(reduce.remote(rollout_obj_ref))
请注意,在反模式示例中,我们调用 ray.get(),这迫使我们将大型 rollout 传输到 driver,然后再传输到 reduce worker。
在修复版本中,我们只传递对象的引用给 reduce 任务。
reduce worker 会隐式调用 ray.get() 从 generate_rollout worker 直接获取实际的 rollout 数据,避免了额外的拷贝到 driver。
其他与 ray.get() 相关的反模式包括: