guest@blog.cmj.tw: ~/posts $

Python 靜態分析


AST

因為工作上的需要,所以要寫一個 Python 的靜態分析程式,主要的目的是找出特定 function 可能 raise 的 Exception。透過同事的提醒跟 Stack Overflow 上的提示,可以用 Python 內建的 ast、inspect 與 textwrap 完成:

#! /usr/bin/env python
import ast
import logging
from inspect import getsource, getclosurevars, isclass
from textwrap import dedent
from collections import ChainMap


class PyAnalysis(ast.NodeVisitor):
    """Statically collect the names of the exceptions a callable may raise.

    Calling an instance with a function parses that function's source,
    records the name used in every ``raise`` statement, and then follows the
    functions it calls (looked up through the function's closure, globals and
    builtins) to merge in the exceptions they raise as well.

    Attributes:
        logger: ``LoggerAdapter`` used for progress and diagnostic messages.
        trace: cache mapping each analysed callable to its exception names.
        exceptions: exception names collected by the analysis in progress.
        fn_called: call targets seen by the analysis in progress; each entry
            is a list of name parts, e.g. ``['os', 'path', 'join']``.
    """

    def __init__(self, level=None, fmt=None):
        """Configure logging and create the empty result containers.

        Args:
            level: logging level; a falsy value selects ``logging.INFO``.
            fmt: ``logging.Formatter`` format string; a falsy value selects
                the built-in default.
        """
        level = level or logging.INFO
        formatter = logging.Formatter(
            fmt or r'[%(asctime)-.19s] %(pathname)s#L%(lineno)d : %(message)s'
        )

        syslog = logging.StreamHandler()
        syslog.setFormatter(formatter)

        # The logger is shared by every instance of the class, so attach a
        # handler only once to avoid duplicated output.
        logger = logging.getLogger(self.__class__.__name__)
        logger.setLevel(level)
        if not logger.handlers:
            logger.addHandler(syslog)

        self.logger = logging.LoggerAdapter(
            logger,
            {
                'app_name': self.__class__.__name__,
            }
        )

        self.trace = {}
        # Created up front so that a first call with ``reset=False`` works.
        self.exceptions = []
        self.fn_called = []

    def __call__(self, obj, reset=True):
        """Return the names of the exceptions ``obj`` may raise.

        Args:
            obj: the callable to analyse; its source must be retrievable
                with ``inspect.getsource``.
            reset: when true (default) start from empty result lists; when
                false keep accumulating into the current ones.

        Returns:
            A list of exception names: the ones raised directly by ``obj``
            (in source order, duplicates kept) followed by the names found
            in the functions it calls that are not already listed.

        Raises:
            NotImplementedError: ``obj`` is not callable, or a ``raise``
                statement in ``obj`` uses an unsupported expression.
            OSError, TypeError: the source or the closure variables of
                ``obj`` cannot be retrieved.
        """
        if not callable(obj):
            raise NotImplementedError(f'cannot trace `{obj}`')
        elif obj in self.trace:
            return self.trace[obj]

        name = getattr(obj, '__name__', '')
        self.logger.info(f'trace `{name}`')
        # get the source code for the target function
        src = dedent(getsource(obj))
        # get the related variables: nonlocals, then globals, then builtins
        var = ChainMap(*getclosurevars(obj)[:3])
        self.logger.debug(f'trace `{name}` : {", ".join(var.keys())}')

        if reset:
            self.exceptions = []
            self.fn_called = []

        # Only the calls discovered by this analysis belong to `var`.
        first_new_call = len(self.fn_called)
        self.analysis(src, var)

        # Keep local references: the nested calls below rebind the attributes.
        exceptions, fn_called = self.exceptions, self.fn_called
        # Register the result before following the callees so that
        # (mutually) recursive functions hit the cache instead of looping.
        self.trace[obj] = exceptions

        for fn in fn_called[first_new_call:]:
            dotted = '.'.join(fn)

            try:
                new_obj = var[fn[0]]
                for fn_attr in fn[1:]:
                    new_obj = getattr(new_obj, fn_attr)
            except (KeyError, AttributeError):
                self.logger.warning(f'`{dotted}` not found in variable')
                continue

            if not callable(new_obj):
                self.logger.warning(f'`{dotted}` not found in variable')
                continue

            try:
                found = self(new_obj, reset=True)
            except Exception as e:
                # Best effort on purpose: builtins and C extensions have no
                # Python source, and a callee may use constructs this
                # visitor does not support. Skip it but say why.
                self.logger.warning(f'cannot trace `{dotted}` : {e!r}')
                continue

            # In-place so that the list stored in `self.trace` is updated too.
            exceptions += [e for e in found if e not in exceptions]

        self.exceptions, self.fn_called = exceptions, fn_called
        return self.trace[obj]

    def visit(self, node):
        """Record ``raise`` targets and call targets, then keep walking.

        Args:
            node: the AST node being visited.

        Returns:
            Whatever ``ast.NodeVisitor.visit`` returns for ``node``.

        Raises:
            NotImplementedError: a ``raise`` statement uses an expression
                that is neither a name, an attribute nor a call of those.
        """
        self.logger.debug(f'visit {node}')

        # catch the AST node you want
        if isinstance(node, ast.Raise):
            # get the exception raised; `raise Exc(...)` wraps it in a call
            n = node.exc
            while isinstance(n, ast.Call):
                n = n.func

            # A bare `raise` (exc is None) re-raises the active exception
            # and names nothing new, so there is nothing to record.
            if n is not None:
                if isinstance(n, ast.Name):
                    e = n.id
                elif isinstance(n, ast.Attribute):
                    e = n.attr
                elif isinstance(n, str):
                    e = str(n)
                else:
                    raise NotImplementedError(f'{n} - {type(n)}')
                self.exceptions.append(e)
        elif isinstance(node, ast.Call):
            # function call and may cause another exception
            n, v = node.func, []

            # `a.b.c()` nests as Attribute(Attribute(Name a, b), c), so the
            # parts are collected right-to-left and reversed afterwards.
            while isinstance(n, ast.Attribute):
                v.append(n.attr)
                n = n.value

            # Only paths rooted at a plain name can be resolved later on;
            # targets such as `f().g()` or `x[0]()` are skipped.
            if isinstance(n, ast.Name):
                v.append(n.id)
                self.fn_called.append(v[::-1])

        return super().visit(node)

    def analysis(self, src, var):
        """Parse ``src`` and walk its AST with this visitor.

        Args:
            src: dedented Python source code of the callable.
            var: mapping of the names visible to the callable; currently
                unused by the walk itself.
        """
        self.visit(ast.parse(src))

if __name__ == '__main__':
    # Minimal demo: analyse a function that raises a single exception.
    def Foo():
        raise KeyError

    analyzer = PyAnalysis()
    raised = analyzer(Foo)
    print(raised)