from transformers import (
    AutoTokenizer,
    Gemma4Config,
    Gemma4ForConditionalGeneration,
    Gemma4TextConfig,
    Gemma4ForCausalLM,
    Gemma4VisionConfig,
    Gemma4AudioConfig,
)


def generate_vlm_model(output_dir="./tiny-random-gemma4", seed=None):
    """Build, save and return a tiny randomly-initialised Gemma 4 VLM.

    The configuration of ``google/gemma-4-E2B-it`` is fetched from the Hub and
    its audio, text and vision sub-configs are shrunk to a few hidden units and
    layers. A fresh ``Gemma4ForConditionalGeneration`` with random weights is
    built from that config. It is written to *output_dir* together with the
    reference model's tokenizer.

    Args:
        output_dir: Directory that the model weights, config and tokenizer
            files are saved to.
        seed: Optional seed passed to ``transformers.set_seed`` before the model
            is instantiated, so the random weights are reproducible. ``None``
            (the default) leaves global RNG state untouched, as before.

    Returns:
        The freshly initialised model, switched to eval mode.
    """
    from pathlib import Path

    source_repo = "google/gemma-4-E2B-it"

    # Only the configuration is needed. Loading the config on its own avoids
    # downloading and materialising the full pretrained checkpoint, which the
    # previous `Gemma4ForConditionalGeneration.from_pretrained` call did just
    # to read `.config`.
    config = Gemma4Config.from_pretrained(source_repo)

    # --- Audio tower: shrink to a single tiny layer. ---
    config.audio_config.hidden_size = 8
    config.audio_config.num_attention_heads = 2
    config.audio_config.num_hidden_layers = 1
    config.audio_config.output_proj_dims = 8

    # --- Text decoder. ---
    config.text_config.global_head_dim = 4
    config.text_config.head_dim = 4
    config.text_config.hidden_size = 8
    config.text_config.hidden_size_per_layer_input = 1
    config.text_config.intermediate_size = 32
    config.text_config.num_attention_heads = 2
    config.text_config.num_hidden_layers = 3
    # One entry per decoder layer; must stay in sync with num_hidden_layers.
    config.text_config.layer_types = ["sliding_attention", "full_attention", "full_attention"]
    # NOTE(review): presumably makes one layer reuse another layer's KV so the
    # sharing code path is exercised - confirm against Gemma4 modelling code.
    config.text_config.num_kv_shared_layers = 1
    config.text_config.dtype = "float32"

    # --- Vision tower. ---
    config.vision_config.default_output_length = 70
    config.vision_config.head_dim = 4
    config.vision_config.hidden_size = 8
    config.vision_config.intermediate_size = 32
    config.vision_config.num_attention_heads = 2
    config.vision_config.num_hidden_layers = 1
    config.vision_config.num_key_value_heads = 2
    config.vision_config.patch_size = 2

    if seed is not None:
        from transformers import set_seed

        set_seed(seed)

    model = Gemma4ForConditionalGeneration(config)
    model.eval()

    model.save_pretrained(output_dir)

    # Reuse the reference model's tokenizer unchanged.
    tokenizer = AutoTokenizer.from_pretrained(source_repo)
    tokenizer.save_pretrained(output_dir)

    # Report the on-disk weight size when a single (unsharded) file was written.
    weights_path = Path(output_dir) / "model.safetensors"
    if weights_path.exists():
        size_mb = weights_path.stat().st_size / (1024 * 1024)
        print(f"  model.safetensors size: {size_mb:.1f} MB")

    print(f"  VLM model saved to {output_dir}")
    return model


if __name__ == "__main__":
    # Script entry point: build and save the tiny checkpoint into the
    # function's default output directory. The returned model is not needed.
    generate_vlm_model()
# Hugging Face Hub model-page metadata (copied alongside the script; not code):
#   Downloads last month: 483
#   Safetensors model size: 3.1M params
#   Tensor types: F32 · BF16
#   Inference Providers (NEW): this model isn't deployed by any Inference
#   Provider. 🙋 Ask for provider support