@@ -26,6 +26,7 @@ def parse_args():
26
26
parser .add_argument ("-options" , "--options" , default = None , nargs = "+" )
27
27
parser .add_argument ("-threads" , default = 2 , type = int )
28
28
parser .add_argument ("-resume" , default = False , action = "store_true" )
29
+ parser .add_argument ("-metric" , default = "avg_inc_acc" , choices = ["avg_inc_acc" , "last_acc" ])
29
30
30
31
return parser .parse_args ()
31
32
@@ -45,6 +46,7 @@ def train_func(config, reporter):
45
46
46
47
total_avg_inc_acc = statistics .mean (all_acc )
47
48
reporter (avg_inc_acc = total_avg_inc_acc )
49
+ #reporter(last_acc=last_acc)
48
50
return total_avg_inc_acc
49
51
50
52
@@ -54,24 +56,27 @@ def _get_abs_path(path):
54
56
return os .path .join (os .path .dirname (os .path .realpath (__file__ )), path )
55
57
56
58
57
- def analyse_ray_dump (ray_directory , topn ):
59
+ def analyse_ray_dump (ray_directory , topn , metric = "avg_inc_acc" ):
60
+ if metric not in ("avg_inc_acc" , "last_acc" ):
61
+ raise NotImplementedError ("Unknown metric {}." .format (metric ))
62
+
58
63
ea = Analysis (ray_directory )
59
64
trials_dataframe = ea .dataframe ()
60
- trials_dataframe = trials_dataframe .sort_values (by = "avg_inc_acc" , ascending = False )
65
+ trials_dataframe = trials_dataframe .sort_values (by = metric , ascending = False )
61
66
62
67
mapping_col_to_index = {}
63
68
result_index = - 1
64
69
for index , col in enumerate (trials_dataframe .columns ):
65
70
if col .startswith ("config:" ):
66
71
mapping_col_to_index [col [7 :]] = index
67
- elif col == "avg_inc_acc" :
72
+ elif col == metric :
68
73
result_index = index
69
74
70
75
print ("Ray config: {}" .format (ray_directory ))
71
76
print ("Best Config:" )
72
77
print (
73
- "avg_inc_acc : {} with {}." .format (
74
- trials_dataframe .iloc [0 ][result_index ],
78
+ "{} : {} with {}." .format (
79
+ metric , trials_dataframe .iloc [0 ][result_index ],
75
80
_get_line_results (trials_dataframe , 0 , mapping_col_to_index )
76
81
)
77
82
)
@@ -119,6 +124,9 @@ def get_tune_config(tune_options, options_files):
119
124
with open (tune_options ) as f :
120
125
options = yaml .load (f , Loader = yaml .FullLoader )
121
126
127
+ if "epochs" in options and options ["epochs" ] == 1 :
128
+ raise ValueError ("Using only 1 epoch, must be a mistake." )
129
+
122
130
config = {}
123
131
for k , v in options .items ():
124
132
if not k .startswith ("var:" ):
@@ -141,6 +149,12 @@ def main():
141
149
if args .tune is not None :
142
150
config = get_tune_config (args .tune , args .options )
143
151
config ["threads" ] = args .threads
152
+
153
+ try :
154
+ os .system ("echo '\ek{}_gridsearch\e\\ '" .format (args .tune ))
155
+ except :
156
+ pass
157
+
144
158
ray .init ()
145
159
tune .run (
146
160
train_func ,
@@ -158,10 +172,12 @@ def main():
158
172
args .ray_directory = os .path .join (args .ray_directory , args .tune .rstrip ("/" ).split ("/" )[- 1 ])
159
173
160
174
if args .tune is not None :
161
- print ("\n \n " , args .tune , "\n \n " )
175
+ print ("\n \n " , args .tune , args . options , "\n \n " )
162
176
163
177
if args .ray_directory is not None :
164
- best_config = analyse_ray_dump (_get_abs_path (args .ray_directory ), args .topn )
178
+ best_config = analyse_ray_dump (
179
+ _get_abs_path (args .ray_directory ), args .topn , metric = args .metric
180
+ )
165
181
166
182
if args .output_options :
167
183
with open (args .output_options , "w+" ) as f :
0 commit comments