tf.SegmentEmbeddingLayer module
- class tf.SegmentEmbeddingLayer.SegmentEmbeddingLayer(*args: Any, **kwargs: Any)
Bases: modelzoo.common.layers.tf.BaseLayer.BaseLayer
Segment embedding layer. Adds segment information to the token embeddings provided as input. Segment information records, for example, which sentence each token belongs to when an input sequence contains multiple sentences, such as the two sentences of a BERT input.
- Parameters
num_segments (int) – Number of encoded segments.
embeddings_regularizer (callable) – Regularizer applied to the segment embedding matrix.
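For the BERT case mentioned above, a sentence pair is usually laid out as `[CLS] A [SEP] B [SEP]`, with segment ID 0 for `[CLS]`, sentence A and the first `[SEP]`, and segment ID 1 for sentence B and the final `[SEP]`, so `num_segments=2`. The minimal pure-Python sketch below builds such IDs; the helper `make_segment_ids` and the example tokens are hypothetical and not part of modelzoo.

```python
from typing import List, Tuple


def make_segment_ids(tokens_a: List[str], tokens_b: List[str]) -> Tuple[List[str], List[int]]:
    """Hypothetical helper: build a BERT-style pair input and its segment IDs.

    Layout: [CLS] A [SEP] B [SEP]. Segment 0 covers [CLS], A and the first
    [SEP]; segment 1 covers B and the final [SEP].
    """
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # +2 accounts for [CLS] and the first [SEP]; +1 for the closing [SEP].
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


tokens, segment_ids = make_segment_ids(["the", "cat"], ["it", "sat"])
print(list(zip(tokens, segment_ids)))
```

Each token gets exactly one ID, so the segment ID sequence always has the same length as the token sequence.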
- build(input_shape)
- call(inputs, segment_ids)
Adds segment embeddings to the input embeddings.
- Parameters
inputs – Tensor of input embeddings.
segment_ids – Segment ID of each token, each in the range [0, num_segments).
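To illustrate what `build` and `call` compute, here is a minimal pure-Python sketch, not the modelzoo implementation. The class `ToySegmentEmbedding`, its `seed` argument, its `regularization_loss` method and its random initializer are all hypothetical. It assumes, as Keras embedding-style layers typically do, that `build` creates a `[num_segments, hidden_size]` table with `hidden_size` taken from the last dimension of `input_shape`, and that `call` looks up each token's segment row and adds it element-wise to that token's embedding.

```python
import random
from typing import Callable, List, Optional, Sequence

Matrix = List[List[float]]


class ToySegmentEmbedding:
    """Hypothetical stand-in for SegmentEmbeddingLayer (lookup-and-add only)."""

    def __init__(self, num_segments: int,
                 embeddings_regularizer: Optional[Callable[[Matrix], float]] = None,
                 seed: int = 0):
        self.num_segments = num_segments
        self.embeddings_regularizer = embeddings_regularizer
        self.seed = seed
        self.table: Optional[Matrix] = None

    def build(self, input_shape: Sequence[int]) -> None:
        # Assumed: hidden size comes from the last input dimension.
        hidden_size = input_shape[-1]
        rng = random.Random(self.seed)
        # One row per segment; small uniform init is an assumption for illustration.
        self.table = [[rng.uniform(-0.05, 0.05) for _ in range(hidden_size)]
                      for _ in range(self.num_segments)]

    def regularization_loss(self) -> float:
        # Hypothetical: apply the optional regularizer to the whole table.
        if self.embeddings_regularizer is None or self.table is None:
            return 0.0
        return self.embeddings_regularizer(self.table)

    def call(self, inputs: List[Matrix], segment_ids: List[List[int]]) -> List[Matrix]:
        # inputs: [batch][seq][hidden]; segment_ids: [batch][seq].
        if self.table is None:
            self.build([len(inputs), len(inputs[0]), len(inputs[0][0])])
        out = []
        for seq_vecs, seq_ids in zip(inputs, segment_ids):
            rows = []
            for vec, sid in zip(seq_vecs, seq_ids):
                if not 0 <= sid < self.num_segments:
                    raise ValueError(f"segment id {sid} out of range")
                # Add the segment's embedding row element-wise to the token embedding.
                rows.append([x + e for x, e in zip(vec, self.table[sid])])
            out.append(rows)
        return out


def l2(table: Matrix) -> float:
    # Simple L2 penalty used as an example regularizer.
    return 0.01 * sum(v * v for row in table for v in row)


layer = ToySegmentEmbedding(num_segments=2, embeddings_regularizer=l2)
inputs = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]  # [batch=1, seq=3, hidden=3]
segment_ids = [[0, 0, 1]]
out = layer.call(inputs, segment_ids)
```

The output keeps the shape of `inputs`. Tokens in the same segment receive the same added vector, which is what lets the model tell the sentences of a pair apart.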