Skip to content

Commit cfd22fc

Browse files
Clean PODNet and add paper options.
1 parent aa79a1f commit cfd22fc

File tree

69 files changed

+778
-3086
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+778
-3086
lines changed

hyperfind.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def parse_args():
2626
parser.add_argument("-options", "--options", default=None, nargs="+")
2727
parser.add_argument("-threads", default=2, type=int)
2828
parser.add_argument("-resume", default=False, action="store_true")
29+
parser.add_argument("-metric", default="avg_inc_acc", choices=["avg_inc_acc", "last_acc"])
2930

3031
return parser.parse_args()
3132

@@ -45,6 +46,7 @@ def train_func(config, reporter):
4546

4647
total_avg_inc_acc = statistics.mean(all_acc)
4748
reporter(avg_inc_acc=total_avg_inc_acc)
49+
#reporter(last_acc=last_acc)
4850
return total_avg_inc_acc
4951

5052

@@ -54,24 +56,27 @@ def _get_abs_path(path):
5456
return os.path.join(os.path.dirname(os.path.realpath(__file__)), path)
5557

5658

57-
def analyse_ray_dump(ray_directory, topn):
59+
def analyse_ray_dump(ray_directory, topn, metric="avg_inc_acc"):
60+
if metric not in ("avg_inc_acc", "last_acc"):
61+
raise NotImplementedError("Unknown metric {}.".format(metric))
62+
5863
ea = Analysis(ray_directory)
5964
trials_dataframe = ea.dataframe()
60-
trials_dataframe = trials_dataframe.sort_values(by="avg_inc_acc", ascending=False)
65+
trials_dataframe = trials_dataframe.sort_values(by=metric, ascending=False)
6166

6267
mapping_col_to_index = {}
6368
result_index = -1
6469
for index, col in enumerate(trials_dataframe.columns):
6570
if col.startswith("config:"):
6671
mapping_col_to_index[col[7:]] = index
67-
elif col == "avg_inc_acc":
72+
elif col == metric:
6873
result_index = index
6974

7075
print("Ray config: {}".format(ray_directory))
7176
print("Best Config:")
7277
print(
73-
"avg_inc_acc: {} with {}.".format(
74-
trials_dataframe.iloc[0][result_index],
78+
"{}: {} with {}.".format(
79+
metric, trials_dataframe.iloc[0][result_index],
7580
_get_line_results(trials_dataframe, 0, mapping_col_to_index)
7681
)
7782
)
@@ -119,6 +124,9 @@ def get_tune_config(tune_options, options_files):
119124
with open(tune_options) as f:
120125
options = yaml.load(f, Loader=yaml.FullLoader)
121126

127+
if "epochs" in options and options["epochs"] == 1:
128+
raise ValueError("Using only 1 epoch, must be a mistake.")
129+
122130
config = {}
123131
for k, v in options.items():
124132
if not k.startswith("var:"):
@@ -141,6 +149,12 @@ def main():
141149
if args.tune is not None:
142150
config = get_tune_config(args.tune, args.options)
143151
config["threads"] = args.threads
152+
153+
try:
154+
os.system("echo '\ek{}_gridsearch\e\\'".format(args.tune))
155+
except:
156+
pass
157+
144158
ray.init()
145159
tune.run(
146160
train_func,
@@ -158,10 +172,12 @@ def main():
158172
args.ray_directory = os.path.join(args.ray_directory, args.tune.rstrip("/").split("/")[-1])
159173

160174
if args.tune is not None:
161-
print("\n\n", args.tune, "\n\n")
175+
print("\n\n", args.tune, args.options, "\n\n")
162176

163177
if args.ray_directory is not None:
164-
best_config = analyse_ray_dump(_get_abs_path(args.ray_directory), args.topn)
178+
best_config = analyse_ray_dump(
179+
_get_abs_path(args.ray_directory), args.topn, metric=args.metric
180+
)
165181

166182
if args.output_options:
167183
with open(args.output_options, "w+") as f:

inclearn/lib/factory.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from torch import optim
55

66
from inclearn import models
7-
from inclearn.convnet import (densenet, my_resnet, my_resnet2, my_resnet_brn,
8-
my_resnet_mcbn, my_resnet_mtl, resnet,
9-
resnet_mtl, ucir_resnet, vgg)
7+
from inclearn.convnet import (
8+
densenet, my_resnet, my_resnet2, my_resnet_brn, my_resnet_mcbn, my_resnet_mtl, resnet,
9+
resnet_mtl, ucir_resnet, vgg
10+
)
1011
from inclearn.lib import data, schedulers
1112

1213

@@ -59,18 +60,15 @@ def get_convnet(convnet_type, **kwargs):
5960
def get_model(args):
6061
dict_models = {
6162
"icarl": models.ICarl,
62-
#"lwf": models.LwF,
63+
"lwf": None,
6364
"e2e": models.End2End,
64-
#"medic": models.Medic,
65-
#"fixed": models.FixedRepresentation,
65+
"fixed": None,
6666
"oracle": None,
6767
"bic": models.BiC,
6868
"ucir": models.UCIR,
69-
"still": models.STILL,
69+
"podnet": models.PODNet,
7070
"lwm": models.LwM,
71-
"zil": models.ZIL,
72-
"zil2": models.ZIL2,
73-
"ull": models.ULL
71+
"zil": models.ZIL
7472
}
7573

7674
model = args["model"].lower()

0 commit comments

Comments
 (0)