<a href="https://black.readthedocs.io/en/stable/"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
</p>

This repository is a reimplementation of the [MAML](https://arxiv.org/abs/1703.03400) (Model-Agnostic Meta-Learning) algorithm. Differentiable optimizers are handled by the [Higher](https://github.com/facebookresearch/higher) library, and [NN-template](https://github.com/lucmos/nn-template) is used for structuring the project. The default settings are used for training on the Omniglot (5-way 5-shot) problem. It can easily be extended to other few-shot datasets thanks to the [Torchmeta](https://github.com/tristandeleu/pytorch-meta) library.

**On Local Machine**

1. Download and install dependencies

```bash
git clone https://github.com/rcmalli/lightning-maml.git
cd ./lightning-maml/
pip install -r requirements.txt
```

2. Create a `.env` file containing the info given below, using your own [Wandb.ai](https://wandb.ai) account to track experiments. You can use the `.env.template` file as a starting point.

```bash
export DATASET_PATH="/your/project/root/data/"
export WANDB_ENTITY="USERNAME"
export WANDB_API_KEY="KEY"
```

3. Run the experiment

```bash
python3 src/run.py train.pl_trainer.gpus=1
```

**On Google Colab**

[![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rcmalli/lightning-maml/blob/main/notebooks/lightning_maml_pub.ipynb)

## Results

In the tables below, inner accuracy refers to performance on the support set after adaptation, and outer accuracy to performance on the query set, reported for both the meta-train and meta-validation splits.

### Omniglot (5-way 5-shot)

<table>
  <thead>
    <tr>
      <th colspan="3"></th>
      <th colspan="2">Metatrain</th>
      <th colspan="2">Metavalidation</th>
    </tr>
    <tr>
      <th>Algorithm</th>
      <th>Model</th>
      <th>inner_steps</th>
      <th>inner accuracy</th>
      <th>outer accuracy</th>
      <th>inner accuracy</th>
      <th>outer accuracy</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>MAML</td>
      <td>OmniConv</td>
      <td>1</td>
      <td></td>
      <td></td>
      <td></td>
      <td></td>
    </tr>
    <tr>
      <td>MAML</td>
      <td>OmniConv</td>
      <td>5</td>
      <td></td>
      <td></td>
      <td></td>
      <td></td>
    </tr>
  </tbody>
</table>

### MiniImageNet (5-way 5-shot)

<table>
  <thead>
    <tr>
      <th colspan="3"></th>
      <th colspan="2">Metatrain</th>
      <th colspan="2">Metavalidation</th>
    </tr>
    <tr>
      <th>Algorithm</th>
      <th>Model</th>
      <th>inner_steps</th>
      <th>inner accuracy</th>
      <th>outer accuracy</th>
      <th>inner accuracy</th>
      <th>outer accuracy</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>MAML</td>
      <td>MiniConv</td>
      <td>1</td>
      <td></td>
      <td></td>
      <td></td>
      <td></td>
    </tr>
    <tr>
      <td>MAML</td>
      <td>MiniConv</td>
      <td>5</td>
      <td></td>
      <td></td>
      <td></td>
      <td></td>
    </tr>
  </tbody>
</table>

## Customization

Inside the `conf` folder, you can change all the settings depending on your problem or dataset. The default parameters are set for the Omniglot dataset. Here are some examples of customization:

### Debug on local machine without GPU
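
For example, you can switch off the GPU and run a quick smoke test directly from the command line. This is a minimal sketch reusing the `train.pl_trainer` override shown above; `fast_dev_run` is a standard PyTorch Lightning trainer flag, though the exact overrides available in this repository may differ:

```bash
python3 src/run.py train.pl_trainer.gpus=0 train.pl_trainer.fast_dev_run=true
```
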
### Using different dataset from Torchmeta

If you want to try a different dataset (e.g. MiniImageNet), you can copy the `default.yaml` file inside `conf/data` to `miniimagenet.yaml` and edit these lines:

```yaml
datamodule:
  # ... (dataset-specific fields elided here) ...

# you may need to update the data augmentation and preprocessing steps as well!
```

Run the experiment as follows:

```bash
python3 src/run.py data=miniimagenet
```
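
For reference, Torchmeta exposes MiniImageNet (and the other few-shot benchmarks) through a small helpers API, which is roughly what the datamodule wraps internally. The snippet below is illustrative, not the repository's exact code:

```python
from torchmeta.datasets.helpers import miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader

# 5-way 5-shot MiniImageNet meta-training split (downloads on first use).
dataset = miniimagenet("data", ways=5, shots=5, meta_train=True, download=True)
# Each batch is a dict with "train" (support) and "test" (query) tensors.
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)
```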

Under the hood, each meta-training step ends with the outer (meta) optimizer update:

```python
outer_optimizer.step()
```
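
Putting it together, one meta-training step with [Higher](https://github.com/facebookresearch/higher) looks roughly like the sketch below. This is a minimal, self-contained illustration with a toy model and made-up shapes and hyperparameters, not the exact code in `src/` (which wraps these steps in a PyTorch Lightning module):

```python
import torch
import torch.nn as nn
import higher

# Toy stand-ins for the real components (hypothetical, for illustration):
model = nn.Linear(4, 2)
inner_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
outer_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
num_inner_steps = 5

# One meta-batch of tasks: (support_x, support_y, query_x, query_y).
tasks = [
    (torch.randn(8, 4), torch.randint(0, 2, (8,)),
     torch.randn(8, 4), torch.randint(0, 2, (8,)))
    for _ in range(4)
]

outer_optimizer.zero_grad()
for support_x, support_y, query_x, query_y in tasks:
    # higher makes the inner-loop updates differentiable, so gradients
    # can flow back to the initial (meta) parameters.
    with higher.innerloop_ctx(
        model, inner_optimizer, copy_initial_weights=False
    ) as (fmodel, diffopt):
        for _ in range(num_inner_steps):
            # Fast adaptation on the support set.
            diffopt.step(loss_fn(fmodel(support_x), support_y))
        # Query loss of the adapted model; backward() accumulates
        # gradients w.r.t. the initial parameters.
        loss_fn(fmodel(query_x), query_y).backward()
outer_optimizer.step()  # the outer update highlighted above
```
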
## References

- [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400)