def surface_area(grid):
    """Return the total exposed surface area of unit-cube stacks on a grid.

    ``grid[i][j]`` is the number of cubes stacked on cell (i, j). Every cell
    with a positive height contributes a top and a bottom face. On each of the
    four sides, a stack exposes the part of itself that is taller than its
    neighbour. Outside the grid the neighbour height counts as 0.

    Args:
        grid: list of equal-length rows of integer heights (assumed
            non-negative and rectangular -- TODO confirm against input spec).

    Returns:
        The surface area as an int. An empty grid (no rows, or rows with no
        columns) has area 0.
    """
    if not grid or not grid[0]:
        return 0
    rows, cols = len(grid), len(grid[0])
    total = 0
    for i in range(rows):
        for j in range(cols):
            height = grid[i][j]
            if height > 0:
                total += 2  # top and bottom faces
            for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                ni, nj = i + di, j + dj
                inside = 0 <= ni < rows and 0 <= nj < cols
                neighbour = grid[ni][nj] if inside else 0
                # Only the part of this stack above its neighbour is visible;
                # summed over both sides of an edge this yields |a - b|.
                total += max(height - neighbour, 0)
    return total


def main():
    """Read 'r c' and then r rows of heights from stdin; print the area."""
    r, c = map(int, input().split())
    # Truncate each row to the declared width so every edge uses the same c.
    grid = [list(map(int, input().split()))[:c] for _ in range(r)]
    print(surface_area(grid))


if __name__ == "__main__":
    main()