import torch


# create a tiny test net

test_input = torch.randn((1, 3, 2, 2))

test_gt = torch.ones((1, 1, 1, 1))


conv1 = torch.nn.Conv2d(3, 2, kernel_size=2)
conv2 = torch.nn.Conv2d(2, 1, kernel_size=1)

# forward pass
a = conv1(test_input)
b = conv2(a)

# simple difference used as the loss; it has a single element,
# so the gradient call below does not need grad_outputs
loss = test_gt - b


# both .grad attributes are still None -- no backward pass has run yet
print(conv1.weight.grad)
print(conv2.weight.grad)


# torch.autograd.grad returns the gradients as a tuple, in the same order as
# the inputs list, without accumulating anything into the .grad attributes
output = torch.autograd.grad(loss, [conv2.weight, conv1.weight])
print(output)


# still None: autograd.grad does not populate .grad
print(conv1.weight.grad)
print(conv2.weight.grad)


print('after manual grad update')
# copy the computed gradient into conv2.weight.grad, accumulating if a
# gradient is already present (mirroring what backward() would do)
if conv2.weight.grad is None:
    conv2.weight.grad = output[0]
else:
    conv2.weight.grad += output[0]

print(conv2.weight.grad)


The code above shows how to calculate gradients for only a few selected tensors. In this case I really only wanted the gradient of conv2.weight, so that I can later update just that weight by the amount derived from the gradient produced by the loss function.
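As a follow-up, here is a minimal sketch of how that manually assigned gradient could then be applied. It continues the snippet above (reusing conv2) and assumes a plain SGD update with a made-up learning rate lr; because the optimizer only sees conv2.weight, conv1 stays untouched.

# hypothetical learning rate for this sketch
lr = 0.01

# hand the optimizer only the parameter we want to update
optimizer = torch.optim.SGD([conv2.weight], lr=lr)

# conv2.weight.grad was filled in manually above, so step() applies exactly
# that gradient to conv2.weight and nothing else
optimizer.step()
optimizer.zero_grad()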

