{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"14a_dropout_intro.ipynb","provenance":[{"file_id":"1A-yZtQzr_Ms-uG2sl9EChVuM1XUeu1RX","timestamp":1612871843955},{"file_id":"https://github.com/mlelarge/dataflowr/blob/master/Notebooks/04_deeper/04_dropout_intro_colab.ipynb","timestamp":1580210876000}],"collapsed_sections":[]},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"}},"cells":[{"cell_type":"markdown","metadata":{"id":"H-pGQemIx72p"},"source":["# DropOut intuition"]},{"cell_type":"code","metadata":{"id":"UbUz5Ocsx72s"},"source":["%matplotlib inline\n","import matplotlib.pyplot as plt\n","import torch"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"D8WXRqxNx72w"},"source":["# 1. Dropout from scratch"]},{"cell_type":"markdown","metadata":{"id":"n87efjJpx72x"},"source":["Let's code our own dropout function first."]},{"cell_type":"code","metadata":{"id":"b5QmfUlJx72x"},"source":["def dropout(X, drop_prob):\n","    \"\"\"Apply (inverted) dropout to the tensor `X`.\n","\n","    Each entry of `X` is zeroed independently with probability `drop_prob`;\n","    the entries that are kept are divided by `1 - drop_prob`, so that the\n","    expected value of the output equals `X`.\n","\n","    Args:\n","        X: input tensor.\n","        drop_prob: probability of dropping an entry, must lie in [0, 1].\n","\n","    Returns:\n","        A tensor with the same shape as `X`.\n","    \"\"\"\n","    assert 0 <= drop_prob <= 1\n","    # In this case, all elements are dropped out (handled separately because\n","    # 1 - drop_prob == 0 would make the rescaling below divide by zero)\n","    if drop_prob == 1:\n","        return torch.zeros_like(X)\n","    # torch.rand samples from [0, 1): with `>=` an entry is kept with probability\n","    # exactly 1 - drop_prob, so drop_prob == 0 can never drop anything.\n","    # The mask is sampled on X's device so that GPU tensors work as well.\n","    mask = torch.rand(X.shape, device=X.device) >= drop_prob\n","    return mask.float() * X / (1.0 - drop_prob)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"2YvMdJrtx72z"},"source":["X = torch.arange(16).view(2,8).float()\n","print(X)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"iTu1ZNvzx722"},"source":["print(dropout(X, 0))\n","print(dropout(X, 0.5))\n","print(dropout(X, 1))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kClbySxGx724"},"source":["# 2. 
DropOut on a toy dataset"]},{"cell_type":"code","metadata":{"id":"p6Yxd33Xx725"},"source":["N_SAMPLES = 20\n","N_HIDDEN = 300\n","NOISE_STD = 0.3  # standard deviation of the Gaussian noise added to the targets"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BbaEOLDLx727"},"source":["Let's generate some training data"]},{"cell_type":"code","metadata":{"id":"aGacY3QHx728"},"source":["# fix the seed so that the noisy samples (and every random draw that follows)\n","# are identical after each Restart & Run All\n","torch.manual_seed(0)\n","\n","# training data: inputs on a regular grid in [-1, 1], targets y = x + noise\n","x = torch.linspace(-1, 1, N_SAMPLES).unsqueeze(1)\n","y = x + NOISE_STD*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))\n","\n","# test data: same input grid, independent noise draw\n","test_x = torch.linspace(-1, 1, N_SAMPLES).unsqueeze(1)\n","test_y = test_x + NOISE_STD*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"a6O7oAzTx72-"},"source":["# show data\n","plt.scatter(x.numpy(), y.numpy(), c='green', s=50, alpha=0.5, label='train')\n","plt.scatter(test_x.numpy(), test_y.numpy(), c='orange', s=50, alpha=0.5, label='test')\n","plt.legend(loc='upper left')\n","plt.ylim((-2.5, 2.5))\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BCkx0YJhx73D"},"source":["#### Exercise"]},{"cell_type":"markdown","metadata":{"id":"egV6cqNex73E"},"source":["Create a network `net_simple` as `nn.Sequential` with the following structure: `Fully Connected N_HIDDEN -> ReLU -> Fully Connected N_HIDDEN -> ReLU -> Fully Connected 1`"]},{"cell_type":"code","metadata":{"id":"NEA2ljVnx73F"},"source":["net_simple = torch.nn.Sequential(\n","# TODO\n",")\n","\n","print(net_simple)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lyDCBZWqx73I"},"source":["#### Exercise\n","\n","Now define the same architecture with `Dropout` layers in-between with $p=0.5$. 
Where will you place them?"]},{"cell_type":"code","metadata":{"id":"PzduP4k8x73J"},"source":["net_dropout = torch.nn.Sequential(\n","# TODO\n",")\n","\n","print(net_dropout)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"txKmgBiDx73M"},"source":["optimizer_simple = torch.optim.Adam(net_simple.parameters(), lr=0.01)\n","optimizer_dropout = torch.optim.Adam(net_dropout.parameters(), lr=0.01)\n","loss_fn = torch.nn.MSELoss()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"IttM4NmMx73P"},"source":["# Train both networks side by side on the same training data; every 100 epochs,\n","# plot their predictions on the test inputs together with the test losses.\n","for epoch in range(1000):\n","    pred_simple = net_simple(x)\n","    pred_drop = net_dropout(x)\n","    loss_simple = loss_fn(pred_simple, y)\n","    loss_dropout = loss_fn(pred_drop, y)\n","\n","    optimizer_simple.zero_grad()\n","    optimizer_dropout.zero_grad()\n","    loss_simple.backward()\n","    loss_dropout.backward()\n","    optimizer_simple.step()\n","    optimizer_dropout.step()\n","\n","    if epoch % 100 == 0:\n","        # change to eval mode so that dropout is switched off for the test predictions\n","        net_simple.eval()\n","        net_dropout.eval()\n","\n","        # evaluation only: no_grad avoids building an autograd graph and lets us\n","        # call .numpy() on the predictions directly\n","        with torch.no_grad():\n","            test_pred_simple = net_simple(test_x)\n","            test_pred_dropout = net_dropout(test_x)\n","            test_loss_simple = loss_fn(test_pred_simple, test_y).item()\n","            test_loss_dropout = loss_fn(test_pred_dropout, test_y).item()\n","\n","        # plotting\n","        plt.cla()\n","        plt.scatter(x.numpy(), y.numpy(), c='green', s=50, alpha=0.3, label='train')\n","        plt.scatter(test_x.numpy(), test_y.numpy(), c='orange', s=50, alpha=0.3, label='test')\n","        plt.plot(test_x.numpy(), test_pred_simple.numpy(), 'r-', lw=3, label='overfitting')\n","        plt.plot(test_x.numpy(), test_pred_dropout.numpy(), 'b--', lw=3, label='dropout(50%)')\n","        plt.text(0, -1.2, 'overfitting test loss=%.4f' % test_loss_simple, fontdict={'size': 16, 'color': 'red'})\n","        plt.text(0, -1.5, 'dropout test loss=%.4f' % test_loss_dropout, fontdict={'size': 16, 'color': 'blue'})\n","        plt.legend(loc='upper left')\n","        plt.ylim((-2.5, 2.5))\n","        plt.pause(0.1)\n","\n","        # change back to train mode\n","        net_simple.train()\n","        net_dropout.train()\n","        plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Vc8m9b9ox73Q"},"source":[""],"execution_count":null,"outputs":[]}]}