pytorch 實現優化函數

import torch


def himmelblau(x):
    return (x[0]**2 + x[1]-11)**2 + (x[0] + x[1]**2 - 7)**2


# [1., 0.],[-4., 0.],[4., 0.]
x = torch.tensor([4., 0.], requires_grad=True)
optimizer = torch.optim.Adam([x], lr=1e-3)
for step in range(20000):

    pred = himmelblau(x)
    optimizer.zero_grad()
    pred.backward()
    optimizer.step()

    if step % 2000 == 0:
        print('step{}:x={},f(x)={}'.format(step, x.tolist(), pred.item()))




發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章