反模式:不必要的调用 ray.get 会影响性能#

TLDR: 避免在中间步骤中不必要地调用 ray.get()。直接使用对象引用进行操作,只在最后一步调用 ray.get() 获取最终结果。

当调用 ray.get() 时,对象必须传输到调用 ray.get() 的 worker/node。如果你不需要操作对象,你可能不需要调用 ray.get()

通常,最好在调用 ray.get() 之前等待尽可能长的时间,甚至设计程序以避免完全调用 ray.get()

代码示例#

反模式:

import ray
import numpy as np

ray.init()


@ray.remote
def generate_rollout():
    return np.ones((10000, 10000))


@ray.remote
def reduce(rollout):
    return np.sum(rollout)


# `ray.get()` downloads the result here.
rollout = ray.get(generate_rollout.remote())
# Now we have to reupload `rollout`
reduced = ray.get(reduce.remote(rollout))
../../_images/unnecessary-ray-get-anti.svg

更好的方法:

# Don't need ray.get here.
rollout_obj_ref = generate_rollout.remote()
# Rollout object is passed by reference.
reduced = ray.get(reduce.remote(rollout_obj_ref))
../../_images/unnecessary-ray-get-better.svg

请注意,在反模式示例中,我们调用 ray.get(),这迫使我们将大型 rollout 传输到 driver,然后再传输到 reduce worker。

在修复版本中,我们只传递对象的引用给 reduce 任务。 reduce worker 会隐式调用 ray.get()generate_rollout worker 直接获取实际的 rollout 数据,避免了额外的拷贝到 driver。

其他与 ray.get() 相关的反模式包括: