Basic statistics concepts for machine learning in Scala Spark

Before applying probability distributions, probability density functions or probability mass functions, we need to understand some basic statistical concepts, most of which we were probably taught in school. We will start by brushing up on these concepts and implementing them in Scala with Spark. As an overview, I will cover mean, median and mode, as well as variance and standard deviation.

1) Mean

This is probably the simplest concept of all: the mean is the average value of a set of discrete numerical values. The formula is simple: (sum of elements) / (number of elements). Let's begin with an example of how to calculate the mean in Spark using Scala.

Follow this tutorial to get started with Apache Spark and Scala; I will use it as a reference throughout.

Now let's download the sample data from here: [Annual enterprise survey: 2018 financial year (provisional) – size bands CSV]

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Main extends App {

  // Stand-alone local Spark session
  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  // Read the downloaded CSV; the first line holds the column names
  val raw = spark.read
    .format("csv")
    .option("header", true)
    .load("/<path to downloaded file>/<name>.csv")

  // Number the rows in year order and drop the first two (unwanted) records
  val data = raw
    .withColumn("rn", row_number().over(Window.orderBy("year")))
    .filter(col("rn") > 2)

  // mean() implicitly casts the string column "value" to a double
  data.select(mean("value") as "mean").show()

  spark.stop()
}

Here I create a stand-alone Spark session, read the CSV into a data frame, add a row number using Window (imported from org.apache.spark.sql.expressions.Window) and remove the unwanted records. The mean itself is generated by data.select(mean("value") as "mean").show().
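To see the formula on its own, here is a minimal plain-Scala sketch with no Spark dependency. The object name MeanSketch and the sample values are illustrative assumptions, not part of the survey data.

```scala
// Plain Scala (no Spark): mean = (sum of elements) / (number of elements)
object MeanSketch {
  def mean(xs: Seq[Double]): Double = {
    require(xs.nonEmpty, "the mean of an empty set is undefined")
    xs.sum / xs.size
  }

  def main(args: Array[String]): Unit = {
    val values = Seq(1.0, 2.0, 3.0, 4.0, 5.0)
    println(mean(values)) // 15.0 / 5 = 3.0
  }
}
```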

2) Median

The median is the middle value of a sorted set. For example, in [1,2,3,4,5] the median is 3, so it is easy when the set has an odd number of values. If the set has an even number of values, e.g. [1,2,3,4], the middle falls between 2 and 3, so the median is (2+3)/2 = 2.5.
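A minimal plain-Scala sketch of the odd/even rule, without Spark. MedianSketch is a hypothetical name, and the inputs are the two example sets [1,2,3,4,5] and [1,2,3,4].

```scala
// Plain Scala (no Spark): median of an odd- or even-sized set
object MedianSketch {
  def median(xs: Seq[Double]): Double = {
    require(xs.nonEmpty, "the median of an empty set is undefined")
    val sorted = xs.sortWith(_ < _).toVector // sort first, Vector gives fast indexing
    val n = sorted.size
    if (n % 2 == 1) sorted(n / 2)                      // odd count: the single middle element
    else (sorted(n / 2 - 1) + sorted(n / 2)) / 2.0     // even count: average of the two middle elements
  }

  def main(args: Array[String]): Unit = {
    println(median(Seq(1.0, 2.0, 3.0, 4.0, 5.0))) // 3.0
    println(median(Seq(1.0, 2.0, 3.0, 4.0)))      // (2 + 3) / 2 = 2.5
  }
}
```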

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.FloatType

object Main extends App {

  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  val raw = spark.read
    .format("csv")
    .option("header", true)
    .load("/<path to downloaded file>/<name>.csv")

  val data = raw
    .withColumn("rn", row_number().over(Window.orderBy("year")))
    .filter(col("rn") > 2)

  // approxQuantile needs a numeric column, so cast the string "value" column first.
  // Array(0.5) requests the 50th percentile; 0.0001 is the relative error.
  val median = data
    .withColumn("units", col("value").cast(FloatType))
    .select("units")
    .stat.approxQuantile("units", Array(0.5), 0.0001)
    .head

  println("==================================")
  println(median)
  println("==================================")

  spark.stop()
}

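The code above fetches the median value. Because the values come from CSV data, the column type is String, and unlike SQL aggregate functions such as mean, the stat functions do not cast the data, so we have to cast it manually. Note that approxQuantile returns an approximate median; setting its relative error argument to 0.0 computes the exact value, at a higher cost.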

3) Mode

The mode is the value that appears most frequently in a set. For example, in [1,1,2,5,6,5] the occurrence counts are [1:2, 2:1, 5:2, 6:1]; both 1 and 5 appear twice, so this set has two modes (it is bimodal). Let's look at sample code to find the mode in Apache Spark.
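A plain-Scala sketch (no Spark) that counts occurrences and returns every value sharing the highest count, so the example set [1,1,2,5,6,5] yields both 1 and 5, since each appears twice. ModeSketch and modes are illustrative names.

```scala
// Plain Scala (no Spark): all modes of a set, so ties are not hidden
object ModeSketch {
  def modes(xs: Seq[Int]): Set[Int] = {
    require(xs.nonEmpty, "the mode of an empty set is undefined")
    // value -> number of occurrences
    val counts: Map[Int, Int] = xs.groupBy(identity).map { case (v, occurrences) => v -> occurrences.size }
    val maxCount = counts.values.max
    // keep every value whose count equals the highest count
    counts.collect { case (v, c) if c == maxCount => v }.toSet
  }

  def main(args: Array[String]): Unit = {
    println(modes(Seq(1, 1, 2, 5, 6, 5)))    // a set containing 1 and 5
    println(modes(Seq(1, 1, 2, 5, 6, 5, 5))) // Set(5)
  }
}
```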

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Main extends App {

  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  val raw = spark.read
    .format("csv")
    .option("header", true)
    .load("/home/ashrith/data/test.csv")

  val data = raw
    .withColumn("rn", row_number().over(Window.orderBy("year")))
    .filter(col("rn") > 2)

  // Drop the non-numeric "C" entries, then count how often each value occurs
  val counts = data
    .filter(col("value") =!= "C")
    .groupBy("value")
    .count()

  // The value with the highest count is the mode
  val mode = counts.orderBy(col("count").desc).first().get(0)
  println(mode)

  spark.stop()
}

The snippet above finds the mode of the data. In Spark this is not a single function call: we first group the values and count the occurrences, then order by count in descending order and fetch the first row. If several values share the highest count, first() returns only one of them.

4) Variance

Variance measures how far a data set is spread out. To calculate the variance:

  1. Find the mean of the data set.
  2. Subtract the mean from each number in the data set and square the difference.
  3. Work out the average of the squared differences from step 2.
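The three steps can be sketched in plain Scala (no Spark). The names and sample values are illustrative assumptions. The sketch also shows the sample variance, which divides the same sum of squares by n - 1 instead of n, because Spark's variance function is an alias for var_samp.

```scala
// Plain Scala (no Spark): population and sample variance
object VarianceSketch {
  // Step 1: mean of the data set
  private def mean(xs: Seq[Double]): Double = xs.sum / xs.size

  // Steps 2 and 3: square each deviation from the mean, then average the squares (divide by n)
  def populationVariance(xs: Seq[Double]): Double = {
    require(xs.nonEmpty, "the variance of an empty set is undefined")
    val m = mean(xs)
    xs.map(x => (x - m) * (x - m)).sum / xs.size
  }

  // Same sum of squared deviations divided by n - 1 (what Spark's variance/var_samp returns)
  def sampleVariance(xs: Seq[Double]): Double = {
    require(xs.size > 1, "the sample variance needs at least two values")
    val m = mean(xs)
    xs.map(x => (x - m) * (x - m)).sum / (xs.size - 1)
  }

  def main(args: Array[String]): Unit = {
    val values = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
    println(populationVariance(values)) // mean 5.0, squared deviations sum to 32.0, 32 / 8 = 4.0
    println(sampleVariance(values))     // 32 / 7
  }
}
```

On small data sets the two versions differ noticeably; as n grows the difference shrinks.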

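To find the variance in Spark using Scala (note that Spark's variance function computes the sample variance, dividing by n - 1 rather than n; var_pop follows the steps above exactly):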

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Main extends App {

  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  val raw = spark.read
    .format("csv")
    .option("header", true)
    .load("/home/ashrith/data/test.csv")

  val data = raw
    .withColumn("rn", row_number().over(Window.orderBy("year")))
    .filter(col("rn") > 2)

  // Drop the non-numeric "C" entries and aggregate the variance of "value"
  data.filter(col("value") =!= "C").agg(variance(col("value"))).show()

  spark.stop()
}

5) Standard Deviation

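Standard deviation is a quantity expressing how much the members of a group differ from the mean value of the group; it is the square root of the variance. It is very useful for finding outliers in a histogram: outliers are values at an abnormal distance from the rest of the group, and they occur rarely. For example, if you take the average household income in a region, billionaires are outliers. In Spark, stddev is an alias for stddev_samp (the sample standard deviation); stddev_pop gives the population version.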
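A plain-Scala sketch (no Spark) of the population standard deviation and of a simple outlier check. Flagging values more than k standard deviations from the mean is one common heuristic, assumed here for illustration; the threshold k = 2.0 and the income figures are hypothetical.

```scala
// Plain Scala (no Spark): standard deviation and a k-sigma outlier check
object StdDevSketch {
  private def mean(xs: Seq[Double]): Double = xs.sum / xs.size

  // Population standard deviation: square root of the population variance
  def stdDevPop(xs: Seq[Double]): Double = {
    require(xs.nonEmpty, "the standard deviation of an empty set is undefined")
    val m = mean(xs)
    math.sqrt(xs.map(x => (x - m) * (x - m)).sum / xs.size)
  }

  // Values more than k standard deviations away from the mean (k is an illustrative threshold)
  def outliers(xs: Seq[Double], k: Double): Seq[Double] = {
    val m = mean(xs)
    val sd = stdDevPop(xs)
    xs.filter(x => math.abs(x - m) > k * sd)
  }

  def main(args: Array[String]): Unit = {
    println(stdDevPop(Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0))) // sqrt(4.0) = 2.0
    // Hypothetical household incomes with one extreme value
    val incomes = Seq(40.0, 42.0, 45.0, 47.0, 50.0, 52.0, 55.0, 1000.0)
    println(outliers(incomes, 2.0)) // List(1000.0)
  }
}
```

Note that the extreme value itself inflates the mean and the standard deviation, so with very few data points this check can miss outliers.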

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Main extends App {

  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  val raw = spark.read
    .format("csv")
    .option("header", true)
    .load("/home/ashrith/data/test.csv")

  val data = raw
    .withColumn("rn", row_number().over(Window.orderBy("year")))
    .filter(col("rn") > 2)

  // Drop the non-numeric "C" entries and aggregate the standard deviation of "value"
  data.filter(col("value") =!= "C").agg(stddev(col("value"))).show()

  spark.stop()
}
