
| save | 将对象保存到磁盘文件。 |
| load | 从文件中加载用torch.save()保存的对象 |
| get_num_threads | 返回用于并行化CPU *** 作的线程数 |
| set_num_threads | 设置CPU上用于 *** 作内并行的线程数。 |
| get_num_interop_threads | 返回CPU上用于 *** 作间并行的线程数。 |
| set_num_interop_threads | 设置用于互 *** 作并行性的线程数。 |
上下文管理器有助于禁用和启用梯度计算
torch.no_grad()
torch.enable_grad()
torch.set_grad_enabled()
| no_grad | 上下文管理器禁用梯度计算. |
| enable_grad | 上下文管理器启用梯度计算. |
| set_grad_enabled | 上下文管理器将“梯度计算”设置为“开”或“关”。 |
| is_grad_enabled | 如果当前启用梯度模式,则返回True。 |
| inference_mode | 上下文管理器启用或禁用推理模式 |
| is_inference_mode_enabled | 如果当前启用了推理模式,则返回True。 |
>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True) # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)