guest@blog.cmj.tw: ~/posts $

flake8-datetimez


主要是接手了一個之前的 Python 專案，其中用了不少即將被棄用 (deprecated) 的 datetime 用法。但是目前使用的 linter 都沒有提供 auto-fix 的功能，所以我就寫了一個 linter 來幫助我們檢查與修正這些問題。

#! /usr/bin/env python
import argparse
import ast
import sys
from pathlib import Path

import libcst as cst
from loguru import logger


class DateTimeTransformer(cst.CSTTransformer):
    """
    rewrite `datetime.utcnow()` calls into `datetime.now(UTC).replace(tzinfo=None)`
    and add `from datetime import UTC` to the module when a call was rewritten
    and no such import was seen
    """

    def __init__(self, /, *args, **kwargs):
        # set once at least one `utcnow()` call has been rewritten
        self._found_utcnow = False
        # set once a plain `from datetime import UTC` has been seen
        self._found_import = False

        super().__init__(*args, **kwargs)

    @staticmethod
    def _import_insert_index(statements) -> int:
        """
        return the index in the module body where a new import can be placed:
        after a leading docstring and after any `from __future__ import ...`
        lines, which both have to stay in front of every other statement
        """
        index = 0
        for position, statement in enumerate(statements):
            if not isinstance(statement, cst.SimpleStatementLine) or not statement.body:
                break

            first = statement.body[0]
            is_docstring = (
                position == 0
                and isinstance(first, cst.Expr)
                and isinstance(first.value, (cst.SimpleString, cst.ConcatenatedString))
            )
            is_future_import = (
                isinstance(first, cst.ImportFrom)
                and isinstance(first.module, cst.Name)
                and first.module.value == "__future__"
            )
            if not (is_docstring or is_future_import):
                break

            index = position + 1

        return index

    def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """
        insert `from datetime import UTC` when a call was rewritten and the
        module does not import `UTC` yet
        """
        if not self._found_utcnow or self._found_import:
            return updated_node

        logger.info("inserted missing `from datetime import UTC`")
        # only `UTC` is new to the module: the rewritten call reuses the caller
        # expression of the original `utcnow()` call, so no extra import of
        # `datetime` itself is required
        utc_import = cst.SimpleStatementLine(
            body=[
                cst.ImportFrom(
                    module=cst.Name("datetime"),
                    names=[cst.ImportAlias(name=cst.Name("UTC"))],
                ),
            ],
        )

        statements = list(updated_node.body)
        statements.insert(self._import_insert_index(statements), utc_import)
        return updated_node.with_changes(body=statements)

    def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """
        convert the `datetime.utcnow()` to `datetime.now(UTC).replace(tzinfo=None)`
        """
        func = updated_node.func
        if not isinstance(func, cst.Attribute) or func.attr.value != "utcnow":
            return updated_node

        # accept `datetime.utcnow()` and `<anything>.datetime.utcnow()` only
        caller = func.value
        is_plain = isinstance(caller, cst.Name) and caller.value == "datetime"
        is_qualified = isinstance(caller, cst.Attribute) and caller.attr.value == "datetime"
        if not (is_plain or is_qualified):
            logger.warning(f"unexpected caller {caller=}")
            return updated_node

        # log only once the replacement is certain to happen
        logger.info("found and replaced `datetime.utcnow()` with `datetime.now(UTC).replace(tzinfo=None)`")

        # <caller>.now(UTC)
        aware_now = cst.Call(
            func=cst.Attribute(caller, attr=cst.Name("now")),
            args=[cst.Arg(cst.Name("UTC"))],
        )
        self._found_utcnow = True

        # <caller>.now(UTC).replace(tzinfo=None)
        return cst.Call(
            func=cst.Attribute(value=aware_now, attr=cst.Name("replace")),
            args=[
                cst.Arg(
                    cst.Name("None"),
                    keyword=cst.Name("tzinfo"),
                ),
            ],
        )

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """
        remember whether the module already has `from datetime import UTC`
        """
        is_datetime_module = (
            not node.relative
            and isinstance(node.module, cst.Name)
            and node.module.value == "datetime"
        )
        # `from datetime import *` has no alias list to iterate; it is not
        # counted as an import of `UTC`, so the explicit import is still added
        if is_datetime_module and not isinstance(node.names, cst.ImportStar):
            # `UTC as something` does not bind the name `UTC`, so it does not count
            if any(alias.name.value == "UTC" and alias.asname is None for alias in node.names):
                self._found_import = True

        # keep visiting the children, as before
        return True


def _process_file(file: Path) -> int:
    """
    process the file to convert the datetime.utcnow() to datetime.now(UTC)

    the file is rewritten in place, and only when the transformation changed
    its content. returns 0 on success and 1 when the file cannot be parsed.
    """
    logger.debug(f"processing the file {file=} ...")

    # newline="" turns off newline translation so the existing line endings
    # of the file are kept when it is written back
    with open(file, encoding="utf-8", newline="") as fd:
        source = fd.read()

    try:
        # check the source with the interpreter's own parser before libcst
        ast.parse(source)
        source_tree = cst.parse_module(source)
    except (SyntaxError, cst.ParserSyntaxError) as err:
        logger.error(f"cannot parse {file}: {err}")
        return 1

    code = source_tree.visit(DateTimeTransformer()).code
    if code == source:
        # nothing was rewritten, leave the file untouched
        return 0

    with open(file, "w", encoding="utf-8", newline="") as fd:
        fd.write(code)

    return 0


def handler(target: str) -> int:
    """
    the entrypoint to handle the flake8-dtz003 issue that should convert the
    datetime.utcnow() to datetime.now(UTC)

    `target` is a single file or a folder that is searched recursively for
    `*.py` files. returns 0 when everything was processed and 1 otherwise;
    processing a folder stops at the first file that reports a failure.
    """
    path = Path(target)
    if not path.exists():
        logger.error(f"{path} does not exist")
        return 1

    if path.is_file():
        # normalise the result so an exit status is always returned
        return 1 if _process_file(path) else 0

    if path.is_dir():
        # sorted so the files are handled in a stable order
        for file in sorted(path.rglob("*.py")):
            if _process_file(file):
                return 1
        return 0

    logger.error(f"{path} is not a file or a folder")
    return 1


if __name__ == "__main__":
    # command line: any number of -v flags plus the file / folder to rewrite
    cli = argparse.ArgumentParser(description="Process the flake8-dtz003")
    cli.add_argument("-v", "--verbose", action="count", default=0, help="increase output verbosity")
    cli.add_argument("target", type=str, help="The filename / folder to process")
    options = cli.parse_args()

    # number of -v flags -> log level; anything beyond the table means TRACE
    levels = ("ERROR", "INFO", "DEBUG")
    level = levels[options.verbose] if options.verbose < len(levels) else "TRACE"

    # replace the default loguru sink with a stderr sink at the chosen level
    logger.remove()
    logger.add(sys.stderr, level=level)

    sys.exit(handler(options.target))