Commit b6cbb30

Introducing CamemBertForZeroShotClassification annotator (#14354)

* [SPARKNLP-856] Introducing CamemBertForZeroShotClassification
* [SPARKNLP-856] Adding notebook examples for CamemBertForZeroShotClassification
* [SPARKNLP-856] Adding CamemBertForZeroShotClassification to ResourceDownloader

1 parent 6b7eb4b commit b6cbb30

File tree

13 files changed: +6556 −49 lines

examples/python/transformers/HuggingFace_in_Spark_NLP_CamemBertForZeroShotClassification.ipynb

Lines changed: 2979 additions & 0 deletions
Large diffs are not rendered by default.

examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_CamemBertForZeroShotClassification.ipynb

Lines changed: 2625 additions & 0 deletions
Large diffs are not rendered by default.
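The two notebooks above cover the import path from Hugging Face into Spark NLP; since their cells are not rendered here, the sketch below only illustrates the usual shape of that export flow. The model id and output folder are placeholders I chose for illustration, not values taken from the notebooks.

    # Hedged sketch of a typical HF -> ONNX export step for such notebooks.
    # "BaptisteDoyen/camembert-base-xnli" is an assumed example NLI checkpoint,
    # not necessarily the one the notebook uses.
    from optimum.onnxruntime import ORTModelForSequenceClassification
    from transformers import AutoTokenizer

    model_id = "BaptisteDoyen/camembert-base-xnli"  # placeholder checkpoint

    # Export the PyTorch weights to an ONNX graph on disk
    onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, export=True)
    onnx_model.save_pretrained("camembert_zsc_onnx")

    # Save the tokenizer next to the graph; Spark NLP's loadSavedModel expects
    # the tokenizer/sentencepiece assets to travel with the exported model.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.save_pretrained("camembert_zsc_onnx")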

python/sparknlp/annotator/classifier_dl/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -51,4 +51,5 @@
 from sparknlp.annotator.classifier_dl.deberta_for_zero_shot_classification import *
 from sparknlp.annotator.classifier_dl.mpnet_for_sequence_classification import *
 from sparknlp.annotator.classifier_dl.mpnet_for_question_answering import *
-from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import *
+from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import *
+from sparknlp.annotator.classifier_dl.camembert_for_zero_shot_classification import *
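Since `sparknlp.annotator` re-exports the `classifier_dl` package, this one-line change is what surfaces the new class at the usual import location; a trivial illustration:

    # Available from the package root once the re-export above is in place
    from sparknlp.annotator import CamemBertForZeroShotClassification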
python/sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes for CamemBertForZeroShotClassification."""

from sparknlp.common import *


class CamemBertForZeroShotClassification(AnnotatorModel,
                                         HasCaseSensitiveProperties,
                                         HasBatchedAnnotate,
                                         HasClassifierActivationProperties,
                                         HasCandidateLabelsProperties,
                                         HasEngine,
                                         HasMaxSentenceLengthLimit):
    """CamemBertForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural language
    inference) tasks. Equivalent of `CamemBertForSequenceClassification` models, but these models don't require a
    hardcoded number of potential classes; the classes can be chosen at runtime. This is usually slower, but much
    more flexible.

    Any combination of sequences and labels can be passed, and each combination will be posed as a premise/hypothesis
    pair and fed to the pretrained model.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> sequenceClassifier = CamemBertForZeroShotClassification.pretrained() \\
    ...     .setInputCols(["token", "document"]) \\
    ...     .setOutputCol("label")

    The default model is ``"camembert_zero_shot_classifier_xnli_onnx"``, if no name is
    provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Text+Classification>`__.

    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT, TOKEN``    ``CATEGORY``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Larger values allow faster processing but require more
        memory, by default 8
    caseSensitive
        Whether to ignore case in tokens for embeddings matching, by default
        True
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    maxSentenceLength
        Max sentence length to process, by default 128
    coalesceSentences
        Instead of 1 class per sentence (if inputCols is `sentence`) output 1
        class per document by averaging probabilities in all sentences, by
        default False
    activation
        Whether to calculate logits via Softmax or Sigmoid, by default
        `"softmax"`.

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("document")
    >>> tokenizer = Tokenizer() \\
    ...     .setInputCols(["document"]) \\
    ...     .setOutputCol("token")
    >>> sequenceClassifier = CamemBertForZeroShotClassification.pretrained() \\
    ...     .setInputCols(["token", "document"]) \\
    ...     .setOutputCol("multi_class") \\
    ...     .setCaseSensitive(True) \\
    ...     .setCandidateLabels(["sport", "politique", "science"])
    >>> pipeline = Pipeline().setStages([
    ...     documentAssembler,
    ...     tokenizer,
    ...     sequenceClassifier
    ... ])
    >>> data = spark.createDataFrame([["L'équipe de France joue aujourd'hui au Parc des Princes"]]).toDF("text")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("multi_class.result").show(truncate=False)
    +-------+
    |result |
    +-------+
    |[sport]|
    +-------+
    """
    name = "CamemBertForZeroShotClassification"

    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.TOKEN]

    outputAnnotatorType = AnnotatorType.CATEGORY

    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    coalesceSentences = Param(Params._dummy(), "coalesceSentences",
                              "Instead of 1 class per sentence (if inputCols is 'sentence') output 1 class per document by averaging probabilities in all sentences.",
                              TypeConverters.toBoolean)

    def getClasses(self):
        """
        Returns labels used to train this model
        """
        return self._call_java("getClasses")

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
        """
        return self._set(configProtoBytes=b)

    def setCoalesceSentences(self, value):
        """Instead of 1 class per sentence (if inputCols is 'sentence') output 1
        class per document by averaging probabilities in all sentences, by default False.

        Due to the max sequence length limit in almost all transformer models such as BERT
        (512 tokens), this parameter helps with feeding all the sentences into the model and
        averaging all the probabilities for the entire document instead of probabilities
        per sentence.

        Parameters
        ----------
        value : bool
            If the output of all sentences will be averaged to one output
        """
        return self._set(coalesceSentences=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.CamemBertForZeroShotClassification",
                 java_model=None):
        super(CamemBertForZeroShotClassification, self).__init__(
            classname=classname,
            java_model=java_model
        )
        self._setDefault(
            batchSize=8,
            maxSentenceLength=128,
            caseSensitive=True,
            coalesceSentences=False,
            activation="softmax"
        )

    @staticmethod
    def loadSavedModel(folder, spark_session):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession

        Returns
        -------
        CamemBertForZeroShotClassification
            The restored model
        """
        from sparknlp.internal import _CamemBertForZeroShotClassificationLoader
        jModel = _CamemBertForZeroShotClassificationLoader(folder, spark_session._jsparkSession)._java_obj
        return CamemBertForZeroShotClassification(java_model=jModel)

    @staticmethod
    def pretrained(name="camembert_zero_shot_classifier_xnli_onnx", lang="fr", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default
            "camembert_zero_shot_classifier_xnli_onnx"
        lang : str, optional
            Language of the pretrained model, by default "fr"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        CamemBertForZeroShotClassification
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(CamemBertForZeroShotClassification, name, lang, remote_loc)
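As the docstring above notes, zero-shot NLI classification scores each (sequence, label) combination as a premise/hypothesis pair. The snippet below clarifies that mechanism using the Hugging Face pipeline API, independently of Spark NLP; the model id is a placeholder and the hypothesis template is illustrative, not the one this annotator uses internally.

    from transformers import pipeline

    # Each candidate label is slotted into a hypothesis such as
    # "Ce texte parle de {label}." and scored against the premise by the NLI
    # model; the label whose entailment score is highest wins.
    classifier = pipeline("zero-shot-classification",
                          model="BaptisteDoyen/camembert-base-xnli")  # placeholder
    out = classifier("L'équipe de France joue aujourd'hui au Parc des Princes",
                     candidate_labels=["sport", "politique", "science"],
                     hypothesis_template="Ce texte parle de {}.")
    print(out["labels"][0])  # expected: "sport"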

python/sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py

Lines changed: 2 additions & 15 deletions
@@ -21,7 +21,8 @@ class DeBertaForZeroShotClassification(AnnotatorModel,
                                        HasBatchedAnnotate,
                                        HasClassifierActivationProperties,
                                        HasCandidateLabelsProperties,
-                                       HasEngine):
+                                       HasEngine,
+                                       HasMaxSentenceLengthLimit):
     """DeBertaForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural language
     inference) tasks. Equivalent of `DeBertaForSequenceClassification` models, but these models don't require a hardcoded
     number of potential classes, they can be chosen at runtime. It usually means it's slower but it is much more
@@ -101,11 +102,6 @@ class per document by averaging probabilities in all sentences, by
 
     outputAnnotatorType = AnnotatorType.CATEGORY
 
-    maxSentenceLength = Param(Params._dummy(),
-                              "maxSentenceLength",
-                              "Max sentence length to process",
-                              typeConverter=TypeConverters.toInt)
-
     configProtoBytes = Param(Params._dummy(),
                              "configProtoBytes",
                              "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
@@ -130,15 +126,6 @@ def setConfigProtoBytes(self, b):
         """
         return self._set(configProtoBytes=b)
 
-    def setMaxSentenceLength(self, value):
-        """Sets max sentence length to process, by default 128.
-        Parameters
-        ----------
-        value : int
-            Max sentence length to process
-        """
-        return self._set(maxSentenceLength=value)
-
     def setCoalesceSentences(self, value):
         """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging
         probabilities in all sentences. Due to max sequence length limit in almost all transformer models such as DeBerta

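Nothing changes for callers of DeBertaForZeroShotClassification: the `maxSentenceLength` Param and `setMaxSentenceLength` removed here are evidently inherited from the `HasMaxSentenceLengthLimit` mixin now added to the class signature (the deletion would otherwise break the public API). A minimal sketch under that assumption, with placeholder column names:

    from sparknlp.annotator import DeBertaForZeroShotClassification

    # setMaxSentenceLength is now supplied by HasMaxSentenceLengthLimit rather
    # than redeclared on the class; the call site looks exactly as before.
    zero_shot = DeBertaForZeroShotClassification.pretrained() \
        .setInputCols(["token", "document"]) \
        .setOutputCol("label") \
        .setMaxSentenceLength(128)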
python/sparknlp/internal/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -798,6 +798,13 @@ def __init__(self, path, jspark):
             jspark,
         )
 
+class _CamemBertForZeroShotClassificationLoader(ExtendedJavaWrapper):
+    def __init__(self, path, jspark):
+        super(_CamemBertForZeroShotClassificationLoader, self).__init__(
+            "com.johnsnowlabs.nlp.annotators.classifier.dl.CamemBertForZeroShotClassification.loadSavedModel",
+            path,
+            jspark,
+        )
 
 class _RobertaQAToZeroShotNerLoader(ExtendedJavaWrapper):
     def __init__(self, path):
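This loader is internal; user code reaches it through the annotator's `loadSavedModel` wrapper shown in the file above. A hedged sketch tying it to the placeholder export folder from the earlier notebook sketch:

    import sparknlp
    from sparknlp.annotator import CamemBertForZeroShotClassification

    spark = sparknlp.start()

    # "camembert_zsc_onnx" is the placeholder folder from the export sketch;
    # loadSavedModel delegates to the Java-side loader registered in this diff.
    zero_shot = CamemBertForZeroShotClassification.loadSavedModel("camembert_zsc_onnx", spark) \
        .setInputCols(["token", "document"]) \
        .setOutputCol("class") \
        .setCandidateLabels(["sport", "politique", "science"])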
python/test/annotator/classifier_dl/camembert_for_zero_shot_classification_test.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest

import pytest

from sparknlp.annotator import *
from sparknlp.base import *
from test.annotator.common.has_max_sentence_length_test import HasMaxSentenceLengthTests
from test.util import SparkContextForTest


@pytest.mark.slow
class CamemBertForZeroShotClassificationTestSpec(unittest.TestCase, HasMaxSentenceLengthTests):
    def setUp(self):
        self.text = "L'équipe de France joue aujourd'hui au Parc des Princes"
        self.data = SparkContextForTest.spark \
            .createDataFrame([[self.text]]).toDF("text")

        self.tested_annotator = CamemBertForZeroShotClassification \
            .pretrained() \
            .setInputCols(["document", "token"]) \
            .setOutputCol("class")

    def test_run(self):
        document_assembler = DocumentAssembler() \
            .setInputCol("text") \
            .setOutputCol("document")

        tokenizer = Tokenizer().setInputCols("document").setOutputCol("token")

        doc_classifier = self.tested_annotator

        pipeline = Pipeline(stages=[
            document_assembler,
            tokenizer,
            doc_classifier
        ])

        model = pipeline.fit(self.data)
        model.transform(self.data).show()

        light_pipeline = LightPipeline(model)
        annotations_result = light_pipeline.fullAnnotate(self.text)
        print(annotations_result)
