function parseTable()

in pai_jobs/easy_rec_flow/easy_rec.lua [349:611]


--- Resolve the table list and column lists needed by a command and build the
-- SQL statements that prepare the output table.
--
-- @param cmd            one of train, evaluate, predict, export,
--                       export_checkpoint, custom, vector_retrieve
-- @param inputTable     input table; for predict it is split on ',' when
--                       registering tables. Defaults to the first entry of
--                       `tables` when blank.
-- @param outputTable    output table; required for predict and vector_retrieve
-- @param selectedCols   comma-separated columns to use (exclusive with
--                       excludedCols); blank means every input column
-- @param excludedCols   comma-separated columns to drop (exclusive with
--                       selectedCols)
-- @param reservedCols   comma-separated columns copied to the output table, or
--                       the literal 'ALL_COLUMNS'
-- @param lifecycle      value interpolated after `lifecycle` in the generated
--                       create-table statement
-- @param outputCol      column definition(s) appended to the output table
-- @param tables         comma-separated list of tables
-- @param trainTables    comma-separated train tables (train only)
-- @param evalTables     comma-separated eval tables (train / evaluate)
-- @param boundaryTable  extra table registered for train
-- @param queryTable     query table (vector_retrieve only, required)
-- @param docTable       doc table (vector_retrieve only, required)
-- @return all_cols, all_col_types, selected_cols, reserved_cols (each a
--         comma-joined string), create_sql, add_partition_sql, tables
--         (comma-joined, de-duplicated, in first-seen order). For export,
--         custom, export_checkpoint, and train without any table, `tables`
--         is returned unchanged.
function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols,
                     reservedCols, lifecycle, outputCol, tables,
                     trainTables, evalTables, boundaryTable, queryTable, docTable)
  local function isBlank(s)
    return s == nil or s == ''
  end

  local valid_cmds = {
    train = true, evaluate = true, predict = true, export = true,
    export_checkpoint = true, custom = true, vector_retrieve = true
  }
  if not valid_cmds[cmd] then
    -- tostring() keeps the message readable even when cmd is nil.
    error('invalid cmd: ' .. tostring(cmd) .. ', should be one of train, evaluate, predict, export, export_checkpoint, custom, vector_retrieve')
  end

  -- These commands do not read any table.
  if cmd == 'export' or cmd == 'custom' or cmd == 'export_checkpoint' then
    return "", "", "", "", "select 1;", "select 1;", tables
  end

  -- Training without any table specified: nothing to resolve.
  if cmd == 'train' and isBlank(tables) and isBlank(trainTables) then
    return "", "", "", "", "select 1;", "select 1;", tables
  end

  -- Ordered, de-duplicated registry of every table that is referenced.
  local seen_tables = {}
  local table_list = {}

  -- Register one table name, ignoring names that were already registered.
  local function addTable(name)
    if seen_tables[name] == nil then
      seen_tables[name] = true
      table.insert(table_list, name)
    end
  end

  -- Register every name of a comma-separated list; returns the split list.
  local function addTables(csv)
    local names = split(csv, ',')
    for _, name in ipairs(names) do
      addTable(name)
    end
    return names
  end

  if not isBlank(tables) then
    local names = addTables(tables)
    if isBlank(inputTable) then
      inputTable = names[1]
    end
  end

  if cmd == 'vector_retrieve' then
    if isBlank(queryTable) then
      error('queryTable is not set')
    end
    if isBlank(docTable) then
      error('docTable is not set')
    end
    inputTable = queryTable
    -- Go through addTable so a name already present in `tables` keeps its
    -- position instead of leaving a hole in the final list.
    addTable(queryTable)
    addTable(docTable)
  end

  if cmd == 'train' then
    if not isBlank(trainTables) then
      inputTable = addTables(trainTables)[1]
      -- evalTables is optional; only register it when it is provided.
      if not isBlank(evalTables) then
        addTables(evalTables)
      end
    end
    if not isBlank(boundaryTable) then
      addTable(boundaryTable)
    end
  end

  if cmd == 'evaluate' and not isBlank(evalTables) then
    inputTable = addTables(evalTables)[1]
  end

  if cmd == 'predict' then
    if not isBlank(inputTable) then
      addTables(inputTable)
    elseif not isBlank(tables) then
      inputTable = split(tables, ',')[1]
    else
      error('either inputTable or tables must be set')
    end
  end

  if isBlank(inputTable) then
    error('inputTable is not defined')
  end

  tables = join(table_list, ',')

  -- Build the statements that create the output table and, when the output
  -- spec carries a partition, add that partition. `col_desc` is the column
  -- definition list placed between the parentheses.
  local function buildOutputSql(col_desc)
    if isBlank(outputTable) then
      error("outputTable is not set")
    end
    local out_proj, out_name, out_partition = splitTableParam(outputTable)
    local out_table_name = out_proj .. "." .. out_name
    if out_partition ~= nil and string.len(out_partition) ~= 0 then
      local partition_names, partition_values = parseParitionSpec(out_partition)
      local create_partition_str = genCreatePartitionStr(partition_names)
      local create_sql = string.format(
        "create table if not exists %s (%s) partitioned by %s lifecycle %s;",
        out_table_name, col_desc, create_partition_str, lifecycle)
      local partition_str = genAddPartitionStr(partition_names, partition_values)
      local add_partition_sql = string.format(
        "alter table %s add if not exists partition %s;",
        out_table_name, partition_str)
      return create_sql, add_partition_sql
    end
    local create_sql = string.format("create table %s (%s) lifecycle %s;",
                                     out_table_name, col_desc, lifecycle)
    -- Without a partition the second statement only describes the table.
    local desc_sql = string.format("desc %s;", out_table_name)
    return create_sql, desc_sql
  end

  if cmd == 'vector_retrieve' then
    local create_sql, add_partition_sql =
      buildOutputSql("query BIGINT, doc BIGINT, distance DOUBLE")
    return "", "", "", "", create_sql, add_partition_sql, tables
  end

  -- Column metadata of the input table: a name -> type map and the name list.
  local proj0, table0 = splitTableParam(inputTable)
  local input_col_types, input_cols = getInputTableColTypes(proj0 .. "." .. table0)

  local has_selected = not isBlank(selectedCols)
  local has_excluded = not isBlank(excludedCols)
  if has_selected and has_excluded then
    error('selected_cols and excluded_cols should not both be set')
  end

  local excluded = {}
  if has_excluded then
    local excluded_list = split(excludedCols, ',')
    for _, col in ipairs(excluded_list) do
      excluded[col] = true
    end
  end

  local selected_cols = {}
  -- all_cols = selected columns followed by reserved columns not yet listed.
  local all_cols = {}
  local all_col_types = {}
  local all_cols_map = {}

  local candidates = input_cols
  if has_selected then
    candidates = split(selectedCols, ",")
  end

  for _, col in ipairs(candidates) do
    local col_type = input_col_types[col]
    if col_type == nil then
      error(string.format("column %s is not in input table", col))
    end
    if excluded[col] == nil and all_cols_map[col] == nil then
      table.insert(all_cols, col)
      table.insert(all_col_types, col_type)
      table.insert(selected_cols, col)
      all_cols_map[col] = true
    end
  end

  if cmd == 'train' or cmd == 'evaluate' then
    return join(all_cols, ","), join(all_col_types, ","),
           join(selected_cols, ","), '',
           "select 1;", "select 1;", tables
  end

  -- Only predict reaches this point. Reserved columns may or may not be among
  -- the selected columns.
  local reserved_cols = {}
  local reserved_col_types = {}
  if not isBlank(reservedCols) then
    local wanted = input_cols
    if reservedCols ~= 'ALL_COLUMNS' then
      wanted = split(reservedCols, ',')
    end
    for _, col in ipairs(wanted) do
      local col_type = input_col_types[col]
      if col_type == nil then
        error("invalid reserved_col: " .. col .. " available: " .. join(input_cols, ","))
      end
      table.insert(reserved_cols, col)
      table.insert(reserved_col_types, col_type)
      if all_cols_map[col] == nil then
        table.insert(all_cols, col)
        table.insert(all_col_types, col_type)
        all_cols_map[col] = true
      end
    end
  end
  -- NOTE(review): when reservedCols is blank no column is reserved. An index-0
  -- lookup on selected_cols is always nil in Lua (sequences are 1-based), so a
  -- default of "first selected column" never took effect; the output schema is
  -- deliberately left unchanged. Confirm before introducing a real default.

  local col_descs = {}
  for i, col in ipairs(reserved_cols) do
    table.insert(col_descs, col .. " " .. reserved_col_types[i])
  end
  table.insert(col_descs, outputCol)

  local create_sql, add_partition_sql = buildOutputSql(join(col_descs, ","))

  return join(all_cols, ","), join(all_col_types, ","),
         join(selected_cols, ","), join(reserved_cols, ","),
         create_sql, add_partition_sql, tables
end