@@ -106,17 +106,53 @@ pip install omegavit
106106## Quick Start
107107
108108``` python
109+ import sys
110+ from omegavit.main import create_advanced_vit, train_step
109111import torch
110- from omegavit import create_advanced_vit
112+ from loguru import logger
113+
114+ def main ():
115+ """ Main training function."""
116+ logger.info(" Starting training setup" )
117+
118+ # Setup
119+ device = torch.device(
120+ " cuda" if torch.cuda.is_available() else " cpu"
121+ )
122+ model = create_advanced_vit().to(device)
123+ optimizer = torch.optim.AdamW(
124+ model.parameters(), lr = 1e-4 , weight_decay = 0.05
125+ )
126+
127+ # Example input for testing
128+ batch_size = 8
129+ example_input = torch.randn(batch_size, 3 , 224 , 224 ).to(device)
130+ example_labels = torch.randint(0 , 1000 , (batch_size,)).to(device)
131+
132+ logger.info(" Running forward pass with example input" )
133+ output = model(example_input)
134+ logger.info(f " Output shape: { output.shape} " )
135+
136+ # Example training step
137+ loss = train_step(
138+ model, optimizer, (example_input, example_labels), device
139+ )
140+ logger.info(f " Example training step loss: { loss:.4f } " )
141+
142+
143+ if __name__ == " __main__" :
144+ # Configure logger
145+ logger.remove()
146+ logger.add(
147+ " advanced_vit.log" ,
148+ rotation = " 500 MB" ,
149+ level = " DEBUG" ,
150+ format = " {time: YYYY-MM-DD HH:mm:ss} | {level} | {message} " ,
151+ )
152+ logger.add(sys.stdout, level = " INFO" )
153+
154+ main()
111155
112- # Create model
113- model = create_advanced_vit(num_classes = 1000 )
114-
115- # Example forward pass
116- batch_size = 8
117- x = torch.randn(batch_size, 3 , 224 , 224 )
118- output = model(x)
119- print (f " Output shape: { output.shape} " ) # [8, 1000]
120156```
121157
122158## Model Configurations
0 commit comments