123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- # -*- coding: utf-8 -*-
- from datetime import datetime, date
- import sys
class BatchInsertUtil:
    """Bulk-insert rows (a list of dicts) into an Odoo model's table with raw SQL.

    The rows are written with parameterized multi-row ``INSERT`` statements
    executed on ``env.cr``; the ORM is bypassed, so only the listed columns
    (plus the optional foreign-key and ``show_index`` columns) are written.
    All state is per instance, and the caller's lists and row dicts are never
    modified.
    """

    def __init__(self, env, insert_model_name, insert_values, insert_model_fields=None, insert_values_keys=None,
                 foreign_key_field_dict=None, is_having_index=False, batch_count=0, is_returnable=False):
        """Store the insert configuration.

        :param env: Odoo environment; ``env.cr`` executes the SQL.
        :param insert_model_name: full model name, e.g. ``'module_name.model_name'``.
        :param insert_values: list of dicts, one per row to insert.
        :param insert_model_fields: column names to insert, e.g. ``['field_name_1', 'field_name_2']``;
            defaults to the keys of the first row of ``insert_values``.
        :param insert_values_keys: keys read from each row dict, in the same order as
            ``insert_model_fields``; defaults to ``insert_model_fields``.
        :param foreign_key_field_dict: ``{column: value}`` pairs written into every row,
            e.g. ``{'foreign_key': foreign_key_value}`` linking the rows to a parent record.
        :param is_having_index: when true, also fill a ``show_index`` column with 1, 2, 3, ...
        :param batch_count: rows per INSERT statement; ``0`` (or a negative value) inserts
            everything in a single statement.
        :param is_returnable: when true, ``batch_insert`` returns the inserted records.
        """
        self.__env = env
        self.__cr = env.cr
        self.__insert_model_name = insert_model_name
        # Own copy of the row list so nothing here ever rebinds or mutates the caller's list.
        self.__insert_values = list(insert_values or [])
        self._init_insert_model_fields(insert_model_fields, self.__insert_values)
        # Independent copy: the key list used to alias the field list, which made every
        # extra column get appended twice.
        if insert_values_keys:
            self.__insert_values_keys = list(insert_values_keys)
        else:
            self.__insert_values_keys = list(self.__insert_model_fields)
        self.__foreign_key_field_dict = dict(foreign_key_field_dict or {})
        self.__is_having_index = is_having_index
        self.__batch_count = batch_count
        self.__is_returnable = is_returnable

    def _init_insert_model_fields(self, insert_model_fields, insert_values):
        """Set the target columns: ``insert_model_fields`` if given, else the first row's keys, else none."""
        if insert_model_fields:
            fields = insert_model_fields
        elif insert_values:
            fields = insert_values[0]
        else:
            fields = []
        self.__insert_model_fields = list(fields)

    def batch_insert(self):
        """Insert every row, in chunks of ``batch_count`` rows when batching is enabled.

        Nothing is executed when there are no rows or no columns.

        :return: if ``is_returnable``, the inserted records (``[]`` when nothing was
            inserted); otherwise ``None``.
        """
        fields, rows = self.__build_rows()
        ids = []
        for chunk in self.__get_batch_values_list(rows):
            ids.extend(self.__batch_insert_data(fields, chunk))
        if self.__is_returnable:
            # browse keeps insertion order and, unlike search, does not apply the
            # implicit ``active`` filter that could hide freshly inserted rows.
            return self.__env[self.__insert_model_name].browse(ids) if ids else []
        return None

    def __build_rows(self):
        """Return ``(fields, rows)``: the column list and one list of SQL parameters per row.

        Foreign-key columns and the optional 1-based ``show_index`` are appended here,
        on copies, so the instance configuration and the caller's dicts stay untouched.
        """
        fields = list(self.__insert_model_fields)
        keys = list(self.__insert_values_keys)
        fk_items = list(self.__foreign_key_field_dict.items())
        fields.extend(key for key, _value in fk_items)
        if self.__is_having_index:
            fields.append('show_index')
        rows = []
        for index, dic in enumerate(self.__insert_values, start=1):
            row = [self.__to_sql_param(dic[key]) for key in keys]
            row.extend(self.__to_sql_param(value) for _key, value in fk_items)
            if self.__is_having_index:
                row.append(index)
            rows.append(row)
        return fields, rows

    @staticmethod
    def __to_sql_param(value):
        """Convert one Python value into the parameter bound for it in the INSERT."""
        if isinstance(value, tuple):  # many2one in (id, display_name) form -> id
            return value[0] if value else None
        if isinstance(value, str):  # kept as-is, including '' (the driver does the quoting)
            return value
        # datetime is a subclass of date, so it has to be checked first.
        if isinstance(value, datetime):
            return value.strftime('%Y-%m-%d %H:%M:%S')
        if isinstance(value, date):
            return value.strftime('%Y-%m-%d')
        if value is None or value is False:  # Odoo's "empty" marker -> NULL
            return None
        if isinstance(value, (int, float)):  # 0 / 0.0 are real values, not NULL
            return value
        if not value:  # other empty values ([], {}, ...) -> NULL
            return None
        # TODO: other types (e.g. dict -> JSON) may still need dedicated handling.
        return value

    def __get_table_name(self):
        """Return the model's table, honouring a custom ``_table``; fall back to the dotted-name convention."""
        try:
            return self.__env[self.__insert_model_name]._table
        except KeyError:
            return self.__insert_model_name.replace('.', '_')

    def __batch_insert_data(self, fields, rows):
        """Run one parameterized multi-row INSERT for *rows*.

        :return: the new ids when ``is_returnable`` is set, otherwise ``[]``
            (also ``[]`` when there is nothing to insert).
        """
        if not fields or not rows:
            return []
        row_placeholder = '({})'.format(','.join(['%s'] * len(fields)))
        sql = 'INSERT INTO "{table}" ({columns}) VALUES {values}{returning}'.format(
            table=self.__get_table_name(),
            columns=','.join('"{}"'.format(field) for field in fields),
            values=','.join([row_placeholder] * len(rows)),
            returning=' RETURNING id' if self.__is_returnable else '',
        )
        params = [param for row in rows for param in row]
        # Values are passed as parameters, never spliced into the SQL text.
        self.__cr.execute(sql, params)
        ids = [value['id'] for value in self.__cr.dictfetchall()] if self.__is_returnable else []
        self.__env.clear()  # raw SQL bypassed the ORM: drop its caches
        return ids

    def __get_batch_values_list(self, rows):
        """Split *rows* into consecutive chunks of ``batch_count`` rows.

        Returns a single chunk when batching is disabled (``batch_count`` falsy or
        negative) and no chunks at all for an empty *rows*.
        """
        size = self.__batch_count
        if not rows:
            return []
        if not size or size <= 0:
            return [rows]
        return [rows[start:start + size] for start in range(0, len(rows), size)]
|