Skip to content

Commit bf60788

Browse files
authored
add five-parameter logistic function as a model option (#813)
* add five-parameter logistic function as a model * add unit test for 5PL fit model * change signature for in to match
1 parent 34fcfeb commit bf60788

File tree

6 files changed

+118
-7
lines changed

6 files changed

+118
-7
lines changed

5PL_v1.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
model_type = "5PL"
2+
model_param_dict = {
3+
# ==== preprocess: normalize each feature ==== #
4+
# 'norm_type':'none',
5+
'norm_type': 'clip_0to1', # rescale to within [0, 1]
6+
7+
# ==== postprocess: clip final quality score ==== #
8+
'score_clip':[0.0, 100.0], # clip to within [0, 100]
9+
}

python/test/train_test_model_test.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from vmaf.config import VmafConfig
77
from vmaf.core.train_test_model import TrainTestModel, \
88
LibsvmNusvrTrainTestModel, SklearnRandomForestTrainTestModel, \
9-
MomentRandomForestTrainTestModel, SklearnExtraTreesTrainTestModel, SklearnLinearRegressionTrainTestModel
9+
MomentRandomForestTrainTestModel, SklearnExtraTreesTrainTestModel, \
10+
SklearnLinearRegressionTrainTestModel, Logistic5PLRegressionTrainTestModel
1011
from vmaf.core.noref_feature_extractor import MomentNorefFeatureExtractor
1112
from vmaf.routine import read_dataset
1213
from vmaf.tools.misc import import_python_file
@@ -309,6 +310,20 @@ def test_train_predict_extratrees(self):
309310
result = model.evaluate(xs, ys)
310311
self.assertAlmostEqual(result['RMSE'], 0.042867322777879642, places=4)
311312

313+
def test_train_logistic_fit_5PL(self):
314+
xs = Logistic5PLRegressionTrainTestModel.get_xs_from_results(self.features, [0, 1, 2, 3, 4, 5], features=['Moment_noref_feature_1st_score'])
315+
ys = Logistic5PLRegressionTrainTestModel.get_ys_from_results(self.features, [0, 1, 2, 3, 4, 5])
316+
317+
xys = {}
318+
xys.update(xs)
319+
xys.update(ys)
320+
321+
model = Logistic5PLRegressionTrainTestModel({'norm_type': 'clip_0to1'}, None)
322+
model.train(xys)
323+
result = model.evaluate(xs, ys)
324+
325+
self.assertAlmostEqual(result['RMSE'], 0.3603374311919728, places=4)
326+
312327

313328
class TrainTestModelWithDisYRawVideoExtractorTest(unittest.TestCase):
314329

python/vmaf/core/niqe_train_test_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ def _assert_dimension(cls, feature_names, results):
2828

2929
@classmethod
3030
@override(TrainTestModel)
31-
def get_xs_from_results(cls, results, indexs=None, aggregate=False):
31+
def get_xs_from_results(cls, results, indexs=None, aggregate=False, features=None):
3232
"""
3333
override by altering aggregate
3434
default to False
3535
"""
3636
return super(NiqeTrainTestModel, cls).get_xs_from_results(
37-
results, indexs, aggregate)
37+
results, indexs, aggregate, features)
3838

3939
@classmethod
4040
@override(TrainTestModel)

python/vmaf/core/train_test_model.py

+81-3
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ def _delete(filename, **more):
740740
os.remove(filename)
741741

742742
@classmethod
743-
def get_xs_from_results(cls, results, indexs=None, aggregate=True):
743+
def get_xs_from_results(cls, results, indexs=None, aggregate=True, features=None):
744744
"""
745745
:param results: list of BasicResult, or pandas.DataFrame
746746
:param indexs: indices of results to be used
@@ -756,8 +756,11 @@ def get_xs_from_results(cls, results, indexs=None, aggregate=True):
756756
# or get_ordered_list_scores_key. Instead, just get the sorted keys
757757
feature_names = results[0].get_ordered_results()
758758

759-
feature_names = list(feature_names)
760-
cls._assert_dimension(feature_names, results)
759+
if features is not None:
760+
feature_names = [f for f in feature_names if f in features]
761+
else:
762+
feature_names = list(feature_names)
763+
cls._assert_dimension(feature_names, results)
761764

762765
# collect results into xs
763766
xs = {}
@@ -1156,6 +1159,81 @@ def _predict(cls, model, xs_2d):
11561159
ys_label_pred = model.predict(xs_2d)
11571160
return ys_label_pred
11581161

1162+
class Logistic5PLRegressionTrainTestModel(TrainTestModel, RegressorMixin):
1163+
1164+
TYPE = '5PL'
1165+
VERSION = "0.1"
1166+
1167+
@classmethod
1168+
def _train(cls, model_param, xys_2d, **kwargs):
1169+
"""
1170+
Fit the following 5PL curve using scipy.optimize.curve_fit
1171+
1172+
Q(x) = B1 + (1/2 - 1/(1 + exp(B2 * (x - B3)))) + B4 * x + B5
1173+
1174+
H. R. Sheikh, M. F. Sabir, and A. C. Bovik,
1175+
"A statistical evaluation of recent full reference image quality assessment algorithms"
1176+
IEEE Trans. Image Process., vol. 15, no. 11, pp. 3440–3451, Nov. 2006.
1177+
1178+
:param model_param:
1179+
:param xys_2d:
1180+
:return:
1181+
"""
1182+
model_param_ = model_param.copy()
1183+
1184+
# remove keys unassociated with sklearn
1185+
if 'norm_type' in model_param_:
1186+
del model_param_['norm_type']
1187+
if 'score_clip' in model_param_:
1188+
del model_param_['score_clip']
1189+
if 'custom_clip_0to1_map' in model_param_:
1190+
del model_param_['custom_clip_0to1_map']
1191+
if 'num_models' in model_param_:
1192+
del model_param_['num_models']
1193+
1194+
from scipy.optimize import curve_fit
1195+
[[b1, b2, b3, b4, b5], _] = curve_fit(
1196+
lambda x, b1, b2, b3, b4, b5: b1 + (0.5 - 1/(1+np.exp(b2*(x-b3))))+b4*x+b5,
1197+
np.ravel(xys_2d[:, 1]),
1198+
np.ravel(xys_2d[:, 0]),
1199+
p0=0.5 * np.ones((5,)),
1200+
maxfev=20000
1201+
)
1202+
1203+
return dict(b1=b1, b2=b2, b3=b3, b4=b4, b5=b5)
1204+
1205+
@staticmethod
1206+
@override(TrainTestModel)
1207+
def _to_file(filename, param_dict, model_dict, **more):
1208+
format = more['format'] if 'format' in more else 'pkl'
1209+
supported_formats = ['pkl', 'json']
1210+
assert format in supported_formats, \
1211+
f'format must be in {supported_formats}, but got: {format}'
1212+
1213+
info_to_save = {'param_dict': param_dict,
1214+
'model_dict': model_dict.copy()}
1215+
1216+
if format == 'pkl':
1217+
with open(filename, 'wb') as file:
1218+
pickle.dump(info_to_save, file)
1219+
elif format == 'json':
1220+
with open(filename, 'wt') as file:
1221+
json.dump(info_to_save, file, indent=4)
1222+
else:
1223+
assert False
1224+
1225+
@classmethod
1226+
def _predict(cls, model, xs_2d):
1227+
b1 = model['b1']
1228+
b2 = model['b2']
1229+
b3 = model['b3']
1230+
b4 = model['b4']
1231+
b5 = model['b5']
1232+
1233+
curve = lambda x: b1 + (0.5 - 1/(1+np.exp(b2*(x-b3))))+b4*x+b5
1234+
predicted = [curve(x) for x in np.ravel(xs_2d)]
1235+
1236+
return predicted
11591237

11601238
class RawVideoTrainTestModelMixin(object):
11611239
"""

resource/feature_param/ssim.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
feature_dict = {
2+
'SSIM_feature': ['ssim'],
3+
}

unittest

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
#!/usr/bin/env sh
22

3-
PYTHONPATH=python python3 -m unittest discover -v -s python/test/ -p '*_test.py'
3+
if [ -z "$1" ]; then
4+
pattern='*_test.py'
5+
else
6+
pattern="$1"
7+
fi
8+
9+
PYTHONPATH=python python3 -m unittest discover -v -s python/test/ -p $pattern

0 commit comments

Comments
 (0)