1- from typing import Callable , Dict , Iterable , List , Optional , Set , Tuple , cast
1+ from typing import Callable , Dict , Iterable , List , Optional , Tuple , cast
22
33import srsly
44import torch
5- from partial_tagger .data import Alignments , LabelSet
6- from partial_tagger .training import compute_partially_supervised_loss
7- from partial_tagger .utils import create_tag
5+ from partial_tagger .training import compute_partially_supervised_loss , create_tag_bitmap
6+ from sequence_label import LabelSet , SequenceLabel
87from spacy import util
98from spacy .errors import Errors
109from spacy .language import Language
@@ -39,7 +38,9 @@ def __init__(
3938
4039 @property
4140 def label_set (self ) -> LabelSet :
42- return LabelSet (set (self .cfg ["labels" ]))
41+ return LabelSet (
42+ labels = set (self .cfg ["labels" ]), padding_index = self .padding_index
43+ )
4344
4445 def predict (self , docs : List [Doc ]) -> Floats2d :
4546 (_ , tag_indices ) = self .model .predict (docs )
@@ -50,16 +51,14 @@ def set_annotations(
5051 docs : List [Doc ],
5152 tag_indices : Floats2d ,
5253 ) -> None :
53- alignments = Alignments (tuple (doc .user_data ["alignment" ] for doc in docs ))
54- tags_batch = alignments .create_char_based_tags (
55- tag_indices .tolist (),
56- label_set = self .label_set ,
57- padding_index = self .padding_index ,
54+ labels = self .label_set .decode (
55+ tag_indices = tag_indices .tolist (),
56+ alignments = tuple (doc .user_data ["alignment" ] for doc in docs ),
5857 )
5958
60- for doc , tags in zip (docs , tags_batch ):
59+ for doc , label in zip (docs , labels ):
6160 ents = []
62- for tag in tags :
61+ for tag in label . tags :
6362 span = doc .char_span (tag .start , tag .start + tag .length , tag .label )
6463 if span :
6564 ents .append (span )
@@ -89,19 +88,17 @@ def update(
8988 losses [self .name ] += loss
9089 return losses
9190
92- def initialize (
93- self , get_examples : Callable , * , nlp : Language , labels : Optional [dict ] = None
94- ) -> None :
91+ def initialize (self , get_examples : Callable , * , nlp : Language ) -> None :
9592 X_small : List [Doc ] = []
96- label : Set [str ] = set ()
93+ labels : List [str ] = []
9794 for example in get_examples ():
9895 if len (X_small ) < 10 :
9996 X_small .append (example .x )
10097 for entity in example .y .ents :
101- if entity .label_ not in label :
102- label . add (entity .label_ )
98+ if entity .label_ not in labels :
99+ labels . append (entity .label_ )
103100
104- self .cfg ["labels" ] = list ( label )
101+ self .cfg ["labels" ] = labels
105102
106103 self .model .initialize (
107104 X = X_small ,
@@ -113,23 +110,32 @@ def get_loss(
113110 ) -> Tuple [float , Floats4d ]:
114111 scores_pt = xp2torch (scores , requires_grad = True )
115112
116- char_based_tags = []
117- temp = []
113+ labels = []
114+ alignments = []
118115 lengths = []
119116 for example in examples :
120- tags = tuple (
121- create_tag (ent .start_char , len (ent .text ), ent .label_ )
122- for ent in example .y .ents
117+ labels .append (
118+ SequenceLabel .from_dict (
119+ tags = [
120+ {
121+ "start" : ent .start_char ,
122+ "end" : ent .end_char ,
123+ "label" : ent .label_ ,
124+ }
125+ for ent in example .y .ents
126+ ],
127+ size = len (example .y .text ),
128+ )
123129 )
124- char_based_tags .append (tags )
125130
126131 alignment = example .x .user_data ["alignment" ]
127- lengths .append (alignment . num_tokens )
128- temp .append (alignment )
132+ alignments .append (alignment )
133+ lengths .append (alignment . target_size )
129134
130- alignments = Alignments (tuple (temp ))
131- tag_bitmap = torch .tensor (
132- alignments .get_tag_bitmap (char_based_tags , self .label_set ),
135+ tag_bitmap = create_tag_bitmap (
136+ label_set = self .label_set ,
137+ labels = tuple (labels ),
138+ alignments = tuple (alignments ),
133139 device = scores_pt .device ,
134140 )
135141
@@ -140,7 +146,7 @@ def get_loss(
140146 )
141147
142148 loss = compute_partially_supervised_loss (
143- scores_pt , tag_bitmap , mask , self .label_set .get_outside_index ()
149+ scores_pt , tag_bitmap , mask , self .label_set .outside_index
144150 )
145151
146152 (grad ,) = torch .autograd .grad (loss , scores_pt )
0 commit comments