@@ -740,7 +740,7 @@ def _delete(filename, **more):
740
740
os .remove (filename )
741
741
742
742
@classmethod
743
- def get_xs_from_results (cls , results , indexs = None , aggregate = True ):
743
+ def get_xs_from_results (cls , results , indexs = None , aggregate = True , features = None ):
744
744
"""
745
745
:param results: list of BasicResult, or pandas.DataFrame
746
746
:param indexs: indices of results to be used
@@ -756,8 +756,11 @@ def get_xs_from_results(cls, results, indexs=None, aggregate=True):
756
756
# or get_ordered_list_scores_key. Instead, just get the sorted keys
757
757
feature_names = results [0 ].get_ordered_results ()
758
758
759
- feature_names = list (feature_names )
760
- cls ._assert_dimension (feature_names , results )
759
+ if features is not None :
760
+ feature_names = [f for f in feature_names if f in features ]
761
+ else :
762
+ feature_names = list (feature_names )
763
+ cls ._assert_dimension (feature_names , results )
761
764
762
765
# collect results into xs
763
766
xs = {}
@@ -1156,6 +1159,81 @@ def _predict(cls, model, xs_2d):
1156
1159
ys_label_pred = model .predict (xs_2d )
1157
1160
return ys_label_pred
1158
1161
1162
+ class Logistic5PLRegressionTrainTestModel (TrainTestModel , RegressorMixin ):
1163
+
1164
+ TYPE = '5PL'
1165
+ VERSION = "0.1"
1166
+
1167
+ @classmethod
1168
+ def _train (cls , model_param , xys_2d , ** kwargs ):
1169
+ """
1170
+ Fit the following 5PL curve using scipy.optimize.curve_fit
1171
+
1172
+ Q(x) = B1 + (1/2 - 1/(1 + exp(B2 * (x - B3)))) + B4 * x + B5
1173
+
1174
+ H. R. Sheikh, M. F. Sabir, and A. C. Bovik,
1175
+ "A statistical evaluation of recent full reference image quality assessment algorithms"
1176
+ IEEE Trans. Image Process., vol. 15, no. 11, pp. 3440–3451, Nov. 2006.
1177
+
1178
+ :param model_param:
1179
+ :param xys_2d:
1180
+ :return:
1181
+ """
1182
+ model_param_ = model_param .copy ()
1183
+
1184
+ # remove keys unassociated with sklearn
1185
+ if 'norm_type' in model_param_ :
1186
+ del model_param_ ['norm_type' ]
1187
+ if 'score_clip' in model_param_ :
1188
+ del model_param_ ['score_clip' ]
1189
+ if 'custom_clip_0to1_map' in model_param_ :
1190
+ del model_param_ ['custom_clip_0to1_map' ]
1191
+ if 'num_models' in model_param_ :
1192
+ del model_param_ ['num_models' ]
1193
+
1194
+ from scipy .optimize import curve_fit
1195
+ [[b1 , b2 , b3 , b4 , b5 ], _ ] = curve_fit (
1196
+ lambda x , b1 , b2 , b3 , b4 , b5 : b1 + (0.5 - 1 / (1 + np .exp (b2 * (x - b3 ))))+ b4 * x + b5 ,
1197
+ np .ravel (xys_2d [:, 1 ]),
1198
+ np .ravel (xys_2d [:, 0 ]),
1199
+ p0 = 0.5 * np .ones ((5 ,)),
1200
+ maxfev = 20000
1201
+ )
1202
+
1203
+ return dict (b1 = b1 , b2 = b2 , b3 = b3 , b4 = b4 , b5 = b5 )
1204
+
1205
+ @staticmethod
1206
+ @override (TrainTestModel )
1207
+ def _to_file (filename , param_dict , model_dict , ** more ):
1208
+ format = more ['format' ] if 'format' in more else 'pkl'
1209
+ supported_formats = ['pkl' , 'json' ]
1210
+ assert format in supported_formats , \
1211
+ f'format must be in { supported_formats } , but got: { format } '
1212
+
1213
+ info_to_save = {'param_dict' : param_dict ,
1214
+ 'model_dict' : model_dict .copy ()}
1215
+
1216
+ if format == 'pkl' :
1217
+ with open (filename , 'wb' ) as file :
1218
+ pickle .dump (info_to_save , file )
1219
+ elif format == 'json' :
1220
+ with open (filename , 'wt' ) as file :
1221
+ json .dump (info_to_save , file , indent = 4 )
1222
+ else :
1223
+ assert False
1224
+
1225
+ @classmethod
1226
+ def _predict (cls , model , xs_2d ):
1227
+ b1 = model ['b1' ]
1228
+ b2 = model ['b2' ]
1229
+ b3 = model ['b3' ]
1230
+ b4 = model ['b4' ]
1231
+ b5 = model ['b5' ]
1232
+
1233
+ curve = lambda x : b1 + (0.5 - 1 / (1 + np .exp (b2 * (x - b3 ))))+ b4 * x + b5
1234
+ predicted = [curve (x ) for x in np .ravel (xs_2d )]
1235
+
1236
+ return predicted
1159
1237
1160
1238
class RawVideoTrainTestModelMixin (object ):
1161
1239
"""
0 commit comments