Skip to content

Commit dfbe32a

Browse files
authored
Merge pull request #1318 from gzrp/dev-postgresql
2 parents ef5715e + 21e0373 commit dfbe32a

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from abc import ABC, abstractmethod
21+
from singa import model
22+
23+
24+
class BaseTuner(ABC):
25+
"""
26+
BaseTuner: the base class of all tuner,all PEFT methods must inherit this class and implement the inject method.
27+
"""
28+
def __init__(self, config):
29+
r"""
30+
Args:
31+
config: object of the PeftConfig class or its subclasses
32+
"""
33+
self.config = config
34+
35+
@abstractmethod
36+
def inject(self, base_model: model.Model) -> model.Model:
37+
r"""
38+
all PEFT methods must implement the inject method, inject the peft method into the base model.
39+
Args:
40+
base_model: the base model
41+
42+
Returns: the base model with inject method
43+
"""
44+
raise NotImplementedError
45+
46+
@abstractmethod
47+
def merge_weights(self, base_model: model.Model, mode: bool = True) -> model.Model:
48+
r"""
49+
all PEFT methods must implement the merge_weights method. After model training, weights need to be combined to speed up inference
50+
Args:
51+
base_model: the base model with inject method
52+
mode: merge parameters or not, default True
53+
54+
Returns: the model with inject method after combining weights
55+
"""
56+
raise NotImplementedError
57+
58+
@staticmethod
59+
def freeze_base_parameters(base_model: model.Model):
60+
r"""
61+
freeze the weights of the base model
62+
Args:
63+
base_model: the base model
64+
"""
65+
params = base_model.get_params()
66+
for k, v in params.items():
67+
v.requires_grad = False
68+
v.stores_grad = False
69+
70+

0 commit comments

Comments
 (0)