Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 02a6984

Browse files
authored
Update run_generation_gpu_woq.py (#1454)
1 parent d6e6e9f commit 02a6984

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,23 @@
103103
type=int,
104104
help="Calibration dataset max or padding max length for AutoRound.",
105105
)
106+
parser.add_argument(
107+
"--lr",
108+
type=float,
109+
default=0.0025,
110+
help="learning rate, if None, it will be set to 1.0/iters automatically",
111+
)
112+
parser.add_argument(
113+
"--minmax_lr",
114+
type=float,
115+
default=0.0025,
116+
help="minmax learning rate; if None, it will be set to the same value as lr",
117+
)
118+
parser.add_argument(
119+
"--use_quant_input",
120+
action="store_true",
121+
help="whether to use the output of quantized block to tune the next block",
122+
)
106123
# =======================================
107124
args = parser.parse_args()
108125
torch_dtype = convert_dtype_str2torch(args.compute_dtype)
@@ -162,6 +179,9 @@
162179
calib_iters=args.calib_iters,
163180
calib_len=args.calib_len,
164181
nsamples=args.nsamples,
182+
lr=args.lr,
183+
minmax_lr=args.minmax_lr,
184+
use_quant_input=args.use_quant_input,
165185
)
166186
elif args.woq_algo.lower() == "rtn":
167187
quantization_config = RtnConfig(

0 commit comments

Comments (0)