diff --git a/util/test/rdtest/util.py b/util/test/rdtest/util.py index 368e54177..94aaeef4d 100644 --- a/util/test/rdtest/util.py +++ b/util/test/rdtest/util.py @@ -131,8 +131,16 @@ def png_compare(test_img: str, ref_img: str, tolerance: int = 2): test_reader = png.Reader(filename=test_img) ref_reader = png.Reader(filename=ref_img) - test_w, test_h, test_data, _ = test_reader.read() - ref_w, ref_h, ref_data, _ = ref_reader.read() + test_w, test_h, test_data, test_info = test_reader.read() + ref_w, ref_h, ref_data, ref_info = ref_reader.read() + + # lookup rgba data straight + rgba_get = lambda data, x: data[x] + # lookup rgb data and return 255 for alpha + rgb_get = lambda data, x: data[ (x >> 2)*3 + (x % 4) ] if (x % 4) < 3 else 255 + + test_get = (rgba_get if test_info['alpha'] else rgb_get) + ref_get = (rgba_get if ref_info['alpha'] else rgb_get) if test_w != ref_w or test_h != test_h: return False @@ -141,7 +149,8 @@ def png_compare(test_img: str, ref_img: str, tolerance: int = 2): diff_data = [] for test_row, ref_row in zip(test_data, ref_data): - diff = [min(255, abs(test_row[i] - ref_row[i])*4) for i in range(0, test_w*4)] + + diff = [min(255, abs(test_get(test_row, i) - ref_get(ref_row, i))*4) for i in range(0, test_w*4)] is_same = is_same and not any([d > tolerance*4 for d in diff])