250. Count Univalue Subtrees

Given the root of a binary tree, return the number of uni-value subtrees. A uni-value subtree means all nodes of the subtree have the same value.

Example 1:

Input: root = [5,1,5,5,5,null,5]
Output: 4

Example 2:

Input: root = []
Output: 0

  • code
class Solution:
    """Count the uni-value subtrees of a binary tree.

    A subtree is uni-value when every node in it carries the same ``val``.
    Nodes are expected to expose ``val``, ``left`` and ``right`` attributes.
    """

    def countUnivalSubtrees(self, root):
        """Return how many subtrees of ``root`` are uni-value.

        An empty tree (``root is None``) yields 0. Otherwise the running
        total is kept in ``self.count``, which is reset on every call.
        """
        if root is None:
            return 0
        self.count = 0
        self.is_uni(root)
        return self.count

    def is_uni(self, node):
        """Tell whether the subtree rooted at ``node`` is uni-value.

        As a side effect, ``self.count`` (which must already exist) is
        incremented once for every uni-value subtree found at or below
        ``node``.

        The walk is an explicit-stack post-order traversal rather than a
        recursive one, so a very deep (e.g. fully skewed) tree cannot hit
        the interpreter's recursion limit.
        """
        # Each entry is (node, children_already_scheduled).
        pending = [(node, False)]
        # Verdicts of finished subtrees, in post-order (left before right).
        verdicts = []

        while pending:
            current, children_done = pending.pop()

            if not children_done:
                # Revisit this node after its children; push right first so
                # the left child is processed (and its verdict stored) first.
                pending.append((current, True))
                if current.right is not None:
                    pending.append((current.right, False))
                if current.left is not None:
                    pending.append((current.left, False))
                continue

            uniform = True
            # Verdicts were appended left-then-right, so pop right first.
            # The pop must happen unconditionally to keep the stack in sync.
            if current.right is not None:
                right_ok = verdicts.pop()
                uniform = right_ok and current.right.val == current.val
            if current.left is not None:
                left_ok = verdicts.pop()
                uniform = left_ok and uniform and current.left.val == current.val

            if uniform:
                # Found a uni-value subtree rooted at ``current``.
                self.count += 1
            verdicts.append(uniform)

        # Exactly one verdict remains: the one for the starting node.
        return verdicts.pop()
  • code
class Solution:
    """Count the uni-value subtrees of a binary tree.

    A subtree is uni-value when every node in it carries the same ``val``.
    """

    # The annotation is quoted so that defining this class does not require
    # ``Optional`` / ``TreeNode`` to be resolvable at definition time.
    def countUnivalSubtrees(self, root: "Optional[TreeNode]") -> int:
        """Return how many subtrees of ``root`` are uni-value.

        The running total is stored in ``self.res`` (reset on every call)
        and also returned. An empty tree yields 0.

        The traversal is iterative, so a very deep (e.g. fully skewed)
        tree cannot hit the interpreter's recursion limit.
        """
        self.res = 0
        if root is None:
            return self.res

        # Pass 1: collect nodes parent-before-descendants (pre-order).
        visit_order = []
        todo = [root]
        while todo:
            node = todo.pop()
            visit_order.append(node)
            for child in (node.left, node.right):
                if child is not None:
                    todo.append(child)

        # Pass 2: walk that order backwards so every child is judged before
        # its parent. Verdicts are keyed by id() because node objects are
        # not guaranteed to be hashable; ``visit_order`` keeps them alive.
        is_uniform = {}
        for node in reversed(visit_order):
            verdict = all(
                is_uniform[id(child)] and child.val == node.val
                for child in (node.left, node.right)
                if child is not None
            )
            is_uniform[id(node)] = verdict
            if verdict:
                self.res += 1

        return self.res