def setup_rag_args()

in parlai/agents/rag/args.py [0:0]


def setup_rag_args(parser: ParlaiParser) -> ParlaiParser:
    group = parser.add_argument_group('RAG Model Args')
    # Standard RAG Agent Arguments
    group.add_argument(
        '--generation-model',
        type=str,
        default='bart',
        help='which generation model to use',
        choices=['transformer/generator', 'bart', 't5'],
    )
    group.add_argument(
        '--query-model',
        type=str,
        default='bert',
        help='Which query model to use for DPR.',
        choices=QUERY_MODEL_TYPES,
    )
    group.add_argument(
        '--rag-model-type',
        type=str,
        default='token',
        help='which rag model decoding to use.',
        choices=['token', 'sequence', 'turn'],
    )
    group.add_argument(
        '--thorough',
        type='bool',
        default=False,
        help='whether to use thorough decoding for rag sequence. ',
    )
    modified_group = parser.add_argument_group('Modified RAG Args')
    modified_group.add_argument(
        '--n-extra-positions',
        type=int,
        default=0,
        help='Specify > 0 to include extra positions in the encoder, in which '
        'retrieved knowledge will go. In this setup, knowledge is _appended_ '
        'instead of prepended.',
    )
    modified_group.add_argument(
        '--gold-knowledge-passage-key',
        type=str,
        default='checked_sentence',
        help='key in the observation dict that indicates the gold knowledge passage. '
        'Specify, along with --debug, to compute passage retrieval metrics at train/test time.',
    )
    modified_group.add_argument(
        '--gold-knowledge-title-key',
        type=str,
        default='title',
        help='key in the observation dict that indicates the gold knowledge passage title. '
        'Specify, along with --debug, to compute passage retrieval metrics at train/test time.',
    )
    retriever_group = parser.add_argument_group('RAG Retriever Args')
    retriever_group.add_argument(
        '--rag-retriever-query',
        type=str,
        default='full_history',
        choices=['one_turn', 'full_history'],
        help='What to use as the query for retrieval. `one_turn` retrieves only on the last turn '
        'of dialogue; `full_history` retrieves based on the full dialogue history.',
    )
    retriever_group.add_argument(
        '--rag-retriever-type',
        type=str,
        default=RetrieverType.DPR.value,
        choices=[r.value for r in RetrieverType],
        help='Which retriever to use',
    )
    retriever_group.add_argument(
        '--retriever-debug-index',
        type=str,
        default=None,
        choices=SMALL_INDEX_TYPES,
        help='Load specified small index, for debugging.',
    )
    retriever_group.add_argument(
        '--n-docs', type=int, default=5, help='How many documents to retrieve'
    )
    retriever_group.add_argument(
        '--min-doc-token-length',
        type=int,
        default=64,
        help='minimum amount of information to retain from document. '
        'Useful to define if encoder does not use a lot of BPE token context.',
    )
    retriever_group.add_argument(
        '--max-doc-token-length',
        type=int,
        default=256,
        help='maximum amount of information to retain from document. ',
    )
    retriever_group.add_argument(
        '--rag-query-truncate',
        type=int,
        default=512,
        help='Max token length of query for retrieval.',
    )
    retriever_group.add_argument(
        '--print-docs',
        type='bool',
        default=False,
        help='Whether to print docs; usually useful during interactive mode.',
    )
    dense_retriever_group = parser.add_argument_group(
        'RAG Dense Passage Retriever Args'
    )
    dense_retriever_group.add_argument(
        '--path-to-index',
        type=str,
        default=WIKIPEDIA_COMPRESSED_INDEX,
        help='path to FAISS Index.',
    )
    dense_retriever_group.add_argument(
        '--path-to-dense-embeddings',
        type=str,
        default=None,
        help='path to dense embeddings directory used to build index. '
        'Default None will assume embeddings and index are in the same directory.',
    )
    dense_retriever_group.add_argument(
        '--dpr-model-file', type=str, default=DPR_ZOO_MODEL, help='path to DPR Model.'
    )
    dense_retriever_group.add_argument(
        '--path-to-dpr-passages',
        type=str,
        default=WIKIPEDIA_ZOO_PASSAGES,
        help='Path to DPR passages, used to build index.',
    )
    dense_retriever_group.add_argument(
        '--retriever-embedding-size',
        type=int,
        default=768,
        help='Embedding size of dense retriever',
    )
    tfidf_retriever_group = parser.add_argument_group('RAG TFIDF Retriever Args')
    tfidf_retriever_group.add_argument(
        '--tfidf-max-doc-paragraphs',
        type=int,
        default=-1,
        help='If > 0, limit documents to this many paragraphs',
    )
    tfidf_retriever_group.add_argument(
        '--tfidf-model-path',
        type=str,
        default=TFIDF_ZOO_MODEL,
        help='Optionally override TFIDF model.',
    )
    dpr_poly_retriever_group = parser.add_argument_group('RAG DPR-POLY Retriever Args')
    dpr_poly_retriever_group.add_argument(
        '--dpr-num-docs',
        type=int,
        default=25,
        help='In two stage retrieval, how many DPR documents to retrieve',
    )
    dpr_poly_retriever_group.add_argument(
        '--poly-score-initial-lambda',
        type=float,
        default=0.5,
        help='In two stage retrieval, how much weight to give to the poly scores. '
        'Note: Learned parameter. Specify initial value here',
    )
    dpr_poly_retriever_group.add_argument(
        '--polyencoder-init-model',
        type=str,
        default='wikito',
        help='Which init model to initialize polyencoder with. Specify wikito or reddit to use '
        'models from the ParlAI zoo; otherwise, provide a path to a trained polyencoder',
    )
    poly_faiss_group = parser.add_argument_group('RAG PolyFAISS retriever args')
    poly_faiss_group.add_argument(
        '--poly-faiss-model-file',
        type=str,
        default=None,
        help='path to poly-encoder for use in poly-faiss retrieval.',
    )
    regret_group = parser.add_argument_group("RAG ReGReT args")
    regret_group.add_argument(
        '--regret',
        type='bool',
        default=False,
        help='Retrieve, Generate, Retrieve, Tune. '
        'Retrieve, generate, then retrieve again, and finally tune (refine).',
    )
    regret_group.add_argument(
        '--regret-intermediate-maxlen',
        type=int,
        default=32,
        help='Maximum length in intermediate regret generation',
    )
    regret_group.add_argument(
        '--regret-model-file',
        type=str,
        default=None,
        help='Path to model for initial round of retrieval. ',
    )
    regret_group.add_argument(
        '--regret-dict-file',
        type=str,
        default=None,
        help='Path to dict file for model for initial round of retrieval. ',
    )
    regret_group.add_argument(
        '--regret-override-index',
        type='bool',
        default=False,
        help='Overrides the index used with the ReGReT model, if using separate models. '
        'I.e., the initial round of retrieval uses the same index as specified for the '
        'second round of retrieval',
    )
    indexer_group = parser.add_argument_group("RAG Indexer Args")
    indexer_group.add_argument(
        '--indexer-type',
        type=str,
        default='compressed',
        choices=['exact', 'compressed'],
        help='Granularity of RAG Indexer. Choose compressed to save on RAM costs, at the '
        'possible expense of accuracy.',
    )
    indexer_group.add_argument(
        '--indexer-buffer-size',
        type=int,
        default=65536,
        help='buffer size for adding vectors to the index',
    )
    indexer_group.add_argument(
        '--compressed-indexer-factory',
        type=str,
        default='IVF4096_HNSW128,PQ128',
        help='If specified, builds compressed indexer from a FAISS Index Factory. '
        'see https://github.com/facebookresearch/faiss/wiki/The-index-factory for details',
    )
    indexer_group.add_argument(
        '--compressed-indexer-gpu-train',
        type='bool',
        default=False,
        hidden=True,
        help='Set False to not train compressed indexer on the gpu.',
    )
    indexer_group.add_argument(
        '--compressed-indexer-nprobe',
        type=int,
        default=64,
        help='How many centroids to search in compressed indexer. See '
        'https://github.com/facebookresearch/faiss/wiki/Faiss-indexes#cell-probe-methods-indexivf-indexes '
        'for details',
    )
    # See https://github.com/facebookresearch/faiss/wiki/Faiss-indexes#indexhnsw-variants for details
    indexer_group.add_argument(
        '--hnsw-indexer-store-n',
        type=int,
        default=128,
        hidden=True,
        help='Granularity of  DenseHNSWIndexer. Higher == more accurate, more RAM',
    )
    indexer_group.add_argument(
        '--hnsw-ef-search',
        type=int,
        default=128,
        hidden=True,
        help='Depth of exploration of search for HNSW.',
    )
    indexer_group.add_argument(
        '--hnsw-ef-construction',
        type=int,
        default=200,
        hidden=True,
        help='Depth of exploration at add time for HNSW',
    )
    return parser