# Licensed under a 3-clause BSD style license - see LICENSE.rst

import matplotlib.transforms as mtransforms
import numpy as np
from matplotlib import rcParams
from matplotlib.text import Text

from .frame import RectangularFrame


class AxisLabels(Text):
    """A single text artist that is re-positioned and drawn once per axis.

    The label text, font and colour are those of the underlying
    `~matplotlib.text.Text`; this class adds the logic that decides on which
    axes of ``frame`` the label is drawn and where it is placed.

    Parameters
    ----------
    frame
        Frame object, indexed by axis name, whose entries provide
        ``_halfway_x_y_angle()``.
    minpad
        Padding between the label and the axis / tick labels, in units of the
        label font size. Either a single number or an object indexable by
        axis name.
    *args, **kwargs
        Passed on to `~matplotlib.text.Text`. ``weight``, ``size`` and
        ``color`` default to the corresponding ``axes.label*`` rcParams.
    """

    def __init__(self, frame, minpad=1, *args, **kwargs):
        # Use rcParams if the following parameters were not specified explicitly
        kwargs.setdefault("weight", rcParams["axes.labelweight"])
        kwargs.setdefault("size", rcParams["axes.labelsize"])
        kwargs.setdefault("color", rcParams["axes.labelcolor"])

        self._frame = frame

        super().__init__(*args, **kwargs)

        self.set_clip_on(True)
        self.set_visible_axes("all")
        self.set_ha("center")
        self.set_va("center")

        self._minpad = minpad
        self._visibility_rule = "labels"

    def get_minpad(self, axis):
        """Return the padding factor for ``axis``.

        If the stored value cannot be indexed (e.g. it is a plain number),
        that value is returned for every axis.
        """
        try:
            return self._minpad[axis]
        except TypeError:
            return self._minpad

    def set_visible_axes(self, visible_axes):
        """Set the axes to label: ``"all"`` or an iterable of axis names."""
        self._visible_axes = visible_axes

    def get_visible_axes(self):
        """Return the list of axis names on which the label may be drawn.

        ``"all"`` expands to every key of the frame; otherwise the stored
        names are filtered down to those present in the frame, with ``"#"``
        always kept.
        """
        if self._visible_axes == "all":
            return list(self._frame.keys())
        return [x for x in self._visible_axes if x in self._frame or x == "#"]

    def set_minpad(self, minpad):
        """Set the padding factor (a number, or an object indexed by axis)."""
        self._minpad = minpad

    def set_visibility_rule(self, value):
        """Set when the label is shown: ``"always"``, ``"labels"`` or ``"ticks"``.

        Raises
        ------
        ValueError
            If ``value`` is not one of the allowed rules.
        """
        allowed = ["always", "labels", "ticks"]
        if value not in allowed:
            raise ValueError(
                f"Axis label visibility rule must be one of {' / '.join(allowed)}"
            )

        self._visibility_rule = value

    def get_visibility_rule(self):
        """Return the current visibility rule."""
        return self._visibility_rule

    def draw(
        self,
        renderer,
        bboxes,
        ticklabels_bbox,
        coord_ticklabels_bbox,
        ticks_locs,
        visible_ticks,
    ):
        """Position and draw the label on every visible axis.

        Parameters
        ----------
        renderer
            Renderer used to convert the font size to pixels and to draw.
        bboxes : list
            The window extent of each label drawn is appended to this list.
        ticklabels_bbox
            Nested mapping whose innermost values are lists of tick label
            bounding boxes; all of them are flattened into a single list.
        coord_ticklabels_bbox
            Mapping from axis name to a one-element list. For rectangular
            frames the entry of each processed axis is overwritten in place
            with the union of all tick label bounding boxes (or ``[None]``).
        ticks_locs
            Mapping from axis name to tick locations; with the ``"ticks"``
            rule an axis with no ticks gets no label.
        visible_ticks
            Container of the axis names that have visible tick labels.
        """
        if not self.get_visible():
            return

        text_size = renderer.points_to_pixels(self.get_size())

        # Flatten the bboxes for all coords and all axes
        ticklabels_bbox_list = []
        for bbcoord in ticklabels_bbox.values():
            for bbaxis in bbcoord.values():
                ticklabels_bbox_list += bbaxis

        # The rule does not change while drawing, so look it up once.
        visibility_rule = self.get_visibility_rule()

        for axis in self.get_visible_axes():
            if axis == "#":
                continue

            if visibility_rule == "ticks":
                if not ticks_locs[axis]:
                    continue
            elif visibility_rule == "labels":
                if not coord_ticklabels_bbox:
                    continue

            padding = text_size * self.get_minpad(axis)

            # Find position of the axis label. For now we pick the mid-point
            # along the path but in future we could allow this to be a
            # parameter.
            x, y, normal_angle = self._frame[axis]._halfway_x_y_angle()

            label_angle = (normal_angle - 90.0) % 360.0
            # Rotate by half a turn for angles in (135, 225) - presumably so
            # that the text is never rendered upside down.
            if 135 < label_angle < 225:
                label_angle += 180
            self.set_rotation(label_angle)

            # Find label position by looking at the bounding box of ticks'
            # labels and the image. It sets the default padding at 1 times the
            # axis label font size which can also be changed by setting
            # the minpad parameter.

            if isinstance(self._frame, RectangularFrame):
                if ticklabels_bbox_list and ticklabels_bbox_list[0] is not None:
                    coord_ticklabels_bbox[axis] = [
                        mtransforms.Bbox.union(ticklabels_bbox_list)
                    ]
                else:
                    coord_ticklabels_bbox[axis] = [None]

                visible = (
                    axis in visible_ticks
                    and coord_ticklabels_bbox[axis][0] is not None
                )

                # When tick labels are visible on this axis, start from their
                # outer edge; otherwise start from the axis mid-point. The
                # padding is applied in both cases.
                if axis == "l":
                    if visible:
                        x = coord_ticklabels_bbox[axis][0].xmin
                    x = x - padding

                elif axis == "r":
                    if visible:
                        x = coord_ticklabels_bbox[axis][0].x1
                    x = x + padding

                elif axis == "b":
                    if visible:
                        y = coord_ticklabels_bbox[axis][0].ymin
                    y = y - padding

                elif axis == "t":
                    if visible:
                        y = coord_ticklabels_bbox[axis][0].y1
                    y = y + padding

            else:  # arbitrary axis
                # Push the label out along the axis normal.
                x = x + np.cos(np.radians(normal_angle)) * (padding + text_size * 1.5)
                y = y + np.sin(np.radians(normal_angle)) * (padding + text_size * 1.5)

            self.set_position((x, y))
            super().draw(renderer)

            bb = super().get_window_extent(renderer)
            bboxes.append(bb)