tf.layers.SegmentEmbeddingLayer module

class tf.layers.SegmentEmbeddingLayer.SegmentEmbeddingLayer(*args: Any, **kwargs: Any)

Bases: modelzoo.common.tf.layers.BaseLayer.BaseLayer

Segment embedding layer. Adds segment information to the token embeddings provided as input. For example, when an input sequence contains multiple sentences (two in the case of the BERT model), the segment information indicates which sentence each token belongs to.

Parameters
  • num_segments (int) – Number of distinct segments to encode (for example, 2 for BERT sentence pairs).

  • embeddings_regularizer (callable) – Embeddings regularizer.

build(input_shape)
call(inputs, segment_ids)

Add segment embedding to inputs.

Parameters
  • inputs – Tensor of input embeddings.

  • segment_ids – Segment IDs.
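The idea behind call can be illustrated with a minimal plain-Python sketch. This is not the layer's TensorFlow implementation: the helper names make_segment_table and add_segment_embeddings, the random table initialization, and the single-sequence (unbatched) shapes are all assumptions made for illustration. It only shows the general technique of looking up one embedding vector per segment ID and adding it to the matching token embedding.

```python
import random


def make_segment_table(num_segments, hidden_size, seed=0):
    """Hypothetical stand-in for the layer's learned embedding table.

    Returns a list of num_segments vectors, each of length hidden_size.
    """
    rng = random.Random(seed)
    return [
        [rng.uniform(-0.1, 0.1) for _ in range(hidden_size)]
        for _ in range(num_segments)
    ]


def add_segment_embeddings(inputs, segment_ids, table):
    """Add the segment embedding for each token to that token's embedding.

    inputs:      list of seq_len token vectors, each of length hidden_size.
    segment_ids: list of seq_len integers in [0, num_segments).
    table:       list of num_segments vectors of length hidden_size.
    """
    if len(inputs) != len(segment_ids):
        raise ValueError("inputs and segment_ids must have the same length")
    outputs = []
    for token_vec, seg in zip(inputs, segment_ids):
        if not 0 <= seg < len(table):
            raise ValueError(f"segment id {seg} is out of range")
        seg_vec = table[seg]
        # Elementwise sum of the token embedding and its segment embedding.
        outputs.append([x + s for x, s in zip(token_vec, seg_vec)])
    return outputs


# Two segments, as in a BERT-style sentence pair (assumed sizes).
table = make_segment_table(num_segments=2, hidden_size=4)
inputs = [[0.0] * 4 for _ in range(5)]   # five tokens, zero embeddings
segment_ids = [0, 0, 0, 1, 1]            # first three tokens in sentence A
outputs = add_segment_embeddings(inputs, segment_ids, table)
```

Because the example token embeddings are all zeros, each output row equals the table row selected by its segment ID, which makes the lookup-and-add behavior easy to check by eye. The output has the same shape as inputs.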