register_kl

paddle.distribution.register_kl(cls_p, cls_q)

Decorator for registering a KL divergence implementation function.

The kl_divergence(p, q) function searches the concrete implementation functions registered by register_kl, using a multiple-dispatch pattern on the types of p and q. If a matching implementation function is found, kl_divergence returns its result; otherwise, it raises a NotImplementedError exception. Users can register their own implementation functions with this decorator.

Parameters
  • cls_p (Distribution) – The Distribution type of instance p. Must be a subclass of Distribution.

  • cls_q (Distribution) – The Distribution type of instance q. Must be a subclass of Distribution.

Examples

>>> import paddle

>>> @paddle.distribution.register_kl(paddle.distribution.Beta, paddle.distribution.Beta)
... def kl_beta_beta(p, q):
...     # kl_divergence(p, q) calls this function with both distribution instances
...     pass  # insert implementation here
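For reference, an implementation registered for the (Beta, Beta) pair in the example above would typically return the standard closed-form KL divergence between two Beta distributions. This formula is the textbook result, not taken from Paddle's source. Writing p = Beta(α_p, β_p) and q = Beta(α_q, β_q), with B the Beta function and ψ the digamma function:

```latex
\mathrm{KL}(p \,\|\, q)
  = \ln B(\alpha_q, \beta_q) - \ln B(\alpha_p, \beta_p)
  + (\alpha_p - \alpha_q)\,\psi(\alpha_p)
  + (\beta_p - \beta_q)\,\psi(\beta_p)
  + (\alpha_q - \alpha_p + \beta_q - \beta_p)\,\psi(\alpha_p + \beta_p)
```

When p and q have identical parameters, every term cancels and the divergence is 0, as expected.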