feat: 脚本兼容Python 3.9, 抽取更多公共代码; 更新使用文档

This commit is contained in:
dhb52 2024-05-02 23:38:56 +08:00
parent d98419b03f
commit 6d19690bea
2 changed files with 159 additions and 153 deletions

View File

@ -50,8 +50,16 @@ TODO 暂未支持
使用方式如下: 使用方式如下:
```Bash 安装依赖库
python3 convertor.py
```bash
pip install simple-ddl-parser
``` ```
然后,执行如下命令打印生成 postgresql 的脚本内容,其他可选参数有:oracle, sqlserver
```Bash
python3 convertor.py postgres
```
程序将 SQL 脚本打印到终端,可以重定向到临时文件 tmp.sql,确认无误后可以利用 IDEA 专业版进行格式化。

View File

@ -6,11 +6,12 @@ Author: dhb52 (https://gitee.com/dhb52)
pip install simple-ddl-parser pip install simple-ddl-parser
""" """
import argparse
import pathlib import pathlib
import re import re
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Tuple from typing import Dict, Generator, Optional, Tuple, Union
from simple_ddl_parser import DDLParser from simple_ddl_parser import DDLParser
@ -60,12 +61,12 @@ class Convertor(ABC):
self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content) self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content)
@abstractmethod @abstractmethod
def translate_type(self, type: str, size: None | int | Tuple[int]) -> str: def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]) -> str:
"""字段类型转换 """字段类型转换
Args: Args:
type (str): 字段类型 type (str): 字段类型
size (None | int | Tuple[int]): 字段长度描述, 如varchar(255), decimal(10,2) size (Optional[Union[int, Tuple[int]]]): 字段长度描述, 如varchar(255), decimal(10,2)
Returns: Returns:
str: 类型定义 str: 类型定义
@ -97,7 +98,7 @@ class Convertor(ABC):
pass pass
@abstractmethod @abstractmethod
def gen_index(self, table_ddl: Dict) -> str: def gen_index(self, ddl: Dict) -> str:
"""生成索引定义 """生成索引定义
Args: Args:
@ -133,6 +134,55 @@ class Convertor(ABC):
""" """
pass pass
@staticmethod
def inserts(table_name: str, script_content: str) -> Generator:
    """Yield the INSERT statements of one table with MySQL quoting removed.

    Args:
        table_name (str): Table name exactly as written between backticks
            in the source script.
        script_content (str): Full text of the source SQL script.

    Yields:
        str: One ``INSERT INTO <table> <columns> VALUES <values>`` statement
            per matching source line. The table name and the column list
            are lower-cased and stripped of backticks; ``\\"`` sequences in
            the values are unescaped to ``"``.

    Raises:
        ValueError: If a matching line does not contain ``" VALUES "``.
    """
    prefix = f"INSERT INTO `{table_name}`"
    target_table = table_name.lower()
    # Only single-line statements that start with the prefix are collected.
    for line in script_content.split("\n"):
        if not line.startswith(prefix):
            continue
        # Strip only the leading prefix. str.replace() would also delete the
        # same text if it happened to occur inside one of the values.
        remainder = line[len(prefix):]
        head, tail = remainder.split(" VALUES ", maxsplit=1)
        head = head.strip().replace("`", "").lower()
        tail = tail.strip().replace(r"\"", '"')
        # NOTE(review): bit(1) literals (b'0' / b'1') are passed through
        # unchanged here - confirm whether any caller still needs them
        # rewritten to '0' / '1'.
        yield f"INSERT INTO {target_table} {head} VALUES {tail}"
@staticmethod
def index(ddl: Dict) -> Generator:
    """Yield one CREATE INDEX statement per index found in a table DDL.

    Args:
        ddl (Dict): Parsed table DDL. This function reads
            ``ddl["table_name"]`` and ``ddl["index"]``; for every index it
            reads the first element of ``index["columns"]`` as a sequence of
            column dicts carrying ``"name"`` and ``"order"``.

    Yields:
        str: A ``CREATE INDEX`` statement without a trailing terminator.
            Indexes are named ``idx_<table>_<NN>`` where ``NN`` is the
            1-based, zero-padded position of the index.
    """

    def column_clause(col):
        # The sort direction is spelled out only when it is not 'ASC'.
        name = col["name"].lower()
        if col["order"] == "ASC":
            return name
        return name + " " + col["order"].lower()

    for seq, entry in enumerate(ddl["index"], 1):
        column_list = ", ".join(column_clause(col) for col in entry["columns"][0])
        table = ddl["table_name"].lower()
        yield f"CREATE INDEX idx_{table}_{seq:02d} ON {table} ({column_list})"
@staticmethod
def filed_comments(table_sql: str) -> Generator:
    """Yield ``(field, comment)`` pairs for the column comments of a table.

    Args:
        table_sql (str): Text of a single CREATE TABLE statement, one column
            definition per line.

    Yields:
        Tuple[str, str]: The backtick-quoted column name at the start of a
            line and the text of its ``COMMENT '...'`` clause, with literal
            ``\\n`` sequences turned into real newlines. Lines without such a
            clause are skipped.
    """
    column_comment = re.compile(r"^`([^`]+)`.* COMMENT '([^']+)'")
    for raw_line in table_sql.split("\n"):
        found = column_comment.match(raw_line.strip())
        if found is None:
            continue
        name, text = found.groups()
        yield name, text.replace("\\n", "\n")
def table_comment(self, table_sql: str) -> Optional[str]:
    """Extract the table-level comment from a CREATE TABLE statement.

    Args:
        table_sql (str): Text of a single CREATE TABLE statement.

    Returns:
        Optional[str]: The text inside ``COMMENT = '...';``, or ``None`` when
            the statement carries no such clause.
    """
    found = re.search(r"COMMENT \= '([^']+)';", table_sql)
    if found is None:
        return None
    return found.group(1)
def print(self): def print(self):
"""打印转换后的sql脚本到终端""" """打印转换后的sql脚本到终端"""
print( print(
@ -192,7 +242,7 @@ class PostgreSQLConvertor(Convertor):
def __init__(self, src): def __init__(self, src):
super().__init__(src, "PostgreSQL") super().__init__(src, "PostgreSQL")
def translate_type(self, type, size): def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
"""类型转换""" """类型转换"""
type = type.lower() type = type.lower()
@ -234,27 +284,30 @@ class PostgreSQLConvertor(Convertor):
table_name = ddl["table_name"].lower() table_name = ddl["table_name"].lower()
columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]] columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
filed_def_list = ",\n ".join(columns)
script = f"""-- ---------------------------- script = f"""-- ----------------------------
-- Table structure for {table_name} -- Table structure for {table_name}
-- ---------------------------- -- ----------------------------
DROP TABLE IF EXISTS {table_name}; DROP TABLE IF EXISTS {table_name};
CREATE TABLE {table_name} ( CREATE TABLE {table_name} (
{',\n '.join(columns)} {filed_def_list}
);""" );"""
return script return script
def gen_comment(self, table_sql, table_name) -> str: def gen_index(self, ddl: Dict) -> str:
return "\n".join(f"{script};" for script in self.index(ddl))
def gen_comment(self, table_sql: str, table_name: str) -> str:
"""生成字段及表的注释""" """生成字段及表的注释"""
script = "" script = ""
for line in table_sql.split("\n"): for field, comment_string in self.filed_comments(table_sql):
match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip()) script += (
if match: f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
script += f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{match.group(2).replace('\\n', '\n')}';\n" )
match = re.search(r"COMMENT \= '([^']+)';", table_sql) table_comment = self.table_comment(table_sql)
table_comment = match.group(1) if match else None
if table_comment: if table_comment:
script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n" script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
@ -264,53 +317,21 @@ CREATE TABLE {table_name} (
"""生成主键定义""" """生成主键定义"""
return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n" return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
def gen_index(self, ddl) -> str: def gen_insert(self, table_name: str) -> str:
"""生成 index"""
def generate_columns(columns):
keys = [
f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
for col in columns[0]
]
return ", ".join(keys)
script = ""
for no, index in enumerate(ddl["index"], 1):
columns = generate_columns(index["columns"])
table_name = ddl["table_name"].lower()
script += (
f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n"
)
return script
def gen_insert(self, table_name) -> str:
"""生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence""" """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
PREFIX = f"INSERT INTO `{table_name}`" inserts = list(Convertor.inserts(table_name, self.content))
# 收集 `table_name` 对应的 insert 语句
inserts = []
for line in self.content.split("\n"):
if line.startswith(PREFIX):
head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
head = head.strip().replace("`", "").lower()
tail = tail.strip().replace(r"\"", '"')
script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
# bit(1)数据转换
script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
inserts.append(script)
## 生成 insert 脚本 ## 生成 insert 脚本
script = "" script = ""
last_id = 0 last_id = 0
if inserts: if inserts:
inserts_lines = "\n".join(inserts)
script += f"""\n\n-- ---------------------------- script += f"""\n\n-- ----------------------------
-- Records of {table_name.lower()} -- Records of {table_name.lower()}
-- ---------------------------- -- ----------------------------
-- @formatter:off -- @formatter:off
BEGIN; BEGIN;
{'\n'.join(inserts)} {inserts_lines}
COMMIT; COMMIT;
-- @formatter:on""" -- @formatter:on"""
match = re.search(r"VALUES \((\d+),", inserts[-1]) match = re.search(r"VALUES \((\d+),", inserts[-1])
@ -332,7 +353,7 @@ class OracleConvertor(Convertor):
def __init__(self, src): def __init__(self, src):
super().__init__(src, "Oracle") super().__init__(src, "Oracle")
def translate_type(self, type, size: None | int | Tuple[int]): def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
"""类型转换""" """类型转换"""
type = type.lower() type = type.lower()
@ -369,15 +390,19 @@ class OracleConvertor(Convertor):
full_type = self.translate_type(type, col["size"]) full_type = self.translate_type(type, col["size"])
nullable = "NULL" if col["nullable"] else "NOT NULL" nullable = "NULL" if col["nullable"] else "NOT NULL"
default = f"DEFAULT {col['default']}" if col["default"] is not None else "" default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
return f"{'\"size\"' if name == "size" else name } {full_type} {default} {nullable}" # Oracle 中 size 不能作为字段名
field_name = '"size"' if name == "size" else name
# Oracle DEFAULT 定义在 NULLABLE 之前
return f"{field_name} {full_type} {default} {nullable}"
table_name = ddl["table_name"].lower() table_name = ddl["table_name"].lower()
columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]] columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
field_def_list = ",\n ".join(columns)
script = f"""-- ---------------------------- script = f"""-- ----------------------------
-- Table structure for {table_name} -- Table structure for {table_name}
-- ---------------------------- -- ----------------------------
CREATE TABLE {ddl['table_name'].lower()} ( CREATE TABLE {table_name} (
{',\n '.join(columns)} {field_def_list}
);""" );"""
# oracle INSERT '' 不能通过 NOT NULL 校验 # oracle INSERT '' 不能通过 NOT NULL 校验
@ -385,72 +410,51 @@ CREATE TABLE {ddl['table_name'].lower()} (
return script return script
def gen_comment(self, table_sql, table_name) -> str: def gen_index(self, ddl: Dict) -> str:
script = "" return "\n".join(f"{script};" for script in self.index(ddl))
for line in table_sql.split("\n"):
match = re.search(r"`([^`]+)`.* COMMENT '([^']+)'", line)
if match:
script += f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{match.group(2).replace('\\n', '\n')}';\n"
match = re.search(r"COMMENT \= '([^']+)';", table_sql) def gen_comment(self, table_sql: str, table_name: str) -> str:
table_comment = match.group(1) if match else None script = ""
for field, comment_string in self.filed_comments(table_sql):
script += (
f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
)
table_comment = self.table_comment(table_sql)
if table_comment: if table_comment:
script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';" script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
return script return script
def gen_pk(self, table_name) -> str: def gen_pk(self, table_name: str) -> str:
"""生成主键定义""" """生成主键定义"""
return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n" return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
def gen_index(self, table_ddl) -> str: def gen_index(self, ddl: Dict) -> str:
"""生成 INDEX 定义""" return "\n".join(f"{script};" for script in self.index(ddl))
def generate_columns(columns): def gen_insert(self, table_name: str) -> str:
keys = [
f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
for col in columns[0]
]
return ", ".join(keys)
script = ""
for no, index in enumerate(table_ddl["index"], 1):
columns = generate_columns(index["columns"])
table_name = table_ddl["table_name"].lower()
script += (
f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n"
)
return script
def gen_insert(self, table_name) -> str:
"""拷贝 INSERT 语句""" """拷贝 INSERT 语句"""
PREFIX = f"INSERT INTO `{table_name}`"
inserts = [] inserts = []
for line in self.content.split("\n"): for insert_script in Convertor.inserts(table_name, self.content):
if line.startswith(PREFIX): # 对日期数据添加 TO_DATE 转换
head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1) insert_script = re.sub(
head = head.strip().replace("`", "").lower() r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
tail = tail.strip().replace(r"\"", '"') r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}" insert_script,
# bit(1)数据转换 )
script = script.replace("b'0'", "'0'").replace("b'1'", "'1'") inserts.append(insert_script)
# 对日期数据添加 TO_DATE 转换
script = re.sub(
r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
script,
)
inserts.append(script)
## 生成 insert 脚本 ## 生成 insert 脚本
script = "" script = ""
last_id = 0 last_id = 0
if inserts: if inserts:
inserts_lines = "\n".join(inserts)
script += f"""\n\n-- ---------------------------- script += f"""\n\n-- ----------------------------
-- Records of {table_name.lower()} -- Records of {table_name.lower()}
-- ---------------------------- -- ----------------------------
-- @formatter:off -- @formatter:off
{'\n'.join(inserts)} {inserts_lines}
COMMIT; COMMIT;
-- @formatter:on""" -- @formatter:on"""
match = re.search(r"VALUES \((\d+),", inserts[-1]) match = re.search(r"VALUES \((\d+),", inserts[-1])
@ -476,7 +480,7 @@ class SQLServerConvertor(Convertor):
def __init__(self, src): def __init__(self, src):
super().__init__(src, "Microsoft SQL Server") super().__init__(src, "Microsoft SQL Server")
def translate_type(self, type, size): def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
"""类型转换""" """类型转换"""
type = type.lower() type = type.lower()
@ -507,7 +511,7 @@ class SQLServerConvertor(Convertor):
def _generate_column(col): def _generate_column(col):
name = col["name"].lower() name = col["name"].lower()
if name == 'id': if name == "id":
return "id bigint NOT NULL PRIMARY KEY IDENTITY" return "id bigint NOT NULL PRIMARY KEY IDENTITY"
if name == "deleted": if name == "deleted":
return "deleted bit DEFAULT 0 NOT NULL" return "deleted bit DEFAULT 0 NOT NULL"
@ -520,35 +524,34 @@ class SQLServerConvertor(Convertor):
table_name = ddl["table_name"].lower() table_name = ddl["table_name"].lower()
columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]] columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
filed_def_list = ",\n ".join(columns)
script = f"""-- ---------------------------- script = f"""-- ----------------------------
-- Table structure for {table_name} -- Table structure for {table_name}
-- ---------------------------- -- ----------------------------
DROP TABLE IF EXISTS {table_name}; DROP TABLE IF EXISTS {table_name};
CREATE TABLE {table_name} ( CREATE TABLE {table_name} (
{',\n '.join(columns)} {filed_def_list}
) )
GO""" GO"""
return script return script
def gen_comment(self, table_sql, table_name) -> str: def gen_comment(self, table_sql: str, table_name: str) -> str:
"""生成字段及表的注释""" """生成字段及表的注释"""
script = "" script = ""
for line in table_sql.split("\n"):
match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip()) for field, comment_string in self.filed_comments(table_sql):
if match: script += f"""EXEC sp_addextendedproperty
script += f"""EXEC sp_addextendedproperty 'MS_Description', N'{comment_string}',
'MS_Description', N'{match.group(2).replace('\\n', '\n')}',
'SCHEMA', N'dbo', 'SCHEMA', N'dbo',
'TABLE', N'{table_name}', 'TABLE', N'{table_name}',
'COLUMN', N'{match.group(1)}' 'COLUMN', N'{field}'
GO GO
""" """
match = re.search(r"COMMENT \= '([^']+)';", table_sql) table_comment = self.table_comment(table_sql)
table_comment = match.group(1) if match else None
if table_comment: if table_comment:
script += f"""EXEC sp_addextendedproperty script += f"""EXEC sp_addextendedproperty
'MS_Description', N'{table_comment}', 'MS_Description', N'{table_comment}',
@ -557,55 +560,34 @@ GO
GO GO
""" """
return script return script
def gen_pk(self, table_name) -> str: def gen_pk(self, table_name: str) -> str:
"""生成主键定义""" """生成主键定义"""
return "" return ""
def gen_index(self, ddl) -> str: def gen_index(self, ddl: Dict) -> str:
"""生成 index""" """生成 index"""
return "\n".join(f"{script}\nGO" for script in self.index(ddl))
def generate_columns(columns): def gen_insert(self, table_name: str) -> str:
keys = [
f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
for col in columns[0]
]
return ", ".join(keys)
script = ""
for no, index in enumerate(ddl["index"], 1):
columns = generate_columns(index["columns"])
table_name = ddl["table_name"].lower()
script += f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})\nGO\n"
return script
def gen_insert(self, table_name) -> str:
"""生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence""" """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
PREFIX = f"INSERT INTO `{table_name}`"
# 收集 `table_name` 对应的 insert 语句 # 收集 `table_name` 对应的 insert 语句
inserts = [] inserts = []
for line in self.content.split("\n"): for insert_script in Convertor.inserts(table_name, self.content):
if line.startswith(PREFIX): # SQLServer: 字符串前加Nhack是否存在替换字符串内容的风险
head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1) insert_script = insert_script.replace(", '", ", N'").replace(
head = head.strip().replace("`", "").lower() "VALUES ('", "VALUES (N')"
tail = tail.strip().replace(r"\"", '"') )
# SQLServer: 字符串前加Nhack是否存在替换字符串内容的风险 # 删除 insert 的结尾分号
tail = tail.replace(", '", ", N'").replace("VALUES ('", "VALUES (N')") insert_script = re.sub(";$", r"\nGO", insert_script)
script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}" inserts.append(insert_script)
# bit(1)数据转换
script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
# 删除 insert 的结尾分号
script = re.sub(";$", r"\nGO", script)
inserts.append(script)
## 生成 insert 脚本 ## 生成 insert 脚本
script = "" script = ""
if inserts: if inserts:
inserts_lines = "\n".join(inserts)
script += f"""\n\n-- ---------------------------- script += f"""\n\n-- ----------------------------
-- Records of {table_name.lower()} -- Records of {table_name.lower()}
-- ---------------------------- -- ----------------------------
@ -614,7 +596,7 @@ BEGIN TRANSACTION
GO GO
SET IDENTITY_INSERT {table_name.lower()} ON SET IDENTITY_INSERT {table_name.lower()} ON
GO GO
{'\n'.join(inserts)} {inserts_lines}
SET IDENTITY_INSERT {table_name.lower()} OFF SET IDENTITY_INSERT {table_name.lower()} OFF
GO GO
COMMIT COMMIT
@ -625,10 +607,26 @@ GO
def main(): def main():
sql_file = pathlib.Path('../mysql/ruoyi-vue-pro.sql').resolve().as_posix() parser = argparse.ArgumentParser(description="芋道系统数据库转换工具")
# convertor = PostgreSQLConvertor(sql_file) parser.add_argument(
# convertor = OracleConvertor(sql_file) "type",
convertor = SQLServerConvertor(sql_file) type=str,
help="目标数据库类型",
choices=["postgres", "oracle", "sqlserver"],
)
args = parser.parse_args()
sql_file = pathlib.Path("../mysql/ruoyi-vue-pro.sql").resolve().as_posix()
convertor = None
if args.type == "postgres":
convertor = PostgreSQLConvertor(sql_file)
elif args.type == "oracle":
convertor = OracleConvertor(sql_file)
elif args.type == "sqlserver":
convertor = SQLServerConvertor(sql_file)
else:
raise NotImplementedError(f"不支持目标数据库类型: {args.type}")
convertor.print() convertor.print()