Use of hooks in PyTorch

In PyTorch, hooks let you inspect and modify the values and gradients of intermediate variables in a network, which makes it easy to analyze the network without changing its structure.

1. Hooks on torch.Tensor

In PyTorch, only the gradients of leaf nodes (tensors created directly by the user rather than computed from other tensors, such as the network inputs and parameters) are retained; the gradients of intermediate nodes are automatically freed once backpropagation finishes, in order to save GPU memory.
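A quick way to tell whether a tensor is a leaf node is its .is_leaf attribute, as in this minimal sketch:

import torch

x = torch.Tensor([1, 2]).requires_grad_(True)  # created directly by the user -> leaf node
z = (x ** 2).mean()                            # computed from another tensor -> non-leaf

print('x.is_leaf:', x.is_leaf)  # True
print('z.is_leaf:', z.is_leaf)  # False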

The following example shows what this means for gradients after backward():

import torch

x=torch.Tensor([1,2]).requires_grad_(True)  # leaf tensor: created directly by the user
y=torch.Tensor([3,4]).requires_grad_(True)  # leaf tensor: created directly by the user
z=((y-x)**2).mean()                         # non-leaf (intermediate) tensor
# z.retain_grad()  # uncomment to keep z's gradient after backward()
z.backward()

print('x.requires_grad:',x.requires_grad)
print('y.requires_grad:',y.requires_grad)
print('z.requires_grad:',z.requires_grad)

print('x.grad:',x.grad)
print('y.grad:',y.grad)
print('z.grad:',z.grad)

Output:

x.requires_grad: True
y.requires_grad: True
z.requires_grad: True
x.grad: tensor([-2., -2.])
y.grad: tensor([2., 2.])
/home/wangguoyu/test.py:14: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more information.
  print('z.grad:',z.grad)
z.grad: None

Here, x and y are leaf nodes, so their gradients are kept after backward(); z also has requires_grad=True, but since it is not a leaf node its gradient is not retained. (For the values above, z = ((y1-x1)^2 + (y2-x2)^2) / 2, so dz/dx_i = -(y_i - x_i) and dz/dy_i = y_i - x_i, which gives x.grad = [-2., -2.] and y.grad = [2., 2.].) If we really do need the gradient of a non-leaf node, we can call its retain_grad() method before backward() (i.e. uncomment the line above), after which z.grad becomes accessible. However, a gradient kept by retain_grad() occupies GPU memory. If we want to avoid that, we can use a hook instead: for an intermediate tensor a, a.register_hook(hook_fn) lets us operate on its gradient (modify it or save it). Here hook_fn is a user-defined function whose signature is:

hook_fn(grad) -> Tensor or None

Its input is the gradient of a. If the hook returns a Tensor, that Tensor replaces a's original gradient for the rest of the backward pass; if it returns nothing (or returns None), a's gradient is left unchanged and continues to propagate backward as usual.
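Before moving on to the hook example, here is a minimal sketch of the retain_grad() alternative mentioned above (the first snippet with its comment removed), showing that z.grad is then populated:

import torch

x = torch.Tensor([1, 2]).requires_grad_(True)
y = torch.Tensor([3, 4]).requires_grad_(True)
z = ((y - x) ** 2).mean()

z.retain_grad()  # explicitly keep the gradient of this non-leaf tensor
z.backward()

print('z.grad:', z.grad)  # expected: tensor(1.), since dz/dz = 1

Now the same computation, but with a hook registered on z instead: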

import torch

def hook_fn(grad):
  # receives z's gradient during backward(); returning nothing (None) leaves it unchanged
  print('here is the hook_fn')
  print(grad)
  
x=torch.Tensor([1,2]).requires_grad_(True)
y=torch.Tensor([3,4]).requires_grad_(True)
z=((y-x)**2).mean()

z.register_hook(hook_fn)

print('before backward')
z.backward()
print('after backward')

print('x.requires_grad:',x.requires_grad)
print('y.requires_grad:',y.requires_grad)
print('z.requires_grad:',z.requires_grad)

print('x.grad:',x.grad)
print('y.grad:',y.grad)
print('z.grad:',z.grad)

Output:

before backward
here is the hook_fn
tensor(1.)
after backward
x.requires_grad: True
y.requires_grad: True
z.requires_grad: True
x.grad: tensor([-2., -2.])
y.grad: tensor([2., 2.])
z.grad: None

As you can see, after hook_fn is bound to z, the gradient of z is printed during backward(). Because hook_fn returns nothing (i.e. None), z's gradient is left unchanged (and, since it is not retained, z.grad is still None afterwards). Next, let's modify the gradient of z:

import torch

def hook_fn(grad):
  # hooks should not modify their argument in place; compute and return a new tensor
  grad = grad * 2
  print('here is the hook_fn')
  print(grad)
  return grad
  
x=torch.Tensor([1,2]).requires_grad_(True)
y=torch.Tensor([3,4]).requires_grad_(True)
z=((y-x)**2).mean()

# z.register_hook(lambda grad: 2 * grad)  # equivalent lambda form of hook_fn
z.register_hook(hook_fn)

print('before backward')
z.backward()
print('after backward')

print('x.requires_grad:',x.requires_grad)
print('y.requires_grad:',y.requires_grad)
print('z.requires_grad:',z.requires_grad)

print('x.grad:',x.grad)
print('y.grad:',y.grad)
print('z.grad:',z.grad)

Output:

before backward
here is the hook_fn
tensor(2.)
after backward
x.requires_grad: True
y.requires_grad: True
z.requires_grad: True
x.grad: tensor([-4., -4.])
y.grad: tensor([4., 4.])
z.grad: None

You can see that the gradients of x and y are twice what they were before: hook_fn doubled the gradient of z, and that doubled value is what the chain rule propagates down to x and y. hook_fn can also be a lambda expression; uncommenting the lambda line in the code above (and commenting out register_hook(hook_fn)) has the same effect, as sketched below.
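For completeness, here is a minimal sketch of the lambda variant; note that register_hook() returns a handle, so the hook can be detached later with handle.remove() if it is no longer wanted:

import torch

x = torch.Tensor([1, 2]).requires_grad_(True)
y = torch.Tensor([3, 4]).requires_grad_(True)
z = ((y - x) ** 2).mean()

handle = z.register_hook(lambda grad: 2 * grad)  # double z's gradient during backward()
z.backward()

print('x.grad:', x.grad)  # expected: tensor([-4., -4.])
print('y.grad:', y.grad)  # expected: tensor([4., 4.])

handle.remove()  # detach the hook once it is no longer needed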

To be added
https://zhuanlan.zhihu.com/p/267800207
