Path: blob/master/elisp/emacs-for-python/rope-dist/rope/refactor/extract.py
1415 views
import re

from rope.base import ast, codeanalyze
from rope.base.change import ChangeSet, ChangeContents
from rope.base.exceptions import RefactoringError
from rope.refactor import (sourceutils, similarfinder,
                           patchedast, suites, usefunction)


# Extract refactoring has lots of special cases.  I tried to split it
# to smaller parts to make it more manageable:
#
# _ExtractInfo: holds information about the refactoring; it is passed
#   to the parts that need to have information about the refactoring
#
# _ExtractCollector: merely saves all of the information necessary for
#   performing the refactoring.
#
# _DefinitionLocationFinder: finds where to insert the definition.
#
# _ExceptionalConditionChecker: checks for exceptional conditions in
#   which the refactoring cannot be applied.
#
# _ExtractMethodParts: generates the pieces of code (like definition)
#   needed for performing extract method.
#
# _ExtractVariableParts: like _ExtractMethodParts for variables.
#
# _ExtractPerformer: Uses above classes to collect refactoring
#   changes.
#
# There are a few more helper functions and classes used by above
# classes.
class _ExtractRefactoring(object):
    """Common base of `ExtractMethod` and `ExtractVariable`.

    Subclasses provide a ``kind`` class attribute ('method' or
    'variable'); `get_changes` uses it for the change description and
    to decide whether a variable is being extracted.
    """

    def __init__(self, project, resource, start_offset, end_offset,
                 variable=False):
        """Remember the target and trim whitespace off the selection.

        `variable` is accepted for interface compatibility but is not
        stored; `get_changes` relies on ``self.kind`` instead.
        """
        self.project = project
        self.pycore = project.pycore
        self.resource = resource
        # Read the resource once and reuse it for both boundaries.
        source = resource.read()
        self.start_offset = self._fix_start(source, start_offset)
        self.end_offset = self._fix_end(source, end_offset)

    def _fix_start(self, source, offset):
        """Move `offset` forward past any whitespace characters"""
        while offset < len(source) and source[offset].isspace():
            offset += 1
        return offset

    def _fix_end(self, source, offset):
        """Move `offset` backward past any whitespace characters"""
        while offset > 0 and source[offset - 1].isspace():
            offset -= 1
        return offset

    def get_changes(self, extracted_name, similar=False, global_=False):
        """Get the changes this refactoring makes

        :parameters:
            - `similar`: if `True`, similar expressions/statements are also
              replaced.
            - `global_`: if `True`, the extracted method/variable will
              be global.

        """
        info = _ExtractInfo(
            self.project, self.resource, self.start_offset, self.end_offset,
            extracted_name, variable=self.kind == 'variable',
            similar=similar, make_global=global_)
        new_contents = _ExtractPerformer(info).extract()
        changes = ChangeSet('Extract %s <%s>' % (self.kind,
                                                 extracted_name))
        changes.add_change(ChangeContents(self.resource, new_contents))
        return changes


class ExtractMethod(_ExtractRefactoring):
    """Extract refactoring whose ``kind`` is 'method'"""

    def __init__(self, *args, **kwds):
        super(ExtractMethod, self).__init__(*args, **kwds)

    kind = 'method'


class ExtractVariable(_ExtractRefactoring):
    """Extract refactoring whose ``kind`` is 'variable'"""

    def __init__(self, *args, **kwds):
        # Copy so the caller's keyword dict is not modified.
        kwds = dict(kwds)
        kwds['variable'] = True
        super(ExtractVariable, self).__init__(*args, **kwds)

    kind = 'variable'


class _ExtractInfo(object):
    """Holds information about the extract to be performed"""

    def __init__(self, project, resource, start, end, new_name,
                 variable, similar, make_global):
        self.pycore = project.pycore
        self.resource = resource
        self.pymodule = self.pycore.resource_to_pyobject(resource)
        self.global_scope = self.pymodule.get_scope()
        self.source = self.pymodule.source_code
        self.lines = self.pymodule.lines
        self.new_name = new_name
        self.variable = variable
        self.similar = similar
        # `_init_scope` reads `region_lines`, which `_init_parts` sets,
        # so the order of these two calls matters.
        self._init_parts(start, end)
        self._init_scope()
        self.make_global = make_global

    def _init_parts(self, start, end):
        """Compute `region`, `region_lines` and `lines_region`

        `region` is the (adjusted) selected offset range,
        `region_lines` the logical-line range containing it and
        `lines_region` the offset range of those whole lines.
        """
        self.region = (self._choose_closest_line_end(start),
                       self._choose_closest_line_end(end, end=True))

        start = self.logical_lines.logical_line_in(
            self.lines.get_line_number(self.region[0]))[0]
        end = self.logical_lines.logical_line_in(
            self.lines.get_line_number(self.region[1]))[1]
        self.region_lines = (start, end)

        self.lines_region = (self.lines.get_line_start(self.region_lines[0]),
                             self.lines.get_line_end(self.region_lines[1]))

    @property
    def logical_lines(self):
        return self.pymodule.logical_lines

    def _init_scope(self):
        """Find the scope holding the region and its offset range"""
        start_line = self.region_lines[0]
        scope = self.global_scope.get_inner_scope_for_line(start_line)
        # A region starting on the first line of a non-module scope is
        # attributed to the enclosing scope.
        if scope.get_kind() != 'Module' and scope.get_start() == start_line:
            scope = scope.parent
        self.scope = scope
        self.scope_region = self._get_scope_region(self.scope)

    def _get_scope_region(self, scope):
        """Return the (start, end) offsets spanned by `scope`"""
        return (self.lines.get_line_start(scope.get_start()),
                self.lines.get_line_end(scope.get_end()) + 1)

    def _choose_closest_line_end(self, offset, end=False):
        """Snap `offset` to a line boundary when only blanks separate them

        If only whitespace lies between the start of the line and
        `offset`, return the line start (or the offset just before it
        when `end` is true).  If only whitespace lies between `offset`
        and the line end, return the line end.  Otherwise return
        `offset` unchanged.
        """
        lineno = self.lines.get_line_number(offset)
        line_start = self.lines.get_line_start(lineno)
        line_end = self.lines.get_line_end(lineno)
        if self.source[line_start:offset].strip() == '':
            if end:
                return line_start - 1
            else:
                return line_start
        elif self.source[offset:line_end].strip() == '':
            return min(line_end, len(self.source))
        return offset

    @property
    def one_line(self):
        """Is the region a part of a single logical line"""
        return self.region != self.lines_region and \
            (self.logical_lines.logical_line_in(self.region_lines[0]) ==
             self.logical_lines.logical_line_in(self.region_lines[1]))

    @property
    def global_(self):
        """Is the containing scope the one without a parent"""
        return self.scope.parent is None

    @property
    def method(self):
        """Is the containing scope directly inside a class scope"""
        return self.scope.parent is not None and \
            self.scope.parent.get_kind() == 'Class'

    @property
    def indents(self):
        return sourceutils.get_indents(self.pymodule.lines,
                                       self.region_lines[0])

    @property
    def scope_indents(self):
        if self.global_:
            return 0
        return sourceutils.get_indents(self.pymodule.lines,
                                       self.scope.get_start())

    @property
    def extracted(self):
        """The source text of the selected region"""
        return self.source[self.region[0]:self.region[1]]

    # Cache for `returned`; None means "not computed yet".
    _returned = None

    @property
    def returned(self):
        """Does the extracted piece contain return statement"""
        if self._returned is None:
            node = _parse_text(self.extracted)
            self._returned = usefunction._returns_last(node)
        return self._returned


class _ExtractCollector(object):
    """Collects information needed for performing the extract"""

    def __init__(self, info):
        # `info` is not stored; the fields below are filled in later
        # by `_ExtractPerformer`.
        self.definition = None
        self.body_pattern = None
        self.checks = {}
        self.replacement_pattern = None
        self.matches = None
        self.replacements = None
        self.definition_location = None


class _ExtractPerformer(object):
    """Builds the new source for an extract described by `_ExtractInfo`"""

    def __init__(self, info):
        self.info = info
        # Raises RefactoringError when the region cannot be extracted.
        _ExceptionalConditionChecker()(self.info)

    def extract(self):
        """Return the changed source: definition added, matches replaced"""
        extract_info = self._collect_info()
        content = codeanalyze.ChangeCollector(self.info.source)
        definition = extract_info.definition
        lineno, indents = extract_info.definition_location
        offset = self.info.lines.get_line_start(lineno)
        indented = sourceutils.fix_indentation(definition, indents)
        content.add_change(offset, offset, indented)
        self._replace_occurrences(content, extract_info)
        return content.get_changed()

    def _replace_occurrences(self, content, extract_info):
        """Replace every match with the filled-in replacement pattern"""
        for match in extract_info.matches:
            replacement = similarfinder.CodeTemplate(
                extract_info.replacement_pattern)
            mapping = {}
            for name in replacement.get_names():
                node = match.get_ast(name)
                if node:
                    # Use the text this particular match has for `name`.
                    start, end = patchedast.node_region(match.get_ast(name))
                    mapping[name] = self.info.source[start:end]
                else:
                    mapping[name] = name
            region = match.get_region()
            content.add_change(region[0], region[1],
                               replacement.substitute(mapping))

    def _collect_info(self):
        """Fill and return an `_ExtractCollector`"""
        extract_collector = _ExtractCollector(self.info)
        # Matches need the body pattern and the definition location
        # needs the matches, hence this order.
        self._find_definition(extract_collector)
        self._find_matches(extract_collector)
        self._find_definition_location(extract_collector)
        return extract_collector

    def _find_matches(self, collector):
        regions = self._where_to_search()
        finder = similarfinder.SimilarFinder(self.info.pymodule)
        matches = []
        for start, end in regions:
            matches.extend((finder.get_matches(collector.body_pattern,
                                               collector.checks, start, end)))
        collector.matches = matches

    def _where_to_search(self):
        """Return the list of (start, end) regions to look for matches in"""
        if self.info.similar:
            if self.info.make_global or self.info.global_:
                return [(0, len(self.info.pymodule.source_code))]
            if self.info.method and not self.info.variable:
                class_scope = self.info.scope.parent
                regions = []
                method_kind = _get_function_kind(self.info.scope)
                for scope in class_scope.get_scopes():
                    # When extracting from a 'method', only scopes of
                    # kind 'method' are searched.
                    if method_kind == 'method' and \
                       _get_function_kind(scope) != 'method':
                        continue
                    start = self.info.lines.get_line_start(scope.get_start())
                    end = self.info.lines.get_line_end(scope.get_end())
                    regions.append((start, end))
                return regions
            else:
                if self.info.variable:
                    return [self.info.scope_region]
                else:
                    return [self.info._get_scope_region(
                        self.info.scope.parent)]
        else:
            return [self.info.region]

    def _find_definition_location(self, collector):
        matched_lines = []
        for match in collector.matches:
            start = self.info.lines.get_line_number(match.get_region()[0])
            start_line = self.info.logical_lines.logical_line_in(start)[0]
            matched_lines.append(start_line)
        location_finder = _DefinitionLocationFinder(self.info, matched_lines)
        collector.definition_location = (location_finder.find_lineno(),
                                         location_finder.find_indents())

    def _find_definition(self, collector):
        if self.info.variable:
            parts = _ExtractVariableParts(self.info)
        else:
            parts = _ExtractMethodParts(self.info)
        collector.definition = parts.get_definition()
        collector.body_pattern = parts.get_body_pattern()
        collector.replacement_pattern = parts.get_replacement_pattern()
        collector.checks = parts.get_checks()


class _DefinitionLocationFinder(object):
    """Finds the line and indentation at which to insert the definition"""

    def __init__(self, info, matched_lines):
        self.info = info
        self.matched_lines = matched_lines
        # This only happens when subexpressions cannot be matched
        if not matched_lines:
            self.matched_lines.append(self.info.region_lines[0])

    def find_lineno(self):
        if self.info.variable and not self.info.make_global:
            return self._get_before_line()
        if self.info.make_global or self.info.global_:
            toplevel = self._find_toplevel(self.info.scope)
            ast = self.info.pymodule.get_ast()
            newlines = sorted(self.matched_lines + [toplevel.get_end() + 1])
            return suites.find_visible(ast, newlines)
        return self._get_after_scope()

    def _find_toplevel(self, scope):
        """Return the ancestor of `scope` whose parent has no parent

        `scope` itself is returned when it has no parent.
        """
        toplevel = scope
        if toplevel.parent is not None:
            while toplevel.parent.parent is not None:
                toplevel = toplevel.parent
        return toplevel

    def find_indents(self):
        if self.info.variable and not self.info.make_global:
            return sourceutils.get_indents(self.info.lines,
                                           self._get_before_line())
        else:
            if self.info.global_ or self.info.make_global:
                return 0
        return self.info.scope_indents

    def _get_before_line(self):
        ast = self.info.scope.pyobject.get_ast()
        return suites.find_visible(ast, self.matched_lines)

    def _get_after_scope(self):
        return self.info.scope.get_end() + 1


class _ExceptionalConditionChecker(object):
    """Raises `RefactoringError` for regions that cannot be extracted"""

    def __call__(self, info):
        self.base_conditions(info)
        if info.one_line:
            self.one_line_conditions(info)
        else:
            self.multi_line_conditions(info)

    def base_conditions(self, info):
        if info.region[1] > info.scope_region[1]:
            raise RefactoringError('Bad region selected for extract method')
        end_line = info.region_lines[1]
        end_scope = info.global_scope.get_inner_scope_for_line(end_line)
        if end_scope != info.scope and end_scope.get_end() != end_line:
            raise RefactoringError('Bad region selected for extract method')
        try:
            extracted = info.source[info.region[0]:info.region[1]]
            if info.one_line:
                # Parenthesize so the selected text is parsed as an
                # expression.
                extracted = '(%s)' % extracted
            if _UnmatchedBreakOrContinueFinder.has_errors(extracted):
                raise RefactoringError('A break/continue without having a '
                                       'matching for/while loop.')
        except SyntaxError:
            raise RefactoringError('Extracted piece should '
                                   'contain complete statements.')

    def one_line_conditions(self, info):
        if self._is_region_on_a_word(info):
            raise RefactoringError('Should extract complete statements.')
        # NOTE(review): `__call__` only invokes this method when
        # `info.one_line` is true, so this check looks unreachable;
        # confirm whether it was meant for `multi_line_conditions`.
        if info.variable and not info.one_line:
            raise RefactoringError('Extract variable should not '
                                   'span multiple lines.')

    def multi_line_conditions(self, info):
        node = _parse_text(info.source[info.region[0]:info.region[1]])
        count = usefunction._return_count(node)
        if count > 1:
            raise RefactoringError('Extracted piece can have only one '
                                   'return statement.')
        if usefunction._yield_count(node):
            raise RefactoringError('Extracted piece cannot '
                                   'have yield statements.')
        if count == 1 and not usefunction._returns_last(node):
            raise RefactoringError('Return should be the last statement.')
        if info.region != info.lines_region:
            raise RefactoringError('Extracted piece should '
                                   'contain complete statements.')

    def _is_region_on_a_word(self, info):
        """Does either end of the region fall inside an identifier"""
        start, end = info.region
        if start > 0 and self._is_on_a_word(info, start - 1):
            return True
        return self._is_on_a_word(info, end - 1)

    def _is_on_a_word(self, info, offset):
        """Are the characters at `offset` and `offset + 1` both word chars

        Returns `False` when `offset` is the last character of the
        source.
        """
        prev = info.source[offset]
        if not (prev.isalnum() or prev == '_') or \
           offset + 1 == len(info.source):
            return False
        next = info.source[offset + 1]
        return next.isalnum() or next == '_'


class _ExtractMethodParts(object):
    """Generates the definition, call and patterns for extract method"""

    def __init__(self, info):
        self.info = info
        self.info_collector = self._create_info_collector()

    def get_definition(self):
        if self.info.global_:
            return '\n%s\n' % self._get_function_definition()
        else:
            return '\n%s' % self._get_function_definition()

    def get_replacement_pattern(self):
        variables = []
        variables.extend(self._find_function_arguments())
        variables.extend(self._find_function_returns())
        return similarfinder.make_pattern(self._get_call(), variables)

    def get_body_pattern(self):
        variables = []
        variables.extend(self._find_function_arguments())
        variables.extend(self._find_function_returns())
        variables.extend(self._find_temps())
        return similarfinder.make_pattern(self._get_body(), variables)

    def _get_body(self):
        result = sourceutils.fix_indentation(self.info.extracted, 0)
        if self.info.one_line:
            result = '(%s)' % result
        return result

    def _find_temps(self):
        return usefunction.find_temps(self.info.pycore.project,
                                      self._get_body())

    def get_checks(self):
        """Return the checks passed to the similar finder

        When extracting from a 'method' (and not making the result
        global), the self name is required to have the class's type.
        """
        if self.info.method and not self.info.make_global:
            if _get_function_kind(self.info.scope) == 'method':
                class_name = similarfinder._pydefined_to_str(
                    self.info.scope.parent.pyobject)
                return {self._get_self_name(): 'type=' + class_name}
        return {}

    def _create_info_collector(self):
        """Walk the scope's source with a `_FunctionInformationCollector`"""
        # The scope is parsed on its own, so region line numbers are
        # shifted to be relative to the scope's first line.
        zero = self.info.scope.get_start() - 1
        start_line = self.info.region_lines[0] - zero
        end_line = self.info.region_lines[1] - zero
        info_collector = _FunctionInformationCollector(start_line, end_line,
                                                       self.info.global_)
        body = self.info.source[self.info.scope_region[0]:
                                self.info.scope_region[1]]
        node = _parse_text(body)
        ast.walk(node, info_collector)
        return info_collector

    def _get_function_definition(self):
        args = self._find_function_arguments()
        returns = self._find_function_returns()
        result = []
        if self.info.method and not self.info.make_global and \
           _get_function_kind(self.info.scope) != 'method':
            result.append('@staticmethod\n')
        result.append('def %s:\n' % self._get_function_signature(args))
        unindented_body = self._get_unindented_function_body(returns)
        indents = sourceutils.get_indent(self.info.pycore)
        function_body = sourceutils.indent_lines(unindented_body, indents)
        result.append(function_body)
        definition = ''.join(result)

        return definition + '\n'

    def _get_function_signature(self, args):
        """Return ``name(arg, ...)`` for the extracted function

        Raises `RefactoringError` when a method is being extracted
        from a function that has no parameters.
        """
        args = list(args)
        prefix = ''
        if self._extracting_method():
            self_name = self._get_self_name()
            if self_name is None:
                raise RefactoringError('Extracting a method from a function '
                                       'with no self argument.')
            # The self name always goes first, exactly once.
            if self_name in args:
                args.remove(self_name)
            args.insert(0, self_name)
        return prefix + self.info.new_name + \
            '(%s)' % self._get_comma_form(args)

    def _extracting_method(self):
        return self.info.method and not self.info.make_global and \
            _get_function_kind(self.info.scope) == 'method'

    def _get_self_name(self):
        """Return the first parameter name of the scope, or `None`"""
        param_names = self.info.scope.pyobject.get_param_names()
        if param_names:
            return param_names[0]

    def _get_function_call(self, args):
        """Return the call expression that replaces the extracted code"""
        # Work on a copy so the caller's sequence is never modified.
        args = list(args)
        prefix = ''
        if self.info.method and not self.info.make_global:
            if _get_function_kind(self.info.scope) == 'method':
                self_name = self._get_self_name()
                if self_name in args:
                    args.remove(self_name)
                prefix = self_name + '.'
            else:
                prefix = self.info.scope.parent.pyobject.get_name() + '.'
        return prefix + '%s(%s)' % (self.info.new_name,
                                    self._get_comma_form(args))

    def _get_comma_form(self, names):
        """Join `names` with ', '"""
        return ', '.join(names)

    def _get_call(self):
        if self.info.one_line:
            args = self._find_function_arguments()
            return self._get_function_call(args)
        args = self._find_function_arguments()
        returns = self._find_function_returns()
        call_prefix = ''
        if returns:
            call_prefix = self._get_comma_form(returns) + ' = '
        if self.info.returned:
            call_prefix = 'return '
        return call_prefix + self._get_function_call(args)

    def _find_function_arguments(self):
        # if not make_global, do not pass any global names; they are
        # all visible.
        if self.info.global_ and not self.info.make_global:
            return ()
        if not self.info.one_line:
            result = (self.info_collector.prewritten &
                      self.info_collector.read)
            result |= (self.info_collector.prewritten &
                       self.info_collector.postread &
                       (self.info_collector.maybe_written -
                        self.info_collector.written))
            return list(result)
        start = self.info.region[0]
        if start == self.info.lines_region[0]:
            # Skip the leading whitespace of the line.
            start = start + re.search(r'\S', self.info.extracted).start()
        function_definition = self.info.source[start:self.info.region[1]]
        read = _VariableReadsAndWritesFinder.find_reads_for_one_liners(
            function_definition)
        return list(self.info_collector.prewritten.intersection(read))

    def _find_function_returns(self):
        if self.info.one_line or self.info.returned:
            return []
        written = self.info_collector.written | \
            self.info_collector.maybe_written
        return list(written & self.info_collector.postread)

    def _get_unindented_function_body(self, returns):
        if self.info.one_line:
            return 'return ' + _join_lines(self.info.extracted)
        extracted_body = self.info.extracted
        unindented_body = sourceutils.fix_indentation(extracted_body, 0)
        if returns:
            unindented_body += '\nreturn %s' % self._get_comma_form(returns)
        return unindented_body


class _ExtractVariableParts(object):
    """Like `_ExtractMethodParts`, for extract variable"""

    def __init__(self, info):
        self.info = info

    def get_definition(self):
        result = self.info.new_name + ' = ' + \
            _join_lines(self.info.extracted) + '\n'
        return result

    def get_body_pattern(self):
        return '(%s)' % self.info.extracted.strip()

    def get_replacement_pattern(self):
        return self.info.new_name

    def get_checks(self):
        return {}


class _FunctionInformationCollector(object):
    """Visitor recording names read/written relative to a line range

    `start` and `end` are the first and last line of the range.  Names
    are sorted into sets by where they occur: before the range
    (`prewritten`), inside it (`read`, `written`, `maybe_written`) or
    after it (`postread`, `postwritten`).
    """

    def __init__(self, start, end, is_global):
        self.start = start
        self.end = end
        self.is_global = is_global
        self.prewritten = set()
        self.maybe_written = set()
        self.written = set()
        self.read = set()
        self.postread = set()
        self.postwritten = set()
        # True until the first function definition has been entered.
        self.host_function = True
        # True while inside an if/while/for node.
        self.conditional = False

    def _read_variable(self, name, lineno):
        if self.start <= lineno <= self.end:
            if name not in self.written:
                self.read.add(name)
        if self.end < lineno:
            if name not in self.postwritten:
                self.postread.add(name)

    def _written_variable(self, name, lineno):
        if self.start <= lineno <= self.end:
            if self.conditional:
                self.maybe_written.add(name)
            else:
                self.written.add(name)
        if self.start > lineno:
            self.prewritten.add(name)
        if self.end < lineno:
            self.postwritten.add(name)

    def _FunctionDef(self, node):
        if not self.is_global and self.host_function:
            # First function seen in a non-global extract: its
            # parameters count as writes and its body is walked with
            # this collector.
            self.host_function = False
            for name in _get_argnames(node.args):
                self._written_variable(name, node.lineno)
            for child in node.body:
                ast.walk(child, self)
        else:
            # Any other function: record its name as written and the
            # names it reads without writing as reads at its line.
            self._written_variable(node.name, node.lineno)
            visitor = _VariableReadsAndWritesFinder()
            for child in node.body:
                ast.walk(child, visitor)
            for name in visitor.read - visitor.written:
                self._read_variable(name, node.lineno)

    def _Name(self, node):
        if isinstance(node.ctx, (ast.Store, ast.AugStore)):
            self._written_variable(node.id, node.lineno)
        if not isinstance(node.ctx, ast.Store):
            self._read_variable(node.id, node.lineno)

    def _Assign(self, node):
        # Visit the value before the targets so reads on the right
        # hand side are seen before the writes on the left.
        ast.walk(node.value, self)
        for child in node.targets:
            ast.walk(child, self)

    def _ClassDef(self, node):
        self._written_variable(node.name, node.lineno)

    def _handle_conditional_node(self, node):
        self.conditional = True
        try:
            for child in ast.get_child_nodes(node):
                ast.walk(child, self)
        finally:
            self.conditional = False

    def _If(self, node):
        self._handle_conditional_node(node)

    def _While(self, node):
        self._handle_conditional_node(node)

    def _For(self, node):
        self._handle_conditional_node(node)


def _get_argnames(arguments):
    """Return the parameter names found in an ``arguments`` node

    Only positional parameters that are `ast.Name` nodes are included,
    followed by the vararg and kwarg entries when present.
    """
    result = [node.id for node in arguments.args
              if isinstance(node, ast.Name)]
    if arguments.vararg:
        result.append(arguments.vararg)
    if arguments.kwarg:
        result.append(arguments.kwarg)
    return result


class _VariableReadsAndWritesFinder(object):
    """Visitor collecting the names read and written in a piece of code"""

    def __init__(self):
        self.written = set()
        self.read = set()

    def _Name(self, node):
        if isinstance(node.ctx, (ast.Store, ast.AugStore)):
            self.written.add(node.id)
        # NOTE(review): this tests `node`, not `node.ctx`, so the name
        # is presumably always recorded as read; callers in this file
        # that subtract `written` from `read` are unaffected.  Confirm
        # the intent before changing it.
        if not isinstance(node, ast.Store):
            self.read.add(node.id)

    def _FunctionDef(self, node):
        self.written.add(node.name)
        visitor = _VariableReadsAndWritesFinder()
        for child in ast.get_child_nodes(node):
            ast.walk(child, visitor)
        self.read.update(visitor.read - visitor.written)

    # NOTE(review): the other visitors in this file name their class
    # handler `_ClassDef`; verify whether `_Class` is ever dispatched.
    def _Class(self, node):
        self.written.add(node.name)

    @staticmethod
    def find_reads_and_writes(code):
        """Return the ``(read, written)`` name sets of `code`"""
        if code.strip() == '':
            return set(), set()
        # NOTE(review): `unicode` is a Python 2 builtin.
        if isinstance(code, unicode):
            code = code.encode('utf-8')
        node = _parse_text(code)
        visitor = _VariableReadsAndWritesFinder()
        ast.walk(node, visitor)
        return visitor.read, visitor.written

    @staticmethod
    def find_reads_for_one_liners(code):
        """Return the set of names read in `code`"""
        if code.strip() == '':
            # Always a single set, so the result can be passed to
            # `set.intersection` like the non-empty case.
            return set()
        node = _parse_text(code)
        visitor = _VariableReadsAndWritesFinder()
        ast.walk(node, visitor)
        return visitor.read


class _UnmatchedBreakOrContinueFinder(object):
    """Visitor flagging break/continue found outside a for/while body"""

    def __init__(self):
        self.error = False
        # Number of loop bodies currently being walked.
        self.loop_count = 0

    def _For(self, node):
        self.loop_encountered(node)

    def _While(self, node):
        self.loop_encountered(node)

    def loop_encountered(self, node):
        self.loop_count += 1
        for child in node.body:
            ast.walk(child, self)
        self.loop_count -= 1
        # The else clause is walked after leaving the loop body, so a
        # break/continue there is checked against the outer loops only.
        # `orelse` is a list of statements; walk them one by one.
        for child in node.orelse:
            ast.walk(child, self)

    def _Break(self, node):
        self.check_loop()

    def _Continue(self, node):
        self.check_loop()

    def check_loop(self):
        if self.loop_count < 1:
            self.error = True

    def _FunctionDef(self, node):
        # Nested definitions are not inspected.
        pass

    def _ClassDef(self, node):
        pass

    @staticmethod
    def has_errors(code):
        """Does `code` contain a break/continue without a loop around it

        Propagates the `SyntaxError` raised when `code` cannot be
        parsed.
        """
        if code.strip() == '':
            return False
        node = _parse_text(code)
        visitor = _UnmatchedBreakOrContinueFinder()
        ast.walk(node, visitor)
        return visitor.error


def _get_function_kind(scope):
    return scope.pyobject.get_kind()


def _parse_text(body):
    """Parse `body` after re-indenting it to column zero"""
    body = sourceutils.fix_indentation(body, 0)
    node = ast.parse(body)
    return node


def _join_lines(code):
    """Join the lines of `code` with single spaces

    Each line is stripped; a trailing backslash is dropped first.
    """
    lines = []
    for line in code.splitlines():
        if line.endswith('\\'):
            lines.append(line[:-1].strip())
        else:
            lines.append(line.strip())
    return ' '.join(lines)