-- pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua

--------------------------------------------------------------------------------
-- Helpers for the PAI EasyRec extension flow: parse OSS/ODPS paths, build the
-- hyperparameter string for the TF job, and generate output-table DDL.
-- NOTE(review): the functions are intentionally global -- the PAI flow engine
-- resolves them by name.
--------------------------------------------------------------------------------

-- Split `str` by `delimiter` into an array; returns nil for nil/empty input.
-- `delimiter` is used as a Lua pattern, so callers must pass pattern-safe
-- delimiters (',', '/', '=' as used in this file).
function split(str, delimiter)
  if str == nil or str == '' or delimiter == nil then
    return nil
  end
  local result = {}
  for match in (str .. delimiter):gmatch("(.-)" .. delimiter) do
    table.insert(result, match)
  end
  return result
end

-- Join a list with `delimiter` (inverse of split).
function join(list, delimiter)
  return table.concat(list, delimiter)
end

-- Return the index of the first element of `list` matching `str_pattern`
-- (a Lua pattern), or nil when nothing matches.
function match_str_in_list(list, str_pattern)
  for idx = 1, #list do
    if string.find(list[idx], str_pattern) ~= nil then
      return idx
    end
  end
  return nil
end

-- An OSS location is usable only when both host and bucket are non-empty.
function CheckOssValid(host, bucket)
  if host == nil or string.len(host) == 0 or
     bucket == nil or string.len(bucket) == 0 then
    return false
  end
  return true
end

-- Parse "oss://bucket[.host]/dir/.../file" into (host, "bucket/dir/.../",
-- file). When the bucket segment carries no embedded host, `default_host`
-- is used. Raises on any malformed uri.
function ParseOssUri(oss_uri, default_host)
  if string.len(oss_uri) > 6 and string.find(oss_uri, "oss://") == 1 then
    local _, _, path, file = string.find(oss_uri, "oss://(.*/)(.*)")
    if path == nil or string.len(path) == 0 then
      error("invalid oss uri: " .. oss_uri .. ", should end with '/'")
    end
    local _, _, bucket_host, dir = string.find(path, "(.-)(/.*)")
    local bucket, host
    if string.find(bucket_host, "%.") then
      -- "bucket.endpoint" form: split on the first dot
      _, _, bucket, host = string.find(bucket_host, "(.-)%.(.*)")
    else
      bucket = bucket_host
      host = default_host
    end
    if not CheckOssValid(host, bucket) then
      error("invalid oss uri: " .. oss_uri .. ", oss host or bucket not found")
    end
    local root_dir = bucket .. dir
    return host, root_dir, file
  end
  error("invalid oss uri: " .. oss_uri)
end

-- Resolve the script resource and entry file for the TF job. When no script
-- is given, fall back to the versioned easy_rec resource archive in
-- `res_project`. (`config` and `cluster` are accepted for interface
-- compatibility but unused.)
function getEntry(script_in, entryFile_in, config, cluster, res_project, version)
  if string.len(entryFile_in) == 0 then
    error('entryFile is not set')
  end
  local script
  if script_in ~= nil and string.len(script_in) > 0 then
    script = script_in
  else
    script = "odps://" .. res_project .. "/resources/easy_rec_ext_" ..
             version .. "_res.tar.gz"
  end
  return script, entryFile_in
end

-- config must be an http(s) url or an oss path.
function checkConfig(config)
  if config == nil or config == '' then
    error('config must be set')
  end
  local s1 = string.find(config, 'oss://')
  local s2 = string.find(config, 'http')
  if s1 == nil and s2 == nil then
    error("config path should be url or oss path")
  end
end

-- A table path must look like odps://project/tables/name[/partitions].
-- (parameter renamed from `table`, which shadowed the table library)
function checkTable(table_path)
  local s1 = string.find(table_path, "/tables/")
  local s2 = string.find(table_path, "odps://")
  if s1 == nil or s2 == nil then
    error(string.format("invalid odps table path: %s", table_path))
  end
end

-- Raises unless `path` contains an oss:// scheme.
function checkOss(path)
  if string.find(path, "oss://") == nil then
    error(string.format("invalid oss path: %s", path))
  end
end

-- `cluster` (json spec) and `gpuRequired` are mutually exclusive ways to
-- size the job.
function check_run_mode(cluster, gpuRequired)
  if (cluster ~= nil and cluster ~= "") and gpuRequired ~= "" then
    error(string.format(
        'cluster and gpuRequired should not be set at the same time. cluster: %s gpuRequired:%s',
        cluster, gpuRequired))
  end
end

-- Build the command-line hyperparameter string handed to the easy_rec TF job
-- and resolve the effective cluster spec.
-- Returns: hyperParameters, cluster, tables, output_table.
-- Raises on missing/invalid required arguments for the given `cmd`.
function getHyperParams(config, cmd, checkpoint_path, fine_tune_checkpoint,
                        eval_result_path, export_dir, gpuRequired, cpuRequired,
                        memRequired, cluster, continue_train,
                        distribute_strategy, with_evaluator, eval_method,
                        edit_config_json, selected_cols, model_dir,
                        hpo_param_path, hpo_metric_save_path, saved_model_dir,
                        all_cols, all_col_types, reserved_cols, output_cols,
                        model_outputs, input_table, output_table, tables,
                        query_table, doc_table, knn_distance,
                        knn_num_neighbours, knn_feature_dims, knn_index_type,
                        knn_feature_delimiter, knn_nlist, knn_nprobe,
                        knn_compress_dim, train_tables, eval_tables,
                        boundary_table, batch_size, profiling_file,
                        mask_feature_name, extra_params)
  local hyperParameters = ""

  if cmd == "predict" then
    if cluster == nil or cluster == '' then
      error('cluster must be set')
    end
    if saved_model_dir == nil or saved_model_dir == '' then
      error('saved_model_dir must be set')
    end
    -- BUGFIX: this validation used to sit inside the error branch above and
    -- therefore never ran; saved_model_dir must be an oss path.
    checkOss(saved_model_dir)
    hyperParameters = " --cmd=" .. cmd
    hyperParameters = hyperParameters .. " --saved_model_dir=" .. saved_model_dir
    hyperParameters = hyperParameters .. " --all_cols=" .. all_cols ..
        " --all_col_types=" .. all_col_types
    if selected_cols ~= nil and selected_cols ~= '' then
      hyperParameters = hyperParameters .. " --selected_cols=" .. selected_cols
    end
    if reserved_cols ~= nil and string.len(reserved_cols) > 0 then
      hyperParameters = hyperParameters .. " --reserved_cols=" .. reserved_cols
    end
    hyperParameters = hyperParameters .. " --batch_size=" .. batch_size
    if profiling_file ~= nil and profiling_file ~= '' then
      checkOss(profiling_file)
      hyperParameters = hyperParameters .. " --profiling_file=" .. profiling_file
    end
    -- support both 'probs float, embedding string' and 'probs, embedding'
    -- format in easy_rec.python.inference.predictor.predict_table
    if model_outputs ~= nil and model_outputs ~= "" then
      hyperParameters = hyperParameters .. " --output_cols='" .. model_outputs .. "'"
    else
      hyperParameters = hyperParameters .. " --output_cols='" .. output_cols .. "'"
    end
    checkTable(input_table)
    checkTable(output_table)
    if extra_params ~= nil and extra_params ~= '' then
      hyperParameters = hyperParameters .. " " .. extra_params
    end
    return hyperParameters, cluster, tables, output_table
  end

  if cmd == "vector_retrieve" then
    if cluster == nil or cluster == '' then
      error('cluster must be set')
    end
    checkTable(query_table)
    checkTable(doc_table)
    checkTable(output_table)
    hyperParameters = " --cmd=" .. cmd
    hyperParameters = hyperParameters .. " --batch_size=" .. batch_size
    hyperParameters = hyperParameters .. " --knn_distance=" .. knn_distance
    if knn_num_neighbours ~= nil and knn_num_neighbours ~= '' then
      hyperParameters = hyperParameters .. ' --knn_num_neighbours=' .. knn_num_neighbours
    end
    if knn_feature_dims ~= nil and knn_feature_dims ~= '' then
      hyperParameters = hyperParameters .. ' --knn_feature_dims=' .. knn_feature_dims
    end
    hyperParameters = hyperParameters .. " --knn_index_type=" .. knn_index_type
    hyperParameters = hyperParameters .. " --knn_feature_delimiter=" .. knn_feature_delimiter
    if knn_nlist ~= nil and knn_nlist ~= '' then
      hyperParameters = hyperParameters .. ' --knn_nlist=' .. knn_nlist
    end
    if knn_nprobe ~= nil and knn_nprobe ~= '' then
      hyperParameters = hyperParameters .. ' --knn_nprobe=' .. knn_nprobe
    end
    if knn_compress_dim ~= nil and knn_compress_dim ~= '' then
      hyperParameters = hyperParameters .. ' --knn_compress_dim=' .. knn_compress_dim
    end
    if extra_params ~= nil and extra_params ~= '' then
      hyperParameters = hyperParameters .. " " .. extra_params
    end
    return hyperParameters, cluster, tables, output_table
  end

  -- remaining commands: train / evaluate / export / export_checkpoint / custom
  if cmd ~= "custom" then
    checkConfig(config)
  end
  hyperParameters = "--config='" .. config .. "'"
  if selected_cols ~= nil and selected_cols ~= '' then
    hyperParameters = hyperParameters .. ' --selected_cols=' .. selected_cols
  end
  hyperParameters = string.format('%s --cmd=%s', hyperParameters, cmd)

  if cmd == 'evaluate' then
    hyperParameters = hyperParameters .. " --checkpoint_path=" .. checkpoint_path
    hyperParameters = hyperParameters .. " --all_cols=" .. all_cols ..
        " --all_col_types=" .. all_col_types
    hyperParameters = hyperParameters .. " --eval_result_path=" .. eval_result_path
    hyperParameters = hyperParameters .. " --mask_feature_name=" .. mask_feature_name
    hyperParameters = hyperParameters .. " --distribute_strategy=" .. distribute_strategy
    if eval_tables ~= "" and eval_tables ~= nil then
      hyperParameters = hyperParameters .. " --eval_tables " .. eval_tables
    end
  elseif cmd == 'export' or cmd == 'export_checkpoint' then
    hyperParameters = hyperParameters .. " --checkpoint_path=" .. checkpoint_path
    hyperParameters = hyperParameters .. " --export_dir=" .. export_dir
  elseif cmd == 'train' then
    hyperParameters = hyperParameters .. " --all_cols=" .. all_cols ..
        " --all_col_types=" .. all_col_types
    hyperParameters = hyperParameters .. " --continue_train=" .. continue_train
    hyperParameters = hyperParameters .. " --distribute_strategy=" .. distribute_strategy
    if with_evaluator ~= "" and tonumber(with_evaluator) ~= 0 then
      hyperParameters = hyperParameters .. " --with_evaluator"
    end
    if fine_tune_checkpoint ~= nil and fine_tune_checkpoint ~= '' then
      hyperParameters = hyperParameters .. " --fine_tune_checkpoint=" .. fine_tune_checkpoint
    end
    if eval_method ~= 'none' and eval_method ~= 'separate' and eval_method ~= 'master' then
      error('invalid eval_method ' .. eval_method)
    end
    if eval_method ~= "" then
      hyperParameters = hyperParameters .. " --eval_method=" .. eval_method
    end
    -- tables used for train and evaluate
    if train_tables ~= "" and train_tables ~= nil then
      hyperParameters = hyperParameters .. " --train_tables " .. train_tables
    end
    if eval_tables ~= "" and eval_tables ~= nil then
      hyperParameters = hyperParameters .. " --eval_tables " .. eval_tables
    end
    if boundary_table ~= "" and boundary_table ~= nil then
      hyperParameters = hyperParameters .. " --boundary_table " .. boundary_table
    end
    if hpo_param_path ~= "" and hpo_param_path ~= nil then
      hyperParameters = hyperParameters .. " --hpo_param_path=" .. hpo_param_path
      if hpo_metric_save_path == nil then
        error('hpo_metric_save_path must be set')
      end
      hyperParameters = hyperParameters .. " --hpo_metric_save_path=" .. hpo_metric_save_path
    end
  end

  if edit_config_json ~= "" and edit_config_json ~= nil then
    hyperParameters = hyperParameters ..
        string.format(" --edit_config_json='%s'", edit_config_json)
  end
  if model_dir ~= "" and model_dir ~= nil then
    checkOss(model_dir)
    hyperParameters = hyperParameters .. " --model_dir=" .. model_dir
  end

  check_run_mode(cluster, gpuRequired)
  local num_gpus_per_worker
  if gpuRequired ~= "" then
    -- gpuRequired is in percent units: 100 == one full gpu
    num_gpus_per_worker = math.max(math.ceil(tonumber(gpuRequired) / 100), 0)
    cluster = string.format(
        '{"worker":{"count":1, "gpu":%s, "cpu":%s, "memory":%s}}',
        gpuRequired, cpuRequired, memRequired)
  elseif cluster ~= "" then
    local gpus_str = string.match(cluster, '"gpu"%s*:%s*(%d+)')
    if gpus_str ~= nil then
      num_gpus_per_worker = math.max(math.ceil(tonumber(gpus_str) / 100), 0)
    else
      num_gpus_per_worker = 1
    end
  else
    num_gpus_per_worker = 1
  end
  hyperParameters = string.format("%s --num_gpus_per_worker=%s ",
                                  hyperParameters, num_gpus_per_worker)
  if extra_params ~= nil and extra_params ~= '' then
    hyperParameters = hyperParameters .. " " .. extra_params
  end
  return hyperParameters, cluster, tables, output_table
end

-- Split "odps://project/tables/table_name[/pa=1/pb=2]" into project name,
-- table name and partition spec string ("" when unpartitioned).
function splitTableParam(table_path)
  local delimiter = '/'
  -- eles: { "odps:", "", project, "tables", table_name, part1, part2, ... }
  local eles = split(table_path, delimiter)
  local project_name = eles[3]
  local table_name = eles[5]
  local partitions = {}
  for i = 6, #eles do
    table.insert(partitions, eles[i])
  end
  local partition_str = join(partitions, delimiter)
  return project_name, table_name, partition_str
end

-- Read column names and types of `inputTable` ("project.table") via the ODPS
-- Builtin api. Returns (name -> type map, ordered name list).
function getInputTableColTypes(inputTable)
  -- to test locally without ODPS, stub this out, e.g.:
  -- return {a = "string", b = "int"}, {"a", "b"}
  local all_input_cols = Builtin.GetAllColumnNames(inputTable, ",")
  local all_input_types = Builtin.GetColumnDataTypes(inputTable, ",")
  local col_list = split(all_input_cols, ',')
  local type_list = split(all_input_types, ',')
  local col_map = {}
  for i = 1, #col_list do
    col_map[col_list[i]] = type_list[i]
  end
  return col_map, col_list
end

-- Build the "(col type,...,result_column string)" column spec for an output
-- table; every reserved column must exist in `col_type_map`.
function getOutputCols(col_type_map, reserved_columns, result_column)
  local res_cols = split(reserved_columns, ',')
  local sql = "("
  if res_cols ~= nil then
    for i = 1, #res_cols do
      if col_type_map[res_cols[i]] == nil then
        error(string.format("column %s is not in input table", res_cols[i]))
      end
      sql = sql .. res_cols[i] .. " " .. col_type_map[res_cols[i]] .. ","
    end
  end
  return sql .. result_column .. " string)"
end

-- Parse "pa=1/pb=2" into parallel name/value lists; raises on malformed spec.
-- (name keeps the historical "Parition" spelling for caller compatibility)
function parseParitionSpec(partitions)
  local parition_names = {}
  local partition_values = {}
  local parts = split(partitions, "/")
  for i = 1, #parts do
    local spec = split(parts[i], "=")
    if #spec ~= 2 then
      error("Partition Spec is not Right " .. parts[i])
    end
    table.insert(parition_names, i, spec[1])
    table.insert(partition_values, i, spec[2])
  end
  return parition_names, partition_values
end

-- "(pa string,pb string)" -- partition column declarations for CREATE TABLE.
function genCreatePartitionStr(partition_names)
  local part_str = "("
  for i = 1, #partition_names do
    part_str = part_str .. partition_names[i] .. " string,"
  end
  part_str = string.sub(part_str, 1, -2)  -- drop trailing comma
  return part_str .. ")"
end

-- '(pa="1",pb="2")' -- partition spec for ALTER TABLE ... ADD PARTITION.
function genAddPartitionStr(parition_names, partition_values)
  local part_str = "("
  for i = 1, #parition_names do
    part_str = part_str .. parition_names[i] .. "=\"" .. partition_values[i] .. "\","
  end
  part_str = string.sub(part_str, 1, -2)  -- drop trailing comma
  return part_str .. ")"
end

-- Analyze input/output tables for the given cmd.
-- Returns: all_cols, all_col_types, selected_cols, reserved_cols (all as
-- comma-joined strings), create_table_sql, add_partition_sql and the merged
-- `tables` string handed to the TF job.
function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols,
                    reservedCols, lifecycle, outputCol, tables, trainTables,
                    evalTables, boundaryTable, queryTable, docTable)
  if cmd ~= 'train' and cmd ~= 'evaluate' and cmd ~= 'predict' and
     cmd ~= 'export' and cmd ~= 'export_checkpoint' and cmd ~= 'custom' and
     cmd ~= 'vector_retrieve' then
    error('invalid cmd: ' .. cmd ..
          ', should be one of train, evaluate, predict, export, export_checkpoint, custom, vector_retrieve')
  end
  -- export-like commands read no input tables
  if cmd == 'export' or cmd == 'custom' or cmd == 'export_checkpoint' then
    return "", "", "", "", "select 1;", "select 1;", tables
  end
  -- online train or train with oss input: no odps tables involved
  if cmd == 'train' and (tables == nil or tables == '') and
     (trainTables == nil or trainTables == '') then
    return "", "", "", "", "select 1;", "select 1;", tables
  end

  -- merge every referenced table into all_tables (name -> first-seen order id)
  local all_tables = {}
  local table_id = 0
  local tmpTables
  if tables ~= nil and tables ~= '' then
    tmpTables = split(tables, ',')
    for k = 1, #tmpTables do
      local v = tmpTables[k]
      if all_tables[v] == nil then
        all_tables[v] = table_id
        table_id = table_id + 1
      end
    end
    if inputTable == nil or inputTable == '' then
      inputTable = tmpTables[1]
    end
  end
  if cmd == 'vector_retrieve' then
    inputTable = queryTable
    all_tables[queryTable] = table_id
    table_id = table_id + 1
    all_tables[docTable] = table_id
    table_id = table_id + 1
  end
  if cmd == 'train' then
    -- merge train tables and eval tables into all_tables
    if trainTables ~= '' and trainTables ~= nil then
      tmpTables = split(trainTables, ',')
      for k = 1, #tmpTables do
        local v = tmpTables[k]
        if all_tables[v] == nil then
          all_tables[v] = table_id
          table_id = table_id + 1
        end
      end
      inputTable = tmpTables[1]
      tmpTables = split(evalTables, ',')
      for k = 1, #tmpTables do
        local v = tmpTables[k]
        if all_tables[v] == nil then
          all_tables[v] = table_id
          table_id = table_id + 1
        end
      end
    end
    if boundaryTable ~= nil and boundaryTable ~= '' then
      if all_tables[boundaryTable] == nil then
        all_tables[boundaryTable] = table_id
        table_id = table_id + 1
      end
    end
  end
  if cmd == 'evaluate' then
    -- merge evalTables into tables if evalTables is set
    if evalTables ~= nil and evalTables ~= '' then
      tmpTables = split(evalTables, ',')
      for k = 1, #tmpTables do
        local v = tmpTables[k]
        if all_tables[v] == nil then
          all_tables[v] = table_id
          table_id = table_id + 1
        end
      end
      inputTable = tmpTables[1]
    end
  end
  if cmd == 'predict' then
    -- merge inputTable into all_tables if inputTable is set
    if inputTable ~= nil and inputTable ~= '' then
      tmpTables = split(inputTable, ',')
      for k = 1, #tmpTables do
        local v = tmpTables[k]
        if all_tables[v] == nil then
          all_tables[v] = table_id
          table_id = table_id + 1
        end
      end
    else
      -- fall back to the first entry of `tables`
      if tables ~= '' and tables ~= nil then
        inputTable = split(tables, ',')[1]
      else
        error('either inputTable or tables must be set')
      end
    end
  end

  -- rebuild `tables` as a comma-joined string in first-seen order
  local merged = {}
  for k, v in pairs(all_tables) do
    merged[v + 1] = k
  end
  if inputTable == nil or inputTable == '' then
    error('inputTable is not defined')
  end
  tables = join(merged, ',')

  if cmd == 'vector_retrieve' then
    if outputTable == nil or outputTable == '' then
      error("outputTable is not set")
    end
    local proj1, table1, partition1 = splitTableParam(outputTable)
    local out_table_name = proj1 .. "." .. table1
    local create_sql
    local add_partition_sql
    if partition1 ~= nil and string.len(partition1) ~= 0 then
      local partition_names, parition_values = parseParitionSpec(partition1)
      local create_partition_str = genCreatePartitionStr(partition_names)
      create_sql = string.format(
          "create table if not exists %s (query BIGINT, doc BIGINT, distance DOUBLE) partitioned by %s lifecycle %s;",
          out_table_name, create_partition_str, lifecycle)
      add_partition_sql = genAddPartitionStr(partition_names, parition_values)
      add_partition_sql = string.format(
          "alter table %s add if not exists partition %s;",
          out_table_name, add_partition_sql)
    else
      create_sql = string.format(
          "create table %s (query BIGINT, doc BIGINT, distance DOUBLE) lifecycle %s;",
          out_table_name, lifecycle)
      add_partition_sql = string.format("desc %s;", out_table_name)
    end
    return "", "", "", "", create_sql, add_partition_sql, tables
  end

  -- analyze selected/excluded cols for train, evaluate and predict
  local proj0, table0, partition0 = splitTableParam(inputTable)
  local input_col_types, input_cols = getInputTableColTypes(proj0 .. "." .. table0)
  if (excludedCols ~= nil and excludedCols ~= '') and
     (selectedCols ~= nil and selectedCols ~= '') then
    error('selected_cols and excluded_cols should not be set')
  end
  local ex_cols_map = {}
  if excludedCols ~= '' and excludedCols ~= nil then
    local ex_cols_lst = split(excludedCols, ',')
    for i = 1, #ex_cols_lst do
      ex_cols_map[ex_cols_lst[i]] = 1
    end
  end
  -- selected_cols: columns fed to the model;
  -- all_cols/all_col_types: every column read by TableRecordDataset
  local selected_cols = {}
  local all_cols = {}
  local all_col_types = {}
  local all_cols_map = {}
  local tmp_cols
  if selectedCols ~= '' and selectedCols ~= nil then
    tmp_cols = split(selectedCols, ",")
  else
    tmp_cols = input_cols
  end
  for i = 1, #tmp_cols do
    if input_col_types[tmp_cols[i]] == nil then
      error(string.format("column %s is not in input table", tmp_cols[i]))
    elseif ex_cols_map[tmp_cols[i]] == nil then
      -- not excluded; add once, preserving order
      if all_cols_map[tmp_cols[i]] == nil then
        table.insert(all_cols, tmp_cols[i])
        table.insert(all_col_types, input_col_types[tmp_cols[i]])
        table.insert(selected_cols, tmp_cols[i])
        all_cols_map[tmp_cols[i]] = 1
      end
    end
  end
  if cmd == 'train' or cmd == 'evaluate' then
    return join(all_cols, ","), join(all_col_types, ","),
           join(selected_cols, ","), '', "select 1;", "select 1;", tables
  end

  -- reserved columns (predict only) are copied through to the output table;
  -- they may overlap selected/excluded columns
  local reserved_cols = {}
  local reserved_col_types = {}
  if reservedCols ~= nil and reservedCols ~= '' then
    if reservedCols == 'ALL_COLUMNS' then
      tmp_cols = input_cols
    else
      tmp_cols = split(reservedCols, ',')
    end
    for i = 1, #tmp_cols do
      if input_col_types[tmp_cols[i]] ~= nil then
        table.insert(reserved_cols, tmp_cols[i])
        table.insert(reserved_col_types, input_col_types[tmp_cols[i]])
        if all_cols_map[tmp_cols[i]] == nil then
          table.insert(all_cols, tmp_cols[i])
          table.insert(all_col_types, input_col_types[tmp_cols[i]])
          all_cols_map[tmp_cols[i]] = 1
        end
      else
        error("invalid reserved_col: " .. tmp_cols[i] ..
              " available: " .. join(input_cols, ","))
      end
    end
  else
    -- BUGFIX: default to the first selected column. The original indexed
    -- selected_cols[0], which is always nil in 1-based Lua, silently leaving
    -- the reserved list empty (and its type was never recorded).
    table.insert(reserved_cols, selected_cols[1])
    table.insert(reserved_col_types, input_col_types[selected_cols[1]])
  end

  -- build create-output-table sql and add-partition sql for predict
  local sql_col_desc = {}
  for i = 1, #reserved_cols do
    table.insert(sql_col_desc, reserved_cols[i] .. " " .. reserved_col_types[i])
  end
  table.insert(sql_col_desc, outputCol)
  sql_col_desc = join(sql_col_desc, ",")
  if outputTable == nil or outputTable == '' then
    error("outputTable is not set")
  end
  local proj1, table1, partition1 = splitTableParam(outputTable)
  local out_table_name = proj1 .. "." .. table1
  local create_sql
  local add_partition_sql
  if partition1 ~= nil and string.len(partition1) ~= 0 then
    local partition_names, parition_values = parseParitionSpec(partition1)
    local create_partition_str = genCreatePartitionStr(partition_names)
    create_sql = string.format(
        "create table if not exists %s (%s) partitioned by %s lifecycle %s;",
        out_table_name, sql_col_desc, create_partition_str, lifecycle)
    add_partition_sql = genAddPartitionStr(partition_names, parition_values)
    add_partition_sql = string.format(
        "alter table %s add if not exists partition %s;",
        out_table_name, add_partition_sql)
  else
    create_sql = string.format("create table %s (%s) lifecycle %s;",
                               out_table_name, sql_col_desc, lifecycle)
    add_partition_sql = string.format("desc %s;", out_table_name)
  end
  return join(all_cols, ","), join(all_col_types, ","),
         join(selected_cols, ","), join(reserved_cols, ","),
         create_sql, add_partition_sql, tables
end

-- Manual smoke test for the predict table analysis (requires the ODPS Builtin
-- api at runtime). NOTE(review): the original called an undefined
-- `createTable` helper; retargeted to parseTable's predict path.
function test_create_table()
  local input_table = "odps://pai_rec_dev/tables/test_longonehot_4deepfm_20/part=1"
  local output_table = "odps://pai_rec_dev/tables/test_longonehot_4deepfm_20_out/part=1"
  local selectedCols = "a,b,c"
  local excludedCols = ""
  local reservedCols = "a"
  local lifecycle = 1
  local outputCol = "score double"
  local all_cols, all_cols_types, selected_cols, reserved_cols, create_sql,
        add_partition_sql = parseTable('predict', input_table, output_table,
                                       selectedCols, excludedCols, reservedCols,
                                       lifecycle, outputCol)
  print(create_sql)
  print(add_partition_sql)
  print(string.format('all_cols = %s', all_cols))
  print(string.format('selected_cols = %s', selected_cols))
  print(string.format('reserved_cols = %s', reserved_cols))
end
--test_create_table()