Skip to content

Commit e698b4d

Browse files
committed
Add pypsnr fex and subclasses; add tests.
1 parent 30e4cb7 commit e698b4d

File tree

2 files changed

+254
-1
lines changed

2 files changed

+254
-1
lines changed

python/test/feature_extractor_test.py

+143-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
MomentFeatureExtractor, \
1010
PsnrFeatureExtractor, SsimFeatureExtractor, MsSsimFeatureExtractor, \
1111
VifFrameDifferenceFeatureExtractor, \
12-
AnsnrFeatureExtractor, VmafIntegerFeatureExtractor
12+
AnsnrFeatureExtractor, PypsnrFeatureExtractor, VmafIntegerFeatureExtractor, \
13+
PypsnrMaxdb100FeatureExtractor
1314
from vmaf.core.asset import Asset
1415
from vmaf.core.result_store import FileSystemResultStore
1516

@@ -626,6 +627,147 @@ def test_run_psnr_fextractor_proc(self):
626627
self.assertAlmostEqual(results[0]['PSNR_feature_psnr_score'], 27.645446604166665, places=8)
627628
self.assertAlmostEqual(results[1]['PSNR_feature_psnr_score'], 31.87683660416667, places=8)
628629

630+
def test_run_pypsnr_fextractor(self):
631+
632+
ref_path, dis_path, asset, asset_original = set_default_576_324_videos_for_testing()
633+
634+
self.fextractor = PypsnrFeatureExtractor(
635+
[asset, asset_original],
636+
None, fifo_mode=True,
637+
result_store=None
638+
)
639+
self.fextractor.run(parallelize=True)
640+
641+
results = self.fextractor.results
642+
643+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 30.755063979166664, places=4)
644+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 38.449441057158786, places=4)
645+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 40.9919102486235, places=4)
646+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 60.0, places=4)
647+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 60.0, places=4)
648+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 60.0, places=4)
649+
650+
def test_run_pypsnr_fextractor_10bit(self):
651+
652+
ref_path, dis_path, asset, asset_original = set_default_576_324_10bit_videos_for_testing()
653+
654+
self.fextractor = PypsnrFeatureExtractor(
655+
[asset, asset_original],
656+
None, fifo_mode=True,
657+
result_store=None
658+
)
659+
self.fextractor.run(parallelize=True)
660+
661+
results = self.fextractor.results
662+
663+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 30.780573260053277, places=4)
664+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 38.769832063651364, places=4)
665+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.28418847734209, places=4)
666+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 72.0, places=4)
667+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 72.0, places=4)
668+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 72.0, places=4)
669+
670+
def test_run_pypsnr_fextractor_10bit_b(self):
671+
672+
ref_path, dis_path, asset, asset_original = set_default_576_324_10bit_videos_for_testing_b()
673+
674+
self.fextractor = PypsnrFeatureExtractor(
675+
[asset, asset_original],
676+
None, fifo_mode=True,
677+
result_store=None
678+
)
679+
self.fextractor.run(parallelize=True)
680+
681+
results = self.fextractor.results
682+
683+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 32.57145231892744, places=4)
684+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 39.03859552689696, places=4)
685+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.28060001337217, places=4)
686+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 72.0, places=4)
687+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 72.0, places=4)
688+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 72.0, places=4)
689+
690+
def test_run_pypsnr_fextractor_12bit(self):
691+
692+
ref_path, dis_path, asset, asset_original = set_default_576_324_12bit_videos_for_testing()
693+
694+
self.fextractor = PypsnrFeatureExtractor(
695+
[asset, asset_original],
696+
None, fifo_mode=True,
697+
result_store=None
698+
)
699+
self.fextractor.run(parallelize=True)
700+
701+
results = self.fextractor.results
702+
703+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 32.577817940053734, places=4)
704+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 39.044961148023255, places=4)
705+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.28696563449846, places=4)
706+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 84.0, places=4)
707+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 84.0, places=4)
708+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 84.0, places=4)
709+
710+
def test_run_pypsnr_fextractor_16bit(self):
711+
712+
ref_path, dis_path, asset, asset_original = set_default_576_324_16bit_videos_for_testing()
713+
714+
self.fextractor = PypsnrFeatureExtractor(
715+
[asset, asset_original],
716+
None, fifo_mode=True,
717+
result_store=None
718+
)
719+
self.fextractor.run(parallelize=True)
720+
721+
results = self.fextractor.results
722+
723+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 32.579806240311484, places=4)
724+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 39.046949448281005, places=4)
725+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.288953934756215, places=4)
726+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 108.0, places=4)
727+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 108.0, places=4)
728+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 108.0, places=4)
729+
730+
def test_run_pypsnr_fextractor_16bit_custom_max_db(self):
731+
732+
ref_path, dis_path, asset, asset_original = set_default_576_324_16bit_videos_for_testing()
733+
734+
self.fextractor = PypsnrFeatureExtractor(
735+
[asset, asset_original],
736+
None, fifo_mode=True,
737+
result_store=None,
738+
optional_dict={'max_db': 100.0}
739+
)
740+
self.fextractor.run(parallelize=True)
741+
742+
results = self.fextractor.results
743+
744+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 32.579806240311484, places=4)
745+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 39.046949448281005, places=4)
746+
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.288953934756215, places=4)
747+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 100.0, places=4)
748+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 100.0, places=4)
749+
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 100.0, places=4)
750+
751+
def test_run_pypsnr_fextractor_maxdb100_16bit(self):
752+
753+
ref_path, dis_path, asset, asset_original = set_default_576_324_16bit_videos_for_testing()
754+
755+
self.fextractor = PypsnrMaxdb100FeatureExtractor(
756+
[asset, asset_original],
757+
None, fifo_mode=True,
758+
result_store=None,
759+
)
760+
self.fextractor.run(parallelize=True)
761+
762+
results = self.fextractor.results
763+
764+
self.assertAlmostEqual(results[0]['Pypsnr_maxdb100_feature_psnry_score'], 32.579806240311484, places=4)
765+
self.assertAlmostEqual(results[0]['Pypsnr_maxdb100_feature_psnru_score'], 39.046949448281005, places=4)
766+
self.assertAlmostEqual(results[0]['Pypsnr_maxdb100_feature_psnrv_score'], 41.288953934756215, places=4)
767+
self.assertAlmostEqual(results[1]['Pypsnr_maxdb100_feature_psnry_score'], 100.0, places=4)
768+
self.assertAlmostEqual(results[1]['Pypsnr_maxdb100_feature_psnru_score'], 100.0, places=4)
769+
self.assertAlmostEqual(results[1]['Pypsnr_maxdb100_feature_psnrv_score'], 100.0, places=4)
770+
629771

630772
if __name__ == '__main__':
631773
unittest.main(verbosity=2)

python/vmaf/core/feature_extractor.py

+111
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,117 @@ def _post_process_result(cls, result):
436436
return result
437437

438438

439+
class PypsnrFeatureExtractor(FeatureExtractor):
440+
441+
TYPE = "Pypsnr_feature"
442+
VERSION = "1.0"
443+
444+
ATOM_FEATURES = ['psnry', 'psnru', 'psnrv']
445+
446+
@staticmethod
447+
def _assert_bit_depth(ref_yuv_reader, dis_yuv_reader):
448+
if ref_yuv_reader._is_8bit():
449+
assert dis_yuv_reader._is_8bit()
450+
elif ref_yuv_reader._is_10bitle():
451+
assert dis_yuv_reader._is_10bitle()
452+
elif ref_yuv_reader._is_12bitle():
453+
assert dis_yuv_reader._is_12bitle()
454+
elif ref_yuv_reader._is_16bitle():
455+
assert dis_yuv_reader._is_16bitle()
456+
else:
457+
assert False, 'unknown bit depth and type'
458+
459+
def _get_max_db(self, ref_yuv_reader):
460+
if self.optional_dict is not None and 'max_db' in self.optional_dict:
461+
assert type(self.optional_dict['max_db']) == int or float
462+
return self.optional_dict['max_db']
463+
elif ref_yuv_reader._is_8bit():
464+
return 60.0
465+
elif ref_yuv_reader._is_10bitle():
466+
return 72.0
467+
elif ref_yuv_reader._is_12bitle():
468+
return 84.0
469+
elif ref_yuv_reader._is_16bitle():
470+
return 108.0
471+
else:
472+
assert False, 'unknown bit depth and type'
473+
474+
def _generate_result(self, asset):
475+
quality_w, quality_h = asset.quality_width_height
476+
yuv_type = self._get_workfile_yuv_type(asset)
477+
log_dicts = list()
478+
with YuvReader(filepath=asset.ref_procfile_path, width=quality_w, height=quality_h,
479+
yuv_type=yuv_type) as ref_yuv_reader:
480+
with YuvReader(filepath=asset.dis_procfile_path, width=quality_w, height=quality_h,
481+
yuv_type=yuv_type) as dis_yuv_reader:
482+
483+
self._assert_bit_depth(ref_yuv_reader, dis_yuv_reader)
484+
max_db = self._get_max_db(ref_yuv_reader)
485+
486+
frm = 0
487+
while True:
488+
try:
489+
ref_yuv = ref_yuv_reader.next(format='float')
490+
dis_yuv = dis_yuv_reader.next(format='float')
491+
except StopIteration:
492+
break
493+
494+
ref_y, ref_u, ref_v = ref_yuv
495+
dis_y, dis_u, dis_v = dis_yuv
496+
mse_y, mse_u, mse_v = np.mean((ref_y - dis_y)**2) + 1e-16, \
497+
np.mean((ref_u - dis_u)**2) + 1e-16, \
498+
np.mean((ref_v - dis_v)**2) + 1e-16
499+
psnr_y, psnr_u, psnr_v = min(10 * np.log10(1.0 / mse_y), max_db), \
500+
min(10 * np.log10(1.0 / mse_u), max_db), \
501+
min(10 * np.log10(1.0 / mse_v), max_db)
502+
503+
log_dicts.append({
504+
'frame': frm,
505+
'psnry': psnr_y,
506+
'psnru': psnr_u,
507+
'psnrv': psnr_v,
508+
})
509+
510+
frm += 1
511+
512+
log_file_path = self._get_log_file_path(asset)
513+
with open(log_file_path, 'wt') as log_file:
514+
log_file.write(str(log_dicts))
515+
516+
@override(FeatureExtractor)
517+
def _get_feature_scores(self, asset):
518+
519+
log_file_path = self._get_log_file_path(asset)
520+
521+
with open(log_file_path, 'rt') as log_file:
522+
log_str = log_file.read()
523+
log_dicts = ast.literal_eval(log_str)
524+
525+
feature_result = dict()
526+
frm = 0
527+
for log_dict in log_dicts:
528+
assert frm == log_dict['frame']
529+
for ft in self.ATOM_FEATURES:
530+
feature_result.setdefault(self.get_scores_key(ft), []).append(log_dict[ft])
531+
frm += 1
532+
533+
return feature_result
534+
535+
536+
class PypsnrMaxdb100FeatureExtractor(PypsnrFeatureExtractor):
537+
538+
TYPE = "Pypsnr_maxdb100_feature"
539+
540+
@override(Executor)
541+
def _custom_init(self):
542+
super()._custom_init()
543+
if self.optional_dict is not None:
544+
assert 'max_db' not in self.optional_dict
545+
if self.optional_dict is None:
546+
self.optional_dict = dict()
547+
self.optional_dict['max_db'] = 100.0
548+
549+
439550
class PsnrFeatureExtractor(VmafexecFeatureExtractorMixin, FeatureExtractor):
440551

441552
TYPE = "PSNR_feature"

0 commit comments

Comments
 (0)