import argparse
import json
import re
import shlex
import sys
import time

import sqlparse
from jinja2 import Template

# ALTER TABLE sub-command keywords this tool hands to pt-online-schema-change.
# The pattern built from them adds a trailing \b so that e.g. LOCK does not
# match inside an identifier such as LOCKED.
ALTER_KEYWORDS = (
    'ADD', 'ALTER', 'CHANGE', 'CHARACTER', 'CONVERT', 'DISABLE', 'ENABLE',
    'DROP', 'FORCE', 'LOCK', 'MODIFY', 'ORDER', 'RENAME', 'WITHOUT', 'WITH',
)


class SQLParseHelper:
    """Split a SQL script into statements and extract basic token info."""

    def __init__(self, sql_str=None, split_str=';'):
        """
        :param sql_str: SQL text to split; ``None`` is treated as empty text.
        :param split_str: statement separator. ``';'`` uses ``sqlparse.split``
            (which respects quoting); any other value is a plain ``str.split``.
        """
        self.sql_str = sql_str
        self.split_str = split_str
        text = sql_str or ''
        if split_str == ';':
            pieces = sqlparse.split(text)
        else:
            pieces = [piece.strip() for piece in text.split(split_str)]
        # Drop empty fragments (e.g. produced by a trailing separator).
        self.sql_list = [piece for piece in pieces if piece != '']

    def get_length(self):
        """Return the number of non-empty statements found in the input."""
        return len(self.sql_list)

    def get_format(self, sql):
        """Format one statement: reindented, keywords upper-cased, comments stripped.

        :param sql: a single SQL statement.
        :return: the formatted SQL text.
        """
        return sqlparse.format(sql, reindent=True, keyword_case='upper',
                               strip_comments=True)

    def get_init(self, sql):
        """Parse the first statement in ``sql``.

        :return: ``(sql_type, tokens)`` where ``sql_type`` comes from
            ``Statement.get_type()`` ('UNKNOWN' when nothing parses) and
            ``tokens`` is a list of ``[ttype, value]`` pairs for the
            statement's top-level tokens.
        """
        parsed = sqlparse.parse(sql)
        if not parsed:
            # Empty input, e.g. a comment-only chunk after strip_comments.
            return 'UNKNOWN', []
        statement = parsed[0]
        tokens = [[token.ttype, token.value] for token in statement.tokens]
        return statement.get_type(), tokens

    def get_type(self, tokens):
        """Return the sub-type of the first token's ttype (e.g. 'DDL'), else 'Others'."""
        if not tokens:
            return 'Others'
        ttype = tokens[0][0]
        # ttype is tuple-like, e.g. ('Keyword', 'DDL'); grouped tokens have None.
        if ttype is not None and len(ttype) > 1:
            return ttype[1]
        return 'Others'

    def sql_statement(self, tokens):
        """Return the text of the first token (e.g. 'ALTER'), else 'Others'."""
        if not tokens:
            return 'Others'
        first = tokens[0]
        return first[1] if len(first) > 1 else 'Others'

    def get_table(self, tokens):
        """Return ``[table]`` using the token two positions after 'TABLE'.

        The offset of 2 skips the whitespace token between 'TABLE' and the
        name. Returns ``[]`` (after printing the error) when 'TABLE' is absent
        or nothing follows it.

        :param tokens: ``[ttype, value]`` pairs as returned by ``get_init``.
        """
        sql_words = [token[1] for token in tokens]
        try:
            index_num = sql_words.index('TABLE')
            table = sql_words[index_num + 2]
        except (ValueError, IndexError) as e:
            print(str(e))
            return []
        return [table]


class PTHelper:
    """Turn checked ALTER statements into per-table pt-osc ``--alter`` strings."""

    # Non-greedy table name plus word-bounded keyword: a greedy prefix would
    # latch onto the LAST keyword in the statement (e.g. CHARACTER in
    # "MODIFY ... CHARACTER SET x"). re.S lets reindented, multi-line
    # statements match as well.
    ALTER_PATTERN = re.compile(
        r'ALTER\s+TABLE\s+(.+?)\s+(' + '|'.join(ALTER_KEYWORDS) + r')\b(.*);',
        re.S,
    )

    def __init__(self, sql):
        """
        :param sql: the dict returned by ``SQLCheck.main``; only its
            ``"data"`` list (items with a ``"format_sql"`` key) is read.
        """
        self.sql = sql
        self.result = []

    def alter_sql_parse(self):
        """Split each formatted ALTER into table name and alter command.

        Backticks are removed and double quotes become single quotes.
        Newline-wrapped clauses are joined onto one line. Example item::

            {"table_name": "frm_user",
             "alter_cmd": "DROP INDEX account_idx",
             "sql_from": {...original check record...}}

        Statements that do not match get empty ``table_name``/``alter_cmd``.

        :return: the list of items (also stored on ``self.result``).
        """
        self.result = []
        for item in self.sql["data"]:
            sql_str = item["format_sql"].replace('`', '')
            match_obj = self.ALTER_PATTERN.match(sql_str)
            if match_obj:
                table_name = match_obj.group(1).strip()
                option = match_obj.group(3).strip().replace('"', "'")
                # pt-osc takes the alter text as one shell argument on one line.
                option = re.sub(r'\s*\n\s*', ' ', option)
                alter_cmd = "{} {}".format(match_obj.group(2), option).strip()
                self.result.append({
                    "table_name": table_name,
                    "alter_cmd": alter_cmd,
                    "sql_from": item,
                })
            else:
                print("不能正常匹配")
                self.result.append({
                    "table_name": "",
                    "alter_cmd": "",
                    "sql_from": item,
                })
        return self.result

    def sql_to_pt(self):
        """Merge commands per table, comma-joined in original statement order.

        Tables appear in first-seen order so the generated script is
        deterministic across runs.

        :return: list of ``{"table_name", "alter_cmd", "sql_from"}`` dicts.
        """
        grouped = {}
        for entry in self.result:
            grouped.setdefault(entry["table_name"], []).append(entry)
        return [
            {
                "table_name": table,
                "alter_cmd": ','.join(entry["alter_cmd"] for entry in entries),
                "sql_from": entries,
            }
            for table, entries in grouped.items()
        ]


class GetScripts:
    """Render the bash script that runs pt-online-schema-change per table."""

    TEMPLATE = """#!/bin/bash
# pt-osc scripts
# {{ time_string }}
# Solar 混合云管理平台
host={{ host }}
port={{ port }}
dbname={{ dbname }}
user={{ user }}
password={{ password }}
{% for sql in sql_list %}
table={{ sql.table_name }}
pt-online-schema-change --user="${user}" --port="${port}" --host="${host}" --password="${password}" --alter="{{ sql.alter_cmd }}" D="${dbname}",t="${table}" --noversion-check --execute --charset=utf8
{% endfor %}
"""

    def __init__(self, sql_list, **kwargs):
        """
        :param sql_list: output of ``PTHelper.sql_to_pt``.
        :param kwargs: must contain host, port, dbname, user and password;
            other keys are ignored.
        :raises KeyError: if a connection key is missing.
        """
        self.render_data = {
            "sql_list": sql_list,
            "time_string": time.strftime('%Y%m%d%H%M%S', time.localtime(time.time())),
            "host": kwargs['host'],
            "port": kwargs['port'],
            "dbname": kwargs['dbname'],
            "user": kwargs['user'],
            "password": kwargs['password'],
        }

    @staticmethod
    def _escape_double_quoted(text):
        """Escape text for a bash "..." string, where \\ $ ` and " stay special."""
        for char in ('\\', '$', '`', '"'):  # backslash first, or we'd double-escape
            text = text.replace(char, '\\' + char)
        return text

    def render_template(self):
        """Render the script with all interpolated values made shell-safe.

        Passwords commonly contain ``#``, ``$``, ``;`` or spaces, which would
        truncate or execute in an unquoted assignment.
        """
        context = dict(self.render_data)
        for key in ("host", "port", "dbname", "user", "password"):
            context[key] = shlex.quote(str(context[key]))
        context["sql_list"] = [
            dict(sql,
                 table_name=shlex.quote(str(sql["table_name"])),
                 alter_cmd=self._escape_double_quoted(sql["alter_cmd"]))
            for sql in context["sql_list"]
        ]
        return Template(self.TEMPLATE).render(**context)

    def maker(self):
        """Return the rendered script text."""
        return self.render_template()


class SQLCheck:
    """
    Validate each parsed statement against three rules:
    1. the statement must not contain a database name (any '.' fails);
    2. the statement must end with a semicolon;
    3. the statement must be a DDL ALTER.

    Each input dict is annotated in place, e.g.::

        {"original_sql": "ALTER TABLE `t` ADD INDEX `idx` (`c`)",
         "format_sql": "ALTER TABLE `t` ADD INDEX `idx` (`c`)",
         "sql_type": "DDL", "sql_statement": "ALTER",
         "check_alter": "Pass", "check_database": "Pass",
         "check_semicolon": "Fail"}
    """

    # Keys written by the checks; a record failed if any of them is "Fail".
    CHECK_KEYS = ("check_alter", "check_database", "check_semicolon")

    def __init__(self, in_data_list):
        self.in_data_list = in_data_list

    @staticmethod
    def _record(_data, key, ok, pass_num, fail_num):
        """Store Pass/Fail under ``key`` and bump the matching counter."""
        _data[key] = "Pass" if ok else "Fail"
        if ok:
            pass_num += 1
        else:
            fail_num += 1
        return pass_num, fail_num, _data

    @classmethod
    def is_failed(cls, _data):
        """Return True if any check recorded on ``_data`` is "Fail"."""
        return any(_data.get(key) == "Fail" for key in cls.CHECK_KEYS)

    def check_database(self, _data, pass_num, fail_num):
        """Fail if the statement contains '.', i.e. a possible db-qualified name."""
        ok = _data["format_sql"].find('.') < 0
        return self._record(_data, "check_database", ok, pass_num, fail_num)

    def check_semicolon(self, _data, pass_num, fail_num):
        """Pass only if the statement ends with ';' (not merely contains one)."""
        ok = _data["format_sql"].rstrip().endswith(';')
        return self._record(_data, "check_semicolon", ok, pass_num, fail_num)

    def check_alter(self, _data, pass_num, fail_num):
        """Pass only for DDL statements whose first token is ALTER."""
        ok = _data["sql_type"] == "DDL" and _data["sql_statement"] == "ALTER"
        return self._record(_data, "check_alter", ok, pass_num, fail_num)

    def main(self):
        """Run all checks.

        :return: ``{"pass_num", "fail_num", "data"}``; the counts are per
            check, not per statement.
        """
        out_data_list = []
        pass_num = 0
        fail_num = 0
        for _data in self.in_data_list:
            pass_num, fail_num, out_data = self.check_alter(_data, pass_num, fail_num)
            pass_num, fail_num, out_data = self.check_database(out_data, pass_num, fail_num)
            pass_num, fail_num, out_data = self.check_semicolon(out_data, pass_num, fail_num)
            out_data_list.append(out_data)
        return {
            "pass_num": pass_num,
            "fail_num": fail_num,
            "data": out_data_list,
        }


def start_up(**kwargs):
    """Parse, check and convert ALTER statements into a pt-osc bash script.

    :param kwargs: ``sql_str`` (the SQL text) plus the connection keys that
        ``GetScripts`` needs (host, port, dbname, user, password).
    :return: the rendered script text.
    :raises SystemExit: with status 1 (after printing the failing
        statements) when any check fails.
    """
    api = SQLParseHelper(sql_str=kwargs["sql_str"], split_str=";")
    sql_num = api.get_length()
    data = []
    for sql in api.sql_list:
        format_sql = api.get_format(sql)
        if not format_sql.strip():
            # Comment-only chunk: nothing left to convert after strip_comments.
            continue
        _, tokens = api.get_init(format_sql)
        data.append({
            "original_sql": sql,
            "format_sql": format_sql.strip(),
            "sql_type": api.get_type(tokens=tokens),
            "sql_statement": api.sql_statement(tokens=tokens),
        })
    new_result = {
        "num": sql_num,
        "data": data,
    }

    check_result = SQLCheck(new_result["data"]).main()
    if check_result["fail_num"] > 0:
        print("检测不通过")
        failed = [item for item in check_result["data"] if SQLCheck.is_failed(item)]
        print(json.dumps(failed, indent=2, ensure_ascii=False))
        sys.exit(1)

    pt_api = PTHelper(check_result)
    pt_api.alter_sql_parse()
    pt_sql_list = pt_api.sql_to_pt()
    return GetScripts(pt_sql_list, **kwargs).maker()


_DESCRIPTION = '''MySQL Alter语句自动转 PT脚本 小工具
支持的Alter类型:ADD|ALTER|CHANGE|CHARACTER|CONVERT|DISABLE|ENABLE|DROP|FORCE|LOCK|MODIFY|ORDER|RENAME|WITHOUT|WITH ;
不支持带库名;
不支持一条sql多个变。
输出结果
#!/bin/bash
# pt-osc scripts
# 20200522190824
# Solar 混合云管理平台
host=localhost
port=3306
dbname=test01
user=root
password=lsdkjkdjfkdf

table=frm_user
pt-online-schema-change --user="${user}" --port="${port}" --host="${host}" --password="${password}" --alter="DROP INDEX account_idx,ADD UNIQUE INDEX uniq_account(account)" D="${dbname}",t="${table}" --noversion-check --execute --charset=utf8

table=aia_present_point_exchange_user
pt-online-schema-change --user="${user}" --port="${port}" --host="${host}" --password="${password}" --alter="ADD COLUMN end_date DATETIME NULL COMMENT '有效期结束' AFTER start_date" D="${dbname}",t="${table}" --noversion-check --execute --charset=utf8

Example:
python3 get_sql2pt.py --InFile demo/mysql_alter_demo.sql --OutFile demo/mysql_alter_pt.sh --Host localhost --Port 3306 --DBName test01 --User root --PassWord lsdkjkdjfkdf
'''


def main(argv=None):
    """Command-line entry point: read --InFile, write the script to --OutFile."""
    parser = argparse.ArgumentParser(description=_DESCRIPTION,
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--InFile", help="MySQL Alter SQL源文件 必要参数")
    parser.add_argument("--OutFile",
                        default='sql2pt-{}'.format(time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))),
                        help="转换后的脚本名 非必要参数")
    parser.add_argument("--Host", help="MySQL服务器连接地址")
    parser.add_argument("--Port", help="MySQL服务器监听端口")
    parser.add_argument("--DBName", help="MySQL访问的数据库名")
    parser.add_argument("--User", help="MySQL服务器登陆用户名")
    parser.add_argument("--PassWord", help="MySQL服务器登陆用户密码")
    args = parser.parse_args(argv)

    required = {
        "--InFile": args.InFile,
        "--Host": args.Host,
        "--Port": args.Port,
        "--DBName": args.DBName,
        "--User": args.User,
        "--PassWord": args.PassWord,
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        # Previously the script silently did nothing; fail loudly instead.
        parser.error("缺少必要参数: " + ", ".join(missing))

    with open(args.InFile, 'r', encoding='utf-8') as f:
        # NOTE(review): stripping every '#' also alters '#' inside string
        # literals (e.g. COMMENT values); kept to preserve existing behavior.
        sql_str = f.read().replace('#', '')
    params = {
        "sql_str": sql_str,
        "out_file": args.OutFile,
        "host": args.Host,
        "port": args.Port,
        "dbname": args.DBName,
        "user": args.User,
        "password": args.PassWord,
    }
    result = start_up(**params)
    with open(args.OutFile, 'w', encoding='utf-8') as f:
        f.write(result)


if __name__ == "__main__":
    main()
|