Path: blob/main/utils/overwrite_expected_slice.py
1440 views
# coding=utf-81# Copyright 2023 The HuggingFace Inc. team.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14import argparse15from collections import defaultdict161718def overwrite_file(file, class_name, test_name, correct_line, done_test):19_id = f"{file}_{class_name}_{test_name}"20done_test[_id] += 12122with open(file, "r") as f:23lines = f.readlines()2425class_regex = f"class {class_name}("26test_regex = f"{4 * ' '}def {test_name}("27line_begin_regex = f"{8 * ' '}{correct_line.split()[0]}"28another_line_begin_regex = f"{16 * ' '}{correct_line.split()[0]}"29in_class = False30in_func = False31in_line = False32insert_line = False33count = 034spaces = 03536new_lines = []37for line in lines:38if line.startswith(class_regex):39in_class = True40elif in_class and line.startswith(test_regex):41in_func = True42elif in_class and in_func and (line.startswith(line_begin_regex) or line.startswith(another_line_begin_regex)):43spaces = len(line.split(correct_line.split()[0])[0])44count += 14546if count == done_test[_id]:47in_line = True4849if in_class and in_func and in_line:50if ")" not in line:51continue52else:53insert_line = True5455if in_class and in_func and in_line and insert_line:56new_lines.append(f"{spaces * ' '}{correct_line}")57in_class = in_func = in_line = insert_line = False58else:59new_lines.append(line)6061with open(file, "w") as f:62for line in new_lines:63f.write(line)646566def main(correct, fail=None):67if fail is not None:68with open(fail, "r") as f:69test_failures = set([l.strip() for l in f.readlines()])70else:71test_failures = None7273with open(correct, "r") as f:74correct_lines = f.readlines()7576done_tests = defaultdict(int)77for line in correct_lines:78file, class_name, test_name, correct_line = line.split(";")79if test_failures is None or "::".join([file, class_name, test_name]) in test_failures:80overwrite_file(file, class_name, test_name, correct_line, done_tests)818283if __name__ == "__main__":84parser = argparse.ArgumentParser()85parser.add_argument("--correct_filename", help="filename of tests with expected result")86parser.add_argument("--fail_filename", help="filename of test failures", type=str, default=None)87args = parser.parse_args()8889main(args.correct_filename, args.fail_filename)909192