AST
因為工作上的需要,我要寫一個 Python 的靜態分析程式,主要目的是:找出特定 function 可能 raise 的 Exception。經過同事的提醒與 Stack Overflow 上的提示,發現可以用 Python 內建的 ast、inspect 與 textwrap 模組完成:
#! /usr/bin/env python
import ast
import logging
from inspect import getsource, getclosurevars, isclass
from textwrap import dedent
from collections import ChainMap
class PyAnalysis(ast.NodeVisitor):
    """Collect the names of exceptions a callable may raise.

    The source of the target callable is parsed with :mod:`ast`. Every
    ``raise`` statement contributes the name of the raised exception. Every
    call whose target is a plain dotted name (``f()``, ``mod.f()``,
    ``a.b.c()``) is resolved against the callable's nonlocal/global/builtin
    variables and traced recursively, and its exception names are merged
    into the caller's result. Results are cached per callable in
    ``self.trace``.
    """

    def __init__(self, level=None, fmt=None):
        """Set up logging and empty analysis state.

        :param level: logging level; falls back to ``logging.INFO`` when falsy.
        :param fmt: logging format string; a default format is used when falsy.
        """
        level = level or logging.INFO
        formatter = logging.Formatter(
            fmt or r'[%(asctime)-.19s] %(pathname)s#L%(lineno)d : %(message)s'
        )
        syslog = logging.StreamHandler()
        syslog.setFormatter(formatter)
        logger = logging.getLogger(self.__class__.__name__)
        logger.setLevel(level)
        # The named logger is shared by every instance. A handler is attached
        # only once, so a later instance's *fmt* has no effect.
        if not logger.handlers:
            logger.addHandler(syslog)
        self.logger = logging.LoggerAdapter(
            logger,
            {
                'app_name': self.__class__.__name__,
            },
        )
        # callable -> list of exception names found for it (cache)
        self.trace = {}
        # Scratch state filled by visit(). It is initialised here so that a
        # first call with reset=False does not raise AttributeError.
        self.exceptions = []
        self.fn_called = []

    def __call__(self, obj, reset=True):
        """Return the exception names *obj* (and what it calls) may raise.

        :param obj: callable to analyse; its source must be retrievable with
            :func:`inspect.getsource`.
        :param reset: start from fresh scratch state (default). With
            ``reset=False``, names are appended to the ``self.exceptions``
            list left by the previous call.
        :returns: list of exception names. The callable's own ``raise``
            statements come first, in source order, followed by names
            contributed by resolvable callees that are not already listed.
            The same list object is cached in ``self.trace``.
        :raises NotImplementedError: if *obj* is not callable, or one of its
            ``raise`` statements uses an unsupported expression form.
        :raises TypeError, OSError: propagated from ``getsource`` /
            ``getclosurevars`` for the top-level *obj*.
        """
        if not callable(obj):
            raise NotImplementedError(f'cannot trace `{obj}`')
        if obj in self.trace:
            return self.trace[obj]

        name = getattr(obj, "__name__", "")
        self.logger.info(f'trace `{name}`')
        # get the source code for the target function
        src = dedent(getsource(obj))
        # nonlocals, then globals, then builtins referenced by obj
        var = ChainMap(*getclosurevars(obj)[:3])
        self.logger.debug(f'trace `{name}` : {", ".join(var.keys())}')

        if reset:
            self.exceptions = []
            self.fn_called = []
        first_call = len(self.fn_called)
        self.analysis(src, var)

        # Hold this frame's state in locals: the recursive self(...) calls
        # below rebind self.exceptions / self.fn_called. Only the calls found
        # in *this* source are resolved, because they belong to *var*.
        exceptions = self.exceptions
        called = self.fn_called
        # Cache before recursing so that (mutually) recursive callables
        # terminate. A cycle gets this in-progress list back from the cache,
        # so results for cycles are best effort.
        self.trace[obj] = exceptions

        for fn in called[first_call:]:
            path = ".".join(fn)
            try:
                callee = self._resolve(fn, var)
            except (KeyError, AttributeError):
                self.logger.warning(f'`{path}` not found in variable')
                continue
            if not callable(callee):
                self.logger.warning(f'`{path}` is not callable')
                continue
            try:
                callee_exceptions = self(callee, reset=True)
            except Exception as err:
                # Best effort: builtins, C extensions and unsupported syntax
                # cannot be traced, and must not abort the caller's analysis.
                self.logger.warning(f'`{path}` cannot be traced: {err!r}')
                continue
            # Iterate over a copy: in a self-recursive call, callee_exceptions
            # is the very list being extended.
            for exc in list(callee_exceptions):
                if exc not in exceptions:
                    exceptions.append(exc)

        # Leave this frame's state visible to callers that inspect it.
        self.exceptions = exceptions
        self.fn_called = called
        return exceptions

    @staticmethod
    def _resolve(path, scope):
        """Look up a dotted call target such as ``['os', 'path', 'join']``.

        The first name is read from *scope*. Each further element is an
        attribute of the previous object.

        :raises KeyError: if the first name is not in *scope*.
        :raises AttributeError: if a later attribute is missing.
        """
        head, *rest = path
        target = scope[head]
        for attr in rest:
            target = getattr(target, attr)
        return target

    @staticmethod
    def _exception_name(node):
        """Return the exception name used by a ``raise`` expression node.

        ``Err``, ``Err(...)``, ``mod.Err`` and ``mod.Err(...)`` all yield
        ``'Err'``.

        :raises NotImplementedError: for any other expression form.
        """
        while isinstance(node, ast.Call):
            node = node.func
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return node.attr
        raise NotImplementedError(f'{node} - {type(node)}')

    def visit(self, node):
        """Record raised exception names and called dotted names, then recurse.

        This overrides :meth:`ast.NodeVisitor.visit`. ``generic_visit`` calls
        back into this method for every child, so the whole tree is walked.
        """
        self.logger.debug(f'visit {node}')
        if isinstance(node, ast.Raise):
            if node.exc is None:
                # A bare ``raise`` re-raises the exception being handled, so
                # there is no new exception name to record.
                self.logger.debug('skip bare `raise`')
            else:
                self.exceptions.append(self._exception_name(node.exc))
        elif isinstance(node, ast.Call):
            # A call may raise whatever its target raises. Remember the dotted
            # path, outermost name first, so __call__ can trace it.
            parts = []
            target = node.func
            while isinstance(target, ast.Attribute):
                parts.append(target.attr)
                target = target.value
            # Only chains rooted in a plain name can be resolved from the
            # variables. ``f().g()`` or ``a[0].b()`` are skipped.
            if isinstance(target, ast.Name):
                parts.append(target.id)
                self.fn_called.append(parts[::-1])
        return super().visit(node)

    def analysis(self, src, var):
        """Parse *src* and walk it, filling self.exceptions / self.fn_called.

        *var* is accepted for interface compatibility; name resolution is
        done by __call__.
        """
        self.visit(ast.parse(src))
if __name__ == '__main__':
    # Demo: trace a function whose only statement is ``raise KeyError``;
    # the printed result is ['KeyError'].
    def Foo():
        raise KeyError

    agent = PyAnalysis()
    result = agent(Foo)
    print(result)