SHOW:
|
|
- or go back to the newest paste.
1 | #!/bin/bash | |
2 | # Copyright (c) Microsoft Corporation. | |
3 | # SPDX-License-Identifier: Apache-2.0 | |
4 | ||
5 | # DeepSpeed Team | |
6 | ACTOR_MODEL_PATH=$1 | |
7 | CRITIC_MODEL_PATH=$2 | |
8 | ACTOR_ZERO_STAGE=$3 | |
9 | CRITIC_ZERO_STAGE=$4 | |
10 | OUTPUT=$5 | |
11 | - | if [ "$OUTPUT" == "" ]; then |
11 | + | |
12 | - | OUTPUT=/nobackup/jirayu/models/opt-1.3b-ppo |
12 | + | |
13 | fi | |
14 | if [ "$CRITIC_ZERO_STAGE" == "" ]; then | |
15 | CRITIC_ZERO_STAGE=2 | |
16 | fi | |
17 | ||
18 | # if actor and critic model names are not provided, then use the publicly available AdamG012/chat-opt-1.3b-sft-deepspeed and AdamG012/chat-opt-350m-reward-deepspeed | |
19 | ||
20 | mkdir -p $OUTPUT | |
21 | ||
22 | - | if [ "$ACTOR_MODEL_PATH" == "" ]; then |
22 | + | |
23 | - | ACTOR_MODEL_PATH=/nobackup/jirayu/models/opt-1.3b-sft |
23 | + | |
24 | Actor_Lr=1e-5 | |
25 | - | if [ "$CRITIC_MODEL_PATH" == "" ]; then |
25 | + | |
26 | - | CRITIC_MODEL_PATH=/nobackup/jirayu/models/opt-350m-rm |
26 | + | |
27 | deepspeed --master_port 12346 main.py \ | |
28 | --data_path stanfordnlp/SHP \ | |
29 | --data_split 0,0,10 \ | |
30 | --actor_model_name_or_path $ACTOR_MODEL_PATH \ | |
31 | --critic_model_name_or_path $CRITIC_MODEL_PATH \ | |
32 | --num_padding_at_beginning 1 \ | |
33 | --per_device_generation_batch_size 8 \ | |
34 | --per_device_training_batch_size 8 \ | |
35 | --generation_batches 1 \ | |
36 | --ppo_epochs 1 \ | |
37 | --max_answer_seq_len 256 \ | |
38 | --max_prompt_seq_len 256 \ | |
39 | --actor_learning_rate ${Actor_Lr} \ | |
40 | --critic_learning_rate ${Critic_Lr} \ | |
41 | --num_train_epochs 1 \ | |
42 | --lr_scheduler_type cosine \ | |
43 | --gradient_accumulation_steps 1 \ | |
44 | --disable_actor_dropout \ | |
45 | --num_warmup_steps 100 \ | |
46 | --deepspeed --seed 1234 \ | |
47 | --enable_hybrid_engine \ | |
48 | --actor_zero_stage $ACTOR_ZERO_STAGE \ | |
49 | --critic_zero_stage $CRITIC_ZERO_STAGE \ | |
50 | --enable_ema \ | |
51 | --output_dir $OUTPUT \ | |
52 | --enable_tensorboard \ | |
53 | --tensorboard_path $OUTPUT \ | |
54 | | tee $OUTPUT/training.log |