-- getHyperParams()
--
-- Excerpt from pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua, lines 104-287.


-- Returns true when an optional job parameter was supplied, i.e. it is
-- neither nil nor the empty string.
local function isNonEmptyParam(value)
  return value ~= nil and value ~= ''
end

-- Returns `params .. prefix .. value` when `value` is set, otherwise `params`
-- unchanged. `prefix` carries the flag name and its separator, e.g.
-- " --selected_cols=" or " --train_tables ".
local function appendIfSet(params, prefix, value)
  if isNonEmptyParam(value) then
    return params .. prefix .. value
  end
  return params
end

--- Build the command-line hyper-parameter string for an EasyRec PAI job.
--
-- Dispatches on `cmd`:
--   * "predict": requires `cluster` and `saved_model_dir`, validates OSS
--     paths (saved_model_dir, profiling_file) and tables (input/output),
--     and returns early with a predict-specific argument string.
--   * "vector_retrieve": requires `cluster`, validates the query/doc/output
--     tables, and returns early with the knn-specific argument string.
--   * anything else ("train", "evaluate", "export", "export_checkpoint",
--     "custom", ...): validates `config` (skipped for "custom"), emits
--     --config/--selected_cols/--cmd plus the cmd-specific flags, then
--     --edit_config_json, --model_dir and --num_gpus_per_worker.
-- In every case a non-empty `extra_params` is appended verbatim at the end.
--
-- Calls helpers defined elsewhere in this file: checkOss, checkTable,
-- checkConfig and check_run_mode.
--
-- @return the hyper-parameter string; `cluster` (rewritten to a single
--   worker spec when `gpuRequired` is set on the generic path); `tables`;
--   `output_table`.
-- @raise error when cluster/saved_model_dir are missing for predict, cluster
--   is missing for vector_retrieve, eval_method is not one of
--   none/separate/master, or hpo_param_path is set without
--   hpo_metric_save_path.
function getHyperParams(config, cmd, checkpoint_path, fine_tune_checkpoint,
                        eval_result_path, export_dir,  gpuRequired,
                        cpuRequired, memRequired, cluster, continue_train,
                        distribute_strategy, with_evaluator, eval_method,
                        edit_config_json, selected_cols,
                        model_dir, hpo_param_path, hpo_metric_save_path,
                        saved_model_dir, all_cols, all_col_types,
                        reserved_cols, output_cols, model_outputs,
                        input_table, output_table, tables, query_table,
                        doc_table, knn_distance, knn_num_neighbours,
                        knn_feature_dims, knn_index_type, knn_feature_delimiter,
                        knn_nlist, knn_nprobe, knn_compress_dim, train_tables,
                        eval_tables, boundary_table, batch_size, profiling_file,
                        mask_feature_name, extra_params)
  local hyperParameters = ""

  if cmd == "predict" then
    if not isNonEmptyParam(cluster) then
      error('cluster must be set')
    end
    if not isNonEmptyParam(saved_model_dir) then
      error('saved_model_dir must be set')
    end
    -- Previously this call sat after error() and could never run.
    checkOss(saved_model_dir)

    hyperParameters = " --cmd=" .. cmd
    hyperParameters = hyperParameters .. " --saved_model_dir=" .. saved_model_dir
    hyperParameters = hyperParameters .. " --all_cols=" .. all_cols ..
                     " --all_col_types=" .. all_col_types
    hyperParameters = appendIfSet(hyperParameters, " --selected_cols=", selected_cols)
    hyperParameters = appendIfSet(hyperParameters, " --reserved_cols=", reserved_cols)
    hyperParameters = hyperParameters .. " --batch_size=" .. batch_size
    if isNonEmptyParam(profiling_file) then
      checkOss(profiling_file)
      hyperParameters = hyperParameters .. " --profiling_file=" .. profiling_file
    end
    -- model_outputs, when given, takes precedence over output_cols.
    if isNonEmptyParam(model_outputs) then
      hyperParameters = hyperParameters .. " --output_cols='" .. model_outputs .. "'"
    else
      hyperParameters = hyperParameters .. " --output_cols='" .. output_cols .. "'"
    end
    checkTable(input_table)
    checkTable(output_table)

    hyperParameters = appendIfSet(hyperParameters, " ", extra_params)
    return hyperParameters, cluster, tables, output_table
  end

  if cmd == "vector_retrieve" then
    if not isNonEmptyParam(cluster) then
      error('cluster must be set')
    end
    checkTable(query_table)
    checkTable(doc_table)
    checkTable(output_table)

    hyperParameters = " --cmd=" .. cmd
    hyperParameters = hyperParameters .. " --batch_size=" .. batch_size
    hyperParameters = hyperParameters .. " --knn_distance=" .. knn_distance
    hyperParameters = appendIfSet(hyperParameters, " --knn_num_neighbours=", knn_num_neighbours)
    hyperParameters = appendIfSet(hyperParameters, " --knn_feature_dims=", knn_feature_dims)
    hyperParameters = hyperParameters .. " --knn_index_type=" .. knn_index_type
    hyperParameters = hyperParameters .. " --knn_feature_delimiter=" .. knn_feature_delimiter
    hyperParameters = appendIfSet(hyperParameters, " --knn_nlist=", knn_nlist)
    hyperParameters = appendIfSet(hyperParameters, " --knn_nprobe=", knn_nprobe)
    hyperParameters = appendIfSet(hyperParameters, " --knn_compress_dim=", knn_compress_dim)

    hyperParameters = appendIfSet(hyperParameters, " ", extra_params)
    return hyperParameters, cluster, tables, output_table
  end

  if cmd ~= "custom" then
    checkConfig(config)
  end

  hyperParameters = "--config='" .. config .. "'"
  hyperParameters = appendIfSet(hyperParameters, " --selected_cols=", selected_cols)
  hyperParameters = string.format('%s --cmd=%s', hyperParameters, cmd)

  if cmd == 'evaluate' then
    hyperParameters = hyperParameters .. " --checkpoint_path=" .. checkpoint_path
    hyperParameters = hyperParameters .. " --all_cols=" .. all_cols ..
                     " --all_col_types=" .. all_col_types
    hyperParameters = hyperParameters .. " --eval_result_path=" .. eval_result_path
    hyperParameters = hyperParameters .. " --mask_feature_name=" .. mask_feature_name
    hyperParameters = hyperParameters .. " --distribute_strategy=" .. distribute_strategy
    hyperParameters = appendIfSet(hyperParameters, " --eval_tables ", eval_tables)
  elseif cmd == 'export' or cmd == 'export_checkpoint' then
    hyperParameters = hyperParameters .. " --checkpoint_path=" .. checkpoint_path
    hyperParameters = hyperParameters .. " --export_dir=" .. export_dir
  elseif cmd == 'train' then
    hyperParameters = hyperParameters .. " --all_cols=" .. all_cols ..
                     " --all_col_types=" .. all_col_types
    hyperParameters = hyperParameters .. " --continue_train=" .. continue_train
    hyperParameters = hyperParameters .. " --distribute_strategy=" .. distribute_strategy
    -- Any set value other than numeric 0 (e.g. "1", "true") enables the
    -- evaluator; nil/"" leave it off.
    if isNonEmptyParam(with_evaluator) and tonumber(with_evaluator) ~= 0 then
      hyperParameters = hyperParameters .. " --with_evaluator"
    end
    hyperParameters = appendIfSet(hyperParameters, " --fine_tune_checkpoint=",
                                  fine_tune_checkpoint)
    -- An unset eval_method omits the flag; a set one must be a known value.
    if isNonEmptyParam(eval_method) then
      if eval_method ~= 'none' and eval_method ~= 'separate' and eval_method ~= 'master' then
        error('invalid eval_method ' .. eval_method)
      end
      hyperParameters = hyperParameters .. " --eval_method=" .. eval_method
    end

    hyperParameters = appendIfSet(hyperParameters, " --train_tables ", train_tables)
    hyperParameters = appendIfSet(hyperParameters, " --eval_tables ", eval_tables)
    hyperParameters = appendIfSet(hyperParameters, " --boundary_table ", boundary_table)

    if isNonEmptyParam(hpo_param_path) then
      -- An empty metric path would emit a dangling "--hpo_metric_save_path=".
      if not isNonEmptyParam(hpo_metric_save_path) then
        error('hpo_metric_save_path must be set')
      end
      hyperParameters = hyperParameters .. " --hpo_param_path=" .. hpo_param_path
      hyperParameters = hyperParameters .. " --hpo_metric_save_path=" .. hpo_metric_save_path
    end
  end

  if isNonEmptyParam(edit_config_json) then
    hyperParameters = hyperParameters ..
        string.format(" --edit_config_json='%s'", edit_config_json)
  end

  if isNonEmptyParam(model_dir) then
    checkOss(model_dir)
    hyperParameters = hyperParameters .. " --model_dir=" .. model_dir
  end

  check_run_mode(cluster, gpuRequired)
  -- NOTE(review): num_gpus_per_worker is intentionally left as a global, as in
  -- the original, in case other code in this file reads it; make it local
  -- once confirmed unused elsewhere.
  if isNonEmptyParam(gpuRequired) then
    -- GPU amounts are divided by 100 and rounded up to get a per-worker count.
    num_gpus_per_worker = math.max(math.ceil(tonumber(gpuRequired) / 100), 0)
    cluster = string.format('{"worker":{"count":1, "gpu":%s, "cpu":%s, "memory":%s}}',
                            gpuRequired, cpuRequired, memRequired)
  elseif isNonEmptyParam(cluster) then
    local gpus_str = string.match(cluster, '"gpu"%s*:%s*(%d+)')
    if gpus_str ~= nil then
      num_gpus_per_worker = math.max(math.ceil(tonumber(gpus_str) / 100), 0)
    else
      num_gpus_per_worker = 1
    end
  else
    num_gpus_per_worker = 1
  end
  hyperParameters = string.format("%s --num_gpus_per_worker=%s ", hyperParameters,
                                  num_gpus_per_worker)

  hyperParameters = appendIfSet(hyperParameters, " ", extra_params)

  return hyperParameters, cluster, tables, output_table
end