in pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua [371:633]
-- Commands accepted by parseTable.
local _PARSE_TABLE_CMDS = {
  train = true,
  evaluate = true,
  predict = true,
  export = true,
  export_checkpoint = true,
  custom = true,
  vector_retrieve = true,
}

-- Returns true when an optional string parameter is nil or ''.
local function _isUnset(value)
  return value == nil or value == ''
end

-- Builds the two SQL statements that prepare the output table.
--   outputTable: table spec understood by splitTableParam
--   colDesc:     column definition list placed inside "( ... )"
--   lifecycle:   value substituted into the "lifecycle %s" clause
-- Returns create_sql, add_partition_sql. When the spec carries a partition
-- the table is created with "if not exists" and the partition is added;
-- otherwise a plain "create table" plus a "desc" statement is produced.
-- Raises an error when outputTable is unset.
local function _genOutputTableSql(outputTable, colDesc, lifecycle)
  if _isUnset(outputTable) then
    error("outputTable is not set")
  end
  local proj, name, partition = splitTableParam(outputTable)
  local out_table_name = proj .. "." .. name
  if partition ~= nil and string.len(partition) ~= 0 then
    local partition_names, partition_values = parseParitionSpec(partition)
    local create_partition_str = genCreatePartitionStr(partition_names)
    local create_sql = string.format(
      "create table if not exists %s (%s) partitioned by %s lifecycle %s;",
      out_table_name, colDesc, create_partition_str, lifecycle)
    local partition_spec = genAddPartitionStr(partition_names, partition_values)
    local add_partition_sql = string.format(
      "alter table %s add if not exists partition %s;",
      out_table_name, partition_spec)
    return create_sql, add_partition_sql
  end
  return string.format("create table %s (%s) lifecycle %s;",
                       out_table_name, colDesc, lifecycle),
         string.format("desc %s;", out_table_name)
end

-- Resolves the table and column parameters of a command.
--
-- Parameters (all strings unless noted; nil and '' both mean "unset"):
--   cmd           one of train, evaluate, predict, export, export_checkpoint,
--                 custom, vector_retrieve
--   inputTable    input table; defaults to the first entry of `tables`
--   outputTable   output table spec (required for predict / vector_retrieve)
--   selectedCols  comma separated columns to use (default: all input columns)
--   excludedCols  comma separated columns to drop; mutually exclusive with
--                 selectedCols
--   reservedCols  comma separated columns copied to the output table, or the
--                 literal 'ALL_COLUMNS'
--   lifecycle     value substituted into the "lifecycle %s" clause
--   outputCol     column definition(s) appended verbatim to the output schema
--   tables        comma separated list of tables
--   trainTables, evalTables, boundaryTable   extra tables for train/evaluate
--   queryTable, docTable                     tables for vector_retrieve
--
-- Returns seven values:
--   all_cols, all_col_types, selected_cols, reserved_cols (comma joined),
--   create_sql, add_partition_sql, tables (comma joined, de-duplicated, in
--   order of first appearance).
-- For export / custom / export_checkpoint, and for train without any table,
-- placeholders are returned and `tables` is passed through unchanged.
--
-- Raises an error on an invalid cmd, a missing required table, an unknown
-- column, or when selectedCols and excludedCols are both set.
function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols,
                    reservedCols, lifecycle, outputCol, tables,
                    trainTables, evalTables, boundaryTable, queryTable, docTable)
  if not _PARSE_TABLE_CMDS[cmd] then
    error('invalid cmd: ' .. tostring(cmd) .. ', should be one of train, ' ..
          'evaluate, predict, export, export_checkpoint, custom, vector_retrieve')
  end
  if cmd == 'export' or cmd == 'custom' or cmd == 'export_checkpoint' then
    return "", "", "", "", "select 1;", "select 1;", tables
  end
  if cmd == 'train' and _isUnset(tables) and _isUnset(trainTables) then
    return "", "", "", "", "select 1;", "select 1;", tables
  end

  -- Ordered, de-duplicated list of every table referenced by the command.
  local table_list = {}
  local seen_tables = {}
  local function register(name)
    if not seen_tables[name] then
      seen_tables[name] = true
      table.insert(table_list, name)
    end
  end
  -- Registers every entry of a comma separated list and returns the entries.
  local function registerAll(csv)
    local names = split(csv, ',')
    for _, name in ipairs(names) do
      register(name)
    end
    return names
  end

  if not _isUnset(tables) then
    local names = registerAll(tables)
    if _isUnset(inputTable) then
      inputTable = names[1]
    end
  end

  if cmd == 'vector_retrieve' then
    if _isUnset(queryTable) then
      error('queryTable is not set')
    end
    if _isUnset(docTable) then
      error('docTable is not set')
    end
    inputTable = queryTable
    -- register() keeps an already listed table at its existing position
    -- instead of re-numbering it, so the joined list never has gaps.
    register(queryTable)
    register(docTable)
  elseif cmd == 'train' then
    if not _isUnset(trainTables) then
      inputTable = registerAll(trainTables)[1]
      -- evalTables is optional when training.
      if not _isUnset(evalTables) then
        registerAll(evalTables)
      end
    end
    if not _isUnset(boundaryTable) then
      register(boundaryTable)
    end
  elseif cmd == 'evaluate' then
    if not _isUnset(evalTables) then
      inputTable = registerAll(evalTables)[1]
    end
  elseif cmd == 'predict' then
    if not _isUnset(inputTable) then
      registerAll(inputTable)
    elseif _isUnset(tables) then
      error('either inputTable or tables must be set')
    end
  end

  if _isUnset(inputTable) then
    error('inputTable is not defined')
  end
  tables = join(table_list, ',')

  if cmd == 'vector_retrieve' then
    local create_sql, add_partition_sql = _genOutputTableSql(
      outputTable, "query BIGINT, doc BIGINT, distance DOUBLE", lifecycle)
    return "", "", "", "", create_sql, add_partition_sql, tables
  end

  local proj0, table0 = splitTableParam(inputTable)
  local input_col_types, input_cols =
    getInputTableColTypes(proj0 .. "." .. table0)
  if not _isUnset(excludedCols) and not _isUnset(selectedCols) then
    error('selected_cols and excluded_cols should not be set')
  end

  local excluded = {}
  if not _isUnset(excludedCols) then
    for _, col in ipairs(split(excludedCols, ',')) do
      excluded[col] = true
    end
  end

  local candidates
  if _isUnset(selectedCols) then
    candidates = input_cols
  else
    candidates = split(selectedCols, ",")
  end

  local selected_cols = {}
  local all_cols = {}
  local all_col_types = {}
  local all_cols_map = {}
  for _, col in ipairs(candidates) do
    local col_type = input_col_types[col]
    if col_type == nil then
      error(string.format("column %s is not in input table", col))
    end
    -- Skip excluded columns and columns that were already collected.
    if not excluded[col] and all_cols_map[col] == nil then
      table.insert(all_cols, col)
      table.insert(all_col_types, col_type)
      table.insert(selected_cols, col)
      all_cols_map[col] = 1
    end
  end

  if cmd == 'train' or cmd == 'evaluate' then
    return join(all_cols, ","), join(all_col_types, ","),
           join(selected_cols, ","), '',
           "select 1;", "select 1;", tables
  end

  -- Only predict reaches this point.
  local reserved_cols = {}
  local reserved_col_types = {}
  if not _isUnset(reservedCols) then
    local wanted
    if reservedCols == 'ALL_COLUMNS' then
      wanted = input_cols
    else
      wanted = split(reservedCols, ',')
    end
    for _, col in ipairs(wanted) do
      local col_type = input_col_types[col]
      if col_type == nil then
        error("invalid reserved_col: " .. col .. " available: " ..
              join(input_cols, ","))
      end
      table.insert(reserved_cols, col)
      table.insert(reserved_col_types, col_type)
      -- Reserved columns must also be read from the input table.
      if all_cols_map[col] == nil then
        table.insert(all_cols, col)
        table.insert(all_col_types, col_type)
        all_cols_map[col] = 1
      end
    end
  end
  -- NOTE(review): when reservedCols is unset no column is reserved. The
  -- previous code inserted selected_cols[0] here, which is always nil with
  -- 1-based indexing, so this was already the effective behaviour. It is kept
  -- so that the generated output schema does not change; confirm whether the
  -- first selected column was meant to be reserved by default.

  local col_descs = {}
  for i, col in ipairs(reserved_cols) do
    table.insert(col_descs, col .. " " .. reserved_col_types[i])
  end
  table.insert(col_descs, outputCol)

  local create_sql, add_partition_sql = _genOutputTableSql(
    outputTable, join(col_descs, ","), lifecycle)
  return join(all_cols, ","), join(all_col_types, ","),
         join(selected_cols, ","), join(reserved_cols, ","),
         create_sql, add_partition_sql, tables
end