Some important PyTorch tasks - A concise summary from a vision researcher

In [2]:
import torch.nn as nn
import torch
from torch.autograd.variable import Variable
from torchvision import datasets, models, transforms
In [20]:
model = models.resnet18(pretrained = False)

Let us first explore this model's layers and then make a decision as to which ones we want to freeze. By freeze we mean that we want the parameters of those layers to be fixed. When fine tuning a model, we are basically taking a model trained on Dataset A, and then training it on a new Dataset B. We could potentially start the training from scratch as well, but it would be like re-inventing the wheel. Let me explain why.

Suppose I want to train a model to differentiate between a car and a bicycle. I could gather images of both categories and train a network from scratch. But given the amount of work already out there, it's easy to find a model trained to identify things like dogs, cats, and humans. Admittedly, none of these three look like cars or bicycles. However, it's still better than nothing, because such a model has already learned generic visual features. We could start by taking this model and training it to learn car vs. bicycle. Gains: 1) training will be faster, 2) we need fewer images of cars and bicycles.

(If interested in knowing more, read this - http://cs231n.github.io/transfer-learning/).

Now, let's take a look at the contents of a resnet18. We use the function .children() for this purpose. This lets us look at the contents/layers of a model. Then, we use the .parameters() function to access the parameters/weights of any layer. Finally, every parameter has a property .requires_grad which defines whether a parameter is trained or frozen. By default it is True, and the network updates it in every iteration. If it is set to False, then it is not updated and is said to be "frozen".

In [21]:
child_counter = 0
for child in model.children():
    print(" child", child_counter, "is -")
    print(child)
    child_counter += 1
 child 0 is -
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
 child 1 is -
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
 child 2 is -
ReLU (inplace)
 child 3 is -
MaxPool2d (size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))
 child 4 is -
Sequential (
  (0): BasicBlock (
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  )
  (1): BasicBlock (
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  )
)
 child 5 is -
Sequential (
  (0): BasicBlock (
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (downsample): Sequential (
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    )
  )
  (1): BasicBlock (
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  )
)
 child 6 is -
Sequential (
  (0): BasicBlock (
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (downsample): Sequential (
      (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    )
  )
  (1): BasicBlock (
    (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
  )
)
 child 7 is -
Sequential (
  (0): BasicBlock (
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (downsample): Sequential (
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    )
  )
  (1): BasicBlock (
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
  )
)
 child 8 is -
AvgPool2d (
)
 child 9 is -
Linear (512 -> 1000)

Now, you can see that some of the children are actually big chunks and have layers within them. To access one level deeper we can run .children() on a child object as well!
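
Going one level deeper can be sketched as follows. This is a minimal example that assumes a small stand-in network (`toy`, which is not the ResNet above) so that it runs quickly; `named_children()` is the `nn.Module` method that yields each child together with its name.

```python
import torch.nn as nn

# Hypothetical stand-in for a model whose children contain further layers.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Sequential(                      # a "big chunk" with layers inside it
        nn.Conv2d(8, 8, kernel_size=3),
        nn.ReLU(),
    ),
)

nested_names = []
for name, child in toy.named_children():
    print("child", name, "is", type(child).__name__)
    # Running .children() / .named_children() on a child goes one level deeper.
    for sub_name, grandchild in child.named_children():
        print("   grandchild", sub_name, "is", type(grandchild).__name__)
        nested_names.append((name, sub_name, type(grandchild).__name__))
```

A plain layer such as `Conv2d` has no children of its own, so the inner loop only runs for the nested `Sequential`.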

Let's say we want to freeze all parameters up to and including the first BasicBlock of child 6. First, let's see what a parameter looks like -

In [22]:
for child in model.children():
    for param in child.parameters():
        print("This is what a parameter looks like - \n",param)
        break
    break
This is what a parameter looks like - 
 Parameter containing:
(0 ,0 ,.,.) = 
  1.8160e-02  2.1680e-02  5.6358e-02  ...  -1.2987e-02 -6.1262e-02 -4.8870e-02
  2.6440e-02  1.0603e-02  1.9794e-02  ...  -4.2643e-02 -4.5565e-03 -4.8300e-02
  9.0205e-03  1.9536e-03  1.9925e-04  ...   1.1413e-02  1.1395e-02  2.8418e-03
                 ...                   ⋱                   ...                
 -2.4830e-02  8.1022e-03 -4.9934e-02  ...   2.2573e-02  1.6346e-02  3.9666e-02
 -2.3857e-02 -1.6275e-02  2.9058e-02  ...   3.0488e-02  2.0294e-02 -5.1073e-03
 -1.6848e-04  5.9266e-02 -5.8456e-03  ...   1.9757e-02 -7.8441e-02  1.3667e-02

(0 ,1 ,.,.) = 
 -1.6319e-02  3.3193e-02 -2.2146e-04  ...   1.2571e-03 -1.3313e-02 -4.7580e-02
 -4.9329e-02  3.2548e-02  5.4202e-03  ...  -4.5771e-02 -2.6863e-03 -3.6992e-03
  8.7714e-03  2.4772e-02  1.0026e-02  ...   1.6512e-02 -7.4382e-03  6.0990e-02
                 ...                   ⋱                   ...                
 -4.0751e-02  3.3605e-04 -2.1426e-02  ...   1.1318e-02 -1.5222e-04 -3.5020e-02
 -4.1432e-02 -9.1312e-03 -1.7572e-02  ...   1.6974e-03  5.9792e-03  1.2868e-02
 -4.4471e-02 -1.1013e-02  4.9902e-03  ...  -2.1241e-02  2.2371e-02 -2.1672e-02

(0 ,2 ,.,.) = 
  1.0826e-02 -4.4230e-02 -1.5594e-02  ...  -1.3197e-03  6.1211e-03 -1.6262e-02
 -1.3989e-02 -3.2357e-02  2.0250e-02  ...   7.5012e-03  2.8761e-04 -2.1318e-02
 -7.8574e-04  1.7702e-02  1.0301e-02  ...  -2.0074e-02  4.4735e-02  1.0149e-02
                 ...                   ⋱                   ...                
 -2.4707e-02  2.3952e-03  6.5615e-04  ...   4.4371e-02 -1.0678e-02  2.3425e-02
 -2.4330e-02  1.3018e-02  1.1473e-02  ...  -3.6666e-03 -2.1145e-02 -1.5511e-02
 -3.0876e-02 -1.6071e-02 -2.4506e-02  ...   2.7417e-03  6.2566e-03  1.6208e-02
     ⋮ 

(1 ,0 ,.,.) = 
 -1.0333e-02  1.5746e-02  3.0517e-02  ...  -1.0851e-02 -7.7141e-04 -4.0873e-02
 -1.6966e-02 -3.6460e-02  5.3054e-02  ...  -2.0641e-02 -1.8781e-02 -7.1048e-03
  3.9752e-02 -3.6240e-02 -4.6019e-03  ...  -2.1766e-02 -2.5955e-03 -3.4346e-02
                 ...                   ⋱                   ...                
 -1.1488e-02 -3.3896e-02  4.6620e-02  ...  -2.7367e-03  1.6170e-02 -8.7509e-03
 -3.4920e-02 -1.7164e-02 -1.2804e-02  ...  -2.6690e-02  4.9540e-02  1.9799e-02
 -1.6736e-02 -9.0173e-03 -2.0421e-02  ...  -4.0255e-03 -1.7746e-02 -4.4906e-02

(1 ,1 ,.,.) = 
  2.0440e-02  7.0067e-03 -2.8885e-03  ...  -4.8313e-02  4.4430e-02 -1.6539e-03
  1.1405e-02  1.3499e-02 -1.0181e-02  ...   3.2469e-03  2.6244e-02  3.3834e-03
  5.5702e-03  1.0040e-02 -1.1350e-02  ...   1.1416e-02  2.5718e-02 -1.1672e-02
                 ...                   ⋱                   ...                
 -2.2712e-02  3.1696e-03 -2.5725e-02  ...  -3.1355e-02  4.4028e-02 -1.7592e-02
 -1.4702e-02 -2.3544e-02 -1.7768e-02  ...  -3.6875e-02 -2.1635e-02  4.1800e-03
 -1.3653e-02 -2.0815e-02  2.5550e-02  ...   2.9072e-02  2.6506e-02 -2.1846e-02

(1 ,2 ,.,.) = 
  2.6863e-02 -1.1023e-02 -3.8302e-02  ...  -2.9343e-02 -1.3996e-02 -1.3504e-02
  2.1842e-02  1.6150e-02 -5.8077e-03  ...  -1.3306e-04  1.2624e-02  8.6928e-03
 -9.3998e-03  3.3038e-02  1.6890e-02  ...   2.2586e-02  2.9318e-02  1.3380e-02
                 ...                   ⋱                   ...                
  2.2479e-02 -1.6830e-02 -1.1064e-02  ...   7.8837e-03  5.3104e-03  5.4637e-02
 -2.3910e-02  2.9069e-02 -3.2870e-02  ...  -2.4287e-02  2.2562e-02 -1.8842e-02
 -3.3899e-02  4.5987e-02 -3.7849e-03  ...  -5.9900e-03  4.6050e-02 -2.0960e-02
     ⋮ 

(2 ,0 ,.,.) = 
 -1.3902e-02  2.4403e-02  2.3496e-02  ...  -3.6818e-03 -1.3517e-02 -3.4732e-03
 -1.3494e-02 -5.9880e-03 -1.8047e-02  ...  -2.9621e-02  2.3363e-02  4.0067e-02
 -6.7423e-02 -4.1190e-02 -1.1207e-02  ...   1.1878e-02  1.2203e-02  6.7536e-03
                 ...                   ⋱                   ...                
 -1.7779e-02 -2.1686e-02 -1.7968e-02  ...   6.1823e-04 -1.1427e-02  3.8056e-03
  4.7457e-02  5.8501e-03  1.3968e-02  ...   1.1012e-02  9.1363e-04  5.3913e-03
  1.3919e-02  3.9247e-02 -2.6585e-03  ...   4.3866e-02 -5.1949e-02  2.9817e-02

(2 ,1 ,.,.) = 
  1.4485e-02  2.5544e-02  2.7102e-02  ...   2.2926e-02  1.4463e-02  3.4932e-02
 -6.3950e-03 -7.9448e-03  2.2152e-02  ...   4.6327e-02 -2.0223e-02 -7.2063e-03
 -2.4014e-02  7.0567e-03 -4.2840e-02  ...  -1.8328e-02 -1.4452e-02  3.2739e-02
                 ...                   ⋱                   ...                
  2.6509e-02  7.7837e-03 -2.7939e-02  ...   1.0316e-02  1.5443e-02  1.5726e-02
  3.7833e-02  1.0973e-02  1.6321e-02  ...  -9.0257e-03 -3.5795e-02 -3.0684e-02
  2.1657e-03 -4.3080e-02 -2.2311e-02  ...  -3.2374e-03 -3.8052e-02 -4.9427e-02

(2 ,2 ,.,.) = 
 -1.8339e-02 -2.6231e-02  1.2887e-02  ...   1.3463e-02  2.4595e-02  9.5057e-03
 -1.4393e-02  1.5389e-02  2.7461e-03  ...  -5.0952e-02  1.2911e-02 -1.8666e-03
 -4.3669e-03 -5.8640e-03 -2.1081e-02  ...  -1.6746e-02  1.9807e-02  8.5502e-03
                 ...                   ⋱                   ...                
 -1.0186e-02 -2.3766e-02  1.6136e-02  ...  -8.0708e-02 -2.5807e-02  1.5700e-02
 -9.2855e-03 -1.9718e-02  1.9457e-02  ...  -1.3100e-02 -1.0821e-02  2.1621e-02
  8.1854e-03 -3.1841e-02  1.3033e-02  ...  -2.0506e-02  1.2037e-02  6.4032e-04
...   
     ⋮ 

(61,0 ,.,.) = 
  9.4316e-03  2.3648e-02 -8.4966e-03  ...  -2.2285e-03 -1.4238e-02  5.2163e-02
  1.1587e-03  1.2474e-02 -1.6408e-02  ...  -2.2976e-02  6.6632e-03  3.6772e-03
  3.7755e-02  8.0352e-04  8.9609e-03  ...   2.1675e-02 -3.6027e-03 -1.1842e-02
                 ...                   ⋱                   ...                
  1.2762e-02  1.9184e-02  2.7700e-02  ...   9.5043e-04 -1.7118e-03  2.9772e-02
  2.8610e-02 -1.5271e-02  5.1606e-02  ...  -3.9722e-03 -3.3161e-02 -5.1093e-02
 -2.0437e-02  1.5838e-02  2.7344e-02  ...  -2.6124e-03  3.0168e-02 -2.4499e-02

(61,1 ,.,.) = 
  1.3869e-02  2.9713e-02 -2.2218e-03  ...  -5.0385e-02 -3.8294e-02  5.0754e-02
 -2.1760e-02 -1.1468e-02  2.2944e-02  ...   1.0988e-02 -1.8024e-03  2.4294e-02
 -1.2950e-02  1.5043e-02 -1.8723e-03  ...  -2.3066e-02  1.9586e-02 -2.3099e-03
                 ...                   ⋱                   ...                
  1.7489e-02  2.0666e-02  1.1381e-02  ...   1.7181e-02 -4.0002e-02 -1.9487e-02
  4.2988e-02  2.6599e-02 -2.4061e-02  ...  -3.5973e-02 -3.5824e-03  1.1549e-02
  1.6065e-02  1.3805e-02  3.6108e-02  ...  -4.8555e-02  8.1907e-03  3.4666e-02

(61,2 ,.,.) = 
 -3.0184e-02  2.5058e-02 -2.9590e-03  ...  -4.1951e-03 -2.5637e-02  1.5420e-02
 -2.7229e-02 -4.2415e-03 -2.7928e-02  ...  -1.1144e-02  1.1510e-03 -1.2208e-02
  5.0833e-02 -1.8479e-02 -1.8046e-02  ...   2.5169e-03  2.3112e-03  8.1823e-03
                 ...                   ⋱                   ...                
  8.0303e-03  2.8856e-02  7.8058e-03  ...   2.3697e-02 -3.2406e-04  4.0989e-02
 -1.1054e-02  1.7881e-02 -7.2309e-03  ...  -2.6414e-02 -3.9901e-02  2.2379e-02
  6.5656e-03  7.0047e-03 -3.3296e-03  ...   2.8250e-03  5.2304e-03 -4.7857e-03
     ⋮ 

(62,0 ,.,.) = 
 -2.1959e-02 -1.4979e-02  1.2155e-02  ...   2.6459e-02 -2.6932e-03 -5.3835e-03
  1.3290e-02 -1.2008e-02  1.7921e-02  ...  -8.4513e-03  4.7896e-02  8.4751e-03
  8.2594e-03  4.3179e-03  9.1544e-03  ...  -7.6523e-03 -1.0549e-02 -1.5311e-02
                 ...                   ⋱                   ...                
  7.0592e-03 -5.5720e-03  5.7900e-02  ...  -7.5445e-03  1.6987e-02 -4.9320e-02
  1.2382e-03  2.9988e-02  1.5510e-02  ...   5.7371e-03 -1.9073e-02  1.1134e-02
  1.1451e-02 -2.5826e-02 -2.6174e-02  ...   2.8024e-02 -8.2831e-02  3.7890e-02

(62,1 ,.,.) = 
 -2.3812e-02  2.6700e-02  2.4878e-02  ...  -1.2590e-02  1.4942e-02  7.3503e-03
 -3.2630e-02 -2.1997e-02 -5.1692e-02  ...   1.8524e-02 -2.1054e-02  8.3692e-03
  2.4765e-02  3.4338e-02  4.4222e-02  ...  -6.9486e-03 -1.3035e-02 -1.6388e-02
                 ...                   ⋱                   ...                
  2.7492e-02  2.1982e-02 -2.1263e-02  ...  -3.3880e-02  1.2141e-02  6.8169e-03
 -2.5239e-02 -2.1256e-02 -3.8697e-03  ...  -1.5080e-02 -1.0833e-03  1.2719e-02
  3.4392e-02 -1.6532e-02 -4.6843e-04  ...   1.6460e-02  3.8641e-02 -3.2814e-02

(62,2 ,.,.) = 
 -3.2325e-02 -4.7595e-02  2.8533e-02  ...   5.0494e-02  1.8599e-02 -1.5499e-02
  2.3927e-02 -2.9398e-02 -5.7063e-02  ...  -1.2592e-03  1.5265e-02 -2.7379e-03
 -1.9947e-02 -2.8994e-02 -1.0712e-02  ...  -1.3606e-02 -5.3947e-03  1.1104e-02
                 ...                   ⋱                   ...                
 -1.5108e-02 -5.3751e-03 -6.6983e-02  ...   2.1419e-02  2.4127e-02 -1.6207e-02
  3.8778e-02  1.6684e-02  2.3376e-02  ...   1.4579e-02  2.0048e-02 -5.1052e-02
  6.9827e-04 -3.1476e-02  2.2414e-02  ...   3.5637e-02 -6.2860e-03  2.1901e-03
     ⋮ 

(63,0 ,.,.) = 
  3.6892e-02 -1.0093e-02  1.4863e-02  ...  -1.9750e-02 -3.5509e-02 -1.9200e-02
 -2.5392e-02  8.6157e-05 -2.5180e-03  ...   3.3918e-03  9.8297e-03  1.7278e-03
  1.5289e-02 -4.9295e-03  1.7144e-02  ...   1.3728e-02  2.6355e-02  1.4548e-03
                 ...                   ⋱                   ...                
 -1.0451e-02 -1.9699e-02 -3.0967e-02  ...  -8.3925e-03  8.0206e-04 -9.3016e-03
  1.5797e-02 -2.2791e-02  3.6044e-02  ...  -2.5666e-02  4.4125e-02 -8.0478e-03
 -1.3139e-02 -1.9758e-02  1.4868e-02  ...  -9.0605e-03  2.7318e-02 -1.0136e-02

(63,1 ,.,.) = 
 -4.4853e-03 -3.4300e-02 -3.2744e-02  ...  -1.2309e-02  3.7756e-02 -2.6677e-02
 -1.1187e-02 -2.2497e-03 -1.6091e-02  ...  -2.4397e-02  6.7627e-03 -1.5241e-02
 -1.4663e-02  1.9999e-02 -3.5072e-02  ...   8.0089e-03  1.9439e-02  3.5001e-02
                 ...                   ⋱                   ...                
 -4.4512e-03 -2.9858e-02  8.8768e-03  ...   2.8442e-02 -2.7011e-02  4.4332e-03
  1.0174e-02 -4.3775e-02 -2.9107e-02  ...   2.9213e-02  1.7982e-02  3.4712e-02
 -1.3463e-02 -1.4656e-02  4.7337e-03  ...   1.6846e-02 -1.6850e-02 -1.9964e-02

(63,2 ,.,.) = 
 -4.5719e-03  2.6236e-02  5.5996e-03  ...   3.7875e-03  8.6500e-03  6.2772e-03
 -2.3837e-02 -1.6006e-03 -2.1914e-02  ...  -1.3637e-02 -1.9399e-02 -1.6704e-03
 -7.3654e-03 -1.4505e-02  4.5674e-02  ...   3.2031e-02 -2.9054e-02 -1.4125e-02
                 ...                   ⋱                   ...                
  2.5126e-02  9.5580e-03 -4.0513e-03  ...  -2.6763e-02  1.8345e-02  6.2725e-04
 -5.2027e-02  1.6874e-02 -8.8866e-03  ...   7.5890e-03 -1.1678e-02  2.8387e-03
 -8.3314e-03  3.1768e-02  2.7137e-02  ...   1.4631e-02 -1.9952e-02  1.1544e-02
[torch.FloatTensor of size 64x3x7x7]

Evidently, updating all of these values takes a lot of computation. Frozen parameters need no gradient computation or update, so freezing a bunch of them makes training faster. Now, let's freeze everything up to and including the first BasicBlock of child 6.

In [23]:
child_counter = 0
for child in model.children():
    if child_counter < 6:
        print("child ",child_counter," was frozen")
        for param in child.parameters():
            param.requires_grad = False
    elif child_counter == 6:
        children_of_child_counter = 0
        for children_of_child in child.children():
            if children_of_child_counter < 1:
                for param in children_of_child.parameters():
                    param.requires_grad = False
                print('child ', children_of_child_counter, 'of child',child_counter,' was frozen')
            else:
                print('child ', children_of_child_counter, 'of child',child_counter,' was not frozen')
            children_of_child_counter += 1

    else:
        print("child ",child_counter," was not frozen")
    child_counter += 1
child  0  was frozen
child  1  was frozen
child  2  was frozen
child  3  was frozen
child  4  was frozen
child  5  was frozen
child  0 of child 6  was frozen
child  1 of child 6  was not frozen
child  7  was not frozen
child  8  was not frozen
child  9  was not frozen
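
One way to double-check a freeze like this is to count trainable and frozen parameters. The sketch below assumes a hypothetical helper `count_parameters` and a two-layer toy model rather than the ResNet, purely to keep the numbers easy to verify by hand.

```python
import torch.nn as nn

def count_parameters(model):
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen

# Hypothetical toy model: freeze its first layer, leave the second trainable.
toy = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))
for param in toy[0].parameters():
    param.requires_grad = False

trainable, frozen = count_parameters(toy)
print("trainable:", trainable, "frozen:", frozen)
```

Here the first layer holds 4*3 + 3 = 15 values and the second 3*2 + 2 = 8, so the helper should report 8 trainable and 15 frozen. The same helper can be called on the ResNet after the loop above.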

Important Note

Now that you have frozen part of this network, one more thing has to change to make this work: your optimizer. The optimizer is what actually updates the parameter values. By default, optimizers are usually created like this -

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.1)

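In older PyTorch releases this raises an error, because the optimizer is handed parameters that do not require gradients. Newer releases accept it and simply skip any parameter that has no gradient. Either way, you've set a bunch of them to frozen, so the explicit way to pass only the ones still being updated is -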

optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
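
The effect of filtering can be checked on a small example. This is a sketch under the assumption of a two-layer toy model and random data (neither is from the notebook); it freezes the first layer, takes one optimizer step, and compares weights before and after.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))   # hypothetical toy model

# Freeze the first layer only.
for param in toy[0].parameters():
    param.requires_grad = False

# Pass only the parameters that still require gradients.
trainable_params = [p for p in toy.parameters() if p.requires_grad]
optimizer = torch.optim.RMSprop(trainable_params, lr=0.1)

frozen_before = toy[0].weight.clone()
trainable_before = toy[1].weight.clone()

# One training step on random data with a dummy loss.
loss = toy(torch.randn(8, 4)).pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(frozen_before, toy[0].weight)
trainable_changed = not torch.equal(trainable_before, toy[1].weight)
print("frozen layer unchanged:", frozen_unchanged)
print("trainable layer changed:", trainable_changed)
```

A list comprehension is used instead of `filter(...)` so that the same parameter list can be inspected or reused; both forms hand the optimizer the same parameters.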

SECTION 2 - Model Saving/Loading

There are two primary ways in which models are saved in PyTorch. The suggested one is using "state dictionaries" (the other is to save the entire model object with torch.save(model, path)). State dictionaries are faster and require less space. Basically, they have no idea of the model structure; they are just the values of the parameters/weights. So, you must first create your model with the required architecture and then load the values into it. The architecture is declared as we did above.

In [ ]:
# Let's assume we will save/load from a path MODEL_PATH

# Saving a Model
torch.save(model.state_dict(), MODEL_PATH)

# Loading the model.

# First create a model and define its architecture as done above in this notebook.
# If you want a custom architecture, read on: it is covered below.
checkpoint = torch.load(MODEL_PATH)
model.load_state_dict(checkpoint)
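
A complete save/load round trip might look like the sketch below. The toy architecture, the `make_model` helper, and the file name `toy_model.pth` are assumptions for illustration; the point is that the architecture is rebuilt in code and only the values come from the file.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # The architecture must be rebuilt in code; the state dict stores only values.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

original = make_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "toy_model.pth")   # hypothetical file name
    torch.save(original.state_dict(), model_path)

    restored = make_model()                    # fresh, randomly initialised copy
    restored.load_state_dict(torch.load(model_path))

state_keys = list(original.state_dict().keys())
x = torch.randn(5, 4)
same_output = torch.allclose(original(x), restored(x))
print("state dict keys:", state_keys)
print("outputs match:", same_output)
```

If the rebuilt architecture does not match the saved one, `load_state_dict` reports missing or unexpected keys, which is why the model must be declared the same way before loading.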

SECTION 3 - changing last layer, deleting last layer, adding layers

Most people who come to PyTorch don't like the fact that they can't do a .pop() to remove the last layer, especially if they've used Keras. So, let's take a look at how these things can be done.

CHANGING LAST LAYER

In [25]:
# Load the model
model = models.resnet18(pretrained = False)

# Get the number of input features of the last layer. We need this to replace the final layer.
num_final_in = model.fc.in_features

# The final layer of the model is model.fc, so we can simply overwrite it
# so that its output size equals the number of classes we need. Say, 300 classes.
NUM_CLASSES = 300
model.fc = nn.Linear(num_final_in, NUM_CLASSES)
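
The same overwrite can be tried on a small stand-in network to see the output size change. `TinyNet` below is hypothetical; it only mimics the fact that torchvision's ResNet keeps its final layer in an attribute called `fc`.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in whose final layer is called `fc`, like the ResNet."""
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))

net = TinyNet()
out_before = net(torch.randn(2, 8)).shape

NUM_CLASSES = 300
# Overwriting the attribute replaces the layer inside the model.
net.fc = nn.Linear(net.fc.in_features, NUM_CLASSES)
out_after = net(torch.randn(2, 8)).shape

print(tuple(out_before), "->", tuple(out_after))
```

Note that the replacement layer is freshly initialised and trainable by default, so it needs training on the new task even if the rest of the model was loaded from a checkpoint.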

DELETING LAST LAYER (OFTEN, WHEN YOU NEED FEATURES OF A LAYER)

In [ ]:
# Load the model
model = models.resnet18(pretrained = False)

We can get the layers by using model.children() as before. Then, we can convert this into a list by using a list() command on it. Then, we can remove the last layer by indexing the list. Finally, we can use the PyTorch function nn.Sequential() to stack this modified list together into a new model. You can edit the list in any way you want. That is, you can delete the last 2 layers if you want the features of an image from the 3rd last layer!

You may even delete layers from the middle of the model. But this would usually lead to a shape mismatch at the layer after it, since most layers change the number of channels or the spatial size of their input. In this case, you can index that specific layer of the model and overwrite it, just as shown above for the final layer.

In [33]:
new_model = nn.Sequential(*list(model.children())[:-1])
In [34]:
new_model_2_removed = nn.Sequential(*list(model.children())[:-2])

ADDING LAYERS

Say you want to add a fully connected layer to the model we have right now. One obvious way would be to edit the list discussed above and append another layer to it. However, we often already have such a model trained and want to load it and add just a new layer on top. As mentioned above, the model we load the weights into must have the SAME architecture as the saved one, so we can't use the list method before loading.

We need to add layers on top. The way to do this is simple in PyTorch - We just need to create a custom model! And this brings us to our next section - creating custom models!

SECTION 4 - CUSTOM MODELS : Combining sections 1-3 and adding layers on top

Let's make a custom model. As mentioned above, we will load half of the model from a pre-trained network. This seems complicated, right? Half the model is trained, half is new. Further, we want some of it to be frozen. Some to be update-able. Really, once you've done this, you can do anything with model architectures in PyTorch.

In [ ]:
# Some imports first
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
from torch.autograd.variable import Variable
from torchvision import datasets, models, transforms

# New models are defined as classes. To create a model, we instantiate the class.
class Resnet_Added_Layers_Half_Frozen(nn.Module):
    def __init__(self,LOAD_VIS_URL=None):
        super(Resnet_Added_Layers_Half_Frozen, self).__init__()
    
        # Start with the resnet model and swap out the final layer, because that's the model we had defined above. 
        model = models.resnet18(pretrained = False)
        num_final_in = model.fc.in_features
        model.fc = nn.Linear(num_final_in, 300)
        
        # Now that the architecture is defined same as above, let's load the model we would have trained above. 
        # MODEL_PATH is assumed to be the path used for saving in Section 2.
        checkpoint = torch.load(MODEL_PATH)
        model.load_state_dict(checkpoint)
        
        
        # Let's freeze the same as above. Same code as above without the print statements
        child_counter = 0
        for child in model.children():
            if child_counter < 6:
                for param in child.parameters():
                    param.requires_grad = False
            elif child_counter == 6:
                children_of_child_counter = 0
                for children_of_child in child.children():
                    if children_of_child_counter < 1:
                        for param in children_of_child.parameters():
                            param.requires_grad = False
                    children_of_child_counter += 1
            child_counter += 1
        
        # Now, let's define new layers that we want to add on top. 
        # Basically, these are just objects we define here. The "adding on top" is defined by the forward()
        # function which decides the flow of the input data into the model.
        
        # NOTE - Even the above model needs to be assigned to self.
        # The final fc layer is dropped here so that the output has 512 features,
        # which is what self.projective expects.
        self.vismodel = nn.Sequential(*list(model.children())[:-1])
        self.projective = nn.Linear(512,400)
        self.nonlinearity = nn.ReLU(inplace=True)
        self.projective2 = nn.Linear(400,300)
        
    
    # The forward function defines the flow of the input data and thus decides which layer/chunk goes on top of what.
    def forward(self,x):
        x = self.vismodel(x)
        # Flatten (N, 512, 1, 1) to (N, 512). Unlike squeeze(), this keeps the batch dimension when N == 1.
        x = x.view(x.size(0), -1)
        x = self.projective(x)
        x = self.nonlinearity(x)
        x = self.projective2(x)
        return x
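
The same pattern can be tried without a saved checkpoint. The sketch below is an assumption-laden miniature: `FrozenBackboneWithHead` and `toy_backbone` are hypothetical names, and the toy backbone merely imitates the (N, 512, 1, 1) output of a truncated ResNet.

```python
import torch
import torch.nn as nn

class FrozenBackboneWithHead(nn.Module):
    """Hypothetical example: a frozen backbone with new trainable layers on top."""
    def __init__(self, backbone, backbone_out=512, hidden=400, out_dim=300):
        super().__init__()
        self.backbone = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False          # freeze everything in the backbone
        self.projective = nn.Linear(backbone_out, hidden)
        self.nonlinearity = nn.ReLU(inplace=True)
        self.projective2 = nn.Linear(hidden, out_dim)

    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)                  # (N, C, 1, 1) -> (N, C)
        x = self.projective(x)
        x = self.nonlinearity(x)
        return self.projective2(x)

# Toy backbone standing in for the truncated ResNet; it outputs (N, 512, 1, 1).
toy_backbone = nn.Sequential(nn.Conv2d(3, 512, kernel_size=3), nn.AdaptiveAvgPool2d(1))
net = FrozenBackboneWithHead(toy_backbone)

out = net(torch.randn(1, 3, 8, 8))               # a batch of one keeps its batch dim
trainable_names = [n for n, p in net.named_parameters() if p.requires_grad]
print(tuple(out.shape), trainable_names)
```

Only the layers assigned to `self` after the freeze show up as trainable, which is the list one would hand to the optimizer.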

SECTION 5 - CUSTOM LOSS FUNCTIONS

Now that we have our model all in place we can load anything and create any architecture we want. That leaves us with 2 important components in any pipeline - Loading the data, and the training part. Let's take a look at the training part. The two most important components of this step are the optimizer and the loss function. The loss function quantifies how far our existing model is from where we want to be, and the optimizer decides how to update parameters such that we can minimize the loss.

Sometimes, we need to define our own loss functions. And here are a few things to know about this -

  • Custom loss functions are defined using a custom class too. They inherit from torch.nn.Module just like the custom model.
  • Often, we need to change the dimensions of one of our inputs. This can be done using the view() function.
  • If we want to add a dimension to a tensor, use the unsqueeze() function.
  • The value finally returned by a loss function MUST BE a scalar, i.e. a tensor holding a single value, not a vector or larger tensor.
  • The value returned must be a Variable, so that it can be used to update the parameters. The best way to ensure this is to make sure that both x and y being passed in are Variables. That way any function of the two will also be a Variable. (Since PyTorch 0.4, Variable has been merged into Tensor, so in current versions an ordinary tensor computed from the model's output plays this role.)
  • A PyTorch Variable is just a PyTorch Tensor, but PyTorch tracks the operations done on it so that it can backpropagate to get the gradient.
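
The points above can be illustrated with a few lines. This is a small sketch with made-up tensors (`x`, `w` are illustrative names), written against current PyTorch where plain tensors carry gradient tracking.

```python
import torch

x = torch.arange(12.0)              # shape (12,), values 0..11
x_2d = x.view(3, 4)                 # same data, new shape (3, 4)
x_3d = x_2d.unsqueeze(1)            # insert a size-1 dimension -> (3, 1, 4)

# A tensor with requires_grad=True has its operations tracked for backprop.
w = torch.ones(3, 4, requires_grad=True)
loss = (w * x_2d).sum()             # a zero-dimensional (scalar) tensor
loss.backward()                     # gradients flow back to w

print(tuple(x_2d.shape), tuple(x_3d.shape), loss.dim(), tuple(w.grad.shape))
```

Because the loss is a scalar, `backward()` can be called on it directly; the gradient of this particular loss with respect to `w` is simply `x_2d`.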

Here I show a custom loss called Regress_Loss which takes two inputs, x and y. It reshapes x to match the shape of y and returns the loss by calculating the squared L2 difference between the reshaped x and y. This is a standard thing you'll run across very often in training networks.

Consider x to be of shape (5,10) and y of shape (5,5,10). We need to add a dimension to x, then repeat it along the added dimension to match the shape of y. Then (x-y) will have shape (5,5,10). We have to sum over all three dimensions to get a scalar; torch.sum() called without a dim argument sums over every element.

In [38]:
class Regress_Loss(torch.nn.Module):
    
    def __init__(self):
        super(Regress_Loss,self).__init__()
        
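    def forward(self,x,y):
        # Number of times x must be repeated to match y along dimension 1
        num_repeats = y.size()[1]
        x_added_dim = x.unsqueeze(1)
        x_stacked_along_dimension1 = x_added_dim.repeat(1,num_repeats,1)
        # Squared difference, summed over the last dimension
        diff = torch.sum((y - x_stacked_along_dimension1)**2,2)
        # torch.sum() without a dim argument sums all remaining elements into a scalar
        totloss = torch.sum(diff)
        return totloss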
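
A loss of this kind can be exercised as sketched below. `RegressLossSketch` is a compact, hypothetical variant written for this example (it is not the class above verbatim), and the shapes follow the (5,10) and (5,5,10) case described earlier.

```python
import torch
import torch.nn as nn

class RegressLossSketch(nn.Module):
    """Sum of squared differences between x (N, D) and every row of y (N, K, D)."""
    def forward(self, x, y):
        num_repeats = y.size(1)
        x_stacked = x.unsqueeze(1).repeat(1, num_repeats, 1)   # (N, K, D)
        return torch.sum((y - x_stacked) ** 2)

x = torch.randn(5, 10, requires_grad=True)   # stands in for a model output
y = torch.randn(5, 5, 10)                    # stands in for the targets

loss = RegressLossSketch()(x, y)
loss.backward()                              # works because the loss is a scalar

# Broadcasting gives the same value without the explicit repeat().
broadcast_loss = torch.sum((y - x.unsqueeze(1)) ** 2)
matches_broadcast = torch.allclose(loss, broadcast_loss)
print(loss.dim(), matches_broadcast, tuple(x.grad.shape))
```

The broadcasting form avoids materialising the repeated copy of x, which can matter for large tensors; the explicit `repeat()` is kept in the class to mirror the description in the text.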