# scripts/generate_config_jat.py
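"""Generate JAT model configs and processors and push them to the Hugging Face Hub.

For each model size (small, medium, large), this script builds a JatConfig and a
matching JatProcessor (GPT-2 tokenizer + CLIP image processor) and uploads both
to the corresponding jat-project repository.
"""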
from transformers import AutoTokenizer, CLIPImageProcessor

from jat.configuration_jat import JatConfig
from jat.processing_jat import JatProcessor
# Small model
tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"])
config = JatConfig(
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=512,
    hidden_size=768,
    num_layers=12,
    attention_types=[[["global", "local"], 6]],
    num_heads=12,
    max_discrete_value=148 + 64,  # 148 (discrete obs from BabyAI) + 64 (max size of BabyAI's text observation)
    tokenizer_class=tokenizer.__class__.__name__,
)
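# Note: attention_types follows the GPT-Neo convention (JatConfig builds on
# GPT-Neo's configuration): the ["global", "local"] pair is repeated 6 times,
# which matches num_layers=12.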
image_processor = CLIPImageProcessor(
size={"shortest_edge": config.image_size}, crop_size={"height": config.image_size, "width": config.image_size}
)
tokenizer.model_max_length = config.max_position_embeddings
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS for padding
processor = JatProcessor(tokenizer=tokenizer, image_processor=image_processor)
config.push_to_hub("jat-project/jat-small")
processor.push_to_hub("jat-project/jat-small")
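# Sanity check (sketch): reload the pushed artifacts to verify the upload.
# Assumes the pushes above succeeded and that you are authenticated to the Hub
# (e.g. via `huggingface-cli login`).
#
#   config = JatConfig.from_pretrained("jat-project/jat-small")
#   processor = JatProcessor.from_pretrained("jat-project/jat-small")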

# Medium model
tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"])
config = JatConfig(
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=1024,
    hidden_size=2048,
    num_layers=24,
    attention_types=[[["global", "local"], 12]],
    num_heads=16,
    max_discrete_value=148 + 64,  # 148 (discrete obs from BabyAI) + 64 (max size of BabyAI's text observation)
    tokenizer_class=tokenizer.__class__.__name__,
)
image_processor = CLIPImageProcessor(
size={"shortest_edge": config.image_size}, crop_size={"height": config.image_size, "width": config.image_size}
)
tokenizer.model_max_length = config.max_position_embeddings
tokenizer.pad_token = tokenizer.eos_token
processor = JatProcessor(tokenizer=tokenizer, image_processor=image_processor)
config.push_to_hub("jat-project/jat-medium")
processor.push_to_hub("jat-project/jat-medium")

# Large model
tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"])
config = JatConfig(
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=2048,
    hidden_size=2560,
    num_layers=32,
    attention_types=[[["global", "local"], 16]],
    num_heads=20,
    max_discrete_value=148 + 64,  # 148 (discrete obs from BabyAI) + 64 (max size of BabyAI's text observation)
    tokenizer_class=tokenizer.__class__.__name__,
)
image_processor = CLIPImageProcessor(
size={"shortest_edge": config.image_size}, crop_size={"height": config.image_size, "width": config.image_size}
)
tokenizer.model_max_length = config.max_position_embeddings
tokenizer.pad_token = tokenizer.eos_token
processor = JatProcessor(tokenizer=tokenizer, image_processor=image_processor)
config.push_to_hub("jat-project/jat-large")
processor.push_to_hub("jat-project/jat-large")
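# Downstream usage (sketch): a pushed config can seed a freshly initialized
# model. The class name `JatModel` (in jat.modeling_jat) is an assumption here,
# not something this script imports.
#
#   from jat.modeling_jat import JatModel
#   config = JatConfig.from_pretrained("jat-project/jat-large")
#   model = JatModel(config)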